Source code for epyt_control.envs.rl_env

  1"""
  2This module contains a base class for reinforcement learning (RL) environments.
  3"""
  4from abc import abstractmethod
  5import os
  6import uuid
  7from copy import deepcopy
  8from typing import Optional, Any, Union
  9import numpy as np
 10from epyt_flow.simulation import ScadaData, ScenarioConfig, ScenarioSimulator
 11from epyt_flow.simulation.sensor_config import SENSOR_TYPE_NODE_PRESSURE, \
 12SENSOR_TYPE_NODE_QUALITY, SENSOR_TYPE_NODE_DEMAND, \
 13SENSOR_TYPE_LINK_FLOW, SENSOR_TYPE_LINK_QUALITY, \
 14SENSOR_TYPE_VALVE_STATE, SENSOR_TYPE_PUMP_STATE, \
 15SENSOR_TYPE_TANK_VOLUME, SENSOR_TYPE_NODE_BULK_SPECIES, \
 16SENSOR_TYPE_LINK_BULK_SPECIES, SENSOR_TYPE_SURFACE_SPECIES, \
 17SENSOR_TYPE_PUMP_EFFICIENCY, SENSOR_TYPE_PUMP_ENERGYCONSUMPTION, \
 18valid_sensor_types
 19
 20from epyt_flow.gym import ScenarioControlEnv
 21from epyt_flow.utils import get_temp_folder
 22from gymnasium import Env
 23from gymnasium.spaces import Space, Box, Discrete, Tuple
 24from gymnasium.spaces.utils import flatten_space
 25
 26from .actions import Action
 27
 28
class RlEnv(ScenarioControlEnv, Env):
    """
    Base class for reinforcement learning environments.

    Parameters
    ----------
    scenario_config : `epyt_flow.simulation.ScenarioConfig <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.html#epyt_flow.simulation.scenario_config.ScenarioConfig>`_
        Config of the scenario.
    gym_action_space : `gymnasium.spaces.Space <https://gymnasium.farama.org/api/spaces/#gymnasium.spaces.Space>`_
        Gymnasium action space.
    action_space : list[:class:`~epyt_control.actions.actions.Action`]
        List of all action spaces -- one space for each element that can be controlled by the agent.
    reload_scenario_when_reset : `bool`, optional
        If True, the scenario (incl. the .inp and .msx file) is reloaded from the hard disk.
        If False, only the simulation is reset.

        The default is True.
    hyd_file_in : `str`, optional
        Path to an EPANET .hyd file containing the simulated hydraulics.
        Can only be used in conjunction with 'hyd_scada_in' in the case of an EPANET-MSX scenario.
        If set, hydraulics will not be simulated but taken from the specified file.

        The default is None.
    hyd_scada_in : `epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_, optional
        ScadaData instance containing the simulated hydraulics -- must match the hydraulics
        from 'hyd_file_in'. Can only be used in conjunction with 'hyd_file_in'.

        The default is None.
    frozen_sensor_config : `bool`, optional
        If True, only the sensor readings from the observation space will be stored when running the
        simulation -- note that this implies that the reward function can only use the observations.
        This can lead to a significant speed-up of the simulation.

        The default is False.

    Raises
    ------
    TypeError
        If any argument is of the wrong type.
    ValueError
        If only one of 'hyd_file_in' and 'hyd_scada_in' is given.
    """
    def __init__(self, scenario_config: ScenarioConfig, gym_action_space: Space,
                 action_space: list[Action], reload_scenario_when_reset: bool = True,
                 hyd_file_in: Optional[str] = None, hyd_scada_in: Optional[ScadaData] = None,
                 frozen_sensor_config: bool = False, **kwds):
        if not isinstance(gym_action_space, Space):
            raise TypeError("'gym_action_space' must be an instance of 'gymnasium.spaces.Space' " +
                            f"but not of '{type(gym_action_space)}'")
        if not isinstance(action_space, list):
            raise TypeError("'action_space' must be an instance of " +
                            "'list[epyt_control.actions.Action]' " +
                            f"but not of '{type(action_space)}'")
        if any(not isinstance(a_s, Action) for a_s in action_space):
            raise TypeError("Every item in 'action_space' must be an instance of " +
                            "'epyt_control.actions.Action'")
        if not isinstance(reload_scenario_when_reset, bool):
            raise TypeError("'reload_scenario_when_reset' must be an instance of 'bool' " +
                            f"but not of '{type(reload_scenario_when_reset)}'")
        if not isinstance(frozen_sensor_config, bool):
            raise TypeError("'frozen_sensor_config' must be an instance of 'bool' " +
                            f"but not of '{type(frozen_sensor_config)}'")

        # The pre-computed hydraulics consist of two parts (file + SCADA data) that
        # are only usable together -- reject the case where exactly one is given.
        if (hyd_file_in is None) != (hyd_scada_in is None):
            raise ValueError("'hyd_file_in' and 'hyd_scada_in' can only be used in " +
                             "conjunction -- either both or none of them must be set")
        if hyd_file_in is not None:
            if not isinstance(hyd_file_in, str):
                raise TypeError("'hyd_file_in' must be an instance of 'str' " +
                                f"but not of '{type(hyd_file_in)}'")
        if hyd_scada_in is not None:
            if not isinstance(hyd_scada_in, ScadaData):
                raise TypeError("'hyd_scada_in' must be an instance of " +
                                "'epyt_flow.simulation.ScadaData' but not of " +
                                f"'{type(hyd_scada_in)}'")

        self._hyd_file_in = hyd_file_in
        self._hyd_scada_in = hyd_scada_in

        super().__init__(scenario_config=scenario_config, **kwds)

        self._observation_space = self._get_observation_space()
        self._action_space = action_space
        self._gym_action_space = gym_action_space
        self._reload_scenario_when_reset = reload_scenario_when_reset
        self._frozen_sensor_config = frozen_sensor_config

    def _get_observation_space(self) -> Space:
        """
        Builds the observation space from the sensor configuration of the scenario.

        One sub-space is created per sensor, following the order given by
        `sensor_ordering` of the sensor configuration. Valve and pump state
        sensors become `Discrete(2, start=2)` spaces, flow sensors become
        unbounded boxes, and all other sensors become non-negative boxes.

        Returns
        -------
        `gymnasium.spaces.Space <https://gymnasium.farama.org/api/spaces/#gymnasium.spaces.Space>`_
            Flattened tuple space of all per-sensor spaces.

        Raises
        ------
        ValueError
            If the sensor ordering contains an unknown sensor type.
        """
        obs_space = []
        sensor_config = self._scenario_config.sensor_config

        for sensor_type in sensor_config.sensor_ordering:
            if sensor_type == SENSOR_TYPE_NODE_PRESSURE:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.pressure_sensors
                )
            elif sensor_type == SENSOR_TYPE_LINK_FLOW:
                # Flows are the only readings with a lower bound of -inf
                obs_space += [Box(low=float("-inf"), high=float("inf"))] * len(
                    sensor_config.flow_sensors
                )
            elif sensor_type == SENSOR_TYPE_NODE_DEMAND:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.demand_sensors
                )
            elif sensor_type == SENSOR_TYPE_NODE_QUALITY:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.quality_node_sensors
                )
            elif sensor_type == SENSOR_TYPE_LINK_QUALITY:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.quality_link_sensors
                )
            elif sensor_type == SENSOR_TYPE_VALVE_STATE:
                # NOTE(review): states are modelled as the two values {2, 3} --
                # presumably EPANET's closed/open status codes; confirm.
                obs_space += [Discrete(2, start=2)] * len(
                    sensor_config.valve_state_sensors
                )
            elif sensor_type == SENSOR_TYPE_PUMP_STATE:
                obs_space += [Discrete(2, start=2)] * len(
                    sensor_config.pump_state_sensors
                )
            elif sensor_type == SENSOR_TYPE_PUMP_EFFICIENCY:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.pump_efficiency_sensors
                )
            elif sensor_type == SENSOR_TYPE_PUMP_ENERGYCONSUMPTION:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.pump_energyconsumption_sensors
                )
            elif sensor_type == SENSOR_TYPE_TANK_VOLUME:
                obs_space += [Box(low=0, high=float("inf"))] * len(
                    sensor_config.tank_volume_sensors
                )
            elif sensor_type == SENSOR_TYPE_SURFACE_SPECIES:
                # Species sensors are grouped by species ID -- one box per sensor location
                for species_id in sensor_config.surface_species_sensors:
                    obs_space += [Box(low=0, high=float("inf"))] * len(
                        sensor_config.surface_species_sensors[species_id]
                    )
            elif sensor_type == SENSOR_TYPE_NODE_BULK_SPECIES:
                for species_id in sensor_config.bulk_species_node_sensors:
                    obs_space += [Box(low=0, high=float("inf"))] * len(
                        sensor_config.bulk_species_node_sensors[species_id]
                    )
            elif sensor_type == SENSOR_TYPE_LINK_BULK_SPECIES:
                for species_id in sensor_config.bulk_species_link_sensors:
                    obs_space += [Box(low=0, high=float("inf"))] * len(
                        sensor_config.bulk_species_link_sensors[species_id]
                    )
            else:
                raise ValueError(
                    f"Invalid sensor type: {sensor_type} "
                    f"Valid sensor types are\n{valid_sensor_types()}"
                )

        return flatten_space(Tuple(obs_space))

    @property
    def observation_space(self) -> Space:
        """
        Returns the observation space of this environment.

        Returns
        -------
        `gymnasium.spaces.Space <https://gymnasium.farama.org/api/spaces/#gymnasium.spaces.Space>`_
            Gymnasium (observation) space instance.
        """
        return self._observation_space

    @property
    def action_space(self) -> Space:
        """
        Returns the action space of this environment.

        Returns
        -------
        `gymnasium.spaces.Space <https://gymnasium.farama.org/api/spaces/#gymnasium.spaces.Space>`_
            Gymnasium (action) space instance.
        """
        return self._gym_action_space

    def _next_sim_itr(self) -> Union[tuple[ScadaData, bool], ScadaData]:
        """
        Advances the running simulation by one step.

        In the case of an EPANET-MSX scenario (i.e. an .msx file is set), the
        hydraulic SCADA data of the current time step is joined into the result.

        Returns
        -------
        `ScadaData` or tuple[`ScadaData`, `bool`]
            If `autoreset` is True, only the SCADA data of the current step is
            returned -- when the simulation is exhausted, the environment is reset
            and the SCADA data of the reset is returned instead.
            Otherwise, a tuple of SCADA data and the terminated flag is returned --
            `(None, True)` when the simulation is exhausted.
        """
        try:
            next(self._sim_generator)
            # Sending False tells the generator NOT to abort the simulation
            scada_data, terminated = self._sim_generator.send(False)

            if self._scenario_sim.f_msx_in is not None:
                cur_time = int(scada_data.sensor_readings_time[0])
                cur_hyd_scada_data = self._hydraulic_scada_data.\
                    extract_time_window(cur_time, cur_time)
                scada_data.join(cur_hyd_scada_data)

            if self.autoreset is True:
                return scada_data
            else:
                return scada_data, terminated
        except StopIteration:
            if self.autoreset is True:
                _, info = self.reset()
                return info["scada_data"]
            else:
                return None, True

    def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
              ) -> tuple[np.ndarray, dict]:
        """
        Resets this environment to an initial internal state, returning an
        initial observation and info.

        Parameters
        ----------
        seed : `int`, optional
            The seed that is used to initialize the environment's PRNG.

            The default is None.
        options : `dict[str, Any]`, optional
            Additional information to specify how the environment is reset
            (optional, depending on the specific environment).

            The default is None.

        Returns
        -------
        tuple[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, dict]
            Observation (`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_),
            {"scada_data": `ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_}
            (`epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_ as additional info).
        """
        # Note that 'options' is not used by this implementation
        Env.reset(self, seed=seed)

        if self._reload_scenario_when_reset is True:
            scada_data = super().reset()
        else:
            if self._scenario_sim is None:
                self._scenario_sim = ScenarioSimulator(scenario_config=self._scenario_config)
            else:
                # Abort current simulation if any is running
                try:
                    next(self._sim_generator)
                    self._sim_generator.send(True)
                except StopIteration:
                    pass

            if self._scenario_sim.f_msx_in is not None:
                if self._hyd_file_in is not None:
                    # Re-use the pre-computed hydraulics passed to the constructor
                    hyd_export = self._hyd_file_in
                    self._hydraulic_scada_data = self._hyd_scada_in
                else:
                    # Simulate the hydraulics and export them to a uniquely named file.
                    # NOTE(review): this file is not removed anywhere in this class.
                    hyd_export = os.path.join(get_temp_folder(),
                                              f"epytflow_env_MSX_{uuid.uuid4()}.hyd")
                    sim = self._scenario_sim.run_hydraulic_simulation
                    self._hydraulic_scada_data = sim(
                        hyd_export=hyd_export,
                        frozen_sensor_config=self._frozen_sensor_config,
                        reapply_uncertainties=self.reapply_uncertainties_at_reset)

                gen = self._scenario_sim.run_advanced_quality_simulation_as_generator
                self._sim_generator = gen(
                    hyd_export, support_abort=True,
                    frozen_sensor_config=self._frozen_sensor_config,
                    reapply_uncertainties=self.reapply_uncertainties_at_reset)
            else:
                gen = self._scenario_sim.run_hydraulic_simulation_as_generator
                self._sim_generator = gen(
                    support_abort=True,
                    frozen_sensor_config=self._frozen_sensor_config,
                    reapply_uncertainties=self.reapply_uncertainties_at_reset)

            scada_data = self._next_sim_itr()

        # Depending on 'autoreset', a tuple (scada_data, terminated) might be returned
        if isinstance(scada_data, tuple):
            scada_data, _ = scada_data
        r = self._get_observation(scada_data)

        return r, {"scada_data": scada_data}

    def _get_observation(self, scada_data: ScadaData) -> Optional[np.ndarray]:
        """
        Converts SCADA data into an observation.

        Parameters
        ----------
        scada_data : `ScadaData`
            Sensor readings -- might be None.

        Returns
        -------
        `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Flattened sensor readings as 32-bit floats, or None if no SCADA data is given.
        """
        if scada_data is not None:
            return scada_data.get_data().flatten().astype(np.float32)
        else:
            return None

    @abstractmethod
    def _compute_reward_function(self, scada_data: ScadaData) -> float:
        """
        Computes the current reward based on the current sensors readings (i.e. SCADA data).

        Parameters
        ----------
        scada_data :`epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
            Current sensor readings.

        Returns
        -------
        `float`
            Current reward.
        """
        raise NotImplementedError()

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]:
        """
        Performs the next step by applying an action and observing the next
        state together with a reward.

        Note that the observation and the reward are both None if no SCADA data
        is available (i.e. the simulation is exhausted and `autoreset` is False).

        Parameters
        ----------
        action : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Actions to be executed.

        Returns
        -------
        tuple[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, float, bool, bool, dict]
            Observation (`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_), reward, terminated, False (truncated), {"scada_data": `ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_}
            (`epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_ as additional info).
        """
        # Apply actions -- the i-th entry of 'action' belongs to the i-th action space
        for action_value, action_def in zip(action, self._action_space):
            action_def.apply(self, action_value)

        # Run one simulation step and observe the sensor readings (SCADA data)
        if self.autoreset is False:
            current_scada_data, terminated = self._next_sim_itr()
        else:
            # With autoreset, the environment restarts by itself and never terminates
            terminated = False
            current_scada_data = self._next_sim_itr()

        if isinstance(current_scada_data, tuple):
            current_scada_data, _ = current_scada_data

        if current_scada_data is not None:
            obs = self._get_observation(current_scada_data)

            # Calculate reward -- pass a copy so that the reward function
            # can not alter the SCADA data returned to the caller
            current_reward = self._compute_reward_function(deepcopy(current_scada_data))
        else:
            obs = None
            current_reward = None

        # Return observation and reward
        return obs, current_reward, terminated, False, {"scada_data": current_scada_data}