Quadrature & Moment Matching
Gaussian integration: \(\mathbb{E}_{x \sim \mathcal{N}(\mu, \Sigma)}[f(x)]\) via
deterministic rules (Gauss-Hermite, unscented / cubature, Taylor) or Monte
Carlo, behind one AbstractIntegrator interface. Everything that needs an
expectation — expected log-likelihoods, EP tilted moments, uncertain-input GP
predictions — takes an integrator argument, so swapping the rule never touches
the model.
State & integrators
GaussianState pairs a mean with a covariance operator; integrators propagate it
through a nonlinear function and return the resulting output moments.
Structured linear algebra and Gaussian primitives for JAX.
GaussianState
Bases: Module
Gaussian distribution as (mean, covariance operator) pair.
Attributes:
| Name | Type | Description |
|---|---|---|
| mean | Float[Array, ' N'] | Mean vector, shape (N,). |
| cov | AbstractLinearOperator | Covariance operator, shape (N, N). |
Source code in src/gaussx/_quadrature/_types.py
PropagationResult
Bases: Module
Output of uncertainty propagation through a nonlinear function.
Attributes:
| Name | Type | Description |
|---|---|---|
| state | GaussianState | Output Gaussian distribution. |
| cross_cov | Float[Array, 'N_in N_out'] \| None | Input-output cross-covariance, shape (N_in, N_out), or None if it is not computed. |
Source code in src/gaussx/_quadrature/_types.py
AbstractIntegrator
Bases: Module
Protocol for Gaussian integral approximation.
Subclasses implement integrate to propagate a Gaussian through
a nonlinear function, returning an approximate output distribution
and (optionally) input-output cross-covariance.
Source code in src/gaussx/_quadrature/_integrator.py
integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult
Abstract method.
Propagate a Gaussian through fn, returning output moments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Float[Array, ' N']], Float[Array, ' M']] | Nonlinear function mapping shape (N,) to shape (M,). | required |
| state | GaussianState | Input Gaussian distribution. | required |
Returns:
| Type | Description |
|---|---|
| PropagationResult | Output Gaussian distribution and, optionally, the input-output cross-covariance. |
Source code in src/gaussx/_quadrature/_integrator.py
GaussHermiteIntegrator
Bases: AbstractIntegrator
Gauss-Hermite quadrature integrator.
Approximates Gaussian expectations with tensor-product Gauss-Hermite quadrature:
\[\mathbb{E}[g(f)] \approx \sum_i w_i \, g(\mu + L z_i)\]
where \((z_i, w_i)\) are the GH points and weights in standard normal space and \(L\) is a square root of the covariance, \(\Sigma = L L^\top\).
The rule is exact for polynomials of degree up to 2 * order - 1 in each coordinate.
Complexity: O(order^dim) function evaluations, so it is practical only for dim <= ~5.
Attributes:
| Name | Type | Description |
|---|---|---|
| order | int | Number of quadrature points per dimension. |
Source code in src/gaussx/_quadrature/_gauss_hermite.py
integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult
Propagate Gaussian via Gauss-Hermite quadrature.
Source code in src/gaussx/_quadrature/_gauss_hermite.py
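The degree bound is easiest to see in one dimension. Below is a minimal NumPy sketch, not the gaussx implementation. It uses numpy.polynomial.hermite_e.hermegauss, whose nodes and weights target the weight exp(-z**2 / 2), and divides the weights by sqrt(2π) so that they define an expectation under N(0, 1). The helper name gh_expect_1d is hypothetical.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gh_expect_1d(g, mu, var, order):
    """Sketch (not gaussx): E[g(x)] for x ~ N(mu, var) with an `order`-point GH rule."""
    z, w = hermegauss(order)        # probabilists' nodes/weights for exp(-z**2 / 2)
    w = w / np.sqrt(2.0 * np.pi)    # normalize so the weights sum to 1
    x = mu + np.sqrt(var) * z       # map standard-normal nodes to N(mu, var)
    return float(np.sum(w * g(x)))


# order = 3 is exact up to degree 2 * 3 - 1 = 5 ...
print(gh_expect_1d(lambda x: x**4, 0.0, 1.0, order=3))  # ≈ 3.0 (true E[z^4] = 3)
# ... but not for degree 6, which needs one more point.
print(gh_expect_1d(lambda x: x**6, 0.0, 1.0, order=3))  # ≈ 9.0 (true E[z^6] = 15)
print(gh_expect_1d(lambda x: x**6, 0.0, 1.0, order=4))  # ≈ 15.0
```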
UnscentedIntegrator
Bases: AbstractIntegrator
Unscented transform: deterministic sigma points.
Generates 2N+1 sigma points around the mean, propagates them
through the nonlinear function, and reconstructs output moments:
chi_i = mu + sqrt((N + lambda) * Sigma) @ xi_i
y_i = f(chi_i)
mu_y = sum(w_m * y_i)
Sigma_y = sum(w_c * (y_i - mu_y)(y_i - mu_y)^T)
cross_cov = sum(w_c * (chi_i - mu)(y_i - mu_y)^T)
where lambda = alpha^2 * (N + kappa) - N.
Attributes:
| Name | Type | Description |
|---|---|---|
| alpha | float | Spread parameter. |
| beta | float | Prior knowledge parameter (2.0 is optimal for a Gaussian). |
| kappa | float | Secondary scaling. |
Source code in src/gaussx/_quadrature/_unscented.py
integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult
Propagate Gaussian via unscented transform.
Source code in src/gaussx/_quadrature/_unscented.py
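A minimal NumPy sketch of the formulas above. It assumes a dense covariance and uses a Cholesky factor as the matrix square root; gaussx dispatches gaussx.sqrt on operator structure instead. The weights are the standard scaled-UT weights: w_m[0] = lambda / (N + lambda), w_c[0] = w_m[0] + 1 - alpha^2 + beta, and 1 / (2 (N + lambda)) for each of the other 2N points. The function name and its alpha = 1.0 default are illustrative choices, not gaussx defaults.

```python
import numpy as np


def unscented_transform(f, mu, Sigma, alpha=1.0, beta=2.0, kappa=0.0):
    """Sketch (not gaussx): scaled unscented transform -> (mu_y, Sigma_y, cross_cov)."""
    mu = np.asarray(mu, dtype=float)
    n = mu.shape[0]
    lam = alpha**2 * (n + kappa) - n
    # Columns of S form a square root of (n + lambda) * Sigma.
    S = np.linalg.cholesky((n + lam) * Sigma)
    # 2n + 1 sigma points: the mean, then mu +/- each column of S.
    chi = np.vstack([mu, mu + S.T, mu - S.T])
    w_m = np.full(2 * n + 1, 1.0 / (2.0 * (n + lam)))
    w_c = w_m.copy()
    w_m[0] = lam / (n + lam)
    w_c[0] = lam / (n + lam) + (1.0 - alpha**2 + beta)
    Y = np.array([np.atleast_1d(f(x)) for x in chi])  # y_i = f(chi_i), (2n + 1, M)
    mu_y = w_m @ Y
    dX, dY = chi - mu, Y - mu_y
    Sigma_y = (w_c[:, None] * dY).T @ dY
    cross_cov = (w_c[:, None] * dX).T @ dY
    return mu_y, Sigma_y, cross_cov
```

For an affine f the transform is exact: it returns A mu + b, A Sigma A^T and Sigma A^T. This makes affine functions a convenient unit test.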
TaylorIntegrator
Bases: AbstractIntegrator
1st or 2nd order Taylor expansion for uncertainty propagation.
1st order (EKF):
mu_y = f(mu_x)
Sigma_y = J @ Sigma_x @ J^T
cross_cov = Sigma_x @ J^T
2nd order:
mu_y_i += 0.5 * tr(H_i @ Sigma_x)
Sigma_y += correction from Hessians
Attributes:
| Name | Type | Description |
|---|---|---|
| order | int | Taylor expansion order (1 or 2). Default 1. |
| correct_variance | bool | If True and order=2, apply the 2nd-order covariance correction using 4th Gaussian moments. Default True to preserve the historical … |
Source code in src/gaussx/_quadrature/_taylor.py
integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult
Propagate Gaussian via Taylor expansion.
Source code in src/gaussx/_quadrature/_taylor.py
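A NumPy sketch of both orders. It assumes the caller supplies the Jacobian and the per-output Hessians analytically; gaussx presumably obtains them by JAX autodiff, which is not shown here. The second-order covariance term used below, 0.5 * tr(H_i Sigma H_j Sigma), is the standard correction derived from fourth Gaussian moments. Treat it as an assumption about what correct_variance computes. The helper name taylor_propagate is hypothetical.

```python
import numpy as np


def taylor_propagate(f, jac, hess, mu, Sigma, order=1):
    """Sketch (not gaussx): Taylor moment propagation for f: R^N -> R^M.

    jac(x) returns the (M, N) Jacobian; hess(x) returns the (M, N, N) Hessians.
    """
    mu = np.asarray(mu, dtype=float)
    J = jac(mu)
    mu_y = np.atleast_1d(f(mu)).astype(float)  # mu_y = f(mu_x)
    Sigma_y = J @ Sigma @ J.T                  # Sigma_y = J Sigma_x J^T
    cross_cov = Sigma @ J.T                    # cross_cov = Sigma_x J^T
    if order == 2:
        H = hess(mu)
        # Mean shift: 0.5 * tr(H_i Sigma_x) for every output i.
        mu_y = mu_y + 0.5 * np.einsum("ijk,kj->i", H, Sigma)
        # Assumed covariance correction: 0.5 * tr(H_i Sigma_x H_j Sigma_x).
        Sigma_y = Sigma_y + 0.5 * np.einsum("iab,bc,jcd,da->ij", H, Sigma, H, Sigma)
    return mu_y, Sigma_y, cross_cov


# f(x) = x^2 with x ~ N(1.5, 0.4): second order is exact for a quadratic.
f = lambda x: x**2
jac = lambda x: np.array([[2.0 * x[0]]])
hess = lambda x: np.array([[[2.0]]])
m, S, C = taylor_propagate(f, jac, hess, [1.5], np.array([[0.4]]), order=2)
# m ≈ [2.65] (= 1.5**2 + 0.4), S ≈ [[3.92]] (= 4 * 2.25 * 0.4 + 2 * 0.4**2)
```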
MonteCarloIntegrator
Bases: AbstractIntegrator
Monte Carlo moment matching: sample, propagate, compute moments.
Propagates uncertainty by drawing samples from the input Gaussian, evaluating the function at each sample, and computing empirical output moments:
x_i ~ N(mu, Sigma) (n_samples points)
y_i = f(x_i)
mu_y = mean(y_i)
Sigma_y = cov(y_i) + regularization * I
cross_cov = cov(x_i, y_i)
Attributes:
| Name | Type | Description |
|---|---|---|
| n_samples | int | Number of Monte Carlo samples. |
| regularization | float | Diagonal jitter for numerical stability. |
| key | Array \| None | PRNG key, or None. |
Source code in src/gaussx/_quadrature/_monte_carlo.py
integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult
Propagate Gaussian via Monte Carlo sampling.
Source code in src/gaussx/_quadrature/_monte_carlo.py
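A NumPy sketch of the sampling recipe above. It uses numpy.random.default_rng in place of a JAX PRNG key. The text does not state whether gaussx divides by n_samples or n_samples - 1 for the empirical covariance; this sketch assumes the unbiased n_samples - 1. The name mc_moments is hypothetical.

```python
import numpy as np


def mc_moments(f, mu, Sigma, n_samples=10_000, regularization=1e-6, seed=0):
    """Sketch (not gaussx): Monte Carlo moment matching -> (mu_y, Sigma_y, cross_cov)."""
    rng = np.random.default_rng(seed)
    X = rng.multivariate_normal(mu, Sigma, size=n_samples)  # x_i ~ N(mu, Sigma), (S, N)
    Y = np.array([np.atleast_1d(f(x)) for x in X])          # y_i = f(x_i), (S, M)
    mu_y = Y.mean(axis=0)
    dX, dY = X - X.mean(axis=0), Y - mu_y
    # Unbiased empirical covariances; the jitter keeps Sigma_y positive definite.
    Sigma_y = dY.T @ dY / (n_samples - 1) + regularization * np.eye(Y.shape[1])
    cross_cov = dX.T @ dY / (n_samples - 1)
    return mu_y, Sigma_y, cross_cov


def f(x):
    return np.array([np.exp(x[0]), 2.0 * x[0]])


# x ~ N(0, 0.25): E[exp(x)] = exp(0.125); Cov[x, 2x] = 0.5; Var[2x] = 1.
m, S, C = mc_moments(f, np.array([0.0]), np.array([[0.25]]), n_samples=20_000)
```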
AssumedDensityFilter
Bases: AbstractIntegrator
KL-optimal Gaussian projection via moment matching.
Projects the (possibly non-Gaussian) output distribution onto the
Gaussian family by matching first and second moments. Equivalent to
argmin_q KL(p(y) || q(y)) within the Gaussian family.
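It also adds adaptive regularization and optional diagnostics for detecting non-Gaussianity. With adaptive regularization enabled, the diagonal jitter scales with the average output variance:
eps = eps_base * trace(Sigma_y) / n_dim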
Attributes:
| Name | Type | Description |
|---|---|---|
n_samples |
int
|
Number of Monte Carlo samples. Default |
regularization |
float
|
Base regularization. Default |
adaptive_regularization |
bool
|
Scale regularization by output
variance. Default |
key |
Array | None
|
PRNG key. If |
Source code in src/gaussx/_quadrature/_adf.py
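The adaptive rule takes a few lines of NumPy. The helper below is hypothetical and is not the gaussx code. Scaling the jitter by the mean output variance, trace(Sigma_y) / n_dim, keeps it relative to the outputs. A fixed 1e-6 is negligible for outputs with variance in the hundreds, but it can swamp outputs whose variance is itself around 1e-6.

```python
import numpy as np


def regularize_output_cov(Sigma_y, eps_base=1e-6, adaptive=True):
    """Sketch (not gaussx): add diagonal jitter, optionally scaled by mean output variance."""
    n_dim = Sigma_y.shape[0]
    # eps = eps_base * trace(Sigma_y) / n_dim in adaptive mode, else eps_base.
    eps = eps_base * np.trace(Sigma_y) / n_dim if adaptive else eps_base
    return Sigma_y + eps * np.eye(n_dim)


R = regularize_output_cov(np.diag([100.0, 300.0]), eps_base=1e-3)
# eps = 1e-3 * 400 / 2 = 0.2 is added to each diagonal entry.
```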
integrate(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> PropagationResult
Propagate Gaussian via assumed density filtering.
Source code in src/gaussx/_quadrature/_adf.py
integrate_with_diagnostics(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState) -> tuple[PropagationResult, dict]
Propagate Gaussian and return non-Gaussianity diagnostics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Float[Array, ' N']], Float[Array, ' M']] | Nonlinear function mapping shape (N,) to shape (M,). | required |
| state | GaussianState | Input Gaussian distribution. | required |
Returns:
| Type | Description |
|---|---|
| tuple[PropagationResult, dict] | Tuple (result, diagnostics): the propagated output and a dict of non-Gaussianity diagnostics. |
Source code in src/gaussx/_quadrature/_adf.py
Quadrature rules
The raw point sets behind the integrators, for when a recipe needs direct control.
gauss_hermite_points(order: int, dim: int) -> tuple[Float[Array, 'P D'], Float[Array, ' P']]
Gauss-Hermite quadrature points and weights.
Generates tensor-product Gauss-Hermite quadrature points for
integrating functions against a standard Gaussian measure
(probabilists' Hermite polynomials).
P = order^dim total points.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| order | int | Number of quadrature points per dimension. | required |
| dim | int | Dimensionality of the integration domain. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'P D'], Float[Array, ' P']] | Tuple (points, weights): points of shape (P, D) and weights of shape (P,). |
Source code in src/gaussx/_quadrature/_quadrature.py
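A NumPy sketch of the tensor-product construction, using a hypothetical helper name and NumPy in place of jax.numpy. Each multi-index picks one 1-D probabilists' node per dimension. The point's weight is the product of the 1-D weights, normalized so that all weights sum to 1.

```python
import itertools

import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gh_points_sketch(order, dim):
    """Sketch (not gaussx): GH points (P, D) and weights (P,) for N(0, I), P = order**dim."""
    z1, w1 = hermegauss(order)          # 1-D probabilists' Hermite rule
    w1 = w1 / np.sqrt(2.0 * np.pi)      # weights of a probability measure
    # Every multi-index (i_1, ..., i_D) selects one 1-D node per dimension.
    idx = np.array(list(itertools.product(range(order), repeat=dim)))  # (P, D)
    points = z1[idx]                    # (P, D)
    weights = np.prod(w1[idx], axis=1)  # product of the 1-D weights
    return points, weights


pts, wts = gh_points_sketch(order=3, dim=2)
print(pts.shape, wts.sum())  # (9, 2) and a total weight of about 1.0
```

To integrate against N(mu, Sigma) instead of N(0, I), map each point to mu + L z, where Sigma = L L^T.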
cubature_points(mean: Float[Array, ' N'], cov: lx.AbstractLinearOperator) -> tuple[Float[Array, 'P N'], Float[Array, ' P']]
Spherical-radial cubature points and weights.
Generates 2N cubature points with equal weights 1/(2N).
This is the cubature Kalman filter (CKF) point set.
Uses gaussx.sqrt(cov) for structured square root dispatch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | Float[Array, ' N'] | Mean vector, shape (N,). | required |
| cov | AbstractLinearOperator | Covariance operator, shape (N, N). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'P N'], Float[Array, ' P']] | Tuple (points, weights) with P = 2N: points of shape (P, N) and equal weights 1/(2N) of shape (P,). |
Source code in src/gaussx/_quadrature/_quadrature.py
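A dense NumPy sketch of the CKF point set. The helper is hypothetical, and a Cholesky factor stands in for gaussx.sqrt. The points are the mean +/- sqrt(N) times each column of L, with cov = L L^T. By construction, the points reproduce the mean and covariance exactly.

```python
import numpy as np


def cubature_points_sketch(mean, cov):
    """Sketch (not gaussx): CKF points mean +/- sqrt(N) * L[:, i], equal weights 1/(2N)."""
    mean = np.asarray(mean, dtype=float)
    n = mean.shape[0]
    L = np.linalg.cholesky(cov)                          # cov = L L^T
    offsets = np.sqrt(n) * L.T                           # row i is sqrt(N) * L[:, i]
    points = np.vstack([mean + offsets, mean - offsets])  # (2N, N)
    weights = np.full(2 * n, 1.0 / (2 * n))
    return points, weights


mean = np.array([1.0, -2.0, 0.5])
cov = np.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 0.5]])
pts, w = cubature_points_sketch(mean, cov)
```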
sigma_points(mean: Float[Array, ' N'], cov: lx.AbstractLinearOperator, alpha: float = 0.001, beta: float = 2.0, kappa: float = 0.0) -> tuple[Float[Array, 'P N'], Float[Array, ' P'], Float[Array, ' P']]
Unscented transform sigma points and weights.
Generates 2N+1 deterministic sigma points for a Gaussian with
the given mean and covariance, using the scaled unscented transform.
Uses gaussx.sqrt(cov) for structured square root dispatch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | Float[Array, ' N'] | Mean vector, shape (N,). | required |
| cov | AbstractLinearOperator | Covariance operator, shape (N, N). | required |
| alpha | float | Spread parameter. Controls how far sigma points are from the mean. | 0.001 |
| beta | float | Prior distribution parameter (2.0 is optimal for a Gaussian). | 2.0 |
| kappa | float | Secondary scaling parameter. | 0.0 |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'P N'], Float[Array, ' P'], Float[Array, ' P']] | Tuple (points, mean weights, covariance weights) with P = 2N + 1. |
Source code in src/gaussx/_quadrature/_quadrature.py
Likelihoods
Observation models with quadrature-friendly log_prob surfaces, shared by the
expectation helpers and the SSM / CVI recipes.
AbstractLikelihood
Bases: Module
Base class for likelihood functions with optional analytical ELL.
Subclasses that support closed-form expected log-likelihood under
a Gaussian variational distribution should override
has_analytical_ell to return True and implement
analytical_expected_log_likelihood.
Source code in src/gaussx/_quadrature/_likelihood.py
log_prob(f: Float[Array, ' N']) -> Float[Array, '']
Abstract method.
Evaluate log p(y | f) for fixed observations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| f | Float[Array, ' N'] | Latent function values, shape (N,). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar log-likelihood. |
Source code in src/gaussx/_quadrature/_likelihood.py
has_analytical_ell() -> bool
analytical_expected_log_likelihood(q_mu: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator) -> Float[Array, '']
Closed-form E_q[log p(y | f)] where q = N(q_mu, q_cov).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_mu | Float[Array, ' N'] | Variational mean, shape (N,). | required |
| q_cov | AbstractLinearOperator | Variational covariance operator, shape (N, N). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar expected log-likelihood. |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | If no analytical form exists. |
Source code in src/gaussx/_quadrature/_likelihood.py
GaussianLikelihood
Bases: AbstractLikelihood
Gaussian likelihood log N(y | f, noise_var * I).
Supports closed-form expected log-likelihood:
\[\mathbb{E}_q[\log \mathcal{N}(y \mid f, \sigma^2 I)] = \log \mathcal{N}(y \mid q_\mu, \sigma^2 I) - \frac{1}{2\sigma^2} \operatorname{tr}(q_{\mathrm{cov}})\]
Attributes:
| Name | Type | Description |
|---|---|---|
| y | Float[Array, ' N'] | Observed targets, shape (N,). |
| noise_var | float | Observation noise variance (scalar). |
Source code in src/gaussx/_quadrature/_likelihood.py
log_prob(f: Float[Array, ' N']) -> Float[Array, '']
Evaluate log N(y | f, noise_var * I).
Source code in src/gaussx/_quadrature/_likelihood.py
has_analytical_ell() -> bool
analytical_expected_log_likelihood(q_mu: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator) -> Float[Array, '']
Closed-form \(\mathbb{E}_q[\log \mathcal{N}(y \mid f, \sigma^2 I)]\).
Uses:
E_q[log N(y|f,R)] = log N(y | q_mu, R) - 0.5 tr(R^{-1} q_cov)
where R = noise_var * I, so the trace term reduces to trace(q_cov) / noise_var.
The log-density term is delegated to gaussx.gaussian_log_prob, which exploits
the diagonal noise structure. The trace term uses the structural
trace(q_cov) / noise_var shortcut rather than trace_product(R^{-1}, q_cov).
As a result, a Kronecker- or BlockDiag-structured q_cov keeps its O(n) fast
path (prod(trace_factor), or per-block traces) instead of being materialized.
Source code in src/gaussx/_quadrature/_likelihood.py
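The closed form is easy to check against Monte Carlo. The NumPy sketch below uses hypothetical helper names and a dense q_cov. It evaluates log N(y | q_mu, s2 I) - tr(q_cov) / (2 s2) and compares the result with a sample average of log N(y | f, s2 I) over f ~ N(q_mu, q_cov).

```python
import numpy as np


def iso_gaussian_log_prob(y, f, noise_var):
    """Sketch: log N(y | f, noise_var * I); f may carry leading batch dimensions."""
    r = y - f
    n = y.shape[-1]
    return -0.5 * (n * np.log(2.0 * np.pi * noise_var) + np.sum(r**2, axis=-1) / noise_var)


def gaussian_ell(y, q_mu, q_cov, noise_var):
    """Sketch: log N(y | q_mu, noise_var I) - tr(q_cov) / (2 noise_var)."""
    return iso_gaussian_log_prob(y, q_mu, noise_var) - 0.5 * np.trace(q_cov) / noise_var


rng = np.random.default_rng(0)
y = np.array([0.5, -0.2, 1.0])
q_mu = np.array([0.3, -0.1, 0.7])
q_cov = np.array([[0.10, 0.02, 0.00], [0.02, 0.20, 0.05], [0.00, 0.05, 0.15]])
f = rng.multivariate_normal(q_mu, q_cov, size=200_000)  # f ~ q
mc = iso_gaussian_log_prob(y, f, 0.5).mean()
print(gaussian_ell(y, q_mu, q_cov, 0.5), mc)  # agree up to Monte Carlo error
```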
HeteroscedasticGaussianLikelihood
Bases: AbstractLikelihood
Heteroscedastic Gaussian likelihood with input-dependent noise.
Attributes:
| Name | Type | Description |
|---|---|---|
| y | Float[Array, ' N'] | Observations, shape (N,). |
Source code in src/gaussx/_quadrature/_likelihoods.py
log_prob(f: Float[Array, ' 2N']) -> Float[Array, '']
Evaluate heteroscedastic Gaussian log-likelihood.
Source code in src/gaussx/_quadrature/_likelihoods.py
BernoulliLikelihood
Bases: AbstractLikelihood
Bernoulli likelihood with logit link.
Attributes:
| Name | Type | Description |
|---|---|---|
| y | Float[Array, ' N'] | Binary observations, shape (N,). |
Source code in src/gaussx/_quadrature/_likelihoods.py
log_prob(f: Float[Array, ' N']) -> Float[Array, '']
Evaluate Bernoulli log-likelihood with logit link.
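A NumPy sketch of the logit-link log-likelihood. It assumes labels y in {0, 1}; the encoding gaussx expects is not shown here. Since log sigmoid(f) = f - log(1 + exp(f)), the per-point term is y f - log(1 + exp(f)). Writing the second part as np.logaddexp(0, f) avoids overflow for large |f|. The helper name is hypothetical.

```python
import numpy as np


def bernoulli_logit_log_prob(y, f):
    """Sketch (not gaussx): sum_i [y_i * f_i - log(1 + exp(f_i))] for y_i in {0, 1}."""
    y = np.asarray(y, dtype=float)
    f = np.asarray(f, dtype=float)
    return float(np.sum(y * f - np.logaddexp(0.0, f)))


print(bernoulli_logit_log_prob([1, 0], [0.0, 0.0]))          # 2 * log(1/2) ≈ -1.386
print(bernoulli_logit_log_prob([1, 0], [1000.0, -1000.0]))   # ≈ 0.0, no overflow
```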
PoissonLikelihood
Bases: AbstractLikelihood
Poisson likelihood with log link.
Attributes:
| Name | Type | Description |
|---|---|---|
| y | Float[Array, ' N'] | Count observations, shape (N,). |
Source code in src/gaussx/_quadrature/_likelihoods.py
log_prob(f: Float[Array, ' N']) -> Float[Array, '']
Evaluate Poisson log-likelihood with log link.
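With the log link the rate is exp(f), so log p(y | f) = sum_i [y_i f_i - exp(f_i) - log(y_i!)]. Below is a NumPy sketch with a hypothetical helper name; it uses math.lgamma(y + 1) for log(y!).

```python
import math

import numpy as np


def poisson_log_prob(y, f):
    """Sketch (not gaussx): Poisson log-likelihood with rate exp(f)."""
    y = np.asarray(y, dtype=float)
    f = np.asarray(f, dtype=float)
    log_y_factorial = np.array([math.lgamma(k + 1.0) for k in y])  # log(y!)
    return float(np.sum(y * f - np.exp(f) - log_y_factorial))


# y = [0, 2] at rates [1, 2]: log Pois(0; 1) + log Pois(2; 2) = -1 + (log 2 - 2)
print(poisson_log_prob([0, 2], [0.0, math.log(2.0)]))  # log(2) - 3 ≈ -2.307
```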
SoftmaxLikelihood
Bases: AbstractLikelihood
Softmax (categorical) likelihood for multi-class classification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y | Int[Array, ' N'] | Integer class labels, shape (N,). | required |
| num_classes | int | Number of classes C. | required |
Source code in src/gaussx/_quadrature/_likelihoods.py
log_prob(f: Float[Array, ' NC']) -> Float[Array, '']
Evaluate softmax log-likelihood.
Source code in src/gaussx/_quadrature/_likelihoods.py
StudentTLikelihood
Bases: AbstractLikelihood
Student-t likelihood for robust regression.
Attributes:
| Name | Type | Description |
|---|---|---|
| y | Float[Array, ' N'] | Observations, shape (N,). |
| df | float | Degrees of freedom (> 0). |
| scale | float | Scale parameter (> 0). |
Source code in src/gaussx/_quadrature/_likelihoods.py
log_prob(f: Float[Array, ' N']) -> Float[Array, '']
Evaluate Student-t log-likelihood.
Source code in src/gaussx/_quadrature/_likelihoods.py
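The density is not spelled out above. A common parameterization, assumed here, is y = f + scale * eps with eps ~ t_df. That gives log p = lgamma((df+1)/2) - lgamma(df/2) - 0.5 log(df π scale^2) - (df+1)/2 log(1 + r^2 / df), where r = (y - f) / scale. A NumPy sketch with a hypothetical helper name:

```python
import math

import numpy as np


def student_t_log_prob(y, f, df, scale):
    """Sketch (assumed parameterization, not gaussx): y = f + scale * t_df noise."""
    r = (np.asarray(y, dtype=float) - np.asarray(f, dtype=float)) / scale
    log_norm = (math.lgamma(0.5 * (df + 1.0)) - math.lgamma(0.5 * df)
                - 0.5 * math.log(df * math.pi * scale**2))
    # Heavy tails: the penalty grows like log(r^2) rather than r^2.
    return float(np.sum(log_norm - 0.5 * (df + 1.0) * np.log1p(r**2 / df)))


print(student_t_log_prob([0.0, 1.0], [0.0, 0.0], df=1.0, scale=1.0))
# Cauchy case: log(1/pi) + log(1/(2 pi)) = -2 log(pi) - log(2) ≈ -2.983
```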
Expectations & EP moments
Expected log-likelihoods (the ELL term of every ELBO), generic mean / gradient / cost expectations, and the tilted-moment matching at the heart of expectation propagation.
elbo(likelihood: AbstractLikelihood, state: GaussianState, kl: Float[Array, ''], integrator: AbstractIntegrator | None = None) -> Float[Array, '']
Evidence lower bound (ELBO).
Computes:
ELBO = E_q[log p(y | f)] - KL(q || p)
Uses the analytical expected log-likelihood when the likelihood provides one (e.g. GaussianLikelihood); otherwise falls back to numerical integration with integrator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| likelihood | AbstractLikelihood | Likelihood object implementing log_prob. | required |
| state | GaussianState | Variational Gaussian distribution q. | required |
| kl | Float[Array, ''] | KL divergence KL(q \|\| p). | required |
| integrator | AbstractIntegrator \| None | Integration method for non-conjugate likelihoods. Ignored when the likelihood has an analytical fast path. | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar ELBO value. |
Source code in src/gaussx/_quadrature/_expectations.py
expected_log_likelihood(likelihood: AbstractLikelihood, state: GaussianState, integrator: AbstractIntegrator | None = None) -> Float[Array, '']
Unified expected log-likelihood with analytical dispatch.
Computes E_q[log p(y | f)] where q = N(mu, Sigma).
If the likelihood has a closed-form expected log-likelihood
(e.g. GaussianLikelihood), it is used directly without
an integrator. Otherwise, an integrator must be provided
for numerical approximation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| likelihood | AbstractLikelihood | Likelihood object implementing log_prob. | required |
| state | GaussianState | Variational Gaussian distribution. | required |
| integrator | AbstractIntegrator \| None | Integration method. Required for non-conjugate likelihoods; ignored when the likelihood has an analytical fast path. | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar expected log-likelihood. |
Raises:
| Type | Description |
|---|---|
| ValueError | If no integrator is provided and the likelihood has no analytical form. |
Source code in src/gaussx/_quadrature/_expectations.py
log_likelihood_expectation(likelihood_fn: Callable[[Float[Array, ' N']], Float[Array, '']], state: GaussianState, integrator: AbstractIntegrator) -> Float[Array, '']
Compute E[log p(y_{obs} | f(x))] where x ~ N(mu, Sigma).
For non-conjugate likelihoods (Bernoulli, Poisson, etc.) where the expectation has no closed form.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| likelihood_fn | Callable[[Float[Array, ' N']], Float[Array, '']] | Function mapping latent values to a scalar log-likelihood, f -> log p(y_obs \| f). | required |
| state | GaussianState | Input Gaussian distribution. | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar expected log-likelihood. |
Source code in src/gaussx/_quadrature/_expectations.py
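When the likelihood factorizes over data points, as the Bernoulli and Poisson models do, E_q[log p(y | f)] = sum_i E_{q(f_i)}[log p(y_i | f_i)]. This sum depends only on the marginals q(f_i) = N(mu_i, Sigma_ii), so N one-dimensional quadratures replace one N-dimensional integral. The NumPy sketch below shows that reduction with a hypothetical helper; it is not a claim about how gaussx dispatches. It is checked against the Poisson closed form E[y f - exp(f)] = y mu - exp(mu + var / 2).

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def factorized_ell(log_lik_pointwise, y, q_mu, q_var, order=20):
    """Sketch: sum_i E[log p(y_i | f_i)], f_i ~ N(q_mu[i], q_var[i]), one 1-D GH rule each."""
    z, w = hermegauss(order)
    w = w / np.sqrt(2.0 * np.pi)                              # normalized GH weights
    f = q_mu[:, None] + np.sqrt(q_var)[:, None] * z[None, :]  # (N, order) nodes
    return float(np.sum(log_lik_pointwise(y[:, None], f) @ w))


def poisson_ll(y, f):
    # Poisson log-likelihood without log(y!), which does not depend on f.
    return y * f - np.exp(f)


y = np.array([0.0, 3.0, 1.0])
q_mu = np.array([0.2, 1.0, -0.5])
q_var = np.array([0.3, 0.1, 0.5])
exact = np.sum(y * q_mu - np.exp(q_mu + 0.5 * q_var))
print(factorized_ell(poisson_ll, y, q_mu, q_var), exact)  # the two values agree closely
```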
mean_expectation(fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], state: GaussianState, integrator: AbstractIntegrator) -> Float[Array, ' M']
Compute E[f(x)] where x ~ N(mu, Sigma).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Float[Array, ' N']], Float[Array, ' M']] | Function mapping shape (N,) to shape (M,). | required |
| state | GaussianState | Input Gaussian distribution. | required |
| integrator | AbstractIntegrator | Integration method (Taylor, Unscented, MC, etc.). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' M'] | Expected function value, shape (M,). |
Source code in src/gaussx/_quadrature/_expectations.py
gradient_expectation(fn: Callable[[Float[Array, ' N']], Float[Array, '']], state: GaussianState, integrator: AbstractIntegrator) -> Float[Array, ' N']
Compute E[nabla f(x)] via Stein's lemma.
Uses the identity:
E[nabla f(x)] = Sigma^{-1} Cov[x, f(x)]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Float[Array, ' N']], Float[Array, '']] | Scalar-valued function mapping shape (N,) to a scalar. | required |
| state | GaussianState | Input Gaussian distribution. | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' N'] | Expected gradient, shape (N,). |
Source code in src/gaussx/_quadrature/_expectations.py
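Stein's lemma states that E[(x - mu) f(x)] = Sigma E[grad f(x)] for x ~ N(mu, Sigma), so the expected gradient needs only function values. The NumPy sketch below uses a hypothetical helper and a dense Sigma. It estimates Cov[x, f(x)] with a tensor-product Gauss-Hermite rule and checks the result on f(x) = exp(a^T x), whose expected gradient a * exp(a^T mu + a^T Sigma a / 2) is known in closed form.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def stein_expected_gradient(f, mu, Sigma, order=20):
    """Sketch: E[grad f(x)] = Sigma^{-1} Cov[x, f(x)], Cov from tensor-product GH."""
    dim = mu.shape[0]
    z1, w1 = hermegauss(order)
    w1 = w1 / np.sqrt(2.0 * np.pi)
    grids = np.meshgrid(*([z1] * dim), indexing="ij")
    Z = np.stack([g.ravel() for g in grids], axis=1)                        # (P, dim)
    W = np.prod(np.meshgrid(*([w1] * dim), indexing="ij"), axis=0).ravel()  # (P,)
    X = mu + Z @ np.linalg.cholesky(Sigma).T                                # nodes for N(mu, Sigma)
    fx = np.array([f(x) for x in X])
    cov_xf = (X - mu).T @ (W * (fx - W @ fx))                               # Cov[x, f(x)]
    return np.linalg.solve(Sigma, cov_xf)


a = np.array([0.5, -0.3])
mu = np.array([0.1, 0.2])
Sigma = np.array([[0.4, 0.1], [0.1, 0.3]])
grad = stein_expected_gradient(lambda x: np.exp(a @ x), mu, Sigma)
exact = a * np.exp(a @ mu + 0.5 * a @ Sigma @ a)
```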
cost_expectation(prediction_fn: Callable[[Float[Array, ' N']], Float[Array, ' M']], cost_fn: Callable[[Float[Array, ' M'], Float[Array, ' M']], Float[Array, '']], state: GaussianState, target: Float[Array, ' M'], integrator: AbstractIntegrator) -> Float[Array, '']
Compute E[Cost(f(x), target)] where x ~ N(mu, Sigma).
For model-based RL: expected cost of a policy under state uncertainty.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction_fn | Callable[[Float[Array, ' N']], Float[Array, ' M']] | Maps a state of shape (N,) to a prediction of shape (M,). | required |
| cost_fn | Callable[[Float[Array, ' M'], Float[Array, ' M']], Float[Array, '']] | Cost function (prediction, target) -> scalar. | required |
| state | GaussianState | Input Gaussian distribution (uncertain state). | required |
| target | Float[Array, ' M'] | Target value, shape (M,). | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar expected cost. |
Source code in src/gaussx/_quadrature/_expectations.py
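For a linear prediction_fn and a squared-error cost, the expectation has the closed form ||A mu - t||^2 + tr(A Sigma A^T). The 2N-point cubature rule, which is exact for polynomials of degree up to 3, reproduces this value. A NumPy sketch with hypothetical helper names:

```python
import numpy as np


def expected_cost_cubature(prediction_fn, cost_fn, mu, Sigma, target):
    """Sketch: E[cost_fn(prediction_fn(x), target)], x ~ N(mu, Sigma), 2N-point CKF rule."""
    n = mu.shape[0]
    offsets = np.sqrt(n) * np.linalg.cholesky(Sigma).T
    points = np.vstack([mu + offsets, mu - offsets])
    costs = np.array([cost_fn(prediction_fn(x), target) for x in points])
    return float(costs.mean())  # equal weights 1 / (2N)


def squared_error(pred, t):
    return np.sum((pred - t) ** 2)


A = np.array([[1.0, 0.5, 0.0], [0.0, -1.0, 2.0]])
mu = np.array([0.2, -0.4, 1.0])
Sigma = np.diag([0.3, 0.2, 0.1])
target = np.array([1.0, 2.0])
approx = expected_cost_cubature(lambda x: A @ x, squared_error, mu, Sigma, target)
exact = np.sum((A @ mu - target) ** 2) + np.trace(A @ Sigma @ A.T)
print(approx, exact)
```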
ep_tilted_moments(log_lik_fn: Callable[[Float[Array, '']], Float[Array, '']], cav_mean: Float[Array, ' *batch'], cav_var: Float[Array, ' *batch'], *, order: int = 20) -> tuple[Float[Array, ' *batch'], Float[Array, ' *batch']]
Compute tilted distribution moments via Gauss-Hermite quadrature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| log_lik_fn | Callable[[Float[Array, '']], Float[Array, '']] | Scalar function mapping a latent value f to log p(y \| f). | required |
| cav_mean | Float[Array, ' *batch'] | Cavity means, shape (*batch,). | required |
| cav_var | Float[Array, ' *batch'] | Cavity variances (positive), shape (*batch,). | required |
| order | int | Number of Gauss-Hermite quadrature points. | 20 |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' *batch'], Float[Array, ' *batch']] | Tuple of tilted moments, each of shape (*batch,). |
Source code in src/gaussx/_quadrature/_tilted_moments.py
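A batched NumPy sketch of the quadrature. It assumes log_lik_fn is vectorized; gaussx takes a scalar function and presumably maps it over the batch. The sketch returns the tilted mean and variance, but the exact contents and order of the gaussx return tuple are not shown above. Subtracting each row's maximum log-likelihood before exponentiating avoids underflow, and the shift cancels in the normalized ratios. With a Gaussian likelihood the tilted distribution is itself Gaussian, which provides an exact check.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def tilted_moments_sketch(log_lik_fn, cav_mean, cav_var, order=20):
    """Sketch: mean and variance of p(f) proportional to N(f; cav_mean, cav_var) * exp(log_lik_fn(f))."""
    z, w = hermegauss(order)
    w = w / np.sqrt(2.0 * np.pi)
    m = np.asarray(cav_mean, dtype=float)[..., None]
    s = np.sqrt(np.asarray(cav_var, dtype=float))[..., None]
    f = m + s * z                                       # (*batch, order) nodes
    ll = log_lik_fn(f)
    lik = np.exp(ll - ll.max(axis=-1, keepdims=True))   # shifted for stability
    Z = np.sum(w * lik, axis=-1, keepdims=True)         # normalizer, up to the shift
    mean = np.sum(w * lik * f, axis=-1, keepdims=True) / Z
    var = np.sum(w * lik * (f - mean) ** 2, axis=-1, keepdims=True) / Z
    return mean[..., 0], var[..., 0]


# Gaussian likelihood N(y | f, s2): the tilted moments are a product of Gaussians.
y, s2 = 1.0, 1.0
cav_mean, cav_var = np.array([0.3, -1.0]), np.array([0.5, 1.0])
mean, var = tilted_moments_sketch(lambda f: -0.5 * (y - f) ** 2 / s2, cav_mean, cav_var)
```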
Kernel expectations & uncertain-input GPs
The \(\Psi\)-statistics \(\Psi_0 = \mathbb{E}[k(x,x)]\), \(\Psi_1 = \mathbb{E}[k(x, X)]\), \(\Psi_2 = \mathbb{E}[k(x,\cdot)k(x,\cdot)^\top]\) and the GP / SVGP / VGP / BGPLVM predictive equations for inputs that are themselves Gaussian.
AnalyticalPsiStatistics
Bases: Protocol
Protocol for kernels with closed-form Ψ statistics.
Ψ statistics are required for uncertain-input GP models (e.g., BGPLVM). A kernel implementing this protocol provides analytical formulae instead of requiring numerical integration.
Source code in src/gaussx/_quadrature/_psi_statistics.py
kernel_expectations(kernel_fn: Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']], state: GaussianState, X_train: Float[Array, 'N_train D'], integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, ' N_train'], Float[Array, 'N_train N_train']]
Compute kernel expectations Psi_0, Psi_1, Psi_2 for uncertain inputs.
These are the core quantities for GP inference with uncertain inputs:
Psi_0 = E[k(x, x)] scalar
Psi_1_i = E[k(x, x_i)] (N_train,)
Psi_2_{ij} = E[k(x, x_i) k(x, x_j)] (N_train, N_train)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']] | Kernel function k(x, x') returning a scalar. | required |
| state | GaussianState | Uncertain input distribution x ~ N(mu, Sigma). | required |
| X_train | Float[Array, 'N_train D'] | Training points, shape (N_train, D). | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, ' N_train'], Float[Array, 'N_train N_train']] | Tuple (Psi_0, Psi_1, Psi_2) with shapes (), (N_train,), and (N_train, N_train). |
Source code in src/gaussx/_quadrature/_gp_predict.py
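A NumPy sketch of the three expectations for a scalar input (D = 1 for brevity). It uses a hypothetical helper and illustrative RBF hyperparameters. One Gauss-Hermite rule serves all three statistics: Psi_0 and Psi_1 are weighted averages of kernel values at the nodes, and Psi_2 is a weighted outer product. The standard RBF closed forms are used only as a check in the test.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def rbf(x, xp, variance=1.5, lengthscale=0.8):
    """Squared-exponential kernel on scalar inputs (broadcasts over arrays)."""
    return variance * np.exp(-0.5 * (x - xp) ** 2 / lengthscale**2)


def psi_statistics_1d(kernel, mu, var, X_train, order=30):
    """Sketch (not gaussx): Psi_0, Psi_1, Psi_2 for x ~ N(mu, var) via Gauss-Hermite."""
    z, w = hermegauss(order)
    w = w / np.sqrt(2.0 * np.pi)
    x = mu + np.sqrt(var) * z                 # (P,) nodes for N(mu, var)
    K = kernel(x[:, None], X_train[None, :])  # (P, N_train) values k(x_p, x_i)
    psi0 = w @ kernel(x, x)                   # E[k(x, x)]
    psi1 = w @ K                              # E[k(x, x_i)]
    psi2 = (w[:, None] * K).T @ K             # E[k(x, x_i) k(x, x_j)]
    return psi0, psi1, psi2


X_train = np.array([-1.0, 0.0, 0.5, 2.0])
psi0, psi1, psi2 = psi_statistics_1d(rbf, 0.3, 0.2, X_train)
```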
compute_psi_statistics(kernel: object, state: GaussianState, X_train: Float[Array, 'M D'], *, integrator: AbstractIntegrator | None = None) -> tuple[Float[Array, ''], Float[Array, ' M'], Float[Array, 'M M']]
Compute Ψ statistics, dispatching to analytical or numerical.
If kernel implements AnalyticalPsiStatistics, uses
the closed-form methods. Otherwise, falls back to numerical
integration via the provided integrator:
Ψ₀ = E[k(x, x)] scalar
Ψ₁ᵢ = E[k(x, xᵢ)] (M,)
Ψ₂ᵢⱼ = E[k(x, xᵢ) k(x, xⱼ)] (M, M)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel | object | Kernel object, optionally implementing AnalyticalPsiStatistics. | required |
| state | GaussianState | Input Gaussian distribution x ~ 𝒩(μ, Σ). | required |
| X_train | Float[Array, 'M D'] | Training/inducing points, shape (M, D). | required |
| integrator | AbstractIntegrator \| None | Numerical integrator for the fallback. Required if kernel does not implement AnalyticalPsiStatistics. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, ' M'], Float[Array, 'M M']] | Tuple (Ψ₀, Ψ₁, Ψ₂) with shapes (), (M,), and (M, M). |
Raises:
| Type | Description |
|---|---|
| ValueError | If kernel does not implement AnalyticalPsiStatistics and no integrator is provided. |
Source code in src/gaussx/_quadrature/_psi_statistics.py
uncertain_gp_predict(kernel_fn: Callable, X_train: Float[Array, 'N_train D'], alpha: Float[Array, ' N_train'], K_inv: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, '']]
Predictive mean and variance for GP with uncertain inputs.
Uses kernel expectations:
mu_pred = Psi_1 @ alpha
var_pred = Psi_0 - tr(K_inv @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function k(x, x'). | required |
| X_train | Float[Array, 'N_train D'] | Training points, shape (N_train, D). | required |
| alpha | Float[Array, ' N_train'] | Precomputed weights, shape (N_train,). | required |
| K_inv | AbstractLinearOperator | Inverse of the training kernel matrix operator. | required |
| state | GaussianState | Uncertain test input x ~ N(mu, Sigma). | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, '']] | Tuple (mu_pred, var_pred). |
Source code in src/gaussx/_quadrature/_gp_predict.py
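Given the Ψ-statistics, each moment formula is one line. Below is a NumPy sketch with dense arrays and a hypothetical helper name. A useful sanity check is a point-mass input (zero input covariance). Then Psi_0 = k(x*, x*), Psi_1 = k_*, and Psi_2 = k_* k_*^T, and the formulas collapse to the ordinary GP predictive mean and variance. In the example, K includes the observation noise, as in a standard GP posterior.

```python
import numpy as np


def uncertain_gp_moments(psi0, psi1, psi2, alpha, K_inv):
    """Sketch: mu = Psi_1 alpha; var = Psi_0 - tr(K^{-1} Psi_2) + alpha^T Psi_2 alpha - mu^2."""
    mu_pred = psi1 @ alpha
    var_pred = psi0 - np.trace(K_inv @ psi2) + alpha @ psi2 @ alpha - mu_pred**2
    return mu_pred, var_pred


def rbf(a, b):
    return np.exp(-0.5 * (a - b) ** 2)


X = np.array([-1.0, 0.0, 1.5])
y = np.array([0.2, 1.0, -0.5])
K = rbf(X[:, None], X[None, :]) + 0.1 * np.eye(3)  # kernel matrix plus noise
K_inv = np.linalg.inv(K)
alpha = K_inv @ y
x_star = 0.4
k_star = rbf(x_star, X)
# Point mass at x_star: Psi_0 = k(x*, x*), Psi_1 = k_*, Psi_2 = k_* k_*^T.
mu_u, var_u = uncertain_gp_moments(rbf(x_star, x_star), k_star, np.outer(k_star, k_star), alpha, K_inv)
mu_std = k_star @ alpha
var_std = rbf(x_star, x_star) - k_star @ K_inv @ k_star
```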
uncertain_gp_predict_mc(predict_fn: Callable[[Float[Array, ' D']], tuple[Float[Array, ''], Float[Array, '']]], state: GaussianState, n_particles: int = 100, key: jax.Array | None = None) -> tuple[Float[Array, ''], Float[Array, '']]
Monte Carlo GP prediction with uncertain inputs.
Alternative to analytic kernel expectations when the Psi integrals are intractable. Uses the law of total variance:
mu = mean(particle_means)
var = var(particle_means) + mean(particle_vars)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predict_fn | Callable[[Float[Array, ' D']], tuple[Float[Array, ''], Float[Array, '']]] | GP predictor mapping a point input of shape (D,) to (mean, variance). | required |
| state | GaussianState | Uncertain test input x ~ N(mu, Sigma). | required |
| n_particles | int | Number of Monte Carlo particles. | 100 |
| key | Array \| None | PRNG key, or None. | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, '']] | Tuple (mu, var). |
Source code in src/gaussx/_quadrature/_gp_predict.py
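A NumPy sketch of the particle recipe. It uses a hypothetical helper name and numpy.random.default_rng in place of a JAX key. Each particle contributes a predictive mean and variance. The spread of the means captures input uncertainty, and the average variance captures the GP's own uncertainty.

```python
import numpy as np


def mc_uncertain_predict(predict_fn, mu, Sigma, n_particles=100, seed=0):
    """Sketch (not gaussx): law of total variance over particles x_s ~ N(mu, Sigma)."""
    rng = np.random.default_rng(seed)
    xs = rng.multivariate_normal(mu, Sigma, size=n_particles)
    means, variances = (np.array(v) for v in zip(*(predict_fn(x) for x in xs)))
    # mu = mean(particle_means); var = var(particle_means) + mean(particle_vars).
    return means.mean(), means.var() + variances.mean()


# Toy predictor: mean 2 x, constant variance 0.1; exact answer is (2.0, 4 * 0.25 + 0.1).
mu_hat, var_hat = mc_uncertain_predict(
    lambda x: (2.0 * x[0], 0.1), np.array([1.0]), np.array([[0.25]]), n_particles=20_000
)
print(mu_hat, var_hat)  # ≈ 2.0 and ≈ 1.1, up to Monte Carlo error
```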
uncertain_svgp_predict(kernel_fn: Callable, Z: Float[Array, 'M D'], alpha: Float[Array, ' M'], Q: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, '']]
Predictive mean and variance for SVGP with uncertain inputs.
Uses kernel expectations with inducing points:
mu_pred = Psi_1 @ alpha
var_pred = Psi_0 + tr(Q @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2
where Q = K_{zz}^{-1} S K_{zz}^{-1} - K_{zz}^{-1} is the variance
adjustment operator (see gaussx.svgp_variance_adjustment).
The exact uncertain GP trace correction is recovered by setting
Q = -K_{zz}^{-1}; if Z equals the training inputs, this matches
gaussx.uncertain_gp_predict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function k(x, x'). | required |
| Z | Float[Array, 'M D'] | Inducing points, shape (M, D). | required |
| alpha | Float[Array, ' M'] | Effective weights, shape (M,). | required |
| Q | AbstractLinearOperator | Variance adjustment operator Q, shape (M, M). | required |
| state | GaussianState | Uncertain test input x ~ N(mu, Sigma). | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, '']] | Tuple (mu_pred, var_pred). |
Source code in src/gaussx/_quadrature/_gp_predict.py
uncertain_vgp_predict(kernel_fn: Callable, X_train: Float[Array, 'N_train D'], alpha: Float[Array, ' N_train'], Q: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ''], Float[Array, '']]
Predictive mean and variance for dense VGP with uncertain inputs.
Uses kernel expectations with training points:
mu_pred = Psi_1 @ alpha
var_pred = Psi_0 + tr(Q @ Psi_2) + alpha^T @ Psi_2 @ alpha - mu_pred^2
where Q = K^{-1} S K^{-1} - K^{-1} and alpha = K^{-1} m.
The exact uncertain GP is the special case Q = -K^{-1}. By contrast,
Q = 0 corresponds to S = K and removes only the trace correction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function k(x, x'). | required |
| X_train | Float[Array, 'N_train D'] | Training points, shape (N_train, D). | required |
| alpha | Float[Array, ' N_train'] | Precomputed weights alpha = K^{-1} m, shape (N_train,). | required |
| Q | AbstractLinearOperator | Variance adjustment operator Q = K^{-1} S K^{-1} - K^{-1}, shape (N_train, N_train). | required |
| state | GaussianState | Uncertain test input x ~ N(mu, Sigma). | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, '']] | Tuple (mu_pred, var_pred). |
Source code in src/gaussx/_quadrature/_gp_predict.py
uncertain_bgplvm_predict(kernel_fn: Callable, X_train: Float[Array, 'N_train D'], alpha: Float[Array, 'N_train D_out'], K_inv: lx.AbstractLinearOperator, state: GaussianState, integrator: AbstractIntegrator) -> tuple[Float[Array, ' D_out'], Float[Array, ' D_out']]
Multi-output uncertain GP prediction for BGPLVM.
Maps a latent variable z ~ N(mu, Sigma) to high-dimensional
reconstruction y using per-output GP weights:
mu_pred_d = Psi_1 @ alpha_d
var_pred_d = Psi_0 - tr(K_inv @ Psi_2)
+ alpha_d^T @ Psi_2 @ alpha_d - mu_pred_d^2
This intentionally uses the exact GP trace term for every output dimension.
Unlike gaussx.uncertain_vgp_predict and
gaussx.uncertain_svgp_predict, there is no separate variational
covariance correction operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function k(x, x'). | required |
| X_train | Float[Array, 'N_train D'] | Training points, shape (N_train, D). | required |
| alpha | Float[Array, 'N_train D_out'] | Multi-output weights, shape (N_train, D_out). | required |
| K_inv | AbstractLinearOperator | Inverse of the training kernel matrix operator. | required |
| state | GaussianState | Uncertain latent input z ~ N(mu, Sigma). | required |
| integrator | AbstractIntegrator | Integration method. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ' D_out'], Float[Array, ' D_out']] | Tuple (mu_pred, var_pred), each of shape (D_out,). |