Gaussian Processes

Layer 3 recipes for GP inference: conditioning, whitening, prediction caches, pathwise (Matheron) sampling, variational bounds, and cross-validation — all expressed over covariance operators so structured kernels keep their fast paths end to end. The modelling shell (kernels with hyperparameter priors, NumPyro sites) lives downstream; gaussx owns the math.

Conditioning & prediction

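This section provides the standard posterior

\[ \mu_* = m_* + K_{*f}\,K_y^{-1}(y - m), \qquad \Sigma_{**} = K_{**} - K_{*f}\,K_y^{-1}K_{f*}, \]

where \(m\) and \(m_*\) are the prior means at the training and test inputs and \(K_y = K_{ff} + \sigma^2 I\) is the noisy training covariance. It also provides a precomputed-cache variant for repeated test-time queries and a Kronecker-structured path for separable kernels on grids.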
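A minimal dense sketch of the two posterior formulas, assuming a zero prior mean, a unit-variance RBF kernel, and K_y = K_ff + σ²I. `rbf` and `gp_posterior` are hypothetical helpers for illustration, not gaussx API. One Cholesky factor of K_y serves both the mean and the covariance:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=0.5, variance=1.0):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return variance * jnp.exp(-0.5 * (d / lengthscale) ** 2)


def gp_posterior(x, y, x_star, noise_var=0.1):
    """Hypothetical dense zero-mean GP posterior using one Cholesky of K_y."""
    K_y = rbf(x, x) + noise_var * jnp.eye(x.shape[0])
    K_sf = rbf(x_star, x)                               # K_{*f}, shape (Nt, N)
    K_ss = rbf(x_star, x_star)                          # K_{**}, shape (Nt, Nt)
    L = jnp.linalg.cholesky(K_y)                        # K_y = L L^T
    alpha = jsla.cho_solve((L, True), y)                # K_y^{-1} y
    mean = K_sf @ alpha                                 # K_{*f} K_y^{-1} y
    V = jsla.solve_triangular(L, K_sf.T, lower=True)    # L^{-1} K_{f*}
    cov = K_ss - V.T @ V                                # K_{**} - K_{*f} K_y^{-1} K_{f*}
    return mean, cov


x = jnp.linspace(0.0, 1.0, 8)
y = jnp.sin(6.0 * x)
x_star = jnp.array([0.25, 0.5, 2.0])
mean, cov = gp_posterior(x, y, x_star)
```

Far from the data (x_star = 2.0) the predictive variance returns most of the way to the prior variance of 1, while inside [0, 1] it shrinks well below it.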

Structured linear algebra and Gaussian primitives for JAX.

PredictionCache

Bases: Module

Cached training solve for amortized predictions.

Stores alpha = K_y^{-1} y so that downstream predictions only require a matrix-vector product rather than a fresh solve.

Attributes:

  • alpha (Float[Array, ' N']): Solved weights K_y^{-1} y, shape (N,).

Source code in src/gaussx/_gp/_prediction_cache.py
class PredictionCache(eqx.Module):
    """Cached training solve for amortized predictions.

    Stores ``alpha = K_y^{-1} y`` so that downstream predictions only
    require a matrix-vector product rather than a fresh solve.

    Attributes:
        alpha: Solved weights ``K_y^{-1} y``, shape ``(N,)``.
    """

    alpha: Float[Array, " N"]

base_conditional(K_mm: Float[Array, 'M M'], K_mn: Float[Array, 'M N'], K_nn: Float[Array, 'N N'] | Float[Array, ' N'], f: Float[Array, 'M R'], *, q_sqrt: Float[Array, 'R M M'] | Float[Array, 'M R'] | None = None, white: bool = False, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'N R'], Float[Array, ...]]

Gaussian conditional distribution via Schur complement.

Computes the conditional distribution q(f_* | u) given:

  • Prior covariance K_mm at inducing locations
  • Cross-covariance K_mn between inducing and test locations
  • Prior (co)variance K_nn at test locations (full or diagonal)
  • Inducing function values f (or whitened values if white=True)
  • Optional variational posterior q(u) = N(f, q_sqrt q_sqrt^T)

The conditional mean is:

mu = K_nm K_mm^{-1} f   (or  K_nm L_mm^{-T} f  if white)

The conditional covariance is:

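Sigma = K_nn - K_nm K_mm^{-1} K_mn + K_nm K_mm^{-1} S K_mm^{-1} K_mn
(with the last term replaced by K_nm L_mm^{-T} S L_mm^{-1} K_mn if white)

where S = q_sqrt @ q_sqrt^T is the variational covariance and L_mm is the lower Cholesky factor of K_mm.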
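The two parametrizations describe the same q(u) when f = L_mm f_w and S = L_mm S_w L_mm^T. A small dense sketch (single output column; `conditional_white` and `conditional_unwhite` are hypothetical helpers, not part of gaussx) that evaluates both formulas and checks that they agree:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=0.3):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)


def conditional_unwhite(K_mm, K_mn, K_nn, f, S):
    """Hypothetical: mean K_nm K_mm^{-1} f, cov adds K_nm K_mm^{-1} S K_mm^{-1} K_mn."""
    P = jnp.linalg.solve(K_mm, K_mn).T                  # K_nm K_mm^{-1}, shape (N, M)
    return P @ f, K_nn - P @ K_mn + P @ S @ P.T


def conditional_white(K_mm, K_mn, K_nn, f_w, S_w):
    """Hypothetical whitened form: u = L_mm v with q(v) = N(f_w, S_w)."""
    L_mm = jnp.linalg.cholesky(K_mm)
    A = jsla.solve_triangular(L_mm, K_mn, lower=True)   # L_mm^{-1} K_mn, shape (M, N)
    return A.T @ f_w, K_nn - A.T @ A + A.T @ S_w @ A


z = jnp.linspace(0.0, 1.0, 5)              # M = 5 inducing inputs
x = jnp.linspace(-0.2, 1.2, 7)             # N = 7 test inputs
K_mm = rbf(z, z) + 1e-6 * jnp.eye(5)       # small jitter for a stable Cholesky
K_mn = rbf(z, x)
K_nn = rbf(x, x)

key_f, key_q = jax.random.split(jax.random.PRNGKey(0))
f_w = jax.random.normal(key_f, (5,))
q_sqrt_w = jnp.tril(0.3 * jax.random.normal(key_q, (5, 5)))
S_w = q_sqrt_w @ q_sqrt_w.T

# Map the whitened q(v) to the equivalent unwhitened q(u) = N(L f_w, L S_w L^T).
L_mm = jnp.linalg.cholesky(K_mm)
mean_w, cov_w = conditional_white(K_mm, K_mn, K_nn, f_w, S_w)
mean_u, cov_u = conditional_unwhite(K_mm, K_mn, K_nn, L_mm @ f_w, L_mm @ S_w @ L_mm.T)
```

Setting f_w = 0 and S_w = I in the whitened form recovers the prior covariance K_nn, since the -A^T A and +A^T S_w A terms cancel.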

Parameters:

  • K_mm (Float[Array, 'M M'], required): Prior covariance at inducing points, shape (M, M).
  • K_mn (Float[Array, 'M N'], required): Cross-covariance, shape (M, N).
  • K_nn (Float[Array, 'N N'] | Float[Array, ' N'], required): Test-point covariance. Full (N, N) or diagonal (N,).
  • f (Float[Array, 'M R'], required): Inducing function values, shape (M, R).
  • q_sqrt (Float[Array, 'R M M'] | Float[Array, 'M R'] | None, default None): Optional variational Cholesky factor. Full: (R, M, M), diagonal: (M, R), or None.
  • white (bool, default False): If True, f and q_sqrt are in whitened space (prior is N(0, I)).
  • solver (AbstractSolverStrategy | None, default None): Accepted for API consistency but currently unused: the Cholesky factorization in this function always uses structural dispatch and does not take a solver.

Returns:

  • tuple[Float[Array, 'N R'], Float[Array, ...]]: (mean, var), where mean has shape (N, R) and var has shape (N, N, R) (full K_nn) or (N, R) (diagonal K_nn).

Source code in src/gaussx/_gp/_base_conditional.py
def base_conditional(
    K_mm: Float[Array, "M M"],
    K_mn: Float[Array, "M N"],
    K_nn: Float[Array, "N N"] | Float[Array, " N"],
    f: Float[Array, "M R"],
    *,
    q_sqrt: Float[Array, "R M M"] | Float[Array, "M R"] | None = None,
    white: bool = False,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, "N R"], Float[Array, ...]]:
    r"""Gaussian conditional distribution via Schur complement.

    Computes the conditional distribution ``q(f_* | u)`` given:

    - Prior covariance ``K_mm`` at inducing locations
    - Cross-covariance ``K_mn`` between inducing and test locations
    - Prior (co)variance ``K_nn`` at test locations (full or diagonal)
    - Inducing function values ``f`` (or whitened values if ``white=True``)
    - Optional variational posterior ``q(u) = N(f, q_sqrt q_sqrt^T)``

    The conditional mean is:

        mu = K_nm K_mm^{-1} f   (or  K_nm L_mm^{-T} f  if white)

    The conditional covariance is:

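        Sigma = K_nn - K_nm K_mm^{-1} K_mn + K_nm K_mm^{-1} S K_mm^{-1} K_mn
        (with the last term replaced by K_nm L_mm^{-T} S L_mm^{-1} K_mn if white)

    where ``S = q_sqrt @ q_sqrt^T`` is the variational covariance and
    ``L_mm`` is the lower Cholesky factor of ``K_mm``.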

    Args:
        K_mm: Prior covariance at inducing points, shape ``(M, M)``.
        K_mn: Cross-covariance, shape ``(M, N)``.
        K_nn: Test-point covariance.  Full ``(N, N)`` or diagonal ``(N,)``.
        f: Inducing function values, shape ``(M, R)``.
        q_sqrt: Optional variational Cholesky factor.
            Full: ``(R, M, M)``, diagonal: ``(M, R)``, or ``None``.
        white: If ``True``, ``f`` and ``q_sqrt`` are in whitened space
            (prior is ``N(0, I)``).
        solver: Accepted for API consistency but currently unused: the
            Cholesky factorization in this function always uses structural
            dispatch and does not take a solver.

    Returns:
        ``(mean, var)`` where ``mean`` has shape ``(N, R)`` and ``var``
        has shape ``(N, N, R)`` (full K_nn) or ``(N, R)`` (diagonal K_nn).
    """
    del solver  # cholesky does not accept a solver; parameter reserved for future use
    R = f.shape[1]

    # Cholesky of prior
    L_mm = cholesky(  # (M, M)
        lx.MatrixLinearOperator(K_mm, lx.positive_semidefinite_tag)
    ).as_matrix()

    # A = L_mm^{-1} K_mn  ->  (M, N)
    A = jsla.solve_triangular(L_mm, K_mn, lower=True)

    # --- Conditional mean ---
    if white:
        mean = A.T @ f  # (N, R)
    else:
        alpha = jsla.solve_triangular(L_mm, f, lower=True)  # (M, R)
        mean = A.T @ alpha  # (N, R)

    # --- Conditional variance ---
    is_diag_knn = K_nn.ndim == 1

    # Prior variance reduction: K_nn - A^T A
    if is_diag_knn:
        prior_reduction = jnp.sum(A**2, axis=0)  # (N,)
        var_base = K_nn - prior_reduction  # (N,)
    else:
        var_base = K_nn - A.T @ A  # (N, N)

    if q_sqrt is not None:
        is_diag_q = q_sqrt.ndim == 2

        if is_diag_q:
            # q_sqrt: (M, R); column r holds the diagonal standard deviations s_r,
            # so S_r = diag(s_r**2).
            # White:     var_adj_r = diag(A^T diag(s_r**2) A)
            #                      = column sums of (s_r[:, None] * A)**2
            # Non-white: var_adj_r = diag(A^T C diag(s_r**2) C^T A), C = L_mm^{-1}
            #                      = column sums of ((C * s_r[None, :]).T @ A)**2

            if white:

                def _var_adj_diag_white(q_r: Float[Array, " M"]) -> Float[Array, " N"]:
                    A_scaled = q_r[:, None] * A  # (M, N)
                    return jnp.sum(A_scaled**2, axis=0)

                var_adj = jax.vmap(_var_adj_diag_white, in_axes=1, out_axes=1)(
                    q_sqrt
                )  # (N, R)
            else:
                # Precompute C = L_mm^{-1} once and share it across all R columns.
                # For each column s_r, C * s_r[None, :] = C diag(s_r), so
                # (C diag(s_r))^T A = diag(s_r) C^T A, and the variance
                # adjustment is the column-wise sum of squares of that (M, N)
                # matrix.
                C = jsla.solve_triangular(L_mm, jnp.eye(L_mm.shape[0]), lower=True)

                def _var_adj_diag_nonwhite(
                    q_r: Float[Array, " M"],
                ) -> Float[Array, " N"]:
                    # C_scaled[i,j] = C[i,j] * q_r[j]
                    C_scaled = C * q_r[None, :]  # (M, M)
                    A_scaled = C_scaled.T @ A  # (M, N)
                    return jnp.sum(A_scaled**2, axis=0)

                var_adj = jax.vmap(_var_adj_diag_nonwhite, in_axes=1, out_axes=1)(
                    q_sqrt
                )  # (N, R)

            if is_diag_knn:
                var = var_base[:, None] + var_adj  # (N, R)
            else:
                # Full K_nn: need full covariance adjustment per r
                # Recompute with full matrices
                if white:

                    def _var_full_diag_white(
                        q_r: Float[Array, " M"],
                    ) -> Float[Array, "N N"]:
                        A_scaled = q_r[:, None] * A
                        return var_base + A_scaled.T @ A_scaled

                    var = jax.vmap(_var_full_diag_white, in_axes=1, out_axes=-1)(
                        q_sqrt
                    )  # (N, N, R)
                else:

                    def _var_full_diag_nonwhite(
                        q_r: Float[Array, " M"],
                    ) -> Float[Array, "N N"]:
                        C_scaled = C * q_r[None, :]
                        A_scaled = C_scaled.T @ A
                        return var_base + A_scaled.T @ A_scaled

                    var = jax.vmap(_var_full_diag_nonwhite, in_axes=1, out_axes=-1)(
                        q_sqrt
                    )  # (N, N, R)
        else:
            # q_sqrt: (R, M, M) — full Cholesky factors
            if white:

                def _var_adj_full_white(
                    L_q: Float[Array, "M M"],
                ) -> Float[Array, " N"]:
                    A_scaled = L_q.T @ A  # (M, N)
                    return jnp.sum(A_scaled**2, axis=0)

                def _var_full_full_white(
                    L_q: Float[Array, "M M"],
                ) -> Float[Array, "N N"]:
                    A_scaled = L_q.T @ A
                    return var_base + A_scaled.T @ A_scaled

            else:

                def _var_adj_full_nonwhite(
                    L_q: Float[Array, "M M"],
                ) -> Float[Array, " N"]:
                    L_q_proj = jsla.solve_triangular(L_mm, L_q, lower=True)
                    A_scaled = L_q_proj.T @ A
                    return jnp.sum(A_scaled**2, axis=0)

                def _var_full_full_nonwhite(
                    L_q: Float[Array, "M M"],
                ) -> Float[Array, "N N"]:
                    L_q_proj = jsla.solve_triangular(L_mm, L_q, lower=True)
                    A_scaled = L_q_proj.T @ A
                    return var_base + A_scaled.T @ A_scaled

            if is_diag_knn:
                if white:
                    var_adj = jax.vmap(_var_adj_full_white)(q_sqrt)  # (R, N)
                else:
                    var_adj = jax.vmap(_var_adj_full_nonwhite)(q_sqrt)  # (R, N)
                var = var_base[None, :] + var_adj  # (R, N)
                var = var.T  # (N, R)
            else:
                if white:
                    var = jax.vmap(_var_full_full_white)(q_sqrt)  # (R, N, N)
                else:
                    var = jax.vmap(_var_full_full_nonwhite)(q_sqrt)  # (R, N, N)
                # Transpose from (R, N, N) to (N, N, R)
                var = rearrange(var, "R N1 N2 -> N1 N2 R")
    else:
        # No variational posterior — just prior conditional
        if is_diag_knn:
            var = repeat(var_base, "N -> N R", R=R)
        else:
            var = repeat(var_base, "N1 N2 -> N1 N2 R", R=R)

    return mean, var
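The non-whitened diagonal branch above never builds S_r = diag(s_r^2) explicitly. A standalone check of that trick for a single output column, comparing it with the explicit K_nm K_mm^{-1} S K_mm^{-1} K_mn term (the kernel, inputs, and the vector s are made-up illustrative values):

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=0.3):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)


z = jnp.linspace(0.0, 1.0, 5)
x = jnp.linspace(-0.2, 1.2, 7)
K_mm = rbf(z, z) + 1e-6 * jnp.eye(5)
K_mn = rbf(z, x)
s = jnp.array([0.5, 0.2, 1.0, 0.3, 0.7])      # diagonal q_sqrt for one output column

L_mm = jnp.linalg.cholesky(K_mm)
A = jsla.solve_triangular(L_mm, K_mn, lower=True)          # L_mm^{-1} K_mn
C = jsla.solve_triangular(L_mm, jnp.eye(5), lower=True)     # L_mm^{-1}

# Trick: scale the columns of C by s, project A, and take column sums of squares.
A_scaled = (C * s[None, :]).T @ A                           # diag(s) C^T A, shape (M, N)
var_adj_trick = jnp.sum(A_scaled**2, axis=0)                # shape (N,)

# Explicit term: diag(K_nm K_mm^{-1} S K_mm^{-1} K_mn) with S = diag(s**2).
P = jnp.linalg.solve(K_mm, K_mn).T                          # K_nm K_mm^{-1}
var_adj_explicit = jnp.diag(P @ jnp.diag(s**2) @ P.T)
```

C = L_mm^{-1} depends only on K_mm, so the source computes it once outside the vmap and reuses it for every output column.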

build_prediction_cache(operator: lx.AbstractLinearOperator, y: Float[Array, ' N'], *, solver: AbstractSolveStrategy | None = None) -> PredictionCache

Solve K_y alpha = y and cache the result.

Parameters:

  • operator (AbstractLinearOperator, required): Training covariance operator K_y, shape (N, N).
  • y (Float[Array, ' N'], required): Training targets, shape (N,).
  • solver (AbstractSolveStrategy | None, default None): Optional solve strategy. When None, falls back to structural-dispatch gaussx.solve.

Returns:

  • PredictionCache: A PredictionCache holding the solved weights.

Source code in src/gaussx/_gp/_prediction_cache.py
def build_prediction_cache(
    operator: lx.AbstractLinearOperator,
    y: Float[Array, " N"],
    *,
    solver: AbstractSolveStrategy | None = None,
) -> PredictionCache:
    """Solve ``A alpha = y`` and cache the result.

    Args:
        operator: Training covariance operator ``K_y``, shape ``(N, N)``.
        y: Training targets, shape ``(N,)``.
        solver: Optional solve strategy. When ``None``, falls back
            to structural-dispatch `gaussx.solve`.

    Returns:
        A `PredictionCache` holding the solved weights.
    """
    alpha = dispatch_solve(operator, y, solver)
    return PredictionCache(alpha=alpha)

predict_mean(cache: PredictionCache, K_cross: Float[Array, 'Nt N']) -> Float[Array, ' Nt']

Predictive mean: mu* = K_*f @ alpha.

Parameters:

  • cache (PredictionCache, required): Prediction cache from build_prediction_cache.
  • K_cross (Float[Array, 'Nt N'], required): Cross-covariance matrix, shape (Nt, N).

Returns:

  • Float[Array, ' Nt']: Predictive mean, shape (Nt,).

Source code in src/gaussx/_gp/_prediction_cache.py
def predict_mean(
    cache: PredictionCache,
    K_cross: Float[Array, "Nt N"],
) -> Float[Array, " Nt"]:
    """Predictive mean: ``mu* = K_*f @ alpha``.

    Args:
        cache: Prediction cache from `build_prediction_cache`.
        K_cross: Cross-covariance matrix, shape ``(Nt, N)``.

    Returns:
        Predictive mean, shape ``(Nt,)``.
    """
    return K_cross @ cache.alpha
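The cache pays off when the same training set serves many test batches: alpha is solved once, and each batch then costs one (Nt, N) matvec. A sketch of that pattern with a dense K_y and plain JAX (the kernel, data, and `predict_mean_dense` are illustrative assumptions; gaussx's own functions take a lineax operator instead):

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=0.4):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)


x = jnp.linspace(0.0, 1.0, 20)
y = jnp.cos(4.0 * x)
noise_var = 0.05
K_y = rbf(x, x) + noise_var * jnp.eye(20)

# One O(N^3) solve, analogous to build_prediction_cache: alpha = K_y^{-1} y.
alpha = jsla.cho_solve(jsla.cho_factor(K_y, lower=True), y)


@jax.jit
def predict_mean_dense(x_star):
    # Analogous to predict_mean: mu_* = K_{*f} @ alpha, with no further solves.
    return rbf(x_star, x) @ alpha


batches = [jnp.linspace(0.0, 0.5, 5), jnp.linspace(0.5, 1.0, 7)]
means = [predict_mean_dense(b) for b in batches]
```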

predict_variance(K_cross: Float[Array, 'Nt N'], K_test_diag: Float[Array, ' Nt'], operator: lx.AbstractLinearOperator, *, solver: AbstractSolveStrategy | None = None) -> Float[Array, ' Nt']

Predictive variance: sigma^2* = k_** - diag(K_*f K_y^{-1} K_f*).

For each test point i, solves K_y v_i = K_cross[i, :] and computes sigma^2_i = K_test_diag[i] - K_cross[i, :] @ v_i.

Parameters:

  • K_cross (Float[Array, 'Nt N'], required): Cross-covariance matrix, shape (Nt, N).
  • K_test_diag (Float[Array, ' Nt'], required): Prior variance at test points, shape (Nt,).
  • operator (AbstractLinearOperator, required): Training covariance operator K_y, shape (N, N).
  • solver (AbstractSolveStrategy | None, default None): Optional solve strategy. When None, falls back to structural-dispatch gaussx.solve.

Returns:

  • Float[Array, ' Nt']: Predictive variance, shape (Nt,).

Source code in src/gaussx/_gp/_prediction_cache.py
def predict_variance(
    K_cross: Float[Array, "Nt N"],
    K_test_diag: Float[Array, " Nt"],
    operator: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, " Nt"]:
    """Predictive variance: ``sigma^2* = k_** - diag(K_*f K_y^{-1} K_f*)``.

    For each test point *i*, solves ``K_y v_i = K_cross[i, :]`` and
    computes ``sigma^2_i = K_test_diag[i] - K_cross[i, :] @ v_i``.

    Args:
        K_cross: Cross-covariance matrix, shape ``(Nt, N)``.
        K_test_diag: Prior variance at test points, shape ``(Nt,)``.
        operator: Training covariance operator ``K_y``, shape ``(N, N)``.
        solver: Optional solve strategy. When ``None``, falls back
            to structural-dispatch `gaussx.solve`.

    Returns:
        Predictive variance, shape ``(Nt,)``.
    """

    from gaussx._linalg._linalg import solve_rows

    V = solve_rows(operator, K_cross, solver=solver)
    return K_test_diag - reduce(K_cross * V, "N M -> N", "sum")

conditional_interpolate(A_fwd: Float[Array, 'd d'], Q_fwd: Float[Array, 'd d'], A_bwd: Float[Array, 'd d'], Q_bwd: Float[Array, 'd d'], mu_prev: Float[Array, ' d'], P_prev: Float[Array, 'd d'], mu_next: Float[Array, ' d'], P_next: Float[Array, 'd d'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' d'], Float[Array, 'd d']]

Interpolated marginal at time t given posteriors at t^- and t^+.

For an SDE-discretized state-space model:

x_t     | x_{t^-} \sim N(A_{fwd} x_{t^-}, Q_{fwd})
x_{t^+} | x_t     \sim N(A_{bwd} x_t,     Q_{bwd})

computes p(x_t | x_{t^-}, x_{t^+}) using information fusion of the forward and backward predictions:

m_{fwd} = A_{fwd} \mu_{prev}
\Lambda_{fwd} = (A_{fwd} P_{prev} A_{fwd}^T + Q_{fwd})^{-1}
\Lambda_{bwd} = A_{bwd}^T (P_{next} + Q_{bwd})^{-1} A_{bwd}
\eta_{1,fwd} = \Lambda_{fwd} m_{fwd}
\eta_{1,bwd} = A_{bwd}^T (P_{next} + Q_{bwd})^{-1} \mu_{next}
\Lambda = \Lambda_{fwd} + \Lambda_{bwd}
P = \Lambda^{-1}
\mu = P (\eta_{1,fwd} + \eta_{1,bwd})

Parameters:

  • A_fwd (Float[Array, 'd d'], required): Forward transition from t^- to t, shape (d, d).
  • Q_fwd (Float[Array, 'd d'], required): Forward process noise, shape (d, d).
  • A_bwd (Float[Array, 'd d'], required): Backward transition from t to t^+, shape (d, d).
  • Q_bwd (Float[Array, 'd d'], required): Backward process noise, shape (d, d).
  • mu_prev (Float[Array, ' d'], required): Marginal mean at t^-, shape (d,).
  • P_prev (Float[Array, 'd d'], required): Marginal covariance at t^-, shape (d, d).
  • mu_next (Float[Array, ' d'], required): Marginal mean at t^+, shape (d,).
  • P_next (Float[Array, 'd d'], required): Marginal covariance at t^+, shape (d, d).
  • solver (AbstractSolverStrategy | None, default None): Optional solver strategy for structured linear algebra. When None, falls back to structural dispatch.

Returns:

  • tuple[Float[Array, ' d'], Float[Array, 'd d']]: Tuple (mean, cov), the interpolated marginal at t.

Source code in src/gaussx/_gp/_interpolation.py
def conditional_interpolate(
    A_fwd: Float[Array, "d d"],
    Q_fwd: Float[Array, "d d"],
    A_bwd: Float[Array, "d d"],
    Q_bwd: Float[Array, "d d"],
    mu_prev: Float[Array, " d"],
    P_prev: Float[Array, "d d"],
    mu_next: Float[Array, " d"],
    P_next: Float[Array, "d d"],
    *,
    solver: AbstractSolverStrategy | None = None,
) -> tuple[Float[Array, " d"], Float[Array, "d d"]]:
    r"""Interpolated marginal at time ``t`` given posteriors at ``t^-`` and ``t^+``.

    For an SDE-discretized state-space model:

        x_t     | x_{t^-} \sim N(A_{fwd} x_{t^-}, Q_{fwd})
        x_{t^+} | x_t     \sim N(A_{bwd} x_t,     Q_{bwd})

    computes ``p(x_t | x_{t^-}, x_{t^+})`` using information fusion
    of the forward and backward predictions:

        m_{fwd} = A_{fwd} \mu_{prev}
        \Lambda_{fwd} = (A_{fwd} P_{prev} A_{fwd}^T + Q_{fwd})^{-1}
        \Lambda_{bwd} = A_{bwd}^T (P_{next} + Q_{bwd})^{-1} A_{bwd}
        \eta_{1,fwd} = \Lambda_{fwd} m_{fwd}
        \eta_{1,bwd} = A_{bwd}^T (P_{next} + Q_{bwd})^{-1} \mu_{next}
        \Lambda = \Lambda_{fwd} + \Lambda_{bwd}
        P = \Lambda^{-1}
        \mu = P (\eta_{1,fwd} + \eta_{1,bwd})

    Args:
        A_fwd: Forward transition from ``t^-`` to ``t``, shape ``(d, d)``.
        Q_fwd: Forward process noise, shape ``(d, d)``.
        A_bwd: Backward transition from ``t`` to ``t^+``, shape ``(d, d)``.
        Q_bwd: Backward process noise, shape ``(d, d)``.
        mu_prev: Marginal mean at ``t^-``, shape ``(d,)``.
        P_prev: Marginal covariance at ``t^-``, shape ``(d, d)``.
        mu_next: Marginal mean at ``t^+``, shape ``(d,)``.
        P_next: Marginal covariance at ``t^+``, shape ``(d, d)``.
        solver: Optional solver strategy for structured linear algebra.
            When ``None``, falls back to structural dispatch.

    Returns:
        Tuple ``(mean, cov)`` — interpolated marginal at ``t``.
    """
    # Forward prediction to t
    m_fwd = A_fwd @ mu_prev
    P_fwd = A_fwd @ P_prev @ A_fwd.T + Q_fwd

    # Forward information
    P_fwd_op = lx.MatrixLinearOperator(P_fwd, lx.positive_semidefinite_tag)
    Lambda_fwd = inv(P_fwd_op).as_matrix()
    eta1_fwd = dispatch_solve(P_fwd_op, m_fwd, solver)

    # Backward information from t+
    S_bwd = P_next + Q_bwd
    S_bwd_op = lx.MatrixLinearOperator(S_bwd, lx.positive_semidefinite_tag)
    Lambda_bwd = A_bwd.T @ solve_columns(S_bwd_op, A_bwd, solver=solver)
    eta1_bwd = A_bwd.T @ dispatch_solve(S_bwd_op, mu_next, solver)

    # Fuse forward and backward
    Lambda = Lambda_fwd + Lambda_bwd
    Lambda_op = lx.MatrixLinearOperator(Lambda, lx.positive_semidefinite_tag)
    P = inv(Lambda_op).as_matrix()
    m = P @ (eta1_fwd + eta1_bwd)

    return m, P
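Information fusion should agree with a moment-form Kalman update that treats mu_next as a noisy observation of A_bwd x_t with noise covariance P_next + Q_bwd. A dense sketch of both routes on a 2-D constant-velocity model (the model, the numbers, and the helper names `fuse` and `kalman_update` are illustrative assumptions, not gaussx API):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fuse(A_fwd, Q_fwd, A_bwd, Q_bwd, mu_prev, P_prev, mu_next, P_next):
    """Hypothetical information-form fusion, following the docstring equations."""
    m_fwd = A_fwd @ mu_prev
    P_fwd = A_fwd @ P_prev @ A_fwd.T + Q_fwd
    S_bwd = P_next + Q_bwd
    Lam = jnp.linalg.inv(P_fwd) + A_bwd.T @ jnp.linalg.solve(S_bwd, A_bwd)
    eta = jnp.linalg.solve(P_fwd, m_fwd) + A_bwd.T @ jnp.linalg.solve(S_bwd, mu_next)
    P = jnp.linalg.inv(Lam)
    return P @ eta, P


def kalman_update(A_fwd, Q_fwd, A_bwd, Q_bwd, mu_prev, P_prev, mu_next, P_next):
    """Hypothetical moment form: predict to t, then condition on mu_next."""
    m_fwd = A_fwd @ mu_prev
    P_fwd = A_fwd @ P_prev @ A_fwd.T + Q_fwd
    S = A_bwd @ P_fwd @ A_bwd.T + P_next + Q_bwd       # innovation covariance
    K = jnp.linalg.solve(S, A_bwd @ P_fwd).T            # gain P_fwd A_bwd^T S^{-1}
    return m_fwd + K @ (mu_next - A_bwd @ m_fwd), P_fwd - K @ A_bwd @ P_fwd


dt = 0.1
A = jnp.array([[1.0, dt], [0.0, 1.0]])                            # constant-velocity transition
Q = 0.01 * jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]])    # matching process noise
mu_prev, P_prev = jnp.array([0.0, 1.0]), jnp.array([[0.2, 0.05], [0.05, 0.1]])
mu_next, P_next = jnp.array([0.25, 0.9]), jnp.array([[0.15, 0.02], [0.02, 0.12]])

m1, P1 = fuse(A, Q, A, Q, mu_prev, P_prev, mu_next, P_next)
m2, P2 = kalman_update(A, Q, A, Q, mu_prev, P_prev, mu_next, P_next)
```

The information form costs a few extra inverses for small d, but it makes the symmetric roles of the forward and backward messages explicit, which is the structure the function above exploits.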

kronecker_posterior_predictive(K_factors: list[lx.AbstractLinearOperator], y: Float[Array, ' N'], noise_var: float, grid_shape: tuple[int, ...], K_cross_factors: list[Float[Array, 'Ni_test Ni_train']], *, K_test_diag_factors: list[Float[Array, ' Ni_test']]) -> tuple[Float[Array, ' N_test'], Float[Array, ' N_test']]

Posterior mean and variance for a Kronecker GP at test points.

Uses the eigendecomposition trick: projects cross-covariances onto the eigenbasis and weights by inverse eigenvalues:

mu_* = K_{*f} (K_{ff} + sigma^2 I)^{-1} y
var_* = k_{**} - K_{*f} (K_{ff} + sigma^2 I)^{-1} K_{f*}

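Both are computed from per-factor eigendecompositions, costing O(sum n_i^3), followed by Kronecker matvecs, so neither the N x N training covariance nor the N_test x N cross-covariance is ever formed.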

Parameters:

  • K_factors (list[AbstractLinearOperator], required): List of per-dimension training kernel operators.
  • y (Float[Array, ' N'], required): Observations, shape (N,) where N = prod(grid_shape).
  • noise_var (float, required): Observation noise variance sigma^2.
  • grid_shape (tuple[int, ...], required): Grid shape, e.g. (n1, n2).
  • K_cross_factors (list[Float[Array, 'Ni_test Ni_train']], required): Per-dimension cross-covariance matrices, each shape (n_i_test, n_i_train).
  • K_test_diag_factors (list[Float[Array, ' Ni_test']], required): Per-dimension prior diagonals at the test points, each shape (n_i_test,).

Returns:

  • tuple[Float[Array, ' N_test'], Float[Array, ' N_test']]: Tuple (mean, variance) at test points.

Source code in src/gaussx/_gp/_kronecker_gp.py
def kronecker_posterior_predictive(
    K_factors: list[lx.AbstractLinearOperator],
    y: Float[Array, " N"],
    noise_var: float,
    grid_shape: tuple[int, ...],
    K_cross_factors: list[Float[Array, "Ni_test Ni_train"]],
    *,
    K_test_diag_factors: list[Float[Array, " Ni_test"]],
) -> tuple[Float[Array, " N_test"], Float[Array, " N_test"]]:
    r"""Posterior mean and variance for a Kronecker GP at test points.

    Uses the eigendecomposition trick: projects cross-covariances onto
    the eigenbasis and weights by inverse eigenvalues:

        mu_* = K_{*f} (K_{ff} + sigma^2 I)^{-1} y
        var_* = k_{**} - K_{*f} (K_{ff} + sigma^2 I)^{-1} K_{f*}

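    Both are computed from per-factor eigendecompositions, costing
    ``O(sum n_i^3)``, followed by Kronecker matvecs, so neither the ``N x N``
    training covariance nor the ``N_test x N`` cross-covariance is formed.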

    Args:
        K_factors: List of per-dimension training kernel operators.
        y: Observations, shape ``(N,)`` where ``N = prod(grid_shape)``.
        noise_var: Observation noise variance ``sigma^2``.
        grid_shape: Grid shape, e.g. ``(n1, n2)``.
        K_cross_factors: Per-dimension cross-covariance matrices,
            each shape ``(n_i_test, n_i_train)``.
        K_test_diag_factors: Per-dimension prior diagonals at the test points,
            each shape ``(n_i_test,)``.

    Returns:
        Tuple ``(mean, variance)`` at test points.
    """
    if len(K_factors) != len(grid_shape):
        msg = "grid_shape must have one entry per Kronecker factor"
        raise ValueError(msg)
    if len(K_cross_factors) != len(K_factors):
        msg = "K_cross_factors must have one matrix per Kronecker factor"
        raise ValueError(msg)
    if len(K_test_diag_factors) != len(K_factors):
        msg = "K_test_diag_factors must have one vector per Kronecker factor"
        raise ValueError(msg)

    # Per-factor eigendecomposition of training kernels
    factor_eigs = [eig(K_i) for K_i in K_factors]
    all_vals = [vals for vals, _ in factor_eigs]
    all_vecs = [vecs for _, vecs in factor_eigs]

    # Combined training eigenvalues
    combined_vals = ft.reduce(jnp.kron, all_vals)
    inv_noisy_vals = 1.0 / (combined_vals + noise_var)

    # Rotate observations: alpha = Q^T y
    alpha = _kron_rotate(all_vecs, y, grid_shape)

    # Weights in eigenbasis: w = diag(1/(lambda + sigma^2)) Q^T y
    w = inv_noisy_vals * alpha

    # Per-factor cross-covariance projected onto eigenbasis: A_i = K_cross_i @ Q_i
    A_factors = [K_cross_factors[i] @ all_vecs[i] for i in range(len(K_factors))]

    # Posterior mean: mu_* = (A_1 kron A_2 kron ...) w
    # Mean: project back from eigenbasis
    mean = _kron_matvec(A_factors, w, grid_shape)

    # Variance: k_** - sum_j A[:, j]^2 / (lambda_j + sigma^2), where
    # A = A_1 kron A_2 kron ...; since the elementwise square of a Kronecker
    # product is the Kronecker product of the squares, this is a Kronecker matvec.
    A_sq_factors = [A**2 for A in A_factors]
    var_reduction = _kron_matvec(A_sq_factors, inv_noisy_vals, grid_shape)

    # Prior diagonal at test points: diag(K_**) = kron_i diag(K_i(test, test))
    K_test_prior = ft.reduce(jnp.kron, K_test_diag_factors)

    variance = jnp.clip(K_test_prior - var_reduction, 0.0)

    return mean, variance
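A dense cross-check of the eigenbasis computation for two factors, assuming row-major flattening of the grid so that the ordering matches jnp.kron. `kron_matvec` and `kron_predict` are hypothetical stand-ins for the private helpers used above, whose exact conventions this page does not show:

```python
import functools as ft

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)


def kron_matvec(factors, v):
    """Hypothetical: (F_1 kron F_2 kron ...) @ v for row-major v, never forming the product."""
    X = v.reshape([F.shape[1] for F in factors])
    for axis, F in enumerate(factors):
        # Contract F's columns with this grid axis, then move the new axis back in place.
        X = jnp.moveaxis(jnp.tensordot(F, X, axes=([1], [axis])), 0, axis)
    return X.reshape(-1)


def kron_predict(K_factors, K_cross_factors, K_test_diag_factors, y, noise_var):
    """Hypothetical: posterior mean and marginal variance via per-factor eigh."""
    eigs = [jnp.linalg.eigh(K) for K in K_factors]
    vecs = [Q for _, Q in eigs]
    inv_noisy = 1.0 / (ft.reduce(jnp.kron, [lam for lam, _ in eigs]) + noise_var)
    w = inv_noisy * kron_matvec([Q.T for Q in vecs], y)    # diag(1/(lambda + s2)) Q^T y
    A = [Kc @ Q for Kc, Q in zip(K_cross_factors, vecs)]    # K_cross_i Q_i
    mean = kron_matvec(A, w)
    reduction = kron_matvec([a**2 for a in A], inv_noisy)   # diag(K_*f K_y^{-1} K_f*)
    return mean, ft.reduce(jnp.kron, K_test_diag_factors) - reduction


x1, x2 = jnp.linspace(0.0, 1.0, 4), jnp.linspace(0.0, 2.0, 5)     # 4 x 5 training grid
t1, t2 = jnp.array([0.3, 0.9]), jnp.array([0.5, 1.1, 1.7])       # 2 x 3 test grid
K1, K2 = rbf(x1, x1, 0.5), rbf(x2, x2, 0.8)
Kc1, Kc2 = rbf(t1, x1, 0.5), rbf(t2, x2, 0.8)
y = jnp.sin(jnp.arange(20.0))        # arbitrary observations on the grid
noise_var = 0.1
mean, var = kron_predict([K1, K2], [Kc1, Kc2], [jnp.ones(2), jnp.ones(3)], y, noise_var)
```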

kronecker_mll(K_factors: list[lx.AbstractLinearOperator], y: Float[Array, ' N'], noise_var: float, grid_shape: tuple[int, ...]) -> Float[Array, '']

Exact marginal log-likelihood for a Kronecker-structured GP.

For a GP with covariance K + sigma^2 I, where K = K_1 \otimes K_2 \otimes \ldots, computes the log marginal likelihood via per-factor eigendecomposition:

log p(y) = -0.5 * (y^T (K + sigma^2 I)^{-1} y
                   + log|K + sigma^2 I|
                   + N log(2 pi))

The Kronecker eigendecomposition avoids forming the full N x N matrix: if K_i = Q_i Lambda_i Q_i^T, the combined eigenvalues are the outer products of per-factor eigenvalues and the combined eigenvectors are the Kronecker product of per-factor eigenvectors.

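Complexity: O(sum n_i^3 + N sum n_i) instead of O(N^3), where N = prod n_i: the eigendecompositions cost O(sum n_i^3), and rotating y into the eigenbasis is a Kronecker matvec costing O(N sum n_i).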

Parameters:

  • K_factors (list[AbstractLinearOperator], required): List of per-dimension kernel operators. Each must be square and symmetric.
  • y (Float[Array, ' N'], required): Observations, shape (N,) where N = prod(grid_shape).
  • noise_var (float, required): Observation noise variance sigma^2.
  • grid_shape (tuple[int, ...], required): Shape of the grid, e.g. (n1, n2) for 2D.

Returns:

  • Float[Array, '']: Scalar log marginal likelihood.

Source code in src/gaussx/_gp/_kronecker_gp.py
def kronecker_mll(
    K_factors: list[lx.AbstractLinearOperator],
    y: Float[Array, " N"],
    noise_var: float,
    grid_shape: tuple[int, ...],
) -> Float[Array, ""]:
    r"""Exact marginal log-likelihood for a Kronecker-structured GP.

    For a GP with covariance ``K + sigma^2 I``, where
    ``K = K_1 \otimes K_2 \otimes \ldots``, computes the log marginal
    likelihood via per-factor eigendecomposition:

        log p(y) = -0.5 * (y^T (K + sigma^2 I)^{-1} y
                           + log|K + sigma^2 I|
                           + N log(2 pi))

    The Kronecker eigendecomposition avoids forming the full ``N x N`` matrix:
    if ``K_i = Q_i Lambda_i Q_i^T``, the combined eigenvalues are the outer
    products of per-factor eigenvalues and the combined eigenvectors are the
    Kronecker product of per-factor eigenvectors.

    Complexity: ``O(sum n_i^3 + N sum n_i)`` instead of ``O(N^3)`` where
    ``N = prod n_i``; the ``N sum n_i`` term is the Kronecker matvec ``Q^T y``.

    Args:
        K_factors: List of per-dimension kernel operators. Each must be
            square and symmetric.
        y: Observations, shape ``(N,)`` where ``N = prod(grid_shape)``.
        noise_var: Observation noise variance ``sigma^2``.
        grid_shape: Shape of the grid, e.g. ``(n1, n2)`` for 2D.

    Returns:
        Scalar log marginal likelihood.
    """
    N = y.shape[0]

    # Per-factor eigendecomposition
    factor_eigs = [eig(K_i) for K_i in K_factors]
    all_vals = [vals for vals, _ in factor_eigs]
    all_vecs = [vecs for _, vecs in factor_eigs]

    # Combined eigenvalues: outer product of per-factor eigenvalues
    combined_vals = ft.reduce(jnp.kron, all_vals)  # (N,)
    noisy_vals = combined_vals + noise_var  # (N,)

    # Rotate data into eigenbasis: alpha = Q^T y
    # Q = Q_1 kron Q_2 kron ..., so Q^T y can be computed factor-by-factor
    alpha = _kron_rotate(all_vecs, y, grid_shape)

    # MLL in eigenbasis: all operations are O(N)
    data_fit = jnp.sum(alpha**2 / noisy_vals)
    log_det = jnp.sum(jnp.log(noisy_vals))
    const = N * jnp.log(2.0 * jnp.pi)

    return -0.5 * (data_fit + log_det + const)
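Because every term lives in the eigenbasis, the result can be checked against a dense Gaussian log-density on a small grid. A sketch under a row-major flattening assumption (`kron_matvec` and `kron_mll` are hypothetical names, not the private helpers used above; jax.scipy.stats.multivariate_normal supplies the dense reference):

```python
import functools as ft

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)


def kron_matvec(factors, v):
    """Hypothetical: (F_1 kron F_2 kron ...) @ v for row-major v, never forming the product."""
    X = v.reshape([F.shape[1] for F in factors])
    for axis, F in enumerate(factors):
        X = jnp.moveaxis(jnp.tensordot(F, X, axes=([1], [axis])), 0, axis)
    return X.reshape(-1)


def kron_mll(K_factors, y, noise_var):
    """Hypothetical: log N(y | 0, K_1 kron K_2 kron ... + noise_var * I) via eigh."""
    eigs = [jnp.linalg.eigh(K) for K in K_factors]
    noisy = ft.reduce(jnp.kron, [vals for vals, _ in eigs]) + noise_var
    alpha = kron_matvec([vecs.T for _, vecs in eigs], y)       # Q^T y
    data_fit = jnp.sum(alpha**2 / noisy)
    log_det = jnp.sum(jnp.log(noisy))
    return -0.5 * (data_fit + log_det + y.shape[0] * jnp.log(2.0 * jnp.pi))


x1, x2 = jnp.linspace(0.0, 1.0, 4), jnp.linspace(0.0, 2.0, 5)
K1, K2 = rbf(x1, x1, 0.5), rbf(x2, x2, 0.8)
y = jnp.sin(jnp.arange(20.0))        # arbitrary observations on the 4 x 5 grid
noise_var = 0.1
mll = kron_mll([K1, K2], y, noise_var)
```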

Pathwise sampling

Matheron's rule turns joint prior draws \((a, b)\) into posterior draws: \(a + \mathrm{Cov}(a,b)\,\mathrm{Cov}(b,b)^{-1}(\beta - b)\).


matheron_update(prior_sample_target: Float[Array, 'S N_star'], prior_sample_conditioning: Float[Array, 'S M'], observed_value: Float[Array, ' M'], cross_covariance: lx.AbstractLinearOperator, conditioning_covariance: lx.AbstractLinearOperator, *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'S N_star']

Posterior samples via Matheron's-rule correction.

Given joint prior draws (a, b) and an observed conditioning value β, Matheron's rule samples from a | b = β by applying

\[ a + \operatorname{Cov}(a, b)\operatorname{Cov}(b, b)^{-1}(β - b). \]

This helper keeps both covariance arguments as lineax operators, so the conditioning solve uses the existing GaussX structural dispatch and the target correction is a rectangular matvec.
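A small dense sketch of pathwise sampling: draw joint prior samples of (a, b) from one Cholesky factor, apply the correction, and check the samples against the analytic posterior. The kernel, inputs, β, and the helper `matheron` are illustrative assumptions rather than gaussx API:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, lengthscale=0.3):
    """Squared-exponential kernel on 1-D inputs (illustrative choice)."""
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * (d / lengthscale) ** 2)


def matheron(a, b, beta, K_ab, K_bb):
    """Hypothetical: a + K_ab K_bb^{-1} (beta - b), applied to each sample row."""
    resid = beta[None, :] - b                              # (S, M)
    return a + jnp.linalg.solve(K_bb, resid.T).T @ K_ab.T  # (S, N_star)


x_cond = jnp.array([0.1, 0.4, 0.8])     # M = 3 conditioning inputs
x_tgt = jnp.array([0.25, 0.6, 1.2])     # N_star = 3 target inputs
beta = jnp.array([0.5, -0.3, 1.0])      # observed values at x_cond

# Joint prior draws of (a, b) from a single Cholesky factor of the stacked kernel.
x_all = jnp.concatenate([x_tgt, x_cond])
L = jnp.linalg.cholesky(rbf(x_all, x_all) + 1e-9 * jnp.eye(6))
joint = jax.random.normal(jax.random.PRNGKey(0), (20000, 6)) @ L.T   # rows ~ N(0, K)
a, b = joint[:, :3], joint[:, 3:]

K_ab = rbf(x_tgt, x_cond)
K_bb = rbf(x_cond, x_cond)
post = matheron(a, b, beta, K_ab, K_bb)

# Conditioning the b-samples on themselves (a = b) must return beta exactly.
pinned = matheron(b, b, beta, K_bb, K_bb)
```

Conditioning b on itself returns β for every sample, as noise-free conditioning should, and the empirical moments of `post` approach the analytic posterior as S grows.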

Parameters:

  • prior_sample_target (Float[Array, 'S N_star'], required): Prior samples at target points, shape (S, N_star).
  • prior_sample_conditioning (Float[Array, 'S M'], required): Joint prior samples at conditioning points, shape (S, M).
  • observed_value (Float[Array, ' M'], required): Observed conditioning value, shape (M,).
  • cross_covariance (AbstractLinearOperator, required): Cross-covariance operator Cov(a, b), shape (N_star, M).
  • conditioning_covariance (AbstractLinearOperator, required): Conditioning covariance operator Cov(b, b), shape (M, M).
  • solver (AbstractSolveStrategy | None, default None): Optional solver strategy for the conditioning solve (e.g. gaussx.CGSolver, gaussx.BBMMSolver). When None, routes through structural dispatch on conditioning_covariance.

Returns:

  • Float[Array, 'S N_star']: Corrected posterior samples, shape (S, N_star).

Source code in src/gaussx/_gp/_matheron.py
def matheron_update(
    prior_sample_target: Float[Array, "S N_star"],
    prior_sample_conditioning: Float[Array, "S M"],
    observed_value: Float[Array, " M"],
    cross_covariance: lx.AbstractLinearOperator,
    conditioning_covariance: lx.AbstractLinearOperator,
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, "S N_star"]:
    r"""Posterior samples via Matheron's-rule correction.

    Given joint prior draws ``(a, b)`` and an observed conditioning value
    ``β``, Matheron's rule samples from ``a | b = β`` by applying

    $$
    a + \operatorname{Cov}(a, b)\operatorname{Cov}(b, b)^{-1}(β - b).
    $$

    This helper keeps both covariance arguments as lineax operators, so the
    conditioning solve uses the existing GaussX structural dispatch and the
    target correction is a rectangular matvec.

    Args:
        prior_sample_target: Prior samples at target points, shape
            ``(S, N_star)``.
        prior_sample_conditioning: Joint prior samples at conditioning
            points, shape ``(S, M)``.
        observed_value: Observed conditioning value, shape ``(M,)``.
        cross_covariance: Cross-covariance operator ``Cov(a, b)``, shape
            ``(N_star, M)``.
        conditioning_covariance: Conditioning covariance operator
            ``Cov(b, b)``, shape ``(M, M)``.
        solver: Optional solver strategy for the conditioning solve
            (e.g. `gaussx.CGSolver`, `gaussx.BBMMSolver`).
            When ``None``, routes through structural dispatch on
            ``conditioning_covariance``.

    Returns:
        Corrected posterior samples, shape ``(S, N_star)``.
    """
    prior_sample_target = jnp.asarray(prior_sample_target)
    prior_sample_conditioning = jnp.asarray(prior_sample_conditioning)
    observed_value = jnp.asarray(observed_value)
    dtype = jnp.result_type(
        prior_sample_target,
        prior_sample_conditioning,
        observed_value,
    )
    # Promote sample arrays and observations together so the residual,
    # structured solve, and matvec all use the same precision.
    prior_sample_target = prior_sample_target.astype(dtype)
    prior_sample_conditioning = prior_sample_conditioning.astype(dtype)
    observed_value = observed_value.astype(dtype)

    if prior_sample_target.ndim != 2:
        raise ValueError("prior_sample_target must have shape (S, N_star).")
    if prior_sample_conditioning.ndim != 2:
        raise ValueError("prior_sample_conditioning must have shape (S, M).")
    if observed_value.ndim != 1:
        raise ValueError("observed_value must have shape (M,).")
    if prior_sample_target.shape[0] != prior_sample_conditioning.shape[0]:
        raise ValueError("prior samples must have the same sample dimension S.")

    num_conditioning = prior_sample_conditioning.shape[1]
    num_target = prior_sample_target.shape[1]
    if observed_value.shape[0] != num_conditioning:
        raise ValueError("observed_value must match the conditioning dimension M.")
    if conditioning_covariance.in_size() != num_conditioning:
        raise ValueError(
            "conditioning_covariance input size must match the conditioning "
            "dimension M."
        )
    if conditioning_covariance.out_size() != num_conditioning:
        raise ValueError(
            "conditioning_covariance output size must match the conditioning "
            "dimension M."
        )
    if cross_covariance.in_size() != num_conditioning:
        raise ValueError(
            "cross_covariance input size must match the conditioning dimension M."
        )
    if cross_covariance.out_size() != num_target:
        raise ValueError(
            "cross_covariance output size must match the target dimension N_star."
        )

    residuals = observed_value[None, :] - prior_sample_conditioning
    if solver is not None:
        solve_one = lambda r: dispatch_solve(conditioning_covariance, r, solver)
    else:
        solve_one = lambda r: solve(conditioning_covariance, r)
    solves = jax.vmap(solve_one)(residuals)
    corrections = jax.vmap(cross_covariance.mv)(solves)
    return prior_sample_target + corrections
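A dense NumPy sketch of the same correction for exact GP regression, where the conditioning variable is \(b = f(X) + \varepsilon\) and \(\beta = y\). The `rbf` kernel, the helper `matheron_update_dense`, and the toy data are illustrative assumptions, and `np.linalg.solve` stands in for the structured solve. A zero prior draw is mapped exactly to the posterior mean. With many draws, the empirical mean and covariance of the corrected samples should approach the analytic posterior.

```python
import numpy as np


def rbf(x1, x2, lengthscale=1.0):
    # Squared-exponential kernel on 1-D inputs (illustrative choice).
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def matheron_update_dense(a, b, beta, cov_ab, cov_bb):
    """a + Cov(a, b) Cov(b, b)^{-1} (beta - b), applied to every sample row.

    a: (S, N_star) prior draws at targets; b: (S, M) joint draws at the
    conditioning points; beta: (M,) observed value.
    """
    residuals = beta[None, :] - b                  # (S, M)
    solves = np.linalg.solve(cov_bb, residuals.T)  # (M, S): S right-hand sides
    return a + (cov_ab @ solves).T                 # (S, N_star)


rng = np.random.default_rng(0)
x = np.array([-1.0, 0.0, 1.5])   # conditioning (training) inputs
xs = np.array([-0.5, 0.5])       # target (test) inputs
y = np.array([0.3, -0.2, 0.8])   # observed value beta
noise = 0.1

K_ss, K_sx, K_xx = rbf(xs, xs), rbf(xs, x), rbf(x, x)
cov_bb = K_xx + noise * np.eye(len(x))  # Cov(b, b) with b = f(x) + eps
cov_ab = K_sx                           # Cov(f(xs), b)

# Joint prior draws of (a, b) = (f(xs), f(x) + eps).
joint = np.block([[K_ss, cov_ab], [cov_ab.T, cov_bb]])
L = np.linalg.cholesky(joint + 1e-10 * np.eye(len(joint)))
draws = rng.standard_normal((100_000, len(joint))) @ L.T
a, b = draws[:, : len(xs)], draws[:, len(xs):]

post = matheron_update_dense(a, b, y, cov_ab, cov_bb)

# Analytic posterior for comparison.
mean = cov_ab @ np.linalg.solve(cov_bb, y)
cov = K_ss - cov_ab @ np.linalg.solve(cov_bb, cov_ab.T)
print(np.abs(post.mean(axis=0) - mean).max())  # Monte Carlo error only
print(np.abs(np.cov(post.T) - cov).max())
```

The documented function keeps both covariances as lineax operators and maps a single-vector solve over the \(S\) residuals with `jax.vmap`. For a dense matrix, stacking the residuals into one multi-right-hand-side solve, as above, yields the same corrected draws.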

Whitening

The whitened parameterization \(u = Lv\), \(v \sim \mathcal{N}(0, I)\) that keeps sparse-variational optimization well-conditioned, and the whitened SVGP predictive that consumes it.


whiten_covariance = unwhiten_covariance module-attribute

unwhiten(m_tilde: Float[Array, ' M'], L: lx.AbstractLinearOperator) -> Float[Array, ' M']

Unwhiten variational mean: m = L @ m_tilde.

Parameters:

Name Type Description Default
m_tilde Float[Array, ' M']

Whitened mean vector, shape (M,).

required
L AbstractLinearOperator

Cholesky factor, shape (M, M).

required

Returns:

Type Description
Float[Array, ' M']

Unwhitened mean m, shape (M,).

Source code in src/gaussx/_gp/_unwhiten.py
def unwhiten(
    m_tilde: Float[Array, " M"],
    L: lx.AbstractLinearOperator,
) -> Float[Array, " M"]:
    """Unwhiten variational mean: ``m = L @ m_tilde``.

    Args:
        m_tilde: Whitened mean vector, shape ``(M,)``.
        L: Cholesky factor, shape ``(M, M)``.

    Returns:
        Unwhitened mean m, shape ``(M,)``.
    """
    return L.mv(m_tilde)
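A minimal dense round trip, assuming a small SPD matrix `K` in place of the inducing covariance. Unwhitening multiplies by the lower Cholesky factor, and the reverse map is a triangular solve. `whiten_mean` is a hypothetical helper for illustration, not part of the documented API.

```python
import numpy as np


def unwhiten_mean(m_tilde, L):
    # m = L @ m_tilde, mirroring `unwhiten` with a dense factor.
    return L @ m_tilde


def whiten_mean(m, L):
    # Hypothetical inverse map: solve L @ m_tilde = m for m_tilde.
    return np.linalg.solve(L, m)


K = np.array([[2.0, 0.6, 0.1],
              [0.6, 1.5, 0.3],
              [0.1, 0.3, 1.0]])
L = np.linalg.cholesky(K)          # lower-triangular, K = L @ L.T
m_tilde = np.array([0.5, -1.0, 2.0])

m = unwhiten_mean(m_tilde, L)
print(np.allclose(whiten_mean(m, L), m_tilde))  # True
```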

unwhiten_covariance(L: lx.AbstractLinearOperator, S_tilde: lx.AbstractLinearOperator) -> lx.MatrixLinearOperator

Unwhiten variational covariance: S = L S̃ Lᵀ.

Delegates to cov_transform.

Parameters:

Name Type Description Default
L AbstractLinearOperator

Cholesky factor, shape (M, M).

required
S_tilde AbstractLinearOperator

Whitened variational covariance, shape (M, M).

required

Returns:

Type Description
MatrixLinearOperator

Unwhitened covariance operator S.

Source code in src/gaussx/_gp/_unwhiten.py
def unwhiten_covariance(
    L: lx.AbstractLinearOperator,
    S_tilde: lx.AbstractLinearOperator,
) -> lx.MatrixLinearOperator:
    """Unwhiten variational covariance: S = L S̃ Lᵀ.

    Delegates to `cov_transform`.

    Args:
        L: Cholesky factor, shape ``(M, M)``.
        S_tilde: Whitened variational covariance, shape ``(M, M)``.

    Returns:
        Unwhitened covariance operator S.
    """
    return cov_transform(L.as_matrix(), S_tilde)
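A dense sketch of \(S = L\tilde{S}L^\top\), with NumPy arrays standing in for lineax operators (an assumption for brevity). It checks two facts: the whitened prior \(\tilde{S} = I\) maps back to \(K\), and draws \(u = Lv\) with \(v \sim \mathcal{N}(0, \tilde{S})\) have covariance \(L\tilde{S}L^\top\).

```python
import numpy as np


def unwhiten_cov_dense(L, S_tilde):
    # S = L @ S_tilde @ L.T
    return L @ S_tilde @ L.T


K = np.array([[2.0, 0.6, 0.1],
              [0.6, 1.5, 0.3],
              [0.1, 0.3, 1.0]])
L = np.linalg.cholesky(K)

# The whitened prior N(0, I) corresponds to the original prior N(0, K).
print(np.allclose(unwhiten_cov_dense(L, np.eye(3)), K))  # True

# A whitened variational covariance S_tilde = C @ C.T.
C = np.array([[0.5, 0.0, 0.0],
              [0.2, 0.4, 0.0],
              [-0.1, 0.3, 0.6]])
S_tilde = C @ C.T
S = unwhiten_cov_dense(L, S_tilde)

# Empirical check: u = L v with v ~ N(0, S_tilde) has covariance S.
rng = np.random.default_rng(1)
v = rng.standard_normal((200_000, 3)) @ C.T
u = v @ L.T
print(np.abs(np.cov(u.T) - S).max())  # Monte Carlo error only
```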

whitened_svgp_predict(K_zz_op: lx.AbstractLinearOperator, K_xz: Float[Array, 'N M'], u_mean: Float[Array, ' M'], u_chol: Float[Array, 'M M'], K_xx_diag: Float[Array, ' N']) -> tuple[Float[Array, ' N'], Float[Array, ' N']]

Whitened SVGP prediction: mean and variance at test points.

Computes the predictive mean and variance for a sparse variational GP using the whitened parameterization:

L_{zz} = cholesky(K_{zz})
A = L_{zz}^{-1} K_{zx}           (triangular solve)
f_{loc} = A^T u_{mean}
Q_{xx} = sum(A^2, axis=0)         (prior variance reduction)
W = u_{chol}^T A
S_{contrib} = sum(W^2, axis=0)    (posterior variance contribution)
f_{var} = K_{xx,diag} - Q_{xx} + S_{contrib}

Parameters:

Name Type Description Default
K_zz_op AbstractLinearOperator

Inducing-point covariance operator, shape (M, M).

required
K_xz Float[Array, 'N M']

Cross-covariance matrix, shape (N, M).

required
u_mean Float[Array, ' M']

Whitened variational mean, shape (M,).

required
u_chol Float[Array, 'M M']

Whitened variational Cholesky factor, shape (M, M). Lower-triangular matrix such that the variational covariance in whitened space is u_chol @ u_chol^T.

required
K_xx_diag Float[Array, ' N']

Prior diagonal variances at test points, shape (N,).

required

Returns:

Type Description
tuple[Float[Array, ' N'], Float[Array, ' N']]

Tuple (f_loc, f_var): predictive mean, shape (N,), and predictive variance, shape (N,).

Source code in src/gaussx/_gp/_svgp.py
def whitened_svgp_predict(
    K_zz_op: lx.AbstractLinearOperator,
    K_xz: Float[Array, "N M"],
    u_mean: Float[Array, " M"],
    u_chol: Float[Array, "M M"],
    K_xx_diag: Float[Array, " N"],
) -> tuple[Float[Array, " N"], Float[Array, " N"]]:
    r"""Whitened SVGP prediction: mean and variance at test points.

    Computes the predictive mean and variance for a sparse variational
    GP using the whitened parameterization:

        L_{zz} = cholesky(K_{zz})
        A = L_{zz}^{-1} K_{zx}           (triangular solve)
        f_{loc} = A^T u_{mean}
        Q_{xx} = sum(A^2, axis=0)         (prior variance reduction)
        W = u_{chol}^T A
        S_{contrib} = sum(W^2, axis=0)    (posterior variance contribution)
        f_{var} = K_{xx,diag} - Q_{xx} + S_{contrib}

    Args:
        K_zz_op: Inducing-point covariance operator, shape ``(M, M)``.
        K_xz: Cross-covariance matrix, shape ``(N, M)``.
        u_mean: Whitened variational mean, shape ``(M,)``.
        u_chol: Whitened variational Cholesky factor, shape ``(M, M)``.
            Lower-triangular matrix such that the variational covariance
            in whitened space is ``u_chol @ u_chol^T``.
        K_xx_diag: Prior diagonal variances at test points, shape ``(N,)``.

    Returns:
        Tuple ``(f_loc, f_var)`` — predictive mean shape ``(N,)`` and
        predictive variance shape ``(N,)``.
    """
    L_zz = cholesky(K_zz_op)

    # A = L_zz^{-1} K_xz^T  -> shape (M, N)
    # Solve L_zz @ A_col = K_xz^T_col for each column of K_xzᵀ
    from gaussx._linalg._linalg import solve_columns

    K_zx = K_xz.T  # (M, N)
    A = solve_columns(L_zz, K_zx)

    # Predictive mean: f_loc = A^T @ u_mean = K_xz @ L_zz^{-T} @ u_mean
    f_loc = A.T @ u_mean

    # Prior variance reduction: Q_xx = Σₘ Aₘₙ²
    Q_xx = reduce(A**2, "M N -> N", "sum")

    # Posterior variance contribution: W = u_cholᵀ @ A, S = Σₖ Wₖₙ²
    W = u_chol.T @ A
    S_contrib = reduce(W**2, "K N -> N", "sum")

    # Predictive variance
    f_var = jnp.clip(K_xx_diag - Q_xx + S_contrib, 0.0)

    return f_loc, f_var
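A NumPy sketch of the same computation, checked against the unwhitened predictive \(\mu = K_{xz}K_{zz}^{-1}m\), \(\sigma^2 = k_{xx} - \operatorname{diag}(K_{xz}K_{zz}^{-1}K_{zx}) + \operatorname{diag}(K_{xz}K_{zz}^{-1}SK_{zz}^{-1}K_{zx})\) with \(m = L_{zz}u_{\text{mean}}\) and \(S = L_{zz}u_{\text{chol}}u_{\text{chol}}^\top L_{zz}^\top\). The `rbf` kernel, jitter, and toy inputs are assumptions, and `np.linalg.solve` replaces the triangular solve.

```python
import numpy as np


def rbf(x1, x2, lengthscale=0.8):
    # Squared-exponential kernel on 1-D inputs (illustrative choice).
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def whitened_predict_dense(K_zz, K_xz, u_mean, u_chol, K_xx_diag):
    L_zz = np.linalg.cholesky(K_zz)
    A = np.linalg.solve(L_zz, K_xz.T)        # (M, N) = L_zz^{-1} K_zx
    f_loc = A.T @ u_mean
    Q_xx = np.sum(A**2, axis=0)               # prior variance reduction
    W = u_chol.T @ A
    S_contrib = np.sum(W**2, axis=0)          # posterior variance contribution
    f_var = np.clip(K_xx_diag - Q_xx + S_contrib, 0.0, None)
    return f_loc, f_var


z = np.linspace(-2.0, 2.0, 4)                 # inducing inputs
x = np.array([-1.5, -0.3, 0.4, 1.7, 2.5])     # test inputs
K_zz = rbf(z, z) + 1e-6 * np.eye(len(z))
K_xz = rbf(x, z)
K_xx_diag = np.ones(len(x))                   # unit-variance kernel

rng = np.random.default_rng(2)
u_mean = rng.standard_normal(len(z))
u_chol = np.tril(rng.standard_normal((4, 4))) * 0.3 + 0.5 * np.eye(4)

f_loc, f_var = whitened_predict_dense(K_zz, K_xz, u_mean, u_chol, K_xx_diag)

# Unwhitened reference.
L_zz = np.linalg.cholesky(K_zz)
m = L_zz @ u_mean
S = L_zz @ u_chol @ u_chol.T @ L_zz.T
B = np.linalg.solve(K_zz, K_xz.T)             # K_zz^{-1} K_zx, shape (M, N)
ref_mean = K_xz @ np.linalg.solve(K_zz, m)
ref_var = (K_xx_diag - np.sum(K_xz * B.T, axis=1)
           + np.sum((B.T @ S) * B.T, axis=1))
print(np.allclose(f_loc, ref_mean), np.allclose(f_var, ref_var))  # True True
```

Setting `u_mean = 0` and `u_chol = I` (the whitened prior) makes `S_contrib` cancel `Q_xx`, so the predictive variance falls back to the prior `K_xx_diag`.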

svgp_variance_adjustment(K_zz_op: lx.AbstractLinearOperator, S_u: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator

Compute the SVGP variance adjustment operator.

Builds the operator Q = K_{zz}^{-1} S_u K_{zz}^{-1} - K_{zz}^{-1}, which enters the sparse variational GP predictive variance as

Var[f_*] = k_{**} + k_{*z} Q k_{z*} = k_{**} - k_{*z} K_{zz}^{-1} k_{z*} + k_{*z} K_{zz}^{-1} S_u K_{zz}^{-1} k_{z*}

The returned value is exposed as a linear operator, but the current implementation materializes dense (M, M) intermediates while building it.

Parameters:

Name Type Description Default
K_zz_op AbstractLinearOperator

Inducing-point covariance operator, shape (M, M).

required
S_u AbstractLinearOperator

Variational covariance operator, shape (M, M).

required

Returns:

Type Description
AbstractLinearOperator

Operator Q of shape (M, M) such that Q @ v = K_{zz}^{-1} S_u K_{zz}^{-1} v - K_{zz}^{-1} v.

Source code in src/gaussx/_gp/_svgp_variance.py
def svgp_variance_adjustment(
    K_zz_op: lx.AbstractLinearOperator,
    S_u: lx.AbstractLinearOperator,
) -> lx.AbstractLinearOperator:
    r"""Compute the SVGP variance adjustment operator.

    Builds the operator ``Q = K_{zz}^{-1} S_u K_{zz}^{-1} - K_{zz}^{-1}``,
    which enters the sparse variational GP predictive variance as

        Var[f_*] = k_{**} + k_{*z} Q k_{z*}

    The returned value is exposed as a linear operator, but the current
    implementation materializes dense ``(M, M)`` intermediates while building
    it.

    Args:
        K_zz_op: Inducing-point covariance operator, shape ``(M, M)``.
        S_u: Variational covariance operator, shape ``(M, M)``.

    Returns:
        Operator ``Q`` of shape ``(M, M)`` such that
        ``Q @ v = K_{zz}^{-1} S_u K_{zz}^{-1} v - K_{zz}^{-1} v``.
    """
    K_inv = inv(K_zz_op)
    K_inv_S = lx.MatrixLinearOperator(
        _compose_dense(K_inv, S_u),
    )
    # Q = K_inv @ S_u @ K_inv - K_inv
    # Build as (K_inv @ S_u - I) @ K_inv
    import jax.numpy as jnp

    M = K_zz_op.out_structure().shape[0]
    K_inv_S_minus_I = lx.AddLinearOperator(
        K_inv_S,
        lx.DiagonalLinearOperator(-jnp.ones(M)),
    )
    Q_dense = _compose_dense(K_inv_S_minus_I, K_inv)
    return lx.MatrixLinearOperator(Q_dense)
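A dense NumPy check of how \(Q\) enters the predictive variance: adding \(k_{*z}Qk_{z*}\) to \(k_{**}\) reproduces \(k_{**} - k_{*z}K_{zz}^{-1}k_{z*} + k_{*z}K_{zz}^{-1}S_uK_{zz}^{-1}k_{z*}\). The matrices are arbitrary SPD examples, and the explicit inverse only mirrors the dense construction above; it is not a recommended way to solve.

```python
import numpy as np


def variance_adjustment_dense(K_zz, S_u):
    # Q = K^{-1} S_u K^{-1} - K^{-1}, built as (K^{-1} S_u - I) K^{-1}.
    K_inv = np.linalg.inv(K_zz)
    M = K_zz.shape[0]
    return (K_inv @ S_u - np.eye(M)) @ K_inv


K_zz = np.array([[1.0, 0.5, 0.2],
                 [0.5, 1.0, 0.5],
                 [0.2, 0.5, 1.0]])
S_u = np.array([[0.3, 0.1, 0.0],
                [0.1, 0.2, 0.05],
                [0.0, 0.05, 0.25]])
k_zs = np.array([0.8, 0.6, 0.1])  # k(z, x_*) for one test point
k_ss = 1.0                        # k(x_*, x_*)

Q = variance_adjustment_dense(K_zz, S_u)
var_via_Q = k_ss + k_zs @ Q @ k_zs

a = np.linalg.solve(K_zz, k_zs)   # K_zz^{-1} k_{z*}
var_direct = k_ss - k_zs @ a + a @ S_u @ a
print(np.isclose(var_via_Q, var_direct))  # True
```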

Variational bounds & KL

ELBOs for Gaussian likelihoods (closed form) and non-conjugate likelihoods (Monte Carlo), the collapsed (Titsias) sparse bound, and the Gaussian-to-Gaussian KL term.


variational_elbo_gaussian(y: Float[Array, ' N'], f_loc: Float[Array, ' N'], f_var: Float[Array, ' N'], noise_var: float, kl: Float[Array, '']) -> Float[Array, '']

Variational ELBO for Gaussian likelihoods (uncollapsed form).

Computes:

ELBO = E_q[log p(y|f)] - KL(q||p)

where the expected log-likelihood under a Gaussian variational distribution with diagonal variance has the closed form:

E_q[log N(y|f, sigma^2 I)]
    = -0.5 * N * log(2 pi sigma^2)
      -0.5 / sigma^2 * (||y - f_loc||^2 + sum(f_var))

Parameters:

Name Type Description Default
y Float[Array, ' N']

Observations, shape (N,).

required
f_loc Float[Array, ' N']

Variational mean, shape (N,).

required
f_var Float[Array, ' N']

Variational marginal variances, shape (N,).

required
noise_var float

Observation noise variance (scalar).

required
kl Float[Array, '']

KL divergence term KL(q || p) (scalar).

required

Returns:

Type Description
Float[Array, '']

Scalar ELBO value.

Source code in src/gaussx/_gp/_elbo.py
def variational_elbo_gaussian(
    y: Float[Array, " N"],
    f_loc: Float[Array, " N"],
    f_var: Float[Array, " N"],
    noise_var: float,
    kl: Float[Array, ""],
) -> Float[Array, ""]:
    """Variational ELBO for Gaussian likelihoods (uncollapsed form).

    Computes:

        ELBO = E_q[log p(y|f)] - KL(q||p)

    where the expected log-likelihood under a Gaussian variational
    distribution with diagonal variance has the closed form:

        E_q[log N(y|f, sigma^2 I)]
            = -0.5 * N * log(2 pi sigma^2)
              -0.5 / sigma^2 * (||y - f_loc||^2 + sum(f_var))

    Args:
        y: Observations, shape ``(N,)``.
        f_loc: Variational mean, shape ``(N,)``.
        f_var: Variational marginal variances, shape ``(N,)``.
        noise_var: Observation noise variance (scalar).
        kl: KL divergence term ``KL(q || p)`` (scalar).

    Returns:
        Scalar ELBO value.
    """
    N = y.shape[-1]
    residual = y - f_loc
    ell = -0.5 * N * jnp.log(noise_var) - 0.5 * N * _LOG_2PI
    ell = ell - 0.5 / noise_var * (jnp.sum(residual**2) + jnp.sum(f_var))
    return ell - kl
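A NumPy sketch that checks the closed-form expected log-likelihood against a Monte Carlo average over \(f \sim \mathcal{N}(f_{\text{loc}}, \operatorname{diag}(f_{\text{var}}))\). The data, the `kl` value, and the sample count are arbitrary.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def elbo_gaussian(y, f_loc, f_var, noise_var, kl):
    # Closed-form E_q[log N(y | f, noise_var I)] minus the KL term.
    N = y.shape[-1]
    residual = y - f_loc
    ell = -0.5 * N * (np.log(noise_var) + LOG_2PI)
    ell -= 0.5 / noise_var * (np.sum(residual**2) + np.sum(f_var))
    return ell - kl


rng = np.random.default_rng(3)
y = np.array([0.2, -0.4, 1.1, 0.5])
f_loc = np.array([0.1, -0.3, 0.9, 0.7])
f_var = np.array([0.05, 0.10, 0.02, 0.20])
noise_var, kl = 0.1, 0.7

closed = elbo_gaussian(y, f_loc, f_var, noise_var, kl)

# Monte Carlo estimate of the same expectation.
f = f_loc + np.sqrt(f_var) * rng.standard_normal((200_000, 4))
log_lik = -0.5 * (LOG_2PI + np.log(noise_var)) - 0.5 * (y - f) ** 2 / noise_var
mc = log_lik.sum(axis=1).mean() - kl
print(closed, mc)  # agree up to Monte Carlo error
```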

variational_elbo_mc(log_likelihood_fn: Callable[[Float[Array, ' N']], Float[Array, '']], f_samples: Float[Array, 'S N'], kl: Float[Array, '']) -> Float[Array, '']

Monte Carlo ELBO for non-conjugate likelihoods.

Computes:

ELBO = (1/S) sum_s log p(y|f_s) - KL(q||p)

where f_s ~ q(f) are samples from the variational distribution. Supports any likelihood (Poisson, Bernoulli, Pareto, etc.).

Parameters:

Name Type Description Default
log_likelihood_fn Callable[[Float[Array, ' N']], Float[Array, '']]

Function mapping latent samples to scalar log-likelihood. Signature (N,) -> scalar.

required
f_samples Float[Array, 'S N']

Samples from the variational posterior, shape (S, N) where S is the number of samples.

required
kl Float[Array, '']

KL divergence term KL(q || p) (scalar).

required

Returns:

Type Description
Float[Array, '']

Scalar ELBO value.

Source code in src/gaussx/_gp/_elbo.py
def variational_elbo_mc(
    log_likelihood_fn: Callable[[Float[Array, " N"]], Float[Array, ""]],
    f_samples: Float[Array, "S N"],
    kl: Float[Array, ""],
) -> Float[Array, ""]:
    """Monte Carlo ELBO for non-conjugate likelihoods.

    Computes:

        ELBO = (1/S) sum_s log p(y|f_s) - KL(q||p)

    where ``f_s ~ q(f)`` are samples from the variational distribution.
    Supports any likelihood (Poisson, Bernoulli, Pareto, etc.).

    Args:
        log_likelihood_fn: Function mapping latent samples to scalar
            log-likelihood. Signature ``(N,) -> scalar``.
        f_samples: Samples from the variational posterior, shape
            ``(S, N)`` where S is the number of samples.
        kl: KL divergence term ``KL(q || p)`` (scalar).

    Returns:
        Scalar ELBO value.
    """
    import jax

    ell = jnp.mean(jax.vmap(log_likelihood_fn)(f_samples))
    return ell - kl
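A NumPy sketch of the Monte Carlo ELBO for a Poisson likelihood with log link, \(y_n \sim \operatorname{Poisson}(e^{f_n})\). The log-likelihood here is vectorized over samples instead of being `vmap`-ed. For this particular likelihood under a diagonal Gaussian \(q\), the expectation also has the closed form \(\sum_n \big(y_n\mu_n - e^{\mu_n + v_n/2} - \log y_n!\big)\), which serves as a reference. The counts, variational parameters, and `kl` are arbitrary.

```python
import math

import numpy as np


def poisson_log_lik(f, y):
    # log p(y | f) for y ~ Poisson(exp(f)), summed over the last axis.
    log_fact = np.array([math.lgamma(k + 1.0) for k in y])
    return np.sum(y * f - np.exp(f) - log_fact, axis=-1)


def elbo_mc(log_lik_fn, f_samples, kl):
    # (1/S) sum_s log p(y | f_s) - KL(q || p)
    return np.mean(log_lik_fn(f_samples)) - kl


rng = np.random.default_rng(4)
y = np.array([0.0, 2.0, 5.0, 1.0])
mu = np.array([-0.5, 0.6, 1.5, 0.1])
var = np.array([0.2, 0.1, 0.05, 0.3])
kl = 1.3

f_samples = mu + np.sqrt(var) * rng.standard_normal((100_000, 4))
mc = elbo_mc(lambda f: poisson_log_lik(f, y), f_samples, kl)

log_fact = np.array([math.lgamma(k + 1.0) for k in y])
exact = np.sum(y * mu - np.exp(mu + var / 2) - log_fact) - kl
print(mc, exact)  # agree up to Monte Carlo error
```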

collapsed_elbo(y: Float[Array, ' N'], K_diag: Float[Array, ' N'], K_xz: Float[Array, 'N M'], K_zz: Float[Array, 'M M'], noise_var: float, *, jitter: float = 1e-06, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']

Collapsed ELBO (Titsias bound) for sparse GP regression.

Computes the variational lower bound on the log marginal likelihood using the matrix determinant lemma for O(NM² + M³) cost:

ELBO = log 𝒩(y | 0, Q_ff + σ²I) − ½σ⁻² tr(K_ff − Q_ff)

where Q_ff = K_xz K_zz⁻¹ K_xzᵀ is the Nyström approximation.

Parameters:

Name Type Description Default
y Float[Array, ' N']

Observations, shape (N,).

required
K_diag Float[Array, ' N']

Diagonal of full kernel matrix K_ff, shape (N,).

required
K_xz Float[Array, 'N M']

Cross-covariance between data and inducing points, shape (N, M).

required
K_zz Float[Array, 'M M']

Inducing point kernel matrix, shape (M, M).

required
noise_var float

Observation noise variance σ² (scalar).

required
jitter float

Diagonal jitter for numerical stability in Cholesky decomposition of K_zz.

1e-06
solver AbstractSolverStrategy | None

Reserved solver strategy for structured linear algebra. It is accepted for API consistency but currently ignored: the Cholesky decompositions in this function do not take a solver.

None

Returns:

Type Description
Float[Array, '']

Scalar ELBO value.

Source code in src/gaussx/_gp/_collapsed_elbo.py
def collapsed_elbo(
    y: Float[Array, " N"],
    K_diag: Float[Array, " N"],
    K_xz: Float[Array, "N M"],
    K_zz: Float[Array, "M M"],
    noise_var: float,
    *,
    jitter: float = 1e-6,
    solver: AbstractSolverStrategy | None = None,
) -> Float[Array, ""]:
    """Collapsed ELBO (Titsias bound) for sparse GP regression.

    Computes the variational lower bound on the log marginal likelihood
    using the matrix determinant lemma for O(NM² + M³) cost:

        ELBO = log 𝒩(y | 0, Q_ff + σ²I) − ½σ⁻² tr(K_ff − Q_ff)

    where Q_ff = K_xz K_zz⁻¹ K_xzᵀ is the Nyström approximation.

    Args:
        y: Observations, shape ``(N,)``.
        K_diag: Diagonal of full kernel matrix K_ff, shape ``(N,)``.
        K_xz: Cross-covariance between data and inducing points,
            shape ``(N, M)``.
        K_zz: Inducing point kernel matrix, shape ``(M, M)``.
        noise_var: Observation noise variance σ² (scalar).
        jitter: Diagonal jitter for numerical stability in Cholesky
            decomposition of K_zz.
        solver: Reserved solver strategy for structured linear algebra.
            Accepted for API consistency but currently ignored: the
            Cholesky decompositions in this function do not take a solver.

    Returns:
        Scalar ELBO value.
    """
    del solver  # cholesky does not accept a solver; parameter reserved for future use
    N = y.shape[0]
    M = K_zz.shape[0]

    # L_zz L_zzᵀ = K_zz + jitter · I
    K_zz_jitter = K_zz + jitter * jnp.eye(M)
    L_zz = cholesky(  # (M, M)
        lx.MatrixLinearOperator(K_zz_jitter, lx.positive_semidefinite_tag)
    ).as_matrix()

    # V = L_zz⁻¹ K_xzᵀ
    V = lax.linalg.triangular_solve(
        L_zz,
        K_xz.T,
        left_side=True,
        lower=True,
    )  # (M, N)

    # B = I_M + σ⁻² V Vᵀ
    B = jnp.eye(M) + (1.0 / noise_var) * (V @ V.T)  # (M, M)
    L_B = cholesky(  # (M, M)
        lx.MatrixLinearOperator(B, lx.positive_semidefinite_tag)
    ).as_matrix()

    # log|Q_ff + σ²I| = N log σ² + log|B|
    from gaussx._primitives._logdet import cholesky_logdet

    log_det = N * jnp.log(noise_var) + cholesky_logdet(L_B)

    # Quadratic form via Woodbury:
    # yᵀ (Q_ff + σ²I)⁻¹ y = σ⁻²(‖y‖² − σ⁻² ‖L_B⁻¹ V y‖²)
    Vy = V @ y  # (M,)
    LBinv_Vy = lax.linalg.triangular_solve(
        L_B,
        Vy,
        left_side=True,
        lower=True,
    )  # (M,)
    quad = (1.0 / noise_var) * (
        jnp.sum(y**2) - (1.0 / noise_var) * jnp.sum(LBinv_Vy**2)
    )

    # Trace penalty: −½σ⁻² (tr(K_ff) − tr(Q_ff))
    # where tr(Q_ff) = ‖V‖²_F
    trace_penalty = -0.5 / noise_var * (jnp.sum(K_diag) - jnp.sum(V**2))

    return -0.5 * (log_det + quad + N * _LOG_2PI) + trace_penalty
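A NumPy cross-check of the Woodbury and determinant-lemma path against the naive \(O(N^3)\) evaluation of \(\log\mathcal{N}(y \mid 0, Q_{ff} + \sigma^2 I) - \tfrac{1}{2\sigma^2}\operatorname{tr}(K_{ff} - Q_{ff})\). It also checks that the bound lies below the exact log marginal likelihood \(\log\mathcal{N}(y \mid 0, K_{ff} + \sigma^2 I)\), as a lower bound must. The kernel, inducing inputs, data, and jitter are assumptions.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def rbf(x1, x2, lengthscale=0.7):
    # Squared-exponential kernel on 1-D inputs (illustrative choice).
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def gauss_logpdf_zero_mean(y, C):
    # log N(y | 0, C) via a Cholesky factorization.
    L = np.linalg.cholesky(C)
    alpha = np.linalg.solve(L, y)
    return -0.5 * (alpha @ alpha + 2.0 * np.sum(np.log(np.diag(L))) + len(y) * LOG_2PI)


def collapsed_elbo_woodbury(y, K_diag, K_xz, K_zz, noise_var, jitter=1e-6):
    N, M = len(y), K_zz.shape[0]
    L_zz = np.linalg.cholesky(K_zz + jitter * np.eye(M))
    V = np.linalg.solve(L_zz, K_xz.T)                  # (M, N)
    B = np.eye(M) + (V @ V.T) / noise_var
    L_B = np.linalg.cholesky(B)
    log_det = N * np.log(noise_var) + 2.0 * np.sum(np.log(np.diag(L_B)))
    c = np.linalg.solve(L_B, V @ y)
    quad = (y @ y - (c @ c) / noise_var) / noise_var
    trace_penalty = -0.5 / noise_var * (np.sum(K_diag) - np.sum(V**2))
    return -0.5 * (log_det + quad + N * LOG_2PI) + trace_penalty


rng = np.random.default_rng(5)
x = np.sort(rng.uniform(-3.0, 3.0, 40))
y = np.sin(x) + 0.1 * rng.standard_normal(40)
z = np.linspace(-3.0, 3.0, 8)
noise_var, jitter = 0.05, 1e-6

K_ff, K_xz, K_zz = rbf(x, x), rbf(x, z), rbf(z, z)
elbo = collapsed_elbo_woodbury(y, np.diag(K_ff), K_xz, K_zz, noise_var, jitter)

# Naive O(N^3) evaluation of the same bound (same jitter on K_zz).
Q_ff = K_xz @ np.linalg.solve(K_zz + jitter * np.eye(8), K_xz.T)
naive = (
    gauss_logpdf_zero_mean(y, Q_ff + noise_var * np.eye(40))
    - 0.5 / noise_var * np.trace(K_ff - Q_ff)
)
exact = gauss_logpdf_zero_mean(y, K_ff + noise_var * np.eye(40))
print(np.isclose(elbo, naive), elbo <= exact)  # True True
```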

gauss_kl(q_mu: Float[Array, 'M R'], q_sqrt: Float[Array, 'R M M'] | Float[Array, 'M R'], K: Float[Array, 'M M'] | None = None, *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']

KL divergence KL[q(u) || p(u)] between Gaussian distributions.

Cholesky-parameterised variant of dist_kl_divergence designed for GP/SVGP models. The Cholesky representation avoids explicit covariance matrix construction and supports both full and diagonal q_sqrt. For lineax-operator covariances, use dist_kl_divergence instead.

Computes the KL divergence where:

  • q = N(q_mu, q_sqrt @ q_sqrt^T)
  • p = N(0, K) or p = N(0, I) if K is None (white prior)

Handles both full and diagonal q_sqrt:

  • Full q_sqrt: shape (R, M, M) — lower-triangular Cholesky factors of the variational covariance per output dimension.
  • Diagonal q_sqrt: shape (M, R) — diagonal standard deviations.

Parameters:

Name Type Description Default
q_mu Float[Array, 'M R']

Variational mean, shape (M, R).

required
q_sqrt Float[Array, 'R M M'] | Float[Array, 'M R']

Variational Cholesky factor or diagonal std devs.

required
K Float[Array, 'M M'] | None

Prior covariance matrix, shape (M, M). If None, uses white prior (identity).

None
solver AbstractSolverStrategy | None

Reserved solver strategy for structured linear algebra. It is accepted for API consistency but currently ignored: the Cholesky decomposition in this function does not take a solver.

None

Returns:

Type Description
Float[Array, '']

Scalar KL divergence summed over all R output dimensions.

See Also

dist_kl_divergence: General KL between two multivariate normals with lineax covariance operators.

Source code in src/gaussx/_gp/_gauss_kl.py
def gauss_kl(
    q_mu: Float[Array, "M R"],
    q_sqrt: Float[Array, "R M M"] | Float[Array, "M R"],
    K: Float[Array, "M M"] | None = None,
    *,
    solver: AbstractSolverStrategy | None = None,
) -> Float[Array, ""]:
    r"""KL divergence ``KL[q(u) || p(u)]`` between Gaussian distributions.

    Cholesky-parameterised variant of
    `dist_kl_divergence` designed for
    GP/SVGP models.  The Cholesky representation avoids explicit covariance
    matrix construction and supports both full and diagonal ``q_sqrt``.
    For lineax-operator covariances, use
    `dist_kl_divergence` instead.

    Computes the KL divergence where:

    - ``q = N(q_mu, q_sqrt @ q_sqrt^T)``
    - ``p = N(0, K)`` or ``p = N(0, I)`` if ``K is None`` (white prior)

    Handles both full and diagonal ``q_sqrt``:

    - **Full** ``q_sqrt``: shape ``(R, M, M)`` — lower-triangular Cholesky
      factors of the variational covariance per output dimension.
    - **Diagonal** ``q_sqrt``: shape ``(M, R)`` — diagonal standard
      deviations.

    Args:
        q_mu: Variational mean, shape ``(M, R)``.
        q_sqrt: Variational Cholesky factor or diagonal std devs.
        K: Prior covariance matrix, shape ``(M, M)``.
            If ``None``, uses white prior (identity).
        solver: Reserved solver strategy for structured linear algebra.
            Accepted for API consistency but currently ignored: the
            Cholesky decomposition in this function does not take a solver.

    Returns:
        Scalar KL divergence summed over all ``R`` output dimensions.

    See Also:
        `dist_kl_divergence`: General KL
        between two multivariate normals with lineax covariance operators.
    """
    del solver  # cholesky does not accept a solver; parameter reserved for future use
    M = q_mu.shape[0]
    R = q_mu.shape[1]

    is_diagonal = q_sqrt.ndim == 2

    # Prior Cholesky factor
    L_K = (
        cholesky(lx.MatrixLinearOperator(K, lx.positive_semidefinite_tag)).as_matrix()
        if K is not None
        else None
    )

    if is_diagonal:
        # q_sqrt shape: (M, R) — diagonal standard deviations
        q_var = q_sqrt**2  # (M, R)

        if L_K is not None:
            # alpha = L_K^{-1} q_mu  ->  Mahalanobis term
            alpha = jsla.solve_triangular(L_K, q_mu, lower=True)  # (M, R)
            mahal = jnp.sum(alpha**2)

            # tr(K^{-1} diag(q_var_r)) = sum_i q_var[i,r] * (K^{-1})_{ii}
            # diag(K^{-1})_i = ||L_K^{-1} e_i||^2, i.e. the squared column
            # norms of L_K^{-1}, since K^{-1} = L_K^{-T} L_K^{-1}. Form
            # L_K^{-1} with one batched triangular solve against I (O(M^3)):
            Kinv_diag = jnp.sum(
                jsla.solve_triangular(L_K, jnp.eye(M), lower=True) ** 2,
                axis=0,
            )  # (M,)
            trace_term = jnp.sum(q_var * Kinv_diag[:, None])

            # log|K| − log|S|
            logdet_K = cholesky_logdet(L_K)
            logdet_S = jnp.sum(jnp.log(q_var))
            logdet_diff = R * logdet_K - logdet_S
        else:
            # White prior: K = I
            mahal = jnp.sum(q_mu**2)
            trace_term = jnp.sum(q_var)
            logdet_diff = -jnp.sum(jnp.log(q_var))

        return 0.5 * (logdet_diff - M * R + trace_term + mahal)

    # q_sqrt shape: (R, M, M) — full lower-triangular Cholesky factors
    # Hoist prior-only quantities outside the per-output computation.
    logdet_K = cholesky_logdet(L_K) if L_K is not None else 0.0

    def _kl_single(
        q_mu_r: Float[Array, " M"], L_q_r: Float[Array, "M M"]
    ) -> Float[Array, ""]:
        logdet_q = cholesky_logdet(L_q_r)

        if L_K is not None:
            alpha = jsla.solve_triangular(L_K, q_mu_r, lower=True)
            mahal_r = jnp.sum(alpha**2)
            L_K_inv_L_q = jsla.solve_triangular(L_K, L_q_r, lower=True)
            trace_r = jnp.sum(L_K_inv_L_q**2)
            logdet_diff_r = logdet_K - logdet_q
        else:
            mahal_r = jnp.sum(q_mu_r**2)
            trace_r = jnp.sum(L_q_r**2)
            logdet_diff_r = -logdet_q

        return 0.5 * (logdet_diff_r - M + trace_r + mahal_r)

    # vmap over R output dimensions: q_mu.T -> (R, M), q_sqrt -> (R, M, M)
    kl_per_output = jax.vmap(_kl_single)(q_mu.T, q_sqrt)
    return jnp.sum(kl_per_output)
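A NumPy sketch comparing the Cholesky-based terms of the full-`q_sqrt` branch with the textbook dense formula \(\mathrm{KL} = \tfrac12\big[\operatorname{tr}(K^{-1}S) + m^\top K^{-1}m - M + \log|K| - \log|S|\big]\) for a single output (\(R = 1\)). The prior `K`, mean `m`, and factor `L_q` are arbitrary, and `np.linalg.solve` stands in for the triangular solves.

```python
import numpy as np


def kl_dense(m, S, K):
    # Textbook KL( N(m, S) || N(0, K) ) with explicit inverse and log-dets.
    K_inv = np.linalg.inv(K)
    _, logdet_K = np.linalg.slogdet(K)
    _, logdet_S = np.linalg.slogdet(S)
    return 0.5 * (np.trace(K_inv @ S) + m @ K_inv @ m - len(m) + logdet_K - logdet_S)


def kl_cholesky(m, L_q, K):
    # Same quantity from Cholesky factors.
    L_K = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L_K, m)                   # Mahalanobis term
    trace = np.sum(np.linalg.solve(L_K, L_q) ** 2)    # ||L_K^{-1} L_q||_F^2
    logdet_K = 2.0 * np.sum(np.log(np.diag(L_K)))
    logdet_q = 2.0 * np.sum(np.log(np.abs(np.diag(L_q))))
    return 0.5 * (logdet_K - logdet_q - len(m) + trace + alpha @ alpha)


K = np.array([[1.0, 0.5, 0.2],
              [0.5, 1.0, 0.5],
              [0.2, 0.5, 1.0]])
m = np.array([0.3, -0.2, 0.5])
L_q = np.array([[0.6, 0.0, 0.0],
                [0.1, 0.5, 0.0],
                [-0.2, 0.1, 0.4]])

print(np.isclose(kl_cholesky(m, L_q, K), kl_dense(m, L_q @ L_q.T, K)))  # True
# KL(p || p) = 0 when q equals the prior.
print(np.isclose(kl_cholesky(np.zeros(3), np.linalg.cholesky(K), K), 0.0))  # True
```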

Cross-validation & diagnostics

LOVE-style cached predictive variances and closed-form leave-one-out cross-validation from a single factorization.


LOVECache

Bases: Module

Cached Lanczos factorization for fast predictive variance.

Stores the eigenvector basis Q and inverse eigenvalues such that \(K^{-1} \approx Q \Lambda^{-1} Q^\top\).

Attributes:

Name Type Description
Q Float[Array, 'N k']

Lanczos eigenvector basis, shape (N, k).

inv_eigvals Float[Array, ' k']

Inverse eigenvalues 1 / lambda_i, shape (k,).

Source code in src/gaussx/_gp/_love.py
class LOVECache(eqx.Module):
    r"""Cached Lanczos factorization for fast predictive variance.

    Stores the eigenvector basis ``Q`` and inverse eigenvalues such that
    ``K^{-1} \approx Q \Lambda^{-1} Q^T``.

    Attributes:
        Q: Lanczos eigenvector basis, shape ``(N, k)``.
        inv_eigvals: Inverse eigenvalues ``1 / lambda_i``, shape ``(k,)``.
    """

    Q: Float[Array, "N k"]
    inv_eigvals: Float[Array, " k"]

LOOResult

Bases: Module

Result of leave-one-out cross-validation.

Attributes:

Name Type Description
loo_log_likelihood Float[Array, '']

Scalar LOO-CV log-likelihood.

loo_means Float[Array, ' N']

Per-point LOO predictive means, shape (N,).

loo_variances Float[Array, ' N']

Per-point LOO predictive variances, shape (N,).

Source code in src/gaussx/_gp/_loo.py
class LOOResult(eqx.Module):
    """Result of leave-one-out cross-validation.

    Attributes:
        loo_log_likelihood: Scalar LOO-CV log-likelihood.
        loo_means: Per-point LOO predictive means, shape ``(N,)``.
        loo_variances: Per-point LOO predictive variances, shape ``(N,)``.
    """

    loo_log_likelihood: Float[Array, ""]
    loo_means: Float[Array, " N"]
    loo_variances: Float[Array, " N"]

love_cache(K_op: lx.AbstractLinearOperator, lanczos_order: int = 50, key: jax.Array | None = None) -> LOVECache

Precompute Lanczos factorization of K^{-1} for fast variance.

Builds a rank-k approximation \(K^{-1} \approx Q \Lambda^{-1} Q^\top\) using the symmetric Lanczos algorithm via partial eigendecomposition.

This amortizes the cost of predictive variance: once cached, each test point needs only O(Nk) instead of O(N^2) for a solve.

Parameters:

Name Type Description Default
K_op AbstractLinearOperator

Training kernel operator, shape (N, N). Must be symmetric positive definite.

required
lanczos_order int

Number of Lanczos iterations (rank of approximation). Default 50.

50
key Array | None

PRNG key for the initial random vector. If None, uses jax.random.PRNGKey(0).

None

Returns:

Type Description
LOVECache

A LOVECache object.

Source code in src/gaussx/_gp/_love.py
def love_cache(
    K_op: lx.AbstractLinearOperator,
    lanczos_order: int = 50,
    key: jax.Array | None = None,
) -> LOVECache:
    r"""Precompute Lanczos factorization of ``K^{-1}`` for fast variance.

    Builds a rank-``k`` approximation ``K^{-1} \approx Q \Lambda^{-1} Q^T``
    using the symmetric Lanczos algorithm via partial eigendecomposition.

    This amortizes the cost of predictive variance: once cached, each
    test point needs only ``O(Nk)`` instead of ``O(N^2)`` for a solve.

    Args:
        K_op: Training kernel operator, shape ``(N, N)``. Must be
            symmetric positive definite.
        lanczos_order: Number of Lanczos iterations (rank of approximation).
            Default ``50``.
        key: PRNG key for the initial random vector. If ``None``, uses
            ``jax.random.PRNGKey(0)``.

    Returns:
        A `LOVECache` object.
    """
    inverse_root = root_inv_decomposition(
        K_op,
        rank=lanczos_order,
        method="lanczos",
        key=key,
    ).root
    # For Lanczos inverse roots, recover 1 / λᵢ via ||R⁻[:, i]||² = 1 / λᵢ.
    inv_eigvals = jnp.sum(inverse_root**2, axis=0)
    floor = jnp.finfo(inv_eigvals.dtype).tiny
    Q = inverse_root / jnp.sqrt(jnp.maximum(inv_eigvals, floor))[None, :]
    return LOVECache(Q=Q, inv_eigvals=inv_eigvals)
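A plain NumPy sketch of the idea: run `k` steps of symmetric Lanczos, eigendecompose the tridiagonal matrix \(T\), and keep the Ritz vectors and inverse Ritz values. Full reorthogonalization is an assumption made here for numerical robustness. This is not the library's `root_inv_decomposition`; `lanczos` and `love_cache_dense` are hypothetical helpers. With `k = N` the approximation becomes exact, which the check uses.

```python
import numpy as np


def lanczos(K, k, rng):
    """k steps of symmetric Lanczos with full reorthogonalization.

    Returns an orthonormal basis Q of shape (N, k) and a tridiagonal
    T of shape (k, k) with Q.T @ K @ Q ~= T.
    """
    n = K.shape[0]
    Q = np.zeros((n, k))
    alpha = np.zeros(k)
    beta = np.zeros(max(k - 1, 0))
    q = rng.standard_normal(n)
    q /= np.linalg.norm(q)
    for j in range(k):
        Q[:, j] = q
        w = K @ q
        alpha[j] = q @ w
        # Orthogonalize against all Lanczos vectors so far (two passes).
        for _ in range(2):
            w -= Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j < k - 1:
            beta[j] = np.linalg.norm(w)
            q = w / beta[j]
    T = np.diag(alpha) + np.diag(beta, 1) + np.diag(beta, -1)
    return Q, T


def love_cache_dense(K, k, seed=0):
    # Ritz pairs give K ~= (Q V) diag(theta) (Q V)^T, hence
    # K^{-1} ~= (Q V) diag(1 / theta) (Q V)^T.
    Q, T = lanczos(K, k, np.random.default_rng(seed))
    theta, V = np.linalg.eigh(T)
    return Q @ V, 1.0 / theta


x = np.linspace(0.0, 4.0, 20)
# Exponential kernel plus noise: a well-separated spectrum for the demo.
K = np.exp(-np.abs(x[:, None] - x[None, :])) + 0.1 * np.eye(20)

Q_full, inv_eig_full = love_cache_dense(K, k=20)
K_inv_approx = (Q_full * inv_eig_full) @ Q_full.T
print(np.allclose(K_inv_approx, np.linalg.inv(K), atol=1e-6))  # True
```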

love_variance(cache: LOVECache, K_star_row: Float[Array, ' N']) -> Float[Array, '']

Fast predictive variance using a LOVE cache.

Computes \(k_*^\top K^{-1} k_*\) in \(O(Nk)\) via the cached Lanczos factorization:

\[ k_*^\top K^{-1} k_* \approx k_*^\top Q \Lambda^{-1} Q^\top k_* = \sum_i \frac{(q_i^\top k_*)^2}{\lambda_i} \]

The predictive variance for a GP is then:

\[ \operatorname{var}_* = k(x_*, x_*) - \texttt{love\_variance}(\text{cache}, k_*) \]

Parameters:

Name Type Description Default
cache LOVECache

A LOVECache from love_cache.

required
K_star_row Float[Array, ' N']

Cross-covariance vector k(X_{train}, x_*), shape (N,).

required

Returns:

Type Description
Float[Array, '']

Scalar k_*^T K^{-1} k_*.

Source code in src/gaussx/_gp/_love.py
def love_variance(
    cache: LOVECache,
    K_star_row: Float[Array, " N"],
) -> Float[Array, ""]:
    r"""Fast predictive variance using a LOVE cache.

    Computes ``k_*^T K^{-1} k_*`` in ``O(Nk)`` via the cached Lanczos
    factorization:

        k_*^T K^{-1} k_* \approx k_*^T Q \Lambda^{-1} Q^T k_*
                              = \sum_i (q_i^T k_*)^2 / \lambda_i

    The predictive variance for a GP is then:

        var_* = k(x_*, x_*) - love\_variance(cache, k_*)

    Args:
        cache: A `LOVECache` from `love_cache`.
        K_star_row: Cross-covariance vector ``k(X_{train}, x_*)``,
            shape ``(N,)``.

    Returns:
        Scalar ``k_*^T K^{-1} k_*``.
    """
    # Project onto eigenvector basis: z = Q^T k_*  -> (k,)
    z = cache.Q.T @ K_star_row
    # Weighted sum: sum_i z_i^2 / lambda_i
    return jnp.sum(z**2 * cache.inv_eigvals)
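A NumPy sketch of the variance step. For simplicity, it assumes the cache holds a full eigendecomposition from `np.linalg.eigh` (so the result is exact) rather than a rank-`k` Lanczos basis. It computes \(z = Q^\top k_*\), the reduction \(\sum_i z_i^2/\lambda_i\), and the predictive variance \(k(x_*, x_*)\) minus that reduction, then compares with a direct solve. The kernel and inputs are illustrative.

```python
import numpy as np


def kern(a, b, lengthscale=1.0):
    # Exponential (Matern-1/2) kernel on 1-D inputs.
    return np.exp(-np.abs(a[:, None] - b[None, :]) / lengthscale)


def love_variance_dense(Q, inv_eigvals, k_star):
    z = Q.T @ k_star                   # project onto the eigenvector basis
    return np.sum(z**2 * inv_eigvals)  # sum_i z_i^2 / lambda_i


x = np.linspace(0.0, 4.0, 20)
K = kern(x, x) + 0.1 * np.eye(20)
eigvals, Q = np.linalg.eigh(K)         # exact "cache": K = Q diag(eigvals) Q^T

x_star = np.array([1.3])
k_star = kern(x, x_star)[:, 0]         # k(X_train, x_*), shape (N,)
reduction = love_variance_dense(Q, 1.0 / eigvals, k_star)
pred_var = 1.0 - reduction             # k(x_*, x_*) = 1 for this kernel

direct = k_star @ np.linalg.solve(K, k_star)
print(np.isclose(reduction, direct), pred_var > 0)  # True True
```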

leave_one_out_cv(operator: lx.AbstractLinearOperator, y: Float[Array, ' N'], *, solver: AbstractSolveStrategy | None = None, diag_inv_method: str = 'solve', diag_inv_num_probes: int = 30, diag_inv_key: jax.Array | None = None) -> LOOResult

LOO-CV via the bordered-system identity.

Computes leave-one-out predictive means, variances, and log-likelihood without refitting the model N times.

Math:

alpha = K_y^{-1} y
mu_LOO_i   = y_i - alpha_i / [K_y^{-1}]_{ii}
sigma^2_LOO_i = 1 / [K_y^{-1}]_{ii}
LOO-CV = -(1/2) sum_i [ log sigma^2_LOO_i
                        + (y_i - mu_LOO_i)^2 / sigma^2_LOO_i
                        + log 2 pi ]
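A NumPy sketch of these identities, compared against brute-force refits that drop each point in turn and predict it from the remaining \(N-1\). Both are predictives for \(y_i\), so the variance includes the noise term. The kernel, noise level, and data are assumptions, and `np.linalg.inv` is used for the diagonal only because the example is tiny.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)

rng = np.random.default_rng(6)
x = np.sort(rng.uniform(0.0, 5.0, 12))
y = np.sin(x) + 0.2 * rng.standard_normal(12)
K_y = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 0.04 * np.eye(12)

# Closed form from one factorization of K_y.
K_inv = np.linalg.inv(K_y)
alpha = K_inv @ y
d = np.diag(K_inv)
loo_var = 1.0 / d
loo_mean = y - alpha / d
loo_ll = -0.5 * np.sum(np.log(loo_var) + (y - loo_mean) ** 2 / loo_var + LOG_2PI)

# Brute force: condition on the other N-1 points for each i.
bf_mean, bf_var = np.empty(12), np.empty(12)
for i in range(12):
    keep = np.arange(12) != i
    k_i = K_y[i, keep]
    w = np.linalg.solve(K_y[np.ix_(keep, keep)], k_i)
    bf_mean[i] = w @ y[keep]
    bf_var[i] = K_y[i, i] - k_i @ w

print(np.allclose(loo_mean, bf_mean), np.allclose(loo_var, bf_var))  # True True
```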

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| operator | AbstractLinearOperator | A linear operator representing the (noise-inclusive) kernel matrix K_y. | required |
| y | Float[Array, ' N'] | Observation vector of shape (N,). | required |
| solver | AbstractSolveStrategy \| None | Optional solve strategy for computing K_y^{-1} y and for the diag_inv computation. When None, falls back to structural dispatch. | None |
| diag_inv_method | str | Method passed to diag_inv. Defaults to "solve" so the LOO variances remain deterministic. | 'solve' |
| diag_inv_num_probes | int | Number of Hutchinson probes when diag_inv_method="hutchinson". | 30 |
| diag_inv_key | Array \| None | PRNG key for probe generation when diag_inv_method="hutchinson". | None |
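The `diag_inv_method` choice trades determinism for cost. A Hutchinson-style diagonal estimator (the Bekas et al. variant) approximates diag(A) as the probe average of z ⊙ (A z), using Rademacher vectors z. For A = K_y^{-1}, each probe costs one solve instead of N. The sketch below assumes the "hutchinson" option works roughly like this, which is not confirmed here. It contrasts the estimator with the exact column-by-column solve. The helper names are hypothetical.

```python
import numpy as np

def diag_inv_exact(K):
    """Exact diag(K^{-1}) by solving against every identity column."""
    return np.diag(np.linalg.solve(K, np.eye(K.shape[0])))

def diag_inv_hutchinson(K, num_probes, rng):
    """Stochastic diag(K^{-1}) estimate from Rademacher probes, one solve each."""
    N = K.shape[0]
    Z = rng.choice([-1.0, 1.0], size=(N, num_probes))
    KinvZ = np.linalg.solve(K, Z)        # batched solves, one per probe column
    return np.mean(Z * KinvZ, axis=1)    # z_i**2 == 1, so no extra normalization

rng = np.random.default_rng(2)
N = 5
B = rng.normal(size=(N, N))
K = np.eye(N) + 0.05 * (B @ B.T)        # well-conditioned SPD test matrix

exact = diag_inv_exact(K)
few = diag_inv_hutchinson(K, num_probes=30, rng=rng)
many = diag_inv_hutchinson(K, num_probes=20000, rng=rng)
```

The 30-probe estimate is visibly noisy. Its error shrinks roughly as 1/sqrt(num_probes), which is why a deterministic solve is the safer default whenever N solves are affordable.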

Returns:

| Type | Description |
| --- | --- |
| LOOResult | A LOOResult containing the LOO log-likelihood, predictive means, and predictive variances. |

Source code in src/gaussx/_gp/_loo.py
def leave_one_out_cv(
    operator: lx.AbstractLinearOperator,
    y: Float[Array, " N"],
    *,
    solver: AbstractSolveStrategy | None = None,
    diag_inv_method: str = "solve",
    diag_inv_num_probes: int = 30,
    diag_inv_key: jax.Array | None = None,
) -> LOOResult:
    """LOO-CV via the bordered-system identity.

    Computes leave-one-out predictive means, variances, and
    log-likelihood without refitting the model N times.

    Math:

        alpha = K_y^{-1} y
        mu_LOO_i   = y_i - alpha_i / [K_y^{-1}]_{ii}
        sigma^2_LOO_i = 1 / [K_y^{-1}]_{ii}
        LOO-CV = -(1/2) sum_i [ log sigma^2_LOO_i
                                + (y_i - mu_LOO_i)^2 / sigma^2_LOO_i
                                + log 2 pi ]

    Args:
        operator: A linear operator representing the (noise-inclusive)
            kernel matrix K_y.
        y: Observation vector of shape ``(N,)``.
        solver: Optional solve strategy for computing K_y^{-1} y
            and for the ``diag_inv`` computation. When ``None``,
            falls back to structural dispatch.
        diag_inv_method: Method passed to `diag_inv`. Defaults to
            ``"solve"`` so the LOO variances remain deterministic.
        diag_inv_num_probes: Number of Hutchinson probes when
            ``diag_inv_method="hutchinson"``.
        diag_inv_key: PRNG key for probe generation when
            ``diag_inv_method="hutchinson"``.

    Returns:
        A `LOOResult` containing the LOO log-likelihood,
        predictive means, and predictive variances.
    """
    alpha = dispatch_solve(operator, y, solver)
    diag_Kinv = diag_inv(
        operator,
        method=diag_inv_method,
        num_probes=diag_inv_num_probes,
        key=diag_inv_key,
        solver=solver,
    )

    loo_means = y - alpha / diag_Kinv
    loo_variances = 1.0 / diag_Kinv

    # (y_i - mu_i)^2 / sigma^2_i simplifies to alpha_i^2 / diag_Kinv_i
    loo_ll = -0.5 * jnp.sum(
        jnp.log(loo_variances) + alpha**2 / diag_Kinv + jnp.log(2.0 * jnp.pi)
    )

    return LOOResult(
        loo_log_likelihood=loo_ll,
        loo_means=loo_means,
        loo_variances=loo_variances,
    )
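The LOO log-likelihood needs only one solve and one diagonal, so it is cheap enough to scan over hyperparameters. The sketch below mirrors the body of `leave_one_out_cv` in plain NumPy on a dense K_y, without gaussx operators or solver strategies. It uses the LOO score to pick an observation-noise variance from a small grid. The RBF kernel, the grid, and the synthetic data are illustrative choices, not anything the library prescribes.

```python
import numpy as np

def loo_log_likelihood(K_y, y):
    """Dense mirror of leave_one_out_cv's LOO log-likelihood."""
    K_inv = np.linalg.inv(K_y)
    alpha = K_inv @ y
    d = np.diag(K_inv)
    return -0.5 * np.sum(np.log(1.0 / d) + alpha**2 / d + np.log(2 * np.pi))

rng = np.random.default_rng(3)
X = np.linspace(-3.0, 3.0, 40)
y = np.sin(X) + 0.2 * rng.normal(size=X.size)   # true noise variance 0.04

K_f = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2)  # unit-lengthscale RBF
noise_grid = np.array([1e-3, 1e-2, 4e-2, 1e-1, 1.0])
scores = np.array(
    [loo_log_likelihood(K_f + s * np.eye(X.size), y) for s in noise_grid]
)
best_noise = noise_grid[np.argmax(scores)]
```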

Multi-output projections

The orthogonal instantaneous linear mixing model (OILMM): project multi-output observations into independent latent processes and back.


oilmm_project(Y: Float[Array, 'N P'], W: Float[Array, 'P L'], noise_var: Float[Array, ' P'] | float) -> tuple[Float[Array, 'N L'], Float[Array, ' L']]

Project multi-output data to independent latent GPs via OILMM.

Given an orthogonal mixing matrix W ∈ ℝᴾˣᴸ with WᵀW = I_L, projects P-output observations to L independent latent channels:

\[ Y_{\text{latent}} = Y W \in \mathbb{R}^{N \times L}, \qquad \sigma^2_{\text{latent}} = (W \odot W)^\top \sigma^2 \in \mathbb{R}^{L} \]
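Here is a quick NumPy check of the two projection formulas. It builds W with orthonormal columns from a reduced QR factorization, though any W with WᵀW = I_L would do. For isotropic noise, every projected variance equals σ². The reason is that (W ⊙ W)ᵀ 1 sums each column's squared entries, and each column has unit norm. The `project` helper is a hypothetical mirror of `oilmm_project`, not the library function.

```python
import numpy as np

rng = np.random.default_rng(4)
N, P, L = 10, 4, 2
W, _ = np.linalg.qr(rng.normal(size=(P, L)))   # (P, L) with orthonormal columns
Y = rng.normal(size=(N, P))

def project(Y, W, noise_var):
    """Hypothetical NumPy mirror of oilmm_project."""
    noise_var = np.broadcast_to(np.asarray(noise_var, dtype=float), (W.shape[0],))
    return Y @ W, (W**2).T @ noise_var

Y_lat, noise_iso = project(Y, W, 0.3)                          # isotropic noise
_, noise_het = project(Y, W, np.array([0.1, 0.2, 0.3, 0.4]))  # per-output noise
```

With heteroscedastic noise, each latent variance is a convex combination of the per-output variances, so it lies between their minimum and maximum.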

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| Y | Float[Array, 'N P'] | Observations, shape (N, P). | required |
| W | Float[Array, 'P L'] | Orthogonal mixing matrix, shape (P, L) with WᵀW = I_L. | required |
| noise_var | Float[Array, ' P'] \| float | Observation noise variance. Scalar for isotropic noise, or shape (P,) for heteroscedastic noise. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, 'N L'], Float[Array, ' L']] | Tuple (Y_latent, noise_latent) with shapes (N, L) and (L,). |

Source code in src/gaussx/_gp/_oilmm.py
def oilmm_project(
    Y: Float[Array, "N P"],
    W: Float[Array, "P L"],
    noise_var: Float[Array, " P"] | float,
) -> tuple[Float[Array, "N L"], Float[Array, " L"]]:
    """Project multi-output data to independent latent GPs via OILMM.

    Given an orthogonal mixing matrix W ∈ ℝᴾˣᴸ with WᵀW = I_L, projects
    P-output observations to L independent latent channels:

        Y_latent    = Y W              (N, L)
        σ²_latent   = (W ⊙ W)ᵀ σ²     (L,)

    Args:
        Y: Observations, shape ``(N, P)``.
        W: Orthogonal mixing matrix, shape ``(P, L)`` with WᵀW = I_L.
        noise_var: Observation noise variance. Scalar for isotropic noise,
            or shape ``(P,)`` for heteroscedastic noise.

    Returns:
        Tuple ``(Y_latent, noise_latent)`` with shapes ``(N, L)``
        and ``(L,)``.
    """
    Y_latent = Y @ W  # (N, L)
    noise_var = jnp.broadcast_to(jnp.asarray(noise_var), (W.shape[0],))  # (P,)
    noise_latent = (W**2).T @ noise_var  # (L,)
    return Y_latent, noise_latent
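The returned `noise_latent` is the diagonal of the projected noise covariance WᵀΣW, where Σ = diag(σ²). For isotropic noise that matrix is exactly σ² I_L, so the latent channels are independent. For heteroscedastic noise it generally has off-diagonal terms, and the returned (L,) vector does not carry them. The NumPy sketch below makes this visible. It illustrates the math only and is not part of the gaussx API.

```python
import numpy as np

rng = np.random.default_rng(5)
P, L = 5, 3
W, _ = np.linalg.qr(rng.normal(size=(P, L)))   # orthonormal columns: W^T W = I_L
sigma2 = np.array([0.1, 0.5, 0.2, 0.9, 0.3])   # heteroscedastic per-output noise

full_cov = W.T @ np.diag(sigma2) @ W            # exact latent noise covariance (L, L)
noise_latent = (W**2).T @ sigma2                # same formula as oilmm_project, (L,)

iso_cov = W.T @ (0.4 * np.eye(P)) @ W           # isotropic case collapses to 0.4 * I_L
off_diag = full_cov - np.diag(np.diag(full_cov))
```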

oilmm_back_project(f_means: Float[Array, 'N L'], f_vars: Float[Array, 'N L'], W: Float[Array, 'P L']) -> tuple[Float[Array, 'N P'], Float[Array, 'N P']]

Back-project latent GP predictions to the observation space.

Reconstructs observation-space predictions via:

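\[ y_{\text{means}} = f_{\text{means}} W^\top \in \mathbb{R}^{N \times P}, \qquad y_{\text{vars}} = f_{\text{vars}} (W \odot W)^\top \in \mathbb{R}^{N \times P} \]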
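Because WᵀW = I_L, projecting and then back-projecting recovers any Y that lies in the column span of W, that is, Y = F Wᵀ. Components outside that span are discarded, since W Wᵀ is the orthogonal projector onto span(W). Below is a NumPy sketch of that round trip, again taking W from a QR factorization for illustration.

```python
import numpy as np

rng = np.random.default_rng(6)
N, P, L = 7, 5, 2
W, _ = np.linalg.qr(rng.normal(size=(P, L)))

F = rng.normal(size=(N, L))          # latent signals
Y_in_span = F @ W.T                  # observations generated by the mixing model
Y_noisy = Y_in_span + 0.1 * rng.normal(size=(N, P))

recovered = (Y_in_span @ W) @ W.T    # project, then back-project the means
filtered = (Y_noisy @ W) @ W.T       # orthogonal projection of Y_noisy onto span(W)
```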

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_means | Float[Array, 'N L'] | Latent predictive means, shape (N, L). | required |
| f_vars | Float[Array, 'N L'] | Latent predictive variances, shape (N, L). | required |
| W | Float[Array, 'P L'] | Orthogonal mixing matrix, shape (P, L) with WᵀW = I_L. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, 'N P'], Float[Array, 'N P']] | Tuple (y_means, y_vars) with shapes (N, P) and (N, P). |

Source code in src/gaussx/_gp/_oilmm.py
def oilmm_back_project(
    f_means: Float[Array, "N L"],
    f_vars: Float[Array, "N L"],
    W: Float[Array, "P L"],
) -> tuple[Float[Array, "N P"], Float[Array, "N P"]]:
    """Back-project latent GP predictions to the observation space.

    Reconstructs observation-space predictions via:

        y_means = f_means Wᵀ              (N, P)
        y_vars  = f_vars (W ⊙ W)ᵀ        (N, P)

    Args:
        f_means: Latent predictive means, shape ``(N, L)``.
        f_vars: Latent predictive variances, shape ``(N, L)``.
        W: Orthogonal mixing matrix, shape ``(P, L)`` with WᵀW = I_L.

    Returns:
        Tuple ``(y_means, y_vars)`` with shapes ``(N, P)`` and ``(N, P)``.
    """
    y_means = f_means @ W.T  # (N, P)
    y_vars = f_vars @ (W**2).T  # (N, P)
    return y_means, y_vars
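Suppose the latent predictions at a point are independent with variances v ∈ ℝᴸ. Then y = W f has covariance W diag(v) Wᵀ, and `y_vars` is exactly its diagonal. These are latent-function variances mapped to the outputs; the function adds no observation noise. The NumPy check below uses illustrative random data and does not call the gaussx API.

```python
import numpy as np

rng = np.random.default_rng(7)
N, P, L = 4, 3, 2
W, _ = np.linalg.qr(rng.normal(size=(P, L)))
f_means = rng.normal(size=(N, L))
f_vars = rng.uniform(0.1, 1.0, size=(N, L))

y_means = f_means @ W.T              # same formulas as oilmm_back_project
y_vars = f_vars @ (W**2).T

# Diagonal of the full output covariance W diag(v) W^T, one test point per row.
full_diag = np.stack([np.diag(W @ np.diag(v) @ W.T) for v in f_vars])
```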