1"""
2This module contains different state transition surrogate models.
3"""
4from abc import abstractmethod
5from typing import Callable
6import numpy as np
7from epyt_flow.topology import NetworkTopology
8from epyt_flow.simulation import ScadaData
9
10from ...envs import RlEnv
11
12
[docs]
class StateTransitionModel:
    """
    Interface of state transition models used inside a surrogate -- e.g. a deep neural network
    approximating the state transition function.

    Every method of this class raises a `NotImplementedError` and has to be
    overridden by a concrete model.

    Notes
    -----
    The methods are tagged with `abc.abstractmethod`, but because this class does not use
    `abc.ABCMeta` as its metaclass the tag is informational only --
    instantiating the class (or an incomplete sub-class) is not prevented.
    """

    @abstractmethod
    def init(self, wdn_topology: NetworkTopology, input_size: int, state_size: int) -> None:
        """
        Sets up the model for a given network and given input/output dimensionalities.

        Parameters
        ----------
        wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
            Topology of the water distribution network (WDN).
        input_size : `int`
            Dimensionality of the model input -- i.e. current state plus all time varying
            quantities (incl. control inputs) that are relevant for the state transition.
        state_size : `int`
            Dimensionality of the state that is to be predicted.
        """
        raise NotImplementedError

    @abstractmethod
    def fit(self, cur_state: np.ndarray, next_time_varying_quantity: np.ndarray,
            next_state: np.ndarray) -> None:
        """
        Fits the model to the given state transition data.

        Parameters
        ----------
        cur_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current state of the system.
        next_time_varying_quantity : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time varying quantities (incl. control signals) that are relevant
            for evolving the state.
        next_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state -- the target that is to be predicted from the other two arguments.
        """
        raise NotImplementedError

    @abstractmethod
    def partial_fit(self, cur_state: np.ndarray, next_time_varying_quantity: np.ndarray,
                    next_state: np.ndarray) -> None:
        """
        Performs a partial fit of the model to the given state transition data.

        Parameters
        ----------
        cur_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current state of the system.
        next_time_varying_quantity : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time varying quantities (incl. control signals) that are relevant
            for evolving the state.
        next_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state -- the target that is to be predicted from the other two arguments.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, cur_state: np.ndarray,
                next_time_varying_quantity: np.ndarray) -> np.ndarray:
        """
        Predicts the next state from the current state and the
        time varying quantities such as control signals.

        Parameters
        ----------
        cur_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current state.
        next_time_varying_quantity : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time varying quantities (incl. control signals) that are relevant
            for evolving the state.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state.
        """
        raise NotImplementedError
89
90
[docs]
class StateTransitionSurrogate():
    """
    Base class of state transition surrogates.

    Parameters
    ----------
    wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
        Information about the topology of the WDN.
    n_actuators : `int`
        Number of actuators -- i.e. control inputs.
    """
    def __init__(self, wdn_topology: NetworkTopology, n_actuators: int):
        self._wdn_topology = wdn_topology
        self._n_actuators = n_actuators

    @abstractmethod
    def fit_to_scada(self, scada_data: ScadaData, control_actions: np.ndarray) -> None:
        """
        Fits the state transition surrogate to given `SCADA data <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_.

        Must be implemented by sub-classes -- this implementation always
        raises a `NotImplementedError`.

        Parameters
        ----------
        scada_data : `epyt_flow.simulation.scada_data.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            SCADA data.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Control signals at every time step.
        """
        raise NotImplementedError()

    def fit_to_env(self, env: RlEnv, n_max_iter: int = None,
                   policy: Callable[[np.ndarray], np.ndarray] = None) -> None:
        """
        Fits the state transition surrogate to a given control environment.

        The environment is reset and then stepped repeatedly. The SCADA data reported in the
        `info` dictionary of each step (key "scada_data") and the applied actions are recorded
        and finally passed to :func:`fit_to_scada`.
        The environment is always closed afterwards -- even if the data collection fails.

        Parameters
        ----------
        env : :class:`~epyt_control.envs.rl_env.RlEnv`
            Control environment.
        n_max_iter : `int`, optional
            Maximum number of iterations used for data collection.
            Note that data collection stops earlier if the environment terminates.
            If None, data is collected until the environment terminates.

            The default is None.
        policy : `Callable[[numpy.ndarray], numpy.ndarray]`, optional
            A policy for mapping observations to actions (i.e. control signals) --
            will be applied at each time step.
            If None, random actions are sampled from the action space.

            The default is None.

        Notes
        -----
        The action of the terminating step is recorded although no SCADA data is recorded
        for that step. If no step yields any data (e.g. `n_max_iter=0`),
        `None` is passed as SCADA data to :func:`fit_to_scada`.
        """
        # Run the environment and collect SCADA data
        scada_data = None
        control_actions = []

        try:
            obs, _ = env.reset()

            n_iter = 0
            # `n_max_iter=None` means "no limit": run until the environment terminates.
            while n_max_iter is None or n_iter < n_max_iter:
                n_iter += 1

                action = policy(obs) if policy is not None else env.action_space.sample()
                control_actions.append(action)
                obs, _, terminated, _, info = env.step(action)
                # Truthiness instead of an identity check against `True`,
                # so that NumPy booleans are recognized as well.
                if terminated:
                    break

                current_scada_data = info["scada_data"]
                if scada_data is None:
                    scada_data = current_scada_data
                else:
                    scada_data.concatenate(current_scada_data)
        finally:
            # Release the environment no matter whether the roll-out succeeded.
            env.close()

        # Fit state transition surrogate model
        self.fit_to_scada(scada_data, np.array(control_actions))
160
161
[docs]
class HydraulicStateTransitionSurrogate(StateTransitionSurrogate):
    """
    Surrogate for the hydraulic state transition function.

    The state consists of the pressures and the flow rates. The next state is predicted from
    the current state, the next demands, and (if there are any actuators) the control signals.

    Parameters
    ----------
    wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
        Information about the topology of the WDN.
    n_actuators : `int`
        Dimensionality of the control signal.
    state_transition_model : :class:`StateTransitionModel`
        State transition model which is used as an approximation of the true state transition function.
        Usually, a neural network is used.
        Its `init` method is called by this constructor.
    """
    def __init__(self, wdn_topology: NetworkTopology, n_actuators: int,
                 state_transition_model: StateTransitionModel):
        super().__init__(wdn_topology, n_actuators)

        # State: one pressure per node + one flow rate per link.
        state_size = wdn_topology.get_number_of_nodes() + wdn_topology.get_number_of_links()
        # Input: state + one demand per node + control signal.
        input_size = state_size + wdn_topology.get_number_of_nodes() + n_actuators
        state_transition_model.init(self._wdn_topology,
                                    input_size=input_size,
                                    state_size=state_size)
        self._state_transition_model = state_transition_model

    def fit_to_scada(self, scada_data: ScadaData, control_actions: np.ndarray = None) -> None:
        """
        Fits the state transition model to given SCADA data.

        Consecutive time steps are turned into training samples: pressures and flows at time
        step t form the current state, demands at time step t+1 together with the
        control signal in row t form the time varying input, and pressures and flows
        at time step t+1 form the next state.

        Parameters
        ----------
        scada_data : `epyt_flow.simulation.scada_data.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            SCADA data.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signals at every time step (one row per time step).
            Can only be None if there are no actuators.

            The default is None.

        Raises
        ------
        TypeError
            If 'scada_data' is not a `ScadaData` instance.
        ValueError
            If 'control_actions' is None although 'n_actuators' > 0.
        """
        if not isinstance(scada_data, ScadaData):
            raise TypeError("'scada_data' must be an instance of " +
                            f"'epyt_flow.simulation.ScadaData' but not of '{type(scada_data)}'")
        if self._n_actuators > 0 and control_actions is None:
            raise ValueError("'control_actions' can not be None if 'n_actuators' > 0")

        X_pressure = scada_data.get_data_pressures()
        X_flows = scada_data.get_data_flows()
        X_demands = scada_data.get_data_demands()
        X_controls = control_actions
        n_time_steps = X_pressure.shape[0]

        cur_state = np.concatenate((X_pressure[:n_time_steps-1, :],
                                    X_flows[:n_time_steps-1, :]), axis=1)
        if X_controls is not None:
            # Surplus rows in 'control_actions' (beyond n_time_steps-1) are ignored.
            next_time_varying_quantity = np.concatenate((X_demands[1:, :],
                                                         X_controls[:n_time_steps-1, :]), axis=1)
        else:
            next_time_varying_quantity = X_demands[1:, :]
        next_state = np.concatenate((X_pressure[1:, :], X_flows[1:, :]), axis=1)

        self._state_transition_model.fit(cur_state, next_time_varying_quantity, next_state)

    def __call__(self, cur_pressure: np.ndarray, cur_flow: np.ndarray,
                 next_demand: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """
        Shortcut for :func:`predict`.
        """
        return self.predict(cur_pressure, cur_flow, next_demand, control_actions)

    def predict(self, cur_pressure: np.ndarray, cur_flow: np.ndarray,
                next_demand: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """
        Predicts how the current state evolves for the next time step.

        All arrays are concatenated along axis 1 -- i.e. they must be two-dimensional
        with one row per sample.

        Parameters
        ----------
        cur_pressure : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current pressure at every node.
        cur_flow : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current flow rate at every link.
        next_demand : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Demand at every node for the next time step.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signal at the current time step.
            If None, only the demand is used as time varying input -- this mirrors
            :func:`fit_to_scada` when it is called without control signals.

            The default is None.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state.
        """
        X_cur_state = np.concatenate((cur_pressure, cur_flow), axis=1)
        if control_actions is not None:
            X_control = np.concatenate((next_demand, control_actions), axis=1)
        else:
            # No actuators: the model was fitted on demands only.
            X_control = next_demand

        return self._state_transition_model.predict(X_cur_state, X_control)
238
239
[docs]
class WaterQualityStateTransitionSurrogate(StateTransitionSurrogate):
    """
    Surrogate for the quality (e.g. water age, chemical concentration) state transition function.

    The state consists of the node and link qualities. The next state is predicted from
    the current state, the next flow rates, and (if there are any actuators) the control signals.

    Parameters
    ----------
    wdn_topology : `epyt_flow.topology.NetworkTopology <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.html#epyt_flow.topology.NetworkTopology>`_
        Information about the topology of the WDN.
    n_actuators : `int`
        Dimensionality of the control signal.
    state_transition_model : :class:`StateTransitionModel`
        State transition model which is used as an approximation of the true state transition function.
        Usually, a neural network is used.
        Its `init` method is called by this constructor.
    """
    def __init__(self, wdn_topology: NetworkTopology, n_actuators: int,
                 state_transition_model: StateTransitionModel):
        super().__init__(wdn_topology, n_actuators)

        # State: one quality value per node + one quality value per link.
        state_size = wdn_topology.get_number_of_nodes() + wdn_topology.get_number_of_links()
        # Input: state + one flow rate per link + control signal.
        input_size = state_size + wdn_topology.get_number_of_links() + n_actuators
        state_transition_model.init(self._wdn_topology,
                                    input_size=input_size,
                                    state_size=state_size)
        self._state_transition_model = state_transition_model

    def fit_to_scada(self, scada_data: ScadaData, control_actions: np.ndarray = None) -> None:
        """
        Fits the state transition model to given SCADA data.

        Consecutive time steps are turned into training samples: node and link qualities at
        time step t form the current state, flow rates at time step t+1 together with the
        control signal in row t form the time varying input, and node and link qualities
        at time step t+1 form the next state.

        Parameters
        ----------
        scada_data : `epyt_flow.simulation.scada_data.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            SCADA data.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signals at every time step (one row per time step).
            Can only be None if there are no actuators.

            The default is None.

        Raises
        ------
        TypeError
            If 'scada_data' is not a `ScadaData` instance.
        ValueError
            If 'control_actions' is None although 'n_actuators' > 0.
        """
        if not isinstance(scada_data, ScadaData):
            raise TypeError("'scada_data' must be an instance of " +
                            f"'epyt_flow.simulation.ScadaData' but not of '{type(scada_data)}'")
        if self._n_actuators > 0 and control_actions is None:
            raise ValueError("'control_actions' can not be None if 'n_actuators' > 0")

        X_flows = scada_data.get_data_flows()
        X_nodes_quality = scada_data.get_data_nodes_quality()
        X_links_quality = scada_data.get_data_links_quality()
        X_controls = control_actions
        n_time_steps = X_flows.shape[0]

        cur_state = np.concatenate((X_nodes_quality[:n_time_steps-1, :],
                                    X_links_quality[:n_time_steps-1, :]), axis=1)
        if X_controls is not None:
            # Surplus rows in 'control_actions' (beyond n_time_steps-1) are ignored.
            next_time_varying_quantity = np.concatenate((X_flows[1:, :],
                                                         X_controls[:n_time_steps-1, :]), axis=1)
        else:
            next_time_varying_quantity = X_flows[1:, :]
        next_state = np.concatenate((X_nodes_quality[1:, :], X_links_quality[1:, :]), axis=1)

        self._state_transition_model.fit(cur_state, next_time_varying_quantity, next_state)

    def __call__(self, cur_node_quality: np.ndarray, cur_link_quality: np.ndarray,
                 next_flow: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """
        Shortcut for :func:`predict`.
        """
        return self.predict(cur_node_quality, cur_link_quality,
                            next_flow, control_actions)

    def predict(self, cur_node_quality: np.ndarray, cur_link_quality: np.ndarray,
                next_flow: np.ndarray, control_actions: np.ndarray = None) -> np.ndarray:
        """
        Predicts the next state (i.e. quality everywhere) based on the current state of the system,
        the next flow, and control signals.

        All arrays are concatenated along axis 1 -- i.e. they must be two-dimensional
        with one row per sample.

        Parameters
        ----------
        cur_node_quality : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current quality at every node.
        cur_link_quality : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Current quality at every link.
        next_flow : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Flow rate at every link for the next time step.
        control_actions : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
            Control signal at the current time step.
            If None, only the flow is used as time varying input -- this mirrors
            :func:`fit_to_scada` when it is called without control signals.

            The default is None.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Next state.
        """
        X_cur_state = np.concatenate((cur_node_quality, cur_link_quality), axis=1)
        if control_actions is not None:
            X_control = np.concatenate((next_flow, control_actions), axis=1)
        else:
            # No actuators: the model was fitted on flows only.
            X_control = next_flow

        return self._state_transition_model.predict(X_cur_state, X_control)