Gaussian Processes
Layer 3 recipes for GP inference: conditioning, whitening, prediction caches, pathwise (Matheron) sampling, variational bounds, and cross-validation — all expressed over covariance operators so structured kernels keep their fast paths end to end. The modelling shell (kernels with hyperparameter priors, NumPyro sites) lives downstream; gaussx owns the math.
Conditioning & prediction
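The standard posterior

\[\mu_* = K_{*f} K_y^{-1} y, \qquad \Sigma_* = K_{**} - K_{*f} K_y^{-1} K_{f*},\]

with \(K_y\) the noise-inclusive training covariance, plus a precomputed-cache variant for repeated test-time queries and a Kronecker-structured path for separable kernels on grids.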
Structured linear algebra and Gaussian primitives for JAX.
PredictionCache
Bases: Module
Cached training solve for amortized predictions.
Stores alpha = K_y^{-1} y so that downstream predictions only require a matrix-vector product rather than a fresh solve.
Attributes:
| Name | Type | Description |
|---|---|---|
| alpha | Float[Array, ' N'] | Solved weights |
Source code in src/gaussx/_gp/_prediction_cache.py
base_conditional(K_mm: Float[Array, 'M M'], K_mn: Float[Array, 'M N'], K_nn: Float[Array, 'N N'] | Float[Array, ' N'], f: Float[Array, 'M R'], *, q_sqrt: Float[Array, 'R M M'] | Float[Array, 'M R'] | None = None, white: bool = False, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, 'N R'], Float[Array, ...]]
Gaussian conditional distribution via Schur complement.
Computes the conditional distribution q(f_* | u) given:
- Prior covariance K_mm at inducing locations
- Cross-covariance K_mn between inducing and test locations
- Prior (co)variance K_nn at test locations (full or diagonal)
- Inducing function values f (or whitened values if white=True)
- Optional variational posterior q(u) = N(f, q_sqrt q_sqrt^T)
The conditional mean is:
mu = K_nm K_mm^{-1} f (or K_nm L_mm^{-T} f if white)
The conditional covariance is:
Sigma = K_nn - K_nm K_mm^{-1} K_mn + K_nm K_mm^{-1} S K_mm^{-1} K_mn
where S = q_sqrt @ q_sqrt^T is the variational covariance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_mm | Float[Array, 'M M'] | Prior covariance at inducing points, shape (M, M). | required |
| K_mn | Float[Array, 'M N'] | Cross-covariance, shape (M, N). | required |
| K_nn | Float[Array, 'N N'] \| Float[Array, ' N'] | Test-point covariance: full (N, N) or diagonal (N,). | required |
| f | Float[Array, 'M R'] | Inducing function values, shape (M, R). | required |
| q_sqrt | Float[Array, 'R M M'] \| Float[Array, 'M R'] \| None | Optional variational Cholesky factor. Full: (R, M, M); diagonal: (M, R). | None |
| white | bool | If True, f is interpreted as whitened inducing values. | False |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N R'] | Conditional mean, shape (N, R). |
| Float[Array, ...] | Conditional (co)variance; its shape depends on whether K_nn is full or diagonal. |
Source code in src/gaussx/_gp/_base_conditional.py
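The mean and covariance formulas above can be checked numerically. The following is a minimal NumPy sketch for a single output column, not the gaussx implementation; the random test matrices and all variable names are illustrative assumptions. It evaluates the unwhitened form and the whitened form (f and q_sqrt premultiplied by L_mm^{-1}) and the two should agree.

```python
import numpy as np

rng = np.random.default_rng(1)
M, N = 4, 3

# Joint prior covariance over inducing and test points (SPD by construction).
B = rng.standard_normal((M + N, M + N))
K = B @ B.T + 1e-3 * np.eye(M + N)
K_mm, K_mn, K_nn = K[:M, :M], K[:M, M:], K[M:, M:]

f = rng.standard_normal(M)                      # inducing mean, one output
q_sqrt = np.tril(rng.standard_normal((M, M)))   # variational Cholesky factor
S = q_sqrt @ q_sqrt.T                           # variational covariance

# Unwhitened form: A = K_mm^{-1} K_mn.
A = np.linalg.solve(K_mm, K_mn)
mu = A.T @ f
Sigma = K_nn - K_mn.T @ A + A.T @ S @ A

# Whitened form: parameters live in the basis defined by L_mm.
L_mm = np.linalg.cholesky(K_mm)
f_white = np.linalg.solve(L_mm, f)
q_sqrt_white = np.linalg.solve(L_mm, q_sqrt)
A_w = np.linalg.solve(L_mm, K_mn)               # L_mm^{-1} K_mn
mu_w = A_w.T @ f_white
Sigma_w = K_nn - A_w.T @ A_w + A_w.T @ (q_sqrt_white @ q_sqrt_white.T) @ A_w
```

A useful sanity check on the covariance formula: setting S = K_mm makes the last two terms cancel, so the conditional covariance falls back to the prior K_nn.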
build_prediction_cache(operator: lx.AbstractLinearOperator, y: Float[Array, ' N'], *, solver: AbstractSolveStrategy | None = None) -> PredictionCache
Solve A alpha = y and cache the result.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Training covariance operator | required |
| y | Float[Array, ' N'] | Training targets, shape (N,). | required |
| solver | AbstractSolveStrategy \| None | Optional solve strategy. When … | None |
Returns:
| Type | Description |
|---|---|
| PredictionCache | A PredictionCache holding the solved weights alpha. |
Source code in src/gaussx/_gp/_prediction_cache.py
predict_mean(cache: PredictionCache, K_cross: Float[Array, 'Nt N']) -> Float[Array, ' Nt']
Predictive mean: mu* = K_*f @ alpha.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cache | PredictionCache | Prediction cache from build_prediction_cache. | required |
| K_cross | Float[Array, 'Nt N'] | Cross-covariance matrix, shape (Nt, N). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' Nt'] | Predictive mean, shape (Nt,). |
Source code in src/gaussx/_gp/_prediction_cache.py
predict_variance(K_cross: Float[Array, 'Nt N'], K_test_diag: Float[Array, ' Nt'], operator: lx.AbstractLinearOperator, *, solver: AbstractSolveStrategy | None = None) -> Float[Array, ' Nt']
Predictive variance: sigma^2* = k_** - diag(K_*f K_y^{-1} K_f*).
For each test point i, solves K_y v_i = K_cross[i, :] and computes sigma^2_i = K_test_diag[i] - K_cross[i, :] @ v_i.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_cross | Float[Array, 'Nt N'] | Cross-covariance matrix, shape (Nt, N). | required |
| K_test_diag | Float[Array, ' Nt'] | Prior variance at test points, shape (Nt,). | required |
| operator | AbstractLinearOperator | Training covariance operator | required |
| solver | AbstractSolveStrategy \| None | Optional solve strategy. When … | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' Nt'] | Predictive variance, shape (Nt,). |
Source code in src/gaussx/_gp/_prediction_cache.py
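The cache, mean, and variance steps described above fit together as follows. This is a minimal dense NumPy sketch of the math only; it does not call gaussx, and the `rbf` helper, the toy data, and the noise level are assumptions made for illustration.

```python
import numpy as np

def rbf(xa, xb, lengthscale=0.5):
    """Unit-variance squared-exponential kernel between two 1-D input arrays."""
    d = xa[:, None] - xb[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

rng = np.random.default_rng(0)
x_train = np.linspace(0.0, 1.0, 8)
x_test = np.array([0.15, 0.5, 0.85])
y = np.sin(2 * np.pi * x_train) + 0.1 * rng.standard_normal(8)

K_y = rbf(x_train, x_train) + 0.01 * np.eye(8)   # noise-inclusive training covariance
K_cross = rbf(x_test, x_train)                   # shape (Nt, N)
K_test_diag = np.ones(3)                         # k(x, x) = 1 for this kernel

# Cache step: a single solve against the training covariance.
alpha = np.linalg.solve(K_y, y)

# Mean: one matrix-vector product, reusable for any batch of test points.
mean = K_cross @ alpha

# Variance: solve K_y v_i = K_cross[i, :] for every test row,
# then take the row-wise dot product K_cross[i, :] @ v_i.
V = np.linalg.solve(K_y, K_cross.T)              # shape (N, Nt)
variance = K_test_diag - np.sum(K_cross * V.T, axis=1)
```

Only the mean benefits from the cached alpha; the variance still needs a solve per test point, which is the cost the LOVE cache in the cross-validation section is meant to amortize.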
conditional_interpolate(A_fwd: Float[Array, 'd d'], Q_fwd: Float[Array, 'd d'], A_bwd: Float[Array, 'd d'], Q_bwd: Float[Array, 'd d'], mu_prev: Float[Array, ' d'], P_prev: Float[Array, 'd d'], mu_next: Float[Array, ' d'], P_next: Float[Array, 'd d'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' d'], Float[Array, 'd d']]
Interpolated marginal at time t given posteriors at t^- and t^+.
For an SDE-discretized state-space model:
x_t | x_{t^-} \sim N(A_{fwd} x_{t^-}, Q_{fwd})
x_{t^+} | x_t \sim N(A_{bwd} x_t, Q_{bwd})
computes p(x_t | x_{t^-}, x_{t^+}) using information fusion
of the forward and backward predictions:
m_{fwd} = A_{fwd} \mu_{prev}
\Lambda_{fwd} = (A_{fwd} P_{prev} A_{fwd}^T + Q_{fwd})^{-1}
\Lambda_{bwd} = A_{bwd}^T (P_{next} + Q_{bwd})^{-1} A_{bwd}
\eta_{1,fwd} = \Lambda_{fwd} m_{fwd}
\eta_{1,bwd} = A_{bwd}^T (P_{next} + Q_{bwd})^{-1} \mu_{next}
\Lambda = \Lambda_{fwd} + \Lambda_{bwd}
P = \Lambda^{-1}
\mu = P (\eta_{1,fwd} + \eta_{1,bwd})
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A_fwd | Float[Array, 'd d'] | Forward transition from t^- to t, shape (d, d). | required |
| Q_fwd | Float[Array, 'd d'] | Forward process noise, shape (d, d). | required |
| A_bwd | Float[Array, 'd d'] | Backward transition from t to t^+, shape (d, d). | required |
| Q_bwd | Float[Array, 'd d'] | Backward process noise, shape (d, d). | required |
| mu_prev | Float[Array, ' d'] | Marginal mean at t^-, shape (d,). | required |
| P_prev | Float[Array, 'd d'] | Marginal covariance at t^-, shape (d, d). | required |
| mu_next | Float[Array, ' d'] | Marginal mean at t^+, shape (d,). | required |
| P_next | Float[Array, 'd d'] | Marginal covariance at t^+, shape (d, d). | required |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' d'], Float[Array, 'd d']] | Tuple (mu, P): interpolated mean, shape (d,), and covariance, shape (d, d). |
Source code in src/gaussx/_gp/_interpolation.py
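The information-fusion equations above translate almost line for line into code. The sketch below is a dense NumPy illustration under the assumption that the forward predicted mean is A_fwd @ mu_prev; the function name `interpolate` and the toy constant-velocity model are hypothetical, and explicit inverses are used only for readability.

```python
import numpy as np

def interpolate(A_fwd, Q_fwd, A_bwd, Q_bwd, mu_prev, P_prev, mu_next, P_next):
    """Fuse the forward prediction and the backward message in information form."""
    # Forward prediction from t^- to t.
    m_fwd = A_fwd @ mu_prev
    Lam_fwd = np.linalg.inv(A_fwd @ P_prev @ A_fwd.T + Q_fwd)
    # Backward message from t^+, pulled back through A_bwd.
    R_inv = np.linalg.inv(P_next + Q_bwd)
    Lam_bwd = A_bwd.T @ R_inv @ A_bwd
    # Sum precisions and precision-weighted means, then convert back.
    eta = Lam_fwd @ m_fwd + A_bwd.T @ R_inv @ mu_next
    P = np.linalg.inv(Lam_fwd + Lam_bwd)
    return P @ eta, P

dt = 0.1
A = np.array([[1.0, dt], [0.0, 1.0]])   # constant-velocity transition
Q = 0.05 * np.eye(2)
mu_prev, P_prev = np.array([0.0, 1.0]), 0.2 * np.eye(2)
mu_next, P_next = np.array([0.3, 0.8]), 0.3 * np.eye(2)

mu, P = interpolate(A, Q, A, Q, mu_prev, P_prev, mu_next, P_next)
```

By the Woodbury identity this is the same result as a Kalman-style update that treats mu_next as an observation of A_bwd x_t with noise covariance P_next + Q_bwd.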
kronecker_posterior_predictive(K_factors: list[lx.AbstractLinearOperator], y: Float[Array, ' N'], noise_var: float, grid_shape: tuple[int, ...], K_cross_factors: list[Float[Array, 'Ni_test Ni_train']], *, K_test_diag_factors: list[Float[Array, ' Ni_test']]) -> tuple[Float[Array, ' N_test'], Float[Array, ' N_test']]
Posterior mean and variance for a Kronecker GP at test points.
Uses the eigendecomposition trick: projects cross-covariances onto the eigenbasis and weights by inverse eigenvalues:
mu_* = K_{*f} (K_{ff} + sigma^2 I)^{-1} y
var_* = k_{**} - K_{*f} (K_{ff} + sigma^2 I)^{-1} K_{f*}
Both computed via per-factor eigendecomposition in
O(sum n_i^3 + N + N_test) time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_factors | list[AbstractLinearOperator] | List of per-dimension training kernel operators. | required |
| y | Float[Array, ' N'] | Observations, shape (N,). | required |
| noise_var | float | Observation noise variance sigma^2. | required |
| grid_shape | tuple[int, ...] | Grid shape (n_1, n_2, ...), where N = prod n_i. | required |
| K_cross_factors | list[Float[Array, 'Ni_test Ni_train']] | Per-dimension cross-covariance matrices, each shape (Ni_test, Ni_train). | required |
| K_test_diag_factors | list[Float[Array, ' Ni_test']] | Per-dimension prior diagonals at the test points, each shape (Ni_test,). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' N_test'], Float[Array, ' N_test']] | Tuple (mean, variance), each of shape (N_test,). |
Source code in src/gaussx/_gp/_kronecker_gp.py
kronecker_mll(K_factors: list[lx.AbstractLinearOperator], y: Float[Array, ' N'], noise_var: float, grid_shape: tuple[int, ...]) -> Float[Array, '']
Exact marginal log-likelihood for a Kronecker-structured GP.
For a GP whose prior covariance is K = K_1 \otimes K_2 \otimes \ldots, observed with noise variance sigma^2,
computes the log marginal likelihood via per-factor eigendecomposition:
log p(y) = -0.5 * (y^T (K + sigma^2 I)^{-1} y
+ log|K + sigma^2 I|
+ N log(2 pi))
The Kronecker eigendecomposition avoids forming the full N x N matrix:
if K_i = Q_i Lambda_i Q_i^T, the combined eigenvalues are the outer
products of per-factor eigenvalues and the combined eigenvectors are the
Kronecker product of per-factor eigenvectors.
Complexity: O(sum n_i^3 + N) instead of O(N^3) where N = prod n_i.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_factors | list[AbstractLinearOperator] | List of per-dimension kernel operators. Each must be square and symmetric. | required |
| y | Float[Array, ' N'] | Observations, shape (N,). | required |
| noise_var | float | Observation noise variance sigma^2. | required |
| grid_shape | tuple[int, ...] | Shape of the grid (n_1, n_2, ...), where N = prod n_i. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar log marginal likelihood. |
Source code in src/gaussx/_gp/_kronecker_gp.py
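The eigendecomposition trick is easy to verify on a small two-factor grid. The following NumPy sketch is illustrative rather than the library code; the helper `random_spd`, the factor sizes, and the row-major layout of y on the grid are assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

def random_spd(n):
    """Random symmetric positive-definite matrix standing in for a kernel factor."""
    B = rng.standard_normal((n, n))
    return B @ B.T + np.eye(n)

n1, n2 = 3, 4
K1, K2 = random_spd(n1), random_spd(n2)
noise_var = 0.1
y = rng.standard_normal(n1 * n2)

# Per-factor eigendecompositions: K_i = Q_i diag(lam_i) Q_i^T.
lam1, Q1 = np.linalg.eigh(K1)
lam2, Q2 = np.linalg.eigh(K2)

# Eigenvalues of K1 (x) K2, laid out on the grid as an outer product.
lam = np.outer(lam1, lam2)                  # shape (n1, n2)

# Project y onto the Kronecker eigenbasis without forming Q1 (x) Q2:
# for row-major reshaping, (Q1 (x) Q2)^T y equals (Q1^T Y Q2) flattened.
Y = y.reshape(n1, n2)
Y_tilde = Q1.T @ Y @ Q2

quad = np.sum(Y_tilde**2 / (lam + noise_var))      # y^T (K + sigma^2 I)^{-1} y
logdet = np.sum(np.log(lam + noise_var))           # log|K + sigma^2 I|
mll = -0.5 * (quad + logdet + y.size * np.log(2 * np.pi))
```

The same projected quantities give the posterior mean and variance in kronecker_posterior_predictive, since both only need the inverse applied in the eigenbasis.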
Pathwise sampling
Matheron's rule turns joint prior draws \((a, b)\) into posterior draws: \(a + \mathrm{Cov}(a,b)\,\mathrm{Cov}(b,b)^{-1}(\beta - b)\).
matheron_update(prior_sample_target: Float[Array, 'S N_star'], prior_sample_conditioning: Float[Array, 'S M'], observed_value: Float[Array, ' M'], cross_covariance: lx.AbstractLinearOperator, conditioning_covariance: lx.AbstractLinearOperator, *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'S N_star']
Posterior samples via Matheron's-rule correction.
Given joint prior draws (a, b) and an observed conditioning value β, Matheron's rule samples from a | b = β by applying
a + Cov(a, b) Cov(b, b)^{-1} (β - b)
to each draw. This helper keeps both covariance arguments as lineax operators, so the conditioning solve uses the existing GaussX structural dispatch and the target correction is a rectangular matvec.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prior_sample_target | Float[Array, 'S N_star'] | Prior samples at target points, shape (S, N_star). | required |
| prior_sample_conditioning | Float[Array, 'S M'] | Joint prior samples at conditioning points, shape (S, M). | required |
| observed_value | Float[Array, ' M'] | Observed conditioning value, shape (M,). | required |
| cross_covariance | AbstractLinearOperator | Cross-covariance operator Cov(a, b). | required |
| conditioning_covariance | AbstractLinearOperator | Conditioning covariance operator Cov(b, b). | required |
| solver | AbstractSolveStrategy \| None | Optional solver strategy for the conditioning solve (e.g. …). | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'S N_star'] | Corrected posterior samples, shape (S, N_star). |
Source code in src/gaussx/_gp/_matheron.py
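The Matheron correction is a single batched solve followed by a rectangular matvec. Below is a minimal dense NumPy sketch, assuming a zero-mean joint prior and a randomly generated joint covariance; it illustrates the update rule and is not the gaussx helper itself.

```python
import numpy as np

rng = np.random.default_rng(3)
N_star, M, S = 3, 4, 5

# Joint prior covariance of (a, b), SPD by construction.
B = rng.standard_normal((N_star + M, N_star + M))
joint_cov = B @ B.T + 1e-3 * np.eye(N_star + M)
cov_ab = joint_cov[:N_star, N_star:]        # Cov(a, b), shape (N_star, M)
cov_bb = joint_cov[N_star:, N_star:]        # Cov(b, b), shape (M, M)

# Joint zero-mean prior draws, one per row.
L = np.linalg.cholesky(joint_cov)
joint = rng.standard_normal((S, N_star + M)) @ L.T
a, b = joint[:, :N_star], joint[:, N_star:]

beta = rng.standard_normal(M)               # observed value of b

# Matheron update, batched over the S draws:
#   a + Cov(a, b) Cov(b, b)^{-1} (beta - b)
residual = beta[None, :] - b                        # shape (S, M)
weights = np.linalg.solve(cov_bb, residual.T)       # shape (M, S)
posterior_samples = a + (cov_ab @ weights).T        # shape (S, N_star)

# Sanity check: the same update applied to b itself pins every draw to beta.
b_updated = b + (cov_bb @ weights).T
```

The prior draws a and b must come from the same joint sample; pairing independent draws would give the right mean but the wrong posterior covariance.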
Whitening
The whitened parameterization \(u = Lv\), \(v \sim \mathcal{N}(0, I)\) that keeps sparse-variational optimization well-conditioned, and the whitened SVGP predictive that consumes it.
whiten_covariance = unwhiten_covariance (module attribute)
unwhiten(m_tilde: Float[Array, ' M'], L: lx.AbstractLinearOperator) -> Float[Array, ' M']
Unwhiten variational mean: m = L @ m_tilde.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| m_tilde | Float[Array, ' M'] | Whitened mean vector, shape (M,). | required |
| L | AbstractLinearOperator | Cholesky factor, shape (M, M). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' M'] | Unwhitened mean m, shape (M,). |
Source code in src/gaussx/_gp/_unwhiten.py
unwhiten_covariance(L: lx.AbstractLinearOperator, S_tilde: lx.AbstractLinearOperator) -> lx.MatrixLinearOperator
Unwhiten variational covariance: S = L S̃ Lᵀ.
Delegates to cov_transform.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| L | AbstractLinearOperator | Cholesky factor, shape (M, M). | required |
| S_tilde | AbstractLinearOperator | Whitened variational covariance, shape (M, M). | required |
Returns:
| Type | Description |
|---|---|
| MatrixLinearOperator | Unwhitened covariance operator S. |
Source code in src/gaussx/_gp/_unwhiten.py
whitened_svgp_predict(K_zz_op: lx.AbstractLinearOperator, K_xz: Float[Array, 'N M'], u_mean: Float[Array, ' M'], u_chol: Float[Array, 'M M'], K_xx_diag: Float[Array, ' N']) -> tuple[Float[Array, ' N'], Float[Array, ' N']]
Whitened SVGP prediction: mean and variance at test points.
Computes the predictive mean and variance for a sparse variational GP using the whitened parameterization:
L_{zz} = cholesky(K_{zz})
A = L_{zz}^{-1} K_{zx} (triangular solve)
f_{loc} = A^T u_{mean}
Q_{xx} = sum(A^2, axis=0) (prior variance reduction)
W = u_{chol}^T A
S_{contrib} = sum(W^2, axis=0) (posterior variance contribution)
f_{var} = K_{xx,diag} - Q_{xx} + S_{contrib}
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_zz_op | AbstractLinearOperator | Inducing-point covariance operator, shape (M, M). | required |
| K_xz | Float[Array, 'N M'] | Cross-covariance matrix, shape (N, M). | required |
| u_mean | Float[Array, ' M'] | Whitened variational mean, shape (M,). | required |
| u_chol | Float[Array, 'M M'] | Whitened variational Cholesky factor, shape (M, M). | required |
| K_xx_diag | Float[Array, ' N'] | Prior diagonal variances at test points, shape (N,). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' N'], Float[Array, ' N']] | Tuple (f_loc, f_var): predictive mean, shape (N,), and predictive variance, shape (N,). |
Source code in src/gaussx/_gp/_svgp.py
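The whitened recipe above can be written out step by step. This is a hedged NumPy sketch on random test matrices, not the gaussx code: `np.linalg.solve` stands in for a triangular solve, and all inputs are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(4)
M, N = 4, 6

# Joint prior covariance over inducing (first M) and test (last N) points.
B = rng.standard_normal((M + N, M + N))
K = B @ B.T + 1e-3 * np.eye(M + N)
K_zz, K_xz, K_xx_diag = K[:M, :M], K[M:, :M], np.diag(K[M:, M:])

u_mean = rng.standard_normal(M)                   # whitened variational mean
u_chol = np.tril(rng.standard_normal((M, M)))     # whitened Cholesky factor

L_zz = np.linalg.cholesky(K_zz)
A = np.linalg.solve(L_zz, K_xz.T)                 # L_zz^{-1} K_zx, shape (M, N)
f_loc = A.T @ u_mean                              # predictive mean
Q_xx = np.sum(A**2, axis=0)                       # prior variance reduction
W = u_chol.T @ A
S_contrib = np.sum(W**2, axis=0)                  # posterior variance contribution
f_var = K_xx_diag - Q_xx + S_contrib
```

Unwhitening with m = L_zz @ u_mean and S = L_zz u_chol u_chol^T L_zz^T and plugging into the unwhitened conditional gives the same mean and variance, which is the relationship unwhiten and unwhiten_covariance encode.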
svgp_variance_adjustment(K_zz_op: lx.AbstractLinearOperator, S_u: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator
Compute the SVGP variance adjustment operator.
Builds the operator Q = K_{zz}^{-1} S_u K_{zz}^{-1} - K_{zz}^{-1}
which appears in every sparse GP predictive variance computation:
Var[f_*] = k_{**} + k_{*z} Q k_{z*}
The returned value is exposed as a linear operator, but the current
implementation materializes dense (M, M) intermediates while building
it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_zz_op | AbstractLinearOperator | Inducing-point covariance operator, shape (M, M). | required |
| S_u | AbstractLinearOperator | Variational covariance operator, shape (M, M). | required |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | Operator Q. |
Source code in src/gaussx/_gp/_svgp_variance.py
Variational bounds & KL
ELBOs for Gaussian and Monte-Carlo variational families, the collapsed (Titsias) sparse bound, and the Gaussian-to-Gaussian KL term.
variational_elbo_gaussian(y: Float[Array, ' N'], f_loc: Float[Array, ' N'], f_var: Float[Array, ' N'], noise_var: float, kl: Float[Array, '']) -> Float[Array, '']
Variational ELBO for Gaussian likelihoods (the uncollapsed bound; see collapsed_elbo for the Titsias bound).
Computes:
ELBO = E_q[log p(y|f)] - KL(q||p)
where the expected log-likelihood under a Gaussian variational distribution with diagonal variance has the closed form:
E_q[log N(y|f, sigma^2 I)]
= -0.5 * N * log(2 pi sigma^2)
-0.5 / sigma^2 * (||y - f_loc||^2 + sum(f_var))
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | Float[Array, ' N'] | Observations, shape (N,). | required |
| f_loc | Float[Array, ' N'] | Variational mean, shape (N,). | required |
| f_var | Float[Array, ' N'] | Variational marginal variances, shape (N,). | required |
| noise_var | float | Observation noise variance (scalar). | required |
| kl | Float[Array, ''] | KL divergence term KL(q||p). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar ELBO value. |
Source code in src/gaussx/_gp/_elbo.py
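The closed-form expected log-likelihood is short enough to write out directly. The sketch below is a NumPy illustration with hypothetical function names (`expected_gaussian_loglik`, `elbo_gaussian`) and made-up inputs; it mirrors the formula above rather than the library source.

```python
import numpy as np

def expected_gaussian_loglik(y, f_loc, f_var, noise_var):
    """E_q[log N(y | f, noise_var * I)] for q(f) with the given marginals."""
    n = y.size
    return (-0.5 * n * np.log(2 * np.pi * noise_var)
            - 0.5 / noise_var * (np.sum((y - f_loc) ** 2) + np.sum(f_var)))

def elbo_gaussian(y, f_loc, f_var, noise_var, kl):
    """Expected log-likelihood minus the KL term."""
    return expected_gaussian_loglik(y, f_loc, f_var, noise_var) - kl

y = np.array([0.2, -0.5, 1.1])
f_loc = np.array([0.1, -0.4, 0.9])
f_var = np.array([0.05, 0.10, 0.02])
noise_var, kl = 0.25, 1.3

ell = expected_gaussian_loglik(y, f_loc, f_var, noise_var)
elbo = elbo_gaussian(y, f_loc, f_var, noise_var, kl)
```

Only the marginal variances enter because the Gaussian likelihood factorizes over data points, so off-diagonal terms of the variational covariance never appear in the expectation.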
variational_elbo_mc(log_likelihood_fn: Callable[[Float[Array, ' N']], Float[Array, '']], f_samples: Float[Array, 'S N'], kl: Float[Array, '']) -> Float[Array, '']
Monte Carlo ELBO for non-conjugate likelihoods.
Computes:
ELBO = (1/S) sum_s log p(y|f_s) - KL(q||p)
where f_s ~ q(f) are samples from the variational distribution.
Supports any likelihood (Poisson, Bernoulli, Pareto, etc.).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_likelihood_fn | Callable[[Float[Array, ' N']], Float[Array, '']] | Function mapping a latent sample of shape (N,) to a scalar log-likelihood. | required |
| f_samples | Float[Array, 'S N'] | Samples from the variational posterior, shape (S, N). | required |
| kl | Float[Array, ''] | KL divergence term KL(q||p). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar ELBO value. |
Source code in src/gaussx/_gp/_elbo.py
collapsed_elbo(y: Float[Array, ' N'], K_diag: Float[Array, ' N'], K_xz: Float[Array, 'N M'], K_zz: Float[Array, 'M M'], noise_var: float, *, jitter: float = 1e-06, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']
Collapsed ELBO (Titsias bound) for sparse GP regression.
Computes the variational lower bound on the log marginal likelihood using the matrix determinant lemma for O(NM² + M³) cost:
ELBO = log 𝒩(y | 0, Q_ff + σ²I) − ½σ⁻² tr(K_ff − Q_ff)
where Q_ff = K_xz K_zz⁻¹ K_xzᵀ is the Nyström approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | Float[Array, ' N'] | Observations, shape (N,). | required |
| K_diag | Float[Array, ' N'] | Diagonal of full kernel matrix K_ff, shape (N,). | required |
| K_xz | Float[Array, 'N M'] | Cross-covariance between data and inducing points, shape (N, M). | required |
| K_zz | Float[Array, 'M M'] | Inducing point kernel matrix, shape (M, M). | required |
| noise_var | float | Observation noise variance σ² (scalar). | required |
| jitter | float | Diagonal jitter for numerical stability in Cholesky decomposition of K_zz. | 1e-06 |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar ELBO value. |
Source code in src/gaussx/_gp/_collapsed_elbo.py
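One way to reach the O(NM² + M³) cost is to work entirely with M × M Cholesky factors. The sketch below is a NumPy illustration under that assumption; the function name `collapsed_elbo_sketch`, the `rbf` helper, and the toy data are hypothetical, and the library may organize the computation differently.

```python
import numpy as np

def rbf(xa, xb, lengthscale=0.3):
    """Unit-variance squared-exponential kernel between two 1-D input arrays."""
    d = xa[:, None] - xb[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def collapsed_elbo_sketch(y, K_diag, K_xz, K_zz, noise_var, jitter=1e-6):
    N, M = K_xz.shape
    L = np.linalg.cholesky(K_zz + jitter * np.eye(M))
    A = np.linalg.solve(L, K_xz.T) / np.sqrt(noise_var)    # (M, N); Q_ff = noise_var * A^T A
    B = np.eye(M) + A @ A.T                                 # M x M system from the determinant lemma
    LB = np.linalg.cholesky(B)
    c = np.linalg.solve(LB, A @ y) / np.sqrt(noise_var)
    log_det = N * np.log(noise_var) + 2.0 * np.sum(np.log(np.diag(LB)))
    quad = y @ y / noise_var - c @ c                        # y^T (Q_ff + noise_var I)^{-1} y
    log_gauss = -0.5 * (N * np.log(2 * np.pi) + log_det + quad)
    trace_term = -0.5 * (np.sum(K_diag) / noise_var - np.sum(A**2))
    return log_gauss + trace_term

x = np.linspace(0.0, 1.0, 10)
z = np.linspace(0.0, 1.0, 4)                 # inducing inputs
y = np.sin(2 * np.pi * x)
noise_var = 0.1
K_ff, K_xz, K_zz = rbf(x, x), rbf(x, z), rbf(z, z)

elbo = collapsed_elbo_sketch(y, np.diag(K_ff), K_xz, K_zz, noise_var)
```

The trace term penalizes the variance that the Nyström approximation fails to explain, which is what keeps the result a lower bound on the exact log marginal likelihood.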
gauss_kl(q_mu: Float[Array, 'M R'], q_sqrt: Float[Array, 'R M M'] | Float[Array, 'M R'], K: Float[Array, 'M M'] | None = None, *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']
KL divergence KL[q(u) || p(u)] between Gaussian distributions.
Cholesky-parameterised variant of
dist_kl_divergence designed for
GP/SVGP models. The Cholesky representation avoids explicit covariance
matrix construction and supports both full and diagonal q_sqrt.
For lineax-operator covariances, use
dist_kl_divergence instead.
Computes the KL divergence where:
- q = N(q_mu, q_sqrt @ q_sqrt^T)
- p = N(0, K), or p = N(0, I) if K is None (white prior)
Handles both full and diagonal q_sqrt:
- Full q_sqrt: shape (R, M, M), lower-triangular Cholesky factors of the variational covariance per output dimension.
- Diagonal q_sqrt: shape (M, R), diagonal standard deviations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_mu | Float[Array, 'M R'] | Variational mean, shape (M, R). | required |
| q_sqrt | Float[Array, 'R M M'] \| Float[Array, 'M R'] | Variational Cholesky factor or diagonal std devs. | required |
| K | Float[Array, 'M M'] \| None | Prior covariance matrix, shape (M, M). If None, the white prior N(0, I) is used. | None |
| solver | AbstractSolverStrategy \| None | Optional solver strategy for structured linear algebra. When … | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar KL divergence summed over all R output dimensions. |
See Also
dist_kl_divergence: General KL
between two multivariate normals with lineax covariance operators.
Source code in src/gaussx/_gp/_gauss_kl.py
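For a single output dimension with a full q_sqrt, the Cholesky-parameterized KL can be sketched as below. This is a NumPy illustration, not the gaussx function: the name `gauss_kl_single` is hypothetical, it handles one output column only, and `np.linalg.solve` stands in for a triangular solve.

```python
import numpy as np

def gauss_kl_single(q_mu, q_sqrt, K=None):
    """KL[N(q_mu, q_sqrt q_sqrt^T) || N(0, K)], or against N(0, I) if K is None."""
    M = q_mu.size
    if K is None:
        alpha, Lq, logdet_p = q_mu, q_sqrt, 0.0
    else:
        Lp = np.linalg.cholesky(K)
        alpha = np.linalg.solve(Lp, q_mu)        # Lp^{-1} q_mu
        Lq = np.linalg.solve(Lp, q_sqrt)         # Lp^{-1} q_sqrt
        logdet_p = 2.0 * np.sum(np.log(np.diag(Lp)))
    mahalanobis = alpha @ alpha                  # q_mu^T K^{-1} q_mu
    trace = np.sum(Lq**2)                        # tr(K^{-1} S)
    logdet_q = 2.0 * np.sum(np.log(np.abs(np.diag(q_sqrt))))
    return 0.5 * (mahalanobis + trace - M + logdet_p - logdet_q)

rng = np.random.default_rng(5)
M = 4
B = rng.standard_normal((M, M))
K = B @ B.T + np.eye(M)
Lp = np.linalg.cholesky(K)

# Whitened variational parameters with a strictly positive diagonal.
m_w = rng.standard_normal(M)
q_w = np.tril(rng.standard_normal((M, M)), -1) + np.diag(rng.uniform(0.5, 1.5, M))

kl_white = gauss_kl_single(m_w, q_w)                   # against N(0, I)
kl_full = gauss_kl_single(Lp @ m_w, Lp @ q_w, K)       # unwhitened, against N(0, K)
```

The two values agree because KL is invariant under the invertible map u = L v, which is why the whitened parameterization can use the simpler K = None branch.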
Cross-validation & diagnostics
LOVE-style cached predictive variances and closed-form leave-one-out cross-validation from a single factorization.
LOVECache
Bases: Module
Cached Lanczos factorization for fast predictive variance.
Stores the eigenvector basis Q and inverse eigenvalues such that K^{-1} \approx Q \Lambda^{-1} Q^T.
Attributes:
| Name | Type | Description |
|---|---|---|
| Q | Float[Array, 'N k'] | Lanczos eigenvector basis, shape (N, k). |
| inv_eigvals | Float[Array, ' k'] | Inverse eigenvalues \Lambda^{-1}. |
Source code in src/gaussx/_gp/_love.py
LOOResult
Bases: Module
Result of leave-one-out cross-validation.
Attributes:
| Name | Type | Description |
|---|---|---|
| loo_log_likelihood | Float[Array, ''] | Scalar LOO-CV log-likelihood. |
| loo_means | Float[Array, ' N'] | Per-point LOO predictive means, shape (N,). |
| loo_variances | Float[Array, ' N'] | Per-point LOO predictive variances, shape (N,). |
Source code in src/gaussx/_gp/_loo.py
love_cache(K_op: lx.AbstractLinearOperator, lanczos_order: int = 50, key: jax.Array | None = None) -> LOVECache
Precompute Lanczos factorization of K^{-1} for fast variance.
Builds a rank-k approximation K^{-1} \approx Q \Lambda^{-1} Q^T
using the symmetric Lanczos algorithm via partial eigendecomposition.
This amortizes the cost of predictive variance: once cached, each
test point needs only O(Nk) instead of O(N^2) for a solve.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_op | AbstractLinearOperator | Training kernel operator, shape (N, N). | required |
| lanczos_order | int | Number of Lanczos iterations (rank of approximation). | 50 |
| key | Array \| None | PRNG key for the initial random vector. If … | None |
Returns:
| Type | Description |
|---|---|
| LOVECache | A LOVECache holding Q and the inverse eigenvalues. |
Source code in src/gaussx/_gp/_love.py
love_variance(cache: LOVECache, K_star_row: Float[Array, ' N']) -> Float[Array, '']
Fast predictive variance using a LOVE cache.
Computes k_*^T K^{-1} k_* in O(Nk) via the cached Lanczos
factorization:
k_*^T K^{-1} k_* \approx k_*^T Q \Lambda^{-1} Q^T k_*
= \sum_i (q_i^T k_*)^2 / \lambda_i
The predictive variance for a GP is then:
var_* = k(x_*, x_*) - love\_variance(cache, k_*)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cache | LOVECache | A LOVECache from love_cache. | required |
| K_star_row | Float[Array, ' N'] | Cross-covariance vector k_*, shape (N,). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar approximation of k_*^T K^{-1} k_*. |
Source code in src/gaussx/_gp/_love.py
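The cached quadratic form can be illustrated with a dense eigendecomposition standing in for the Lanczos factorization. This NumPy sketch is an assumption-laden illustration: `np.linalg.eigh` replaces Lanczos, `cached_quadratic_form` is a hypothetical name, and `k_star` is an arbitrary vector rather than a real cross-covariance.

```python
import numpy as np

rng = np.random.default_rng(6)
N = 6
B = rng.standard_normal((N, N))
K = B @ B.T + np.eye(N)                  # SPD training kernel matrix
k_star = rng.standard_normal(N)          # stand-in for a cross-covariance row

# Cache step (dense eigendecomposition used here in place of Lanczos).
eigvals, Q = np.linalg.eigh(K)           # eigenvalues in ascending order
inv_eigvals = 1.0 / eigvals

def cached_quadratic_form(Q, inv_eigvals, k_star_row):
    """sum_i (q_i^T k_*)^2 / lambda_i using the cached basis."""
    proj = Q.T @ k_star_row              # O(N k) per test point
    return np.sum(proj**2 * inv_eigvals)

full = cached_quadratic_form(Q, inv_eigvals, k_star)                    # k = N, exact
truncated = cached_quadratic_form(Q[:, -3:], inv_eigvals[-3:], k_star)  # k = 3
```

Every term in the sum is non-negative, so a rank-k cache can only underestimate the quadratic form, which means the resulting predictive variance errs on the large side.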
leave_one_out_cv(operator: lx.AbstractLinearOperator, y: Float[Array, ' N'], *, solver: AbstractSolveStrategy | None = None, diag_inv_method: str = 'solve', diag_inv_num_probes: int = 30, diag_inv_key: jax.Array | None = None) -> LOOResult
LOO-CV via the bordered-system identity.
Computes leave-one-out predictive means, variances, and log-likelihood without refitting the model N times.
Math:
alpha = K_y^{-1} y
mu_LOO_i = y_i - alpha_i / [K_y^{-1}]_{ii}
sigma^2_LOO_i = 1 / [K_y^{-1}]_{ii}
LOO-CV = -(1/2) sum_i [ log sigma^2_LOO_i
+ (y_i - mu_LOO_i)^2 / sigma^2_LOO_i
+ log 2 pi ]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A linear operator representing the (noise-inclusive) kernel matrix K_y. | required |
| y | Float[Array, ' N'] | Observation vector of shape (N,). | required |
| solver | AbstractSolveStrategy \| None | Optional solve strategy for computing K_y^{-1} y and for the … | None |
| diag_inv_method | str | Method passed to … | 'solve' |
| diag_inv_num_probes | int | Number of Hutchinson probes when … | 30 |
| diag_inv_key | Array \| None | PRNG key for probe generation when … | None |
Returns:
| Type | Description |
|---|---|
| LOOResult | A LOOResult with the LOO log-likelihood, predictive means, and predictive variances. |
Source code in src/gaussx/_gp/_loo.py
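The closed-form identities above need only K_y^{-1} y and the diagonal of K_y^{-1}. Here is a minimal dense NumPy sketch; the `rbf` helper, the toy data, and the explicit inverse are illustrative assumptions, and the library obtains the diagonal through its own diag_inv_method options.

```python
import numpy as np

def rbf(xa, xb, lengthscale=0.3):
    """Unit-variance squared-exponential kernel between two 1-D input arrays."""
    d = xa[:, None] - xb[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

rng = np.random.default_rng(7)
N = 7
x = np.linspace(0.0, 1.0, N)
y = np.cos(3.0 * x) + 0.1 * rng.standard_normal(N)
K_y = rbf(x, x) + 0.05 * np.eye(N)       # noise-inclusive kernel matrix

K_inv = np.linalg.inv(K_y)               # dense inverse, for illustration only
alpha = K_inv @ y
d = np.diag(K_inv)

loo_means = y - alpha / d                # mu_LOO_i
loo_vars = 1.0 / d                       # sigma^2_LOO_i
loo_ll = -0.5 * np.sum(
    np.log(loo_vars) + (y - loo_means) ** 2 / loo_vars + np.log(2 * np.pi)
)
```

Because K_y includes the noise term, the LOO variances are predictive variances for held-out observations rather than for the latent function.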
Multi-output projections
The orthogonal instantaneous linear mixing model (OILMM): project multi-output observations into independent latent processes and back.
oilmm_project(Y: Float[Array, 'N P'], W: Float[Array, 'P L'], noise_var: Float[Array, ' P'] | float) -> tuple[Float[Array, 'N L'], Float[Array, ' L']]
Project multi-output data to independent latent GPs via OILMM.
Given an orthogonal mixing matrix W ∈ ℝᴾˣᴸ with WᵀW = I_L, projects P-output observations to L independent latent channels:
Y_latent = Y W (N, L)
σ²_latent = (W ⊙ W)ᵀ σ² (L,)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| Y | Float[Array, 'N P'] | Observations, shape (N, P). | required |
| W | Float[Array, 'P L'] | Orthogonal mixing matrix, shape (P, L). | required |
| noise_var | Float[Array, ' P'] \| float | Observation noise variance. Scalar for isotropic noise, or shape (P,). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'N L'], Float[Array, ' L']] | Tuple (Y_latent, σ²_latent), with shapes (N, L) and (L,). |
Source code in src/gaussx/_gp/_oilmm.py
oilmm_back_project(f_means: Float[Array, 'N L'], f_vars: Float[Array, 'N L'], W: Float[Array, 'P L']) -> tuple[Float[Array, 'N P'], Float[Array, 'N P']]
Back-project latent GP predictions to the observation space.
Reconstructs observation-space predictions via:
y_means = f_means Wᵀ (N, P)
y_vars = f_vars (W ⊙ W)ᵀ (N, P)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| f_means | Float[Array, 'N L'] | Latent predictive means, shape (N, L). | required |
| f_vars | Float[Array, 'N L'] | Latent predictive variances, shape (N, L). | required |
| W | Float[Array, 'P L'] | Orthogonal mixing matrix, shape (P, L). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'N P'], Float[Array, 'N P']] | Tuple (y_means, y_vars), each of shape (N, P). |
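The projection and back-projection formulas can be exercised with a few lines of NumPy. This is a sketch under stated assumptions: W is built from a QR factorization to get orthonormal columns, and the latent "predictions" are simple stand-ins (the projected data and the projected noise) rather than the output of fitted latent GPs.

```python
import numpy as np

rng = np.random.default_rng(8)
N, P, L = 5, 4, 2

# Mixing matrix with orthonormal columns, so W^T W = I_L.
W, _ = np.linalg.qr(rng.standard_normal((P, L)))   # shape (P, L)
Y = rng.standard_normal((N, P))
noise_var = np.array([0.1, 0.2, 0.3, 0.4])         # per-output noise, shape (P,)

# Projection to L independent latent channels.
Y_latent = Y @ W                                   # shape (N, L)
noise_latent = (W * W).T @ noise_var               # shape (L,)

# Stand-in latent predictions (a real workflow would fit L separate GPs here).
f_means = Y_latent
f_vars = np.broadcast_to(noise_latent, (N, L))

# Back-projection to the observation space.
y_means = f_means @ W.T                            # shape (N, P)
y_vars = f_vars @ (W * W).T                        # shape (N, P)
```

With L < P the round trip is lossy in observation space, but projecting the back-projected means again recovers the latent means exactly because W^T W = I_L.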