State-Space Models & Kalman

Layer 3 recipes for linear-Gaussian state-space models. Stationary 1-D GP kernels with rational spectral densities admit exact SDE representations:

\[ \dot{x}(t) = F\,x(t) + L\,w(t), \qquad f(t) = H\,x(t), \]

turning \(O(N^3)\) GP inference into \(O(N d^3)\) Kalman filtering, where \(N\) is the number of time points and \(d\) is the state dimension. This page covers the SDE kernel zoo, the filters and smoothers (sequential, parallel associative-scan, square-root, and steady-state), and the natural-parameter / site machinery for non-conjugate likelihoods.
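
As a rough illustration of the linear-in-N cost, the sketch below runs a scalar Kalman filter for the Matern-1/2 case in plain standard-library Python. It is not the gaussx API: the function name `ou_kalman_loglik` and its arguments are made up for this example, and it assumes sorted time points, Gaussian observation noise, and the closed-form discretisation A = exp(-dt / lengthscale), Q = variance * (1 - A^2).

```python
import math


def ou_kalman_loglik(ts, ys, variance, lengthscale, noise_var):
    """Marginal log-likelihood of y = f + noise, f ~ GP(0, Matern-1/2), in O(N)."""
    m, P = 0.0, variance  # stationary prior on the state at the first time point
    loglik, t_prev = 0.0, ts[0]
    for t, y in zip(ts, ys):
        # Predict: x_k = A x_{k-1} + q, with q ~ N(0, Q).
        A = math.exp(-(t - t_prev) / lengthscale)
        Q = variance * (1.0 - A * A)  # scalar form of Q = P_inf - A P_inf A^T
        m, P = A * m, A * P * A + Q
        # Update with H = 1: innovation v and innovation variance S.
        S = P + noise_var
        v = y - m
        loglik += -0.5 * (math.log(2.0 * math.pi * S) + v * v / S)
        K = P / S
        m, P = m + K * v, (1.0 - K) * P
        t_prev = t
    return loglik


ts = [0.0, 0.7, 1.5]
ys = [0.3, -0.1, 0.4]
ll = ou_kalman_loglik(ts, ys, variance=1.0, lengthscale=0.5, noise_var=0.1)
```

Each time step costs a fixed number of scalar operations here; for a d-dimensional state the same recursion uses d x d matrix products, which is where the \(d^3\) factor comes from.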

SDE kernels

Structured linear algebra and Gaussian primitives for JAX.

SDEKernel

Bases: Module

Abstract base class for state-space kernel representations.

Subclasses implement sde_params to provide the continuous-time SDE matrices (F, L, H, Q_c, P_inf). The default discretise computes the transition matrix with a matrix exponential; subclasses may override it with closed-form solutions.

Source code in src/gaussx/_ssm/_sde_kernel.py
class SDEKernel(eqx.Module):
    """Abstract base class for state-space kernel representations.

    Subclasses implement `sde_params` to provide the continuous-time
    SDE matrices ``(F, L, H, Q_c, P_inf)``. The default `discretise`
    uses the matrix exponential for discretization; subclasses may override
    with closed-form solutions.
    """

    @property
    @abc.abstractmethod
    def state_dim(self) -> int:
        """Dimension of the latent state vector."""
        ...

    @abc.abstractmethod
    def sde_params(self) -> SDEParams:
        """Return continuous-time SDE parameters."""
        ...

    def discretise(
        self,
        dt: Float[Array, ""],
    ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
        """Discretise the SDE at time step ``dt``.

        Default implementation computes:

            A = expm(F * dt)
            Q = P_inf - A @ P_inf @ A^T

        Subclasses may override with closed-form expressions.

        Args:
            dt: Time step (scalar, positive).

        Returns:
            Tuple ``(A, Q)`` where A is the transition matrix and
            Q is the process noise covariance.
        """
        params = self.sde_params()
        A = jsl.expm(params.F * dt)
        Q = params.P_inf - A @ params.P_inf @ A.T
        Q = symmetrize(Q)
        return A, Q

    def discretise_sequence(
        self,
        dt: Float[Array, " N"],
    ) -> tuple[Float[Array, "N d d"], Float[Array, "N d d"]]:
        """Discretise the SDE at multiple time steps.

        Args:
            dt: Time steps, shape ``(N,)``.

        Returns:
            Tuple ``(A_seq, Q_seq)`` with shapes ``(N, d, d)``.
        """
        return jax.vmap(self.discretise)(dt)

state_dim: int abstractmethod property

Dimension of the latent state vector.

sde_params() -> SDEParams abstractmethod

Return continuous-time SDE parameters.

Source code in src/gaussx/_ssm/_sde_kernel.py
@abc.abstractmethod
def sde_params(self) -> SDEParams:
    """Return continuous-time SDE parameters."""
    ...

discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Discretise the SDE at time step dt.

Default implementation computes:

A = expm(F * dt)
Q = P_inf - A @ P_inf @ A^T

Subclasses may override with closed-form expressions.
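
One way to see where the default formula for Q comes from, assuming that P_inf satisfies the stationary Lyapunov equation \(F P_\infty + P_\infty F^T + L Q_c L^T = 0\) (an assumption about how sde_params is populated; the method itself does not check it), and writing \(\Delta t\) for dt:

\[ Q = \int_0^{\Delta t} e^{F s}\, L Q_c L^T\, e^{F^T s}\, ds = -\int_0^{\Delta t} \frac{d}{ds}\left( e^{F s} P_\infty e^{F^T s} \right) ds = P_\infty - A P_\infty A^T, \qquad A = e^{F \Delta t}. \]

The middle step uses \(\frac{d}{ds}\left(e^{Fs} P_\infty e^{F^T s}\right) = e^{Fs}\left(F P_\infty + P_\infty F^T\right) e^{F^T s} = -e^{Fs} L Q_c L^T e^{F^T s}\).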

Parameters:

Name Type Description Default
dt Float[Array, '']

Time step (scalar, positive).

required

Returns:

Type Description
tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Tuple (A, Q) where A is the transition matrix and Q is the process noise covariance.

Source code in src/gaussx/_ssm/_sde_kernel.py
def discretise(
    self,
    dt: Float[Array, ""],
) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
    """Discretise the SDE at time step ``dt``.

    Default implementation computes:

        A = expm(F * dt)
        Q = P_inf - A @ P_inf @ A^T

    Subclasses may override with closed-form expressions.

    Args:
        dt: Time step (scalar, positive).

    Returns:
        Tuple ``(A, Q)`` where A is the transition matrix and
        Q is the process noise covariance.
    """
    params = self.sde_params()
    A = jsl.expm(params.F * dt)
    Q = params.P_inf - A @ params.P_inf @ A.T
    Q = symmetrize(Q)
    return A, Q

discretise_sequence(dt: Float[Array, ' N']) -> tuple[Float[Array, 'N d d'], Float[Array, 'N d d']]

Discretise the SDE at multiple time steps.

Parameters:

Name Type Description Default
dt Float[Array, ' N']

Time steps, shape (N,).

required

Returns:

Type Description
tuple[Float[Array, 'N d d'], Float[Array, 'N d d']]

Tuple (A_seq, Q_seq) with shapes (N, d, d).

Source code in src/gaussx/_ssm/_sde_kernel.py
def discretise_sequence(
    self,
    dt: Float[Array, " N"],
) -> tuple[Float[Array, "N d d"], Float[Array, "N d d"]]:
    """Discretise the SDE at multiple time steps.

    Args:
        dt: Time steps, shape ``(N,)``.

    Returns:
        Tuple ``(A_seq, Q_seq)`` with shapes ``(N, d, d)``.
    """
    return jax.vmap(self.discretise)(dt)

SDEParams

Bases: NamedTuple

Continuous-time SDE parameters for a stationary kernel.

Defines the linear time-invariant SDE:

dx = F x dt + L dW,   dW ~ N(0, Q_c dt)

with observation model y = H x.

Attributes:

Name Type Description
F Float[Array, 'd d']

Drift matrix, shape (d, d).

L Float[Array, 'd s']

Diffusion matrix, shape (d, s).

H Float[Array, '1 d']

Observation matrix, shape (1, d).

Q_c Float[Array, 's s']

Spectral density, shape (s, s).

P_inf Float[Array, 'd d']

Stationary covariance, shape (d, d).
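
The five fields are not independent: for a stationary model, P_inf is expected to solve the Lyapunov equation F P_inf + P_inf F^T + L Q_c L^T = 0. The sketch below checks this in plain standard-library Python for the Matern-3/2 values given in the MaternSDE source listing on this page. The helper functions and the chosen hyperparameter values are illustrative only and are not part of gaussx.

```python
import math


def matmul(A, B):
    """Product of two matrices stored as nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


variance, lengthscale = 1.5, 0.8  # arbitrary example values
lam = math.sqrt(3.0) / lengthscale

# Matern-3/2 fields, as written in MaternSDE._matern32.
F = [[0.0, 1.0], [-(lam**2), -2.0 * lam]]
L = [[0.0], [1.0]]
Q_c = [[4.0 * lam**3 * variance]]
P_inf = [[variance, 0.0], [0.0, lam**2 * variance]]

FP = matmul(F, P_inf)
PFt = matmul(P_inf, transpose(F))
LQLt = matmul(matmul(L, Q_c), transpose(L))

# Entry-wise residual of F P_inf + P_inf F^T + L Q_c L^T, which should vanish.
residual = [[FP[i][j] + PFt[i][j] + LQLt[i][j] for j in range(2)] for i in range(2)]
max_residual = max(abs(x) for row in residual for x in row)
```

A residual at round-off level indicates that F, L, Q_c and P_inf describe the same stationary process.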

Source code in src/gaussx/_ssm/_sde_kernel.py
class SDEParams(NamedTuple):
    """Continuous-time SDE parameters for a stationary kernel.

    Defines the linear time-invariant SDE:

        dx = F x dt + L dW,   dW ~ N(0, Q_c dt)

    with observation model ``y = H x``.

    Attributes:
        F: Drift matrix, shape ``(d, d)``.
        L: Diffusion matrix, shape ``(d, s)``.
        H: Observation matrix, shape ``(1, d)``.
        Q_c: Spectral density, shape ``(s, s)``.
        P_inf: Stationary covariance, shape ``(d, d)``.
    """

    F: Float[Array, "d d"]
    L: Float[Array, "d s"]
    H: Float[Array, "1 d"]
    Q_c: Float[Array, "s s"]
    P_inf: Float[Array, "d d"]

ConstantSDE

Bases: SDEKernel

State-space representation of a constant kernel.

Models \(k(\tau) = \sigma^2\) — a degenerate kernel with zero dynamics and zero diffusion. State dimension is 1.

Attributes:

Name Type Description
variance Float[Array, '']

Signal variance \(\sigma^2\).

Source code in src/gaussx/_ssm/_constant.py
class ConstantSDE(SDEKernel):
    r"""State-space representation of a constant kernel.

    Models $k(\tau) = \sigma^2$ — a degenerate kernel with zero
    dynamics and zero diffusion. State dimension is 1.

    Attributes:
        variance: Signal variance $\sigma^2$.
    """

    variance: Float[Array, ""]

    @property
    def state_dim(self) -> int:
        return 1

    def sde_params(self) -> SDEParams:
        """Return SDE parameters for the constant kernel."""
        F = jnp.zeros((1, 1))
        L = jnp.zeros((1, 1))
        H = jnp.array([[1.0]])
        Q_c = jnp.zeros((1, 1))
        P_inf = jnp.array([[self.variance]])
        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

    def discretise(
        self,
        dt: Float[Array, ""],
    ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
        """Closed-form: A = I, Q = 0 (no dynamics)."""
        A = jnp.eye(1)
        Q = jnp.zeros((1, 1))
        return A, Q

sde_params() -> SDEParams

Return SDE parameters for the constant kernel.

Source code in src/gaussx/_ssm/_constant.py
def sde_params(self) -> SDEParams:
    """Return SDE parameters for the constant kernel."""
    F = jnp.zeros((1, 1))
    L = jnp.zeros((1, 1))
    H = jnp.array([[1.0]])
    Q_c = jnp.zeros((1, 1))
    P_inf = jnp.array([[self.variance]])
    return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Closed-form: A = I, Q = 0 (no dynamics).

Source code in src/gaussx/_ssm/_constant.py
def discretise(
    self,
    dt: Float[Array, ""],
) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
    """Closed-form: A = I, Q = 0 (no dynamics)."""
    A = jnp.eye(1)
    Q = jnp.zeros((1, 1))
    return A, Q

MaternSDE

Bases: SDEKernel

State-space representation of the Matern kernel.

Supports orders 0 (Matern-1/2), 1 (Matern-3/2), and 2 (Matern-5/2). The state dimension is order + 1.

Attributes:

Name Type Description
variance Float[Array, '']

Signal variance \(\sigma^2\).

lengthscale Float[Array, '']

Lengthscale \(\ell\).

order int

Matern order (0, 1, or 2).
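
The three supported orders follow one pattern: with p = order and smoothness nu = p + 1/2, the rate is lam = sqrt(2 nu) / lengthscale and the white-noise spectral density is q = 2 sigma^2 sqrt(pi) lam^(2p+1) Gamma(p+1) / Gamma(p+1/2). The plain-Python sketch below evaluates that general formula so it can be compared with the hard-coded coefficients 2, 4 and 16/3 in the source listing. The function name `matern_rate_and_density` is hypothetical and not part of gaussx.

```python
import math


def matern_rate_and_density(order, variance, lengthscale):
    """Return (lam, q) for the Matern kernel with smoothness nu = order + 1/2."""
    p = order
    lam = math.sqrt(2.0 * p + 1.0) / lengthscale
    q = (
        2.0 * variance * math.sqrt(math.pi) * lam ** (2 * p + 1)
        * math.gamma(p + 1.0) / math.gamma(p + 0.5)
    )
    return lam, q


variance, lengthscale = 1.3, 0.7  # arbitrary example values
table = {order: matern_rate_and_density(order, variance, lengthscale) for order in (0, 1, 2)}

# Coefficients that multiply lam^(2p+1) * variance in the source listing.
source_coeffs = {0: 2.0, 1: 4.0, 2: 16.0 / 3.0}
```

For order 0 this reduces to q = 2 lam sigma^2, the Ornstein-Uhlenbeck case.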

Source code in src/gaussx/_ssm/_matern.py
class MaternSDE(SDEKernel):
    r"""State-space representation of the Matern kernel.

    Supports orders 0 (Matern-1/2), 1 (Matern-3/2), and 2 (Matern-5/2).
    The state dimension is ``order + 1``.

    Attributes:
        variance: Signal variance $\sigma^2$.
        lengthscale: Lengthscale $\ell$.
        order: Matern order (0, 1, or 2).
    """

    variance: Float[Array, ""]
    lengthscale: Float[Array, ""]
    order: int = eqx.field(static=True)

    @property
    def state_dim(self) -> int:
        return self.order + 1

    def sde_params(self) -> SDEParams:
        """Compute SDE parameters for the Matern kernel."""
        if self.order == 0:
            return self._matern12()
        elif self.order == 1:
            return self._matern32()
        elif self.order == 2:
            return self._matern52()
        else:
            msg = f"Unsupported Matern order {self.order}; must be 0, 1, or 2"
            raise ValueError(msg)

    def _matern12(self) -> SDEParams:
        lam = 1.0 / self.lengthscale
        F = jnp.array([[-lam]])
        L = jnp.array([[1.0]])
        H = jnp.array([[1.0]])
        Q_c = jnp.array([[2.0 * lam * self.variance]])
        P_inf = jnp.array([[self.variance]])
        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

    def _matern32(self) -> SDEParams:
        lam = jnp.sqrt(3.0) / self.lengthscale
        F = jnp.array([[0.0, 1.0], [-(lam**2), -2.0 * lam]])
        L = jnp.array([[0.0], [1.0]])
        H = jnp.array([[1.0, 0.0]])
        q = 4.0 * lam**3 * self.variance
        Q_c = jnp.array([[q]])
        P_inf = jnp.array([[self.variance, 0.0], [0.0, lam**2 * self.variance]])
        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

    def _matern52(self) -> SDEParams:
        lam = jnp.sqrt(5.0) / self.lengthscale
        F = jnp.array(
            [
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [-(lam**3), -3.0 * lam**2, -3.0 * lam],
            ]
        )
        L = jnp.array([[0.0], [0.0], [1.0]])
        H = jnp.array([[1.0, 0.0, 0.0]])
        kappa = 5.0 / 3.0 * self.variance / self.lengthscale**2
        q = 16.0 / 3.0 * lam**5 * self.variance
        Q_c = jnp.array([[q]])
        P_inf = jnp.array(
            [
                [self.variance, 0.0, -kappa],
                [0.0, kappa, 0.0],
                [-kappa, 0.0, lam**4 * self.variance],
            ]
        )
        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

sde_params() -> SDEParams

Compute SDE parameters for the Matern kernel.

Source code in src/gaussx/_ssm/_matern.py
def sde_params(self) -> SDEParams:
    """Compute SDE parameters for the Matern kernel."""
    if self.order == 0:
        return self._matern12()
    elif self.order == 1:
        return self._matern32()
    elif self.order == 2:
        return self._matern52()
    else:
        msg = f"Unsupported Matern order {self.order}; must be 0, 1, or 2"
        raise ValueError(msg)

PeriodicSDE

Bases: SDEKernel

State-space representation of the periodic (MacKay) kernel.

Approximates the periodic kernel via Fourier series truncation to n_harmonics terms. State dimension is 2 * n_harmonics.

Attributes:

Name Type Description
variance Float[Array, '']

Signal variance \(\sigma^2\).

lengthscale Float[Array, '']

Lengthscale \(\ell\).

period Float[Array, '']

Period \(T\).

n_harmonics int

Number of Fourier harmonics (truncation order).
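
To give a feel for the truncation, the sketch below computes the normalised Fourier coefficients of the periodic kernel in plain standard-library Python and compares the truncated series with the closed-form kernel sigma^2 exp(-2 sin^2(pi tau / T) / ell^2). The function names are made up for this example. Note that the source listing builds harmonics j = 1..J only; the constant coefficient q_0 is included here solely so that the series can be checked against the closed form, and how gaussx intends the constant term to be handled is not stated on this page.

```python
import math


def bessel_i(n, x, terms=30):
    """Modified Bessel function I_n(x) for integer n >= 0, via its power series."""
    return sum(
        (x / 2.0) ** (2 * k + n) / (math.factorial(k) * math.factorial(k + n))
        for k in range(terms)
    )


def periodic_coeffs(lengthscale, n_harmonics):
    """Coefficients q_0 and [q_1, ..., q_J] of exp((cos(theta) - 1) / ell^2)."""
    x = 1.0 / lengthscale**2
    q0 = bessel_i(0, x) * math.exp(-x)
    qs = [2.0 * bessel_i(j, x) * math.exp(-x) for j in range(1, n_harmonics + 1)]
    return q0, qs


def periodic_kernel(tau, variance, lengthscale, period):
    return variance * math.exp(-2.0 * math.sin(math.pi * tau / period) ** 2 / lengthscale**2)


variance, lengthscale, period, J = 1.0, 1.0, 2.0, 6
q0, qs = periodic_coeffs(lengthscale, J)
w0 = 2.0 * math.pi / period
tau = 0.3

# Truncated cosine series: constant term plus J harmonics.
approx = variance * (q0 + sum(q * math.cos((j + 1) * w0 * tau) for j, q in enumerate(qs)))
exact = periodic_kernel(tau, variance, lengthscale, period)
captured = q0 + sum(qs)  # fraction of the variance captured at tau = 0
```

Shorter lengthscales put more mass in higher harmonics, so they need a larger n_harmonics for the same accuracy.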

Source code in src/gaussx/_ssm/_periodic.py
class PeriodicSDE(SDEKernel):
    r"""State-space representation of the periodic (MacKay) kernel.

    Approximates the periodic kernel via Fourier series truncation
    to ``n_harmonics`` terms. State dimension is ``2 * n_harmonics``.

    Attributes:
        variance: Signal variance $\sigma^2$.
        lengthscale: Lengthscale $\ell$.
        period: Period $T$.
        n_harmonics: Number of Fourier harmonics (truncation order).
    """

    variance: Float[Array, ""]
    lengthscale: Float[Array, ""]
    period: Float[Array, ""]
    n_harmonics: int = eqx.field(static=True, default=6)

    @property
    def state_dim(self) -> int:
        return 2 * self.n_harmonics

    def sde_params(self) -> SDEParams:
        """Return SDE parameters for the periodic kernel."""
        J = self.n_harmonics
        d = 2 * J
        w0 = 2.0 * jnp.pi / self.period

        inv_ell_sq = 1.0 / self.lengthscale**2
        js = jnp.arange(1, J + 1)
        log_ij = self._log_bessel_i(js, inv_ell_sq)
        log_q = jnp.log(2.0) + log_ij - inv_ell_sq
        q_j = self.variance * jnp.exp(log_q)

        F = jnp.zeros((d, d))
        P_inf = jnp.zeros((d, d))
        for j_idx in range(J):
            freq = (j_idx + 1) * w0
            block_start = 2 * j_idx
            F = F.at[block_start, block_start + 1].set(-freq)
            F = F.at[block_start + 1, block_start].set(freq)
            P_inf = P_inf.at[block_start, block_start].set(q_j[j_idx])
            P_inf = P_inf.at[block_start + 1, block_start + 1].set(q_j[j_idx])

        L = jnp.zeros((d, 1))
        H = jnp.zeros((1, d))
        for j_idx in range(J):
            H = H.at[0, 2 * j_idx].set(1.0)

        Q_c = jnp.zeros((1, 1))
        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

    def discretise(
        self,
        dt: Float[Array, ""],
    ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
        """Closed-form: block-diagonal rotation matrices."""
        J = self.n_harmonics
        d = 2 * J
        w0 = 2.0 * jnp.pi / self.period

        A = jnp.zeros((d, d))
        for j_idx in range(J):
            freq = (j_idx + 1) * w0
            cos_val = jnp.cos(freq * dt)
            sin_val = jnp.sin(freq * dt)
            block_start = 2 * j_idx
            A = A.at[block_start, block_start].set(cos_val)
            A = A.at[block_start, block_start + 1].set(-sin_val)
            A = A.at[block_start + 1, block_start].set(sin_val)
            A = A.at[block_start + 1, block_start + 1].set(cos_val)

        Q = jnp.zeros((d, d))
        return A, Q

    @staticmethod
    def _log_bessel_i(
        order: Float[Array, " J"],
        x: Float[Array, ""],
    ) -> Float[Array, " J"]:
        """Log of modified Bessel function I_n(x) via series."""
        half_x = x / 2.0
        log_half_x = jnp.log(half_x)

        log_leading = order * log_half_x - jss.gammaln(order + 1.0)

        x2_over_4 = x**2 / 4.0
        K = 20
        log_sum = jnp.zeros_like(order)
        log_term = jnp.zeros_like(order)
        for k in range(1, K + 1):
            log_term = (
                log_term
                + jnp.log(x2_over_4)
                - jnp.log(jnp.array(k, dtype=order.dtype))
                - jnp.log(order + k)
            )
            log_sum = jnp.logaddexp(log_sum, log_term)

        return log_leading + log_sum

sde_params() -> SDEParams

Return SDE parameters for the periodic kernel.

Source code in src/gaussx/_ssm/_periodic.py
def sde_params(self) -> SDEParams:
    """Return SDE parameters for the periodic kernel."""
    J = self.n_harmonics
    d = 2 * J
    w0 = 2.0 * jnp.pi / self.period

    inv_ell_sq = 1.0 / self.lengthscale**2
    js = jnp.arange(1, J + 1)
    log_ij = self._log_bessel_i(js, inv_ell_sq)
    log_q = jnp.log(2.0) + log_ij - inv_ell_sq
    q_j = self.variance * jnp.exp(log_q)

    F = jnp.zeros((d, d))
    P_inf = jnp.zeros((d, d))
    for j_idx in range(J):
        freq = (j_idx + 1) * w0
        block_start = 2 * j_idx
        F = F.at[block_start, block_start + 1].set(-freq)
        F = F.at[block_start + 1, block_start].set(freq)
        P_inf = P_inf.at[block_start, block_start].set(q_j[j_idx])
        P_inf = P_inf.at[block_start + 1, block_start + 1].set(q_j[j_idx])

    L = jnp.zeros((d, 1))
    H = jnp.zeros((1, d))
    for j_idx in range(J):
        H = H.at[0, 2 * j_idx].set(1.0)

    Q_c = jnp.zeros((1, 1))
    return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Closed-form: block-diagonal rotation matrices.

Source code in src/gaussx/_ssm/_periodic.py
def discretise(
    self,
    dt: Float[Array, ""],
) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
    """Closed-form: block-diagonal rotation matrices."""
    J = self.n_harmonics
    d = 2 * J
    w0 = 2.0 * jnp.pi / self.period

    A = jnp.zeros((d, d))
    for j_idx in range(J):
        freq = (j_idx + 1) * w0
        cos_val = jnp.cos(freq * dt)
        sin_val = jnp.sin(freq * dt)
        block_start = 2 * j_idx
        A = A.at[block_start, block_start].set(cos_val)
        A = A.at[block_start, block_start + 1].set(-sin_val)
        A = A.at[block_start + 1, block_start].set(sin_val)
        A = A.at[block_start + 1, block_start + 1].set(cos_val)

    Q = jnp.zeros((d, d))
    return A, Q

QuasiPeriodicSDE

Bases: ProductSDE

Quasi-periodic kernel: product of Matern and Periodic SDE.

Attributes:

Name Type Description
kernel1 SDEKernel

Modulating kernel (typically Matern).

kernel2 SDEKernel

Periodic kernel.

Source code in src/gaussx/_ssm/_composition.py
class QuasiPeriodicSDE(ProductSDE):
    """Quasi-periodic kernel: product of Matern and Periodic SDE.

    Attributes:
        kernel1: Modulating kernel (typically Matern).
        kernel2: Periodic kernel.
    """

    pass

CosineSDE

Bases: SDEKernel

State-space representation of the cosine kernel.

Models \(k(\tau) = \sigma^2 \cos(\omega_0 \tau)\) via a 2-D rotation SDE. State dimension is 2.

Attributes:

Name Type Description
variance Float[Array, '']

Signal variance \(\sigma^2\).

frequency Float[Array, '']

Angular frequency \(\omega_0\).
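
A small plain-Python check of the rotation picture, using only the standard library (the helper name `cosine_transition` is illustrative, not a gaussx function): the transition over a lag tau is a rotation by frequency * tau, the autocovariance H A P_inf H^T recovers sigma^2 cos(omega_0 tau), and the process noise P_inf - A P_inf A^T vanishes because rotations leave sigma^2 I unchanged.

```python
import math


def cosine_transition(frequency, dt):
    """2x2 rotation matrix exp(F dt) for F = [[0, -w], [w, 0]]."""
    c, s = math.cos(frequency * dt), math.sin(frequency * dt)
    return [[c, -s], [s, c]]


variance, frequency, tau = 2.0, 3.0, 0.4  # arbitrary example values
A = cosine_transition(frequency, tau)

# With H = [1, 0] and P_inf = variance * I, H A P_inf H^T is A[0][0] * variance.
k_tau = A[0][0] * variance

# Q = P_inf - A P_inf A^T, computed entry by entry.
APAt = [[variance * sum(A[i][k] * A[j][k] for k in range(2)) for j in range(2)] for i in range(2)]
Q = [[(variance if i == j else 0.0) - APAt[i][j] for j in range(2)] for i in range(2)]
```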

Source code in src/gaussx/_ssm/_periodic.py
class CosineSDE(SDEKernel):
    r"""State-space representation of the cosine kernel.

    Models $k(\tau) = \sigma^2 \cos(\omega_0 \tau)$ via a 2-D
    rotation SDE. State dimension is 2.

    Attributes:
        variance: Signal variance $\sigma^2$.
        frequency: Angular frequency $\omega_0$.
    """

    variance: Float[Array, ""]
    frequency: Float[Array, ""]

    @property
    def state_dim(self) -> int:
        return 2

    def sde_params(self) -> SDEParams:
        """Return SDE parameters for the cosine kernel."""
        w = self.frequency
        F = jnp.array([[0.0, -w], [w, 0.0]])
        L = jnp.zeros((2, 1))
        H = jnp.array([[1.0, 0.0]])
        Q_c = jnp.zeros((1, 1))
        P_inf = self.variance * jnp.eye(2)
        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

    def discretise(
        self,
        dt: Float[Array, ""],
    ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
        """Closed-form rotation matrix discretization."""
        w = self.frequency
        cos_wdt = jnp.cos(w * dt)
        sin_wdt = jnp.sin(w * dt)
        A = jnp.array([[cos_wdt, -sin_wdt], [sin_wdt, cos_wdt]])
        Q = jnp.zeros((2, 2))
        return A, Q

sde_params() -> SDEParams

Return SDE parameters for the cosine kernel.

Source code in src/gaussx/_ssm/_periodic.py
def sde_params(self) -> SDEParams:
    """Return SDE parameters for the cosine kernel."""
    w = self.frequency
    F = jnp.array([[0.0, -w], [w, 0.0]])
    L = jnp.zeros((2, 1))
    H = jnp.array([[1.0, 0.0]])
    Q_c = jnp.zeros((1, 1))
    P_inf = self.variance * jnp.eye(2)
    return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Closed-form rotation matrix discretization.

Source code in src/gaussx/_ssm/_periodic.py
def discretise(
    self,
    dt: Float[Array, ""],
) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
    """Closed-form rotation matrix discretization."""
    w = self.frequency
    cos_wdt = jnp.cos(w * dt)
    sin_wdt = jnp.sin(w * dt)
    A = jnp.array([[cos_wdt, -sin_wdt], [sin_wdt, cos_wdt]])
    Q = jnp.zeros((2, 2))
    return A, Q

ProductSDE

Bases: SDEKernel

Product of two SDE kernels via Kronecker composition.

Attributes:

Name Type Description
kernel1 SDEKernel

First component kernel.

kernel2 SDEKernel

Second component kernel.

Source code in src/gaussx/_ssm/_composition.py
class ProductSDE(SDEKernel):
    """Product of two SDE kernels via Kronecker composition.

    Attributes:
        kernel1: First component kernel.
        kernel2: Second component kernel.
    """

    kernel1: SDEKernel
    kernel2: SDEKernel

    @property
    def state_dim(self) -> int:
        return self.kernel1.state_dim * self.kernel2.state_dim

    def sde_params(self) -> SDEParams:
        """Return Kronecker-structured SDE parameters.

        Note:
            ``SDEParams`` currently types its fields as dense
            ``jaxtyping.Float[Array, ...]``. The Kronecker products
            below are dense materializations of size
            ``(state_dim, state_dim)``, where ``state_dim`` is
            ``kernel1.state_dim * kernel2.state_dim`` — for typical SSM
            kernels (Matérn-3/2, periodic) this is ≤ 32, so the
            materialization is bounded and cheap. A future refactor
            could expose a parallel ``sde_operators()`` method that
            returns `gaussx.Kronecker` operators for downstream
            filters that can exploit the structure (issue #153).
        """
        p1 = self.kernel1.sde_params()
        p2 = self.kernel2.sde_params()

        d1 = self.kernel1.state_dim
        d2 = self.kernel2.state_dim

        F = jnp.kron(p1.F, jnp.eye(d2)) + jnp.kron(jnp.eye(d1), p2.F)
        L = jnp.kron(p1.L, p2.L)
        H = jnp.kron(p1.H, p2.H)
        Q_c = jnp.kron(p1.Q_c, p2.Q_c)
        P_inf = jnp.kron(p1.P_inf, p2.P_inf)

        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

    def discretise(
        self,
        dt: Float[Array, ""],
    ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
        r"""Discretise via the Kronecker matrix-exponential identity.

        For a product kernel ``F = F_1 \oplus F_2 = F_1 \otimes I + I \otimes F_2``,
        the factors ``F_1 \otimes I`` and ``I \otimes F_2`` commute, so

        $$
        \exp(F \, dt) = \exp(F_1 \, dt) \otimes \exp(F_2 \, dt).
        $$

        This computes two ``expm`` calls of size ``d_1`` and ``d_2``
        each, plus one Kronecker product, instead of one ``expm`` of
        size ``d_1 \cdot d_2``. Numerically equivalent to the dense
        ``expm`` on ``F`` but cheaper for moderate factor sizes.

        ``Q = P_\infty - A P_\infty A^T`` is computed densely from the
        resulting ``A``; with ``P_\infty = P_{\infty,1} \otimes
        P_{\infty,2}`` this could itself be expressed as a Kronecker
        difference, but is left dense to keep the consumer-facing
        ``(A, Q)`` interface unchanged.

        Args:
            dt: Time step (scalar, positive).

        Returns:
            Tuple ``(A, Q)`` matching `SDEKernel.discretise`.
        """
        p1 = self.kernel1.sde_params()
        p2 = self.kernel2.sde_params()
        A1 = jsl.expm(p1.F * dt)
        A2 = jsl.expm(p2.F * dt)
        A = jnp.kron(A1, A2)

        # Use the per-factor stationary covariances directly; building
        # the full ``F`` via ``self.sde_params()`` would defeat the
        # whole point of this override.
        P_inf = jnp.kron(p1.P_inf, p2.P_inf)
        Q = P_inf - A @ P_inf @ A.T
        Q = symmetrize(Q)
        return A, Q

sde_params() -> SDEParams

Return Kronecker-structured SDE parameters.

Note

SDEParams currently types its fields as dense jaxtyping.Float[Array, ...]. The Kronecker products below are dense materializations of size (state_dim, state_dim), where state_dim is kernel1.state_dim * kernel2.state_dim — for typical SSM kernels (Matérn-3/2, periodic) this is ≤ 32, so the materialization is bounded and cheap. A future refactor could expose a parallel sde_operators() method that returns gaussx.Kronecker operators for downstream filters that can exploit the structure (issue #153).

Source code in src/gaussx/_ssm/_composition.py
def sde_params(self) -> SDEParams:
    """Return Kronecker-structured SDE parameters.

    Note:
        ``SDEParams`` currently types its fields as dense
        ``jaxtyping.Float[Array, ...]``. The Kronecker products
        below are dense materializations of size
        ``(state_dim, state_dim)``, where ``state_dim`` is
        ``kernel1.state_dim * kernel2.state_dim`` — for typical SSM
        kernels (Matérn-3/2, periodic) this is ≤ 32, so the
        materialization is bounded and cheap. A future refactor
        could expose a parallel ``sde_operators()`` method that
        returns `gaussx.Kronecker` operators for downstream
        filters that can exploit the structure (issue #153).
    """
    p1 = self.kernel1.sde_params()
    p2 = self.kernel2.sde_params()

    d1 = self.kernel1.state_dim
    d2 = self.kernel2.state_dim

    F = jnp.kron(p1.F, jnp.eye(d2)) + jnp.kron(jnp.eye(d1), p2.F)
    L = jnp.kron(p1.L, p2.L)
    H = jnp.kron(p1.H, p2.H)
    Q_c = jnp.kron(p1.Q_c, p2.Q_c)
    P_inf = jnp.kron(p1.P_inf, p2.P_inf)

    return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

discretise(dt: Float[Array, '']) -> tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Discretise via the Kronecker matrix-exponential identity.

For a product kernel, \(F = F_1 \oplus F_2 = F_1 \otimes I + I \otimes F_2\). The two summands \(F_1 \otimes I\) and \(I \otimes F_2\) commute, so

\[ \exp(F \, dt) = \exp(F_1 \, dt) \otimes \exp(F_2 \, dt). \]

This computes two expm calls of size \(d_1\) and \(d_2\), plus one Kronecker product, instead of one expm of size \(d_1 \cdot d_2\). The result equals the dense expm of \(F\) up to floating-point round-off but is cheaper for moderate factor sizes.

\(Q = P_\infty - A P_\infty A^T\) is computed densely from the resulting \(A\); with \(P_\infty = P_{\infty,1} \otimes P_{\infty,2}\) this could itself be expressed as a Kronecker difference, but is left dense to keep the consumer-facing (A, Q) interface unchanged.
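
The Kronecker identity can be checked numerically with a few lines of plain Python. The sketch below uses only the standard library, a naive Taylor-series matrix exponential (adequate for the small matrices here, not a substitute for jax.scipy.linalg.expm), and two example factors: a cosine rotation and a Matern-3/2 drift with unit lengthscale. All helper names are illustrative.

```python
import math


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def eye(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def kron(A, B):
    """Kronecker product: block (i, j) of the result is A[i][j] * B."""
    return [[a * b for a in ra for b in rb] for ra in A for rb in B]


def scaled(M, c):
    return [[c * x for x in row] for row in M]


def expm_taylor(M, terms=40):
    """Truncated Taylor series sum_k M^k / k!; fine for small norms only."""
    result, term = eye(len(M)), eye(len(M))
    for k in range(1, terms):
        term = scaled(matmul(term, M), 1.0 / k)
        result = [[r + t for r, t in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result


dt, w, lam = 0.3, 2.0, math.sqrt(3.0)
F1 = [[0.0, -w], [w, 0.0]]                   # cosine factor
F2 = [[0.0, 1.0], [-(lam**2), -2.0 * lam]]   # Matern-3/2 factor

# Kronecker sum F = F1 (x) I + I (x) F2.
F = [
    [a + b for a, b in zip(r1, r2)]
    for r1, r2 in zip(kron(F1, eye(2)), kron(eye(2), F2))
]

A_dense = expm_taylor(scaled(F, dt))                                     # one 4x4 exponential
A_kron = kron(expm_taylor(scaled(F1, dt)), expm_taylor(scaled(F2, dt)))  # two 2x2 exponentials
max_diff = max(abs(a - b) for ra, rb in zip(A_dense, A_kron) for a, b in zip(ra, rb))
```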

Parameters:

Name Type Description Default
dt Float[Array, '']

Time step (scalar, positive).

required

Returns:

Type Description
tuple[Float[Array, 'd d'], Float[Array, 'd d']]

Tuple (A, Q) matching SDEKernel.discretise.

Source code in src/gaussx/_ssm/_composition.py
def discretise(
    self,
    dt: Float[Array, ""],
) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
    r"""Discretise via the Kronecker matrix-exponential identity.

    For a product kernel ``F = F_1 \oplus F_2 = F_1 \otimes I + I \otimes F_2``,
    the factors ``F_1 \otimes I`` and ``I \otimes F_2`` commute, so

    $$
    \exp(F \, dt) = \exp(F_1 \, dt) \otimes \exp(F_2 \, dt).
    $$

    This computes two ``expm`` calls of size ``d_1`` and ``d_2``
    each, plus one Kronecker product, instead of one ``expm`` of
    size ``d_1 \cdot d_2``. Numerically equivalent to the dense
    ``expm`` on ``F`` but cheaper for moderate factor sizes.

    ``Q = P_\infty - A P_\infty A^T`` is computed densely from the
    resulting ``A``; with ``P_\infty = P_{\infty,1} \otimes
    P_{\infty,2}`` this could itself be expressed as a Kronecker
    difference, but is left dense to keep the consumer-facing
    ``(A, Q)`` interface unchanged.

    Args:
        dt: Time step (scalar, positive).

    Returns:
        Tuple ``(A, Q)`` matching `SDEKernel.discretise`.
    """
    p1 = self.kernel1.sde_params()
    p2 = self.kernel2.sde_params()
    A1 = jsl.expm(p1.F * dt)
    A2 = jsl.expm(p2.F * dt)
    A = jnp.kron(A1, A2)

    # Use the per-factor stationary covariances directly; building
    # the full ``F`` via ``self.sde_params()`` would defeat the
    # whole point of this override.
    P_inf = jnp.kron(p1.P_inf, p2.P_inf)
    Q = P_inf - A @ P_inf @ A.T
    Q = symmetrize(Q)
    return A, Q

SumSDE

Bases: SDEKernel

Sum of SDE kernels via block-diagonal composition.

Attributes:

Name Type Description
kernels tuple[SDEKernel, ...]

Tuple of component SDE kernels.
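
Because the blocks never interact, the autocovariance of the stacked model is the sum of the component autocovariances. A minimal plain-Python check for a Matern-1/2 component plus a cosine component follows; it uses only the standard library, and the helper name and parameter values are made up for this example.

```python
import math


def sum_transition(lam, w, dt):
    """Block-diagonal exp(F dt): a 1x1 Matern-1/2 block and a 2x2 rotation block."""
    c, s = math.cos(w * dt), math.sin(w * dt)
    return [
        [math.exp(-lam * dt), 0.0, 0.0],
        [0.0, c, -s],
        [0.0, s, c],
    ]


var1, var2, lam, w, tau = 1.0, 0.5, 2.0, 3.0, 0.25
A = sum_transition(lam, w, tau)
H = [1.0, 1.0, 0.0]          # concatenation of [1] and [1, 0]
P_diag = [var1, var2, var2]  # block_diag of the two P_inf matrices (diagonal here)

# K(tau) = H A P_inf H^T for the stacked model.
k_sum = sum(H[i] * A[i][j] * P_diag[j] * H[j] for i in range(3) for j in range(3))

# Sum of the two component kernels evaluated separately.
k_parts = var1 * math.exp(-lam * tau) + var2 * math.cos(w * tau)
```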

Source code in src/gaussx/_ssm/_composition.py
class SumSDE(SDEKernel):
    """Sum of SDE kernels via block-diagonal composition.

    Attributes:
        kernels: Tuple of component SDE kernels.
    """

    kernels: tuple[SDEKernel, ...] = eqx.field()

    @property
    def state_dim(self) -> int:
        return sum(k.state_dim for k in self.kernels)

    def sde_params(self) -> SDEParams:
        """Return block-diagonal SDE parameters."""
        params_list = [k.sde_params() for k in self.kernels]

        F = jsl.block_diag(*[p.F for p in params_list])
        P_inf = jsl.block_diag(*[p.P_inf for p in params_list])

        L_blocks = [p.L for p in params_list]
        total_rows = sum(b.shape[0] for b in L_blocks)
        total_cols = sum(b.shape[1] for b in L_blocks)
        L = jnp.zeros((total_rows, total_cols))
        row_offset = 0
        col_offset = 0
        for block in L_blocks:
            r, c = block.shape
            L = L.at[row_offset : row_offset + r, col_offset : col_offset + c].set(
                block
            )
            row_offset += r
            col_offset += c

        Q_c = jsl.block_diag(*[p.Q_c for p in params_list])
        H = jnp.concatenate([p.H for p in params_list], axis=1)

        return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

sde_params() -> SDEParams

Return block-diagonal SDE parameters.

Source code in src/gaussx/_ssm/_composition.py
def sde_params(self) -> SDEParams:
    """Return block-diagonal SDE parameters."""
    params_list = [k.sde_params() for k in self.kernels]

    F = jsl.block_diag(*[p.F for p in params_list])
    P_inf = jsl.block_diag(*[p.P_inf for p in params_list])

    L_blocks = [p.L for p in params_list]
    total_rows = sum(b.shape[0] for b in L_blocks)
    total_cols = sum(b.shape[1] for b in L_blocks)
    L = jnp.zeros((total_rows, total_cols))
    row_offset = 0
    col_offset = 0
    for block in L_blocks:
        r, c = block.shape
        L = L.at[row_offset : row_offset + r, col_offset : col_offset + c].set(
            block
        )
        row_offset += r
        col_offset += c

    Q_c = jsl.block_diag(*[p.Q_c for p in params_list])
    H = jnp.concatenate([p.H for p in params_list], axis=1)

    return SDEParams(F=F, L=L, H=H, Q_c=Q_c, P_inf=P_inf)

sde_autocovariance(kernel: SDEKernel, tau: Float[Array, ' *batch']) -> Float[Array, ' *batch']

Compute the stationary autocovariance of an SDE kernel.

Evaluates:

\[ K(\tau) = H \, \exp(F |\tau|) \, P_\infty \, H^T \]

Parameters:

Name Type Description Default
kernel SDEKernel

An SDE kernel with sde_params() method.

required
tau Float[Array, ' *batch']

Lag values, shape (*batch,).

required

Returns:

Type Description
Float[Array, ' *batch']

Autocovariance values K(tau), shape (*batch,).
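
As a sanity check of the formula, the plain-Python sketch below evaluates H exp(F |tau|) P_inf H^T for the Matern-3/2 matrices from the source listing and compares it with the closed-form kernel sigma^2 (1 + lam |tau|) exp(-lam |tau|). It uses only the standard library and a naive Taylor-series matrix exponential, which is an illustrative stand-in for jax.scipy.linalg.expm and only suitable for small lags. The names `autocov` and `expm_taylor` are hypothetical.

```python
import math


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def expm_taylor(M, terms=60):
    """Truncated Taylor series sum_k M^k / k!; fine for small norms only."""
    n = len(M)
    result = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    term = [row[:] for row in result]
    for k in range(1, terms):
        term = [[x / k for x in row] for row in matmul(term, M)]
        result = [[r + t for r, t in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result


def autocov(F, H, P_inf, tau):
    """K(tau) = H expm(F |tau|) P_inf H^T, with H given as a flat list."""
    E = expm_taylor([[f * abs(tau) for f in row] for row in F])
    EP = matmul(E, P_inf)
    n = len(H)
    return sum(H[i] * EP[i][j] * H[j] for i in range(n) for j in range(n))


variance, lengthscale = 1.5, 0.8  # arbitrary example values
lam = math.sqrt(3.0) / lengthscale
F = [[0.0, 1.0], [-(lam**2), -2.0 * lam]]
H = [1.0, 0.0]
P_inf = [[variance, 0.0], [0.0, lam**2 * variance]]

taus = [-0.5, 0.0, 0.3, 1.0]
values = [autocov(F, H, P_inf, t) for t in taus]
closed_form = [variance * (1.0 + lam * abs(t)) * math.exp(-lam * abs(t)) for t in taus]
```

Negative and positive lags of equal magnitude give the same value because only |tau| enters the exponential.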

Source code in src/gaussx/_ssm/_autocovariance.py
def sde_autocovariance(
    kernel: SDEKernel,
    tau: Float[Array, " *batch"],
) -> Float[Array, " *batch"]:
    r"""Compute the stationary autocovariance of an SDE kernel.

    Evaluates:

        K(\tau) = H \, \exp(F |\tau|) \, P_\infty \, H^T

    Args:
        kernel: An SDE kernel with ``sde_params()`` method.
        tau: Lag values, shape ``(*batch,)``.

    Returns:
        Autocovariance values ``K(tau)``, shape ``(*batch,)``.
    """
    params = kernel.sde_params()

    def _single_autocov(t: Float[Array, ""]) -> Float[Array, ""]:
        abs_t = jnp.abs(t)
        eF = jsl.expm(params.F * abs_t)
        cov_matrix = params.H @ eF @ params.P_inf @ params.H.T
        return cov_matrix.squeeze()

    orig_shape = tau.shape
    flat_tau = tau.ravel()
    flat_result = jax.vmap(_single_autocov)(flat_tau)
    return flat_result.reshape(orig_shape)
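
As a sanity check of the formula, the sketch below evaluates it for a hand-written Matérn-3/2 model and compares the result with the closed-form kernel \(\sigma^2 (1 + \lambda|\tau|) e^{-\lambda|\tau|}\), \(\lambda = \sqrt{3}/\ell\). The matrices and hyperparameter values are assumptions chosen for illustration; the block does not call `gaussx` and only mirrors the flatten, `vmap`, reshape pattern of the listing.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

jax.config.update("jax_enable_x64", True)

# Assumed Matern-3/2 state-space matrices (state = [f, df/dt]).
variance, lengthscale = 1.5, 0.8
lam = 3.0**0.5 / lengthscale
F = jnp.array([[0.0, 1.0], [-(lam**2), -2.0 * lam]])
H = jnp.array([[1.0, 0.0]])
P_inf = jnp.diag(jnp.array([variance, lam**2 * variance]))


def autocov(tau):
    """K(tau) = H expm(F |tau|) P_inf H^T for one scalar lag."""
    return (H @ jsl.expm(F * jnp.abs(tau)) @ P_inf @ H.T).squeeze()


# Batched lags of arbitrary shape: flatten, vmap, restore the shape.
taus = jnp.linspace(-2.0, 2.0, 9).reshape(3, 3)
k_sde = jax.vmap(autocov)(taus.ravel()).reshape(taus.shape)

# Closed-form Matern-3/2 kernel for comparison.
r = lam * jnp.abs(taus)
k_closed = variance * (1.0 + r) * jnp.exp(-r)
```

The absolute value on the lag is what makes the result symmetric in `tau`; at `tau = 0` the matrix exponential is the identity and the value reduces to `H P_inf H^T`.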

Kalman filtering & smoothing

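The forward filter and RTS smoother, their parallel (associative-scan) counterparts with \(O(\log N)\) parallel depth in the number of time steps \(N\) (called T in the signatures below, where N denotes the state dimension), and the steady-state (infinite-horizon) variants built on the discrete algebraic Riccati equation (DARE).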


EmissionModel

Bases: Module

Observation (emission) model wrapping a linear observation matrix.

Provides named methods for common Kalman filter projection operations with observation matrix H ∈ ℝᴹˣᴺ.

Attributes:

Name Type Description
H Float[Array, 'M N']

Observation matrix, shape (M, N).

Source code in src/gaussx/_ssm/_emission.py
class EmissionModel(eqx.Module):
    """Observation (emission) model wrapping a linear observation matrix.

    Provides named methods for common Kalman filter projection
    operations with observation matrix H ∈ ℝᴹˣᴺ.

    Attributes:
        H: Observation matrix, shape ``(M, N)``.
    """

    H: Float[Array, "M N"]

    def project_mean(
        self,
        mean: Float[Array, " N"],
    ) -> Float[Array, " M"]:
        """Project state mean to observation space: ŷ = H x.

        Args:
            mean: State mean, shape ``(N,)``.

        Returns:
            Projected mean, shape ``(M,)``.
        """
        return self.H @ mean

    def project_covariance(
        self,
        cov: Float[Array, "N N"],
        noise: Float[Array, "M M"] | None = None,
    ) -> Float[Array, "M M"]:
        """Project state covariance: S = H P Hᵀ [+ R].

        Args:
            cov: State covariance P, shape ``(N, N)``.
            noise: Optional observation noise R, shape ``(M, M)``.

        Returns:
            Innovation covariance S, shape ``(M, M)``.
        """
        S = self.H @ cov @ self.H.T  # (M, M)
        if noise is not None:
            S = S + noise
        return S

    def innovation(
        self,
        y: Float[Array, " M"],
        x_pred: Float[Array, " N"],
    ) -> Float[Array, " M"]:
        """Compute innovation (measurement residual): v = y − H x.

        Args:
            y: Observation, shape ``(M,)``.
            x_pred: Predicted state mean, shape ``(N,)``.

        Returns:
            Innovation vector v, shape ``(M,)``.
        """
        return y - self.H @ x_pred

    def back_project_precision(
        self,
        noise_prec: Float[Array, "M M"],
    ) -> Float[Array, "N N"]:
        """Back-project observation precision: Hᵀ R⁻¹ H.

        Args:
            noise_prec: Observation noise precision R⁻¹, shape ``(M, M)``.

        Returns:
            Information matrix contribution, shape ``(N, N)``.
        """
        return self.H.T @ noise_prec @ self.H

    def back_project_info(
        self,
        y: Float[Array, " M"],
        noise_prec: Float[Array, "M M"],
    ) -> Float[Array, " N"]:
        """Back-project observation to information vector: Hᵀ R⁻¹ y.

        Args:
            y: Observation, shape ``(M,)``.
            noise_prec: Observation noise precision R⁻¹, shape ``(M, M)``.

        Returns:
            Information vector contribution, shape ``(N,)``.
        """
        return self.H.T @ noise_prec @ y
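
The five projections correspond to the two standard ways of writing a Kalman measurement update. The sketch below uses plain arrays instead of `EmissionModel` so that it stays self-contained; all numeric values are made up for illustration. The covariance form uses the quantities computed by `innovation` and `project_covariance`, the information form uses those computed by `back_project_precision` and `back_project_info`, and both should give the same posterior.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Illustrative model with M = 2 observations of an N = 3 state.
H = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
m = jnp.array([0.5, -1.0, 2.0])  # prior mean
P = jnp.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 1.5]])  # prior cov
R = jnp.diag(jnp.array([0.1, 0.4]))  # observation noise
y = jnp.array([0.7, 0.9])

# Covariance form: innovation v and innovation covariance S.
v = y - H @ m  # what innovation(y, m) returns
S = H @ P @ H.T + R  # what project_covariance(P, noise=R) returns
K = jnp.linalg.solve(S, H @ P).T  # K = P H^T S^{-1} (S, P symmetric)
m_cov = m + K @ v
P_cov = P - K @ S @ K.T

# Information form: add H^T R^{-1} H and H^T R^{-1} y to the prior
# precision and information vector.
R_inv = jnp.linalg.inv(R)
precision = jnp.linalg.inv(P) + H.T @ R_inv @ H  # back_project_precision
info_vec = jnp.linalg.solve(P, m) + H.T @ R_inv @ y  # back_project_info
P_info = jnp.linalg.inv(precision)
m_info = P_info @ info_vec
```

The information form is the one that extends naturally to site-based (natural-parameter) updates, since each observation contributes an additive term.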

project_mean(mean: Float[Array, ' N']) -> Float[Array, ' M']

Project state mean to observation space: ŷ = H x.

Parameters:

Name Type Description Default
mean Float[Array, ' N']

State mean, shape (N,).

required

Returns:

Type Description
Float[Array, ' M']

Projected mean, shape (M,).

Source code in src/gaussx/_ssm/_emission.py
def project_mean(
    self,
    mean: Float[Array, " N"],
) -> Float[Array, " M"]:
    """Project state mean to observation space: ŷ = H x.

    Args:
        mean: State mean, shape ``(N,)``.

    Returns:
        Projected mean, shape ``(M,)``.
    """
    return self.H @ mean

project_covariance(cov: Float[Array, 'N N'], noise: Float[Array, 'M M'] | None = None) -> Float[Array, 'M M']

Project state covariance: S = H P Hᵀ [+ R].

Parameters:

Name Type Description Default
cov Float[Array, 'N N']

State covariance P, shape (N, N).

required
noise Float[Array, 'M M'] | None

Optional observation noise R, shape (M, M).

None

Returns:

Type Description
Float[Array, 'M M']

Innovation covariance S, shape (M, M).

Source code in src/gaussx/_ssm/_emission.py
def project_covariance(
    self,
    cov: Float[Array, "N N"],
    noise: Float[Array, "M M"] | None = None,
) -> Float[Array, "M M"]:
    """Project state covariance: S = H P Hᵀ [+ R].

    Args:
        cov: State covariance P, shape ``(N, N)``.
        noise: Optional observation noise R, shape ``(M, M)``.

    Returns:
        Innovation covariance S, shape ``(M, M)``.
    """
    S = self.H @ cov @ self.H.T  # (M, M)
    if noise is not None:
        S = S + noise
    return S

innovation(y: Float[Array, ' M'], x_pred: Float[Array, ' N']) -> Float[Array, ' M']

Compute innovation (measurement residual): v = y − H x.

Parameters:

Name Type Description Default
y Float[Array, ' M']

Observation, shape (M,).

required
x_pred Float[Array, ' N']

Predicted state mean, shape (N,).

required

Returns:

Type Description
Float[Array, ' M']

Innovation vector v, shape (M,).

Source code in src/gaussx/_ssm/_emission.py
def innovation(
    self,
    y: Float[Array, " M"],
    x_pred: Float[Array, " N"],
) -> Float[Array, " M"]:
    """Compute innovation (measurement residual): v = y − H x.

    Args:
        y: Observation, shape ``(M,)``.
        x_pred: Predicted state mean, shape ``(N,)``.

    Returns:
        Innovation vector v, shape ``(M,)``.
    """
    return y - self.H @ x_pred

back_project_precision(noise_prec: Float[Array, 'M M']) -> Float[Array, 'N N']

Back-project observation precision: Hᵀ R⁻¹ H.

Parameters:

Name Type Description Default
noise_prec Float[Array, 'M M']

Observation noise precision R⁻¹, shape (M, M).

required

Returns:

Type Description
Float[Array, 'N N']

Information matrix contribution, shape (N, N).

Source code in src/gaussx/_ssm/_emission.py
def back_project_precision(
    self,
    noise_prec: Float[Array, "M M"],
) -> Float[Array, "N N"]:
    """Back-project observation precision: Hᵀ R⁻¹ H.

    Args:
        noise_prec: Observation noise precision R⁻¹, shape ``(M, M)``.

    Returns:
        Information matrix contribution, shape ``(N, N)``.
    """
    return self.H.T @ noise_prec @ self.H

back_project_info(y: Float[Array, ' M'], noise_prec: Float[Array, 'M M']) -> Float[Array, ' N']

Back-project observation to information vector: Hᵀ R⁻¹ y.

Parameters:

Name Type Description Default
y Float[Array, ' M']

Observation, shape (M,).

required
noise_prec Float[Array, 'M M']

Observation noise precision R⁻¹, shape (M, M).

required

Returns:

Type Description
Float[Array, ' N']

Information vector contribution, shape (N,).

Source code in src/gaussx/_ssm/_emission.py
def back_project_info(
    self,
    y: Float[Array, " M"],
    noise_prec: Float[Array, "M M"],
) -> Float[Array, " N"]:
    """Back-project observation to information vector: Hᵀ R⁻¹ y.

    Args:
        y: Observation, shape ``(M,)``.
        noise_prec: Observation noise precision R⁻¹, shape ``(M, M)``.

    Returns:
        Information vector contribution, shape ``(N,)``.
    """
    return self.H.T @ noise_prec @ y

FilterState

Bases: Module

Output of kalman_filter.

Attributes:

Name Type Description
filtered_means Float[Array, 'T N']

Shape (T, N) — filtered state estimates.

filtered_covs Float[Array, 'T N N']

Shape (T, N, N) — filtered covariances.

predicted_means Float[Array, 'T N']

Shape (T, N) — predicted state estimates.

predicted_covs Float[Array, 'T N N']

Shape (T, N, N) — predicted covariances.

log_likelihood Float[Array, '']

Scalar — total log-likelihood.

Source code in src/gaussx/_ssm/_kalman.py
class FilterState(eqx.Module):
    """Output of ``kalman_filter``.

    Attributes:
        filtered_means: Shape ``(T, N)`` — filtered state estimates.
        filtered_covs: Shape ``(T, N, N)`` — filtered covariances.
        predicted_means: Shape ``(T, N)`` — predicted state estimates.
        predicted_covs: Shape ``(T, N, N)`` — predicted covariances.
        log_likelihood: Scalar — total log-likelihood.
    """

    filtered_means: Float[Array, "T N"]
    filtered_covs: Float[Array, "T N N"]
    predicted_means: Float[Array, "T N"]
    predicted_covs: Float[Array, "T N N"]
    log_likelihood: Float[Array, ""]

InfiniteHorizonState

Bases: Module

Output of infinite_horizon_filter.

Attributes:

Name Type Description
filtered_means Float[Array, 'T N']

Filtered state estimates, shape (T, N).

filtered_covs Float[Array, 'T N N']

Filtered covariances (constant), shape (T, N, N).

predicted_means Float[Array, 'T N']

Predicted state estimates, shape (T, N).

predicted_covs Float[Array, 'T N N']

Predicted covariances (constant), shape (T, N, N).

log_likelihood Float[Array, '']

Total log-likelihood (scalar).

Source code in src/gaussx/_ssm/_infinite_horizon_kalman.py
class InfiniteHorizonState(eqx.Module):
    """Output of ``infinite_horizon_filter``.

    Attributes:
        filtered_means: Filtered state estimates, shape ``(T, N)``.
        filtered_covs: Filtered covariances (constant), shape ``(T, N, N)``.
        predicted_means: Predicted state estimates, shape ``(T, N)``.
        predicted_covs: Predicted covariances (constant), shape ``(T, N, N)``.
        log_likelihood: Total log-likelihood (scalar).
    """

    filtered_means: Float[Array, "T N"]
    filtered_covs: Float[Array, "T N N"]
    predicted_means: Float[Array, "T N"]
    predicted_covs: Float[Array, "T N N"]
    log_likelihood: Float[Array, ""]

DAREResult

Bases: Module

Result of DARE solver.

Attributes:

Name Type Description
P_inf Float[Array, 'D D']

Steady-state covariance, shape (D, D).

K_inf Float[Array, 'D M']

Steady-state Kalman gain, shape (D, M).

converged Bool[Array, '']

Scalar boolean indicating convergence.

Source code in src/gaussx/_ssm/_dare.py
class DAREResult(eqx.Module):
    """Result of DARE solver.

    Attributes:
        P_inf: Steady-state covariance, shape ``(D, D)``.
        K_inf: Steady-state Kalman gain, shape ``(D, M)``.
        converged: Scalar boolean indicating convergence.
    """

    P_inf: Float[Array, "D D"]
    K_inf: Float[Array, "D M"]
    converged: Bool[Array, ""]
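
For intuition about what a DARE solution represents, the sketch below iterates the Riccati recursion until the predictive covariance stops changing. This is a naive fixed-point scheme written for illustration, not the library's solver; `riccati_step` and `solve_dare_naive` are hypothetical names, and the sketch takes the steady-state covariance to be the predictive one, which this page does not specify for `DAREResult.P_inf`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def riccati_step(P, A, H, Q, R):
    """One update + predict step on the predictive covariance P."""
    S = H @ P @ H.T + R
    K = jnp.linalg.solve(S, H @ P).T  # K = P H^T S^{-1}
    P_filt = P - K @ S @ K.T
    return A @ P_filt @ A.T + Q, K


def solve_dare_naive(A, H, Q, R, tol=1e-12, max_iter=10_000):
    """Iterate the Riccati map from P = Q; returns (P, K, converged)."""

    def cond(state):
        _, delta, i = state
        return (delta > tol) & (i < max_iter)

    def body(state):
        P, _, i = state
        P_new, _ = riccati_step(P, A, H, Q, R)
        return P_new, jnp.max(jnp.abs(P_new - P)), i + 1

    init = (Q, jnp.array(jnp.inf), jnp.array(0))
    P, delta, _ = jax.lax.while_loop(cond, body, init)
    _, K = riccati_step(P, A, H, Q, R)
    return P, K, delta <= tol


# Constant-velocity model observed through its position.
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
H = jnp.array([[1.0, 0.0]])
Q = 0.01 * jnp.eye(2)
R = jnp.array([[0.5]])
P_ss, K_ss, converged = solve_dare_naive(A, H, Q, R)
```

Once the gain is constant, the filter means follow a fixed affine recursion and the covariances no longer need to be propagated, which is what makes the infinite-horizon variants cheap.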

kalman_filter(transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_model: Float[Array, '*T M N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_noise: Float[Array, '*T M M'] | lx.AbstractLinearOperator, observations: Float[Array, 'T M'], init_mean: Float[Array, ' N'], init_cov: Float[Array, 'N N'], *, mask: Bool[Array, ' T'] | None = None, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> FilterState

Kalman filter forward pass via jax.lax.scan.

Implements the predict-update cycle for a (possibly time-varying) linear-Gaussian state-space model:

x_t = A_t @ x_{t-1} + q_t,   q_t ~ N(0, Q_t)
y_t = H_t @ x_t + r_t,        r_t ~ N(0, R_t)

Time-invariant inputs (single (N, N) / (M, N) etc.) are automatically broadcast along the time axis. Time-varying inputs are passed as (T, …) stacks (e.g. from discretise_sequence).

Operator inputs (lineax BlockDiag / Kronecker / DiagonalLinearOperator / MaskedOperator / etc.) are accepted in the time-invariant signature only. The structural matvec (A @ x, H @ x) runs through the operator's mv. Operator-typed Q / R are materialised to dense arrays once, outside the scan, except that an operator-typed R stays structural when woodbury_innovation=True. The per-step sandwiches A P A^T / H P H^T run inside the scan because they depend on the evolving covariances (P_filt and P_pred, respectively).

Parameters:

Name Type Description Default
transition Float[Array, '*T N N'] | AbstractLinearOperator

State transition matrix A. Shape (N, N), (T, N, N), or lineax.AbstractLinearOperator.

required
obs_model Float[Array, '*T M N'] | AbstractLinearOperator

Observation matrix H. Shape (M, N), (T, M, N), or operator.

required
process_noise Float[Array, '*T N N'] | AbstractLinearOperator

Process noise covariance Q. Shape (N, N), (T, N, N), or operator.

required
obs_noise Float[Array, '*T M M'] | AbstractLinearOperator

Observation noise covariance R. Shape (M, M), (T, M, M), or operator.

required
observations Float[Array, 'T M']

Observed data, shape (T, M).

required
init_mean Float[Array, ' N']

Initial state mean, shape (N,).

required
init_cov Float[Array, 'N N']

Initial state covariance, shape (N, N).

required
mask Bool[Array, ' T'] | None

Optional per-step boolean mask, shape (T,). True (or 1) runs the full predict + update step; False (or 0) runs the predict step only and skips the log-likelihood contribution. Defaults to all-True. Useful for prediction on merged train/test grids.

None
solver AbstractSolverStrategy | None

Optional solver strategy. When None, uses structural dispatch.

None
woodbury_innovation bool

When True, build the innovation covariance S = H P Hᵀ + R as a gaussx.LowRankUpdate so structured R can use Woodbury solves/log-determinants. Defaults to False to preserve the dense innovation path.

False
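
The identities behind a Woodbury-style innovation path can be checked with dense arrays. The sketch below is a minimal illustration assuming a diagonal `R` and random test matrices; it does not use `gaussx.LowRankUpdate`, and all variable names are illustrative. It solves with `S = R + H P Hᵀ` and computes `log det S` using only an `N × N` system, which pays off when `M` is much larger than `N`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

key_h, key_p, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
M, N = 50, 3  # many observations, small state
H = jax.random.normal(key_h, (M, N))
B = jax.random.normal(key_p, (N, N))
P = B @ B.T + jnp.eye(N)  # positive-definite state covariance
r_diag = jnp.linspace(0.5, 2.0, M)  # diagonal of R
v = jax.random.normal(key_v, (M,))

# Dense reference: O(M^3) solve and log-determinant.
S = H @ P @ H.T + jnp.diag(r_diag)
x_dense = jnp.linalg.solve(S, v)
logdet_dense = jnp.linalg.slogdet(S)[1]

# Woodbury identity with C = P^{-1} + H^T R^{-1} H (an N x N matrix):
#   S^{-1} v = R^{-1} v - R^{-1} H C^{-1} H^T R^{-1} v
Rinv_v = v / r_diag
Rinv_H = H / r_diag[:, None]
C = jnp.linalg.inv(P) + H.T @ Rinv_H
x_wood = Rinv_v - Rinv_H @ jnp.linalg.solve(C, H.T @ Rinv_v)

# Matrix determinant lemma:
#   log det S = log det R + log det P + log det C
logdet_wood = (
    jnp.sum(jnp.log(r_diag)) + jnp.linalg.slogdet(P)[1] + jnp.linalg.slogdet(C)[1]
)
```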

Raises:

Type Description
TypeError

If operator-typed inputs are mixed with 3D (T, …) arrays. Operator inputs must come from the time-invariant signature (per-step structured stacks are not supported; pass dense (T, …) arrays for the time-varying path).

Returns:

Type Description
FilterState

A FilterState with filtered/predicted means, covariances, and total log-likelihood.

Source code in src/gaussx/_ssm/_kalman.py
def kalman_filter(
    transition: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    obs_model: Float[Array, "*T M N"] | lx.AbstractLinearOperator,
    process_noise: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    obs_noise: Float[Array, "*T M M"] | lx.AbstractLinearOperator,
    observations: Float[Array, "T M"],
    init_mean: Float[Array, " N"],
    init_cov: Float[Array, "N N"],
    *,
    mask: Bool[Array, " T"] | None = None,
    solver: AbstractSolverStrategy | None = None,
    woodbury_innovation: bool = False,
) -> FilterState:
    r"""Kalman filter forward pass via ``jax.lax.scan``.

    Implements the predict-update cycle for a (possibly time-varying)
    linear-Gaussian state-space model:

        x_t = A_t @ x_{t-1} + q_t,   q_t ~ N(0, Q_t)
        y_t = H_t @ x_t + r_t,        r_t ~ N(0, R_t)

    **Time-invariant inputs** (single ``(N, N)`` / ``(M, N)`` etc.) are
    automatically broadcast along the time axis. **Time-varying inputs**
    are passed as ``(T, …)`` stacks (e.g. from
    `discretise_sequence`).

    **Operator inputs** (lineax ``BlockDiag`` / ``Kronecker`` /
    ``DiagonalLinearOperator`` / ``MaskedOperator`` / etc.) are accepted
    in the **time-invariant** signature only. The structural matvec
    (``A @ x``, ``H @ x``) runs through the operator's ``mv``.
    Operator-typed ``Q`` / ``R`` are materialised to dense arrays once,
    outside the scan, except that an operator-typed ``R`` stays
    structural when ``woodbury_innovation=True``. The per-step
    sandwiches ``A P A^T`` / ``H P H^T`` run inside the scan because
    they depend on the evolving covariances (``P_filt`` and ``P_pred``,
    respectively).

    Args:
        transition: State transition matrix ``A``. Shape ``(N, N)``,
            ``(T, N, N)``, or `lineax.AbstractLinearOperator`.
        obs_model: Observation matrix ``H``. Shape ``(M, N)``,
            ``(T, M, N)``, or operator.
        process_noise: Process noise covariance ``Q``. Shape ``(N, N)``,
            ``(T, N, N)``, or operator.
        obs_noise: Observation noise covariance ``R``. Shape ``(M, M)``,
            ``(T, M, M)``, or operator.
        observations: Observed data, shape ``(T, M)``.
        init_mean: Initial state mean, shape ``(N,)``.
        init_cov: Initial state covariance, shape ``(N, N)``.
        mask: Optional per-step boolean mask, shape ``(T,)``. ``True``
            (or ``1``) runs the full predict + update step; ``False``
            (or ``0``) runs the predict step only and skips the
            log-likelihood contribution. Defaults to all-True. Useful
            for prediction on merged train/test grids.
        solver: Optional solver strategy. When ``None``, uses
            structural dispatch.
        woodbury_innovation: When ``True``, build the innovation
            covariance ``S = H P Hᵀ + R`` as a
            `gaussx.LowRankUpdate` so structured ``R`` can use
            Woodbury solves/log-determinants. Defaults to ``False`` to
            preserve the dense innovation path.

    Raises:
        TypeError: If operator-typed inputs are mixed with 3D ``(T, …)``
            arrays. Operator inputs must come from the time-invariant
            signature (per-step structured stacks are not supported;
            pass dense ``(T, …)`` arrays for the time-varying path).

    Returns:
        A ``FilterState`` with filtered/predicted means, covariances,
        and total log-likelihood.
    """
    M = observations.shape[-1]
    T = observations.shape[0]

    # Closure-friendly matvec: when an operator is supplied, prefer its
    # structural ``mv`` over the dense ``A @ x``. Otherwise the
    # broadcast 3D array contains ``A_seq[t]`` for each step.
    A_op = transition if isinstance(transition, lx.AbstractLinearOperator) else None
    H_op = obs_model if isinstance(obs_model, lx.AbstractLinearOperator) else None
    R_op = obs_noise if isinstance(obs_noise, lx.AbstractLinearOperator) else None
    A_seq, H_seq, Q_seq, R_seq, mask_seq, _ = _normalise_tv_inputs(
        transition,
        obs_model,
        process_noise,
        obs_noise,
        T=T,
        mask=mask,
        materialise_transition=A_op is None,
        materialise_obs=H_op is None,
        # Skip the O(T M²) dense broadcast of structured R when the
        # Woodbury path consumes the operator directly.
        materialise_obs_noise=not (woodbury_innovation and R_op is not None),
    )

    def step(carry, inputs):
        x_filt, P_filt, ll = carry
        A_t, H_t, Q_t, R_t, y_t, mask_t = inputs

        # --- Predict ---
        # Structural matvec when an operator was supplied; dense matmul otherwise.
        x_pred = A_op.mv(x_filt) if A_op is not None else A_t @ x_filt
        if A_op is not None:
            P_filt_op = lx.MatrixLinearOperator(P_filt, lx.positive_semidefinite_tag)
            P_pred = sandwich(A_op, P_filt_op).as_matrix() + Q_t
        else:
            P_pred = A_t @ P_filt @ A_t.T + Q_t

        # --- Update (gated by mask via lax.cond so the predict-only
        #             branch evaluates neither the update arithmetic
        #             nor produces gradients for the dropped path). ---
        def _do_update(_):
            v = y_t - (H_op.mv(x_pred) if H_op is not None else H_t @ x_pred)
            # Resolve ``R`` for innovation: operator path uses the
            # closed-over ``R_op`` (kept structural for Woodbury); array
            # path falls back to the per-step ``R_t``.
            R_innov = R_op if R_op is not None else R_t
            # Resolve ``H`` similarly so the operator preserves structure
            # in both the Woodbury and the structural-sandwich paths.
            H_innov = H_op if H_op is not None else H_t
            S_op = _innovation_covariance(
                H_innov, P_pred, R_innov, woodbury=woodbury_innovation
            )

            PHt = (
                _right_matmul_transpose(P_pred, H_op)
                if H_op is not None
                else P_pred @ H_t.T
            )  # (N, M)
            K = solve_rows(S_op, PHt, solver=solver)  # (N, M)

            x_upd = x_pred + K @ v
            if woodbury_innovation:
                # Avoid materialising S for the covariance update. Use
                # the operator path when available (H_t is a (0, 0)
                # placeholder under operator mode).
                HP_pred = (
                    _left_matmul(H_op, P_pred) if H_op is not None else H_t @ P_pred
                )
                P_upd = P_pred - K @ HP_pred
            else:
                P_upd = P_pred - K @ S_op.as_matrix() @ K.T

            Sinv_v = dispatch_solve(S_op, v, solver)
            ld = dispatch_logdet(S_op, solver)
            ll_inc = -0.5 * (v @ Sinv_v + ld + M * _LOG_2PI)
            return x_upd, P_upd, ll_inc

        def _skip_update(_):
            return x_pred, P_pred, jnp.array(0.0)

        x_filt_new, P_filt_new, ll_inc = jax.lax.cond(
            mask_t, _do_update, _skip_update, operand=None
        )
        ll_new = ll + ll_inc

        carry_new = (x_filt_new, P_filt_new, ll_new)
        outputs = (x_filt_new, P_filt_new, x_pred, P_pred)
        return carry_new, outputs

    init_carry = (init_mean, init_cov, jnp.array(0.0))
    final_carry, (f_means, f_covs, p_means, p_covs) = jax.lax.scan(
        step, init_carry, (A_seq, H_seq, Q_seq, R_seq, observations, mask_seq)
    )

    return FilterState(
        filtered_means=f_means,
        filtered_covs=f_covs,
        predicted_means=p_means,
        predicted_covs=p_covs,
        log_likelihood=final_carry[2],
    )
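
The predict-update cycle and the mask semantics can be checked end to end on a small dense model. The sketch below is a simplified, time-invariant re-implementation for illustration only: `kalman_filter_dense` and `dense_log_marginal` are hypothetical names, and the mask is applied with `jnp.where` (both branches are evaluated) rather than the `lax.cond` used in the listing above. The reference value is the log-density of the observed `y_t` under their joint Gaussian, which the filter's accumulated log-likelihood should reproduce.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)


def kalman_filter_dense(A, H, Q, R, ys, m0, P0, mask):
    """Time-invariant Kalman filter; masked-off steps are predict-only."""
    M = ys.shape[-1]

    def step(carry, inputs):
        m, P, ll = carry
        y, observed = inputs
        m_pred = A @ m
        P_pred = A @ P @ A.T + Q
        v = y - H @ m_pred
        S = H @ P_pred @ H.T + R
        K = jnp.linalg.solve(S, H @ P_pred).T  # K = P_pred H^T S^{-1}
        m_upd = m_pred + K @ v
        P_upd = P_pred - K @ S @ K.T
        ll_inc = -0.5 * (
            v @ jnp.linalg.solve(S, v)
            + jnp.linalg.slogdet(S)[1]
            + M * jnp.log(2.0 * jnp.pi)
        )
        m_new = jnp.where(observed, m_upd, m_pred)
        P_new = jnp.where(observed, P_upd, P_pred)
        ll_new = ll + jnp.where(observed, ll_inc, 0.0)
        return (m_new, P_new, ll_new), (m_new, P_new)

    init = (m0, P0, jnp.array(0.0))
    (_, _, ll), (means, covs) = jax.lax.scan(step, init, (ys, mask))
    return means, covs, ll


def dense_log_marginal(A, H, Q, R, ys, m0, P0, mask):
    """Reference: joint Gaussian log-density of the observed y_t."""
    T = ys.shape[0]
    mus, covs, m, P = [], [], m0, P0
    for _ in range(T):  # prior moments of x_1..x_T
        m, P = A @ m, A @ P @ A.T + Q
        mus.append(m)
        covs.append(P)

    def cross(s, t):  # Cov(x_s, x_t)
        C = jnp.linalg.matrix_power(A, abs(s - t)) @ covs[min(s, t)]
        return C if s >= t else C.T

    idx = [t for t in range(T) if bool(mask[t])]
    mu = jnp.concatenate([H @ mus[t] for t in idx])
    K = jnp.block(
        [[H @ cross(s, t) @ H.T + (R if s == t else 0.0 * R) for t in idx] for s in idx]
    )
    y = jnp.concatenate([ys[t] for t in idx])
    return multivariate_normal.logpdf(y, mu, K)


A = jnp.array([[1.0, 0.5], [0.0, 0.9]])
H = jnp.array([[1.0, 0.0]])
Q = 0.1 * jnp.eye(2)
R = jnp.array([[0.3]])
m0, P0 = jnp.zeros(2), jnp.eye(2)
ys = jnp.array([[0.3], [-0.1], [0.0], [0.8], [0.5]])
mask = jnp.array([True, True, False, True, True])

means, covs, ll = kalman_filter_dense(A, H, Q, R, ys, m0, P0, mask)
ll_ref = dense_log_marginal(A, H, Q, R, ys, m0, P0, mask)
```

Note that, as in the listing, the first step predicts from `(init_mean, init_cov)` before updating, so the initial distribution describes the state one step before the first observation.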

rts_smoother(filter_state: FilterState, transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'T N'], Float[Array, 'T N N']]

Rauch-Tung-Striebel backward smoother.

Accepts the same time-invariant / time-varying / operator forms for transition and process_noise as kalman_filter. When a step was masked off in the filter (mask[t] == 0), the smoother formula degenerates harmlessly because filtered == predicted at that step.

Parameters:

Name Type Description Default
filter_state FilterState

Output of kalman_filter.

required
transition Float[Array, '*T N N'] | AbstractLinearOperator

State transition matrix or operator.

required
process_noise Float[Array, '*T N N'] | AbstractLinearOperator

Process noise covariance or operator. (Not currently used by the standard RTS recurrence — kept for API symmetry with kalman_filter.)

required
solver AbstractSolverStrategy | None

Optional solver strategy.

None

Returns:

Type Description
tuple[Float[Array, 'T N'], Float[Array, 'T N N']]

Tuple (smoothed_means, smoothed_covs).

Source code in src/gaussx/_ssm/_kalman.py
def rts_smoother(
    filter_state: FilterState,
    transition: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    process_noise: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, "T N"], Float[Array, "T N N"]]:
    """Rauch-Tung-Striebel backward smoother.

    Accepts the same time-invariant / time-varying / operator forms for
    ``transition`` and ``process_noise`` as `kalman_filter`. When
    a step was masked off in the filter (``mask[t] == 0``), the
    smoother formula degenerates harmlessly because filtered ==
    predicted at that step.

    Args:
        filter_state: Output of `kalman_filter`.
        transition: State transition matrix or operator.
        process_noise: Process noise covariance or operator. (Not
            currently used by the standard RTS recurrence — kept for
            API symmetry with `kalman_filter`.)
        solver: Optional solver strategy.

    Returns:
        Tuple ``(smoothed_means, smoothed_covs)``.
    """
    del process_noise  # not used in the standard RTS recurrence

    T = filter_state.filtered_means.shape[0]

    # Materialise once outside the scan for the sandwich; matvec stays
    # structural via the operator's mv.
    A_dense = _materialise(transition)
    A_op = transition if isinstance(transition, lx.AbstractLinearOperator) else None
    if A_dense.ndim == 2:
        A_seq = jnp.broadcast_to(A_dense, (T, *A_dense.shape))
    elif A_dense.ndim == 3:
        if A_op is not None:
            raise TypeError(
                "Operator-typed transition cannot have a leading time axis."
            )
        A_seq = A_dense
    else:
        raise ValueError(f"transition must have ndim 2 or 3, got {A_dense.ndim}.")

    def step(carry, inputs):
        x_smooth, P_smooth = carry
        x_filt, P_filt, x_pred, P_pred, A_next = inputs

        # Smoother gain: G = P_filt A_{t+1}^T P_pred_{t+1}^{-1}
        P_pred_op = lx.MatrixLinearOperator(P_pred, lx.positive_semidefinite_tag)
        G = P_filt @ A_next.T  # (N, N)
        G = solve_rows(P_pred_op, G, solver=solver)  # (N, N)

        x_smooth_new = x_filt + G @ (x_smooth - x_pred)
        P_smooth_new = P_filt + G @ (P_smooth - P_pred) @ G.T

        return (x_smooth_new, P_smooth_new), (x_smooth_new, P_smooth_new)

    init_carry = (
        filter_state.filtered_means[T - 1],
        filter_state.filtered_covs[T - 1],
    )

    # Reverse the sequences for backward pass (exclude last time step).
    # ``A_next[t]`` is the transition that maps step ``t`` to step ``t+1``,
    # i.e. ``A_seq[t+1]``.
    inputs = (
        filter_state.filtered_means[:-1][::-1],
        filter_state.filtered_covs[:-1][::-1],
        filter_state.predicted_means[1:][::-1],
        filter_state.predicted_covs[1:][::-1],
        A_seq[1:][::-1],
    )

    _, (s_means_rev, s_covs_rev) = jax.lax.scan(step, init_carry, inputs)

    # Reverse back and prepend last filtered state.
    s_means = jnp.concatenate(
        [s_means_rev[::-1], filter_state.filtered_means[T - 1 :]], axis=0
    )
    s_covs = jnp.concatenate(
        [s_covs_rev[::-1], filter_state.filtered_covs[T - 1 :]], axis=0
    )

    return s_means, s_covs
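
The backward recurrence can be verified against batch Gaussian conditioning on a small problem. The sketch below is a dense, time-invariant illustration with hypothetical helper names (`kalman_filter_dense`, `rts_smoother_dense`, `dense_posterior`); it uses `jax.lax.scan(..., reverse=True)` instead of reversing the inputs by hand as the listing does, and the reference conditions the joint prior over all states on all observations at once.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def kalman_filter_dense(A, H, Q, R, ys, m0, P0):
    """Returns filtered and predicted means/covs, each stacked over time."""

    def step(carry, y):
        m, P = carry
        m_p, P_p = A @ m, A @ P @ A.T + Q
        S = H @ P_p @ H.T + R
        K = jnp.linalg.solve(S, H @ P_p).T
        m_f = m_p + K @ (y - H @ m_p)
        P_f = P_p - K @ S @ K.T
        return (m_f, P_f), (m_f, P_f, m_p, P_p)

    _, out = jax.lax.scan(step, (m0, P0), ys)
    return out


def rts_smoother_dense(A, mf, Pf, mp, Pp):
    """RTS backward pass; the last smoothed state is the last filtered one."""

    def step(carry, inputs):
        m_s, P_s = carry
        m_f, P_f, m_p_next, P_p_next = inputs
        G = jnp.linalg.solve(P_p_next, A @ P_f).T  # G = P_f A^T P_pred^{-1}
        m_new = m_f + G @ (m_s - m_p_next)
        P_new = P_f + G @ (P_s - P_p_next) @ G.T
        return (m_new, P_new), (m_new, P_new)

    xs = (mf[:-1], Pf[:-1], mp[1:], Pp[1:])
    _, (ms, Ps) = jax.lax.scan(step, (mf[-1], Pf[-1]), xs, reverse=True)
    return jnp.concatenate([ms, mf[-1:]]), jnp.concatenate([Ps, Pf[-1:]])


def dense_posterior(A, H, Q, R, ys, m0, P0):
    """Reference: condition the joint prior over x_1..x_T on all y."""
    T, N = ys.shape[0], m0.shape[0]
    mus, covs, m, P = [], [], m0, P0
    for _ in range(T):
        m, P = A @ m, A @ P @ A.T + Q
        mus.append(m)
        covs.append(P)

    def cross(s, t):  # Cov(x_s, x_t)
        C = jnp.linalg.matrix_power(A, abs(s - t)) @ covs[min(s, t)]
        return C if s >= t else C.T

    Kxx = jnp.block([[cross(s, t) for t in range(T)] for s in range(T)])
    Hb, Rb = jnp.kron(jnp.eye(T), H), jnp.kron(jnp.eye(T), R)
    mu = jnp.concatenate(mus)
    S = Hb @ Kxx @ Hb.T + Rb
    gain = jnp.linalg.solve(S, Hb @ Kxx).T
    mean = mu + gain @ (ys.ravel() - Hb @ mu)
    return mean.reshape(T, N), Kxx - gain @ Hb @ Kxx


A = jnp.array([[1.0, 0.5], [0.0, 0.9]])
H = jnp.array([[1.0, 0.0]])
Q, R = 0.1 * jnp.eye(2), jnp.array([[0.3]])
m0, P0 = jnp.zeros(2), jnp.eye(2)
ys = jnp.array([[0.3], [-0.1], [0.0], [0.8], [0.5]])

mf, Pf, mp, Pp = kalman_filter_dense(A, H, Q, R, ys, m0, P0)
sm_means, sm_covs = rts_smoother_dense(A, mf, Pf, mp, Pp)
ref_means, ref_cov = dense_posterior(A, H, Q, R, ys, m0, P0)
```

The smoother never touches `Q` directly because the predicted covariances stored by the filter already include it, which is consistent with `process_noise` being unused in the listing.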

kalman_gain(P: lx.AbstractLinearOperator, H: lx.AbstractLinearOperator, R: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> Float[Array, 'N M']

Compute Kalman gain K = P @ H^T @ (H @ P @ H^T + R)^{-1}.

Parameters:

Name Type Description Default
P AbstractLinearOperator

Prior covariance operator, shape (N, N).

required
H AbstractLinearOperator

Observation model operator, shape (M, N).

required
R AbstractLinearOperator

Observation noise operator, shape (M, M).

required
solver AbstractSolverStrategy | None

Optional solver strategy. When None, uses structural dispatch.

None
woodbury_innovation bool

When True, route the innovation covariance through gaussx.LowRankUpdate.

False

Returns:

Type Description
Float[Array, 'N M']

Kalman gain matrix of shape (N, M).

Source code in src/gaussx/_ssm/_kalman.py
def kalman_gain(
    P: lx.AbstractLinearOperator,
    H: lx.AbstractLinearOperator,
    R: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
    woodbury_innovation: bool = False,
) -> Float[Array, "N M"]:
    """Compute Kalman gain ``K = P @ H^T @ (H @ P @ H^T + R)^{-1}``.

    Args:
        P: Prior covariance operator, shape ``(N, N)``.
        H: Observation model operator, shape ``(M, N)``.
        R: Observation noise operator, shape ``(M, M)``.
        solver: Optional solver strategy. When ``None``, uses
            structural dispatch.
        woodbury_innovation: When ``True``, route the innovation
            covariance through `gaussx.LowRankUpdate`.

    Returns:
        Kalman gain matrix of shape ``(N, M)``.
    """
    P_mat = _materialise(P)
    H_mat = _materialise(H)

    S_op = _innovation_covariance(H, P, R, woodbury=woodbury_innovation)

    # K = P Hᵀ S⁻¹
    PHt = P_mat @ H_mat.T  # (N, M)
    return solve_rows(S_op, PHt, solver=solver)  # (N, M)
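
The gain is obtained from a linear solve rather than an explicit inverse. The sketch below shows the same idea with dense arrays and a Cholesky factorisation; `kalman_gain_dense` is a hypothetical helper, and reading `solve_rows` as a solve applied to each row of `P Hᵀ` is an assumption based on its name and the shapes in the listing.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def kalman_gain_dense(P, H, R):
    """K = P H^T S^{-1}, computed without forming S^{-1}."""
    S = H @ P @ H.T + R  # innovation covariance, symmetric positive definite
    PHt = P @ H.T  # (N, M)
    # K S = P H^T  is equivalent to  S K^T = H P  because S is symmetric.
    return cho_solve(cho_factor(S, lower=True), PHt.T).T


P = jnp.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 1.5]])
H = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
R = jnp.diag(jnp.array([0.1, 0.4]))

K = kalman_gain_dense(P, H, R)
S = H @ P @ H.T + R
```

With this gain the two common covariance updates agree: `P - K S Kᵀ` (used on the dense path above) and `(I - K H) P` (the form used on the Woodbury path of `kalman_filter`).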

parallel_kalman_filter(transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_model: Float[Array, '*T M N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, obs_noise: Float[Array, '*T M M'] | lx.AbstractLinearOperator, observations: Float[Array, 'T M'], init_mean: Float[Array, ' N'], init_cov: Float[Array, 'N N'], *, mask: Bool[Array, ' T'] | None = None, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False, form: str = 'covariance') -> FilterState

Parallel Kalman filter via jax.lax.associative_scan.

Numerically equivalent to gaussx.kalman_filter but with O(log T) parallel depth on accelerators. Same generalised contract (TI / TV / operator-typed inputs, optional mask, scalar log-likelihood). Empty observation windows (T == 0) return a zero-length FilterState with log_likelihood == 0.
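
The \(O(\log T)\) depth comes from rewriting a sequential recursion as a prefix "sum" under an associative operator. The sketch below illustrates only that pattern, on the affine recursion `x_t = A_t x_{t-1} + b_t`; it is not the filtering combinator used by the library, whose scan elements carry additional terms. The names `combine`, `As`, `bs` and the random test data are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def combine(first, second):
    """Compose two affine maps: apply `first`, then `second`.

    associative_scan calls this on slices with a leading batch axis,
    so both operations are written to broadcast over it.
    """
    A1, b1 = first
    A2, b2 = second
    return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, b1) + b2


T, N = 16, 2
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
As = 0.5 * jax.random.normal(key_a, (T, N, N))
bs = jax.random.normal(key_b, (T, N))
x0 = jnp.array([1.0, -1.0])

# Parallel: prefix-compose the maps, then apply each prefix to x0.
A_cum, b_cum = jax.lax.associative_scan(combine, (As, bs))
xs_parallel = jnp.einsum("tij,j->ti", A_cum, x0) + b_cum


# Sequential reference with lax.scan.
def seq_step(x, ab):
    A, b = ab
    x = A @ x + b
    return x, x


_, xs_sequential = jax.lax.scan(seq_step, x0, (As, bs))
```

The two results agree up to floating-point rounding, since the parallel version associates the matrix products in a different order.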

Parameters:

Name Type Description Default
transition Float[Array, '*T N N'] | AbstractLinearOperator

State transition matrix or operator.

required
obs_model Float[Array, '*T M N'] | AbstractLinearOperator

Observation matrix or operator.

required
process_noise Float[Array, '*T N N'] | AbstractLinearOperator

Process noise covariance or operator.

required
obs_noise Float[Array, '*T M M'] | AbstractLinearOperator

Observation noise covariance or operator.

required
observations Float[Array, 'T M']

Observed data, shape (T, M).

required
init_mean Float[Array, ' N']

Initial state mean, shape (N,).

required
init_cov Float[Array, 'N N']

Initial state covariance, shape (N, N).

required
mask Bool[Array, ' T'] | None

Optional (T,) boolean mask; False runs predict-only and contributes 0 to the log-likelihood. Defaults to all-True.

None
solver AbstractSolverStrategy | None

Accepted for API symmetry with kalman_filter but not currently threaded through the per-element solves; the covariance-form combinator uses unstructured dense solves. The square-root form also uses dense solves for the affine terms.

None
woodbury_innovation bool

When True, delegates to gaussx.kalman_filter with the same flag so structured R uses the Woodbury innovation path.

False
form str

Either "covariance" (default) or "sqrt". The square-root form maintains lower-triangular covariance factors alongside the covariance updates and reconstructs PSD covariance matrices in the returned FilterState. Note: the associative-scan equations themselves still use the covariance form internally; the factor path is a PSD-safety net for ill-conditioned float32 chains rather than a fully factor-propagating combinator (see #165).

'covariance'

Raises:

Type Description
ValueError

If form is not "covariance" or "sqrt".

Returns:

Type Description
FilterState

FilterState with filtered / predicted means and covs and the total log-likelihood.

Source code in src/gaussx/_ssm/_parallel_kalman.py
def parallel_kalman_filter(
    transition: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    obs_model: Float[Array, "*T M N"] | lx.AbstractLinearOperator,
    process_noise: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    obs_noise: Float[Array, "*T M M"] | lx.AbstractLinearOperator,
    observations: Float[Array, "T M"],
    init_mean: Float[Array, " N"],
    init_cov: Float[Array, "N N"],
    *,
    mask: Bool[Array, " T"] | None = None,
    solver: AbstractSolverStrategy | None = None,
    woodbury_innovation: bool = False,
    form: str = "covariance",
) -> FilterState:
    """Parallel Kalman filter via `jax.lax.associative_scan`.

    Numerically equivalent to `gaussx.kalman_filter` but with
    ``O(log T)`` parallel depth on accelerators. Same generalised
    contract (TI / TV / operator-typed inputs, optional mask, scalar
    log-likelihood). Empty observation windows (``T == 0``) return a
    zero-length `FilterState` with ``log_likelihood == 0``.

    Args:
        transition: State transition matrix or operator.
        obs_model: Observation matrix or operator.
        process_noise: Process noise covariance or operator.
        obs_noise: Observation noise covariance or operator.
        observations: Observed data, shape ``(T, M)``.
        init_mean: Initial state mean, shape ``(N,)``.
        init_cov: Initial state covariance, shape ``(N, N)``.
        mask: Optional ``(T,)`` boolean mask; ``False`` runs predict-only
            and contributes 0 to the log-likelihood. Defaults to all-True.
        solver: Accepted for API symmetry with `kalman_filter` but
            not currently threaded through the per-element solves; the
            covariance-form combinator uses unstructured dense solves.
            The square-root form also uses dense solves for the affine
            terms.
        woodbury_innovation: When ``True``, delegates to
            `gaussx.kalman_filter` with the same flag so structured
            ``R`` uses the Woodbury innovation path.
        form: Either ``"covariance"`` (default) or ``"sqrt"``. The
            square-root form maintains lower-triangular covariance
            factors alongside the covariance updates and reconstructs
            PSD covariance matrices in the returned `FilterState`.
            Note: the associative-scan equations themselves still use
            the covariance form internally; the factor path is a
            PSD-safety net for ill-conditioned float32 chains rather
            than a fully factor-propagating combinator (see #165).

    Raises:
        ValueError: If ``form`` is not ``"covariance"`` or ``"sqrt"``.

    Returns:
        `FilterState` with filtered / predicted means and covs
        and the total log-likelihood.
    """
    if form == "sqrt":
        from gaussx._ssm._parallel_kalman_sqrt import parallel_kalman_filter_sqrt

        return parallel_kalman_filter_sqrt(
            transition,
            obs_model,
            process_noise,
            obs_noise,
            observations,
            init_mean,
            init_cov,
            mask=mask,
            solver=solver,
        )
    if form != "covariance":
        raise ValueError("form must be 'covariance' or 'sqrt'.")

    if woodbury_innovation:
        return kalman_filter(
            transition,
            obs_model,
            process_noise,
            obs_noise,
            observations,
            init_mean,
            init_cov,
            mask=mask,
            solver=solver,
            woodbury_innovation=True,
        )

    del solver  # not currently threaded through; see docstring + #165

    M_obs = observations.shape[-1]
    T = observations.shape[0]
    N = init_mean.shape[0]

    # Empty observation window: match kalman_filter's empty-scan output.
    if T == 0:
        return FilterState(
            filtered_means=jnp.zeros((0, N), dtype=init_mean.dtype),
            filtered_covs=jnp.zeros((0, N, N), dtype=init_cov.dtype),
            predicted_means=jnp.zeros((0, N), dtype=init_mean.dtype),
            predicted_covs=jnp.zeros((0, N, N), dtype=init_cov.dtype),
            log_likelihood=jnp.zeros((), dtype=init_mean.dtype),
        )

    A_seq, H_seq, Q_seq, R_seq, mask_seq, _ = _normalise_tv_inputs(
        transition, obs_model, process_noise, obs_noise, T=T, mask=mask
    )

    # Build per-step elements. ``vmap`` of ``lax.cond`` evaluates both
    # branches and selects, so we instead substitute mask-aware safe
    # inputs (H=0, R=I, y=0 for masked steps) into a single active path.
    # With those substitutions the active builder collapses to
    # (F, 0, Q, 0, 0) — exactly the predict-only element — and the
    # Cholesky operates on the well-conditioned identity, so even
    # garbage in masked H / R / y can't NaN the gradient.
    def _build_step(F, H, Q, R, y, m):
        H_eff = jnp.where(m, H, jnp.zeros_like(H))
        R_eff = jnp.where(m, R, jnp.eye(M_obs, dtype=R.dtype))
        y_eff = jnp.where(m, y, jnp.zeros_like(y))
        return _generic_filter_element_active(F, H_eff, Q, R_eff, y_eff)

    elems = jax.vmap(_build_step)(A_seq, H_seq, Q_seq, R_seq, observations, mask_seq)

    # Patch element 0 to absorb the initial prior. Outer ``lax.cond``
    # genuinely skips the inactive branch (no ``vmap`` wrapping here).
    first = jax.lax.cond(
        mask_seq[0],
        lambda: _first_filter_element_active(
            A_seq[0],
            H_seq[0],
            Q_seq[0],
            R_seq[0],
            observations[0],
            init_mean,
            init_cov,
        ),
        lambda: _first_filter_element_masked(
            A_seq[0],
            Q_seq[0],
            init_mean,
            init_cov,
        ),
    )
    elems = tuple(arr.at[0].set(val) for arr, val in zip(elems, first, strict=True))

    # ----- Associative scan -----
    _A_out, b_out, C_out, _eta_out, _J_out = jax.lax.associative_scan(
        _filter_combine, elems
    )
    filtered_means = b_out
    filtered_covs = jax.vmap(_sym)(C_out)

    # Reconstruct predicted means / covs from filtered + transition.
    prev_means = jnp.concatenate([init_mean[None], filtered_means[:-1]], axis=0)
    prev_covs = jnp.concatenate([init_cov[None], filtered_covs[:-1]], axis=0)

    def _predict_step(F, m, P, Q):
        return F @ m, _sym(F @ P @ F.T + Q)

    predicted_means, predicted_covs = jax.vmap(_predict_step)(
        A_seq, prev_means, prev_covs, Q_seq
    )

    # Log-likelihood from innovations. Same safe substitution as the
    # element builder so masked steps don't drive the Cholesky through
    # ill-conditioned user-supplied R / NaN gradients.
    def _ll_contrib(y, m_pred, P_pred, H, R, m):
        H_eff = jnp.where(m, H, jnp.zeros_like(H))
        R_eff = jnp.where(m, R, jnp.eye(M_obs, dtype=R.dtype))
        y_eff = jnp.where(m, y, jnp.zeros_like(y))
        v = y_eff - H_eff @ m_pred
        S = _sym(H_eff @ P_pred @ H_eff.T + R_eff)
        L = jnp.linalg.cholesky(S)
        Sinv_v = jax.scipy.linalg.cho_solve((L, True), v)
        quad = v @ Sinv_v
        logdet = cholesky_logdet(L)
        contrib = -0.5 * (quad + logdet + M_obs * _LOG_2PI)
        return jnp.where(m, contrib, jnp.zeros_like(contrib))

    ll_contribs = jax.vmap(_ll_contrib)(
        observations, predicted_means, predicted_covs, H_seq, R_seq, mask_seq
    )
    log_likelihood = jnp.sum(ll_contribs)

    return FilterState(
        filtered_means=filtered_means,
        filtered_covs=filtered_covs,
        predicted_means=predicted_means,
        predicted_covs=predicted_covs,
        log_likelihood=log_likelihood,
    )
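
The masked-step handling in the listing above uses a substitution rather than a branch: with H = 0, R = I and y = 0 the Kalman gain vanishes, so the update returns the prediction unchanged. The following is a minimal sketch of that property on a plain dense update. `kalman_update` and `masked_update` are hypothetical names for this illustration and do not reproduce the library's `_generic_filter_element_active`.

```python
import jax.numpy as jnp


def kalman_update(m_pred, P_pred, H, R, y):
    """Dense covariance-form measurement update (illustrative only)."""
    S = H @ P_pred @ H.T + R                 # innovation covariance
    HP = H @ P_pred
    K = jnp.linalg.solve(S, HP).T            # K = P H^T S^{-1}, using symmetry of S and P
    m = m_pred + K @ (y - H @ m_pred)
    P = P_pred - K @ HP
    return m, P


def masked_update(m_pred, P_pred, H, R, y, observed):
    """Swap in H=0, R=I, y=0 when `observed` is False, then update as usual."""
    H_eff = jnp.where(observed, H, jnp.zeros_like(H))
    R_eff = jnp.where(observed, R, jnp.eye(R.shape[0], dtype=R.dtype))
    y_eff = jnp.where(observed, y, jnp.zeros_like(y))
    return kalman_update(m_pred, P_pred, H_eff, R_eff, y_eff)


m_pred = jnp.array([1.0, -2.0])
P_pred = jnp.array([[2.0, 0.3], [0.3, 1.0]])

# Masked slot filled with garbage: the substitution keeps the update finite.
H_bad = jnp.full((1, 2), jnp.nan)
R_bad = jnp.full((1, 1), jnp.nan)
y_bad = jnp.full((1,), jnp.nan)
m_masked, P_masked = masked_update(m_pred, P_pred, H_bad, R_bad, y_bad, observed=False)

# Observed slot: an ordinary update.
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[1.0]])
y = jnp.array([2.0])
m_obs, P_obs = masked_update(m_pred, P_pred, H, R, y, observed=True)
```

With `observed=False` the outputs equal `m_pred` and `P_pred`, which is the predict-only behaviour the source comment describes.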

parallel_rts_smoother(filter_state: FilterState, transition: Float[Array, '*T N N'] | lx.AbstractLinearOperator, process_noise: Float[Array, '*T N N'] | lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None, form: str = 'covariance') -> tuple[Float[Array, 'T N'], Float[Array, 'T N N']]

Parallel RTS smoother via reverse jax.lax.associative_scan.

Pairs with parallel_kalman_filter. Numerically equivalent to gaussx.rts_smoother with O(log T) parallel depth.

Parameters:

Name Type Description Default
filter_state FilterState

Output of parallel_kalman_filter or gaussx.kalman_filter.

required
transition Float[Array, '*T N N'] | AbstractLinearOperator

State transition matrix or operator.

required
process_noise Float[Array, '*T N N'] | AbstractLinearOperator

Unused — kept for API symmetry with the sequential smoother.

required
solver AbstractSolverStrategy | None

Accepted for API symmetry; not currently threaded through.

None
form str

Either "covariance" (default) or "sqrt". The square-root form maintains lower-triangular factors alongside the smoother associative scan and returns PSD-reconstructed covariances (see parallel_kalman_filter for the same caveat about the internal combinator).

'covariance'

Raises:

Type Description
ValueError

If form is not "covariance" or "sqrt".

Returns:

Type Description
tuple[Float[Array, 'T N'], Float[Array, 'T N N']]

Tuple (smoothed_means, smoothed_covs).

Source code in src/gaussx/_ssm/_parallel_kalman.py
def parallel_rts_smoother(
    filter_state: FilterState,
    transition: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    process_noise: Float[Array, "*T N N"] | lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
    form: str = "covariance",
) -> tuple[Float[Array, "T N"], Float[Array, "T N N"]]:
    """Parallel RTS smoother via reverse `jax.lax.associative_scan`.

    Pairs with `parallel_kalman_filter`. Numerically equivalent to
    `gaussx.rts_smoother` with ``O(log T)`` parallel depth.

    Args:
        filter_state: Output of `parallel_kalman_filter` or
            `gaussx.kalman_filter`.
        transition: State transition matrix or operator.
        process_noise: Unused — kept for API symmetry with the sequential
            smoother.
        solver: Accepted for API symmetry; not currently threaded
            through.
        form: Either ``"covariance"`` (default) or ``"sqrt"``. The
            square-root form maintains lower-triangular factors
            alongside the smoother associative scan and returns
            PSD-reconstructed covariances (see `parallel_kalman_filter`
            for the same caveat about the internal combinator).

    Raises:
        ValueError: If ``form`` is not ``"covariance"`` or ``"sqrt"``.

    Returns:
        Tuple ``(smoothed_means, smoothed_covs)``.
    """
    if form == "sqrt":
        from gaussx._ssm._parallel_kalman_sqrt import parallel_rts_smoother_sqrt

        return parallel_rts_smoother_sqrt(
            filter_state,
            transition,
            process_noise,
            solver=solver,
        )
    if form != "covariance":
        raise ValueError("form must be 'covariance' or 'sqrt'.")

    del process_noise, solver

    f_means = filter_state.filtered_means
    f_covs = filter_state.filtered_covs
    p_means = filter_state.predicted_means
    p_covs = filter_state.predicted_covs
    T = f_means.shape[0]
    N = f_means.shape[-1]

    if T == 0:
        return (
            jnp.zeros((0, N), dtype=f_means.dtype),
            jnp.zeros((0, N, N), dtype=f_covs.dtype),
        )

    A_dense = _materialise(transition)
    A_op = transition if isinstance(transition, lx.AbstractLinearOperator) else None
    if A_dense.ndim == 2:
        A_seq = jnp.broadcast_to(A_dense, (T, *A_dense.shape))
    elif A_dense.ndim == 3:
        if A_op is not None:
            raise TypeError(
                "Operator-typed transition cannot have a leading time axis."
            )
        A_seq = A_dense
    else:
        raise ValueError(f"transition must have ndim 2 or 3, got {A_dense.ndim}.")

    def _build_inner(f_mean, f_cov, p_mean_next, p_cov_next, A_next):
        # G = f_cov @ A_next.T @ inv(p_cov_next); p_cov_next is symmetric.
        rhs = f_cov @ A_next.T  # (N, N)
        G = jnp.linalg.solve(p_cov_next, rhs.T).T
        E = G
        g = f_mean - G @ p_mean_next
        L = _sym(f_cov - G @ p_cov_next @ G.T)
        return E, g, L

    inner_E, inner_g, inner_L = jax.vmap(_build_inner)(
        f_means[:-1], f_covs[:-1], p_means[1:], p_covs[1:], A_seq[1:]
    )
    last_E = jnp.zeros((1, N, N), dtype=f_means.dtype)
    last_g = f_means[-1:]
    last_L = f_covs[-1:]

    E = jnp.concatenate([inner_E, last_E], axis=0)
    g = jnp.concatenate([inner_g, last_g], axis=0)
    L = jnp.concatenate([inner_L, last_L], axis=0)

    _E_out, smoothed_means, smoothed_covs = jax.lax.associative_scan(
        _smoother_combine, (E, g, L), reverse=True
    )
    smoothed_covs = jax.vmap(_sym)(smoothed_covs)
    return smoothed_means, smoothed_covs
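
The elements `(E, g, L)` built in the listing above encode the backward recursions m_k = g_k + E_k m_{k+1} and P_k = L_k + E_k P_{k+1} E_kᵀ. The sketch below checks, on toy data, that composing such elements with an associative scan reproduces the sequential recursion. The `combine` operator here is an assumption reconstructed from those recursions (the library's `_smoother_combine` is not shown on this page), and the time axis is flipped explicitly instead of passing `reverse=True`, so the argument order is unambiguous.

```python
import jax
import jax.numpy as jnp


@jax.vmap
def combine(later, earlier):
    """Compose two backward affine maps; `later` is the already-accumulated one."""
    E_a, g_a, L_a = later
    E_b, g_b, L_b = earlier
    return E_b @ E_a, E_b @ g_a + g_b, E_b @ L_a @ E_b.T + L_b


def backward_parallel(E, g, L):
    # Flip time so a forward scan walks from the last step to the first.
    flipped = tuple(x[::-1] for x in (E, g, L))
    _, means, covs = jax.lax.associative_scan(combine, flipped)
    return means[::-1], covs[::-1]


def backward_sequential(E, g, L):
    # Reference loop: m_k = g_k + E_k m_{k+1},  P_k = L_k + E_k P_{k+1} E_k^T.
    means, covs = [g[-1]], [L[-1]]
    for k in range(E.shape[0] - 2, -1, -1):
        means.append(g[k] + E[k] @ means[-1])
        covs.append(L[k] + E[k] @ covs[-1] @ E[k].T)
    return jnp.stack(means[::-1]), jnp.stack(covs[::-1])


# Toy elements: T = 5 steps, N = 2 states; the last E is zero, as in the listing.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
E = (0.5 * jax.random.normal(k1, (5, 2, 2))).at[-1].set(0.0)
g = jax.random.normal(k2, (5, 2))
B = jax.random.normal(k3, (5, 2, 2))
L = B @ jnp.swapaxes(B, -1, -2)              # symmetric PSD blocks

m_par, P_par = backward_parallel(E, g, L)
m_seq, P_seq = backward_sequential(E, g, L)
```

The two results agree up to floating-point error; only the scan version has logarithmic parallel depth.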

infinite_horizon_filter(transition: Float[Array, 'N N'] | lx.AbstractLinearOperator, obs_model: Float[Array, 'M N'] | lx.AbstractLinearOperator, process_noise: Float[Array, 'N N'] | lx.AbstractLinearOperator, obs_noise: Float[Array, 'M M'] | lx.AbstractLinearOperator, observations: Float[Array, 'T M'], init_mean: Float[Array, ' N'] | None = None, *, dare_result: DAREResult | None = None, max_iter: int = 100, tol: float = 1e-08, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> InfiniteHorizonState

Infinite-horizon Kalman filter with fixed steady-state gain.

Uses the DARE solution for a constant Kalman gain K∞, avoiding per-step Riccati updates. For dense matrices, the per-step cost is O(N² + MN + M²) rather than the O(N³) of the standard Kalman filter. Each step reduces to a mean-only recursion:

Predict:  x⁻ₜ = A xₜ₋₁
Update:   vₜ  = yₜ − H x⁻ₜ
          xₜ  = x⁻ₜ + K∞ vₜ

All four operator/array arguments accept either a raw JAX array or a lineax.AbstractLinearOperator. Operator inputs keep their structured matvec inside the per-step scan; the sandwich products (such as A P∞ Aᵀ) are materialised once, outside the scan.
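
As a rough illustration of the fixed-gain recursion, the sketch below runs the predict/update equations with `jax.lax.scan` on dense arrays. `fixed_gain_means` is a hypothetical helper, and the gain `K_inf` is picked by hand for the example rather than obtained from a DARE solve.

```python
import jax
import jax.numpy as jnp


def fixed_gain_means(A, H, K_inf, ys, x0):
    """Mean-only recursion with a constant gain; no covariance is propagated."""

    def step(x_filt, y_t):
        x_pred = A @ x_filt              # predict
        v = y_t - H @ x_pred             # innovation
        x_new = x_pred + K_inf @ v       # update with the fixed gain
        return x_new, (x_new, x_pred)

    _, (filtered, predicted) = jax.lax.scan(step, x0, ys)
    return filtered, predicted


A = jnp.array([[0.9]])
H = jnp.array([[1.0]])
K_inf = jnp.array([[0.5]])               # assumed gain, chosen by hand
ys = jnp.array([[1.0], [1.0]])

filtered, predicted = fixed_gain_means(A, H, K_inf, ys, x0=jnp.zeros(1))
```

Because the gain and covariances are constant, each step only needs matrix-vector products.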

Parameters:

Name Type Description Default
transition Float[Array, 'N N'] | AbstractLinearOperator

State transition matrix or operator, shape (N, N).

required
obs_model Float[Array, 'M N'] | AbstractLinearOperator

Observation matrix or operator, shape (M, N).

required
process_noise Float[Array, 'N N'] | AbstractLinearOperator

Process noise covariance or operator, shape (N, N).

required
obs_noise Float[Array, 'M M'] | AbstractLinearOperator

Observation noise covariance or operator, shape (M, M).

required
observations Float[Array, 'T M']

Observed data y, shape (T, M).

required
init_mean Float[Array, ' N'] | None

Initial state mean, shape (N,). Defaults to zeros.

None
dare_result DAREResult | None

Precomputed DARE result. If None, calls dare() internally.

None
max_iter int

Maximum DARE iterations (used only if dare_result is None).

100
tol float

DARE convergence tolerance (used only if dare_result is None).

1e-08
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None
woodbury_innovation bool

When True, build the steady-state innovation covariance as gaussx.LowRankUpdate so structured R can use Woodbury solves/log-determinants.

False

Returns:

Type Description
InfiniteHorizonState

An InfiniteHorizonState with filtered/predicted means, covariances, and total log-likelihood.

Source code in src/gaussx/_ssm/_infinite_horizon_kalman.py
def infinite_horizon_filter(
    transition: Float[Array, "N N"] | lx.AbstractLinearOperator,
    obs_model: Float[Array, "M N"] | lx.AbstractLinearOperator,
    process_noise: Float[Array, "N N"] | lx.AbstractLinearOperator,
    obs_noise: Float[Array, "M M"] | lx.AbstractLinearOperator,
    observations: Float[Array, "T M"],
    init_mean: Float[Array, " N"] | None = None,
    *,
    dare_result: DAREResult | None = None,
    max_iter: int = 100,
    tol: float = 1e-8,
    solver: AbstractSolverStrategy | None = None,
    woodbury_innovation: bool = False,
) -> InfiniteHorizonState:
    """Infinite-horizon Kalman filter with fixed steady-state gain.

    Uses the DARE solution for a constant Kalman gain K∞, avoiding
    per-step Riccati updates.  For dense matrices, the per-step cost is
    O(N² + MN + M²) instead of O(N³) for the standard Kalman filter:

        Predict:  x⁻ₜ = A xₜ₋₁
        Update:   vₜ  = yₜ − H x⁻ₜ
                  xₜ  = x⁻ₜ + K∞ vₜ

    All four operator/array arguments accept either a raw JAX array or
    a `lineax.AbstractLinearOperator`. Operator inputs preserve
    their structural matvec inside the per-step scan; the sandwiches
    materialise once outside the scan.

    Args:
        transition: State transition matrix or operator, shape ``(N, N)``.
        obs_model: Observation matrix or operator, shape ``(M, N)``.
        process_noise: Process noise covariance or operator, shape ``(N, N)``.
        obs_noise: Observation noise covariance or operator, shape ``(M, M)``.
        observations: Observed data y, shape ``(T, M)``.
        init_mean: Initial state mean, shape ``(N,)``. Defaults to zeros.
        dare_result: Precomputed DARE result. If ``None``, calls
            ``dare()`` internally.
        max_iter: Maximum DARE iterations (used only if ``dare_result``
            is ``None``).
        tol: DARE convergence tolerance (used only if ``dare_result``
            is ``None``).
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.
        woodbury_innovation: When ``True``, build the steady-state
            innovation covariance as `gaussx.LowRankUpdate` so
            structured ``R`` can use Woodbury solves/log-determinants.

    Returns:
        An ``InfiniteHorizonState`` with filtered/predicted means,
        covariances, and total log-likelihood.
    """
    if dare_result is None:
        dare_result = dare(
            transition,
            obs_model,
            process_noise,
            obs_noise,
            max_iter=max_iter,
            tol=tol,
            solver=solver,
            woodbury_innovation=woodbury_innovation,
        )

    A_op = _as_operator(transition)
    H_op = _as_operator(obs_model)
    Q_dense = _materialise(process_noise)
    # Keep ``R`` lazy when the Woodbury innovation path will consume the
    # operator directly — avoids an O(M²) allocation for large structured
    # noise (e.g. ``DiagonalLinearOperator`` with large ``M``).
    R_for_innovation = (
        obs_noise
        if woodbury_innovation and isinstance(obs_noise, lx.AbstractLinearOperator)
        else _materialise(obs_noise)
    )

    P_inf = dare_result.P_inf  # (N, N)
    K_inf = dare_result.K_inf  # (N, M)
    T = observations.shape[0]
    M = observations.shape[-1]
    N = A_op.out_size()

    # Precompute steady-state quantities
    P_inf_op = lx.MatrixLinearOperator(P_inf, lx.positive_semidefinite_tag)
    P_pred_inf = sandwich(A_op, P_inf_op).as_matrix() + Q_dense  # (N, N)
    S_inf = _innovation_covariance(
        H_op, P_pred_inf, R_for_innovation, woodbury=woodbury_innovation
    )
    ld_inf = dispatch_logdet(S_inf, solver)  # scalar

    # Steady-state filtered covariance: P_filt = (I − K∞ H) P⁻pred
    HP_pred_inf = _left_matmul(H_op, P_pred_inf)
    P_filt_inf = P_pred_inf - K_inf @ HP_pred_inf  # (N, N)

    def step(carry, y_t):
        x_filt, ll = carry

        x_pred = _matvec(transition, x_filt)  # (N,)
        v = y_t - _matvec(obs_model, x_pred)  # (M,)  innovation
        x_filt_new = x_pred + K_inf @ v  # (N,)

        # Log-likelihood increment.
        Sinv_v = dispatch_solve(S_inf, v, solver)  # (M,)
        ll_inc = -0.5 * (v @ Sinv_v + ld_inf + M * _LOG_2PI)

        return (x_filt_new, ll + ll_inc), (x_filt_new, x_pred)

    if init_mean is None:
        init_mean = jnp.zeros(N)
    init_carry = (init_mean, jnp.array(0.0))
    (_, total_ll), (f_means, p_means) = jax.lax.scan(
        step,
        init_carry,
        observations,
    )

    # Broadcast constant covariances to (T, N, N)
    f_covs = repeat(P_filt_inf, "n1 n2 -> T n1 n2", T=T)
    p_covs = repeat(P_pred_inf, "n1 n2 -> T n1 n2", T=T)

    return InfiniteHorizonState(
        filtered_means=f_means,
        filtered_covs=f_covs,
        predicted_means=p_means,
        predicted_covs=p_covs,
        log_likelihood=total_ll,
    )

infinite_horizon_smoother(filter_state: InfiniteHorizonState, transition: Float[Array, 'N N'] | lx.AbstractLinearOperator, dare_result: DAREResult, process_noise: Float[Array, 'N N'] | lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'T N'], Float[Array, 'T N N']]

Infinite-horizon RTS smoother with fixed steady-state gain.

Precomputes the steady-state smoother gain G∞ = P∞ Aᵀ (P⁻pred)⁻¹, where P⁻pred = A P∞ Aᵀ + Q, then runs a backward scan with fixed G∞. The steady-state smoothed covariance is the solution of the discrete Lyapunov equation:

P_smooth = P∞ + G∞ (P_smooth − P⁻pred) G∞ᵀ
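
To make the fixed-point equation concrete, here is a small dense sketch that rearranges it to P − G P Gᵀ = rhs and solves by vectorisation. This is only practical for small N: per the comments in the source below, the library routes this through `discrete_lyapunov_solve` to avoid forming the (N², N²) Kronecker matrix. `lyapunov_dense` and the toy values are assumptions for illustration.

```python
import jax.numpy as jnp


def lyapunov_dense(G, rhs):
    """Solve P - G P G^T = rhs by vectorisation (small N only)."""
    n = G.shape[0]
    # Row-major vec: vec(G P G^T) = (G kron G) vec(P).
    lhs = jnp.eye(n * n) - jnp.kron(G, G)
    return jnp.linalg.solve(lhs, rhs.reshape(-1)).reshape(n, n)


G = jnp.array([[0.5, 0.1], [0.0, 0.3]])      # toy gain with spectral radius < 1
rhs = jnp.array([[1.0, 0.2], [0.2, 2.0]])    # stands in for P_inf - G P_pred G^T

P_smooth = lyapunov_dense(G, rhs)
residual = P_smooth - (rhs + G @ P_smooth @ G.T)
```

The residual of the fixed-point equation should vanish up to floating-point error, and `P_smooth` comes out symmetric because `rhs` is.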

Parameters:

Name Type Description Default
filter_state InfiniteHorizonState

Output of infinite_horizon_filter.

required
transition Float[Array, 'N N'] | AbstractLinearOperator

State transition matrix or operator, shape (N, N).

required
dare_result DAREResult

DARE result used in the filter.

required
process_noise Float[Array, 'N N'] | AbstractLinearOperator

Process noise covariance or operator, shape (N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, 'T N'], Float[Array, 'T N N']]

Tuple (smoothed_means, smoothed_covs) with shapes (T, N) and (T, N, N).

Source code in src/gaussx/_ssm/_infinite_horizon_kalman.py
def infinite_horizon_smoother(
    filter_state: InfiniteHorizonState,
    transition: Float[Array, "N N"] | lx.AbstractLinearOperator,
    dare_result: DAREResult,
    process_noise: Float[Array, "N N"] | lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, "T N"], Float[Array, "T N N"]]:
    """Infinite-horizon RTS smoother with fixed steady-state gain.

    Precomputes the steady-state smoother gain G∞ = P∞ Aᵀ P⁻pred⁻¹,
    then runs a backward scan with fixed G∞.  The steady-state smoothed
    covariance is the solution of the discrete Lyapunov equation:

        P_smooth = P∞ + G∞ (P_smooth − P⁻pred) G∞ᵀ

    Args:
        filter_state: Output of ``infinite_horizon_filter``.
        transition: State transition matrix or operator, shape ``(N, N)``.
        dare_result: DARE result used in the filter.
        process_noise: Process noise covariance or operator, shape ``(N, N)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(smoothed_means, smoothed_covs)`` with shapes
        ``(T, N)`` and ``(T, N, N)``.
    """
    A_op = _as_operator(transition)
    Q_dense = _materialise(process_noise)
    P_inf = dare_result.P_inf  # (N, N)
    P_inf_op = lx.MatrixLinearOperator(P_inf, lx.positive_semidefinite_tag)
    P_pred_inf = sandwich(A_op, P_inf_op).as_matrix() + Q_dense  # (N, N)

    # Steady-state smoother gain: G∞ = P∞ Aᵀ P⁻pred⁻¹
    P_pred_inf_op = lx.MatrixLinearOperator(P_pred_inf, lx.positive_semidefinite_tag)
    G_inf = solve_rows(
        P_pred_inf_op,
        _right_matmul_transpose(P_inf, A_op),
        solver=solver,
    )  # (N, N)

    # Solve discrete Lyapunov equation:
    # P_smooth = P∞ + G∞ (P_smooth − P⁻pred) G∞ᵀ
    # ⟺ P_smooth − G∞ P_smooth G∞ᵀ = P∞ − G∞ P⁻pred G∞ᵀ
    # Routed through `discrete_lyapunov_solve` which uses a
    # per-factor eigendecomposition of ``G∞`` instead of materializing
    # the ``(N², N²)`` Kronecker matrix ``I − G∞ ⊗ G∞``.
    rhs = P_inf - G_inf @ P_pred_inf @ G_inf.T  # (N, N)
    P_smooth_inf = discrete_lyapunov_solve(G_inf, rhs)
    P_smooth_inf = symmetrize(P_smooth_inf)

    T = filter_state.filtered_means.shape[0]

    def step(carry, inputs):
        x_smooth = carry
        x_filt, x_pred = inputs
        x_smooth_new = x_filt + G_inf @ (x_smooth - x_pred)  # (N,)
        return x_smooth_new, x_smooth_new

    init = filter_state.filtered_means[T - 1]
    inputs = (
        filter_state.filtered_means[:-1][::-1],
        filter_state.predicted_means[1:][::-1],
    )

    _, s_means_rev = jax.lax.scan(step, init, inputs)

    s_means = jnp.concatenate(
        [s_means_rev[::-1], filter_state.filtered_means[T - 1 :]],
        axis=0,
    )  # (T, N)
    s_covs = repeat(P_smooth_inf, "n1 n2 -> T n1 n2", T=T)  # (T, N, N)

    return s_means, s_covs

dare(A: Float[Array, 'D D'] | lx.AbstractLinearOperator, H: Float[Array, 'M D'] | lx.AbstractLinearOperator, Q: Float[Array, 'D D'] | lx.AbstractLinearOperator, R: Float[Array, 'M M'] | lx.AbstractLinearOperator, *, P_init: Float[Array, 'D D'] | None = None, max_iter: int = 100, tol: float = 1e-08, solver: AbstractSolverStrategy | None = None, woodbury_innovation: bool = False) -> DAREResult

Discrete Algebraic Riccati Equation solver.

Iterates the Kalman predict-update equations until convergence:

Predict:  P⁻ = A P Aᵀ + Q
Update:   S = H P⁻ Hᵀ + R
          K = P⁻ Hᵀ S⁻¹
          P = (I - KH) P⁻

Convergence is declared when max|P_new - P_old| < tol.
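
The iteration can be sketched on dense arrays as follows. This is a simplified stand-in under stated assumptions (dense inputs, no operator dispatch, hypothetical names `riccati_step` and `dare_fixed_point`), with a looser tolerance than the documented default so that it also terminates in single precision.

```python
import jax
import jax.numpy as jnp


def riccati_step(P, A, H, Q, R):
    """One predict-update step; returns the new filtered covariance and the gain."""
    P_pred = A @ P @ A.T + Q
    S = H @ P_pred @ H.T + R
    HP = H @ P_pred
    K = jnp.linalg.solve(S, HP).T            # K = P_pred H^T S^{-1}
    return P_pred - K @ HP, K


def dare_fixed_point(A, H, Q, R, max_iter=100, tol=1e-6):
    def cond(state):
        _, i, converged = state
        return (i < max_iter) & (~converged)

    def body(state):
        P_old, i, _ = state
        P_new, _ = riccati_step(P_old, A, H, Q, R)
        return P_new, i + 1, jnp.max(jnp.abs(P_new - P_old)) < tol

    # Start from P = Q, mirroring the documented default for P_init.
    P, _, converged = jax.lax.while_loop(cond, body, (Q, 0, jnp.array(False)))
    _, K = riccati_step(P, A, H, Q, R)       # gain evaluated at the converged P
    return P, K, converged


one = jnp.array([[1.0]])
P_inf, K_inf, converged = dare_fixed_point(one, one, one, one)
```

For this scalar random walk (A = H = Q = R = 1) the fixed point satisfies P² + P − 1 = 0, so both `P_inf` and `K_inf` approach (√5 − 1)/2.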

Parameters:

Name Type Description Default
A Float[Array, 'D D'] | AbstractLinearOperator

Transition matrix or operator, shape (D, D).

required
H Float[Array, 'M D'] | AbstractLinearOperator

Observation matrix or operator, shape (M, D).

required
Q Float[Array, 'D D'] | AbstractLinearOperator

Process noise covariance or operator, shape (D, D).

required
R Float[Array, 'M M'] | AbstractLinearOperator

Observation noise covariance or operator, shape (M, M).

required
P_init Float[Array, 'D D'] | None

Initial covariance guess, shape (D, D). Defaults to Q.

None
max_iter int

Maximum number of iterations.

100
tol float

Convergence tolerance on the element-wise max absolute change.

1e-08
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None
woodbury_innovation bool

When True, build S = H P⁻ Hᵀ + R as a gaussx.LowRankUpdate so structured R uses Woodbury solves.

False

Returns:

Type Description
DAREResult

A DAREResult containing the steady-state covariance, Kalman gain, and convergence flag.

Source code in src/gaussx/_ssm/_dare.py
def dare(
    A: Float[Array, "D D"] | lx.AbstractLinearOperator,
    H: Float[Array, "M D"] | lx.AbstractLinearOperator,
    Q: Float[Array, "D D"] | lx.AbstractLinearOperator,
    R: Float[Array, "M M"] | lx.AbstractLinearOperator,
    *,
    P_init: Float[Array, "D D"] | None = None,
    max_iter: int = 100,
    tol: float = 1e-8,
    solver: AbstractSolverStrategy | None = None,
    woodbury_innovation: bool = False,
) -> DAREResult:
    """Discrete Algebraic Riccati Equation solver.

    Iterates the Kalman predict-update equations until convergence:

        Predict:  P⁻ = A P Aᵀ + Q
        Update:   S = H P⁻ Hᵀ + R
                  K = P⁻ Hᵀ S⁻¹
                  P = (I - KH) P⁻

    Convergence is declared when ``max|P_new - P_old| < tol``.

    Args:
        A: Transition matrix or operator, shape ``(D, D)``.
        H: Observation matrix or operator, shape ``(M, D)``.
        Q: Process noise covariance or operator, shape ``(D, D)``.
        R: Observation noise covariance or operator, shape ``(M, M)``.
        P_init: Initial covariance guess, shape ``(D, D)``. Defaults to ``Q``.
        max_iter: Maximum number of iterations.
        tol: Convergence tolerance on the element-wise max absolute change.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.
        woodbury_innovation: When ``True``, build ``S = H P⁻ Hᵀ + R``
            as a `gaussx.LowRankUpdate` so structured ``R`` uses
            Woodbury solves.

    Returns:
        A `DAREResult` containing the steady-state covariance,
        Kalman gain, and convergence flag.
    """
    A_op = _as_operator(A)
    H_op = _as_operator(H)
    Q_dense = _materialise(Q)
    # Keep ``R`` lazy when the Woodbury innovation path will consume the
    # operator directly — avoids an O(M²) allocation for large structured
    # noise (e.g. ``DiagonalLinearOperator`` with large ``M``).
    R_for_innovation = (
        R
        if woodbury_innovation and isinstance(R, lx.AbstractLinearOperator)
        else _materialise(R)
    )

    if P_init is None:
        P_init = Q_dense

    def _step(
        P: Float[Array, "D D"],
    ) -> tuple[Float[Array, "D D"], Float[Array, "D M"]]:
        """One predict-update step. Returns ``(P_new, K)``."""
        P_op = lx.MatrixLinearOperator(P, lx.positive_semidefinite_tag)
        P_pred = sandwich(A_op, P_op).as_matrix() + Q_dense
        # K = P_pred @ H.T @ S⁻¹, computed via a single factorization
        # on the matrix RHS for numerical stability and efficiency.
        S_op = _innovation_covariance(
            H_op, P_pred, R_for_innovation, woodbury=woodbury_innovation
        )
        HP_pred = _left_matmul(H_op, P_pred)
        K = solve_matrix(S_op, HP_pred, solver=solver).T
        P_new = P_pred - K @ HP_pred
        return P_new, K

    def _cond(
        state: tuple[Float[Array, "D D"], int, Bool[Array, ""]],
    ) -> Bool[Array, ""]:
        _, i, converged = state
        return (i < max_iter) & (~converged)

    def _body(
        state: tuple[Float[Array, "D D"], int, Bool[Array, ""]],
    ) -> tuple[Float[Array, "D D"], int, Bool[Array, ""]]:
        P_old, i, _ = state
        P_new, _ = _step(P_old)
        converged = jnp.max(jnp.abs(P_new - P_old)) < tol
        return P_new, i + 1, converged

    init_state = (P_init, 0, jnp.array(False))
    P_inf, _, converged = jax.lax.while_loop(_cond, _body, init_state)

    # Compute the final gain from the converged covariance.
    _, K_inf = _step(P_inf)

    return DAREResult(P_inf=P_inf, K_inf=K_inf, converged=converged)

pairwise_marginals(means: Float[Array, 'T d'], covariances: Float[Array, 'T d d'], cross_covariances: Float[Array, 'Tm1 d d']) -> tuple[Float[Array, 'Tm1 two_d'], Float[Array, 'Tm1 two_d two_d']]

Joint p(x_k, x_{k+1}) for each consecutive pair.

For each pair (k, k+1), the joint distribution is:

p(x_k, x_{k+1}) = N([mu_k; mu_{k+1}],
                     [[P_k,      C_k^T],
                      [C_k,      P_{k+1}]])

where C_k = Cov[x_{k+1}, x_k] is the pairwise cross-covariance.
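
A toy instance of this block layout, written as a standalone sketch (`joint_pair` and the numbers are made up for illustration), makes the shapes and the placement of C_k explicit:

```python
import jax
import jax.numpy as jnp


def joint_pair(m_k, m_kp1, P_k, P_kp1, C_k):
    """Stack one consecutive pair; C_k = Cov[x_{k+1}, x_k] is the lower-left block."""
    mean = jnp.concatenate([m_k, m_kp1])
    cov = jnp.block([[P_k, C_k.T], [C_k, P_kp1]])
    return mean, cov


T, d = 4, 2
means = jnp.arange(T * d, dtype=jnp.float32).reshape(T, d)
covs = jnp.tile(2.0 * jnp.eye(d), (T, 1, 1))
cross = jnp.tile(jnp.array([[0.5, 0.1], [0.0, 0.5]]), (T - 1, 1, 1))

# One joint per consecutive pair: T - 1 of them.
joint_means, joint_covs = jax.vmap(joint_pair)(
    means[:-1], means[1:], covs[:-1], covs[1:], cross
)
```

Each joint covariance is symmetric by construction, and its diagonal blocks are the two marginal covariances.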

Parameters:

Name Type Description Default
means Float[Array, 'T d']

Smoothed means, shape (T, d).

required
covariances Float[Array, 'T d d']

Smoothed covariances, shape (T, d, d).

required
cross_covariances Float[Array, 'Tm1 d d']

Pairwise cross-covariances Cov[x_{k+1}, x_k], shape (T-1, d, d).

required

Returns:

Type Description
tuple[Float[Array, 'Tm1 two_d'], Float[Array, 'Tm1 two_d two_d']]

Tuple (joint_means, joint_covariances) where:

  • joint_means: shape (T-1, 2*d)
  • joint_covariances: shape (T-1, 2*d, 2*d)

Source code in src/gaussx/_ssm/_pairwise_marginals.py
def pairwise_marginals(
    means: Float[Array, "T d"],
    covariances: Float[Array, "T d d"],
    cross_covariances: Float[Array, "Tm1 d d"],
) -> tuple[Float[Array, "Tm1 two_d"], Float[Array, "Tm1 two_d two_d"]]:
    r"""Joint p(x_k, x_{k+1}) for each consecutive pair.

    For each pair ``(k, k+1)``, the joint distribution is:

        p(x_k, x_{k+1}) = N([mu_k; mu_{k+1}],
                             [[P_k,      C_k^T],
                              [C_k,      P_{k+1}]])

    where ``C_k = Cov[x_{k+1}, x_k]`` is the pairwise cross-covariance.

    Args:
        means: Smoothed means, shape ``(T, d)``.
        covariances: Smoothed covariances, shape ``(T, d, d)``.
        cross_covariances: Pairwise cross-covariances
            ``Cov[x_{k+1}, x_k]``, shape ``(T-1, d, d)``.

    Returns:
        Tuple ``(joint_means, joint_covariances)`` where:

        - ``joint_means``: shape ``(T-1, 2*d)``
        - ``joint_covariances``: shape ``(T-1, 2*d, 2*d)``
    """

    def _single_pair(
        m_k: Float[Array, " d"],
        m_kp1: Float[Array, " d"],
        P_k: Float[Array, "d d"],
        P_kp1: Float[Array, "d d"],
        C_k: Float[Array, "d d"],
    ) -> tuple[Float[Array, " two_d"], Float[Array, "two_d two_d"]]:
        joint_mean = jnp.concatenate([m_k, m_kp1])
        joint_cov = jnp.block(
            [
                [P_k, C_k.T],
                [C_k, P_kp1],
            ]
        )
        return joint_mean, joint_cov

    joint_means, joint_covariances = jax.vmap(_single_pair)(
        means[:-1],
        means[1:],
        covariances[:-1],
        covariances[1:],
        cross_covariances,
    )

    return joint_means, joint_covariances

SpInGP

State-space (sparse-in-time) GP inference: marginal likelihood and posterior through the SSM representation.

spingp_log_likelihood(prior_precision: BlockTriDiag, emission_model: Array, obs_noise: lx.AbstractLinearOperator, observations: Float[Array, 'N d_obs'], *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']

Log marginal likelihood via sparse inverse GP formulation.

Computes the log marginal likelihood using the precision-form Kalman filter (SpInGP):

1. Likelihood precision sites: $\Lambda_{lik} = H^T R^{-1} H$
2. Posterior precision: $\Lambda_{post} = \Lambda_{prior} + \Lambda_{lik}$
3. log p(y) via banded Cholesky logdet and quadratic form

The full expression is:

\[
\log p(y) = -\tfrac{1}{2}\Bigl( N_{obs}\log(2\pi) + \log|R|_{total} + y^T R^{-1} y - \eta^T \Lambda_{post}^{-1}\eta + \log|\Lambda_{post}| - \log|\Lambda_{prior}| \Bigr)
\]

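where \(\eta = H^T R^{-1} y\), \(N_{obs}\) is the total number of scalar observations, and \(\log|R|_{total} = N \log|R|\) accumulates the noise log-determinant over the N time steps.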

All operations exploit banded structure for O(Nd³) cost.

The solver parameter controls the algorithm used for the large-scale posterior precision operations (solve, logdet). Observation noise operations always use structural dispatch since obs_noise is typically a small dense matrix.

Parameters:

Name Type Description Default
prior_precision BlockTriDiag

Prior precision as BlockTriDiag, shape (N, d, d) diagonal and (N-1, d, d) sub-diagonal.

required
emission_model Array

Emission matrix H. Shape (d_obs, d) for shared or (N, d_obs, d) per time step.

required
obs_noise AbstractLinearOperator

Observation noise covariance R operator.

required
observations Float[Array, 'N d_obs']

Observations y, shape (N, d_obs).

required
solver AbstractSolverStrategy | None

Optional solver strategy for posterior precision operations. When None, uses structural dispatch. Observation noise operations always use structural dispatch.

None

Returns:

Type Description
Float[Array, '']

Scalar log marginal likelihood.

Source code in src/gaussx/_ssm/_spingp.py
def spingp_log_likelihood(
    prior_precision: BlockTriDiag,
    emission_model: Array,
    obs_noise: lx.AbstractLinearOperator,
    observations: Float[Array, "N d_obs"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> Float[Array, ""]:
    r"""Log marginal likelihood via sparse inverse GP formulation.

    Computes the log marginal likelihood using the precision-form
    Kalman filter (SpInGP):

        1. Likelihood precision sites: $\Lambda_{lik} = H^T R^{-1} H$
        2. Posterior precision: $\Lambda_{post} = \Lambda_{prior} + \Lambda_{lik}$
        3. log p(y) via banded Cholesky logdet and quadratic form

    The full expression is:

        log p(y) = -0.5 * (N_{obs} * log(2\pi) + log|R|_{total}
                   + y^T R^{-1} y - \eta^T \Lambda_{post}^{-1} \eta
                   + log|\Lambda_{post}| - log|\Lambda_{prior}|)

    where $\eta = H^T R^{-1} y$.

    All operations exploit banded structure for O(Nd³) cost.

    The ``solver`` parameter controls the algorithm used for the
    large-scale posterior precision operations (solve, logdet).
    Observation noise operations always use structural dispatch
    since ``obs_noise`` is typically a small dense matrix.

    Args:
        prior_precision: Prior precision as ``BlockTriDiag``,
            shape ``(N, d, d)`` diagonal and ``(N-1, d, d)`` sub-diagonal.
        emission_model: Emission matrix H. Shape ``(d_obs, d)`` for
            shared or ``(N, d_obs, d)`` per time step.
        obs_noise: Observation noise covariance R operator.
        observations: Observations y, shape ``(N, d_obs)``.
        solver: Optional solver strategy for posterior precision
            operations. When ``None``, uses structural dispatch.
            Observation noise operations always use structural dispatch.

    Returns:
        Scalar log marginal likelihood.
    """
    N = prior_precision._num_blocks
    d = prior_precision._block_size
    N_obs = observations.size
    log_2pi = jnp.log(2.0 * jnp.pi)

    # Build likelihood precision and posterior precision
    lik_prec = _build_likelihood_precision(emission_model, obs_noise, N, d)
    post_prec = prior_precision.add(lik_prec)

    # Data vector: eta = H^T R^{-1} y
    eta = _build_data_vector(emission_model, obs_noise, observations)

    # Quadratic term: eta^T Lambda_post^{-1} eta
    post_solve = dispatch_solve(post_prec, eta, solver)
    quad_term = jnp.dot(eta, post_solve)

    # Observation quadratic: y^T R^{-1} y (obs_noise is small, use inv)
    R_inv = inv(obs_noise).as_matrix()
    obs_quad = jnp.sum(jax.vmap(lambda y_k: y_k @ R_inv @ y_k)(observations))

    # Log determinants (posterior precision: may be large, use solver)
    ld_post = dispatch_logdet(post_prec, solver)
    ld_prior = dispatch_logdet(prior_precision, solver)

    # Total observation noise logdet: N * log|R| (small, structural dispatch)
    ld_R = logdet(obs_noise)
    ld_R_total = N * ld_R

    return -0.5 * (
        N_obs * log_2pi + ld_R_total + obs_quad - quad_term + ld_post - ld_prior
    )
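
The expression returned above can be cross-checked against the direct density \(\log N(y \mid 0, H \Lambda_{prior}^{-1} H^T + R)\) on a small dense problem. The sketch below is a NumPy illustration under stated assumptions, not the banded implementation: both function names are hypothetical, and `H` and `R` are the stacked (all time steps) matrices, so `log|R|` already plays the role of \(\log|R|_{total}\).

```python
import numpy as np


def precision_form_log_lik(prior_prec, H, R, y):
    """Dense version of the precision-form expression (illustrative only)."""
    R_inv = np.linalg.inv(R)
    eta = H.T @ R_inv @ y
    post_prec = prior_prec + H.T @ R_inv @ H
    quad = eta @ np.linalg.solve(post_prec, eta)
    ld_post = np.linalg.slogdet(post_prec)[1]
    ld_prior = np.linalg.slogdet(prior_prec)[1]
    ld_R = np.linalg.slogdet(R)[1]
    return -0.5 * (
        y.size * np.log(2.0 * np.pi) + ld_R + y @ R_inv @ y - quad + ld_post - ld_prior
    )


def covariance_form_log_lik(prior_prec, H, R, y):
    """Direct log N(y | 0, H P H^T + R) with P = prior_prec^{-1}."""
    S = H @ np.linalg.inv(prior_prec) @ H.T + R
    ld_S = np.linalg.slogdet(S)[1]
    return -0.5 * (y.size * np.log(2.0 * np.pi) + ld_S + y @ np.linalg.solve(S, y))


# N = 3 time steps, state dimension d = 2, scalar observations (d_obs = 1).
chain = 2.0 * np.eye(3) - 0.3 * (np.eye(3, k=1) + np.eye(3, k=-1))
prior_prec = np.kron(chain, np.eye(2))          # block-tridiagonal and SPD
H = np.kron(np.eye(3), np.array([[1.0, 0.0]]))  # observe the first state component
R = 0.25 * np.eye(3)                            # stacked noise covariance
y = np.array([0.5, -1.0, 2.0])

lp_precision = precision_form_log_lik(prior_prec, H, R, y)
lp_covariance = covariance_form_log_lik(prior_prec, H, R, y)
```

The two values agree by the matrix determinant lemma and the Woodbury identity; the precision form is the one that keeps the block-tridiagonal sparsity.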

spingp_posterior(prior_precision: BlockTriDiag, emission_model: Array, obs_noise: lx.AbstractLinearOperator, observations: Float[Array, 'N d_obs'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' Nd'], BlockTriDiag]

Posterior mean and precision via SpInGP.

Computes the posterior by adding likelihood precision sites to the prior precision and solving for the posterior mean:

\[
\begin{aligned}
\Lambda_{post} &= \Lambda_{prior} + H^T R^{-1} H \\
\mu_{post} &= \Lambda_{post}^{-1} H^T R^{-1} y
\end{aligned}
\]

Parameters:

Name Type Description Default
prior_precision BlockTriDiag

Prior precision as BlockTriDiag.

required
emission_model Array

Emission matrix H. Shape (d_obs, d) for shared or (N, d_obs, d) per time step.

required
obs_noise AbstractLinearOperator

Observation noise covariance R operator.

required
observations Float[Array, 'N d_obs']

Observations y, shape (N, d_obs).

required
solver AbstractSolverStrategy | None

Optional solver strategy for posterior precision operations. When None, uses structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, ' Nd'], BlockTriDiag]

Tuple (posterior_mean, posterior_precision) where posterior_mean has shape (N * d,) and posterior_precision is BlockTriDiag.

Source code in src/gaussx/_ssm/_spingp.py
def spingp_posterior(
    prior_precision: BlockTriDiag,
    emission_model: Array,
    obs_noise: lx.AbstractLinearOperator,
    observations: Float[Array, "N d_obs"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " Nd"], BlockTriDiag]:
    r"""Posterior mean and precision via SpInGP.

    Computes the posterior by adding likelihood precision sites to the
    prior precision and solving for the posterior mean:

        \Lambda_{post} = \Lambda_{prior} + H^T R^{-1} H
        \mu_{post} = \Lambda_{post}^{-1} H^T R^{-1} y

    Args:
        prior_precision: Prior precision as ``BlockTriDiag``.
        emission_model: Emission matrix H. Shape ``(d_obs, d)`` for
            shared or ``(N, d_obs, d)`` per time step.
        obs_noise: Observation noise covariance R operator.
        observations: Observations y, shape ``(N, d_obs)``.
        solver: Optional solver strategy for posterior precision
            operations. When ``None``, uses structural dispatch.

    Returns:
        Tuple ``(posterior_mean, posterior_precision)`` where
        ``posterior_mean`` has shape ``(N * d,)`` and
        ``posterior_precision`` is ``BlockTriDiag``.
    """
    N = prior_precision._num_blocks
    d = prior_precision._block_size

    # Build likelihood precision and posterior precision
    lik_prec = _build_likelihood_precision(emission_model, obs_noise, N, d)
    post_prec = prior_precision.add(lik_prec)

    # Data vector: eta = H^T R^{-1} y
    eta = _build_data_vector(emission_model, obs_noise, observations)

    # Posterior mean: Lambda_post^{-1} eta
    post_mean = dispatch_solve(post_prec, eta, solver)

    return post_mean, post_prec
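
For intuition, the precision-form posterior can be compared with the familiar covariance-form (Kalman gain) update on a tiny dense problem. This is a NumPy sketch with made-up inputs and stacked `H` and `R`; it does not call the library.

```python
import numpy as np

# N = 3 time steps, state dimension d = 2, scalar observations.
chain = 2.0 * np.eye(3) - 0.3 * (np.eye(3, k=1) + np.eye(3, k=-1))
prior_prec = np.kron(chain, np.eye(2))
H = np.kron(np.eye(3), np.array([[1.0, 0.0]]))
R = 0.25 * np.eye(3)
y = np.array([0.5, -1.0, 2.0])

# Precision form: add the likelihood sites, then solve for the mean.
post_prec = prior_prec + H.T @ np.linalg.inv(R) @ H
mean_precision = np.linalg.solve(post_prec, H.T @ np.linalg.solve(R, y))

# Covariance form: gain K = P H^T (H P H^T + R)^{-1} with P = prior_prec^{-1}.
P = np.linalg.inv(prior_prec)
gain = P @ H.T @ np.linalg.inv(H @ P @ H.T + R)
mean_covariance = gain @ y
cov_covariance = P - gain @ H @ P
```

Both routes give the same mean and covariance; only the precision route preserves the banded structure that the solver exploits.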

Sites & natural parameters

Conjugate-computation VI (CVI) site updates and the conversions between SSM moment, expectation, and natural parameterizations used by non-conjugate temporal inference.

GaussianSites

Bases: Module

Time-varying Gaussian likelihood sites in natural parameterization.

Stores per-timestep natural parameters for N Gaussian sites, following the \(\eta_2 = -\tfrac{1}{2}\Lambda\) convention (consistent with gaussx.mean_cov_to_natural).

Attributes:

Name Type Description
nat1 Float[Array, 'N d']

Natural location parameters, shape (N, d).

nat2 Float[Array, 'N d d']

Natural precision parameters, shape (N, d, d). Stores \(-\tfrac{1}{2}\Lambda_k\) at each time step.

Source code in src/gaussx/_ssm/_cvi.py
class GaussianSites(eqx.Module):
    r"""Time-varying Gaussian likelihood sites in natural parameterization.

    Stores per-timestep natural parameters for ``N`` Gaussian sites,
    following the ``\eta_2 = -\tfrac{1}{2}\Lambda`` convention
    (consistent with `gaussx.mean_cov_to_natural`).

    Attributes:
        nat1: Natural location parameters, shape ``(N, d)``.
        nat2: Natural precision parameters, shape ``(N, d, d)``.
            Stores ``-\tfrac{1}{2}\Lambda_k`` at each time step.
    """

    nat1: Float[Array, "N d"]
    nat2: Float[Array, "N d d"]

cvi_update_sites(sites: GaussianSites, grad_nat1: Float[Array, 'N d'], grad_nat2: Float[Array, 'N d d'], rho: float) -> GaussianSites

Natural gradient update for CVI sites.

Performs a damped update in natural parameter space:

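\[
\theta \leftarrow (1 - \rho)\,\theta + \rho\,\nabla
\]

Here \(\theta\) stands for nat1 and nat2 in turn, and \(\nabla\) for the matching grad_nat1 or grad_nat2.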

Parameters:

Name Type Description Default
sites GaussianSites

Current Gaussian sites.

required
grad_nat1 Float[Array, 'N d']

Natural gradient for location, shape (N, d).

required
grad_nat2 Float[Array, 'N d d']

Natural gradient for precision, shape (N, d, d).

required
rho float

Step size / damping factor in [0, 1].

required

Returns:

Type Description
GaussianSites

Updated GaussianSites.

Source code in src/gaussx/_ssm/_cvi.py
def cvi_update_sites(
    sites: GaussianSites,
    grad_nat1: Float[Array, "N d"],
    grad_nat2: Float[Array, "N d d"],
    rho: float,
) -> GaussianSites:
    r"""Natural gradient update for CVI sites.

    Performs a damped update in natural parameter space:

        \theta \leftarrow (1 - \rho) \theta + \rho \nabla

    Args:
        sites: Current Gaussian sites.
        grad_nat1: Natural gradient for location, shape ``(N, d)``.
        grad_nat2: Natural gradient for precision, shape ``(N, d, d)``.
        rho: Step size / damping factor in ``[0, 1]``.

    Returns:
        Updated `GaussianSites`.
    """
    new_nat1 = (1.0 - rho) * sites.nat1 + rho * grad_nat1
    new_nat2 = (1.0 - rho) * sites.nat2 + rho * grad_nat2
    return GaussianSites(nat1=new_nat1, nat2=new_nat2)
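
The damped update is a convex combination, which a few lines of NumPy make concrete. This is a sketch with a hypothetical helper `damped_update` and made-up values; it mirrors the arithmetic above without using `GaussianSites`.

```python
import numpy as np


def damped_update(nat, target, rho):
    """Convex combination of the current value and the update target."""
    return (1.0 - rho) * nat + rho * target


# N = 3 scalar sites (d = 1), stored in the -0.5 * precision convention.
nat1 = np.zeros((3, 1))
nat2 = -0.5 * np.ones((3, 1, 1))      # precision 1 at every step
target1 = np.ones((3, 1))
target2 = -1.0 * np.ones((3, 1, 1))   # precision 2 at every step

half1 = damped_update(nat1, target1, 0.5)
half2 = damped_update(nat2, target2, 0.5)
```

With `rho = 0` the sites are unchanged and with `rho = 1` they are replaced by the target. For `rho` in `[0, 1]` and negative-definite inputs, the combined `nat2` stays negative definite, so the implied precision remains positive.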

sites_to_precision(sites: GaussianSites) -> BlockTriDiag

Convert Gaussian sites to a block-tridiagonal precision.

Returns a block-diagonal BlockTriDiag (zero sub-diagonals) representing the precision contribution of the sites. This can be added to a prior precision via .add() or + to form the posterior precision:

\[
\Lambda_{post} = \Lambda_{prior} + \Lambda_{sites}
\]

Since nat2 stores \(-\tfrac{1}{2}\Lambda\), the precision blocks are \(-2 \cdot \mathrm{nat2}\).

Parameters:

Name Type Description Default
sites GaussianSites

Gaussian sites with nat2 in eta2 convention.

required

Returns:

Type Description
BlockTriDiag

Block-diagonal BlockTriDiag precision.

Source code in src/gaussx/_ssm/_cvi.py
def sites_to_precision(sites: GaussianSites) -> BlockTriDiag:
    r"""Convert Gaussian sites to a block-tridiagonal precision.

    Returns a block-diagonal `BlockTriDiag` (zero
    sub-diagonals) representing the precision contribution of the
    sites. This can be added to a prior precision via ``.add()``
    or ``+`` to form the posterior precision:

        \Lambda_{post} = \Lambda_{prior} + \Lambda_{sites}

    Since ``nat2`` stores ``-\tfrac{1}{2}\Lambda``, the precision
    blocks are ``-2 \cdot nat2``.

    Args:
        sites: Gaussian sites with ``nat2`` in eta2 convention.

    Returns:
        Block-diagonal `BlockTriDiag` precision.
    """
    N, d = sites.nat1.shape
    diag_blocks = -2.0 * sites.nat2  # (N, d, d)
    sub_diag_blocks = jnp.zeros((N - 1, d, d), dtype=diag_blocks.dtype)
    return BlockTriDiag(diag_blocks, sub_diag_blocks)
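
As a worked case of the `-2 * nat2` conversion, consider exact Gaussian likelihood sites \(N(y_k \mid H x_k, R)\), whose natural parameters are \(H^T R^{-1} y_k\) and \(-\tfrac{1}{2} H^T R^{-1} H\). The NumPy sketch below uses made-up values and plain arrays in place of `GaussianSites` and `BlockTriDiag`.

```python
import numpy as np

H = np.array([[1.0, 0.0]])           # observe the first of d = 2 state components
R = np.array([[0.5]])                # observation noise variance
ys = np.array([[1.0], [2.0], [-1.0]])  # N = 3 scalar observations

R_inv = np.linalg.inv(R)
site_prec = H.T @ R_inv @ H          # (d, d) precision contributed per step

# Site natural parameters in the eta_2 = -0.5 * Lambda convention.
nat1 = ys @ R_inv @ H                # row k is (H^T R^{-1} y_k)^T, shape (N, d)
nat2 = np.broadcast_to(-0.5 * site_prec, (3, 2, 2))

# Converting back recovers the precision blocks, ready to add to a prior.
diag_blocks = -2.0 * nat2
prior_diag = np.stack([np.eye(2)] * 3)
post_diag = prior_diag + diag_blocks
```

Only the diagonal blocks change, which is why the sites can be represented with zero sub-diagonals.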

cavity_from_marginal(marg_mean: Float[Array, ' *batch'], marg_var: Float[Array, ' *batch'], site_nat1: Float[Array, ' *batch'], site_nat2: Float[Array, ' *batch']) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Compute cavity distribution by removing a site from the marginal.

Parameters:

Name Type Description Default
marg_mean Float[Array, ' *batch']

Marginal distribution means.

required
marg_var Float[Array, ' *batch']

Marginal distribution variances (positive).

required
site_nat1 Float[Array, ' *batch']

Site precision-weighted means to remove.

required
site_nat2 Float[Array, ' *batch']

Site precisions to remove (the precision itself, not the \(-\tfrac{1}{2}\Lambda\) value stored in GaussianSites.nat2).

required

Returns:

Type Description
tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Tuple (cav_mean, cav_var) of the cavity distribution.

Source code in src/gaussx/_ssm/_site_natural.py
def cavity_from_marginal(
    marg_mean: Float[Array, " *batch"],
    marg_var: Float[Array, " *batch"],
    site_nat1: Float[Array, " *batch"],
    site_nat2: Float[Array, " *batch"],
) -> tuple[Float[Array, " *batch"], Float[Array, " *batch"]]:
    """Compute cavity distribution by removing a site from the marginal.

    Args:
        marg_mean: Marginal distribution means.
        marg_var: Marginal distribution variances (positive).
        site_nat1: Site precision-weighted means to remove.
        site_nat2: Site precisions to remove.

    Returns:
        Tuple ``(cav_mean, cav_var)`` of the cavity distribution.
    """
    cav_prec = jnp.reciprocal(marg_var) - site_nat2
    cav_var = jnp.reciprocal(cav_prec)
    cav_mean = (marg_mean * jnp.reciprocal(marg_var) - site_nat1) * cav_var
    return cav_mean, cav_var
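
A scalar round trip shows what the cavity computation does: removing a site and multiplying it back in recovers the original marginal. The sketch below is plain Python with hypothetical helper names (`cavity`, `recombine`) and made-up numbers.

```python
def cavity(marg_mean, marg_var, site_nat1, site_nat2):
    """Subtract the site's natural parameters from the marginal's."""
    cav_prec = 1.0 / marg_var - site_nat2
    cav_var = 1.0 / cav_prec
    cav_mean = (marg_mean / marg_var - site_nat1) * cav_var
    return cav_mean, cav_var


def recombine(cav_mean, cav_var, site_nat1, site_nat2):
    """Inverse operation: add the site's natural parameters back."""
    prec = 1.0 / cav_var + site_nat2
    var = 1.0 / prec
    mean = (cav_mean / cav_var + site_nat1) * var
    return mean, var


# Marginal N(1.0, 0.5) has precision 2; the site carries precision 0.5.
cav_mean, cav_var = cavity(1.0, 0.5, 0.4, 0.5)
marg_mean, marg_var = recombine(cav_mean, cav_var, 0.4, 0.5)
```

The listing above has no guard for the case where the site precision exceeds the marginal precision; in that case the cavity variance comes out negative, so callers may want to check for it.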

site_natural_from_tilted(tilted_mean: Float[Array, ' *batch'], tilted_var: Float[Array, ' *batch'], cav_mean: Float[Array, ' *batch'], cav_var: Float[Array, ' *batch']) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Compute site natural parameters from tilted and cavity moments.

Parameters:

Name Type Description Default
tilted_mean Float[Array, ' *batch']

Tilted distribution means.

required
tilted_var Float[Array, ' *batch']

Tilted distribution variances (positive).

required
cav_mean Float[Array, ' *batch']

Cavity distribution means.

required
cav_var Float[Array, ' *batch']

Cavity distribution variances (positive).

required

Returns:

Type Description
tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Tuple (site_nat1, site_nat2).

Source code in src/gaussx/_ssm/_site_natural.py
def site_natural_from_tilted(
    tilted_mean: Float[Array, " *batch"],
    tilted_var: Float[Array, " *batch"],
    cav_mean: Float[Array, " *batch"],
    cav_var: Float[Array, " *batch"],
) -> tuple[Float[Array, " *batch"], Float[Array, " *batch"]]:
    """Compute site natural parameters from tilted and cavity moments.

    Args:
        tilted_mean: Tilted distribution means.
        tilted_var: Tilted distribution variances (positive).
        cav_mean: Cavity distribution means.
        cav_var: Cavity distribution variances (positive).

    Returns:
        Tuple ``(site_nat1, site_nat2)``.
    """
    site_nat2 = jnp.reciprocal(tilted_var) - jnp.reciprocal(cav_var)
    site_nat1 = tilted_mean * jnp.reciprocal(tilted_var) - cav_mean * jnp.reciprocal(
        cav_var
    )
    return site_nat1, site_nat2

site_mean_var_from_natural(site_nat1: Float[Array, ' *batch'], site_nat2: Float[Array, ' *batch']) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Convert per-site natural parameters to mean/variance.

Parameters:

Name Type Description Default
site_nat1 Float[Array, ' *batch']

Site precision-weighted means.

required
site_nat2 Float[Array, ' *batch']

Site precisions (positive for valid Gaussians).

required

Returns:

Type Description
tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Tuple (mean, var) of the equivalent Gaussian site.

Source code in src/gaussx/_ssm/_site_natural.py
def site_mean_var_from_natural(
    site_nat1: Float[Array, " *batch"],
    site_nat2: Float[Array, " *batch"],
) -> tuple[Float[Array, " *batch"], Float[Array, " *batch"]]:
    """Convert per-site natural parameters to mean/variance.

    Args:
        site_nat1: Site precision-weighted means.
        site_nat2: Site precisions (positive for valid Gaussians).

    Returns:
        Tuple ``(mean, var)`` of the equivalent Gaussian site.
    """
    var = jnp.reciprocal(site_nat2)
    mean = site_nat1 * var
    return mean, var
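
The two conversions above can be sanity-checked on a case where the answer is known. With a Gaussian likelihood \(N(y \mid f, \sigma^2)\) the tilted distribution is an exact Gaussian product, so the recovered site should be the likelihood itself, with mean \(y\) and variance \(\sigma^2\). This is a plain-Python sketch with hypothetical helper names and made-up numbers.

```python
def site_from_tilted(tilted_mean, tilted_var, cav_mean, cav_var):
    """Site natural parameters = tilted naturals minus cavity naturals."""
    nat2 = 1.0 / tilted_var - 1.0 / cav_var
    nat1 = tilted_mean / tilted_var - cav_mean / cav_var
    return nat1, nat2


def site_mean_var(nat1, nat2):
    """Convert (precision-weighted mean, precision) to (mean, variance)."""
    var = 1.0 / nat2
    return nat1 * var, var


# Cavity N(0, 1) and a Gaussian likelihood with y = 2, noise variance 0.5.
cav_mean, cav_var = 0.0, 1.0
y, noise_var = 2.0, 0.5

# Exact tilted moments for a product of two Gaussians.
tilted_var = 1.0 / (1.0 / cav_var + 1.0 / noise_var)
tilted_mean = tilted_var * (cav_mean / cav_var + y / noise_var)

nat1, nat2 = site_from_tilted(tilted_mean, tilted_var, cav_mean, cav_var)
site_mean, site_var = site_mean_var(nat1, nat2)
```

For non-Gaussian likelihoods the tilted moments come from moment matching instead, and the same two conversions apply unchanged.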

expectations_to_ssm(eta1: Float[Array, ' Nd'], eta2: BlockTriDiag) -> tuple[Float[Array, 'N d'], Float[Array, 'N d d'], Float[Array, 'Nm1 d d']]

Convert expectation parameters back to SSM marginals.

Recovers (means, covs, cross_covs) from the expectation parameters of the joint Gaussian.

Parameters:

Name Type Description Default
eta1 Float[Array, ' Nd']

Concatenated means, shape (N*d,).

required
eta2 BlockTriDiag

Second-moment BlockTriDiag.

required

Returns:

Type Description
tuple[Float[Array, 'N d'], Float[Array, 'N d d'], Float[Array, 'Nm1 d d']]

Tuple (means, covs, cross_covs) where:

  • means: shape (N, d)
  • covs: shape (N, d, d)
  • cross_covs: shape (N-1, d, d)

Source code in src/gaussx/_ssm/_ssm_natural.py
def expectations_to_ssm(
    eta1: Float[Array, " Nd"],
    eta2: BlockTriDiag,
) -> tuple[
    Float[Array, "N d"],
    Float[Array, "N d d"],
    Float[Array, "Nm1 d d"],
]:
    r"""Convert expectation parameters back to SSM marginals.

    Recovers ``(means, covs, cross_covs)`` from the expectation
    parameters of the joint Gaussian.

    Args:
        eta1: Concatenated means, shape ``(N*d,)``.
        eta2: Second-moment `BlockTriDiag`.

    Returns:
        Tuple ``(means, covs, cross_covs)`` where:

        - ``means``: shape ``(N, d)``
        - ``covs``: shape ``(N, d, d)``
        - ``cross_covs``: shape ``(N-1, d, d)``
    """
    d = eta2._block_size
    N = eta2._num_blocks

    means = rearrange(eta1, "(N d) -> N d", N=N, d=d)

    # covs = E[xₖ xₖᵀ] − mₖ mₖᵀ
    covs = eta2.diagonal - einsum(means, means, "N i, N j -> N i j")

    # cross_covs = E[xₖ₊₁ xₖᵀ] − mₖ₊₁ mₖᵀ
    cross_covs = eta2.sub_diagonal - einsum(means[1:], means[:-1], "N i, N j -> N i j")

    return means, covs, cross_covs

naturals_to_ssm(theta_linear: Float[Array, ' Nd'], theta_precision: BlockTriDiag, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'Nm1 d d'], Float[Array, 'N d d'], Float[Array, ' d'], Float[Array, 'd d']]

Convert natural parameters back to SSM parameters.

Recovers \((A, Q, \mu_0, P_0)\) from the block-tridiagonal natural parameters via a backward recurrence on the precision blocks.

Parameters:

Name Type Description Default
theta_linear Float[Array, ' Nd']

Natural location parameter, shape (N*d,).

required
theta_precision BlockTriDiag

Natural precision parameter as BlockTriDiag (eta2 convention).

required
solver AbstractSolverStrategy | None

Optional solver strategy, accepted for API consistency with the other conversions. It is currently unused: the per-block inverses in this function always go through structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, 'Nm1 d d'], Float[Array, 'N d d'], Float[Array, ' d'], Float[Array, 'd d']]

Tuple (A, Q, mu_0, P_0) where:

  • A: Transition matrices, shape (N-1, d, d).
  • Q: Process noise covariances, shape (N, d, d).
  • mu_0: Initial mean, shape (d,).
  • P_0: Initial covariance, shape (d, d).

Source code in src/gaussx/_ssm/_ssm_natural.py
def naturals_to_ssm(
    theta_linear: Float[Array, " Nd"],
    theta_precision: BlockTriDiag,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[
    Float[Array, "Nm1 d d"],
    Float[Array, "N d d"],
    Float[Array, " d"],
    Float[Array, "d d"],
]:
    r"""Convert natural parameters back to SSM parameters.

    Recovers ``(A, Q, \mu_0, P_0)`` from the block-tridiagonal natural
    parameters via a backward recurrence on the precision blocks.

    Args:
        theta_linear: Natural location parameter, shape ``(N*d,)``.
        theta_precision: Natural precision parameter as
            `BlockTriDiag` (eta2 convention).
        solver: Optional solver strategy, accepted for API consistency
            with the other conversions. It is currently unused: the
            per-block inverses in this function always go through
            structural dispatch.

    Returns:
        Tuple ``(A, Q, mu_0, P_0)`` where:

        - ``A``: Transition matrices, shape ``(N-1, d, d)``.
        - ``Q``: Process noise covariances, shape ``(N, d, d)``.
        - ``mu_0``: Initial mean, shape ``(d,)``.
        - ``P_0``: Initial covariance, shape ``(d, d)``.
    """
    del solver  # inv does not accept a solver; parameter reserved for future use
    d = theta_precision._block_size

    # Convert from eta2 to raw precision
    prec_diag = -2.0 * theta_precision.diagonal  # (N, d, d)
    prec_sub = -2.0 * theta_precision.sub_diagonal  # (N-1, d, d)

    # Backward recurrence to recover Q and A
    # Start from last block: Q[N-1] = inv(prec_diag[N-1])
    # Then for k = N-2 down to 0:
    #   A[k] = Q[k+1] @ (-prec_sub[k])  (sub-diag was -Q_{k+1}^{-1} A_k)
    #   Q[k] = inv(prec_diag[k] - A[k]^T @ Q[k+1]^{-1} @ A[k])

    def _backward_step(Q_next_inv, inputs):
        diag_k, sub_k = inputs
        Q_next = inv(
            lx.MatrixLinearOperator(Q_next_inv, lx.positive_semidefinite_tag)
        ).as_matrix()
        # sub_k = -Q_{k+1}^{-1} A_k, so A_k = -Q_{k+1} @ sub_k
        A_k = -Q_next @ sub_k
        # Q_k^{-1} = diag_k - A_k^T @ Q_next_inv @ A_k
        Q_k_inv = diag_k - A_k.T @ Q_next_inv @ A_k
        return Q_k_inv, (A_k, Q_k_inv)

    Q_last_inv = prec_diag[-1]

    # Reverse scan: iterate from k=N-2 down to 0
    _, (A_rev, Q_inv_rev) = jax.lax.scan(
        _backward_step,
        Q_last_inv,
        (prec_diag[:-1], prec_sub),
        reverse=True,
    )

    # A_rev is (N-1, d, d), Q_inv_rev is (N-1, d, d) for k=0..N-2
    A = A_rev

    # Q: invert all Q_inv values (batch over N)
    def _inv_single(q_inv):
        return inv(
            lx.MatrixLinearOperator(q_inv, lx.positive_semidefinite_tag)
        ).as_matrix()

    Q_inv_all = jnp.concatenate([Q_inv_rev, Q_last_inv[None]], axis=0)
    Q = jax.vmap(_inv_single)(Q_inv_all)

    # Recover initial conditions
    P_0 = Q[0]
    mu_0 = P_0 @ theta_linear[:d]

    return A, Q, mu_0, P_0
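
The backward recurrence can be followed step by step with a plain loop. The NumPy sketch below is an illustration under stated assumptions: it works on raw precision blocks (the library function first multiplies the stored blocks by -2), the helper names are hypothetical, and the inputs are made up.

```python
import numpy as np


def precision_blocks(A, Q):
    """Raw precision blocks of the joint prior; Q[0] plays the role of P_0."""
    Q_inv = np.linalg.inv(Q)
    diag = Q_inv.copy()
    # D[k] = Q[k]^{-1} + A[k]^T Q[k+1]^{-1} A[k] for every k except the last.
    diag[:-1] += np.swapaxes(A, -1, -2) @ Q_inv[1:] @ A
    # S[k] = -Q[k+1]^{-1} A[k]
    sub = -(Q_inv[1:] @ A)
    return diag, sub


def recover_ssm(diag, sub):
    """Backward recurrence from the last block down to the first."""
    N = diag.shape[0]
    Q_inv = [None] * N
    A = [None] * (N - 1)
    Q_inv[-1] = diag[-1]
    for k in range(N - 2, -1, -1):
        # S[k] = -Q[k+1]^{-1} A[k]  =>  A[k] = -Q[k+1] S[k]
        A[k] = -np.linalg.solve(Q_inv[k + 1], sub[k])
        Q_inv[k] = diag[k] - A[k].T @ Q_inv[k + 1] @ A[k]
    return np.stack(A), np.linalg.inv(np.stack(Q_inv))


# N = 3 time steps, state dimension d = 2.
A = np.stack([np.array([[0.9, 0.1], [0.0, 0.8]])] * 2)
Q = np.stack([np.eye(2), 0.2 * np.eye(2), 0.3 * np.eye(2)])

diag, sub = precision_blocks(A, Q)
A_rec, Q_rec = recover_ssm(diag, sub)
```

The recovered `Q_rec[0]` is the initial covariance, matching the `P_0 = Q[0]` line in the listing above.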

ssm_to_expectations(means: Float[Array, 'N d'], covs: Float[Array, 'N d d'], cross_covs: Float[Array, 'Nm1 d d']) -> tuple[Float[Array, ' Nd'], BlockTriDiag]

Convert SSM marginals to expectation parameters.

Given filtered or smoothed marginals, computes the expectation parameters (eta1, eta2) of the joint Gaussian where:

  • eta1 = E[x] (concatenated means)
  • eta2 is a BlockTriDiag storing the block-tridiagonal subset of E[xx^T] (second moments matching the Gauss-Markov sparsity pattern, not the full dense matrix)

The diagonal blocks of eta2 are E[x_k x_k^T] = P_k + m_k m_k^T and the sub-diagonal blocks are E[x_{k+1} x_k^T] = C_k + m_{k+1} m_k^T where C_k is the cross-covariance Cov(x_{k+1}, x_k).

Parameters:

Name Type Description Default
means Float[Array, 'N d']

Marginal means, shape (N, d).

required
covs Float[Array, 'N d d']

Marginal covariances, shape (N, d, d).

required
cross_covs Float[Array, 'Nm1 d d']

Cross-covariances Cov(x_{k+1}, x_k), shape (N-1, d, d).

required

Returns:

Type Description
tuple[Float[Array, ' Nd'], BlockTriDiag]

Tuple (eta1, eta2) where eta1 has shape (N*d,) and eta2 is a BlockTriDiag.

Source code in src/gaussx/_ssm/_ssm_natural.py
def ssm_to_expectations(
    means: Float[Array, "N d"],
    covs: Float[Array, "N d d"],
    cross_covs: Float[Array, "Nm1 d d"],
) -> tuple[Float[Array, " Nd"], BlockTriDiag]:
    r"""Convert SSM marginals to expectation parameters.

    Given filtered or smoothed marginals, computes the expectation
    parameters ``(eta1, eta2)`` of the joint Gaussian where:

    - ``eta1 = E[x]`` (concatenated means)
    - ``eta2`` is a `BlockTriDiag` storing the
      block-tridiagonal subset of ``E[xx^T]`` (second moments matching
      the Gauss-Markov sparsity pattern, not the full dense matrix)

    The diagonal blocks of ``eta2`` are ``E[x_k x_k^T] = P_k + m_k m_k^T``
    and the sub-diagonal blocks are
    ``E[x_{k+1} x_k^T] = C_k + m_{k+1} m_k^T`` where ``C_k`` is the
    cross-covariance ``Cov(x_{k+1}, x_k)``.

    Args:
        means: Marginal means, shape ``(N, d)``.
        covs: Marginal covariances, shape ``(N, d, d)``.
        cross_covs: Cross-covariances ``Cov(x_{k+1}, x_k)``,
            shape ``(N-1, d, d)``.

    Returns:
        Tuple ``(eta1, eta2)`` where ``eta1`` has shape ``(N*d,)``
        and ``eta2`` is a `BlockTriDiag`.
    """
    _N, _d = means.shape

    # eta1 = concatenated means
    eta1 = rearrange(means, "N d -> (N d)")

    # Diagonal blocks: E[xₖ xₖᵀ] = Pₖ + mₖ mₖᵀ
    diag = covs + einsum(means, means, "N i, N j -> N i j")  # (N, d, d)

    # Sub-diagonal blocks: E[xₖ₊₁ xₖᵀ] = Cₖ + mₖ₊₁ mₖᵀ
    sub_diag = cross_covs + einsum(
        means[1:], means[:-1], "N i, N j -> N i j"
    )  # (N-1, d, d)

    eta2 = BlockTriDiag(diag, sub_diag)
    return eta1, eta2
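
The moment and expectation parameterizations are inverses of each other, which a short round trip makes explicit. This NumPy sketch uses hypothetical helper names, made-up inputs, and plain arrays for the diagonal and sub-diagonal blocks instead of `BlockTriDiag`.

```python
import numpy as np


def to_expectations(means, covs, cross_covs):
    """Second moments on the block-tridiagonal pattern."""
    eta1 = means.reshape(-1)
    diag = covs + np.einsum("ni,nj->nij", means, means)
    sub = cross_covs + np.einsum("ni,nj->nij", means[1:], means[:-1])
    return eta1, diag, sub


def from_expectations(eta1, diag, sub):
    """Subtract the mean outer products to get covariances back."""
    N, d = diag.shape[0], diag.shape[1]
    means = eta1.reshape(N, d)
    covs = diag - np.einsum("ni,nj->nij", means, means)
    cross_covs = sub - np.einsum("ni,nj->nij", means[1:], means[:-1])
    return means, covs, cross_covs


# N = 3 time steps, state dimension d = 2.
means = np.array([[1.0, 0.0], [0.5, -1.0], [2.0, 3.0]])
covs = np.stack([np.eye(2), 2.0 * np.eye(2), 0.5 * np.eye(2)])
cross_covs = np.stack([0.1 * np.ones((2, 2)), 0.2 * np.eye(2)])

eta1, diag, sub = to_expectations(means, covs, cross_covs)
means_rt, covs_rt, cross_rt = from_expectations(eta1, diag, sub)
```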

ssm_to_naturals(A: Float[Array, 'Nm1 d d'], Q: Float[Array, 'N d d'], mu_0: Float[Array, ' d'], P_0: Float[Array, 'd d'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' Nd'], BlockTriDiag]

Convert SSM parameters to natural parameters.

For a linear-Gaussian state-space model:

\[
x_0 \sim N(\mu_0, P_0), \qquad x_{k+1} = A_k x_k + \epsilon_k,\quad \epsilon_k \sim N(0, Q_{k+1})
\]

the joint prior \(p(x_0, \ldots, x_{N-1})\) has a block-tridiagonal precision matrix. This function returns its natural parameters \((\theta_1, \theta_2)\) where \(\theta_2 = -\tfrac{1}{2}\Lambda\) (matching the convention in gaussx.mean_cov_to_natural).

Parameters:

Name Type Description Default
A Float[Array, 'Nm1 d d']

Transition matrices, shape (N-1, d, d).

required
Q Float[Array, 'N d d']

Process noise covariances, shape (N, d, d). Q[0] must equal P_0 and Q[k] for k >= 1 is the process noise at step k.

required
mu_0 Float[Array, ' d']

Initial mean, shape (d,).

required
P_0 Float[Array, 'd d']

Initial covariance, shape (d, d).

required
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, ' Nd'], BlockTriDiag]

Tuple (theta_linear, theta_precision) where theta_linear has shape (N*d,) and theta_precision is a BlockTriDiag in the eta_2 = -0.5 * Lambda convention.

Source code in src/gaussx/_ssm/_ssm_natural.py
def ssm_to_naturals(
    A: Float[Array, "Nm1 d d"],
    Q: Float[Array, "N d d"],
    mu_0: Float[Array, " d"],
    P_0: Float[Array, "d d"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " Nd"], BlockTriDiag]:
    r"""Convert SSM parameters to natural parameters.

    For a linear-Gaussian state-space model:

        x_0 \sim N(\mu_0, P_0)
        x_{k+1} = A_k x_k + \epsilon_k,\quad \epsilon_k \sim N(0, Q_{k+1})

    the joint prior ``p(x_0, \ldots, x_{N-1})`` has a block-tridiagonal
    precision matrix. This function returns its natural parameters
    ``(\theta_1, \theta_2)`` where ``\theta_2 = -\tfrac{1}{2}\Lambda``
    (matching the convention in `gaussx.mean_cov_to_natural`).

    Args:
        A: Transition matrices, shape ``(N-1, d, d)``.
        Q: Process noise covariances, shape ``(N, d, d)``.
            ``Q[0]`` must equal ``P_0`` and ``Q[k]`` for ``k >= 1`` is the
            process noise at step ``k``.
        mu_0: Initial mean, shape ``(d,)``.
        P_0: Initial covariance, shape ``(d, d)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(theta_linear, theta_precision)`` where
        ``theta_linear`` has shape ``(N*d,)`` and
        ``theta_precision`` is a `BlockTriDiag`
        in the ``eta_2 = -0.5 * Lambda`` convention.
    """
    N = Q.shape[0]
    d = Q.shape[1]

    try:
        q0_matches_p0 = bool(jnp.allclose(Q[0], P_0))
    except jax.errors.TracerBoolConversionError:
        q0_matches_p0 = True  # skip validation under jax.jit

    if not q0_matches_p0:
        msg = "Q[0] must match P_0 so the returned natural parameters are consistent"
        raise ValueError(msg)

    # Invert all process noise covariances (batch over N)
    def _inv_single(q):
        return inv(lx.MatrixLinearOperator(q, lx.positive_semidefinite_tag)).as_matrix()

    Q_inv = jax.vmap(_inv_single)(Q)  # (N, d, d)
    P_0_op = lx.MatrixLinearOperator(P_0, lx.positive_semidefinite_tag)
    P_0_inv = inv(P_0_op).as_matrix()

    # Future contributions: A_k^T Q_{k+1}^{-1} A_k for k = 0..N-2
    future = jax.vmap(lambda Ak, Qinv_kp1: Ak.T @ Qinv_kp1 @ Ak)(
        A, Q_inv[1:]
    )  # (N-1, d, d)

    # Precision diagonal blocks (raw Lambda, not eta2)
    # D[0] = P_0^{-1} + A[0]^T Q[1]^{-1} A[0]
    # D[k] = Q[k]^{-1} + A[k]^T Q[k+1]^{-1} A[k]  for k=1..N-2
    # D[N-1] = Q[N-1]^{-1}
    diag = jnp.zeros((N, d, d), dtype=Q.dtype)
    diag = diag.at[0].set(P_0_inv + future[0] if N > 1 else P_0_inv)
    if N > 2:
        diag = diag.at[1:-1].set(Q_inv[1:-1] + future[1:])
    diag = diag.at[-1].set(Q_inv[-1])

    # Sub-diagonal blocks (raw precision off-diagonal)
    # S[k] = -Q[k+1]^{-1} A[k]  for k=0..N-2
    # (negative because precision cross-terms are negative for transitions)
    sub_diag = jax.vmap(lambda Qinv_kp1, Ak: -Qinv_kp1 @ Ak)(
        Q_inv[1:], A
    )  # (N-1, d, d)

    # Convert to eta2 convention: theta_precision = -0.5 * Lambda
    theta_precision = BlockTriDiag(-0.5 * diag, -0.5 * sub_diag)

    # Linear natural parameter: theta_1 = Lambda @ mu
    # For zero-mean transitions, only the initial block P_0^{-1} mu_0 is non-zero
    theta_linear = jnp.zeros(N * d, dtype=Q.dtype)
    eta1_0 = dispatch_solve(P_0_op, mu_0, solver)
    theta_linear = theta_linear.at[:d].set(eta1_0)

    return theta_linear, theta_precision
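
The block formulas in the comments above can be verified densely: the assembled precision should invert the joint covariance obtained by propagating the model, and \(\Lambda m\) should be non-zero only in its first block. The NumPy sketch below uses hypothetical helper names and made-up inputs, and builds the raw precision \(\Lambda\) rather than the stored \(-\tfrac{1}{2}\Lambda\).

```python
import numpy as np


def dense_precision(A, Q):
    """Assemble the raw block-tridiagonal precision as a dense matrix."""
    N, d = Q.shape[0], Q.shape[1]
    Q_inv = np.linalg.inv(Q)
    Lam = np.zeros((N * d, N * d))
    for k in range(N):
        cur = slice(k * d, (k + 1) * d)
        Lam[cur, cur] += Q_inv[k]
        if k + 1 < N:
            nxt = slice((k + 1) * d, (k + 2) * d)
            Lam[cur, cur] += A[k].T @ Q_inv[k + 1] @ A[k]
            Lam[nxt, cur] = -Q_inv[k + 1] @ A[k]
            Lam[cur, nxt] = Lam[nxt, cur].T
    return Lam


def propagation_matrix(A, d):
    """Block lower-triangular M with x - mean = M @ (stacked noise terms)."""
    N = A.shape[0] + 1
    M = np.eye(N * d)
    for i in range(N):
        for j in range(i):
            Phi = np.eye(d)
            for k in range(j, i):
                Phi = A[k] @ Phi          # Phi = A[i-1] ... A[j]
            M[i * d:(i + 1) * d, j * d:(j + 1) * d] = Phi
    return M


# N = 3 time steps, d = 2; Q[0] plays the role of P_0.
A = np.stack([np.array([[0.9, 0.1], [0.0, 0.8]])] * 2)
Q = np.stack([np.eye(2), 0.2 * np.eye(2), 0.3 * np.eye(2)])
mu_0 = np.array([1.0, -2.0])

Lam = dense_precision(A, Q)
M = propagation_matrix(A, 2)
Q_big = np.zeros((6, 6))
for k in range(3):
    Q_big[2 * k:2 * k + 2, 2 * k:2 * k + 2] = Q[k]
cov = M @ Q_big @ M.T                      # dense joint covariance

joint_mean = M @ np.concatenate([mu_0, np.zeros(4)])
theta_linear = Lam @ joint_mean
```

The zero tail of `theta_linear` is what lets the function fill in only the first `d` entries.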

Process noise

process_noise_covariance(A: Float[Array, 'N N'], Pinf: Float[Array, 'N N']) -> Float[Array, 'N N']

Compute process noise from stationary covariance.

Computes:

Q = Pinf - A @ Pinf @ A^T

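This is the process noise of a discrete-time state-space model with stationary covariance Pinf and transition matrix A: it is the Q for which Pinf = A @ Pinf @ A^T + Q, so the stationary covariance is preserved from one step to the next.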

Parameters:

Name Type Description Default
A Float[Array, 'N N']

State transition matrix, shape (N, N).

required
Pinf Float[Array, 'N N']

Stationary covariance, shape (N, N).

required

Returns:

Type Description
Float[Array, 'N N']

Process noise covariance Q, shape (N, N).

Source code in src/gaussx/_inference/_inference.py
def process_noise_covariance(
    A: Float[Array, "N N"],
    Pinf: Float[Array, "N N"],
) -> Float[Array, "N N"]:
    """Compute process noise from stationary covariance.

    Computes:

        Q = Pinf - A @ Pinf @ A^T

    For a discrete-time state-space model with stationary covariance
    ``Pinf`` and transition matrix ``A``.

    Args:
        A: State transition matrix, shape ``(N, N)``.
        Pinf: Stationary covariance, shape ``(N, N)``.

    Returns:
        Process noise covariance Q, shape ``(N, N)``.
    """
    return Pinf - A @ Pinf @ A.T
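
A scalar Ornstein-Uhlenbeck (Matérn-1/2) model gives a case where the result has a known closed form, \(Q = \sigma^2 (1 - e^{-2\Delta t/\ell})\). The NumPy sketch below re-implements the one-line formula under the hypothetical name `process_noise`; the variance, lengthscale and time step are made-up values.

```python
import numpy as np


def process_noise(A, Pinf):
    """Q = Pinf - A Pinf A^T, the noise that keeps Pinf stationary."""
    return Pinf - A @ Pinf @ A.T


# Assumed OU parameters: variance sigma^2, lengthscale ell, time step dt.
variance, lengthscale, dt = 2.0, 0.5, 0.1
A = np.array([[np.exp(-dt / lengthscale)]])   # discrete transition
Pinf = np.array([[variance]])                 # stationary covariance

Q = process_noise(A, Pinf)
closed_form = variance * (1.0 - np.exp(-2.0 * dt / lengthscale))
```

Propagating the stationary covariance one step, `A @ Pinf @ A.T + Q`, returns `Pinf` again, which is the defining property of this choice of Q.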