Skip to content

GP API

Wave 2 ships the dense-GP foundation: kernel math functions, concrete Parameterized kernel classes, abstract component protocols, and the model-facing entry points (GPPrior, ConditionedGP, gp_factor, gp_sample). Scalable matrix construction and solver strategies (numerically stable assembly, implicit operators, batched matvec, Cholesky / CG / BBMM / LSMR / SLQ) live in gaussx.

Split with gaussx

pyrox owns the kernel function side — closed-form math primitives readable in a dozen lines — plus the NumPyro-aware model shell (GPPrior, gp_factor, gp_sample). gaussx owns every piece of linear algebra: stable matrix construction, solver strategies, and the underlying MultivariateNormal distribution. The model entry points accept any gaussx.AbstractSolverStrategy (default gaussx.DenseSolver()).

Model entry points

import jax.numpy as jnp
import numpyro
from pyrox.gp import GPPrior, RBF, gp_factor, gp_sample

def regression_model(X, y):
    """GP regression with the Gaussian likelihood collapsed analytically."""
    prior = GPPrior(kernel=RBF(), X=X)
    # gp_factor adds the log marginal likelihood to the NumPyro trace.
    gp_factor("obs", prior, y, noise_var=jnp.array(0.05))


def latent_model(X):
    """GP with an explicit latent function, for non-conjugate likelihoods."""
    prior = GPPrior(kernel=RBF(), X=X)
    f = gp_sample("f", prior)
    # Attach any likelihood to f here, e.g. Bernoulli or Poisson.

Swap the solver strategy at construction time:

from gaussx import CGSolver, ComposedSolver, DenseLogdet, DenseSolver
# Conjugate-gradient solves in place of the default DenseSolver():
prior = GPPrior(kernel=RBF(), X=X, solver=CGSolver())
# Or compose — CG for solve, dense Cholesky for logdet:
prior = GPPrior(
    kernel=RBF(), X=X,
    solver=ComposedSolver(solve_strategy=CGSolver(), logdet_strategy=DenseLogdet()),
)

pyrox.gp.GPPrior

Bases: Module

Finite-dimensional GP prior over a fixed training input set.

Holds a kernel, training inputs X, an optional mean function, an optional solver strategy, and a small diagonal jitter for numerical stability on otherwise-singular prior covariances.

Attributes:

Name Type Description
kernel Kernel

Any :class:pyrox.gp.Kernel — evaluated on X.

X Float[Array, 'N D']

Training inputs of shape (N, D).

mean_fn Callable[[Float[Array, 'N D']], Float[Array, ' N']] | None

Callable X -> (N,) or None for the zero mean.

solver AbstractSolverStrategy | None

Any gaussx.AbstractSolverStrategy. Defaults to gaussx.DenseSolver() — swap for CGSolver, BBMMSolver, ComposedSolver(solve=..., logdet=...), etc.

jitter float

Diagonal regularization added to the prior covariance for numerical stability. Not a noise model — use noise_var on :meth:condition for that.

Source code in src/pyrox/gp/_models.py
class GPPrior(eqx.Module):
    """Finite-dimensional GP prior over a fixed set of training inputs.

    Bundles a kernel, the training inputs ``X``, an optional mean
    function, an optional gaussx solver strategy, and a small diagonal
    jitter that keeps otherwise-singular prior covariances numerically
    well behaved.

    Attributes:
        kernel: Any :class:`pyrox.gp.Kernel`; evaluated on ``X``.
        X: Training inputs, shape ``(N, D)``.
        mean_fn: Callable mapping ``X`` to an ``(N,)`` mean vector, or
            ``None`` for the zero mean.
        solver: A ``gaussx.AbstractSolverStrategy``; ``None`` selects
            ``gaussx.DenseSolver()``. Alternatives include ``CGSolver``,
            ``BBMMSolver``, and ``ComposedSolver(solve=..., logdet=...)``.
        jitter: Diagonal regularization folded into the prior covariance
            purely for numerical stability — observation noise belongs
            in ``noise_var`` on :meth:`condition` instead.
    """

    kernel: Kernel
    X: Float[Array, "N D"]
    mean_fn: Callable[[Float[Array, "N D"]], Float[Array, " N"]] | None = None
    solver: AbstractSolverStrategy | None = None
    jitter: float = 1e-6

    def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Evaluate the mean function at ``X``; zero when ``mean_fn`` is unset."""
        if self.mean_fn is not None:
            return self.mean_fn(X)
        return jnp.zeros(X.shape[0], dtype=X.dtype)

    def _prior_operator(self) -> lx.AbstractLinearOperator:
        # K(X, X) + jitter * I, wrapped as a PSD linear operator.
        gram = self.kernel(self.X, self.X)
        eye = jnp.eye(gram.shape[0], dtype=gram.dtype)
        return _psd_operator(gram + self.jitter * eye)

    def _noisy_operator(self, noise_var: Float[Array, ""]) -> lx.AbstractLinearOperator:
        # K(X, X) + (jitter + noise_var) * I for Gaussian-likelihood solves.
        gram = self.kernel(self.X, self.X)
        eye = jnp.eye(gram.shape[0], dtype=gram.dtype)
        return _psd_operator(gram + (self.jitter + noise_var) * eye)

    def _resolved_solver(self) -> AbstractSolverStrategy:
        # The dense solver is the fallback when no strategy was configured.
        if self.solver is not None:
            return self.solver
        return DenseSolver()  # ty: ignore[invalid-return-type]

    def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
        r"""Marginal log-density of ``f`` under the GP prior.

        Evaluates :math:`\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I)`
        through :func:`gaussx.log_marginal_likelihood`, so whatever
        solver strategy this prior carries is honored.
        """
        return log_marginal_likelihood(
            self.mean(self.X),
            self._prior_operator(),
            f,
            solver=self._resolved_solver(),
        )

    def sample(self, key: Array) -> Float[Array, " N"]:
        r"""Draw ``f \sim \mathcal{N}(\mu(X), K + \text{jitter}\,I)``.

        Builds a :class:`gaussx.MultivariateNormal` over the prior with
        the configured :attr:`solver`. This is the non-NumPyro sibling of
        :func:`gp_sample` — handy for tests, diagnostics, and
        prior-sample initialization with no sample site registered.
        """
        operator = self._prior_operator()
        location = self.mean(self.X)
        mvn = MultivariateNormal(location, operator, solver=self._resolved_solver())
        return mvn.sample(key)

    def condition(
        self,
        y: Float[Array, " N"],
        noise_var: Float[Array, ""],
    ) -> ConditionedGP:
        """Condition on Gaussian-likelihood observations ``y``.

        Precomputes
        ``alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X))`` and
        stashes it in the returned :class:`ConditionedGP`. The prior's
        ``jitter`` rides along with ``noise_var`` in the conditioned
        operator and solve, so every downstream predict / sample call
        sees the same regularized covariance.
        """
        noisy = self._noisy_operator(noise_var)
        centered = y - self.mean(self.X)
        alpha_cache = build_prediction_cache(
            noisy, centered, solver=self._resolved_solver()
        )
        return ConditionedGP(  # ty: ignore[invalid-return-type]
            prior=self,
            y=y,
            noise_var=noise_var,
            cache=alpha_cache,
            operator=noisy,
        )

condition(y, noise_var)

Condition on Gaussian-likelihood observations y.

Precomputes alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X)) and caches it in the returned :class:ConditionedGP. The same jitter regularization configured on this prior is included alongside noise_var in the conditioned operator and solve, so every downstream predict / sample call sees the regularized covariance.

Source code in src/pyrox/gp/_models.py
def condition(
    self,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> ConditionedGP:
    """Condition on Gaussian-likelihood observations ``y``.

    Precomputes
    ``alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X))`` and
    caches it in the returned :class:`ConditionedGP`. The same
    ``jitter`` regularization configured on this prior is included
    alongside ``noise_var`` in the conditioned operator and solve, so
    every downstream predict / sample call sees the regularized
    covariance.
    """
    # K + (jitter + noise_var) * I as a PSD linear operator.
    operator = self._noisy_operator(noise_var)
    residual = y - self.mean(self.X)
    # One training solve (alpha) performed up front; predictions at any
    # number of test sets reuse it via the cache.
    cache = build_prediction_cache(
        operator, residual, solver=self._resolved_solver()
    )
    return ConditionedGP(  # ty: ignore[invalid-return-type]
        prior=self,
        y=y,
        noise_var=noise_var,
        cache=cache,
        operator=operator,
    )

log_prob(f)

Marginal log-density of f under the GP prior.

Computes :math:\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I) using :func:gaussx.log_marginal_likelihood, so any solver strategy on this prior applies.

Source code in src/pyrox/gp/_models.py
def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
    r"""Marginal log-density of ``f`` under the GP prior.

    Computes :math:`\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I)`
    using :func:`gaussx.log_marginal_likelihood`, so any solver strategy
    on this prior applies.
    """
    # Mean, jittered covariance operator, and solver all come from self,
    # so this density matches the distribution sample() draws from.
    return log_marginal_likelihood(
        self.mean(self.X),
        self._prior_operator(),
        f,
        solver=self._resolved_solver(),
    )

mean(X)

Evaluate the mean function at X; zero by default.

Source code in src/pyrox/gp/_models.py
def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    """Evaluate the mean function at ``X``; zero by default."""
    if self.mean_fn is None:
        # No mean function configured: zero vector in X's dtype.
        return jnp.zeros(X.shape[0], dtype=X.dtype)
    return self.mean_fn(X)

sample(key)

Draw f \sim p(f) = \mathcal{N}(\mu(X), K + \text{jitter}\,I).

Wraps the prior in a :class:gaussx.MultivariateNormal with the configured :attr:solver. This is the non-NumPyro analogue of :func:gp_sample — useful for tests, diagnostics, and prior-sample initialization without registering a sample site.

Source code in src/pyrox/gp/_models.py
def sample(self, key: Array) -> Float[Array, " N"]:
    r"""Draw ``f \sim p(f) = \mathcal{N}(\mu(X), K + \text{jitter}\,I)``.

    Wraps the prior in a :class:`gaussx.MultivariateNormal` with
    the configured :attr:`solver`. This is the non-NumPyro analogue
    of :func:`gp_sample` — useful for tests, diagnostics, and
    prior-sample initialization without registering a sample site.
    """
    # Prior covariance (K + jitter * I) as a linear operator.
    op = self._prior_operator()
    loc = self.mean(self.X)
    # The configured solver strategy drives the MVN's internal solves.
    mvn = MultivariateNormal(loc, op, solver=self._resolved_solver())
    return mvn.sample(key)

pyrox.gp.ConditionedGP

Bases: Module

GP conditioned on Gaussian-likelihood training observations.

Holds the precomputed training solve alpha (via :class:gaussx.PredictionCache) and the noisy covariance operator so predictions at multiple test sets reuse the training solve.

Source code in src/pyrox/gp/_models.py
class ConditionedGP(eqx.Module):
    """GP conditioned on Gaussian-likelihood training observations.

    Carries the precomputed training solve ``alpha`` (a
    :class:`gaussx.PredictionCache`) together with the noisy covariance
    operator, so predicting at any number of test sets reuses the single
    training-time solve.
    """

    prior: GPPrior
    y: Float[Array, " N"]
    noise_var: Float[Array, ""]
    cache: PredictionCache
    operator: lx.AbstractLinearOperator

    def predict_mean(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
        r""":math:`\mu_* = \mu(X_*) + K_{*f}\,\alpha`."""
        with _kernel_context(self.prior.kernel):
            cross = self.prior.kernel(X_star, self.prior.X)
        # The bare ``predict_mean`` call resolves to the imported helper,
        # not this method.
        return self.prior.mean(X_star) + predict_mean(self.cache, cross)

    def predict_var(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
        r"""Diagonal predictive variance at ``X_*``.

        .. math::
            \sigma^2_{*,i} = k(x_{*,i}, x_{*,i})
                - K_{*f}[i,:] \cdot (K + \sigma^2 I)^{-1} K_{f*}[:,i]

        The cross covariance and the prior diagonal are evaluated inside
        one shared kernel context, so Pattern B / C kernels with prior'd
        hyperparameters register their NumPyro sites once and reuse them
        across both kernel calls (and the cached training solve).
        """
        with _kernel_context(self.prior.kernel):
            cross = self.prior.kernel(X_star, self.prior.X)
            prior_diag = self.prior.kernel.diag(X_star)
        solver = self.prior._resolved_solver()
        return predict_variance(cross, prior_diag, self.operator, solver=solver)

    def predict(
        self, X_star: Float[Array, "M D"]
    ) -> tuple[Float[Array, " M"], Float[Array, " M"]]:
        """Return ``(mean, variance)`` at ``X_*`` as a tuple.

        Both kernel evaluations run under a single shared kernel
        context; see :meth:`predict_var`.
        """
        with _kernel_context(self.prior.kernel):
            mu = self.predict_mean(X_star)
            var = self.predict_var(X_star)
        return mu, var

    def sample(
        self,
        key: Array,
        X_star: Float[Array, "M D"],
        n_samples: int = 1,
    ) -> Float[Array, "S M"]:
        """Sample from the diagonal predictive ``N(mean, diag(var))``.

        Samples are independent across test points; correlated joint
        draws from the full predictive covariance are outside the Wave 2
        dense surface. For those, build the full predictive covariance
        explicitly and draw from :class:`gaussx.MultivariateNormal`.
        """
        with _kernel_context(self.prior.kernel):
            mu = self.predict_mean(X_star)
            var = self.predict_var(X_star)
        # Clamp tiny negative variances from floating-point cancellation.
        scale = jnp.sqrt(jnp.maximum(var, 0.0))
        noise = jax.random.normal(key, (n_samples, X_star.shape[0]), dtype=mu.dtype)
        # (M,) scale broadcasts across the (S, M) unit normals.
        return mu + scale * noise

predict(X_star)

Return (mean, variance) at X_* as a tuple.

Both kernel evaluations share a single kernel context; see :meth:predict_var.

Source code in src/pyrox/gp/_models.py
def predict(
    self, X_star: Float[Array, "M D"]
) -> tuple[Float[Array, " M"], Float[Array, " M"]]:
    """Return ``(mean, variance)`` at ``X_*`` as a tuple.

    Both kernel evaluations share a single kernel context; see
    :meth:`predict_var`.
    """
    # One outer context wraps both calls so kernel hyperparameter sites
    # are registered a single time across mean and variance.
    with _kernel_context(self.prior.kernel):
        return self.predict_mean(X_star), self.predict_var(X_star)

predict_mean(X_star)

:math:\mu_* = \mu(X_*) + K_{*f}\,\alpha.

Source code in src/pyrox/gp/_models.py
def predict_mean(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
    r""":math:`\mu_* = \mu(X_*) + K_{*f}\,\alpha`."""
    with _kernel_context(self.prior.kernel):
        # Cross covariance K_{*f} between test and training inputs.
        K_cross = self.prior.kernel(X_star, self.prior.X)
    # NOTE: the bare ``predict_mean`` call resolves to the imported
    # module-level helper, not this method (which shadows it here).
    return self.prior.mean(X_star) + predict_mean(self.cache, K_cross)

predict_var(X_star)

Diagonal predictive variance at X_*.

.. math:: \sigma^2_{*,i} = k(x_{*,i}, x_{*,i}) - K_{*f}[i,:] \cdot (K + \sigma^2 I)^{-1} K_{f*}[:,i]

K_cross and K_diag are computed under one shared kernel context so Pattern B / C kernels with prior'd hyperparameters register their NumPyro sites once and reuse them across both kernel calls (and the cached training solve).

Source code in src/pyrox/gp/_models.py
def predict_var(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
    r"""Diagonal predictive variance at ``X_*``.

    .. math::
        \sigma^2_{*,i} = k(x_{*,i}, x_{*,i})
            - K_{*f}[i,:] \cdot (K + \sigma^2 I)^{-1} K_{f*}[:,i]

    ``K_cross`` and ``K_diag`` are computed under one shared kernel
    context so Pattern B / C kernels with prior'd hyperparameters
    register their NumPyro sites once and reuse them across both
    kernel calls (and the cached training solve).
    """
    with _kernel_context(self.prior.kernel):
        # Both evaluations share the context entered above.
        K_cross = self.prior.kernel(X_star, self.prior.X)
        K_diag = self.prior.kernel.diag(X_star)
    # Prior diagonal minus the variance explained by the training data;
    # the solve against ``self.operator`` uses the prior's solver.
    return predict_variance(
        K_cross,
        K_diag,
        self.operator,
        solver=self.prior._resolved_solver(),
    )

sample(key, X_star, n_samples=1)

Sample from the diagonal predictive N(mean, diag(var)).

Returns samples independently per test point; correlated joint samples from the full predictive covariance are not covered by the Wave 2 dense surface. For correlated samples, build the full predictive covariance explicitly and draw from :class:gaussx.MultivariateNormal.

Source code in src/pyrox/gp/_models.py
def sample(
    self,
    key: Array,
    X_star: Float[Array, "M D"],
    n_samples: int = 1,
) -> Float[Array, "S M"]:
    """Sample from the diagonal predictive ``N(mean, diag(var))``.

    Returns samples independently per test point; correlated joint
    samples from the full predictive covariance are not covered by
    the Wave 2 dense surface. For correlated samples, build the full
    predictive covariance explicitly and draw from
    :class:`gaussx.MultivariateNormal`.
    """
    with _kernel_context(self.prior.kernel):
        mean = self.predict_mean(X_star)
        var = self.predict_var(X_star)
    # Clamp tiny negative variances from floating-point cancellation.
    std = jnp.sqrt(jnp.clip(var, min=0.0))
    eps = jax.random.normal(key, (n_samples, X_star.shape[0]), dtype=mean.dtype)
    # Per-point scale applied to the (S, M) unit normals, then shifted.
    return einsum(std, eps, "m, s m -> s m") + mean

pyrox.gp.gp_factor(name, prior, y, noise_var)

Register the collapsed GP log marginal likelihood with NumPyro.

Adds log p(y | X, theta) = log N(y | mu, K + (jitter + sigma^2) I) to the NumPyro trace as numpyro.factor(name, ...). The prior's jitter is included in addition to the observation noise variance so the covariance matches what :meth:GPPrior.condition builds. Use this inside a NumPyro model when the likelihood is Gaussian and you want the latent function marginalized analytically.

Source code in src/pyrox/gp/_models.py
def gp_factor(
    name: str,
    prior: GPPrior,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> None:
    """Register the collapsed GP log marginal likelihood with NumPyro.

    Contributes
    ``log p(y | X, theta) = log N(y | mu, K + (jitter + sigma^2) I)``
    to the trace via ``numpyro.factor(name, ...)``. The prior's
    ``jitter`` is added on top of the observation noise variance, so
    this covariance is exactly the one :meth:`GPPrior.condition` builds.
    Intended for NumPyro models with a Gaussian likelihood where the
    latent function should be marginalized analytically.
    """
    loc = prior.mean(prior.X)
    noisy = prior._noisy_operator(noise_var)
    logp = log_marginal_likelihood(loc, noisy, y, solver=prior._resolved_solver())
    numpyro.factor(name, logp)

pyrox.gp.gp_sample(name, prior, *, whitened=False, guide=None)

Sample a latent function f at the prior's training inputs.

Three mutually exclusive modes:

  • whitened=False, guide=None (default) — register a single numpyro.sample(name, MVN(mu, K + jitter I)) site. The latent function is sampled directly from the prior.
  • whitened=True, guide=None — register a unit-normal latent site f"{name}_u" with shape (N,) and return the deterministic value f = mu(X) + L u where L is the Cholesky factor of K + jitter I. This reparameterization is the standard fix for mean-field SVI on GP-correlated latents (Murray & Adams, 2010): a NumPyro auto-guide such as :class:numpyro.infer.autoguide.AutoNormal then approximates the well-conditioned isotropic posterior over u instead of the ill-conditioned correlated posterior over f.
  • guide provided — delegate to guide.register(name, prior). Concrete variational guides (Wave 3) own their own parameterization, so combining whitened=True with guide is rejected.

Use this inside a NumPyro model for non-conjugate likelihoods, where the latent function cannot be marginalized analytically.

Source code in src/pyrox/gp/_models.py
def gp_sample(
    name: str,
    prior: GPPrior,
    *,
    whitened: bool = False,
    guide: object | None = None,
) -> Float[Array, " N"]:
    r"""Sample a latent function ``f`` at the prior's training inputs.

    Exactly one of three mutually exclusive modes applies:

    * default (``whitened=False``, ``guide=None``) — a single
      ``numpyro.sample(name, MVN(mu, K + jitter I))`` site draws the
      latent function straight from the prior.
    * ``whitened=True``, ``guide=None`` — registers a unit-normal latent
      site ``f"{name}_u"`` of shape ``(N,)`` and returns the
      deterministic ``f = mu(X) + L u``, with ``L`` the Cholesky factor
      of ``K + jitter I``. This reparameterization is the standard fix
      for mean-field SVI on GP-correlated latents (Murray & Adams,
      2010): an auto-guide such as
      :class:`numpyro.infer.autoguide.AutoNormal` then fits the
      well-conditioned isotropic posterior over ``u`` instead of the
      ill-conditioned correlated posterior over ``f``.
    * ``guide`` provided — delegates to ``guide.register(name, prior)``.
      Concrete variational guides (Wave 3) own their parameterization,
      so ``whitened=True`` combined with ``guide`` is rejected.

    Use inside a NumPyro model for non-conjugate likelihoods, where the
    latent function cannot be marginalized analytically.
    """
    if guide is not None:
        if whitened:
            raise ValueError(
                "gp_sample: cannot combine `whitened=True` with `guide=...`. "
                "Provide one or the other; concrete guides own their own "
                "parameterization."
            )
        return guide.register(name, prior)  # type: ignore[attr-defined]  # ty: ignore[unresolved-attribute]

    if not whitened:
        return numpyro.sample(  # ty: ignore[invalid-return-type]
            name,
            MultivariateNormal(
                prior.mean(prior.X),
                prior._prior_operator(),
                solver=prior._resolved_solver(),
            ),
        )

    # Whitened path: sample u ~ N(0, I)^N and deterministically map to f.
    chol = cholesky(prior._prior_operator())
    n = prior.X.shape[0]
    dtype = prior.X.dtype
    unit_normal = dist.Normal(
        jnp.zeros(n, dtype=dtype), jnp.ones((), dtype=dtype)
    ).to_event(1)
    u = numpyro.sample(f"{name}_u", unit_normal)
    f = prior.mean(prior.X) + unwhiten(jnp.asarray(u), chol)
    return numpyro.deterministic(name, f)  # ty: ignore[invalid-return-type]

Concrete kernels

Each Parameterized kernel registers its hyperparameters with positivity constraints where appropriate. Attach priors with set_prior, autoguides with autoguide, and flip set_mode("model" | "guide").

pyrox.gp.RBF

Bases: _ParameterizedKernel

Radial basis function (squared exponential) kernel.

Source code in src/pyrox/gp/_kernels.py
class RBF(_ParameterizedKernel):
    """Radial basis function (squared exponential) kernel."""

    pyrox_name: str = "RBF"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0

    def setup(self) -> None:
        # Both hyperparameters are constrained strictly positive.
        for pname, init in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
        ):
            self.register_param(
                pname, jnp.asarray(init), constraint=dist.constraints.positive
            )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        return _k.rbf_kernel(X1, X2, variance, lengthscale)

pyrox.gp.Matern

Bases: _ParameterizedKernel

Matern kernel with nu in {0.5, 1.5, 2.5}.

nu is a static class attribute — it selects a code path in the underlying math primitive and is not a trainable parameter.

Source code in src/pyrox/gp/_kernels.py
class Matern(_ParameterizedKernel):
    """Matern kernel with ``nu in {0.5, 1.5, 2.5}``.

    ``nu`` is static class state — it selects a code path inside the
    underlying math primitive and is never trained.
    """

    pyrox_name: str = "Matern"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0
    nu: float = 2.5

    def setup(self) -> None:
        # Both trainable hyperparameters are constrained strictly positive.
        for pname, init in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
        ):
            self.register_param(
                pname, jnp.asarray(init), constraint=dist.constraints.positive
            )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        return _k.matern_kernel(X1, X2, variance, lengthscale, self.nu)

pyrox.gp.Periodic

Bases: _ParameterizedKernel

Periodic (MacKay) kernel.

Source code in src/pyrox/gp/_kernels.py
class Periodic(_ParameterizedKernel):
    """Periodic (MacKay) kernel."""

    pyrox_name: str = "Periodic"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0
    init_period: float = 1.0

    def setup(self) -> None:
        # All three hyperparameters are constrained strictly positive.
        for pname, init in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
            ("period", self.init_period),
        ):
            self.register_param(
                pname, jnp.asarray(init), constraint=dist.constraints.positive
            )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        period = self.get_param("period")
        return _k.periodic_kernel(X1, X2, variance, lengthscale, period)

pyrox.gp.Linear

Bases: _ParameterizedKernel

Linear kernel sigma^2 x^T x' + bias.

bias is constrained nonnegative because k = sigma^2 X X^T + b 1 1^T is only PSD for b >= 0 (e.g. X = 0 gives eigenvalue N*b).

Source code in src/pyrox/gp/_kernels.py
class Linear(_ParameterizedKernel):
    """Linear kernel ``sigma^2 x^T x' + bias``.

    ``bias`` must be nonnegative: ``k = sigma^2 X X^T + b 1 1^T`` is PSD
    only when ``b >= 0`` (take ``X = 0`` — the rank-one all-ones term has
    eigenvalue ``N*b``).
    """

    pyrox_name: str = "Linear"
    init_variance: float = 1.0
    init_bias: float = 0.0

    def setup(self) -> None:
        # Each registration carries its own constraint.
        for pname, init, constraint in (
            ("variance", self.init_variance, dist.constraints.positive),
            ("bias", self.init_bias, dist.constraints.nonnegative),
        ):
            self.register_param(pname, jnp.asarray(init), constraint=constraint)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        return _k.linear_kernel(X1, X2, variance, bias)

    def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        # Non-stationary kernel: the diagonal depends on |X[i]|^2.
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        sq_norms = jnp.sum(X * X, axis=-1)
        return variance * sq_norms + bias

pyrox.gp.RationalQuadratic

Bases: _ParameterizedKernel

Rational quadratic kernel.

Source code in src/pyrox/gp/_kernels.py
class RationalQuadratic(_ParameterizedKernel):
    """Rational quadratic kernel."""

    pyrox_name: str = "RationalQuadratic"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0
    init_alpha: float = 1.0

    def setup(self) -> None:
        # All three hyperparameters are constrained strictly positive.
        for pname, init in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
            ("alpha", self.init_alpha),
        ):
            self.register_param(
                pname, jnp.asarray(init), constraint=dist.constraints.positive
            )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        alpha = self.get_param("alpha")
        return _k.rational_quadratic_kernel(X1, X2, variance, lengthscale, alpha)

pyrox.gp.Polynomial

Bases: _ParameterizedKernel

Polynomial kernel sigma^2 (x^T x' + bias)^degree.

degree is a static class field (it selects an integer power, not an optimization target). bias is constrained nonnegative — the degree=1 case reduces to :class:Linear and has the same PSD-requires-b>=0 failure mode.

Source code in src/pyrox/gp/_kernels.py
class Polynomial(_ParameterizedKernel):
    """Polynomial kernel ``sigma^2 (x^T x' + bias)^degree``.

    ``degree`` is a static class field — it selects an integer power and
    is not an optimization target. ``bias`` is constrained nonnegative:
    the ``degree=1`` case reduces to :class:`Linear`, which is PSD only
    for ``b >= 0``.
    """

    pyrox_name: str = "Polynomial"
    init_variance: float = 1.0
    init_bias: float = 0.0
    degree: int = 2

    def setup(self) -> None:
        # Each registration carries its own constraint.
        for pname, init, constraint in (
            ("variance", self.init_variance, dist.constraints.positive),
            ("bias", self.init_bias, dist.constraints.nonnegative),
        ):
            self.register_param(pname, jnp.asarray(init), constraint=constraint)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        return _k.polynomial_kernel(X1, X2, variance, bias, self.degree)

    def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        # Non-stationary kernel: diagonal is v * (|x_i|^2 + b)^degree.
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        sq_norms = jnp.sum(X * X, axis=-1)
        return variance * (sq_norms + bias) ** self.degree

pyrox.gp.Cosine

Bases: _ParameterizedKernel

Cosine kernel sigma^2 cos(2 pi ||x - x'|| / period).

Source code in src/pyrox/gp/_kernels.py
class Cosine(_ParameterizedKernel):
    """Cosine kernel ``sigma^2 cos(2 pi ||x - x'|| / period)``."""

    pyrox_name: str = "Cosine"
    init_variance: float = 1.0
    init_period: float = 1.0

    def setup(self) -> None:
        # Both hyperparameters are constrained strictly positive.
        for pname, init in (
            ("variance", self.init_variance),
            ("period", self.init_period),
        ):
            self.register_param(
                pname, jnp.asarray(init), constraint=dist.constraints.positive
            )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        period = self.get_param("period")
        return _k.cosine_kernel(X1, X2, variance, period)

pyrox.gp.White

Bases: _ParameterizedKernel

White-noise kernel sigma^2 delta(x, x').

Source code in src/pyrox/gp/_kernels.py
class White(_ParameterizedKernel):
    """White-noise kernel ``sigma^2 delta(x, x')``."""

    pyrox_name: str = "White"
    init_variance: float = 1.0

    def setup(self) -> None:
        # Single strictly-positive hyperparameter.
        self.register_param(
            "variance",
            jnp.asarray(self.init_variance),
            constraint=dist.constraints.positive,
        )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        return _k.white_kernel(X1, X2, variance)

pyrox.gp.Constant

Bases: _ParameterizedKernel

Constant kernel k(x, x') = sigma^2.

Source code in src/pyrox/gp/_kernels.py
class Constant(_ParameterizedKernel):
    """Constant kernel ``k(x, x') = sigma^2``."""

    pyrox_name: str = "Constant"
    init_variance: float = 1.0

    def setup(self) -> None:
        # Single strictly-positive hyperparameter.
        self.register_param(
            "variance",
            jnp.asarray(self.init_variance),
            constraint=dist.constraints.positive,
        )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        return _k.constant_kernel(X1, X2, variance)

Sparse-GP inducing features (#49)

Inter-domain inducing-feature families used to build scalable sparse GPs where the inducing-prior covariance K_uu becomes diagonal. Pass any of these to :class:SparseGPPrior via the inducing= keyword in place of a raw point matrix Z.

from pyrox.gp import RBF, FourierInducingFeatures, SparseGPPrior

kernel   = RBF(init_lengthscale=0.3, init_variance=1.0)
features = FourierInducingFeatures.init(in_features=1, num_basis_per_dim=64, L=5.0)
prior    = SparseGPPrior(kernel=kernel, inducing=features)   # K_uu is diagonal!

pyrox.gp.InducingFeatures

Bases: Protocol

Protocol for inter-domain inducing features.

Implementations expose the inducing-prior covariance K_uu and the cross-covariance k_ux(X) between data points and inducing features. Diagonal-friendly concretions return :class:lineax.DiagonalLinearOperator so the downstream solve dispatches to elementwise division.

Input shape is family-dependent. k_ux takes a batch of data points X in whatever representation the family consumes:

  • :class:FourierInducingFeatures: coordinates (N, D).
  • :class:SphericalHarmonicInducingFeatures: unit vectors (N, 3).
  • :class:LaplacianInducingFeatures: integer node indices (N,).

Each implementation validates its own expected shape and dtype.

Source code in src/pyrox/gp/_inducing.py
@runtime_checkable
class InducingFeatures(Protocol):
    """Protocol for inter-domain inducing features.

    Implementations expose the inducing-prior covariance ``K_uu`` and the
    cross-covariance ``k_ux(X)`` between data points and inducing
    features. Diagonal-friendly concretions return
    :class:`lineax.DiagonalLinearOperator` so the downstream solve dispatches
    to elementwise division.

    **Input shape is family-dependent.** ``k_ux`` takes a batch of data
    points ``X`` in whatever representation the family consumes:

    - :class:`FourierInducingFeatures`: coordinates ``(N, D)``.
    - :class:`SphericalHarmonicInducingFeatures`: unit vectors ``(N, 3)``.
    - :class:`LaplacianInducingFeatures`: integer node indices ``(N,)``.

    Each implementation validates its own expected shape and dtype.
    """

    # Total number of inducing features ``M``.
    @property
    def num_features(self) -> int: ...

    # Inducing-prior covariance ``K_uu`` (with jitter) as a lineax operator.
    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.AbstractLinearOperator: ...

    # Cross-covariance ``k_ux(X)``; expected shape of ``x`` is
    # family-dependent (see class docstring).
    def k_ux(self, x: Array, kernel: Kernel) -> Float[Array, "N M"]: ...

pyrox.gp.FourierInducingFeatures

Bases: Module

VFF — Variational Fourier inducing features on :math:[-L, L]^D.

For a stationary kernel with spectral density :math:S(\cdot), the basis :math:\{\phi_j\} of Laplacian eigenfunctions on the box gives

.. math::

K_{uu} = \mathrm{diag}\!\big(S(\sqrt{\lambda_j})\big),
\qquad
K_{ux}(x)_j = S(\sqrt{\lambda_j})\,\phi_j(x).

With this convention :math:K_{ux} K_{uu}^{-1} = \phi_j(x), so the SVGP predictive mean reduces to a basis evaluation. K_{uu} is returned as a :class:lineax.DiagonalLinearOperator to preserve the O(M) solve dispatch end-to-end.

Attributes:

Name Type Description
in_features int

Input dimension :math:D.

num_basis_per_dim tuple[int, ...]

Per-axis number of 1D eigenfunctions; total count is prod(num_basis_per_dim).

L tuple[float, ...]

Per-axis box half-width.

Source code in src/pyrox/gp/_inducing.py
class FourierInducingFeatures(eqx.Module):
    r"""Variational Fourier features (VFF) on the box :math:`[-L, L]^D`.

    Given a stationary kernel with spectral density :math:`S(\cdot)` and
    the Laplacian eigenfunctions :math:`\{\phi_j\}` of the box,

    .. math::

        K_{uu} = \mathrm{diag}\!\big(S(\sqrt{\lambda_j})\big),
        \qquad
        K_{ux}(x)_j = S(\sqrt{\lambda_j})\,\phi_j(x).

    Under this convention :math:`K_{ux} K_{uu}^{-1} = \phi_j(x)`, so the
    SVGP predictive mean is just a basis evaluation. ``K_uu`` is returned
    as a :class:`lineax.DiagonalLinearOperator`, keeping the downstream
    solve dispatch at O(M) end-to-end.

    Attributes:
        in_features: Input dimension :math:`D`.
        num_basis_per_dim: Per-axis number of 1D eigenfunctions; total
            count is ``prod(num_basis_per_dim)``.
        L: Per-axis box half-width.
    """

    in_features: int = eqx.field(static=True)
    num_basis_per_dim: tuple[int, ...] = eqx.field(static=True)
    L: tuple[float, ...] = eqx.field(static=True)

    @classmethod
    def init(
        cls,
        in_features: int,
        num_basis_per_dim: int | tuple[int, ...],
        L: float | tuple[float, ...],
    ) -> FourierInducingFeatures:
        counts = _to_tuple(num_basis_per_dim, in_features, "num_basis_per_dim")
        widths = _to_tuple(L, in_features, "L")
        if any(width <= 0 for width in widths):
            raise ValueError(f"L must be all positive; got {widths}.")
        if any(count < 1 for count in counts):
            raise ValueError(f"num_basis_per_dim must be all >= 1; got {counts}.")
        return cls(
            in_features=in_features,
            num_basis_per_dim=counts,
            L=tuple(float(width) for width in widths),
        )

    @property
    def num_features(self) -> int:
        # Tensor-product basis: total features = product of per-axis counts.
        total = 1
        for count in self.num_basis_per_dim:
            total *= count
        return total

    def _check_stationary(self, kernel: Kernel) -> None:
        if _is_stationary(kernel):
            return
        raise ValueError(
            f"FourierInducingFeatures requires a stationary kernel with a "
            f"registered spectral density (RBF or Matern); got "
            f"{type(kernel).__name__}."
        )

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.DiagonalLinearOperator:
        """Diagonal :math:`K_{uu}` — entries ``S(sqrt(lambda_j))`` plus jitter."""
        self._check_stationary(kernel)
        with _kernel_context(kernel):
            eigvals = fourier_eigenvalues(
                self.num_basis_per_dim, self.L, self.in_features
            )
            density = spectral_density(kernel, eigvals, D=self.in_features)
        return _diagonal_with_jitter(density, jitter)

    def k_ux(self, x: Float[Array, "N D"], kernel: Kernel) -> Float[Array, "N M"]:
        """Cross-covariance entries :math:`S(\\sqrt{\\lambda_j})\\,\\phi_j(x)`."""
        self._check_stationary(kernel)
        if x.ndim != 2 or x.shape[-1] != self.in_features:
            raise ValueError(f"x must be (N, {self.in_features}); got shape {x.shape}.")
        with _kernel_context(kernel):
            basis, eigvals = fourier_basis(x, self.num_basis_per_dim, self.L)
            density = spectral_density(kernel, eigvals, D=self.in_features)
        # Scale each basis column by its spectral weight.
        return basis * density[None, :]

K_uu(kernel, *, jitter=1e-06)

Diagonal :math:K_{uu} — entries S(sqrt(lambda_j)) plus jitter.

Source code in src/pyrox/gp/_inducing.py
def K_uu(
    self, kernel: Kernel, *, jitter: float = 1e-6
) -> lx.DiagonalLinearOperator:
    """Diagonal :math:`K_{uu}` — entries ``S(sqrt(lambda_j))`` plus jitter."""
    # VFF needs a stationary kernel with a registered spectral density.
    self._check_stationary(kernel)
    with _kernel_context(kernel):
        # Box-Laplacian eigenvalues, then the kernel's spectral density
        # evaluated over them.
        lam = fourier_eigenvalues(self.num_basis_per_dim, self.L, self.in_features)
        S = spectral_density(kernel, lam, D=self.in_features)
    # Jitter is added to the diagonal vector, keeping the operator diagonal.
    return _diagonal_with_jitter(S, jitter)

k_ux(x, kernel)

Cross-covariance entries :math:S(\sqrt{\lambda_j})\,\phi_j(x).

Source code in src/pyrox/gp/_inducing.py
def k_ux(self, x: Float[Array, "N D"], kernel: Kernel) -> Float[Array, "N M"]:
    """Cross-covariance entries :math:`S(\\sqrt{\\lambda_j})\\,\\phi_j(x)`."""
    # VFF needs a stationary kernel with a registered spectral density.
    self._check_stationary(kernel)
    # Family-specific input check: this family consumes coordinates (N, D).
    if x.ndim != 2 or x.shape[-1] != self.in_features:
        raise ValueError(f"x must be (N, {self.in_features}); got shape {x.shape}.")
    with _kernel_context(kernel):
        # Basis evaluations phi_j(x) with their eigenvalues, then the density.
        Phi, lam = fourier_basis(x, self.num_basis_per_dim, self.L)
        S = spectral_density(kernel, lam, D=self.in_features)
    # Scale each basis column by its spectral weight.
    return Phi * S[None, :]

pyrox.gp.SphericalHarmonicInducingFeatures

Bases: Module

VISH — inducing harmonics on :math:S^2 (Dutordoir et al. 2020).

For any zonal kernel :math:k(x, x') = \kappa(x \cdot x') on the unit 2-sphere, the Funk-Hecke theorem gives a diagonal :math:K_{uu} whose eigenvalues are the kernel's Funk-Hecke coefficients :math:a_l. The cross-covariance is :math:a_l\,Y_{lm}(x).

Funk-Hecke coefficients are computed by Gauss-Legendre quadrature (arbitrary kernels supported, no closed form required). For kernels that have a closed-form Funk-Hecke series (RBF on S² via Bessel functions etc.), the numerical and analytic answers should agree to the quadrature tolerance.

Attributes:

Name Type Description
l_max int

Maximum harmonic degree, inclusive.

num_quadrature int

Gauss-Legendre nodes for the Funk-Hecke integral.

Source code in src/pyrox/gp/_inducing.py
class SphericalHarmonicInducingFeatures(eqx.Module):
    r"""VISH inducing features — spherical harmonics on :math:`S^2`
    (Dutordoir et al. 2020).

    Any zonal kernel :math:`k(x, x') = \kappa(x \cdot x')` on the unit
    2-sphere has, by the Funk-Hecke theorem, a diagonal :math:`K_{uu}`
    whose eigenvalues are the kernel's Funk-Hecke coefficients
    :math:`a_l`; the cross-covariance is :math:`a_l\,Y_{lm}(x)`.

    Coefficients are obtained by Gauss-Legendre quadrature, so no closed
    form is required of the kernel. Where a closed-form Funk-Hecke
    series exists (RBF on S² via Bessel functions etc.), the numerical
    and analytic answers should agree to the quadrature tolerance.

    Attributes:
        l_max: Maximum harmonic degree, inclusive.
        num_quadrature: Gauss-Legendre nodes for the Funk-Hecke integral.
    """

    l_max: int = eqx.field(static=True)
    num_quadrature: int = eqx.field(static=True, default=256)

    @classmethod
    def init(
        cls, l_max: int, *, num_quadrature: int = 256
    ) -> SphericalHarmonicInducingFeatures:
        if l_max < 0:
            raise ValueError(f"l_max must be >= 0; got {l_max}.")
        if num_quadrature < 1:
            raise ValueError(f"num_quadrature must be >= 1; got {num_quadrature}.")
        return cls(l_max=l_max, num_quadrature=num_quadrature)

    @property
    def num_features(self) -> int:
        # Degrees 0..l_max contribute sum_l (2l + 1) = (l_max + 1)^2 harmonics.
        degrees = self.l_max + 1
        return degrees * degrees

    def _per_feature_coeffs(self, kernel: Kernel) -> Float[Array, " M"]:
        a = funk_hecke_coefficients(
            kernel, self.l_max, num_quadrature=self.num_quadrature
        )
        # Degree l carries 2l + 1 harmonics, all sharing the coefficient a_l.
        blocks = []
        for degree in range(self.l_max + 1):
            blocks.append(jnp.full((2 * degree + 1,), a[degree]))
        return jnp.concatenate(blocks)

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.DiagonalLinearOperator:
        """Diagonal :math:`K_{uu}` — Funk-Hecke coefficients per harmonic."""
        entries = self._per_feature_coeffs(kernel)
        return _diagonal_with_jitter(entries, jitter)

    def k_ux(
        self,
        unit_xyz: Float[Array, "N 3"],
        kernel: Kernel,
    ) -> Float[Array, "N M"]:
        r"""Cross-covariance: :math:`a_l\,Y_{lm}(x)`."""
        if unit_xyz.ndim != 2 or unit_xyz.shape[-1] != 3:
            raise ValueError(f"unit_xyz must be (N, 3); got {unit_xyz.shape}.")
        harmonics = real_spherical_harmonics(unit_xyz, self.l_max)
        coeffs = self._per_feature_coeffs(kernel)
        # Scale each harmonic column by its Funk-Hecke coefficient.
        return harmonics * coeffs[None, :]

K_uu(kernel, *, jitter=1e-06)

Diagonal :math:K_{uu} — Funk-Hecke coefficients per harmonic.

Source code in src/pyrox/gp/_inducing.py
def K_uu(
    self, kernel: Kernel, *, jitter: float = 1e-6
) -> lx.DiagonalLinearOperator:
    """Diagonal :math:`K_{uu}` — Funk-Hecke coefficients per harmonic."""
    # One coefficient per harmonic; degree l is repeated 2l + 1 times.
    diag = self._per_feature_coeffs(kernel)
    return _diagonal_with_jitter(diag, jitter)

k_ux(unit_xyz, kernel)

Cross-covariance: :math:a_l\,Y_{lm}(x).

Source code in src/pyrox/gp/_inducing.py
def k_ux(
    self,
    unit_xyz: Float[Array, "N 3"],
    kernel: Kernel,
) -> Float[Array, "N M"]:
    r"""Cross-covariance: :math:`a_l\,Y_{lm}(x)`."""
    # Family-specific input check: unit vectors on the sphere, shape (N, 3).
    if unit_xyz.ndim != 2 or unit_xyz.shape[-1] != 3:
        raise ValueError(f"unit_xyz must be (N, 3); got {unit_xyz.shape}.")
    Y = real_spherical_harmonics(unit_xyz, self.l_max)
    a_per_feature = self._per_feature_coeffs(kernel)
    # Scale each harmonic column by its Funk-Hecke coefficient.
    return Y * a_per_feature[None, :]

pyrox.gp.LaplacianInducingFeatures

Bases: Module

Inducing features from low-frequency graph Laplacian eigenvectors.

For a graph with normalized Laplacian :math:L, take the smallest num_basis eigenpairs :math:(\mu_j, v_j). Treating the kernel as a function of the graph distance — specifically, applying the kernel spectrum :math:g(\mu) to the Laplacian eigenvalues — gives a diagonal :math:K_{uu}.

This implementation supports the heat-kernel family :math:g(\mu) = \exp(-\mu / (2 \ell^2)) (matching :class:pyrox.gp.RBF in spectrum) by reusing :func:pyrox._basis.spectral_density with the eigenvalues as input.

Attributes:

Name Type Description
eigvals Float[Array, ' M']

(M,) Laplacian eigenvalues.

eigvecs Float[Array, 'V M']

(V, M) Laplacian eigenvectors.

num_quadrature

Not a stored field on this class (the only fields are eigvals and eigvecs); the entry is kept only for uniformity with the other inducing-feature families.

Note

X is a vector of node indices (integer-valued), not coordinates. The returned cross-covariance gathers the relevant rows of eigvecs.

Source code in src/pyrox/gp/_inducing.py
class LaplacianInducingFeatures(eqx.Module):
    r"""Inducing features from low-frequency graph Laplacian eigenvectors.

    For a graph with normalized Laplacian :math:`L`, take the smallest
    ``num_basis`` eigenpairs :math:`(\mu_j, v_j)`. Treating the kernel as
    a function of the graph distance — specifically, applying the kernel
    *spectrum* :math:`g(\mu)` to the Laplacian eigenvalues — gives a
    diagonal :math:`K_{uu}`.

    This implementation supports the *heat-kernel* family
    :math:`g(\mu) = \exp(-\mu / (2 \ell^2))` (matching :class:`pyrox.gp.RBF`
    in spectrum) by reusing :func:`pyrox._basis.spectral_density` with the
    eigenvalues as input.

    Attributes:
        eigvals: ``(M,)`` Laplacian eigenvalues.
        eigvecs: ``(V, M)`` Laplacian eigenvectors.

    Note:
        ``X`` is a vector of *node indices* (integer-valued), not
        coordinates. The returned cross-covariance gathers the relevant
        rows of ``eigvecs``.
    """

    eigvals: Float[Array, " M"]
    eigvecs: Float[Array, "V M"]

    @classmethod
    def fit(
        cls,
        adjacency: Float[Array, "V V"],
        num_basis: int,
        *,
        normalized: bool = True,
    ) -> LaplacianInducingFeatures:
        """Build features from the ``num_basis`` smallest Laplacian eigenpairs."""
        eigvals, eigvecs = graph_laplacian_eigpairs(
            adjacency, num_basis, normalized=normalized
        )
        return cls(eigvals=eigvals, eigvecs=eigvecs)

    @property
    def num_features(self) -> int:
        # One feature per retained eigenpair.
        return int(self.eigvals.shape[0])

    def _check_stationary(self, kernel: Kernel) -> None:
        if not _is_stationary(kernel):
            raise ValueError(
                "LaplacianInducingFeatures requires a stationary kernel with a "
                f"registered spectral density; got {type(kernel).__name__}."
            )

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.DiagonalLinearOperator:
        """Diagonal :math:`K_{uu}` — kernel spectrum over the eigenvalues, plus jitter."""
        self._check_stationary(kernel)
        with _kernel_context(kernel):
            # Spectral density evaluated at the Laplacian eigenvalues (D=1).
            S = spectral_density(kernel, self.eigvals, D=1)
        return _diagonal_with_jitter(S, jitter)

    def k_ux(
        self, node_indices: Int[Array, " N"], kernel: Kernel
    ) -> Float[Array, "N M"]:
        """Cross-covariance rows for the given graph nodes."""
        self._check_stationary(kernel)
        # Family-specific input check: 1D integer node indices, not coordinates.
        if node_indices.ndim != 1:
            raise ValueError(
                "node_indices must be a 1D integer array; got shape "
                f"{node_indices.shape}."
            )
        with _kernel_context(kernel):
            S = spectral_density(kernel, self.eigvals, D=1)
        # Gather eigenvector rows for the requested nodes, scale per column.
        rows = self.eigvecs[node_indices]
        return rows * S[None, :]

pyrox.gp.DecoupledInducingFeatures

Bases: Module

Decoupled mean / covariance inducing-feature bases (Cheng & Boots 2017).

Two independent inducing-feature sets:

  • mean_features: a large alpha-basis used by the SVGP posterior mean (cheap — predictive mean cost is linear in the mean-basis size).
  • cov_features: a small beta-basis used for the posterior covariance (the true bottleneck; keep this small).

The two bases need not share the same family — a common pattern is a large Fourier basis for the mean and a small spherical-harmonic basis for the covariance, or vice versa. The downstream guide consumes both via the standard SVGP machinery.

Attributes:

Name Type Description
mean_features InducingFeatures

Inducing-feature object backing the predictive mean.

cov_features InducingFeatures

Inducing-feature object backing the predictive covariance.

Note

DecoupledInducingFeatures itself does not implement :class:InducingFeatures (no single K_uu makes sense for two bases). Consumers should access .mean_features and .cov_features directly.

Source code in src/pyrox/gp/_inducing.py
class DecoupledInducingFeatures(eqx.Module):
    r"""Decoupled mean / covariance bases for SVGP (Cheng & Boots 2017).

    Holds two independent inducing-feature sets:

    - ``mean_features`` — the large ``alpha``-basis backing the SVGP
      posterior *mean*; predictive-mean cost is linear in its size, so
      it can be generous.
    - ``cov_features`` — the small ``beta``-basis backing the posterior
      *covariance*, the true bottleneck; keep this small.

    The two bases may come from different families — e.g. a wide
    Fourier mean basis paired with a compact spherical-harmonic
    covariance basis, or vice versa. Downstream guides consume both
    through the standard SVGP machinery.

    Attributes:
        mean_features: Inducing-feature object backing the predictive mean.
        cov_features: Inducing-feature object backing the predictive covariance.

    Note:
        This class deliberately does *not* implement
        :class:`InducingFeatures` — no single ``K_uu`` spans two bases.
        Consumers should access ``.mean_features`` and ``.cov_features``
        directly.
    """

    mean_features: InducingFeatures
    cov_features: InducingFeatures

    @property
    def num_mean_features(self) -> int:
        """Size of the mean (``alpha``) basis."""
        return self.mean_features.num_features

    @property
    def num_cov_features(self) -> int:
        """Size of the covariance (``beta``) basis."""
        return self.cov_features.num_features

pyrox.gp.funk_hecke_coefficients(kernel, l_max, *, num_quadrature=256)

Funk-Hecke coefficients of a zonal kernel on :math:S^2.

For a kernel of the form :math:k(x, x') = \kappa(x \cdot x') on the unit 2-sphere, the Funk-Hecke theorem gives:

.. math::

a_l = 2\pi \int_{-1}^{1} \kappa(t)\,P_l(t)\,dt.

Returns (l_max + 1,) coefficients indexed by l. We treat any Euclidean kernel as zonal-on-the-sphere via :math:\kappa(t) = k_{\mathrm{euc}}(\hat{n}_0, \hat{n}_t) for unit vectors at angular separation arccos(t).

Source code in src/pyrox/gp/_inducing.py
def funk_hecke_coefficients(
    kernel: Kernel,
    l_max: int,
    *,
    num_quadrature: int = 256,
) -> Float[Array, " l_max_plus_1"]:
    r"""Funk-Hecke coefficients of a zonal kernel on :math:`S^2`.

    For a kernel of the form :math:`k(x, x') = \kappa(x \cdot x')` on the
    unit 2-sphere, the Funk-Hecke theorem gives:

    .. math::

        a_l = 2\pi \int_{-1}^{1} \kappa(t)\,P_l(t)\,dt.

    Returns ``(l_max + 1,)`` coefficients indexed by ``l``. Any Euclidean
    kernel is treated as zonal-on-the-sphere via
    :math:`\kappa(t) = k_{\mathrm{euc}}(\hat{n}_0, \hat{n}_t)` for unit
    vectors at angular separation ``arccos(t)``.
    """
    # Gauss-Legendre quadrature nodes/weights on [-1, 1] (host-side setup).
    nodes, weights = _gauss_legendre_nodes(num_quadrature)
    # Unit-vector pairs at separation arccos(t):
    # x0 = (0, 0, 1) and x_t = (sin(arccos t), 0, t).
    sines = jnp.sqrt(jnp.maximum(1.0 - nodes**2, 0.0))
    north = jnp.array([0.0, 0.0, 1.0])
    directions = jnp.stack([sines, jnp.zeros_like(nodes), nodes], axis=-1)  # (Q, 3)
    # One batched kernel call keeps everything on-device and preserves
    # autodiff edges to hyperparameters sampled inside ``kernel``; taking
    # row 0 of the (1, Q) Gram costs O(Q), not O(Q^2).
    with _kernel_context(kernel):
        kappa = kernel(north[None, :], directions)[0]  # (Q,)

    def project(poly):
        # a_l = 2 pi * integral of kappa(t) P_l(t) dt, by quadrature.
        return 2.0 * jnp.pi * jnp.sum(weights * kappa * poly)

    # Legendre polynomials by the three-term recurrence:
    # l P_l = (2l - 1) t P_{l-1} - (l - 1) P_{l-2}.
    p_prev = jnp.ones_like(nodes)  # P_0
    p_curr = nodes  # P_1
    coeffs = [2.0 * jnp.pi * jnp.sum(weights * kappa)]  # a_0
    if l_max >= 1:
        coeffs.append(project(p_curr))  # a_1
    for ell in range(2, l_max + 1):
        p_next = ((2 * ell - 1) * nodes * p_curr - (ell - 1) * p_prev) / ell
        coeffs.append(project(p_next))
        p_prev, p_curr = p_curr, p_next
    return jnp.stack(coeffs, axis=0)

Sparse GP prior

pyrox.gp.SparseGPPrior

Bases: Module

GP prior parameterized over inducing inputs Z.

Represents the zero-mean prior over inducing values u = f(Z) used by sparse variational guides:

.. math::

p(u) = \mathcal{N}(0,\, K_{ZZ} + \mathrm{jitter}\,I).

The standard SVGP convention is to subtract any global mean function before forming the prior over u and to add it back at predict time, so the inducing-prior mean is fixed to zero (this is what the guides' KL terms assume — see :meth:FullRankGuide.kl_divergence, :meth:MeanFieldGuide.kl_divergence, :meth:WhitenedGuide.kl_divergence). The :attr:mean_fn attribute on this class is exposed as a convenience for callers that want to add mu(X_*) back onto the predictive mean returned by :meth:Guide.predict; it is not incorporated in :meth:inducing_operator or in the guides' KL.

Pair with a sparse variational guide that owns q(u) = N(m, S) to obtain the standard SVGP predictive

.. math::

\mu_*(x) = K_{xZ} K_{ZZ}^{-1} m, \qquad
\sigma_*^2(x) = k(x, x) - K_{xZ} K_{ZZ}^{-1} K_{Zx}
               + K_{xZ} K_{ZZ}^{-1} S K_{ZZ}^{-1} K_{Zx}.

Attributes:

Name Type Description
kernel Kernel

Any :class:pyrox.gp.Kernel — evaluated on Z.

Z Float[Array, 'M D'] | None

Inducing inputs of shape (M, D).

mean_fn Callable[[Float[Array, 'N D']], Float[Array, ' N']] | None

Callable X -> (N,) or None for the zero mean. Convenience accessor; not folded into the inducing prior.

solver AbstractSolverStrategy | None

Any gaussx.AbstractSolverStrategy. Defaults to gaussx.DenseSolver(). Used by guides that need to solve against K_zz (e.g. for KL or unwhitening).

jitter float

Diagonal regularization added to K_zz for numerical stability. Not a noise model — sparse SVGP does not put observation noise on the inducing-value prior.

Source code in src/pyrox/gp/_sparse.py
class SparseGPPrior(eqx.Module):
    r"""GP prior parameterized over inducing inputs ``Z``.

    Represents the *zero-mean* prior over inducing values ``u = f(Z)``
    used by sparse variational guides:

    .. math::

        p(u) = \mathcal{N}(0,\, K_{ZZ} + \mathrm{jitter}\,I).

    The standard SVGP convention is to subtract any global mean function
    before forming the prior over ``u`` and to add it back at predict
    time, so the inducing-prior mean is fixed to zero (this is what the
    guides' KL terms assume — see :meth:`FullRankGuide.kl_divergence`,
    :meth:`MeanFieldGuide.kl_divergence`, :meth:`WhitenedGuide.kl_divergence`).
    The :attr:`mean_fn` attribute on this class is exposed as a
    convenience for callers that want to add ``mu(X_*)`` back onto the
    predictive mean returned by :meth:`Guide.predict`; it is **not**
    incorporated in :meth:`inducing_operator` or in the guides' KL.

    Pair with a sparse variational guide that owns ``q(u) = N(m, S)`` to
    obtain the standard SVGP predictive

    .. math::

        \mu_*(x) = K_{xZ} K_{ZZ}^{-1} m, \qquad
        \sigma_*^2(x) = k(x, x) - K_{xZ} K_{ZZ}^{-1} K_{Zx}
                       + K_{xZ} K_{ZZ}^{-1} S K_{ZZ}^{-1} K_{Zx}.

    Attributes:
        kernel: Any :class:`pyrox.gp.Kernel` — evaluated on ``Z``.
        Z: Inducing inputs of shape ``(M, D)``.
        inducing: Inter-domain inducing features implementing
            :class:`InducingFeatures`. Exactly one of ``Z`` / ``inducing``
            must be provided (enforced by ``__check_init__``).
        mean_fn: Callable ``X -> (N,)`` or ``None`` for the zero mean.
            Convenience accessor; not folded into the inducing prior.
        solver: Any ``gaussx.AbstractSolverStrategy``. Defaults to
            ``gaussx.DenseSolver()``. Used by guides that need to solve
            against ``K_zz`` (e.g. for KL or unwhitening).
        jitter: Diagonal regularization added to ``K_zz`` for numerical
            stability. Not a noise model — sparse SVGP does not put
            observation noise on the inducing-value prior.
    """

    kernel: Kernel
    Z: Float[Array, "M D"] | None = None
    inducing: InducingFeatures | None = None
    mean_fn: Callable[[Float[Array, "N D"]], Float[Array, " N"]] | None = None
    solver: AbstractSolverStrategy | None = None
    jitter: float = 1e-6

    def __check_init__(self) -> None:
        # `(Z is None) == (inducing is None)` is True both when neither and
        # when both are supplied — either way the prior is ill-specified.
        if (self.Z is None) == (self.inducing is None):
            raise ValueError(
                "SparseGPPrior must be constructed with exactly one of `Z` "
                "(point inducing) or `inducing` (inducing features)."
            )

    @property
    def num_inducing(self) -> int:
        """Number of inducing inputs / features ``M``."""
        if self.inducing is not None:
            return self.inducing.num_features
        assert self.Z is not None
        return self.Z.shape[0]

    def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Evaluate the mean function at ``X``; zero by default."""
        if self.mean_fn is None:
            return jnp.zeros(X.shape[0], dtype=X.dtype)
        return self.mean_fn(X)

    def inducing_operator(self) -> lx.AbstractLinearOperator:
        r"""Return ``K_{ZZ} + \text{jitter}\,I`` as a ``lineax`` operator.

        For point-inducing priors, returns a dense
        :class:`lineax.MatrixLinearOperator` with ``positive_semidefinite_tag``.
        For inducing-feature priors, delegates to
        :meth:`InducingFeatures.K_uu` — typically a
        :class:`lineax.DiagonalLinearOperator` so the downstream
        :func:`gaussx.solve` dispatches in O(M) instead of O(M^3).

        Single kernel call; safe standalone for kernels with priors. For
        building several SVGP blocks together, prefer
        :meth:`predictive_blocks`, which scopes one shared kernel
        context across ``K_zz``, ``K_xz``, and ``K_xx_diag`` so
        Pattern B / C kernels register their NumPyro hyperparameter
        sites once instead of resampling per call.
        """
        if self.inducing is not None:
            with _kernel_context(self.kernel):
                return self.inducing.K_uu(self.kernel, jitter=self.jitter)
        assert self.Z is not None
        with _kernel_context(self.kernel):
            K = self.kernel(self.Z, self.Z)
        # Jitter applied outside the kernel context — pure array arithmetic.
        K = K + self.jitter * jnp.eye(K.shape[0], dtype=K.dtype)
        return _psd_operator(K)

    def cross_covariance(self, X: Array) -> Float[Array, "N M"]:
        r""":math:`K_{XZ}` — covariance between ``X`` and the inducing inputs/features.

        The expected shape of ``X`` is inducing-family-dependent:

        - Point-inducing (``Z``) or :class:`FourierInducingFeatures`:
          coordinates ``(N, D)``.
        - :class:`SphericalHarmonicInducingFeatures`: unit vectors ``(N, 3)``.
        - :class:`LaplacianInducingFeatures`: integer node indices ``(N,)``.

        See :meth:`predictive_blocks` for the shared-context batch
        helper to use when assembling several SVGP blocks together.
        """
        if self.inducing is not None:
            with _kernel_context(self.kernel):
                return self.inducing.k_ux(X, self.kernel)
        assert self.Z is not None
        with _kernel_context(self.kernel):
            return self.kernel(X, self.Z)

    def kernel_diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        r"""Prior diagonal ``\mathrm{diag}\,K(X, X)`` — variance at each ``x``.

        See :meth:`predictive_blocks` for the shared-context batch
        helper to use when assembling several SVGP blocks together.
        """
        with _kernel_context(self.kernel):
            return self.kernel.diag(X)

    def predictive_blocks(
        self, X: Array
    ) -> tuple[
        lx.AbstractLinearOperator,
        Float[Array, "N M"],
        Float[Array, " N"],
    ]:
        r"""Return ``(K_zz_op, K_xz, K_xx_diag)`` under one shared kernel context.

        For Pattern B / C kernels with prior'd hyperparameters, the three
        kernel evaluations needed for an SVGP predictive must share a
        single :class:`pyrox.PyroxModule` context so the underlying
        ``pyrox_sample`` sites register once and yield consistent
        hyperparameter draws across ``K_{ZZ}``, ``K_{XZ}``, and the
        diagonal ``\mathrm{diag}\,K(X, X)``. Without this scoping, three
        separate calls would draw three independent hyperparameter
        samples (under seed) or raise NumPyro duplicate-site errors
        (under tracing) — either way invalidating the SVGP math.

        For pure :class:`equinox.Module` kernels (no ``_get_context``),
        this is equivalent to calling :meth:`inducing_operator`,
        :meth:`cross_covariance`, and :meth:`kernel_diag` independently.

        For inducing-feature priors, ``K_zz_op`` is a
        :class:`lineax.DiagonalLinearOperator` (jitter folded into the
        diagonal vector — never ``+ jnp.eye``) so the downstream solve
        stays O(M).
        """
        # All three evaluations inside ONE context so hyperparameter sites
        # register exactly once (see docstring).
        with _kernel_context(self.kernel):
            if self.inducing is not None:
                K_zz_op = self.inducing.K_uu(self.kernel, jitter=self.jitter)
                K_xz = self.inducing.k_ux(X, self.kernel)
            else:
                assert self.Z is not None
                K_zz_raw = self.kernel(self.Z, self.Z)
                K_xz = self.kernel(X, self.Z)
                # Dense point-inducing path: jitter enters via `+ jitter * I`.
                K_zz = K_zz_raw + self.jitter * jnp.eye(
                    K_zz_raw.shape[0], dtype=K_zz_raw.dtype
                )
                K_zz_op = _psd_operator(K_zz)
            K_xx_diag = self.kernel.diag(X)
        return K_zz_op, K_xz, K_xx_diag

    def _resolved_solver(self) -> AbstractSolverStrategy:
        # Lazy default so a `None` field means "dense" without storing a solver.
        return DenseSolver() if self.solver is None else self.solver  # ty: ignore[invalid-return-type]

    def log_prob(self, u: Float[Array, " M"]) -> Float[Array, ""]:
        r"""Log-density under :math:`p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I)`.

        Delegates to :func:`gaussx.gaussian_log_prob` with the
        configured :attr:`solver` so the user-supplied solver controls
        the ``solve`` / ``logdet`` work on ``K_zz_op``. Useful for
        scoring inducing values against the SVGP prior in non-NumPyro
        contexts (e.g. tests, diagnostics).
        """
        m = jnp.zeros(self.num_inducing, dtype=u.dtype)
        return gaussian_log_prob(
            m, self.inducing_operator(), u, solver=self._resolved_solver()
        )

    def sample(self, key: Array) -> Float[Array, " M"]:
        r"""Draw ``u \sim p(u)`` from the inducing prior.

        Wraps the inducing operator in a
        :class:`gaussx.MultivariateNormal` with the configured
        :attr:`solver`. ``MultivariateNormal.sample`` factors the
        covariance via :func:`gaussx.cholesky` and reparameterizes;
        the returned draw has shape ``(M,)``.

        Note: the SVGP variational workflow samples ``u`` from the
        *guide* :math:`q(u)`, not the prior. This method exists so the
        prior surface is symmetric with the guide surface and so users
        can score / draw inducing values against the prior directly
        (e.g. for tests or for prior-sample initialization).
        """
        n = self.num_inducing
        op = self.inducing_operator()
        # Zero mean in the operator's output dtype.
        loc = jnp.zeros(n, dtype=op.out_structure().dtype)
        mvn = MultivariateNormal(loc, op, solver=self._resolved_solver())
        return mvn.sample(key)

num_inducing property

Number of inducing inputs / features M.

cross_covariance(X)

:math:K_{XZ} — covariance between X and the inducing inputs/features.

The expected shape of X is inducing-family-dependent:

  • Point-inducing (Z) or :class:FourierInducingFeatures: coordinates (N, D).
  • :class:SphericalHarmonicInducingFeatures: unit vectors (N, 3).
  • :class:LaplacianInducingFeatures: integer node indices (N,).

See :meth:predictive_blocks for the shared-context batch helper to use when assembling several SVGP blocks together.

Source code in src/pyrox/gp/_sparse.py
def cross_covariance(self, X: Array) -> Float[Array, "N M"]:
    r""":math:`K_{XZ}` — covariance between ``X`` and the inducing inputs/features.

    The expected shape of ``X`` depends on the inducing family:

    - Point-inducing (``Z``) or :class:`FourierInducingFeatures`:
      coordinates ``(N, D)``.
    - :class:`SphericalHarmonicInducingFeatures`: unit vectors ``(N, 3)``.
    - :class:`LaplacianInducingFeatures`: integer node indices ``(N,)``.

    When assembling several SVGP blocks together, prefer the
    shared-context batch helper :meth:`predictive_blocks`.
    """
    # Exactly one branch runs, so a single context entry per call either way.
    with _kernel_context(self.kernel):
        if self.inducing is not None:
            return self.inducing.k_ux(X, self.kernel)
        assert self.Z is not None
        return self.kernel(X, self.Z)

inducing_operator()

Return K_{ZZ} + \text{jitter}\,I as a lineax operator.

For point-inducing priors, returns a dense :class:lineax.MatrixLinearOperator with positive_semidefinite_tag. For inducing-feature priors, delegates to :meth:InducingFeatures.K_uu — typically a :class:lineax.DiagonalLinearOperator so the downstream :func:gaussx.solve dispatches in O(M) instead of O(M^3).

Single kernel call; safe standalone for kernels with priors. For building several SVGP blocks together, prefer :meth:predictive_blocks, which scopes one shared kernel context across K_zz, K_xz, and K_xx_diag so Pattern B / C kernels register their NumPyro hyperparameter sites once instead of resampling per call.

Source code in src/pyrox/gp/_sparse.py
def inducing_operator(self) -> lx.AbstractLinearOperator:
    r"""Return ``K_{ZZ} + \text{jitter}\,I`` as a ``lineax`` operator.

    Point-inducing priors yield a dense
    :class:`lineax.MatrixLinearOperator` carrying
    ``positive_semidefinite_tag``. Inducing-feature priors delegate to
    :meth:`InducingFeatures.K_uu` — typically a
    :class:`lineax.DiagonalLinearOperator` — so the downstream
    :func:`gaussx.solve` dispatches in O(M) rather than O(M^3).

    This issues a single kernel call, so it is safe standalone even for
    kernels with priors. When several SVGP blocks are assembled
    together, prefer :meth:`predictive_blocks`, which scopes one shared
    kernel context across ``K_zz``, ``K_xz``, and ``K_xx_diag`` so
    Pattern B / C kernels register their NumPyro hyperparameter sites
    once instead of resampling per call.
    """
    if self.inducing is not None:
        with _kernel_context(self.kernel):
            return self.inducing.K_uu(self.kernel, jitter=self.jitter)
    assert self.Z is not None
    with _kernel_context(self.kernel):
        gram = self.kernel(self.Z, self.Z)
    jittered = gram + self.jitter * jnp.eye(gram.shape[0], dtype=gram.dtype)
    return _psd_operator(jittered)

kernel_diag(X)

Prior diagonal \mathrm{diag}\,K(X, X) — variance at each x.

See :meth:predictive_blocks for the shared-context batch helper to use when assembling several SVGP blocks together.

Source code in src/pyrox/gp/_sparse.py
def kernel_diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    r"""Prior variance at each input: ``\mathrm{diag}\,K(X, X)``.

    When assembling several SVGP blocks together, prefer the
    shared-context batch helper :meth:`predictive_blocks`.
    """
    kern = self.kernel
    with _kernel_context(kern):
        return kern.diag(X)

log_prob(u)

Log-density under :math:p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I).

Delegates to :func:gaussx.gaussian_log_prob with the configured :attr:solver so the user-supplied solver controls the solve / logdet work on K_zz_op. Useful for scoring inducing values against the SVGP prior in non-NumPyro contexts (e.g.\ tests, diagnostics).

Source code in src/pyrox/gp/_sparse.py
def log_prob(self, u: Float[Array, " M"]) -> Float[Array, ""]:
    r"""Log-density of ``u`` under :math:`p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I)`.

    Delegates to :func:`gaussx.gaussian_log_prob` with the configured
    :attr:`solver`, so the user-supplied solver controls the ``solve``
    / ``logdet`` work on ``K_zz_op``. Useful for scoring inducing
    values against the SVGP prior in non-NumPyro contexts
    (e.g.\\ tests, diagnostics).
    """
    zero_mean = jnp.zeros(self.num_inducing, dtype=u.dtype)
    op = self.inducing_operator()
    return gaussian_log_prob(zero_mean, op, u, solver=self._resolved_solver())

mean(X)

Evaluate the mean function at X; zero by default.

Source code in src/pyrox/gp/_sparse.py
def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    """Evaluate the mean function at ``X``; zero by default."""
    fn = self.mean_fn
    if fn is not None:
        return fn(X)
    return jnp.zeros(X.shape[0], dtype=X.dtype)

predictive_blocks(X)

Return (K_zz_op, K_xz, K_xx_diag) under one shared kernel context.

For Pattern B / C kernels with prior'd hyperparameters, the three kernel evaluations needed for an SVGP predictive must share a single :class:pyrox.PyroxModule context so the underlying pyrox_sample sites register once and yield consistent hyperparameter draws across K_{ZZ}, K_{XZ}, and the diagonal \mathrm{diag}\,K(X, X). Without this scoping, three separate calls would draw three independent hyperparameter samples (under seed) or raise NumPyro duplicate-site errors (under tracing) — either way invalidating the SVGP math.

For pure :class:equinox.Module kernels (no _get_context), this is equivalent to calling :meth:inducing_operator, :meth:cross_covariance, and :meth:kernel_diag independently.

For inducing-feature priors, K_zz_op is a :class:lineax.DiagonalLinearOperator (jitter folded into the diagonal vector — never + jnp.eye) so the downstream solve stays O(M).

Source code in src/pyrox/gp/_sparse.py
def predictive_blocks(
    self, X: Array
) -> tuple[
    lx.AbstractLinearOperator,
    Float[Array, "N M"],
    Float[Array, " N"],
]:
    r"""Return ``(K_zz_op, K_xz, K_xx_diag)`` under one shared kernel context.

    For Pattern B / C kernels with prior'd hyperparameters, the three
    kernel evaluations needed for an SVGP predictive must share a
    single :class:`pyrox.PyroxModule` context so the underlying
    ``pyrox_sample`` sites register once and yield consistent
    hyperparameter draws across ``K_{ZZ}``, ``K_{XZ}``, and the
    diagonal ``\mathrm{diag}\,K(X, X)``. Without this scoping, three
    separate calls would draw three independent hyperparameter
    samples (under seed) or raise NumPyro duplicate-site errors
    (under tracing) — either way invalidating the SVGP math.

    For pure :class:`equinox.Module` kernels (no ``_get_context``),
    this is equivalent to calling :meth:`inducing_operator`,
    :meth:`cross_covariance`, and :meth:`kernel_diag` independently.

    For inducing-feature priors, ``K_zz_op`` is a
    :class:`lineax.DiagonalLinearOperator` (jitter folded into the
    diagonal vector — never ``+ jnp.eye``) so the downstream solve
    stays O(M).
    """
    # Single context for all three evaluations: Pattern B / C kernels
    # register their hyperparameter sites exactly once in here.
    with _kernel_context(self.kernel):
        if self.inducing is not None:
            # Feature families fold jitter into K_uu themselves
            # (typically a diagonal operator -> O(M) downstream solves).
            K_zz_op = self.inducing.K_uu(self.kernel, jitter=self.jitter)
            K_xz = self.inducing.k_ux(X, self.kernel)
        else:
            assert self.Z is not None
            K_zz_raw = self.kernel(self.Z, self.Z)
            K_xz = self.kernel(X, self.Z)
            # Dense point-inducing path: add jitter explicitly, then
            # wrap as a PSD-tagged operator.
            K_zz = K_zz_raw + self.jitter * jnp.eye(
                K_zz_raw.shape[0], dtype=K_zz_raw.dtype
            )
            K_zz_op = _psd_operator(K_zz)
        K_xx_diag = self.kernel.diag(X)
    return K_zz_op, K_xz, K_xx_diag

sample(key)

Draw u \sim p(u) from the inducing prior.

Wraps the inducing operator in a :class:gaussx.MultivariateNormal with the configured :attr:solver. MultivariateNormal.sample factors the covariance via :func:gaussx.cholesky and reparameterizes; the returned draw has shape (M,).

Note: the SVGP variational workflow samples u from the guide :math:q(u), not the prior. This method exists so the prior surface is symmetric with the guide surface and so users can score / draw inducing values against the prior directly (e.g.\ for tests or for prior-sample initialization).

Source code in src/pyrox/gp/_sparse.py
def sample(self, key: Array) -> Float[Array, " M"]:
    r"""Draw ``u \sim p(u)`` from the inducing prior.

    Wraps the inducing operator in a :class:`gaussx.MultivariateNormal`
    with the configured :attr:`solver`. ``MultivariateNormal.sample``
    factors the covariance via :func:`gaussx.cholesky` and
    reparameterizes, so the returned draw has shape ``(M,)``.

    Note: the SVGP variational workflow samples ``u`` from the *guide*
    :math:`q(u)`, not the prior. This method exists so the prior
    surface is symmetric with the guide surface and so users can
    score / draw inducing values against the prior directly
    (e.g.\\ for tests or for prior-sample initialization).
    """
    cov_op = self.inducing_operator()
    zero_mean = jnp.zeros(self.num_inducing, dtype=cov_op.out_structure().dtype)
    dist = MultivariateNormal(zero_mean, cov_op, solver=self._resolved_solver())
    return dist.sample(key)

Component protocols

Abstract pyrox-local bases for the orthogonal component stack. Wave 2 ships only the contracts for Guide, Integrator, and Likelihood — concrete implementations land in later waves. Solver strategies live in gaussx._strategies.

pyrox.gp.Kernel

Bases: Module

Abstract base for GP covariance functions.

Subclasses implement :meth:__call__ returning the Gram matrix on a pair of input batches. :meth:gram and :meth:diag are convenience defaults that derive from :meth:__call__; structured subclasses (Kronecker, state-space, etc.) should override them for efficiency.

Source code in src/pyrox/gp/_protocols.py
class Kernel(eqx.Module):
    """Abstract base class for GP covariance functions.

    Subclasses provide :meth:`__call__`, which evaluates the Gram matrix
    on a pair of input batches. The defaults for :meth:`gram` and
    :meth:`diag` are derived from :meth:`__call__`; structured
    subclasses (Kronecker, state-space, etc.) should override them for
    efficiency.
    """

    @abstractmethod
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        raise NotImplementedError

    def gram(self, X: Float[Array, "N D"]) -> Float[Array, "N N"]:
        """Symmetric Gram matrix ``K(X, X)``."""
        return self(X, X)

    def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Diagonal of ``K(X, X)``.

        The default materializes the full Gram and reads off its
        diagonal. Stationary kernels with a constant diagonal should
        override this with a vectorized broadcast for the ``O(N)``
        shortcut.
        """
        full_gram = self(X, X)
        return jnp.diagonal(full_gram)

diag(X)

Diagonal of K(X, X).

Default implementation extracts the diagonal of the full Gram. For stationary kernels with constant diagonal, override with a vectorized broadcast for the O(N) shortcut.

Source code in src/pyrox/gp/_protocols.py
def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    """Diagonal of ``K(X, X)``.

    The default implementation materializes the full Gram and reads off
    its diagonal. Stationary kernels with a constant diagonal should
    override this with a vectorized broadcast for the ``O(N)`` shortcut.
    """
    full_gram = self(X, X)
    return jnp.diagonal(full_gram)

gram(X)

Symmetric Gram matrix K(X, X).

Source code in src/pyrox/gp/_protocols.py
def gram(self, X: Float[Array, "N D"]) -> Float[Array, "N N"]:
    """Symmetric Gram matrix ``K(X, X)`` — the kernel's self-covariance on ``X``."""
    K = self(X, X)
    return K

pyrox.gp.Guide

Bases: Module

Abstract base for variational posterior families.

Concrete guides (DeltaGuide, MeanFieldGuide, LowRankGuide, FullRankGuide, etc.) land in the dedicated guide waves (#28, #29). The whitening principle keeps optimization geometry well-conditioned — sample from a unit-scale latent and unwhiten with the prior Cholesky.

Two distinct entry points:

  • :meth:sample / :meth:log_prob — pure variational draws and densities. sample(self, key) returns a draw from q(f); log_prob(self, f) evaluates log q(f). Neither touches the NumPyro trace.
  • register(name, prior) (optional) — the NumPyro-integration hook invoked by :func:pyrox.gp.gp_sample when a guide is supplied. Use it to register a sample / param site (or compose one out of guide state) under name and return the latent function value. Concrete guides that participate in :func:gp_sample should implement this; the protocol leaves it unspecified so guides usable purely outside NumPyro stay valid.
Source code in src/pyrox/gp/_protocols.py
class Guide(eqx.Module):
    """Abstract base for variational posterior families.

    Concrete guides (``DeltaGuide``, ``MeanFieldGuide``, ``LowRankGuide``,
    ``FullRankGuide``, etc.) arrive in the dedicated guide waves
    (#28, #29). The whitening principle keeps the optimization geometry
    well-conditioned: sample a unit-scale latent, then unwhiten with
    the prior Cholesky.

    There are two distinct entry points:

    * :meth:`sample` / :meth:`log_prob` — pure variational draws and
      densities. ``sample(self, key)`` draws from ``q(f)``;
      ``log_prob(self, f)`` evaluates ``log q(f)``. Neither touches
      the NumPyro trace.
    * ``register(name, prior)`` (optional) — the NumPyro-integration
      hook that :func:`pyrox.gp.gp_sample` invokes when a guide is
      supplied. Implementations register a sample / param site (or
      compose one out of guide state) under ``name`` and return the
      latent function value. Concrete guides meant for
      :func:`gp_sample` should implement it; the protocol leaves it
      unspecified so guides used purely outside NumPyro remain valid.
    """

    @abstractmethod
    def sample(self, key: Any) -> Float[Array, " ..."]:
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, f: Float[Array, " ..."]) -> Float[Array, ""]:
        raise NotImplementedError

pyrox.gp.Integrator

Bases: Module

Abstract base for Gaussian-expectation integrators.

Computes :math:\mathbb{E}_{q(f)}[g(f)] where q(f) = N(mean, var). Concrete integrators (Gauss-Hermite, sigma-points, cubature, Taylor, Monte Carlo) land in later waves and may delegate to gaussx's quadrature primitives.

Source code in src/pyrox/gp/_protocols.py
class Integrator(eqx.Module):
    """Abstract base for Gaussian-expectation integrators.

    Computes :math:`\\mathbb{E}_{q(f)}[g(f)]` where ``q(f) = N(mean, var)``.
    Concrete integrators (Gauss-Hermite, sigma-points, cubature, Taylor,
    Monte Carlo) land in later waves and may delegate to ``gaussx``'s
    quadrature primitives.
    """

    @abstractmethod
    def integrate(
        self,
        fn: Callable[[Float[Array, " ..."]], Float[Array, " ..."]],
        mean: Float[Array, " ..."],
        var: Float[Array, " ..."],
    ) -> Float[Array, " ..."]:
        """Approximate ``E_{N(mean, var)}[fn(f)]``.

        Args:
            fn: Integrand applied to the Gaussian latent ``f``.
            mean: Mean of the Gaussian over ``f``.
            var: Variance of the Gaussian over ``f``.

        Returns:
            The approximated expectation; result shape is defined by the
            concrete integrator.
        """
        raise NotImplementedError

pyrox.gp.Likelihood

Bases: Module

Abstract base for observation models.

Implements the conditional p(y | f) and a default :meth:expected_log_prob that integrates over a Gaussian latent via an :class:Integrator. Concrete likelihoods (Gaussian, Bernoulli, Poisson, StudentT, ...) land in later waves.

Source code in src/pyrox/gp/_protocols.py
class Likelihood(eqx.Module):
    """Abstract base for observation models.

    Defines the conditional ``p(y | f)`` via the abstract
    :meth:`log_prob`. Concrete likelihoods (Gaussian, Bernoulli,
    Poisson, StudentT, ...) land in later waves, along with an
    ``expected_log_prob`` helper that integrates the log-density over a
    Gaussian latent via an :class:`Integrator`.
    """

    # NOTE(review): the previous docstring advertised "a default
    # :meth:`expected_log_prob`", but no such method is defined on this
    # class; the wording above defers it to the concrete-likelihood waves.

    @abstractmethod
    def log_prob(
        self,
        f: Float[Array, " ..."],
        y: Float[Array, " ..."],
    ) -> Float[Array, ""]:
        """Log-density ``log p(y | f)`` of observations ``y`` given the latent ``f``."""
        raise NotImplementedError

Math primitives

Pure JAX kernel functions. Stateless, differentiable, composable — (Array, ..., hyperparams) -> Gram. No NumPyro, no protocols.

Pure JAX kernel evaluation primitives — math definitions only.

Each function takes two input matrices X1 of shape (N1, D), X2 of shape (N2, D), hyperparameters as JAX arrays, and returns the (N1, N2) Gram matrix. All inputs are 2-D; callers that have 1-D arrays must add a trailing singleton dimension first.

These are the canonical closed-form math for each kernel — small, readable, and tutorial-facing. The companion scalable construction surface (numerically stable matrix assembly, mixed-precision accumulation, implicit/structured operators, batched matvec) lives in gaussx; see :func:gaussx.stable_rbf_kernel and :class:gaussx.ImplicitKernelOperator for the production path.

The composition helpers :func:kernel_add and :func:kernel_mul act on already-evaluated Gram matrices, not on callables. Higher-level :class:pyrox.gp.Kernel classes (Wave 2 Layer 1, see issue #20) compose callables and may opt in to gaussx's scalable variants when needed.

Index axes are named via :mod:einops (einsum / rearrange) rather than raw broadcasting so shape intent stays legible at the call site.

constant_kernel(X1, X2, variance)

Constant kernel.

.. math:: k(x, x') = \sigma^2

A rank-one kernel useful as a scalar offset additive component.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar value.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix filled with variance.

Source code in src/pyrox/gp/_src/kernels.py
def constant_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Constant kernel.

    .. math::
        k(x, x') = \sigma^2

    A rank-one kernel, useful as a scalar additive offset component.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar value.

    Returns:
        ``(N1, N2)`` Gram matrix filled with ``variance``.
    """
    n1, n2 = X1.shape[0], X2.shape[0]
    ones = jnp.ones((n1, n2), dtype=X1.dtype)
    return ones * variance

cosine_kernel(X1, X2, variance, period)

Cosine kernel.

.. math:: k(x, x') = \sigma^2 \cos\!\left( \frac{2 \pi \|x - x'\|}{p} \right)

Useful as a simple periodic building block alongside :func:periodic_kernel; unlike the Mackay form this one uses plain cosine of distance and can go negative.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
period Float[Array, '']

Scalar period.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def cosine_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    period: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Cosine kernel.

    .. math::
        k(x, x') = \sigma^2 \cos\!\left(
            \frac{2 \pi \|x - x'\|}{p}
        \right)

    A simple periodic building block alongside :func:`periodic_kernel`;
    unlike the MacKay form this takes a plain cosine of the distance and
    can therefore go negative.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        period: Scalar period.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    sq_dist = _pairwise_sq_dist(X1, X2)
    # Floor inside the sqrt: d/dr sqrt(r) is undefined at r = 0, so the
    # clip keeps gradients finite at coincident inputs.
    dist = jnp.sqrt(jnp.clip(sq_dist, min=1e-30))
    phase = 2.0 * jnp.pi * dist / period
    return variance * jnp.cos(phase)

kernel_add(K1, K2)

Pointwise sum of two already-evaluated Gram matrices.

Source code in src/pyrox/gp/_src/kernels.py
def kernel_add(
    K1: Float[Array, "N1 N2"],
    K2: Float[Array, "N1 N2"],
) -> Float[Array, "N1 N2"]:
    """Elementwise sum of two already-evaluated Gram matrices."""
    summed = K1 + K2
    return summed

kernel_mul(K1, K2)

Pointwise (Hadamard) product of two already-evaluated Gram matrices.

Source code in src/pyrox/gp/_src/kernels.py
def kernel_mul(
    K1: Float[Array, "N1 N2"],
    K2: Float[Array, "N1 N2"],
) -> Float[Array, "N1 N2"]:
    """Elementwise (Hadamard) product of two already-evaluated Gram matrices."""
    product = K1 * K2
    return product

linear_kernel(X1, X2, variance, bias)

Linear kernel.

.. math:: k(x, x') = \sigma^2\, x^\top x' + b

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar variance multiplier on the dot product.

required
bias Float[Array, '']

Scalar additive bias.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def linear_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    bias: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Linear kernel.

    .. math::
        k(x, x') = \sigma^2\, x^\top x' + b

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar variance multiplier on the dot product.
        bias: Scalar additive bias.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    # Named-axis contraction per the module's einops convention.
    dot_products = einsum(X1, X2, "n1 d, n2 d -> n1 n2")
    return bias + variance * dot_products

matern_kernel(X1, X2, variance, lengthscale, nu)

Matern kernel with closed-form nu in {1/2, 3/2, 5/2}.

.. math:: k(x, x') = \sigma^2\, f_\nu(r / \ell), \qquad r = \|x - x'\|

Only the three common half-integer orders are supported because those admit closed-form expressions without Bessel evaluations. nu is a static Python float (not a JAX array) so the branch specializes at trace time.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
lengthscale Float[Array, '']

Scalar lengthscale.

required
nu float

Smoothness parameter; must be 0.5, 1.5, or 2.5.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def matern_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
    nu: float,
) -> Float[Array, "N1 N2"]:
    r"""Matern kernel with closed-form ``nu in {1/2, 3/2, 5/2}``.

    .. math::
        k(x, x') = \sigma^2\, f_\nu(r / \ell),
        \qquad r = \|x - x'\|

    Only the three common half-integer orders are supported: they admit
    closed-form expressions with no Bessel evaluations. ``nu`` is a
    static Python float (not a JAX array) so the branch specializes at
    trace time.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        lengthscale: Scalar lengthscale.
        nu: Smoothness parameter; must be ``0.5``, ``1.5``, or ``2.5``.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    sq_dist = _pairwise_sq_dist(X1, X2)
    # Floor inside the sqrt keeps the gradient finite at r = 0.
    scaled_r = jnp.sqrt(jnp.clip(sq_dist, min=1e-30)) / lengthscale
    if nu == 0.5:
        envelope = jnp.exp(-scaled_r)
    elif nu == 1.5:
        t = jnp.sqrt(3.0) * scaled_r
        envelope = (1.0 + t) * jnp.exp(-t)
    elif nu == 2.5:
        t = jnp.sqrt(5.0) * scaled_r
        envelope = (1.0 + t + (t * t) / 3.0) * jnp.exp(-t)
    else:
        raise ValueError(f"matern_kernel supports nu in {{0.5, 1.5, 2.5}}, got {nu!r}")
    return variance * envelope

periodic_kernel(X1, X2, variance, lengthscale, period)

Periodic (MacKay) kernel.

.. math:: k(x, x') = \sigma^2 \exp\!\left( -\frac{2 \sin^2(\pi \|x - x'\| / p)}{\ell^2} \right)

For multi-dimensional inputs the argument uses the Euclidean distance, matching the common GPML convention.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
lengthscale Float[Array, '']

Scalar lengthscale.

required
period Float[Array, '']

Scalar period p.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def periodic_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
    period: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Periodic (MacKay) kernel.

    .. math::
        k(x, x') = \sigma^2 \exp\!\left(
            -\frac{2 \sin^2(\pi \|x - x'\| / p)}{\ell^2}
        \right)

    For multi-dimensional inputs the argument uses the Euclidean
    distance, matching the common GPML convention.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        lengthscale: Scalar lengthscale.
        period: Scalar period ``p``.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    sq_dist = _pairwise_sq_dist(X1, X2)
    # Floor inside the sqrt keeps gradients finite at r = 0.
    dist = jnp.sqrt(jnp.clip(sq_dist, min=1e-30))
    sin_sq = jnp.sin(jnp.pi * dist / period) ** 2
    return variance * jnp.exp(-2.0 * sin_sq / (lengthscale * lengthscale))

polynomial_kernel(X1, X2, variance, bias, degree)

Polynomial kernel.

.. math:: k(x, x') = \sigma^2 \bigl(x^\top x' + b\bigr)^d

:func:linear_kernel is the special case degree == 1 without the outer power. degree is a static Python int so the kernel specializes at trace time.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar multiplier.

required
bias Float[Array, '']

Scalar additive bias inside the power.

required
degree int

Positive integer polynomial degree.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def polynomial_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    bias: Float[Array, ""],
    degree: int,
) -> Float[Array, "N1 N2"]:
    r"""Polynomial kernel.

    .. math::
        k(x, x') = \sigma^2 \bigl(x^\top x' + b\bigr)^d

    :func:`linear_kernel` is the ``degree == 1`` special case without
    the outer power. ``degree`` is a static Python int so the kernel
    specializes at trace time.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar multiplier.
        bias: Scalar additive bias inside the power.
        degree: Positive integer polynomial degree.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    # Static validation: degree is a Python int, so this runs at trace time.
    if degree < 1:
        raise ValueError(f"polynomial_kernel requires degree >= 1, got {degree!r}")
    inner = einsum(X1, X2, "n1 d, n2 d -> n1 n2") + bias
    return variance * inner**degree

rational_quadratic_kernel(X1, X2, variance, lengthscale, alpha)

Rational quadratic kernel.

.. math:: k(x, x') = \sigma^2 \left( 1 + \frac{\|x - x'\|^2}{2\alpha \ell^2} \right)^{-\alpha}

Scale mixture of RBF kernels: the limit alpha -> infty recovers the RBF, small alpha yields heavier-tailed correlations.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
lengthscale Float[Array, '']

Scalar lengthscale.

required
alpha Float[Array, '']

Scalar shape parameter; must be positive.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def rational_quadratic_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
    alpha: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Rational quadratic kernel.

    .. math::
        k(x, x') = \sigma^2 \left(
            1 + \frac{\|x - x'\|^2}{2\alpha \ell^2}
        \right)^{-\alpha}

    A scale mixture of RBF kernels: ``alpha -> infty`` recovers the
    RBF, while small ``alpha`` yields heavier-tailed correlations.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        lengthscale: Scalar lengthscale.
        alpha: Scalar shape parameter; must be positive.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    sq_dist = _pairwise_sq_dist(X1, X2)
    denom = 2.0 * alpha * lengthscale * lengthscale
    base = 1.0 + sq_dist / denom
    return variance * base ** (-alpha)

rbf_kernel(X1, X2, variance, lengthscale)

Radial basis function (squared exponential) kernel.

.. math:: k(x, x') = \sigma^2 \exp\!\left(-\frac{\|x - x'\|^2}{2\ell^2}\right)

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance sigma^2.

required
lengthscale Float[Array, '']

Scalar lengthscale ell.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) kernel Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def rbf_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Radial basis function (squared exponential) kernel.

    .. math::
        k(x, x') = \sigma^2 \exp\!\left(-\frac{\|x - x'\|^2}{2\ell^2}\right)

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance ``sigma^2``.
        lengthscale: Scalar lengthscale ``ell``.

    Returns:
        ``(N1, N2)`` kernel Gram matrix.
    """
    sq_dist = _pairwise_sq_dist(X1, X2)
    scale_sq = lengthscale * lengthscale
    exponent = -0.5 * sq_dist / scale_sq
    return variance * jnp.exp(exponent)

white_kernel(X1, X2, variance)

White-noise kernel.

.. math:: k(x, x') = \sigma^2 \,\delta(x, x')

Nonzero only where X1[i] exactly matches X2[j] across all feature dimensions. When evaluated at X1 == X2 this yields sigma^2 * I.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar noise variance.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def white_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""White-noise kernel.

    .. math::
        k(x, x') = \sigma^2 \,\delta(x, x')

    Nonzero only where ``X1[i]`` exactly matches ``X2[j]`` across all
    feature dimensions; evaluated at ``X1 == X2`` this yields
    ``sigma^2 * I``.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar noise variance.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    # Broadcast to (N1, N2, D) per the module's einops convention.
    lhs = rearrange(X1, "n1 d -> n1 1 d")
    rhs = rearrange(X2, "n2 d -> 1 n2 d")
    exact_match = jnp.all(lhs - rhs == 0.0, axis=-1)
    return variance * exact_match.astype(X1.dtype)