Quadrature & Moment Matching

Gaussian integration approximates \(\mathbb{E}_{x \sim \mathcal{N}(\mu, \Sigma)}[f(x)]\) with deterministic rules (Gauss-Hermite, unscented / cubature, Taylor) or Monte Carlo, all behind a single AbstractIntegrator interface. Anything that needs an expectation (expected log-likelihoods, EP tilted moments, uncertain-input GP predictions) takes an integrator argument, so swapping the rule never touches the model.

State & integrators

GaussianState pairs a mean with a covariance operator; integrators propagate functions of it.

Structured linear algebra and Gaussian primitives for JAX.

GaussianState

Bases: Module

Gaussian distribution as (mean, covariance operator) pair.

Attributes:

Name Type Description
mean Float[Array, ' N']

Mean vector, shape (N,).

cov AbstractLinearOperator

Covariance operator, shape (N, N).

Source code in src/gaussx/_quadrature/_types.py
class GaussianState(eqx.Module):
    """Gaussian distribution as (mean, covariance operator) pair.

    Attributes:
        mean: Mean vector, shape ``(N,)``.
        cov: Covariance operator, shape ``(N, N)``.
    """

    mean: Float[Array, " N"]
    cov: lx.AbstractLinearOperator

PropagationResult

Bases: Module

Output of uncertainty propagation through a nonlinear function.

Attributes:

Name Type Description
state GaussianState

Output Gaussian distribution.

cross_cov Float[Array, 'N_in N_out'] | None

Input-output cross-covariance, shape (N_in, N_out). Used for downstream Kalman updates. None if not computed.

Source code in src/gaussx/_quadrature/_types.py
class PropagationResult(eqx.Module):
    """Output of uncertainty propagation through a nonlinear function.

    Attributes:
        state: Output Gaussian distribution.
        cross_cov: Input-output cross-covariance, shape ``(N_in, N_out)``.
            Used for downstream Kalman updates. ``None`` if not computed.
    """

    state: GaussianState
    cross_cov: Float[Array, "N_in N_out"] | None
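
The cross-covariance is the quantity a Kalman-style measurement update consumes. The following is a minimal NumPy sketch of that downstream step, not gaussx code: the helper `kalman_update`, the noise covariance `R`, and the observation `y_obs` are illustrative names, and a linear measurement is assumed so that the propagated moments are exact.

```python
import numpy as np


def kalman_update(mu_x, Sigma_x, mu_y, Sigma_y, cross_cov, y_obs, R):
    """Condition x on an observation, given propagated moments.

    cross_cov has shape (N_in, N_out), the layout documented for
    PropagationResult.cross_cov.
    """
    S = Sigma_y + R                            # innovation covariance (N_out, N_out)
    K = np.linalg.solve(S, cross_cov.T).T      # gain = cross_cov @ inv(S), S symmetric
    mu_post = mu_x + K @ (y_obs - mu_y)
    Sigma_post = Sigma_x - K @ S @ K.T
    return mu_post, Sigma_post


# Hypothetical linear measurement h(x) = H x observing the first coordinate.
H = np.array([[1.0, 0.0]])
mu_x = np.array([0.0, 1.0])
Sigma_x = np.array([[2.0, 0.5], [0.5, 1.0]])
R = np.array([[0.5]])

mu_y = H @ mu_x
Sigma_y = H @ Sigma_x @ H.T
cross_cov = Sigma_x @ H.T                      # (2, 1)

mu_post, Sigma_post = kalman_update(
    mu_x, Sigma_x, mu_y, Sigma_y, cross_cov, np.array([1.0]), R
)
```

With a nonlinear measurement, the same update would take `mu_y`, `Sigma_y`, and `cross_cov` from whichever integrator produced the PropagationResult.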

AbstractIntegrator

Bases: Module

Protocol for Gaussian integral approximation.

Subclasses implement integrate to propagate a Gaussian through a nonlinear function, returning an approximate output distribution and (optionally) input-output cross-covariance.

Source code in src/gaussx/_quadrature/_integrator.py
class AbstractIntegrator(eqx.Module):
    """Protocol for Gaussian integral approximation.

    Subclasses implement ``integrate`` to propagate a Gaussian through
    a nonlinear function, returning an approximate output distribution
    and (optionally) input-output cross-covariance.
    """

    @abc.abstractmethod
    def integrate(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> PropagationResult:
        """Propagate a Gaussian through ``fn``, returning output moments.

        Args:
            fn: Nonlinear function mapping ``(N,) -> (M,)``.
            state: Input Gaussian distribution.

        Returns:
            ``PropagationResult`` with output distribution and optional
            cross-covariance.
        """
        ...

integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult abstractmethod

Propagate a Gaussian through fn, returning output moments.

Parameters:

Name Type Description Default
fn Callable[[Float[Array, ' N']], Float[Array, ' M']]

Nonlinear function mapping (N,) -> (M,).

required
state GaussianState

Input Gaussian distribution.

required

Returns:

Type Description
PropagationResult

PropagationResult with output distribution and optional cross-covariance.
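
To illustrate why a shared `integrate` signature lets the rule be swapped without touching the model, here is a small sketch using plain NumPy classes in place of `eqx.Module`. Everything in it (`Rule`, `LinearizedRule`, `CubatureRule`, `predict`) is a hypothetical stand-in, and it returns a plain tuple rather than a PropagationResult.

```python
import abc

import numpy as np


class Rule(abc.ABC):
    """Stand-in for AbstractIntegrator operating on NumPy arrays."""

    @abc.abstractmethod
    def integrate(self, fn, mean, cov):
        """Return (mean_y, cov_y, cross_cov) for y = fn(x), x ~ N(mean, cov)."""


class LinearizedRule(Rule):
    """First-order Taylor rule with a central finite-difference Jacobian."""

    def __init__(self, step=1e-6):
        self.step = step

    def integrate(self, fn, mean, cov):
        n = mean.shape[0]
        cols = []
        for i in range(n):
            e = np.zeros(n)
            e[i] = self.step
            cols.append((fn(mean + e) - fn(mean - e)) / (2 * self.step))
        J = np.stack(cols, axis=1)                        # (M, N)
        return fn(mean), J @ cov @ J.T, cov @ J.T


class CubatureRule(Rule):
    """Spherical-radial cubature: 2N equally weighted points."""

    def integrate(self, fn, mean, cov):
        n = mean.shape[0]
        S = np.sqrt(n) * np.linalg.cholesky(cov)
        chi = np.concatenate([mean + S.T, mean - S.T])    # (2N, N)
        Y = np.stack([fn(x) for x in chi])                # (2N, M)
        mean_y = Y.mean(axis=0)
        dY, dX = Y - mean_y, chi - mean
        return mean_y, dY.T @ dY / (2 * n), dX.T @ dY / (2 * n)


def predict(fn, mean, cov, rule):
    """Model code: depends only on the Rule interface."""
    return rule.integrate(fn, mean, cov)


A = np.array([[1.0, 2.0], [0.0, 1.0]])
mean = np.array([1.0, -1.0])
cov = np.array([[1.0, 0.3], [0.3, 2.0]])


def linear(x):
    return A @ x


out_taylor = predict(linear, mean, cov, LinearizedRule())
out_cubature = predict(linear, mean, cov, CubatureRule())
```

For a linear function both rules agree with the closed form; they differ only once `fn` is nonlinear, which is exactly when the choice of rule matters.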

GaussHermiteIntegrator

Bases: AbstractIntegrator

Gauss-Hermite quadrature integrator.

Approximates Gaussian expectations using tensor-product Gauss-Hermite quadrature:

\(\mathbb{E}[g(f)] \approx \sum_i w_i \, g(\mu + L z_i)\)

where \((z_i, w_i)\) are Gauss-Hermite points and weights in standard normal space and \(L\) is a square root of the covariance, \(\Sigma = L L^\top\).

Exact for polynomials of degree up to 2 * order - 1 in each dimension. Cost is O(order^dim) function evaluations, so the rule is practical for dim <= ~5.

Attributes:

Name Type Description
order int

Number of quadrature points per dimension. Default 20.

Source code in src/gaussx/_quadrature/_gauss_hermite.py
class GaussHermiteIntegrator(AbstractIntegrator):
    r"""Gauss-Hermite quadrature integrator.

    Approximates Gaussian expectations using tensor-product Gauss-Hermite
    quadrature:

        E[g(f)] \approx \sum_i w_i \cdot g(\mu + L z_i)

    where ``(z_i, w_i)`` are GH points/weights in standard normal space
    and ``L`` is the square root of the covariance.

    Exact for polynomials up to degree ``2 * order - 1``.
    Complexity: ``O(order^dim)``, practical for ``dim <= ~5``.

    Attributes:
        order: Number of quadrature points per dimension. Default ``20``.
    """

    order: int = eqx.field(static=True, default=20)

    def integrate(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> PropagationResult:
        """Propagate Gaussian via Gauss-Hermite quadrature."""
        from gaussx._primitives._sqrt import sqrt
        from gaussx._quadrature._quadrature import gauss_hermite_points

        mu = state.mean
        N = mu.shape[0]

        # GH points in standard normal space
        z, w = gauss_hermite_points(self.order, N)

        # Transform to input space: xᵢ = μ + S zᵢ
        S = sqrt(state.cov).as_matrix()
        chi = mu[None, :] + z @ S.T  # (P, N)

        # Normalize weights to sum to 1
        w = w / jnp.sum(w)

        # Propagate all quadrature points
        Y = jax.vmap(fn)(chi)  # (P, M)

        return assemble_propagation_result(chi, Y, mu, w)

integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult

Propagate Gaussian via Gauss-Hermite quadrature.
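
The exactness claim (degree up to 2 * order - 1) can be checked in one dimension. This sketch calls NumPy's `hermegauss` directly instead of going through gaussx; `gh_expectation` is an illustrative helper, not a library function.

```python
import numpy as np


def gh_expectation(g, mu, sigma, order):
    """E[g(x)] for x ~ N(mu, sigma^2) with an `order`-point Gauss-Hermite rule."""
    z, w = np.polynomial.hermite_e.hermegauss(order)  # weight function exp(-z^2 / 2)
    w = w / w.sum()                                   # raw weights sum to sqrt(2 * pi)
    return np.sum(w * g(mu + sigma * z))


# x^4 has degree 4 <= 2 * 3 - 1, so three points reproduce E[x^4] = 3.
exact_order3 = gh_expectation(lambda x: x**4, 0.0, 1.0, order=3)

# Two points are exact only up to degree 3, so the same integrand is off.
inexact_order2 = gh_expectation(lambda x: x**4, 0.0, 1.0, order=2)
```

The two-point rule has nodes at ±1 with equal weights, so it returns 1 instead of 3.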

UnscentedIntegrator

Bases: AbstractIntegrator

Unscented transform: deterministic sigma points.

Generates 2N+1 sigma points around the mean, propagates them through the nonlinear function, and reconstructs output moments:

chi_i = mu + sqrt((N + lambda) * Sigma) @ xi_i
y_i = f(chi_i)
mu_y = sum(w_m * y_i)
Sigma_y = sum(w_c * (y_i - mu_y)(y_i - mu_y)^T)
cross_cov = sum(w_c * (chi_i - mu)(y_i - mu_y)^T)

where lambda = alpha^2 * (N + kappa) - N.

Attributes:

Name Type Description
alpha float

Spread parameter. Default 1e-3.

beta float

Prior knowledge parameter; 2.0 is optimal for Gaussian inputs. Default 2.0.

kappa float

Secondary scaling. Default 0.0.

Source code in src/gaussx/_quadrature/_unscented.py
class UnscentedIntegrator(AbstractIntegrator):
    r"""Unscented transform: deterministic sigma points.

    Generates ``2N+1`` sigma points around the mean, propagates them
    through the nonlinear function, and reconstructs output moments:

        chi_i = mu + sqrt((N + lambda) * Sigma) @ xi_i
        y_i = f(chi_i)
        mu_y = sum(w_m * y_i)
        Sigma_y = sum(w_c * (y_i - mu_y)(y_i - mu_y)^T)
        cross_cov = sum(w_c * (chi_i - mu)(y_i - mu_y)^T)

    where ``lambda = alpha^2 * (N + kappa) - N``.

    Attributes:
        alpha: Spread parameter. Default ``1e-3``.
        beta: Prior knowledge parameter (2.0 optimal for Gaussian).
        kappa: Secondary scaling. Default ``0.0``.
    """

    alpha: float = eqx.field(static=True, default=1e-3)
    beta: float = eqx.field(static=True, default=2.0)
    kappa: float = eqx.field(static=True, default=0.0)

    def integrate(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> PropagationResult:
        """Propagate Gaussian via unscented transform."""
        chi, w_m, w_c = sigma_points(
            state.mean,
            state.cov,
            alpha=self.alpha,
            beta=self.beta,
            kappa=self.kappa,
        )
        Y = jax.vmap(fn)(chi)
        return assemble_propagation_result(chi, Y, state.mean, w_m, w_c)

integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult

Propagate Gaussian via unscented transform.
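
The formulas in the class docstring can be followed end to end in a NumPy sketch. It is an independent re-implementation for illustration: it assumes a Cholesky factor as the square root, whereas the library dispatches through `gaussx.sqrt`, and `unscented_transform` is a hypothetical name.

```python
import numpy as np


def unscented_transform(fn, mu, Sigma, alpha=1e-3, beta=2.0, kappa=0.0):
    """Scaled unscented transform; returns (mu_y, Sigma_y, cross_cov)."""
    n = mu.shape[0]
    lam = alpha**2 * (n + kappa) - n
    c = n + lam

    # Columns of the scaled factor are the sigma-point offsets.
    S = np.sqrt(c) * np.linalg.cholesky(Sigma)
    chi = np.concatenate([mu[None, :], mu + S.T, mu - S.T])   # (2N+1, N)

    w_m = np.full(2 * n + 1, 1.0 / (2.0 * c))
    w_c = w_m.copy()
    w_m[0] = lam / c
    w_c[0] = lam / c + (1.0 - alpha**2 + beta)

    Y = np.stack([fn(x) for x in chi])                         # (2N+1, M)
    mu_y = w_m @ Y
    dY, dX = Y - mu_y, chi - mu
    Sigma_y = (w_c[:, None] * dY).T @ dY
    cross_cov = (w_c[:, None] * dX).T @ dY                     # (N, M)
    return mu_y, Sigma_y, cross_cov


# A linear map is propagated exactly, which makes a convenient sanity check.
A = np.array([[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]])
mu = np.array([0.5, -1.0])
Sigma = np.array([[1.0, 0.2], [0.2, 0.5]])
mu_y, Sigma_y, cross_cov = unscented_transform(lambda x: A @ x, mu, Sigma)
```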

TaylorIntegrator

Bases: AbstractIntegrator

1st or 2nd order Taylor expansion for uncertainty propagation.

1st order (EKF):

mu_y = f(mu_x)
Sigma_y = J @ Sigma_x @ J^T
cross_cov = Sigma_x @ J^T

2nd order:

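mu_y[i] += 0.5 * tr(H_i @ Sigma_x)
Sigma_y[i, j] += 0.5 * tr(H_i @ Sigma_x @ H_j @ Sigma_x)   (only when correct_variance=True)

where H_i is the Hessian of the i-th output component, evaluated at mu_x.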

Attributes:

Name Type Description
order int

Taylor expansion order (1 or 2). Default 1.

correct_variance bool

If True and order=2, apply 2nd-order covariance correction using 4th Gaussian moments. Default True to preserve the historical order=2 behaviour. Set to False for the mean-only correction used in the standard EKF literature. Ignored when order=1.

Source code in src/gaussx/_quadrature/_taylor.py
class TaylorIntegrator(AbstractIntegrator):
    r"""1st or 2nd order Taylor expansion for uncertainty propagation.

    **1st order (EKF)**:

        mu_y = f(mu_x)
        Sigma_y = J @ Sigma_x @ J^T
        cross_cov = Sigma_x @ J^T

    **2nd order**:

        mu_y_i += 0.5 * tr(H_i @ Sigma_x)
        Sigma_y += correction from Hessians

    Attributes:
        order: Taylor expansion order (1 or 2). Default 1.
        correct_variance: If True and order=2, apply 2nd-order covariance
            correction using 4th Gaussian moments. Default True to preserve
            the historical ``order=2`` behaviour. Set to False for the
            mean-only correction used in the standard EKF literature.
            Ignored when order=1.
    """

    order: int = eqx.field(static=True, default=1)
    correct_variance: bool = eqx.field(static=True, default=True)

    def integrate(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> PropagationResult:
        """Propagate Gaussian via Taylor expansion."""
        if self.order not in (1, 2):
            msg = f"TaylorIntegrator.order must be 1 or 2, got {self.order}"
            raise ValueError(msg)
        mu = state.mean
        Sigma = state.cov.as_matrix()

        # Evaluate function and Jacobian at the mean
        f_mu = fn(mu)
        J = jax.jacobian(fn)(mu)  # (M, N)

        if self.order == 1:
            mu_y = f_mu
            Sigma_y = J @ Sigma @ J.T
        else:
            # 2nd order correction: mu_y_i += 0.5 * tr(H_i @ Sigma_x)
            H_fn = jax.hessian(fn)
            H = H_fn(mu)  # (M, N, N)
            corrections = jax.vmap(lambda H_i: jnp.trace(H_i @ Sigma))(H)
            mu_y = f_mu + 0.5 * corrections
            Sigma_y = J @ Sigma @ J.T
            if self.correct_variance:
                second_order_cov = jax.vmap(
                    lambda H_i: jax.vmap(
                        lambda H_j: 0.5 * jnp.trace(H_i @ Sigma @ H_j @ Sigma)
                    )(H)
                )(H)
                Sigma_y = Sigma_y + second_order_cov

        # Symmetrize for numerical stability
        Sigma_y = symmetrize(Sigma_y)

        # Cross-covariance: Sigma_x @ J^T
        cross_cov = Sigma @ J.T

        cov_y = lx.MatrixLinearOperator(Sigma_y, lx.positive_semidefinite_tag)
        out_state = GaussianState(mean=mu_y, cov=cov_y)

        return PropagationResult(state=out_state, cross_cov=cross_cov)

integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult

Propagate Gaussian via Taylor expansion.
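
For a quadratic function the second-order expansion is not an approximation, which gives a closed-form check of both corrections. The sketch below takes the Jacobian and Hessians as explicit arrays instead of using `jax.jacobian` / `jax.hessian`; `taylor_second_order` is an illustrative name.

```python
import numpy as np


def taylor_second_order(f_mu, J, H, Sigma, correct_variance=True):
    """Moments from f(mu), Jacobian J (M, N) and Hessians H (M, N, N)."""
    mu_y = f_mu + 0.5 * np.einsum("mij,ji->m", H, Sigma)       # 0.5 * tr(H_m Sigma)
    Sigma_y = J @ Sigma @ J.T
    if correct_variance:
        HS = H @ Sigma                                         # H_m @ Sigma per output
        # 0.5 * tr(H_a Sigma H_b Sigma) for every output pair (a, b)
        Sigma_y = Sigma_y + 0.5 * np.einsum("aij,bji->ab", HS, HS)
    return mu_y, Sigma_y, Sigma @ J.T


# f(x) = x^2 with x ~ N(1.5, 0.4): f(mu) = 2.25, f'(mu) = 3, f''(mu) = 2.
mu, var = 1.5, 0.4
mu_y, Sigma_y, cross_cov = taylor_second_order(
    f_mu=np.array([mu**2]),
    J=np.array([[2 * mu]]),
    H=np.array([[[2.0]]]),
    Sigma=np.array([[var]]),
)
```

The result matches the exact Gaussian moments of x^2: mean mu^2 + var and variance 4 mu^2 var + 2 var^2. With `correct_variance=False` the 2 var^2 term is dropped.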

MonteCarloIntegrator

Bases: AbstractIntegrator

Monte Carlo moment matching: sample, propagate, compute moments.

Propagates uncertainty by drawing samples from the input Gaussian, evaluating the function at each sample, and computing empirical output moments:

x_i ~ N(mu, Sigma)       (n_samples points)
y_i = f(x_i)
mu_y = mean(y_i)
Sigma_y = cov(y_i) + regularization * I
cross_cov = cov(x_i, y_i)

Attributes:

Name Type Description
n_samples int

Number of Monte Carlo samples. Default 1000.

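regularization float

Diagonal jitter added to the output covariance for numerical stability. Default 1e-6.

key Array | None

PRNG key. If None, uses jax.random.key(0). The key is reused on every call, so repeated calls with the same inputs draw identical samples.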

Source code in src/gaussx/_quadrature/_monte_carlo.py
class MonteCarloIntegrator(AbstractIntegrator):
    r"""Monte Carlo moment matching: sample, propagate, compute moments.

    Propagates uncertainty by drawing samples from the input Gaussian,
    evaluating the function at each sample, and computing empirical
    output moments:

        x_i ~ N(mu, Sigma)       (n_samples points)
        y_i = f(x_i)
        mu_y = mean(y_i)
        Sigma_y = cov(y_i) + regularization * I
        cross_cov = cov(x_i, y_i)

    Attributes:
        n_samples: Number of Monte Carlo samples. Default ``1000``.
        regularization: Diagonal jitter for numerical stability.
        key: PRNG key. If ``None``, uses ``jax.random.key(0)``.
    """

    n_samples: int = eqx.field(static=True, default=1000)
    regularization: float = eqx.field(static=True, default=1e-6)
    key: jax.Array | None = None

    def integrate(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> PropagationResult:
        """Propagate Gaussian via Monte Carlo sampling."""
        if self.n_samples < 2:
            msg = (
                f"MonteCarloIntegrator requires n_samples >= 2 for "
                f"Bessel-corrected covariance, got {self.n_samples}."
            )
            raise ValueError(msg)
        mu = state.mean
        N = mu.shape[0]

        key = self.key if self.key is not None else jr.key(0)

        # Sample from input Gaussian: xᵢ = μ + L εᵢ
        L = cholesky(state.cov).as_matrix()
        eps = jr.normal(key, (self.n_samples, N))
        chi = mu[None, :] + eps @ L.T  # (S, N)

        # Propagate samples
        Y = jax.vmap(fn)(chi)  # (S, M)

        # Uniform weights = 1/S for empirical moments
        # Use 1/(S−1) Bessel correction via covariance weights
        S = self.n_samples
        w_m = jnp.full(S, 1.0 / S)
        w_c = jnp.full(S, 1.0 / (S - 1))

        result = assemble_propagation_result(chi, Y, mu, w_m, w_c)

        # Add regularization jitter to output covariance
        Sigma_y = result.state.cov.as_matrix()
        M = Y.shape[1]
        import lineax as lx

        Sigma_y = Sigma_y + self.regularization * jnp.eye(M)
        cov_y = lx.MatrixLinearOperator(Sigma_y, lx.positive_semidefinite_tag)
        out_state = GaussianState(mean=result.state.mean, cov=cov_y)

        return PropagationResult(state=out_state, cross_cov=result.cross_cov)

integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult

Propagate Gaussian via Monte Carlo sampling.
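
The sample, propagate, and moment-match steps can be sketched in NumPy as follows. This is an illustration under assumptions rather than the library code: `mc_moment_match` and its `seed` argument are hypothetical, NumPy's generator replaces the JAX key, and the cross-covariance is centred on the known input mean.

```python
import numpy as np


def mc_moment_match(fn, mu, Sigma, n_samples=1000, regularization=1e-6, seed=0):
    """Monte Carlo moment matching; returns (mu_y, Sigma_y, cross_cov)."""
    rng = np.random.default_rng(seed)
    L = np.linalg.cholesky(Sigma)
    X = mu + rng.standard_normal((n_samples, mu.shape[0])) @ L.T   # (S, N)
    Y = np.stack([fn(x) for x in X])                               # (S, M)

    mu_y = Y.mean(axis=0)
    dY = Y - mu_y
    dX = X - mu                       # assumption: centre on the known input mean
    # Bessel-corrected covariance plus diagonal jitter
    Sigma_y = dY.T @ dY / (n_samples - 1) + regularization * np.eye(Y.shape[1])
    cross_cov = dX.T @ dY / (n_samples - 1)
    return mu_y, Sigma_y, cross_cov


# y = exp(x) with x ~ N(0, 0.25) is log-normal, so the exact moments are known.
mu = np.array([0.0])
Sigma = np.array([[0.25]])
mu_y, Sigma_y, cross_cov = mc_moment_match(np.exp, mu, Sigma, n_samples=20_000)

# A fixed seed plays the role of the fixed key: a second call repeats the draw.
again = mc_moment_match(np.exp, mu, Sigma, n_samples=20_000)
```

The estimates land near the log-normal mean exp(0.125) and variance exp(0.25) * (exp(0.25) - 1), with the usual O(1 / sqrt(n_samples)) error.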

AssumedDensityFilter

Bases: AbstractIntegrator

KL-optimal Gaussian projection via moment matching.

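Projects the (possibly non-Gaussian) output distribution onto the Gaussian family by matching first and second moments, which are estimated from Monte Carlo samples. Moment matching is equivalent to argmin_q KL(p(y) || q(y)) within the Gaussian family.

On top of plain moment matching, it offers optional diagnostics for detecting non-Gaussianity and adaptive regularization, which scales the base jitter by the mean output variance:

eps = eps_base * trace(Sigma_y) / n_dim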

Attributes:

Name Type Description
n_samples int

Number of Monte Carlo samples. Default 5000.

regularization float

Base regularization. Default 1e-6.

adaptive_regularization bool

Scale regularization by output variance. Default True.

key Array | None

PRNG key. If None, uses jax.random.key(0).

Source code in src/gaussx/_quadrature/_adf.py
class AssumedDensityFilter(AbstractIntegrator):
    r"""KL-optimal Gaussian projection via moment matching.

    Projects the (possibly non-Gaussian) output distribution onto the
    Gaussian family by matching first and second moments. Equivalent to
    ``argmin_q KL(p(y) || q(y))`` within the Gaussian family.

    Adds adaptive regularization and optional diagnostics for detecting
    non-Gaussianity:

        eps = eps_base * trace(Sigma_y) / n_dim

    Attributes:
        n_samples: Number of Monte Carlo samples. Default ``5000``.
        regularization: Base regularization. Default ``1e-6``.
        adaptive_regularization: Scale regularization by output
            variance. Default ``True``.
        key: PRNG key. If ``None``, uses ``jax.random.key(0)``.
    """

    n_samples: int = eqx.field(static=True, default=5000)
    regularization: float = eqx.field(static=True, default=1e-6)
    adaptive_regularization: bool = eqx.field(static=True, default=True)
    key: jax.Array | None = None

    def integrate(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> PropagationResult:
        """Propagate Gaussian via assumed density filtering."""
        result, _ = self._integrate_impl(fn, state)
        return result

    def integrate_with_diagnostics(
        self,
        fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
        state: GaussianState,
    ) -> tuple[PropagationResult, dict]:
        """Propagate Gaussian and return non-Gaussianity diagnostics.

        Args:
            fn: Nonlinear function mapping ``(N,) -> (M,)``.
            state: Input Gaussian distribution.

        Returns:
            Tuple ``(result, diagnostics)`` where diagnostics contains
            ``skewness``, ``kurtosis``, ``min_eigval``, and
            ``condition_number``.
        """
        return self._integrate_impl(fn, state, compute_diagnostics=True)

    def _integrate_impl(
        self,
        fn: Callable,
        state: GaussianState,
        compute_diagnostics: bool = False,
    ) -> tuple[PropagationResult, dict]:
        """Core implementation with optional diagnostics."""
        mu = state.mean
        N = mu.shape[0]

        key = self.key if self.key is not None else jr.key(0)

        # Sample from input Gaussian
        L = cholesky(state.cov).as_matrix()
        eps = jr.normal(key, (self.n_samples, N))
        x_samples = mu[None, :] + eps @ L.T

        # Propagate samples
        y_samples = jax.vmap(fn)(x_samples)

        # Moment matching (KL-optimal Gaussian projection)
        mu_y = jnp.mean(y_samples, axis=0)
        dy = y_samples - mu_y[None, :]
        Sigma_y = (dy.T @ dy) / (self.n_samples - 1)
        M = mu_y.shape[0]

        # Adaptive regularization
        if self.adaptive_regularization:
            eps_reg = self.regularization * jnp.trace(Sigma_y) / M
        else:
            eps_reg = self.regularization
        Sigma_y = Sigma_y + eps_reg * jnp.eye(M)
        Sigma_y = symmetrize(Sigma_y)

        # Cross-covariance
        dx = x_samples - mu[None, :]
        cross_cov = (dx.T @ dy) / (self.n_samples - 1)

        cov_y = lx.MatrixLinearOperator(Sigma_y, lx.positive_semidefinite_tag)
        out_state = GaussianState(mean=mu_y, cov=cov_y)
        result = PropagationResult(state=out_state, cross_cov=cross_cov)

        diagnostics: dict = {}
        if compute_diagnostics:
            # Compute non-Gaussianity diagnostics via the gaussx
            # eigvals primitive (dispatches on operator structure).
            eigvals = _gaussx_eigvals(cov_y)
            min_eigval = jnp.min(eigvals)
            max_eigval = jnp.max(eigvals)
            cond = max_eigval / jnp.maximum(min_eigval, 1e-30)

            # Per-dimension skewness and kurtosis
            std_dy = dy / jnp.sqrt(jnp.diag(Sigma_y))[None, :]
            skewness = jnp.mean(std_dy**3, axis=0)
            kurtosis = jnp.mean(std_dy**4, axis=0)

            diagnostics = {
                "skewness": skewness,
                "kurtosis": kurtosis,
                "min_eigval": min_eigval,
                "condition_number": cond,
            }

        return result, diagnostics

integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult

Propagate Gaussian via assumed density filtering.

integrate_with_diagnostics(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> tuple[PropagationResult, dict]

Propagate Gaussian and return non-Gaussianity diagnostics.

Parameters:

Name Type Description Default
fn Callable[[Float[Array, ' N']], Float[Array, ' M']]

Nonlinear function mapping (N,) -> (M,).

required
state GaussianState

Input Gaussian distribution.

required

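Returns:

Type Description
tuple[PropagationResult, dict]

Tuple (result, diagnostics) where diagnostics contains skewness, kurtosis, min_eigval, and condition_number. Skewness and kurtosis are reported per output dimension; kurtosis is the raw fourth standardized moment (3 for a Gaussian), not excess kurtosis.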
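
To show how the diagnostics read in practice, the sketch below recomputes them from raw samples in NumPy, including the adaptive jitter. `moment_diagnostics` is an illustrative helper that mirrors the arithmetic in the source listing above; it is not part of the gaussx API.

```python
import numpy as np


def moment_diagnostics(Y, regularization=1e-6, adaptive=True):
    """Moment-match samples Y (S, M) and report non-Gaussianity diagnostics."""
    S, M = Y.shape
    mu_y = Y.mean(axis=0)
    dY = Y - mu_y
    Sigma_y = dY.T @ dY / (S - 1)

    # Adaptive jitter: base value scaled by the average output variance.
    eps = regularization * np.trace(Sigma_y) / M if adaptive else regularization
    Sigma_y = Sigma_y + eps * np.eye(M)

    z = dY / np.sqrt(np.diag(Sigma_y))            # standardized residuals
    eigvals = np.linalg.eigvalsh(Sigma_y)
    return {
        "skewness": (z**3).mean(axis=0),          # about 0 for a Gaussian
        "kurtosis": (z**4).mean(axis=0),          # about 3 for a Gaussian
        "min_eigval": eigvals.min(),
        "condition_number": eigvals.max() / max(eigvals.min(), 1e-30),
    }


rng = np.random.default_rng(0)
x = rng.standard_normal((50_000, 1))
gaussian_like = moment_diagnostics(2.0 * x + 1.0)   # linear map keeps Gaussianity
skewed = moment_diagnostics(x**2)                   # chi-square output
```

A linear map leaves skewness near 0 and kurtosis near 3, while the squared output is strongly right-skewed and heavy-tailed, which signals that a Gaussian projection loses information.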

Quadrature rules

The raw point sets behind the integrators, for when a recipe needs direct control.

gauss_hermite_points(order: int, dim: int) -> tuple[Float[Array, 'P D'], Float[Array, ' P']]

Gauss-Hermite quadrature points and weights.

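Generates tensor-product Gauss-Hermite quadrature points for integrating functions against the weight exp(-|x|^2 / 2) (probabilists' Hermite polynomials). The returned weights are not normalized: they sum to (2π)^(dim/2), so divide by their sum to integrate against the standard Gaussian density, as GaussHermiteIntegrator does. P = order^dim total points.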

Parameters:

Name Type Description Default
order int

Number of quadrature points per dimension.

required
dim int

Dimensionality of the integration domain.

required

Returns:

Type Description
tuple[Float[Array, 'P D'], Float[Array, ' P']]

Tuple (points, weights) where points has shape (order^dim, dim) and weights has shape (order^dim,).

Source code in src/gaussx/_quadrature/_quadrature.py
def gauss_hermite_points(
    order: int,
    dim: int,
) -> tuple[Float[Array, "P D"], Float[Array, " P"]]:
    r"""Gauss-Hermite quadrature points and weights.

    Generates tensor-product Gauss-Hermite quadrature points for
    integrating functions against a standard Gaussian measure
    (probabilists' Hermite polynomials).
    ``P = order^dim`` total points.

    Args:
        order: Number of quadrature points per dimension.
        dim: Dimensionality of the integration domain.

    Returns:
        Tuple ``(points, weights)`` where points has shape
        ``(order^dim, dim)`` and weights has shape ``(order^dim,)``.
    """
    # 1D probabilists' Gauss-Hermite (weight = exp(-x^2/2))
    x1d_np, w1d_np = np.polynomial.hermite_e.hermegauss(order)
    x1d = jnp.array(x1d_np)
    w1d = jnp.array(w1d_np)

    if dim == 1:
        return x1d[:, None], w1d

    # Tensor product via meshgrid
    grids = jnp.meshgrid(*([x1d] * dim), indexing="ij")
    stacked = jnp.stack(grids, axis=0)  # (dim, *grid_shape)
    points = rearrange(stacked, "D ... -> (...) D")

    weight_grids = jnp.meshgrid(*([w1d] * dim), indexing="ij")
    weight_stack = jnp.stack(weight_grids, axis=0)  # (dim, *grid_shape)
    weights = reduce(weight_stack, "D ... -> (...)", "prod")

    return points, weights
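
The tensor-product construction can be reproduced with NumPy alone, which makes the point count and the weight normalization easy to inspect. `gh_tensor_points` is a hypothetical re-implementation using `np.meshgrid` and `ravel` in place of the einops calls.

```python
import numpy as np


def gh_tensor_points(order, dim):
    """Tensor-product probabilists' Gauss-Hermite rule: (P, dim) points, (P,) weights."""
    x1d, w1d = np.polynomial.hermite_e.hermegauss(order)
    grids = np.meshgrid(*([x1d] * dim), indexing="ij")
    points = np.stack([g.ravel() for g in grids], axis=-1)       # (order**dim, dim)
    w_grids = np.meshgrid(*([w1d] * dim), indexing="ij")
    weights = np.prod(np.stack([g.ravel() for g in w_grids], axis=-1), axis=-1)
    return points, weights


points, weights = gh_tensor_points(order=3, dim=2)
w = weights / weights.sum()          # normalize to a probability rule

# Under N(0, I): E[x0^2 * x1^2] = 1 and the second-moment matrix is the identity.
mixed_moment = np.sum(w * points[:, 0] ** 2 * points[:, 1] ** 2)
second_moment = (w[:, None] * points).T @ points
```

Here `points` has 3^2 = 9 rows and the raw weights sum to 2π, the integral of exp(-|x|^2 / 2) over the plane.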

cubature_points(mean: Float[Array, ' N'], cov: lx.AbstractLinearOperator) -> tuple[Float[Array, 'P N'], Float[Array, ' P']]

Spherical-radial cubature points and weights.

Generates 2N cubature points with equal weights 1/(2N). This is the cubature Kalman filter (CKF) point set.

Uses gaussx.sqrt(cov) for structured square root dispatch.

Parameters:

Name Type Description Default
mean Float[Array, ' N']

Mean vector, shape (N,).

required
cov AbstractLinearOperator

Covariance operator, shape (N, N).

required

Returns:

Type Description
tuple[Float[Array, 'P N'], Float[Array, ' P']]

Tuple (chi, weights) where:
  • chi: Cubature points, shape (2N, N).
  • weights: Equal weights, shape (2N,).

Source code in src/gaussx/_quadrature/_quadrature.py
def cubature_points(
    mean: Float[Array, " N"],
    cov: lx.AbstractLinearOperator,
) -> tuple[Float[Array, "P N"], Float[Array, " P"]]:
    r"""Spherical-radial cubature points and weights.

    Generates ``2N`` cubature points with equal weights ``1/(2N)``.
    This is the cubature Kalman filter (CKF) point set.

    Uses ``gaussx.sqrt(cov)`` for structured square root dispatch.

    Args:
        mean: Mean vector, shape ``(N,)``.
        cov: Covariance operator, shape ``(N, N)``.

    Returns:
        Tuple ``(chi, weights)`` where:
        - ``chi``: Cubature points, shape ``(2N, N)``.
        - ``weights``: Equal weights, shape ``(2N,)``.
    """
    N = mean.shape[0]

    S = sqrt(cov)
    S_mat = S.as_matrix()
    S_scaled = jnp.sqrt(N) * S_mat  # (N, N)

    chi_plus = mean[None, :] + S_scaled.T  # (N, N)
    chi_minus = mean[None, :] - S_scaled.T  # (N, N)
    chi = jnp.concatenate([chi_plus, chi_minus], axis=0)  # (2N, N)

    weights = jnp.full(2 * N, 1.0 / (2.0 * N))

    return chi, weights
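
A quick way to see what the point set encodes is to check that its weighted mean and covariance recover the inputs. The sketch assumes a Cholesky factor as the square root (the library dispatches through `gaussx.sqrt`), and `cubature_points_np` is an illustrative name.

```python
import numpy as np


def cubature_points_np(mean, cov):
    """2N spherical-radial cubature points with equal weights 1 / (2N)."""
    n = mean.shape[0]
    S = np.sqrt(n) * np.linalg.cholesky(cov)        # columns are the point offsets
    chi = np.concatenate([mean + S.T, mean - S.T])  # (2N, N)
    weights = np.full(2 * n, 1.0 / (2 * n))
    return chi, weights


mean = np.array([1.0, -2.0, 0.5])
cov = np.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 0.5]])
chi, w = cubature_points_np(mean, cov)

recovered_mean = w @ chi
d = chi - recovered_mean
recovered_cov = (w[:, None] * d).T @ d
```

Because the offsets come in ± pairs, the mean is recovered exactly, and the sqrt(N) scaling is what makes the equally weighted outer products sum back to the covariance.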

sigma_points(mean: Float[Array, ' N'], cov: lx.AbstractLinearOperator, alpha: float = 0.001, beta: float = 2.0, kappa: float = 0.0) -> tuple[Float[Array, 'P N'], Float[Array, ' P'], Float[Array, ' P']]

Unscented transform sigma points and weights.

Generates 2N+1 deterministic sigma points for a Gaussian with the given mean and covariance, using the scaled unscented transform.

Uses gaussx.sqrt(cov) for structured square root dispatch.

Parameters:

Name Type Description Default
mean Float[Array, ' N']

Mean vector, shape (N,).

required
cov AbstractLinearOperator

Covariance operator, shape (N, N).

required
alpha float

Spread parameter. Controls how far sigma points are from the mean. Default 1e-3.

0.001
beta float

Prior distribution parameter. beta=2 is optimal for Gaussians. Default 2.0.

2.0
kappa float

Secondary scaling parameter. Default 0.0.

0.0

Returns:

Type Description
tuple[Float[Array, 'P N'], Float[Array, ' P'], Float[Array, ' P']]

Tuple (chi, w_m, w_c) where:
  • chi: Sigma points, shape (2N+1, N).
  • w_m: Mean weights, shape (2N+1,).
  • w_c: Covariance weights, shape (2N+1,).

Source code in src/gaussx/_quadrature/_quadrature.py
def sigma_points(
    mean: Float[Array, " N"],
    cov: lx.AbstractLinearOperator,
    alpha: float = 1e-3,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> tuple[Float[Array, "P N"], Float[Array, " P"], Float[Array, " P"]]:
    r"""Unscented transform sigma points and weights.

    Generates ``2N+1`` deterministic sigma points for a Gaussian with
    the given mean and covariance, using the scaled unscented transform.

    Uses ``gaussx.sqrt(cov)`` for structured square root dispatch.

    Args:
        mean: Mean vector, shape ``(N,)``.
        cov: Covariance operator, shape ``(N, N)``.
        alpha: Spread parameter. Controls how far sigma points are
            from the mean. Default ``1e-3``.
        beta: Prior distribution parameter. ``beta=2`` is optimal for
            Gaussians. Default ``2.0``.
        kappa: Secondary scaling parameter. Default ``0.0``.

    Returns:
        Tuple ``(chi, w_m, w_c)`` where:
        - ``chi``: Sigma points, shape ``(2N+1, N)``.
        - ``w_m``: Mean weights, shape ``(2N+1,)``.
        - ``w_c``: Covariance weights, shape ``(2N+1,)``.
    """
    N = mean.shape[0]
    lam = alpha**2 * (N + kappa) - N
    c = N + lam

    # Matrix square root: S where cov = S S^T
    S = sqrt(cov)
    S_mat = S.as_matrix()  # (N, N)
    S_scaled = jnp.sqrt(c) * S_mat

    # Sigma points: chi_0 = mu, chi_i = mu + s_i, chi_{N+i} = mu - s_i,
    # where s_i is column i of S_scaled = sqrt(c) * S
    chi_0 = mean[None, :]  # (1, N)
    chi_plus = mean[None, :] + S_scaled.T  # (N, N) — each row is a point
    chi_minus = mean[None, :] - S_scaled.T  # (N, N)
    chi = jnp.concatenate([chi_0, chi_plus, chi_minus], axis=0)  # (2N+1, N)

    # Mean weights
    w_m_0 = lam / c
    w_m_i = 1.0 / (2.0 * c)
    w_m = jnp.concatenate(
        [
            jnp.array([w_m_0]),
            jnp.full(2 * N, w_m_i),
        ]
    )

    # Covariance weights
    w_c_0 = lam / c + (1.0 - alpha**2 + beta)
    w_c = jnp.concatenate(
        [
            jnp.array([w_c_0]),
            jnp.full(2 * N, w_m_i),
        ]
    )

    return chi, w_m, w_c
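
As a sanity check on the construction above, the following sketch re-implements the same formulas in plain NumPy and confirms that the weighted points reproduce the input mean and covariance. It does not call gaussx: the name `sigma_points_np` is illustrative, and a dense Cholesky factor stands in for `gaussx.sqrt(cov)`, which is an assumption about what the structured square root reduces to for a dense covariance.

```python
import numpy as np


def sigma_points_np(mean, cov, alpha=1e-3, beta=2.0, kappa=0.0):
    """Scaled unscented sigma points; dense stand-in for the listing above."""
    n = mean.shape[0]
    lam = alpha**2 * (n + kappa) - n
    c = n + lam
    # Assumption: a Cholesky factor L (cov = L L^T) plays the role of sqrt(cov).
    L = np.linalg.cholesky(cov)
    offsets = np.sqrt(c) * L.T  # row i is sqrt(c) times column i of L
    chi = np.concatenate([mean[None, :], mean + offsets, mean - offsets], axis=0)
    w_i = 1.0 / (2.0 * c)
    w_m = np.concatenate([[lam / c], np.full(2 * n, w_i)])
    w_c = np.concatenate([[lam / c + (1.0 - alpha**2 + beta)], np.full(2 * n, w_i)])
    return chi, w_m, w_c


mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.3], [0.3, 0.5]])
chi, w_m, w_c = sigma_points_np(mean, cov)

mean_hat = w_m @ chi                     # weighted mean of the points
dev = chi - mean
cov_hat = (w_c[:, None] * dev).T @ dev   # weighted sum of outer products
```

With `kappa = 0` the central mean weight equals `1 - 1/alpha**2`, roughly `-1e6` at the default `alpha`, so both sums rely on cancellation between very large weights. Double precision handles this comfortably; in single precision a larger `alpha` may be the safer choice.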

Likelihoods

Observation models with quadrature-friendly log_prob surfaces, shared by the expectation helpers and the SSM / CVI recipes.

AbstractLikelihood

Bases: Module

Base class for likelihood functions with optional analytical ELL.

Subclasses that support closed-form expected log-likelihood under a Gaussian variational distribution should override has_analytical_ell to return True and implement analytical_expected_log_likelihood.

Source code in src/gaussx/_quadrature/_likelihood.py
class AbstractLikelihood(eqx.Module):
    """Base class for likelihood functions with optional analytical ELL.

    Subclasses that support closed-form expected log-likelihood under
    a Gaussian variational distribution should override
    ``has_analytical_ell`` to return ``True`` and implement
    ``analytical_expected_log_likelihood``.
    """

    @abc.abstractmethod
    def log_prob(
        self,
        f: Float[Array, " N"],
    ) -> Float[Array, ""]:
        """Evaluate ``log p(y | f)`` for fixed observations.

        Args:
            f: Latent function values, shape ``(N,)``.

        Returns:
            Scalar log-likelihood.
        """
        ...

    def has_analytical_ell(self) -> bool:
        """Whether this likelihood supports closed-form ELL."""
        return False

    def analytical_expected_log_likelihood(
        self,
        q_mu: Float[Array, " N"],
        q_cov: lx.AbstractLinearOperator,
    ) -> Float[Array, ""]:
        """Closed-form ``E_q[log p(y | f)]`` where ``q = N(q_mu, q_cov)``.

        Args:
            q_mu: Variational mean, shape ``(N,)``.
            q_cov: Variational covariance operator, shape ``(N, N)``.

        Returns:
            Scalar expected log-likelihood.

        Raises:
            NotImplementedError: If no analytical form exists.
        """
        msg = f"{type(self).__name__} has no analytical ELL"
        raise NotImplementedError(msg)

log_prob(f: Float[Array, ' N']) -> Float[Array, ''] abstractmethod

Evaluate log p(y | f) for fixed observations.

Parameters:

Name Type Description Default
f Float[Array, ' N']

Latent function values, shape (N,).

required

Returns:

Type Description
Float[Array, '']

Scalar log-likelihood.

Source code in src/gaussx/_quadrature/_likelihood.py
@abc.abstractmethod
def log_prob(
    self,
    f: Float[Array, " N"],
) -> Float[Array, ""]:
    """Evaluate ``log p(y | f)`` for fixed observations.

    Args:
        f: Latent function values, shape ``(N,)``.

    Returns:
        Scalar log-likelihood.
    """
    ...

has_analytical_ell() -> bool

Whether this likelihood supports closed-form ELL.

Source code in src/gaussx/_quadrature/_likelihood.py
def has_analytical_ell(self) -> bool:
    """Whether this likelihood supports closed-form ELL."""
    return False

analytical_expected_log_likelihood(q_mu: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator) -> Float[Array, '']

Closed-form E_q[log p(y | f)] where q = N(q_mu, q_cov).

Parameters:

Name Type Description Default
q_mu Float[Array, ' N']

Variational mean, shape (N,).

required
q_cov AbstractLinearOperator

Variational covariance operator, shape (N, N).

required

Returns:

Type Description
Float[Array, '']

Scalar expected log-likelihood.

Raises:

Type Description
NotImplementedError

If no analytical form exists.

Source code in src/gaussx/_quadrature/_likelihood.py
def analytical_expected_log_likelihood(
    self,
    q_mu: Float[Array, " N"],
    q_cov: lx.AbstractLinearOperator,
) -> Float[Array, ""]:
    """Closed-form ``E_q[log p(y | f)]`` where ``q = N(q_mu, q_cov)``.

    Args:
        q_mu: Variational mean, shape ``(N,)``.
        q_cov: Variational covariance operator, shape ``(N, N)``.

    Returns:
        Scalar expected log-likelihood.

    Raises:
        NotImplementedError: If no analytical form exists.
    """
    msg = f"{type(self).__name__} has no analytical ELL"
    raise NotImplementedError(msg)

GaussianLikelihood

Bases: AbstractLikelihood

Gaussian likelihood log N(y | f, noise_var * I).

Supports closed-form expected log-likelihood:

\(\mathbb{E}_q[\log \mathcal{N}(y \mid f, \sigma^2 I)] = \log \mathcal{N}(y \mid q_\mu, \sigma^2 I) - \frac{1}{2\sigma^2}\,\mathrm{tr}(q_{\mathrm{cov}})\)

Attributes:

Name Type Description
y Float[Array, ' N']

Observed targets, shape (N,).

noise_var float

Observation noise variance (scalar).

Source code in src/gaussx/_quadrature/_likelihood.py
class GaussianLikelihood(AbstractLikelihood):
    r"""Gaussian likelihood ``log N(y | f, noise_var * I)``.

    Supports closed-form expected log-likelihood:

        E_q[log N(y | f, \sigma^2 I)]
            = log N(y | q_\mu, \sigma^2 I)
              - 0.5 / \sigma^2 \cdot tr(q_{cov})

    Attributes:
        y: Observed targets, shape ``(N,)``.
        noise_var: Observation noise variance (scalar).
    """

    y: Float[Array, " N"]
    noise_var: float

    def log_prob(
        self,
        f: Float[Array, " N"],
    ) -> Float[Array, ""]:
        """Evaluate ``log N(y | f, noise_var * I)``."""
        N = self.y.shape[-1]
        residual = self.y - f
        return -0.5 * (
            N * _LOG_2PI
            + N * jnp.log(self.noise_var)
            + jnp.sum(residual**2) / self.noise_var
        )

    def has_analytical_ell(self) -> bool:
        """Gaussian likelihood has closed-form ELL."""
        return True

    def analytical_expected_log_likelihood(
        self,
        q_mu: Float[Array, " N"],
        q_cov: lx.AbstractLinearOperator,
    ) -> Float[Array, ""]:
        r"""Closed-form ``E_q[log N(y | f, \sigma^2 I)]``.

        Uses:

            E_q[log N(y|f,R)] = log N(y | q_mu, R) - 0.5 tr(R^{-1} q_cov)

        where ``R = noise_var * I``. Delegates the log-density term to
        `gaussx.gaussian_log_prob` (which exploits the diagonal
        noise structure) and computes the trace correction directly via
        the structural ``trace(q_cov) / noise_var`` shortcut, so
        Kronecker/BlockDiag-structured ``q_cov`` keeps its O(n)
        ``prod(trace_factor)`` / per-block ``trace`` fast paths instead
        of materializing through ``trace_product(R^{-1}, q_cov)``.
        """
        from gaussx._distributions._gaussian import gaussian_log_prob
        from gaussx._primitives._trace import trace

        N = self.y.shape[-1]
        noise = lx.DiagonalLinearOperator(jnp.full(N, self.noise_var))
        log_pdf = gaussian_log_prob(q_mu, noise, self.y)
        # Structural fast path: tr(R^{-1} q_cov) = trace(q_cov) / noise_var
        # for scalar isotropic noise. ``trace`` dispatches on operator
        # structure (Kronecker, BlockDiag, …) via gaussx primitives.
        tr_term = trace(q_cov) / self.noise_var
        return log_pdf - 0.5 * tr_term

log_prob(f: Float[Array, ' N']) -> Float[Array, '']

Evaluate log N(y | f, noise_var * I).

Source code in src/gaussx/_quadrature/_likelihood.py
def log_prob(
    self,
    f: Float[Array, " N"],
) -> Float[Array, ""]:
    """Evaluate ``log N(y | f, noise_var * I)``."""
    N = self.y.shape[-1]
    residual = self.y - f
    return -0.5 * (
        N * _LOG_2PI
        + N * jnp.log(self.noise_var)
        + jnp.sum(residual**2) / self.noise_var
    )

has_analytical_ell() -> bool

Gaussian likelihood has closed-form ELL.

Source code in src/gaussx/_quadrature/_likelihood.py
def has_analytical_ell(self) -> bool:
    """Gaussian likelihood has closed-form ELL."""
    return True

analytical_expected_log_likelihood(q_mu: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator) -> Float[Array, '']

Closed-form E_q[log N(y | f, \sigma^2 I)].

Uses:

\(\mathbb{E}_q[\log \mathcal{N}(y \mid f, R)] = \log \mathcal{N}(y \mid q_\mu, R) - \tfrac{1}{2}\,\mathrm{tr}(R^{-1} q_{\mathrm{cov}})\)

where R = noise_var * I. The log-density term is delegated to gaussx.gaussian_log_prob, which exploits the diagonal noise structure. The trace correction uses the shortcut tr(R^{-1} q_cov) = trace(q_cov) / noise_var, valid because the noise is isotropic, so a Kronecker- or BlockDiag-structured q_cov keeps its O(n) fast paths (prod(trace_factor) and per-block trace) instead of being materialized through trace_product(R^{-1}, q_cov).

Source code in src/gaussx/_quadrature/_likelihood.py
def analytical_expected_log_likelihood(
    self,
    q_mu: Float[Array, " N"],
    q_cov: lx.AbstractLinearOperator,
) -> Float[Array, ""]:
    r"""Closed-form ``E_q[log N(y | f, \sigma^2 I)]``.

    Uses:

        E_q[log N(y|f,R)] = log N(y | q_mu, R) - 0.5 tr(R^{-1} q_cov)

    where ``R = noise_var * I``. Delegates the log-density term to
    `gaussx.gaussian_log_prob` (which exploits the diagonal
    noise structure) and computes the trace correction directly via
    the structural ``trace(q_cov) / noise_var`` shortcut, so
    Kronecker/BlockDiag-structured ``q_cov`` keeps its O(n)
    ``prod(trace_factor)`` / per-block ``trace`` fast paths instead
    of materializing through ``trace_product(R^{-1}, q_cov)``.
    """
    from gaussx._distributions._gaussian import gaussian_log_prob
    from gaussx._primitives._trace import trace

    N = self.y.shape[-1]
    noise = lx.DiagonalLinearOperator(jnp.full(N, self.noise_var))
    log_pdf = gaussian_log_prob(q_mu, noise, self.y)
    # Structural fast path: tr(R^{-1} q_cov) = trace(q_cov) / noise_var
    # for scalar isotropic noise. ``trace`` dispatches on operator
    # structure (Kronecker, BlockDiag, …) via gaussx primitives.
    tr_term = trace(q_cov) / self.noise_var
    return log_pdf - 0.5 * tr_term
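
The closed form can be checked in one dimension without any library support. The sketch below is a standalone illustration, not gaussx code, and all function names in it are hypothetical. It compares the formula against a three-point Gauss-Hermite rule, which integrates polynomials up to degree 5 exactly and is therefore exact for the quadratic log-density.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def gauss_log_pdf(y, f, noise_var):
    """log N(y | f, noise_var) for scalars."""
    return -0.5 * (LOG_2PI + math.log(noise_var) + (y - f) ** 2 / noise_var)


def analytical_ell(y, q_mu, q_var, noise_var):
    """Closed form: log N(y | q_mu, noise_var) - 0.5 * q_var / noise_var."""
    return gauss_log_pdf(y, q_mu, noise_var) - 0.5 * q_var / noise_var


def quadrature_ell(y, q_mu, q_var, noise_var):
    """Three-point Gauss-Hermite rule for a standard normal variable."""
    nodes = (-math.sqrt(3.0), 0.0, math.sqrt(3.0))
    weights = (1.0 / 6.0, 2.0 / 3.0, 1.0 / 6.0)
    std = math.sqrt(q_var)
    return sum(
        w * gauss_log_pdf(y, q_mu + std * z, noise_var)
        for z, w in zip(nodes, weights)
    )


closed = analytical_ell(y=0.7, q_mu=0.2, q_var=0.4, noise_var=0.1)
numeric = quadrature_ell(y=0.7, q_mu=0.2, q_var=0.4, noise_var=0.1)
```

The two values agree to rounding error, which is why the analytical path can skip the integrator entirely for this likelihood.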

HeteroscedasticGaussianLikelihood

Bases: AbstractLikelihood

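Heteroscedastic Gaussian likelihood with input-dependent noise. log_prob expects a stacked latent vector of shape (2N,): the first N entries are the latent means, and the last N entries are mapped through a softplus to give the per-observation noise standard deviation.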

Attributes:

Name Type Description
y Float[Array, ' N']

Observations, shape (N,).

Source code in src/gaussx/_quadrature/_likelihoods.py
class HeteroscedasticGaussianLikelihood(AbstractLikelihood):
    r"""Heteroscedastic Gaussian likelihood with input-dependent noise.

    Attributes:
        y: Observations, shape ``(N,)``.
    """

    y: Float[Array, " N"]
    latent_dim: int = eqx.field(static=True, default=2)

    def log_prob(self, f: Float[Array, " 2N"]) -> Float[Array, ""]:
        """Evaluate heteroscedastic Gaussian log-likelihood."""
        N = self.y.shape[0]
        f_mean = f[:N]
        f_noise = f[N:]
        noise_std = jax.nn.softplus(f_noise)
        noise_var = noise_std**2

        log_2pi = jnp.log(2.0 * jnp.pi)
        residual = self.y - f_mean
        return jnp.sum(-0.5 * (log_2pi + jnp.log(noise_var) + residual**2 / noise_var))

log_prob(f: Float[Array, ' 2N']) -> Float[Array, '']

Evaluate heteroscedastic Gaussian log-likelihood.

Source code in src/gaussx/_quadrature/_likelihoods.py
def log_prob(self, f: Float[Array, " 2N"]) -> Float[Array, ""]:
    """Evaluate heteroscedastic Gaussian log-likelihood."""
    N = self.y.shape[0]
    f_mean = f[:N]
    f_noise = f[N:]
    noise_std = jax.nn.softplus(f_noise)
    noise_var = noise_std**2

    log_2pi = jnp.log(2.0 * jnp.pi)
    residual = self.y - f_mean
    return jnp.sum(-0.5 * (log_2pi + jnp.log(noise_var) + residual**2 / noise_var))
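
The stacked layout of `f` is easy to get wrong, so here is a small NumPy mirror of the listing with the two halves built explicitly. It is a sketch for illustration only: `hetero_log_prob` is a hypothetical name, and `np.logaddexp(0, x)` is used as the softplus.

```python
import numpy as np


def hetero_log_prob(y, f):
    """NumPy mirror of the listing: f stacks [mean latents, noise latents]."""
    n = y.shape[0]
    f_mean, f_noise = f[:n], f[n:]
    noise_std = np.logaddexp(0.0, f_noise)  # softplus(x) = log(1 + exp(x))
    noise_var = noise_std**2
    resid = y - f_mean
    return np.sum(
        -0.5 * (np.log(2.0 * np.pi) + np.log(noise_var) + resid**2 / noise_var)
    )


y = np.array([0.5, -1.0, 2.0])
f_mean = np.array([0.4, -0.8, 1.5])
f_noise = np.array([-1.0, 0.0, 1.0])    # pre-softplus noise latents
f = np.concatenate([f_mean, f_noise])   # shape (2N,) = (6,)
value = hetero_log_prob(y, f)
```

Swapping the two halves, or interleaving them per observation, would silently evaluate a different model, so it is worth constructing `f` with an explicit concatenation as above.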

BernoulliLikelihood

Bases: AbstractLikelihood

Bernoulli likelihood with logit link.

Attributes:

Name Type Description
y Float[Array, ' N']

Binary observations, shape (N,).

Source code in src/gaussx/_quadrature/_likelihoods.py
class BernoulliLikelihood(AbstractLikelihood):
    r"""Bernoulli likelihood with logit link.

    Attributes:
        y: Binary observations, shape ``(N,)``.
    """

    y: Float[Array, " N"]

    def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
        """Evaluate Bernoulli log-likelihood with logit link."""
        return jnp.sum(
            self.y * jax.nn.log_sigmoid(f) + (1.0 - self.y) * jax.nn.log_sigmoid(-f)
        )

log_prob(f: Float[Array, ' N']) -> Float[Array, '']

Evaluate Bernoulli log-likelihood with logit link.

Source code in src/gaussx/_quadrature/_likelihoods.py
def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
    """Evaluate Bernoulli log-likelihood with logit link."""
    return jnp.sum(
        self.y * jax.nn.log_sigmoid(f) + (1.0 - self.y) * jax.nn.log_sigmoid(-f)
    )
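
The listing calls `jax.nn.log_sigmoid` rather than taking the logarithm of a sigmoid. The standard-library sketch below illustrates why that distinction matters for extreme logits. The helper names are hypothetical, and the piecewise formula is one common stable formulation, not necessarily the one JAX uses internally.

```python
import math


def log_sigmoid(x):
    """Stable log(sigmoid(x)): never exponentiates a large positive number."""
    if x >= 0.0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def bernoulli_log_prob(y, f):
    """Sum of y * log sigmoid(f) + (1 - y) * log sigmoid(-f), as in the listing."""
    return sum(
        yi * log_sigmoid(fi) + (1.0 - yi) * log_sigmoid(-fi)
        for yi, fi in zip(y, f)
    )


y = [1.0, 0.0, 1.0]
f = [2.0, -1.0, -800.0]  # the last logit breaks the naive formula
value = bernoulli_log_prob(y, f)

try:
    math.log(1.0 / (1.0 + math.exp(800.0)))  # naive log(sigmoid(-800))
    naive_ok = True
except OverflowError:
    naive_ok = False
```

The stable version returns a finite value dominated by the `-800` term, while the naive expression overflows before the logarithm is ever taken.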

PoissonLikelihood

Bases: AbstractLikelihood

Poisson likelihood with log link.

Attributes:

Name Type Description
y Float[Array, ' N']

Count observations, shape (N,).

Source code in src/gaussx/_quadrature/_likelihoods.py
class PoissonLikelihood(AbstractLikelihood):
    r"""Poisson likelihood with log link.

    Attributes:
        y: Count observations, shape ``(N,)``.
    """

    y: Float[Array, " N"]

    def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
        """Evaluate Poisson log-likelihood with log link."""
        return jnp.sum(self.y * f - jnp.exp(f) - jax.scipy.special.gammaln(self.y + 1))

log_prob(f: Float[Array, ' N']) -> Float[Array, '']

Evaluate Poisson log-likelihood with log link.

Source code in src/gaussx/_quadrature/_likelihoods.py
def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
    """Evaluate Poisson log-likelihood with log link."""
    return jnp.sum(self.y * f - jnp.exp(f) - jax.scipy.special.gammaln(self.y + 1))
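
To make the log link concrete, the sketch below evaluates the same expression with the standard library and compares it with the textbook Poisson mass function at rate `exp(f)`. Function names are illustrative only.

```python
import math


def poisson_log_prob(y, f):
    """Sum of y * f - exp(f) - log(y!), with rate exp(f) (log link)."""
    return sum(
        yi * fi - math.exp(fi) - math.lgamma(yi + 1.0) for yi, fi in zip(y, f)
    )


def poisson_log_pmf_direct(k, rate):
    """Textbook log of rate**k * exp(-rate) / k! for integer k."""
    return math.log(rate**k * math.exp(-rate) / math.factorial(k))


y = [0, 3, 7]
f = [-0.5, 1.0, 2.0]
value = poisson_log_prob(y, f)
direct = sum(poisson_log_pmf_direct(k, math.exp(fi)) for k, fi in zip(y, f))
```

Working with `y * f` and `gammaln(y + 1)` avoids forming `rate**k` and `k!` explicitly, both of which overflow for large counts.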

SoftmaxLikelihood

Bases: AbstractLikelihood

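Softmax (categorical) likelihood for multi-class classification. log_prob expects the latent vector flattened to shape (N*C,), with the C class logits of each data point stored contiguously.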

Parameters:

Name Type Description Default
y Int[Array, ' N']

Integer class labels, shape (N,).

required
num_classes int

Number of classes C.

required
Source code in src/gaussx/_quadrature/_likelihoods.py
class SoftmaxLikelihood(AbstractLikelihood):
    r"""Softmax (categorical) likelihood for multi-class classification.

    Args:
        y: Integer class labels, shape ``(N,)``.
        num_classes: Number of classes C.
    """

    y: Int[Array, " N"]
    num_classes: int = eqx.field(static=True)
    latent_dim: int = eqx.field(static=True, default=1)

    def __init__(self, y: Int[Array, " N"], num_classes: int):
        self.y = y
        self.num_classes = num_classes
        self.latent_dim = num_classes

    def log_prob(self, f: Float[Array, " NC"]) -> Float[Array, ""]:
        """Evaluate softmax log-likelihood."""
        f_2d = rearrange(f, "(N C) -> N C", C=self.num_classes)
        log_probs = jax.nn.log_softmax(f_2d, axis=-1)
        return jnp.sum(log_probs[jnp.arange(self.y.shape[0]), self.y])

log_prob(f: Float[Array, ' NC']) -> Float[Array, '']

Evaluate softmax log-likelihood.

Source code in src/gaussx/_quadrature/_likelihoods.py
def log_prob(self, f: Float[Array, " NC"]) -> Float[Array, ""]:
    """Evaluate softmax log-likelihood."""
    f_2d = rearrange(f, "(N C) -> N C", C=self.num_classes)
    log_probs = jax.nn.log_softmax(f_2d, axis=-1)
    return jnp.sum(log_probs[jnp.arange(self.y.shape[0]), self.y])
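
The `"(N C) -> N C"` pattern means the flat latent vector is point-major: all C logits of the first data point come first. The NumPy sketch below mirrors the listing under that reading. `softmax_log_prob` is a hypothetical helper, and `reshape(-1, C)` plays the role of the einops `rearrange` call.

```python
import numpy as np


def softmax_log_prob(y, f_flat, num_classes):
    """NumPy mirror of the listing; f_flat is read as rows of C logits."""
    f_2d = f_flat.reshape(-1, num_classes)             # "(N C) -> N C"
    shifted = f_2d - f_2d.max(axis=-1, keepdims=True)  # stabilise the logsumexp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(y.shape[0]), y].sum()


logits = np.array([[2.0, 0.0, -1.0],   # point 0, classes 0..2
                   [0.5, 0.5, 3.0]])   # point 1, classes 0..2
y = np.array([0, 2])
f_flat = logits.reshape(-1)            # point-major flattening, shape (6,)
value = softmax_log_prob(y, f_flat, num_classes=3)
```

A class-major flattening (`logits.T.reshape(-1)`) has the same length and would run without error, but it pairs logits with the wrong data points.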

StudentTLikelihood

Bases: AbstractLikelihood

Student-t likelihood for robust regression.

Attributes:

Name Type Description
y Float[Array, ' N']

Observations, shape (N,).

df float

Degrees of freedom (> 0).

scale float

Scale parameter (> 0).

Source code in src/gaussx/_quadrature/_likelihoods.py
class StudentTLikelihood(AbstractLikelihood):
    r"""Student-t likelihood for robust regression.

    Attributes:
        y: Observations, shape ``(N,)``.
        df: Degrees of freedom (> 0).
        scale: Scale parameter (> 0).
    """

    y: Float[Array, " N"]
    df: float
    scale: float

    def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
        """Evaluate Student-t log-likelihood."""
        df = self.df
        scale = self.scale
        residual = self.y - f
        half_df = 0.5 * df
        half_dfp1 = 0.5 * (df + 1.0)

        log_norm = (
            jax.scipy.special.gammaln(half_dfp1)
            - jax.scipy.special.gammaln(half_df)
            - 0.5 * jnp.log(df * jnp.pi * scale**2)
        )
        log_kernel = -half_dfp1 * jnp.log1p(residual**2 / (df * scale**2))
        return jnp.sum(log_norm + log_kernel)

log_prob(f: Float[Array, ' N']) -> Float[Array, '']

Evaluate Student-t log-likelihood.

Source code in src/gaussx/_quadrature/_likelihoods.py
def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
    """Evaluate Student-t log-likelihood."""
    df = self.df
    scale = self.scale
    residual = self.y - f
    half_df = 0.5 * df
    half_dfp1 = 0.5 * (df + 1.0)

    log_norm = (
        jax.scipy.special.gammaln(half_dfp1)
        - jax.scipy.special.gammaln(half_df)
        - 0.5 * jnp.log(df * jnp.pi * scale**2)
    )
    log_kernel = -half_dfp1 * jnp.log1p(residual**2 / (df * scale**2))
    return jnp.sum(log_norm + log_kernel)
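
One way to validate the normalising constant is the special case `df = 1`, where the Student-t density reduces to a Cauchy density with a simple closed form. The sketch below performs that comparison with the standard library. Both function names are hypothetical.

```python
import math


def student_t_log_prob(y, f, df, scale):
    """Scalar-loop mirror of the listing."""
    half_dfp1 = 0.5 * (df + 1.0)
    log_norm = (
        math.lgamma(half_dfp1)
        - math.lgamma(0.5 * df)
        - 0.5 * math.log(df * math.pi * scale**2)
    )
    return sum(
        log_norm - half_dfp1 * math.log1p((yi - fi) ** 2 / (df * scale**2))
        for yi, fi in zip(y, f)
    )


def cauchy_log_pdf(y, loc, scale):
    """log of 1 / (pi * scale * (1 + ((y - loc) / scale)**2))."""
    return -math.log(math.pi * scale * (1.0 + ((y - loc) / scale) ** 2))


y = [0.3, -4.0, 25.0]  # the last point is a gross outlier
f = [0.0, 0.5, 1.0]
t_value = student_t_log_prob(y, f, df=1.0, scale=2.0)
cauchy_value = sum(cauchy_log_pdf(yi, fi, 2.0) for yi, fi in zip(y, f))
```

The outlier contributes a penalty that grows only logarithmically in the residual, which is the property that makes this likelihood useful for robust regression.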

Expectations & EP moments

Expected log-likelihoods (the ELL term of every ELBO), generic mean / gradient / cost expectations, and the tilted-moment matching at the heart of expectation propagation.

elbo(likelihood: AbstractLikelihood, state: GaussianState, kl: Float[Array, ''], integrator: AbstractIntegrator | None = None) -> Float[Array, '']

Evidence lower bound (ELBO).

Computes:

ELBO = E_q[log p(y | f)] - KL(q || p)

Dispatches to analytical expected log-likelihood when available (e.g. Gaussian likelihood), or uses numerical integration.

Parameters:

Name Type Description Default
likelihood AbstractLikelihood

Likelihood object with log_prob method.

required
state GaussianState

Variational Gaussian distribution q(f).

required
kl Float[Array, '']

KL divergence KL(q || p) (scalar, precomputed).

required
integrator AbstractIntegrator | None

Integration method for non-conjugate likelihoods. Ignored when the likelihood has an analytical fast path.

None

Returns:

Type Description
Float[Array, '']

Scalar ELBO value.

Source code in src/gaussx/_quadrature/_expectations.py
def elbo(
    likelihood: AbstractLikelihood,
    state: GaussianState,
    kl: Float[Array, ""],
    integrator: AbstractIntegrator | None = None,
) -> Float[Array, ""]:
    r"""Evidence lower bound (ELBO).

    Computes:

        ELBO = E_q[log p(y | f)] - KL(q || p)

    Dispatches to analytical expected log-likelihood when available
    (e.g. Gaussian likelihood), or uses numerical integration.

    Args:
        likelihood: Likelihood object with ``log_prob`` method.
        state: Variational Gaussian distribution ``q(f)``.
        kl: KL divergence ``KL(q || p)`` (scalar, precomputed).
        integrator: Integration method for non-conjugate likelihoods.
            Ignored when the likelihood has an analytical fast path.

    Returns:
        Scalar ELBO value.
    """
    ell = expected_log_likelihood(likelihood, state, integrator)
    return ell - kl
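
A scalar conjugate model gives a useful end-to-end check of the decomposition, because the ELBO must equal the log evidence when q is the exact posterior and fall below it otherwise. The following sketch works this out with the standard library. It is independent of gaussx, and the function names and the zero-mean prior are assumptions made for the example.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def normal_log_pdf(x, mean, var):
    return -0.5 * (LOG_2PI + math.log(var) + (x - mean) ** 2 / var)


def elbo_1d(y, q_mean, q_var, prior_var, noise_var):
    """E_q[log N(y | f, noise_var)] - KL(N(q_mean, q_var) || N(0, prior_var))."""
    ell = normal_log_pdf(y, q_mean, noise_var) - 0.5 * q_var / noise_var
    kl = 0.5 * (
        q_var / prior_var
        + q_mean**2 / prior_var
        - 1.0
        + math.log(prior_var / q_var)
    )
    return ell - kl


y, prior_var, noise_var = 1.3, 2.0, 0.5
log_evidence = normal_log_pdf(y, 0.0, prior_var + noise_var)

# Exact posterior of the conjugate model f ~ N(0, prior_var), y | f ~ N(f, noise_var).
post_var = 1.0 / (1.0 / prior_var + 1.0 / noise_var)
post_mean = post_var * y / noise_var

elbo_at_posterior = elbo_1d(y, post_mean, post_var, prior_var, noise_var)
elbo_elsewhere = elbo_1d(y, 0.0, 1.0, prior_var, noise_var)
```

Here `elbo_at_posterior` matches `log_evidence` to rounding error and `elbo_elsewhere` is strictly smaller, as the bound requires.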

expected_log_likelihood(likelihood: AbstractLikelihood, state: GaussianState, integrator: AbstractIntegrator | None = None) -> Float[Array, '']

Unified expected log-likelihood with analytical dispatch.

Computes E_q[log p(y | f)] where q = N(mu, Sigma).

If the likelihood has a closed-form expected log-likelihood (e.g. GaussianLikelihood), it is used directly without an integrator. Otherwise, an integrator must be provided for numerical approximation.

Parameters:

Name Type Description Default
likelihood AbstractLikelihood

Likelihood object with log_prob method.

required
state GaussianState

Variational Gaussian distribution.

required
integrator AbstractIntegrator | None

Integration method. Required for non-conjugate likelihoods; ignored when the likelihood has an analytical fast path.

None

Returns:

Type Description
Float[Array, '']

Scalar expected log-likelihood.

Raises:

Type Description
ValueError

If no integrator is provided and the likelihood has no analytical form.

Source code in src/gaussx/_quadrature/_expectations.py
def expected_log_likelihood(
    likelihood: AbstractLikelihood,
    state: GaussianState,
    integrator: AbstractIntegrator | None = None,
) -> Float[Array, ""]:
    r"""Unified expected log-likelihood with analytical dispatch.

    Computes ``E_q[log p(y | f)]`` where ``q = N(mu, Sigma)``.

    If the likelihood has a closed-form expected log-likelihood
    (e.g. ``GaussianLikelihood``), it is used directly without
    an integrator. Otherwise, an ``integrator`` must be provided
    for numerical approximation.

    Args:
        likelihood: Likelihood object with ``log_prob`` method.
        state: Variational Gaussian distribution.
        integrator: Integration method. Required for non-conjugate
            likelihoods; ignored when the likelihood has an
            analytical fast path.

    Returns:
        Scalar expected log-likelihood.

    Raises:
        ValueError: If no integrator is provided and the likelihood
            has no analytical form.
    """
    if likelihood.has_analytical_ell():
        return likelihood.analytical_expected_log_likelihood(state.mean, state.cov)
    if integrator is None:
        msg = (
            f"{type(likelihood).__name__} has no analytical ELL; provide an integrator"
        )
        raise ValueError(msg)
    return log_likelihood_expectation(likelihood.log_prob, state, integrator)

log_likelihood_expectation(likelihood_fn: Callable[[Float[Array, ' N']], Float[Array, '']], state: GaussianState, integrator: AbstractIntegrator) -> Float[Array, '']

Compute E[log p(y_{obs} | f(x))] where x ~ N(mu, Sigma).

Intended for non-conjugate likelihoods (Bernoulli, Poisson, etc.), for which the expectation has no closed form.

Parameters:

Name Type Description Default
likelihood_fn Callable[[Float[Array, ' N']], Float[Array, '']]

Function mapping latent values to scalar log-likelihood: (N,) -> ().

required
state GaussianState

Input Gaussian distribution.

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
Float[Array, '']

Scalar expected log-likelihood.

Source code in src/gaussx/_quadrature/_expectations.py
def log_likelihood_expectation(
    likelihood_fn: Callable[[Float[Array, " N"]], Float[Array, ""]],
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> Float[Array, ""]:
    r"""Compute ``E[log p(y_{obs} | f(x))]`` where ``x ~ N(mu, Sigma)``.

    For non-conjugate likelihoods (Bernoulli, Poisson, etc.) where
    the expectation has no closed form.

    Args:
        likelihood_fn: Function mapping latent values to scalar
            log-likelihood: ``(N,) -> ()``.
        state: Input Gaussian distribution.
        integrator: Integration method.

    Returns:
        Scalar expected log-likelihood.
    """
    return mean_expectation(
        lambda x: jnp.atleast_1d(likelihood_fn(x)),
        state,
        integrator,
    )[0]

mean_expectation(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState, integrator: AbstractIntegrator) -> Float[Array, ' M']

Compute E[f(x)] where x ~ N(mu, Sigma).

Parameters:

Name Type Description Default
fn Callable[[Float[Array, ' N']], Float[Array, ' M']]

Function mapping (N,) -> (M,).

required
state GaussianState

Input Gaussian distribution.

required
integrator AbstractIntegrator

Integration method (Taylor, Unscented, MC, etc.).

required

Returns:

Type Description
Float[Array, ' M']

Expected function value, shape (M,).

Source code in src/gaussx/_quadrature/_expectations.py
def mean_expectation(
    fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> Float[Array, " M"]:
    r"""Compute ``E[f(x)]`` where ``x ~ N(mu, Sigma)``.

    Args:
        fn: Function mapping ``(N,) -> (M,)``.
        state: Input Gaussian distribution.
        integrator: Integration method (Taylor, Unscented, MC, etc.).

    Returns:
        Expected function value, shape ``(M,)``.
    """
    result = integrator.integrate(fn, state)
    return result.state.mean
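
For intuition about what an integrator computes here, the sketch below approximates a scalar Gaussian expectation with a Gauss-Hermite rule in NumPy and compares it with a known closed form. It does not use the `AbstractIntegrator` interface. `gauss_hermite_mean` is a hypothetical stand-in, and the choice of 20 nodes is arbitrary.

```python
import numpy as np


def gauss_hermite_mean(fn, mean, var, order=20):
    """Approximate E[fn(x)] for scalar x ~ N(mean, var)."""
    # hermegauss targets the weight exp(-z**2 / 2); dividing by sqrt(2*pi)
    # turns the weights into standard-normal probabilities.
    z, w = np.polynomial.hermite_e.hermegauss(order)
    w = w / np.sqrt(2.0 * np.pi)
    return np.sum(w * fn(mean + np.sqrt(var) * z))


mean, var = 0.4, 0.6
approx = gauss_hermite_mean(np.sin, mean, var)
exact = np.sin(mean) * np.exp(-0.5 * var)  # closed form for E[sin(x)]
```

For a smooth integrand like this one a modest number of nodes already matches the closed form to near machine precision. Rougher integrands converge more slowly and may favour a different rule.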

gradient_expectation(fn: Callable[[Float[Array, ' N']], Float[Array, '']], state: GaussianState, integrator: AbstractIntegrator) -> Float[Array, ' N']

Compute E[nabla f(x)] via Stein's lemma.

Uses the identity:

\(\mathbb{E}[\nabla f(x)] = \Sigma^{-1}\,\mathrm{Cov}[x, f(x)]\)

Parameters:

Name Type Description Default
fn Callable[[Float[Array, ' N']], Float[Array, '']]

Scalar-valued function mapping (N,) -> ().

required
state GaussianState

Input Gaussian distribution.

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
Float[Array, ' N']

Expected gradient, shape (N,).

Source code in src/gaussx/_quadrature/_expectations.py
def gradient_expectation(
    fn: Callable[[Float[Array, " N"]], Float[Array, ""]],
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> Float[Array, " N"]:
    r"""Compute ``E[nabla f(x)]`` via Stein's lemma.

    Uses the identity:

        E[nabla f(x)] = Sigma^{-1} Cov[x, f(x)]

    Args:
        fn: Scalar-valued function mapping ``(N,) -> ()``.
        state: Input Gaussian distribution.
        integrator: Integration method.

    Returns:
        Expected gradient, shape ``(N,)``.
    """
    from gaussx._primitives._solve import solve

    # Wrap scalar fn to return (1,) for the integrator
    def fn_vec(x: Float[Array, " N"]) -> Float[Array, " 1"]:
        return jnp.atleast_1d(fn(x))

    result = integrator.integrate(fn_vec, state)
    cross_cov = result.cross_cov
    assert cross_cov is not None, "Integrator must return cross_cov"
    # E[nabla f] = Sigma^{-1} @ Cov[x, f]  (N, 1) -> (N,)
    return solve(state.cov, cross_cov[:, 0])
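
Stein's identity can be verified numerically in one dimension. The sketch below is a NumPy illustration that does not call gaussx, with `f(x) = sin(x)` chosen only because its expectations are known in closed form. It computes the cross-covariance by quadrature, divides by the variance, and compares the result with the directly integrated derivative.

```python
import numpy as np

z, w = np.polynomial.hermite_e.hermegauss(30)
w = w / np.sqrt(2.0 * np.pi)   # standard-normal quadrature weights

mean, var = 0.4, 0.6
x = mean + np.sqrt(var) * z    # nodes for x ~ N(mean, var)

f_vals = np.sin(x)             # f(x) = sin(x), so grad f(x) = cos(x)
f_mean = np.sum(w * f_vals)
cross_cov = np.sum(w * (x - mean) * (f_vals - f_mean))  # Cov[x, f(x)]

stein_grad = cross_cov / var         # Sigma^{-1} Cov[x, f(x)]
direct_grad = np.sum(w * np.cos(x))  # E[grad f(x)] computed directly
```

The appeal of the identity is that `stein_grad` needs only function values, so it remains available when `fn` is not differentiable by the autodiff system. It does require the integrator to return a cross-covariance, as the assertion in the listing enforces.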

cost_expectation(prediction_fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], cost_fn: Callable[[Float[Array, ' M'], Float[Array, ' M']], Float[Array, '']], state: GaussianState, target: Float[Array, ' M'], integrator: AbstractIntegrator) -> Float[Array, '']

Compute E[Cost(f(x), target)] where x ~ N(mu, Sigma).

A typical use is model-based RL, where it gives the expected cost of a policy under state uncertainty.

Parameters:

Name Type Description Default
prediction_fn Callable[[Float[Array, ' N']], Float[Array, ' M']]

Maps state to prediction, (N,) -> (M,).

required
cost_fn Callable[[Float[Array, ' M'], Float[Array, ' M']], Float[Array, '']]

Cost function, (M,), (M,) -> ().

required
state GaussianState

Input Gaussian distribution (uncertain state).

required
target Float[Array, ' M']

Target value, shape (M,).

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
Float[Array, '']

Scalar expected cost.

Source code in src/gaussx/_quadrature/_expectations.py
def cost_expectation(
    prediction_fn: Callable[[Float[Array, " N"]], Float[Array, " M"]],
    cost_fn: Callable[[Float[Array, " M"], Float[Array, " M"]], Float[Array, ""]],
    state: GaussianState,
    target: Float[Array, " M"],
    integrator: AbstractIntegrator,
) -> Float[Array, ""]:
    r"""Compute ``E[Cost(f(x), target)]`` where ``x ~ N(mu, Sigma)``.

    For model-based RL: expected cost of a policy under state
    uncertainty.

    Args:
        prediction_fn: Maps state to prediction, ``(N,) -> (M,)``.
        cost_fn: Cost function, ``(M,), (M,) -> ()``.
        state: Input Gaussian distribution (uncertain state).
        target: Target value, shape ``(M,)``.
        integrator: Integration method.

    Returns:
        Scalar expected cost.
    """

    def combined_fn(x: Float[Array, " N"]) -> Float[Array, " 1"]:
        pred = prediction_fn(x)
        return jnp.atleast_1d(cost_fn(pred, target))

    return mean_expectation(combined_fn, state, integrator)[0]
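
When the prediction is linear and the cost is a squared error, the expected cost has the closed form \(\lVert A\mu - t\rVert^2 + \mathrm{tr}(A \Sigma A^\top)\), which makes a convenient test case. The NumPy sketch below composes the two callables as in the listing and integrates them with a hand-built symmetric cubature rule. The rule (2N points at `mean +/- sqrt(N) * L[:, i]` with equal weights) and all names are assumptions for illustration, not the gaussx integrator.

```python
import numpy as np

mean = np.array([0.5, -1.0])
cov = np.array([[0.4, 0.1], [0.1, 0.3]])
A = np.array([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]])  # linear map, (N,) -> (M,)
target = np.array([0.0, 1.0, -2.0])


def prediction_fn(x):
    return A @ x


def cost_fn(pred, tgt):
    return np.sum((pred - tgt) ** 2)


# Symmetric cubature rule: 2N points mean +/- sqrt(N) * L[:, i], equal weights.
n = mean.shape[0]
L = np.linalg.cholesky(cov)
offsets = np.sqrt(n) * L.T
points = np.concatenate([mean + offsets, mean - offsets], axis=0)
weights = np.full(2 * n, 1.0 / (2 * n))

expected_cost = sum(
    w * cost_fn(prediction_fn(p), target) for w, p in zip(weights, points)
)
closed_form = np.sum((A @ mean - target) ** 2) + np.trace(A @ cov @ A.T)
```

The rule is exact here because the composed function is quadratic in `x`. For a nonlinear prediction or a saturating cost the result is an approximation whose quality depends on the integrator chosen.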

ep_tilted_moments(log_lik_fn: Callable[[Float[Array, '']], Float[Array, '']], cav_mean: Float[Array, ' *batch'], cav_var: Float[Array, ' *batch'], *, order: int = 20) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

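Compute the mean and variance of the tilted distribution, proportional to p(y|f) N(f | cav_mean, cav_var), via Gauss-Hermite quadrature. Each batch element is treated as an independent scalar site, and the returned variance is floored at 1e-10.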

Parameters:

Name Type Description Default
log_lik_fn Callable[[Float[Array, '']], Float[Array, '']]

Scalar function mapping latent value f to scalar log-likelihood log p(y|f).

required
cav_mean Float[Array, ' *batch']

Cavity means, shape (*batch,).

required
cav_var Float[Array, ' *batch']

Cavity variances (positive), shape (*batch,).

required
order int

Number of Gauss-Hermite quadrature points. Default 20.

20

Returns:

Type Description
tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]

Tuple (tilted_mean, tilted_var), each of shape (*batch,).

Source code in src/gaussx/_quadrature/_tilted_moments.py
def ep_tilted_moments(
    log_lik_fn: Callable[[Float[Array, ""]], Float[Array, ""]],
    cav_mean: Float[Array, " *batch"],
    cav_var: Float[Array, " *batch"],
    *,
    order: int = 20,
) -> tuple[Float[Array, " *batch"], Float[Array, " *batch"]]:
    r"""Compute tilted distribution moments via Gauss-Hermite quadrature.

    Args:
        log_lik_fn: Scalar function mapping latent value ``f`` to scalar
            log-likelihood ``log p(y|f)``.
        cav_mean: Cavity means, shape ``(*batch,)``.
        cav_var: Cavity variances (positive), shape ``(*batch,)``.
        order: Number of Gauss-Hermite quadrature points. Default 20.

    Returns:
        Tuple ``(tilted_mean, tilted_var)``.
    """
    from gaussx._quadrature._quadrature import gauss_hermite_points

    z, w = gauss_hermite_points(order, dim=1)
    z = z.squeeze(-1)
    log_w = jnp.log(w)

    def _compute_moments(mean_i: Float[Array, ""], var_i: Float[Array, ""]):
        std_i = jnp.sqrt(var_i)
        f_nodes = mean_i + std_i * z

        log_lik_vals = jax.vmap(log_lik_fn)(f_nodes)
        log_joint = log_w + log_lik_vals

        log_Z = jax.scipy.special.logsumexp(log_joint)
        weights = jnp.exp(log_joint - log_Z)

        t_mean = jnp.sum(weights * f_nodes)
        t_var = jnp.sum(weights * (f_nodes - t_mean) ** 2)
        t_var = jnp.maximum(t_var, 1e-10)
        return t_mean, t_var

    orig_shape = cav_mean.shape
    flat_mean = cav_mean.ravel()
    flat_var = cav_var.ravel()

    flat_t_mean, flat_t_var = jax.vmap(_compute_moments)(flat_mean, flat_var)

    tilted_mean = flat_t_mean.reshape(orig_shape)
    tilted_var = flat_t_var.reshape(orig_shape)
    return tilted_mean, tilted_var
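
The per-site computation can be reproduced in NumPy and checked against a case with a known answer: for a Gaussian site the tilted distribution is itself Gaussian, with precisions that add. The sketch below follows the same log-space normalisation as the listing. `tilted_moments_1d` is a hypothetical name, and NumPy's `hermegauss` stands in for `gauss_hermite_points`, whose exact conventions are assumed rather than confirmed.

```python
import numpy as np


def tilted_moments_1d(log_lik_fn, cav_mean, cav_var, order=20):
    """Gauss-Hermite tilted mean and variance for one scalar site."""
    z, w = np.polynomial.hermite_e.hermegauss(order)
    log_w = np.log(w / np.sqrt(2.0 * np.pi))
    f_nodes = cav_mean + np.sqrt(cav_var) * z
    log_joint = log_w + log_lik_fn(f_nodes)
    # Normalise in log space (a hand-rolled logsumexp).
    shift = log_joint.max()
    log_Z = shift + np.log(np.exp(log_joint - shift).sum())
    weights = np.exp(log_joint - log_Z)
    t_mean = np.sum(weights * f_nodes)
    t_var = np.sum(weights * (f_nodes - t_mean) ** 2)
    return t_mean, t_var


y, noise_var = 1.0, 2.0
cav_mean, cav_var = 0.3, 0.5
t_mean, t_var = tilted_moments_1d(
    lambda f: -0.5 * (y - f) ** 2 / noise_var,  # Gaussian site, constants dropped
    cav_mean,
    cav_var,
)

# Product of two Gaussians in f: precisions add.
exact_var = 1.0 / (1.0 / cav_var + 1.0 / noise_var)
exact_mean = exact_var * (cav_mean / cav_var + y / noise_var)
```

Additive constants in the log-likelihood cancel in the normalisation, so the site function can omit them. Accuracy degrades when the likelihood is much sharper than the cavity, in which case a higher `order` may be needed.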

Kernel expectations & uncertain-input GPs

The \(\Psi\)-statistics \(\Psi_0 = \mathbb{E}[k(x,x)]\), \(\Psi_1 = \mathbb{E}[k(x, X)]\), \(\Psi_2 = \mathbb{E}[k(x,\cdot)k(x,\cdot)^\top]\) and the GP / SVGP / VGP / BGPLVM predictive equations for inputs that are themselves Gaussian.

AnalyticalPsiStatistics

Bases: Protocol

Protocol for kernels with closed-form Ψ statistics.

Ψ statistics are required for uncertain-input GP models (e.g., BGPLVM). A kernel implementing this protocol provides analytical formulae instead of requiring numerical integration.

Source code in src/gaussx/_quadrature/_psi_statistics.py
@runtime_checkable
class AnalyticalPsiStatistics(Protocol):
    """Protocol for kernels with closed-form Ψ statistics.

    Ψ statistics are required for uncertain-input GP models
    (e.g., BGPLVM). A kernel implementing this protocol provides
    analytical formulae instead of requiring numerical integration.
    """

    def psi0(self, state: GaussianState) -> Float[Array, ""]:
        """Compute Ψ₀ = E[k(x, x)] (scalar)."""
        ...

    def psi1(
        self,
        state: GaussianState,
        X_train: Float[Array, "M D"],
    ) -> Float[Array, " M"]:
        """Compute Ψ₁ᵢ = E[k(x, xᵢ)], shape ``(M,)``."""
        ...

    def psi2(
        self,
        state: GaussianState,
        X_train: Float[Array, "M D"],
    ) -> Float[Array, "M M"]:
        """Compute Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)], shape ``(M, M)``."""
        ...

psi0(state: GaussianState) -> Float[Array, '']

Compute Ψ₀ = E[k(x, x)] (scalar).

Source code in src/gaussx/_quadrature/_psi_statistics.py
def psi0(self, state: GaussianState) -> Float[Array, ""]:
    """Compute Ψ₀ = E[k(x, x)] (scalar)."""
    ...

psi1(state: GaussianState, X_train: Float[Array, 'M D']) -> Float[Array, ' M']

Compute Ψ₁ᵢ = E[k(x, xᵢ)], shape (M,).

Source code in src/gaussx/_quadrature/_psi_statistics.py
def psi1(
    self,
    state: GaussianState,
    X_train: Float[Array, "M D"],
) -> Float[Array, " M"]:
    """Compute Ψ₁ᵢ = E[k(x, xᵢ)], shape ``(M,)``."""
    ...

psi2(state: GaussianState, X_train: Float[Array, 'M D']) -> Float[Array, 'M M']

Compute Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)], shape (M, M).

Source code in src/gaussx/_quadrature/_psi_statistics.py
def psi2(
    self,
    state: GaussianState,
    X_train: Float[Array, "M D"],
) -> Float[Array, "M M"]:
    """Compute Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)], shape ``(M, M)``."""
    ...
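
As an illustration of what a conforming kernel could look like, here is a NumPy sketch of the closed-form Ψ statistics for an isotropic RBF kernel. It is not the gaussx implementation: `DenseState` is a hypothetical stand-in for `GaussianState` that stores the covariance as a dense array rather than a linear operator, and `RBFWithPsi` is an illustrative class name. The formulas are the standard Gaussian integrals of a squared-exponential.

```python
import numpy as np
from typing import NamedTuple


class DenseState(NamedTuple):
    """Hypothetical stand-in for GaussianState with a dense covariance."""

    mean: np.ndarray  # (D,)
    cov: np.ndarray  # (D, D)


class RBFWithPsi:
    """Isotropic RBF: k(x, x') = variance * exp(-|x - x'|^2 / (2 l^2))."""

    def __init__(self, variance=1.0, lengthscale=1.0):
        self.variance = variance
        self.lengthscale = lengthscale

    def __call__(self, x, y):
        sq_dist = np.sum((x - y) ** 2)
        return self.variance * np.exp(-0.5 * sq_dist / self.lengthscale**2)

    def psi0(self, state):
        # k(x, x) is constant for a stationary kernel.
        return self.variance

    def psi1(self, state, X_train):
        D = state.mean.shape[0]
        l2 = self.lengthscale**2
        A = l2 * np.eye(D) + state.cov
        diff = X_train - state.mean  # (M, D)
        quad = np.sum(diff * np.linalg.solve(A, diff.T).T, axis=1)
        scale = np.linalg.det(np.eye(D) + state.cov / l2) ** -0.5
        return self.variance * scale * np.exp(-0.5 * quad)

    def psi2(self, state, X_train):
        D = state.mean.shape[0]
        l2 = self.lengthscale**2
        A = 0.5 * l2 * np.eye(D) + state.cov
        # Midpoints of every pair of training points, shape (M, M, D).
        mid = 0.5 * (X_train[:, None, :] + X_train[None, :, :])
        diff = mid - state.mean
        quad = np.einsum("ijd,de,ije->ij", diff, np.linalg.inv(A), diff)
        sep = np.sum((X_train[:, None, :] - X_train[None, :, :]) ** 2, axis=-1)
        scale = np.linalg.det(np.eye(D) + 2.0 * state.cov / l2) ** -0.5
        return self.variance**2 * scale * np.exp(-0.25 * sep / l2 - 0.5 * quad)


state = DenseState(
    mean=np.array([0.2, -0.1]),
    cov=np.array([[0.3, 0.1], [0.1, 0.2]]),
)
X_train = np.array([[0.0, 0.0], [1.0, -0.5], [-0.5, 0.8]])
kernel = RBFWithPsi(variance=1.5, lengthscale=0.9)

psi0 = kernel.psi0(state)
psi1 = kernel.psi1(state, X_train)  # shape (3,)
psi2 = kernel.psi2(state, X_train)  # shape (3, 3)
```

A real implementation against the protocol would read the covariance from the operator held by `GaussianState` (for instance by materializing it) instead of the dense `cov` assumed here.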

kernel_expectations(kernel_fn: Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']], state: GaussianState, X_train: Float[Array, 'N_train D'], integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, ' N_train'], Float[Array, 'N_train N_train']]

Compute kernel expectations Psi_0, Psi_1, Psi_2 for uncertain inputs.

These are the core quantities for GP inference with uncertain inputs:

Psi_0 = E[k(x, x)]                      scalar
Psi_1_i = E[k(x, x_i)]                  (N_train,)
Psi_2_{ij} = E[k(x, x_i) k(x, x_j)]    (N_train, N_train)

Parameters:

Name Type Description Default
kernel_fn Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']]

Kernel function k(x, x') -> scalar.

required
state GaussianState

Uncertain input distribution x ~ N(mu, Sigma).

required
X_train Float[Array, 'N_train D']

Training points, shape (N_train, D).

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, ' N_train'], Float[Array, 'N_train N_train']]

Tuple (Psi_0, Psi_1, Psi_2).

Source code in src/gaussx/_quadrature/_gp_predict.py
def kernel_expectations(
    kernel_fn: Callable[[Float[Array, " D"], Float[Array, " D"]], Float[Array, ""]],
    state: GaussianState,
    X_train: Float[Array, "N_train D"],
    integrator: AbstractIntegrator,
) -> tuple[Float[Array, ""], Float[Array, " N_train"], Float[Array, "N_train N_train"]]:
    r"""Compute kernel expectations Psi_0, Psi_1, Psi_2 for uncertain inputs.

    These are the core quantities for GP inference with uncertain inputs:

        Psi_0 = E[k(x, x)]                      scalar
        Psi_1_i = E[k(x, x_i)]                  (N_train,)
        Psi_2_{ij} = E[k(x, x_i) k(x, x_j)]    (N_train, N_train)

    Args:
        kernel_fn: Kernel function ``k(x, x') -> scalar``.
        state: Uncertain input distribution ``x ~ N(mu, Sigma)``.
        X_train: Training points, shape ``(N_train, D)``.
        integrator: Integration method.

    Returns:
        Tuple ``(Psi_0, Psi_1, Psi_2)``.
    """
    from gaussx._quadrature._expectations import mean_expectation

    # Psi_0 = E[k(x, x)]
    Psi_0 = mean_expectation(
        lambda x: jnp.atleast_1d(kernel_fn(x, x)),
        state,
        integrator,
    )[0]

    # Psi_1_i = E[k(x, x_i)]
    def psi1_fn(x: Float[Array, " D"]) -> Float[Array, " N_train"]:
        return jax.vmap(lambda xi: kernel_fn(x, xi))(X_train)

    Psi_1 = mean_expectation(psi1_fn, state, integrator)

    # Psi_2_{ij} = E[k(x, x_i) * k(x, x_j)]
    def psi2_fn(
        x: Float[Array, " D"],
    ) -> Float[Array, "N_train N_train"]:
        k_vec = jax.vmap(lambda xi: kernel_fn(x, xi))(X_train)
        return jnp.outer(k_vec, k_vec)

    from gaussx._einx import rearrange

    N_train = X_train.shape[0]
    psi2_flat_fn = lambda x: rearrange(psi2_fn(x), "i j -> (i j)")
    Psi_2_flat = mean_expectation(psi2_flat_fn, state, integrator)
    Psi_2 = rearrange(Psi_2_flat, "(i j) -> i j", i=N_train, j=N_train)

    return Psi_0, Psi_1, Psi_2
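
The flatten-integrate-reshape step for Psi_2 can be seen in isolation with a one-dimensional input. The sketch below uses plain NumPy and a hypothetical `expectation` helper built on `numpy.polynomial.hermite_e.hermegauss`; it stands in for `mean_expectation` plus an integrator and is not the gaussx code path.

```python
import numpy as np


def expectation(fn, mean, var, order=30):
    """E[fn(x)] for scalar x ~ N(mean, var), where fn returns a 1-D array."""
    z, w = np.polynomial.hermite_e.hermegauss(order)
    w = w / w.sum()  # normalize so the rule integrates N(0, 1)
    values = np.stack([fn(mean + np.sqrt(var) * zi) for zi in z])  # (order, P)
    return w @ values  # (P,)


def rbf(x, y, lengthscale=1.0):
    return np.exp(-0.5 * (x - y) ** 2 / lengthscale**2)


X_train = np.array([-1.0, 0.0, 0.7, 2.0])
N = X_train.shape[0]
mean, var = 0.4, 0.25

psi_0 = expectation(lambda x: np.atleast_1d(rbf(x, x)), mean, var)[0]
psi_1 = expectation(lambda x: rbf(x, X_train), mean, var)


def k_outer_flat(x):
    # Psi_2 is matrix-valued, so flatten the outer product to (N * N,)
    # and integrate it as an ordinary vector-valued function.
    k_vec = rbf(x, X_train)
    return np.outer(k_vec, k_vec).ravel()


psi_2 = expectation(k_outer_flat, mean, var).reshape(N, N)
```

Because the rule has positive weights that sum to one, `psi_2 - np.outer(psi_1, psi_1)` is the covariance of `k(x, X_train)` under the discrete rule and is therefore positive semi-definite up to rounding.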

compute_psi_statistics(kernel: object, state: GaussianState, X_train: Float[Array, 'M D'], *, integrator: AbstractIntegrator | None = None) -> tuple[Float[Array, ''], Float[Array, ' M'], Float[Array, 'M M']]

Compute Ψ statistics, dispatching to analytical or numerical.

If kernel implements AnalyticalPsiStatistics, uses the closed-form methods. Otherwise, falls back to numerical integration via the provided integrator:

Ψ₀   = E[k(x, x)]                   scalar
Ψ₁ᵢ  = E[k(x, xᵢ)]                 (M,)
Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)]       (M, M)

Parameters:

Name Type Description Default
kernel object

Kernel object, optionally implementing AnalyticalPsiStatistics.

required
state GaussianState

Input Gaussian distribution x ~ 𝒩(μ, Σ).

required
X_train Float[Array, 'M D']

Training/inducing points, shape (M, D).

required
integrator AbstractIntegrator | None

Numerical integrator for fallback. Required if kernel does not implement analytical Ψ statistics.

None

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, ' M'], Float[Array, 'M M']]

Tuple (Ψ₀, Ψ₁, Ψ₂) of Psi statistics.

Raises:

Type Description
ValueError

If kernel has no analytical Ψ statistics and no integrator is provided.

Source code in src/gaussx/_quadrature/_psi_statistics.py
def compute_psi_statistics(
    kernel: object,
    state: GaussianState,
    X_train: Float[Array, "M D"],
    *,
    integrator: AbstractIntegrator | None = None,
) -> tuple[Float[Array, ""], Float[Array, " M"], Float[Array, "M M"]]:
    """Compute Ψ statistics, dispatching to analytical or numerical.

    If ``kernel`` implements `AnalyticalPsiStatistics`, uses
    the closed-form methods. Otherwise, falls back to numerical
    integration via the provided integrator:

        Ψ₀   = E[k(x, x)]                   scalar
        Ψ₁ᵢ  = E[k(x, xᵢ)]                 (M,)
        Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)]       (M, M)

    Args:
        kernel: Kernel object, optionally implementing
            `AnalyticalPsiStatistics`.
        state: Input Gaussian distribution x ~ 𝒩(μ, Σ).
        X_train: Training/inducing points, shape ``(M, D)``.
        integrator: Numerical integrator for fallback. Required if
            ``kernel`` does not implement analytical Ψ statistics.

    Returns:
        Tuple ``(Ψ₀, Ψ₁, Ψ₂)`` of Psi statistics.

    Raises:
        ValueError: If ``kernel`` has no analytical Ψ statistics
            and no integrator is provided.
    """
    if isinstance(kernel, AnalyticalPsiStatistics):
        psi0 = kernel.psi0(state)
        psi1 = kernel.psi1(state, X_train)
        psi2 = kernel.psi2(state, X_train)
        return psi0, psi1, psi2

    if integrator is None:
        msg = (
            "Kernel does not implement AnalyticalPsiStatistics and no "
            "integrator was provided. Either implement the protocol on "
            "the kernel or pass an integrator for numerical computation."
        )
        raise ValueError(msg)

    # ── Numerical fallback ────────────────────────────────────────
    k_call = cast(Callable, kernel)

    # Ψ₀ = E[k(x, x)]
    def _k_self(x: Float[Array, " D"]) -> Float[Array, " 1"]:
        return jnp.atleast_1d(k_call(x, x))

    psi0_result = integrator.integrate(_k_self, state)
    psi0 = psi0_result.state.mean[0]  # scalar

    # Ψ₁ᵢ = E[k(x, xᵢ)]
    def _k_cross(x: Float[Array, " D"]) -> Float[Array, " M"]:
        return jax.vmap(lambda xj: k_call(x, xj))(X_train)

    psi1_result = integrator.integrate(_k_cross, state)
    psi1 = psi1_result.state.mean  # (M,)

    # Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)]
    M = X_train.shape[0]

    def _k_outer(x: Float[Array, " D"]) -> Float[Array, " flat"]:
        kx = jax.vmap(lambda xj: k_call(x, xj))(X_train)
        return rearrange(jnp.outer(kx, kx), "i j -> (i j)")  # (M²,)

    psi2_result = integrator.integrate(_k_outer, state)
    psi2 = rearrange(
        psi2_result.state.mean,
        "(i j) -> i j",
        i=M,
        j=M,
    )  # (M, M)

    return psi0, psi1, psi2
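
The dispatch rests on `isinstance` against a `runtime_checkable` protocol, which checks only that the named methods exist, not their signatures or return types. A small standard-library sketch of that behaviour, using a local stand-in protocol `HasPsiStatistics` and hypothetical kernel classes with placeholder return values rather than the gaussx types:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasPsiStatistics(Protocol):
    """Local stand-in with the same three method names as the protocol."""

    def psi0(self, state): ...
    def psi1(self, state, X_train): ...
    def psi2(self, state, X_train): ...


class ClosedFormKernel:
    """Defines all three methods, so it passes the structural check."""

    def psi0(self, state):
        return 1.0

    def psi1(self, state, X_train):
        return [1.0 for _ in X_train]

    def psi2(self, state, X_train):
        return [[1.0 for _ in X_train] for _ in X_train]


class PartialKernel:
    """Missing psi2, so it would be routed to the numerical fallback."""

    def psi0(self, state):
        return 1.0

    def psi1(self, state, X_train):
        return [1.0 for _ in X_train]


def plain_kernel(x, y):
    """An ordinary callable kernel with no psi methods at all."""
    return 1.0


takes_analytical_path = {
    "ClosedFormKernel": isinstance(ClosedFormKernel(), HasPsiStatistics),
    "PartialKernel": isinstance(PartialKernel(), HasPsiStatistics),
    "plain_kernel": isinstance(plain_kernel, HasPsiStatistics),
}
```

One consequence is that a kernel defining only some of the three methods is treated as non-analytical and needs an integrator.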

uncertain_gp_predict(kernel_fn: Callable, X_train: Float[Array, 'N_train D'], alpha: Float[Array, ' N_train'], K_inv: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, '']]

Predictive mean and variance for GP with uncertain inputs.

Uses kernel expectations:

mu_pred = Psi_1 @ alpha
var_pred = Psi_0 - tr(K_inv @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2

Parameters:

Name Type Description Default
kernel_fn Callable

Kernel function k(x, x') -> scalar.

required
X_train Float[Array, 'N_train D']

Training points, shape (N_train, D).

required
alpha Float[Array, ' N_train']

Precomputed weights K^{-1} y, shape (N_train,).

required
K_inv AbstractLinearOperator

Linear operator representing the inverse of the training kernel matrix.

required
state GaussianState

Uncertain test input x ~ N(mu, Sigma).

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, '']]

Tuple (mean, variance) — scalar predictive moments.

Source code in src/gaussx/_quadrature/_gp_predict.py
def uncertain_gp_predict(
    kernel_fn: Callable,
    X_train: Float[Array, "N_train D"],
    alpha: Float[Array, " N_train"],
    K_inv: lx.AbstractLinearOperator,
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    r"""Predictive mean and variance for GP with uncertain inputs.

    Uses kernel expectations:

        mu_pred = Psi_1 @ alpha
        var_pred = Psi_0 - tr(K_inv @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2

    Args:
        kernel_fn: Kernel function ``k(x, x') -> scalar``.
        X_train: Training points, shape ``(N_train, D)``.
        alpha: Precomputed weights ``K^{-1} y``, shape ``(N_train,)``.
        K_inv: Inverse of training kernel matrix operator.
        state: Uncertain test input ``x ~ N(mu, Sigma)``.
        integrator: Integration method.

    Returns:
        Tuple ``(mean, variance)`` — scalar predictive moments.
    """
    from gaussx._linalg._linalg import trace_product

    Psi_0, Psi_1, Psi_2 = kernel_expectations(kernel_fn, state, X_train, integrator)

    mu_pred = jnp.dot(Psi_1, alpha)

    Psi_2_op = lx.MatrixLinearOperator(Psi_2)
    tr_term = trace_product(K_inv, Psi_2_op)
    quad_term = alpha @ Psi_2 @ alpha

    var_pred = Psi_0 - tr_term + quad_term - mu_pred**2
    var_pred = jnp.clip(var_pred, 0.0)

    return mu_pred, var_pred
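
The variance formula follows from the law of total variance applied to the GP posterior at a random input. The sketch below writes \(k_x = k(X, x)\) and assumes a zero-mean prior, \(\alpha = K^{-1} y\), and that `K_inv` is the same \(K^{-1}\) that appears in the posterior variance \(k(x,x) - k_x^\top K^{-1} k_x\):

```latex
\[
\begin{aligned}
\mathbb{E}[f(x)]
  &= \mathbb{E}[k_x]^\top \alpha = \Psi_1^\top \alpha = \mu_{\text{pred}}, \\
\mathbb{E}_x\!\left[\operatorname{Var}(f \mid x)\right]
  &= \mathbb{E}[k(x,x)] - \mathbb{E}\!\left[k_x^\top K^{-1} k_x\right]
   = \Psi_0 - \operatorname{tr}\!\left(K^{-1} \Psi_2\right), \\
\operatorname{Var}_x\!\left(\mathbb{E}[f \mid x]\right)
  &= \mathbb{E}\!\left[(k_x^\top \alpha)^2\right] - \mu_{\text{pred}}^2
   = \alpha^\top \Psi_2\, \alpha - \mu_{\text{pred}}^2, \\
\operatorname{Var}[f(x)]
  &= \Psi_0 - \operatorname{tr}\!\left(K^{-1} \Psi_2\right)
   + \alpha^\top \Psi_2\, \alpha - \mu_{\text{pred}}^2 .
\end{aligned}
\]
```

The second line uses \(\mathbb{E}[k_x^\top K^{-1} k_x] = \operatorname{tr}(K^{-1}\,\mathbb{E}[k_x k_x^\top])\) with \(\Psi_2 = \mathbb{E}[k_x k_x^\top]\).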

uncertain_gp_predict_mc(predict_fn: Callable[[Float[Array, ' D']], tuple[Float[Array, ''], Float[Array, '']]], state: GaussianState, n_particles: int = 100, key: jax.Array | None = None) -> tuple[Float[Array, ''], Float[Array, '']]

Monte Carlo GP prediction with uncertain inputs.

An alternative to analytic kernel expectations when the Psi integrals are intractable. It averages per-particle predictions and combines them with the law of total variance:

mu = mean(particle_means)
var = var(particle_means) + mean(particle_vars)

Parameters:

Name Type Description Default
predict_fn Callable[[Float[Array, ' D']], tuple[Float[Array, ''], Float[Array, '']]]

GP predictor mapping (D,) -> (mean, var).

required
state GaussianState

Uncertain test input x ~ N(mu, Sigma).

required
n_particles int

Number of Monte Carlo particles. Default 100.

100
key Array | None

PRNG key. If None, uses jax.random.key(0), so repeated calls without a key draw the same particles.

None

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, '']]

Tuple (mean, variance) — scalar predictive moments.

Source code in src/gaussx/_quadrature/_gp_predict.py
def uncertain_gp_predict_mc(
    predict_fn: Callable[
        [Float[Array, " D"]], tuple[Float[Array, ""], Float[Array, ""]]
    ],
    state: GaussianState,
    n_particles: int = 100,
    key: jax.Array | None = None,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    r"""Monte Carlo GP prediction with uncertain inputs.

    An alternative to analytic kernel expectations when the Psi integrals
    are intractable. Uses the law of total variance:

        mu = mean(particle_means)
        var = var(particle_means) + mean(particle_vars)

    Args:
        predict_fn: GP predictor mapping ``(D,) -> (mean, var)``.
        state: Uncertain test input ``x ~ N(mu, Sigma)``.
        n_particles: Number of Monte Carlo particles. Default ``100``.
        key: PRNG key. If ``None``, uses ``jax.random.key(0)``.

    Returns:
        Tuple ``(mean, variance)`` — scalar predictive moments.
    """
    mu = state.mean
    N = mu.shape[0]

    if key is None:
        key = jr.key(0)

    # Sample inputs
    L = cholesky(state.cov).as_matrix()
    eps = jr.normal(key, (n_particles, N))
    x_samples = mu[None, :] + eps @ L.T

    # Predict at each sample
    means, variances = jax.vmap(predict_fn)(x_samples)

    # Law of total variance
    pred_mean = jnp.mean(means)
    pred_var = jnp.var(means) + jnp.mean(variances)

    return pred_mean, pred_var
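
The sampling step and the variance decomposition can be checked on a case with a known answer. This NumPy sketch is not the gaussx function: `predict_batch` is a hypothetical batched predictor with a linear mean and constant variance, chosen so the exact moments are `a @ mu` and `a @ Sigma @ a + noise`.

```python
import numpy as np

a = np.array([1.0, -2.0])
noise = 0.3


def predict_batch(x):
    """Toy batched predictor (hypothetical): mean a @ x, constant variance."""
    return x @ a, np.full(x.shape[0], noise)


mu = np.array([0.5, 1.0])
Sigma = np.array([[0.4, 0.1], [0.1, 0.2]])
n_particles = 200_000

rng = np.random.default_rng(0)
L = np.linalg.cholesky(Sigma)  # lower triangular, Sigma = L @ L.T
eps = rng.standard_normal((n_particles, mu.shape[0]))
x_samples = mu[None, :] + eps @ L.T  # rows are draws from N(mu, Sigma)

means, variances = predict_batch(x_samples)

# Law of total variance: Var[f] = Var_x(E[f | x]) + E_x[Var(f | x)].
pred_mean = means.mean()
pred_var = means.var() + variances.mean()

# Exact moments for this linear toy predictor.
exact_mean = a @ mu
exact_var = a @ Sigma @ a + noise
```

The Monte Carlo error shrinks as one over the square root of `n_particles`, so the default of 100 particles gives a much rougher estimate than the 200,000 used here.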

uncertain_svgp_predict(kernel_fn: Callable, Z: Float[Array, 'M D'], alpha: Float[Array, ' M'], Q: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, '']]

Predictive mean and variance for SVGP with uncertain inputs.

Uses kernel expectations with inducing points:

mu_pred = Psi_1 @ alpha
var_pred = Psi_0 + tr(Q @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2

where Q = K_{zz}^{-1} S K_{zz}^{-1} - K_{zz}^{-1} is the variance adjustment operator (see gaussx.svgp_variance_adjustment). The exact uncertain GP trace correction is recovered by setting Q = -K_{zz}^{-1}; if Z equals the training inputs, this matches gaussx.uncertain_gp_predict.

Parameters:

Name Type Description Default
kernel_fn Callable

Kernel function k(x, x') -> scalar.

required
Z Float[Array, 'M D']

Inducing points, shape (M, D).

required
alpha Float[Array, ' M']

Effective weights, shape (M,).

required
Q AbstractLinearOperator

Variance adjustment operator K_{zz}^{-1} S K_{zz}^{-1} - K_{zz}^{-1}, shape (M, M).

required
state GaussianState

Uncertain test input x ~ N(mu, Sigma).

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, '']]

Tuple (mean, variance) — scalar predictive moments.

Source code in src/gaussx/_quadrature/_gp_predict.py
def uncertain_svgp_predict(
    kernel_fn: Callable,
    Z: Float[Array, "M D"],
    alpha: Float[Array, " M"],
    Q: lx.AbstractLinearOperator,
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    r"""Predictive mean and variance for SVGP with uncertain inputs.

    Uses kernel expectations with inducing points:

        mu_pred = Psi_1 @ alpha
        var_pred = Psi_0 + tr(Q @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2

    where ``Q = K_{zz}^{-1} S K_{zz}^{-1} - K_{zz}^{-1}`` is the variance
    adjustment operator (see `gaussx.svgp_variance_adjustment`).
    The exact uncertain GP trace correction is recovered by setting
    ``Q = -K_{zz}^{-1}``; if ``Z`` equals the training inputs, this matches
    `gaussx.uncertain_gp_predict`.

    Args:
        kernel_fn: Kernel function ``k(x, x') -> scalar``.
        Z: Inducing points, shape ``(M, D)``.
        alpha: Effective weights, shape ``(M,)``.
        Q: Variance adjustment operator ``K_{zz}^{-1} S K_{zz}^{-1} - K_{zz}^{-1}``,
            shape ``(M, M)``.
        state: Uncertain test input ``x ~ N(mu, Sigma)``.
        integrator: Integration method.

    Returns:
        Tuple ``(mean, variance)`` — scalar predictive moments.
    """
    from gaussx._linalg._linalg import trace_product

    Psi_0, Psi_1, Psi_2 = kernel_expectations(kernel_fn, state, Z, integrator)

    mu_pred = jnp.dot(Psi_1, alpha)

    Psi_2_op = lx.MatrixLinearOperator(Psi_2)
    tr_term = trace_product(Q, Psi_2_op)
    quad_term = alpha @ Psi_2 @ alpha

    var_pred = Psi_0 + tr_term + quad_term - mu_pred**2
    var_pred = jnp.clip(var_pred, 0.0)

    return mu_pred, var_pred
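
The role of `Q` is easiest to see in its two limiting cases. The NumPy sketch below builds `Q` directly from the definition for a hypothetical `K_zz` and placeholder kernel expectations; it does not call `gaussx.svgp_variance_adjustment`, and `variance_adjustment` and `predictive_variance` are illustrative helper names.

```python
import numpy as np

rng = np.random.default_rng(0)
M = 4

# Hypothetical inducing-point kernel matrix (symmetric positive definite).
B = rng.standard_normal((M, M))
K_zz = B @ B.T + M * np.eye(M)
K_inv = np.linalg.inv(K_zz)


def variance_adjustment(S):
    """Q = K_zz^{-1} S K_zz^{-1} - K_zz^{-1} for a variational covariance S."""
    return K_inv @ S @ K_inv - K_inv


def predictive_variance(psi_0, psi_1, psi_2, alpha, Q):
    """var_pred from the formula above, before clipping at zero."""
    mu_pred = psi_1 @ alpha
    return psi_0 + np.trace(Q @ psi_2) + alpha @ psi_2 @ alpha - mu_pred**2


# Placeholder kernel expectations: Psi_2 = Psi_1 Psi_1^T + a PSD term.
psi_0 = 1.0
psi_1 = rng.uniform(0.1, 0.9, size=M)
C = 0.05 * rng.standard_normal((M, M))
psi_2 = np.outer(psi_1, psi_1) + C @ C.T
alpha = rng.standard_normal(M)

Q_prior = variance_adjustment(K_zz)  # S = K_zz gives Q = 0
Q_exact = variance_adjustment(np.zeros((M, M)))  # S = 0 gives Q = -K_zz^{-1}

# With Q = -K_zz^{-1} the SVGP form matches the exact-GP form term by term.
var_exact_form = (
    psi_0 - np.trace(K_inv @ psi_2) + alpha @ psi_2 @ alpha - (psi_1 @ alpha) ** 2
)
var_svgp_form = predictive_variance(psi_0, psi_1, psi_2, alpha, Q_exact)
```

Algebraically, `S = 0` is what the definition requires for `Q = -K_zz^{-1}`, and `S = K_zz` makes the trace term vanish.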

uncertain_vgp_predict(kernel_fn: Callable, X_train: Float[Array, 'N_train D'], alpha: Float[Array, ' N_train'], Q: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, '']]

Predictive mean and variance for dense VGP with uncertain inputs.

Uses kernel expectations with training points:

mu_pred = Psi_1 @ alpha
var_pred = Psi_0 + tr(Q @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2

where Q = K^{-1} S K^{-1} - K^{-1} and alpha = K^{-1} m. The exact uncertain GP is the special case Q = -K^{-1}. By contrast, Q = 0 corresponds to S = K and removes only the trace correction.

Parameters:

Name Type Description Default
kernel_fn Callable

Kernel function k(x, x') -> scalar.

required
X_train Float[Array, 'N_train D']

Training points, shape (N_train, D).

required
alpha Float[Array, ' N_train']

Precomputed weights K^{-1} m, shape (N_train,).

required
Q AbstractLinearOperator

Variance adjustment operator K^{-1} S K^{-1} - K^{-1}, shape (N_train, N_train).

required
state GaussianState

Uncertain test input x ~ N(mu, Sigma).

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, '']]

Tuple (mean, variance) — scalar predictive moments.

Source code in src/gaussx/_quadrature/_gp_predict.py
def uncertain_vgp_predict(
    kernel_fn: Callable,
    X_train: Float[Array, "N_train D"],
    alpha: Float[Array, " N_train"],
    Q: lx.AbstractLinearOperator,
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    r"""Predictive mean and variance for dense VGP with uncertain inputs.

    Uses kernel expectations with training points:

        mu_pred = Psi_1 @ alpha
        var_pred = Psi_0 + tr(Q @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2

    where ``Q = K^{-1} S K^{-1} - K^{-1}`` and ``alpha = K^{-1} m``.
    The exact uncertain GP is the special case ``Q = -K^{-1}``. By contrast,
    ``Q = 0`` corresponds to ``S = K`` and removes only the trace correction.

    Args:
        kernel_fn: Kernel function ``k(x, x') -> scalar``.
        X_train: Training points, shape ``(N_train, D)``.
        alpha: Precomputed weights ``K^{-1} m``, shape ``(N_train,)``.
        Q: Variance adjustment operator ``K^{-1} S K^{-1} - K^{-1}``,
            shape ``(N_train, N_train)``.
        state: Uncertain test input ``x ~ N(mu, Sigma)``.
        integrator: Integration method.

    Returns:
        Tuple ``(mean, variance)`` — scalar predictive moments.
    """
    from gaussx._linalg._linalg import trace_product

    Psi_0, Psi_1, Psi_2 = kernel_expectations(kernel_fn, state, X_train, integrator)

    mu_pred = jnp.dot(Psi_1, alpha)

    Psi_2_op = lx.MatrixLinearOperator(Psi_2)
    tr_term = trace_product(Q, Psi_2_op)
    quad_term = alpha @ Psi_2 @ alpha

    var_pred = Psi_0 + tr_term + quad_term - mu_pred**2
    var_pred = jnp.clip(var_pred, 0.0)

    return mu_pred, var_pred
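
A useful sanity check is the limit of a known test input, where the kernel expectations become point evaluations and the formula should reduce to the usual VGP predictive variance. The NumPy sketch below verifies that identity for hypothetical `m`, `S` and an RBF kernel; the `0.1 * I` added to `K` is an assumption made here to keep the example well conditioned.

```python
import numpy as np


def rbf(x, y):
    return np.exp(-0.5 * np.sum((x - y) ** 2))


rng = np.random.default_rng(1)
X_train = rng.standard_normal((5, 2))
x_star = np.array([0.3, -0.2])  # a known test input (zero input covariance)

K = np.array([[rbf(a, b) for b in X_train] for a in X_train]) + 0.1 * np.eye(5)
K_inv = np.linalg.inv(K)

# Hypothetical variational posterior N(m, S) over the training-point values.
m = rng.standard_normal(5)
A = 0.3 * rng.standard_normal((5, 5))
S = A @ A.T + 0.1 * np.eye(5)

alpha = K_inv @ m
Q = K_inv @ S @ K_inv - K_inv

# With zero input covariance the kernel expectations are point evaluations.
k_star = np.array([rbf(x_star, xi) for xi in X_train])
psi_0 = rbf(x_star, x_star)
psi_1 = k_star
psi_2 = np.outer(k_star, k_star)

mu_pred = psi_1 @ alpha
var_pred = psi_0 + np.trace(Q @ psi_2) + alpha @ psi_2 @ alpha - mu_pred**2

# Usual VGP predictive variance at a fixed input.
var_standard = (
    rbf(x_star, x_star)
    - k_star @ K_inv @ k_star
    + k_star @ K_inv @ S @ K_inv @ k_star
)
```

In this limit `alpha @ psi_2 @ alpha` equals `mu_pred**2`, so the last two terms cancel and only `psi_0 + k_star @ Q @ k_star` remains.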

uncertain_bgplvm_predict(kernel_fn: Callable, X_train: Float[Array, 'N_train D'], alpha: Float[Array, 'N_train D_out'], K_inv: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ' D_out'], Float[Array, ' D_out']]

Multi-output uncertain GP prediction for BGPLVM.

Maps a latent variable z ~ N(mu, Sigma) to a high-dimensional reconstruction y using per-output GP weights:

mu_pred_d = Psi_1 @ alpha_d
var_pred_d = Psi_0 - tr(K_inv @ Psi_2)
           + alpha_d^T @ Psi_2 @ alpha_d - mu_pred_d^2

This intentionally uses the exact GP trace term for every output dimension. Unlike gaussx.uncertain_vgp_predict and gaussx.uncertain_svgp_predict, there is no separate variational covariance correction operator.

Parameters:

Name Type Description Default
kernel_fn Callable

Kernel function k(x, x') -> scalar.

required
X_train Float[Array, 'N_train D']

Training points, shape (N_train, D).

required
alpha Float[Array, 'N_train D_out']

Multi-output weights K^{-1} Y, shape (N_train, D_out).

required
K_inv AbstractLinearOperator

Linear operator representing the inverse of the training kernel matrix.

required
state GaussianState

Uncertain latent input z ~ N(mu, Sigma).

required
integrator AbstractIntegrator

Integration method.

required

Returns:

Type Description
tuple[Float[Array, ' D_out'], Float[Array, ' D_out']]

Tuple (mean, variance) — predictive moments each of shape (D_out,).

Source code in src/gaussx/_quadrature/_gp_predict.py
def uncertain_bgplvm_predict(
    kernel_fn: Callable,
    X_train: Float[Array, "N_train D"],
    alpha: Float[Array, "N_train D_out"],
    K_inv: lx.AbstractLinearOperator,
    state: GaussianState,
    integrator: AbstractIntegrator,
) -> tuple[Float[Array, " D_out"], Float[Array, " D_out"]]:
    r"""Multi-output uncertain GP prediction for BGPLVM.

    Maps a latent variable ``z ~ N(mu, Sigma)`` to a high-dimensional
    reconstruction ``y`` using per-output GP weights:

        mu_pred_d = Psi_1 @ alpha_d
        var_pred_d = Psi_0 - tr(K_inv @ Psi_2)
                   + alpha_d^T @ Psi_2 @ alpha_d - mu_pred_d^2

    This intentionally uses the exact GP trace term for every output dimension.
    Unlike `gaussx.uncertain_vgp_predict` and
    `gaussx.uncertain_svgp_predict`, there is no separate variational
    covariance correction operator.

    Args:
        kernel_fn: Kernel function ``k(x, x') -> scalar``.
        X_train: Training points, shape ``(N_train, D)``.
        alpha: Multi-output weights ``K^{-1} Y``, shape ``(N_train, D_out)``.
        K_inv: Inverse of training kernel matrix operator.
        state: Uncertain latent input ``z ~ N(mu, Sigma)``.
        integrator: Integration method.

    Returns:
        Tuple ``(mean, variance)`` — predictive moments each of shape ``(D_out,)``.
    """
    from gaussx._linalg._linalg import trace_product

    Psi_0, Psi_1, Psi_2 = kernel_expectations(kernel_fn, state, X_train, integrator)

    # mu_pred_d = Psi_1 @ alpha_d  for each output dimension
    mu_pred = Psi_1 @ alpha  # (D_out,)

    Psi_2_op = lx.MatrixLinearOperator(Psi_2)
    tr_term = trace_product(K_inv, Psi_2_op)

    # quad_term_d = alpha_d^T @ Psi_2 @ alpha_d  for each output
    quad_term = jnp.sum(alpha * (Psi_2 @ alpha), axis=0)  # (D_out,)

    var_pred = Psi_0 - tr_term + quad_term - mu_pred**2
    var_pred = jnp.clip(var_pred, 0.0)

    return mu_pred, var_pred
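
The per-output quadratic term is computed without a loop over output dimensions. This NumPy sketch, with random placeholder values for `alpha` and `Psi_2`, shows that the element-wise product followed by a column sum matches one quadratic form per output.

```python
import numpy as np

rng = np.random.default_rng(0)
N_train, D_out = 6, 3

alpha = rng.standard_normal((N_train, D_out))  # column d holds alpha_d
V = rng.standard_normal((N_train, N_train))
psi_2 = V @ V.T  # placeholder symmetric PSD matrix standing in for Psi_2

# Vectorized form: one matrix product, then a sum over the training axis.
quad_vectorized = np.sum(alpha * (psi_2 @ alpha), axis=0)  # (D_out,)

# Reference: one quadratic form alpha_d^T Psi_2 alpha_d per output.
quad_loop = np.array([alpha[:, d] @ psi_2 @ alpha[:, d] for d in range(D_out)])

# Equivalent einsum, and the diagonal of the full (D_out, D_out) product.
quad_einsum = np.einsum("nd,nm,md->d", alpha, psi_2, alpha)
quad_diag = np.diag(alpha.T @ psi_2 @ alpha)
```

The vectorized form avoids building the full `(D_out, D_out)` matrix, whose off-diagonal entries are not needed. The scalar terms `Psi_0` and the trace are shared by all outputs and broadcast against the `(D_out,)` vectors.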