GP API¶
Wave 2 ships the dense-GP foundation: kernel math functions, concrete
Parameterized kernel classes, abstract component protocols, and the
model-facing entry points (GPPrior, ConditionedGP, gp_factor,
gp_sample). Scalable matrix construction and solver strategies
(numerically stable assembly, implicit operators, batched matvec,
Cholesky / CG / BBMM / LSMR / SLQ) live in
gaussx.
Split with gaussx
pyrox owns the kernel function side — closed-form math primitives
readable in a dozen lines — plus the NumPyro-aware model shell
(GPPrior, gp_factor, gp_sample). gaussx owns every piece of
linear algebra: stable matrix construction, solver strategies, and
the underlying MultivariateNormal distribution. The model entry
points accept any gaussx.AbstractSolverStrategy (default
gaussx.DenseSolver()).
Model entry points¶
```python
import jax.numpy as jnp
import numpyro

from pyrox.gp import GPPrior, RBF, gp_factor, gp_sample


def regression_model(X, y):
    """Collapsed Gaussian-likelihood GP regression."""
    kernel = RBF()
    prior = GPPrior(kernel=kernel, X=X)
    gp_factor("obs", prior, y, noise_var=jnp.array(0.05))


def latent_model(X):
    """Latent-function GP for non-conjugate likelihoods."""
    kernel = RBF()
    prior = GPPrior(kernel=kernel, X=X)
    f = gp_sample("f", prior)
    # ... attach any likelihood to f here, e.g. Bernoulli or Poisson.
```
Swap the solver strategy at construction time:
```python
from gaussx import CGSolver, ComposedSolver, DenseLogdet, DenseSolver

prior = GPPrior(kernel=RBF(), X=X, solver=CGSolver())

# Or compose — CG for solve, dense Cholesky for logdet:
prior = GPPrior(
    kernel=RBF(), X=X,
    solver=ComposedSolver(solve_strategy=CGSolver(), logdet_strategy=DenseLogdet()),
)
```
pyrox.gp.GPPrior
¶
Bases: Module
Finite-dimensional GP prior over a fixed training input set.
Holds a kernel, training inputs X, an optional mean function, an
optional solver strategy, and a small diagonal jitter for numerical
stability on otherwise-singular prior covariances.
Attributes:
| Name | Type | Description |
|---|---|---|
| `kernel` | `Kernel` | Any :class:Kernel instance. |
| `X` | `Float[Array, 'N D']` | Training inputs of shape `(N, D)`. |
| `mean_fn` | `Callable[[Float[Array, 'N D']], Float[Array, ' N']] \| None` | Callable mean function; zero mean when `None`. |
| `solver` | `AbstractSolverStrategy \| None` | Any :class:gaussx.AbstractSolverStrategy; defaults to :class:gaussx.DenseSolver. |
| `jitter` | `float` | Diagonal regularization added to the prior covariance for numerical stability. Not a noise model — use `noise_var` in :meth:condition / :func:gp_factor for observation noise. |
Source code in src/pyrox/gp/_models.py
condition(y, noise_var)
¶
Condition on Gaussian-likelihood observations y.
Precomputes
alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X)) and
caches it in the returned :class:ConditionedGP. The same
jitter regularization configured on this prior is included
alongside noise_var in the conditioned operator and solve, so
every downstream predict / sample call sees the regularized
covariance.
Source code in src/pyrox/gp/_models.py
log_prob(f)
¶
Marginal log-density of f under the GP prior.
Computes :math:\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I)
using :func:gaussx.log_marginal_likelihood, so any solver strategy
on this prior applies.
Source code in src/pyrox/gp/_models.py
mean(X)
¶
Evaluate the mean function at X; zero by default.
sample(key)
¶
Draw f \sim p(f) = \mathcal{N}(\mu(X), K + \text{jitter}\,I).
Wraps the prior in a :class:gaussx.MultivariateNormal with
the configured :attr:solver. This is the non-NumPyro analogue
of :func:gp_sample — useful for tests, diagnostics, and
prior-sample initialization without registering a sample site.
Source code in src/pyrox/gp/_models.py
pyrox.gp.ConditionedGP
¶
Bases: Module
GP conditioned on Gaussian-likelihood training observations.
Holds the precomputed training solve alpha (via
:class:gaussx.PredictionCache) and the noisy covariance operator so
predictions at multiple test sets reuse the training solve.
Source code in src/pyrox/gp/_models.py
predict(X_star)
¶
Return (mean, variance) at X_* as a tuple.
Both kernel evaluations share a single kernel context; see
:meth:predict_var.
Source code in src/pyrox/gp/_models.py
predict_mean(X_star)
¶
:math:\mu_* = \mu(X_*) + K_{*f}\,\alpha.
Source code in src/pyrox/gp/_models.py
predict_var(X_star)
¶
Diagonal predictive variance at X_*.
.. math:: \sigma^2_{*,i} = k(x_{*,i}, x_{*,i}) - K_{*f}[i,:] \cdot (K + \sigma^2 I)^{-1} K_{f*}[:,i]
K_cross and K_diag are computed under one shared kernel
context so Pattern B / C kernels with prior'd hyperparameters
register their NumPyro sites once and reuse them across both
kernel calls (and the cached training solve).
Source code in src/pyrox/gp/_models.py
sample(key, X_star, n_samples=1)
¶
Sample from the diagonal predictive N(mean, diag(var)).
Returns samples independently per test point; correlated joint
samples from the full predictive covariance are not covered by
the Wave 2 dense surface. For correlated samples, build the full
predictive covariance explicitly and draw from
:class:gaussx.MultivariateNormal.
Source code in src/pyrox/gp/_models.py
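The condition / predict equations above can be mirrored in a few lines of plain NumPy. This is a hedged sketch with a hand-rolled unit-variance RBF Gram standing in for the kernel classes, not the pyrox API:

```python
import numpy as np

def rbf_gram(X1, X2, lengthscale=0.5):
    # Squared-exponential Gram: exp(-||x - x'||^2 / (2 l^2)), unit variance.
    sq = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2 * lengthscale**2))

X = np.linspace(0.0, 1.0, 10)[:, None]
y = np.sin(4.0 * X[:, 0])
noise_var = 0.1

# Training solve, computed once and reused for every test set.
A = rbf_gram(X, X) + noise_var * np.eye(len(X))
alpha = np.linalg.solve(A, y)

X_star = np.array([[0.25], [0.75]])
K_sf = rbf_gram(X_star, X)

mean_star = K_sf @ alpha                      # mu_* = K_{*f} alpha
V = np.linalg.solve(A, K_sf.T)                # (K + sigma^2 I)^{-1} K_{f*}
var_star = rbf_gram(X_star, X_star).diagonal() - np.einsum("ij,ji->i", K_sf, V)
```

Caching `alpha` lets the mean path skip repeated training solves across test batches, which is the point of the precomputed training solve described above.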
pyrox.gp.gp_factor(name, prior, y, noise_var)
¶
Register the collapsed GP log marginal likelihood with NumPyro.
Adds
log p(y | X, theta) = log N(y | mu, K + (jitter + sigma^2) I)
to the NumPyro trace as numpyro.factor(name, ...). The prior's
jitter is included in addition to the observation noise variance
so the covariance matches what :meth:GPPrior.condition builds. Use
this inside a NumPyro model when the likelihood is Gaussian and you
want the latent function marginalized analytically.
Source code in src/pyrox/gp/_models.py
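For intuition, the quantity registered by the factor can be sketched in plain NumPy via a Cholesky factorization (a standalone illustration of the formula above, not the gaussx implementation):

```python
import numpy as np

def log_marginal(y, K, sigma2, jitter=1e-6):
    # log N(y | 0, K + (jitter + sigma^2) I) via a Cholesky factorization.
    N = len(y)
    A = K + (jitter + sigma2) * np.eye(N)
    L = np.linalg.cholesky(A)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    logdet = 2.0 * np.log(np.diag(L)).sum()     # log|A| from Cholesky diagonal
    return -0.5 * (y @ alpha + logdet + N * np.log(2.0 * np.pi))

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 1))
K = np.exp(-0.5 * (X - X.T) ** 2)   # unit-variance RBF Gram in 1-D
y = rng.normal(size=6)
lml = log_marginal(y, K, sigma2=0.05)
```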
pyrox.gp.gp_sample(name, prior, *, whitened=False, guide=None)
¶
Sample a latent function f at the prior's training inputs.
Three mutually exclusive modes:
- `whitened=False`, `guide=None` (default) — register a single `numpyro.sample(name, MVN(mu, K + jitter I))` site. The latent function is sampled directly from the prior.
- `whitened=True`, `guide=None` — register a unit-normal latent site `f"{name}_u"` with shape `(N,)` and return the deterministic value `f = mu(X) + L u`, where `L` is the Cholesky factor of `K + jitter I`. This reparameterization is the standard fix for mean-field SVI on GP-correlated latents (Murray & Adams, 2010): a NumPyro auto-guide such as :class:numpyro.infer.autoguide.AutoNormal then approximates the well-conditioned isotropic posterior over `u` instead of the ill-conditioned correlated posterior over `f`.
- `guide` provided — delegate to `guide.register(name, prior)`. Concrete variational guides (Wave 3) own their own parameterization, so combining `whitened=True` with `guide` is rejected.
Use this inside a NumPyro model for non-conjugate likelihoods, where the latent function cannot be marginalized analytically.
Source code in src/pyrox/gp/_models.py
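The whitened mode can be illustrated standalone in NumPy: a sketch of the `f = mu(X) + L u` map with a hand-rolled RBF Gram, not the pyrox implementation:

```python
import numpy as np

rng = np.random.default_rng(1)
X = np.linspace(0.0, 1.0, 12)[:, None]
K = np.exp(-0.5 * ((X - X.T) / 0.3) ** 2)   # RBF Gram, lengthscale 0.3
jitter = 1e-6
L = np.linalg.cholesky(K + jitter * np.eye(12))

# Whitened draw: u is unit-normal; f = mu + L u has the GP prior covariance.
u = rng.standard_normal(12)
mu = np.zeros(12)
f = mu + L @ u

# The map u -> f transports N(0, I) to N(mu, K + jitter I):
assert np.allclose(L @ L.T, K + jitter * np.eye(12))
```

The auto-guide then only has to fit the isotropic posterior over `u`, which is why the reparameterization conditions mean-field SVI so much better.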
Concrete kernels¶
Each Parameterized kernel registers its hyperparameters with positivity
constraints where appropriate. Attach priors with set_prior, autoguides
with autoguide, and flip set_mode("model" | "guide").
pyrox.gp.RBF
¶
Bases: _ParameterizedKernel
Radial basis function (squared exponential) kernel.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Matern
¶
Bases: _ParameterizedKernel
Matern kernel with nu in {0.5, 1.5, 2.5}.
nu is a static class attribute — it selects a code path in the
underlying math primitive and is not a trainable parameter.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Periodic
¶
Bases: _ParameterizedKernel
Periodic (MacKay) kernel.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Linear
¶
Bases: _ParameterizedKernel
Linear kernel sigma^2 x^T x' + bias.
bias is constrained nonnegative because k = sigma^2 X X^T + b 1 1^T
is only PSD for b >= 0 (e.g. X = 0 gives eigenvalue N*b).
Source code in src/pyrox/gp/_kernels.py
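A quick NumPy check of the PSD argument: at `X = 0` the Gram degenerates to `b * 1 1^T`, whose spectrum is `{0, ..., 0, N * b}`, so a negative bias would make it indefinite:

```python
import numpy as np

N, bias = 4, 0.5
X = np.zeros((N, 1))                  # degenerate inputs: X = 0
K = X @ X.T + bias * np.ones((N, N))  # linear kernel Gram with sigma^2 = 1
eigs = np.linalg.eigvalsh(K)          # spectrum: {0, ..., 0, N * bias}
```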
pyrox.gp.RationalQuadratic
¶
Bases: _ParameterizedKernel
Rational quadratic kernel.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Polynomial
¶
Bases: _ParameterizedKernel
Polynomial kernel sigma^2 (x^T x' + bias)^degree.
degree is a static class field (it selects an integer power, not
an optimization target). bias is constrained nonnegative — the
degree=1 case reduces to :class:Linear and has the same
PSD-requires-b>=0 failure mode.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Cosine
¶
Bases: _ParameterizedKernel
Cosine kernel sigma^2 cos(2 pi ||x - x'|| / period).
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.White
¶
Bases: _ParameterizedKernel
White-noise kernel sigma^2 delta(x, x').
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Constant
¶
Bases: _ParameterizedKernel
Constant kernel k(x, x') = sigma^2.
Source code in src/pyrox/gp/_kernels.py
Sparse-GP inducing features (#49)¶
Inter-domain inducing-feature families used to build scalable sparse GPs
where the inducing-prior covariance K_uu becomes diagonal. Pass any
of these to :class:SparseGPPrior via the inducing= keyword in
place of a raw point matrix Z.
```python
from pyrox.gp import RBF, FourierInducingFeatures, SparseGPPrior

kernel = RBF(init_lengthscale=0.3, init_variance=1.0)
features = FourierInducingFeatures.init(in_features=1, num_basis_per_dim=64, L=5.0)
prior = SparseGPPrior(kernel=kernel, inducing=features)  # K_uu is diagonal!
```
pyrox.gp.InducingFeatures
¶
Bases: Protocol
Protocol for inter-domain inducing features.
Implementations expose the inducing-prior covariance K_uu and the
cross-covariance k_ux(X) between data points and inducing
features. Diagonal-friendly concretions return
:class:lineax.DiagonalLinearOperator so the downstream solve dispatches
to elementwise division.
Input shape is family-dependent. k_ux takes a batch of data
points X in whatever representation the family consumes:
- :class:FourierInducingFeatures — coordinates `(N, D)`.
- :class:SphericalHarmonicInducingFeatures — unit vectors `(N, 3)`.
- :class:LaplacianInducingFeatures — integer node indices `(N,)`.
Each implementation validates its own expected shape and dtype.
Source code in src/pyrox/gp/_inducing.py
pyrox.gp.FourierInducingFeatures
¶
Bases: Module
VFF — Variational Fourier inducing features on :math:[-L, L]^D.
For a stationary kernel with spectral density :math:S(\cdot), the
basis :math:\{\phi_j\} of Laplacian eigenfunctions on the box gives
.. math::
K_{uu} = \mathrm{diag}\!\big(S(\sqrt{\lambda_j})\big),
\qquad
K_{ux}(x)_j = S(\sqrt{\lambda_j})\,\phi_j(x).
With this convention :math:(K_{uu}^{-1} K_{ux}(x))_j = \phi_j(x), so the
SVGP predictive mean reduces to a basis evaluation. K_{uu} is
returned as a :class:lineax.DiagonalLinearOperator to preserve the
O(M) solve dispatch end-to-end.
Attributes:
| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension :math:D. |
| `num_basis_per_dim` | `tuple[int, ...]` | Per-axis number of 1D eigenfunctions; total count is the product over axes. |
| `L` | `tuple[float, ...]` | Per-axis box half-width. |
Source code in src/pyrox/gp/_inducing.py
K_uu(kernel, *, jitter=1e-06)
¶
Diagonal :math:K_{uu} — entries S(sqrt(lambda_j)) plus jitter.
Source code in src/pyrox/gp/_inducing.py
k_ux(x, kernel)
¶
Cross-covariance entries :math:S(\sqrt{\lambda_j})\,\phi_j(x).
Source code in src/pyrox/gp/_inducing.py
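A toy NumPy illustration of why the diagonal `K_uu` matters. The spectral weights and basis values below are illustrative numbers, not a real VFF basis:

```python
import numpy as np

# Toy spectral weights S(sqrt(lambda_j)) and basis evaluations phi_j(x)
# for M = 5 features at N = 3 inputs.
S = np.array([1.0, 0.7, 0.4, 0.2, 0.1])             # diagonal of K_uu
phi = np.random.default_rng(2).normal(size=(5, 3))  # phi_j(x_n)

K_ux = S[:, None] * phi          # K_ux[j, n] = S_j * phi_j(x_n)

# Diagonal K_uu: the O(M) "solve" is elementwise division ...
solved = K_ux / S[:, None]

# ... and K_uu^{-1} K_ux collapses back to the raw basis evaluation.
assert np.allclose(solved, phi)
```

This is the identity that lets the SVGP predictive mean reduce to a basis evaluation.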
pyrox.gp.SphericalHarmonicInducingFeatures
¶
Bases: Module
VISH — inducing harmonics on :math:S^2 (Dutordoir et al. 2020).
For any zonal kernel :math:k(x, x') = \kappa(x \cdot x') on the
unit 2-sphere, the Funk-Hecke theorem gives a diagonal :math:K_{uu}
whose eigenvalues are the kernel's Funk-Hecke coefficients
:math:a_l. The cross-covariance is :math:a_l\,Y_{lm}(x).
Funk-Hecke coefficients are computed by Gauss-Legendre quadrature (arbitrary kernels supported, no closed form required). For kernels that have a closed-form Funk-Hecke series (RBF on S² via Bessel functions etc.), the numerical and analytic answers should agree to the quadrature tolerance.
Attributes:
| Name | Type | Description |
|---|---|---|
| `l_max` | `int` | Maximum harmonic degree, inclusive. |
| `num_quadrature` | `int` | Gauss-Legendre nodes for the Funk-Hecke integral. |
Source code in src/pyrox/gp/_inducing.py
K_uu(kernel, *, jitter=1e-06)
¶
Diagonal :math:K_{uu} — Funk-Hecke coefficients per harmonic.
Source code in src/pyrox/gp/_inducing.py
k_ux(unit_xyz, kernel)
¶
Cross-covariance: :math:a_l\,Y_{lm}(x).
Source code in src/pyrox/gp/_inducing.py
pyrox.gp.LaplacianInducingFeatures
¶
Bases: Module
Inducing features from low-frequency graph Laplacian eigenvectors.
For a graph with normalized Laplacian :math:L, take the smallest
num_basis eigenpairs :math:(\mu_j, v_j). Treating the kernel as
a function of the graph distance — specifically, applying the kernel
spectrum :math:g(\mu) to the Laplacian eigenvalues — gives a
diagonal :math:K_{uu}.
This implementation supports the heat-kernel family
:math:g(\mu) = \exp(-\mu / (2 \ell^2)) (matching :class:pyrox.gp.RBF
in spectrum) by reusing :func:pyrox._basis.spectral_density with the
eigenvalues as input.
Attributes:
| Name | Type | Description |
|---|---|---|
| `eigvals` | `Float[Array, ' M']` | The `M` smallest Laplacian eigenvalues :math:\mu_j. |
| `eigvecs` | `Float[Array, 'V M']` | Matching eigenvectors :math:v_j, one column per eigenpair. |
| `num_quadrature` | `Float[Array, 'V M']` | Unused (kept for protocol uniformity). |
Note
X is a vector of node indices (integer-valued), not
coordinates. The returned cross-covariance gathers the relevant
rows of eigvecs.
Source code in src/pyrox/gp/_inducing.py
pyrox.gp.DecoupledInducingFeatures
¶
Bases: Module
Decoupled mean / covariance inducing-feature bases (Cheng & Boots 2017).
Two independent inducing-feature sets:
- `mean_features` — a large `alpha`-basis used by the SVGP posterior mean (cheap: predictive mean cost is linear in the mean-basis size).
- `cov_features` — a small `beta`-basis used for the posterior covariance (the true bottleneck; keep this small).
The two bases need not share the same family — a common pattern is a large Fourier basis for the mean and a small spherical-harmonic basis for the covariance, or vice versa. The downstream guide consumes both via the standard SVGP machinery.
Attributes:
| Name | Type | Description |
|---|---|---|
| `mean_features` | `InducingFeatures` | Inducing-feature object backing the predictive mean. |
| `cov_features` | `InducingFeatures` | Inducing-feature object backing the predictive covariance. |
Note
DecoupledInducingFeatures itself does not implement
:class:InducingFeatures (no single K_uu makes sense for two
bases). Consumers should access .mean_features and
.cov_features directly.
Source code in src/pyrox/gp/_inducing.py
pyrox.gp.funk_hecke_coefficients(kernel, l_max, *, num_quadrature=256)
¶
Funk-Hecke coefficients of a zonal kernel on :math:S^2.
For a kernel of the form :math:k(x, x') = \kappa(x \cdot x') on the
unit 2-sphere, the Funk-Hecke theorem gives:
.. math::
a_l = 2\pi \int_{-1}^{1} \kappa(t)\,P_l(t)\,dt.
Returns (l_max + 1,) coefficients indexed by l. We treat any
Euclidean kernel as zonal-on-the-sphere via
:math:\kappa(t) = k_{\mathrm{euc}}(\hat{n}_0, \hat{n}_t) for unit
vectors at angular separation arccos(t).
Source code in src/pyrox/gp/_inducing.py
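The quadrature rule is easy to sketch in plain NumPy (a standalone illustration of the formula, not the pyrox implementation). With `kappa(t) = t`, which equals `P_1(t)`, Legendre orthogonality leaves only the `l = 1` coefficient:

```python
import numpy as np

def funk_hecke(kappa, l_max, num_quadrature=64):
    # a_l = 2 pi * integral_{-1}^{1} kappa(t) P_l(t) dt via Gauss-Legendre.
    t, w = np.polynomial.legendre.leggauss(num_quadrature)
    coeffs = []
    for l in range(l_max + 1):
        P_l = np.polynomial.legendre.Legendre.basis(l)(t)
        coeffs.append(2.0 * np.pi * np.sum(w * kappa(t) * P_l))
    return np.array(coeffs)

# kappa(t) = t is exactly P_1, so only a_1 = 2*pi * 2/3 survives.
a = funk_hecke(lambda t: t, l_max=2)
```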
Sparse GP prior¶
pyrox.gp.SparseGPPrior
¶
Bases: Module
GP prior parameterized over inducing inputs Z.
Represents the zero-mean prior over inducing values u = f(Z)
used by sparse variational guides:
.. math::
p(u) = \mathcal{N}(0,\, K_{ZZ} + \mathrm{jitter}\,I).
The standard SVGP convention is to subtract any global mean function
before forming the prior over u and to add it back at predict
time, so the inducing-prior mean is fixed to zero (this is what the
guides' KL terms assume — see :meth:FullRankGuide.kl_divergence,
:meth:MeanFieldGuide.kl_divergence, :meth:WhitenedGuide.kl_divergence).
The :attr:mean_fn attribute on this class is exposed as a
convenience for callers that want to add mu(X_*) back onto the
predictive mean returned by :meth:Guide.predict; it is not
incorporated in :meth:inducing_operator or in the guides' KL.
Pair with a sparse variational guide that owns q(u) = N(m, S) to
obtain the standard SVGP predictive
.. math::
\mu_*(x) = K_{xZ} K_{ZZ}^{-1} m, \qquad
\sigma_*^2(x) = k(x, x) - K_{xZ} K_{ZZ}^{-1} K_{Zx}
+ K_{xZ} K_{ZZ}^{-1} S K_{ZZ}^{-1} K_{Zx}.
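A plain-NumPy sanity check of these identities (hand-rolled RBF Gram, not the pyrox API): setting `m = 0` and `S = K_ZZ` makes `q(u)` equal the prior, so the SVGP predictive must collapse back to the prior marginals:

```python
import numpy as np

def rbf(A, B, ls=0.4):
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2 * ls**2))

Z = np.linspace(0.0, 1.0, 6)[:, None]   # inducing inputs
Xs = np.array([[0.3], [0.9]])           # test points

K_zz = rbf(Z, Z) + 1e-8 * np.eye(6)
K_xz = rbf(Xs, Z)
m, S = np.zeros(6), K_zz                # q(u) = p(u): posterior equals prior

W = np.linalg.solve(K_zz, K_xz.T).T     # K_xZ K_ZZ^{-1}
mean = W @ m
var = (rbf(Xs, Xs).diagonal()
       - np.einsum("ij,ji->i", W, K_xz.T)
       + np.einsum("ij,jk,ki->i", W, S, W.T))
```

With `S = K_ZZ` the last two terms cancel exactly, so `var` returns to the prior diagonal (1 for this unit-variance kernel) and `mean` to zero.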
Attributes:
| Name | Type | Description |
|---|---|---|
| `kernel` | `Kernel` | Any :class:Kernel instance. |
| `Z` | `Float[Array, 'M D'] \| None` | Inducing inputs of shape `(M, D)`; `None` when an inducing-feature family is supplied instead. |
| `mean_fn` | `Callable[[Float[Array, 'N D']], Float[Array, ' N']] \| None` | Callable mean function; see the note above — not part of the inducing prior. |
| `solver` | `AbstractSolverStrategy \| None` | Any :class:gaussx.AbstractSolverStrategy. |
| `jitter` | `float` | Diagonal regularization added to :math:K_{ZZ}. |
Source code in src/pyrox/gp/_sparse.py
num_inducing
property
¶
Number of inducing inputs / features M.
cross_covariance(X)
¶
:math:K_{XZ} — covariance between X and the inducing inputs/features.
The expected shape of X is inducing-family-dependent:
- Point-inducing (`Z`) or :class:FourierInducingFeatures — coordinates `(N, D)`.
- :class:SphericalHarmonicInducingFeatures — unit vectors `(N, 3)`.
- :class:LaplacianInducingFeatures — integer node indices `(N,)`.
See :meth:predictive_blocks for the shared-context batch
helper to use when assembling several SVGP blocks together.
Source code in src/pyrox/gp/_sparse.py
inducing_operator()
¶
Return K_{ZZ} + \text{jitter}\,I as a lineax operator.
For point-inducing priors, returns a dense
:class:lineax.MatrixLinearOperator with positive_semidefinite_tag.
For inducing-feature priors, delegates to
:meth:InducingFeatures.K_uu — typically a
:class:lineax.DiagonalLinearOperator so the downstream
:func:gaussx.solve dispatches in O(M) instead of O(M^3).
Single kernel call; safe standalone for kernels with priors. For
building several SVGP blocks together, prefer
:meth:predictive_blocks, which scopes one shared kernel
context across K_zz, K_xz, and K_xx_diag so
Pattern B / C kernels register their NumPyro hyperparameter
sites once instead of resampling per call.
Source code in src/pyrox/gp/_sparse.py
kernel_diag(X)
¶
Prior diagonal \mathrm{diag}\,K(X, X) — variance at each x.
See :meth:predictive_blocks for the shared-context batch
helper to use when assembling several SVGP blocks together.
Source code in src/pyrox/gp/_sparse.py
log_prob(u)
¶
Log-density under :math:p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I).
Delegates to :func:gaussx.gaussian_log_prob with the
configured :attr:solver so the user-supplied solver controls
the solve / logdet work on K_zz_op. Useful for
scoring inducing values against the SVGP prior in non-NumPyro
contexts (e.g. tests, diagnostics).
Source code in src/pyrox/gp/_sparse.py
mean(X)
¶
Evaluate the mean function at X; zero by default.
predictive_blocks(X)
¶
Return (K_zz_op, K_xz, K_xx_diag) under one shared kernel context.
For Pattern B / C kernels with prior'd hyperparameters, the three
kernel evaluations needed for an SVGP predictive must share a
single :class:pyrox.PyroxModule context so the underlying
pyrox_sample sites register once and yield consistent
hyperparameter draws across K_{ZZ}, K_{XZ}, and the
diagonal \mathrm{diag}\,K(X, X). Without this scoping, three
separate calls would draw three independent hyperparameter
samples (under seed) or raise NumPyro duplicate-site errors
(under tracing) — either way invalidating the SVGP math.
For pure :class:equinox.Module kernels (no _get_context),
this is equivalent to calling :meth:inducing_operator,
:meth:cross_covariance, and :meth:kernel_diag independently.
For inducing-feature priors, K_zz_op is a
:class:lineax.DiagonalLinearOperator (jitter folded into the
diagonal vector — never + jnp.eye) so the downstream solve
stays O(M).
Source code in src/pyrox/gp/_sparse.py
sample(key)
¶
Draw u \sim p(u) from the inducing prior.
Wraps the inducing operator in a
:class:gaussx.MultivariateNormal with the configured
:attr:solver. MultivariateNormal.sample factors the
covariance via :func:gaussx.cholesky and reparameterizes;
the returned draw has shape (M,).
Note: the SVGP variational workflow samples u from the
guide :math:q(u), not the prior. This method exists so the
prior surface is symmetric with the guide surface and so users
can score / draw inducing values against the prior directly
(e.g. for tests or for prior-sample initialization).
Source code in src/pyrox/gp/_sparse.py
Component protocols¶
Abstract pyrox-local bases for the orthogonal component stack. Wave 2
ships only the contracts for Guide, Integrator, and Likelihood —
concrete implementations land in later waves. Solver strategies live in
gaussx._strategies.
pyrox.gp.Kernel
¶
Bases: Module
Abstract base for GP covariance functions.
Subclasses implement :meth:__call__ returning the Gram matrix on a pair
of input batches. :meth:gram and :meth:diag are convenience defaults
that derive from :meth:__call__; structured subclasses (Kronecker,
state-space, etc.) should override them for efficiency.
Source code in src/pyrox/gp/_protocols.py
diag(X)
¶
Diagonal of K(X, X).
Default implementation extracts the diagonal of the full Gram. For
stationary kernels with constant diagonal, override with a vectorized
broadcast for the O(N) shortcut.
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.Guide
¶
Bases: Module
Abstract base for variational posterior families.
Concrete guides (DeltaGuide, MeanFieldGuide, LowRankGuide,
FullRankGuide, etc.) land in the dedicated guide waves (#28, #29).
The whitening principle keeps optimization geometry well-conditioned —
sample from a unit-scale latent and unwhiten with the prior Cholesky.
Two distinct entry points:
- :meth:sample / :meth:log_prob — pure variational draws and densities. `sample(self, key)` returns a draw from `q(f)`; `log_prob(self, f)` evaluates `log q(f)`. Neither touches the NumPyro trace.
- `register(name, prior)` (optional) — the NumPyro-integration hook invoked by :func:pyrox.gp.gp_sample when a guide is supplied. Use it to register a sample / param site (or compose one out of guide state) under `name` and return the latent function value. Concrete guides that participate in :func:gp_sample should implement this; the protocol leaves it unspecified so guides usable purely outside NumPyro stay valid.
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.Integrator
¶
Bases: Module
Abstract base for Gaussian-expectation integrators.
Computes :math:\mathbb{E}_{q(f)}[g(f)] where q(f) = N(mean, var).
Concrete integrators (Gauss-Hermite, sigma-points, cubature, Taylor,
Monte Carlo) land in later waves and may delegate to gaussx's
quadrature primitives.
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.Likelihood
¶
Bases: Module
Abstract base for observation models.
Implements the conditional p(y | f) and a default
:meth:expected_log_prob that integrates over a Gaussian latent via an
:class:Integrator. Concrete likelihoods (Gaussian, Bernoulli, Poisson,
StudentT, ...) land in later waves.
Source code in src/pyrox/gp/_protocols.py
Math primitives¶
Pure JAX kernel functions. Stateless, differentiable, composable —
(Array, ..., hyperparams) -> Gram. No NumPyro, no protocols.
Pure JAX kernel evaluation primitives — math definitions only.
Each function takes two input matrices X1 of shape (N1, D),
X2 of shape (N2, D), hyperparameters as JAX arrays, and returns
the (N1, N2) Gram matrix. All inputs are 2-D; callers that have 1-D
arrays must add a trailing singleton dimension first.
These are the canonical closed-form math for each kernel — small,
readable, and tutorial-facing. The companion scalable construction
surface (numerically stable matrix assembly, mixed-precision
accumulation, implicit/structured operators, batched matvec) lives in
gaussx; see :func:gaussx.stable_rbf_kernel and
:class:gaussx.ImplicitKernelOperator for the production path.
The composition helpers :func:kernel_add and :func:kernel_mul act on
already-evaluated Gram matrices, not on callables. Higher-level
:class:pyrox.gp.Kernel classes (Wave 2 Layer 1, see issue #20) compose
callables and may opt in to gaussx's scalable variants when needed.
Index axes are named via :mod:einops (einsum / rearrange) rather
than raw broadcasting so shape intent stays legible at the call site.
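A minimal sketch of the calling convention, with plain NumPy standing in for the JAX primitives: inputs are always 2-D, and the composition helpers act on evaluated Grams:

```python
import numpy as np

# All primitives take 2-D inputs; promote 1-D arrays with a trailing axis.
x = np.linspace(0.0, 1.0, 5)         # shape (5,)
X = x[:, None]                       # shape (5, 1) -- the required form

def rbf_gram(X1, X2, variance=1.0, lengthscale=1.0):
    sq = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)
    return variance * np.exp(-sq / (2 * lengthscale**2))

K1 = rbf_gram(X, X)                  # (5, 5) Gram
K2 = np.full((5, 5), 0.25)           # constant-kernel Gram, sigma^2 = 0.25

# Composition helpers act on evaluated Grams, not callables:
K_sum = K1 + K2                      # kernel_add analogue
K_prod = K1 * K2                     # kernel_mul analogue (Hadamard product)
```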
constant_kernel(X1, X2, variance)
¶
Constant kernel.
.. math:: k(x, x') = \sigma^2
A rank-one kernel useful as a scalar offset additive component.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar value. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
cosine_kernel(X1, X2, variance, period)
¶
Cosine kernel.
.. math:: k(x, x') = \sigma^2 \cos\!\left( \frac{2 \pi |x - x'|}{p} \right)
Useful as a simple periodic building block alongside
:func:periodic_kernel; unlike the MacKay form, it uses the plain
cosine of the distance and can go negative.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar signal variance. | required |
| `period` | `Float[Array, '']` | Scalar period. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
kernel_add(K1, K2)
¶
Pointwise sum of two already-evaluated Gram matrices.
kernel_mul(K1, K2)
¶
Pointwise (Hadamard) product of two already-evaluated Gram matrices.
linear_kernel(X1, X2, variance, bias)
¶
Linear kernel.
.. math:: k(x, x') = \sigma^2\, x^\top x' + b
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar variance multiplier on the dot product. | required |
| `bias` | `Float[Array, '']` | Scalar additive bias. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
matern_kernel(X1, X2, variance, lengthscale, nu)
¶
Matern kernel with closed-form nu in {1/2, 3/2, 5/2}.
.. math:: k(x, x') = \sigma^2\, f_\nu(r / \ell), \qquad r = |x - x'|
Only the three common half-integer orders are supported because those
admit closed-form expressions without Bessel evaluations. nu is a
static Python float (not a JAX array) so the branch specializes at
trace time.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar signal variance. | required |
| `lengthscale` | `Float[Array, '']` | Scalar lengthscale. | required |
| `nu` | `float` | Smoothness parameter; must be one of `{0.5, 1.5, 2.5}`. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
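The three closed forms are small enough to sketch in plain NumPy (a standalone illustration of the standard half-integer Matern formulas, not the pyrox implementation):

```python
import numpy as np

def matern(X1, X2, variance, lengthscale, nu):
    # Closed-form half-integer Matern; nu is a static Python float that
    # selects a code path, mirroring the trace-time specialization above.
    r = np.sqrt(((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1))
    s = r / lengthscale
    if nu == 0.5:
        f = np.exp(-s)
    elif nu == 1.5:
        f = (1.0 + np.sqrt(3.0) * s) * np.exp(-np.sqrt(3.0) * s)
    elif nu == 2.5:
        f = (1.0 + np.sqrt(5.0) * s + 5.0 * s**2 / 3.0) * np.exp(-np.sqrt(5.0) * s)
    else:
        raise ValueError("nu must be one of {0.5, 1.5, 2.5}")
    return variance * f

X = np.array([[0.0], [1.0]])
K = matern(X, X, variance=2.0, lengthscale=1.0, nu=1.5)
```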
periodic_kernel(X1, X2, variance, lengthscale, period)
¶
Periodic (MacKay) kernel.
.. math:: k(x, x') = \sigma^2 \exp\!\left( -\frac{2 \sin^2(\pi |x - x'| / p)}{\ell^2} \right)
For multi-dimensional inputs the argument uses the Euclidean distance, matching the common GPML convention.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar signal variance. | required |
| `lengthscale` | `Float[Array, '']` | Scalar lengthscale. | required |
| `period` | `Float[Array, '']` | Scalar period. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
polynomial_kernel(X1, X2, variance, bias, degree)
¶
Polynomial kernel.
.. math:: k(x, x') = \sigma^2 \bigl(x^\top x' + b\bigr)^d
:func:linear_kernel is the special case degree == 1 without the
outer power. degree is a static Python int so the kernel specializes
at trace time.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar multiplier. | required |
| `bias` | `Float[Array, '']` | Scalar additive bias inside the power. | required |
| `degree` | `int` | Positive integer polynomial degree. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
rational_quadratic_kernel(X1, X2, variance, lengthscale, alpha)
¶
Rational quadratic kernel.
.. math:: k(x, x') = \sigma^2 \left( 1 + \frac{|x - x'|^2}{2\alpha \ell^2} \right)^{-\alpha}
Scale mixture of RBF kernels: the limit alpha -> infty recovers the
RBF, small alpha yields heavier-tailed correlations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar signal variance. | required |
| `lengthscale` | `Float[Array, '']` | Scalar lengthscale. | required |
| `alpha` | `Float[Array, '']` | Scalar shape parameter; must be positive. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
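The scale-mixture claim is easy to check numerically (plain NumPy, unit variance): as `alpha` grows, the rational quadratic approaches the RBF pointwise:

```python
import numpy as np

r2 = np.linspace(0.0, 4.0, 50)     # squared distances
ell = 1.0

rbf = np.exp(-r2 / (2 * ell**2))
rq = lambda alpha: (1 + r2 / (2 * alpha * ell**2)) ** (-alpha)

# Larger alpha pushes the rational quadratic toward the RBF limit.
err_small = np.abs(rq(1.0) - rbf).max()
err_large = np.abs(rq(1e4) - rbf).max()
```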
rbf_kernel(X1, X2, variance, lengthscale)
¶
Radial basis function (squared exponential) kernel.
.. math:: k(x, x') = \sigma^2 \exp\!\left(-\frac{|x - x'|^2}{2\ell^2}\right)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar signal variance. | required |
| `lengthscale` | `Float[Array, '']` | Scalar lengthscale. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |
Source code in src/pyrox/gp/_src/kernels.py
white_kernel(X1, X2, variance)
¶
White-noise kernel.
.. math:: k(x, x') = \sigma^2 \,\delta(x, x')
Nonzero only where X1[i] exactly matches X2[j] across all feature
dimensions. When evaluated at X1 == X2 this yields sigma^2 * I.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X1` | `Float[Array, 'N1 D']` | First input batch of shape `(N1, D)`. | required |
| `X2` | `Float[Array, 'N2 D']` | Second input batch of shape `(N2, D)`. | required |
| `variance` | `Float[Array, '']` | Scalar noise variance. | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N1 N2']` | The `(N1, N2)` Gram matrix. |

Source code in src/pyrox/gp/_src/kernels.py
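The exact-match semantics can be sketched in plain NumPy (a standalone illustration, not the pyrox implementation):

```python
import numpy as np

def white_kernel(X1, X2, variance):
    # Nonzero only where rows of X1 and X2 match exactly in every dimension.
    match = np.all(X1[:, None, :] == X2[None, :, :], axis=-1)
    return variance * match.astype(X1.dtype)

X = np.array([[0.0], [1.0], [2.0]])
K_same = white_kernel(X, X, variance=0.3)         # 0.3 * I at X1 == X2
K_cross = white_kernel(X, X + 0.5, variance=0.3)  # all zeros: no exact matches
```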