1"""
2This module contains several smoothing methods for improving filters.
3"""
4from typing import Optional, Callable
5import numpy as np
6
7from .kalman_filters import KalmanFilter, TimeVaryingKalmanFilter
8
9
class RauchTungStriebelSmoother(KalmanFilter):
    """
    Implementation of the Rauch-Tung-Striebel Kalman filter smoother.

    Every call of :meth:`step` runs the underlying Kalman filter over a whole
    window of observations (forward pass) and afterwards refines the filtered
    estimates with the Rauch-Tung-Striebel backward pass.

    Parameters
    ----------
    time_window_length : `int`
        Length of the time window which is considered for smoothing.
        Must be positive.
    state_dim : `int`
        Dimensionality of states.
    obs_dim : `int`
        Dimensionality of observations.
    init_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Initial state.
    measurement_func : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Measurement function -- i.e. matrix that is converting a state into an observation.
    state_transition_func : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        State transition function -- i.e. matrix moving from a given state to the next state.
    init_state_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the initial state uncertainty.
        If None, the identity matrix will be used.

        The default is None.
    measurement_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the measurement/observation uncertainty.
        If None, the identity matrix will be used.

        The default is None.
    system_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the system uncertainty.
        If None, the identity matrix will be used.

        The default is None.

    Raises
    ------
    TypeError
        If `time_window_length` is not an integer.
    ValueError
        If `time_window_length` is not positive.
    """
    def __init__(self, time_window_length: int, state_dim: int, obs_dim: int,
                 init_state: np.ndarray, measurement_func: np.ndarray,
                 state_transition_func: np.ndarray,
                 init_state_uncertainty_cov: Optional[np.ndarray] = None,
                 measurement_uncertainty_cov: Optional[np.ndarray] = None,
                 system_uncertainty_cov: Optional[np.ndarray] = None) -> None:
        if not isinstance(time_window_length, int):
            raise TypeError("'time_window_length' must be an instance of 'int' " +
                            f"but not of '{type(time_window_length)}'")
        if time_window_length <= 0:
            raise ValueError("'time_window_length' must be positive")

        self._time_window_length = time_window_length

        super().__init__(state_dim=state_dim, obs_dim=obs_dim, init_state=init_state,
                         measurement_func=measurement_func,
                         state_transition_func=state_transition_func,
                         init_state_uncertainty_cov=init_state_uncertainty_cov,
                         measurement_uncertainty_cov=measurement_uncertainty_cov,
                         system_uncertainty_cov=system_uncertainty_cov)

    @property
    def time_window_length(self) -> int:
        """
        Returns the length of the time window.

        Returns
        -------
        `int`
            Time window length.
        """
        return self._time_window_length

    def __eq__(self, other) -> bool:
        # Only another smoother of this type can be equal; this also prevents an
        # AttributeError on objects that lack `time_window_length`.
        if not isinstance(other, RauchTungStriebelSmoother):
            return False
        return super().__eq__(other) and self._time_window_length == other.time_window_length

    def __str__(self) -> str:
        return super().__str__() + f" time_window_length: {self._time_window_length}"

    def step(self, observation: np.ndarray) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """
        Smooths the system states over a given time window of observations.

        First runs the Kalman filter over every observation in the window
        (forward pass), then applies the Rauch-Tung-Striebel backward pass to
        the filtered estimates.
        Also, updates all other internal states of the Kalman filter.

        Parameters
        ----------
        observation : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time window of observations -- one row per time step, i.e. of
            shape (time_window_length, obs_dim).

        Returns
        -------
        tuple[list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_]]
            Lists (one entry per time step of the window) of smoothed system
            states and uncertainty covariance matrices.

        Raises
        ------
        TypeError
            If `observation` is not a numpy.ndarray.
        ValueError
            If `observation` is not of shape (time_window_length, obs_dim).
        """
        if not isinstance(observation, np.ndarray):
            raise TypeError("'observation' must be an instance of 'numpy.ndarray' " +
                            f"but not of '{type(observation)}'")
        if observation.shape != (self._time_window_length, self._obs_dim):
            raise ValueError("'observation' must be of shape (time_window_length, obs_dim) -- " +
                             f"i.e. {(self._time_window_length, self._obs_dim)}. " +
                             f"But found {observation.shape}")

        # Forward pass: filter every observation (row) of the window in turn.
        # NOTE(review): assumes the base class `step` returns the filtered
        # (posterior) state and covariance -- confirm in KalmanFilter.
        X, P = [], []
        for obs in observation:
            x, cov = super().step(obs.flatten())
            X.append(x)
            P.append(cov)

        # Backward pass (Rauch-Tung-Striebel): walk from the second-to-last
        # entry back to the first, using the already smoothed entry i + 1.
        F, Q = self._F, self._Q
        for i in range(self._time_window_length - 2, -1, -1):
            # Predicted covariance of entry i + 1 given the estimate at entry i.
            C = F @ P[i] @ F.T + Q
            # Smoother gain.
            K = P[i] @ F.T @ np.linalg.inv(C)
            # Out-of-place updates: do not mutate arrays that may still be
            # referenced by the filter, and avoid in-place dtype-casting errors.
            X[i] = X[i] + K @ (X[i + 1] - F @ X[i])
            P[i] = P[i] + K @ (P[i + 1] - C) @ K.T

        # The last entry is not altered by the backward pass; keep it as the
        # filter's current state.
        self._x = X[-1]
        self._P = P[-1]

        return X, P
125
126
class TimeVaryingRauchTungStriebelSmoother(TimeVaryingKalmanFilter):
    """
    Implementation of the time varying Rauch-Tung-Striebel Kalman filter smoother.

    Every call of :meth:`step` runs the underlying time varying Kalman filter
    over a whole window of observations (forward pass) and afterwards refines
    the filtered estimates with the Rauch-Tung-Striebel backward pass, using
    the time dependent state transition function and system uncertainty.

    Parameters
    ----------
    time_window_length : `int`
        Length of the time window which is considered for smoothing.
        Must be positive.
    state_dim : `int`
        Dimensionality of states.
    obs_dim : `int`
        Dimensionality of observations.
    init_state : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Initial state.
    measurement_func : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
        Measurement function -- i.e. matrix that is converting a state into an observation.
    state_transition_func : Callable[[int], `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_]
        Function mapping time (integer) to the time dependent state transition function --
        i.e. matrix moving from a given state to the next state.
    init_state_uncertainty_cov : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_, optional
        Covariance matrix of the initial state uncertainty.
        If None, the identity matrix will be used.

        The default is None.
    measurement_uncertainty_cov : Callable[[int], `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], optional
        Function mapping time (integer) to the time dependent covariance matrix of the
        measurement/observation uncertainty.
        If None, the identity matrix will be used in all time steps.

        The default is None.
    system_uncertainty_cov : Callable[[int], `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], optional
        Function mapping time (integer) to the time dependent covariance matrix of the
        system uncertainty.
        If None, the identity matrix will be used in all time steps.

        The default is None.

    Raises
    ------
    TypeError
        If `time_window_length` is not an integer.
    ValueError
        If `time_window_length` is not positive.
    """
    def __init__(self, time_window_length: int, state_dim: int,
                 obs_dim: int, init_state: np.ndarray,
                 measurement_func: np.ndarray,
                 state_transition_func: Callable[[int], np.ndarray],
                 init_state_uncertainty_cov: Optional[np.ndarray] = None,
                 measurement_uncertainty_cov: Optional[Callable[[int], np.ndarray]] = None,
                 system_uncertainty_cov: Optional[Callable[[int], np.ndarray]] = None) -> None:
        if not isinstance(time_window_length, int):
            raise TypeError("'time_window_length' must be an instance of 'int' " +
                            f"but not of '{type(time_window_length)}'")
        if time_window_length <= 0:
            raise ValueError("'time_window_length' must be positive")

        self._time_window_length = time_window_length

        super().__init__(state_dim=state_dim, obs_dim=obs_dim, init_state=init_state,
                         measurement_func=measurement_func,
                         state_transition_func=state_transition_func,
                         init_state_uncertainty_cov=init_state_uncertainty_cov,
                         measurement_uncertainty_cov=measurement_uncertainty_cov,
                         system_uncertainty_cov=system_uncertainty_cov)

    @property
    def time_window_length(self) -> int:
        """
        Returns the length of the time window.

        Returns
        -------
        `int`
            Time window length.
        """
        return self._time_window_length

    def __eq__(self, other) -> bool:
        # Only another smoother of this type can be equal; this also prevents an
        # AttributeError on objects that lack `time_window_length`.
        if not isinstance(other, TimeVaryingRauchTungStriebelSmoother):
            return False
        return super().__eq__(other) and self._time_window_length == other.time_window_length

    def __str__(self) -> str:
        return super().__str__() + f" time_window_length: {self._time_window_length}"

    def step(self, observation: np.ndarray) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """
        Smooths the system states over a given time window of observations.

        First runs the time varying Kalman filter over every observation in the
        window (forward pass), then applies the Rauch-Tung-Striebel backward
        pass to the filtered estimates.
        Also, updates all other internal states of the Kalman filter.

        Parameters
        ----------
        observation : `numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_
            Time window of observations -- one row per time step, i.e. of
            shape (time_window_length, obs_dim).

        Returns
        -------
        tuple[list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_], list[`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_]]
            Lists (one entry per time step of the window) of smoothed system
            states and uncertainty covariance matrices.

        Raises
        ------
        TypeError
            If `observation` is not a numpy.ndarray.
        ValueError
            If `observation` is not of shape (time_window_length, obs_dim).
        """
        if not isinstance(observation, np.ndarray):
            raise TypeError("'observation' must be an instance of 'numpy.ndarray' " +
                            f"but not of '{type(observation)}'")
        if observation.shape != (self._time_window_length, self._obs_dim):
            raise ValueError("'observation' must be of shape (time_window_length, obs_dim) -- " +
                             f"i.e. {(self._time_window_length, self._obs_dim)}. " +
                             f"But found {observation.shape}")

        # Forward pass: filter every observation (row) of the window in turn.
        # NOTE(review): assumes the base class `step` returns the filtered
        # (posterior) state and covariance -- confirm in TimeVaryingKalmanFilter.
        X, P = [], []
        for obs in observation:
            x, cov = super().step(obs.flatten())
            X.append(x)
            P.append(cov)

        # Backward pass (Rauch-Tung-Striebel): walk from the second-to-last
        # entry back to the first, using the already smoothed entry i + 1.
        n = self._time_window_length
        for i in range(n - 2, -1, -1):
            # Time index of the transition from entry i to entry i + 1; equals
            # self._t - 1 for i = n - 2 and decreases by one per step.
            # NOTE(review): assumes `self._t` was advanced once per forward step
            # and that the forward step used F/Q at that time -- confirm in
            # TimeVaryingKalmanFilter.
            t = self._t - (n - 1 - i)
            F = self._get_state_transition_func(t)
            Q = self._get_system_uncertainty_cov(t)

            # Predicted covariance of entry i + 1 given the estimate at entry i.
            C = F @ P[i] @ F.T + Q
            # Smoother gain.
            K = P[i] @ F.T @ np.linalg.inv(C)
            # Out-of-place updates: do not mutate arrays that may still be
            # referenced by the filter, and avoid in-place dtype-casting errors.
            X[i] = X[i] + K @ (X[i + 1] - F @ X[i])
            P[i] = P[i] + K @ (P[i + 1] - C) @ K.T

        # The last entry is not altered by the backward pass; keep it as the
        # filter's current state.
        self._x = X[-1]
        self._P = P[-1]

        return X, P