GP API
Wave 2 ships the dense-GP foundation: kernel math functions, concrete
Parameterized kernel classes, abstract component protocols, and the
model-facing entry points (GPPrior, ConditionedGP, gp_factor,
gp_sample). Scalable matrix construction and solver strategies
(numerically stable assembly, implicit operators, batched matvec,
Cholesky / CG / BBMM / LSMR / SLQ) live in
gaussx.
Split with gaussx
pyrox owns the kernel function side — closed-form math primitives
readable in a dozen lines — plus the NumPyro-aware model shell
(GPPrior, gp_factor, gp_sample). gaussx owns every piece of
linear algebra: stable matrix construction, solver strategies, and
the underlying MultivariateNormal distribution. The model entry
points accept any gaussx.AbstractSolverStrategy (default
gaussx.DenseSolver()).
Model entry points
```python
import jax.numpy as jnp
import numpyro
from pyrox.gp import GPPrior, RBF, gp_factor, gp_sample


def regression_model(X, y):
    """Collapsed Gaussian-likelihood GP regression."""
    kernel = RBF()
    prior = GPPrior(kernel=kernel, X=X)
    # Adds log N(y | mu, K + (jitter + noise_var) I) to the trace as a factor site.
    gp_factor("obs", prior, y, noise_var=jnp.array(0.05))


def latent_model(X):
    """Latent-function GP for non-conjugate likelihoods."""
    kernel = RBF()
    prior = GPPrior(kernel=kernel, X=X)
    f = gp_sample("f", prior)
    # ... attach any likelihood to f here, e.g. Bernoulli or Poisson.
```
Swap the solver strategy at construction time:
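```python
from gaussx import CGSolver, ComposedSolver, DenseLogdet, DenseSolver
from pyrox.gp import GPPrior, RBF

# X: (N, D) training inputs. DenseSolver() is the default strategy.
prior = GPPrior(kernel=RBF(), X=X, solver=DenseSolver())
prior = GPPrior(kernel=RBF(), X=X, solver=CGSolver())

# Or compose: CG for solve, dense Cholesky for logdet.
prior = GPPrior(
    kernel=RBF(), X=X,
    solver=ComposedSolver(solve_strategy=CGSolver(), logdet_strategy=DenseLogdet()),
)
```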
pyrox.gp.GPPrior
Bases: Module
Finite-dimensional GP prior over a fixed training input set.
Holds a kernel, training inputs X, an optional mean function, an
optional solver strategy, and a small diagonal jitter for numerical
stability on otherwise-singular prior covariances.
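Attributes:
| Name | Type | Description |
|---|---|---|
| `kernel` | `Kernel` | Any `Kernel` instance. |
| `X` | `Float[Array, 'N D']` | Training inputs of shape `(N, D)`. |
| `mean_fn` | `Callable[[Float[Array, 'N D']], Float[Array, ' N']] \| None` | Callable returning the prior mean at the inputs; `None` gives a zero mean. |
| `solver` | `AbstractSolverStrategy \| None` | Any `gaussx.AbstractSolverStrategy`; the default is `gaussx.DenseSolver()`. |
| `jitter` | `float` | Diagonal regularization added to the prior covariance for numerical stability. Not a noise model: pass observation noise as `noise_var` to `condition` or `gp_factor`. |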
Source code in src/pyrox/gp/_models.py
condition(y, noise_var)
Condition on Gaussian-likelihood observations `y`.
Precomputes `alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X))` and caches it in the returned `ConditionedGP`. The `jitter` configured on this prior is added alongside `noise_var` in both the conditioned operator and the solve, so every downstream `predict` / `sample` call sees the same regularized covariance.
The operator construction and any subsequent hyperparameter capture share one `_kernel_context`. For Pattern B/C kernels with priors, the cached operator and the resolved hyperparameters on the returned `ConditionedGP` therefore come from the same draw. Downstream consumers (notably `pyrox.gp.PathwiseSampler`) reuse those values to stay consistent with the cached operator.
Source code in src/pyrox/gp/_models.py
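Conditioning amounts to one factorization and two solves. The NumPy sketch below illustrates only the math. It assumes a zero mean function and a unit-variance RBF. The helpers `rbf`, `cond_alpha`, and `predict_mean` are hypothetical, and a plain Cholesky factorization stands in for whatever gaussx solver strategy the prior is configured with.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def cond_alpha(X, y, noise_var, jitter=1e-6):
    """alpha = (K + (jitter + noise_var) I)^{-1} y, assuming a zero prior mean."""
    A = rbf(X, X) + (jitter + noise_var) * np.eye(X.shape[0])
    L = np.linalg.cholesky(A)
    # Two solves against the Cholesky factor; done once, reused for every test set.
    return np.linalg.solve(L.T, np.linalg.solve(L, y))


def predict_mean(X_star, X, alpha):
    """mu_* = K(X_*, X) alpha (zero prior mean)."""
    return rbf(X_star, X) @ alpha


rng = np.random.default_rng(0)
X = np.linspace(0.0, 5.0, 8)[:, None]
y = np.sin(X[:, 0]) + 0.1 * rng.standard_normal(8)
alpha = cond_alpha(X, y, noise_var=0.05)

# The same alpha serves two unrelated test sets.
mean_a = predict_mean(np.linspace(0.0, 5.0, 20)[:, None], X, alpha)
mean_b = predict_mean(np.array([[2.5], [7.0]]), X, alpha)
```

Only the `alpha` solve touches the $N \times N$ operator. Each additional test set costs one cross-kernel block and a matrix-vector product.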
condition_nongauss(likelihood, y, *, strategy)
Condition on a non-Gaussian likelihood via a site-based strategy.
Convenience method that forwards to `strategy.fit(self, likelihood, y)`. Pick any of the site-based strategies in `pyrox.gp._inference_nongauss`: `pyrox.gp.LaplaceInference`, `pyrox.gp.GaussNewtonInference`, `pyrox.gp.PosteriorLinearization`, `pyrox.gp.ExpectationPropagation`, or `pyrox.gp.QuasiNewtonInference`. Returns a `pyrox.gp.NonGaussConditionedGP` with the same `predict` / `predict_mean` / `predict_var` API as the Gaussian-likelihood `ConditionedGP`.
Example:
```python
from pyrox.gp import (
    BernoulliLikelihood,
    ExpectationPropagation,
    GPPrior,
    RBF,
)

# X: (N, D) training inputs, y: (N,) binary labels, X_star: (N_*, D) test inputs.
prior = GPPrior(kernel=RBF(), X=X)
cond = prior.condition_nongauss(
    BernoulliLikelihood(), y,
    strategy=ExpectationPropagation(),
)
mean, var = cond.predict(X_star)
```
Source code in src/pyrox/gp/_models.py
log_prob(f)
Marginal log-density of `f` under the GP prior.
Computes $\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I)$ using `gaussx.log_marginal_likelihood`, so any solver strategy configured on this prior applies.
Source code in src/pyrox/gp/_models.py
mean(X)
Evaluate the mean function at `X`; zero by default.
sample(key)
Draw $f \sim p(f) = \mathcal{N}(\mu(X), K + \text{jitter}\,I)$.
Wraps the prior in a `gaussx.MultivariateNormal` with the configured `solver`. This is the non-NumPyro analogue of `gp_sample`, useful for tests, diagnostics, and prior-sample initialization without registering a sample site.
Source code in src/pyrox/gp/_models.py
pyrox.gp.ConditionedGP
Bases: Module
GP conditioned on Gaussian-likelihood training observations.
Holds the precomputed training solve `alpha` (via `gaussx.PredictionCache`) and the noisy covariance operator, so predictions at multiple test sets reuse the training solve.
Source code in src/pyrox/gp/_models.py
predict(X_star)
Return `(mean, variance)` at $X_*$ as a tuple.
Both kernel evaluations share a single kernel context; see `predict_var`.
Source code in src/pyrox/gp/_models.py
predict_mean(X_star)
$\mu_* = \mu(X_*) + K_{*f}\,\alpha$.
Source code in src/pyrox/gp/_models.py
predict_var(X_star)
Diagonal predictive variance at $X_*$:
$$
\sigma^2_{*,i} = k(x_{*,i}, x_{*,i}) - K_{*f}[i,:]\,(K + \sigma^2 I)^{-1}\,K_{f*}[:,i].
$$
`K_cross` and `K_diag` are computed under one shared kernel context. Pattern B / C kernels with prior'd hyperparameters therefore register their NumPyro sites once and reuse them across both kernel calls (and the cached training solve).
Source code in src/pyrox/gp/_models.py
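The diagonal variance can be computed without forming the $N_* \times N_*$ predictive covariance: with $K + (\text{jitter} + \sigma^2) I = L L^\top$, the quadratic form is the squared column norm of $L^{-1} K_{f*}$. The NumPy sketch below makes three assumptions: a unit-variance RBF (so $k(x, x) = 1$), a zero mean, and hypothetical helper names. The operator includes `jitter` alongside the noise, as `condition` describes.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def predict_var(X_star, X, noise_var, jitter=1e-6):
    """Diagonal of k(X_*, X_*) - K_{*f} (K + (jitter + noise_var) I)^{-1} K_{f*}."""
    A = rbf(X, X) + (jitter + noise_var) * np.eye(X.shape[0])
    L = np.linalg.cholesky(A)
    V = np.linalg.solve(L, rbf(X, X_star))  # L^{-1} K_{f*}, shape (N, N_*)
    k_diag = np.ones(X_star.shape[0])       # unit-variance RBF: k(x, x) = 1
    # Column-wise squared norms give the quadratic form for each test point.
    return k_diag - np.sum(V**2, axis=0)


X = np.linspace(0.0, 5.0, 8)[:, None]
X_star = np.linspace(-1.0, 6.0, 15)[:, None]
var = predict_var(X_star, X, noise_var=0.05)
```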
sample(key, X_star, n_samples=1)
Sample from the diagonal predictive $\mathcal{N}(\text{mean}, \mathrm{diag}(\text{var}))$.
Samples are drawn independently per test point. Correlated joint samples from the full predictive covariance are not covered by the Wave 2 dense surface. For correlated samples, build the full predictive covariance explicitly and draw from `gaussx.MultivariateNormal`.
Source code in src/pyrox/gp/_models.py
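The difference between the diagonal draws that `sample` returns and correlated joint draws can be seen in a small NumPy sketch. The sketch assumes a zero mean, a unit-variance RBF, and hypothetical variable names. A plain Cholesky factor of the full predictive covariance, plus a small jitter, stands in for `gaussx.MultivariateNormal`.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


rng = np.random.default_rng(0)
X = np.linspace(0.0, 5.0, 8)[:, None]
y = np.sin(X[:, 0])
X_star = np.linspace(0.0, 5.0, 6)[:, None]
noise_var, jitter = 0.05, 1e-6

A = rbf(X, X) + (jitter + noise_var) * np.eye(len(X))
K_sf = rbf(X_star, X)
mean = K_sf @ np.linalg.solve(A, y)
cov = rbf(X_star, X_star) - K_sf @ np.linalg.solve(A, K_sf.T)
cov = 0.5 * (cov + cov.T)  # symmetrize against round-off

n_samples = 20000
# Independent per-point draws: N(mean, diag(var)), ignoring cross-correlations.
diag_draws = mean + np.sqrt(np.diag(cov)) * rng.standard_normal((n_samples, len(X_star)))
# Joint draws: factor the full covariance (small jitter keeps the Cholesky stable).
L = np.linalg.cholesky(cov + jitter * np.eye(len(X_star)))
joint_draws = mean + rng.standard_normal((n_samples, len(X_star))) @ L.T
```

Both sets of draws share marginals. Only `joint_draws` reproduces the off-diagonal covariance between test points.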
pyrox.gp.gp_factor(name, prior, y, noise_var)
Register the collapsed GP log marginal likelihood with NumPyro.
Adds
$\log p(y \mid X, \theta) = \log \mathcal{N}(y \mid \mu, K + (\text{jitter} + \sigma^2) I)$
to the NumPyro trace as `numpyro.factor(name, ...)`. The prior's `jitter` is included in addition to the observation noise variance, so the covariance matches what `GPPrior.condition` builds. Use this inside a NumPyro model when the likelihood is Gaussian and you want the latent function marginalized analytically.
Source code in src/pyrox/gp/_models.py
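The scalar that `gp_factor` registers can be sketched in NumPy with one Cholesky factorization. The sketch assumes a unit-variance RBF and uses a hypothetical helper name. Inside a NumPyro model, this value is what would be passed to `numpyro.factor(name, ...)`.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def gp_log_marginal(X, y, noise_var, jitter=1e-6, mean=None):
    """log N(y | mu, K + (jitter + noise_var) I) via one Cholesky factorization."""
    n = X.shape[0]
    mu = np.zeros(n) if mean is None else mean
    L = np.linalg.cholesky(rbf(X, X) + (jitter + noise_var) * np.eye(n))
    r = np.linalg.solve(L, y - mu)                # whitened residual L^{-1}(y - mu)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))     # log|K + (jitter + noise_var) I|
    return -0.5 * (r @ r + logdet + n * np.log(2.0 * np.pi))


X = np.linspace(0.0, 5.0, 10)[:, None]
y = np.cos(X[:, 0])
lml = gp_log_marginal(X, y, noise_var=0.05)
```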
pyrox.gp.gp_sample(name, prior, *, whitened=False, guide=None)
Sample a latent function `f` at the prior's training inputs.
Three mutually exclusive modes:
- `whitened=False, guide=None` (default): register a single `numpyro.sample(name, MVN(mu, K + jitter I))` site. The latent function is sampled directly from the prior.
- `whitened=True, guide=None`: register a unit-normal latent site `f"{name}_u"` with shape `(N,)` and return the deterministic value `f = mu(X) + L u`, where `L` is the Cholesky factor of `K + jitter I`. This reparameterization is the standard fix for mean-field SVI on GP-correlated latents (Murray & Adams, 2010). A NumPyro autoguide such as `numpyro.infer.autoguide.AutoNormal` then approximates the well-conditioned isotropic posterior over `u` instead of the ill-conditioned correlated posterior over `f`.
- `guide` provided: delegate to `guide.register(name, prior)`. Concrete variational guides (Wave 3) own their parameterization, so combining `whitened=True` with `guide` is rejected.
Use this inside a NumPyro model for non-conjugate likelihoods, where the latent function cannot be marginalized analytically.
Source code in src/pyrox/gp/_models.py
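The whitened mode is a deterministic change of variables. The NumPy sketch below draws many unit-normal vectors `u` and maps each through the Cholesky factor. It confirms that `f = mu + L u` has covariance `K + jitter I`. It assumes a zero mean and a unit-variance RBF, uses hypothetical names, and omits NumPyro site registration.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


rng = np.random.default_rng(0)
X = np.linspace(0.0, 5.0, 12)[:, None]
jitter = 1e-6
mu = np.zeros(12)  # zero mean function

K = rbf(X, X) + jitter * np.eye(12)
L = np.linalg.cholesky(K)

# Whitened: u ~ N(0, I) plays the role of the f"{name}_u" site.
u = rng.standard_normal((50000, 12))
f = mu + u @ L.T  # each row is f = mu + L u
```

The RBF Gram matrix here is badly conditioned, which is why a mean-field guide struggles on `f` directly. The coordinates of `u` are independent a priori.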
Concrete kernels
Each `Parameterized` kernel registers its hyperparameters with positivity constraints where appropriate. Attach priors with `set_prior` and autoguides with `autoguide`, and switch between model and guide behaviour with `set_mode("model")` or `set_mode("guide")`.
pyrox.gp.RBF
Bases: _ParameterizedKernel
Radial basis function (squared exponential) kernel.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Matern
Bases: _ParameterizedKernel
Matern kernel with nu in {0.5, 1.5, 2.5}.
nu is a static class attribute — it selects a code path in the
underlying math primitive and is not a trainable parameter.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Periodic
Bases: _ParameterizedKernel
Periodic (MacKay) kernel.
Source code in src/pyrox/gp/_kernels.py
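pyrox.gp.Linear
Bases: _ParameterizedKernel
Linear kernel $k(x, x') = \sigma^2 x^\top x' + b$, where $b$ is `bias`.
`bias` is constrained nonnegative because $K = \sigma^2 X X^\top + b\,\mathbf{1}\mathbf{1}^\top$ is guaranteed PSD only for $b \ge 0$. For example, $X = 0$ gives the eigenvalue $N b$, which is negative whenever $b < 0$.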
Source code in src/pyrox/gp/_kernels.py
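The failure mode behind the nonnegativity constraint takes a few lines to reproduce. The NumPy sketch below uses a hypothetical standalone `linear_kernel` helper, not the pyrox class. With all inputs at zero, the Gram matrix is $b\,\mathbf{1}\mathbf{1}^\top$, whose only nonzero eigenvalue is $N b$.

```python
import numpy as np


def linear_kernel(X1, X2, variance=1.0, bias=0.0):
    """k(x, x') = variance * x^T x' + bias (hypothetical helper)."""
    return variance * X1 @ X2.T + bias


N = 5
X = np.zeros((N, 2))  # degenerate inputs: every x = 0

eig_bad = np.linalg.eigvalsh(linear_kernel(X, X, bias=-0.1))  # -0.1 * 1 1^T
eig_ok = np.linalg.eigvalsh(linear_kernel(X, X, bias=0.1))    # +0.1 * 1 1^T
```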
pyrox.gp.RationalQuadratic
Bases: _ParameterizedKernel
Rational quadratic kernel.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Polynomial
Bases: _ParameterizedKernel
Polynomial kernel sigma^2 (x^T x' + bias)^degree.
`degree` is a static class field: it selects an integer power and is not an optimization target. `bias` is constrained nonnegative because the `degree=1` case reduces to `Linear` and has the same PSD-requires-$b \ge 0$ failure mode.
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Cosine
Bases: _ParameterizedKernel
Cosine kernel sigma^2 cos(2 pi ||x - x'|| / period).
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.White
Bases: _ParameterizedKernel
White-noise kernel sigma^2 delta(x, x').
Source code in src/pyrox/gp/_kernels.py
pyrox.gp.Constant
Bases: _ParameterizedKernel
Constant kernel k(x, x') = sigma^2.
Source code in src/pyrox/gp/_kernels.py
Sparse-GP inducing features (#49)
Inter-domain inducing-feature families used to build scalable sparse GPs where the inducing-prior covariance `K_uu` becomes diagonal. Pass any of these to `SparseGPPrior` via the `inducing=` keyword in place of a raw point matrix `Z`.
```python
from pyrox.gp import RBF, FourierInducingFeatures, SparseGPPrior

kernel = RBF(init_lengthscale=0.3, init_variance=1.0)
features = FourierInducingFeatures.init(in_features=1, num_basis_per_dim=64, L=5.0)
prior = SparseGPPrior(kernel=kernel, inducing=features)  # K_uu is diagonal
```
pyrox.gp.InducingFeatures
Bases: Protocol
Protocol for inter-domain inducing features.
Implementations expose the inducing-prior covariance `K_uu` and the cross-covariance `k_ux(X)` between data points and inducing features. Diagonal-friendly concretions return a `lineax.DiagonalLinearOperator`, so the downstream solve dispatches to elementwise division.
The input shape is family-dependent: `k_ux` takes a batch of data points `X` in whatever representation the family consumes:
- `FourierInducingFeatures`: coordinates `(N, D)`.
- `SphericalHarmonicInducingFeatures`: unit vectors `(N, 3)`.
- `LaplacianInducingFeatures`: integer node indices `(N,)`.
Each implementation validates its own expected shape and dtype.
Source code in src/pyrox/gp/_inducing.py
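The protocol's shape can be sketched with `typing.Protocol`. This is a hypothetical NumPy stand-in, not the pyrox definition: it returns the diagonal of `K_uu` as a plain vector instead of a `lineax.DiagonalLinearOperator`, and it lays out `k_ux` as `(M, N)` by assumption. The toy family shows why a diagonal `K_uu` turns the solve into elementwise division.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class InducingFeaturesLike(Protocol):
    """Hypothetical stand-in for pyrox.gp.InducingFeatures (NumPy, not lineax)."""

    def K_uu(self, kernel, *, jitter: float = 1e-6) -> np.ndarray: ...

    def k_ux(self, x: np.ndarray, kernel) -> np.ndarray: ...


class ToyDiagonalFeatures:
    """Toy family whose K_uu is diagonal, returned as its diagonal vector."""

    def __init__(self, weights: np.ndarray):
        self.weights = weights  # shape (M,)

    def K_uu(self, kernel, *, jitter=1e-6):
        return self.weights + jitter  # diagonal entries only

    def k_ux(self, x, kernel):
        if x.ndim != 2:  # each family validates its own input shape
            raise ValueError(f"expected (N, D) coordinates, got shape {x.shape}")
        freqs = np.arange(len(self.weights))
        # (M, N) cross-covariance built from an arbitrary cosine basis.
        return self.weights[:, None] * np.cos(np.outer(freqs, x[:, 0]))


feats = ToyDiagonalFeatures(np.array([1.0, 0.5, 0.25]))
d = feats.K_uu(kernel=None)
b = feats.k_ux(np.linspace(0.0, 1.0, 4)[:, None], kernel=None)
solved = b / d[:, None]  # K_uu^{-1} b as elementwise division: O(M) per column
```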
pyrox.gp.FourierInducingFeatures
Bases: Module
VFF: Variational Fourier inducing features on $[-L, L]^D$.
For a stationary kernel with spectral density $S(\cdot)$, the basis $\{\phi_j\}$ of Laplacian eigenfunctions on the box gives
$$
K_{uu} = \mathrm{diag}\big(S(\sqrt{\lambda_j})\big), \qquad K_{ux}(x)_j = S(\sqrt{\lambda_j})\,\phi_j(x).
$$
With this convention $\big(K_{uu}^{-1} K_{ux}(x)\big)_j = \phi_j(x)$, so the SVGP predictive mean reduces to a basis evaluation. $K_{uu}$ is returned as a `lineax.DiagonalLinearOperator` to preserve the $O(M)$ solve dispatch end-to-end.
Attributes:
| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension $D$. |
| `num_basis_per_dim` | `tuple[int, ...]` | Per-axis number of 1D eigenfunctions; the total count is the product over axes. |
| `L` | `tuple[float, ...]` | Per-axis box half-width. |
Source code in src/pyrox/gp/_inducing.py
K_uu(kernel, *, jitter=1e-06)
Diagonal $K_{uu}$: entries $S(\sqrt{\lambda_j})$ plus `jitter`.
Source code in src/pyrox/gp/_inducing.py
k_ux(x, kernel)
Cross-covariance entries $S(\sqrt{\lambda_j})\,\phi_j(x)$.
Source code in src/pyrox/gp/_inducing.py
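The VFF construction can be sketched in 1D with NumPy. The sketch relies on three assumptions, since the boundary condition and normalization pyrox uses are not stated here:
- Dirichlet eigenfunctions $\phi_j(x) = \sin(\sqrt{\lambda_j}(x + L))/\sqrt{L}$.
- Eigenvalues given by $\sqrt{\lambda_j} = \pi j / (2L)$.
- The unit-variance RBF spectral density $S(\omega) = \sqrt{2\pi}\,\ell\,e^{-\ell^2\omega^2/2}$.

The parameters mirror the example above (lengthscale 0.3, 64 basis functions, $L = 5$). The check shows that $K_{xu} K_{uu}^{-1} K_{ux}$ recovers the exact RBF Gram matrix for points well inside the box.

```python
import numpy as np


def rbf_spectral_density(omega, lengthscale=1.0):
    """Unit-variance 1-D RBF spectral density S(omega)."""
    return np.sqrt(2.0 * np.pi) * lengthscale * np.exp(-0.5 * (lengthscale * omega) ** 2)


def box_eigs(x, num_basis, L):
    """Dirichlet Laplacian eigenpairs on [-L, L] (assumed convention)."""
    j = np.arange(1, num_basis + 1)
    sqrt_lam = np.pi * j / (2.0 * L)                      # sqrt(lambda_j)
    phi = np.sin(np.outer(x + L, sqrt_lam)) / np.sqrt(L)  # (N, M), orthonormal on the box
    return sqrt_lam, phi


lengthscale, M, L = 0.3, 64, 5.0
x = np.linspace(-1.0, 1.0, 25)
sqrt_lam, phi = box_eigs(x, M, L)
s = rbf_spectral_density(sqrt_lam, lengthscale=lengthscale)

K_uu_diag = s              # K_uu = diag(S(sqrt(lambda_j))), jitter omitted
K_ux = s[None, :] * phi    # K_ux(x)_j = S(sqrt(lambda_j)) phi_j(x), stored as (N, M)

# K_xu K_uu^{-1} K_ux: with a diagonal K_uu this is a weighted basis sum.
Q = (K_ux / K_uu_diag) @ K_ux.T
K_exact = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / lengthscale**2)
```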
pyrox.gp.SphericalHarmonicInducingFeatures
Bases: Module
VISH: inducing spherical harmonics on $S^2$ (Dutordoir et al., 2020).
For any zonal kernel $k(x, x') = \kappa(x \cdot x')$ on the unit 2-sphere, the Funk-Hecke theorem gives a diagonal $K_{uu}$ whose entries are the kernel's Funk-Hecke coefficients $a_l$. The cross-covariance is $a_l\,Y_{lm}(x)$.
Funk-Hecke coefficients are computed by Gauss-Legendre quadrature, so arbitrary kernels are supported and no closed form is required. For kernels that do have a closed-form Funk-Hecke series (e.g. RBF on $S^2$ via Bessel functions), the numerical and analytic coefficients should agree to within the quadrature tolerance.
Attributes:
| Name | Type | Description |
|---|---|---|
| `l_max` | `int` | Maximum harmonic degree, inclusive. |
| `num_quadrature` | `int` | Gauss-Legendre nodes for the Funk-Hecke integral. |
Source code in src/pyrox/gp/_inducing.py
K_uu(kernel, *, jitter=1e-06)
Diagonal $K_{uu}$: Funk-Hecke coefficients per harmonic.
Source code in src/pyrox/gp/_inducing.py
k_ux(unit_xyz, kernel)
Cross-covariance: $a_l\,Y_{lm}(x)$.
Source code in src/pyrox/gp/_inducing.py
pyrox.gp.LaplacianInducingFeatures
Bases: Module
Inducing features from low-frequency graph Laplacian eigenvectors.
For a graph with normalized Laplacian $L$, take the smallest `num_basis` eigenpairs $(\mu_j, v_j)$. Treating the kernel as a function on the graph, specifically by applying the kernel spectrum $g(\mu)$ to the Laplacian eigenvalues, gives a diagonal $K_{uu}$.
This implementation supports the heat-kernel family $g(\mu) = \exp(-\mu / (2 \ell^2))$ (matching `pyrox.gp.RBF` in spectrum) by reusing `pyrox._basis.spectral_density` with the eigenvalues as input.
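Attributes:
| Name | Type | Description |
|---|---|---|
| `eigvals` | `Float[Array, ' M']` | The smallest `num_basis` Laplacian eigenvalues $\mu_j$. |
| `eigvecs` | `Float[Array, 'V M']` | Matching eigenvectors $v_j$ as columns, one row per node. |
| `num_quadrature` | `Float[Array, 'V M']` | Unused (kept for protocol uniformity). |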
Note: `X` is a vector of node indices (integer-valued), not coordinates. The returned cross-covariance gathers the relevant rows of `eigvecs`.
Source code in src/pyrox/gp/_inducing.py
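The construction can be sketched on a small graph in NumPy. Everything here is hypothetical: the 6-node cycle graph, the variable names, and the `(N, M)` layout of the cross-covariance. The spectrum is written as a generic heat kernel $\exp(-t\mu)$ with a hypothetical diffusion time `t`, rather than pyrox's lengthscale parameterization. Inputs are node indices, and the cross-covariance gathers rows of the eigenvector matrix.

```python
import numpy as np

# Hypothetical example graph: a 6-node cycle.
V = 6
A = np.zeros((V, V))
for i in range(V):
    A[i, (i + 1) % V] = A[(i + 1) % V, i] = 1.0

# Symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}.
d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
L_norm = np.eye(V) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

mu_all, v_all = np.linalg.eigh(L_norm)       # eigenvalues in ascending order
M = 3
eigvals, eigvecs = mu_all[:M], v_all[:, :M]  # low-frequency basis, shape (V, M)

t = 0.5                                      # hypothetical diffusion time
g = np.exp(-t * eigvals)                     # heat-kernel spectrum g(mu_j)
K_uu_diag = g                                # diagonal K_uu

node_idx = np.array([0, 2, 5])               # inputs are node indices
K_xu = eigvecs[node_idx] * g                 # gather rows, scale by g: shape (3, M)

# Low-rank kernel on the selected nodes: K_xu K_uu^{-1} K_ux.
Q = (K_xu / K_uu_diag) @ K_xu.T
```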
pyrox.gp.DecoupledInducingFeatures
Bases: Module
Decoupled mean / covariance inducing-feature bases (Cheng & Boots, 2017).
Two independent inducing-feature sets:
- `mean_features`: a large $\alpha$-basis used by the SVGP posterior mean. This basis is cheap, because predictive-mean cost is linear in the mean-basis size.
- `cov_features`: a small $\beta$-basis used for the posterior covariance. This is the true bottleneck, so keep it small.
The two bases need not share the same family. A common pattern is a large Fourier basis for the mean and a small spherical-harmonic basis for the covariance, or vice versa. The downstream guide consumes both via the standard SVGP machinery.
Attributes:
| Name | Type | Description |
|---|---|---|
| `mean_features` | `InducingFeatures` | Inducing-feature object backing the predictive mean. |
| `cov_features` | `InducingFeatures` | Inducing-feature object backing the predictive covariance. |
Note: `DecoupledInducingFeatures` itself does not implement `InducingFeatures`, since no single `K_uu` makes sense for two bases. Consumers should access `.mean_features` and `.cov_features` directly.
Source code in src/pyrox/gp/_inducing.py
pyrox.gp.funk_hecke_coefficients(kernel, l_max, *, num_quadrature=256)
Funk-Hecke coefficients of a zonal kernel on $S^2$.
For a kernel of the form $k(x, x') = \kappa(x \cdot x')$ on the unit 2-sphere, the Funk-Hecke theorem gives
$$
a_l = 2\pi \int_{-1}^{1} \kappa(t)\,P_l(t)\,dt.
$$
Returns `(l_max + 1,)` coefficients indexed by $l$. Any Euclidean kernel is treated as zonal on the sphere via $\kappa(t) = k_{\mathrm{euc}}(\hat{n}_0, \hat{n}_t)$ for unit vectors $\hat n_0, \hat n_t$ at angular separation $\arccos(t)$.
Source code in src/pyrox/gp/_inducing.py
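The quadrature can be sketched with NumPy's Gauss-Legendre nodes. The helper names are hypothetical, and this is not the pyrox source. For unit vectors, $\|x - x'\|^2 = 2 - 2t$, which turns a Euclidean RBF into $\kappa(t)$. Two sanity checks are easy to verify by hand. The linear zonal kernel $\kappa(t) = t$ has a single nonzero coefficient, $a_1 = 4\pi/3$. The unit RBF has $a_0 = 2\pi(1 - e^{-2})$.

```python
import numpy as np
from numpy.polynomial import legendre as npleg


def funk_hecke(kappa, l_max, num_quadrature=256):
    """a_l = 2 pi * int_{-1}^{1} kappa(t) P_l(t) dt by Gauss-Legendre quadrature."""
    t, w = npleg.leggauss(num_quadrature)
    kt = kappa(t)
    coeffs = []
    for l in range(l_max + 1):
        c = np.zeros(l + 1)
        c[l] = 1.0                   # coefficient vector selecting P_l
        P_l = npleg.legval(t, c)
        coeffs.append(2.0 * np.pi * np.sum(w * kt * P_l))
    return np.array(coeffs)


def rbf_on_sphere(t, variance=1.0, lengthscale=1.0):
    """kappa(t) for a Euclidean RBF restricted to unit vectors."""
    return variance * np.exp(-(2.0 - 2.0 * t) / (2.0 * lengthscale**2))


a_lin = funk_hecke(lambda t: t, l_max=3)  # linear zonal kernel k = x . x'
a_rbf = funk_hecke(rbf_on_sphere, l_max=6)
```

Strictly positive coefficients are what make the diagonal $K_{uu}$ a valid covariance for a positive-definite zonal kernel.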
Sparse GP prior
pyrox.gp.SparseGPPrior
Bases: Module
GP prior parameterized over inducing inputs `Z`.
Represents the zero-mean prior over inducing values $u = f(Z)$ used by sparse variational guides:
$$
p(u) = \mathcal{N}(0,\, K_{ZZ} + \mathrm{jitter}\,I).
$$
The standard SVGP convention is to subtract any global mean function before forming the prior over $u$ and to add it back at predict time, so the inducing-prior mean is fixed to zero. This is what the guides' KL terms assume; see `FullRankGuide.kl_divergence`, `MeanFieldGuide.kl_divergence`, and `WhitenedGuide.kl_divergence`.
The `mean_fn` attribute on this class is exposed as a convenience for callers that want to add $\mu(X_*)$ back onto the predictive mean returned by `Guide.predict`. It is not incorporated in `inducing_operator` or in the guides' KL.
Pair with a sparse variational guide that owns $q(u) = \mathcal{N}(m, S)$ to obtain the standard SVGP predictive
$$
\mu_*(x) = K_{xZ} K_{ZZ}^{-1} m, \qquad
\sigma_*^2(x) = k(x, x) - K_{xZ} K_{ZZ}^{-1} K_{Zx} + K_{xZ} K_{ZZ}^{-1} S K_{ZZ}^{-1} K_{Zx}.
$$
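Attributes:
| Name | Type | Description |
|---|---|---|
| `kernel` | `Kernel` | Any `Kernel` instance. |
| `Z` | `Float[Array, 'M D'] \| None` | Inducing inputs of shape `(M, D)`. |
| `mean_fn` | `Callable[[Float[Array, 'N D']], Float[Array, ' N']] \| None` | Callable returning $\mu(X)$, for adding back onto the predictive mean; `None` gives a zero mean. |
| `solver` | `AbstractSolverStrategy \| None` | Any `gaussx.AbstractSolverStrategy`; controls the solve / logdet work on `K_zz_op`. |
| `jitter` | `float` | Diagonal regularization added to $K_{ZZ}$. |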
Source code in src/pyrox/gp/_sparse.py
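The SVGP predictive can be sketched in NumPy, together with one sanity check. If the inducing inputs coincide with the training inputs ($Z = X$) and $q(u)$ is set to the exact Gaussian posterior at those points, the SVGP mean and variance must reproduce the exact GP predictive. The sketch assumes a zero prior mean and a unit-variance RBF. The helper `svgp_predict` is hypothetical, and the exact-posterior choice of `m, S` stands in for a trained guide.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def svgp_predict(X_star, Z, m, S, jitter=1e-6):
    """SVGP predictive mean / diagonal variance for q(u) = N(m, S), zero prior mean."""
    Kzz = rbf(Z, Z) + jitter * np.eye(len(Z))
    Kxz = rbf(X_star, Z)
    A = np.linalg.solve(Kzz, Kxz.T).T   # K_xZ K_ZZ^{-1}, shape (N_*, M)
    mean = A @ m
    k_diag = np.ones(len(X_star))       # unit-variance RBF
    var = k_diag - np.sum(A * Kxz, axis=1) + np.sum((A @ S) * A, axis=1)
    return mean, var


X = np.linspace(0.0, 5.0, 6)[:, None]
y = np.sin(X[:, 0])
noise_var, jitter = 0.05, 1e-6

# Exact posterior over f(X) under the (jittered) prior, used as q(u) with Z = X.
Kj = rbf(X, X) + jitter * np.eye(6)
C = Kj + noise_var * np.eye(6)
m = Kj @ np.linalg.solve(C, y)
S = Kj - Kj @ np.linalg.solve(C, Kj)

X_star = np.linspace(-0.5, 5.5, 9)[:, None]
mean, var = svgp_predict(X_star, X, m, S, jitter)
```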
num_inducing (property)
Number of inducing inputs / features $M$.
cross_covariance(X)
$K_{XZ}$: covariance between `X` and the inducing inputs/features.
The expected shape of `X` depends on the inducing family:
- Point-inducing (`Z`) or `FourierInducingFeatures`: coordinates `(N, D)`.
- `SphericalHarmonicInducingFeatures`: unit vectors `(N, 3)`.
- `LaplacianInducingFeatures`: integer node indices `(N,)`.
See `predictive_blocks` for the shared-context batch helper to use when assembling several SVGP blocks together.
Source code in src/pyrox/gp/_sparse.py
inducing_operator()
Return $K_{ZZ} + \text{jitter}\,I$ as a lineax operator.
For point-inducing priors, returns a dense `lineax.MatrixLinearOperator` with `positive_semidefinite_tag`. For inducing-feature priors, delegates to `InducingFeatures.K_uu`, typically a `lineax.DiagonalLinearOperator`, so the downstream `gaussx.solve` dispatches in $O(M)$ instead of $O(M^3)$.
This is a single kernel call, so it is safe standalone for kernels with priors. When building several SVGP blocks together, prefer `predictive_blocks`. It scopes one shared kernel context across `K_zz`, `K_xz`, and `K_xx_diag`, so Pattern B / C kernels register their NumPyro hyperparameter sites once instead of resampling per call.
Source code in src/pyrox/gp/_sparse.py
kernel_diag(X)
Prior diagonal $\mathrm{diag}\,K(X, X)$, i.e. the prior variance at each $x$.
See `predictive_blocks` for the shared-context batch helper to use when assembling several SVGP blocks together.
Source code in src/pyrox/gp/_sparse.py
log_prob(u)
Log-density under $p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I)$.
Delegates to `gaussx.gaussian_log_prob` with the configured `solver`, so the user-supplied solver controls the solve / logdet work on `K_zz_op`. Useful for scoring inducing values against the SVGP prior in non-NumPyro contexts (e.g. tests, diagnostics).
Source code in src/pyrox/gp/_sparse.py
mean(X)
Evaluate the mean function at `X`; zero by default.
predictive_blocks(X)
Return `(K_zz_op, K_xz, K_xx_diag)` under one shared kernel context.
For Pattern B / C kernels with prior'd hyperparameters, the three kernel evaluations needed for an SVGP predictive must share a single `pyrox.PyroxModule` context. The underlying `pyrox_sample` sites then register once and yield consistent hyperparameter draws across $K_{ZZ}$, $K_{XZ}$, and the diagonal $\mathrm{diag}\,K(X, X)$. Without this scoping, three separate calls would draw three independent hyperparameter samples (under `seed`) or raise NumPyro duplicate-site errors (under tracing). Either way, the SVGP math would be invalid.
For pure `equinox.Module` kernels (no `_get_context`), this is equivalent to calling `inducing_operator`, `cross_covariance`, and `kernel_diag` independently.
For inducing-feature priors, `K_zz_op` is a `lineax.DiagonalLinearOperator`. The jitter is folded into the diagonal vector rather than added via `+ jnp.eye`, so the downstream solve stays $O(M)$.
Source code in src/pyrox/gp/_sparse.py
sample(key)
Draw $u \sim p(u)$ from the inducing prior.
Wraps the inducing operator in a `gaussx.MultivariateNormal` with the configured `solver`. `MultivariateNormal.sample` factors the covariance via `gaussx.cholesky` and reparameterizes; the returned draw has shape `(M,)`.
Note: the SVGP variational workflow samples $u$ from the guide $q(u)$, not the prior. This method exists so the prior surface is symmetric with the guide surface. It also lets users score and draw inducing values against the prior directly (e.g. for tests or for prior-sample initialization).
Source code in src/pyrox/gp/_sparse.py
Pathwise posterior samplers (#39)
Callable posterior function draws via Matheron's rule. Each sampled path is a `PathwiseFunction` that evaluates in O(N_* · F · D + N_* · N_corr) per path. The N_* · F · D term covers the RFF prior draw. The N_* · N_corr term covers the kernel correction against the N_corr training (exact) or inducing (decoupled) points. The same draw can therefore be reused at arbitrary test sets without rebuilding a test-set covariance. This makes pathwise sampling the standard enabler for Thompson sampling, Bayesian optimization, and posterior visualization.
```python
import jax
import jax.numpy as jnp
from pyrox.gp import (
    RBF,
    GPPrior,
    PathwiseSampler,
    DecoupledPathwiseSampler,
    FullRankGuide,
    SparseGPPrior,
)

# X, y: training data; X_star: test inputs; Z: (M, D) inducing inputs.

# Exact GP:
posterior = GPPrior(kernel=RBF(), X=X).condition(y, jnp.array(0.05))
paths = PathwiseSampler(posterior, n_features=512).sample_paths(
    jax.random.PRNGKey(0), n_paths=32
)
draws = paths(X_star)  # (32, N_star)

# Sparse / decoupled:
key = jax.random.PRNGKey(1)
sparse = SparseGPPrior(kernel=RBF(), Z=Z)
guide = FullRankGuide.init(Z.shape[0])
paths = DecoupledPathwiseSampler(sparse, guide).sample_paths(key, n_paths=16)
samples = paths(X_star)
```
Currently supports RBF and Matern kernels. Point-inducing
SparseGPPrior only — inducing-feature priors raise at construction.
pyrox.gp.PathwiseSampler
Bases: Module
Exact-GP pathwise posterior sampler using Matheron's rule.
Given a `ConditionedGP`, the sampler:
- draws a zero-mean RFF prior path `f_tilde` and an iid noise draw `eps_tilde` at the training inputs;
- forms the residual `y - mu(X) - f_tilde(X) - eps_tilde`;
- solves it against the cached noisy operator $K + (\text{jitter} + \sigma^2) I$ and stores the result as posterior correction weights.

The returned `PathwiseFunction` is callable at any $X_*$ in $\mathcal{O}(N_* \cdot F \cdot D + N_* \cdot N)$ per path, where $N$ is the number of training (correction) points. The RFF prior term recomputes features over $X_*$ on each call ($N_* \cdot F \cdot D$), and the correction term forms a fresh $K(X_*, X)$ block ($N_* \cdot N$).
Example:
```python
posterior = GPPrior(kernel=RBF(), X=X).condition(y, jnp.array(0.05))
sampler = PathwiseSampler(posterior, n_features=512)
paths = sampler.sample_paths(key, n_paths=32)
draws = paths(X_star)
```
Example (a single path, as used for Thompson sampling):
```python
sampler = PathwiseSampler(posterior, n_features=1024)
thompson = sampler.sample_paths(key, n_paths=1)
values = thompson(X_candidates)
```
Source code in src/pyrox/gp/_pathwise.py
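Matheron's rule for the exact case can be sketched in NumPy. The sketch assumes a zero mean and a unit-variance RBF, and uses hypothetical helper names. One set of RFF frequencies and phases is shared by all paths, each path gets its own feature weights, and the correction is solved against $K + (\text{jitter} + \sigma^2) I$. The path average matches the exact posterior mean up to Monte Carlo error. The path variance matches the exact posterior variance up to RFF and Monte Carlo error.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def sample_paths(X, y, noise_var, n_paths, n_features, rng, lengthscale=1.0, jitter=1e-6):
    """Matheron's rule with a random-Fourier-feature RBF prior (zero mean, unit variance)."""
    D = X.shape[1]
    omega = rng.standard_normal((n_features, D)) / lengthscale  # RBF spectral samples
    phase = rng.uniform(0.0, 2.0 * np.pi, n_features)
    w = rng.standard_normal((n_paths, n_features))              # one weight vector per path

    def prior(Xq):
        feats = np.sqrt(2.0 / n_features) * np.cos(Xq @ omega.T + phase)  # (N, F)
        return w @ feats.T                                                 # (P, N)

    # Residual y - f_tilde(X) - eps_tilde, solved against K + (jitter + noise_var) I.
    eps = np.sqrt(noise_var) * rng.standard_normal((n_paths, X.shape[0]))
    A = rbf(X, X, lengthscale) + (jitter + noise_var) * np.eye(X.shape[0])
    weights = np.linalg.solve(A, (y[None, :] - prior(X) - eps).T).T     # (P, N)

    def path(X_star):
        # f_post(x_*) = f_tilde(x_*) + K(x_*, X) alpha, for every path at once.
        return prior(X_star) + weights @ rbf(X, X_star, lengthscale)     # (P, N_*)

    return path


rng = np.random.default_rng(0)
X = np.linspace(0.0, 5.0, 10)[:, None]
y = np.sin(X[:, 0])
X_star = np.linspace(0.0, 5.0, 7)[:, None]
paths = sample_paths(X, y, noise_var=0.05, n_paths=1000, n_features=2000, rng=rng)
draws = paths(X_star)  # (1000, 7); the same paths can be re-evaluated anywhere
```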
__call__(key, X_star, n_paths=1)
Convenience wrapper for `sample_paths(key, n_paths)(X_star)`.
Source code in src/pyrox/gp/_pathwise.py
sample_paths(key, n_paths=1)
Sample callable posterior paths.
`key` is split into three subkeys: one for the RFF basis, one for the iid training-noise draw, and one reserved for future extensions.
Source code in src/pyrox/gp/_pathwise.py
pyrox.gp.DecoupledPathwiseSampler
Bases: Module
Sparse/decoupled pathwise sampler with an RFF prior and an inducing-point update.
The prior draw uses random features while the correction is represented in the inducing-point basis, so each sampled path stays callable at arbitrary inputs after a one-time inducing solve.
Supported for point-inducing `SparseGPPrior` (`Z=...`). Inducing-feature priors (`inducing=...`) are rejected at construction with a clear error.
Handles `WhitenedGuide` automatically: whitened guide draws $v \sim q(v)$ are unwhitened to inducing values $u = L_{ZZ} v$ via `gaussx.unwhiten` before forming the inducing-space residual.
Example:
```python
prior = SparseGPPrior(kernel=RBF(), Z=Z)
guide = FullRankGuide.init(Z.shape[0])
sampler = DecoupledPathwiseSampler(prior, guide, n_features=512)
paths = sampler.sample_paths(key, n_paths=16)
draws = paths(X_star)
```
Source code in src/pyrox/gp/_pathwise.py
__call__(key, X_star, n_paths=1)
Convenience wrapper for `sample_paths(key, n_paths)(X_star)`.
Source code in src/pyrox/gp/_pathwise.py
sample_paths(key, n_paths=1)
Sample callable sparse posterior paths.
`key` is split into three subkeys: one for the RFF basis, one for `n_paths` independent guide draws, and one for the jitter augmentation of the prior inducing draw. The RFF basis draw and the $K_{zz}$ assembly share a single `_kernel_context`, so kernels with hyperparameter priors (Pattern B / C) sample `(variance, lengthscale)` once.
The Matheron correction needs $\mathrm{Cov}(\tilde u) = K_{zz} + \text{jitter}\,I$ so that it matches the operator the correction is solved against. The bare RFF draw at `Z` produces only the $K_{zz}$ part, so an iid Gaussian with variance `jitter` per inducing index is added to close the gap. Without it, paths are under-dispersed when `jitter` is bumped up for stability.
Source code in src/pyrox/gp/_pathwise.py
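The sparse update, including the jitter augmentation, can be sketched in NumPy. This is not pyrox's implementation. The sketch assumes a zero mean and a unit-variance RBF, and it uses hypothetical values for the guide parameters `m, S` and hypothetical helper names. Guide draws $u \sim q(u)$ are corrected against an RFF draw at `Z` that has been augmented with iid variance `jitter`, so its covariance matches $K_{zz} + \text{jitter}\,I$. The path mean should then match $K_{xZ} K_{ZZ}^{-1} m$ up to Monte Carlo error, and the path variance should match the SVGP formula up to RFF error.

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Unit-variance squared-exponential kernel on (N, D) x (M, D) inputs."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def decoupled_paths(Z, m, S, n_paths, n_features, rng, lengthscale=1.0, jitter=1e-6):
    """Sparse Matheron update: RFF prior plus an inducing-space correction."""
    M, D = Z.shape
    omega = rng.standard_normal((n_features, D)) / lengthscale
    phase = rng.uniform(0.0, 2.0 * np.pi, n_features)
    w = rng.standard_normal((n_paths, n_features))

    def prior(Xq):
        return w @ (np.sqrt(2.0 / n_features) * np.cos(Xq @ omega.T + phase)).T

    u = m + rng.standard_normal((n_paths, M)) @ np.linalg.cholesky(S).T  # u ~ q(u)
    # The RFF draw at Z has covariance ~K_zz; the iid term supplies the jitter part.
    u_tilde = prior(Z) + np.sqrt(jitter) * rng.standard_normal((n_paths, M))
    K_zz = rbf(Z, Z, lengthscale) + jitter * np.eye(M)
    weights = np.linalg.solve(K_zz, (u - u_tilde).T).T                 # (P, M)

    return lambda X_star: prior(X_star) + weights @ rbf(Z, X_star, lengthscale)


rng = np.random.default_rng(1)
Z = np.linspace(0.0, 5.0, 5)[:, None]
m = np.sin(Z[:, 0])   # hypothetical guide mean
S = 0.01 * np.eye(5)  # hypothetical guide covariance
X_star = np.linspace(0.0, 5.0, 7)[:, None]
draws = decoupled_paths(Z, m, S, n_paths=1000, n_features=4000, rng=rng)(X_star)
```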
pyrox.gp.PathwiseFunction
Bases: Module
Callable posterior function draw(s) produced by a pathwise sampler.
Carries the random-feature prior basis (`omega`, `phase`, `feature_weights`) and the posterior correction weights evaluated against either the training inputs (exact) or the inducing inputs (sparse). Calling the instance on test points `X_star` evaluates
$$
f_{\text{post}}(x_*) = \tilde{f}(x_*) + K(x_*,\, X_{\mathrm{corr}})\,\alpha + \mu(x_*),
$$
where $\tilde f$ is the stored RFF prior draw, $\alpha$ holds the stored correction weights, and $X_{\mathrm{corr}}$ is either the training set (exact) or the inducing set (sparse).
The kernel enters only as a frozen `(X1, X2) -> K` callable with the sample-time variance and lengthscale baked in. Repeated evaluations therefore stay consistent with the original RFF draw, even for Pattern B/C kernels that register hyperparameter priors.
Example (exact GP):
```python
prior = GPPrior(kernel=RBF(), X=X)
posterior = prior.condition(y, noise_var=jnp.array(0.05))
sampler = PathwiseSampler(posterior, n_features=512)
paths = sampler.sample_paths(key, n_paths=8)
samples = paths(X_star)
```
Example (sparse GP):
```python
sparse_prior = SparseGPPrior(kernel=RBF(), Z=Z)
guide = FullRankGuide.init(Z.shape[0])
paths = DecoupledPathwiseSampler(sparse_prior, guide).sample_paths(key)
thompson_values = paths(X_candidates)
```
Source code in src/pyrox/gp/_pathwise.py
__call__(X_star)
Evaluate the sampled function(s) at arbitrary inputs `X_star`.
Source code in src/pyrox/gp/_pathwise.py
State-space (SDE) kernels
Stationary 1-D kernels expressed as linear time-invariant SDEs. Once in state-space form, GP inference on a 1-D grid reduces to Kalman filtering in $O(N d^3)$ instead of $O(N^3)$ Cholesky. The protocol exposes `sde_params() -> (F, L, H, Q_c, P_inf)` and `discretise(dt) -> (A_k, Q_k)` for downstream Kalman / RTS use.
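The $O(N d^3)$ claim can be made concrete for Matern-1/2 ($d = 1$). There, $A_k = e^{-\lambda \Delta t_k}$ and $Q_k = \sigma^2(1 - A_k^2)$ follow from the `SDEKernel` and `MaternSDE` formulas documented in this section. The NumPy sketch below uses a hypothetical helper name and a scalar Kalman filter written out by hand, not any pyrox routine. It runs one forward pass and compares the resulting log marginal likelihood with a dense $O(N^3)$ evaluation.

```python
import numpy as np


def kalman_loglik_matern12(t, y, variance, lengthscale, noise_var):
    """O(N) Kalman-filter log marginal likelihood for a Matern-1/2 GP (d = 1)."""
    lam = 1.0 / lengthscale
    m, P = 0.0, variance  # stationary initial state: x_0 ~ N(0, P_inf)
    ll = 0.0
    for k in range(len(t)):
        if k > 0:
            a = np.exp(-lam * (t[k] - t[k - 1]))          # A_k = exp(F dt)
            m, P = a * m, a * a * P + variance * (1.0 - a * a)  # Q_k = P_inf - A P_inf A
        S = P + noise_var  # innovation variance (H = 1)
        v = y[k] - m
        ll += -0.5 * (np.log(2.0 * np.pi * S) + v * v / S)
        gain = P / S
        m, P = m + gain * v, (1.0 - gain) * P
    return ll


rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 10.0, 50))
y = np.sin(t) + 0.1 * rng.standard_normal(50)
ll_kf = kalman_loglik_matern12(t, y, variance=1.0, lengthscale=0.7, noise_var=0.01)

# Dense O(N^3) reference: log N(y | 0, K + noise I), K_ij = exp(-|t_i - t_j| / l).
K = np.exp(-np.abs(t[:, None] - t[None, :]) / 0.7) + 0.01 * np.eye(50)
sign, logdet = np.linalg.slogdet(K)
ll_dense = -0.5 * (y @ np.linalg.solve(K, y) + logdet + 50 * np.log(2.0 * np.pi))
```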
```python
import jax.numpy as jnp
from pyrox.gp import (
    ConstantSDE, CosineSDE, MaternSDE, PeriodicSDE,
    ProductSDE, QuasiPeriodicSDE, SumSDE,
)

# Primitive kernels
matern = MaternSDE(variance=1.0, lengthscale=0.5, order=1)  # nu = 3/2
cos = CosineSDE(variance=1.0, frequency=2.0)
const = ConstantSDE(variance=0.3)
per = PeriodicSDE(variance=1.0, lengthscale=1.0, period=2.0, n_harmonics=7)

# Composition: trend + offset
trend = SumSDE((matern, const))  # state dim = 2 + 1 = 3

# Composition: damped oscillation (Matern x Cosine)
damped = ProductSDE(matern, cos)  # state dim = 2 * 2 = 4

# Quasi-periodic (Matern x Periodic): convenience wrapper around ProductSDE
qp = QuasiPeriodicSDE(matern, per)  # state dim = 2 * 15 = 30
```
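The state-dimension arithmetic in the comments above follows from the standard composition rules. A sum stacks the two states, giving block-diagonal $F$ and $P_\infty$ and a concatenated $H$. A product uses the Kronecker sum $F_a \otimes I + I \otimes F_b$ with Kronecker-product $H$ and $P_\infty$. Whether `SumSDE` and `ProductSDE` order their blocks exactly this way is an assumption. The NumPy sketch below verifies two properties: the composed autocovariance $H e^{F\tau} P_\infty H^\top$ equals $k_a + k_b$ and $k_a k_b$ respectively, and the product transition $e^{F_a\tau} \otimes e^{F_b\tau}$ satisfies $\tfrac{d}{d\tau} E = F E$. It uses closed-form transitions for Matern-3/2 and cosine so no generic `expm` is needed. All names are hypothetical.

```python
import numpy as np


def matern32(variance, lengthscale):
    """(F, H, P_inf, tau -> expm(F tau)) for Matern-3/2, closed-form transition."""
    lam = np.sqrt(3.0) / lengthscale
    F = np.array([[0.0, 1.0], [-lam**2, -2.0 * lam]])
    H = np.array([[1.0, 0.0]])
    P = variance * np.diag([1.0, lam**2])

    def expm(tau):
        return np.exp(-lam * tau) * np.array([[1 + lam * tau, tau], [-lam**2 * tau, 1 - lam * tau]])

    return F, H, P, expm


def cosine(variance, frequency):
    """Cosine kernel as an undamped oscillator: expm(F tau) is a rotation."""
    F = np.array([[0.0, -frequency], [frequency, 0.0]])
    H = np.array([[1.0, 0.0]])
    P = variance * np.eye(2)

    def expm(tau):
        c, s = np.cos(frequency * tau), np.sin(frequency * tau)
        return np.array([[c, -s], [s, c]])

    return F, H, P, expm


def block_diag(A, B):
    out = np.zeros((A.shape[0] + B.shape[0], A.shape[1] + B.shape[1]))
    out[: A.shape[0], : A.shape[1]] = A
    out[A.shape[0]:, A.shape[1]:] = B
    return out


def sum_sde(a, b):
    """k_a + k_b: block-diagonal F and P_inf, concatenated H (dim d_a + d_b)."""
    (Fa, Ha, Pa, Ea), (Fb, Hb, Pb, Eb) = a, b
    return (block_diag(Fa, Fb), np.hstack([Ha, Hb]), block_diag(Pa, Pb),
            lambda tau: block_diag(Ea(tau), Eb(tau)))


def product_sde(a, b):
    """k_a * k_b: Kronecker-sum F, Kronecker-product H and P_inf (dim d_a * d_b)."""
    (Fa, Ha, Pa, Ea), (Fb, Hb, Pb, Eb) = a, b
    F = np.kron(Fa, np.eye(Fb.shape[0])) + np.kron(np.eye(Fa.shape[0]), Fb)
    # The two Kronecker terms commute, so exp(F tau) = exp(Fa tau) (x) exp(Fb tau).
    return F, np.kron(Ha, Hb), np.kron(Pa, Pb), lambda tau: np.kron(Ea(tau), Eb(tau))


def autocov(sde, tau):
    """k(tau) = H expm(F tau) P_inf H^T for tau >= 0."""
    F, H, P, E = sde
    return (H @ E(tau) @ P @ H.T)[0, 0]


matern, cos = matern32(1.0, 0.5), cosine(1.0, 2.0)
summed = sum_sde(matern, cos)      # state dim 2 + 2 = 4
damped = product_sde(matern, cos)  # state dim 2 * 2 = 4
```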
pyrox.gp.SDEKernel
Bases: Module
Abstract base for kernels with state-space (SDE) representations.
Stationary kernels with rational spectral densities admit exact finite-dimensional state-space representations of the form
$$
d\mathbf{x}(t) = F\,\mathbf{x}(t)\, dt + L\, dw(t), \qquad f(t) = H\,\mathbf{x}(t),
$$
where $w(t)$ is white noise with spectral density $Q_c$, and $P_\infty$ is the stationary state covariance solving the Lyapunov equation $F P_\infty + P_\infty F^\top + L Q_c L^\top = 0$.
Discretisation at time step $\Delta t$ gives
$$
A_k = \exp(F\,\Delta t), \qquad Q_k = P_\infty - A_k\,P_\infty\,A_k^\top,
$$
so that $x_{k+1} = A_k x_k + q_k$ with $q_k \sim \mathcal{N}(0, Q_k)$.
Concrete subclasses implement `sde_params` returning the closed-form `(F, L, H, Q_c, P_inf)` tuple. `discretise` defaults to a generic `expm`-based implementation; subclasses with closed-form transitions (e.g. Matern-1/2) may override it.
The continuous-time autocovariance recovered from the SDE is $k(\tau) = H\,\exp(F|\tau|)\,P_\infty\,H^\top$ for stationary kernels.
Source code in src/pyrox/gp/_protocols.py
state_dim (abstract property)
State dimension $d$ of the SDE representation.
discretise(dt)
¶
Discretise the SDE at time steps dt.
Default implementation evaluates A_k = expm(F dt_k) via
jax.scipy.linalg.expm and Q_k = P_\infty - A_k P_\infty A_k^\top.
Subclasses with closed-form transitions should override.
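The sketch below shows that default path in plain JAX. It is written from the formulas above and is not pyrox's source. It computes one expm per time step and then applies the stationary process-noise formula. For Matern-1/2 (F = [-λ], P_inf = [σ²]) the result can be checked against the scalar closed forms A_k = exp(-λ dt_k) and Q_k = σ²(1 - exp(-2λ dt_k)).

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def discretise(F, P_inf, dt):
    """Generic expm discretisation: one (A_k, Q_k) pair per entry of dt.

    A_k = expm(F dt_k) and Q_k = P_inf - A_k P_inf A_k^T.
    """
    A = jnp.stack([expm(F * step) for step in dt])      # (N, d, d)
    APA = jnp.einsum("nij,jk,nlk->nil", A, P_inf, A)    # A_k P_inf A_k^T
    Q = P_inf[None, :, :] - APA                         # (N, d, d)
    return A, Q


# Matern-1/2 check against its scalar closed form.
variance, lengthscale = 2.0, 0.5
lam = 1.0 / lengthscale
F = jnp.array([[-lam]])
P_inf = jnp.array([[variance]])
dt = jnp.array([0.1, 0.2, 0.3])
A, Q = discretise(F, P_inf, dt)
```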
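Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dt | Float[Array, ' N'] | Time steps dt_k at which to discretise. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'N d d'], Float[Array, 'N d d']] | Tuple (A, Q) of transition matrices A_k and process-noise covariances Q_k. |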
Source code in src/pyrox/gp/_protocols.py
sde_params()
abstractmethod
Return (F, L, H, Q_c, P_inf) defining the continuous SDE.
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.MaternSDE
Bases: SDEKernel
Matern kernel in state-space (companion) form for order in {0, 1, 2}.
The Matern-:math:`\nu` kernel with :math:`\nu = p + 1/2` for
:math:`p \in \{0, 1, 2\}` has an exact :math:`d = p + 1` dimensional
SDE representation. The closed-form parameters are:
- Matern-1/2 (order=0, :math:`d=1`): :math:`\lambda = 1/\ell`,
  .. math:: F = [-\lambda],\quad L = [1],\quad H = [1],\quad Q_c = 2\sigma^2\lambda,\quad P_\infty = \sigma^2.
- Matern-3/2 (order=1, :math:`d=2`): :math:`\lambda = \sqrt{3}/\ell`,
  .. math:: F = \begin{pmatrix} 0 & 1 \\ -\lambda^2 & -2\lambda \end{pmatrix}, \quad L = \begin{pmatrix} 0 \\ 1 \end{pmatrix},\quad H = \begin{pmatrix} 1 & 0 \end{pmatrix},
  .. math:: Q_c = 4\sigma^2\lambda^3,\quad P_\infty = \sigma^2\,\mathrm{diag}(1,\;\lambda^2).
- Matern-5/2 (order=2, :math:`d=3`): :math:`\lambda = \sqrt{5}/\ell`,
  .. math:: F = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ -\lambda^3 & -3\lambda^2 & -3\lambda \end{pmatrix}, \quad L = \begin{pmatrix} 0 \\ 0 \\ 1 \end{pmatrix},\quad H = \begin{pmatrix} 1 & 0 & 0 \end{pmatrix},
  .. math:: Q_c = \tfrac{16}{3}\sigma^2\lambda^5,\quad P_\infty = \sigma^2 \begin{pmatrix} 1 & 0 & -\lambda^2/3 \\ 0 & \lambda^2/3 & 0 \\ -\lambda^2/3 & 0 & \lambda^4 \end{pmatrix}.
order is a static (Python int) field — it picks a code path,
not a trainable parameter. variance and lengthscale are
JAX-traced scalars suitable for autograd.
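One way to sanity-check these closed forms is to substitute them into the Lyapunov equation F P_inf + P_inf F^T + L Q_c L^T = 0 from the SDEKernel docstring. The sketch below does this for the Matern-5/2 case using plain jax.numpy. `matern52_sde` is a hypothetical helper and not part of the pyrox API.

```python
import jax.numpy as jnp


def matern52_sde(variance, lengthscale):
    """Hypothetical helper: the Matern-5/2 (F, L, H, Q_c, P_inf) closed forms above."""
    lam = jnp.sqrt(5.0) / lengthscale
    F = jnp.array([[0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0],
                   [-lam**3, -3.0 * lam**2, -3.0 * lam]])
    L = jnp.array([[0.0], [0.0], [1.0]])
    H = jnp.array([[1.0, 0.0, 0.0]])
    Q_c = jnp.array([[16.0 / 3.0 * variance * lam**5]])
    P_inf = variance * jnp.array([[1.0, 0.0, -lam**2 / 3.0],
                                  [0.0, lam**2 / 3.0, 0.0],
                                  [-lam**2 / 3.0, 0.0, lam**4]])
    return F, L, H, Q_c, P_inf


F, L, H, Q_c, P_inf = matern52_sde(variance=1.0, lengthscale=1.5)
# Stationarity: F P_inf + P_inf F^T + L Q_c L^T should vanish.
residual = F @ P_inf + P_inf @ F.T + L @ Q_c @ L.T
```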
Examples:
>>> import jax.numpy as jnp
>>> sde = MaternSDE(variance=1.0, lengthscale=0.5, order=1)
>>> F, L, H, Q_c, P_inf = sde.sde_params()
>>> A, Q = sde.discretise(jnp.array([0.1, 0.2, 0.3]))
>>> A.shape, Q.shape
((3, 2, 2), (3, 2, 2))
References
Sarkka & Solin (2019), Applied Stochastic Differential Equations, Ch. 12; Hartikainen & Sarkka (2010), Kalman Filtering and Smoothing Solutions to Temporal Gaussian Process Regression Models, IEEE MLSP.
Source code in src/pyrox/gp/_sde_kernels.py
nu
property
Smoothness nu = order + 1/2.
state_dim
property
State dimension d = order + 1.
sde_params()
Return (F, L, H, Q_c, P_inf) for the chosen Matern order.
Source code in src/pyrox/gp/_sde_kernels.py
pyrox.gp.ConstantSDE
Bases: SDEKernel
Constant kernel :math:`k(\tau) = \sigma^2` in state-space form.
A degenerate 1-D state space with zero dynamics and zero diffusion:
.. math:: F = [0],\quad L = [0],\quad H = [1],\quad Q_c = [0],\quad P_\infty = [\sigma^2].
The transition is the identity, A_k = I, and the process noise is
zero, Q_k = 0. Useful as an additive component of a
:class:`SumSDE` (e.g. Matern + Constant for a fixed offset).
Source code in src/pyrox/gp/_sde_kernels.py
pyrox.gp.CosineSDE
Bases: SDEKernel
Cosine kernel :math:`k(\tau) = \sigma^2 \cos(\omega_0 \tau)` in SDE form.
A 2-D deterministic oscillator with rotation matrix transitions:
.. math:: F = \begin{pmatrix} 0 & -\omega_0 \\ \omega_0 & 0 \end{pmatrix}, \quad L = \begin{pmatrix} 0 \\ 0 \end{pmatrix},\quad H = \begin{pmatrix} 1 & 0 \end{pmatrix},
.. math:: Q_c = 0,\quad P_\infty = \sigma^2 I_2.
There is no driving noise, so the discrete-time transition is a pure
rotation :math:`A_k = R(\omega_0\,\Delta t_k)` and :math:`Q_k = 0`.
The :meth:`discretise` method overrides the default expm path
with the closed-form rotation for efficiency.
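The closed-form override can be checked against the generic expm path in a few lines of jax.numpy. The sketch below compares R(ω0 Δt_k) with expm(F Δt_k) for a batch of steps and confirms that Q_k = P_inf - A_k P_inf A_k^T vanishes. The batched `rotation` helper is hypothetical and is not pyrox's implementation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def rotation(theta):
    """Batched 2x2 rotation matrices R(theta) with shape (..., 2, 2)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.stack(
        [jnp.stack([c, -s], axis=-1), jnp.stack([s, c], axis=-1)], axis=-2
    )


variance, omega0 = 1.5, 2.0
dt = jnp.array([0.1, 0.5, 1.0])
F = jnp.array([[0.0, -omega0], [omega0, 0.0]])
P_inf = variance * jnp.eye(2)

A_closed = rotation(omega0 * dt)                        # (3, 2, 2) closed form
A_expm = jnp.stack([expm(F * step) for step in dt])     # generic expm path
Q = P_inf - jnp.einsum("nij,jk,nlk->nil", A_closed, P_inf, A_closed)
```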
Source code in src/pyrox/gp/_sde_kernels.py
discretise(dt)
Closed-form rotation: A_k = R(omega * dt_k), Q_k = 0.
Source code in src/pyrox/gp/_sde_kernels.py
pyrox.gp.PeriodicSDE
Bases: SDEKernel
Periodic kernel in state-space form via Fourier-series truncation.
The MacKay periodic kernel
:math:`k(\tau) = \sigma^2 \exp\!\bigl(-2 \sin^2(\pi\tau/T)/\ell^2\bigr)`
expands as
.. math:: k(\tau) = \sigma^2 e^{-1/\ell^2} \Bigl[I_0(1/\ell^2) + 2 \sum_{j=1}^\infty I_j(1/\ell^2) \cos(j\,\omega_0 \tau)\Bigr],
with :math:`\omega_0 = 2\pi/T`. Truncating to J = n_harmonics
cosines gives a deterministic state-space model whose state collects
a degenerate 1-D constant block (the :math:`j=0` DC term) and J
2-D rotation blocks, one per harmonic. The total state dimension is
:math:`1 + 2J`, with L = 0 and Q_c = 0 (no driving noise), and
:math:`P_\infty` is block-diagonal, holding q_0 on the DC block and
:math:`q_j I_2` on harmonic block j, where
.. math:: q_0 = \sigma^2 e^{-1/\ell^2} I_0(1/\ell^2),\qquad q_j = 2 \sigma^2 e^{-1/\ell^2} I_j(1/\ell^2)\quad (j \geq 1).
The scaled modified Bessel coefficients are computed by
:func:`_scaled_bessel_i_seq` using a log-space Taylor-series
accumulation (logsumexp over (j + 2k) log(x/2) - x - log(k!) -
log((k+j)!), with :math:`x = 1/\ell^2`). For n_harmonics around 7 the truncation matches
the dense MacKay periodic kernel to better than 1e-6 across the
typical hyperparameter regime.
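The sketch below reproduces the coefficient computation and the truncated expansion in plain JAX, so that the truncation error can be inspected directly. `scaled_bessel_seq` is a hypothetical stand-in for the private `_scaled_bessel_i_seq` and follows the log-space series described above. The fixed number of Taylor terms (40) is an assumption. It is ample for moderate x = 1/ℓ², but very small lengthscales would need more terms.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, logsumexp


def scaled_bessel_seq(x, n_harmonics, n_terms=40):
    """e^{-x} I_j(x) for j = 0..n_harmonics via a log-space Taylor series.

    log term_{j,k} = (j + 2k) log(x/2) - x - log(k!) - log((k+j)!)
    """
    j = jnp.arange(n_harmonics + 1)[:, None]
    k = jnp.arange(n_terms)[None, :]
    log_terms = (
        (j + 2 * k) * jnp.log(x / 2.0) - x - gammaln(k + 1.0) - gammaln(k + j + 1.0)
    )
    return jnp.exp(logsumexp(log_terms, axis=1))


def periodic_fourier(tau, variance, lengthscale, period, n_harmonics):
    """Truncated expansion sum_j q_j cos(j omega_0 tau)."""
    b = scaled_bessel_seq(1.0 / lengthscale**2, n_harmonics)
    q = variance * b.at[1:].multiply(2.0)   # q_0, then q_j = 2 sigma^2 e^{-x} I_j(x)
    omega0 = 2.0 * jnp.pi / period
    j = jnp.arange(n_harmonics + 1)
    return jnp.cos(omega0 * tau[:, None] * j[None, :]) @ q


def periodic_dense(tau, variance, lengthscale, period):
    """Dense MacKay periodic kernel as a function of the lag tau."""
    return variance * jnp.exp(-2.0 * jnp.sin(jnp.pi * tau / period) ** 2 / lengthscale**2)


tau = jnp.linspace(0.0, 4.0, 41)
err = jnp.max(jnp.abs(
    periodic_fourier(tau, 1.0, 1.0, 2.0, n_harmonics=7)
    - periodic_dense(tau, 1.0, 1.0, 2.0)
))
```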
References
Solin & Sarkka (2014), Explicit Link Between Periodic Covariance Functions and State Space Models, AISTATS.
Source code in src/pyrox/gp/_sde_kernels.py
discretise(dt)
Closed-form discretisation: harmonic block rotations, Q_k = 0.
The drift F is exactly block-diagonal, with a (0) block for the DC mode
and J skew-symmetric 2x2 rotation generators
:math:`F_j = \mathrm{skew}(j\,\omega_0)`. The matrix exponential
of each block has a closed form (identity for the DC mode and a
2-D rotation for each harmonic), and Q_k vanishes identically
because the diffusion is zero. Using the closed form avoids the
float32 error that the generic expm path accumulates for
large j * omega_0 * dt.
Source code in src/pyrox/gp/_sde_kernels.py
pyrox.gp.SumSDE
Bases: SDEKernel
Sum of SDE kernels via block-diagonal state-space composition.
For :math:`k(\tau) = \sum_i k_i(\tau)`, the SDE is the block-diagonal
concatenation of each component:
.. math:: F = \mathrm{blkdiag}(F_1, \dots, F_K),\quad L = \mathrm{blkdiag}(L_1, \dots, L_K),\quad Q_c = \mathrm{blkdiag}(Q_{c,1}, \dots, Q_{c,K}),
.. math:: H = [H_1, \dots, H_K],\quad P_\infty = \mathrm{blkdiag}(P_{\infty,1}, \dots, P_{\infty,K}).
Total state dimension is :math:`\sum_i d_i`. Components with disjoint
state spaces evolve independently.
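The block-diagonal rule can be checked with `jax.scipy.linalg.block_diag`. The sketch below combines hand-written Matern-1/2 and constant components and confirms that the resulting autocovariance is the sum k_1 + k_2. All helper names are hypothetical. This is plain JAX and not pyrox's SumSDE.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag, expm


def matern12_sde(variance, lengthscale):
    """Hypothetical helper: Matern-1/2 (F, L, H, Q_c, P_inf), all 1x1."""
    lam = 1.0 / lengthscale
    return (jnp.array([[-lam]]), jnp.array([[1.0]]), jnp.array([[1.0]]),
            jnp.array([[2.0 * variance * lam]]), jnp.array([[variance]]))


def constant_sde(variance):
    """Hypothetical helper: constant kernel with zero dynamics and diffusion."""
    zero = jnp.zeros((1, 1))
    return zero, zero, jnp.array([[1.0]]), zero, jnp.array([[variance]])


def sum_sde(*components):
    """Block-diagonal F, L, Q_c, P_inf and a concatenated readout H."""
    F, L, H, Q_c, P_inf = zip(*components)
    return (block_diag(*F), block_diag(*L), jnp.concatenate(H, axis=1),
            block_diag(*Q_c), block_diag(*P_inf))


F, L, H, Q_c, P_inf = sum_sde(matern12_sde(1.0, 0.5), constant_sde(0.3))
tau = 0.4
k_sum = (H @ expm(F * tau) @ P_inf @ H.T)[0, 0]
k_expected = 1.0 * jnp.exp(-tau / 0.5) + 0.3   # Matern-1/2 plus constant
```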
Source code in src/pyrox/gp/_sde_kernels.py
pyrox.gp.ProductSDE
Bases: SDEKernel
Product of two SDE kernels via Kronecker composition.
For :math:`k(\tau) = k_1(\tau)\,k_2(\tau)`, the joint SDE has
Kronecker-sum drift and Kronecker-product readout / stationary
covariance:
.. math:: F = F_1 \otimes I_{d_2} + I_{d_1} \otimes F_2,\quad H = H_1 \otimes H_2,\quad P_\infty = P_{\infty,1} \otimes P_{\infty,2}.
The diffusion is not a simple Kronecker product. Substituting into the Lyapunov equation yields
.. math:: L Q_c L^\top = (L_1 Q_{c,1} L_1^\top) \otimes P_{\infty,2} + P_{\infty,1} \otimes (L_2 Q_{c,2} L_2^\top).
For simplicity we set :math:`L = I_{d_1 d_2}` and store the right-hand
side as Q_c directly. Total state dimension is :math:`d_1 \cdot d_2`.
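The sketch below implements this Kronecker composition in plain JAX, sets L = I and Q_c to the right-hand side above, and checks two properties: the Lyapunov residual vanishes, and the autocovariance factorises as k_1(τ) k_2(τ). It uses a Matern-1/2 envelope and a cosine oscillator written out by hand. `product_sde` is a hypothetical helper and not pyrox's ProductSDE.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def product_sde(sde1, sde2):
    """Kronecker composition with L = I and Q_c from the Lyapunov identity."""
    F1, L1, H1, Qc1, P1 = sde1
    F2, L2, H2, Qc2, P2 = sde2
    d1, d2 = F1.shape[0], F2.shape[0]
    F = jnp.kron(F1, jnp.eye(d2)) + jnp.kron(jnp.eye(d1), F2)   # Kronecker sum
    H = jnp.kron(H1, H2)
    P_inf = jnp.kron(P1, P2)
    Q_c = jnp.kron(L1 @ Qc1 @ L1.T, P2) + jnp.kron(P1, L2 @ Qc2 @ L2.T)
    L = jnp.eye(d1 * d2)
    return F, L, H, Q_c, P_inf


# Matern-1/2 envelope (d=1) times a cosine oscillator (d=2).
lam, var1 = 2.0, 1.0          # lam = 1 / lengthscale
omega0, var2 = 3.0, 0.5
matern = (jnp.array([[-lam]]), jnp.array([[1.0]]), jnp.array([[1.0]]),
          jnp.array([[2.0 * var1 * lam]]), jnp.array([[var1]]))
cosine = (jnp.array([[0.0, -omega0], [omega0, 0.0]]), jnp.zeros((2, 1)),
          jnp.array([[1.0, 0.0]]), jnp.zeros((1, 1)), var2 * jnp.eye(2))

F, L, H, Q_c, P_inf = product_sde(matern, cosine)
lyapunov_residual = F @ P_inf + P_inf @ F.T + L @ Q_c @ L.T
tau = 0.7
k_prod = (H @ expm(F * tau) @ P_inf @ H.T)[0, 0]
k_expected = var1 * jnp.exp(-lam * tau) * var2 * jnp.cos(omega0 * tau)
```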
Source code in src/pyrox/gp/_sde_kernels.py
pyrox.gp.QuasiPeriodicSDE
Bases: ProductSDE
Quasi-periodic kernel: :math:`k(\tau) = k_{\rm Mat}(\tau)\,k_{\rm Per}(\tau)`.
A thin documented subclass of :class:`ProductSDE` that captures the
standard Matern :math:`\times` Periodic decomposition used for
modulated periodic signals (stellar light curves, modulated seasonal
patterns). The Matern envelope sets the timescale on which the
amplitude drifts; the periodic factor sets the cycle.
Example
>>> import jax.numpy as jnp
>>> qp = QuasiPeriodicSDE(
...     MaternSDE(variance=1.0, lengthscale=2.0, order=1),
...     PeriodicSDE(variance=1.0, lengthscale=1.0, period=1.0, n_harmonics=5),
... )
>>> qp.state_dim
22
References
Sarkka & Solin (2019), Applied Stochastic Differential Equations, Sec. 12.3; Wilkinson et al. (2021), BayesNewton.
Source code in src/pyrox/gp/_sde_kernels.py
Markov GP: Kalman / RTS workflow
MarkovGPPrior consumes any SDEKernel over a sorted 1-D grid. A forward
Kalman filter gives the marginal likelihood and a backward RTS smoother
gives the posterior, both in O(N d^3) time, where d is the SDE state
dimension. Use it for temporal GP regression / forecasting when the
training grid lives on a single time axis. Predictions at arbitrary
test times (forecasting, backcasting, and within-window
interpolation) re-run the filter+smoother over the merged grid with the
test points masked out of the update step.
```python
import jax.numpy as jnp
from pyrox.gp import MaternSDE, MarkovGPPrior, markov_gp_factor

times = jnp.linspace(0.0, 5.0, 200)
y = jnp.sin(times) + 0.05 * jnp.cos(7.0 * times)
prior = MarkovGPPrior(
    MaternSDE(variance=1.0, lengthscale=0.5, order=1),  # Matern-3/2
    times,
)
log_marg = prior.log_marginal(y, jnp.asarray(0.01))     # Kalman forward
cond = prior.condition(y, jnp.asarray(0.01))            # filter + RTS smoother
mean, var = cond.predict(jnp.linspace(-0.5, 6.0, 50))   # arbitrary test times
```
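The forward pass behind log_marginal can be sketched in a few dozen lines of plain JAX. The version below is a minimal illustration and not pyrox's implementation. It assumes a zero mean, scalar observations, and the Matern-3/2 closed forms given under MaternSDE. It then checks the filtered log marginal against a dense multivariate-normal log density built from the same kernel. `kalman_log_marginal` is a hypothetical name.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64 for a tight comparison

import jax.numpy as jnp
from jax.scipy.linalg import expm
from jax.scipy.stats import multivariate_normal


def kalman_log_marginal(F, H, P_inf, times, y, noise_var):
    """Forward Kalman filter log p(y) for a zero-mean stationary SDE prior."""
    dts = jnp.diff(times, prepend=times[0])   # first step has dt = 0, so A = I

    def step(carry, inputs):
        m, P, ll = carry
        dt, y_n = inputs
        A = expm(F * dt)
        m = A @ m                                     # predict mean
        P = A @ P @ A.T + (P_inf - A @ P_inf @ A.T)   # predict cov with Q_k
        S = (H @ P @ H.T)[0, 0] + noise_var           # innovation variance
        v = y_n - (H @ m)[0]                          # innovation
        K = (P @ H.T)[:, 0] / S                       # Kalman gain
        m = m + K * v
        P = P - jnp.outer(K, K) * S
        ll = ll - 0.5 * (jnp.log(2.0 * jnp.pi * S) + v**2 / S)
        return (m, P, ll), None

    init = (jnp.zeros(F.shape[0]), P_inf, jnp.array(0.0))
    (_, _, ll), _ = jax.lax.scan(step, init, (dts, y))
    return ll


# Matern-3/2 state space, matching MaternSDE(order=1).
variance, lengthscale, noise_var = 1.0, 0.5, 0.01
lam = jnp.sqrt(3.0) / lengthscale
F = jnp.array([[0.0, 1.0], [-lam**2, -2.0 * lam]])
H = jnp.array([[1.0, 0.0]])
P_inf = variance * jnp.diag(jnp.array([1.0, lam**2]))

times = jnp.linspace(0.0, 5.0, 60)
y = jnp.sin(times) + 0.05 * jnp.cos(7.0 * times)
ll_kalman = kalman_log_marginal(F, H, P_inf, times, y, noise_var)

# Dense O(N^3) reference with the closed-form Matern-3/2 Gram.
s = jnp.sqrt(3.0) * jnp.abs(times[:, None] - times[None, :]) / lengthscale
K = variance * (1.0 + s) * jnp.exp(-s) + noise_var * jnp.eye(times.size)
ll_dense = multivariate_normal.logpdf(y, jnp.zeros_like(y), K)
```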
Inside a NumPyro model, swap gp_factor for markov_gp_factor:
```python
import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist
from pyrox.gp import MarkovGPPrior, MaternSDE, markov_gp_factor


def temporal_model(times, y):
    sigma2 = numpyro.sample("variance", dist.LogNormal(0.0, 1.0))
    ell = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    sde = MaternSDE(variance=sigma2, lengthscale=ell, order=1)
    prior = MarkovGPPrior(sde, times)
    markov_gp_factor("obs", prior, y, jnp.array(0.01))
```
markov_gp_factor is scoped to Gaussian-likelihood regression on a single time axis. Non-Gaussian likelihoods on the Markov path go through MarkovGPPrior.condition_nongauss with a Markov-aware site-based strategy (see below); spatio-temporal Markov priors land in later waves.
pyrox.gp.MarkovGPPrior
Bases: Module
Linear-time temporal GP prior over a sorted 1-D grid.
Wraps any :class:`pyrox.gp.SDEKernel` (e.g. :class:`pyrox.gp.MaternSDE`,
:class:`pyrox.gp.SumSDE`, :class:`pyrox.gp.PeriodicSDE`) to give Kalman
filtering for the marginal log-likelihood and RTS smoothing for the
posterior on the training grid. Supports an optional mean function and a
small observation-noise floor for numerical stability.
Attributes:
| Name | Type | Description |
|---|---|---|
| sde_kernel | SDEKernel | Any :class:`SDEKernel` instance. |
| times | Float[Array, ' N'] | Sorted, strictly increasing observation times of shape (N,). |
| mean_fn | Callable[[Float[Array, ' N']], Float[Array, ' N']] \| None | Optional callable mapping times to prior mean values; a zero mean is used when None. |
| obs_noise_floor | float | Small extra diagonal added to the observation variance. |
Examples:
>>> import jax.numpy as jnp
>>> from pyrox.gp import MaternSDE, MarkovGPPrior
>>> times = jnp.linspace(0.0, 5.0, 50)
>>> sde = MaternSDE(variance=1.0, lengthscale=0.5, order=1)
>>> prior = MarkovGPPrior(sde, times)
>>> y = jnp.sin(times) + 0.05 * jnp.cos(3.0 * times)
>>> log_marg = prior.log_marginal(y, jnp.asarray(0.01))
Notes
The solver-strategy plumbing used by :class:`pyrox.gp.GPPrior` does
not apply here: Kalman filtering is its own linear-algebra path
and does not factor through gaussx.AbstractSolverStrategy.
Source code in src/pyrox/gp/_markov.py
state_dim
property
SDE state dimension :math:`d` for this kernel.
condition(y, noise_var)
Condition on Gaussian-likelihood observations via filter + smoother.
Source code in src/pyrox/gp/_markov.py
condition_nongauss(likelihood, y, *, strategy)
Condition on a non-Gaussian likelihood via a site-based strategy.
Convenience that forwards to strategy.fit(self, likelihood, y).
Pick any of the Markov-aware site-based strategies in
:mod:`pyrox.gp._inference_nongauss_markov`:
:class:`pyrox.gp.LaplaceMarkovInference`,
:class:`pyrox.gp.GaussNewtonMarkovInference`,
:class:`pyrox.gp.PosteriorLinearizationMarkov`, or
:class:`pyrox.gp.ExpectationPropagationMarkov`. Returns a
:class:`pyrox.gp.NonGaussConditionedMarkovGP` with the same
predict API as the Gaussian-likelihood
:class:`ConditionedMarkovGP`.
Source code in src/pyrox/gp/_markov.py
filter(y, noise_var)
Run the forward Kalman filter on the training grid.
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N d'] | Returned as a tuple of four arrays, in the order listed. |
| Float[Array, 'N d d'] | |
| Float[Array, 'N d'] | |
| Float[Array, 'N d d'] | |
Source code in src/pyrox/gp/_markov.py
log_marginal(y, noise_var)
Marginal log-likelihood log p(y | theta) via Kalman filtering.
log_prob(f)
Log density of an exact-state path :math:`f(t_n) = H x_n` under the prior.
Evaluates log N(f | mu(times), K_NN) where K_NN is the dense
Gram of the kernel encoded by sde_kernel on self.times.
Computes the dense covariance via H exp(F |t_i - t_j|) P_inf H^T
with one expm per pairwise lag, costing :math:`O(N^2 d^3)` for the
Gram plus :math:`O(N^3)` for the Cholesky solve. Intended for
sanity checks and small-grid use rather than scalable inference.
For training, prefer :meth:`log_marginal`.
Source code in src/pyrox/gp/_markov.py
mean(times)
Evaluate the mean function at times; zero by default.
smooth(y, noise_var)
Run filter + RTS smoother on the training grid.
Returns (m_smooth, P_smooth, log_marginal) over the training
times.
Source code in src/pyrox/gp/_markov.py
pyrox.gp.ConditionedMarkovGP
Bases: Module
Markov GP conditioned on Gaussian-likelihood observations.
Holds the smoothed posterior on the training grid plus the marginal
log-likelihood. Use :meth:`predict` for marginal posterior mean / variance
at arbitrary test times.
Attributes:
| Name | Type | Description |
|---|---|---|
| prior | MarkovGPPrior | The originating :class:`MarkovGPPrior`. |
| y | Float[Array, ' N'] | Observations of shape (N,). |
| noise_var | Float[Array, ''] | Observation variance used for conditioning. |
| smoothed_means | Float[Array, 'N d'] | RTS-smoothed state means on the training grid. |
| smoothed_covs | Float[Array, 'N d d'] | RTS-smoothed state covariances on the training grid. |
| log_marginal | Float[Array, ''] | Scalar :math:`\log p(y \mid \theta)`. |
Source code in src/pyrox/gp/_markov.py
predict(t_star)
Predictive marginals (mean, var) at arbitrary test times.
Implementation: re-run the filter+smoother over the merged grid
sort(times ∪ t_star) with the test points masked out of the
update step, then read off the smoothed marginals at the test
positions via H @ m and H @ P @ H^T. For M test points the cost is
:math:`O((N + M)\,d^3)`. Handles training-grid lookups, forecasting,
backcasting, and within-window interpolation under one code path.
Source code in src/pyrox/gp/_markov.py
pyrox.gp.markov_gp_factor(name, prior, y, noise_var)
Register the collapsed Markov-GP marginal log-likelihood with NumPyro.
Computes log p(y | times, theta) via Kalman filtering and adds it as
numpyro.factor(name, ...). Use this inside a NumPyro model for
Gaussian-likelihood temporal GP regression — the latent function is
marginalized analytically.
Source code in src/pyrox/gp/_markov.py
pyrox.gp.markov_gp_sample(name, prior)
Sample a latent function f at the prior's training times.
Registers a single numpyro.sample(name, MVN(mu, K)) site where K
is the dense Gram derived from the SDE autocovariance
H exp(F|tau|) P_inf H^T. This is the simple, dense path — use it
when N is small. Scalable Markov-aware sample sites land in a
later wave alongside non-Gaussian likelihood support.
Source code in src/pyrox/gp/_markov.py
Component protocols
Abstract pyrox-local bases for the orthogonal component stack. Wave 2
ships only the contracts for Guide, Integrator, and Likelihood —
concrete implementations land in later waves. Solver strategies live in
gaussx._strategies.
pyrox.gp.Kernel
Bases: Module
Abstract base for GP covariance functions.
Subclasses implement :meth:`__call__` returning the Gram matrix on a pair
of input batches. :meth:`gram` and :meth:`diag` are convenience defaults
that derive from :meth:`__call__`; structured subclasses (Kronecker,
state-space, etc.) should override them for efficiency.
Source code in src/pyrox/gp/_protocols.py
diag(X)
Diagonal of K(X, X).
Default implementation extracts the diagonal of the full Gram. For
stationary kernels with constant diagonal, override with a vectorized
broadcast for the O(N) shortcut.
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.Guide
Bases: Module
Abstract base for variational posterior families.
Concrete guides (DeltaGuide, MeanFieldGuide, LowRankGuide,
FullRankGuide, etc.) land in the dedicated guide waves (#28, #29).
The whitening principle keeps optimization geometry well-conditioned —
sample from a unit-scale latent and unwhiten with the prior Cholesky.
Two distinct entry points:
- :meth:`sample` / :meth:`log_prob`: pure variational draws and densities. sample(self, key) returns a draw from q(f); log_prob(self, f) evaluates log q(f). Neither touches the NumPyro trace.
- register(name, prior) (optional): the NumPyro-integration hook invoked by :func:`pyrox.gp.gp_sample` when a guide is supplied. Use it to register a sample / param site (or compose one out of guide state) under name and return the latent function value. Concrete guides that participate in :func:`gp_sample` should implement this; the protocol leaves it unspecified so guides usable purely outside NumPyro stay valid.
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.Integrator
Bases: Module
Abstract base for Gaussian-expectation integrators.
Computes :math:`\mathbb{E}_{q(f)}[g(f)]` where q(f) = N(mean, var).
Concrete integrators (Gauss-Hermite, sigma-points, cubature, Taylor,
Monte Carlo) land in later waves and may delegate to gaussx's
quadrature primitives.
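To show what such an integrator computes, the sketch below implements Gauss-Hermite quadrature with NumPy's `hermgauss` nodes. This is a standalone illustration and not pyrox's or gaussx's integrator. The substitution f = mean + sqrt(2 var) x maps the Hermite weight e^{-x²} onto N(mean, var) and leaves a 1/sqrt(π) normaliser. The function name is hypothetical.

```python
import jax.numpy as jnp
from numpy.polynomial.hermite import hermgauss


def gauss_hermite_expectation(g, mean, var, n_points=20):
    """E_{f ~ N(mean, var)}[g(f)] by Gauss-Hermite quadrature.

    hermgauss targets the weight exp(-x^2); substituting
    f = mean + sqrt(2 var) x maps it to the Gaussian density.
    """
    x, w = (jnp.asarray(a) for a in hermgauss(n_points))
    f = mean + jnp.sqrt(2.0 * var) * x
    return jnp.sum(w * g(f)) / jnp.sqrt(jnp.pi)


m, v = 0.3, 0.5
second_moment = gauss_hermite_expectation(lambda f: f**2, m, v)   # exact: m^2 + v
lognormal_mean = gauss_hermite_expectation(jnp.exp, m, v)         # exact: exp(m + v/2)
```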
Source code in src/pyrox/gp/_protocols.py
pyrox.gp.Likelihood
Bases: Module
Abstract base for observation models.
Implements the conditional p(y | f) and a default
:meth:`expected_log_prob` that integrates over a Gaussian latent via an
:class:`Integrator`. Concrete scalar-latent likelihoods
(:class:`GaussianLikelihood`, :class:`BernoulliLikelihood`,
:class:`PoissonLikelihood`, :class:`StudentTLikelihood`) and
multi-latent ones (:class:`SoftmaxLikelihood`,
:class:`HeteroscedasticGaussianLikelihood`) live in
:mod:`pyrox.gp._likelihoods`.
Multi-latent likelihoods declare latent_dim: int as a static
field (e.g. latent_dim = num_classes for softmax). Scalar
likelihoods may omit the field; consumers should read
getattr(lik, "latent_dim", 1).
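For a Gaussian observation model, the expectation under q(f) = N(m, v) has a closed form: E_q[log N(y | f, σ²)] = log N(y | m, σ²) - v / (2σ²). This gives a convenient test for any quadrature-based expected_log_prob. The sketch below compares the closed form with Gauss-Hermite quadrature in plain JAX. Both function names are hypothetical and are not pyrox API.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm
from numpy.polynomial.hermite import hermgauss


def gaussian_expected_log_prob(y, m, v, noise_var):
    """Closed form: log N(y | m, noise_var) - v / (2 noise_var)."""
    return norm.logpdf(y, m, jnp.sqrt(noise_var)) - v / (2.0 * noise_var)


def quadrature_expected_log_prob(y, m, v, noise_var, n_points=20):
    """Same expectation by Gauss-Hermite quadrature over f ~ N(m, v)."""
    x, w = (jnp.asarray(a) for a in hermgauss(n_points))
    f = m + jnp.sqrt(2.0 * v) * x
    return jnp.sum(w * norm.logpdf(y, f, jnp.sqrt(noise_var))) / jnp.sqrt(jnp.pi)


y, m, v, noise_var = 0.8, 0.2, 0.3, 0.1
exact = gaussian_expected_log_prob(y, m, v, noise_var)
approx = quadrature_expected_log_prob(y, m, v, noise_var)
```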
Source code in src/pyrox/gp/_protocols.py
Math primitives
Pure JAX kernel functions. Stateless, differentiable, composable —
(Array, ..., hyperparams) -> Gram. No NumPyro, no protocols.
Pure JAX kernel evaluation primitives — math definitions only.
Each function takes two input matrices X1 of shape (N1, D),
X2 of shape (N2, D), hyperparameters as JAX arrays, and returns
the (N1, N2) Gram matrix. All inputs are 2-D; callers that have 1-D
arrays must add a trailing singleton dimension first.
These are the canonical closed-form math for each kernel — small,
readable, and tutorial-facing. The companion scalable construction
surface (numerically stable matrix assembly, mixed-precision
accumulation, implicit/structured operators, batched matvec) lives in
gaussx; see :func:`gaussx.stable_rbf_kernel` and
:class:`gaussx.ImplicitKernelOperator` for the production path.
The composition helpers :func:`kernel_add` and :func:`kernel_mul` act on
already-evaluated Gram matrices, not on callables. Higher-level
:class:`pyrox.gp.Kernel` classes (Wave 2 Layer 1, see issue #20) compose
callables and may opt in to gaussx's scalable variants when needed.
Index axes are named via :mod:`einops` (einsum / rearrange) rather
than raw broadcasting so shape intent stays legible at the call site.
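The sketch below illustrates the shape contract with an RBF Gram written from its formula. It uses `jnp.einsum` for the named axes instead of einops, which is an assumption made to keep the example dependency-free. It is not pyrox's `rbf_kernel` source. Clamping at zero guards against small negative squared distances caused by cancellation.

```python
import jax.numpy as jnp


def rbf_kernel_sketch(X1, X2, variance, lengthscale):
    """RBF Gram for X1: (N1, D) and X2: (N2, D), returning (N1, N2)."""
    # ||x - x'||^2 = ||x||^2 + ||x'||^2 - 2 x^T x', with named einsum axes.
    sq1 = jnp.einsum("nd,nd->n", X1, X1)
    sq2 = jnp.einsum("md,md->m", X2, X2)
    cross = jnp.einsum("nd,md->nm", X1, X2)
    sqdist = jnp.maximum(sq1[:, None] + sq2[None, :] - 2.0 * cross, 0.0)
    return variance * jnp.exp(-0.5 * sqdist / lengthscale**2)


x = jnp.linspace(0.0, 1.0, 4)
X = x[:, None]                  # 1-D inputs need a trailing singleton axis
K = rbf_kernel_sketch(X, X[:3], jnp.array(2.0), jnp.array(0.5))
```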
constant_kernel(X1, X2, variance)
Constant kernel.
.. math:: k(x, x') = \sigma^2
A rank-one kernel, useful as an additive component that models a constant offset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar value. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
cosine_kernel(X1, X2, variance, period)
Cosine kernel.
.. math:: k(x, x') = \sigma^2 \cos\!\left( \frac{2 \pi |x - x'|}{p} \right)
Useful as a simple periodic building block alongside
:func:`periodic_kernel`; unlike the MacKay form, this one uses the plain
cosine of the distance and can go negative.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar signal variance. | required |
| period | Float[Array, ''] | Scalar period. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
kernel_add(K1, K2)
Elementwise sum of two already-evaluated Gram matrices.
kernel_mul(K1, K2)
Pointwise (Hadamard) product of two already-evaluated Gram matrices.
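Because both helpers act on evaluated Gram matrices, composing kernels reduces to array arithmetic. The sketch below assumes kernel_add is the elementwise sum and builds the Grams by hand rather than through pyrox. It forms an RBF + periodic sum and an RBF x periodic (locally periodic) product. Both remain positive semidefinite; for the product this follows from the Schur product theorem.

```python
import jax.numpy as jnp


def sqdist(X1, X2):
    """Pairwise squared Euclidean distances, shape (N1, N2)."""
    return jnp.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)


X = jnp.linspace(0.0, 4.0, 25)[:, None]
K_rbf = jnp.exp(-0.5 * sqdist(X, X) / 1.5**2)
K_per = jnp.exp(-2.0 * jnp.sin(jnp.pi * jnp.sqrt(sqdist(X, X)) / 1.0) ** 2 / 0.8**2)

K_sum = K_rbf + K_per      # elementwise sum of Grams
K_prod = K_rbf * K_per     # Hadamard product, as kernel_mul describes
```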
linear_kernel(X1, X2, variance, bias)
Linear kernel.
.. math:: k(x, x') = \sigma^2\, x^\top x' + b
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar variance multiplier on the dot product. | required |
| bias | Float[Array, ''] | Scalar additive bias. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
matern_kernel(X1, X2, variance, lengthscale, nu)
Matern kernel with closed-form nu in {1/2, 3/2, 5/2}.
.. math:: k(x, x') = \sigma^2\, f_\nu(r / \ell), \qquad r = |x - x'|
Only the three common half-integer orders are supported because those
admit closed-form expressions without Bessel evaluations. nu is a
static Python float (not a JAX array) so the branch specializes at
trace time.
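The docstring leaves f_ν implicit. The three closed forms are f_{1/2}(r) = e^{-r}, f_{3/2}(r) = (1 + √3 r) e^{-√3 r}, and f_{5/2}(r) = (1 + √5 r + 5r²/3) e^{-√5 r}, which matches the λ = √(2ν)/ℓ scaling used by MaternSDE. The sketch below writes them out in plain JAX. It is an illustration and not pyrox's `matern_kernel` source. The Python-level branch on nu mirrors the static-argument design described above.

```python
import jax.numpy as jnp


def matern_kernel_sketch(X1, X2, variance, lengthscale, nu):
    """Closed-form Matern Gram for nu in {0.5, 1.5, 2.5}; nu is a static Python float."""
    sq = jnp.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    # Clamping before the sqrt keeps its gradient finite on the diagonal.
    r = jnp.sqrt(jnp.maximum(sq, 1e-30)) / lengthscale
    if nu == 0.5:
        f = jnp.exp(-r)
    elif nu == 1.5:
        s = jnp.sqrt(3.0) * r
        f = (1.0 + s) * jnp.exp(-s)
    elif nu == 2.5:
        s = jnp.sqrt(5.0) * r
        f = (1.0 + s + s**2 / 3.0) * jnp.exp(-s)
    else:
        raise ValueError(f"nu must be one of 0.5, 1.5, 2.5; got {nu}")
    return variance * f


X = jnp.array([[0.0], [0.5], [2.0]])
K12 = matern_kernel_sketch(X, X, 1.0, 1.0, nu=0.5)
K32 = matern_kernel_sketch(X, X, 1.0, 1.0, nu=1.5)
```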
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar signal variance. | required |
| lengthscale | Float[Array, ''] | Scalar lengthscale. | required |
| nu | float | Smoothness parameter; must be one of 0.5, 1.5, 2.5. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
periodic_kernel(X1, X2, variance, lengthscale, period)
Periodic (MacKay) kernel.
.. math:: k(x, x') = \sigma^2 \exp\!\left( -\frac{2 \sin^2(\pi |x - x'| / p)}{\ell^2} \right)
For multi-dimensional inputs the argument uses the Euclidean distance, matching the common GPML convention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar signal variance. | required |
| lengthscale | Float[Array, ''] | Scalar lengthscale. | required |
| period | Float[Array, ''] | Scalar period. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
polynomial_kernel(X1, X2, variance, bias, degree)
Polynomial kernel.
.. math:: k(x, x') = \sigma^2 \bigl(x^\top x' + b\bigr)^d
:func:`linear_kernel` is not exactly the degree == 1 case: at degree 1 this
kernel is :math:`\sigma^2 x^\top x' + \sigma^2 b`, whereas :func:`linear_kernel`
adds b unscaled. degree is a static Python int so the kernel specializes
at trace time.
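The bias bookkeeping is easy to confirm numerically. In the sketch below, the two functions are plain JAX reimplementations written from the formulas and are not pyrox source. The degree-1 polynomial Gram equals the linear Gram only after the linear bias is set to σ² b.

```python
import jax.numpy as jnp


def linear_kernel_sketch(X1, X2, variance, bias):
    """sigma^2 x^T x' + b"""
    return variance * (X1 @ X2.T) + bias


def polynomial_kernel_sketch(X1, X2, variance, bias, degree):
    """sigma^2 (x^T x' + b)^degree"""
    return variance * (X1 @ X2.T + bias) ** degree


X = jnp.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
variance, bias = 2.0, 0.7
K_poly1 = polynomial_kernel_sketch(X, X, variance, bias, degree=1)
K_lin = linear_kernel_sketch(X, X, variance, variance * bias)  # bias rescaled by sigma^2
```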
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar multiplier. | required |
| bias | Float[Array, ''] | Scalar additive bias inside the power. | required |
| degree | int | Positive integer polynomial degree. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
rational_quadratic_kernel(X1, X2, variance, lengthscale, alpha)
Rational quadratic kernel.
.. math:: k(x, x') = \sigma^2 \left( 1 + \frac{|x - x'|^2}{2\alpha \ell^2} \right)^{-\alpha}
Scale mixture of RBF kernels: the limit alpha -> infinity recovers the
RBF, while small alpha yields heavier-tailed correlations.
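Both limits can be observed numerically. The sketch below is a plain JAX illustration and not pyrox source. It evaluates (1 + z/α)^{-α} as exp(-α log1p(z/α)); this is our choice to stay accurate at large α in float32. At α = 10⁴ the result is close to the RBF, and at α = 0.5 it lies pointwise above the RBF, which is the heavier tail.

```python
import jax.numpy as jnp


def rq_kernel_sketch(X1, X2, variance, lengthscale, alpha):
    """Rational quadratic Gram via exp(-alpha * log1p(z / alpha))."""
    sq = jnp.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-alpha * jnp.log1p(sq / (2.0 * alpha * lengthscale**2)))


def rbf_kernel_sketch(X1, X2, variance, lengthscale):
    """RBF Gram for comparison."""
    sq = jnp.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq / lengthscale**2)


X = jnp.linspace(0.0, 3.0, 20)[:, None]
K_rbf = rbf_kernel_sketch(X, X, 1.0, 1.0)
K_rq_large = rq_kernel_sketch(X, X, 1.0, 1.0, alpha=1e4)   # close to the RBF
K_rq_small = rq_kernel_sketch(X, X, 1.0, 1.0, alpha=0.5)   # heavier tails
```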
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar signal variance. | required |
| lengthscale | Float[Array, ''] | Scalar lengthscale. | required |
| alpha | Float[Array, ''] | Scalar shape parameter; must be positive. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
rbf_kernel(X1, X2, variance, lengthscale)
Radial basis function (squared exponential) kernel.
.. math:: k(x, x') = \sigma^2 \exp\!\left(-\frac{|x - x'|^2}{2\ell^2}\right)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar signal variance. | required |
| lengthscale | Float[Array, ''] | Scalar lengthscale. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |
Source code in src/pyrox/gp/_src/kernels.py
white_kernel(X1, X2, variance)
White-noise kernel.
.. math:: k(x, x') = \sigma^2 \,\delta(x, x')
Nonzero only where X1[i] exactly matches X2[j] across all feature
dimensions. When evaluated at X1 == X2 this yields sigma^2 * I.
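The exact-match semantics have two consequences. Duplicated input rows also pick up off-diagonal σ² entries, and any perturbation of the inputs sends the kernel to zero. The sketch below is a plain JAX reimplementation written from this description, not pyrox's source. It shows both the σ² I case and the all-zero case.

```python
import jax.numpy as jnp


def white_kernel_sketch(X1, X2, variance):
    """sigma^2 where rows of X1 and X2 match exactly in every feature, else 0."""
    same = jnp.all(X1[:, None, :] == X2[None, :, :], axis=-1)
    return variance * same.astype(X1.dtype)


X = jnp.array([[0.0, 1.0], [0.5, 1.0], [2.0, -1.0]])
K_same = white_kernel_sketch(X, X, 0.1)           # 0.1 * I for distinct rows
K_shift = white_kernel_sketch(X, X + 1e-3, 0.1)   # no exact matches: all zeros
```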
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X1 | Float[Array, 'N1 D'] | | required |
| X2 | Float[Array, 'N2 D'] | | required |
| variance | Float[Array, ''] | Scalar noise variance. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'N1 N2'] | |