Source code for epyt_control.signal_processing.state_estimation.smoothers

  1"""
  2This module contains several smoothing methods for improving filters.
  3"""
  4from typing import Optional, Callable
  5import numpy as np
  6
  7from .kalman_filters import KalmanFilter, TimeVaryingKalmanFilter
  8
  9
class RauchTungStriebelSmoother(KalmanFilter):
    """
    Implementation of the Rauch-Tung-Striebel Kalman filter smoother.

    Parameters
    ----------
    time_window_length : `int`
        Length of the time window which is considered for smoothing.
    state_dim : `int`
        Dimensionality of states.
    obs_dim : `int`
        Dimensionality of observations.
    init_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Initial state.
    measurement_func : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Measurement function -- i.e. matrix that is converting a state into an observation.
    state_transition_func : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        State transition function -- i.e. matrix moving from a given state to the next state.
    init_state_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the initial state uncertainty.
        If None, the identity matrix will be used.

        The default is None.
    measurement_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the measurement/observation uncertainty.
        If None, the identity matrix will be used.

        The default is None.
    system_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the system uncertainty.
        If None, the identity matrix will be used.

        The default is None.

    Raises
    ------
    TypeError
        If `time_window_length` is not an `int`.
    ValueError
        If `time_window_length` is not positive.
    """
    def __init__(self, time_window_length: int, state_dim: int, obs_dim: int,
                 init_state: np.ndarray, measurement_func: np.ndarray,
                 state_transition_func: np.ndarray,
                 init_state_uncertainty_cov: Optional[np.ndarray] = None,
                 measurement_uncertainty_cov: Optional[np.ndarray] = None,
                 system_uncertainty_cov: Optional[np.ndarray] = None) -> None:
        if not isinstance(time_window_length, int):
            raise TypeError("'time_window_length' must be an instance of 'int' " +
                            f"but not of '{type(time_window_length)}'")
        if time_window_length <= 0:
            raise ValueError("'time_window_length' must be positive")

        self._time_window_length = time_window_length

        super().__init__(state_dim=state_dim, obs_dim=obs_dim, init_state=init_state,
                         measurement_func=measurement_func,
                         state_transition_func=state_transition_func,
                         init_state_uncertainty_cov=init_state_uncertainty_cov,
                         measurement_uncertainty_cov=measurement_uncertainty_cov,
                         system_uncertainty_cov=system_uncertainty_cov)

    @property
    def time_window_length(self) -> int:
        """
        Returns the length of the time window.

        Returns
        -------
        `int`
            Time window length.
        """
        return self._time_window_length

    def __eq__(self, other):
        # The isinstance check guards against comparing with a plain KalmanFilter,
        # which has no 'time_window_length' attribute.
        return super().__eq__(other) and \
            isinstance(other, RauchTungStriebelSmoother) and \
            self._time_window_length == other.time_window_length

    def __str__(self):
        return super().__str__() + f" time_window_length: {self._time_window_length}"

    def step(self, observation: np.ndarray) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """
        Smooths the state estimates (incl. their uncertainty) over a given
        time window of observations.

        First runs the Kalman filter forward over every observation in the window.
        Then applies the Rauch-Tung-Striebel backward recursion, so that each
        estimate also takes the later observations of the window into account.
        The internal filter state is left at the estimate of the last time step
        of the window. For the last step, the smoothed and filtered estimates
        coincide.

        Parameters
        ----------
        observation : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time window of observations -- shape (time_window_length, obs_dim).

        Returns
        -------
        tuple[list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_]]
            Lists of smoothed system states and uncertainty covariance matrices --
            one entry per time step of the window.

        Raises
        ------
        TypeError
            If `observation` is not a `numpy.ndarray`.
        ValueError
            If `observation` is not of shape (time_window_length, obs_dim).
        """
        if not isinstance(observation, np.ndarray):
            raise TypeError("'observation' must be an instance of 'numpy.ndarray' " +
                            f"but not of '{type(observation)}'")
        if observation.shape != (self._time_window_length, self._obs_dim):
            raise ValueError("'observation' must be of shape (time_window_length, obs_dim) -- " +
                             f"i.e. {(self._time_window_length, self._obs_dim)}. " +
                             f"But found {observation.shape}")

        # Forward pass: plain Kalman filtering over the window.
        # Copies make sure the smoothed estimates never alias arrays that the
        # parent filter may keep (and mutate) internally.
        X, P = [], []
        for i in range(self._time_window_length):
            x, cov = super().step(observation[i, :].flatten())
            X.append(np.array(x, copy=True))
            P.append(np.array(cov, copy=True))

        # Backward pass: Rauch-Tung-Striebel recursion, from the second-to-last
        # step back to the first. X[i + 1] and P[i + 1] are already smoothed here.
        F, Q = self._F, self._Q
        for i in range(self._time_window_length - 2, -1, -1):
            # Predicted covariance of step i+1 given the filtered estimate of step i
            P_pred = F @ P[i] @ F.T + Q
            # Smoother gain K = P[i] F^T P_pred^-1. It is computed through a linear
            # solve instead of an explicit inverse, for numerical stability.
            K = np.linalg.solve(P_pred.T, (P[i] @ F.T).T).T
            # Out-of-place updates avoid dtype-casting errors on integer arrays
            X[i] = X[i] + K @ (X[i + 1] - F @ X[i])
            P[i] = P[i] + K @ (P[i + 1] - P_pred) @ K.T

        # The backward pass never touches the last entry, so the filter continues
        # from its filtered estimate. Copies decouple it from the returned lists.
        self._x = X[-1].copy()
        self._P = P[-1].copy()

        return X, P
125 126
class TimeVaryingRauchTungStriebelSmoother(TimeVaryingKalmanFilter):
    """
    Implementation of the time varying Rauch-Tung-Striebel Kalman filter smoother.

    Parameters
    ----------
    time_window_length : `int`
        Length of the time window which is considered for smoothing.
    state_dim : `int`
        Dimensionality of states.
    obs_dim : `int`
        Dimensionality of observations.
    init_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Initial state.
    measurement_func : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Measurement function -- i.e. matrix that is converting a state into an observation.
    state_transition_func : Callable[[int], `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_]
        Function mapping time (integer) to the time dependent state transition function --
        i.e. matrix moving from a given state to the next state.
    init_state_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the initial state uncertainty.
        If None, the identity matrix will be used.

        The default is None.
    measurement_uncertainty_cov : Callable[[int], `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], optional
        Function mapping time (integer) to the time dependent covariance matrix of the
        measurement/observation uncertainty.
        If None, the identity matrix will be used in all time steps.

        The default is None.
    system_uncertainty_cov : Callable[[int], `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], optional
        Function mapping time (integer) to the time dependent covariance matrix of the
        system uncertainty.
        If None, the identity matrix will be used in all time steps.

        The default is None.

    Raises
    ------
    TypeError
        If `time_window_length` is not an `int`.
    ValueError
        If `time_window_length` is not positive.
    """
    def __init__(self, time_window_length: int, state_dim: int,
                 obs_dim: int, init_state: np.ndarray,
                 measurement_func: np.ndarray,
                 state_transition_func: Callable[[int], np.ndarray],
                 init_state_uncertainty_cov: Optional[np.ndarray] = None,
                 measurement_uncertainty_cov: Optional[Callable[[int], np.ndarray]] = None,
                 system_uncertainty_cov: Optional[Callable[[int], np.ndarray]] = None) -> None:
        if not isinstance(time_window_length, int):
            raise TypeError("'time_window_length' must be an instance of 'int' " +
                            f"but not of '{type(time_window_length)}'")
        if time_window_length <= 0:
            raise ValueError("'time_window_length' must be positive")

        self._time_window_length = time_window_length

        super().__init__(state_dim=state_dim, obs_dim=obs_dim, init_state=init_state,
                         measurement_func=measurement_func,
                         state_transition_func=state_transition_func,
                         init_state_uncertainty_cov=init_state_uncertainty_cov,
                         measurement_uncertainty_cov=measurement_uncertainty_cov,
                         system_uncertainty_cov=system_uncertainty_cov)

    @property
    def time_window_length(self) -> int:
        """
        Returns the length of the time window.

        Returns
        -------
        `int`
            Time window length.
        """
        return self._time_window_length

    def __eq__(self, other):
        # The isinstance check guards against comparing with a plain
        # TimeVaryingKalmanFilter, which has no 'time_window_length' attribute.
        return super().__eq__(other) and \
            isinstance(other, TimeVaryingRauchTungStriebelSmoother) and \
            self._time_window_length == other.time_window_length

    def __str__(self):
        return super().__str__() + f" time_window_length: {self._time_window_length}"

    def step(self, observation: np.ndarray) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """
        Smooths the state estimates (incl. their uncertainty) over a given
        time window of observations.

        First runs the time varying Kalman filter forward over every observation
        in the window. Then applies the Rauch-Tung-Striebel backward recursion
        with the time dependent state transition and system uncertainty matrices.
        The internal filter state is left at the estimate of the last time step
        of the window.

        Parameters
        ----------
        observation : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time window of observations -- shape (time_window_length, obs_dim).

        Returns
        -------
        tuple[list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_]]
            Lists of smoothed system states and uncertainty covariance matrices --
            one entry per time step of the window.

        Raises
        ------
        TypeError
            If `observation` is not a `numpy.ndarray`.
        ValueError
            If `observation` is not of shape (time_window_length, obs_dim).
        """
        if not isinstance(observation, np.ndarray):
            raise TypeError("'observation' must be an instance of 'numpy.ndarray' " +
                            f"but not of '{type(observation)}'")
        if observation.shape != (self._time_window_length, self._obs_dim):
            raise ValueError("'observation' must be of shape (time_window_length, obs_dim) -- " +
                             f"i.e. {(self._time_window_length, self._obs_dim)}. " +
                             f"But found {observation.shape}")

        # Forward pass: plain (time varying) Kalman filtering over the window.
        # Copies make sure the smoothed estimates never alias arrays that the
        # parent filter may keep (and mutate) internally.
        X, P = [], []
        for i in range(self._time_window_length):
            x, cov = super().step(observation[i, :].flatten())
            X.append(np.array(x, copy=True))
            P.append(np.array(cov, copy=True))

        # Backward pass: Rauch-Tung-Striebel recursion, from the second-to-last
        # step back to the first. X[i + 1] and P[i + 1] are already smoothed here.
        for i in range(self._time_window_length - 2, -1, -1):
            # Time index of the transition from window step i to step i+1:
            # self._t - 1 for the second-to-last step, decreasing by one per step.
            # NOTE(review): this assumes the parent filter predicts step t with
            # F(t)/Q(t) and increments self._t afterwards -- confirm against
            # TimeVaryingKalmanFilter.
            t = self._t - self._time_window_length + 1 + i
            F = self._get_state_transition_func(t)
            Q = self._get_system_uncertainty_cov(t)

            # Predicted covariance of step i+1 given the filtered estimate of step i
            P_pred = F @ P[i] @ F.T + Q
            # Smoother gain K = P[i] F^T P_pred^-1. It is computed through a linear
            # solve instead of an explicit inverse, for numerical stability.
            K = np.linalg.solve(P_pred.T, (P[i] @ F.T).T).T
            # Out-of-place updates avoid dtype-casting errors on integer arrays
            X[i] = X[i] + K @ (X[i + 1] - F @ X[i])
            P[i] = P[i] + K @ (P[i + 1] - P_pred) @ K.T

        # The backward pass never touches the last entry, so the filter continues
        # from its filtered estimate. Copies decouple it from the returned lists.
        self._x = X[-1].copy()
        self._P = P[-1].copy()

        return X, P