State-Space Models & Kalman
Layer 3 recipes for linear-Gaussian state-space models. Stationary 1-D GP kernels with rational spectral densities admit exact SDE representations,
turning \(O(N^3)\) GP inference into \(O(N d^3)\) Kalman filtering, where \(d\) is the state dimension. This page covers the SDE kernel zoo, the filters and smoothers (sequential, parallel associative-scan, square-root, and steady-state), and the natural-parameter / site machinery for non-conjugate likelihoods.
SDE kernels
Structured linear algebra and Gaussian primitives for JAX.
SDEKernel
Bases: Module
Abstract base class for state-space kernel representations.
Subclasses implement sde_params to provide the continuous-time
SDE matrices (F, L, H, Q_c, P_inf). The default discretise
method uses the matrix exponential; subclasses may override it
with closed-form solutions.
Source code in src/gaussx/_ssm/_sde_kernel.py
state_dim: int (abstract property)
Dimension of the latent state vector.
sde_params() -> SDEParams (abstract method)
discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]
Discretise the SDE at time step dt.
Default implementation computes:
A = expm(F * dt)
Q = P_inf - A @ P_inf @ A^T
Subclasses may override with closed-form expressions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dt | Float[Array, ''] | Time step (scalar, positive). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'd d'] | A, the discrete-time transition matrix. |
| Float[Array, 'd d'] | Q, the process noise covariance. |
Source code in src/gaussx/_ssm/_sde_kernel.py
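The default discretisation above can be sketched in plain NumPy. This assumes a dense drift F and stationary covariance P_inf. The expm helper is a hypothetical stand-in for a library matrix exponential, using scaling and squaring of a truncated Taylor series; it is not gaussx code. For Matérn-1/2 (Ornstein-Uhlenbeck), the result can be checked against the closed form \(A = e^{-\Delta t/\ell}\), \(Q = \sigma^2(1 - e^{-2\Delta t/\ell})\).

```python
import numpy as np


def expm(M: np.ndarray, terms: int = 20) -> np.ndarray:
    """Matrix exponential by scaling and squaring a truncated Taylor series."""
    n = M.shape[0]
    norm = np.linalg.norm(M, 1)
    # Choose s so that ||M / 2^s||_1 <= 1/2, where the series converges quickly.
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / 2.0**s
    result, term = np.eye(n), np.eye(n)
    for k in range(1, terms + 1):
        term = term @ X / k
        result = result + term
    for _ in range(s):  # undo the scaling: exp(M) = exp(M / 2^s)^(2^s)
        result = result @ result
    return result


def discretise(F: np.ndarray, P_inf: np.ndarray, dt: float):
    """Default SDE discretisation: A = expm(F dt), Q = P_inf - A P_inf A^T."""
    A = expm(F * dt)
    Q = P_inf - A @ P_inf @ A.T
    return A, Q


# Matérn-1/2 with variance 2 and lengthscale 0.5, so F = -1 / lengthscale = -2.
A, Q = discretise(np.array([[-2.0]]), np.array([[2.0]]), dt=0.1)
```

The identity \(Q = P_\infty - A P_\infty A^\top\) avoids integrating the noise term explicitly. It relies on \(P_\infty\) being the exact stationary covariance of the SDE.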
discretise_sequence(dt: Float[Array, ' N']) -> tuple[Float[Array, 'N d d'], Float[Array, 'N d d']]
Discretise the SDE at multiple time steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dt | Float[Array, ' N'] | Time steps, shape (N,). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'N d d'], Float[Array, 'N d d']] | Tuple of stacked per-step transition matrices and process noise covariances, each of shape (N, d, d). |
Source code in src/gaussx/_ssm/_sde_kernel.py
SDEParams
Bases: NamedTuple
Continuous-time SDE parameters for a stationary kernel.
Defines the linear time-invariant SDE:
dx = F x dt + L dW, with dW ~ N(0, Q_c dt)
with observation model y = H x.
Attributes:
| Name | Type | Description |
|---|---|---|
| F | Float[Array, 'd d'] | Drift matrix, shape (d, d). |
| L | Float[Array, 'd s'] | Diffusion matrix, shape (d, s). |
| H | Float[Array, '1 d'] | Observation matrix, shape (1, d). |
| Q_c | Float[Array, 's s'] | Spectral density, shape (s, s). |
| P_inf | Float[Array, 'd d'] | Stationary covariance, shape (d, d). |
Source code in src/gaussx/_ssm/_sde_kernel.py
ConstantSDE
Bases: SDEKernel
State-space representation of a constant kernel.
Models \(k(\tau) = \sigma^2\) — a degenerate kernel with zero dynamics and zero diffusion. State dimension is 1.
Attributes:
| Name | Type | Description |
|---|---|---|
| variance | Float[Array, ''] | Signal variance \(\sigma^2\). |
Source code in src/gaussx/_ssm/_constant.py
MaternSDE
Bases: SDEKernel
State-space representation of the Matern kernel.
Supports orders 0 (Matern-1/2), 1 (Matern-3/2), and 2 (Matern-5/2).
The state dimension is order + 1.
Attributes:
| Name | Type | Description |
|---|---|---|
| variance | Float[Array, ''] | Signal variance \(\sigma^2\). |
| lengthscale | Float[Array, ''] | Lengthscale \(\ell\). |
| order | int | Matern order (0, 1, or 2). |
Source code in src/gaussx/_ssm/_matern.py
sde_params() -> SDEParams
Compute SDE parameters for the Matern kernel.
Source code in src/gaussx/_ssm/_matern.py
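For orientation, the sketch below writes out one standard companion-form SDE for the Matérn-3/2 case (order 1) in NumPy, with \(\lambda = \sqrt{3}/\ell\) and state \([f, f']\). This is the textbook construction; MaternSDE may order or scale the state differently. The check confirms that P_inf solves the continuous Lyapunov equation \(F P_\infty + P_\infty F^\top + L Q_c L^\top = 0\). That equation is what makes P_inf the stationary covariance.

```python
import numpy as np


def matern32_sde(variance: float, lengthscale: float):
    """Companion-form SDE matrices (F, L, H, Q_c, P_inf) for the Matérn-3/2 kernel."""
    lam = np.sqrt(3.0) / lengthscale
    F = np.array([[0.0, 1.0], [-lam**2, -2.0 * lam]])
    L = np.array([[0.0], [1.0]])
    H = np.array([[1.0, 0.0]])
    Q_c = np.array([[4.0 * lam**3 * variance]])
    P_inf = np.diag([variance, lam**2 * variance])
    return F, L, H, Q_c, P_inf


F, L, H, Q_c, P_inf = matern32_sde(variance=1.5, lengthscale=0.7)
# P_inf is stationary exactly when this Lyapunov residual vanishes.
residual = F @ P_inf + P_inf @ F.T + L @ Q_c @ L.T
```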
PeriodicSDE
Bases: SDEKernel
State-space representation of the periodic (MacKay) kernel.
Approximates the periodic kernel via Fourier series truncation
to n_harmonics terms. State dimension is 2 * n_harmonics.
Attributes:
| Name | Type | Description |
|---|---|---|
| variance | Float[Array, ''] | Signal variance \(\sigma^2\). |
| lengthscale | Float[Array, ''] | Lengthscale \(\ell\). |
| period | Float[Array, ''] | Period \(T\). |
| n_harmonics | int | Number of Fourier harmonics (truncation order). |
Source code in src/gaussx/_ssm/_periodic.py
sde_params() -> SDEParams
Return SDE parameters for the periodic kernel.
Source code in src/gaussx/_ssm/_periodic.py
discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]
Closed-form: block-diagonal rotation matrices.
Source code in src/gaussx/_ssm/_periodic.py
QuasiPeriodicSDE
Bases: ProductSDE
Quasi-periodic kernel: product of Matern and Periodic SDE.
Attributes:
| Name | Type | Description |
|---|---|---|
| kernel1 | SDEKernel | Modulating kernel (typically Matern). |
| kernel2 | SDEKernel | Periodic kernel. |
Source code in src/gaussx/_ssm/_composition.py
CosineSDE
Bases: SDEKernel
State-space representation of the cosine kernel.
Models \(k(\tau) = \sigma^2 \cos(\omega_0 \tau)\) via a 2-D rotation SDE. State dimension is 2.
Attributes:
| Name | Type | Description |
|---|---|---|
| variance | Float[Array, ''] | Signal variance \(\sigma^2\). |
| frequency | Float[Array, ''] | Angular frequency \(\omega_0\). |
Source code in src/gaussx/_ssm/_periodic.py
sde_params() -> SDEParams
Return SDE parameters for the cosine kernel.
Source code in src/gaussx/_ssm/_periodic.py
discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]
Closed-form rotation matrix discretization.
Source code in src/gaussx/_ssm/_periodic.py
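The closed-form rotation discretisation can be sketched in NumPy. This assumes the common convention \(F = \begin{bmatrix} 0 & -\omega_0 \\ \omega_0 & 0 \end{bmatrix}\), \(P_\infty = \sigma^2 I\) and \(H = [1, 0]\); the exact convention inside CosineSDE may differ. Because A is a rotation, \(A P_\infty A^\top = P_\infty\), so the process noise Q vanishes. In addition, \(H A P_\infty H^\top\) recovers \(\sigma^2 \cos(\omega_0 \tau)\).

```python
import numpy as np


def cosine_discretise(omega: float, dt: float) -> np.ndarray:
    """Closed-form exp(F dt) for F = [[0, -omega], [omega, 0]]: a 2-D rotation."""
    c, s = np.cos(omega * dt), np.sin(omega * dt)
    return np.array([[c, -s], [s, c]])


variance, omega, tau = 2.0, 3.0, 0.4
P_inf = variance * np.eye(2)
H = np.array([[1.0, 0.0]])

A = cosine_discretise(omega, tau)
Q = P_inf - A @ P_inf @ A.T          # zero: rotations preserve sigma^2 * I
k_tau = (H @ A @ P_inf @ H.T)[0, 0]  # stationary autocovariance at lag tau
```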
ProductSDE
Bases: SDEKernel
Product of two SDE kernels via Kronecker composition.
Attributes:
| Name | Type | Description |
|---|---|---|
| kernel1 | SDEKernel | First component kernel. |
| kernel2 | SDEKernel | Second component kernel. |
Source code in src/gaussx/_ssm/_composition.py
sde_params() -> SDEParams
Return Kronecker-structured SDE parameters.
Note
SDEParams currently types its fields as dense
jaxtyping.Float[Array, ...], so the Kronecker products
computed here are materialized densely with shape
(state_dim, state_dim), where state_dim is
kernel1.state_dim * kernel2.state_dim. For typical SSM
kernels (Matérn-3/2, periodic) this is ≤ 32, so the
materialization is bounded and cheap. A future refactor
could expose a parallel sde_operators() method that
returns gaussx.Kronecker operators for downstream
filters that can exploit the structure (issue #153).
Source code in src/gaussx/_ssm/_composition.py
discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]
Discretise via the Kronecker matrix-exponential identity.
For a product kernel, the drift is the Kronecker sum \(F = F_1 \oplus F_2 = F_1 \otimes I + I \otimes F_2\). The factors \(F_1 \otimes I\) and \(I \otimes F_2\) commute, so
\[ A = \exp(F\,\Delta t) = \exp(F_1 \Delta t) \otimes \exp(F_2 \Delta t). \]
This needs two expm calls, of sizes \(d_1\) and \(d_2\), plus one Kronecker product, instead of one expm of size \(d_1 d_2\). It agrees with the dense expm of \(F\) up to floating-point error but is cheaper for moderate factor sizes.
\(Q = P_\infty - A P_\infty A^\top\) is computed densely from the resulting \(A\). Since \(P_\infty = P_{\infty,1} \otimes P_{\infty,2}\), Q could itself be expressed as a Kronecker difference. It is left dense to keep the consumer-facing (A, Q) interface unchanged.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dt | Float[Array, ''] | Time step (scalar, positive). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'd d'], Float[Array, 'd d']] | Tuple of the discrete-time transition matrix A and the process noise covariance Q. |
Source code in src/gaussx/_ssm/_composition.py
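The Kronecker identity used by ProductSDE.discretise can be checked numerically. This NumPy sketch reuses a small scaling-and-squaring expm helper, which is hypothetical and not gaussx code. It compares the dense exponential of the Kronecker sum \(F_1 \otimes I + I \otimes F_2\) with \(\exp(F_1 \Delta t) \otimes \exp(F_2 \Delta t)\). The two factors are a Matérn-3/2 drift and a cosine (rotation) drift, giving a quasi-periodic-style product.

```python
import numpy as np


def expm(M: np.ndarray, terms: int = 20) -> np.ndarray:
    """Matrix exponential by scaling and squaring a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / 2.0**s
    result, term = np.eye(M.shape[0]), np.eye(M.shape[0])
    for k in range(1, terms + 1):
        term = term @ X / k
        result = result + term
    for _ in range(s):
        result = result @ result
    return result


lam, omega, dt = np.sqrt(3.0) / 0.7, 3.0, 0.3
F1 = np.array([[0.0, 1.0], [-lam**2, -2.0 * lam]])  # Matérn-3/2 drift
F2 = np.array([[0.0, -omega], [omega, 0.0]])        # cosine (rotation) drift
I1, I2 = np.eye(2), np.eye(2)

F_product = np.kron(F1, I2) + np.kron(I1, F2)       # Kronecker sum, 4 x 4
A_dense = expm(F_product * dt)                      # one 4 x 4 exponential
A_kron = np.kron(expm(F1 * dt), expm(F2 * dt))      # two 2 x 2 exponentials
```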
SumSDE
Bases: SDEKernel
Sum of SDE kernels via block-diagonal composition.
Attributes:
| Name | Type | Description |
|---|---|---|
| kernels | tuple[SDEKernel, ...] | Tuple of component SDE kernels. |
Source code in src/gaussx/_ssm/_composition.py
sde_params() -> SDEParams
Return block-diagonal SDE parameters.
Source code in src/gaussx/_ssm/_composition.py
sde_autocovariance(kernel: SDEKernel, tau: Float[Array, ' *batch']) -> Float[Array, ' *batch']
Compute the stationary autocovariance of an SDE kernel.
Evaluates:
\[ K(\tau) = H \exp(F|\tau|)\, P_\infty H^\top \]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel | SDEKernel | An SDE kernel with … | required |
| tau | Float[Array, ' *batch'] | Lag values, shape (*batch). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' *batch'] | Autocovariance values \(K(\tau)\), shape (*batch). |
Source code in src/gaussx/_ssm/_autocovariance.py
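As a check on \(K(\tau) = H \exp(F|\tau|) P_\infty H^\top\), the sketch below evaluates it for the textbook companion-form Matérn-3/2 SDE; sde_autocovariance itself is not called. Write \(F = -\lambda I + N\), where \(N^2 = 0\). The exponential then has the closed form \(e^{-\lambda\tau}(I + N\tau)\), so no general expm is needed. The result should match the Matérn-3/2 kernel \(\sigma^2(1 + \lambda|\tau|)e^{-\lambda|\tau|}\).

```python
import numpy as np


def matern32_autocovariance(tau, variance: float, lengthscale: float) -> np.ndarray:
    """K(tau) = H exp(F |tau|) P_inf H^T for the companion-form Matérn-3/2 SDE."""
    lam = np.sqrt(3.0) / lengthscale
    N = np.array([[lam, 1.0], [-lam**2, -lam]])       # F = -lam * I + N, with N @ N = 0
    P_inf = np.diag([variance, lam**2 * variance])
    H = np.array([1.0, 0.0])
    t = np.abs(np.asarray(tau, dtype=float))[..., None, None]
    expFt = np.exp(-lam * t) * (np.eye(2) + N * t)    # exact exp(F t), batched over tau
    return np.einsum("i,...ij,jk,k->...", H, expFt, P_inf, H)


tau = np.linspace(-2.0, 2.0, 9)
k = matern32_autocovariance(tau, variance=1.3, lengthscale=0.8)
```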
Kalman filtering & smoothing
The forward filter and RTS smoother, their \(O(\log N)\) parallel (associative-scan) counterparts, and the steady-state (infinite-horizon) variants built on the discrete algebraic Riccati equation.
EmissionModel
Bases: Module
Observation (emission) model wrapping a linear observation matrix.
Provides named methods for common Kalman filter projection operations with observation matrix H ∈ ℝᴹˣᴺ.
Attributes:
| Name | Type | Description |
|---|---|---|
| H | Float[Array, 'M N'] | Observation matrix, shape (M, N). |
Source code in src/gaussx/_ssm/_emission.py
project_mean(mean: Float[Array, ' N']) -> Float[Array, ' M']
Project state mean to observation space: ŷ = H x.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | Float[Array, ' N'] | State mean, shape (N,). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' M'] | Projected mean, shape (M,). |
Source code in src/gaussx/_ssm/_emission.py
project_covariance(cov: Float[Array, 'N N'], noise: Float[Array, 'M M'] | None = None) -> Float[Array, 'M M']
Project state covariance: S = H P Hᵀ [+ R].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cov | Float[Array, 'N N'] | State covariance P, shape (N, N). | required |
| noise | Float[Array, 'M M'] \| None | Optional observation noise R, shape (M, M). | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'M M'] | Innovation covariance S, shape (M, M). |
Source code in src/gaussx/_ssm/_emission.py
innovation(y: Float[Array, ' M'], x_pred: Float[Array, ' N']) -> Float[Array, ' M']
Compute innovation (measurement residual): v = y − H x.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | Float[Array, ' M'] | Observation, shape (M,). | required |
| x_pred | Float[Array, ' N'] | Predicted state mean, shape (N,). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' M'] | Innovation vector v, shape (M,). |
Source code in src/gaussx/_ssm/_emission.py
back_project_precision(noise_prec: Float[Array, 'M M']) -> Float[Array, 'N N']
Back-project observation precision: Hᵀ R⁻¹ H.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| noise_prec | Float[Array, 'M M'] | Observation noise precision R⁻¹, shape (M, M). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N N'] | Information matrix contribution, shape (N, N). |
Source code in src/gaussx/_ssm/_emission.py
back_project_info(y: Float[Array, ' M'], noise_prec: Float[Array, 'M M']) -> Float[Array, ' N']
Back-project observation to information vector: Hᵀ R⁻¹ y.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | Float[Array, ' M'] | Observation, shape (M,). | required |
| noise_prec | Float[Array, 'M M'] | Observation noise precision R⁻¹, shape (M, M). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' N'] | Information vector contribution, shape (N,). |
Source code in src/gaussx/_ssm/_emission.py
FilterState
Bases: Module
Output of kalman_filter.
Attributes:
| Name | Type | Description |
|---|---|---|
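| filtered_means | Float[Array, 'T N'] | Filtered state means, shape (T, N). |
| filtered_covs | Float[Array, 'T N N'] | Filtered covariances, shape (T, N, N). |
| predicted_means | Float[Array, 'T N'] | Predicted state means, shape (T, N). |
| predicted_covs | Float[Array, 'T N N'] | Predicted covariances, shape (T, N, N). |
| log_likelihood | Float[Array, ''] | Scalar total log-likelihood. |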
Source code in src/gaussx/_ssm/_kalman.py
InfiniteHorizonState
Bases: Module
Output of infinite_horizon_filter.
Attributes:
| Name | Type | Description |
|---|---|---|
| filtered_means | Float[Array, 'T N'] | Filtered state estimates, shape (T, N). |
| filtered_covs | Float[Array, 'T N N'] | Filtered covariances (constant), shape (T, N, N). |
| predicted_means | Float[Array, 'T N'] | Predicted state estimates, shape (T, N). |
| predicted_covs | Float[Array, 'T N N'] | Predicted covariances (constant), shape (T, N, N). |
| log_likelihood | Float[Array, ''] | Total log-likelihood (scalar). |
Source code in src/gaussx/_ssm/_infinite_horizon_kalman.py
DAREResult
Bases: Module
Result of DARE solver.
Attributes:
| Name | Type | Description |
|---|---|---|
| P_inf | Float[Array, 'D D'] | Steady-state covariance, shape (D, D). |
| K_inf | Float[Array, 'D M'] | Steady-state Kalman gain, shape (D, M). |
| converged | Bool[Array, ''] | Scalar boolean indicating convergence. |
Source code in src/gaussx/_ssm/_dare.py
kalman_filter(transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_model: Float[Array, '*T M N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_noise: Float[Array, '*T M M'] | lx.AbstractLinearOperator, observations: Float[Array, 'T M'], init_mean: Float[Array, ' N'], init_cov: Float[Array, 'N N'], *, mask: Bool[Array, ' T'] | None = None, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> FilterState
Kalman filter forward pass via jax.lax.scan.
Implements the predict-update cycle for a (possibly time-varying) linear-Gaussian state-space model:
x_t = A_t @ x_{t-1} + q_t, q_t ~ N(0, Q_t)
y_t = H_t @ x_t + r_t, r_t ~ N(0, R_t)
Time-invariant inputs (single (N, N) / (M, N) etc.) are
automatically broadcast along the time axis. Time-varying inputs
are passed as (T, …) stacks (e.g. from
discretise_sequence).
Operator inputs (lineax BlockDiag / Kronecker /
DiagonalLinearOperator / MaskedOperator / etc.) are accepted
in the time-invariant signature only. The structural matvec
(A @ x, H @ x) runs through the operator's mv;
operator-typed Q / R are materialised to dense arrays once
outside the scan (the per-step sandwiches A P A^T / H P H^T
themselves run inside the scan because they depend on the evolving
P_filt).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transition | Float[Array, '*T N N'] \| AbstractLinearOperator | State transition matrix, shape (N, N) or (T, N, N), or an operator. | required |
| obs_model | Float[Array, '*T M N'] \| AbstractLinearOperator | Observation matrix, shape (M, N) or (T, M, N), or an operator. | required |
| process_noise | Float[Array, '*T N N'] \| AbstractLinearOperator | Process noise covariance, shape (N, N) or (T, N, N), or an operator. | required |
| obs_noise | Float[Array, '*T M M'] \| AbstractLinearOperator | Observation noise covariance, shape (M, M) or (T, M, M), or an operator. | required |
| observations | Float[Array, 'T M'] | Observed data, shape (T, M). | required |
| init_mean | Float[Array, ' N'] | Initial state mean, shape (N,). | required |
| init_cov | Float[Array, 'N N'] | Initial state covariance, shape (N, N). | required |
| mask | Bool[Array, ' T'] \| None | Optional per-step boolean mask, shape (T,). | None |
| solver | AbstractSolverStrategy \| None | Optional solver strategy. When … | None |
| woodbury_innovation | bool | When … | False |
Raises:
| Type | Description |
|---|---|
| TypeError | If operator-typed inputs are mixed with 3D (time-varying) array stacks. |
Returns:
| Type | Description |
|---|---|
| FilterState | A FilterState with the filtered and predicted means and covariances and the total log-likelihood. |
Source code in src/gaussx/_ssm/_kalman.py
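A plain-NumPy sketch of the predict-update cycle, with log-likelihood and an optional mask, for time-invariant dense (A, H, Q, R). It uses a Python loop where kalman_filter uses jax.lax.scan. The sketch assumes that (init_mean, init_cov) is the prior on \(x_0\), so every step, including the first, begins with a prediction. That first-step convention is an assumption, not taken from the gaussx source. Masked steps skip the update and add nothing to the log-likelihood.

```python
import numpy as np


def kalman_filter(A, H, Q, R, ys, m0, P0, mask=None):
    """Sequential predict-update Kalman filter for time-invariant dense (A, H, Q, R).

    Assumed convention: (m0, P0) is the prior on x_0, and every step begins
    with a prediction x_t = A x_{t-1} + q_t.
    """
    T, M = ys.shape
    mask = np.ones(T, dtype=bool) if mask is None else mask
    m, P = m0, P0
    means, covs, loglik = [], [], 0.0
    for t in range(T):
        # Predict.
        m = A @ m
        P = A @ P @ A.T + Q
        if mask[t]:
            # Update with innovation v = y - H m and innovation covariance S.
            v = ys[t] - H @ m
            S = H @ P @ H.T + R
            K = np.linalg.solve(S, H @ P).T  # K = P H^T S^{-1}; P and S are symmetric
            m = m + K @ v
            P = P - K @ S @ K.T
            _, logdet_S = np.linalg.slogdet(S)
            loglik += -0.5 * (M * np.log(2 * np.pi) + logdet_S + v @ np.linalg.solve(S, v))
        means.append(m)
        covs.append(P)
    return np.array(means), np.array(covs), loglik


# Scalar random walk: x_t = x_{t-1} + q_t, y_t = x_t + r_t.
rng = np.random.default_rng(0)
T, q, r = 6, 0.3, 0.2
ys = rng.normal(size=(T, 1))
means, covs, loglik = kalman_filter(
    np.eye(1), np.eye(1), q * np.eye(1), r * np.eye(1), ys, np.array([0.5]), np.eye(1)
)
```

The log-likelihood accumulated this way equals the dense Gaussian log-density of all observations. That equality is a convenient way to validate any filter implementation on small problems.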
rts_smoother(filter_state: FilterState, transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'T N'], Float[Array, 'T N N']]
Rauch-Tung-Striebel backward smoother.
Accepts the same time-invariant / time-varying / operator forms for
transition and process_noise as kalman_filter. When
a step was masked off in the filter (mask[t] == 0), the
smoother formula degenerates harmlessly because filtered ==
predicted at that step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | FilterState | Output of the forward Kalman filter. | required |
| transition | Float[Array, '*T N N'] \| AbstractLinearOperator | State transition matrix or operator. | required |
| process_noise | Float[Array, '*T N N'] \| AbstractLinearOperator | Process noise covariance or operator. Not currently used by the standard RTS recurrence; kept for API symmetry. | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'T N'], Float[Array, 'T N N']] | Tuple of smoothed means, shape (T, N), and smoothed covariances, shape (T, N, N). |
Source code in src/gaussx/_ssm/_kalman.py
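A NumPy sketch of the RTS backward pass. It reads the predicted moments stored by the forward pass, which is why the smoother does not need the process noise. The compact filter_pass helper is hypothetical and repeated here only to keep the block self-contained. It assumes the same "prior on \(x_0\), predict first" convention as a plain sequential filter. On a scalar random walk, the smoothed moments match the dense GP posterior.

```python
import numpy as np


def filter_pass(A, H, Q, R, ys, m0, P0):
    """Forward filter that also stores predicted moments (prior on x_0, predict first)."""
    m, P = m0, P0
    fm, fP, pm, pP = [], [], [], []
    for y in ys:
        m, P = A @ m, A @ P @ A.T + Q                   # predict
        pm.append(m)
        pP.append(P)
        S = H @ P @ H.T + R
        K = np.linalg.solve(S, H @ P).T                 # K = P H^T S^{-1}
        m, P = m + K @ (y - H @ m), P - K @ S @ K.T     # update
        fm.append(m)
        fP.append(P)
    return [np.array(a) for a in (fm, fP, pm, pP)]


def rts_smoother(A, fm, fP, pm, pP):
    """RTS backward pass; predicted moments come from the filter, so Q is not needed."""
    sm, sP = fm.copy(), fP.copy()
    for t in range(len(fm) - 2, -1, -1):
        G = np.linalg.solve(pP[t + 1], A @ fP[t]).T     # G_t = P_t A^T (P^-_{t+1})^{-1}
        sm[t] = fm[t] + G @ (sm[t + 1] - pm[t + 1])
        sP[t] = fP[t] + G @ (sP[t + 1] - pP[t + 1]) @ G.T
    return sm, sP


rng = np.random.default_rng(0)
T, q, r, m0, p0 = 6, 0.3, 0.2, 0.5, 1.0
ys = rng.normal(size=(T, 1))
A = np.eye(1)
fm, fP, pm, pP = filter_pass(A, np.eye(1), q * np.eye(1), r * np.eye(1), ys, np.array([m0]), p0 * np.eye(1))
sm, sP = rts_smoother(A, fm, fP, pm, pP)
```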
kalman_gain(P: lx.AbstractLinearOperator, H: lx.AbstractLinearOperator, R: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> Float[Array, 'N M']
Compute Kalman gain K = P @ H^T @ (H @ P @ H^T + R)^{-1}.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| P | AbstractLinearOperator | Prior covariance operator, shape (N, N). | required |
| H | AbstractLinearOperator | Observation model operator, shape (M, N). | required |
| R | AbstractLinearOperator | Observation noise operator, shape (M, M). | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy. When … | None |
| woodbury_innovation | bool | When … | False |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N M'] | Kalman gain matrix of shape (N, M). |
Source code in src/gaussx/_ssm/_kalman.py
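Two algebraically equivalent ways to form the gain, sketched in NumPy on dense matrices:

- The covariance form solves against the innovation covariance \(S = HPH^\top + R\) instead of inverting it.
- The information form \((P^{-1} + H^\top R^{-1} H)^{-1} H^\top R^{-1}\) follows from the Woodbury (push-through) identity and factorises an N×N system rather than an M×M one.

This sketch does not claim to reproduce how woodbury_innovation is implemented in gaussx.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M = 4, 2
B = rng.normal(size=(N, N))
P = B @ B.T + N * np.eye(N)          # symmetric positive definite prior covariance
H = rng.normal(size=(M, N))
R = np.diag([0.5, 0.8])


def kalman_gain(P, H, R):
    """Covariance form: K = P H^T S^{-1} with S = H P H^T + R, via a solve."""
    S = H @ P @ H.T + R
    return np.linalg.solve(S, H @ P).T


def kalman_gain_information(P, H, R):
    """Information form: K = (P^{-1} + H^T R^{-1} H)^{-1} H^T R^{-1}."""
    Rinv_H = np.linalg.solve(R, H)
    Lam = np.linalg.inv(P) + H.T @ Rinv_H
    return np.linalg.solve(Lam, Rinv_H.T)


K_cov = kalman_gain(P, H, R)
K_info = kalman_gain_information(P, H, R)
```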
parallel_kalman_filter(transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_model: Float[Array, '*T M N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_noise: Float[Array, '*T M M'] | lx.AbstractLinearOperator, observations: Float[Array, 'T M'], init_mean: Float[Array, ' N'], init_cov: Float[Array, 'N N'], *, mask: Bool[Array, ' T'] | None = None, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False, form: str = 'covariance') -> FilterState
Parallel Kalman filter via jax.lax.associative_scan.
Numerically equivalent to gaussx.kalman_filter but with
O(log T) parallel depth on accelerators. Same generalised
contract (TI / TV / operator-typed inputs, optional mask, scalar
log-likelihood). Empty observation windows (T == 0) return a
zero-length FilterState with log_likelihood == 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transition | Float[Array, '*T N N'] \| AbstractLinearOperator | State transition matrix or operator. | required |
| obs_model | Float[Array, '*T M N'] \| AbstractLinearOperator | Observation matrix or operator. | required |
| process_noise | Float[Array, '*T N N'] \| AbstractLinearOperator | Process noise covariance or operator. | required |
| obs_noise | Float[Array, '*T M M'] \| AbstractLinearOperator | Observation noise covariance or operator. | required |
| observations | Float[Array, 'T M'] | Observed data, shape (T, M). | required |
| init_mean | Float[Array, ' N'] | Initial state mean, shape (N,). | required |
| init_cov | Float[Array, 'N N'] | Initial state covariance, shape (N, N). | required |
| mask | Bool[Array, ' T'] \| None | Optional per-step boolean mask, shape (T,). | None |
| solver | AbstractSolverStrategy \| None | Accepted for API symmetry with … | None |
| woodbury_innovation | bool | When … | False |
| form | str | Either … | 'covariance' |
Raises:
| Type | Description |
|---|---|
| ValueError | If … |
Returns:
| Type | Description |
|---|---|
| FilterState | A FilterState with the filtered and predicted means and covariances and the total log-likelihood. |
Source code in src/gaussx/_ssm/_parallel_kalman.py
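An associative scan only needs an associative binary operator, which a tree-shaped reduction then evaluates in O(log T) depth. The NumPy sketch below shows this for the prediction-only special case. It uses affine-Gaussian transition elements (A, b, Q) that compose into another element of the same form. A full parallel Kalman filter element also carries conditioning terms for the observations, and gaussx's own element definition is not reproduced here.

```python
import numpy as np
from functools import reduce


def combine(first, second):
    """Compose transition elements: apply `first`, then `second`.

    An element (A, b, Q) represents x -> A x + b + noise, noise ~ N(0, Q).
    """
    A1, b1, Q1 = first
    A2, b2, Q2 = second
    return A2 @ A1, A2 @ b1 + b2, A2 @ Q1 @ A2.T + Q2


def tree_reduce(elems):
    """Pairwise reduction in O(log T) rounds; valid only because `combine` is associative."""
    elems = list(elems)
    while len(elems) > 1:
        paired = [combine(elems[i], elems[i + 1]) for i in range(0, len(elems) - 1, 2)]
        if len(elems) % 2:
            paired.append(elems[-1])  # odd element carries over, order preserved
        elems = paired
    return elems[0]


rng = np.random.default_rng(4)


def random_element(n: int = 2):
    B = rng.normal(size=(n, n))
    return 0.5 * rng.normal(size=(n, n)), rng.normal(size=n), B @ B.T


elems = [random_element() for _ in range(7)]
sequential = reduce(combine, elems)   # left-to-right fold, like a sequential scan
parallel = tree_reduce(elems)         # balanced tree, as an associative scan would evaluate
```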
parallel_rts_smoother(filter_state: FilterState, transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None, form: str = 'covariance') -> tuple[Float[Array, 'T N'], Float[Array, 'T N N']]
Parallel RTS smoother via reverse jax.lax.associative_scan.
Pairs with parallel_kalman_filter. Numerically equivalent to
gaussx.rts_smoother with O(log T) parallel depth.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | FilterState | Output of the parallel forward filter. | required |
| transition | Float[Array, '*T N N'] \| AbstractLinearOperator | State transition matrix or operator. | required |
| process_noise | Float[Array, '*T N N'] \| AbstractLinearOperator | Unused; kept for API symmetry with the sequential smoother. | required |
| solver | AbstractSolverStrategy \| None | Accepted for API symmetry; not currently threaded through. | None |
| form | str | Either … | 'covariance' |
Raises:
| Type | Description |
|---|---|
| ValueError | If … |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'T N'], Float[Array, 'T N N']] | Tuple of smoothed means, shape (T, N), and smoothed covariances, shape (T, N, N). |
Source code in src/gaussx/_ssm/_parallel_kalman.py
infinite_horizon_filter(transition: Float[Array, 'N N'] | lx.AbstractLinearOperator, obs_model: Float[Array, 'M N'] | lx.AbstractLinearOperator, process_noise: Float[Array, 'N N'] | lx.AbstractLinearOperator, obs_noise: Float[Array, 'M M'] | lx.AbstractLinearOperator, observations: Float[Array, 'T M'], init_mean: Float[Array, ' N'] | None = None, *, dare_result: DAREResult | None = None, max_iter: int = 100, tol: float = 1e-08, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> InfiniteHorizonState
Infinite-horizon Kalman filter with fixed steady-state gain.
Uses the DARE solution for a constant Kalman gain K∞, avoiding per-step Riccati updates. For dense matrices, the per-step cost is O(N² + MN + M²) instead of O(N³) for the standard Kalman filter:
Predict: x⁻ₜ = A xₜ₋₁
Update: vₜ = yₜ − H x⁻ₜ
xₜ = x⁻ₜ + K∞ vₜ
All four operator/array arguments accept either a raw JAX array or
a lineax.AbstractLinearOperator. Operator inputs preserve
their structural matvec inside the per-step scan; the sandwiches
materialise once outside the scan.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transition | Float[Array, 'N N'] \| AbstractLinearOperator | State transition matrix or operator, shape (N, N). | required |
| obs_model | Float[Array, 'M N'] \| AbstractLinearOperator | Observation matrix or operator, shape (M, N). | required |
| process_noise | Float[Array, 'N N'] \| AbstractLinearOperator | Process noise covariance or operator, shape (N, N). | required |
| obs_noise | Float[Array, 'M M'] \| AbstractLinearOperator | Observation noise covariance or operator, shape (M, M). | required |
| observations | Float[Array, 'T M'] | Observed data y, shape (T, M). | required |
| init_mean | Float[Array, ' N'] \| None | Initial state mean, shape (N,). | None |
| dare_result | DAREResult \| None | Precomputed DARE result. If … | None |
| max_iter | int | Maximum DARE iterations (used only if …). | 100 |
| tol | float | DARE convergence tolerance (used only if …). | 1e-08 |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
| woodbury_innovation | bool | When … | False |
Returns:
| Type | Description |
|---|---|
| InfiniteHorizonState | An InfiniteHorizonState with the filtered and predicted means, the (constant) covariances, and the total log-likelihood. |
Source code in src/gaussx/_ssm/_infinite_horizon_kalman.py
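A scalar NumPy sketch of the constant-gain idea. For a random walk \(x_t = x_{t-1} + q_t\), \(y_t = x_t + r_t\), the scalar DARE has the closed-form predicted variance \(p^- = \tfrac{1}{2}\left(q + \sqrt{q^2 + 4qr}\right)\) and gain \(K_\infty = p^-/(p^- + r)\). A standard filter started from a different initial variance converges to that gain, and after a burn-in its means coincide with the fixed-gain filter. The example model and helper names are illustrative, not gaussx code.

```python
import numpy as np


def steady_state_gain_random_walk(q: float, r: float) -> float:
    """Closed-form scalar DARE gain for x_t = x_{t-1} + q_t, y_t = x_t + r_t."""
    p_pred = 0.5 * (q + np.sqrt(q**2 + 4.0 * q * r))   # steady-state predicted variance
    return p_pred / (p_pred + r)


def constant_gain_filter(ys, K, m0=0.0):
    """Infinite-horizon filter: fixed gain, no per-step variance recursion."""
    m, means = m0, []
    for y in ys:
        # Predict is the identity for A = 1; update uses the fixed gain K.
        m = m + K * (y - m)
        means.append(m)
    return np.array(means)


def time_varying_filter(ys, q, r, m0=0.0, p0=10.0):
    """Standard scalar Kalman filter, recording the gain at every step."""
    m, p, means, gains = m0, p0, [], []
    for y in ys:
        p = p + q              # predict
        k = p / (p + r)        # time-varying gain
        m = m + k * (y - m)    # update
        p = (1.0 - k) * p
        means.append(m)
        gains.append(k)
    return np.array(means), np.array(gains)


rng = np.random.default_rng(2)
q, r = 0.1, 0.5
ys = np.cumsum(rng.normal(scale=np.sqrt(q), size=200)) + rng.normal(scale=np.sqrt(r), size=200)
K_inf = steady_state_gain_random_walk(q, r)
m_tv, k_tv = time_varying_filter(ys, q, r)
m_ss = constant_gain_filter(ys, K_inf)
```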
infinite_horizon_smoother(filter_state: InfiniteHorizonState, transition: Float[Array, 'N N'] | lx.AbstractLinearOperator, dare_result: DAREResult, process_noise: Float[Array, 'N N'] | lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'T N'], Float[Array, 'T N N']]
Infinite-horizon RTS smoother with fixed steady-state gain.
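Precomputes the steady-state smoother gain \(G_\infty = P_\infty A^\top (P^-_\infty)^{-1}\), where \(P^-_\infty = A P_\infty A^\top + Q\) is the steady-state predicted covariance. It then runs a backward scan with the fixed gain \(G_\infty\). The steady-state smoothed covariance is the solution of the discrete Lyapunov equation:
\[ P_{\text{smooth}} = P_\infty + G_\infty \left(P_{\text{smooth}} - P^-_\infty\right) G_\infty^\top \]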
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| filter_state | InfiniteHorizonState | Output of the infinite-horizon filter. | required |
| transition | Float[Array, 'N N'] \| AbstractLinearOperator | State transition matrix or operator, shape (N, N). | required |
| dare_result | DAREResult | DARE result used in the filter. | required |
| process_noise | Float[Array, 'N N'] \| AbstractLinearOperator | Process noise covariance or operator, shape (N, N). | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'T N'], Float[Array, 'T N N']] | Tuple of smoothed means, shape (T, N), and smoothed covariances, shape (T, N, N). |
Source code in src/gaussx/_ssm/_infinite_horizon_kalman.py
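The steady-state smoothed covariance can be obtained directly from the Lyapunov (Stein) equation above. This NumPy sketch solves it by vectorisation and cross-checks it by iterating the backward recursion. It assumes Q is positive definite, which keeps the spectral radius of \(G_\infty\) below one, so the equation has a unique solution. The steady-state filtered covariance comes from a plain Riccati iteration. The model matrices are illustrative values.

```python
import numpy as np


def steady_state_covariances(A, H, Q, R, iters=500):
    """Iterate the Riccati recursion; return (filtered P_inf, predicted P_pred)."""
    P = Q.copy()
    for _ in range(iters):
        P_pred = A @ P @ A.T + Q
        S = H @ P_pred @ H.T + R
        K = np.linalg.solve(S, H @ P_pred).T
        P = P_pred - K @ S @ K.T
    return P, A @ P @ A.T + Q


def steady_state_smoother_cov(A, P_inf, P_pred):
    """Solve P_s = P_inf + G (P_s - P_pred) G^T for P_s (a Stein equation)."""
    G = np.linalg.solve(P_pred, A @ P_inf).T          # G = P_inf A^T P_pred^{-1}
    C = P_inf - G @ P_pred @ G.T                      # rewrite as P_s = G P_s G^T + C
    n = P_inf.shape[0]
    # Row-major vec identity: (G X G^T).ravel() == kron(G, G) @ X.ravel()
    vec = np.linalg.solve(np.eye(n * n) - np.kron(G, G), C.ravel())
    return G, vec.reshape(n, n)


A = np.array([[0.95, 0.5], [0.0, 0.9]])
H = np.array([[1.0, 0.0]])
Q = np.array([[0.05, 0.01], [0.01, 0.1]])   # positive definite
R = np.array([[0.3]])
P_inf, P_pred = steady_state_covariances(A, H, Q, R)
G, P_smooth = steady_state_smoother_cov(A, P_inf, P_pred)
```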
dare(A: Float[Array, 'D D'] | lx.AbstractLinearOperator, H: Float[Array, 'M D'] | lx.AbstractLinearOperator, Q: Float[Array, 'D D'] | lx.AbstractLinearOperator, R: Float[Array, 'M M'] | lx.AbstractLinearOperator, *, P_init: Float[Array, 'D D'] | None = None, max_iter: int = 100, tol: float = 1e-08, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> DAREResult
Discrete Algebraic Riccati Equation solver.
Iterates the Kalman predict-update equations until convergence:
Predict: P⁻ = A P Aᵀ + Q
Update: S = H P⁻ Hᵀ + R
K = P⁻ Hᵀ S⁻¹
P = (I - KH) P⁻
Convergence is declared when max|P_new - P_old| < tol.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Float[Array, 'D D'] \| AbstractLinearOperator | Transition matrix or operator, shape (D, D). | required |
| H | Float[Array, 'M D'] \| AbstractLinearOperator | Observation matrix or operator, shape (M, D). | required |
| Q | Float[Array, 'D D'] \| AbstractLinearOperator | Process noise covariance or operator, shape (D, D). | required |
| R | Float[Array, 'M M'] \| AbstractLinearOperator | Observation noise covariance or operator, shape (M, M). | required |
| P_init | Float[Array, 'D D'] \| None | Initial covariance guess, shape (D, D). | None |
| max_iter | int | Maximum number of iterations. | 100 |
| tol | float | Convergence tolerance on the element-wise max absolute change. | 1e-08 |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
| woodbury_innovation | bool | When … | False |
Returns:
| Type | Description |
|---|---|
| DAREResult | A DAREResult with the steady-state covariance, the steady-state Kalman gain, and the convergence flag. |
Source code in src/gaussx/_ssm/_dare.py
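A dense NumPy sketch of the fixed-point iteration described above, with the same max-absolute-change stopping rule. It returns a plain (P, K, converged) tuple rather than a DAREResult. The scalar random walk gives a closed-form check: \(p^- = \tfrac{1}{2}\left(q + \sqrt{q^2 + 4qr}\right)\), \(K = p^-/(p^- + r)\) and \(P = p^- r/(p^- + r)\).

```python
import numpy as np


def dare(A, H, Q, R, P_init=None, max_iter=100, tol=1e-8):
    """Iterate Kalman predict-update until max|P_new - P_old| < tol (dense sketch)."""
    P = Q.copy() if P_init is None else P_init
    K = np.zeros((A.shape[0], H.shape[0]))
    converged = False
    for _ in range(max_iter):
        P_pred = A @ P @ A.T + Q                       # predict
        S = H @ P_pred @ H.T + R
        K = np.linalg.solve(S, H @ P_pred).T           # K = P_pred H^T S^{-1}
        P_new = (np.eye(A.shape[0]) - K @ H) @ P_pred  # update
        if np.max(np.abs(P_new - P)) < tol:
            P, converged = P_new, True
            break
        P = P_new
    return P, K, converged


# Scalar random walk: closed-form steady state for comparison.
q, r = 0.1, 0.5
P, K, ok = dare(np.eye(1), np.eye(1), q * np.eye(1), r * np.eye(1))
p_pred = 0.5 * (q + np.sqrt(q**2 + 4 * q * r))
```

Convergence of this plain iteration is linear, so weakly observed states can need far more than the default number of iterations.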
pairwise_marginals(means: Float[Array, 'T d'], covariances: Float[Array, 'T d d'], cross_covariances: Float[Array, 'Tm1 d d']) -> tuple[Float[Array, 'Tm1 two_d'], Float[Array, 'Tm1 two_d two_d']]
Joint p(x_k, x_{k+1}) for each consecutive pair.
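For each pair (k, k+1), the joint distribution is:
\[ p(x_k, x_{k+1}) = \mathcal{N}\left( \begin{bmatrix} \mu_k \\ \mu_{k+1} \end{bmatrix}, \begin{bmatrix} P_k & C_k^\top \\ C_k & P_{k+1} \end{bmatrix} \right) \]
where \(C_k = \operatorname{Cov}[x_{k+1}, x_k]\) is the pairwise cross-covariance.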
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| means | Float[Array, 'T d'] | Smoothed means, shape (T, d). | required |
| covariances | Float[Array, 'T d d'] | Smoothed covariances, shape (T, d, d). | required |
| cross_covariances | Float[Array, 'Tm1 d d'] | Pairwise cross-covariances \(C_k\), shape (T-1, d, d). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'Tm1 two_d'], Float[Array, 'Tm1 two_d two_d']] | Tuple of joint means, shape (T-1, 2d), and joint covariances, shape (T-1, 2d, 2d). |
Source code in src/gaussx/_ssm/_pairwise_marginals.py
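The block assembly can be written with a few concatenations. This vectorised NumPy sketch is not the gaussx implementation. It follows the convention \(C_k = \operatorname{Cov}[x_{k+1}, x_k]\), so \(C_k^\top\) sits in the top-right block. The check slices the same 2d×2d windows out of a known joint Gaussian over all time steps.

```python
import numpy as np


def pairwise_marginals(means, covs, cross):
    """Joint (mean, cov) of (x_k, x_{k+1}); cross[k] = Cov[x_{k+1}, x_k]."""
    mu = np.concatenate([means[:-1], means[1:]], axis=-1)                 # (T-1, 2d)
    top = np.concatenate([covs[:-1], np.swapaxes(cross, -1, -2)], axis=-1)  # [P_k, C_k^T]
    bottom = np.concatenate([cross, covs[1:]], axis=-1)                   # [C_k, P_{k+1}]
    return mu, np.concatenate([top, bottom], axis=-2)                     # (T-1, 2d, 2d)


# Reference joint Gaussian over T = 4 steps of a d = 2 state.
rng = np.random.default_rng(3)
T, d = 4, 2
B = rng.normal(size=(T * d, T * d))
J = B @ B.T + np.eye(T * d)
mu_all = rng.normal(size=T * d)
means = mu_all.reshape(T, d)
covs = np.stack([J[k * d:(k + 1) * d, k * d:(k + 1) * d] for k in range(T)])
cross = np.stack([J[(k + 1) * d:(k + 2) * d, k * d:(k + 1) * d] for k in range(T - 1)])
joint_means, joint_covs = pairwise_marginals(means, covs, cross)
```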
SpInGP
Sparse inverse GP (SpInGP) inference: marginal likelihood and posterior computed through the banded precision of the state-space representation.
spingp_log_likelihood(prior_precision: BlockTriDiag, emission_model: Array, obs_noise: lx.AbstractLinearOperator, observations: Float[Array, 'N d_obs'], *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']
Log marginal likelihood via sparse inverse GP formulation.
Computes the log marginal likelihood using the precision-form Kalman filter (SpInGP):
1. Likelihood precision sites: \(\Lambda_{\text{lik}} = H^\top R^{-1} H\)
2. Posterior precision: \(\Lambda_{\text{post}} = \Lambda_{\text{prior}} + \Lambda_{\text{lik}}\)
3. \(\log p(y)\) via banded Cholesky log-determinant and quadratic form
The full expression is:
\[ \log p(y) = -\tfrac{1}{2}\left( N_{\text{obs}} \log 2\pi + \log|R|_{\text{total}} + y^\top R^{-1} y - \eta^\top \Lambda_{\text{post}}^{-1} \eta + \log|\Lambda_{\text{post}}| - \log|\Lambda_{\text{prior}}| \right) \]
where \(\eta = H^\top R^{-1} y\).
All operations exploit banded structure for O(Nd³) cost.
The solver parameter controls the algorithm used for the
large-scale posterior precision operations (solve, logdet).
Observation noise operations always use structural dispatch
since obs_noise is typically a small dense matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior_precision | BlockTriDiag | Prior precision as a BlockTriDiag operator. | required |
| emission_model | Array | Emission matrix H. Shape … | required |
| obs_noise | AbstractLinearOperator | Observation noise covariance R operator. | required |
| observations | Float[Array, 'N d_obs'] | Observations y, shape (N, d_obs). | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for posterior precision operations. When … | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar log marginal likelihood. |
Source code in src/gaussx/_ssm/_spingp.py
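The log p(y) expression in the spingp_log_likelihood description can be verified against the dense marginal \(y \sim \mathcal{N}(0, HKH^\top + R)\). The check uses the matrix determinant lemma and the Woodbury identity. This NumPy sketch uses dense matrices purely for checking; gaussx works with the banded BlockTriDiag structure instead. The prior is an Ornstein-Uhlenbeck (Matérn-1/2) process, whose precision is tridiagonal because the process is Markov. The prior mean is assumed to be zero, as in the expression.

```python
import numpy as np


def logdet(M: np.ndarray) -> float:
    return np.linalg.slogdet(M)[1]


def precision_form_loglik(Lam_prior, H, R, y):
    """log p(y) from the prior precision, term by term as in the SpInGP expression."""
    Rinv_H = np.linalg.solve(R, H)
    Lam_post = Lam_prior + H.T @ Rinv_H                 # posterior precision
    eta = Rinv_H.T @ y                                  # eta = H^T R^{-1} y
    quad = y @ np.linalg.solve(R, y) - eta @ np.linalg.solve(Lam_post, eta)
    return -0.5 * (y.size * np.log(2 * np.pi) + logdet(R) + quad
                   + logdet(Lam_post) - logdet(Lam_prior))


# OU prior on 8 sorted inputs; 5 of the 8 latent values are observed.
t = np.array([0.0, 0.3, 0.5, 1.1, 1.4, 2.0, 2.2, 3.0])
K = np.exp(-np.abs(t[:, None] - t[None, :]) / 0.8)
Lam_prior = np.linalg.inv(K)                            # tridiagonal up to round-off
H = np.eye(8)[[0, 2, 3, 5, 7]]                          # select observed states
R = 0.1 * np.eye(5)
y = np.array([0.2, -0.1, 0.4, 0.3, -0.5])
ll = precision_form_loglik(Lam_prior, H, R, y)
```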
spingp_posterior(prior_precision: BlockTriDiag, emission_model: Array, obs_noise: lx.AbstractLinearOperator, observations: Float[Array, 'N d_obs'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' Nd'], BlockTriDiag]
Posterior mean and precision via SpInGP.
Computes the posterior by adding likelihood precision sites to the prior precision and solving for the posterior mean:
\[
\Lambda_{post} = \Lambda_{prior} + H^T R^{-1} H
\]
\[
\mu_{post} = \Lambda_{post}^{-1} H^T R^{-1} y
\]
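A dense sketch of these two equations, useful for checking them against the familiar covariance-form GP posterior. The helper `dense_spingp_posterior` is hypothetical, works on plain dense arrays rather than `BlockTriDiag`, and assumes a zero prior mean.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def dense_spingp_posterior(prior_prec, H, R, y):
    """Hypothetical dense reference: posterior mean and precision (zero prior mean)."""
    post_prec = prior_prec + H.T @ jnp.linalg.solve(R, H)  # Lambda_post
    post_mean = jnp.linalg.solve(post_prec, H.T @ jnp.linalg.solve(R, y))
    return post_mean, post_prec


# Usage: cross-check with the covariance form Sigma H^T (H Sigma H^T + R)^{-1} y,
# where Sigma = Lambda_prior^{-1}.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
n, m = 5, 3
B = jax.random.normal(k1, (n, n))
prior_prec = B @ B.T + n * jnp.eye(n)
H = jax.random.normal(k2, (m, n))
R = 0.3 * jnp.eye(m)
y = jax.random.normal(k3, (m,))
mean, prec = dense_spingp_posterior(prior_prec, H, R, y)
Sigma = jnp.linalg.inv(prior_prec)
mean_cov_form = Sigma @ H.T @ jnp.linalg.solve(H @ Sigma @ H.T + R, y)
```

The precision form only ever adds \(H^T R^{-1} H\) to the prior precision, which keeps the block-tridiagonal sparsity when each observation touches a single time step.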
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior_precision | BlockTriDiag | Prior precision as | required |
| emission_model | Array | Emission matrix H. Shape | required |
| obs_noise | AbstractLinearOperator | Observation noise covariance R operator. | required |
| observations | Float[Array, 'N d_obs'] | Observations y, shape | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for posterior precision operations. When | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' Nd'] | Tuple |
| BlockTriDiag | |
| tuple[Float[Array, ' Nd'], BlockTriDiag] | |
Source code in src/gaussx/_ssm/_spingp.py
Sites & natural parameters
Conjugate-computation VI (CVI) site updates and the conversions between SSM moment, expectation, and natural parameterizations used by non-conjugate temporal inference.
GaussianSites
Bases: Module
Time-varying Gaussian likelihood sites in natural parameterization.
Stores per-timestep natural parameters for N Gaussian sites,
following the \(\eta_2 = -\tfrac{1}{2}\Lambda\) convention
(consistent with gaussx.mean_cov_to_natural).
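To make the convention concrete, here is a sketch of how per-timestep moments map to site natural parameters and back, assuming the usual pairing \(\eta_1 = \Lambda m\), \(\eta_2 = -\tfrac{1}{2}\Lambda\). Both helper names are hypothetical; they operate on plain arrays shaped like nat1 `(N, d)` and nat2 `(N, d, d)`, not on a GaussianSites instance.

```python
import jax.numpy as jnp


def site_naturals_from_moments(means, precs):
    """Hypothetical helper: (N, d) means and (N, d, d) precisions -> (nat1, nat2)."""
    nat1 = jnp.einsum("nij,nj->ni", precs, means)  # eta_1 = Lambda_k m_k
    nat2 = -0.5 * precs                             # eta_2 = -1/2 Lambda_k
    return nat1, nat2


def site_moments_from_naturals(nat1, nat2):
    """Inverse map: recover per-site means and precisions from (nat1, nat2)."""
    precs = -2.0 * nat2
    # Solve Lambda_k m_k = nat1_k for every site in one batched call.
    means = jnp.linalg.solve(precs, nat1[..., None])[..., 0]
    return means, precs


means = jnp.array([[1.0, -2.0], [0.5, 0.0]])
precs = jnp.array([[[2.0, 0.5], [0.5, 1.0]],
                   [[4.0, 0.0], [0.0, 0.25]]])
nat1, nat2 = site_naturals_from_moments(means, precs)
```

Under this convention a valid site has negative-definite nat2, which is why the precision is recovered as \(-2 \cdot\) nat2.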
Attributes:
| Name | Type | Description |
|---|---|---|
| nat1 | Float[Array, 'N d'] | Natural location parameters, shape |
| nat2 | Float[Array, 'N d d'] | Natural precision parameters, shape |
Source code in src/gaussx/_ssm/_cvi.py
cvi_update_sites(sites: GaussianSites, grad_nat1: Float[Array, 'N d'], grad_nat2: Float[Array, 'N d d'], rho: float) -> GaussianSites
Natural gradient update for CVI sites.
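Performs a damped update in natural parameter space:
\[
\theta \leftarrow (1 - \rho)\,\theta + \rho\,\nabla
\]
where \(\theta\) = (nat1, nat2) are the current site parameters and \(\nabla\) = (grad_nat1, grad_nat2) is the supplied natural gradient.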
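The damped step is a convex combination applied leaf by leaf. A minimal sketch on a `(nat1, nat2)` tuple, using `jax.tree_util.tree_map`; the function name is hypothetical and it does not construct or return a GaussianSites object.

```python
import jax
import jax.numpy as jnp


def damped_natural_update(nat, grad, rho):
    """Hypothetical sketch of theta <- (1 - rho) * theta + rho * grad on a pytree.

    nat, grad: matching pytrees, e.g. (nat1, nat2) tuples of arrays.
    rho:       step size; rho = 1 replaces the sites, rho = 0 keeps them.
    """
    return jax.tree_util.tree_map(lambda t, g: (1.0 - rho) * t + rho * g, nat, grad)


eye = jnp.broadcast_to(jnp.eye(2), (3, 2, 2))
nat = (jnp.zeros((3, 2)), -0.5 * eye)   # current (nat1, nat2)
grad = (jnp.ones((3, 2)), -1.0 * eye)   # natural-gradient target
half = damped_natural_update(nat, grad, 0.5)
```

For \(\rho \in [0, 1]\) the update mixes two negative-definite nat2 values, so the result stays negative definite and the sites remain valid Gaussians.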
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sites | GaussianSites | Current Gaussian sites. | required |
| grad_nat1 | Float[Array, 'N d'] | Natural gradient for location, shape | required |
| grad_nat2 | Float[Array, 'N d d'] | Natural gradient for precision, shape | required |
| rho | float | Step size / damping factor in | required |

Returns:
| Type | Description |
|---|---|
| GaussianSites | Updated |
Source code in src/gaussx/_ssm/_cvi.py
sites_to_precision(sites: GaussianSites) -> BlockTriDiag
Convert Gaussian sites to a block-tridiagonal precision.
Returns a block-diagonal BlockTriDiag (zero
sub-diagonals) representing the precision contribution of the
sites. This can be added to a prior precision via .add()
or + to form the posterior precision:
\[
\Lambda_{post} = \Lambda_{prior} + \Lambda_{sites}
\]
Since nat2 stores \(-\tfrac{1}{2}\Lambda\), the precision
blocks are \(-2 \cdot \text{nat2}\).
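A dense analogue of this conversion, for intuition: scale nat2 by \(-2\), place the blocks on the diagonal, and add them to the prior. The helper name is hypothetical, and it assumes the prior precision is available as a dense `(N*d, N*d)` matrix instead of using BlockTriDiag and its .add().

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


def dense_posterior_precision(prior_prec, nat2):
    """Hypothetical dense analogue of prior + sites_to_precision(sites).

    prior_prec: (N*d, N*d) dense prior precision (block tridiagonal in practice)
    nat2:       (N, d, d) site natural parameters, nat2 = -1/2 * Lambda_site
    """
    site_blocks = -2.0 * nat2             # per-site precision blocks Lambda_site
    site_prec = block_diag(*site_blocks)  # block diagonal: zero off-diagonal blocks
    return prior_prec + site_prec


N, d = 3, 2
prior_prec = 2.0 * jnp.eye(N * d)
nat2 = -0.5 * jnp.stack([(k + 1.0) * jnp.eye(d) for k in range(N)])
post = dense_posterior_precision(prior_prec, nat2)
```

Because the site term has no off-diagonal blocks, the posterior keeps exactly the sparsity pattern of the prior.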
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sites | GaussianSites | Gaussian sites with | required |

Returns:
| Type | Description |
|---|---|
| BlockTriDiag | Block-diagonal |
Source code in src/gaussx/_ssm/_cvi.py
cavity_from_marginal(marg_mean: Float[Array, ' *batch'], marg_var: Float[Array, ' *batch'], site_nat1: Float[Array, ' *batch'], site_nat2: Float[Array, ' *batch']) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]
Compute cavity distribution by removing a site from the marginal.
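Going by the parameter table, site_nat1 is a precision-weighted mean and site_nat2 a positive precision, so these elementwise helpers appear to use the plain precision rather than the \(-\tfrac{1}{2}\Lambda\) scaling of GaussianSites. Under that assumption, removing a site means subtracting natural parameters. The function below is a hypothetical sketch, not the library's code, and it includes no guard against a non-positive cavity precision.

```python
import jax.numpy as jnp


def cavity_sketch(marg_mean, marg_var, site_nat1, site_nat2):
    """Hypothetical elementwise cavity: divide the site out of the marginal.

    cavity precision               = 1 / marg_var - site_nat2
    cavity precision-weighted mean = marg_mean / marg_var - site_nat1
    """
    cav_prec = 1.0 / marg_var - site_nat2  # can be <= 0 for an overconfident site
    cav_var = 1.0 / cav_prec
    cav_mean = cav_var * (marg_mean / marg_var - site_nat1)
    return cav_mean, cav_var


cav_mean, cav_var = cavity_sketch(jnp.array(1.0), jnp.array(0.5),
                                  jnp.array(0.5), jnp.array(1.0))
# cav_mean -> 1.5, cav_var -> 1.0
```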
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| marg_mean | Float[Array, ' *batch'] | Marginal distribution means. | required |
| marg_var | Float[Array, ' *batch'] | Marginal distribution variances (positive). | required |
| site_nat1 | Float[Array, ' *batch'] | Site precision-weighted means to remove. | required |
| site_nat2 | Float[Array, ' *batch'] | Site precisions to remove. | required |

Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' *batch'], Float[Array, ' *batch']] | Tuple |
Source code in src/gaussx/_ssm/_site_natural.py
site_natural_from_tilted(tilted_mean: Float[Array, ' *batch'], tilted_var: Float[Array, ' *batch'], cav_mean: Float[Array, ' *batch'], cav_var: Float[Array, ' *batch']) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]
Compute site natural parameters from tilted and cavity moments.
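In expectation propagation, the new site is the ratio of the (moment-matched) tilted distribution to the cavity, so its natural parameters are differences. A hypothetical elementwise sketch, assuming the same (precision-weighted mean, precision) site convention as the parameter tables in this section; the numbers are chosen so the site comes out as precision 1 and precision-weighted mean 0.5.

```python
import jax.numpy as jnp


def site_from_tilted_sketch(tilted_mean, tilted_var, cav_mean, cav_var):
    """Hypothetical sketch: site naturals = tilted naturals - cavity naturals."""
    site_nat2 = 1.0 / tilted_var - 1.0 / cav_var                    # precision
    site_nat1 = tilted_mean / tilted_var - cav_mean / cav_var       # precision * mean
    return site_nat1, site_nat2


site_nat1, site_nat2 = site_from_tilted_sketch(jnp.array(1.0), jnp.array(0.5),
                                               jnp.array(1.5), jnp.array(1.0))
# site_nat1 -> 0.5, site_nat2 -> 1.0
```

If the tilted variance exceeds the cavity variance, site_nat2 becomes negative; EP implementations commonly damp or clip in that case.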
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tilted_mean | Float[Array, ' *batch'] | Tilted distribution means. | required |
| tilted_var | Float[Array, ' *batch'] | Tilted distribution variances (positive). | required |
| cav_mean | Float[Array, ' *batch'] | Cavity distribution means. | required |
| cav_var | Float[Array, ' *batch'] | Cavity distribution variances (positive). | required |

Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' *batch'], Float[Array, ' *batch']] | Tuple |
Source code in src/gaussx/_ssm/_site_natural.py
site_mean_var_from_natural(site_nat1: Float[Array, ' *batch'], site_nat2: Float[Array, ' *batch']) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]
Convert per-site natural parameters to mean/variance.
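With site_nat1 a precision-weighted mean and site_nat2 a positive precision (as the parameter table describes), the conversion is two elementwise divisions. A hypothetical sketch:

```python
import jax.numpy as jnp


def site_mean_var_sketch(site_nat1, site_nat2):
    """Hypothetical sketch: (precision * mean, precision) -> (mean, variance)."""
    site_var = 1.0 / site_nat2          # only meaningful when site_nat2 > 0
    site_mean = site_nat1 * site_var    # mean = (precision * mean) / precision
    return site_mean, site_var


m, v = site_mean_var_sketch(jnp.array([0.5, 3.0]), jnp.array([1.0, 2.0]))
# m -> [0.5, 1.5], v -> [1.0, 0.5]
```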
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| site_nat1 | Float[Array, ' *batch'] | Site precision-weighted means. | required |
| site_nat2 | Float[Array, ' *batch'] | Site precisions (positive for valid Gaussians). | required |

Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' *batch'], Float[Array, ' *batch']] | Tuple |
Source code in src/gaussx/_ssm/_site_natural.py
expectations_to_ssm(eta1: Float[Array, ' Nd'], eta2: BlockTriDiag) -> tuple[Float[Array, 'N d'], Float[Array, 'N d d'], Float[Array, 'Nm1 d d']]
Convert expectation parameters back to SSM marginals.
Recovers (means, covs, cross_covs) from the expectation
parameters of the joint Gaussian.
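Recovering the marginals amounts to subtracting outer products of the means from the stored second moments. The sketch below is hypothetical: rather than a BlockTriDiag it takes the diagonal blocks \(E[x_k x_k^T]\) and sub-diagonal blocks \(E[x_{k+1} x_k^T]\) as plain arrays.

```python
import jax.numpy as jnp


def moments_from_expectation_blocks(eta1, diag_blocks, sub_blocks):
    """Hypothetical dense-block sketch of expectations -> (means, covs, cross_covs).

    eta1:        (N*d,)      concatenated means E[x]
    diag_blocks: (N, d, d)   E[x_k x_k^T]
    sub_blocks:  (N-1, d, d) E[x_{k+1} x_k^T]
    """
    N, d, _ = diag_blocks.shape
    means = eta1.reshape(N, d)
    outer = jnp.einsum("ni,nj->nij", means, means)                 # m_k m_k^T
    covs = diag_blocks - outer                                     # P_k
    cross_outer = jnp.einsum("ni,nj->nij", means[1:], means[:-1])  # m_{k+1} m_k^T
    cross_covs = sub_blocks - cross_outer                          # Cov(x_{k+1}, x_k)
    return means, covs, cross_covs


eta1 = jnp.array([1.0, 2.0])                   # N = 2 steps, d = 1
diag_blocks = jnp.array([[[2.0]], [[6.0]]])
sub_blocks = jnp.array([[[2.5]]])
means, covs, cross_covs = moments_from_expectation_blocks(eta1, diag_blocks, sub_blocks)
# means -> [[1], [2]], covs -> [[[1]], [[2]]], cross_covs -> [[[0.5]]]
```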
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| eta1 | Float[Array, ' Nd'] | Concatenated means, shape | required |
| eta2 | BlockTriDiag | Second-moment | required |

Returns:
| Type | Description |
|---|---|
| Float[Array, 'N d'] | Tuple |
| Float[Array, 'N d d'] | |
| Float[Array, 'Nm1 d d'] | |
| tuple[Float[Array, 'N d'], Float[Array, 'N d d'], Float[Array, 'Nm1 d d']] | |
Source code in src/gaussx/_ssm/_ssm_natural.py
naturals_to_ssm(theta_linear: Float[Array, ' Nd'], theta_precision: BlockTriDiag, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'Nm1 d d'], Float[Array, 'N d d'], Float[Array, ' d'], Float[Array, 'd d']]
Convert natural parameters back to SSM parameters.
Recovers \((A, Q, \mu_0, P_0)\) from the block-tridiagonal natural
parameters via a backward recurrence on the precision blocks.
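One way such a backward recurrence can work, sketched under stated assumptions (this is not the gaussx implementation). For the chain model, the last diagonal precision block is the inverse of the last transition noise, each sub-diagonal block equals \(-Q_k^{-1} A_k\), and each diagonal block carries an extra \(A_k^T Q_k^{-1} A_k\) term; peeling these off from the end yields every \(A_k\) and \(Q_k\), and whatever remains of block 0 is \(P_0^{-1}\). The hypothetical helper takes precision blocks \(\Lambda = -2\theta_2\) as plain arrays, assumes \(N \ge 2\) and no transition offsets (so \(\theta_1\) is nonzero only in block 0), and returns N-1 transition noises, unlike the (N, d, d) Q in the signature above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ssm_from_precision_blocks(theta1, diag, sub):
    """Hypothetical backward recurrence: precision blocks -> (A, Q, mu0, P0).

    theta1: (N*d,)      natural location, P0^{-1} mu0 in block 0, zeros elsewhere
    diag:   (N, d, d)   diagonal blocks of Lambda (= -2 * diagonal blocks of theta_2)
    sub:    (N-1, d, d) sub-diagonal blocks Lambda_{k+1, k}
    Q[k] is taken to be the noise covariance of the step k -> k+1.
    """
    N, d, _ = diag.shape
    A, Q = [None] * (N - 1), [None] * (N - 1)
    q_inv = diag[N - 1]                        # last block is Q[N-2]^{-1} alone
    for k in range(N - 2, -1, -1):
        Q[k] = jnp.linalg.inv(q_inv)
        A[k] = -Q[k] @ sub[k]                  # Lambda_{k+1,k} = -Q_k^{-1} A_k
        # Strip A_k^T Q_k^{-1} A_k from block k; the rest is Q_{k-1}^{-1} (or P0^{-1}).
        q_inv = diag[k] - A[k].T @ q_inv @ A[k]
    P0 = jnp.linalg.inv(q_inv)
    mu0 = P0 @ theta1[:d]                      # undo theta1 block 0 = P0^{-1} mu0
    return jnp.stack(A), jnp.stack(Q), mu0, P0
```

Building the blocks from known \((A, Q, \mu_0, P_0)\) with the forward block formulas and passing them back through this function recovers the inputs up to round-off.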
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| theta_linear | Float[Array, ' Nd'] | Natural location parameter, shape | required |
| theta_precision | BlockTriDiag | Natural precision parameter as | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'Nm1 d d'] | Tuple |
| Float[Array, 'N d d'] | |
| Float[Array, ' d'] | |
| Float[Array, 'd d'] | |
| tuple[Float[Array, 'Nm1 d d'], Float[Array, 'N d d'], Float[Array, ' d'], Float[Array, 'd d']] | |
Source code in src/gaussx/_ssm/_ssm_natural.py
ssm_to_expectations(means: Float[Array, 'N d'], covs: Float[Array, 'N d d'], cross_covs: Float[Array, 'Nm1 d d']) -> tuple[Float[Array, ' Nd'], BlockTriDiag]
Convert SSM marginals to expectation parameters.
Given filtered or smoothed marginals, computes the expectation
parameters (eta1, eta2) of the joint Gaussian where:
- eta1 = \(E[x]\) (concatenated means)
- eta2 is a BlockTriDiag storing the block-tridiagonal subset of \(E[xx^T]\) (second moments matching the Gauss-Markov sparsity pattern, not the full dense matrix)

The diagonal blocks of eta2 are \(E[x_k x_k^T] = P_k + m_k m_k^T\)
and the sub-diagonal blocks are
\(E[x_{k+1} x_k^T] = C_k + m_{k+1} m_k^T\), where \(C_k\) is the
cross-covariance \(\mathrm{Cov}(x_{k+1}, x_k)\).
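The block formulas above translate directly into batched outer products. A hypothetical sketch that returns the eta2 blocks as plain arrays rather than assembling a BlockTriDiag:

```python
import jax.numpy as jnp


def expectation_blocks_from_moments(means, covs, cross_covs):
    """Hypothetical sketch of marginals -> (eta1, diagonal blocks, sub-diagonal blocks).

    means: (N, d); covs: (N, d, d); cross_covs: (N-1, d, d) with C_k = Cov(x_{k+1}, x_k).
    """
    eta1 = means.reshape(-1)                                                  # E[x]
    diag_blocks = covs + jnp.einsum("ni,nj->nij", means, means)               # P_k + m_k m_k^T
    sub_blocks = cross_covs + jnp.einsum("ni,nj->nij", means[1:], means[:-1])  # C_k + m_{k+1} m_k^T
    return eta1, diag_blocks, sub_blocks


means = jnp.array([[1.0], [2.0]])              # N = 2 steps, d = 1
covs = jnp.array([[[1.0]], [[2.0]]])
cross_covs = jnp.array([[[0.5]]])
eta1, diag_blocks, sub_blocks = expectation_blocks_from_moments(means, covs, cross_covs)
# eta1 -> [1, 2], diag_blocks -> [[[2]], [[6]]], sub_blocks -> [[[2.5]]]
```

Only N diagonal and N-1 sub-diagonal blocks are formed, so memory stays \(O(N d^2)\) instead of \(O((Nd)^2)\) for the dense second-moment matrix.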
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| means | Float[Array, 'N d'] | Marginal means, shape | required |
| covs | Float[Array, 'N d d'] | Marginal covariances, shape | required |
| cross_covs | Float[Array, 'Nm1 d d'] | Cross-covariances | required |

Returns:
| Type | Description |
|---|---|
| Float[Array, ' Nd'] | Tuple |
| BlockTriDiag | and |
Source code in src/gaussx/_ssm/_ssm_natural.py
ssm_to_naturals(A: Float[Array, 'Nm1 d d'], Q: Float[Array, 'N d d'], mu_0: Float[Array, ' d'], P_0: Float[Array, 'd d'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' Nd'], BlockTriDiag]
Convert SSM parameters to natural parameters.
For a linear-Gaussian state-space model:
\[
x_0 \sim \mathcal{N}(\mu_0, P_0), \qquad x_{k+1} = A_k x_k + \epsilon_k,\quad \epsilon_k \sim \mathcal{N}(0, Q_{k+1})
\]
the joint prior \(p(x_0, \ldots, x_{N-1})\) has a block-tridiagonal
precision matrix. This function returns its natural parameters
\((\theta_1, \theta_2)\) where \(\theta_2 = -\tfrac{1}{2}\Lambda\)
(matching the convention in gaussx.mean_cov_to_natural).
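A dense sketch of where the block-tridiagonal structure comes from: each transition \(k \to k+1\) contributes \(A_k^T Q_k^{-1} A_k\) to diagonal block k, \(Q_k^{-1}\) to block k+1, and \(-Q_k^{-1} A_k\) to the off-diagonal pair, while the initial state adds \(P_0^{-1}\). The helper is hypothetical: it builds a dense \((Nd \times Nd)\) matrix instead of a BlockTriDiag and takes only the N-1 transition noises (Q[k] is the covariance of \(\epsilon_k\), which the formula above labels \(Q_{k+1}\)).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def dense_ssm_naturals(A, Q, mu0, P0):
    """Hypothetical dense sketch: (A, Q, mu0, P0) -> (theta1, theta2 = -1/2 Lambda).

    A: (N-1, d, d) transitions; Q: (N-1, d, d), Q[k] = Cov(eps_k);
    mu0: (d,) initial mean; P0: (d, d) initial covariance.
    """
    Nm1, d, _ = A.shape
    N = Nm1 + 1
    Q_inv = jnp.linalg.inv(Q)
    Lam = jnp.zeros((N * d, N * d)).at[:d, :d].add(jnp.linalg.inv(P0))
    for k in range(Nm1):
        i, j = k * d, (k + 1) * d
        off = -Q_inv[k] @ A[k]                                 # Lambda_{k+1,k}
        Lam = Lam.at[i:i + d, i:i + d].add(A[k].T @ Q_inv[k] @ A[k])
        Lam = Lam.at[j:j + d, j:j + d].add(Q_inv[k])
        Lam = Lam.at[j:j + d, i:i + d].add(off)
        Lam = Lam.at[i:i + d, j:j + d].add(off.T)              # symmetric partner
    # Transitions have no offset, so only x_0 has a linear term: P0^{-1} mu0.
    theta1 = jnp.zeros(N * d).at[:d].set(jnp.linalg.solve(P0, mu0))
    return theta1, -0.5 * Lam
```

Equivalently, writing \(e = Tx\) with \(T\) block lower-bidiagonal (identity on the diagonal, \(-A_k\) below it) gives \(\Lambda = T^T D^{-1} T\) with \(D = \mathrm{blockdiag}(P_0, Q_0, \ldots)\); the loop simply expands that product block by block.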
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Float[Array, 'Nm1 d d'] | Transition matrices, shape | required |
| Q | Float[Array, 'N d d'] | Process noise covariances, shape | required |
| mu_0 | Float[Array, ' d'] | Initial mean, shape | required |
| P_0 | Float[Array, 'd d'] | Initial covariance, shape | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' Nd'] | Tuple |
| BlockTriDiag | |
| tuple[Float[Array, ' Nd'], BlockTriDiag] | in the |
Source code in src/gaussx/_ssm/_ssm_natural.py
Process noise
process_noise_covariance(A: Float[Array, 'N N'], Pinf: Float[Array, 'N N']) -> Float[Array, 'N N']
Compute process noise from stationary covariance.
Computes
Q = Pinf - A @ Pinf @ A^T
for a discrete-time state-space model with stationary covariance
Pinf and transition matrix A.
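A short illustration using two standard SDE kernels. For Matérn-1/2 (Ornstein-Uhlenbeck), \(A = e^{-\Delta t/\ell}\) and the formula reduces to \(Q = \sigma^2(1 - e^{-2\Delta t/\ell})\). For Matérn-3/2, the sketch assumes the usual companion-form drift with \(\lambda = \sqrt{3}/\ell\) and \(P_\infty = \mathrm{diag}(\sigma^2, \lambda^2\sigma^2)\), discretised with `jax.scipy.linalg.expm`. The helper name is hypothetical; it repeats the formula rather than calling process_noise_covariance.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def process_noise_sketch(A, Pinf):
    """Q = Pinf - A Pinf A^T, the stationary-covariance identity above."""
    return Pinf - A @ Pinf @ A.T


ell, sigma2, dt = 0.7, 1.5, 0.1

# Matern-1/2 (OU): scalar state, A = exp(-dt / ell).
A_ou = jnp.exp(-dt / ell) * jnp.eye(1)
Q_ou = process_noise_sketch(A_ou, sigma2 * jnp.eye(1))

# Matern-3/2: state (f, f'), drift F in companion form, A = expm(F dt).
lam = jnp.sqrt(3.0) / ell
F = jnp.array([[0.0, 1.0], [-lam**2, -2.0 * lam]])
Pinf = jnp.diag(jnp.array([sigma2, lam**2 * sigma2]))
A_m32 = expm(F * dt)
Q_m32 = process_noise_sketch(A_m32, Pinf)  # symmetric positive definite for dt > 0
```

The identity holds only when Pinf is the stationary covariance of the SDE, which is why it pairs naturally with the P_inf returned by sde_params.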
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Float[Array, 'N N'] | State transition matrix, shape | required |
| Pinf | Float[Array, 'N N'] | Stationary covariance, shape | required |

Returns:
| Type | Description |
|---|---|
| Float[Array, 'N N'] | Process noise covariance Q, shape |