Bayesian Inference & Ensembles

Layer 3 recipes for conjugate updates, second-order variational steps, and ensemble data assimilation. All covariances are operators, so the updates inherit structured solves; all stochastic routines take explicit PRNG keys.

Bayesian linear regression

Closed-form Gaussian posterior updates — full covariance or diagonal-only — plus the marginal likelihood and expected log-likelihood that score them.

Structured linear algebra and Gaussian primitives for JAX.

blr_full_update(nat1: Float[Array, ' d'], nat2: Float[Array, 'd d'], grad: Float[Array, ' d'], hessian: Float[Array, 'd d'], lr: float, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' d'], Float[Array, 'd d']]

Full-rank natural parameter BLR update step.

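Computes the damped update for full-rank variational parameters. The mean is recovered from the current (pre-update) natural parameters, then both parameters are moved toward their Newton targets:

\mu = solve(-2 \cdot nat2, nat1)
nat1_{new} = (1 - lr) \cdot nat1 + lr \cdot (grad - H \mu)
nat2_{new} = (1 - lr) \cdot nat2 + lr \cdot \tfrac{1}{2} H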
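
The update above can be sketched with dense NumPy arrays. This is a minimal illustration, not the library's implementation: `blr_full_step` is a hypothetical stand-in, and the gradient and Hessian are assumed to be evaluated at the current mean. For a quadratic log-density a single undamped step (`lr=1`) should land on the exact Gaussian.

```python
import numpy as np

def blr_full_step(nat1, nat2, grad, hessian, lr):
    # Hypothetical dense stand-in that mirrors the formulas above.
    mu = np.linalg.solve(-2.0 * nat2, nat1)   # current mean from natural parameters
    nat1_target = grad - hessian @ mu         # Newton target for the first parameter
    nat2_target = 0.5 * hessian               # eta2 = -0.5 * precision, precision = -H
    return ((1.0 - lr) * nat1 + lr * nat1_target,
            (1.0 - lr) * nat2 + lr * nat2_target)

# Assumed toy target: log p(x) = -0.5 (x - m)^T P (x - m) + const.
P = np.array([[2.0, 0.5], [0.5, 1.0]])
m = np.array([1.0, -1.0])

# Start from a standard normal: mean 0, precision I.
nat1 = np.zeros(2)
nat2 = -0.5 * np.eye(2)
mu = np.linalg.solve(-2.0 * nat2, nat1)

grad = -P @ (mu - m)   # gradient of the toy log-density at mu
hessian = -P           # constant Hessian of the toy log-density

nat1_new, nat2_new = blr_full_step(nat1, nat2, grad, hessian, lr=1.0)
mean_new = np.linalg.solve(-2.0 * nat2_new, nat1_new)
precision_new = -2.0 * nat2_new
```

With `lr < 1` the same call moves only part of the way toward the target, which is the usual choice when the Hessian changes between iterations.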

Parameters:

Name Type Description Default
nat1 Float[Array, ' d']

Current natural location, shape (d,).

required
nat2 Float[Array, 'd d']

Current second natural parameter eta2 = -\tfrac{1}{2} \Lambda, where \Lambda is the precision matrix, shape (d, d).

required
grad Float[Array, ' d']

Gradient of log-likelihood, shape (d,).

required
hessian Float[Array, 'd d']

Hessian of log-likelihood (negative semi-definite for log-concave likelihoods), shape (d, d).

required
lr float

Learning rate / damping factor.

required
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, ' d'], Float[Array, 'd d']]

Tuple (nat1_new, nat2_new) — updated natural parameters.

Source code in src/gaussx/_inference/_blr.py
def blr_full_update(
    nat1: Float[Array, " d"],
    nat2: Float[Array, "d d"],
    grad: Float[Array, " d"],
    hessian: Float[Array, "d d"],
    lr: float,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " d"], Float[Array, "d d"]]:
    r"""Full-rank natural parameter BLR update step.

    Computes the damped update for full-rank variational parameters:

        nat2_{new} = (1 - lr) \cdot nat2 + lr \cdot (-\tfrac{1}{2}(-H))
        \mu = solve(-2 \cdot nat2, nat1)
        nat1_{new} = (1 - lr) \cdot nat1 + lr \cdot (grad - H \mu)

    Args:
        nat1: Current natural location, shape ``(d,)``.
        nat2: Current natural precision matrix (eta2), shape ``(d, d)``.
        grad: Gradient of log-likelihood, shape ``(d,)``.
        hessian: Hessian of log-likelihood (negative for log-concave),
            shape ``(d, d)``.
        lr: Learning rate / damping factor.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(nat1_new, nat2_new)`` — updated natural parameters.
    """
    # Current mean from natural parameters: mu = solve(-2*eta2, eta1)
    Lambda = -2.0 * nat2
    Lambda_op = lx.MatrixLinearOperator(Lambda, lx.positive_semidefinite_tag)
    mu = dispatch_solve(Lambda_op, nat1, solver)

    # Target natural parameters from Newton step
    nat1_target = grad - hessian @ mu
    nat2_target = 0.5 * hessian  # eta2 = -0.5 * (-H) = 0.5 * H

    # Damped update
    nat1_new = (1.0 - lr) * nat1 + lr * nat1_target
    nat2_new = (1.0 - lr) * nat2 + lr * nat2_target

    return nat1_new, nat2_new

blr_diag_update(nat1: Float[Array, ' d'], nat2_diag: Float[Array, ' d'], grad: Float[Array, ' d'], hessian_diag: Float[Array, ' d'], lr: float) -> tuple[Float[Array, ' d'], Float[Array, ' d']]

Diagonal natural parameter BLR update step.

Computes the damped update for diagonal variational parameters:

\mu = nat1 / (-2 \cdot nat2)
eta2_{target} = -\tfrac{1}{2}(-hessian\_diag) = 0.5 \cdot hessian\_diag
eta1_{target} = grad - hessian\_diag \cdot \mu
nat1_{new} = (1 - lr) \cdot nat1 + lr \cdot eta1_{target}
nat2_{new} = (1 - lr) \cdot nat2 + lr \cdot eta2_{target}

where nat2 (eta2) stores -\tfrac{1}{2} \lambda with \lambda = -hessian\_diag (diagonal precision).
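
As a rough sketch of how the damped diagonal step behaves over several iterations, the snippet below applies the formulas above with NumPy. `blr_diag_step` is a hypothetical stand-in, and the toy target (independent Gaussians with known precision) is an assumption chosen so the fixed point is known in closed form.

```python
import numpy as np

def blr_diag_step(nat1, nat2_diag, grad, hessian_diag, lr):
    # Hypothetical elementwise stand-in for the diagonal update above.
    mu = nat1 / (-2.0 * nat2_diag)
    nat1_target = grad - hessian_diag * mu
    nat2_target = 0.5 * hessian_diag
    return ((1.0 - lr) * nat1 + lr * nat1_target,
            (1.0 - lr) * nat2_diag + lr * nat2_target)

# Assumed toy target: independent Gaussians with these precisions and means.
prec_true = np.array([4.0, 1.0])
mean_true = np.array([0.5, -2.0])

nat1 = np.zeros(2)           # start at mean 0
nat2 = -0.5 * np.ones(2)     # start at unit precision

for _ in range(50):
    mu = nat1 / (-2.0 * nat2)
    grad = -prec_true * (mu - mean_true)   # gradient at the current mean
    hess = -prec_true                      # diagonal Hessian (non-positive)
    nat1, nat2 = blr_diag_step(nat1, nat2, grad, hess, lr=0.5)

mean_fit = nat1 / (-2.0 * nat2)
prec_fit = -2.0 * nat2
```

Because the targets are constant for this toy problem, the gap to the fixed point halves on every iteration at `lr=0.5`.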

Parameters:

Name Type Description Default
nat1 Float[Array, ' d']

Current natural location, shape (d,).

required
nat2_diag Float[Array, ' d']

Current diagonal of the second natural parameter eta2 = -\tfrac{1}{2} \lambda, shape (d,).

required
grad Float[Array, ' d']

Gradient of log-likelihood, shape (d,).

required
hessian_diag Float[Array, ' d']

Diagonal of the Hessian (non-positive for log-concave likelihoods), shape (d,).

required
lr float

Learning rate / damping factor.

required

Returns:

Type Description
tuple[Float[Array, ' d'], Float[Array, ' d']]

Tuple (nat1_new, nat2_new) — updated natural parameters.

Source code in src/gaussx/_inference/_blr.py
def blr_diag_update(
    nat1: Float[Array, " d"],
    nat2_diag: Float[Array, " d"],
    grad: Float[Array, " d"],
    hessian_diag: Float[Array, " d"],
    lr: float,
) -> tuple[Float[Array, " d"], Float[Array, " d"]]:
    r"""Diagonal natural parameter BLR update step.

    Computes the damped update for diagonal variational parameters:

        \mu = nat1 / (-2 \cdot nat2)
        eta2_{target} = -\tfrac{1}{2}(-hessian\_diag) = 0.5 \cdot hessian\_diag
        eta1_{target} = grad - hessian\_diag \cdot \mu
        nat1_{new} = (1 - lr) \cdot nat1 + lr \cdot eta1_{target}
        nat2_{new} = (1 - lr) \cdot nat2 + lr \cdot eta2_{target}

    where ``nat2`` (eta2) stores ``-\tfrac{1}{2} \lambda`` with
    ``\lambda = -hessian\_diag`` (diagonal precision).

    Args:
        nat1: Current natural location, shape ``(d,)``.
        nat2_diag: Current diagonal natural precision (eta2), shape ``(d,)``.
        grad: Gradient of log-likelihood, shape ``(d,)``.
        hessian_diag: Diagonal of Hessian (negative for log-concave),
            shape ``(d,)``.
        lr: Learning rate / damping factor.

    Returns:
        Tuple ``(nat1_new, nat2_new)`` — updated natural parameters.
    """
    # Current mean from natural parameters
    mu = nat1 / (-2.0 * nat2_diag)

    # Target natural parameters from Newton step
    nat1_target = grad - hessian_diag * mu
    nat2_target = -0.5 * (-hessian_diag)  # eta2 = -0.5 * (-H) = 0.5 * H

    # Damped update
    nat1_new = (1.0 - lr) * nat1 + lr * nat1_target
    nat2_new = (1.0 - lr) * nat2_diag + lr * nat2_target

    return nat1_new, nat2_new

log_marginal_likelihood(loc: Float[Array, ' N'], cov_operator: lx.AbstractLinearOperator, y: Float[Array, ' N'], *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']

GP log marginal likelihood.

Computes:

log p(y) = -0.5 * (y-mu)^T K^{-1} (y-mu) - 0.5 * log|K| - N/2 * log(2pi)

Delegates to gaussx.gaussian_log_prob.
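
For orientation, the formula above can be evaluated with a dense Cholesky factor. This is a sketch under the assumption of a small dense K. `gaussian_log_marginal` is a hypothetical helper, not the structured solve that gaussx dispatches to.

```python
import numpy as np

def gaussian_log_marginal(loc, K, y):
    # Hypothetical dense reference for the formula above.
    n = y.shape[-1]
    L = np.linalg.cholesky(K)            # K = L L^T
    z = np.linalg.solve(L, y - loc)      # z^T z = (y - loc)^T K^{-1} (y - loc)
    quad = z @ z
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (quad + logdet + n * np.log(2.0 * np.pi))

# Diagonal K, so the result is a sum of two univariate normal log-densities.
K = np.diag([2.0, 0.5])
loc = np.array([0.0, 1.0])
y = np.array([1.0, 1.0])
value = gaussian_log_marginal(loc, K, y)

# Hand computation: quad = 0.5, log|K| = log(2 * 0.5) = 0.
expected = -0.25 - np.log(2.0 * np.pi)
```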

Parameters:

Name Type Description Default
loc Float[Array, ' N']

Prior mean, shape (N,).

required
cov_operator AbstractLinearOperator

Covariance operator K, shape (N, N).

required
y Float[Array, ' N']

Observations, shape (N,).

required
solver AbstractSolverStrategy | None

Optional solver strategy. When None, uses structural dispatch.

None

Returns:

Type Description
Float[Array, '']

Scalar log marginal likelihood.

Source code in src/gaussx/_inference/_inference.py
def log_marginal_likelihood(
    loc: Float[Array, " N"],
    cov_operator: lx.AbstractLinearOperator,
    y: Float[Array, " N"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> Float[Array, ""]:
    """GP log marginal likelihood.

    Computes:

        log p(y) = -0.5 * (y-mu)^T K^{-1} (y-mu) - 0.5 * log|K| - N/2 * log(2pi)

    Delegates to `gaussx.gaussian_log_prob`.

    Args:
        loc: Prior mean, shape ``(N,)``.
        cov_operator: Covariance operator K, shape ``(N, N)``.
        y: Observations, shape ``(N,)``.
        solver: Optional solver strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Scalar log marginal likelihood.
    """
    return gaussian_log_prob(loc, cov_operator, y, solver=solver)

gaussian_expected_log_lik(y: Float[Array, ' N'], q_mu: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator, noise: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']

Expected log-likelihood E_q[log N(y | f, R)].

Computes:

E_q[log N(y|f,R)] = log N(y | q_mu, R) - 0.5 * tr(R^{-1} q_cov)

Core to variational inference ELBO computation.
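
The identity above splits the expectation into a Gaussian log-density at the variational mean plus a trace penalty for the variational spread. A dense NumPy sketch, with `expected_log_lik` as a hypothetical helper name, makes the two pieces visible:

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)

def expected_log_lik(y, q_mu, q_cov, R):
    # Hypothetical dense reference: log N(y | q_mu, R) - 0.5 * tr(R^{-1} q_cov).
    n = y.shape[-1]
    residual = y - q_mu
    quad = residual @ np.linalg.solve(R, residual)
    _, logdet = np.linalg.slogdet(R)
    tr_term = np.trace(np.linalg.solve(R, q_cov))
    return -0.5 * (n * LOG_2PI + logdet + quad + tr_term)

y = np.array([1.0, 0.0])
q_mu = np.array([0.0, 1.0])
q_cov = np.array([[1.0, 0.3], [0.3, 0.5]])
R = np.diag([0.5, 2.0])

value = expected_log_lik(y, q_mu, q_cov, R)

# Hand computation: log|R| = 0, quad = 1/0.5 + 1/2 = 2.5,
# tr(R^{-1} q_cov) = 1.0/0.5 + 0.5/2.0 = 2.25.
expected = -LOG_2PI - 0.5 * (2.5 + 2.25)
```

Only the diagonal of `q_cov` matters here because R is diagonal, which is why the off-diagonal 0.3 does not appear in the hand computation.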

Parameters:

Name Type Description Default
y Float[Array, ' N']

Observations, shape (N,).

required
q_mu Float[Array, ' N']

Variational mean, shape (N,).

required
q_cov AbstractLinearOperator

Variational covariance operator, shape (N, N).

required
noise AbstractLinearOperator

Noise covariance operator R, shape (N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy. When None, uses structural dispatch.

None

Returns:

Type Description
Float[Array, '']

Scalar expected log-likelihood.

Source code in src/gaussx/_inference/_inference.py
def gaussian_expected_log_lik(
    y: Float[Array, " N"],
    q_mu: Float[Array, " N"],
    q_cov: lx.AbstractLinearOperator,
    noise: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> Float[Array, ""]:
    r"""Expected log-likelihood ``E_q[log N(y | f, R)]``.

    Computes:

        E_q[log N(y|f,R)] = log N(y | q_mu, R) - 0.5 * tr(R^{-1} q_cov)

    Core to variational inference ELBO computation.

    Args:
        y: Observations, shape ``(N,)``.
        q_mu: Variational mean, shape ``(N,)``.
        q_cov: Variational covariance operator, shape ``(N, N)``.
        noise: Noise covariance operator R, shape ``(N, N)``.
        solver: Optional solver strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Scalar expected log-likelihood.
    """
    N = y.shape[-1]
    residual = y - q_mu
    alpha = dispatch_solve(noise, residual, solver)
    quad = residual @ alpha
    ld = dispatch_logdet(noise, solver)

    # Trace correction: tr(R^{-1} q_cov)
    R_inv = inv(noise)
    from gaussx._linalg._linalg import trace_product

    tr_term = trace_product(R_inv, q_cov)

    return -0.5 * (N * _LOG_2PI + ld + quad + tr_term)

Newton & natural-gradient updates

Second-order variational steps: Newton's method on the variational objective, Gauss-Newton curvature (exact diagonal or Hutchinson-estimated), damped natural-gradient steps, and the PSD projection that keeps Riemannian updates on the manifold.

newton_update(mean: Float[Array, ' N'], jacobian: Float[Array, ' N'], hessian: Float[Array, 'N N']) -> tuple[Float[Array, ' N'], Float[Array, 'N N']]

Convert a Newton step to natural pseudo-likelihood parameters.

Computes:

nat1 = jacobian - hessian @ mean
nat2 = -hessian

Used in Laplace/Newton-based approximate inference to convert function-space derivatives into site natural parameters.
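
A quick way to read these formulas is to feed in a Gaussian likelihood, where the site should reproduce the likelihood exactly. The sketch below assumes an isotropic noise variance. `newton_site` is a hypothetical stand-in for the two lines above.

```python
import numpy as np

def newton_site(mean, jacobian, hessian):
    # Hypothetical stand-in: nat1 = J - H @ mean, nat2 = -H.
    return jacobian - hessian @ mean, -hessian

# Assumed likelihood: log N(y | f, noise_var * I).
y = np.array([1.0, -0.5, 2.0])
noise_var = 0.25
f = np.array([0.2, 0.1, 1.5])            # current mean, arbitrary

jac = (y - f) / noise_var                # d/df log N(y | f, noise_var * I)
hess = -np.eye(3) / noise_var            # second derivative, negative definite

nat1, nat2 = newton_site(f, jac, hess)
```

The dependence on `f` cancels: `nat1` equals `y / noise_var` and `nat2` equals the noise precision, whatever the linearization point.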

Parameters:

Name Type Description Default
mean Float[Array, ' N']

Current mean, shape (N,).

required
jacobian Float[Array, ' N']

First derivative of log-likelihood, shape (N,).

required
hessian Float[Array, 'N N']

Second derivative (negative definite), shape (N, N).

required

Returns:

Type Description
tuple[Float[Array, ' N'], Float[Array, 'N N']]

Tuple (nat1, nat2) — site natural parameters.

Source code in src/gaussx/_inference/_inference.py
def newton_update(
    mean: Float[Array, " N"],
    jacobian: Float[Array, " N"],
    hessian: Float[Array, "N N"],
) -> tuple[Float[Array, " N"], Float[Array, "N N"]]:
    """Convert a Newton step to natural pseudo-likelihood parameters.

    Computes:

        nat1 = jacobian - hessian @ mean
        nat2 = -hessian

    Used in Laplace/Newton-based approximate inference to convert
    function-space derivatives into site natural parameters.

    Args:
        mean: Current mean, shape ``(N,)`` or ``(D,)``.
        jacobian: First derivative of log-likelihood, shape ``(N,)``.
        hessian: Second derivative (negative definite), shape ``(N, N)``.

    Returns:
        Tuple ``(nat1, nat2)`` — site natural parameters.
    """
    nat1 = jacobian - hessian @ mean
    nat2 = -hessian
    return nat1, nat2

damped_natural_update(nat1_old: Float[Array, ' d'], nat2_old: lx.AbstractLinearOperator | Float[Array, 'd d'], nat1_target: Float[Array, ' d'], nat2_target: lx.AbstractLinearOperator | Float[Array, 'd d'], lr: float = 1.0) -> tuple[Float[Array, ' d'], lx.AbstractLinearOperator | Float[Array, 'd d']]

Damped update in natural parameter space.

The universal primitive for iterative approximate inference (EP, VI, Newton, PL). Every method reduces to computing target natural parameters and applying this damped update:

nat1_{new} = (1 - lr) \cdot nat1_{old} + lr \cdot nat1_{target}
nat2_{new} = (1 - lr) \cdot nat2_{old} + lr \cdot nat2_{target}
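
The damped update is a convex combination, so repeated application with a fixed target converges geometrically at rate `1 - lr`. A small NumPy sketch with dense arrays only (`damped_step` is a hypothetical helper):

```python
import numpy as np

def damped_step(old, target, lr):
    # Convex combination of the old and target natural parameters.
    return (1.0 - lr) * old + lr * target

nat1_old = np.array([0.0, 0.0])
nat1_target = np.array([2.0, -4.0])
nat2_old = -0.5 * np.eye(2)
nat2_target = -0.5 * np.array([[3.0, 1.0], [1.0, 2.0]])

# One half step lands midway between old and target.
nat1_half = damped_step(nat1_old, nat1_target, lr=0.5)

# lr = 1 returns the target unchanged (undamped update).
nat2_full = damped_step(nat2_old, nat2_target, lr=1.0)

# Repeated half steps shrink the gap by a factor of 2 each time.
nat1 = nat1_old
for _ in range(20):
    nat1 = damped_step(nat1, nat1_target, lr=0.5)
gap = np.max(np.abs(nat1 - nat1_target))
```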

Parameters:

Name Type Description Default
nat1_old Float[Array, ' d']

Current natural location parameter.

required
nat2_old AbstractLinearOperator | Float[Array, 'd d']

Current natural precision-like parameter. Can be an array, BlockTriDiag, or any linear operator.

required
nat1_target Float[Array, ' d']

Target natural location parameter.

required
nat2_target AbstractLinearOperator | Float[Array, 'd d']

Target natural precision-like parameter.

required
lr float

Learning rate / damping factor. lr=1 gives the undamped update. Default 1.0.

1.0

Returns:

Type Description
tuple[Float[Array, ' d'], AbstractLinearOperator | Float[Array, 'd d']]

Tuple (nat1_new, nat2_new) with same types as inputs.

Source code in src/gaussx/_inference/_natural_gradient.py
def damped_natural_update(
    nat1_old: Float[Array, " d"],
    nat2_old: lx.AbstractLinearOperator | Float[Array, "d d"],
    nat1_target: Float[Array, " d"],
    nat2_target: lx.AbstractLinearOperator | Float[Array, "d d"],
    lr: float = 1.0,
) -> tuple[Float[Array, " d"], lx.AbstractLinearOperator | Float[Array, "d d"]]:
    r"""Damped update in natural parameter space.

    The universal primitive for iterative approximate inference
    (EP, VI, Newton, PL). Every method reduces to computing target
    natural parameters and applying this damped update:

        nat1_{new} = (1 - lr) \cdot nat1_{old} + lr \cdot nat1_{target}
        nat2_{new} = (1 - lr) \cdot nat2_{old} + lr \cdot nat2_{target}

    Args:
        nat1_old: Current natural location parameter.
        nat2_old: Current natural precision-like parameter.
            Can be an array, ``BlockTriDiag``, or any linear operator.
        nat1_target: Target natural location parameter.
        nat2_target: Target natural precision-like parameter.
        lr: Learning rate / damping factor. ``lr=1`` gives the
            undamped update. Default ``1.0``.

    Returns:
        Tuple ``(nat1_new, nat2_new)`` with same types as inputs.
    """
    nat1_new = (1.0 - lr) * nat1_old + lr * nat1_target

    if isinstance(nat2_old, jax.Array) and isinstance(nat2_target, jax.Array):
        nat2_new: lx.AbstractLinearOperator | Float[Array, "d d"] = (
            1.0 - lr
        ) * nat2_old + lr * nat2_target
    elif isinstance(nat2_old, BlockTriDiag) and isinstance(nat2_target, BlockTriDiag):
        nat2_new = (1.0 - lr) * nat2_old + lr * nat2_target
    elif isinstance(nat2_old, lx.AbstractLinearOperator) and isinstance(
        nat2_target, lx.AbstractLinearOperator
    ):
        nat2_new_mat = (1.0 - lr) * nat2_old.as_matrix() + lr * nat2_target.as_matrix()
        nat2_new = lx.MatrixLinearOperator(nat2_new_mat)
    else:
        msg = "nat2_old and nat2_target must be the same type"
        raise TypeError(msg)

    return nat1_new, nat2_new

gauss_newton_precision(jacobian: Float[Array, 'D_obs D_latent']) -> lx.AbstractLinearOperator

Gauss-Newton precision matrix J^T J.

For likelihoods with residual structure r(f), the Gauss-Newton Hessian approximation is -J_r^T J_r which gives precision \Lambda = J^T J (always PSD).

When D_{obs} < D_{latent}, returns a LowRankUpdate to enable efficient Woodbury-based solves downstream.
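
To illustrate why the low-rank form helps, the sketch below solves a system in `A + J^T J` with the Woodbury identity and compares it with a dense solve. The diagonal matrix `A` is an assumption added for the example (a stand-in for a prior precision), since `J^T J` alone is singular when D_obs < D_latent. None of this is the gaussx dispatch code.

```python
import numpy as np

rng = np.random.default_rng(0)
D_obs, D_latent = 2, 5
J = rng.normal(size=(D_obs, D_latent))
U = J.T                                   # (D_latent, D_obs), so J^T J = U U^T

a = np.full(D_latent, 0.5)                # assumed diagonal prior precision A
b = rng.normal(size=D_latent)

# Woodbury: (A + U U^T)^{-1} b
#   = A^{-1} b - A^{-1} U (I + U^T A^{-1} U)^{-1} U^T A^{-1} b
Ainv_b = b / a
Ainv_U = U / a[:, None]
capacitance = np.eye(D_obs) + U.T @ Ainv_U    # only (D_obs, D_obs)
x_woodbury = Ainv_b - Ainv_U @ np.linalg.solve(capacitance, U.T @ Ainv_b)

# Dense reference that materializes the (D_latent, D_latent) matrix.
x_dense = np.linalg.solve(np.diag(a) + J.T @ J, b)
```

The only matrix factorized on the Woodbury path has size D_obs, which is the saving the low-rank representation is meant to expose.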

Parameters:

Name Type Description Default
jacobian Float[Array, 'D_obs D_latent']

Jacobian of the residual, shape (D_obs, D_latent).

required

Returns:

Type Description
AbstractLinearOperator

PSD precision operator of shape (D_latent, D_latent).

Source code in src/gaussx/_inference/_natural_gradient.py
def gauss_newton_precision(
    jacobian: Float[Array, "D_obs D_latent"],
) -> lx.AbstractLinearOperator:
    r"""Gauss-Newton precision matrix ``J^T J``.

    For likelihoods with residual structure ``r(f)``, the Gauss-Newton
    Hessian approximation is ``-J_r^T J_r`` which gives precision
    ``\Lambda = J^T J`` (always PSD).

    When ``D_{obs} < D_{latent}``, returns a `LowRankUpdate`
    to enable efficient Woodbury-based solves downstream.

    Args:
        jacobian: Jacobian of the residual, shape ``(D_obs, D_latent)``.

    Returns:
        PSD precision operator of shape ``(D_latent, D_latent)``.
    """
    D_obs, D_latent = jacobian.shape

    if D_obs < D_latent:
        base = lx.DiagonalLinearOperator(jnp.zeros(D_latent))
        return LowRankUpdate(
            base=base,
            U=jacobian.T,
            d=jnp.ones(D_obs),
            tags=frozenset({lx.symmetric_tag, lx.positive_semidefinite_tag}),
        )

    return lx.MatrixLinearOperator(
        jacobian.T @ jacobian,
        lx.positive_semidefinite_tag,
    )

ggn_diagonal(jacobian: Float[Array, 'N d']) -> Float[Array, ' d']

Generalized Gauss-Newton diagonal approximation.

Computes \mathrm{diag}(J^T J)_j = \sum_i J_{ij}^2, the diagonal of the Gauss-Newton Hessian approximation. Every entry is non-negative, so a diagonal precision built from it is PSD.
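
A small NumPy check of the column-wise sum of squares against the diagonal of the dense product (illustrative values only):

```python
import numpy as np

J = np.array([[1.0, 2.0, 0.0],
              [0.0, -1.0, 3.0]])          # shape (N, d) = (2, 3)

ggn_diag = np.sum(J ** 2, axis=0)         # sum of squares down each column
dense_diag = np.diag(J.T @ J)             # reference: form J^T J, then take the diagonal
```

The first form costs O(N d) and never builds the (d, d) matrix.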

Parameters:

Name Type Description Default
jacobian Float[Array, 'N d']

Jacobian matrix, shape (N, d) where N is the number of observations and d is the parameter dimension.

required

Returns:

Type Description
Float[Array, ' d']

Diagonal of J^T J, shape (d,).

Source code in src/gaussx/_inference/_blr.py
def ggn_diagonal(
    jacobian: Float[Array, "N d"],
) -> Float[Array, " d"]:
    r"""Generalized Gauss-Newton diagonal approximation.

    Computes ``\mathrm{diag}(J^T J) = \sum_i J_{i,:}^2``, the diagonal
    of the Gauss-Newton Hessian approximation. Always non-negative,
    guaranteeing PSD precision updates.

    Args:
        jacobian: Jacobian matrix, shape ``(N, d)`` where N is the
            number of observations and d is the parameter dimension.

    Returns:
        Diagonal of ``J^T J``, shape ``(d,)``.
    """
    return reduce(jacobian**2, "K D -> D", "sum")

hutchinson_hessian_diag(hvp_fn: Callable[[Float[Array, ' d']], Float[Array, ' d']], key: jax.Array, d: int, n_samples: int = 1, dtype: DTypeLike | None = None) -> Float[Array, ' d']

Stochastic Hessian diagonal via Hutchinson with Rademacher probes.

Estimates \mathrm{diag}(H) using the identity \mathrm{diag}(H) = E[z \odot (H z)] where z is a Rademacher random vector (entries \pm 1 with equal probability).
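
The identity can be checked exactly on a tiny problem by averaging over every sign vector, and approximately with random probes. The sketch uses NumPy's generator rather than JAX PRNG keys. The matrix `H` and the probe count are arbitrary choices for the example.

```python
import itertools
import numpy as np

H = np.array([[2.0, -1.0, 0.5],
              [-1.0, 3.0, 0.25],
              [0.5, 0.25, 1.0]])
d = H.shape[0]

def hvp(v):
    # Hessian-vector product for the toy matrix.
    return H @ v

# Exact expectation: average z * (H z) over all 2^d Rademacher vectors.
all_signs = np.array(list(itertools.product([-1.0, 1.0], repeat=d)))
exact = np.mean([z * hvp(z) for z in all_signs], axis=0)

# Monte Carlo estimate with random probes.
rng = np.random.default_rng(0)
probes = rng.choice([-1.0, 1.0], size=(2000, d))
estimate = np.mean(probes * (probes @ H.T), axis=0)
```

The off-diagonal terms cancel because `E[z_i z_j] = 0` for `i != j`, while `z_i^2 = 1` keeps the diagonal. For a diagonal `H` a single probe is already exact.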

Parameters:

Name Type Description Default
hvp_fn Callable[[Float[Array, ' d']], Float[Array, ' d']]

Hessian-vector product function v -> H @ v.

required
key Array

PRNG key for random probe generation.

required
d int

Dimension of the Hessian.

required
n_samples int

Number of random probes. More samples give better estimates. Default 1.

1
dtype DTypeLike | None

Floating-point dtype for the Rademacher probes. Defaults to the current JAX default floating dtype.

None

Returns:

Type Description
Float[Array, ' d']

Estimated diagonal of the Hessian, shape (d,).

Source code in src/gaussx/_inference/_blr.py
def hutchinson_hessian_diag(
    hvp_fn: Callable[[Float[Array, " d"]], Float[Array, " d"]],
    key: jax.Array,
    d: int,
    n_samples: int = 1,
    dtype: DTypeLike | None = None,
) -> Float[Array, " d"]:
    r"""Stochastic Hessian diagonal via Hutchinson with Rademacher probes.

    Estimates ``\mathrm{diag}(H)`` using the identity
    ``\mathrm{diag}(H) = E[z \odot (H z)]`` where ``z`` is a
    Rademacher random vector (entries ``\pm 1`` with equal probability).

    Args:
        hvp_fn: Hessian-vector product function ``v -> H @ v``.
        key: PRNG key for random probe generation.
        d: Dimension of the Hessian.
        n_samples: Number of random probes. More samples give better
            estimates. Default ``1``.
        dtype: Floating-point dtype for the Rademacher probes. Defaults to the
            current JAX default floating dtype.

    Returns:
        Estimated diagonal of the Hessian, shape ``(d,)``.
    """

    probe_dtype = jnp.dtype(jnp.asarray(0.0).dtype if dtype is None else dtype)

    def _single_probe(k):
        z = jnp.where(
            jax.random.bernoulli(k, shape=(d,)),
            jnp.array(1.0, dtype=probe_dtype),
            jnp.array(-1.0, dtype=probe_dtype),
        )
        return z * hvp_fn(z)

    keys = jax.random.split(key, n_samples)
    estimates = jax.vmap(_single_probe)(keys)
    return jnp.mean(estimates, axis=0)

riemannian_psd_correction(hessian: Float[Array, 'd d'], site_precision: Float[Array, 'd d'], site_covariance: Float[Array, 'd d'], lr: float = 1.0) -> Float[Array, 'd d']

Riemannian gradient correction for PSD precision updates.

Ensures the corrected Hessian remains negative semi-definite, stabilizing Newton/EP/VI when the raw Hessian is indefinite:

G = site\_precision + hessian
H_{psd} = hessian - 0.5 \cdot lr \cdot G \cdot S \cdot G

where S is the site covariance.
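
One case where the effect of the correction can be verified by hand is `lr = 1` with `S` equal to the inverse of a positive-definite site precision `P`. Expanding `G S G` then gives `H_psd = -0.5 (P + H P^{-1} H)`, which is negative semi-definite even for an indefinite `H`. The sketch below checks that special case only. `riemannian_correction` is a hypothetical stand-in, and other settings of `lr` or `S` are not covered by this algebra.

```python
import numpy as np

def riemannian_correction(hessian, site_precision, site_covariance, lr=1.0):
    # Hypothetical stand-in for the two formulas above.
    G = site_precision + hessian
    return hessian - 0.5 * lr * (G @ site_covariance @ G)

P = np.diag([2.0, 1.0])                   # site precision, positive definite
S = np.linalg.inv(P)                      # assumed: site covariance = P^{-1}
H = np.diag([1.0, -3.0])                  # indefinite raw Hessian

H_psd = riemannian_correction(H, P, S, lr=1.0)
closed_form = -0.5 * (P + H @ S @ H)
eigs = np.linalg.eigvalsh(H_psd)
```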

Parameters:

Name Type Description Default
hessian Float[Array, 'd d']

Raw second derivative, shape (d, d).

required
site_precision Float[Array, 'd d']

Current site precision, shape (d, d).

required
site_covariance Float[Array, 'd d']

Current site covariance, shape (d, d).

required
lr float

Learning rate. Default 1.0.

1.0

Returns:

Type Description
Float[Array, 'd d']

Corrected Hessian, shape (d, d).

Source code in src/gaussx/_inference/_natural_gradient.py
def riemannian_psd_correction(
    hessian: Float[Array, "d d"],
    site_precision: Float[Array, "d d"],
    site_covariance: Float[Array, "d d"],
    lr: float = 1.0,
) -> Float[Array, "d d"]:
    r"""Riemannian gradient correction for PSD precision updates.

    Ensures the corrected Hessian remains negative semi-definite,
    stabilizing Newton/EP/VI when the raw Hessian is indefinite:

        G = site\_precision + hessian
        H_{psd} = hessian - 0.5 \cdot lr \cdot G \cdot S \cdot G

    where ``S`` is the site covariance.

    Args:
        hessian: Raw second derivative, shape ``(d, d)``.
        site_precision: Current site precision, shape ``(d, d)``.
        site_covariance: Current site covariance, shape ``(d, d)``.
        lr: Learning rate. Default ``1.0``.

    Returns:
        Corrected Hessian, shape ``(d, d)``.
    """
    G = site_precision + hessian
    correction = G @ site_covariance @ G
    return hessian - 0.5 * lr * correction

cavity_distribution(post_mean: Float[Array, ' N'], post_cov: lx.AbstractLinearOperator, site_nat1: Float[Array, ' N'], site_nat2: lx.AbstractLinearOperator, power: float = 1.0) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]

Compute EP cavity distribution by removing a site.

Computes:

cav_prec = post_prec - power * site_nat2
cav_cov  = inv(cav_prec)
cav_mean = cav_cov @ (post_prec @ post_mean - power * site_nat1)
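
A useful sanity check for these formulas: if the posterior is a Gaussian prior multiplied by a single Gaussian site, removing that site with `power=1` should return the prior. The dense NumPy sketch below assumes exactly that setup. `cavity` is a hypothetical stand-in.

```python
import numpy as np

def cavity(post_mean, post_cov, site_nat1, site_nat2, power=1.0):
    # Hypothetical dense stand-in for the three formulas above.
    post_prec = np.linalg.inv(post_cov)
    cav_prec = post_prec - power * site_nat2
    cav_cov = np.linalg.inv(cav_prec)
    cav_mean = cav_cov @ (post_prec @ post_mean - power * site_nat1)
    return cav_mean, cav_cov

prior_mean = np.array([0.5, -1.0])
prior_cov = np.array([[1.0, 0.2], [0.2, 2.0]])
site_nat1 = np.array([1.0, 0.0])          # precision-weighted site mean
site_nat2 = np.diag([0.5, 0.25])          # site precision

# Posterior = prior * site, combined in natural parameters.
prior_prec = np.linalg.inv(prior_cov)
post_cov = np.linalg.inv(prior_prec + site_nat2)
post_mean = post_cov @ (prior_prec @ prior_mean + site_nat1)

cav_mean, cav_cov = cavity(post_mean, post_cov, site_nat1, site_nat2)
```

With `power < 1` only a fraction of the site is removed, so the cavity sits between the prior and the posterior.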

Parameters:

Name Type Description Default
post_mean Float[Array, ' N']

Posterior mean, shape (N,).

required
post_cov AbstractLinearOperator

Posterior covariance operator.

required
site_nat1 Float[Array, ' N']

Site natural parameter (precision-weighted mean).

required
site_nat2 AbstractLinearOperator

Site natural parameter (precision).

required
power float

Power EP fraction (default 1.0 for standard EP).

1.0

Returns:

Type Description
tuple[Float[Array, ' N'], AbstractLinearOperator]

Tuple (cav_mean, cav_cov).

Source code in src/gaussx/_inference/_inference.py
def cavity_distribution(
    post_mean: Float[Array, " N"],
    post_cov: lx.AbstractLinearOperator,
    site_nat1: Float[Array, " N"],
    site_nat2: lx.AbstractLinearOperator,
    power: float = 1.0,
) -> tuple[Float[Array, " N"], lx.AbstractLinearOperator]:
    """Compute EP cavity distribution by removing a site.

    Computes:

        cav_prec = post_prec - power * site_nat2
        cav_cov  = inv(cav_prec)
        cav_mean = cav_cov @ (post_prec @ post_mean - power * site_nat1)

    Args:
        post_mean: Posterior mean, shape ``(N,)``.
        post_cov: Posterior covariance operator.
        site_nat1: Site natural parameter (precision-weighted mean).
        site_nat2: Site natural parameter (precision).
        power: Power EP fraction (default 1.0 for standard EP).

    Returns:
        Tuple ``(cav_mean, cav_cov)``.
    """
    post_prec = inv(post_cov)
    cav_prec_mat = post_prec.as_matrix() - power * site_nat2.as_matrix()
    cav_prec = lx.MatrixLinearOperator(cav_prec_mat)
    cav_cov = inv(cav_prec)

    eta1_cav = post_prec.mv(post_mean) - power * site_nat1
    cav_mean = cav_cov.mv(eta1_cav)

    return cav_mean, cav_cov

trace_correction(K_xx: lx.AbstractLinearOperator, K_xz: Float[Array, 'N M'], K_zz: lx.AbstractLinearOperator, *, solver: AbstractSolveStrategy | None = None) -> Float[Array, '']

Trace term in Titsias collapsed ELBO.

Computes:

tr(K_xx) - tr(K_xz K_zz^{-1} K_xz^T)

This is the "trace correction" that penalizes the Nyström approximation error.
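
The Nyström residual trace `tr(K_xx) - tr(K_xz K_zz^{-1} K_xz^T)` can be computed without forming the (N, N) approximation, using `tr(A^T B) = sum(A * B)`. The sketch below is a dense NumPy illustration. The RBF kernel, the inputs, and the jitter are all assumptions for the example.

```python
import numpy as np

def rbf(a, b, lengthscale=1.0):
    # Squared-exponential kernel on 1-D inputs (assumed for illustration).
    sq = (a[:, None] - b[None, :]) ** 2
    return np.exp(-0.5 * sq / lengthscale ** 2)

x = np.linspace(0.0, 4.0, 6)              # N = 6 data inputs
z = np.array([0.5, 2.0, 3.5])             # M = 3 inducing inputs

K_xx = rbf(x, x)
K_xz = rbf(x, z)                          # (N, M)
K_zz = rbf(z, z) + 1e-8 * np.eye(3)       # small jitter for stability

# Row i of W is K_zz^{-1} applied to row i of K_xz, i.e. W = K_xz K_zz^{-1}.
W = np.linalg.solve(K_zz, K_xz.T).T       # (N, M)
correction = np.trace(K_xx) - np.sum(K_xz * W)

# Dense reference that materializes the (N, N) Nystrom matrix.
dense = np.trace(K_xx - K_xz @ np.linalg.solve(K_zz, K_xz.T))
```

The correction is non-negative because the Nyström matrix is dominated by `K_xx` in the PSD order. It shrinks toward zero as the inducing inputs cover the data.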

Parameters:

Name Type Description Default
K_xx AbstractLinearOperator

Full covariance, shape (N, N).

required
K_xz Float[Array, 'N M']

Cross-covariance, shape (N, M).

required
K_zz AbstractLinearOperator

Inducing covariance, shape (M, M).

required
solver AbstractSolveStrategy | None

Optional solve strategy. When None, uses structural dispatch.

None

Returns:

Type Description
Float[Array, '']

Scalar trace correction.

Source code in src/gaussx/_inference/_inference.py
def trace_correction(
    K_xx: lx.AbstractLinearOperator,
    K_xz: Float[Array, "N M"],
    K_zz: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, ""]:
    """Trace term in Titsias collapsed ELBO.

    Computes:

        tr(K_xx) - tr(K_xz K_zz^{-1} K_xz^T)

    This is the "trace correction" that penalizes the Nystrom
    approximation error.

    Args:
        K_xx: Full covariance, shape ``(N, N)``.
        K_xz: Cross-covariance, shape ``(N, M)``.
        K_zz: Inducing covariance, shape ``(M, M)``.
        solver: Optional solve strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Scalar trace correction.
    """
    tr_full = trace(K_xx)

    # tr(K_xz K_zz^{-1} K_xz^T) = sum_ij K_xz_ij * W_ij, using tr(A^T B) = sum(A * B),
    # where row i of W is K_zz^{-1} applied to row i of K_xz.
    from gaussx._linalg._linalg import solve_rows

    W = solve_rows(K_zz, K_xz, solver=solver)  # (N, M)
    tr_approx = jnp.sum(K_xz * W)

    return tr_full - tr_approx

Ensemble covariances & Kalman gain

Empirical (cross-)covariances from ensemble members, with an optional Bessel correction, and the ensemble Kalman gain built from them.

ensemble_covariance(particles: Float[Array, 'J N'], *, bessel: bool = False) -> LowRankUpdate

Empirical covariance from an ensemble as a low-rank operator.

Returns C = c X'^T X', where X' is the (J, N) matrix of mean-centered particles, with c = 1 / J when bessel=False (default, maximum likelihood) and c = 1 / (J - 1) when bessel=True (unbiased / ensemble Kalman filter convention). The result is a LowRankUpdate of rank <= J-1, so the full (N, N) matrix is never materialized. This is efficient when J << N.
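
The low-rank factor behind this operator can be sketched in NumPy: scale the centered particles so that `C = U U^T`, then apply `C` to a vector through `U` alone. `low_rank_factor` is a hypothetical helper, and `np.cov` serves only as a dense reference.

```python
import numpy as np

def low_rank_factor(particles, bessel=False):
    # Hypothetical helper: returns U of shape (N, J) with C = U @ U.T.
    J = particles.shape[0]
    deviations = particles - particles.mean(axis=0, keepdims=True)
    divisor = J - 1 if bessel else J
    return deviations.T / np.sqrt(divisor)

rng = np.random.default_rng(0)
particles = rng.normal(size=(4, 6))       # J = 4 members, N = 6 state dims

U_ml = low_rank_factor(particles)                 # 1 / J scaling
U_unb = low_rank_factor(particles, bessel=True)   # 1 / (J - 1) scaling

# Matrix-vector product without forming the (N, N) covariance.
v = rng.normal(size=6)
Cv = U_unb @ (U_unb.T @ v)

cov_ml = np.cov(particles, rowvar=False, bias=True)
cov_unb = np.cov(particles, rowvar=False)
```

The columns of `U` sum to zero because the deviations are centered, which is why the rank is at most `J - 1` rather than `J`.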

Parameters:

Name Type Description Default
particles Float[Array, 'J N']

Ensemble of shape (J, N).

required
bessel bool

If True, apply the 1 / (J - 1) Bessel correction used throughout the ensemble Kalman filter literature. This lower-level helper defaults to False for backwards compatibility; ensemble_kalman_gain defaults to True for the EnKF convention.

False

Returns:

Type Description
LowRankUpdate

A LowRankUpdate operator representing the empirical covariance, with a zero base and J-column low-rank factor.

Source code in src/gaussx/_inference/_ensemble.py
def ensemble_covariance(
    particles: Float[Array, "J N"],
    *,
    bessel: bool = False,
) -> LowRankUpdate:
    r"""Empirical covariance from an ensemble as a low-rank operator.

    Returns ``C = c X'^T X'`` with ``c = 1 / J`` when ``bessel=False``
    (default, maximum likelihood) and ``c = 1 / (J - 1)`` when
    ``bessel=True`` (unbiased / ensemble Kalman filter convention).
    The result is a ``LowRankUpdate`` of rank ``<= J-1`` rather than
    materializing the full ``(N, N)`` matrix.  Efficient when
    ``J << N``.

    Args:
        particles: Ensemble of shape ``(J, N)``.
        bessel: If True, apply the ``1 / (J - 1)`` Bessel correction
            used throughout the ensemble Kalman filter literature. This
            lower-level helper defaults to False for backwards compatibility;
            `ensemble_kalman_gain` defaults to True for the EnKF
            convention.

    Returns:
        A ``LowRankUpdate`` operator representing the empirical
        covariance, with a zero base and ``J``-column low-rank factor.
    """
    J, N = particles.shape
    _check_ensemble_size(J, bessel)
    mean = jnp.mean(particles, axis=0)
    deviations = particles - mean[None, :]  # (J, N)

    divisor = J - 1 if bessel else J
    U = deviations.T / jnp.sqrt(divisor)  # (N, J)

    base = lx.DiagonalLinearOperator(jnp.zeros(N, dtype=particles.dtype))
    return LowRankUpdate(base, U)

ensemble_cross_covariance(particles_theta: Float[Array, 'J N'], particles_G: Float[Array, 'J M'], *, bessel: bool = False) -> Float[Array, 'N M']

Cross-covariance between two ensemble sets.

Computes C^{\theta,G} = c \sum_j (\theta_j - \bar{\theta})(G_j - \bar{G})^T, where \bar{\theta} and \bar{G} are the ensemble means, with c = 1 / J by default or c = 1 / (J - 1) when bessel=True.
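
A short NumPy sketch of the cross-covariance, using a linear forward map so the result can be checked against the state covariance. `cross_cov` and the map `A` are hypothetical names introduced for the example.

```python
import numpy as np

def cross_cov(theta, G, bessel=False):
    # Hypothetical helper: centered outer-product average, shape (N, M).
    J = theta.shape[0]
    dev_theta = theta - theta.mean(axis=0, keepdims=True)
    dev_G = G - G.mean(axis=0, keepdims=True)
    divisor = J - 1 if bessel else J
    return dev_theta.T @ dev_G / divisor

rng = np.random.default_rng(1)
theta = rng.normal(size=(5, 3))           # J = 5 members, N = 3 parameters

# Assumed linear forward map (M = 2); the constant offset drops out on centering.
A = np.array([[1.0, 0.0, 2.0],
              [0.0, -1.0, 1.0]])
G = theta @ A.T + 0.7

C_tG = cross_cov(theta, G, bessel=True)
C_tt = cross_cov(theta, theta, bessel=True)
```

For a linear map the cross-covariance reduces to `C_tt @ A.T`. For a nonlinear forward model the ensemble version acts as a derivative-free surrogate for that product.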

Parameters:

Name Type Description Default
particles_theta Float[Array, 'J N']

First ensemble, shape (J, N).

required
particles_G Float[Array, 'J M']

Second ensemble, shape (J, M).

required
bessel bool

If True, apply the 1 / (J - 1) Bessel correction used by ensemble Kalman filter recipes. This lower-level helper defaults to False for backwards compatibility; ensemble_kalman_gain defaults to True for the EnKF convention.

False

Returns:

Type Description
Float[Array, 'N M']

Cross-covariance array of shape (N, M).

Source code in src/gaussx/_inference/_ensemble.py
def ensemble_cross_covariance(
    particles_theta: Float[Array, "J N"],
    particles_G: Float[Array, "J M"],
    *,
    bessel: bool = False,
) -> Float[Array, "N M"]:
    r"""Cross-covariance between two ensemble sets.

    Computes ``C^{theta,G} = c sum_j (theta_j - bar)(G_j - bar)^T``
    with ``c = 1 / J`` by default or ``c = 1 / (J - 1)`` when
    ``bessel=True``.

    Args:
        particles_theta: First ensemble, shape ``(J, N)``.
        particles_G: Second ensemble, shape ``(J, M)``.
        bessel: If True, apply the ``1 / (J - 1)`` Bessel correction
            used by ensemble Kalman filter recipes. This lower-level helper
            defaults to False for backwards compatibility; `ensemble_kalman_gain`
            defaults to True for the EnKF convention.

    Returns:
        Cross-covariance array of shape ``(N, M)``.
    """
    J = particles_theta.shape[0]
    _check_ensemble_size(J, bessel)
    dev_theta = particles_theta - jnp.mean(particles_theta, axis=0, keepdims=True)
    dev_G = particles_G - jnp.mean(particles_G, axis=0, keepdims=True)
    divisor = J - 1 if bessel else J
    return (dev_theta.T @ dev_G) / divisor

ensemble_kalman_gain(particles: Float[Array, 'J N'], obs_particles: Float[Array, 'J M'], obs_noise: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None, bessel: bool = True) -> Float[Array, 'N M']

Kalman gain from an ensemble and its image in observation space.

Computes K = C^{xH} (C^{HH} + R)^{-1}, where C^{xH} is the state-observation cross-covariance and C^{HH} is the observation-space ensemble covariance. The innovation covariance S = C^{HH} + R is assembled as a LowRankUpdate so solve_rows can use structural dispatch via the Woodbury identity.
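
The gain formula can be sketched densely in NumPy and checked against the textbook expression `C H^T (H C H^T + R)^{-1}` when the observation operator is linear. `enkf_gain` and `H_mat` are hypothetical names, and the dense solve replaces the Woodbury path described above.

```python
import numpy as np

def enkf_gain(particles, obs_particles, R, bessel=True):
    # Hypothetical dense stand-in for K = C^{xH} (C^{HH} + R)^{-1}.
    J = particles.shape[0]
    dx = particles - particles.mean(axis=0, keepdims=True)
    dh = obs_particles - obs_particles.mean(axis=0, keepdims=True)
    divisor = J - 1 if bessel else J
    C_xh = dx.T @ dh / divisor            # (N, M) cross-covariance
    C_hh = dh.T @ dh / divisor            # (M, M) observation-space covariance
    S = C_hh + R                          # innovation covariance, symmetric
    # K = C_xh S^{-1}; with S symmetric, solve S K^T = C_xh^T instead.
    return np.linalg.solve(S, C_xh.T).T

rng = np.random.default_rng(2)
particles = rng.normal(size=(8, 4))       # J = 8 members, N = 4 state dims

# Assumed linear observation operator picking state components 0 and 2.
H_mat = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0]])
obs_particles = particles @ H_mat.T
R = 0.1 * np.eye(2)

K = enkf_gain(particles, obs_particles, R)

# Reference built from the dense state covariance.
C = np.cov(particles, rowvar=False)
K_ref = C @ H_mat.T @ np.linalg.inv(H_mat @ C @ H_mat.T + R)
```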

Parameters:

Name Type Description Default
particles Float[Array, 'J N']

Prior ensemble in state space, shape (J, N).

required
obs_particles Float[Array, 'J M']

Prior ensemble in observation space, shape (J, M).

required
obs_noise AbstractLinearOperator

Observation error covariance operator, shape (M, M).

required
solver AbstractSolverStrategy | None

Optional solver strategy. None uses structural dispatch.

None
bessel bool

Defaults to True, unlike the lower-level covariance helpers, because this recipe follows the unbiased EnKF convention. Use False for maximum-likelihood recipes with a 1 / J divisor.

True

Returns:

Type Description
Float[Array, 'N M']

Dense Kalman gain of shape (N, M).

Source code in src/gaussx/_inference/_ensemble.py
def ensemble_kalman_gain(
    particles: Float[Array, "J N"],
    obs_particles: Float[Array, "J M"],
    obs_noise: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
    bessel: bool = True,
) -> Float[Array, "N M"]:
    r"""Kalman gain from an ensemble and its image in observation space.

    Computes ``K = C^{xH} (C^{HH} + R)^{-1}``, where ``C^{xH}`` is the
    state-observation cross-covariance and ``C^{HH}`` is the
    observation-space ensemble covariance. The innovation covariance
    ``S = C^{HH} + R`` is assembled as a ``LowRankUpdate`` so
    ``solve_rows`` can use structural dispatch via the Woodbury identity.

    Args:
        particles: Prior ensemble in state space, shape ``(J, N)``.
        obs_particles: Prior ensemble in observation space, shape ``(J, M)``.
        obs_noise: Observation error covariance operator, shape ``(M, M)``.
        solver: Optional solver strategy. ``None`` uses structural dispatch.
        bessel: Defaults to True, unlike the lower-level covariance helpers,
            because this recipe follows the unbiased EnKF convention. Use
            False for maximum-likelihood recipes with a ``1 / J`` divisor.

    Returns:
        Dense Kalman gain of shape ``(N, M)``.
    """
    if particles.shape[0] != obs_particles.shape[0]:
        raise ValueError(
            "particles and obs_particles must share the same ensemble size, "
            f"got J={particles.shape[0]} and J={obs_particles.shape[0]}."
        )
    cross_cov = ensemble_cross_covariance(
        particles,
        obs_particles,
        bessel=bessel,
    )
    innovation_cov = ensemble_covariance(obs_particles, bessel=bessel)
    innovation_cov = LowRankUpdate(obs_noise, innovation_cov.U)
    return solve_rows(innovation_cov, cross_cov, solver=solver)
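
To illustrate why the low-rank assembly helps, here is a plain-NumPy sketch of the Woodbury route for the gain. It assumes a diagonal `R` (the name `r_diag` and all sizes are illustrative) and does not use the library's `LowRankUpdate` or `solve_rows`; it only shows that a `(J, J)` system can replace the `(M, M)` one.

```python
import numpy as np

rng = np.random.default_rng(0)
J, N, M = 5, 4, 30                      # few members, many observations
X = rng.normal(size=(J, N))             # state-space ensemble
Y = rng.normal(size=(J, M))             # observation-space ensemble
r_diag = np.full(M, 0.5)                # diagonal R (assumed for this sketch)

dX = X - X.mean(axis=0)
dY = Y - Y.mean(axis=0)
C_xh = dX.T @ dY / (J - 1)              # (N, M) cross-covariance
U = dY.T / np.sqrt(J - 1)               # (M, J), so C_hh = U @ U.T

# Dense reference: solve against the full (M, M) innovation covariance.
S = U @ U.T + np.diag(r_diag)
K_dense = np.linalg.solve(S, C_xh.T).T  # S is symmetric

# Woodbury: S^{-1} = R^{-1} - R^{-1} U (I + U^T R^{-1} U)^{-1} U^T R^{-1},
# so only a (J, J) system is solved.
Rinv_U = U / r_diag[:, None]            # R^{-1} U, (M, J)
small = np.eye(J) + U.T @ Rinv_U        # (J, J)
Rinv_Ct = C_xh.T / r_diag[:, None]      # R^{-1} C^T, (M, N)
K_woodbury = (Rinv_Ct - Rinv_U @ np.linalg.solve(small, U.T @ Rinv_Ct)).T
```

Both routes give the same gain up to floating-point error; the Woodbury form pays off when `J` is much smaller than `M` and `R` is cheap to invert.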

etkf_transform(obs_particles: Float[Array, 'J M'], y: Float[Array, ' M'], obs_noise: lx.AbstractLinearOperator, *, inflation: float = 1.0) -> tuple[Float[Array, ' J'], Float[Array, 'J J']]

Ensemble Transform Kalman Filter (ETKF) analysis weights.

Deterministic (perturbed-obs-free) ensemble square-root analysis in the J-dimensional ensemble space (Bishop et al. 2001; Hunt et al. 2007). With raw observation perturbations Y = H X'^f (columns are members) and d = y - H x_bar^f,

\[ \tilde{A}^{-1} = \tfrac{J-1}{\lambda} I + Y^T R^{-1} Y, \qquad \bar{w} = \tilde{A}\, Y^T R^{-1} d, \qquad W = \big((J-1)\,\tilde{A}\big)^{1/2}, \]

where lambda is the (multiplicative) inflation and W is the symmetric square root. The analysis ensemble is reconstructed as

\[ \bar{x}^a = \bar{x}^f + X'^f \bar{w}, \qquad X'^a = X'^f\, W. \]

The symmetric (eigendecomposition) square root -- not a Cholesky factor -- is required: because the observation perturbations are zero-mean, 1 is an eigenvector of W with eigenvalue sqrt(lambda) (exactly 1 when inflation = 1), which keeps the analysis perturbations zero-mean (sum_j X'^a_j = 0) and so makes the transform exactly mean-preserving.

Parameters:

Name Type Description Default
obs_particles Float[Array, 'J M']

Forecast ensemble in observation space, shape (J, M).

required
y Float[Array, ' M']

Observation vector, shape (M,).

required
obs_noise AbstractLinearOperator

Observation error covariance operator R, shape (M, M).

required
inflation float

Multiplicative covariance inflation lambda >= 1, applied to the prior term (J - 1) / lambda.

1.0

Returns:

Type Description
tuple[Float[Array, ' J'], Float[Array, 'J J']]

(w_mean, transform), where w_mean has shape (J,) and transform has shape (J, J). Apply to forecast state perturbations Xp (shape (J, N)) as x_bar^a = x_bar^f + w_mean @ Xp and X'^a = transform @ Xp.

Source code in src/gaussx/_inference/_ensemble.py
def etkf_transform(
    obs_particles: Float[Array, "J M"],
    y: Float[Array, " M"],
    obs_noise: lx.AbstractLinearOperator,
    *,
    inflation: float = 1.0,
) -> tuple[Float[Array, " J"], Float[Array, "J J"]]:
    r"""Ensemble Transform Kalman Filter (ETKF) analysis weights.

    Deterministic (perturbed-obs-free) ensemble square-root analysis in the
    ``J``-dimensional ensemble space (Bishop et al. 2001; Hunt et al. 2007).
    With raw observation perturbations ``Y = H X'^f`` (columns are members) and
    ``d = y - H x_bar^f``,

    $$
    \tilde{A}^{-1} = \tfrac{J-1}{\lambda} I + Y^T R^{-1} Y, \qquad
    \bar{w} = \tilde{A}\, Y^T R^{-1} d, \qquad
    W = \big((J-1)\,\tilde{A}\big)^{1/2},
    $$

    where ``lambda`` is the (multiplicative) ``inflation`` and ``W`` is the
    **symmetric** square root. The analysis ensemble is reconstructed as

    $$
    \bar{x}^a = \bar{x}^f + X'^f \bar{w}, \qquad X'^a = X'^f\, W.
    $$

    The symmetric (eigendecomposition) square root -- not a Cholesky factor --
    is required: because the observation perturbations are zero-mean, ``1`` is
    an eigenvector of ``W`` with eigenvalue ``sqrt(lambda)`` (exactly ``1``
    when ``inflation = 1``), which keeps the analysis perturbations zero-mean
    (``sum_j X'^a_j = 0``) and so makes the transform exactly mean-preserving.

    Args:
        obs_particles: Forecast ensemble in observation space, shape ``(J, M)``.
        y: Observation vector, shape ``(M,)``.
        obs_noise: Observation error covariance operator ``R``, shape ``(M, M)``.
        inflation: Multiplicative covariance inflation ``lambda >= 1``, applied
            to the prior term ``(J - 1) / lambda``.

    Returns:
        ``(w_mean, transform)`` where ``w_mean`` has shape ``(J,)`` and
        ``transform`` has shape ``(J, J)``. Apply to forecast state
        perturbations ``Xp`` (shape ``(J, N)``) as
        ``x_bar^a = x_bar^f + w_mean @ Xp`` and ``X'^a = transform @ Xp``.
    """
    n_ens = obs_particles.shape[0]
    obs_mean = jnp.mean(obs_particles, axis=0)
    obs_pert = obs_particles - obs_mean[None, :]  # (J, M), zero-mean rows

    r_matrix = obs_noise.as_matrix()
    # R^{-1} applied to the (M, .) right-hand sides.
    rinv_pert = jnp.linalg.solve(r_matrix, obs_pert.T)  # (M, J)
    rinv_d = jnp.linalg.solve(r_matrix, y - obs_mean)  # (M,)

    precision = (n_ens - 1) / inflation * jnp.eye(n_ens) + obs_pert @ rinv_pert
    precision = symmetrize(precision)
    analysis_cov = jnp.linalg.inv(precision)  # tilde A, (J, J)

    w_mean = analysis_cov @ (obs_pert @ rinv_d)  # (J,)
    transform = _symmetric_sqrt((n_ens - 1) * analysis_cov)
    return w_mean, transform
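
The analysis step can be sketched end to end in plain NumPy. The function `etkf_weights` below re-derives the formulas from the docstring and is not the library routine; the linear observation operator `H`, the sizes, and the random inputs are assumptions made for the example.

```python
import numpy as np

def etkf_weights(obs_particles, y, R, inflation=1.0):
    """Hypothetical NumPy transcription of the ETKF weight formulas."""
    J = obs_particles.shape[0]
    obs_mean = obs_particles.mean(axis=0)
    Yp = obs_particles - obs_mean                  # (J, M), zero-mean over members
    precision = (J - 1) / inflation * np.eye(J) + Yp @ np.linalg.solve(R, Yp.T)
    precision = 0.5 * (precision + precision.T)    # remove round-off asymmetry
    A = np.linalg.inv(precision)                   # tilde A, (J, J)
    w_mean = A @ (Yp @ np.linalg.solve(R, y - obs_mean))
    # Symmetric square root via eigendecomposition (not Cholesky).
    evals, evecs = np.linalg.eigh((J - 1) * A)
    W = (evecs * np.sqrt(evals)) @ evecs.T
    return w_mean, W

rng = np.random.default_rng(0)
J, N, M = 6, 4, 3
H = rng.normal(size=(M, N))                        # assumed linear observation operator
X = rng.normal(size=(J, N))                        # forecast ensemble
R = 0.3 * np.eye(M)
y = rng.normal(size=M)

w_mean, W = etkf_weights(X @ H.T, y, R)
x_mean = X.mean(axis=0)
Xp = X - x_mean                                    # forecast perturbations, (J, N)
X_analysis = (x_mean + w_mean @ Xp) + W @ Xp       # analysis mean plus perturbations
```

With a linear `H` and no inflation, the analysis mean and sample covariance of `X_analysis` should agree with the Kalman update built from the forecast sample covariance, which is a convenient sanity check for the symmetric square root.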

Localization & inflation

The standard fixes for small-ensemble rank deficiency: Schur-product localization with a taper (Gaspari-Cohn by default) and multiplicative / RTPP / RTPS inflation.


localization_matrix(coords_a: Float[Array, 'Na D'], coords_b: Float[Array, 'Nb D'], c: float, metric: Callable[[Float[Array, 'Na D'], Float[Array, 'Nb D']], Float[Array, 'Na Nb']] = euclidean_distance) -> Float[Array, 'Na Nb']

Pairwise Gaspari-Cohn taper rho(dist(a_i, b_j); c).

Use this to build the rho_xy (state-obs) and rho_yy (obs-obs) localization matrices consumed by localized_kalman_gain.

Parameters:

Name Type Description Default
coords_a Float[Array, 'Na D']

First set of points, shape (Na, D).

required
coords_b Float[Array, 'Nb D']

Second set of points, shape (Nb, D).

required
c float

Gaspari-Cohn compact-support radius.

required
metric Callable[[Float[Array, 'Na D'], Float[Array, 'Nb D']], Float[Array, 'Na Nb']]

Pairwise distance function returning an (Na, Nb) matrix. Defaults to euclidean_distance; pass haversine_distance for spherical coordinates.

euclidean_distance

Returns:

Type Description
Float[Array, 'Na Nb']

Localization matrix of shape (Na, Nb) with entries in [0, 1].

Source code in src/gaussx/_inference/_ensemble.py
def localization_matrix(
    coords_a: Float[Array, "Na D"],
    coords_b: Float[Array, "Nb D"],
    c: float,
    metric: Callable[
        [Float[Array, "Na D"], Float[Array, "Nb D"]], Float[Array, "Na Nb"]
    ] = euclidean_distance,
) -> Float[Array, "Na Nb"]:
    """Pairwise Gaspari-Cohn taper ``rho(dist(a_i, b_j); c)``.

    Use this to build the ``rho_xy`` (state-obs) and ``rho_yy`` (obs-obs)
    localization matrices consumed by `localized_kalman_gain`.

    Args:
        coords_a: First set of points, shape ``(Na, D)``.
        coords_b: Second set of points, shape ``(Nb, D)``.
        c: Gaspari-Cohn compact-support radius.
        metric: Pairwise distance function returning an ``(Na, Nb)`` matrix.
            Defaults to `euclidean_distance`; pass
            `haversine_distance` for spherical coordinates.

    Returns:
        Localization matrix of shape ``(Na, Nb)`` with entries in ``[0, 1]``.
    """
    return gaspari_cohn(metric(coords_a, coords_b), c)

localized_kalman_gain(particles: Float[Array, 'J N'], obs_particles: Float[Array, 'J M'], obs_noise: lx.AbstractLinearOperator, rho_xy: Float[Array, 'N M'], rho_yy: Float[Array, 'M M'], *, solver: AbstractSolverStrategy | None = None, bessel: bool = True) -> Float[Array, 'N M']

Ensemble Kalman gain with Hadamard (Schur-product) localization.

Computes

\[ K = (\rho_{xy} \circ P_{xy})\,(\rho_{yy} \circ P_{yy} + R)^{-1}, \]

where P_xy is the state-observation cross-covariance and P_yy the observation-space ensemble covariance. Tapering kills spurious long-range sample correlations; because Gaspari-Cohn is positive-definite, the Schur product theorem keeps rho_yy . P_yy PSD, so the innovation covariance stays invertible.

This is the localized counterpart of ensemble_kalman_gain. Unlike that routine, the Hadamard product destroys the low-rank structure, so the innovation covariance is materialized densely: the Hadamard products cost O(N M + M^2) and the dense solve costs O(M^3 + N M^2). Recover the unlocalized gain as the c -> inf limit (rho_xy = rho_yy = 1).

Parameters:

Name Type Description Default
particles Float[Array, 'J N']

Prior ensemble in state space, shape (J, N).

required
obs_particles Float[Array, 'J M']

Prior ensemble in observation space, shape (J, M).

required
obs_noise AbstractLinearOperator

Observation error covariance operator R, shape (M, M).

required
rho_xy Float[Array, 'N M']

State-observation localization matrix, shape (N, M).

required
rho_yy Float[Array, 'M M']

Observation-observation localization matrix, shape (M, M).

required
solver AbstractSolverStrategy | None

Optional solver strategy for the dense innovation solve.

None
bessel bool

Use the 1 / (J - 1) divisor (EnKF convention, default).

True

Returns:

Type Description
Float[Array, 'N M']

Dense localized Kalman gain of shape (N, M).

Source code in src/gaussx/_inference/_ensemble.py
def localized_kalman_gain(
    particles: Float[Array, "J N"],
    obs_particles: Float[Array, "J M"],
    obs_noise: lx.AbstractLinearOperator,
    rho_xy: Float[Array, "N M"],
    rho_yy: Float[Array, "M M"],
    *,
    solver: AbstractSolverStrategy | None = None,
    bessel: bool = True,
) -> Float[Array, "N M"]:
    r"""Ensemble Kalman gain with Hadamard (Schur-product) localization.

    Computes

    $$
    K = (\rho_{xy} \circ P_{xy})\,(\rho_{yy} \circ P_{yy} + R)^{-1},
    $$

    where ``P_xy`` is the state-observation cross-covariance and ``P_yy`` the
    observation-space ensemble covariance. Tapering kills spurious long-range
    sample correlations; because Gaspari-Cohn is positive-definite, the Schur
    product theorem keeps ``rho_yy . P_yy`` PSD, so the innovation covariance
    stays invertible.

    This is the localized counterpart of `ensemble_kalman_gain`. Unlike
    that routine, the Hadamard product destroys the low-rank structure, so the
    innovation covariance is materialized densely: the Hadamard products cost
    ``O(N M + M^2)`` and the dense solve costs ``O(M^3 + N M^2)``. Recover the
    unlocalized gain as the ``c -> inf`` limit (``rho_xy = rho_yy = 1``).

    Args:
        particles: Prior ensemble in state space, shape ``(J, N)``.
        obs_particles: Prior ensemble in observation space, shape ``(J, M)``.
        obs_noise: Observation error covariance operator ``R``, shape ``(M, M)``.
        rho_xy: State-observation localization matrix, shape ``(N, M)``.
        rho_yy: Observation-observation localization matrix, shape ``(M, M)``.
        solver: Optional solver strategy for the dense innovation solve.
        bessel: Use the ``1 / (J - 1)`` divisor (EnKF convention, default).

    Returns:
        Dense localized Kalman gain of shape ``(N, M)``.
    """
    if particles.shape[0] != obs_particles.shape[0]:
        raise ValueError(
            "particles and obs_particles must share the same ensemble size, "
            f"got J={particles.shape[0]} and J={obs_particles.shape[0]}."
        )
    cross_cov = ensemble_cross_covariance(particles, obs_particles, bessel=bessel)
    obs_cov = ensemble_cross_covariance(obs_particles, obs_particles, bessel=bessel)

    localized_cross = rho_xy * cross_cov
    innovation = rho_yy * obs_cov + obs_noise.as_matrix()
    innovation = symmetrize(innovation)
    innovation_op = lx.MatrixLinearOperator(innovation, lx.positive_semidefinite_tag)
    return solve_rows(innovation_op, localized_cross, solver=solver)
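
The limiting behaviour of the tapers can be checked with a dense NumPy sketch. `localized_gain` is a hypothetical helper that transcribes the formula for K with Bessel-corrected covariances; it does not call the library and ignores the `solver` argument.

```python
import numpy as np

def localized_gain(X, Y, R, rho_xy, rho_yy):
    """Hypothetical dense NumPy version of the localized gain formula."""
    J = X.shape[0]
    dX = X - X.mean(axis=0)
    dY = Y - Y.mean(axis=0)
    P_xy = dX.T @ dY / (J - 1)                  # (N, M)
    P_yy = dY.T @ dY / (J - 1)                  # (M, M)
    S = rho_yy * P_yy + R                       # Hadamard taper, then add R
    return np.linalg.solve(S, (rho_xy * P_xy).T).T

rng = np.random.default_rng(1)
J, N, M = 5, 6, 4
X = rng.normal(size=(J, N))
Y = rng.normal(size=(J, M))
R = 0.2 * np.eye(M)

# All-ones tapers: no localization, so this is the plain ensemble gain.
K_plain = localized_gain(X, Y, R, np.ones((N, M)), np.ones((M, M)))

# Extreme taper: every state-observation pair is out of range (rho_xy = 0),
# so no observation can update any state component.
K_zero = localized_gain(X, Y, R, np.zeros((N, M)), np.eye(M))
```

Realistic tapers sit between these two extremes, shrinking the contribution of distant observations without removing nearby ones.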

gaspari_cohn(r: Float[Array, '*shape'], c: float) -> Float[Array, '*shape']

Gaspari-Cohn (1999) fifth-order compactly-supported taper.

The standard positive-definite, approximately-Gaussian localization function. With z = 2 |r| / c it is the piecewise-rational

\[ \begin{aligned} \rho = \begin{cases} -\tfrac14 z^5 + \tfrac12 z^4 + \tfrac58 z^3 - \tfrac53 z^2 + 1 & 0 \le z \le 1 \\ \tfrac1{12} z^5 - \tfrac12 z^4 + \tfrac58 z^3 + \tfrac53 z^2 - 5 z + 4 - \tfrac{2}{3 z} & 1 < z \le 2 \\ 0 & z > 2. \end{cases} \end{aligned} \]

so rho(0) = 1 and rho = 0 for |r| >= c (c is the compact-support radius, not a Gaussian length scale). The taper is only \(C^1\) at the knots z = 1, 2.

Differentiability: the 2 / (3 z) term in the middle branch is guarded with a safe denominator so reverse-mode gradients are finite at r = 0 (which would otherwise produce NaN via the standard where pitfall).

Parameters:

Name Type Description Default
r Float[Array, '*shape']

Distances (any shape), e.g. a pairwise distance matrix.

required
c float

Compact-support radius; rho = 0 beyond |r| = c.

required

Returns:

Type Description
Float[Array, '*shape']

Taper values in [0, 1], same shape as r.

Source code in src/gaussx/_inference/_ensemble.py
def gaspari_cohn(r: Float[Array, "*shape"], c: float) -> Float[Array, "*shape"]:
    r"""Gaspari-Cohn (1999) fifth-order compactly-supported taper.

    The standard positive-definite, approximately-Gaussian localization
    function. With ``z = 2 |r| / c`` it is the piecewise-rational

    $$
    \begin{aligned}
    \rho = \begin{cases}
      -\tfrac14 z^5 + \tfrac12 z^4 + \tfrac58 z^3 - \tfrac53 z^2 + 1
        & 0 \le z \le 1 \\
      \tfrac1{12} z^5 - \tfrac12 z^4 + \tfrac58 z^3 + \tfrac53 z^2
        - 5 z + 4 - \tfrac{2}{3 z}
        & 1 < z \le 2 \\
      0 & z > 2.
    \end{cases}
    \end{aligned}
    $$

    so ``rho(0) = 1`` and ``rho = 0`` for ``|r| >= c`` (``c`` is the
    compact-support radius, **not** a Gaussian length scale). The taper is
    only $C^1$ at the knots ``z = 1, 2``.

    Differentiability: the ``2 / (3 z)`` term in the middle branch is guarded
    with a safe denominator so reverse-mode gradients are finite at ``r = 0``
    (which would otherwise produce ``NaN`` via the standard ``where`` pitfall).

    Args:
        r: Distances (any shape), e.g. a pairwise distance matrix.
        c: Compact-support radius; ``rho = 0`` beyond ``|r| = c``.

    Returns:
        Taper values in ``[0, 1]``, same shape as ``r``.
    """
    z = 2.0 * jnp.abs(r) / c
    # Guard the 1 / z term: at z = 0 the near branch is selected, but JAX still
    # traces the middle branch, so an unguarded 1 / z poisons the gradient.
    z_safe = jnp.where(z > 0.0, z, 1.0)

    near = -0.25 * z**5 + 0.5 * z**4 + 0.625 * z**3 - (5.0 / 3.0) * z**2 + 1.0
    mid = (
        (1.0 / 12.0) * z**5
        - 0.5 * z**4
        + 0.625 * z**3
        + (5.0 / 3.0) * z**2
        - 5.0 * z
        + 4.0
        - 2.0 / (3.0 * z_safe)
    )
    return jnp.where(z <= 1.0, near, jnp.where(z < 2.0, mid, 0.0))
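
As a quick check of the piecewise formula, the following NumPy transcription evaluates the taper on a small 1-D grid. `gaspari_cohn_np` is a hypothetical forward-only copy written for this example; the grid spacing and `c = 3` are arbitrary choices.

```python
import numpy as np

def gaspari_cohn_np(r, c):
    """Hypothetical NumPy transcription of the taper (forward pass only)."""
    z = 2.0 * np.abs(np.asarray(r, dtype=float)) / c
    z_safe = np.where(z > 0.0, z, 1.0)          # avoids 1 / 0 in the unused branch
    near = -0.25 * z**5 + 0.5 * z**4 + 0.625 * z**3 - (5.0 / 3.0) * z**2 + 1.0
    mid = (z**5 / 12.0 - 0.5 * z**4 + 0.625 * z**3 + (5.0 / 3.0) * z**2
           - 5.0 * z + 4.0 - 2.0 / (3.0 * z_safe))
    return np.where(z <= 1.0, near, np.where(z < 2.0, mid, 0.0))

# Points on a line, one unit apart; compact-support radius c = 3.
coords = np.arange(6.0)
dist = np.abs(coords[:, None] - coords[None, :])
rho = gaspari_cohn_np(dist, c=3.0)

# Value at the inner knot z = 1, i.e. |r| = c / 2; both branches give 5 / 24.
knot_value = gaspari_cohn_np(1.5, 3.0)
```

The resulting matrix has ones on the diagonal and exact zeros for every pair at distance 3 or more, which is what makes the tapered covariances sparse in practice.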

inflate_multiplicative(ensemble: Float[Array, 'J N'], factor: float) -> Float[Array, 'J N']

Multiplicative ensemble inflation about the mean.

Restores ensemble spread lost to sampling error / model collapse by scaling perturbations: x_j <- x_bar + factor (x_j - x_bar).

Parameters:

Name Type Description Default
ensemble Float[Array, 'J N']

Ensemble of shape (J, N).

required
factor float

Inflation factor >= 1 (e.g. 1.02-1.10).

required

Returns:

Type Description
Float[Array, 'J N']

Inflated ensemble, shape (J, N). The mean is unchanged.

Source code in src/gaussx/_inference/_ensemble.py
def inflate_multiplicative(
    ensemble: Float[Array, "J N"],
    factor: float,
) -> Float[Array, "J N"]:
    r"""Multiplicative ensemble inflation about the mean.

    Restores ensemble spread lost to sampling error / model collapse by scaling
    perturbations: ``x_j <- x_bar + factor (x_j - x_bar)``.

    Args:
        ensemble: Ensemble of shape ``(J, N)``.
        factor: Inflation factor ``>= 1`` (e.g. ``1.02``-``1.10``).

    Returns:
        Inflated ensemble, shape ``(J, N)``. The mean is unchanged.
    """
    mean = jnp.mean(ensemble, axis=0, keepdims=True)
    return mean + factor * (ensemble - mean)
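
A short NumPy sketch of the same update shows its effect on the ensemble statistics. The ensemble and the factor `1.05` are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
ensemble = rng.normal(size=(10, 3))               # J = 10 members, N = 3
factor = 1.05

mean = ensemble.mean(axis=0, keepdims=True)
inflated = mean + factor * (ensemble - mean)      # scale perturbations about the mean

# Per-coordinate variance ratio after / before inflation.
var_ratio = inflated.var(axis=0) / ensemble.var(axis=0)
```

Because perturbations are scaled by `factor`, variances (and the full sample covariance) scale by `factor ** 2`, while the ensemble mean is untouched.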

inflate_rtpp(posterior: Float[Array, 'J N'], prior: Float[Array, 'J N'], alpha: float) -> Float[Array, 'J N']

Relaxation to prior perturbations (RTPP; Zhang et al. 2004).

Relaxes posterior perturbations toward the prior perturbations while keeping the posterior mean: x'^a <- (1 - alpha) x'^a + alpha x'^f, where the perturbations are taken about each ensemble's own mean.

Parameters:

Name Type Description Default
posterior Float[Array, 'J N']

Analysis ensemble, shape (J, N).

required
prior Float[Array, 'J N']

Forecast ensemble, shape (J, N).

required
alpha float

Relaxation weight in [0, 1].

required

Returns:

Type Description
Float[Array, 'J N']

Relaxed analysis ensemble, shape (J, N). The posterior mean is preserved.

Source code in src/gaussx/_inference/_ensemble.py
def inflate_rtpp(
    posterior: Float[Array, "J N"],
    prior: Float[Array, "J N"],
    alpha: float,
) -> Float[Array, "J N"]:
    r"""Relaxation to prior perturbations (RTPP; Zhang et al. 2004).

    Relaxes posterior perturbations toward the prior perturbations while keeping
    the posterior mean: ``x'^a <- (1 - alpha) x'^a + alpha x'^f``, where the
    perturbations are taken about each ensemble's own mean.

    Args:
        posterior: Analysis ensemble, shape ``(J, N)``.
        prior: Forecast ensemble, shape ``(J, N)``.
        alpha: Relaxation weight in ``[0, 1]``.

    Returns:
        Relaxed analysis ensemble, shape ``(J, N)``. The posterior mean is
        preserved.
    """
    post_mean = jnp.mean(posterior, axis=0, keepdims=True)
    post_pert = posterior - post_mean
    prior_pert = prior - jnp.mean(prior, axis=0, keepdims=True)
    return post_mean + (1.0 - alpha) * post_pert + alpha * prior_pert
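
The relaxation can be illustrated with a toy case in which the analysis halves the forecast spread. `rtpp` below is a hypothetical NumPy copy of the update, and the "analysis" is constructed by hand purely for the example.

```python
import numpy as np

def rtpp(posterior, prior, alpha):
    """Hypothetical NumPy copy of the RTPP update."""
    post_mean = posterior.mean(axis=0, keepdims=True)
    post_pert = posterior - post_mean
    prior_pert = prior - prior.mean(axis=0, keepdims=True)
    return post_mean + (1.0 - alpha) * post_pert + alpha * prior_pert

rng = np.random.default_rng(0)
prior = rng.normal(scale=2.0, size=(8, 3))
posterior = 1.0 + 0.5 * prior     # toy analysis: shifted mean, halved perturbations

relaxed = rtpp(posterior, prior, alpha=0.5)
```

Here the relaxed perturbations are `0.5 * 0.5 + 0.5 = 0.75` times the prior perturbations, so the spread lands between the analysis and forecast spreads; `alpha = 0` returns the analysis unchanged.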

inflate_rtps(posterior: Float[Array, 'J N'], prior: Float[Array, 'J N'], beta: float, eps: float = 1e-12) -> Float[Array, 'J N']

Relaxation to prior spread (RTPS; Whitaker & Hamill 2012).

Scales each posterior perturbation, per coordinate, so the analysis spread relaxes back toward the prior spread: x'^a <- x'^a [ (1 - beta) + beta sigma^f / sigma^a ], with sigma the per-coordinate ensemble standard deviation.

Parameters:

Name Type Description Default
posterior Float[Array, 'J N']

Analysis ensemble, shape (J, N).

required
prior Float[Array, 'J N']

Forecast ensemble, shape (J, N).

required
beta float

Relaxation weight in [0, 1].

required
eps float

Floor on the posterior std to avoid division by zero.

1e-12

Returns:

Type Description
Float[Array, 'J N']

Spread-restored analysis ensemble, shape (J, N). The posterior mean is preserved.

Source code in src/gaussx/_inference/_ensemble.py
def inflate_rtps(
    posterior: Float[Array, "J N"],
    prior: Float[Array, "J N"],
    beta: float,
    eps: float = 1e-12,
) -> Float[Array, "J N"]:
    r"""Relaxation to prior spread (RTPS; Whitaker & Hamill 2012).

    Scales each posterior perturbation, per coordinate, so the analysis spread
    relaxes back toward the prior spread:
    ``x'^a <- x'^a [ (1 - beta) + beta sigma^f / sigma^a ]``, with ``sigma`` the
    per-coordinate ensemble standard deviation.

    Args:
        posterior: Analysis ensemble, shape ``(J, N)``.
        prior: Forecast ensemble, shape ``(J, N)``.
        beta: Relaxation weight in ``[0, 1]``.
        eps: Floor on the posterior std to avoid division by zero.

    Returns:
        Spread-restored analysis ensemble, shape ``(J, N)``. The posterior mean
        is preserved.
    """
    post_mean = jnp.mean(posterior, axis=0, keepdims=True)
    post_pert = posterior - post_mean
    sigma_post = jnp.std(posterior, axis=0)
    sigma_prior = jnp.std(prior, axis=0)
    scale = (1.0 - beta) + beta * sigma_prior / (sigma_post + eps)
    return post_mean + post_pert * scale[None, :]
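
The effect of `beta` can be seen on a toy pair of ensembles. `rtps` is a hypothetical NumPy copy of the update, and the forecast and analysis ensembles are independent random draws chosen only to have different spreads.

```python
import numpy as np

def rtps(posterior, prior, beta, eps=1e-12):
    """Hypothetical NumPy copy of the RTPS update."""
    post_mean = posterior.mean(axis=0, keepdims=True)
    post_pert = posterior - post_mean
    scale = (1.0 - beta) + beta * prior.std(axis=0) / (posterior.std(axis=0) + eps)
    return post_mean + post_pert * scale[None, :]

rng = np.random.default_rng(0)
prior = rng.normal(scale=2.0, size=(8, 3))
posterior = rng.normal(loc=1.0, scale=0.5, size=(8, 3))   # toy analysis, small spread

full = rtps(posterior, prior, beta=1.0)    # spread restored to the prior's
half = rtps(posterior, prior, beta=0.5)    # halfway between the two spreads
```

Unlike RTPP, only the per-coordinate magnitude changes: each member keeps the sign and relative size of its analysis perturbation.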

Distances


euclidean_distance(coords_a: Float[Array, 'Na D'], coords_b: Float[Array, 'Nb D']) -> Float[Array, 'Na Nb']

Pairwise Euclidean distances ||a_i - b_j||.

A default metric for localization_matrix. Builds on stable_squared_distances and takes a gradient-safe square root so zero distances (e.g. the diagonal of a self-distance matrix) do not produce NaN gradients.

Parameters:

Name Type Description Default
coords_a Float[Array, 'Na D']

First set of points, shape (Na, D).

required
coords_b Float[Array, 'Nb D']

Second set of points, shape (Nb, D).

required

Returns:

Type Description
Float[Array, 'Na Nb']

Distance matrix of shape (Na, Nb).

Source code in src/gaussx/_inference/_ensemble.py
def euclidean_distance(
    coords_a: Float[Array, "Na D"],
    coords_b: Float[Array, "Nb D"],
) -> Float[Array, "Na Nb"]:
    """Pairwise Euclidean distances ``||a_i - b_j||``.

    A default ``metric`` for `localization_matrix`. Builds on
    `stable_squared_distances` and takes a gradient-safe square root so
    zero distances (e.g. the diagonal of a self-distance matrix) do not produce
    ``NaN`` gradients.

    Args:
        coords_a: First set of points, shape ``(Na, D)``.
        coords_b: Second set of points, shape ``(Nb, D)``.

    Returns:
        Distance matrix of shape ``(Na, Nb)``.
    """
    sq = stable_squared_distances(
        coords_a,
        coords_b,
        compute_dtype=coords_a.dtype,
        accumulate_dtype=coords_a.dtype,
    )
    sq_safe = jnp.where(sq > 0.0, sq, 1.0)
    return jnp.where(sq > 0.0, jnp.sqrt(sq_safe), 0.0)

haversine_distance(coords_a: Float[Array, 'Na 2'], coords_b: Float[Array, 'Nb 2'], radius: float = 6371000.0) -> Float[Array, 'Na Nb']

Pairwise great-circle (haversine) distances on a sphere.

A metric for localization_matrix on geophysical grids. Coordinates are (latitude, longitude) in radians.

Parameters:

Name Type Description Default
coords_a Float[Array, 'Na 2']

First set of points (lat, lon) in radians, shape (Na, 2).

required
coords_b Float[Array, 'Nb 2']

Second set of points (lat, lon) in radians, shape (Nb, 2).

required
radius float

Sphere radius in the units of the returned distance (default the Earth mean radius, 6.371e6 m).

6371000.0

Returns:

Type Description
Float[Array, 'Na Nb']

Great-circle distance matrix of shape (Na, Nb).

Source code in src/gaussx/_inference/_ensemble.py
def haversine_distance(
    coords_a: Float[Array, "Na 2"],
    coords_b: Float[Array, "Nb 2"],
    radius: float = 6.371e6,
) -> Float[Array, "Na Nb"]:
    """Pairwise great-circle (haversine) distances on a sphere.

    A ``metric`` for `localization_matrix` on geophysical grids.
    Coordinates are ``(latitude, longitude)`` in **radians**.

    Args:
        coords_a: First set of points ``(lat, lon)`` in radians, shape ``(Na, 2)``.
        coords_b: Second set of points ``(lat, lon)`` in radians, shape ``(Nb, 2)``.
        radius: Sphere radius in the units of the returned distance (default the
            Earth mean radius, ``6.371e6`` m).

    Returns:
        Great-circle distance matrix of shape ``(Na, Nb)``.
    """
    lat_a = coords_a[:, 0][:, None]
    lon_a = coords_a[:, 1][:, None]
    lat_b = coords_b[:, 0][None, :]
    lon_b = coords_b[:, 1][None, :]
    dlat = lat_b - lat_a
    dlon = lon_b - lon_a
    h = (
        jnp.sin(dlat / 2.0) ** 2
        + jnp.cos(lat_a) * jnp.cos(lat_b) * jnp.sin(dlon / 2.0) ** 2
    )
    return 2.0 * radius * jnp.arcsin(jnp.sqrt(jnp.clip(h, 0.0, 1.0)))
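
Since the coordinates must be in radians, a common pattern is to convert from degrees just before the call. The sketch below uses a hypothetical NumPy copy of the formula (`haversine_np`) and three points that are pairwise a quarter of a great circle apart.

```python
import numpy as np

def haversine_np(coords_a, coords_b, radius=6.371e6):
    """Hypothetical NumPy copy of the haversine formula."""
    lat_a, lon_a = coords_a[:, 0][:, None], coords_a[:, 1][:, None]
    lat_b, lon_b = coords_b[:, 0][None, :], coords_b[:, 1][None, :]
    h = (np.sin((lat_b - lat_a) / 2.0) ** 2
         + np.cos(lat_a) * np.cos(lat_b) * np.sin((lon_b - lon_a) / 2.0) ** 2)
    return 2.0 * radius * np.arcsin(np.sqrt(np.clip(h, 0.0, 1.0)))

# (lat, lon) in degrees: two equatorial points 90 degrees apart and the north pole.
points_deg = np.array([[0.0, 0.0], [0.0, 90.0], [90.0, 0.0]])
points = np.deg2rad(points_deg)            # convert before calling

dist = haversine_np(points, points)        # (3, 3) distances in metres
quarter = 0.5 * np.pi * 6.371e6            # quarter of a great circle
```

Every off-diagonal entry equals `quarter` (about 1.0e7 m) and the diagonal is zero; passing degrees directly would silently give wrong distances.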