Implementation of the Kalman Filter
Methods
Filter
This is a two-step process:
- Predict step, using the transition model
- Update (measurement) step, using the emission model
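Both steps operate on a linear-Gaussian state-space model. Written with the symbols the code uses ($F$ and $Q$ for the transition, $H$ and $R$ for the emission) plus a Gaussian prior on the initial state (the names $\mu_0$, $\Sigma_0$ are assumed here), a sketch of that model is:

$$
\begin{aligned}
z_0 &\sim \mathcal{N}(\mu_0, \Sigma_0) \\
z_t &= F z_{t-1} + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, Q) \quad \text{(transition model)} \\
x_t &= H z_t + \delta_t, \qquad \delta_t \sim \mathcal{N}(0, R) \quad \text{(emission model)}
\end{aligned}
$$

The predict step pushes the current belief about $z$ through the first equation; the update step conditions it on the new observation $x_t$ through the second.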
Predict Step
```python
# mu_z_prev, Sigma_z_prev: filtered state mean and covariance from step t-1
# predictive mean (state), t|t-1
mu_z_t_cond = F @ mu_z_prev
# predictive covariance (state), t|t-1
Sigma_z_t_cond = F @ Sigma_z_prev @ F.T + Q
```
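As a minimal, self-contained sketch (plain NumPy rather than JAX; the helper name `kf_predict` and the constant-velocity example are illustrative assumptions, not part of the original), the predict step can be wrapped as a function and checked on a two-dimensional [position, velocity] state:

```python
import numpy as np


def kf_predict(mu_z_prev, Sigma_z_prev, F, Q):
    """Hypothetical helper: propagate the filtered state moments through the transition model."""
    # predictive mean (state), t|t-1
    mu_z_t_cond = F @ mu_z_prev
    # predictive covariance (state), t|t-1
    Sigma_z_t_cond = F @ Sigma_z_prev @ F.T + Q
    return mu_z_t_cond, Sigma_z_t_cond


# Illustrative constant-velocity model: state = [position, velocity], unit time step
F = np.array([[1.0, 1.0],
              [0.0, 1.0]])
Q = 0.1 * np.eye(2)
mu_prev = np.array([0.0, 1.0])
Sigma_prev = np.eye(2)

mu_pred, Sigma_pred = kf_predict(mu_prev, Sigma_prev, F, Q)
# mu_pred    -> [1.0, 1.0]
# Sigma_pred -> [[2.1, 1.0], [1.0, 1.1]]
```

The uncertainty grows in the predict step: the position variance rises from 1.0 to 2.1 because the uncertain velocity is added to the position, plus the transition noise $Q$.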
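Update Step
First, we map the predicted state into observation space, which gives the predictive distribution of the new observation:
$$
p(x_t \mid x_{1:t-1}) = \mathcal{N}\left(x_t \mid H\mu_{z,t|t-1},\; H\Sigma_{z,t|t-1}H^\top + R\right)
$$
where:
- $\mu_{z,t|t-1}$ (`mu_z_t_cond`) is the estimate of the state mean given the observations $x_{1:t-1}$
- $\Sigma_{z,t|t-1}$ (`Sigma_z_t_cond`) is the estimate of the state covariance given the observations $x_{1:t-1}$.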
```python
# predictive mean (obs), t|t-1
mu_x_t_cond = H @ mu_z_t_cond
# predictive covariance (obs), t|t-1
Sigma_x_t_cond = H @ Sigma_z_t_cond @ H.T + R
```
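A small NumPy sketch of this step (the helper name `kf_predict_obs` and the numbers are illustrative assumptions): a [position, velocity] state with predicted covariance [[2.1, 1.0], [1.0, 1.1]], where only the position is observed, with noise variance $R = 0.5$:

```python
import numpy as np


def kf_predict_obs(mu_z_t_cond, Sigma_z_t_cond, H, R):
    """Hypothetical helper: map the predicted state moments into observation space."""
    # predictive mean (obs), t|t-1
    mu_x_t_cond = H @ mu_z_t_cond
    # predictive covariance (obs), t|t-1
    Sigma_x_t_cond = H @ Sigma_z_t_cond @ H.T + R
    return mu_x_t_cond, Sigma_x_t_cond


# observe only the position component of a [position, velocity] state
H = np.array([[1.0, 0.0]])
R = np.array([[0.5]])
mu_pred = np.array([1.0, 1.0])
Sigma_pred = np.array([[2.1, 1.0],
                       [1.0, 1.1]])

mu_x, Sigma_x = kf_predict_obs(mu_pred, Sigma_pred, H, R)
# mu_x -> [1.0], Sigma_x -> [[2.6]]  (2.1 from the state plus 0.5 of observation noise)
```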
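We then correct the predicted state using the actual observation $x_t$. This is done in two steps: computing the innovation and the Kalman gain, then applying the correction.
Innovation & Kalman Gain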
```python
from numpy.linalg import inv

# innovation: the observation minus its predicted mean
r_t = x_t - mu_x_t_cond
# kalman gain
K_t = Sigma_z_t_cond @ H.T @ inv(Sigma_x_t_cond)
```
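Forming the inverse explicitly is usually avoided in practice. A hedged NumPy sketch (hypothetical helper `kf_gain`; illustrative numbers) computes the same gain with a linear solve, which is also what the `solve(S, HP).T` line in the filter function does: because $\Sigma_{x,t|t-1}$ is symmetric, $K_t^\top = \Sigma_{x,t|t-1}^{-1} H \Sigma_{z,t|t-1}$.

```python
import numpy as np


def kf_gain(x_t, mu_x_t_cond, Sigma_x_t_cond, Sigma_z_t_cond, H):
    """Hypothetical helper: innovation and Kalman gain without an explicit inverse."""
    # innovation: how far the observation is from its prediction
    r_t = x_t - mu_x_t_cond
    # K = Sigma_z H^T Sigma_x^{-1}; Sigma_x is symmetric, so K^T = Sigma_x^{-1} (H Sigma_z)
    K_t = np.linalg.solve(Sigma_x_t_cond, H @ Sigma_z_t_cond).T
    return r_t, K_t


H = np.array([[1.0, 0.0]])
Sigma_z = np.array([[2.1, 1.0],
                    [1.0, 1.1]])
Sigma_x = H @ Sigma_z @ H.T + np.array([[0.5]])  # [[2.6]]
x_t = np.array([1.5])
mu_x = np.array([1.0])

r_t, K_t = kf_gain(x_t, mu_x, Sigma_x, Sigma_z, H)
# the solve-based gain agrees with the explicit-inverse formula
assert np.allclose(K_t, Sigma_z @ H.T @ np.linalg.inv(Sigma_x))
```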
Correction
```python
# estimated state mean, t|t
mu_z_t = mu_z_t_cond + K_t @ r_t
# estimated state covariance, t|t (I is the state_dim x state_dim identity)
Sigma_z_t = (I - K_t @ H) @ Sigma_z_t_cond
```
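A NumPy sketch of the correction (hypothetical helper `kf_correct`). The gain and innovation below are illustrative values for a [position, velocity] state with predicted covariance [[2.1, 1.0], [1.0, 1.1]], $H = [1, 0]$, $R = 0.5$ and an observation residual of 0.5:

```python
import numpy as np


def kf_correct(mu_z_t_cond, Sigma_z_t_cond, K_t, r_t, H):
    """Hypothetical helper: fold the innovation back into the predicted state moments."""
    # estimated state mean, t|t
    mu_z_t = mu_z_t_cond + K_t @ r_t
    # estimated state covariance, t|t, using a state-sized identity
    I = np.eye(Sigma_z_t_cond.shape[0])
    Sigma_z_t = (I - K_t @ H) @ Sigma_z_t_cond
    return mu_z_t, Sigma_z_t


H = np.array([[1.0, 0.0]])
mu_pred = np.array([1.0, 1.0])
Sigma_pred = np.array([[2.1, 1.0],
                       [1.0, 1.1]])
K_t = np.array([[2.1], [1.0]]) / 2.6   # gain for R = 0.5
r_t = np.array([0.5])                  # innovation

mu_z, Sigma_z = kf_correct(mu_pred, Sigma_pred, K_t, r_t, H)
```

Here the position variance shrinks from 2.1 to 2.1 × 0.5 / 2.6 ≈ 0.40, and the unobserved velocity is also nudged through its covariance with the position. In floating point, `(I - K_t @ H) @ Sigma` can drift from exact symmetry; the Joseph form $(I - KH)\Sigma(I - KH)^\top + KRK^\top$ is a common, more robust alternative.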
Filtering
Pseudocode
Inputs
- $F$ - transition matrix
- $Q$ - transition noise
- $H$ - emission matrix
- $\mu_0$ - prior for the state, $z_0$, mean
- $\Sigma_0$ - prior for the state, $z_0$, covariance
Parameters
- $F$ - transition matrix
- $Q$ - transition noise
- $H$ - emission matrix
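Putting the steps together, a minimal plain-Python sketch of the filtering loop might look like the following. It assumes constant $Q$ and $R$ (the JAX function allows per-step noise) and omits the log-likelihood and missing-data mask; `kalman_filter_loop` and the tracking example are hypothetical.

```python
import numpy as np


def kalman_filter_loop(xs, F, Q, H, R, mu_0, Sigma_0):
    """Hypothetical sketch: filter observations xs of shape (T, obs_dim).

    Returns filtered means (T, state_dim) and covariances (T, state_dim, state_dim).
    """
    mu, Sigma = mu_0, Sigma_0
    means, covs = [], []
    for x_t in xs:
        # predict step (transition model)
        mu_pred = F @ mu
        Sigma_pred = F @ Sigma @ F.T + Q
        # predictive moments of the observation (emission model)
        mu_x = H @ mu_pred
        Sigma_x = H @ Sigma_pred @ H.T + R
        # innovation and Kalman gain (linear solve instead of an inverse)
        r_t = x_t - mu_x
        K_t = np.linalg.solve(Sigma_x, H @ Sigma_pred).T
        # correction
        mu = mu_pred + K_t @ r_t
        Sigma = Sigma_pred - K_t @ H @ Sigma_pred
        means.append(mu)
        covs.append(Sigma)
    return np.stack(means), np.stack(covs)


# Illustrative data: noisy position readings of an object moving at unit speed
rng = np.random.default_rng(0)
T = 50
F = np.array([[1.0, 1.0], [0.0, 1.0]])
Q = 1e-3 * np.eye(2)
H = np.array([[1.0, 0.0]])
R = np.array([[0.25]])
true_pos = np.arange(1, T + 1, dtype=float)
xs = (true_pos + 0.5 * rng.standard_normal(T))[:, None]

means, covs = kalman_filter_loop(xs, F, Q, H, R, mu_0=np.zeros(2), Sigma_0=10.0 * np.eye(2))
# the last velocity estimate, means[-1, 1], should be close to the true speed of 1.0
```

The `_sequential_kf` function replaces this Python loop with `scan`, carrying the mean, covariance and running log-likelihood from one step to the next, which lets JAX compile the whole recursion.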
Function
```python
# Assumed imports: the original uses `np`, `scan` and `solve` without showing their source.
import jax.numpy as np
from jax.lax import scan
from jax.scipy.linalg import solve

# NOTE: `mvn_logpdf(y, mean, cov, mask)` is assumed to be a helper defined elsewhere
# that returns the (masked) multivariate Gaussian log-density.


def _sequential_kf(F, Qs, H, ys, Rs, m0, P0, masks, return_predict=False):
    """
    Parameters
    ----------
    F : np.ndarray, shape=(state_dim, state_dim)
        transition matrix for the transition function
    Qs : np.ndarray, shape=(n_time, state_dim, state_dim)
        noise covariances for the transition function
    H : np.ndarray, shape=(obs_dim, state_dim)
        the emission matrix for the emission function
    ys : np.ndarray, shape=(n_time, obs_dim, 1)
        the observations
    Rs : np.ndarray, shape=(n_time, obs_dim, obs_dim)
        noise covariances for the emission function
    m0 : np.ndarray, shape=(state_dim, 1)
        prior mean of the state
    P0 : np.ndarray, shape=(state_dim, state_dim)
        prior covariance of the state
    masks : np.ndarray, shape=(n_time, obs_dim, 1)
        boolean mask, True where an observation is missing
    return_predict : bool
        if True, return the predicted (t|t-1) moments instead of the filtered (t|t) ones
    """
    def body(carry, inputs):
        # ==================
        # Unroll Inputs
        # ==================
        # per-step inputs: observation, emission noise, transition noise, mask
        y, R, Q, mask = inputs
        # carried state: filtered mean, filtered covariance, running log-likelihood
        m, P, ell = carry
        # ==================
        # Predict Step
        # ==================
        m_ = F @ m
        P_ = F @ P @ F.T + Q
        # ==================
        # Update Step
        # ==================
        # predictive moments of the observation
        obs_mean = H @ m_
        HP = H @ P_
        S = HP @ H.T + R
        # log likelihood
        ell_n = mvn_logpdf(y, obs_mean, S, mask)
        ell = ell + ell_n
        # Kalman gain K = P_ H^T S^{-1}, via a linear solve (S is symmetric)
        K = solve(S, HP).T
        # correction step
        m = m_ + K @ (y - obs_mean)
        P = P_ - K @ HP
        if return_predict:
            return (m, P, ell), (m_, P_)
        else:
            return (m, P, ell), (m, P)

    (_, _, loglik), (fms, fPs) = scan(
        f=body,
        init=(m0, P0, 0.),
        xs=(ys, Rs, Qs, masks)
    )
    return loglik, fms, fPs


def kalman_filter(dt, kernel, y, noise_cov, mask=None, return_predict=False):
    """
    Run the Kalman filter to get p(fₙ|y₁,...,yₙ).
    Assumes a heteroscedastic Gaussian observation model, i.e. var is vector valued
    :param dt: step sizes [N, 1]
    :param kernel: an instantiation of the kernel class, used to determine the state space model
    :param y: observations [N, D, 1]
    :param noise_cov: observation noise covariances [N, D, D]
    :param mask: boolean mask for the observations (to indicate missing data locations) [N, D, 1]
    :param return_predict: flag whether to return predicted state, rather than updated state
    :return:
        ell: the log-marginal likelihood log p(y), for hyperparameter optimisation (learning) [scalar]
        means: intermediate filtering means [N, state_dim, 1]
        covs: intermediate filtering covariances [N, state_dim, state_dim]
    """
    if mask is None:
        mask = np.zeros_like(y, dtype=bool)
    # initialise the state at the kernel's stationary distribution
    Pinf = kernel.stationary_covariance()
    minf = np.zeros([Pinf.shape[0], 1])
    # get constant params
    F = ...  # transition matrix
    H = ...  # emission matrix
    # generate noise matrices
    Qs = ...  # noise for the transitions [N, state_dim, state_dim]
    Rs = noise_cov  # noise for the emissions: the observation noise covariances [N, D, D]
    ell, means, covs = _sequential_kf(F, Qs, H, y, Rs, minf, Pinf, mask, return_predict=return_predict)
    return ell, (means, covs)
```
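The filter above calls a `mvn_logpdf` helper that is not shown. As a hedged stand-in (plain NumPy, no missing-data mask; the name `gaussian_logpdf` is hypothetical), the per-step term $\log \mathcal{N}(x_t \mid \mu_{x,t|t-1}, \Sigma_{x,t|t-1})$ can be computed from a Cholesky factor:

```python
import numpy as np


def gaussian_logpdf(x, mean, cov):
    """Hypothetical helper: log N(x | mean, cov) for a single vector observation."""
    d = x.shape[0]
    L = np.linalg.cholesky(cov)               # cov = L @ L.T
    z = np.linalg.solve(L, x - mean)          # whitened residual
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + z @ z)


# per-step term for a 1-D observation 1.5 with predictive mean 1.0 and variance 2.6
ell_n = gaussian_logpdf(np.array([1.5]), np.array([1.0]), np.array([[2.6]]))
```

Summing this term over $t$, each time with the predictive observation moments from the update step, gives $\log p(x_{1:T})$, which is what `ell` accumulates inside the scan. The Cholesky factor avoids an explicit inverse and yields the log-determinant directly from its diagonal.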
Smoothing
Likelihoods
Missing Data
Resources
Courses
- Sensor Fusion and Non-Linear Filtering - YouTube Playlist
- Kalman Filtering and Applications in Finance - YouTube
Papers
Code
- Kalman Filter with Dask/XArray - Repo
- Ouala, S., Fablet, R., Herzet, C., Chapron, B., Pascual, A., Collard, F., & Gaultier, L. (2018). Neural Network Based Kalman Filters for the Spatio-Temporal Interpolation of Satellite-Derived Sea Surface Temperature. Remote Sensing, 10(12). 10.3390/rs10121864