Skip to content

Commit

Permalink
Merge pull request #60 from jeiloh/master
Browse files Browse the repository at this point in the history
Kalman filter upgrade
  • Loading branch information
mdbartos authored Nov 13, 2023
2 parents 0328cc5 + 0860e0b commit ee07af3
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 23 deletions.
19 changes: 10 additions & 9 deletions pipedream_solver/nutils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from pipedream_solver.utils import _square_root_kalman_semi_implicit
from numba import njit

@njit
Expand Down Expand Up @@ -184,13 +185,13 @@ def _kalman_semi_implicit(Z_next, P_x_k_k, A_1, A_2, b, H, C,
Measurement noise covariance
"""
I = np.eye(A_1.shape[0])
y_k1_k = b
A_1_inv = np.linalg.inv(A_1)
H_1 = H @ A_1_inv
P_y_k1_k = A_2 @ P_x_k_k @ A_2.T + C @ Qcov @ C.T
L_y_k1 = P_y_k1_k @ H_1.T @ np.linalg.inv((H_1 @ P_y_k1_k @ H_1.T) + Rcov)
P_y_k1_k1 = (I - L_y_k1 @ H_1) @ P_y_k1_k
b_hat = y_k1_k + L_y_k1 @ (Z_next - H_1 @ y_k1_k)
P_x_k1_k1 = A_1_inv @ P_y_k1_k1 @ A_1_inv.T
return b_hat, P_x_k1_k1


x_k1_k = A_1_inv @ b
P_x_k1_k = A_1_inv @ A_2 @ P_x_k_k @ A_2.T @ A_1_inv.T + C @ Qcov @ C.T
L_x_k1 = P_x_k1_k @ H.T @ np.linalg.inv((H @ P_x_k1_k @ H.T) + Rcov)
P_zz = (H @ P_x_k1_k @ H.T) + Rcov
P_x_k1_k1 = (I - L_x_k1 @ H) @ P_x_k1_k
x_hat = x_k1_k + L_x_k1 @ (Z_next - H @ x_k1_k)
b_hat = A_1 @ x_hat
return b_hat, P_x_k1_k1, P_zz
23 changes: 17 additions & 6 deletions pipedream_solver/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
except:
_HAS_NUMBA = False
if _HAS_NUMBA:
from pipedream_solver.nutils import interpolate_sample, _kalman_semi_implicit
from pipedream_solver.nutils import interpolate_sample, _kalman_semi_implicit, _square_root_kalman_semi_implicit
else:
from pipedream_solver.utils import interpolate_sample, _kalman_semi_implicit
from pipedream_solver.utils import interpolate_sample, _kalman_semi_implicit, _square_root_kalman_semi_implicit

eps = np.finfo(float).eps

Expand Down Expand Up @@ -86,7 +86,7 @@ class Simulation():
def __init__(self, model, Q_in=None, H_bc=None, Q_Ik=None, t_start=None,
t_end=None, dt=None, max_iter=None, min_dt=1, max_dt=200,
tol=0.01, min_rel_change=1e-10, max_rel_change=1e10, safety_factor=0.9,
Qcov=None, Rcov=None, C=None, H=None, interpolation_method='linear'):
Pxx = None, Qcov=None, Rcov=None, C=None, H=None, interpolation_method='linear'):
self.model = model
if Q_in is not None:
self.Q_in = Q_in.copy(deep=True)
Expand Down Expand Up @@ -204,7 +204,12 @@ def __init__(self, model, Q_in=None, H_bc=None, Q_Ik=None, t_start=None,
else:
assert isinstance(H, np.ndarray)
self.H = H
self.P_x_k_k = self.C @ self.Qcov @ self.C.T
if Pxx is None:
self.P_x_k_k = self.C @ self.Qcov @ self.C.T
else:
self.P_x_k_k = Pxx.copy()
self.A_1 = None
self.P_zz = None
# Progress bar checkpoints
if np.isfinite(self.t_end):
self._checkpoints = np.linspace(self.t_start, self.t_end)
Expand Down Expand Up @@ -447,7 +452,7 @@ def filter_step_size(self, tol=0.5, dts=None, errs=None, coeffs=[0.5, 0.5, 0, 0.
return dt_np1

def kalman_filter(self, Z, H=None, C=None, Qcov=None, Rcov=None, P_x_k_k=None,
dt=None, **kwargs):
dt=None, SR=False, **kwargs):
"""
Apply Kalman Filter to fuse observed data into model.
Expand Down Expand Up @@ -481,9 +486,15 @@ def kalman_filter(self, Z, H=None, C=None, Qcov=None, Rcov=None, P_x_k_k=None,
if Rcov is None:
Rcov = self.Rcov
A_1, A_2, b = self.model._semi_implicit_system(_dt=dt)
b_hat, P_x_k_k = _kalman_semi_implicit(Z, P_x_k_k, A_1, A_2, b, H, C,
if SR == False:
b_hat, P_x_k_k, P_zz = _kalman_semi_implicit(Z, P_x_k_k, A_1, A_2, b, H, C,
Qcov, Rcov)
else:
b_hat, P_x_k_k, P_zz = _square_root_kalman_semi_implicit(Z, P_x_k_k, A_1, A_2, b, H, C,
Qcov, Rcov)
self.P_x_k_k = P_x_k_k
self.P_zz = P_zz
self.A_1 = A_1
self.model.b = b_hat
self.model.iter_count -= 1
self.model.t -= dt
Expand Down
61 changes: 53 additions & 8 deletions pipedream_solver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,58 @@ def _kalman_semi_implicit(Z_next, P_x_k_k, A_1, A_2, b, H, C,
Measurement noise covariance
"""
I = np.eye(A_1.shape[0])
y_k1_k = b
A_1_inv = np.linalg.inv(A_1)
H_1 = H @ A_1_inv
P_y_k1_k = A_2 @ P_x_k_k @ A_2.T + C @ Qcov @ C.T
L_y_k1 = P_y_k1_k @ H_1.T @ np.linalg.inv((H_1 @ P_y_k1_k @ H_1.T) + Rcov)
P_y_k1_k1 = (I - L_y_k1 @ H_1) @ P_y_k1_k
b_hat = y_k1_k + L_y_k1 @ (Z_next - H_1 @ y_k1_k)
P_x_k1_k1 = A_1_inv @ P_y_k1_k1 @ A_1_inv.T
return b_hat, P_x_k1_k1

x_k1_k = A_1_inv @ b
P_x_k1_k = A_1_inv @ A_2 @ P_x_k_k @ A_2.T @ A_1_inv.T + C @ Qcov @ C.T
L_x_k1 = P_x_k1_k @ H.T @ np.linalg.inv((H @ P_x_k1_k @ H.T) + Rcov)
P_zz = (H @ P_x_k1_k @ H.T) + Rcov
P_x_k1_k1 = (I - L_x_k1 @ H) @ P_x_k1_k
x_hat = x_k1_k + L_x_k1 @ (Z_next - H @ x_k1_k)
b_hat = A_1 @ x_hat
return b_hat, P_x_k1_k1, P_zz

def _square_root_kalman_semi_implicit(Z_next, P_x_k_k, A_1, A_2, b, H, C,
Qcov, Rcov):
"""
Perform Kalman filtering to estimate state and error covariance.
Inputs:
-------
Z_next : np.ndarray (b x 1)
Observed data
P_x_k_k : np.ndarray (M x M)
Posterior error covariance estimate at previous timestep
A_1 : np.ndarray (M x M)
Left state transition matrix
A_2 : np.ndarray (M x M)
Right state transition matrix
b : np.ndarray (M x 1)
Right-hand side solution vector
H : np.ndarray (M x b)
Observation matrix
C : np.ndarray (a x M)
Signal-input matrix
Qcov : np.ndarray (M x M)
Process noise covariance
Rcov : np.ndarray (M x M)
Measurement noise covariance
"""
I = np.eye(A_1.shape[0])
A_1_inv = np.linalg.inv(A_1)
Rq = np.linalg.cholesky(Qcov)
Rr = np.linalg.cholesky(Rcov)
F = np.linalg.cholesky(P_x_k_k)

x_k1_k = A_1_inv @ b
Fbar = np.linalg.qr(np.vstack((F@A_2.T@A_1_inv.T, Rq)), mode='r')

G = np.linalg.qr(np.block([[Fbar@H.T], [Rr]]), mode='r')

L_x_k1 = (np.linalg.inv(G)@(np.linalg.inv(G).T@H)@Fbar.T@Fbar).T

Fhat = np.linalg.qr(np.block([[Fbar@(I - L_x_k1@H).T], [Rr@L_x_k1.T]]), mode='r')
x_hat = x_k1_k + L_x_k1 @ (Z_next - H @ x_k1_k)

P_x_k1_k1 = Fhat.T@Fhat
b_hat = A_1 @ x_hat
return b_hat, P_x_k1_k1

0 comments on commit ee07af3

Please sign in to comment.