
Distributions & Exponential Family

Layer 2: Gaussian distributions over structured covariance operators, the sugar operations that probabilistic code actually calls, and the exponential-family (natural-parameter) view used by variational and EP-style inference.

Multivariate normal distributions

NumPyro-compatible distributions whose covariance (or precision) is a lineax operator, so sample / log_prob inherit every structured fast path. MultivariateNormalPrecision carries \(\Lambda = \Sigma^{-1}\) directly — the natural home for natural-parameter guides, where materializing \(\Sigma\) would be wasted work. Both require numpyro to be installed.

Structured linear algebra and Gaussian primitives for JAX.

MultivariateNormal

Bases: Distribution

Multivariate normal parameterized by a lineax linear operator.

Unlike numpyro.distributions.MultivariateNormal, which requires dense arrays, this distribution accepts any lineax.AbstractLinearOperator as its covariance. This enables efficient log-prob, sampling, and entropy for structured covariances (Kronecker, block-diagonal, low-rank, diagonal, etc.) via gaussx structural dispatch.

Requires the numpyro optional extra (pip install "gaussx[numpyro]").

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loc | Float[Array, '*batch N'] | Mean vector of shape (N,). | required |
| cov_operator | AbstractLinearOperator | Covariance as a lineax linear operator of shape (N, N). | required |
| solver | AbstractSolverStrategy \| None | Solver strategy for solve and logdet. Defaults to AutoSolver(). | None |
| validate_args | bool \| None | Whether to validate input arguments. | None |

Examples:

>>> import jax.numpy as jnp
>>> import lineax as lx
>>> from gaussx._distributions import MultivariateNormal
>>> Sigma = lx.MatrixLinearOperator(
...     jnp.eye(3), lx.positive_semidefinite_tag
... )
>>> d = MultivariateNormal(jnp.zeros(3), Sigma)
>>> d.log_prob(jnp.ones(3))
Source code in src/gaussx/_distributions/_mvn.py
class MultivariateNormal(dist.Distribution):
    """Multivariate normal parameterized by a lineax linear operator.

    Unlike ``numpyro.distributions.MultivariateNormal`` which requires
    dense arrays, this distribution accepts any
    ``lineax.AbstractLinearOperator`` as its covariance. This enables
    efficient log-prob, sampling, and entropy for structured covariances
    (Kronecker, block-diagonal, low-rank, diagonal, etc.) via gaussx
    structural dispatch.

    Requires the ``numpyro`` optional extra
    (``pip install "gaussx[numpyro]"``).

    Args:
        loc: Mean vector of shape ``(N,)``.
        cov_operator: Covariance as a lineax linear operator of shape
            ``(N, N)``.
        solver: Solver strategy for ``solve`` and ``logdet``. Defaults
            to ``AutoSolver()``.
        validate_args: Whether to validate input arguments.

    Examples:

        >>> import jax.numpy as jnp
        >>> import lineax as lx
        >>> from gaussx._distributions import MultivariateNormal
        >>> Sigma = lx.MatrixLinearOperator(
        ...     jnp.eye(3), lx.positive_semidefinite_tag
        ... )
        >>> d = MultivariateNormal(jnp.zeros(3), Sigma)
        >>> d.log_prob(jnp.ones(3))
    """

    arg_constraints = {"loc": dist.constraints.real_vector}  # noqa: RUF012
    support = dist.constraints.real_vector
    reparametrized_params = ["loc"]  # noqa: RUF012
    pytree_data_fields = ("loc", "cov_operator", "solver")

    def __init__(
        self,
        loc: Float[Array, "*batch N"],
        cov_operator: lx.AbstractLinearOperator,
        solver: AbstractSolverStrategy | None = None,
        *,
        validate_args: bool | None = None,
    ) -> None:
        if solver is None:
            solver = AutoSolver()
        self.loc = loc
        self.cov_operator = cov_operator
        self.solver = solver
        event_shape = loc.shape[-1:]
        batch_shape = loc.shape[:-1]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def _log_prob_single(self, residual: Float[Array, " N"]) -> Float[Array, ""]:
        return _gaussian_log_prob_residual(
            residual, self.cov_operator, solver=self.solver
        )

    @validate_sample
    def log_prob(self, value: Float[Array, "*batch N"]) -> Float[Array, "*batch"]:
        residual = value - self.loc
        leading_shape = residual.shape[:-1]
        residual_flat = rearrange(residual, "... D -> (...) D")
        log_prob_flat = jax.vmap(self._log_prob_single)(residual_flat)
        return _reshape_batch(log_prob_flat, leading_shape)

    def sample(
        self,
        key: jax.Array | None,
        sample_shape: tuple[int, ...] = (),
    ) -> Float[Array, "*batch N"]:
        if key is None:
            raise ValueError(
                "PRNG key must be provided to sample from MultivariateNormal."
            )
        L = _cholesky(self.cov_operator)
        shape = sample_shape + self.batch_shape + self.event_shape
        eps = jax.random.normal(key, shape=shape)  # type: ignore[arg-type]
        eps_flat = rearrange(eps, "... D -> (...) D")
        samples_flat = jax.vmap(L.mv)(eps_flat)
        return self.loc + _reshape_samples(samples_flat, shape[:-1])

    @lazy_property
    def mean(self) -> Float[Array, "*batch N"]:
        return self.loc

    @lazy_property
    def variance(self) -> Float[Array, "*batch N"]:
        return jnp.broadcast_to(
            _diag(self.cov_operator), self.batch_shape + self.event_shape
        )

    def entropy(self) -> Float[Array, ""]:
        return gaussian_entropy(self.cov_operator, solver=self.solver)
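
The sample method above draws eps ~ N(0, I) and returns loc + L eps, where L is the Cholesky factor of the covariance, vmapping L.mv over the flattened batch. The following is a dense sketch of the same reparameterization, assuming a plain covariance array and using jnp.linalg.cholesky in place of gaussx's structured _cholesky; sample_mvn is a hypothetical helper, not part of gaussx.

```python
import jax
import jax.numpy as jnp


def sample_mvn(key, loc, cov, sample_shape=()):
    """Reparameterized draws x = loc + L @ eps, where cov = L @ L.T."""
    L = jnp.linalg.cholesky(cov)  # dense stand-in for gaussx's structured Cholesky
    eps = jax.random.normal(key, sample_shape + loc.shape, dtype=loc.dtype)
    # Each row of eps is one standard-normal draw; eps @ L.T applies L row-wise,
    # which matches vmapping L.mv over the flattened batch.
    return loc + eps @ L.T


loc = jnp.array([1.0, -2.0, 0.5])
cov = jnp.array([[2.0, 0.3, 0.0],
                 [0.3, 1.0, 0.2],
                 [0.0, 0.2, 0.5]])
x = sample_mvn(jax.random.PRNGKey(0), loc, cov, (200_000,))
emp_mean = x.mean(axis=0)
emp_cov = jnp.cov(x, rowvar=False)  # approaches cov as the sample grows
```

Because the draw is a deterministic affine function of eps, gradients flow through loc (and, in a dense setting, through cov), which is why the distribution lists loc in reparametrized_params.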

MultivariateNormalPrecision

Bases: Distribution

Multivariate normal parameterized by a precision (inverse covariance) operator.

This is the natural parameterization for many inference algorithms (e.g. message passing, variational inference in natural coordinates). The precision operator Lambda satisfies Lambda = Sigma^{-1}.

Requires the numpyro optional extra (pip install "gaussx[numpyro]").

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loc | Float[Array, '*batch N'] | Mean vector of shape (N,). | required |
| prec_operator | AbstractLinearOperator | Precision matrix as a lineax linear operator of shape (N, N). | required |
| solver | AbstractSolverStrategy \| None | Solver strategy for solve and logdet. Defaults to AutoSolver(). | None |
| validate_args | bool \| None | Whether to validate input arguments. | None |

Examples:

>>> import jax.numpy as jnp
>>> import lineax as lx
>>> from gaussx._distributions import MultivariateNormalPrecision
>>> Lambda = lx.MatrixLinearOperator(
...     2.0 * jnp.eye(3), lx.positive_semidefinite_tag
... )
>>> d = MultivariateNormalPrecision(jnp.zeros(3), Lambda)
>>> d.log_prob(jnp.ones(3))
Source code in src/gaussx/_distributions/_mvn_prec.py
class MultivariateNormalPrecision(dist.Distribution):
    """Multivariate normal parameterized by a precision (inverse covariance) operator.

    This is the natural parameterization for many inference algorithms
    (e.g. message passing, variational inference in natural coordinates).
    The precision operator ``Lambda`` satisfies ``Lambda = Sigma^{-1}``.

    Requires the ``numpyro`` optional extra
    (``pip install "gaussx[numpyro]"``).

    Args:
        loc: Mean vector of shape ``(N,)``.
        prec_operator: Precision matrix as a lineax linear operator of
            shape ``(N, N)``.
        solver: Solver strategy for ``solve`` and ``logdet``. Defaults
            to ``AutoSolver()``.
        validate_args: Whether to validate input arguments.

    Examples:

        >>> import jax.numpy as jnp
        >>> import lineax as lx
        >>> from gaussx._distributions import MultivariateNormalPrecision
        >>> Lambda = lx.MatrixLinearOperator(
        ...     2.0 * jnp.eye(3), lx.positive_semidefinite_tag
        ... )
        >>> d = MultivariateNormalPrecision(jnp.zeros(3), Lambda)
        >>> d.log_prob(jnp.ones(3))
    """

    arg_constraints = {"loc": dist.constraints.real_vector}  # noqa: RUF012
    support = dist.constraints.real_vector
    reparametrized_params = ["loc"]  # noqa: RUF012
    pytree_data_fields = ("loc", "prec_operator", "solver")

    def __init__(
        self,
        loc: Float[Array, "*batch N"],
        prec_operator: lx.AbstractLinearOperator,
        solver: AbstractSolverStrategy | None = None,
        *,
        validate_args: bool | None = None,
    ) -> None:
        if solver is None:
            solver = AutoSolver()
        self.loc = loc
        self.prec_operator = prec_operator
        self.solver = solver
        event_shape = loc.shape[-1:]
        batch_shape = loc.shape[:-1]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def _log_prob_single(self, residual: Float[Array, " N"]) -> Float[Array, ""]:
        quad = jnp.sum(residual * self.prec_operator.mv(residual), axis=-1)
        ld = self.solver.logdet(self.prec_operator)
        n = self.loc.shape[-1]
        return -0.5 * (n * _LOG_2PI - ld + quad)

    @validate_sample
    def log_prob(self, value: Float[Array, "*batch N"]) -> Float[Array, "*batch"]:
        residual = value - self.loc
        leading_shape = residual.shape[:-1]
        residual_flat = rearrange(residual, "... D -> (...) D")
        log_prob_flat = jax.vmap(self._log_prob_single)(residual_flat)
        return _reshape_batch(log_prob_flat, leading_shape)

    def sample(
        self,
        key: jax.Array | None,
        sample_shape: tuple[int, ...] = (),
    ) -> Float[Array, "*batch N"]:
        if key is None:
            raise ValueError(
                "PRNG key must be provided to sample from MultivariateNormalPrecision."
            )
        L = _cholesky(self.prec_operator)
        shape = sample_shape + self.batch_shape + self.event_shape
        eps = jax.random.normal(key, shape=shape)  # type: ignore[arg-type]

        def _solve_one(z):
            return _solve(L.T, z)

        eps_flat = rearrange(eps, "... D -> (...) D")
        samples_flat = jax.vmap(_solve_one)(eps_flat)
        return self.loc + _reshape_samples(samples_flat, shape[:-1])

    @lazy_property
    def mean(self) -> Float[Array, "*batch N"]:
        return self.loc

    @lazy_property
    def variance(self) -> Float[Array, "*batch N"]:
        return jnp.broadcast_to(
            _diag(_inv(self.prec_operator)), self.batch_shape + self.event_shape
        )

    def entropy(self) -> Float[Array, ""]:
        n = self.loc.shape[-1]
        ld = self.solver.logdet(self.prec_operator)
        return 0.5 * (n * (1.0 + _LOG_2PI) - ld)
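
The precision parameterization never forms Sigma: log_prob uses a matvec with Lambda and adds (rather than subtracts) log|Lambda|, and sample solves L^T z = eps with Lambda = L L^T, so Cov[z] = L^{-T} L^{-1} = Lambda^{-1}. A dense sketch of both steps, assuming plain arrays; precision_log_prob and precision_sample are hypothetical names, and the result is checked against jax.scipy.stats.multivariate_normal with Sigma = Lambda^{-1}.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

LOG_2PI = jnp.log(2.0 * jnp.pi)


def precision_log_prob(x, loc, prec):
    """log N(x | loc, prec^{-1}) using only the precision matrix."""
    r = x - loc
    quad = r @ prec @ r                       # r^T Lambda r: a matvec, no solve
    _, logdet_prec = jnp.linalg.slogdet(prec)
    n = loc.shape[-1]
    return -0.5 * (n * LOG_2PI - logdet_prec + quad)  # log|Sigma| = -log|Lambda|


def precision_sample(key, loc, prec, num):
    """x = loc + L^{-T} eps with Lambda = L L^T, so Cov[x] = Lambda^{-1}."""
    L = jnp.linalg.cholesky(prec)
    eps = jax.random.normal(key, (loc.shape[-1], num), dtype=loc.dtype)
    z = solve_triangular(L.T, eps, lower=False)  # back-substitution, shape (N, num)
    return loc + z.T


prec = jnp.array([[2.0, 0.5], [0.5, 1.0]])
loc = jnp.array([0.0, 1.0])
x = jnp.array([0.3, -0.4])
lp = precision_log_prob(x, loc, prec)
ref = multivariate_normal.logpdf(x, loc, jnp.linalg.inv(prec))
samples = precision_sample(jax.random.PRNGKey(1), loc, prec, 200_000)
```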

Gaussian sugar ops

\[ \log \mathcal{N}(x \mid \mu, \Sigma) = -\tfrac12 (x-\mu)^\top \Sigma^{-1} (x-\mu) - \tfrac12 \log|\Sigma| - \tfrac{N}{2}\log 2\pi \]

The sugar ops evaluate this density through a structured solve and log-determinant, and also provide entropy, quadratic forms, KL divergences, conditioning, and the numerically stable Joseph-form covariance update.


gaussian_log_prob(loc: Float[Array, ' N'], cov_operator: lx.AbstractLinearOperator, value: Float[Array, ' N'], *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']

Multivariate normal log-probability.

Computes:

log N(value | loc, Sigma)
= -0.5 * (N log(2 pi) + log|Sigma| + (value - loc)^T Sigma^{-1} (value - loc))

All expensive operations (solve, logdet) dispatch on operator structure automatically, or through an explicit solver.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loc | Float[Array, ' N'] | Mean vector, shape (N,). | required |
| cov_operator | AbstractLinearOperator | Covariance operator Sigma, shape (N, N). | required |
| value | Float[Array, ' N'] | Observation vector, shape (N,). | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy (needs both solve and logdet). When None, uses structural dispatch. | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ''] | Scalar log-probability. |

Source code in src/gaussx/_distributions/_gaussian.py
def gaussian_log_prob(
    loc: Float[Array, " N"],
    cov_operator: lx.AbstractLinearOperator,
    value: Float[Array, " N"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> Float[Array, ""]:
    """Multivariate normal log-probability.

    Computes:

        log N(value | loc, Sigma)
        = -0.5 * (N log(2 pi) + log|Sigma| + (value - loc)^T Sigma^{-1} (value - loc))

    All expensive operations (``solve``, ``logdet``) dispatch on
    operator structure automatically, or through an explicit *solver*.

    Args:
        loc: Mean vector, shape ``(N,)``.
        cov_operator: Covariance operator Sigma, shape ``(N, N)``.
        value: Observation vector, shape ``(N,)``.
        solver: Optional solver strategy (needs both solve and logdet).
            When ``None``, uses structural dispatch.

    Returns:
        Scalar log-probability.
    """
    return _gaussian_log_prob_residual(value - loc, cov_operator, solver=solver)
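
To illustrate what "dispatch on operator structure" buys, here is a dense sketch of the same formula next to a diagonal special case. Both helpers are hypothetical and take plain arrays rather than lineax operators: the dense path reuses one Cholesky factor for both the quadratic term and the log-determinant (O(N^3)), while the diagonal path is O(N). The check uses jax.scipy.stats.multivariate_normal.logpdf as a reference.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

LOG_2PI = jnp.log(2.0 * jnp.pi)


def log_prob_dense(loc, cov, value):
    """O(N^3): one Cholesky gives both the solve and the logdet."""
    L = jnp.linalg.cholesky(cov)
    z = solve_triangular(L, value - loc, lower=True)  # z = L^{-1} (x - mu)
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))      # log|Sigma| from diag(L)
    return -0.5 * (loc.shape[-1] * LOG_2PI + logdet + z @ z)


def log_prob_diag(loc, diag, value):
    """O(N): the same formula when Sigma = diag(d)."""
    r = value - loc
    return -0.5 * (loc.shape[-1] * LOG_2PI + jnp.sum(jnp.log(diag)) + jnp.sum(r * r / diag))


loc = jnp.array([0.0, 1.0, -1.0])
value = jnp.array([0.5, 0.5, 0.5])
cov = jnp.array([[1.0, 0.2, 0.0], [0.2, 2.0, 0.3], [0.0, 0.3, 1.5]])
d = jnp.array([1.0, 2.0, 1.5])
```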

gaussian_entropy(cov_operator: lx.AbstractLinearOperator, *, solver: AbstractLogdetStrategy | None = None) -> Float[Array, '']

Entropy of a multivariate normal N(mu, Sigma).

Computes:

H = 0.5 * (N * (1 + log(2 pi)) + log|Sigma|)

Independent of the mean.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cov_operator | AbstractLinearOperator | Covariance operator, shape (N, N). | required |
| solver | AbstractLogdetStrategy \| None | Optional logdet strategy. When None, uses structural dispatch. | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ''] | Scalar entropy. |

Source code in src/gaussx/_distributions/_gaussian.py
def gaussian_entropy(
    cov_operator: lx.AbstractLinearOperator,
    *,
    solver: AbstractLogdetStrategy | None = None,
) -> Float[Array, ""]:
    """Entropy of a multivariate normal ``N(mu, Sigma)``.

    Computes:

        H = 0.5 * (N * (1 + log(2 pi)) + log|Sigma|)

    Independent of the mean.

    Args:
        cov_operator: Covariance operator, shape ``(N, N)``.
        solver: Optional logdet strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Scalar entropy.
    """
    N = cov_operator.in_size()
    ld = dispatch_logdet(cov_operator, solver)
    return 0.5 * (N * (1.0 + _LOG_2PI) + ld)
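
A quick sanity check of the entropy formula, as a dense sketch with a hypothetical entropy_dense helper: the closed form should match a Monte Carlo estimate of -E[log p(x)] (for any mean, since H does not depend on it), and for a diagonal covariance it should split into a sum of univariate entropies 0.5 * (1 + log(2 pi) + log d_i).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

LOG_2PI = jnp.log(2.0 * jnp.pi)


def entropy_dense(cov):
    """H = 0.5 * (N * (1 + log 2pi) + log|Sigma|)."""
    n = cov.shape[-1]
    _, logdet = jnp.linalg.slogdet(cov)
    return 0.5 * (n * (1.0 + LOG_2PI) + logdet)


cov = jnp.array([[1.0, 0.4, 0.0], [0.4, 2.0, 0.1], [0.0, 0.1, 0.5]])
loc = jnp.array([3.0, -1.0, 0.0])  # any mean gives the same entropy
x = jax.random.multivariate_normal(jax.random.PRNGKey(2), loc, cov, shape=(100_000,))
h_mc = -jnp.mean(multivariate_normal.logpdf(x, loc, cov))  # Monte Carlo -E[log p]

# For a diagonal covariance the entropy is a sum of 1-D entropies.
d = jnp.array([0.5, 1.0, 4.0])
h_sum = jnp.sum(0.5 * (1.0 + LOG_2PI + jnp.log(d)))
```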

quadratic_form(operator: lx.AbstractLinearOperator, x: Float[Array, ' N'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, '']

Compute x^T A^{-1} x via a single solve.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| operator | AbstractLinearOperator | A non-singular linear operator A. | required |
| x | Float[Array, ' N'] | Vector, shape (N,). | required |
| solver | AbstractSolveStrategy \| None | Optional solve strategy. When None, uses structural dispatch. | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ''] | Scalar x^T A^{-1} x. |

Source code in src/gaussx/_distributions/_gaussian.py
def quadratic_form(
    operator: lx.AbstractLinearOperator,
    x: Float[Array, " N"],
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, ""]:
    """Compute ``x^T A^{-1} x`` via a single solve.

    Args:
        operator: A non-singular linear operator A.
        x: Vector, shape ``(N,)``.
        solver: Optional solve strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Scalar ``x^T A^{-1} x``.
    """
    return x @ dispatch_solve(operator, x, solver)
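
The point of "a single solve" is that A^{-1} is never formed: solving A y = x and taking x @ y is cheaper and more accurate than inverting A. A minimal dense sketch, where quadratic_form_dense is a hypothetical stand-in using jnp.linalg.solve instead of gaussx's structural dispatch:

```python
import jax.numpy as jnp


def quadratic_form_dense(A, x):
    """x^T A^{-1} x with one linear solve (no explicit inverse)."""
    return x @ jnp.linalg.solve(A, x)


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
x = jnp.array([1.0, 2.0])
q = quadratic_form_dense(A, x)  # A^{-1} x = [1/11, 7/11], so q = 15/11
```

For a covariance A this is the squared Mahalanobis distance that appears in the log-probability above.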

kl_standard_normal(m: Float[Array, ' N'], S: lx.AbstractLinearOperator, *, solver: AbstractLogdetStrategy | None = None) -> Float[Array, '']

KL divergence KL(N(m, S) || N(0, I)).

Special case of dist_kl_divergence with q_loc = 0 and q_cov = I. The identity prior means no matrix inversion is required, making this more efficient than calling the general form directly.

Computes:

KL = 0.5 * (tr(S) + m^T m - N - log|S|)

Ubiquitous in variational inference as the prior KL term.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| m | Float[Array, ' N'] | Mean vector, shape (N,). | required |
| S | AbstractLinearOperator | Covariance operator, shape (N, N). | required |
| solver | AbstractLogdetStrategy \| None | Optional logdet strategy. When None, uses structural dispatch. | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ''] | Scalar KL divergence. |

See Also

dist_kl_divergence: General KL between two multivariate normals with arbitrary lineax covariance operators.

Source code in src/gaussx/_distributions/_gaussian.py
def kl_standard_normal(
    m: Float[Array, " N"],
    S: lx.AbstractLinearOperator,
    *,
    solver: AbstractLogdetStrategy | None = None,
) -> Float[Array, ""]:
    """KL divergence ``KL(N(m, S) || N(0, I))``.

    Special case of `dist_kl_divergence`
    with ``q_loc = 0`` and ``q_cov = I``.  The identity prior means no
    matrix inversion is required, making this more efficient than calling
    the general form directly.

    Computes:

        KL = 0.5 * (tr(S) + m^T m - N - log|S|)

    Ubiquitous in variational inference as the prior KL term.

    Args:
        m: Mean vector, shape ``(N,)``.
        S: Covariance operator, shape ``(N, N)``.
        solver: Optional logdet strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Scalar KL divergence.

    See Also:
        `dist_kl_divergence`: General KL
        between two multivariate normals with arbitrary lineax covariance
        operators.
    """
    N = m.shape[-1]
    tr_S = trace(S)
    mTm = m @ m
    ld = dispatch_logdet(S, solver)
    return 0.5 * (tr_S + mTm - N - ld)
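
Setting q = N(0, I) in the general KL makes Sigma_q^{-1} the identity, so tr(Sigma_q^{-1} S) collapses to tr(S), the quadratic term to m^T m, and log|Sigma_q| to 0. A dense sketch confirming that the special case agrees with the general formula; kl_general and kl_std_normal are hypothetical helpers working on plain arrays.

```python
import jax.numpy as jnp


def kl_general(p_loc, p_cov, q_loc, q_cov):
    """Dense KL(N(p_loc, p_cov) || N(q_loc, q_cov))."""
    n = p_loc.shape[-1]
    delta = q_loc - p_loc
    trace_term = jnp.trace(jnp.linalg.solve(q_cov, p_cov))
    quad = delta @ jnp.linalg.solve(q_cov, delta)
    _, ld_q = jnp.linalg.slogdet(q_cov)
    _, ld_p = jnp.linalg.slogdet(p_cov)
    return 0.5 * (trace_term + quad - n + ld_q - ld_p)


def kl_std_normal(m, S):
    """KL(N(m, S) || N(0, I)): only trace and logdet of S, no solve."""
    _, ld = jnp.linalg.slogdet(S)
    return 0.5 * (jnp.trace(S) + m @ m - m.shape[-1] - ld)


m = jnp.array([0.5, -1.0, 2.0])
S = jnp.array([[1.5, 0.2, 0.0], [0.2, 0.8, 0.1], [0.0, 0.1, 1.2]])
```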

dist_kl_divergence(p_loc: Float[Array, ' N'], p_cov: lx.AbstractLinearOperator, q_loc: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator) -> Float[Array, '']

KL divergence KL(p || q) between two multivariate normals.

This is the canonical KL implementation for lineax-operator covariances. The specialised variants below all compute the same quantity but with different parameterisations suited to their use cases:

  • kl_standard_normal — special case KL(N(m, S) || N(0, I)); avoids matrix inversion.
  • gauss_kl — Cholesky-parameterised form for GP/SVGP models; supports multi-output and diagonal q_sqrt.
  • kl_divergence — Bregman-divergence form operating on natural parameters for the exponential family.
\[ KL(p \| q) = \frac{1}{2}\bigl( \operatorname{tr}(\Sigma_q^{-1} \Sigma_p) + (\mu_q - \mu_p)^T \Sigma_q^{-1} (\mu_q - \mu_p) - N + \log|\Sigma_q| - \log|\Sigma_p| \bigr) \]

Exploits structured operators for the trace and logdet terms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p_loc | Float[Array, ' N'] | Mean of distribution p, shape (N,). | required |
| p_cov | AbstractLinearOperator | Covariance operator of distribution p. | required |
| q_loc | Float[Array, ' N'] | Mean of distribution q, shape (N,). | required |
| q_cov | AbstractLinearOperator | Covariance operator of distribution q. | required |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ''] | Scalar KL divergence. |

Source code in src/gaussx/_distributions/_kl.py
def dist_kl_divergence(
    p_loc: Float[Array, " N"],
    p_cov: lx.AbstractLinearOperator,
    q_loc: Float[Array, " N"],
    q_cov: lx.AbstractLinearOperator,
) -> Float[Array, ""]:
    r"""KL divergence ``KL(p || q)`` between two multivariate normals.

    This is the **canonical KL implementation** for lineax-operator covariances.
    The specialised variants below all compute the same quantity but with
    different parameterisations suited to their use cases:

    - `kl_standard_normal` —
      special case ``KL(N(m, S) || N(0, I))``; avoids matrix inversion.
    - `gauss_kl` — Cholesky-parameterised form
      for GP/SVGP models; supports multi-output and diagonal ``q_sqrt``.
    - `kl_divergence` — Bregman-divergence
      form operating on natural parameters for the exponential family.

    $$
    KL(p \| q) = \frac{1}{2}\bigl(
        \operatorname{tr}(\Sigma_q^{-1} \Sigma_p)
        + (\mu_q - \mu_p)^T \Sigma_q^{-1} (\mu_q - \mu_p)
        - N
        + \log|\Sigma_q| - \log|\Sigma_p|
    \bigr)
    $$

    Exploits structured operators for the trace and logdet terms.

    Args:
        p_loc: Mean of distribution p, shape ``(N,)``.
        p_cov: Covariance operator of distribution p.
        q_loc: Mean of distribution q, shape ``(N,)``.
        q_cov: Covariance operator of distribution q.

    Returns:
        Scalar KL divergence.
    """
    N = p_loc.shape[-1]
    delta = q_loc - p_loc

    # tr(Sigma_q^{-1} Sigma_p)
    q_inv = inv(q_cov)
    trace_term = trace_product(q_inv, p_cov)

    # Quadratic term: delta^T Sigma_q^{-1} delta
    quad = jnp.sum(delta * solve(q_cov, delta))

    # Log-determinant difference
    ld_q = logdet(q_cov)
    ld_p = logdet(p_cov)

    return 0.5 * (trace_term + quad - N + ld_q - ld_p)
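
The implementation above forms inv(q_cov) for the trace term. For dense covariances, an alternative that avoids any explicit inverse uses the identity tr(Sigma_q^{-1} Sigma_p) = ||L_q^{-1} L_p||_F^2, with Sigma = L L^T. The sketch below is that swapped-in Cholesky technique (kl_cholesky is hypothetical, not the gaussx code path), checked against a sum of univariate KLs on a diagonal example.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def kl_cholesky(p_loc, p_cov, q_loc, q_cov):
    """KL(p || q) using two Cholesky factors and triangular solves only."""
    Lp = jnp.linalg.cholesky(p_cov)
    Lq = jnp.linalg.cholesky(q_cov)
    M = solve_triangular(Lq, Lp, lower=True)
    trace_term = jnp.sum(M * M)                       # ||Lq^{-1} Lp||_F^2
    z = solve_triangular(Lq, q_loc - p_loc, lower=True)
    ld_q = 2.0 * jnp.sum(jnp.log(jnp.diag(Lq)))
    ld_p = 2.0 * jnp.sum(jnp.log(jnp.diag(Lp)))
    return 0.5 * (trace_term + z @ z - p_loc.shape[-1] + ld_q - ld_p)


# Diagonal example, so the answer can be cross-checked dimension by dimension.
p_loc, p_var = jnp.array([0.0, 1.0]), jnp.array([1.0, 4.0])
q_loc, q_var = jnp.array([1.0, 0.0]), jnp.array([2.0, 1.0])
kl = kl_cholesky(p_loc, jnp.diag(p_var), q_loc, jnp.diag(q_var))

# Univariate KL: log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 s_q^2) - 1/2
kl_1d_sum = jnp.sum(
    0.5 * jnp.log(q_var / p_var) + (p_var + (p_loc - q_loc) ** 2) / (2.0 * q_var) - 0.5
)
```

Working through the example by hand gives 0.5 * (4.5 + 1.5 - 2 + log 2 - log 4) = 2 - 0.5 log 2, roughly 1.6534.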

conditional(loc: Float[Array, ' N'], cov: lx.AbstractLinearOperator, obs_idx: Int[Array, ' M'], obs_values: Float[Array, ' M'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' R'], lx.AbstractLinearOperator]

Compute p(x_A | x_B = b) from a joint Gaussian p(x_A, x_B).

Given a joint distribution \(\mathcal{N}(\mu, \Sigma)\) and observed indices B with values b, returns the conditional distribution over the remaining indices A:

\[ \begin{aligned} \mu_{A|B} &= \mu_A + \Sigma_{AB} \Sigma_{BB}^{-1} (b - \mu_B) \\ \Sigma_{A|B} &= \Sigma_{AA} - \Sigma_{AB} \Sigma_{BB}^{-1} \Sigma_{BA} \end{aligned} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loc | Float[Array, ' N'] | Mean vector of the joint distribution, shape (N,). | required |
| cov | AbstractLinearOperator | Covariance operator of the joint distribution, shape (N, N). | required |
| obs_idx | Int[Array, ' M'] | Indices of the observed variables, shape (M,). | required |
| obs_values | Float[Array, ' M'] | Observed values, shape (M,). | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, ' R'], AbstractLinearOperator] | Tuple (cond_mean, cond_cov): mean and covariance of the conditional distribution over the unobserved variables, where R = N - M. |

Source code in src/gaussx/_distributions/_conditional.py
def conditional(
    loc: Float[Array, " N"],
    cov: lx.AbstractLinearOperator,
    obs_idx: Int[Array, " M"],
    obs_values: Float[Array, " M"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " R"], lx.AbstractLinearOperator]:
    r"""Compute ``p(x_A | x_B = b)`` from a joint Gaussian ``p(x_A, x_B)``.

    Given a joint distribution $\mathcal{N}(\mu, \Sigma)$ and
    observed indices *B* with values *b*, returns the conditional
    distribution over the remaining indices *A*:

    $$
    \begin{aligned}
    \mu_{A|B} &= \mu_A + \Sigma_{AB} \Sigma_{BB}^{-1} (b - \mu_B) \\
    \Sigma_{A|B} &= \Sigma_{AA} - \Sigma_{AB} \Sigma_{BB}^{-1} \Sigma_{BA}
    \end{aligned}
    $$

    Args:
        loc: Mean vector of the joint distribution, shape ``(N,)``.
        cov: Covariance operator of the joint distribution, shape ``(N, N)``.
        obs_idx: Indices of the observed variables, shape ``(M,)``.
        obs_values: Observed values, shape ``(M,)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(cond_mean, cond_cov)`` — mean and covariance of the
        conditional distribution over unobserved variables.
    """
    N = loc.shape[0]
    obs_idx = jnp.asarray(obs_idx, dtype=jnp.int32)
    obs_values = jnp.asarray(obs_values, dtype=loc.dtype)

    if obs_idx.ndim != 1:
        raise ValueError("obs_idx must be a 1D array.")
    if obs_values.shape != obs_idx.shape:
        raise ValueError("obs_values must have the same shape as obs_idx.")
    if bool(jnp.any((obs_idx < 0) | (obs_idx >= N))):
        raise ValueError(f"obs_idx must be within bounds [0, {N}).")
    if bool(jnp.any(jnp.diff(jnp.sort(obs_idx)) == 0)):
        raise ValueError("obs_idx must not contain duplicates.")

    # Build mask for unobserved indices
    mask = jnp.ones(N, dtype=bool).at[obs_idx].set(False)
    free_idx = jnp.where(mask, size=N - obs_idx.shape[0])[0]

    # Extract sub-blocks via structural dispatch — avoids materializing
    # the full ``(N, N)`` joint covariance for structured operators
    # (Diagonal, BlockDiag). Falls back to full materialization for
    # unstructured operators where there is no efficient alternative.
    Sigma_AA = submatrix(cov, free_idx, free_idx)
    Sigma_AB = submatrix(cov, free_idx, obs_idx)
    Sigma_BB = submatrix(cov, obs_idx, obs_idx)

    mu_A = loc[free_idx]
    mu_B = loc[obs_idx]

    # Sigma_BB^{-1} (b - mu_B)
    residual = obs_values - mu_B
    Sigma_BB_op = lx.MatrixLinearOperator(Sigma_BB, lx.positive_semidefinite_tag)
    alpha = dispatch_solve(Sigma_BB_op, residual, solver)

    # Conditional mean: mu_A + Sigma_AB @ alpha
    cond_mean = mu_A + Sigma_AB @ alpha

    # Sigma_BB^{-1} Sigma_BA — single matrix solve (one factorization for
    # the whole RHS in the default PSD path).
    Sigma_BA = Sigma_AB.T
    X = solve_matrix(Sigma_BB_op, Sigma_BA, solver=solver)

    # Conditional covariance: Sigma_AA - Sigma_AB @ X
    cond_cov_mat = Sigma_AA - Sigma_AB @ X

    # Symmetrize for numerical stability
    cond_cov_mat = symmetrize(cond_cov_mat)
    cond_cov = lx.MatrixLinearOperator(cond_cov_mat, lx.positive_semidefinite_tag)

    return cond_mean, cond_cov
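
The same Schur-complement steps on a dense array, as a sketch: conditional_dense is a hypothetical helper that slices blocks with jnp.ix_ instead of gaussx's submatrix, and factors Sigma_BB once with cho_factor so both the mean and covariance updates reuse it. Two standard identities serve as checks: the bivariate case (mean mu_A + rho (b - mu_B), variance 1 - rho^2 for unit variances), and Sigma_{A|B} = (Lambda_AA)^{-1}, the inverse of the free block of the joint precision.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def conditional_dense(loc, cov, obs_idx, obs_values):
    """Dense p(x_A | x_B = b) via the Schur complement."""
    n = loc.shape[0]
    mask = jnp.ones(n, dtype=bool).at[obs_idx].set(False)
    free_idx = jnp.nonzero(mask, size=n - obs_idx.shape[0])[0]
    S_AA = cov[jnp.ix_(free_idx, free_idx)]
    S_AB = cov[jnp.ix_(free_idx, obs_idx)]
    S_BB = cov[jnp.ix_(obs_idx, obs_idx)]
    chol = cho_factor(S_BB, lower=True)                 # factor Sigma_BB once
    alpha = cho_solve(chol, obs_values - loc[obs_idx])  # Sigma_BB^{-1} (b - mu_B)
    X = cho_solve(chol, S_AB.T)                         # Sigma_BB^{-1} Sigma_BA
    mean = loc[free_idx] + S_AB @ alpha
    cond_cov = S_AA - S_AB @ X
    return mean, 0.5 * (cond_cov + cond_cov.T)          # symmetrize


# Bivariate check: mean = 1 + 0.6 * (3 - 2) = 1.6, variance = 1 - 0.6^2 = 0.64.
m2, S2 = conditional_dense(
    jnp.array([1.0, 2.0]),
    jnp.array([[1.0, 0.6], [0.6, 1.0]]),
    jnp.array([1]),
    jnp.array([3.0]),
)

# Precision check: Sigma_{A|B} equals the inverse of the precision block Lambda_AA.
cov3 = jnp.array([[2.0, 0.5, 0.3], [0.5, 1.0, 0.2], [0.3, 0.2, 1.5]])
m3, S3 = conditional_dense(jnp.zeros(3), cov3, jnp.array([0, 2]), jnp.array([0.5, -1.0]))
S3_from_prec = jnp.linalg.inv(jnp.linalg.inv(cov3)[1:2, 1:2])
```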

joseph_update(P_pred: Float[Array, 'N N'], K: Float[Array, 'N M'], H: Float[Array, 'M N'], R: Float[Array, 'M M']) -> Float[Array, 'N N']

Numerically stable Joseph-form covariance update.

Computes the updated covariance after a Kalman measurement update:

P_update = (I - K H) P_pred (I - K H)^T + K R K^T

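This form is more numerically stable than the simplified P = P_pred - K S K^T or P = (I - K H) P_pred. Given positive semi-definite P_pred and R, it is a sum of two positive semi-definite terms, so it preserves positive semi-definiteness (the implementation also symmetrizes the result explicitly). It also remains valid for any gain K, not only the optimal one, which makes it more robust when K is approximate or the system is poorly conditioned.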

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| P_pred | Float[Array, 'N N'] | Predicted covariance, shape (N, N). | required |
| K | Float[Array, 'N M'] | Kalman gain, shape (N, M). | required |
| H | Float[Array, 'M N'] | Observation model, shape (M, N). | required |
| R | Float[Array, 'M M'] | Observation noise covariance, shape (M, M). | required |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'N N'] | Updated covariance, shape (N, N). |

Source code in src/gaussx/_distributions/_joseph.py
def joseph_update(
    P_pred: Float[Array, "N N"],
    K: Float[Array, "N M"],
    H: Float[Array, "M N"],
    R: Float[Array, "M M"],
) -> Float[Array, "N N"]:
    r"""Numerically stable Joseph-form covariance update.

    Computes the updated covariance after a Kalman measurement update:

        P_update = (I - K H) P_pred (I - K H)^T + K R K^T

    This form is more numerically stable than the simplified
    ``P = P_pred - K S K^T`` or ``P = (I - K H) P_pred`` because it
    guarantees symmetry and is more robust when the Kalman gain ``K``
    is approximate or the system is poorly conditioned.

    Args:
        P_pred: Predicted covariance, shape ``(N, N)``.
        K: Kalman gain, shape ``(N, M)``.
        H: Observation model, shape ``(M, N)``.
        R: Observation noise covariance, shape ``(M, M)``.

    Returns:
        Updated covariance, shape ``(N, N)``.
    """
    N = P_pred.shape[0]
    I_KH = jnp.eye(N, dtype=P_pred.dtype) - K @ H  # (N, N)
    P_update = I_KH @ P_pred @ I_KH.T + K @ R @ K.T
    return (P_update + P_update.T) / 2
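
A small numeric illustration of the difference, reimplementing the same expression densely (joseph is a hypothetical local copy, not an import from gaussx). With the optimal gain K = P H^T S^{-1} the Joseph and simplified forms agree; with a deliberately perturbed gain the simplified (I - K H) P is no longer symmetric, while the Joseph form stays symmetric and positive definite.

```python
import jax.numpy as jnp


def joseph(P, K, H, R):
    I_KH = jnp.eye(P.shape[0], dtype=P.dtype) - K @ H
    out = I_KH @ P @ I_KH.T + K @ R @ K.T
    return 0.5 * (out + out.T)


P = jnp.array([[2.0, 0.5], [0.5, 1.0]])
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[0.5]])
S = H @ P @ H.T + R                           # innovation covariance, here 2.5
K_opt = P @ H.T @ jnp.linalg.inv(S)           # optimal gain [0.8, 0.2]^T
P_joseph = joseph(P, K_opt, H, R)
P_simple = (jnp.eye(2) - K_opt @ H) @ P       # valid only for the optimal gain

K_bad = K_opt + jnp.array([[0.3], [-0.2]])    # a deliberately suboptimal gain
P_joseph_bad = joseph(P, K_bad, H, R)
P_simple_bad = (jnp.eye(2) - K_bad @ H) @ P   # not symmetric for this gain
```

For the suboptimal gain, the Joseph form is the actual error covariance produced by using K_bad; the simplified form has no such interpretation.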

add_jitter(operator: lx.AbstractLinearOperator, jitter: float = 1e-06) -> lx.AbstractLinearOperator

Add diagonal jitter for numerical stability: A + jitter * I.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| operator | AbstractLinearOperator | A linear operator, shape (N, N). | required |
| jitter | float | Scalar jitter value. Default 1e-6. | 1e-06 |

Returns:

| Type | Description |
| --- | --- |
| AbstractLinearOperator | A + jitter * I as a lineax AddLinearOperator. |

Source code in src/gaussx/_distributions/_gaussian.py
def add_jitter(
    operator: lx.AbstractLinearOperator,
    jitter: float = 1e-6,
) -> lx.AbstractLinearOperator:
    """Add diagonal jitter for numerical stability: ``A + eps * I``.

    Args:
        operator: A linear operator, shape ``(N, N)``.
        jitter: Scalar jitter value. Default ``1e-6``.

    Returns:
        ``A + jitter * I`` as a lineax ``AddLinearOperator``.
    """
    n = operator.in_size()
    dtype = operator.out_structure().dtype
    jitter_op = lx.DiagonalLinearOperator(jnp.full(n, jitter, dtype=dtype))
    return operator + jitter_op
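
Adding jitter * I shifts every eigenvalue up by exactly jitter, which turns a singular positive semi-definite matrix into a positive definite one that Cholesky can factor. A dense sketch with a rank-1 kernel matrix; add_jitter_dense is a hypothetical helper, and the jitter of 1e-3 is chosen larger than the 1e-6 default so the float32 example is comfortably conditioned.

```python
import jax.numpy as jnp


def add_jitter_dense(A, jitter=1e-6):
    """Dense analogue of A + jitter * I."""
    return A + jitter * jnp.eye(A.shape[-1], dtype=A.dtype)


v = jnp.array([1.0, 2.0, 3.0])
K = jnp.outer(v, v)               # rank 1, eigenvalues {0, 0, 14}: singular
K_j = add_jitter_dense(K, 1e-3)   # eigenvalues {1e-3, 1e-3, 14.001}
L = jnp.linalg.cholesky(K_j)      # succeeds now that K_j is positive definite
```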

project(K_XZ: Float[Array, 'B M'], L_Z: lx.AbstractLinearOperator) -> Float[Array, 'B M']

Compute A_X = K_XZ @ K_ZZ^{-1} via Cholesky solve.

Solves L_Z @ L_Z^T @ A_X^T = K_XZ^T using forward/backward substitution. Used in sparse variational GPs to project test points onto the inducing space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| K_XZ | Float[Array, 'B M'] | Cross-covariance matrix, shape (B, M). | required |
| L_Z | AbstractLinearOperator | Lower-triangular Cholesky factor of K_ZZ, shape (M, M). | required |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, 'B M'] | Projection matrix A_X, shape (B, M). |

Source code in src/gaussx/_distributions/_project.py
def project(
    K_XZ: Float[Array, "B M"],
    L_Z: lx.AbstractLinearOperator,
) -> Float[Array, "B M"]:
    """Compute A_X = K_XZ @ K_ZZ^{-1} via Cholesky solve.

    Solves ``L_Z @ L_Z^T @ A_X^T = K_XZ^T`` using forward/backward
    substitution.  Used in sparse variational GPs to project test
    points onto the inducing space.

    Args:
        K_XZ: Cross-covariance matrix, shape ``(B, M)``.
        L_Z: Lower-triangular Cholesky factor of K_ZZ, shape ``(M, M)``.

    Returns:
        Projection matrix A_X, shape ``(B, M)``.
    """
    # Solve L_Z @ Y = K_XZ^T, then L_Z^T @ A_X^T = Y
    # Equivalently, solve (L_Z @ L_Z^T) @ A_X^T = K_XZ^T per column
    solver = lx.Triangular()

    def _solve_col(kxz_row):
        # Solve L_Z y = kxz_row
        y = lx.linear_solve(L_Z, kxz_row, solver).value
        # Solve L_Z^T a = y
        return lx.linear_solve(L_Z.T, y, solver).value

    return jax.vmap(_solve_col)(K_XZ)
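
The two triangular solves can also be written densely, handling all B rows at once by solving against K_XZ^T instead of vmapping per row. In this sketch project_dense and rbf are hypothetical helpers, jax.scipy.linalg.solve_triangular stands in for lx.linear_solve with lx.Triangular, and a small RBF kernel with 1e-4 jitter supplies K_ZZ = L_Z L_Z^T. The check is the defining relation A_X K_ZZ = K_XZ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def project_dense(K_XZ, L_Z):
    """A_X = K_XZ K_ZZ^{-1} with K_ZZ = L_Z L_Z^T, via two triangular solves."""
    Y = solve_triangular(L_Z, K_XZ.T, lower=True)    # forward: L_Z Y = K_XZ^T
    A_T = solve_triangular(L_Z.T, Y, lower=False)    # backward: L_Z^T A_X^T = Y
    return A_T.T


def rbf(x, z, lengthscale=1.0):
    return jnp.exp(-0.5 * (x[:, None] - z[None, :]) ** 2 / lengthscale**2)


Z = jnp.linspace(-2.0, 2.0, 5)    # inducing inputs, M = 5
X = jnp.linspace(-3.0, 3.0, 7)    # test inputs, B = 7
K_ZZ = rbf(Z, Z) + 1e-4 * jnp.eye(5)
K_XZ = rbf(X, Z)
L_Z = jnp.linalg.cholesky(K_ZZ)
A_X = project_dense(K_XZ, L_Z)
```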

Exponential family

The Gaussian in natural form: \(\eta_1 = \Lambda\mu\), \(\eta_2 = -\tfrac12 \Lambda\). Conversions between mean/covariance, natural, and expectation parameterizations — multivariate (operator-aware) and univariate (per-site diagonal) — plus the log-partition, Fisher information, and sufficient statistics that natural-gradient and EP updates are built from.


GaussianExpFam

Bases: Module

Gaussian in natural (exponential family) parameters.

\[ q(x \mid \eta) = h(x) \exp(\eta^T T(x) - A(\eta)) \]

where:

  • Natural parameters: eta1 = Lambda @ mu, eta2 = -0.5 * Lambda
  • Sufficient statistics: T(x) = [x, x x^T]
  • Log-partition: A(eta) = -0.25 * eta1^T eta2^{-1} eta1 - 0.5 * log|-2 eta2|
  • Base measure: h(x) = (2 pi)^{-N/2}

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| eta1 | Float[Array, ' N'] | Natural location parameter, shape (N,). |
| eta2 | AbstractLinearOperator | Natural precision-like operator, shape (N, N). Represents -0.5 * Lambda where Lambda is the precision. |

Source code in src/gaussx/_expfam/_gaussian.py
class GaussianExpFam(eqx.Module):
    r"""Gaussian in natural (exponential family) parameters.

    $$
    q(x \mid \eta) = h(x) \exp(\eta^T T(x) - A(\eta))
    $$

    where:

    - Natural parameters: ``eta1 = Lambda @ mu``, ``eta2 = -0.5 * Lambda``
    - Sufficient statistics: ``T(x) = [x, x x^T]``
    - Log-partition: ``A(eta) = -0.25 * eta1^T eta2^{-1} eta1 - 0.5 * log|-2 eta2|``
    - Base measure: ``h(x) = (2 pi)^{-N/2}``

    Attributes:
        eta1: Natural location parameter, shape ``(N,)``.
        eta2: Natural precision-like operator, shape ``(N, N)``.
            Represents ``-0.5 * Lambda`` where Lambda is the precision.
    """

    eta1: Float[Array, " N"]
    eta2: lx.AbstractLinearOperator

    @staticmethod
    def from_mean_cov(
        mu: Float[Array, " N"],
        Sigma: lx.AbstractLinearOperator,
    ) -> GaussianExpFam:
        """Construct from mean and covariance.

        Args:
            mu: Mean vector, shape ``(N,)``.
            Sigma: Covariance operator, shape ``(N, N)``.

        Returns:
            A ``GaussianExpFam`` instance.
        """
        eta1, eta2 = mean_cov_to_natural(mu, Sigma)
        return GaussianExpFam(eta1=eta1, eta2=eta2)

    @staticmethod
    def from_mean_prec(
        mu: Float[Array, " N"],
        Lambda: lx.AbstractLinearOperator,
    ) -> GaussianExpFam:
        """Construct from mean and precision.

        Args:
            mu: Mean vector, shape ``(N,)``.
            Lambda: Precision operator, shape ``(N, N)``.

        Returns:
            A ``GaussianExpFam`` instance.
        """
        eta1 = Lambda.mv(mu)
        eta2 = -0.5 * Lambda
        return GaussianExpFam(eta1=eta1, eta2=eta2)

from_mean_cov(mu: Float[Array, ' N'], Sigma: lx.AbstractLinearOperator) -> GaussianExpFam staticmethod

Construct from mean and covariance.

Parameters:

Name Type Description Default
mu Float[Array, ' N']

Mean vector, shape (N,).

required
Sigma AbstractLinearOperator

Covariance operator, shape (N, N).

required

Returns:

Type Description
GaussianExpFam

A GaussianExpFam instance.

from_mean_prec(mu: Float[Array, ' N'], Lambda: lx.AbstractLinearOperator) -> GaussianExpFam staticmethod

Construct from mean and precision.

Parameters:

Name Type Description Default
mu Float[Array, ' N']

Mean vector, shape (N,).

required
Lambda AbstractLinearOperator

Precision operator, shape (N, N).

required

Returns:

Type Description
GaussianExpFam

A GaussianExpFam instance.
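The two constructors should produce the same natural parameters; from_mean_prec simply skips the solve against Sigma when the precision is already available. A dense sketch of both routes, with plain arrays standing in for lineax operators (variable names are illustrative):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

mu = jnp.array([0.5, 1.0, -1.0])
A = jnp.array([[2.0, 0.1, 0.0], [0.1, 1.5, 0.2], [0.0, 0.2, 1.0]])
Sigma = A @ A.T                  # symmetric positive-definite covariance
Lambda = jnp.linalg.inv(Sigma)   # its precision

# Covariance route (the formulas from_mean_cov uses): solve, then invert.
eta1_cov = jnp.linalg.solve(Sigma, mu)
eta2_cov = -0.5 * jnp.linalg.inv(Sigma)

# Precision route (the formulas from_mean_prec uses): one matvec, no solve.
eta1_prec = Lambda @ mu
eta2_prec = -0.5 * Lambda
```

When the precision is what you have, the second route avoids forming or factoring Sigma altogether.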

to_natural(mu: Float[Array, ' N'], Sigma: lx.AbstractLinearOperator) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]

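Convert mean/covariance parameters to natural parameters. This is a thin wrapper around mean_cov_to_natural with the default (structural-dispatch) solver.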

Parameters:

Name Type Description Default
mu Float[Array, ' N']

Mean vector, shape (N,).

required
Sigma AbstractLinearOperator

Covariance operator, shape (N, N).

required

Returns:

Type Description
tuple[Float[Array, ' N'], AbstractLinearOperator]

Tuple (eta1, eta2) — natural parameters.

Source code in src/gaussx/_expfam/_gaussian.py
def to_natural(
    mu: Float[Array, " N"],
    Sigma: lx.AbstractLinearOperator,
) -> tuple[Float[Array, " N"], lx.AbstractLinearOperator]:
    """Convert expectation to natural parameters.

    Args:
        mu: Mean vector, shape ``(N,)``.
        Sigma: Covariance operator, shape ``(N, N)``.

    Returns:
        Tuple ``(eta1, eta2)`` — natural parameters.
    """
    return mean_cov_to_natural(mu, Sigma)

to_expectation(expfam: GaussianExpFam) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]

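Convert natural parameters to mean/covariance form. Despite the name, this returns (mu, Sigma), not the moment pair (m1, m2) returned by natural_to_expectation; it is a thin wrapper around natural_to_mean_cov.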

Parameters:

Name Type Description Default
expfam GaussianExpFam

Gaussian in natural form.

required

Returns:

Type Description
tuple[Float[Array, ' N'], AbstractLinearOperator]

Tuple (mu, Sigma) — mean vector and covariance operator.

Source code in src/gaussx/_expfam/_gaussian.py
def to_expectation(
    expfam: GaussianExpFam,
) -> tuple[Float[Array, " N"], lx.AbstractLinearOperator]:
    """Convert natural to expectation parameters.

    Args:
        expfam: Gaussian in natural form.

    Returns:
        Tuple ``(mu, Sigma)`` — mean vector and covariance operator.
    """
    return natural_to_mean_cov(expfam.eta1, expfam.eta2)

mean_cov_to_natural(mu: Float[Array, ' N'], Sigma: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]

Convert mean/covariance to natural parameters (operator form).

Given mean mu and covariance Sigma:

  • eta1 = solve(Sigma, mu)
  • eta2 = -0.5 * inv(Sigma)

Operator structure (diagonal, Kronecker, …) is exploited via structural dispatch. For dense arrays where the covariance is given by its Cholesky factor, see meanvar_to_natural.

For block-tridiagonal (SSM) inputs see gaussx._ssm._ssm_natural.ssm_to_naturals.
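For a dense, symmetric positive-definite Sigma, the two formulas can be sketched with a single Cholesky factorization reused for both solves. This illustrates the math with plain arrays and is not the gaussx code path; for structured operators, structural dispatch would be expected to pick a cheaper route (for a diagonal Sigma the formulas reduce to elementwise division). The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def mean_cov_to_natural_dense(mu, Sigma):
    """Hypothetical dense stand-in for mean_cov_to_natural."""
    factor = cho_factor(Sigma, lower=True)   # Sigma = L L^T, computed once
    eta1 = cho_solve(factor, mu)             # Sigma^{-1} mu
    Sigma_inv = cho_solve(factor, jnp.eye(Sigma.shape[0]))
    return eta1, -0.5 * Sigma_inv


mu = jnp.array([1.0, 2.0])
Sigma = jnp.diag(jnp.array([4.0, 0.25]))
eta1, eta2 = mean_cov_to_natural_dense(mu, Sigma)
# Diagonal case: eta1 = mu / diag(Sigma) = [0.25, 8.0],
# eta2 = diag(-0.5 / diag(Sigma)) = diag(-0.125, -2.0).
```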

Parameters:

Name Type Description Default
mu Float[Array, ' N']

Mean vector, shape (N,).

required
Sigma AbstractLinearOperator

Covariance operator, shape (N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy. When None, uses structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, ' N'], AbstractLinearOperator]

Tuple (eta1, eta2), where eta1 has shape (N,) and eta2 is a linear operator.

Source code in src/gaussx/_expfam/_natural.py
def mean_cov_to_natural(
    mu: Float[Array, " N"],
    Sigma: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " N"], lx.AbstractLinearOperator]:
    """Convert mean/covariance to natural parameters (operator form).

    Given mean ``mu`` and covariance ``Sigma``:

    - ``eta1 = solve(Sigma, mu)``
    - ``eta2 = -0.5 * inv(Sigma)``

    Operator structure (diagonal, Kronecker, …) is exploited via
    structural dispatch. For dense-array inputs see
    `meanvar_to_natural`.

    For block-tridiagonal (SSM) inputs see
    `gaussx._ssm._ssm_natural.ssm_to_naturals`.

    Args:
        mu: Mean vector, shape ``(N,)``.
        Sigma: Covariance operator, shape ``(N, N)``.
        solver: Optional solver strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Tuple ``(eta1, eta2)`` where eta1 is shape ``(N,)`` and
        eta2 is a linear operator.
    """
    eta1 = dispatch_solve(Sigma, mu, solver)
    eta2 = -0.5 * inv(Sigma)
    return eta1, eta2

natural_to_mean_cov(eta1: Float[Array, ' N'], eta2: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]

Convert natural parameters to mean/covariance (operator form).

Given natural parameters (eta1, eta2) where eta1 = Lambda @ mu and eta2 = -0.5 * Lambda:

  • mu = solve(-2 * eta2, eta1)
  • Sigma = inv(-2 * eta2)

Operator structure (diagonal, Kronecker, …) is exploited via structural dispatch. For dense-array inputs see natural_to_meanvar.

For block-tridiagonal (SSM) inputs see gaussx._ssm._ssm_natural.naturals_to_ssm.
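The inverse map can be checked by a round trip: start from (mu, Sigma), form (eta1, eta2), and recover both. This is a dense sketch with plain arrays; jnp.linalg.solve and jnp.linalg.inv stand in for dispatch_solve and the gaussx inv, and the helper name is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def natural_to_mean_cov_dense(eta1, eta2):
    """Hypothetical dense stand-in for natural_to_mean_cov."""
    Lambda = -2.0 * eta2                  # recover the precision
    mu = jnp.linalg.solve(Lambda, eta1)   # mu = Lambda^{-1} eta1
    Sigma = jnp.linalg.inv(Lambda)        # Sigma = Lambda^{-1}
    return mu, Sigma


mu = jnp.array([0.3, -1.2, 2.0])
B = jnp.array([[1.0, 0.0, 0.0], [0.5, 2.0, 0.0], [-0.3, 0.4, 1.5]])
Sigma = B @ B.T

# Forward map (eta1 = Sigma^{-1} mu, eta2 = -1/2 Sigma^{-1}), then back again.
eta1 = jnp.linalg.solve(Sigma, mu)
eta2 = -0.5 * jnp.linalg.inv(Sigma)
mu_back, Sigma_back = natural_to_mean_cov_dense(eta1, eta2)
```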

Parameters:

Name Type Description Default
eta1 Float[Array, ' N']

Natural location parameter, shape (N,).

required
eta2 AbstractLinearOperator

Natural precision-like operator, shape (N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy. When None, uses structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, ' N'], AbstractLinearOperator]

Tuple (mu, Sigma), where mu has shape (N,) and Sigma is a linear operator.

Source code in src/gaussx/_expfam/_natural.py
def natural_to_mean_cov(
    eta1: Float[Array, " N"],
    eta2: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " N"], lx.AbstractLinearOperator]:
    """Convert natural parameters to mean/covariance (operator form).

    Given natural parameters ``(eta1, eta2)`` where
    ``eta1 = Lambda @ mu`` and ``eta2 = -0.5 * Lambda``:

    - ``mu = solve(-2 * eta2, eta1)``
    - ``Sigma = inv(-2 * eta2)``

    Operator structure (diagonal, Kronecker, …) is exploited via
    structural dispatch. For dense-array inputs see
    `natural_to_meanvar`.

    For block-tridiagonal (SSM) inputs see
    `gaussx._ssm._ssm_natural.naturals_to_ssm`.

    Args:
        eta1: Natural location parameter, shape ``(N,)``.
        eta2: Natural precision-like operator, shape ``(N, N)``.
        solver: Optional solver strategy. When ``None``, uses
            structural dispatch.

    Returns:
        Tuple ``(mu, Sigma)`` where mu is shape ``(N,)`` and
        Sigma is a linear operator.
    """
    neg2_eta2 = -2.0 * eta2
    mu = dispatch_solve(neg2_eta2, eta1, solver)
    Sigma = inv(neg2_eta2)
    return mu, Sigma

meanvar_to_natural(mu: Float[Array, '*batch N'], S_sqrt: Float[Array, '*batch N N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Convert mean/variance (Cholesky) to natural parameters.

Given mu and lower-triangular S_sqrt such that Sigma = S_sqrt @ S_sqrt^T:

  • eta1 = Sigma^{-1} mu
  • eta2 = -0.5 * Sigma^{-1}

Uses the Cholesky factor directly via triangular solves; no solver parameter is exposed because the underlying systems are triangular rather than symmetric/PSD, and iterative strategies (CG, BBMM, PreconditionedCG, MINRES) are not valid here.
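The same computation can be sketched with jax.scipy.linalg.cho_solve, which performs the two triangular solves against a lower-triangular factor, together with a batched call through jax.vmap (the listing below flattens leading batch axes and vmaps in the same spirit). The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

jax.config.update("jax_enable_x64", True)


def meanvar_to_natural_sketch(mu, S_sqrt):
    """Hypothetical single-example version; S_sqrt is lower-triangular."""
    eta1 = cho_solve((S_sqrt, True), mu)  # Sigma^{-1} mu via two triangular solves
    Sigma_inv = cho_solve((S_sqrt, True), jnp.eye(mu.shape[-1]))
    return eta1, -0.5 * Sigma_inv


mu = jnp.array([1.0, -1.0])
S_sqrt = jnp.array([[2.0, 0.0], [1.0, 1.0]])  # Sigma = [[4, 2], [2, 2]]
eta1, eta2 = meanvar_to_natural_sketch(mu, S_sqrt)
# Expected: eta1 = [1.0, -1.5], eta2 = [[-0.25, 0.25], [0.25, -0.5]]

# Batched: a leading axis of size 3 handled by vmap.
mus = jnp.stack([mu, 2.0 * mu, -mu])
factors = jnp.stack([S_sqrt, S_sqrt, S_sqrt])
eta1_b, eta2_b = jax.vmap(meanvar_to_natural_sketch)(mus, factors)
```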

Parameters:

Name Type Description Default
mu Float[Array, '*batch N']

Mean vector, shape (*batch, N).

required
S_sqrt Float[Array, '*batch N N']

Lower-triangular Cholesky factor, shape (*batch, N, N).

required

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (eta1, eta2) of natural parameters.

Source code in src/gaussx/_expfam/_natural.py
def meanvar_to_natural(
    mu: Float[Array, "*batch N"],
    S_sqrt: Float[Array, "*batch N N"],
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    r"""Convert mean/variance (Cholesky) to natural parameters.

    Given ``mu`` and lower-triangular ``S_sqrt`` such that
    ``Sigma = S_sqrt @ S_sqrt^T``:

    - ``eta1 = Sigma^{-1} mu``
    - ``eta2 = -0.5 * Sigma^{-1}``

    Uses the Cholesky factor directly via triangular solves; no solver
    parameter is exposed because the underlying systems are triangular
    rather than symmetric/PSD, and iterative strategies (CG, BBMM,
    PreconditionedCG, MINRES) are not valid here.

    Args:
        mu: Mean vector, shape ``(*batch, N)``.
        S_sqrt: Lower-triangular Cholesky factor, shape ``(*batch, N, N)``.

    Returns:
        Tuple ``(eta1, eta2)`` of natural parameters.
    """

    def _core(mu_s: Float[Array, " N"], s_sqrt_s: Float[Array, "N N"]):
        # eta1 = Sigma^{-1} mu = S_sqrt^{-T} S_sqrt^{-1} mu via cho_solve.
        eta1_s = jax.scipy.linalg.cho_solve((s_sqrt_s, True), mu_s)
        # eta2 = -0.5 * Sigma^{-1}, computed by a single matrix cho_solve.
        N = s_sqrt_s.shape[0]
        identity = jnp.eye(N, dtype=s_sqrt_s.dtype)
        Sigma_inv = jax.scipy.linalg.cho_solve((s_sqrt_s, True), identity)
        return eta1_s, -0.5 * Sigma_inv

    *batch, N = mu.shape
    if not batch:
        return _core(mu, S_sqrt)
    mu_flat = mu.reshape(-1, N)
    s_flat = S_sqrt.reshape(-1, N, N)
    eta1_flat, eta2_flat = jax.vmap(_core)(mu_flat, s_flat)
    return eta1_flat.reshape(mu.shape), eta2_flat.reshape(S_sqrt.shape)

natural_to_meanvar(eta1: Float[Array, '*batch N'], eta2: Float[Array, '*batch N N'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Convert natural parameters to mean/variance (Cholesky).

Given eta1 = Lambda @ mu and eta2 = -0.5 * Lambda:

  • Sigma = (-2 * eta2)^{-1}
  • mu = Sigma @ eta1
  • S_sqrt = cholesky(Sigma)
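A dense sketch of the three steps, using jnp.linalg.cholesky (which returns the lower-triangular factor). Like the listing, it materializes Sigma before factoring. Factoring the precision instead, Lambda = L L^T, would give Sigma = L^{-T} L^{-1}, whose square root L^{-T} is upper- rather than lower-triangular, so it would not directly yield the S_sqrt this function returns. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def natural_to_meanvar_sketch(eta1, eta2):
    """Hypothetical single-example version of natural_to_meanvar."""
    Lambda = -2.0 * eta2                  # precision; must be positive definite
    mu = jnp.linalg.solve(Lambda, eta1)   # mu = Sigma eta1
    Sigma = jnp.linalg.inv(Lambda)
    S_sqrt = jnp.linalg.cholesky(Sigma)   # lower-triangular, Sigma = S S^T
    return mu, S_sqrt


eta1 = jnp.array([1.0, -1.5])
eta2 = jnp.array([[-0.25, 0.25], [0.25, -0.5]])
mu, S_sqrt = natural_to_meanvar_sketch(eta1, eta2)
# Expected: mu = [1.0, -1.0], S_sqrt = [[2.0, 0.0], [1.0, 1.0]]
```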

Parameters:

Name Type Description Default
eta1 Float[Array, '*batch N']

Natural location parameter, shape (*batch, N).

required
eta2 Float[Array, '*batch N N']

Natural quadratic parameter, shape (*batch, N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (mu, S_sqrt), where S_sqrt is the lower-triangular Cholesky factor of the covariance.

Source code in src/gaussx/_expfam/_natural.py
def natural_to_meanvar(
    eta1: Float[Array, "*batch N"],
    eta2: Float[Array, "*batch N N"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    r"""Convert natural parameters to mean/variance (Cholesky).

    Given ``eta1 = Lambda @ mu`` and ``eta2 = -0.5 * Lambda``:

    - ``Sigma = (-2 * eta2)^{-1}``
    - ``mu = Sigma @ eta1``
    - ``S_sqrt = cholesky(Sigma)``

    Args:
        eta1: Natural location parameter, shape ``(*batch, N)``.
        eta2: Natural quadratic parameter, shape ``(*batch, N, N)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(mu, S_sqrt)`` where ``S_sqrt`` is the lower-triangular
        Cholesky factor of the covariance.
    """

    def _core(e1: Float[Array, " N"], e2: Float[Array, "N N"]):
        Lambda_op = lx.MatrixLinearOperator(-2.0 * e2, lx.positive_semidefinite_tag)
        mu_s = dispatch_solve(Lambda_op, e1, solver)
        Sigma = inv(Lambda_op).as_matrix()
        Sigma_op = lx.MatrixLinearOperator(Sigma, lx.positive_semidefinite_tag)
        return mu_s, cholesky(Sigma_op).as_matrix()

    *batch, N = eta1.shape
    if not batch:
        return _core(eta1, eta2)
    eta1_flat = eta1.reshape(-1, N)
    eta2_flat = eta2.reshape(-1, N, N)
    mu_flat, s_flat = jax.vmap(_core)(eta1_flat, eta2_flat)
    return mu_flat.reshape(eta1.shape), s_flat.reshape(eta2.shape)

meanvar_to_expectation(mu: Float[Array, '*batch N'], S_sqrt: Float[Array, '*batch N N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Convert mean/variance (Cholesky) to expectation parameters.

Given mu and S_sqrt (lower-triangular Cholesky of Sigma):

  • m1 = mu
  • m2 = mu @ mu^T + Sigma = mu @ mu^T + S_sqrt @ S_sqrt^T
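Both the outer product and the S_sqrt S_sqrt^T term broadcast over leading batch axes, so no vmap is needed here. A sketch with plain arrays (the helper name is hypothetical):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def meanvar_to_expectation_sketch(mu, S_sqrt):
    """Hypothetical batched version: mu is (*batch, N), S_sqrt is (*batch, N, N)."""
    outer = mu[..., :, None] * mu[..., None, :]   # mu mu^T per batch element
    cov = S_sqrt @ jnp.swapaxes(S_sqrt, -1, -2)   # S_sqrt S_sqrt^T
    return mu, outer + cov


mu = jnp.array([[1.0, -1.0], [0.0, 2.0], [3.0, 0.5]])  # batch of 3
S_sqrt = jnp.broadcast_to(jnp.array([[2.0, 0.0], [1.0, 1.0]]), (3, 2, 2))
m1, m2 = meanvar_to_expectation_sketch(mu, S_sqrt)
# m2[0] = [[1, -1], [-1, 1]] + [[4, 2], [2, 2]] = [[5, 1], [1, 3]]
```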

Parameters:

Name Type Description Default
mu Float[Array, '*batch N']

Mean vector, shape (*batch, N).

required
S_sqrt Float[Array, '*batch N N']

Lower-triangular Cholesky factor, shape (*batch, N, N).

required

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (m1, m2) of expectation parameters.

Source code in src/gaussx/_expfam/_natural.py
def meanvar_to_expectation(
    mu: Float[Array, "*batch N"],
    S_sqrt: Float[Array, "*batch N N"],
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    r"""Convert mean/variance (Cholesky) to expectation parameters.

    Given ``mu`` and ``S_sqrt`` (lower-triangular Cholesky of ``Sigma``):

    - ``m1 = mu``
    - ``m2 = mu @ mu^T + Sigma = mu @ mu^T + S_sqrt @ S_sqrt^T``

    Args:
        mu: Mean vector, shape ``(*batch, N)``.
        S_sqrt: Lower-triangular Cholesky factor, shape ``(*batch, N, N)``.

    Returns:
        Tuple ``(m1, m2)`` of expectation parameters.
    """
    m1 = mu
    m2 = mu[..., None] * mu[..., None, :] + S_sqrt @ S_sqrt.mT
    return m1, m2

expectation_to_meanvar(m1: Float[Array, '*batch N'], m2: Float[Array, '*batch N N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Convert expectation parameters to mean/variance (Cholesky).

Given m1 = mu and m2 = mu @ mu^T + Sigma:

  • mu = m1
  • Sigma = m2 - m1 @ m1^T
  • S_sqrt = cholesky(Sigma)

No solver parameter is exposed because the only linear-algebra operation is Cholesky factorization, which is structurally fixed.
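A dense sketch of this inverse map follows. One caveat of the expectation parameterization: Sigma is recovered as a difference, m2 - m1 m1^T, so when the mean is large relative to the spread the subtraction cancels catastrophically, especially in float32. The second half of the example shows the variance of a 1-D Gaussian with mean 1e4 and variance 1 vanishing entirely in float32. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def expectation_to_meanvar_sketch(m1, m2):
    """Hypothetical version: Sigma = m2 - m1 m1^T, then a lower Cholesky factor."""
    Sigma = m2 - m1[..., :, None] * m1[..., None, :]
    return m1, jnp.linalg.cholesky(Sigma)


m1 = jnp.array([1.0, -1.0])
m2 = jnp.array([[5.0, 1.0], [1.0, 3.0]])
mu, S_sqrt = expectation_to_meanvar_sketch(m1, m2)
# Expected: S_sqrt = [[2.0, 0.0], [1.0, 1.0]]

# Cancellation: mean 1e4, variance 1, stored in float32.
m1_f32 = jnp.array([1e4], dtype=jnp.float32)
m2_f32 = jnp.array([[1e8 + 1.0]], dtype=jnp.float32)  # rounds to 1e8 in float32
Sigma_f32 = m2_f32 - m1_f32[:, None] * m1_f32[None, :]
# Sigma_f32 is [[0.0]]: the unit variance has been lost.
```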

Parameters:

Name Type Description Default
m1 Float[Array, '*batch N']

First moment (mean), shape (*batch, N).

required
m2 Float[Array, '*batch N N']

Second moment, shape (*batch, N, N).

required

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (mu, S_sqrt), where S_sqrt is the lower-triangular Cholesky factor of the covariance.

Source code in src/gaussx/_expfam/_natural.py
def expectation_to_meanvar(
    m1: Float[Array, "*batch N"],
    m2: Float[Array, "*batch N N"],
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    r"""Convert expectation parameters to mean/variance (Cholesky).

    Given ``m1 = mu`` and ``m2 = mu @ mu^T + Sigma``:

    - ``mu = m1``
    - ``Sigma = m2 - m1 @ m1^T``
    - ``S_sqrt = cholesky(Sigma)``

    No solver parameter is exposed because the only linear-algebra
    operation is Cholesky factorization, which is structurally fixed.

    Args:
        m1: First moment (mean), shape ``(*batch, N)``.
        m2: Second moment, shape ``(*batch, N, N)``.

    Returns:
        Tuple ``(mu, S_sqrt)`` where ``S_sqrt`` is the lower-triangular
        Cholesky factor of the covariance.
    """

    def _core(m1_s: Float[Array, " N"], m2_s: Float[Array, "N N"]):
        Sigma = m2_s - m1_s[:, None] * m1_s[None, :]
        Sigma_op = lx.MatrixLinearOperator(Sigma, lx.positive_semidefinite_tag)
        return m1_s, cholesky(Sigma_op).as_matrix()

    *batch, N = m1.shape
    if not batch:
        return _core(m1, m2)
    m1_flat = m1.reshape(-1, N)
    m2_flat = m2.reshape(-1, N, N)
    mu_flat, s_flat = jax.vmap(_core)(m1_flat, m2_flat)
    return mu_flat.reshape(m1.shape), s_flat.reshape(m2.shape)

expectation_to_natural(m1: Float[Array, '*batch N'], m2: Float[Array, '*batch N N'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Convert expectation parameters to natural parameters.

Given m1 = mu and m2 = mu @ mu^T + Sigma:

  • Sigma = m2 - m1 @ m1^T
  • eta1 = Sigma^{-1} @ m1
  • eta2 = -0.5 * Sigma^{-1}
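A dense sketch of the same three steps. It assumes the recovered Sigma is symmetric positive definite and reuses one Cholesky factorization for both solves; that is a choice made for the sketch, whereas the listing routes through dispatch_solve and solve_columns with an optional solver. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def expectation_to_natural_sketch(m1, m2):
    """Hypothetical single-example version of expectation_to_natural."""
    Sigma = m2 - jnp.outer(m1, m1)           # central second moment
    factor = cho_factor(Sigma, lower=True)   # assumes Sigma is SPD
    eta1 = cho_solve(factor, m1)             # Sigma^{-1} m1
    Sigma_inv = cho_solve(factor, jnp.eye(m1.shape[0]))
    return eta1, -0.5 * Sigma_inv


m1 = jnp.array([1.0, -1.0])
m2 = jnp.array([[5.0, 1.0], [1.0, 3.0]])     # Sigma = [[4, 2], [2, 2]]
eta1, eta2 = expectation_to_natural_sketch(m1, m2)
# Expected: eta1 = [1.0, -1.5], eta2 = [[-0.25, 0.25], [0.25, -0.5]]
```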

Parameters:

Name Type Description Default
m1 Float[Array, '*batch N']

First moment (mean), shape (*batch, N).

required
m2 Float[Array, '*batch N N']

Second moment, shape (*batch, N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (eta1, eta2) of natural parameters.

Source code in src/gaussx/_expfam/_natural.py
def expectation_to_natural(
    m1: Float[Array, "*batch N"],
    m2: Float[Array, "*batch N N"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    r"""Convert expectation parameters to natural parameters.

    Given ``m1 = mu`` and ``m2 = mu @ mu^T + Sigma``:

    - ``Sigma = m2 - m1 @ m1^T``
    - ``eta1 = Sigma^{-1} @ m1``
    - ``eta2 = -0.5 * Sigma^{-1}``

    Args:
        m1: First moment (mean), shape ``(*batch, N)``.
        m2: Second moment, shape ``(*batch, N, N)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(eta1, eta2)`` of natural parameters.
    """

    def _core(m1_s: Float[Array, " N"], m2_s: Float[Array, "N N"]):
        Sigma = m2_s - m1_s[:, None] * m1_s[None, :]
        Sigma_op = lx.MatrixLinearOperator(Sigma, lx.positive_semidefinite_tag)
        eta1_s = dispatch_solve(Sigma_op, m1_s, solver)
        N = m1_s.shape[0]
        identity = jnp.eye(N, dtype=m1_s.dtype)
        Sigma_inv = solve_columns(Sigma_op, identity, solver=solver)
        return eta1_s, -0.5 * Sigma_inv

    *batch, N = m1.shape
    if not batch:
        return _core(m1, m2)
    m1_flat = m1.reshape(-1, N)
    m2_flat = m2.reshape(-1, N, N)
    eta1_flat, eta2_flat = jax.vmap(_core)(m1_flat, m2_flat)
    return eta1_flat.reshape(m1.shape), eta2_flat.reshape(m2.shape)

natural_to_expectation(eta1: Float[Array, '*batch N'], eta2: Float[Array, '*batch N N'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Convert natural parameters to expectation parameters.

Given eta1 = Lambda @ mu and eta2 = -0.5 * Lambda:

  • Sigma = (-2 * eta2)^{-1}
  • mu = Sigma @ eta1
  • m1 = mu
  • m2 = mu @ mu^T + Sigma
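The listing handles arbitrary leading batch axes by flattening them, applying a per-example core under jax.vmap, and restoring the shape. The sketch below reproduces that pattern with dense jnp.linalg calls in place of the gaussx operators; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _core(e1, e2):
    """Single example: (eta1, eta2) -> (m1, m2)."""
    Lambda = -2.0 * e2
    mu = jnp.linalg.solve(Lambda, e1)
    Sigma = jnp.linalg.inv(Lambda)
    return mu, jnp.outer(mu, mu) + Sigma


def natural_to_expectation_sketch(eta1, eta2):
    *batch, N = eta1.shape
    if not batch:
        return _core(eta1, eta2)
    # Flatten all batch axes into one, vmap, then restore the original shapes.
    m1, m2 = jax.vmap(_core)(eta1.reshape(-1, N), eta2.reshape(-1, N, N))
    return m1.reshape(eta1.shape), m2.reshape(eta2.shape)


e1 = jnp.array([1.0, -1.5])
e2 = jnp.array([[-0.25, 0.25], [0.25, -0.5]])
eta1 = jnp.broadcast_to(e1, (2, 3, 2))       # batch shape (2, 3)
eta2 = jnp.broadcast_to(e2, (2, 3, 2, 2))
m1, m2 = natural_to_expectation_sketch(eta1, eta2)
# Every batch element should give m1 = [1.0, -1.0], m2 = [[5.0, 1.0], [1.0, 3.0]].
```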

Parameters:

Name Type Description Default
eta1 Float[Array, '*batch N']

Natural location parameter, shape (*batch, N).

required
eta2 Float[Array, '*batch N N']

Natural quadratic parameter, shape (*batch, N, N).

required
solver AbstractSolverStrategy | None

Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

None

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (m1, m2) of expectation parameters.

Source code in src/gaussx/_expfam/_natural.py
def natural_to_expectation(
    eta1: Float[Array, "*batch N"],
    eta2: Float[Array, "*batch N N"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    r"""Convert natural parameters to expectation parameters.

    Given ``eta1 = Lambda @ mu`` and ``eta2 = -0.5 * Lambda``:

    - ``Sigma = (-2 * eta2)^{-1}``
    - ``mu = Sigma @ eta1``
    - ``m1 = mu``
    - ``m2 = mu @ mu^T + Sigma``

    Args:
        eta1: Natural location parameter, shape ``(*batch, N)``.
        eta2: Natural quadratic parameter, shape ``(*batch, N, N)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(m1, m2)`` of expectation parameters.
    """

    def _core(e1: Float[Array, " N"], e2: Float[Array, "N N"]):
        Lambda_op = lx.MatrixLinearOperator(-2.0 * e2, lx.positive_semidefinite_tag)
        mu_s = dispatch_solve(Lambda_op, e1, solver)
        Sigma = inv(Lambda_op).as_matrix()
        m2_s = mu_s[:, None] * mu_s[None, :] + Sigma
        return mu_s, m2_s

    *batch, N = eta1.shape
    if not batch:
        return _core(eta1, eta2)
    eta1_flat = eta1.reshape(-1, N)
    eta2_flat = eta2.reshape(-1, N, N)
    m1_flat, m2_flat = jax.vmap(_core)(eta1_flat, eta2_flat)
    return m1_flat.reshape(eta1.shape), m2_flat.reshape(eta2.shape)

log_partition(expfam: GaussianExpFam) -> Float[Array, '']

Log-partition function A(eta).

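\[ A(\eta) = -\frac{1}{4} \eta_1^T \eta_2^{-1} \eta_1 - \frac{1}{2} \log|-2\eta_2| \]

The returned value also includes the constant \(\tfrac{N}{2}\log 2\pi\) (see the listing below), so it is the log-normalizer under the convention h(x) = 1 rather than the h(x) = (2 pi)^{-N/2} stated for GaussianExpFam. The constant cancels in differences of log-partitions, as in kl_divergence, and does not affect gradients.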

Parameters:

Name Type Description Default
expfam GaussianExpFam

Gaussian in natural form.

required

Returns:

Type Description
Float[Array, '']

Scalar log-partition value.
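A useful property to check is that the gradient of the log-partition gives the expectation parameters: the derivative with respect to eta1 is mu, and with respect to eta2 it is mu mu^T + Sigma. The sketch below verifies this with jax.grad on a dense, hypothetical log_partition_dense that follows the formula above without any additive constant (such a constant would not change the gradient).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def log_partition_dense(eta1, eta2):
    """A(eta) = -1/4 eta1^T eta2^{-1} eta1 - 1/2 log|-2 eta2| (hypothetical helper)."""
    quad = -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1)
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return quad - 0.5 * logdet


mu = jnp.array([1.0, -1.0])
Sigma = jnp.array([[4.0, 2.0], [2.0, 2.0]])
Lambda = jnp.linalg.inv(Sigma)
eta1, eta2 = Lambda @ mu, -0.5 * Lambda

# Gradients with respect to both natural parameters.
g1, g2 = jax.grad(log_partition_dense, argnums=(0, 1))(eta1, eta2)
# g1 should match mu; g2 should match mu mu^T + Sigma = [[5, 1], [1, 3]].
```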

Source code in src/gaussx/_expfam/_gaussian.py
def log_partition(expfam: GaussianExpFam) -> Float[Array, ""]:
    r"""Log-partition function ``A(eta)``.

    $$
    A(\eta) = -\frac{1}{4} \eta_1^T \eta_2^{-1} \eta_1
              - \frac{1}{2} \log|-2\eta_2|
    $$

    Args:
        expfam: Gaussian in natural form.

    Returns:
        Scalar log-partition value.
    """
    neg2_eta2 = -2.0 * expfam.eta2
    N = neg2_eta2.in_size()

    # -0.25 * eta1^T @ eta2^{-1} @ eta1
    # eta2^{-1} = (-0.5 Lambda)^{-1} = -2 Sigma
    # So -0.25 * eta1^T @ (-2 Sigma) @ eta1 = 0.5 * eta1^T Sigma eta1
    eta2_inv_eta1 = solve(expfam.eta2, expfam.eta1)
    quad = -0.25 * (expfam.eta1 @ eta2_inv_eta1)

    # -0.5 * log|-2 eta2| = -0.5 * logdet(Lambda)
    ld = -0.5 * logdet(neg2_eta2)

    # Add base measure contribution: N/2 * log(2pi)
    return quad + ld + 0.5 * N * _LOG_2PI

fisher_info(expfam: GaussianExpFam) -> lx.AbstractLinearOperator

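Fisher information of the Gaussian with respect to its mean.

Returns the precision Lambda = -2 * eta2 = Sigma^{-1}, which is the Fisher information for the mean mu when the covariance is held fixed. It is not the full Hessian nabla^2 A(eta) over both natural parameters: the eta1-eta1 block of that Hessian is Sigma, the inverse of the returned operator.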

Parameters:

Name Type Description Default
expfam GaussianExpFam

Gaussian in natural form.

required

Returns:

Type Description
AbstractLinearOperator

Precision operator (the Fisher information matrix).
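The returned precision can be related numerically to the Hessian of the log-partition. In the sketch below (dense arrays; log_partition_dense is a hypothetical helper implementing the formula for A(eta)), the Hessian with respect to eta1 alone comes out as Sigma, and -2 eta2 is its inverse, i.e. the precision Lambda, which is the Fisher information for the mean mu with the covariance held fixed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def log_partition_dense(eta1, eta2):
    """A(eta) for a Gaussian in natural form (hypothetical helper)."""
    quad = -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1)
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return quad - 0.5 * logdet


Sigma = jnp.array([[4.0, 2.0], [2.0, 2.0]])
mu = jnp.array([1.0, -1.0])
Lambda = jnp.linalg.inv(Sigma)
eta1, eta2 = Lambda @ mu, -0.5 * Lambda

# Hessian of A with respect to eta1 only (eta2 held fixed).
H11 = jax.hessian(log_partition_dense, argnums=0)(eta1, eta2)
precision = -2.0 * eta2  # the quantity fisher_info returns, as a dense matrix
# H11 should equal Sigma, and precision @ H11 should be the identity.
```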

Source code in src/gaussx/_expfam/_gaussian.py
def fisher_info(
    expfam: GaussianExpFam,
) -> lx.AbstractLinearOperator:
    r"""Fisher information matrix ``F(eta) = nabla^2 A(eta)``.

    For a Gaussian, the Fisher information in terms of the
    covariance is ``Sigma^{-1}`` (the precision matrix).

    Args:
        expfam: Gaussian in natural form.

    Returns:
        Precision operator (the Fisher information matrix).
    """
    # Lambda = -2 * eta2
    return -2.0 * expfam.eta2

sufficient_stats(x: Float[Array, '*batch N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Compute sufficient statistics T(x) = [x, x x^T].

Parameters:

Name Type Description Default
x Float[Array, '*batch N']

Data vector of shape (N,), or a batch of shape (B, N). Only a single leading batch axis is supported, despite the *batch annotation.

required

Returns:

Type Description
tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]

Tuple (x, outer_product), where outer_product has shape (N, N) or (B, N, N).
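The listing covers the unbatched case with jnp.outer and a single batch axis with the einsum pattern "b i, b j -> b i j", which would not match an input with two or more leading axes. If arbitrary leading batch axes are needed, a broadcasting outer product is one option; the sketch below is a hypothetical variant, not the gaussx function.

```python
import jax.numpy as jnp


def sufficient_stats_broadcast(x):
    """Hypothetical variant: T(x) = (x, x x^T) for any leading batch shape."""
    return x, x[..., :, None] * x[..., None, :]


x = jnp.arange(24.0).reshape(2, 3, 4)   # batch shape (2, 3), N = 4
t1, t2 = sufficient_stats_broadcast(x)
# t2 has shape (2, 3, 4, 4); each t2[i, j] equals jnp.outer(x[i, j], x[i, j]).
```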

Source code in src/gaussx/_expfam/_gaussian.py
def sufficient_stats(
    x: Float[Array, "*batch N"],
) -> tuple[Float[Array, "*batch N"], Float[Array, "*batch N N"]]:
    """Compute sufficient statistics ``T(x) = [x, x x^T]``.

    Args:
        x: Data vector, shape ``(N,)`` or batch ``(B, N)``.

    Returns:
        Tuple ``(x, outer_product)`` where outer_product has
        shape ``(N, N)`` or ``(B, N, N)``.
    """
    if x.ndim == 1:
        return x, jnp.outer(x, x)
    # Batched: (B, N) -> (B, N, N)
    return x, einsum(x, x, "b i, b j -> b i j")

kl_divergence(q: GaussianExpFam, p: GaussianExpFam) -> Float[Array, '']

KL divergence KL(q || p) via the Bregman-divergence form on natural parameters.

Exponential-family expression of the KL divergence in terms of the log-partition A and the natural parameters of q and p. Mathematically equivalent to dist_kl_divergence.

The current implementation evaluates the Bregman form by calling to_expectation to obtain (mu_q, Sigma_q), which give the gradient nabla A(eta_q) = (mu_q, mu_q mu_q^T + Sigma_q) used in the linear term (eta_p - eta_q)^T nabla A(eta_q). The second-moment part of that term splits into a quadratic form (operator matvecs) plus gaussx.trace_product, so structured eta2 / Sigma_q operators are not materialized in this step. Compared with dist_kl_divergence, the computation stays in natural-parameter space, which suits use inside a natural-gradient loop.

\[ \mathrm{KL}(q \,\|\, p) = A(\eta_p) - A(\eta_q) - (\eta_p - \eta_q)^T \nabla A(\eta_q) \]
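The identity can be checked against the standard closed-form KL divergence between Gaussians, 1/2 [tr(Sigma_p^{-1} Sigma_q) + (mu_p - mu_q)^T Sigma_p^{-1} (mu_p - mu_q) - N + log|Sigma_p| - log|Sigma_q|]. The sketch below uses dense arrays and hypothetical helper names; it evaluates the linear term with nabla A(eta_q) = (mu_q, mu_q mu_q^T + Sigma_q), as the listing does, but it does not call gaussx.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def log_partition_dense(eta1, eta2):
    quad = -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1)
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return quad - 0.5 * logdet


def to_natural_dense(mu, Sigma):
    Lambda = jnp.linalg.inv(Sigma)
    return Lambda @ mu, -0.5 * Lambda


def kl_bregman(mu_q, Sigma_q, mu_p, Sigma_p):
    """KL(q || p) = A(eta_p) - A(eta_q) - (eta_p - eta_q)^T grad A(eta_q)."""
    e1q, e2q = to_natural_dense(mu_q, Sigma_q)
    e1p, e2p = to_natural_dense(mu_p, Sigma_p)
    second_moment_q = jnp.outer(mu_q, mu_q) + Sigma_q
    # eta1 part pairs with mu_q; eta2 part is a Frobenius inner product.
    linear = (e1p - e1q) @ mu_q + jnp.sum((e2p - e2q) * second_moment_q)
    return log_partition_dense(e1p, e2p) - log_partition_dense(e1q, e2q) - linear


def kl_closed_form(mu_q, Sigma_q, mu_p, Sigma_p):
    N = mu_q.shape[0]
    Lambda_p = jnp.linalg.inv(Sigma_p)
    d = mu_p - mu_q
    _, logdet_p = jnp.linalg.slogdet(Sigma_p)
    _, logdet_q = jnp.linalg.slogdet(Sigma_q)
    return 0.5 * (jnp.trace(Lambda_p @ Sigma_q) + d @ Lambda_p @ d - N + logdet_p - logdet_q)


mu_q, Sigma_q = jnp.array([1.0, -1.0]), jnp.array([[4.0, 2.0], [2.0, 2.0]])
mu_p, Sigma_p = jnp.array([0.0, 0.5]), jnp.array([[1.0, 0.2], [0.2, 0.5]])
print(kl_bregman(mu_q, Sigma_q, mu_p, Sigma_p))
print(kl_closed_form(mu_q, Sigma_q, mu_p, Sigma_p))  # should agree
```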

Parameters:

Name Type Description Default
q GaussianExpFam

First Gaussian (the "true" distribution).

required
p GaussianExpFam

Second Gaussian (the "approximate" distribution).

required

Returns:

Type Description
Float[Array, '']

Scalar KL divergence.

See Also

dist_kl_divergence: General KL in mean/covariance form with lineax operators.

Source code in src/gaussx/_expfam/_gaussian.py
def kl_divergence(
    q: GaussianExpFam,
    p: GaussianExpFam,
) -> Float[Array, ""]:
    """KL divergence ``KL(q || p)`` via the Bregman-divergence form on
    natural parameters.

    Exponential-family expression of the KL divergence in terms of the
    log-partition ``A`` and the natural parameters of ``q`` and ``p``.
    Mathematically equivalent to
    `dist_kl_divergence`.

    The current implementation evaluates the Bregman form by routing
    through `to_expectation` for the natural-gradient term
    ``(eta_p - eta_q)^T nabla A(eta_q)``. The second-moment contraction
    splits into a quadratic form (operator matvecs) plus
    `gaussx.trace_product`, so structured ``eta2`` / ``Sigma_q``
    operators are never materialized. The benefit relative to
    `dist_kl_divergence` is keeping the gradient flowing in
    natural-parameter space (suitable inside a natural-gradient loop).

    $$
    KL(q || p) = A(eta_p) - A(eta_q) - (eta_p - eta_q)^T nabla A(eta_q)
    $$

    Args:
        q: First Gaussian (the "true" distribution).
        p: Second Gaussian (the "approximate" distribution).

    Returns:
        Scalar KL divergence.

    See Also:
        `dist_kl_divergence`: General KL
        in mean/covariance form with lineax operators.
    """
    A_p = log_partition(p)
    A_q = log_partition(q)

    # grad A(eta_q) w.r.t eta1 = mu_q, w.r.t eta2 = mu_q mu_q^T + Sigma_q
    # The linear term: (eta_p - eta_q)^T grad A(eta_q)
    # For eta1 part: (eta1_p - eta1_q)^T mu_q
    mu_q, Sigma_q = to_expectation(q)

    delta_eta1 = p.eta1 - q.eta1
    linear_eta1 = delta_eta1 @ mu_q

    # For eta2 part: tr((eta2_p - eta2_q) @ (mu mu^T + Sigma))
    # = mu^T (eta2_p - eta2_q) mu + tr(eta2_p Sigma) - tr(eta2_q Sigma).
    # Quadratic form via matvecs + structured trace_product — no
    # materialization of eta2 or Sigma_q.
    quad = mu_q @ (p.eta2.mv(mu_q) - q.eta2.mv(mu_q))
    linear_eta2 = quad + trace_product(p.eta2, Sigma_q) - trace_product(q.eta2, Sigma_q)

    return A_p - A_q - linear_eta1 - linear_eta2