Source code for epyt_control.evaluation.evaluation

 1"""
 2This module provides functions for evaluating policies/agents/control strategies on environments.
 3"""
 4from typing import Callable, Union
 5import numpy as np
 6from epyt_flow.simulation import ScadaData
 7from gymnasium import Wrapper
 8
 9from ..envs import RlEnv
10
11
def evaluate_policy(env: Union[RlEnv, Wrapper], policy: Callable[[np.ndarray], np.ndarray],
                    n_max_iter: int = 10) -> tuple[list[float], ScadaData]:
    """
    Evaluates a given policy/agent/control strategy for a given environment --
    i.e. the policy/agent is applied to the environment and the rewards and
    `ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
    observations over time are recorded.

    The environment is reset once, then stepped with the actions returned by `policy`
    until either `n_max_iter` steps have been performed, the environment reports
    termination, or the environment reports truncation. The environment is always
    closed afterwards, even if an exception is raised during the evaluation.

    Parameters
    ----------
    env : :class:`~epyt_control.envs.rl_env.RlEnv` or `gymnasium.Wrapper <https://gymnasium.farama.org/api/wrappers/#gymnasium.Wrapper>`_
        The environment.

        Note that in the case of a
        `gymnasium.Wrapper <https://gymnasium.farama.org/api/wrappers/#gymnasium.Wrapper>`_
        instance, the underlying environment must be an instance of
        :class:`~epyt_control.envs.rl_env.RlEnv`. Nested wrappers are allowed -- the
        chain of wrappers is followed down to the innermost environment.
    policy : `Callable[[numpy.ndarray], numpy.ndarray]`
        Policy/Agent/Control strategy to be evaluated.
    n_max_iter : `int`, optional
        Upper bound on the number of iterations that is used for evaluating the given policy/agent.

        The default is 10.

    Returns
    -------
    tuple[list[float], `epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_]
        Tuple of rewards over time and a
        `epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
        instance containing the WDN states over time.

        The outputs (reward and SCADA data) of a step that reports termination are
        not recorded. Consequently, the returned SCADA data is `None` if the very
        first step already reports termination.

    Raises
    ------
    TypeError
        If `env` is neither an :class:`~epyt_control.envs.rl_env.RlEnv` nor a
        `gymnasium.Wrapper`, if a wrapper does not (eventually) wrap an
        :class:`~epyt_control.envs.rl_env.RlEnv`, or if `policy` is not callable.
    ValueError
        If `n_max_iter` is not an integer >= 1.
    """
    if not isinstance(env, (RlEnv, Wrapper)):
        raise TypeError("'env' must be an instance of 'epyt_control.envs.RlEnv' or " +
                        f"'gymnasium.Wrapper' but not of '{type(env)}'")
    if isinstance(env, Wrapper):
        # Follow the (possibly nested) chain of wrappers down to the base environment.
        base_env = env
        while isinstance(base_env, Wrapper):
            base_env = base_env.env
        if not isinstance(base_env, RlEnv):
            raise TypeError("The wrapped environment must be an instance of " +
                            f"'epyt_control.envs.RlEnv' but not of '{type(base_env)}'")
    if not callable(policy):
        raise TypeError("'policy' must be callable -- " +
                        "i.e. mapping observations (numpy.ndarray) to actions (numpy.ndarray)")
    if not isinstance(n_max_iter, int) or n_max_iter < 1:
        raise ValueError("'n_max_iter' must be an integer >= 1")

    rewards = []
    scada_data = None

    try:
        obs, _ = env.reset()
        for _ in range(n_max_iter):
            action = policy(obs)
            obs, reward, terminated, truncated, info = env.step(action)

            # The outputs of the terminating step are deliberately not recorded.
            # NOTE(review): presumably the final step does not carry a regular
            # simulation step's SCADA data -- confirm against RlEnv.step.
            # Truthiness (not `is True`) so that NumPy booleans also end the episode.
            if terminated:
                break

            rewards.append(reward)
            current_scada_data = info["scada_data"]
            if scada_data is None:
                scada_data = current_scada_data
            else:
                # Appends in place to the ScadaData object of the first recorded step.
                scada_data.concatenate(current_scada_data)

            # A truncated step is still valid and was recorded above. Per the
            # gymnasium API the episode must not be stepped any further, though.
            if truncated:
                break
    finally:
        # Release the environment's resources even if the policy or the environment raises.
        env.close()

    return rewards, scada_data