Source code for epyt_control.signal_processing.state_forecasting.surrogates

  1"""
  2This module contains different state transition surrogate models.
  3"""
  4from abc import abstractmethod
  5from typing import Callable
  6import numpy as np
  7from epyt_flow.topology import NetworkTopology
  8from epyt_flow.simulation import ScadaData
  9
 10from ...envs import RlEnv
 11
 12
class StateTransitionModel:
    """
    Abstract base class of state transition models used in a surrogate -- i.e. a deep neural
    network approximating the state transition function.

    Subclasses are expected to override :meth:`init`, :meth:`fit`, :meth:`partial_fit`, and
    :meth:`predict`; the default implementations raise :class:`NotImplementedError`.
    """
    # NOTE(review): this class does not use `abc.ABCMeta` (it neither sets the metaclass nor
    # derives from `abc.ABC`), so the `@abstractmethod` markers below are not enforced --
    # Python will happily instantiate a subclass that misses one of them. Switching to
    # `abc.ABC` would enforce it, but could break existing subclasses that, e.g., never
    # implement `partial_fit`.

    @abstractmethod
    def init(self, wdn_topology: NetworkTopology, input_size: int, state_size: int) -> None:
        """
        Initializes the model.

        Parameters
        ----------
        wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
            Information about the topology of the WDN.
        input_size : `int`
            Dimensionality of the input -- i.e. current state + time varying inputs that are
            relevant for the state transition (incl. control inputs).
        state_size : `int`
            Dimensionality of the state to be predicted.
        """
        raise NotImplementedError()

    @abstractmethod
    def fit(self, cur_state: np.ndarray, next_time_varying_quantity: np.ndarray,
            next_state: np.ndarray) -> None:
        """
        Fits the model to given state transition data.

        Parameters
        ----------
        cur_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current state of the system.
        next_time_varying_quantity : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time varying events (incl. control signals) that are relevant for evolving the state.
        next_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state -- is to be predicted based on the other two arguments.
        """
        raise NotImplementedError()

    @abstractmethod
    def partial_fit(self, cur_state: np.ndarray, next_time_varying_quantity: np.ndarray,
                    next_state: np.ndarray) -> None:
        """
        Performs a partial fit of the state transition model to given data.

        Parameters
        ----------
        cur_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current state of the system.
        next_time_varying_quantity : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time varying events (incl. control signals) that are relevant for evolving the state.
        next_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state -- is to be predicted based on the other two arguments.
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(self, cur_state: np.ndarray,
                next_time_varying_quantity: np.ndarray) -> np.ndarray:
        """
        Predicts the next state based on the current state and
        time varying events such as control signals.

        Parameters
        ----------
        cur_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current state.
        next_time_varying_quantity : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time varying events (incl. control signals) that are relevant for evolving the state.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state.
        """
        raise NotImplementedError()
class StateTransitionSurrogate:
    """
    Base class of state transition surrogates.

    Parameters
    ----------
    wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
        Information about the topology of the WDN.
    n_actuators : `int`
        Number of actuators -- i.e. control inputs.
    """
    def __init__(self, wdn_topology: NetworkTopology, n_actuators: int):
        self._wdn_topology = wdn_topology
        self._n_actuators = n_actuators

    # NOTE(review): this class does not use `abc.ABCMeta`, so `@abstractmethod` is not
    # enforced; a subclass that does not override `fit_to_scada` raises NotImplementedError
    # only when the method is called.
    @abstractmethod
    def fit_to_scada(self, scada_data: ScadaData, control_actions: np.ndarray) -> None:
        """
        Fits the state transition surrogate to given `SCADA data <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_.

        Parameters
        ----------
        scada_data : `epyt_flow.simulation.scada_data.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            SCADA data.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Control signals at every time step.
        """
        raise NotImplementedError()

    def fit_to_env(self, env: RlEnv, n_max_iter: int = None,
                   policy: Callable[[np.ndarray], np.ndarray] = None) -> None:
        """
        Fits the state transition surrogate to a given control environment.

        The environment is reset, stepped with actions from `policy` (or random actions),
        and closed afterwards -- also if an error occurs. The SCADA data reported in
        ``info["scada_data"]`` of every step is concatenated and passed, together with all
        applied actions, to :meth:`fit_to_scada`.

        Parameters
        ----------
        env : :class:`~epyt_control.envs.rl_env.RlEnv`
            Control environment.
        n_max_iter : `int`, optional
            Maximum number of iterations (i.e. environment steps) used for data collection.
            If None, data is collected until the environment terminates or is truncated.
            Note that data collection always stops if the environment terminates or is
            truncated.

            The default is None.
        policy : `Callable[[numpy.ndarray], numpy.ndarray]`, optional
            A policy for mapping observations to actions (i.e. control signals) -- will be
            applied at each time step.
            If None, random actions are sampled from the action space.

            The default is None.

        Notes
        -----
        The action of the final (terminating/truncating) step is recorded, but the SCADA data
        of that step is not. If the very first step already ends the episode, no SCADA data is
        collected and None is passed on to :meth:`fit_to_scada`.
        """
        # Run the environment and collect SCADA data
        scada_data = None
        control_actions = []

        n_iter = 0
        try:
            obs, _ = env.reset()
            while n_max_iter is None or n_iter < n_max_iter:
                n_iter += 1

                action = policy(obs) if policy is not None else env.action_space.sample()
                control_actions.append(action)
                obs, _, terminated, truncated, info = env.step(action)
                # Stop on termination OR truncation -- stepping past either is not defined
                # by the Gymnasium API. Test truthiness instead of `is True`, because
                # environments may report a numpy bool.
                if terminated or truncated:
                    break

                current_scada_data = info["scada_data"]
                if scada_data is None:
                    scada_data = current_scada_data
                else:
                    scada_data.concatenate(current_scada_data)
        finally:
            env.close()

        # Fit state transition surrogate model
        self.fit_to_scada(scada_data, np.array(control_actions))
class HydraulicStateTransitionSurrogate(StateTransitionSurrogate):
    """
    Surrogate for the hydraulic state transition function.

    The hydraulic state is the pressure at every node followed by the flow rate at every
    link. The next state is predicted from the current state, the demand at every node for
    the next time step, and the current control signal.

    Parameters
    ----------
    wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
        Information about the topology of the WDN.
    n_actuators : `int`
        Dimensionality of the control signal.
    state_transition_model : :class:`StateTransitionModel`
        State transition model which is used as an approximation of the true state transition
        function. Usually, a neural network is used.
    """
    def __init__(self, wdn_topology: NetworkTopology, n_actuators: int,
                 state_transition_model: StateTransitionModel):
        super().__init__(wdn_topology, n_actuators)

        # State: one value per node (pressure) + one value per link (flow).
        state_size = wdn_topology.get_number_of_nodes() + wdn_topology.get_number_of_links()
        # Input: current state + next demand at every node + control signal.
        input_size = state_size + wdn_topology.get_number_of_nodes() + n_actuators
        state_transition_model.init(self._wdn_topology,
                                    input_size=input_size,
                                    state_size=state_size)
        self._state_transition_model = state_transition_model

    def fit_to_scada(self, scada_data: ScadaData, control_actions: np.ndarray = None) -> None:
        """
        Fits the surrogate to the hydraulic state transitions contained in given SCADA data.

        Consecutive time steps t and t+1 form one training sample: the state (pressures and
        flows) at t, together with the demands at t+1 and the control signal at t, is mapped
        to the state at t+1.

        Parameters
        ----------
        scada_data : `epyt_flow.simulation.scada_data.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            SCADA data.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signals at every time step -- one row per time step. A 1-dimensional
            array (or list) is interpreted as the signal of a single actuator.
            Must be given if 'n_actuators' > 0.

            The default is None.

        Raises
        ------
        TypeError
            If `scada_data` is not an instance of ScadaData.
        ValueError
            If `control_actions` is None although 'n_actuators' > 0.
        """
        if not isinstance(scada_data, ScadaData):
            raise TypeError("'scada_data' must be an instance of " +
                            f"'epyt_flow.simulation.ScadaData' but not of '{type(scada_data)}'")
        if self._n_actuators > 0 and control_actions is None:
            raise ValueError("'control_actions' can not be None if 'n_actuators' > 0")

        X_pressure = scada_data.get_data_pressures()
        X_flows = scada_data.get_data_flows()
        X_demands = scada_data.get_data_demands()
        X_controls = control_actions
        if X_controls is not None:
            X_controls = np.asarray(X_controls)
            if X_controls.ndim == 1:
                # A single actuator given as a flat vector -> one column per time step.
                X_controls = X_controls.reshape(-1, 1)
        n_time_steps = X_pressure.shape[0]

        # Pair time step t (inputs) with time step t+1 (targets).
        cur_state = np.concatenate((X_pressure[:n_time_steps-1, :],
                                    X_flows[:n_time_steps-1, :]), axis=1)
        if X_controls is not None:
            next_time_varying_quantity = np.concatenate((X_demands[1:, :],
                                                         X_controls[:n_time_steps-1, :]), axis=1)
        else:
            next_time_varying_quantity = X_demands[1:, :]
        next_state = np.concatenate((X_pressure[1:, :], X_flows[1:, :]), axis=1)

        self._state_transition_model.fit(cur_state, next_time_varying_quantity, next_state)

    def __call__(self, cur_pressure: np.ndarray, cur_flow: np.ndarray,
                 next_demand: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """Alias for :meth:`predict`."""
        return self.predict(cur_pressure, cur_flow, next_demand, control_actions)

    def predict(self, cur_pressure: np.ndarray, cur_flow: np.ndarray,
                next_demand: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """
        Predicts how the current state evolves for the next time step.

        Parameters
        ----------
        cur_pressure : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current pressure at every node.
        cur_flow : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current flow rate at every link.
        next_demand : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Demand at every node for the next time step.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signal at the current time step.
            Can only be omitted if 'n_actuators' == 0.

            The default is None.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state.

        Raises
        ------
        ValueError
            If `control_actions` is None although 'n_actuators' > 0.
        """
        if self._n_actuators > 0 and control_actions is None:
            raise ValueError("'control_actions' can not be None if 'n_actuators' > 0")

        # Inputs are joined along axis 1, so they must be (at least) 2-D -- presumably one
        # row per sample, mirroring the layout used in `fit_to_scada`.
        X_cur_state = np.concatenate((cur_pressure, cur_flow), axis=1)
        if control_actions is not None:
            X_control = np.concatenate((next_demand, control_actions), axis=1)
        else:
            X_control = next_demand

        return self._state_transition_model.predict(X_cur_state, X_control)
class WaterQualityStateTransitionSurrogate(StateTransitionSurrogate):
    """
    Surrogate for the quality (e.g. water age, chemical concentration) state transition function.

    The quality state is the quality at every node followed by the quality at every link.
    The next state is predicted from the current state, the flow rate at every link for the
    next time step, and the current control signal.

    Parameters
    ----------
    wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
        Information about the topology of the WDN.
    n_actuators : `int`
        Dimensionality of the control signal.
    state_transition_model : :class:`StateTransitionModel`
        State transition model which is used as an approximation of the true state transition
        function. Usually, a neural network is used.
    """
    def __init__(self, wdn_topology: NetworkTopology, n_actuators: int,
                 state_transition_model: StateTransitionModel):
        super().__init__(wdn_topology, n_actuators)

        # State: quality at every node + quality at every link.
        state_size = wdn_topology.get_number_of_nodes() + wdn_topology.get_number_of_links()
        # Input: current state + next flow at every link + control signal.
        input_size = state_size + wdn_topology.get_number_of_links() + n_actuators
        state_transition_model.init(self._wdn_topology,
                                    input_size=input_size,
                                    state_size=state_size)
        self._state_transition_model = state_transition_model

    def fit_to_scada(self, scada_data: ScadaData, control_actions: np.ndarray = None) -> None:
        """
        Fits the surrogate to the quality state transitions contained in given SCADA data.

        Consecutive time steps t and t+1 form one training sample: the state (node and link
        quality) at t, together with the flows at t+1 and the control signal at t, is mapped
        to the state at t+1.

        Parameters
        ----------
        scada_data : `epyt_flow.simulation.scada_data.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            SCADA data.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signals at every time step -- one row per time step. A 1-dimensional
            array (or list) is interpreted as the signal of a single actuator.
            Must be given if 'n_actuators' > 0.

            The default is None.

        Raises
        ------
        TypeError
            If `scada_data` is not an instance of ScadaData.
        ValueError
            If `control_actions` is None although 'n_actuators' > 0.
        """
        if not isinstance(scada_data, ScadaData):
            raise TypeError("'scada_data' must be an instance of " +
                            f"'epyt_flow.simulation.ScadaData' but not of '{type(scada_data)}'")
        if self._n_actuators > 0 and control_actions is None:
            raise ValueError("'control_actions' can not be None if 'n_actuators' > 0")

        X_flows = scada_data.get_data_flows()
        X_nodes_quality = scada_data.get_data_nodes_quality()
        X_links_quality = scada_data.get_data_links_quality()
        X_controls = control_actions
        if X_controls is not None:
            X_controls = np.asarray(X_controls)
            if X_controls.ndim == 1:
                # A single actuator given as a flat vector -> one column per time step.
                X_controls = X_controls.reshape(-1, 1)
        n_time_steps = X_flows.shape[0]

        # Pair time step t (inputs) with time step t+1 (targets).
        cur_state = np.concatenate((X_nodes_quality[:n_time_steps-1, :],
                                    X_links_quality[:n_time_steps-1, :]), axis=1)
        if X_controls is not None:
            next_time_varying_quantity = np.concatenate((X_flows[1:, :],
                                                         X_controls[:n_time_steps-1, :]), axis=1)
        else:
            next_time_varying_quantity = X_flows[1:, :]
        next_state = np.concatenate((X_nodes_quality[1:, :], X_links_quality[1:, :]), axis=1)

        self._state_transition_model.fit(cur_state, next_time_varying_quantity, next_state)

    def __call__(self, cur_node_quality: np.ndarray, cur_link_quality: np.ndarray,
                 next_flow: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """Alias for :meth:`predict`."""
        return self.predict(cur_node_quality, cur_link_quality,
                            next_flow, control_actions)

    def predict(self, cur_node_quality: np.ndarray, cur_link_quality: np.ndarray,
                next_flow: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """
        Predicts the next state (i.e. quality everywhere) based on the current state of the
        system, the next flow, and control signals.

        Parameters
        ----------
        cur_node_quality : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current quality at every node.
        cur_link_quality : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current quality at every link.
        next_flow : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Flow rate at every link for the next time step.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signal at the current time step.
            Can only be omitted if 'n_actuators' == 0.

            The default is None.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state.

        Raises
        ------
        ValueError
            If `control_actions` is None although 'n_actuators' > 0.
        """
        if self._n_actuators > 0 and control_actions is None:
            raise ValueError("'control_actions' can not be None if 'n_actuators' > 0")

        # Inputs are joined along axis 1, so they must be (at least) 2-D -- presumably one
        # row per sample, mirroring the layout used in `fit_to_scada`.
        X_cur_state = np.concatenate((cur_node_quality, cur_link_quality), axis=1)
        if control_actions is not None:
            X_control = np.concatenate((next_flow, control_actions), axis=1)
        else:
            X_control = next_flow

        return self._state_transition_model.predict(X_cur_state, X_control)