1"""
2This module provides functions for evaluating policies/agents/control strategies on environments.
3"""
4from typing import Callable, Union
5import numpy as np
6from epyt_flow.simulation import ScadaData
7from gymnasium import Wrapper
8
9from ..envs import RlEnv
10
11
[docs]
def evaluate_policy(env: Union[RlEnv, Wrapper], policy: Callable[[np.ndarray], np.ndarray],
                    n_max_iter: int = 10) -> tuple[list[float], ScadaData]:
    """
    Evaluates a given policy/agent/control strategy for a given environment --
    i.e. the policy/agent is applied to the environment and the rewards and
    `ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
    observations over time are recorded.

    The evaluation starts with a call to ``env.reset()`` and runs for at most `n_max_iter`
    steps. It stops earlier if the environment reports that the episode has terminated or
    has been truncated. The environment is always closed afterwards, also when an exception
    is raised while evaluating the policy.

    Parameters
    ----------
    env : :class:`~epyt_control.envs.rl_env.RlEnv` or `gymnasium.Wrapper <https://gymnasium.farama.org/api/wrappers/#gymnasium.Wrapper>`_
        The environment.

        Note that in the case of a (possibly nested)
        `gymnasium.Wrapper <https://gymnasium.farama.org/api/wrappers/#gymnasium.Wrapper>`_
        instance, the underlying (unwrapped) environment must be an instance of
        :class:`~epyt_control.envs.rl_env.RlEnv`.
    policy : `Callable[[numpy.ndarray], numpy.ndarray]`
        Policy/Agent/Control strategy to be evaluated -- maps an observation to an action.
    n_max_iter : `int`, optional
        Upper bound on the number of iterations that is used for evaluating the given policy/agent.

        The default is 10.

    Returns
    -------
    tuple[list[float], `epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_]
        Tuple of rewards over time and a
        `epyt_flow.simulation.ScadaData <https://epyt-flow.readthedocs.io/en/stable/epyt_flow.simulation.scada.html#epyt_flow.simulation.scada.scada_data.ScadaData>`_
        instance containing the WDN states over time.

        The step in which the environment reports termination is not recorded.
        If no step is recorded at all (e.g. termination in the very first step),
        the list of rewards is empty and the ScadaData entry is None.

    Raises
    ------
    TypeError
        If `env` is neither an RlEnv instance nor a gymnasium.Wrapper around one,
        or if `policy` is not callable.
    ValueError
        If `n_max_iter` is not an integer >= 1.
    """
    if not isinstance(env, (RlEnv, Wrapper)):
        raise TypeError("'env' must be an instance of 'epyt_control.envs.RlEnv' or " +
                        f"'gymnasium.Wrapper' but not of '{type(env)}'")
    if isinstance(env, Wrapper):
        # `unwrapped` follows the complete wrapper chain, so nested wrappers are accepted as well
        base_env = env.unwrapped
        if not isinstance(base_env, RlEnv):
            raise TypeError("The wrapped environment must be an instance of " +
                            f"'epyt_control.envs.RlEnv' but not of '{type(base_env)}'")
    if not callable(policy):
        raise TypeError("'policy' must be callable -- " +
                        "i.e. mapping observations (numpy.ndarray) to actions (numpy.ndarray)")
    if not isinstance(n_max_iter, int) or n_max_iter < 1:
        raise ValueError("'n_max_iter' must be an integer >= 1")

    rewards = []
    scada_data = None

    try:
        obs, _ = env.reset()
        for _ in range(n_max_iter):
            action = policy(obs)
            obs, reward, terminated, truncated, info = env.step(action)

            # Truthiness instead of `is True`, so that e.g. numpy.bool_ flags are honoured too.
            # The terminating step itself is deliberately not recorded (unchanged behaviour);
            # presumably its reward/info do not belong to a regular simulation step -- TODO confirm
            # against RlEnv.step.
            if terminated:
                break

            rewards.append(reward)
            current_scada_data = info["scada_data"]
            if scada_data is None:
                scada_data = current_scada_data
            else:
                scada_data.concatenate(current_scada_data)

            # Gymnasium API: after truncation the episode is over and the environment must be
            # reset before stepping again -- keep the (valid) truncated step, then stop.
            if truncated:
                break
    finally:
        # Release the environment's resources even if reset/policy/step raised.
        env.close()

    return rewards, scada_data
75 return rewards, scada_data