Skip to content

GP API

Wave 2 ships the dense-GP foundation: kernel math functions, concrete Parameterized kernel classes, abstract component protocols, and the model-facing entry points (GPPrior, ConditionedGP, gp_factor, gp_sample). Scalable matrix construction and solver strategies (numerically stable assembly, implicit operators, batched matvec, Cholesky / CG / BBMM / LSMR / SLQ) live in gaussx.

Split with gaussx

pyrox owns the kernel function side — closed-form math primitives readable in a dozen lines — plus the NumPyro-aware model shell (GPPrior, gp_factor, gp_sample). gaussx owns every piece of linear algebra: stable matrix construction, solver strategies, and the underlying MultivariateNormal distribution. The model entry points accept any gaussx.AbstractSolverStrategy (default gaussx.DenseSolver()).

Model entry points

import jax.numpy as jnp
import numpyro
from pyrox.gp import GPPrior, RBF, gp_factor, gp_sample

def regression_model(X, y):
    """GP regression with a Gaussian likelihood; ``f`` is integrated out by ``gp_factor``."""
    prior = GPPrior(kernel=RBF(), X=X)
    noise_var = jnp.array(0.05)
    gp_factor("obs", prior, y, noise_var=noise_var)


def latent_model(X):
    """Sample the latent GP function ``f`` explicitly (non-conjugate likelihoods)."""
    prior = GPPrior(kernel=RBF(), X=X)
    f = gp_sample("f", prior)
    # Attach the observation model to ``f`` here — e.g. a Bernoulli or
    # Poisson likelihood site.

Swap the solver strategy at construction time:

from gaussx import CGSolver, ComposedSolver, DenseLogdet, DenseSolver
prior = GPPrior(kernel=RBF(), X=X, solver=CGSolver())
# Or compose — CG for solve, dense Cholesky for logdet:
kernel = RBF()
composed = ComposedSolver(solve_strategy=CGSolver(), logdet_strategy=DenseLogdet())
prior = GPPrior(kernel=kernel, X=X, solver=composed)

pyrox.gp.GPPrior

Bases: Module

Finite-dimensional GP prior over a fixed training input set.

Holds a kernel, training inputs X, an optional mean function, an optional solver strategy, and a small diagonal jitter for numerical stability on otherwise-singular prior covariances.

Attributes:

Name Type Description
kernel Kernel

Any `pyrox.gp.Kernel` — evaluated on `X`.

X Float[Array, 'N D']

Training inputs of shape (N, D).

mean_fn Callable[[Float[Array, 'N D']], Float[Array, ' N']] | None

Callable X -> (N,) or None for the zero mean.

solver AbstractSolverStrategy | None

Any `gaussx.AbstractSolverStrategy`. Defaults to `gaussx.DenseSolver()` — swap for `CGSolver`, `BBMMSolver`, `ComposedSolver(solve_strategy=..., logdet_strategy=...)`, etc.

jitter float

Diagonal regularization added to the prior covariance for numerical stability. Not a noise model — use `noise_var` on `condition` for that.

Source code in src/pyrox/gp/_models.py
class GPPrior(eqx.Module):
    """Finite-dimensional GP prior over a fixed training input set.

    Holds a kernel, training inputs ``X``, an optional mean function, an
    optional solver strategy, and a small diagonal jitter for numerical
    stability on otherwise-singular prior covariances.

    Attributes:
        kernel: Any :class:`pyrox.gp.Kernel` — evaluated on ``X``.
        X: Training inputs of shape ``(N, D)``.
        mean_fn: Callable ``X -> (N,)`` or ``None`` for the zero mean.
        solver: Any ``gaussx.AbstractSolverStrategy``. ``None`` resolves
            to ``gaussx.DenseSolver()`` — swap for ``CGSolver``,
            ``BBMMSolver``, ``ComposedSolver(solve_strategy=...,
            logdet_strategy=...)``, etc.
        jitter: Diagonal regularization added to the prior covariance
            for numerical stability. Not a noise model — use
            ``noise_var`` on :meth:`condition` for that.
    """

    kernel: Kernel
    X: Float[Array, "N D"]
    mean_fn: Callable[[Float[Array, "N D"]], Float[Array, " N"]] | None = None
    solver: AbstractSolverStrategy | None = None
    jitter: float = 1e-6

    def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Evaluate the mean function at ``X``; zero (in ``X.dtype``) by default."""
        if self.mean_fn is None:
            return jnp.zeros(X.shape[0], dtype=X.dtype)
        return self.mean_fn(X)

    def _regularized_operator(
        self, diag_shift: float | Float[Array, ""]
    ) -> lx.AbstractLinearOperator:
        """Return ``K(X, X) + diag_shift * I`` wrapped as a PSD operator.

        Single place where the training covariance is assembled, so the
        prior and noisy operators cannot drift apart.
        """
        K = self.kernel(self.X, self.X)
        return _psd_operator(K + diag_shift * jnp.eye(K.shape[0], dtype=K.dtype))

    def _prior_operator(self) -> lx.AbstractLinearOperator:
        """Prior covariance ``K(X, X) + jitter * I``."""
        return self._regularized_operator(self.jitter)

    def _noisy_operator(self, noise_var: Float[Array, ""]) -> lx.AbstractLinearOperator:
        """Observation covariance ``K(X, X) + (jitter + noise_var) * I``."""
        return self._regularized_operator(self.jitter + noise_var)

    def _resolved_solver(self) -> AbstractSolverStrategy:
        """Return the configured solver, or ``DenseSolver()`` when unset."""
        return DenseSolver() if self.solver is None else self.solver  # ty: ignore[invalid-return-type]

    def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
        r"""Marginal log-density of ``f`` under the GP prior.

        Computes :math:`\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I)`
        using :func:`gaussx.log_marginal_likelihood`, so any solver strategy
        on this prior applies.
        """
        return log_marginal_likelihood(
            self.mean(self.X),
            self._prior_operator(),
            f,
            solver=self._resolved_solver(),
        )

    def sample(self, key: Array) -> Float[Array, " N"]:
        r"""Draw ``f \sim p(f) = \mathcal{N}(\mu(X), K + \text{jitter}\,I)``.

        Wraps the prior in a :class:`gaussx.MultivariateNormal` with
        the configured :attr:`solver`. This is the non-NumPyro analogue
        of :func:`gp_sample` — useful for tests, diagnostics, and
        prior-sample initialization without registering a sample site.
        """
        op = self._prior_operator()
        loc = self.mean(self.X)
        mvn = MultivariateNormal(loc, op, solver=self._resolved_solver())
        return mvn.sample(key)

    def condition(
        self,
        y: Float[Array, " N"],
        noise_var: Float[Array, ""],
    ) -> ConditionedGP:
        """Condition on Gaussian-likelihood observations ``y``.

        Precomputes
        ``alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X))`` and
        caches it in the returned :class:`ConditionedGP`. The same
        ``jitter`` regularization configured on this prior is included
        alongside ``noise_var`` in the conditioned operator and solve, so
        every downstream predict / sample call sees the regularized
        covariance.

        The operator construction and any subsequent hyperparameter
        capture share one :func:`_kernel_context`, so for Pattern B/C
        kernels with priors the cached operator and the resolved
        hyperparameters on the returned :class:`ConditionedGP` come from
        the same draw. Downstream consumers (notably
        :class:`pyrox.gp.PathwiseSampler`) reuse those values to stay
        consistent with the cached operator.
        """
        with _kernel_context(self.kernel):
            operator = self._noisy_operator(noise_var)
            resolved_hyperparams = _resolve_kernel_hyperparams(self.kernel)
        residual = y - self.mean(self.X)
        cache = build_prediction_cache(
            operator, residual, solver=self._resolved_solver()
        )
        return ConditionedGP(  # ty: ignore[invalid-return-type]
            prior=self,
            y=y,
            noise_var=noise_var,
            cache=cache,
            operator=operator,
            resolved_hyperparams=resolved_hyperparams,
        )

    def condition_nongauss(
        self,
        likelihood: Likelihood,
        y: Float[Array, " N"],
        *,
        strategy: _NonGaussStrategy,
    ) -> NonGaussConditionedGP:
        """Condition on a non-Gaussian likelihood via a site-based strategy.

        Convenience that forwards to ``strategy.fit(self, likelihood, y)``.
        Pick any of the site-based strategies in
        :mod:`pyrox.gp._inference_nongauss`:
        :class:`pyrox.gp.LaplaceInference`,
        :class:`pyrox.gp.GaussNewtonInference`,
        :class:`pyrox.gp.PosteriorLinearization`,
        :class:`pyrox.gp.ExpectationPropagation`, or
        :class:`pyrox.gp.QuasiNewtonInference`. Returns a
        :class:`pyrox.gp.NonGaussConditionedGP` with the same
        ``predict`` / ``predict_mean`` / ``predict_var`` API as the
        Gaussian-likelihood :class:`ConditionedGP`.

        Example::

            from pyrox.gp import (
                BernoulliLikelihood,
                ExpectationPropagation,
                GPPrior,
                RBF,
            )

            prior = GPPrior(kernel=RBF(), X=X)
            cond = prior.condition_nongauss(
                BernoulliLikelihood(), y,
                strategy=ExpectationPropagation(),
            )
            mean, var = cond.predict(X_star)
        """
        return strategy.fit(self, likelihood, y)

condition(y, noise_var)

Condition on Gaussian-likelihood observations y.

Precomputes `alpha = (K + (jitter + noise_var) * I)^{-1} (y - mu(X))` and caches it in the returned `ConditionedGP`. The same `jitter` regularization configured on this prior is included alongside `noise_var` in the conditioned operator and solve, so every downstream predict / sample call sees the regularized covariance.

The operator construction and any subsequent hyperparameter capture share one `_kernel_context`, so for Pattern B/C kernels with priors the cached operator and the resolved hyperparameters on the returned `ConditionedGP` come from the same draw. Downstream consumers (notably `pyrox.gp.PathwiseSampler`) reuse those values to stay consistent with the cached operator.

Source code in src/pyrox/gp/_models.py
def condition(
    self,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> ConditionedGP:
    """Build a :class:`ConditionedGP` from Gaussian-likelihood targets ``y``.

    The training system is ``A = K(X, X) + (jitter + noise_var) * I``; the
    prior's ``jitter`` is kept alongside ``noise_var``. The vector
    ``alpha = A^{-1} (y - mu(X))`` is precomputed into a prediction cache,
    so every later predict / sample call reuses the regularized solve.

    ``A`` and the resolved kernel hyperparameters are produced inside one
    :func:`_kernel_context`. For Pattern B/C kernels with priors, both
    therefore come from the same draw, which downstream consumers such as
    :class:`pyrox.gp.PathwiseSampler` reuse to stay consistent with the
    cached operator.
    """
    with _kernel_context(self.kernel):
        train_op = self._noisy_operator(noise_var)
        hyperparams = _resolve_kernel_hyperparams(self.kernel)
    centered_y = y - self.mean(self.X)
    prediction_cache = build_prediction_cache(
        train_op, centered_y, solver=self._resolved_solver()
    )
    return ConditionedGP(  # ty: ignore[invalid-return-type]
        prior=self,
        y=y,
        noise_var=noise_var,
        cache=prediction_cache,
        operator=train_op,
        resolved_hyperparams=hyperparams,
    )

condition_nongauss(likelihood, y, *, strategy)

Condition on a non-Gaussian likelihood via a site-based strategy.

Convenience that forwards to `strategy.fit(self, likelihood, y)`. Pick any of the site-based strategies in `pyrox.gp._inference_nongauss`: `pyrox.gp.LaplaceInference`, `pyrox.gp.GaussNewtonInference`, `pyrox.gp.PosteriorLinearization`, `pyrox.gp.ExpectationPropagation`, or `pyrox.gp.QuasiNewtonInference`. Returns a `pyrox.gp.NonGaussConditionedGP` with the same `predict` / `predict_mean` / `predict_var` API as the Gaussian-likelihood `ConditionedGP`.

Example::

from pyrox.gp import (
    BernoulliLikelihood,
    ExpectationPropagation,
    GPPrior,
    RBF,
)

prior = GPPrior(kernel=RBF(), X=X)
likelihood = BernoulliLikelihood()
strategy = ExpectationPropagation()
cond = prior.condition_nongauss(likelihood, y, strategy=strategy)
mean, var = cond.predict(X_star)
Source code in src/pyrox/gp/_models.py
def condition_nongauss(
    self,
    likelihood: Likelihood,
    y: Float[Array, " N"],
    *,
    strategy: _NonGaussStrategy,
) -> NonGaussConditionedGP:
    """Fit a non-Gaussian likelihood to ``y`` with a site-based strategy.

    Thin delegation: the returned object is whatever
    ``strategy.fit(self, likelihood, y)`` produces. The site-based
    strategies live in :mod:`pyrox.gp._inference_nongauss`:

    * :class:`pyrox.gp.LaplaceInference`
    * :class:`pyrox.gp.GaussNewtonInference`
    * :class:`pyrox.gp.PosteriorLinearization`
    * :class:`pyrox.gp.ExpectationPropagation`
    * :class:`pyrox.gp.QuasiNewtonInference`

    The result is a :class:`pyrox.gp.NonGaussConditionedGP`, which offers
    the same ``predict`` / ``predict_mean`` / ``predict_var`` surface as
    :class:`ConditionedGP`.

    Example::

        from pyrox.gp import (
            BernoulliLikelihood,
            ExpectationPropagation,
            GPPrior,
            RBF,
        )

        prior = GPPrior(kernel=RBF(), X=X)
        cond = prior.condition_nongauss(
            BernoulliLikelihood(), y,
            strategy=ExpectationPropagation(),
        )
        mean, var = cond.predict(X_star)
    """
    fit = strategy.fit
    return fit(self, likelihood, y)

log_prob(f)

Marginal log-density of f under the GP prior.

Computes $\log \mathcal{N}(f \mid \mu(X),\, K(X, X) + \text{jitter}\,I)$ using `gaussx.log_marginal_likelihood`, so any solver strategy on this prior applies.

Source code in src/pyrox/gp/_models.py
def log_prob(self, f: Float[Array, " N"]) -> Float[Array, ""]:
    r"""Log-density of ``f`` under the jittered GP prior.

    Evaluates :math:`\log \mathcal{N}(f \mid \mu(X), K(X, X) + \text{jitter}\,I)`
    through :func:`gaussx.log_marginal_likelihood`, passing this prior's
    resolved solver strategy so the configured solver is honoured.
    """
    loc = self.mean(self.X)
    covariance = self._prior_operator()
    strategy = self._resolved_solver()
    return log_marginal_likelihood(loc, covariance, f, solver=strategy)

mean(X)

Evaluate the mean function at X; zero by default.

Source code in src/pyrox/gp/_models.py
def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    """Prior mean at ``X``: ``mean_fn(X)`` if set, else zeros in ``X.dtype``."""
    mean_fn = self.mean_fn
    if mean_fn is not None:
        return mean_fn(X)
    return jnp.zeros(X.shape[0], dtype=X.dtype)

sample(key)

Draw $f \sim p(f) = \mathcal{N}(\mu(X),\, K + \text{jitter}\,I)$.

Wraps the prior in a `gaussx.MultivariateNormal` with the configured `solver`. This is the non-NumPyro analogue of `gp_sample` — useful for tests, diagnostics, and prior-sample initialization without registering a sample site.

Source code in src/pyrox/gp/_models.py
def sample(self, key: Array) -> Float[Array, " N"]:
    r"""Draw one ``f \sim \mathcal{N}(\mu(X), K + \text{jitter}\,I)`` outside NumPyro.

    Builds a :class:`gaussx.MultivariateNormal` from the jittered prior
    operator and the resolved :attr:`solver`, then samples it with
    ``key``. Unlike :func:`gp_sample`, no NumPyro site is registered,
    which makes this handy for tests, diagnostics and initialization.
    """
    covariance = self._prior_operator()
    loc = self.mean(self.X)
    prior_mvn = MultivariateNormal(loc, covariance, solver=self._resolved_solver())
    return prior_mvn.sample(key)

pyrox.gp.ConditionedGP

Bases: Module

GP conditioned on Gaussian-likelihood training observations.

Holds the precomputed training solve `alpha` (via `gaussx.PredictionCache`) and the noisy covariance operator so predictions at multiple test sets reuse the training solve.

Source code in src/pyrox/gp/_models.py
class ConditionedGP(eqx.Module):
    """GP conditioned on Gaussian-likelihood training observations.

    Holds the precomputed training solve ``alpha`` (via
    :class:`gaussx.PredictionCache`) and the regularized training
    covariance operator, so predictions at multiple test sets reuse the
    training solve.

    Attributes:
        prior: The :class:`GPPrior` the observations were conditioned on.
        y: Training targets of shape ``(N,)``.
        noise_var: Observation-noise variance used when conditioning.
        cache: Prediction cache holding the training solve ``alpha``.
        operator: Training covariance used for the solve; when built by
            :meth:`GPPrior.condition` this is
            ``K + (jitter + noise_var) * I``.
        resolved_hyperparams: Kernel hyperparameters captured in the same
            kernel context as ``operator``, or ``None``.
    """

    prior: GPPrior
    y: Float[Array, " N"]
    noise_var: Float[Array, ""]
    cache: PredictionCache
    operator: lx.AbstractLinearOperator
    # NOTE(review): annotated as a 2-tuple; confirm what
    # `_resolve_kernel_hyperparams` returns for kernels with more
    # hyperparameters (e.g. Periodic, RationalQuadratic).
    resolved_hyperparams: tuple[Float[Array, ""], Float[Array, ""]] | None = None

    def predict_mean(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
        r""":math:`\mu_* = \mu(X_*) + K_{*f}\,\alpha`."""
        with _kernel_context(self.prior.kernel):
            K_cross = self.prior.kernel(X_star, self.prior.X)
        return self.prior.mean(X_star) + predict_mean(self.cache, K_cross)

    def predict_var(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
        r"""Diagonal predictive variance at ``X_*``.

        .. math::
            \sigma^2_{*,i} = k(x_{*,i}, x_{*,i})
                - K_{*f}[i,:] \cdot A^{-1} K_{f*}[:,i]

        where ``A`` is the cached :attr:`operator` — for a result of
        :meth:`GPPrior.condition`, ``K + (jitter + noise_var) I``, i.e. the
        prior's jitter is included alongside the noise variance.

        ``K_cross`` and ``K_diag`` are computed under one shared kernel
        context so Pattern B / C kernels with prior'd hyperparameters
        register their NumPyro sites once and reuse them across both
        kernel calls (and the cached training solve).
        """
        with _kernel_context(self.prior.kernel):
            K_cross = self.prior.kernel(X_star, self.prior.X)
            K_diag = self.prior.kernel.diag(X_star)
        return predict_variance(
            K_cross,
            K_diag,
            self.operator,
            solver=self.prior._resolved_solver(),
        )

    def predict(
        self, X_star: Float[Array, "M D"]
    ) -> tuple[Float[Array, " M"], Float[Array, " M"]]:
        """Return ``(mean, variance)`` at ``X_*`` as a tuple.

        Both kernel evaluations share a single kernel context; see
        :meth:`predict_var`.
        """
        with _kernel_context(self.prior.kernel):
            return self.predict_mean(X_star), self.predict_var(X_star)

    def sample(
        self,
        key: Array,
        X_star: Float[Array, "M D"],
        n_samples: int = 1,
    ) -> Float[Array, "S M"]:
        """Sample from the diagonal predictive ``N(mean, diag(var))``.

        Returns an ``(n_samples, M)`` array of samples drawn independently
        per test point; correlated joint samples from the full predictive
        covariance are not covered by the Wave 2 dense surface. For
        correlated samples, build the full predictive covariance
        explicitly and draw from :class:`gaussx.MultivariateNormal`.
        """
        with _kernel_context(self.prior.kernel):
            mean = self.predict_mean(X_star)
            var = self.predict_var(X_star)
        # Clamp tiny negative variances (round-off) before the sqrt.
        std = jnp.sqrt(jnp.clip(var, min=0.0))
        eps = jax.random.normal(key, (n_samples, X_star.shape[0]), dtype=mean.dtype)
        return einsum(std, eps, "m, s m -> s m") + mean

predict(X_star)

Return (mean, variance) at X_* as a tuple.

Both kernel evaluations share a single kernel context; see `predict_var`.

Source code in src/pyrox/gp/_models.py
def predict(
    self, X_star: Float[Array, "M D"]
) -> tuple[Float[Array, " M"], Float[Array, " M"]]:
    """Predictive ``(mean, variance)`` at ``X_*``.

    Both quantities are evaluated under one outer kernel context, so the
    kernel calls inside :meth:`predict_mean` and :meth:`predict_var`
    share it.
    """
    with _kernel_context(self.prior.kernel):
        mu = self.predict_mean(X_star)
        sigma2 = self.predict_var(X_star)
    return mu, sigma2

predict_mean(X_star)

$\mu_* = \mu(X_*) + K_{*f}\,\alpha$.

Source code in src/pyrox/gp/_models.py
def predict_mean(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
    r"""Predictive mean :math:`\mu_* = \mu(X_*) + K_{*f}\,\alpha` using the cached solve."""
    kernel = self.prior.kernel
    with _kernel_context(kernel):
        K_cross = kernel(X_star, self.prior.X)
    prior_mean = self.prior.mean(X_star)
    return prior_mean + predict_mean(self.cache, K_cross)

predict_var(X_star)

Diagonal predictive variance at X_*.

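$$\sigma^2_{*,i} = k(x_{*,i}, x_{*,i}) - K_{*f}[i,:] \cdot (K + \sigma^2 I)^{-1} K_{f*}[:,i]$$

Here $\sigma^2 I$ stands for the full diagonal regularization of the cached training operator, i.e. $(\text{jitter} + \text{noise\_var})\,I$.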

K_cross and K_diag are computed under one shared kernel context so Pattern B / C kernels with prior'd hyperparameters register their NumPyro sites once and reuse them across both kernel calls (and the cached training solve).

Source code in src/pyrox/gp/_models.py
def predict_var(self, X_star: Float[Array, "M D"]) -> Float[Array, " M"]:
    r"""Diagonal predictive variance at ``X_*``.

    .. math::
        \sigma^2_{*,i} = k(x_{*,i}, x_{*,i})
            - K_{*f}[i,:] \cdot A^{-1} K_{f*}[:,i]

    where ``A`` is the cached :attr:`operator` — for a result of
    :meth:`GPPrior.condition`, ``K + (jitter + noise_var) I``, i.e. the
    prior's jitter is included alongside the noise variance.

    ``K_cross`` and ``K_diag`` are computed under one shared kernel
    context so Pattern B / C kernels with prior'd hyperparameters
    register their NumPyro sites once and reuse them across both
    kernel calls (and the cached training solve).
    """
    with _kernel_context(self.prior.kernel):
        K_cross = self.prior.kernel(X_star, self.prior.X)
        K_diag = self.prior.kernel.diag(X_star)
    return predict_variance(
        K_cross,
        K_diag,
        self.operator,
        solver=self.prior._resolved_solver(),
    )

sample(key, X_star, n_samples=1)

Sample from the diagonal predictive N(mean, diag(var)).

Returns samples independently per test point; correlated joint samples from the full predictive covariance are not covered by the Wave 2 dense surface. For correlated samples, build the full predictive covariance explicitly and draw from `gaussx.MultivariateNormal`.

Source code in src/pyrox/gp/_models.py
def sample(
    self,
    key: Array,
    X_star: Float[Array, "M D"],
    n_samples: int = 1,
) -> Float[Array, "S M"]:
    """Draw ``n_samples`` rows from the factorized predictive ``N(mean, diag(var))``.

    Each test point is sampled independently, so cross-point correlations
    are ignored. That is the scope of the Wave 2 dense surface. For
    correlated joint draws, assemble the full predictive covariance and
    sample it with :class:`gaussx.MultivariateNormal`.
    """
    with _kernel_context(self.prior.kernel):
        mu = self.predict_mean(X_star)
        sigma2 = self.predict_var(X_star)
    # Negative variances can only come from round-off; clamp before sqrt.
    sd = jnp.sqrt(jnp.clip(sigma2, min=0.0))
    n_test = X_star.shape[0]
    noise = jax.random.normal(key, (n_samples, n_test), dtype=mu.dtype)
    return einsum(sd, noise, "m, s m -> s m") + mu

pyrox.gp.gp_factor(name, prior, y, noise_var)

Register the collapsed GP log marginal likelihood with NumPyro.

Adds $\log p(y \mid X, \theta) = \log \mathcal{N}(y \mid \mu,\, K + (\text{jitter} + \sigma^2) I)$ to the NumPyro trace as `numpyro.factor(name, ...)`. The prior's `jitter` is included in addition to the observation noise variance so the covariance matches what `GPPrior.condition` builds. Use this inside a NumPyro model when the likelihood is Gaussian and you want the latent function marginalized analytically.

Source code in src/pyrox/gp/_models.py
def gp_factor(
    name: str,
    prior: GPPrior,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> None:
    """Add the collapsed (f-marginalized) GP evidence as a NumPyro factor.

    The registered term is
    ``log p(y | X, theta) = log N(y | mu, K + (jitter + sigma^2) I)``,
    recorded via ``numpyro.factor(name, ...)``. The prior's ``jitter`` is
    added on top of the observation noise, so the covariance is the same
    one :meth:`GPPrior.condition` builds. Use inside a NumPyro model with
    a Gaussian likelihood.
    """
    loc = prior.mean(prior.X)
    covariance = prior._noisy_operator(noise_var)
    strategy = prior._resolved_solver()
    evidence = log_marginal_likelihood(loc, covariance, y, solver=strategy)
    numpyro.factor(name, evidence)

pyrox.gp.gp_sample(name, prior, *, whitened=False, guide=None)

Sample a latent function f at the prior's training inputs.

Three mutually exclusive modes:

  • `whitened=False`, `guide=None` (default) — register a single `numpyro.sample(name, MVN(mu, K + jitter I))` site. The latent function is sampled directly from the prior.
  • `whitened=True`, `guide=None` — register a unit-normal latent site `f"{name}_u"` with shape `(N,)` and return the deterministic value $f = \mu(X) + L u$, where $L$ is the Cholesky factor of $K + \text{jitter}\,I$. This reparameterization is the standard fix for mean-field SVI on GP-correlated latents (Murray & Adams, 2010): a NumPyro auto-guide such as `numpyro.infer.autoguide.AutoNormal` then approximates the well-conditioned isotropic posterior over $u$ instead of the ill-conditioned correlated posterior over $f$.
  • `guide` provided — delegate to `guide.register(name, prior)`. Concrete variational guides (Wave 3) own their own parameterization, so combining `whitened=True` with `guide` is rejected with a `ValueError`.

Use this inside a NumPyro model for non-conjugate likelihoods, where the latent function cannot be marginalized analytically.

Source code in src/pyrox/gp/_models.py
def gp_sample(
    name: str,
    prior: GPPrior,
    *,
    whitened: bool = False,
    guide: object | None = None,
) -> Float[Array, " N"]:
    r"""Draw the latent function ``f`` at ``prior.X`` inside a NumPyro model.

    Exactly one of three modes applies:

    * Default (``whitened=False``, ``guide=None``): a single
      ``numpyro.sample(name, MVN(mu, K + jitter I))`` site, built from the
      prior's mean, jittered prior operator and resolved solver.
    * ``whitened=True`` without a guide: a standard-normal site
      ``f"{name}_u"`` of shape ``(N,)`` is sampled and mapped to
      ``f = mu(X) + L u`` with ``L`` the Cholesky factor of
      ``K + jitter I``. The result is recorded with
      ``numpyro.deterministic(name, f)``. Whitening lets mean-field
      autoguides (e.g. :class:`numpyro.infer.autoguide.AutoNormal`) work
      on a well-conditioned isotropic ``u`` rather than the strongly
      correlated ``f`` (Murray & Adams, 2010).
    * ``guide`` given: returns ``guide.register(name, prior)``. Concrete
      guides (Wave 3) own their parameterization, so passing
      ``whitened=True`` as well raises ``ValueError``.

    Meant for non-conjugate likelihoods, where ``f`` cannot be integrated
    out analytically.
    """
    if guide is not None and whitened:
        raise ValueError(
            "gp_sample: cannot combine `whitened=True` with `guide=...`. "
            "Provide one or the other; concrete guides own their own "
            "parameterization."
        )
    if guide is not None:
        return guide.register(name, prior)  # type: ignore[attr-defined]  # ty: ignore[unresolved-attribute]

    if not whitened:
        return numpyro.sample(  # ty: ignore[invalid-return-type]
            name,
            MultivariateNormal(
                prior.mean(prior.X),
                prior._prior_operator(),
                solver=prior._resolved_solver(),
            ),
        )

    # Keep the original evaluation order (kernel, then the `u` site, then the
    # mean): kernels / mean functions may register their own sites, and site
    # order can affect seeded draws.
    chol = cholesky(prior._prior_operator())
    num_points = prior.X.shape[0]
    dtype = prior.X.dtype
    standard_normal = dist.Normal(
        jnp.zeros(num_points, dtype=dtype), jnp.ones((), dtype=dtype)
    ).to_event(1)
    u = numpyro.sample(f"{name}_u", standard_normal)
    f = prior.mean(prior.X) + unwhiten(jnp.asarray(u), chol)
    return numpyro.deterministic(name, f)  # ty: ignore[invalid-return-type]

Concrete kernels

Each Parameterized kernel registers its hyperparameters with positivity (or, for `bias`, nonnegativity) constraints where appropriate. Attach priors with `set_prior`, autoguides with `autoguide`, and switch modes with `set_mode("model" | "guide")`.

pyrox.gp.RBF

Bases: _ParameterizedKernel

Radial basis function (squared exponential) kernel.

Source code in src/pyrox/gp/_kernels.py
class RBF(_ParameterizedKernel):
    """Radial basis function (squared-exponential) kernel.

    Registers two trainable hyperparameters, both constrained positive:
    ``variance`` (initialized from ``init_variance``) and ``lengthscale``
    (initialized from ``init_lengthscale``). Evaluation delegates to
    ``_k.rbf_kernel``.
    """

    pyrox_name: str = "RBF"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0

    def setup(self) -> None:
        positive = dist.constraints.positive
        for param_name, init_value in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
        ):
            self.register_param(param_name, jnp.asarray(init_value), constraint=positive)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        return _k.rbf_kernel(X1, X2, variance, lengthscale)

pyrox.gp.Matern

Bases: _ParameterizedKernel

Matern kernel with nu in {0.5, 1.5, 2.5}.

nu is a static class attribute — it selects a code path in the underlying math primitive and is not a trainable parameter.

Source code in src/pyrox/gp/_kernels.py
class Matern(_ParameterizedKernel):
    """Matern kernel with ``nu in {0.5, 1.5, 2.5}``.

    ``variance`` and ``lengthscale`` are trainable and constrained
    positive. ``nu`` is a fixed configuration value passed straight to
    ``_k.matern_kernel``, where it selects a code path; it is not
    registered as a trainable parameter.
    """

    pyrox_name: str = "Matern"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0
    # NOTE(review): `nu` is a plain field here (not marked static); confirm it
    # stays a concrete Python float under jax transformations if the math
    # primitive branches on it.
    nu: float = 2.5

    def setup(self) -> None:
        positive = dist.constraints.positive
        for param_name, init_value in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
        ):
            self.register_param(param_name, jnp.asarray(init_value), constraint=positive)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        return _k.matern_kernel(X1, X2, variance, lengthscale, self.nu)

pyrox.gp.Periodic

Bases: _ParameterizedKernel

Periodic (MacKay) kernel.

Source code in src/pyrox/gp/_kernels.py
class Periodic(_ParameterizedKernel):
    """Periodic (MacKay) kernel.

    Registers three trainable hyperparameters, all constrained positive:
    ``variance``, ``lengthscale`` and ``period`` (initialized from the
    matching ``init_*`` fields). Evaluation delegates to
    ``_k.periodic_kernel``.
    """

    pyrox_name: str = "Periodic"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0
    init_period: float = 1.0

    def setup(self) -> None:
        positive = dist.constraints.positive
        for param_name, init_value in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
            ("period", self.init_period),
        ):
            self.register_param(param_name, jnp.asarray(init_value), constraint=positive)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        period = self.get_param("period")
        return _k.periodic_kernel(X1, X2, variance, lengthscale, period)

pyrox.gp.Linear

Bases: _ParameterizedKernel

Linear kernel sigma^2 x^T x' + bias.

`bias` is constrained nonnegative because $K = \sigma^2 X X^\top + b\,\mathbf{1}\mathbf{1}^\top$ is only PSD for $b \ge 0$ (e.g. $X = 0$ gives the eigenvalue $N b$).

Source code in src/pyrox/gp/_kernels.py
class Linear(_ParameterizedKernel):
    """Linear kernel ``sigma^2 x^T x' + bias``.

    ``variance`` is constrained positive. ``bias`` is constrained
    nonnegative because ``k = sigma^2 X X^T + b 1 1^T`` is only PSD for
    ``b >= 0`` (e.g. ``X = 0`` gives eigenvalue ``N*b``).
    """

    pyrox_name: str = "Linear"
    init_variance: float = 1.0
    init_bias: float = 0.0

    def setup(self) -> None:
        for param_name, init_value, constraint in (
            ("variance", self.init_variance, dist.constraints.positive),
            ("bias", self.init_bias, dist.constraints.nonnegative),
        ):
            self.register_param(param_name, jnp.asarray(init_value), constraint=constraint)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        return _k.linear_kernel(X1, X2, variance, bias)

    def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Return ``variance * |X[i]|^2 + bias`` for each row of ``X``."""
        # Non-stationary: the diagonal depends on each row's squared norm,
        # so it cannot be a constant like for stationary kernels.
        # NOTE(review): unlike __call__, this method is not wrapped with
        # @pyrox_method; confirm get_param behaves the same outside it.
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        sq_norms = jnp.sum(X * X, axis=-1)
        return variance * sq_norms + bias

pyrox.gp.RationalQuadratic

Bases: _ParameterizedKernel

Rational quadratic kernel.

Source code in src/pyrox/gp/_kernels.py
class RationalQuadratic(_ParameterizedKernel):
    """Rational quadratic kernel.

    Holds three positive hyperparameters — ``variance``, ``lengthscale``
    and ``alpha`` — and delegates the Gram-matrix computation to
    ``_k.rational_quadratic_kernel``.
    """

    pyrox_name: str = "RationalQuadratic"
    init_variance: float = 1.0
    init_lengthscale: float = 1.0
    init_alpha: float = 1.0

    def setup(self) -> None:
        """Register ``variance``, ``lengthscale`` and ``alpha`` (all positive)."""
        positive = dist.constraints.positive
        for name, initial in (
            ("variance", self.init_variance),
            ("lengthscale", self.init_lengthscale),
            ("alpha", self.init_alpha),
        ):
            self.register_param(name, jnp.asarray(initial), constraint=positive)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        """Gram matrix ``k(X1, X2)`` of shape ``(N1, N2)``."""
        variance = self.get_param("variance")
        lengthscale = self.get_param("lengthscale")
        alpha = self.get_param("alpha")
        return _k.rational_quadratic_kernel(X1, X2, variance, lengthscale, alpha)

pyrox.gp.Polynomial

Bases: _ParameterizedKernel

Polynomial kernel sigma^2 (x^T x' + bias)^degree.

`degree` is a static class field (it selects an integer power, not an optimization target). `bias` is constrained nonnegative: the `degree=1` case reduces to `Linear` and has the same failure mode, because the kernel is only PSD when `b >= 0`.

Source code in src/pyrox/gp/_kernels.py
class Polynomial(_ParameterizedKernel):
    """Polynomial kernel ``sigma^2 (x^T x' + bias)^degree``.

    ``degree`` is a static class field (it selects an integer power, not
    an optimization target) and must be a nonnegative integer; any other
    value raises ``ValueError`` when the kernel or its diagonal is
    evaluated. ``bias`` is constrained nonnegative — the ``degree=1`` case
    reduces to :class:`Linear` and has the same PSD-requires-``b>=0``
    failure mode.
    """

    pyrox_name: str = "Polynomial"
    init_variance: float = 1.0
    init_bias: float = 0.0
    degree: int = 2

    def setup(self) -> None:
        # variance > 0, bias >= 0 (see the class docstring for why b >= 0).
        self.register_param(
            "variance",
            jnp.asarray(self.init_variance),
            constraint=dist.constraints.positive,
        )
        self.register_param(
            "bias",
            jnp.asarray(self.init_bias),
            constraint=dist.constraints.nonnegative,
        )

    def _checked_degree(self) -> int:
        """Return ``degree`` after checking it is a nonnegative integer.

        A negative power of ``x^T x' + b`` is not PSD in general, and a
        fractional power of a negative base evaluates to NaN; both would
        otherwise only surface later as a failed factorization.

        Raises:
            ValueError: If ``degree`` is negative or not integer-valued.
        """
        degree = self.degree
        if degree < 0 or int(degree) != degree:
            raise ValueError(
                f"Polynomial degree must be a nonnegative integer; got {degree!r}."
            )
        return degree

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        """Gram matrix of shape ``(N1, N2)`` (power applied elementwise)."""
        degree = self._checked_degree()
        return _k.polynomial_kernel(
            X1,
            X2,
            self.get_param("variance"),
            self.get_param("bias"),
            degree,
        )

    def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Per-row ``sigma^2 (||x_i||^2 + bias)^degree`` without the full Gram."""
        degree = self._checked_degree()
        v = self.get_param("variance")
        b = self.get_param("bias")
        return v * (jnp.sum(X * X, axis=-1) + b) ** degree

pyrox.gp.Cosine

Bases: _ParameterizedKernel

Cosine kernel sigma^2 cos(2 pi ||x - x'|| / period).

Source code in src/pyrox/gp/_kernels.py
class Cosine(_ParameterizedKernel):
    """Cosine kernel ``sigma^2 cos(2 pi ||x - x'|| / period)``.

    Both hyperparameters (``variance`` and ``period``) are constrained
    positive; evaluation is delegated to ``_k.cosine_kernel``.
    """

    pyrox_name: str = "Cosine"
    init_variance: float = 1.0
    init_period: float = 1.0

    def setup(self) -> None:
        """Register ``variance`` then ``period``, both positive."""
        positive = dist.constraints.positive
        for name, initial in (
            ("variance", self.init_variance),
            ("period", self.init_period),
        ):
            self.register_param(name, jnp.asarray(initial), constraint=positive)

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        """Gram matrix ``k(X1, X2)`` of shape ``(N1, N2)``."""
        variance = self.get_param("variance")
        period = self.get_param("period")
        return _k.cosine_kernel(X1, X2, variance, period)

pyrox.gp.White

Bases: _ParameterizedKernel

White-noise kernel sigma^2 delta(x, x').

Source code in src/pyrox/gp/_kernels.py
class White(_ParameterizedKernel):
    """White-noise kernel ``sigma^2 delta(x, x')``.

    A single positive ``variance`` hyperparameter; the Gram matrix is built
    by ``_k.white_kernel``.
    """

    pyrox_name: str = "White"
    init_variance: float = 1.0

    def setup(self) -> None:
        """Register the positive ``variance`` hyperparameter."""
        initial = jnp.asarray(self.init_variance)
        self.register_param(
            "variance", initial, constraint=dist.constraints.positive
        )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        """Gram matrix ``k(X1, X2)`` of shape ``(N1, N2)``."""
        variance = self.get_param("variance")
        return _k.white_kernel(X1, X2, variance)

pyrox.gp.Constant

Bases: _ParameterizedKernel

Constant kernel k(x, x') = sigma^2.

Source code in src/pyrox/gp/_kernels.py
class Constant(_ParameterizedKernel):
    """Constant kernel ``k(x, x') = sigma^2``.

    A single positive ``variance`` hyperparameter; the Gram matrix is built
    by ``_k.constant_kernel``.
    """

    pyrox_name: str = "Constant"
    init_variance: float = 1.0

    def setup(self) -> None:
        """Register the positive ``variance`` hyperparameter."""
        initial = jnp.asarray(self.init_variance)
        self.register_param(
            "variance", initial, constraint=dist.constraints.positive
        )

    @pyrox_method
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        """Gram matrix ``k(X1, X2)`` of shape ``(N1, N2)``."""
        variance = self.get_param("variance")
        return _k.constant_kernel(X1, X2, variance)

Sparse-GP inducing features (#49)

Inter-domain inducing-feature families for building scalable sparse GPs in which the inducing-prior covariance `K_uu` is diagonal. Pass any of these to `SparseGPPrior` through the `inducing=` keyword instead of a raw inducing-point matrix `Z`.

from pyrox.gp import RBF, FourierInducingFeatures, SparseGPPrior

kernel = RBF(init_lengthscale=0.3, init_variance=1.0)
features = FourierInducingFeatures.init(
    in_features=1, num_basis_per_dim=64, L=5.0
)
# Fourier features make K_uu diagonal, so solves against it are elementwise.
prior = SparseGPPrior(kernel=kernel, inducing=features)

pyrox.gp.InducingFeatures

Bases: Protocol

Protocol for inter-domain inducing features.

Implementations expose the inducing-prior covariance `K_uu` and the cross-covariance `k_ux(X)` between data points and inducing features. Implementations whose `K_uu` is diagonal return a `lineax.DiagonalLinearOperator`, so the downstream solve dispatches to elementwise division.

Input shape is family-dependent. k_ux takes a batch of data points X in whatever representation the family consumes:

  • `FourierInducingFeatures`: coordinates `(N, D)`.
  • `SphericalHarmonicInducingFeatures`: unit vectors `(N, 3)`.
  • `LaplacianInducingFeatures`: integer node indices `(N,)`.

Each implementation validates its own expected shape and dtype.

Source code in src/pyrox/gp/_inducing.py
@runtime_checkable
class InducingFeatures(Protocol):
    """Protocol for inter-domain inducing features.

    An implementation supplies the inducing-prior covariance ``K_uu`` and
    the cross-covariance ``k_ux(X)`` between data points and inducing
    features. Implementations whose ``K_uu`` is diagonal return a
    :class:`lineax.DiagonalLinearOperator`, so the downstream solve
    dispatches to elementwise division.

    **The input representation depends on the family.** ``k_ux`` receives a
    batch of data points ``X`` in whatever form that family consumes:

    - :class:`FourierInducingFeatures`: coordinates ``(N, D)``.
    - :class:`SphericalHarmonicInducingFeatures`: unit vectors ``(N, 3)``.
    - :class:`LaplacianInducingFeatures`: integer node indices ``(N,)``.

    Each implementation is responsible for validating its own expected
    shape and dtype.
    """

    @property
    def num_features(self) -> int:
        """Number of inducing features ``M``."""
        ...

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.AbstractLinearOperator:
        """Inducing-prior covariance as an ``(M, M)`` operator.

        The concrete families in this module add ``jitter`` to its diagonal.
        """
        ...

    def k_ux(self, x: Array, kernel: Kernel) -> Float[Array, "N M"]:
        """Cross-covariance between the ``N`` inputs in ``x`` and the ``M`` features."""
        ...

pyrox.gp.FourierInducingFeatures

Bases: Module

VFF — Variational Fourier inducing features on $[-L, L]^D$.

For a stationary kernel with spectral density $S(\cdot)$, the Laplacian eigenfunctions $\{\phi_j\}$ on the box, with eigenvalues $\lambda_j$, give

$$
K_{uu} = \mathrm{diag}\big(S(\sqrt{\lambda_j})\big),
\qquad
K_{ux}(x)_j = S(\sqrt{\lambda_j})\,\phi_j(x).
$$

With this convention $\big(K_{ux}(x) K_{uu}^{-1}\big)_j = \phi_j(x)$ (up to the diagonal jitter), so the SVGP predictive mean reduces to a basis evaluation. `K_uu` is returned as a `lineax.DiagonalLinearOperator` to preserve the O(M) solve dispatch end-to-end.

Attributes:

Name Type Description
in_features int

Input dimension $D$.

num_basis_per_dim tuple[int, ...]

Per-axis number of 1D eigenfunctions; total count is prod(num_basis_per_dim).

L tuple[float, ...]

Per-axis box half-width.

Source code in src/pyrox/gp/_inducing.py
class FourierInducingFeatures(eqx.Module):
    r"""VFF — Variational Fourier inducing features on :math:`[-L, L]^D`.

    Given a stationary kernel with spectral density :math:`S(\cdot)` and
    the Laplacian eigenpairs :math:`(\lambda_j, \phi_j)` of the box,

    .. math::

        K_{uu} = \mathrm{diag}\!\big(S(\sqrt{\lambda_j})\big),
        \qquad
        K_{ux}(x)_j = S(\sqrt{\lambda_j})\,\phi_j(x).

    Consequently :math:`K_{ux} K_{uu}^{-1}` evaluates the basis
    :math:`\phi_j(x)` (up to the diagonal jitter), and the SVGP predictive
    mean is a plain basis expansion. ``K_{uu}`` is returned as a
    :class:`lineax.DiagonalLinearOperator` so downstream solves stay O(M).

    Attributes:
        in_features: Input dimension :math:`D`.
        num_basis_per_dim: Number of 1D eigenfunctions along each axis; the
            total feature count is ``prod(num_basis_per_dim)``.
        L: Box half-width along each axis.
    """

    in_features: int = eqx.field(static=True)
    num_basis_per_dim: tuple[int, ...] = eqx.field(static=True)
    L: tuple[float, ...] = eqx.field(static=True)

    @classmethod
    def init(
        cls,
        in_features: int,
        num_basis_per_dim: int | tuple[int, ...],
        L: float | tuple[float, ...],
    ) -> FourierInducingFeatures:
        """Validated constructor.

        ``_to_tuple`` normalizes each scalar-or-tuple setting to a per-axis
        tuple (presumably broadcasting scalars to all ``in_features`` axes).

        Raises:
            ValueError: If any half-width is not positive, or any per-axis
                basis count is below 1 (checked in that order).
        """
        counts = _to_tuple(num_basis_per_dim, in_features, "num_basis_per_dim")
        half_widths = _to_tuple(L, in_features, "L")
        if any(width <= 0 for width in half_widths):
            raise ValueError(f"L must be all positive; got {half_widths}.")
        if any(count < 1 for count in counts):
            raise ValueError(f"num_basis_per_dim must be all >= 1; got {counts}.")
        return cls(
            in_features=in_features,
            num_basis_per_dim=counts,
            L=tuple(map(float, half_widths)),
        )

    @property
    def num_features(self) -> int:
        """Total feature count, ``prod(num_basis_per_dim)``."""
        total = 1
        for count in self.num_basis_per_dim:
            total = total * count
        return total

    def _check_stationary(self, kernel: Kernel) -> None:
        """Raise ``ValueError`` unless ``kernel`` passes ``_is_stationary``."""
        if _is_stationary(kernel):
            return
        kernel_name = type(kernel).__name__
        raise ValueError(
            "FourierInducingFeatures requires a stationary kernel with a "
            "registered spectral density (RBF or Matern); got "
            f"{kernel_name}."
        )

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.DiagonalLinearOperator:
        """Diagonal :math:`K_{uu}` — entries ``S(sqrt(lambda_j))`` plus jitter."""
        self._check_stationary(kernel)
        with _kernel_context(kernel):
            eigvals = fourier_eigenvalues(
                self.num_basis_per_dim, self.L, self.in_features
            )
            density = spectral_density(kernel, eigvals, D=self.in_features)
        return _diagonal_with_jitter(density, jitter)

    def k_ux(self, x: Float[Array, "N D"], kernel: Kernel) -> Float[Array, "N M"]:
        r"""Cross-covariance entries :math:`S(\sqrt{\lambda_j})\,\phi_j(x)`.

        Raises:
            ValueError: If the kernel is not stationary, or ``x`` is not a
                2D array with ``in_features`` columns.
        """
        self._check_stationary(kernel)
        is_matrix = x.ndim == 2
        if not is_matrix or x.shape[-1] != self.in_features:
            raise ValueError(f"x must be (N, {self.in_features}); got shape {x.shape}.")
        with _kernel_context(kernel):
            basis, eigvals = fourier_basis(x, self.num_basis_per_dim, self.L)
            density = spectral_density(kernel, eigvals, D=self.in_features)
        # Scale each basis column j by its spectral weight S(sqrt(lambda_j)).
        return basis * density[None, :]

K_uu(kernel, *, jitter=1e-06)

Diagonal :math:K_{uu} — entries S(sqrt(lambda_j)) plus jitter.

Source code in src/pyrox/gp/_inducing.py
def K_uu(
    self, kernel: Kernel, *, jitter: float = 1e-6
) -> lx.DiagonalLinearOperator:
    """Diagonal :math:`K_{uu}` — entries ``S(sqrt(lambda_j))`` plus jitter."""
    self._check_stationary(kernel)
    with _kernel_context(kernel):
        eigvals = fourier_eigenvalues(
            self.num_basis_per_dim, self.L, self.in_features
        )
        density = spectral_density(kernel, eigvals, D=self.in_features)
    return _diagonal_with_jitter(density, jitter)

k_ux(x, kernel)

Cross-covariance entries :math:S(\sqrt{\lambda_j})\,\phi_j(x).

Source code in src/pyrox/gp/_inducing.py
def k_ux(self, x: Float[Array, "N D"], kernel: Kernel) -> Float[Array, "N M"]:
    r"""Cross-covariance entries :math:`S(\sqrt{\lambda_j})\,\phi_j(x)`.

    Raises:
        ValueError: If the kernel is not stationary, or ``x`` is not a
            2D array with ``in_features`` columns.
    """
    self._check_stationary(kernel)
    is_matrix = x.ndim == 2
    if not is_matrix or x.shape[-1] != self.in_features:
        raise ValueError(f"x must be (N, {self.in_features}); got shape {x.shape}.")
    with _kernel_context(kernel):
        basis, eigvals = fourier_basis(x, self.num_basis_per_dim, self.L)
        density = spectral_density(kernel, eigvals, D=self.in_features)
    # Scale each basis column j by its spectral weight S(sqrt(lambda_j)).
    return basis * density[None, :]

pyrox.gp.SphericalHarmonicInducingFeatures

Bases: Module

VISH — inducing harmonics on $S^2$ (Dutordoir et al. 2020).

For any zonal kernel $k(x, x') = \kappa(x \cdot x')$ on the unit 2-sphere, the Funk-Hecke theorem gives a diagonal $K_{uu}$ whose eigenvalues are the kernel's Funk-Hecke coefficients $a_l$. The cross-covariance is $a_l\,Y_{lm}(x)$.

Funk-Hecke coefficients are computed by Gauss-Legendre quadrature, so arbitrary kernels are supported and no closed form is required. For kernels with a closed-form Funk-Hecke series (e.g. RBF on $S^2$ via Bessel functions), the numerical and analytic answers should agree to within the quadrature tolerance.

Attributes:

Name Type Description
l_max int

Maximum harmonic degree, inclusive.

num_quadrature int

Gauss-Legendre nodes for the Funk-Hecke integral.

Source code in src/pyrox/gp/_inducing.py
class SphericalHarmonicInducingFeatures(eqx.Module):
    r"""VISH — inducing harmonics on :math:`S^2` (Dutordoir et al. 2020).

    For a zonal kernel :math:`k(x, x') = \kappa(x \cdot x')` on the unit
    2-sphere, the Funk-Hecke theorem makes :math:`K_{uu}` diagonal: its
    entries are the kernel's Funk-Hecke coefficients :math:`a_l`, repeated
    for each of the ``2l + 1`` harmonics of degree :math:`l`. The
    cross-covariance is :math:`a_l\,Y_{lm}(x)`.

    The coefficients are obtained by Gauss-Legendre quadrature, so any
    kernel works without a closed form. Where a closed-form Funk-Hecke
    series exists (e.g. RBF on S² via Bessel functions), the numerical and
    analytic values should agree to within the quadrature tolerance.

    Attributes:
        l_max: Largest harmonic degree included (inclusive).
        num_quadrature: Number of Gauss-Legendre nodes for the Funk-Hecke
            integral.
    """

    l_max: int = eqx.field(static=True)
    num_quadrature: int = eqx.field(static=True, default=256)

    @classmethod
    def init(
        cls, l_max: int, *, num_quadrature: int = 256
    ) -> SphericalHarmonicInducingFeatures:
        """Validated constructor.

        Raises:
            ValueError: If ``l_max < 0`` or ``num_quadrature < 1``.
        """
        if l_max < 0:
            raise ValueError(f"l_max must be >= 0; got {l_max}.")
        if num_quadrature < 1:
            raise ValueError(f"num_quadrature must be >= 1; got {num_quadrature}.")
        return cls(l_max=l_max, num_quadrature=num_quadrature)

    @property
    def num_features(self) -> int:
        """``(l_max + 1)**2`` — degree ``l`` contributes ``2l + 1`` harmonics."""
        num_degrees = self.l_max + 1
        return num_degrees * num_degrees

    def _per_feature_coeffs(self, kernel: Kernel) -> Float[Array, " M"]:
        """Expand the per-degree coefficients ``a_l`` to one entry per feature."""
        per_degree = funk_hecke_coefficients(
            kernel, self.l_max, num_quadrature=self.num_quadrature
        )
        # NOTE(review): assumes real_spherical_harmonics orders its columns
        # degree-major (all of l=0, then l=1, ...) to match this layout.
        blocks = [
            jnp.full((2 * ell + 1,), per_degree[ell])
            for ell in range(self.l_max + 1)
        ]
        return jnp.concatenate(blocks)

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.DiagonalLinearOperator:
        """Diagonal :math:`K_{uu}` — Funk-Hecke coefficients per harmonic."""
        return _diagonal_with_jitter(self._per_feature_coeffs(kernel), jitter)

    def k_ux(
        self,
        unit_xyz: Float[Array, "N 3"],
        kernel: Kernel,
    ) -> Float[Array, "N M"]:
        r"""Cross-covariance: :math:`a_l\,Y_{lm}(x)`.

        Raises:
            ValueError: If ``unit_xyz`` is not a 2D array with 3 columns.
        """
        well_formed = unit_xyz.ndim == 2 and unit_xyz.shape[-1] == 3
        if not well_formed:
            raise ValueError(f"unit_xyz must be (N, 3); got {unit_xyz.shape}.")
        harmonics = real_spherical_harmonics(unit_xyz, self.l_max)
        weights = self._per_feature_coeffs(kernel)
        return harmonics * weights[None, :]

K_uu(kernel, *, jitter=1e-06)

Diagonal :math:K_{uu} — Funk-Hecke coefficients per harmonic.

Source code in src/pyrox/gp/_inducing.py
def K_uu(
    self, kernel: Kernel, *, jitter: float = 1e-6
) -> lx.DiagonalLinearOperator:
    """Diagonal :math:`K_{uu}` — Funk-Hecke coefficients per harmonic."""
    return _diagonal_with_jitter(self._per_feature_coeffs(kernel), jitter)

k_ux(unit_xyz, kernel)

Cross-covariance: :math:a_l\,Y_{lm}(x).

Source code in src/pyrox/gp/_inducing.py
def k_ux(
    self,
    unit_xyz: Float[Array, "N 3"],
    kernel: Kernel,
) -> Float[Array, "N M"]:
    r"""Cross-covariance: :math:`a_l\,Y_{lm}(x)`.

    Raises:
        ValueError: If ``unit_xyz`` is not a 2D array with 3 columns.
    """
    well_formed = unit_xyz.ndim == 2 and unit_xyz.shape[-1] == 3
    if not well_formed:
        raise ValueError(f"unit_xyz must be (N, 3); got {unit_xyz.shape}.")
    harmonics = real_spherical_harmonics(unit_xyz, self.l_max)
    weights = self._per_feature_coeffs(kernel)
    return harmonics * weights[None, :]

pyrox.gp.LaplacianInducingFeatures

Bases: Module

Inducing features from low-frequency graph Laplacian eigenvectors.

For a graph with normalized Laplacian $L$, take the smallest `num_basis` eigenpairs $(\mu_j, v_j)$. Treating the kernel as a function of graph distance — specifically, applying the kernel spectrum $g(\mu)$ to the Laplacian eigenvalues — gives a diagonal $K_{uu}$.

This implementation supports the heat-kernel family obtained from `pyrox.gp.RBF`. It evaluates `pyrox._basis.spectral_density` directly on the eigenvalues, i.e. $g(\mu) = S(\sqrt{\mu})$, the same convention used for the Fourier features. For RBF this is proportional to $\exp(-\ell^2 \mu / 2)$.

Attributes:

Name Type Description
eigvals Float[Array, ' M']

(M,) Laplacian eigenvalues.

eigvecs Float[Array, 'V M']

(V, M) Laplacian eigenvectors.

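num_quadrature (not a field)

Not stored on this class; only eigvals and eigvecs are. It is mentioned for uniformity with SphericalHarmonicInducingFeatures and has no effect here.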

Note

X is a vector of node indices (integer-valued), not coordinates. The returned cross-covariance gathers the relevant rows of eigvecs.

Source code in src/pyrox/gp/_inducing.py
class LaplacianInducingFeatures(eqx.Module):
    r"""Inducing features from low-frequency graph Laplacian eigenvectors.

    For a graph with Laplacian :math:`L` (normalized by default), take the
    smallest ``num_basis`` eigenpairs :math:`(\mu_j, v_j)`. Applying the
    kernel *spectrum* :math:`g(\mu)` to the Laplacian eigenvalues gives

    .. math::

        K_{uu} = \mathrm{diag}\big(g(\mu_j)\big),
        \qquad
        K_{ux}(i)_j = g(\mu_j)\, v_j[i],

    where :math:`i` is a node index.

    The spectrum is the kernel's registered spectral density evaluated
    directly on the eigenvalues with ``D=1`` — the same
    :math:`g(\mu) = S(\sqrt{\mu})` convention that
    :class:`FourierInducingFeatures` applies to its box eigenvalues. For
    :class:`pyrox.gp.RBF` this should be a heat-kernel spectrum,
    proportional to :math:`\exp(-\ell^2 \mu / 2)`.

    Attributes:
        eigvals: ``(M,)`` Laplacian eigenvalues.
        eigvecs: ``(V, M)`` Laplacian eigenvectors, one column per eigenpair.

    Note:
        ``k_ux`` takes a 1D integer array of *node indices*, not
        coordinates, and gathers the matching rows of ``eigvecs``. JAX
        gathers clamp out-of-range indices instead of raising, so an index
        ``>= V`` silently maps to the last node.
    """

    eigvals: Float[Array, " M"]
    eigvecs: Float[Array, "V M"]

    @classmethod
    def fit(
        cls,
        adjacency: Float[Array, "V V"],
        num_basis: int,
        *,
        normalized: bool = True,
    ) -> LaplacianInducingFeatures:
        """Build features from the graph Laplacian of ``adjacency``.

        Args:
            adjacency: ``(V, V)`` adjacency matrix.
            num_basis: Number of eigenpairs ``M`` to keep.
            normalized: Forwarded to ``graph_laplacian_eigpairs`` to select
                the normalized Laplacian.
        """
        eigvals, eigvecs = graph_laplacian_eigpairs(
            adjacency, num_basis, normalized=normalized
        )
        return cls(eigvals=eigvals, eigvecs=eigvecs)

    @property
    def num_features(self) -> int:
        """Number of retained eigenpairs ``M``."""
        return int(self.eigvals.shape[0])

    def _check_stationary(self, kernel: Kernel) -> None:
        """Raise ``ValueError`` unless ``kernel`` passes ``_is_stationary``."""
        if not _is_stationary(kernel):
            raise ValueError(
                "LaplacianInducingFeatures requires a stationary kernel with a "
                f"registered spectral density; got {type(kernel).__name__}."
            )

    def _spectrum(self, kernel: Kernel) -> Float[Array, " M"]:
        """Kernel spectral density evaluated at each Laplacian eigenvalue."""
        with _kernel_context(kernel):
            return spectral_density(kernel, self.eigvals, D=1)

    def K_uu(
        self, kernel: Kernel, *, jitter: float = 1e-6
    ) -> lx.DiagonalLinearOperator:
        """Diagonal :math:`K_{uu}` — ``g(mu_j)`` per eigenpair, plus jitter.

        Raises:
            ValueError: If the kernel is not stationary.
        """
        self._check_stationary(kernel)
        return _diagonal_with_jitter(self._spectrum(kernel), jitter)

    def k_ux(
        self, node_indices: Int[Array, " N"], kernel: Kernel
    ) -> Float[Array, "N M"]:
        """Cross-covariance ``g(mu_j) * eigvecs[i, j]`` for each node index ``i``.

        Raises:
            ValueError: If the kernel is not stationary, or ``node_indices``
                is not a 1D integer array.
        """
        self._check_stationary(kernel)
        if node_indices.ndim != 1:
            raise ValueError(
                "node_indices must be a 1D integer array; got shape "
                f"{node_indices.shape}."
            )
        # The protocol requires each family to validate its own dtype; float
        # indices would otherwise fail deep inside the JAX gather.
        if not jnp.issubdtype(node_indices.dtype, jnp.integer):
            raise ValueError(
                "node_indices must be a 1D integer array; got dtype "
                f"{node_indices.dtype}."
            )
        S = self._spectrum(kernel)
        rows = self.eigvecs[node_indices]
        return rows * S[None, :]

pyrox.gp.DecoupledInducingFeatures

Bases: Module

Decoupled mean / covariance inducing-feature bases (Cheng & Boots 2017).

Two independent inducing-feature sets:

  • mean_features: a large alpha-basis used by the SVGP posterior mean (cheap — predictive mean cost is linear in the mean-basis size).
  • cov_features: a small beta-basis used for the posterior covariance (the true bottleneck; keep this small).

The two bases need not share the same family — a common pattern is a large Fourier basis for the mean and a small spherical-harmonic basis for the covariance, or vice versa. The downstream guide consumes both via the standard SVGP machinery.

Attributes:

Name Type Description
mean_features InducingFeatures

Inducing-feature object backing the predictive mean.

cov_features InducingFeatures

Inducing-feature object backing the predictive covariance.

Note

`DecoupledInducingFeatures` itself does not implement `InducingFeatures`, because no single `K_uu` makes sense for two bases. Consumers should access `.mean_features` and `.cov_features` directly.

Source code in src/pyrox/gp/_inducing.py
class DecoupledInducingFeatures(eqx.Module):
    r"""Decoupled mean / covariance inducing-feature bases (Cheng & Boots 2017).

    Bundles two independent inducing-feature sets:

    - ``mean_features``: a large ``alpha``-basis for the SVGP posterior
      *mean*, which is cheap because predictive-mean cost grows only
      linearly with the mean-basis size.
    - ``cov_features``: a small ``beta``-basis for the posterior
      *covariance*, which is the real bottleneck, so keep it small.

    The two bases may come from different families — for example a large
    Fourier basis for the mean and a small spherical-harmonic basis for the
    covariance, or the reverse. Downstream guides consume both through the
    standard SVGP machinery.

    Attributes:
        mean_features: Inducing-feature object backing the predictive mean.
        cov_features: Inducing-feature object backing the predictive covariance.

    Note:
        This class does *not* itself satisfy :class:`InducingFeatures`,
        because no single ``K_uu`` is meaningful for two bases. Use
        ``.mean_features`` and ``.cov_features`` directly.
    """

    mean_features: InducingFeatures
    cov_features: InducingFeatures

    @property
    def num_mean_features(self) -> int:
        """Feature count of the mean (``alpha``) basis."""
        basis = self.mean_features
        return basis.num_features

    @property
    def num_cov_features(self) -> int:
        """Feature count of the covariance (``beta``) basis."""
        basis = self.cov_features
        return basis.num_features

pyrox.gp.funk_hecke_coefficients(kernel, l_max, *, num_quadrature=256)

Funk-Hecke coefficients of a zonal kernel on $S^2$.

For a kernel of the form $k(x, x') = \kappa(x \cdot x')$ on the unit 2-sphere, the Funk-Hecke theorem gives

$$
a_l = 2\pi \int_{-1}^{1} \kappa(t)\,P_l(t)\,dt,
$$

where $P_l$ is the Legendre polynomial of degree $l$.

Returns `(l_max + 1,)` coefficients indexed by `l`. Any Euclidean kernel is treated as zonal on the sphere via $\kappa(t) = k_{\mathrm{euc}}(\hat{n}_0, \hat{n}_t)$, where $\hat{n}_0$ and $\hat{n}_t$ are unit vectors at angular separation $\arccos(t)$.

Source code in src/pyrox/gp/_inducing.py
def funk_hecke_coefficients(
    kernel: Kernel,
    l_max: int,
    *,
    num_quadrature: int = 256,
) -> Float[Array, " l_max_plus_1"]:
    r"""Funk-Hecke coefficients of a zonal kernel on :math:`S^2`.

    For a kernel of the form :math:`k(x, x') = \kappa(x \cdot x')` on the
    unit 2-sphere, the Funk-Hecke theorem gives:

    .. math::

        a_l = 2\pi \int_{-1}^{1} \kappa(t)\,P_l(t)\,dt,

    where :math:`P_l` is the Legendre polynomial of degree :math:`l`.

    Returns ``(l_max + 1,)`` coefficients indexed by ``l``. We treat any
    Euclidean kernel as zonal-on-the-sphere via
    :math:`\kappa(t) = k_{\mathrm{euc}}(\hat{n}_0, \hat{n}_t)` for unit
    vectors at angular separation ``arccos(t)``.

    Raises:
        ValueError: If ``l_max < 0`` (the result would otherwise silently
            have one entry instead of ``l_max + 1``) or ``num_quadrature < 1``.
    """
    # Same checks and messages as SphericalHarmonicInducingFeatures.init.
    if l_max < 0:
        raise ValueError(f"l_max must be >= 0; got {l_max}.")
    if num_quadrature < 1:
        raise ValueError(f"num_quadrature must be >= 1; got {num_quadrature}.")
    # Gauss-Legendre quadrature nodes on [-1, 1] (host-side setup constants).
    t, w = _gauss_legendre_nodes(num_quadrature)
    # Build pairs of unit vectors: x0 = (0, 0, 1), x_t = (sin(arccos t), 0, t).
    sin_t = jnp.sqrt(jnp.maximum(1.0 - t**2, 0.0))
    n0 = jnp.array([0.0, 0.0, 1.0])
    nT = jnp.stack([sin_t, jnp.zeros_like(t), t], axis=-1)  # (Q, 3)
    # Single batched kernel call — stays on-device and keeps autodiff edges
    # to any hyperparameters sampled inside ``kernel``. Taking row 0 of the
    # ``(1, Q)`` Gram is O(Q), not O(Q^2).
    with _kernel_context(kernel):
        kt = kernel(n0[None, :], nT)[0]  # (Q,)
    coeffs = [2.0 * jnp.pi * jnp.sum(w * kt)]  # a_0 (P_0 = 1)
    if l_max >= 1:
        coeffs.append(2.0 * jnp.pi * jnp.sum(w * kt * t))  # a_1 (P_1 = t)
    # p_older = P_{ell-2}, p_old = P_{ell-1} at the top of each iteration.
    p_older, p_old = jnp.ones_like(t), t
    for ell in range(2, l_max + 1):
        # Bonnet recurrence: ell P_ell = (2ell - 1) t P_{ell-1} - (ell - 1) P_{ell-2}.
        p_ell = ((2 * ell - 1) * t * p_old - (ell - 1) * p_older) / ell
        coeffs.append(2.0 * jnp.pi * jnp.sum(w * kt * p_ell))
        p_older, p_old = p_old, p_ell
    return jnp.stack(coeffs, axis=0)

Sparse GP prior

pyrox.gp.SparseGPPrior

Bases: Module

GP prior parameterized over inducing inputs Z.

Represents the zero-mean prior over inducing values $u = f(Z)$ used by sparse variational guides:

$$
p(u) = \mathcal{N}\big(0,\; K_{ZZ} + \mathrm{jitter}\,I\big).
$$

The standard SVGP convention is to subtract any global mean function before forming the prior over `u` and to add it back at prediction time. The inducing-prior mean is therefore fixed to zero, which is what the guides' KL terms assume (see `FullRankGuide.kl_divergence`, `MeanFieldGuide.kl_divergence` and `WhitenedGuide.kl_divergence`). The `mean_fn` attribute on this class is a convenience for callers who want to add $\mu(X_*)$ back onto the predictive mean returned by `Guide.predict`. It is not incorporated into `inducing_operator` or into the guides' KL.

Pair with a sparse variational guide that owns $q(u) = \mathcal{N}(m, S)$ to obtain the standard SVGP predictive

$$
\mu_*(x) = K_{xZ} K_{ZZ}^{-1} m, \qquad
\sigma_*^2(x) = k(x, x) - K_{xZ} K_{ZZ}^{-1} K_{Zx} + K_{xZ} K_{ZZ}^{-1} S K_{ZZ}^{-1} K_{Zx}.
$$

Attributes:

Name Type Description
kernel Kernel

Any `pyrox.gp.Kernel`, evaluated on `Z`.

Z Float[Array, 'M D'] | None

Inducing inputs of shape (M, D).

mean_fn Callable[[Float[Array, 'N D']], Float[Array, ' N']] | None

Callable X -> (N,) or None for the zero mean. Convenience accessor; not folded into the inducing prior.

solver AbstractSolverStrategy | None

Any `gaussx.AbstractSolverStrategy`. Defaults to `gaussx.DenseSolver()`. Used by guides that need to solve against `K_zz`, e.g. for the KL term or for unwhitening.

jitter float

Diagonal regularization added to K_zz for numerical stability. Not a noise model — sparse SVGP does not put observation noise on the inducing-value prior.

Source code in src/pyrox/gp/_sparse.py
class SparseGPPrior(eqx.Module):
    r"""GP prior parameterized over inducing inputs ``Z`` or inducing features.

    Represents the *zero-mean* prior over inducing values ``u = f(Z)``
    used by sparse variational guides:

    .. math::

        p(u) = \mathcal{N}(0,\, K_{ZZ} + \mathrm{jitter}\,I).

    The standard SVGP convention is to subtract any global mean function
    before forming the prior over ``u`` and to add it back at predict
    time, so the inducing-prior mean is fixed to zero (this is what the
    guides' KL terms assume — see :meth:`FullRankGuide.kl_divergence`,
    :meth:`MeanFieldGuide.kl_divergence`, :meth:`WhitenedGuide.kl_divergence`).
    The :attr:`mean_fn` attribute on this class is exposed as a
    convenience for callers that want to add ``mu(X_*)`` back onto the
    predictive mean returned by :meth:`Guide.predict`; it is **not**
    incorporated in :meth:`inducing_operator` or in the guides' KL.

    Pair with a sparse variational guide that owns ``q(u) = N(m, S)`` to
    obtain the standard SVGP predictive

    .. math::

        \mu_*(x) = K_{xZ} K_{ZZ}^{-1} m, \qquad
        \sigma_*^2(x) = k(x, x) - K_{xZ} K_{ZZ}^{-1} K_{Zx}
                       + K_{xZ} K_{ZZ}^{-1} S K_{ZZ}^{-1} K_{Zx}.

    Exactly one of ``Z`` (point inducing) or ``inducing`` (inducing
    features) must be supplied; :meth:`__check_init__` raises
    ``ValueError`` otherwise. With ``inducing``, the ``K_{ZZ}`` and
    ``K_{XZ}`` blocks above come from the features' ``K_uu`` and
    ``k_ux`` methods instead of direct kernel evaluations on ``Z``.

    Attributes:
        kernel: Any :class:`pyrox.gp.Kernel` — evaluated on ``Z``, or
            passed to the ``inducing`` features' ``K_uu`` / ``k_ux``.
        Z: Inducing inputs of shape ``(M, D)``; ``None`` when
            ``inducing`` is used instead.
        inducing: :class:`InducingFeatures` defining the inducing
            variables; ``None`` when ``Z`` is used instead. ``M`` is its
            ``num_features``.
        mean_fn: Callable ``X -> (N,)`` or ``None`` for the zero mean.
            Convenience accessor; not folded into the inducing prior.
        solver: Any ``gaussx.AbstractSolverStrategy``. Defaults to
            ``gaussx.DenseSolver()``. Used by guides that need to solve
            against ``K_zz`` (e.g.\ for KL or unwhitening).
        jitter: Diagonal regularization added to ``K_zz`` for numerical
            stability. Not a noise model — sparse SVGP does not put
            observation noise on the inducing-value prior.
    """

    # Covariance function shared by every block this prior builds.
    kernel: Kernel
    # Point-inducing inputs; mutually exclusive with ``inducing``.
    Z: Float[Array, "M D"] | None = None
    # Inducing-feature family; mutually exclusive with ``Z``.
    inducing: InducingFeatures | None = None
    # Optional mean; only used by ``mean`` (never by the inducing prior).
    mean_fn: Callable[[Float[Array, "N D"]], Float[Array, " N"]] | None = None
    # ``None`` resolves to ``DenseSolver()`` in ``_resolved_solver``.
    solver: AbstractSolverStrategy | None = None
    # Diagonal ridge added to K_zz (dense ``+ jitter * I`` or via ``K_uu``).
    jitter: float = 1e-6

    def __check_init__(self) -> None:
        """Require exactly one of ``Z`` / ``inducing`` to be set.

        Raises:
            ValueError: if both or neither of ``Z`` and ``inducing`` are given.
        """
        if (self.Z is None) == (self.inducing is None):
            raise ValueError(
                "SparseGPPrior must be constructed with exactly one of `Z` "
                "(point inducing) or `inducing` (inducing features)."
            )

    @property
    def num_inducing(self) -> int:
        """Number of inducing inputs / features ``M``."""
        if self.inducing is not None:
            return self.inducing.num_features
        # Z is non-None here because __check_init__ enforces exactly one.
        assert self.Z is not None
        return self.Z.shape[0]

    def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Evaluate the mean function at ``X``; zero by default."""
        # NOTE(review): the zero mean inherits X's dtype, so integer
        # node-index inputs (Laplacian features) give an integer zero
        # vector — confirm callers promote to float as intended.
        if self.mean_fn is None:
            return jnp.zeros(X.shape[0], dtype=X.dtype)
        return self.mean_fn(X)

    def inducing_operator(self) -> lx.AbstractLinearOperator:
        r"""Return ``K_{ZZ} + \text{jitter}\,I`` as a ``lineax`` operator.

        For point-inducing priors, returns a dense
        :class:`lineax.MatrixLinearOperator` with ``positive_semidefinite_tag``.
        For inducing-feature priors, delegates to
        :meth:`InducingFeatures.K_uu` — typically a
        :class:`lineax.DiagonalLinearOperator` so the downstream
        :func:`gaussx.solve` dispatches in O(M) instead of O(M^3).

        Single kernel call; safe standalone for kernels with priors. For
        building several SVGP blocks together, prefer
        :meth:`predictive_blocks`, which scopes one shared kernel
        context across ``K_zz``, ``K_xz``, and ``K_xx_diag`` so
        Pattern B / C kernels register their NumPyro hyperparameter
        sites once instead of resampling per call.
        """
        if self.inducing is not None:
            with _kernel_context(self.kernel):
                return self.inducing.K_uu(self.kernel, jitter=self.jitter)
        assert self.Z is not None
        with _kernel_context(self.kernel):
            K = self.kernel(self.Z, self.Z)
        # The jitter ridge is pure array math, so it is added outside the
        # kernel context.
        K = K + self.jitter * jnp.eye(K.shape[0], dtype=K.dtype)
        return _psd_operator(K)

    def cross_covariance(self, X: Array) -> Float[Array, "N M"]:
        r""":math:`K_{XZ}` — covariance between ``X`` and the inducing inputs/features.

        The expected shape of ``X`` is inducing-family-dependent:

        - Point-inducing (``Z``) or :class:`FourierInducingFeatures`:
          coordinates ``(N, D)``.
        - :class:`SphericalHarmonicInducingFeatures`: unit vectors ``(N, 3)``.
        - :class:`LaplacianInducingFeatures`: integer node indices ``(N,)``.

        See :meth:`predictive_blocks` for the shared-context batch
        helper to use when assembling several SVGP blocks together.
        """
        if self.inducing is not None:
            with _kernel_context(self.kernel):
                return self.inducing.k_ux(X, self.kernel)
        assert self.Z is not None
        with _kernel_context(self.kernel):
            return self.kernel(X, self.Z)

    def kernel_diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        r"""Prior diagonal ``\mathrm{diag}\,K(X, X)`` — variance at each ``x``.

        See :meth:`predictive_blocks` for the shared-context batch
        helper to use when assembling several SVGP blocks together.
        """
        with _kernel_context(self.kernel):
            return self.kernel.diag(X)

    def predictive_blocks(
        self, X: Array
    ) -> tuple[
        lx.AbstractLinearOperator,
        Float[Array, "N M"],
        Float[Array, " N"],
    ]:
        r"""Return ``(K_zz_op, K_xz, K_xx_diag)`` under one shared kernel context.

        For Pattern B / C kernels with prior'd hyperparameters, the three
        kernel evaluations needed for an SVGP predictive must share a
        single :class:`pyrox.PyroxModule` context so the underlying
        ``pyrox_sample`` sites register once and yield consistent
        hyperparameter draws across ``K_{ZZ}``, ``K_{XZ}``, and the
        diagonal ``\mathrm{diag}\,K(X, X)``. Without this scoping, three
        separate calls would draw three independent hyperparameter
        samples (under seed) or raise NumPyro duplicate-site errors
        (under tracing) — either way invalidating the SVGP math.

        For pure :class:`equinox.Module` kernels (no ``_get_context``),
        this is equivalent to calling :meth:`inducing_operator`,
        :meth:`cross_covariance`, and :meth:`kernel_diag` independently.

        For inducing-feature priors, ``K_zz_op`` is a
        :class:`lineax.DiagonalLinearOperator` (jitter folded into the
        diagonal vector — never ``+ jnp.eye``) so the downstream solve
        stays O(M).
        """
        # Every kernel evaluation below must stay inside this single
        # context; see the docstring for why.
        with _kernel_context(self.kernel):
            if self.inducing is not None:
                K_zz_op = self.inducing.K_uu(self.kernel, jitter=self.jitter)
                K_xz = self.inducing.k_ux(X, self.kernel)
            else:
                assert self.Z is not None
                K_zz_raw = self.kernel(self.Z, self.Z)
                K_xz = self.kernel(X, self.Z)
                K_zz = K_zz_raw + self.jitter * jnp.eye(
                    K_zz_raw.shape[0], dtype=K_zz_raw.dtype
                )
                K_zz_op = _psd_operator(K_zz)
            K_xx_diag = self.kernel.diag(X)
        return K_zz_op, K_xz, K_xx_diag

    def _resolved_solver(self) -> AbstractSolverStrategy:
        """Return :attr:`solver`, or a fresh ``DenseSolver()`` when it is ``None``."""
        return DenseSolver() if self.solver is None else self.solver  # ty: ignore[invalid-return-type]

    def log_prob(self, u: Float[Array, " M"]) -> Float[Array, ""]:
        r"""Log-density under :math:`p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I)`.

        Delegates to :func:`gaussx.gaussian_log_prob` with the
        configured :attr:`solver` so the user-supplied solver controls
        the ``solve`` / ``logdet`` work on ``K_zz_op``. Useful for
        scoring inducing values against the SVGP prior in non-NumPyro
        contexts (e.g.\ tests, diagnostics).
        """
        # Zero location per the SVGP convention (mean_fn is not used here).
        m = jnp.zeros(self.num_inducing, dtype=u.dtype)
        return gaussian_log_prob(
            m, self.inducing_operator(), u, solver=self._resolved_solver()
        )

    def sample(self, key: Array) -> Float[Array, " M"]:
        r"""Draw ``u \sim p(u)`` from the inducing prior.

        Wraps the inducing operator in a
        :class:`gaussx.MultivariateNormal` with the configured
        :attr:`solver`. ``MultivariateNormal.sample`` factors the
        covariance via :func:`gaussx.cholesky` and reparameterizes;
        the returned draw has shape ``(M,)``.

        Note: the SVGP variational workflow samples ``u`` from the
        *guide* :math:`q(u)`, not the prior. This method exists so the
        prior surface is symmetric with the guide surface and so users
        can score / draw inducing values against the prior directly
        (e.g.\ for tests or for prior-sample initialization).
        """
        n = self.num_inducing
        op = self.inducing_operator()
        # Match the location dtype to the covariance operator's output dtype.
        loc = jnp.zeros(n, dtype=op.out_structure().dtype)
        mvn = MultivariateNormal(loc, op, solver=self._resolved_solver())
        return mvn.sample(key)

num_inducing property

Number of inducing inputs / features M.

cross_covariance(X)

$K_{XZ}$ — covariance between X and the inducing inputs/features.

The expected shape of X is inducing-family-dependent:

  • Point-inducing (Z) or :class:FourierInducingFeatures: coordinates (N, D).
  • :class:SphericalHarmonicInducingFeatures: unit vectors (N, 3).
  • :class:LaplacianInducingFeatures: integer node indices (N,).

See :meth:predictive_blocks for the shared-context batch helper to use when assembling several SVGP blocks together.

Source code in src/pyrox/gp/_sparse.py
def cross_covariance(self, X: Array) -> Float[Array, "N M"]:
    r"""Cross-covariance :math:`K_{XZ}` between ``X`` and the inducing set.

    How ``X`` must be shaped depends on the inducing family in use:

    - point inducing (``Z``) or :class:`FourierInducingFeatures`:
      coordinates of shape ``(N, D)``;
    - :class:`SphericalHarmonicInducingFeatures`: unit vectors ``(N, 3)``;
    - :class:`LaplacianInducingFeatures`: integer node indices ``(N,)``.

    The kernel is evaluated inside its own kernel context. To build
    several SVGP blocks from one consistent hyperparameter draw, use
    :meth:`predictive_blocks` instead.
    """
    features = self.inducing
    if features is None:
        # Point-inducing path: evaluate the kernel directly against Z.
        assert self.Z is not None
        with _kernel_context(self.kernel):
            return self.kernel(X, self.Z)
    with _kernel_context(self.kernel):
        return features.k_ux(X, self.kernel)

inducing_operator()

Return $K_{ZZ} + \text{jitter}\,I$ as a `lineax` operator.

For point-inducing priors, returns a dense :class:lineax.MatrixLinearOperator with positive_semidefinite_tag. For inducing-feature priors, delegates to :meth:InducingFeatures.K_uu — typically a :class:lineax.DiagonalLinearOperator so the downstream :func:gaussx.solve dispatches in O(M) instead of O(M^3).

Single kernel call; safe standalone for kernels with priors. For building several SVGP blocks together, prefer :meth:predictive_blocks, which scopes one shared kernel context across K_zz, K_xz, and K_xx_diag so Pattern B / C kernels register their NumPyro hyperparameter sites once instead of resampling per call.

Source code in src/pyrox/gp/_sparse.py
def inducing_operator(self) -> lx.AbstractLinearOperator:
    r"""Build the regularized inducing covariance ``K_{ZZ} + \text{jitter}\,I``.

    Point-inducing priors produce a dense
    :class:`lineax.MatrixLinearOperator` carrying
    ``positive_semidefinite_tag``. Inducing-feature priors defer to
    :meth:`InducingFeatures.K_uu`, which typically returns a
    :class:`lineax.DiagonalLinearOperator`, so that :func:`gaussx.solve`
    can dispatch in O(M) rather than O(M^3).

    The kernel is evaluated under a single kernel context, so the method
    is safe to call on its own for kernels with priors. When several
    SVGP blocks are needed together, call :meth:`predictive_blocks`
    instead. It shares one context across ``K_zz``, ``K_xz`` and
    ``K_xx_diag``, so Pattern B / C kernels register their NumPyro
    hyperparameter sites only once.
    """
    features = self.inducing
    if features is None:
        assert self.Z is not None
        with _kernel_context(self.kernel):
            gram = self.kernel(self.Z, self.Z)
        # Jitter is plain array arithmetic, so it lives outside the context.
        size = gram.shape[0]
        return _psd_operator(gram + self.jitter * jnp.eye(size, dtype=gram.dtype))
    with _kernel_context(self.kernel):
        return features.K_uu(self.kernel, jitter=self.jitter)

kernel_diag(X)

Prior diagonal $\mathrm{diag}\,K(X, X)$ — variance at each x.

See :meth:predictive_blocks for the shared-context batch helper to use when assembling several SVGP blocks together.

Source code in src/pyrox/gp/_sparse.py
def kernel_diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    r"""Pointwise prior variances ``\mathrm{diag}\,K(X, X)``.

    Evaluated inside its own kernel context. When assembling several
    SVGP blocks from one hyperparameter draw, use
    :meth:`predictive_blocks` instead.
    """
    kern = self.kernel
    with _kernel_context(kern):
        variances = kern.diag(X)
    return variances

log_prob(u)

Log-density under $p(u) = \mathcal{N}(0, K_{ZZ} + \text{jitter}\,I)$.

Delegates to `gaussx.gaussian_log_prob` with the configured `solver` so the user-supplied solver controls the solve / logdet work on K_zz_op. Useful for scoring inducing values against the SVGP prior in non-NumPyro contexts (e.g. tests, diagnostics).

Source code in src/pyrox/gp/_sparse.py
def log_prob(self, u: Float[Array, " M"]) -> Float[Array, ""]:
    r"""Score inducing values under :math:`\mathcal{N}(0, K_{ZZ} + \text{jitter}\,I)`.

    The computation goes through :func:`gaussx.gaussian_log_prob` with
    the resolved :attr:`solver`, so that strategy performs the
    ``solve`` / ``logdet`` work on the inducing operator. Handy for
    evaluating inducing values against the SVGP prior outside NumPyro
    models (e.g.\ tests, diagnostics).
    """
    zero_mean = jnp.zeros(self.num_inducing, dtype=u.dtype)
    cov_op = self.inducing_operator()
    strategy = self._resolved_solver()
    return gaussian_log_prob(zero_mean, cov_op, u, solver=strategy)

mean(X)

Evaluate the mean function at X; zero by default.

Source code in src/pyrox/gp/_sparse.py
def mean(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    """Mean function at ``X``, or a zero vector (``X``'s dtype) if none is set."""
    mean_fn = self.mean_fn
    if mean_fn is not None:
        return mean_fn(X)
    return jnp.zeros(X.shape[0], dtype=X.dtype)

predictive_blocks(X)

Return (K_zz_op, K_xz, K_xx_diag) under one shared kernel context.

For Pattern B / C kernels with prior'd hyperparameters, the three kernel evaluations needed for an SVGP predictive must share a single `pyrox.PyroxModule` context so the underlying `pyrox_sample` sites register once and yield consistent hyperparameter draws across $K_{ZZ}$, $K_{XZ}$, and the diagonal $\mathrm{diag}\,K(X, X)$. Without this scoping, three separate calls would draw three independent hyperparameter samples (under seed) or raise NumPyro duplicate-site errors (under tracing) — either way invalidating the SVGP math.

For pure :class:equinox.Module kernels (no _get_context), this is equivalent to calling :meth:inducing_operator, :meth:cross_covariance, and :meth:kernel_diag independently.

For inducing-feature priors, K_zz_op is a :class:lineax.DiagonalLinearOperator (jitter folded into the diagonal vector — never + jnp.eye) so the downstream solve stays O(M).

Source code in src/pyrox/gp/_sparse.py
def predictive_blocks(
    self, X: Array
) -> tuple[
    lx.AbstractLinearOperator,
    Float[Array, "N M"],
    Float[Array, " N"],
]:
    r"""Assemble ``(K_zz_op, K_xz, K_xx_diag)`` inside one kernel context.

    Kernels with prior'd hyperparameters (Pattern B / C) must see all
    three SVGP evaluations — ``K_{ZZ}``, ``K_{XZ}`` and the diagonal
    ``\mathrm{diag}\,K(X, X)`` — inside a single
    :class:`pyrox.PyroxModule` context. The ``pyrox_sample`` sites then
    register once and every block uses the same hyperparameter draw.
    Separate calls would instead draw three independent samples (under
    seed) or hit NumPyro duplicate-site errors (under tracing), and
    either way the SVGP math is invalidated.

    For plain :class:`equinox.Module` kernels (no ``_get_context``) the
    result equals calling :meth:`inducing_operator`,
    :meth:`cross_covariance` and :meth:`kernel_diag` one by one.

    With inducing features, ``K_zz_op`` is a
    :class:`lineax.DiagonalLinearOperator` whose diagonal already holds
    the jitter (no ``+ jnp.eye``), keeping the downstream solve O(M).
    """
    kern = self.kernel
    with _kernel_context(kern):
        # Call order (K_zz, K_xz, diag) matches the original so NumPyro
        # sees identical site registration.
        if self.inducing is None:
            assert self.Z is not None
            gram = kern(self.Z, self.Z)
            cross = kern(X, self.Z)
            zz_op = _psd_operator(
                gram + self.jitter * jnp.eye(gram.shape[0], dtype=gram.dtype)
            )
        else:
            zz_op = self.inducing.K_uu(kern, jitter=self.jitter)
            cross = self.inducing.k_ux(X, kern)
        diag = kern.diag(X)
    return zz_op, cross, diag

sample(key)

Draw $u \sim p(u)$ from the inducing prior.

Wraps the inducing operator in a :class:gaussx.MultivariateNormal with the configured :attr:solver. MultivariateNormal.sample factors the covariance via :func:gaussx.cholesky and reparameterizes; the returned draw has shape (M,).

Note: the SVGP variational workflow samples u from the guide $q(u)$, not the prior. This method exists so the prior surface is symmetric with the guide surface and so users can score / draw inducing values against the prior directly (e.g. for tests or for prior-sample initialization).

Source code in src/pyrox/gp/_sparse.py
def sample(self, key: Array) -> Float[Array, " M"]:
    r"""Draw a single inducing-value vector ``u \sim p(u)`` of shape ``(M,)``.

    Builds a :class:`gaussx.MultivariateNormal` with a zero location,
    the inducing operator as covariance, and the resolved
    :attr:`solver`, then calls its ``sample`` method. That method
    factors the covariance via :func:`gaussx.cholesky` and
    reparameterizes.

    SVGP training draws ``u`` from the variational guide :math:`q(u)`,
    not from this prior. The method keeps the prior's surface symmetric
    with the guide's, so inducing values can be drawn from or scored
    against the prior directly (e.g.\ for tests or for prior-sample
    initialization).
    """
    count = self.num_inducing
    cov_op = self.inducing_operator()
    zero_loc = jnp.zeros(count, dtype=cov_op.out_structure().dtype)
    dist = MultivariateNormal(zero_loc, cov_op, solver=self._resolved_solver())
    return dist.sample(key)

Pathwise posterior samplers (#39)

Callable posterior function draws via Matheron's rule. Each sampled path is a :class:PathwiseFunction that evaluates in O(N_* · F · D + N_* · N_corr) per path — N_* · F · D for the RFF prior draw and N_* · N_corr for the kernel correction against the N_corr training (exact) or inducing (decoupled) points — so the same draw can be reused at arbitrary test sets without rebuilding a test-set covariance. Standard enabler for Thompson sampling, Bayesian optimization, and posterior visualization.

from pyrox.gp import (
    RBF,
    GPPrior,
    PathwiseSampler,
    DecoupledPathwiseSampler,
    FullRankGuide,
    SparseGPPrior,
)
import jax
import jax.numpy as jnp

# Assumes training inputs ``X`` (N, D), targets ``y`` (N,), inducing
# inputs ``Z`` (M, D) and test inputs ``X_star`` (N_star, D) are defined.
# One root key, split so the two examples use independent randomness.
key = jax.random.PRNGKey(0)
exact_key, sparse_key = jax.random.split(key)

# Exact GP:
posterior = GPPrior(kernel=RBF(), X=X).condition(y, jnp.array(0.05))
paths = PathwiseSampler(posterior, n_features=512).sample_paths(
    exact_key, n_paths=32
)
draws = paths(X_star)            # (32, N_star)

# Sparse / decoupled:
sparse  = SparseGPPrior(kernel=RBF(), Z=Z)
guide   = FullRankGuide.init(Z.shape[0])
paths   = DecoupledPathwiseSampler(sparse, guide).sample_paths(sparse_key, n_paths=16)
samples = paths(X_star)          # (16, N_star)

Currently supports RBF and Matern kernels. Point-inducing SparseGPPrior only — inducing-feature priors raise at construction.

pyrox.gp.PathwiseSampler

Bases: Module

Exact-GP pathwise posterior sampler using Matheron's rule.

Given a `ConditionedGP`, draws a zero-mean RFF prior path `f_tilde` and an iid noise draw `eps_tilde` at the training inputs, forms the residual `y - mu(X) - f_tilde(X) - eps_tilde`, solves it against the cached noisy operator $K + (\text{jitter} + \sigma^2) I$, and stores the result as posterior correction weights. The returned `PathwiseFunction` is callable at any $X_*$ in $\mathcal{O}(N_* \cdot F \cdot D + N_* \cdot N)$ per path, where $N$ is the number of training (correction) points: the RFF prior term recomputes features over $X_*$ each call ($N_* \cdot F \cdot D$), and the correction term forms a fresh $K(X_*, X)$ block ($N_* \cdot N$).

Example

```python
posterior = GPPrior(kernel=RBF(), X=X).condition(y, jnp.array(0.05))
sampler = PathwiseSampler(posterior, n_features=512)
paths = sampler.sample_paths(key, n_paths=32)
draws = paths(X_star)
```

Example

```python
sampler = PathwiseSampler(posterior, n_features=1024)
thompson = sampler.sample_paths(key, n_paths=1)
values = thompson(X_candidates)
```

Source code in src/pyrox/gp/_pathwise.py
class PathwiseSampler(eqx.Module):
    """Exact-GP pathwise posterior sampler using Matheron's rule.

    Given a :class:`ConditionedGP`, draws a zero-mean RFF prior path
    ``f_tilde`` and an iid noise draw ``eps_tilde`` at the training
    inputs, forms the residual ``y - mu(X) - f_tilde(X) - eps_tilde``,
    solves it against the cached noisy operator ``K + (jitter + sigma^2)I``,
    and stores the result as posterior correction weights. The returned
    :class:`PathwiseFunction` is callable at any ``X_*`` in
    :math:`\\mathcal{O}(N_* \\cdot F \\cdot D + N_* \\cdot N)` per path,
    where ``N`` is the number of training (correction) points: the RFF
    prior term recomputes features over ``X_*`` each call
    (``N_* · F · D``), and the correction term forms a fresh
    ``K(X_*, X)`` block (``N_* · N``).

    Example:
        >>> posterior = GPPrior(kernel=RBF(), X=X).condition(y, jnp.array(0.05))
        >>> sampler = PathwiseSampler(posterior, n_features=512)
        >>> paths = sampler.sample_paths(key, n_paths=32)
        >>> draws = paths(X_star)

    Example:
        >>> sampler = PathwiseSampler(posterior, n_features=1024)
        >>> thompson = sampler.sample_paths(key, n_paths=1)
        >>> values = thompson(X_candidates)
    """

    # Exact posterior. Supplies prior.X / kernel / mean_fn / jitter, the
    # targets y, noise_var, the cached noisy operator, and (optionally)
    # the resolved (variance, lengthscale) captured at conditioning time.
    conditioned_gp: ConditionedGP
    # RFF feature count per path (passed to draw_rff_cosine_basis);
    # declared static, i.e. not a traced pytree leaf.
    n_features: int = eqx.field(static=True, default=512)

    def sample_paths(self, key: Array, n_paths: int = 1) -> PathwiseFunction:
        """Sample callable posterior paths.

        ``key`` is split into three subkeys: one for the RFF basis,
        one for the iid training-noise draw, and one reserved for
        future extensions.

        Args:
            key: PRNG key consumed by this call.
            n_paths: Number of independent posterior paths ``S`` to draw.

        Returns:
            A :class:`PathwiseFunction` whose evaluation at ``X_*``
            yields an ``(S, N_*)`` array (see :meth:`__call__`).
        """
        # The third subkey is intentionally unused (reserved).
        rff_key, noise_key, _reserved = jax.random.split(key, 3)

        X = self.conditioned_gp.prior.X
        kernel = self.conditioned_gp.prior.kernel
        # Reuse the resolved (variance, lengthscale) captured by
        # GPPrior.condition under its kernel context. For Pattern B/C
        # kernels the cached operator was built with these exact
        # values; resampling here would put the RFF basis in a
        # different posterior than the cached training solve.
        cached = self.conditioned_gp.resolved_hyperparams
        cached_variance, cached_lengthscale = (
            cached if cached is not None else (None, None)
        )
        variance, lengthscale, omega, phase, feature_weights = draw_rff_cosine_basis(
            kernel,
            rff_key,
            n_paths=n_paths,
            n_features=self.n_features,
            in_features=X.shape[1],
            dtype=X.dtype,
            variance=cached_variance,
            lengthscale=cached_lengthscale,
        )
        # Prior path draws f_tilde evaluated at the training inputs.
        prior_train = evaluate_rff_cosine_paths(
            X,
            variance=variance,
            lengthscale=lengthscale,
            omega=omega,
            phase=phase,
            weights=feature_weights,
        )
        mean_train = _broadcast_mean(self.conditioned_gp.prior.mean_fn, X)
        # Matheron requires Cov(eps_tilde) to match the diagonal added to
        # the cached operator. _noisy_operator uses (noise_var + jitter) I,
        # so eps_tilde must have the same variance — otherwise the
        # correction solve is inconsistent and paths are under-dispersed
        # (pronounced when jitter is bumped up for stability).
        noise_var = jnp.asarray(self.conditioned_gp.noise_var, dtype=X.dtype)
        jitter = jnp.asarray(self.conditioned_gp.prior.jitter, dtype=X.dtype)
        eps_var = noise_var + jitter
        noise = jnp.sqrt(eps_var) * jax.random.normal(
            noise_key, shape=(n_paths, X.shape[0]), dtype=X.dtype
        )
        # Matheron residual y - mu(X) - f_tilde(X) - eps_tilde, one row per path.
        residual = (
            self.conditioned_gp.y[None, :] - (mean_train[None, :] + prior_train) - noise
        )
        # Solve against the cached noisy operator; note the Cholesky factor
        # is recomputed on every sample_paths call.
        correction_weights = _solve_with_cholesky(
            cholesky(self.conditioned_gp.operator), residual
        )
        # The correction kernel is built from the same (variance,
        # lengthscale) used for the RFF basis.
        return PathwiseFunction(  # ty: ignore[invalid-return-type]
            kernel_fn=_frozen_kernel_fn(kernel, variance, lengthscale),
            correction_points=X,
            correction_weights=correction_weights,
            omega=omega,
            phase=phase,
            feature_weights=feature_weights,
            variance=variance,
            lengthscale=lengthscale,
            mean_fn=self.conditioned_gp.prior.mean_fn,
        )

    def __call__(
        self,
        key: Array,
        X_star: Float[Array, "N D"],
        n_paths: int = 1,
    ) -> Float[Array, "S N"]:
        """Convenience wrapper for ``sample_paths(key, n_paths)(X_star)``."""
        return self.sample_paths(key, n_paths=n_paths)(X_star)

__call__(key, X_star, n_paths=1)

Convenience wrapper for sample_paths(key, n_paths)(X_star).

Source code in src/pyrox/gp/_pathwise.py
def __call__(
    self,
    key: Array,
    X_star: Float[Array, "N D"],
    n_paths: int = 1,
) -> Float[Array, "S N"]:
    """Draw ``n_paths`` posterior paths with ``key`` and evaluate them at ``X_star``."""
    paths = self.sample_paths(key, n_paths=n_paths)
    return paths(X_star)

sample_paths(key, n_paths=1)

Sample callable posterior paths.

key is split into three subkeys: one for the RFF basis, one for the iid training-noise draw, and one reserved for future extensions.

Source code in src/pyrox/gp/_pathwise.py
def sample_paths(self, key: Array, n_paths: int = 1) -> PathwiseFunction:
    """Sample callable posterior paths.

    ``key`` is split into three subkeys: one for the RFF basis,
    one for the iid training-noise draw, and one reserved for
    future extensions.

    Args:
        key: PRNG key consumed by this call.
        n_paths: Number of independent posterior paths ``S`` to draw.

    Returns:
        A :class:`PathwiseFunction` whose evaluation at ``X_*`` yields
        an ``(S, N_*)`` array (see :meth:`__call__`).
    """
    # The third subkey is intentionally unused (reserved).
    rff_key, noise_key, _reserved = jax.random.split(key, 3)

    X = self.conditioned_gp.prior.X
    kernel = self.conditioned_gp.prior.kernel
    # Reuse the resolved (variance, lengthscale) captured by
    # GPPrior.condition under its kernel context. For Pattern B/C
    # kernels the cached operator was built with these exact
    # values; resampling here would put the RFF basis in a
    # different posterior than the cached training solve.
    cached = self.conditioned_gp.resolved_hyperparams
    cached_variance, cached_lengthscale = (
        cached if cached is not None else (None, None)
    )
    variance, lengthscale, omega, phase, feature_weights = draw_rff_cosine_basis(
        kernel,
        rff_key,
        n_paths=n_paths,
        n_features=self.n_features,
        in_features=X.shape[1],
        dtype=X.dtype,
        variance=cached_variance,
        lengthscale=cached_lengthscale,
    )
    # Prior path draws f_tilde evaluated at the training inputs.
    prior_train = evaluate_rff_cosine_paths(
        X,
        variance=variance,
        lengthscale=lengthscale,
        omega=omega,
        phase=phase,
        weights=feature_weights,
    )
    mean_train = _broadcast_mean(self.conditioned_gp.prior.mean_fn, X)
    # Matheron requires Cov(eps_tilde) to match the diagonal added to
    # the cached operator. _noisy_operator uses (noise_var + jitter) I,
    # so eps_tilde must have the same variance — otherwise the
    # correction solve is inconsistent and paths are under-dispersed
    # (pronounced when jitter is bumped up for stability).
    noise_var = jnp.asarray(self.conditioned_gp.noise_var, dtype=X.dtype)
    jitter = jnp.asarray(self.conditioned_gp.prior.jitter, dtype=X.dtype)
    eps_var = noise_var + jitter
    noise = jnp.sqrt(eps_var) * jax.random.normal(
        noise_key, shape=(n_paths, X.shape[0]), dtype=X.dtype
    )
    # Matheron residual y - mu(X) - f_tilde(X) - eps_tilde, one row per path.
    residual = (
        self.conditioned_gp.y[None, :] - (mean_train[None, :] + prior_train) - noise
    )
    # Solve against the cached noisy operator; note the Cholesky factor is
    # recomputed on every sample_paths call.
    correction_weights = _solve_with_cholesky(
        cholesky(self.conditioned_gp.operator), residual
    )
    # The correction kernel is built from the same (variance, lengthscale)
    # used for the RFF basis.
    return PathwiseFunction(  # ty: ignore[invalid-return-type]
        kernel_fn=_frozen_kernel_fn(kernel, variance, lengthscale),
        correction_points=X,
        correction_weights=correction_weights,
        omega=omega,
        phase=phase,
        feature_weights=feature_weights,
        variance=variance,
        lengthscale=lengthscale,
        mean_fn=self.conditioned_gp.prior.mean_fn,
    )

pyrox.gp.DecoupledPathwiseSampler

Bases: Module

Sparse/decoupled pathwise sampler with RFF prior + inducing update.

The prior draw uses random features while the correction is represented in the inducing-point basis, so each sampled path stays callable at arbitrary inputs after a one-time inducing solve.

Supported for point-inducing :class:SparseGPPrior (Z=...); inducing-feature priors (inducing=...) are rejected at construction with a clear error.

Handles :class:WhitenedGuide automatically: whitened guide draws v ~ q(v) are unwhitened to inducing values u = L_ZZ v via :func:gaussx.unwhiten before forming the inducing-space residual.

Example

```python
prior = SparseGPPrior(kernel=RBF(), Z=Z)
guide = FullRankGuide.init(Z.shape[0])
sampler = DecoupledPathwiseSampler(prior, guide, n_features=512)
paths = sampler.sample_paths(key, n_paths=16)
draws = paths(X_star)
```

Source code in src/pyrox/gp/_pathwise.py
class DecoupledPathwiseSampler(eqx.Module):
    r"""Decoupled (sparse) pathwise sampler: RFF prior plus inducing update.

    Each path is the sum of a random-Fourier-feature prior draw and a
    Matheron-style correction expressed in the inducing-point basis. Only a
    single inducing-space solve is needed, after which every sampled path
    can be evaluated at arbitrary inputs.

    Only point-inducing :class:`SparseGPPrior` instances (built with
    ``Z=...``) are supported; priors built from inducing features
    (``inducing=...``) are rejected at construction with a clear error.

    A :class:`WhitenedGuide` is handled automatically: its draws
    ``v ~ q(v)`` are mapped to inducing values ``u = L_ZZ v`` through
    :func:`gaussx.unwhiten` before the inducing-space residual is formed.

    Example:
        >>> prior = SparseGPPrior(kernel=RBF(), Z=Z)
        >>> guide = FullRankGuide.init(Z.shape[0])
        >>> sampler = DecoupledPathwiseSampler(prior, guide, n_features=512)
        >>> paths = sampler.sample_paths(key, n_paths=16)
        >>> draws = paths(X_star)
    """

    prior: SparseGPPrior
    guide: Guide
    n_features: int = eqx.field(static=True, default=512)

    def __check_init__(self) -> None:
        """Reject priors without explicit inducing inputs ``Z``."""
        # The correction below is assembled against explicit inducing points,
        # so inducing-feature priors cannot be handled yet.
        if self.prior.Z is not None:
            return
        raise ValueError(
            "DecoupledPathwiseSampler currently requires a point-inducing "
            "SparseGPPrior constructed with `Z=...`. Inducing-feature "
            "priors (FourierInducingFeatures, SphericalHarmonicInducingFeatures, "
            "LaplacianInducingFeatures) are not yet supported."
        )

    def sample_paths(self, key: Array, n_paths: int = 1) -> PathwiseFunction:
        r"""Draw ``n_paths`` callable posterior paths from the sparse model.

        ``key`` is split into three subkeys, used (in this order) for the
        RFF basis, for ``n_paths`` independent guide draws, and for the
        jitter augmentation of the prior inducing draw. The RFF basis draw
        and the :math:`K_{zz}` assembly share one ``_kernel_context`` so
        kernels with hyperparameter priors (Pattern B / C) sample
        ``(variance, lengthscale)`` only once.

        The Matheron correction is solved against
        ``K_{zz} + \text{jitter}\,I``, so the prior inducing draw
        ``u_tilde`` must carry exactly that covariance. The RFF draw at
        ``Z`` only accounts for ``K_{zz}``. An iid Gaussian with variance
        ``jitter`` is therefore added per inducing index to cover the rest.
        Without it, paths are under-dispersed whenever jitter is raised for
        numerical stability.

        Args:
            key: PRNG key.
            n_paths: Number of independent paths to draw.

        Returns:
            A :class:`PathwiseFunction` whose correction is anchored at the
            inducing inputs ``Z``.
        """
        rff_key, guide_key, jitter_key = jax.random.split(key, 3)

        inducing_inputs = self.prior.Z
        assert inducing_inputs is not None  # guaranteed by __check_init__
        dtype = inducing_inputs.dtype

        # The RFF basis and K_zz must see the same hyperparameter sample, so
        # both are built inside one kernel context.
        with _kernel_context(self.prior.kernel):
            variance, lengthscale, omega, phase, feature_weights = (
                draw_rff_cosine_basis(
                    self.prior.kernel,
                    rff_key,
                    n_paths=n_paths,
                    n_features=self.n_features,
                    in_features=inducing_inputs.shape[1],
                    dtype=dtype,
                )
            )
            rff_at_inducing = evaluate_rff_cosine_paths(
                inducing_inputs,
                variance=variance,
                lengthscale=lengthscale,
                omega=omega,
                phase=phase,
                weights=feature_weights,
            )
            chol_zz = cholesky(self.prior.inducing_operator())

        # Close the covariance gap: Cov(u_tilde) = K_zz + jitter * I.
        jitter_scale = jnp.sqrt(jnp.asarray(self.prior.jitter, dtype=dtype))
        jitter_noise = jax.random.normal(
            jitter_key, shape=rff_at_inducing.shape, dtype=dtype
        )
        u_tilde = rff_at_inducing + jitter_scale * jitter_noise

        # One guide draw per path; whitened draws v are mapped to u = L_ZZ v.
        guide_draws = jax.vmap(self.guide.sample)(
            jax.random.split(guide_key, n_paths)
        )
        u_samples = (
            jax.vmap(lambda v: unwhiten(v, chol_zz))(guide_draws)
            if isinstance(self.guide, WhitenedGuide)
            else guide_draws
        )

        return PathwiseFunction(  # ty: ignore[invalid-return-type]
            kernel_fn=_frozen_kernel_fn(self.prior.kernel, variance, lengthscale),
            correction_points=inducing_inputs,
            correction_weights=_solve_with_cholesky(chol_zz, u_samples - u_tilde),
            omega=omega,
            phase=phase,
            feature_weights=feature_weights,
            variance=variance,
            lengthscale=lengthscale,
            mean_fn=self.prior.mean_fn,
        )

    def __call__(
        self,
        key: Array,
        X_star: Float[Array, "N D"],
        n_paths: int = 1,
    ) -> Float[Array, "S N"]:
        """Shorthand for ``sample_paths(key, n_paths)(X_star)``."""
        paths = self.sample_paths(key, n_paths=n_paths)
        return paths(X_star)

__call__(key, X_star, n_paths=1)

Convenience wrapper for sample_paths(key, n_paths)(X_star).

Source code in src/pyrox/gp/_pathwise.py
def __call__(
    self,
    key: Array,
    X_star: Float[Array, "N D"],
    n_paths: int = 1,
) -> Float[Array, "S N"]:
    """Shorthand for ``sample_paths(key, n_paths)(X_star)``."""
    paths = self.sample_paths(key, n_paths=n_paths)
    return paths(X_star)

sample_paths(key, n_paths=1)

Sample callable sparse posterior paths.

key is split into three subkeys: one for the RFF basis, one for n_paths independent guide draws, and one for the jitter-augmentation of the prior inducing draw. The RFF basis draw and the :math:K_{zz} assembly share a single _kernel_context so kernels with hyperparameter priors (Pattern B / C) sample (variance, lengthscale) once.

The Matheron correction needs Cov(u_tilde) = K_{zz} + \text{jitter}\,I so it matches the operator that the correction is solved against. The bare RFF draw at Z produces only the K_{zz} part; we add an iid Gaussian with variance jitter per inducing index to close the gap — without this, paths are under-dispersed when jitter is bumped up for stability.

Source code in src/pyrox/gp/_pathwise.py
def sample_paths(self, key: Array, n_paths: int = 1) -> PathwiseFunction:
    r"""Draw ``n_paths`` callable posterior paths from the sparse model.

    ``key`` is split into three subkeys, used (in this order) for the RFF
    basis, for ``n_paths`` independent guide draws, and for the jitter
    augmentation of the prior inducing draw. The RFF basis draw and the
    :math:`K_{zz}` assembly share one ``_kernel_context`` so kernels with
    hyperparameter priors (Pattern B / C) sample ``(variance, lengthscale)``
    only once.

    The Matheron correction is solved against
    ``K_{zz} + \text{jitter}\,I``, so the prior inducing draw ``u_tilde``
    must carry exactly that covariance. The RFF draw at ``Z`` only accounts
    for ``K_{zz}``. An iid Gaussian with variance ``jitter`` is therefore
    added per inducing index to cover the rest. Without it, paths are
    under-dispersed whenever jitter is raised for numerical stability.

    Args:
        key: PRNG key.
        n_paths: Number of independent paths to draw.

    Returns:
        A :class:`PathwiseFunction` whose correction is anchored at the
        inducing inputs ``Z``.
    """
    rff_key, guide_key, jitter_key = jax.random.split(key, 3)

    inducing_inputs = self.prior.Z
    assert inducing_inputs is not None  # guaranteed by __check_init__
    dtype = inducing_inputs.dtype

    # The RFF basis and K_zz must see the same hyperparameter sample, so
    # both are built inside one kernel context.
    with _kernel_context(self.prior.kernel):
        variance, lengthscale, omega, phase, feature_weights = (
            draw_rff_cosine_basis(
                self.prior.kernel,
                rff_key,
                n_paths=n_paths,
                n_features=self.n_features,
                in_features=inducing_inputs.shape[1],
                dtype=dtype,
            )
        )
        rff_at_inducing = evaluate_rff_cosine_paths(
            inducing_inputs,
            variance=variance,
            lengthscale=lengthscale,
            omega=omega,
            phase=phase,
            weights=feature_weights,
        )
        chol_zz = cholesky(self.prior.inducing_operator())

    # Close the covariance gap: Cov(u_tilde) = K_zz + jitter * I.
    jitter_scale = jnp.sqrt(jnp.asarray(self.prior.jitter, dtype=dtype))
    jitter_noise = jax.random.normal(
        jitter_key, shape=rff_at_inducing.shape, dtype=dtype
    )
    u_tilde = rff_at_inducing + jitter_scale * jitter_noise

    # One guide draw per path; whitened draws v are mapped to u = L_ZZ v.
    guide_draws = jax.vmap(self.guide.sample)(jax.random.split(guide_key, n_paths))
    u_samples = (
        jax.vmap(lambda v: unwhiten(v, chol_zz))(guide_draws)
        if isinstance(self.guide, WhitenedGuide)
        else guide_draws
    )

    return PathwiseFunction(  # ty: ignore[invalid-return-type]
        kernel_fn=_frozen_kernel_fn(self.prior.kernel, variance, lengthscale),
        correction_points=inducing_inputs,
        correction_weights=_solve_with_cholesky(chol_zz, u_samples - u_tilde),
        omega=omega,
        phase=phase,
        feature_weights=feature_weights,
        variance=variance,
        lengthscale=lengthscale,
        mean_fn=self.prior.mean_fn,
    )

pyrox.gp.PathwiseFunction

Bases: Module

Callable posterior function draw(s) produced by a pathwise sampler.

Carries the random-feature prior basis (omega, phase, feature_weights) and the posterior correction weights evaluated against either the training inputs (exact) or the inducing inputs (sparse). Calling the instance on test points X_star evaluates

.. math::

f_{\text{post}}(x_*) =
    \tilde{f}(x_*)
    + K(x_*,\, X_{\mathrm{corr}})\,\alpha
    + \mu(x_*),

where :math:\tilde f is the stored RFF prior draw and :math:X_{\mathrm{corr}} is either the training set (exact) or the inducing set (sparse).

The kernel enters only as a frozen (X1, X2) -> K callable with the sample-time variance and lengthscale baked in, so repeated evaluations stay consistent with the original RFF draw even for Pattern B/C kernels that register hyperparameter priors.

Example

prior = GPPrior(kernel=RBF(), X=X)
posterior = prior.condition(y, noise_var=jnp.array(0.05))
sampler = PathwiseSampler(posterior, n_features=512)
paths = sampler.sample_paths(key, n_paths=8)
samples = paths(X_star)

Example

sparse_prior = SparseGPPrior(kernel=RBF(), Z=Z)
guide = FullRankGuide.init(Z.shape[0])
paths = DecoupledPathwiseSampler(sparse_prior, guide).sample_paths(key)
thompson_values = paths(X_candidates)

Source code in src/pyrox/gp/_pathwise.py
class PathwiseFunction(eqx.Module):
    r"""Callable posterior function draw(s) returned by a pathwise sampler.

    Stores a random-feature prior basis (``omega``, ``phase``,
    ``feature_weights``) together with posterior correction weights. The
    weights are anchored either at the training inputs (exact sampler) or
    at the inducing inputs (sparse sampler). Calling the instance on test
    points ``X_star`` evaluates

    .. math::

        f_{\text{post}}(x_*) =
            \tilde{f}(x_*)
            + K(x_*,\, X_{\mathrm{corr}})\,\alpha
            + \mu(x_*),

    where :math:`\tilde f` is the stored RFF prior draw, :math:`\alpha` are
    the ``correction_weights``, and :math:`X_{\mathrm{corr}}` are the
    ``correction_points`` (training set for exact, inducing set for sparse).

    The kernel is held only as a frozen ``(X1, X2) -> K`` callable with the
    sample-time ``variance`` and ``lengthscale`` baked in. Repeated
    evaluations therefore stay consistent with the original RFF draw, even
    for Pattern B/C kernels that register hyperparameter priors.

    Example:
        >>> prior = GPPrior(kernel=RBF(), X=X)
        >>> posterior = prior.condition(y, noise_var=jnp.array(0.05))
        >>> sampler = PathwiseSampler(posterior, n_features=512)
        >>> paths = sampler.sample_paths(key, n_paths=8)
        >>> samples = paths(X_star)

    Example:
        >>> sparse_prior = SparseGPPrior(kernel=RBF(), Z=Z)
        >>> guide = FullRankGuide.init(Z.shape[0])
        >>> paths = DecoupledPathwiseSampler(sparse_prior, guide).sample_paths(key)
        >>> thompson_values = paths(X_candidates)
    """

    # Frozen kernel with the sample-time hyperparameters baked in.
    kernel_fn: Callable[
        [Float[Array, "N1 D"], Float[Array, "N2 D"]], Float[Array, "N1 N2"]
    ]
    # Anchor inputs of the correction term (R points).
    correction_points: Float[Array, "R D"]
    # One row of correction weights per path (S paths).
    correction_weights: Float[Array, "S R"]
    # Random-feature basis of the prior draw.
    omega: Float[Array, "S D F"]
    phase: Float[Array, "S F"]
    feature_weights: Float[Array, "S F"]
    # Hyperparameters the RFF basis was drawn with.
    variance: Float[Array, ""]
    lengthscale: Float[Array, ""]
    # Optional prior mean function added to every path.
    mean_fn: Callable[[Float[Array, "N D"]], Float[Array, " N"]] | None = None

    def __call__(self, X_star: Float[Array, "N D"]) -> Float[Array, "S N"]:
        """Evaluate every stored path at the query inputs ``X_star``."""
        # Prior part: the stored random-feature draw evaluated at X_star.
        rff_part = evaluate_rff_cosine_paths(
            X_star,
            variance=self.variance,
            lengthscale=self.lengthscale,
            omega=self.omega,
            phase=self.phase,
            weights=self.feature_weights,
        )
        # Update part: K(X_star, X_corr) contracted with each path's weights.
        cross_cov = self.kernel_fn(X_star, self.correction_points)
        update_part = jnp.einsum("nr,sr->sn", cross_cov, self.correction_weights)
        mean_part = _broadcast_mean(self.mean_fn, X_star)
        return rff_part + update_part + mean_part[None, :]

__call__(X_star)

Evaluate the sampled function(s) at arbitrary inputs X_star.

Source code in src/pyrox/gp/_pathwise.py
def __call__(self, X_star: Float[Array, "N D"]) -> Float[Array, "S N"]:
    """Evaluate every stored path at the query inputs ``X_star``."""
    # Prior part: the stored random-feature draw evaluated at X_star.
    rff_part = evaluate_rff_cosine_paths(
        X_star,
        variance=self.variance,
        lengthscale=self.lengthscale,
        omega=self.omega,
        phase=self.phase,
        weights=self.feature_weights,
    )
    # Update part: K(X_star, X_corr) contracted with each path's weights.
    cross_cov = self.kernel_fn(X_star, self.correction_points)
    update_part = jnp.einsum("nr,sr->sn", cross_cov, self.correction_weights)
    mean_part = _broadcast_mean(self.mean_fn, X_star)
    return rff_part + update_part + mean_part[None, :]

State-space (SDE) kernels

Stationary 1-D kernels expressed as linear time-invariant SDEs. Once in state-space form, GP inference on a 1-D grid reduces to Kalman filtering in O(N d^3) instead of O(N^3) Cholesky. The protocol exposes sde_params() -> (F, L, H, Q_c, P_inf) and discretise(dt) -> (A_k, Q_k) for downstream Kalman / RTS use.

import jax.numpy as jnp
from pyrox.gp import (
    ConstantSDE,
    CosineSDE,
    MaternSDE,
    PeriodicSDE,
    ProductSDE,
    QuasiPeriodicSDE,
    SumSDE,
)

# --- Primitive state-space kernels ------------------------------------------
matern = MaternSDE(variance=1.0, lengthscale=0.5, order=1)  # Matern-3/2 (nu = 3/2)
cosine = CosineSDE(variance=1.0, frequency=2.0)
offset = ConstantSDE(variance=0.3)
periodic = PeriodicSDE(variance=1.0, lengthscale=1.0, period=2.0, n_harmonics=7)

# --- Compositions -----------------------------------------------------------
# Trend plus constant offset; state dim = 2 + 1 = 3.
trend = SumSDE((matern, offset))

# Damped oscillation (Matern x Cosine); state dim = 2 * 2 = 4.
damped = ProductSDE(matern, cosine)

# Quasi-periodic (Matern x Periodic), a convenience wrapper around
# ProductSDE; state dim = 2 * 15 = 30.
qp = QuasiPeriodicSDE(matern, periodic)

pyrox.gp.SDEKernel

Bases: Module

Abstract base for kernels with state-space (SDE) representations.

Stationary kernels with rational spectral densities admit exact finite-dimensional state-space representations of the form

.. math:: d\mathbf{x}(t) = F\,\mathbf{x}(t)\, dt + L\, dw(t), \qquad f(t) = H\,\mathbf{x}(t)

where :math:w(t) is white noise with spectral density :math:Q_c and :math:P_\infty is the stationary state covariance solving the Lyapunov equation :math:F P_\infty + P_\infty F^\top + L Q_c L^\top = 0.

Discretisation at time step :math:\Delta t gives

.. math:: A_k = \exp(F\,\Delta t), \qquad Q_k = P_\infty - A_k\,P_\infty\,A_k^\top,

so that :math:x_{k+1} = A_k x_k + q_k with :math:q_k \sim \mathcal{N}(0, Q_k).

Concrete subclasses implement :meth:sde_params returning the closed-form (F, L, H, Q_c, P_inf) tuple. :meth:discretise defaults to a generic expm-based implementation; subclasses with closed-form transitions (e.g. Matern-1/2) may override it.

The continuous-time autocovariance recovered from the SDE is :math:k(\tau) = H\,\exp(F|\tau|)\,P_\infty\,H^\top for stationary kernels.

Source code in src/pyrox/gp/_protocols.py
class SDEKernel(eqx.Module):
    r"""Base class for kernels that admit a state-space (SDE) representation.

    A stationary kernel whose spectral density is rational can be written
    exactly as a finite-dimensional linear time-invariant SDE,

    .. math::
        d\mathbf{x}(t) = F\,\mathbf{x}(t)\, dt + L\, dw(t),
        \qquad f(t) = H\,\mathbf{x}(t),

    driven by white noise :math:`w(t)` with spectral density :math:`Q_c`.
    The stationary state covariance :math:`P_\infty` solves the Lyapunov
    equation :math:`F P_\infty + P_\infty F^\top + L Q_c L^\top = 0`.

    Discretising over a step :math:`\Delta t` yields

    .. math::
        A_k = \exp(F\,\Delta t),
        \qquad Q_k = P_\infty - A_k\,P_\infty\,A_k^\top,

    i.e. :math:`x_{k+1} = A_k x_k + q_k` with
    :math:`q_k \sim \mathcal{N}(0, Q_k)`.

    Subclasses must provide :attr:`state_dim` and :meth:`sde_params`, which
    returns the closed-form ``(F, L, H, Q_c, P_inf)`` tuple.
    :meth:`discretise` has a generic ``expm``-based default that subclasses
    with closed-form transitions (e.g. Matern-1/2) may override.

    For stationary kernels the autocovariance implied by the SDE is
    :math:`k(\tau) = H\,\exp(F|\tau|)\,P_\infty\,H^\top`.
    """

    @property
    @abstractmethod
    def state_dim(self) -> int:
        """Dimension :math:`d` of the latent SDE state."""
        raise NotImplementedError()

    @abstractmethod
    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "d d"],
        Float[Array, "d s"],
        Float[Array, "1 d"],
        Float[Array, "s s"],
        Float[Array, "d d"],
    ]:
        """Return the continuous-time SDE matrices ``(F, L, H, Q_c, P_inf)``."""
        raise NotImplementedError()

    def discretise(
        self,
        dt: Float[Array, " N"],
    ) -> tuple[Float[Array, "N d d"], Float[Array, "N d d"]]:
        r"""Compute per-step transition and process-noise matrices.

        For each step ``dt_k`` this evaluates ``A_k = expm(F dt_k)`` with
        ``jax.scipy.linalg.expm`` and
        ``Q_k = P_\infty - A_k P_\infty A_k^\top``. Subclasses with
        closed-form transitions should override this method.

        Args:
            dt: ``(N,)`` array of (non-negative) time steps.

        Returns:
            Tuple ``(A, Q)`` of ``(N, d, d)`` arrays.
        """
        F, _, _, _, P_inf = self.sde_params()

        def _transition_and_noise(
            step: Float[Array, ""],
        ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
            transition = jsl.expm(F * step)
            raw_noise = P_inf - transition @ P_inf @ transition.T
            # ``Q`` is symmetric in exact arithmetic, but roundoff can leave
            # an asymmetric residue that breaks the Cholesky factorisations
            # used downstream in Kalman steps, so symmetrise explicitly.
            return transition, 0.5 * (raw_noise + raw_noise.T)

        return jax.vmap(_transition_and_noise)(jnp.asarray(dt))

state_dim abstractmethod property

State dimension :math:d of the SDE representation.

discretise(dt)

Discretise the SDE at time steps dt.

Default implementation evaluates A_k = expm(F dt_k) via jax.scipy.linalg.expm and Q_k = P_\infty - A_k P_\infty A_k^\top. Subclasses with closed-form transitions should override.

Parameters:

Name Type Description Default
dt Float[Array, ' N']

(N,) array of (non-negative) time steps.

required

Returns:

Type Description
tuple[Float[Array, 'N d d'], Float[Array, 'N d d']]

Tuple (A, Q) of (N, d, d) arrays.

Source code in src/pyrox/gp/_protocols.py
def discretise(
    self,
    dt: Float[Array, " N"],
) -> tuple[Float[Array, "N d d"], Float[Array, "N d d"]]:
    r"""Compute per-step transition and process-noise matrices.

    For each step ``dt_k`` this evaluates ``A_k = expm(F dt_k)`` with
    ``jax.scipy.linalg.expm`` and ``Q_k = P_\infty - A_k P_\infty A_k^\top``.
    Subclasses with closed-form transitions should override this method.

    Args:
        dt: ``(N,)`` array of (non-negative) time steps.

    Returns:
        Tuple ``(A, Q)`` of ``(N, d, d)`` arrays.
    """
    F, _, _, _, P_inf = self.sde_params()

    def _transition_and_noise(
        step: Float[Array, ""],
    ) -> tuple[Float[Array, "d d"], Float[Array, "d d"]]:
        transition = jsl.expm(F * step)
        raw_noise = P_inf - transition @ P_inf @ transition.T
        # ``Q`` is symmetric in exact arithmetic, but roundoff can leave an
        # asymmetric residue that breaks the Cholesky factorisations used
        # downstream in Kalman steps, so symmetrise explicitly.
        return transition, 0.5 * (raw_noise + raw_noise.T)

    return jax.vmap(_transition_and_noise)(jnp.asarray(dt))

sde_params() abstractmethod

Return (F, L, H, Q_c, P_inf) defining the continuous SDE.

Source code in src/pyrox/gp/_protocols.py
@abstractmethod
def sde_params(
    self,
) -> tuple[
    Float[Array, "d d"],
    Float[Array, "d s"],
    Float[Array, "1 d"],
    Float[Array, "s s"],
    Float[Array, "d d"],
]:
    """Return the continuous-time SDE matrices ``(F, L, H, Q_c, P_inf)``.

    Shapes: ``F`` and ``P_inf`` are ``(d, d)``, ``L`` is ``(d, s)``, ``H``
    is ``(1, d)`` and ``Q_c`` is ``(s, s)``. Here ``d`` is the state
    dimension and ``s`` the dimension of the driving white noise.

    Subclasses must override this method.
    """
    raise NotImplementedError()

pyrox.gp.MaternSDE

Bases: SDEKernel

Matern kernel in state-space (companion) form for order in {0, 1, 2}.

The Matern-:math:\nu kernel with :math:\nu = p + 1/2 for :math:p \in \{0, 1, 2\} has an exact :math:d = p + 1 dimensional SDE representation. The closed-form parameters are:

  • Matern-1/2 (order=0, :math:d=1): :math:\lambda = 1/\ell,

.. math:: F = [-\lambda],\quad L = [1],\quad H = [1],\quad Q_c = 2\sigma^2\lambda,\quad P_\infty = \sigma^2.

  • Matern-3/2 (order=1, :math:d=2): :math:\lambda = \sqrt{3}/\ell,

.. math:: F = \begin{pmatrix} 0 & 1 \\ -\lambda^2 & -2\lambda \end{pmatrix}, \quad L = \begin{pmatrix} 0 \\ 1 \end{pmatrix},\quad H = \begin{pmatrix} 1 & 0 \end{pmatrix},

.. math:: Q_c = 4\sigma^2\lambda^3,\quad P_\infty = \sigma^2\,\mathrm{diag}(1,\;\lambda^2).

  • Matern-5/2 (order=2, :math:d=3): :math:\lambda = \sqrt{5}/\ell,

.. math:: F = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ -\lambda^3 & -3\lambda^2 & -3\lambda \end{pmatrix}, \quad L = \begin{pmatrix} 0 \\ 0 \\ 1 \end{pmatrix},\quad H = \begin{pmatrix} 1 & 0 & 0 \end{pmatrix},

.. math:: Q_c = \tfrac{16}{3}\sigma^2\lambda^5,\quad P_\infty = \sigma^2 \begin{pmatrix} 1 & 0 & -\lambda^2/3 \\ 0 & \lambda^2/3 & 0 \\ -\lambda^2/3 & 0 & \lambda^4 \end{pmatrix}.

order is a static (Python int) field — it picks a code path, not a trainable parameter. variance and lengthscale are JAX-traced scalars suitable for autograd.

Examples:

>>> import jax.numpy as jnp
>>> sde = MaternSDE(variance=1.0, lengthscale=0.5, order=1)
>>> F, L, H, Q_c, P_inf = sde.sde_params()
>>> A, Q = sde.discretise(jnp.array([0.1, 0.2, 0.3]))
>>> A.shape, Q.shape
((3, 2, 2), (3, 2, 2))
References

Sarkka & Solin (2019), Applied Stochastic Differential Equations, Ch. 12; Hartikainen & Sarkka (2010), Kalman Filtering and Smoothing Solutions to Temporal Gaussian Process Regression Models, IEEE MLSP.

Source code in src/pyrox/gp/_sde_kernels.py
class MaternSDE(SDEKernel):
    r"""Matern kernel in state-space (companion) form for ``order in {0, 1, 2}``.

    The Matern-:math:`\nu` kernel with :math:`\nu = p + 1/2` for
    :math:`p \in \{0, 1, 2\}` has an exact :math:`d = p + 1` dimensional
    SDE representation. The closed-form parameters are:

    * **Matern-1/2** (``order=0``, :math:`d=1`): :math:`\lambda = 1/\ell`,

      .. math::
          F = [-\lambda],\quad L = [1],\quad H = [1],\quad
          Q_c = 2\sigma^2\lambda,\quad P_\infty = \sigma^2.

    * **Matern-3/2** (``order=1``, :math:`d=2`): :math:`\lambda = \sqrt{3}/\ell`,

      .. math::
          F = \begin{pmatrix} 0 & 1 \\ -\lambda^2 & -2\lambda \end{pmatrix},
          \quad L = \begin{pmatrix} 0 \\ 1 \end{pmatrix},\quad
          H = \begin{pmatrix} 1 & 0 \end{pmatrix},

      .. math::
          Q_c = 4\sigma^2\lambda^3,\quad
          P_\infty = \sigma^2\,\mathrm{diag}(1,\;\lambda^2).

    * **Matern-5/2** (``order=2``, :math:`d=3`): :math:`\lambda = \sqrt{5}/\ell`,

      .. math::
          F = \begin{pmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\
          -\lambda^3 & -3\lambda^2 & -3\lambda \end{pmatrix},
          \quad L = \begin{pmatrix} 0 \\ 0 \\ 1 \end{pmatrix},\quad
          H = \begin{pmatrix} 1 & 0 & 0 \end{pmatrix},

      .. math::
          Q_c = \tfrac{16}{3}\sigma^2\lambda^5,\quad
          P_\infty = \sigma^2 \begin{pmatrix}
              1 & 0 & -\lambda^2/3 \\
              0 & \lambda^2/3 & 0 \\
              -\lambda^2/3 & 0 & \lambda^4
          \end{pmatrix}.

    ``order`` is a static (Python ``int``) field. It picks a code path and
    is not a trainable parameter. It is normalised with ``int(...)``, so
    integral inputs such as ``1.0`` still give an integer
    :attr:`state_dim`. ``variance`` and ``lengthscale`` are JAX-traced
    scalars suitable for autograd.

    Examples:
        >>> import jax.numpy as jnp
        >>> sde = MaternSDE(variance=1.0, lengthscale=0.5, order=1)
        >>> F, L, H, Q_c, P_inf = sde.sde_params()
        >>> A, Q = sde.discretise(jnp.array([0.1, 0.2, 0.3]))
        >>> A.shape, Q.shape
        ((3, 2, 2), (3, 2, 2))

    References:
        Sarkka & Solin (2019), *Applied Stochastic Differential Equations*,
        Ch. 12; Hartikainen & Sarkka (2010), *Kalman Filtering and
        Smoothing Solutions to Temporal Gaussian Process Regression
        Models*, IEEE MLSP.
    """

    variance: Float[Array, ""]
    lengthscale: Float[Array, ""]
    order: int = eqx.field(static=True)

    def __init__(
        self,
        variance: float | Float[Array, ""] = 1.0,
        lengthscale: float | Float[Array, ""] = 1.0,
        order: int = 1,
    ) -> None:
        r"""Validate the hyperparameters and store them as floating scalars.

        Args:
            variance: Marginal variance :math:`\sigma^2`; must be positive.
            lengthscale: Lengthscale :math:`\ell`; must be positive.
            order: Matern order ``p`` in ``{0, 1, 2}`` (``nu = p + 1/2``).

        Raises:
            ValueError: If ``order`` is not one of ``0, 1, 2``, or if a
                concrete Python ``int``/``float`` ``variance`` or
                ``lengthscale`` is not strictly positive (NaN included).
        """
        if order not in (0, 1, 2):
            raise ValueError(
                "MaternSDE supports order in {0, 1, 2} (nu = order + 1/2), "
                f"got {order!r}"
            )
        # Eager positivity checks for concrete (non-traced) Python scalar inputs.
        # Written as ``not x > 0`` instead of ``x <= 0`` so that NaN, for which
        # every comparison is False, is rejected too. JAX tracer inputs (e.g.
        # inside ``jax.jit``) are skipped because their value cannot be
        # inspected; downstream training-time priors handle constraints.
        if isinstance(variance, (int, float)) and not variance > 0:
            raise ValueError(f"variance must be positive, got {variance!r}")
        if isinstance(lengthscale, (int, float)) and not lengthscale > 0:
            raise ValueError(f"lengthscale must be positive, got {lengthscale!r}")
        # Coerce to a floating dtype so integer inputs (``variance=1``) don't
        # propagate as integer-typed parameters.
        self.variance = jnp.asarray(variance, dtype=jnp.result_type(variance, 0.0))
        self.lengthscale = jnp.asarray(
            lengthscale, dtype=jnp.result_type(lengthscale, 0.0)
        )
        # ``1.0 in (0, 1, 2)`` is True, so normalise to a plain int; otherwise
        # ``state_dim`` / ``nu`` would silently become floats.
        self.order = int(order)

    @property
    def state_dim(self) -> int:
        """State dimension ``d = order + 1``."""
        return self.order + 1

    @property
    def nu(self) -> float:
        """Smoothness ``nu = order + 1/2``."""
        return self.order + 0.5

    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "d d"],
        Float[Array, "d 1"],
        Float[Array, "1 d"],
        Float[Array, "1 1"],
        Float[Array, "d d"],
    ]:
        """Return ``(F, L, H, Q_c, P_inf)`` for the chosen Matern order.

        All matrices are stacked from 0-d JAX scalars derived from
        ``variance`` and ``lengthscale``, so they remain differentiable.
        """
        sigma2 = self.variance
        ell = self.lengthscale
        zero = jnp.zeros_like(ell)
        one = jnp.ones_like(ell)

        if self.order == 0:
            lam = one / ell
            F = jnp.stack([jnp.stack([-lam])])
            L = jnp.stack([jnp.stack([one])])
            H = jnp.stack([jnp.stack([one])])
            Q_c = jnp.stack([jnp.stack([2.0 * sigma2 * lam])])
            P_inf = jnp.stack([jnp.stack([sigma2])])
            return F, L, H, Q_c, P_inf

        if self.order == 1:
            lam = jnp.sqrt(jnp.asarray(3.0)) / ell
            F = jnp.stack(
                [
                    jnp.stack([zero, one]),
                    jnp.stack([-(lam**2), -2.0 * lam]),
                ]
            )
            L = jnp.stack([jnp.stack([zero]), jnp.stack([one])])
            H = jnp.stack([jnp.stack([one, zero])])
            Q_c = jnp.stack([jnp.stack([4.0 * sigma2 * lam**3])])
            P_inf = jnp.stack(
                [
                    jnp.stack([sigma2, zero]),
                    jnp.stack([zero, sigma2 * lam**2]),
                ]
            )
            return F, L, H, Q_c, P_inf

        # order == 2 (Matern-5/2)
        lam = jnp.sqrt(jnp.asarray(5.0)) / ell
        F = jnp.stack(
            [
                jnp.stack([zero, one, zero]),
                jnp.stack([zero, zero, one]),
                jnp.stack([-(lam**3), -3.0 * lam**2, -3.0 * lam]),
            ]
        )
        L = jnp.stack([jnp.stack([zero]), jnp.stack([zero]), jnp.stack([one])])
        H = jnp.stack([jnp.stack([one, zero, zero])])
        Q_c = jnp.stack([jnp.stack([(16.0 / 3.0) * sigma2 * lam**5])])
        kappa = sigma2 * lam**2 / 3.0  # off-diagonal magnitude of P_inf
        P_inf = jnp.stack(
            [
                jnp.stack([sigma2, zero, -kappa]),
                jnp.stack([zero, kappa, zero]),
                jnp.stack([-kappa, zero, sigma2 * lam**4]),
            ]
        )
        return F, L, H, Q_c, P_inf

nu property

Smoothness nu = order + 1/2.

state_dim property

State dimension d = order + 1.

sde_params()

Return (F, L, H, Q_c, P_inf) for the chosen Matern order.

Source code in src/pyrox/gp/_sde_kernels.py
def sde_params(
    self,
) -> tuple[
    Float[Array, "d d"],
    Float[Array, "d 1"],
    Float[Array, "1 d"],
    Float[Array, "1 1"],
    Float[Array, "d d"],
]:
    """Build the closed-form ``(F, L, H, Q_c, P_inf)`` for ``self.order``.

    Every matrix is assembled by stacking 0-d JAX scalars derived from
    ``variance`` and ``lengthscale``, so gradients flow through them.
    """
    sigma2, ell = self.variance, self.lengthscale
    zero, one = jnp.zeros_like(ell), jnp.ones_like(ell)

    def _matrix(*rows):
        # Stack each row of 0-d arrays, then stack the rows into a 2-D array.
        return jnp.stack([jnp.stack(row) for row in rows])

    if self.order == 0:
        # Matern-1/2: one-dimensional state.
        lam = one / ell
        F = _matrix([-lam])
        L = _matrix([one])
        H = _matrix([one])
        Q_c = _matrix([2.0 * sigma2 * lam])
        P_inf = _matrix([sigma2])
    elif self.order == 1:
        # Matern-3/2: two-dimensional state.
        lam = jnp.sqrt(jnp.asarray(3.0)) / ell
        F = _matrix([zero, one], [-(lam**2), -2.0 * lam])
        L = _matrix([zero], [one])
        H = _matrix([one, zero])
        Q_c = _matrix([4.0 * sigma2 * lam**3])
        P_inf = _matrix([sigma2, zero], [zero, sigma2 * lam**2])
    else:
        # Matern-5/2 (order == 2): three-dimensional state.
        lam = jnp.sqrt(jnp.asarray(5.0)) / ell
        F = _matrix(
            [zero, one, zero],
            [zero, zero, one],
            [-(lam**3), -3.0 * lam**2, -3.0 * lam],
        )
        L = _matrix([zero], [zero], [one])
        H = _matrix([one, zero, zero])
        Q_c = _matrix([(16.0 / 3.0) * sigma2 * lam**5])
        kappa = sigma2 * lam**2 / 3.0  # off-diagonal magnitude of P_inf
        P_inf = _matrix(
            [sigma2, zero, -kappa],
            [zero, kappa, zero],
            [-kappa, zero, sigma2 * lam**4],
        )
    return F, L, H, Q_c, P_inf

pyrox.gp.ConstantSDE

Bases: SDEKernel

Constant kernel :math:k(\tau) = \sigma^2 in state-space form.

A degenerate 1-D state space with zero dynamics and zero diffusion:

.. math:: F = [0],\quad L = [0],\quad H = [1],\quad Q_c = [0],\quad P_\infty = [\sigma^2].

The transition is the identity A_k = I and the process noise is zero Q_k = 0. Useful as a non-trivial component of a :class:SumSDE (e.g. Matern + Constant for a fixed offset).

Source code in src/pyrox/gp/_sde_kernels.py
class ConstantSDE(SDEKernel):
    r"""State-space form of the constant kernel :math:`k(\tau) = \sigma^2`.

    The state is one-dimensional and never moves: there is no drift and no
    diffusion,

    .. math::
        F = [0],\quad L = [0],\quad H = [1],\quad Q_c = [0],\quad
        P_\infty = [\sigma^2].

    Consequently every discrete step is the identity ``A_k = I`` with zero
    process noise ``Q_k = 0``. Its main use is as a component of a
    :class:`SumSDE` (e.g. ``Matern + Constant`` to model a fixed offset).
    """

    variance: Float[Array, ""]

    def __init__(self, variance: float | Float[Array, ""] = 1.0) -> None:
        # Only plain Python scalars are validated up front (arrays / tracers
        # pass through unchecked). A non-positive variance would give a
        # non-PSD ``P_inf`` and break later Cholesky / Kalman steps.
        is_python_scalar = isinstance(variance, (int, float))
        if is_python_scalar and variance <= 0:
            raise ValueError(f"variance must be positive, got {variance!r}")
        float_dtype = jnp.result_type(variance, 0.0)
        self.variance = jnp.asarray(variance, dtype=float_dtype)

    @property
    def state_dim(self) -> int:
        return 1

    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "1 1"],
        Float[Array, "1 1"],
        Float[Array, "1 1"],
        Float[Array, "1 1"],
        Float[Array, "1 1"],
    ]:
        s2 = self.variance
        nil = jnp.zeros_like(s2)
        unit = jnp.ones_like(s2)

        def as_1x1(entry):
            # Wrap a scalar entry into a ``(1, 1)`` matrix.
            return jnp.stack([jnp.stack([entry])])

        # Order: (F, L, H, Q_c, P_inf).
        return as_1x1(nil), as_1x1(nil), as_1x1(unit), as_1x1(nil), as_1x1(s2)

pyrox.gp.CosineSDE

Bases: SDEKernel

Cosine kernel $k(\tau) = \sigma^2 \cos(\omega_0 \tau)$ in SDE form.

A 2-D deterministic oscillator with rotation-matrix transitions:

$$F = \begin{pmatrix} 0 & -\omega_0 \\ \omega_0 & 0 \end{pmatrix},\quad L = \begin{pmatrix} 0 \\ 0 \end{pmatrix},\quad H = \begin{pmatrix} 1 & 0 \end{pmatrix},$$

$$Q_c = 0,\quad P_\infty = \sigma^2 I_2.$$

There is no driving noise, so the discrete-time transition is a pure rotation $A_k = R(\omega_0\,\Delta t_k)$ and $Q_k = 0$. The `discretise` method overrides the default `expm` path with the closed-form rotation for efficiency.

Source code in src/pyrox/gp/_sde_kernels.py
class CosineSDE(SDEKernel):
    r"""State-space form of the cosine kernel :math:`k(\tau) = \sigma^2 \cos(\omega_0 \tau)`.

    The latent state is an undamped 2-D oscillator:

    .. math::
        F = \begin{pmatrix} 0 & -\omega_0 \\ \omega_0 & 0 \end{pmatrix},
        \quad L = \begin{pmatrix} 0 \\ 0 \end{pmatrix},\quad
        H = \begin{pmatrix} 1 & 0 \end{pmatrix},

    .. math::
        Q_c = 0,\quad P_\infty = \sigma^2 I_2.

    Since nothing drives the state, each step is an exact rotation
    :math:`A_k = R(\omega_0\,\Delta t_k)` with :math:`Q_k = 0`;
    :meth:`discretise` returns that closed form directly instead of going
    through the generic ``expm`` path.
    """

    variance: Float[Array, ""]
    frequency: Float[Array, ""]

    def __init__(
        self,
        variance: float | Float[Array, ""] = 1.0,
        frequency: float | Float[Array, ""] = 1.0,
    ) -> None:
        # Eager validation applies to plain Python scalars only. A
        # non-positive variance would make ``P_inf`` non-PSD; the sign of
        # ``frequency`` is deliberately left unconstrained.
        if isinstance(variance, (int, float)) and variance <= 0:
            raise ValueError(f"variance must be positive, got {variance!r}")
        var_arr = jnp.asarray(variance, dtype=jnp.result_type(variance, 0.0))
        freq_arr = jnp.asarray(frequency, dtype=jnp.result_type(frequency, 0.0))
        self.variance = var_arr
        self.frequency = freq_arr

    @property
    def state_dim(self) -> int:
        return 2

    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "2 2"],
        Float[Array, "2 1"],
        Float[Array, "1 2"],
        Float[Array, "1 1"],
        Float[Array, "2 2"],
    ]:
        w = self.frequency
        s2 = self.variance
        nil, unit = jnp.zeros_like(w), jnp.ones_like(w)

        def mat(*rows):
            # Assemble a matrix from rows of scalar entries.
            return jnp.stack([jnp.stack(row) for row in rows])

        drift = mat([nil, -w], [w, nil])
        noise_gain = mat([nil], [nil])
        readout = mat([unit, nil])
        spectral = mat([nil])
        stationary = mat([s2, nil], [nil, s2])
        return drift, noise_gain, readout, spectral, stationary

    def discretise(
        self,
        dt: Float[Array, " N"],
    ) -> tuple[Float[Array, "N 2 2"], Float[Array, "N 2 2"]]:
        """Closed-form rotation: ``A_k = R(omega * dt_k)``, ``Q_k = 0``."""
        angle = jnp.asarray(dt) * self.frequency
        cos_t = jnp.cos(angle)
        sin_t = jnp.sin(angle)
        # Build each row as (N, 2), then stack rows into (N, 2, 2).
        top = jnp.stack([cos_t, -sin_t], axis=-1)
        bottom = jnp.stack([sin_t, cos_t], axis=-1)
        transition = jnp.stack([top, bottom], axis=-2)
        return transition, jnp.zeros_like(transition)

discretise(dt)

Closed-form rotation: A_k = R(omega * dt_k), Q_k = 0.

Source code in src/pyrox/gp/_sde_kernels.py
def discretise(
    self,
    dt: Float[Array, " N"],
) -> tuple[Float[Array, "N 2 2"], Float[Array, "N 2 2"]]:
    """Closed-form rotation: ``A_k = R(omega * dt_k)``, ``Q_k = 0``."""
    angle = jnp.asarray(dt) * self.frequency
    cos_t = jnp.cos(angle)
    sin_t = jnp.sin(angle)
    # Build each row as (N, 2), then stack rows into (N, 2, 2).
    top = jnp.stack([cos_t, -sin_t], axis=-1)
    bottom = jnp.stack([sin_t, cos_t], axis=-1)
    transition = jnp.stack([top, bottom], axis=-2)
    return transition, jnp.zeros_like(transition)

pyrox.gp.PeriodicSDE

Bases: SDEKernel

Periodic kernel in state-space form via Fourier-series truncation.

The MacKay periodic kernel $k(\tau) = \sigma^2 \exp\!\bigl(-2 \sin^2(\pi\tau/T)/\ell^2\bigr)$ expands as

$$k(\tau) = \sigma^2 e^{-1/\ell^2} \Bigl[I_0(1/\ell^2) + 2 \sum_{j=1}^\infty I_j(1/\ell^2) \cos(j\,\omega_0 \tau)\Bigr],$$

with $\omega_0 = 2\pi/T$. Truncating to $J$ = `n_harmonics` cosines gives a deterministic state-space model. Its state collects a degenerate 1-D constant block (the $j=0$ DC term) and $J$ rotation blocks, one per harmonic. The total state dimension is $1 + 2J$, with $L = 0$ and $Q_c = 0$ (no driving noise). $P_\infty$ is diagonal: $q_0$ for the DC state and $q_j$ repeated on the (cos, sin) pair of harmonic $j$, where

$$q_0 = \sigma^2 e^{-1/\ell^2} I_0(1/\ell^2),\qquad q_j = 2 \sigma^2 e^{-1/\ell^2} I_j(1/\ell^2)\quad (j \geq 1).$$

The scaled modified Bessel coefficients are computed by `_scaled_bessel_i_seq` using a log-space Taylor-series accumulation (`logsumexp` over $(j + 2k)\log(x/2) - x - \log k! - \log (k+j)!$). For `n_harmonics` around 7 the truncation matches the dense MacKay periodic kernel to better than 1e-6 across the typical hyperparameter regime.

References

Solin & Sarkka (2014), Explicit Link Between Periodic Covariance Functions and State Space Models, AISTATS.

Source code in src/pyrox/gp/_sde_kernels.py
class PeriodicSDE(SDEKernel):
    r"""Periodic kernel in state-space form via Fourier-series truncation.

    The MacKay periodic kernel
    :math:`k(\tau) = \sigma^2 \exp\!\bigl(-2 \sin^2(\pi\tau/T)/\ell^2\bigr)`
    expands as

    .. math::
        k(\tau) = \sigma^2 e^{-1/\ell^2}
        \Bigl[I_0(1/\ell^2) + 2 \sum_{j=1}^\infty I_j(1/\ell^2)
        \cos(j\,\omega_0 \tau)\Bigr],

    with :math:`\omega_0 = 2\pi/T`. Truncating to ``J = n_harmonics``
    cosines gives a deterministic state-space model whose state collects
    a degenerate 1-D constant block (the :math:`j=0` DC term) and ``J``
    rotation blocks, one per harmonic. Total state dimension is
    :math:`1 + 2J`, ``L = 0``, ``Q_c = 0`` (no driving noise), and
    :math:`P_\infty` is diagonal: ``q_0`` for the DC state and ``q_j``
    repeated on the (cos, sin) pair of harmonic ``j``, where

    .. math::
        q_0 = \sigma^2 e^{-1/\ell^2} I_0(1/\ell^2),\qquad
        q_j = 2 \sigma^2 e^{-1/\ell^2} I_j(1/\ell^2)\quad (j \geq 1).

    The scaled modified Bessel coefficients are computed by
    :func:`_scaled_bessel_i_seq` using a log-space Taylor-series
    accumulation (``logsumexp`` over ``(j + 2k) log(x/2) - x - log(k!) -
    log((k+j)!)``). For ``n_harmonics`` around 7 the truncation matches
    the dense MacKay periodic kernel to better than 1e-6 across the
    typical hyperparameter regime.

    References:
        Solin & Sarkka (2014), *Explicit Link Between Periodic Covariance
        Functions and State Space Models*, AISTATS.
    """

    variance: Float[Array, ""]
    lengthscale: Float[Array, ""]
    period: Float[Array, ""]
    n_harmonics: int = eqx.field(static=True)

    def __init__(
        self,
        variance: float | Float[Array, ""] = 1.0,
        lengthscale: float | Float[Array, ""] = 1.0,
        period: float | Float[Array, ""] = 1.0,
        n_harmonics: int = 7,
    ) -> None:
        if n_harmonics < 1:
            raise ValueError(
                f"PeriodicSDE requires n_harmonics >= 1, got {n_harmonics!r}"
            )
        # Eager positivity checks for concrete (non-traced) Python scalar
        # inputs. ``variance``/``lengthscale <= 0`` would corrupt ``P_inf``
        # (non-PSD or NaN through the Bessel coefficients); ``period <= 0``
        # would divide-by-zero in ``omega0 = 2*pi/T``.
        if isinstance(variance, (int, float)) and variance <= 0:
            raise ValueError(f"variance must be positive, got {variance!r}")
        if isinstance(lengthscale, (int, float)) and lengthscale <= 0:
            raise ValueError(f"lengthscale must be positive, got {lengthscale!r}")
        if isinstance(period, (int, float)) and period <= 0:
            raise ValueError(f"period must be positive, got {period!r}")
        self.variance = jnp.asarray(variance, dtype=jnp.result_type(variance, 0.0))
        self.lengthscale = jnp.asarray(
            lengthscale, dtype=jnp.result_type(lengthscale, 0.0)
        )
        self.period = jnp.asarray(period, dtype=jnp.result_type(period, 0.0))
        self.n_harmonics = n_harmonics

    @property
    def state_dim(self) -> int:
        return 1 + 2 * self.n_harmonics

    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "d d"],
        Float[Array, "d 1"],
        Float[Array, "1 d"],
        Float[Array, "1 1"],
        Float[Array, "d d"],
    ]:
        J = self.n_harmonics
        d = 1 + 2 * J
        sigma2 = self.variance
        ell = self.lengthscale
        T = self.period
        omega0 = 2.0 * jnp.pi / T

        # Rotation blocks F_j = [[0, -j*omega0], [j*omega0, 0]] for j = 1..J.
        # Plus a 1x1 zero block at the top for the j=0 DC mode.
        blocks = [jnp.zeros((1, 1), dtype=ell.dtype)]
        for j in range(1, J + 1):
            wj = jnp.asarray(j, dtype=ell.dtype) * omega0
            zero = jnp.zeros_like(wj)
            blocks.append(jnp.stack([jnp.stack([zero, -wj]), jnp.stack([wj, zero])]))
        F = jsl.block_diag(*blocks)

        # H reads out the cosine ("first") coordinate of each block.
        # Layout: [DC, cos_1, sin_1, cos_2, sin_2, ..., cos_J, sin_J].
        H_entries = jnp.zeros(d, dtype=ell.dtype)
        idx = jnp.array([0] + [1 + 2 * (j - 1) for j in range(1, J + 1)])
        H_entries = H_entries.at[idx].set(1.0)
        H = H_entries[None, :]

        # No driving noise: L = 0, Q_c = 0.
        L = jnp.zeros((d, 1), dtype=ell.dtype)
        Q_c = jnp.zeros((1, 1), dtype=ell.dtype)

        # Bessel-weighted P_inf:
        #   q_0 = sigma^2 * i0e(1/ell^2),  q_j = 2*sigma^2 * i_j_e(1/ell^2).
        x = 1.0 / (ell * ell)
        i_seq = _scaled_bessel_i_seq(x, j_max=J)  # (J+1,)
        q_vals = jnp.concatenate(
            [
                sigma2 * i_seq[0:1],
                2.0 * sigma2 * i_seq[1:],
            ]
        )

        # Diagonal entries of P_inf: q_0 (1 entry), then q_j repeated twice
        # for the (cos, sin) coordinates of each rotation block.
        diag = jnp.concatenate(
            [
                q_vals[0:1],
                jnp.repeat(q_vals[1:], 2),
            ]
        )
        P_inf = jnp.diag(diag)
        return F, L, H, Q_c, P_inf

    def discretise(
        self,
        dt: Float[Array, " N"],
    ) -> tuple[Float[Array, "N d d"], Float[Array, "N d d"]]:
        r"""Closed-form discretisation: harmonic block rotations, ``Q_k = 0``.

        The drift ``F`` is exactly block-diagonal with a ``1x1`` zero block
        (DC mode) and ``J`` skew-symmetric ``2x2`` rotation generators
        :math:`F_j = \mathrm{skew}(j\,\omega_0)`. The matrix exponential
        of each block has a closed form (identity for the DC mode and a
        2-D rotation for each harmonic), and ``Q_k`` vanishes identically
        because the diffusion is zero. Using the closed form avoids the
        float32 accumulation error a generic ``expm`` would incur for
        large ``j * omega_0 * dt`` (the same reason
        :meth:`CosineSDE.discretise` uses a closed-form rotation).

        Args:
            dt: Step sizes of shape ``(N,)``. Integer / bool inputs are
                promoted to a floating dtype first; floating inputs keep
                their dtype.

        Returns:
            ``(A, Q)``, each of shape ``(N, d, d)`` with ``d = 1 + 2J``.
        """
        J = self.n_harmonics
        d = 1 + 2 * J
        T = self.period
        omega0 = 2.0 * jnp.pi / T

        dt = jnp.asarray(dt)
        # ``A``/``Q`` are allocated with ``dt.dtype``; for an integer ``dt``
        # the cos/sin entries written below would be truncated to integers.
        # Promote to floating point (a no-op for floating ``dt``).
        dt = dt.astype(jnp.result_type(dt, 0.0))
        n = dt.shape[0]
        # Per-harmonic frequencies (J,)
        js = jnp.arange(1, J + 1, dtype=dt.dtype)
        omegas = js * omega0  # (J,)
        # Angles per (step, harmonic) pair: (n, J)
        theta = dt[:, None] * omegas[None, :]
        c = jnp.cos(theta)  # (n, J)
        s = jnp.sin(theta)  # (n, J)

        A = jnp.zeros((n, d, d), dtype=dt.dtype)
        # DC mode: identity in the leading 1x1 block.
        A = A.at[:, 0, 0].set(jnp.ones((n,), dtype=dt.dtype))
        # Harmonic rotation blocks. Block j occupies indices
        # ``[1 + 2(j-1) : 1 + 2j]`` (cos, sin coordinates).
        idx_cos = 1 + 2 * jnp.arange(J)
        idx_sin = idx_cos + 1
        # A[:, cos, cos] = cos, A[:, cos, sin] = -sin,
        # A[:, sin, cos] = sin, A[:, sin, sin] = cos.
        A = A.at[:, idx_cos, idx_cos].set(c)
        A = A.at[:, idx_cos, idx_sin].set(-s)
        A = A.at[:, idx_sin, idx_cos].set(s)
        A = A.at[:, idx_sin, idx_sin].set(c)

        Q = jnp.zeros((n, d, d), dtype=dt.dtype)
        return A, Q

discretise(dt)

Closed-form discretisation: harmonic block rotations, $Q_k = 0$.

The drift $F$ is exactly block-diagonal, with a $1\times 1$ zero block (the DC mode) and $J$ skew-symmetric $2\times 2$ rotation generators $F_j = \mathrm{skew}(j\,\omega_0)$. The matrix exponential of each block has a closed form (identity for the DC mode and a 2-D rotation for each harmonic), and $Q_k$ vanishes identically because the diffusion is zero. The closed form avoids the float32 accumulation error that a generic `expm` would incur for large $j\,\omega_0\,\Delta t$. This is the same reason `CosineSDE.discretise` uses a closed-form rotation.

Source code in src/pyrox/gp/_sde_kernels.py
def discretise(
    self,
    dt: Float[Array, " N"],
) -> tuple[Float[Array, "N d d"], Float[Array, "N d d"]]:
    r"""Closed-form discretisation: harmonic block rotations, ``Q_k = 0``.

    The drift ``F`` is exactly block-diagonal with a ``1x1`` zero block
    (DC mode) and ``J`` skew-symmetric ``2x2`` rotation generators
    :math:`F_j = \mathrm{skew}(j\,\omega_0)`. The matrix exponential
    of each block has a closed form (identity for the DC mode and a
    2-D rotation for each harmonic), and ``Q_k`` vanishes identically
    because the diffusion is zero. Using the closed form avoids the
    float32 accumulation error a generic ``expm`` would incur for
    large ``j * omega_0 * dt`` (the same reason
    :meth:`CosineSDE.discretise` uses a closed-form rotation).

    Args:
        dt: Step sizes of shape ``(N,)``. Integer / bool inputs are
            promoted to a floating dtype first; floating inputs keep
            their dtype.

    Returns:
        ``(A, Q)``, each of shape ``(N, d, d)`` with ``d = 1 + 2J``.
    """
    J = self.n_harmonics
    d = 1 + 2 * J
    T = self.period
    omega0 = 2.0 * jnp.pi / T

    dt = jnp.asarray(dt)
    # ``A``/``Q`` are allocated with ``dt.dtype``; for an integer ``dt``
    # the cos/sin entries written below would be truncated to integers.
    # Promote to floating point (a no-op for floating ``dt``).
    dt = dt.astype(jnp.result_type(dt, 0.0))
    n = dt.shape[0]
    # Per-harmonic frequencies (J,)
    js = jnp.arange(1, J + 1, dtype=dt.dtype)
    omegas = js * omega0  # (J,)
    # Angles per (step, harmonic) pair: (n, J)
    theta = dt[:, None] * omegas[None, :]
    c = jnp.cos(theta)  # (n, J)
    s = jnp.sin(theta)  # (n, J)

    A = jnp.zeros((n, d, d), dtype=dt.dtype)
    # DC mode: identity in the leading 1x1 block.
    A = A.at[:, 0, 0].set(jnp.ones((n,), dtype=dt.dtype))
    # Harmonic rotation blocks. Block j occupies indices
    # ``[1 + 2(j-1) : 1 + 2j]`` (cos, sin coordinates).
    idx_cos = 1 + 2 * jnp.arange(J)
    idx_sin = idx_cos + 1
    # A[:, cos, cos] = cos, A[:, cos, sin] = -sin,
    # A[:, sin, cos] = sin, A[:, sin, sin] = cos.
    A = A.at[:, idx_cos, idx_cos].set(c)
    A = A.at[:, idx_cos, idx_sin].set(-s)
    A = A.at[:, idx_sin, idx_cos].set(s)
    A = A.at[:, idx_sin, idx_sin].set(c)

    Q = jnp.zeros((n, d, d), dtype=dt.dtype)
    return A, Q

pyrox.gp.SumSDE

Bases: SDEKernel

Sum of SDE kernels via block-diagonal state-space composition.

For $k(\tau) = \sum_i k_i(\tau)$, the SDE is the block-diagonal concatenation of each component:

$$F = \mathrm{blkdiag}(F_1, \dots, F_K),\quad L = \mathrm{blkdiag}(L_1, \dots, L_K),\quad Q_c = \mathrm{blkdiag}(Q_{c,1}, \dots, Q_{c,K}),$$

$$H = [H_1, \dots, H_K],\quad P_\infty = \mathrm{blkdiag}(P_{\infty,1}, \dots, P_{\infty,K}).$$

Total state dimension is $\sum_i d_i$. Each component occupies its own block of the state and evolves independently of the others.

Source code in src/pyrox/gp/_sde_kernels.py
class SumSDE(SDEKernel):
    r"""Sum of SDE kernels via block-diagonal state-space composition.

    For :math:`k(\tau) = \sum_i k_i(\tau)`, the SDE is the block-diagonal
    concatenation of each component:

    .. math::
        F = \mathrm{blkdiag}(F_1, \dots, F_K),\quad
        L = \mathrm{blkdiag}(L_1, \dots, L_K),\quad
        Q_c = \mathrm{blkdiag}(Q_{c,1}, \dots, Q_{c,K}),

    .. math::
        H = [H_1, \dots, H_K],\quad
        P_\infty = \mathrm{blkdiag}(P_{\infty,1}, \dots, P_{\infty,K}).

    Total state dimension is :math:`\sum_i d_i`. Each component occupies
    its own block of the state and evolves independently.
    """

    components: tuple[SDEKernel, ...]

    def __init__(self, components: tuple[SDEKernel, ...]) -> None:
        """Store the component kernels.

        Args:
            components: One or more SDE kernels. Any iterable is accepted
                and is materialised into a tuple.

        Raises:
            ValueError: If no components are given.
        """
        # Materialise before checking the length so that generators and
        # other one-shot iterables work (``len`` on them raises TypeError).
        components = tuple(components)
        if not components:
            raise ValueError("SumSDE requires at least one component.")
        self.components = components

    @property
    def state_dim(self) -> int:
        return sum(c.state_dim for c in self.components)

    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "d d"],
        Float[Array, "d s"],
        Float[Array, "1 d"],
        Float[Array, "s s"],
        Float[Array, "d d"],
    ]:
        params = [c.sde_params() for c in self.components]
        # Transpose the per-component 5-tuples into five per-matrix tuples.
        Fs, Ls, Hs, Qcs, Ps = zip(*params, strict=True)
        F = jsl.block_diag(*Fs)
        L = jsl.block_diag(*Ls)
        # Readouts are concatenated so the observation sums the components.
        H = jnp.concatenate(Hs, axis=1)
        Q_c = jsl.block_diag(*Qcs)
        P_inf = jsl.block_diag(*Ps)
        return F, L, H, Q_c, P_inf

pyrox.gp.ProductSDE

Bases: SDEKernel

Product of two SDE kernels via Kronecker composition.

For $k(\tau) = k_1(\tau)\,k_2(\tau)$, the joint SDE has a Kronecker-sum drift and a Kronecker-product readout / stationary covariance:

$$F = F_1 \otimes I_{d_2} + I_{d_1} \otimes F_2,\quad H = H_1 \otimes H_2,\quad P_\infty = P_{\infty,1} \otimes P_{\infty,2}.$$

The diffusion is not a simple Kronecker product. Substituting into the stationary Lyapunov equation $F P_\infty + P_\infty F^\top + L Q_c L^\top = 0$ yields

$$L Q_c L^\top = (L_1 Q_{c,1} L_1^\top) \otimes P_{\infty,2} + P_{\infty,1} \otimes (L_2 Q_{c,2} L_2^\top).$$

For simplicity we set $L = I_{d_1 d_2}$ and store the right-hand side as $Q_c$ directly. Total state dimension is $d_1 \cdot d_2$.

Source code in src/pyrox/gp/_sde_kernels.py
class ProductSDE(SDEKernel):
    r"""Kronecker composition of two SDE kernels, :math:`k = k_1 k_2`.

    The drift is the Kronecker sum of the factors; readout and stationary
    covariance are Kronecker products:

    .. math::
        F = F_1 \otimes I_{d_2} + I_{d_1} \otimes F_2,\quad
        H = H_1 \otimes H_2,\quad
        P_\infty = P_{\infty,1} \otimes P_{\infty,2}.

    The diffusion does *not* factor as a Kronecker product. Plugging the
    above into the stationary Lyapunov equation gives

    .. math::
        L Q_c L^\top = (L_1 Q_{c,1} L_1^\top) \otimes P_{\infty,2}
        + P_{\infty,1} \otimes (L_2 Q_{c,2} L_2^\top),

    so we take :math:`L = I_{d_1 d_2}` and return that right-hand side as
    ``Q_c``. The state dimension is :math:`d_1 \cdot d_2`.
    """

    left: SDEKernel
    right: SDEKernel

    def __init__(self, left: SDEKernel, right: SDEKernel) -> None:
        self.left, self.right = left, right

    @property
    def state_dim(self) -> int:
        left_dim = self.left.state_dim
        return left_dim * self.right.state_dim

    def sde_params(
        self,
    ) -> tuple[
        Float[Array, "d d"],
        Float[Array, "d d"],
        Float[Array, "1 d"],
        Float[Array, "d d"],
        Float[Array, "d d"],
    ]:
        F_a, L_a, H_a, Qc_a, P_a = self.left.sde_params()
        F_b, L_b, H_b, Qc_b, P_b = self.right.sde_params()
        n_a = F_a.shape[0]
        n_b = F_b.shape[0]
        eye_a = jnp.eye(n_a, dtype=F_a.dtype)
        eye_b = jnp.eye(n_b, dtype=F_b.dtype)

        drift = jnp.kron(F_a, eye_b) + jnp.kron(eye_a, F_b)
        readout = jnp.kron(H_a, H_b)
        stationary = jnp.kron(P_a, P_b)

        # Effective diffusion L Q_c L^T of each factor.
        diff_a = L_a @ Qc_a @ L_a.T  # (n_a, n_a)
        diff_b = L_b @ Qc_b @ L_b.T  # (n_b, n_b)
        spectral = jnp.kron(diff_a, P_b) + jnp.kron(P_a, diff_b)
        gain = jnp.eye(n_a * n_b, dtype=drift.dtype)
        return drift, gain, readout, spectral, stationary

pyrox.gp.QuasiPeriodicSDE

Bases: ProductSDE

Quasi-periodic kernel: $k(\tau) = k_{\rm Mat}(\tau)\,k_{\rm Per}(\tau)$.

A thin documented subclass of `ProductSDE` that captures the standard Matern $\times$ Periodic decomposition used for modulated periodic signals (stellar light curves, modulated seasonal patterns). The Matern envelope sets the timescale on which the amplitude drifts; the periodic factor sets the cycle.

Example

    >>> import jax.numpy as jnp
    >>> qp = QuasiPeriodicSDE(
    ...     MaternSDE(variance=1.0, lengthscale=2.0, order=1),
    ...     PeriodicSDE(variance=1.0, lengthscale=1.0, period=1.0, n_harmonics=5),
    ... )
    >>> qp.state_dim
    22

References

Sarkka & Solin (2019), Applied Stochastic Differential Equations, Sec. 12.3; Wilkinson et al. (2021), BayesNewton.

Source code in src/pyrox/gp/_sde_kernels.py
class QuasiPeriodicSDE(ProductSDE):
    r"""Quasi-periodic kernel :math:`k(\tau) = k_{\rm Mat}(\tau)\,k_{\rm Per}(\tau)`.

    A convenience subclass of :class:`ProductSDE` for the common
    Matern :math:`\times` Periodic construction used on modulated periodic
    signals (stellar light curves, seasonal patterns whose amplitude
    wanders). The Matern factor controls how fast the amplitude drifts;
    the periodic factor controls the cycle itself.

    Example:
        >>> import jax.numpy as jnp
        >>> qp = QuasiPeriodicSDE(
        ...     MaternSDE(variance=1.0, lengthscale=2.0, order=1),
        ...     PeriodicSDE(variance=1.0, lengthscale=1.0, period=1.0, n_harmonics=5),
        ... )
        >>> qp.state_dim
        22

    References:
        Sarkka & Solin (2019), *Applied Stochastic Differential Equations*,
        Sec. 12.3; Wilkinson et al. (2021), *BayesNewton*.
    """

    def __init__(self, matern: SDEKernel, periodic: SDEKernel) -> None:
        # Matern envelope is the left factor, periodic carrier the right.
        super().__init__(matern, periodic)

Markov GP — Kalman / RTS workflow

MarkovGPPrior consumes any SDEKernel over a sorted 1-D grid and gives O(N d^3) marginal likelihood (forward Kalman filter) and posterior smoothing (backward RTS), where d is the SDE state dimension. Use it for temporal GP regression / forecasting when the training grid lives on a single time axis. Predictions at arbitrary test times — including forecasting, backcasting, and within-window interpolation — re-run the filter+smoother over the merged grid with the test points masked out of the update step.

import jax.numpy as jnp
from pyrox.gp import MaternSDE, MarkovGPPrior  # markov_gp_factor is unused here

times = jnp.linspace(0.0, 5.0, 200)
y     = jnp.sin(times) + 0.05 * jnp.cos(7.0 * times)

prior = MarkovGPPrior(
    MaternSDE(variance=1.0, lengthscale=0.5, order=1),  # Matern-3/2
    times,
)
log_marg = prior.log_marginal(y, jnp.asarray(0.01))     # Kalman forward
cond     = prior.condition(y, jnp.asarray(0.01))        # filter + RTS smoother
mean, var = cond.predict(jnp.linspace(-0.5, 6.0, 50))   # arbitrary test times

Inside a NumPyro model, swap gp_factor for markov_gp_factor:

import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist
from pyrox.gp import MarkovGPPrior, MaternSDE, markov_gp_factor

def temporal_model(times, y):
    """Markov (Kalman) GP regression with log-normal hyperpriors."""
    # Keyword arguments are evaluated left to right, so "variance" is still
    # sampled before "lengthscale".
    kernel = MaternSDE(
        variance=numpyro.sample("variance", dist.LogNormal(0.0, 1.0)),
        lengthscale=numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0)),
        order=1,
    )
    markov_gp_factor("obs", MarkovGPPrior(kernel, times), y, jnp.array(0.01))

Currently scoped to Gaussian-likelihood regression on a single time axis. Non-Gaussian likelihoods on top of the Markov path (CVI, EP) and spatio-temporal Markov priors land in later waves.

pyrox.gp.MarkovGPPrior

Bases: Module

Linear-time temporal GP prior over a sorted 1-D grid.

Wraps any `pyrox.gp.SDEKernel` (e.g. `pyrox.gp.MaternSDE`, `pyrox.gp.SumSDE`, `pyrox.gp.PeriodicSDE`). It gives Kalman filtering for the marginal log-likelihood and RTS smoothing for the posterior on the training grid. Supports an optional mean function and a small observation-noise floor for numerical stability.

Attributes:

Name Type Description
sde_kernel SDEKernel

Any `SDEKernel`. Provides `(F, L, H, Q_c, P_inf)` via `sde_params()` and the discrete transition tuple via `discretise(dt)`.

times Float[Array, ' N']

Sorted, strictly increasing observation times of shape `(N,)`. Concrete (non-traced) `times` arrays are validated for monotonicity at construction time. Under `jax.jit` / SVI / MCMC the input is a tracer and the check is silently skipped, so callers must guarantee monotonicity in that case.

mean_fn Callable[[Float[Array, ' N']], Float[Array, ' N']] | None

Optional callable mapping times -> (N,) mean values. Defaults to the zero mean. The mean is subtracted from observations before filtering and added back at predict time.

obs_noise_floor float

Small extra diagonal added to the observation variance R = noise_var + obs_noise_floor for stability when noise_var is near zero. Defaults to 0.0.

Examples:

>>> import jax.numpy as jnp
>>> from pyrox.gp import MaternSDE, MarkovGPPrior
>>> times = jnp.linspace(0.0, 5.0, 50)
>>> sde = MaternSDE(variance=1.0, lengthscale=0.5, order=1)
>>> prior = MarkovGPPrior(sde, times)
>>> y = jnp.sin(times) + 0.05 * jnp.cos(3.0 * times)
>>> log_marg = prior.log_marginal(y, jnp.asarray(0.01))

Notes

The solver-strategy plumbing used by `pyrox.gp.GPPrior` does not apply here. Kalman filtering is its own linear-algebra path and does not factor through `gaussx.AbstractSolverStrategy`.

Source code in src/pyrox/gp/_markov.py
class MarkovGPPrior(eqx.Module):
    r"""Linear-time temporal GP prior over a sorted 1-D grid.

    Wraps any :class:`pyrox.gp.SDEKernel` (e.g. :class:`pyrox.gp.MaternSDE`,
    :class:`pyrox.gp.SumSDE`, :class:`pyrox.gp.PeriodicSDE`) to give Kalman
    filtering for the marginal log-likelihood and RTS smoothing for the
    posterior on the training grid. Supports an optional mean function and a
    small observation-noise floor for numerical stability.

    Attributes:
        sde_kernel: Any :class:`SDEKernel`. Provides ``(F, L, H, Q_c, P_inf)``
            via ``sde_params()`` and the discrete transition tuple via
            ``discretise(dt)``.
        times: Sorted, strictly increasing observation times of shape
            ``(N,)``. Concrete (non-traced) ``times`` arrays are validated
            for monotonicity at construction time; under :func:`jax.jit` /
            SVI / MCMC the input is a tracer and the check is silently
            skipped — callers must guarantee monotonicity in that case.
        mean_fn: Optional callable mapping ``times -> (N,)`` mean values.
            Defaults to the zero mean. The mean is subtracted from
            observations before filtering and added back at predict time.
        obs_noise_floor: Small extra diagonal added to the observation
            variance ``R = noise_var + obs_noise_floor`` for stability when
            ``noise_var`` is near zero. Defaults to ``0.0``.

    Examples:
        >>> import jax.numpy as jnp
        >>> from pyrox.gp import MaternSDE, MarkovGPPrior
        >>> times = jnp.linspace(0.0, 5.0, 50)
        >>> sde = MaternSDE(variance=1.0, lengthscale=0.5, order=1)
        >>> prior = MarkovGPPrior(sde, times)
        >>> y = jnp.sin(times) + 0.05 * jnp.cos(3.0 * times)
        >>> log_marg = prior.log_marginal(y, jnp.asarray(0.01))

    Notes:
        The solver-strategy plumbing used by :class:`pyrox.gp.GPPrior` does
        not apply here — Kalman filtering is its own linear-algebra path
        and does not factor through ``gaussx.AbstractSolverStrategy``.
    """

    sde_kernel: SDEKernel
    times: Float[Array, " N"]
    mean_fn: Callable[[Float[Array, " N"]], Float[Array, " N"]] | None = None
    obs_noise_floor: float = eqx.field(static=True, default=0.0)

    def __init__(
        self,
        sde_kernel: SDEKernel,
        times: Float[Array, " N"],
        mean_fn: Callable[[Float[Array, " N"]], Float[Array, " N"]] | None = None,
        obs_noise_floor: float = 0.0,
    ) -> None:
        """Validate and store the SDE kernel, time grid, mean and noise floor.

        Args:
            sde_kernel: State-space kernel supplying ``sde_params()`` and
                ``discretise(dt)``.
            times: 1-D grid of observation times; cast to a floating dtype.
            mean_fn: Optional prior mean ``times -> (N,)``; ``None`` means zero.
            obs_noise_floor: Non-negative extra observation variance.

        Raises:
            ValueError: If ``obs_noise_floor`` is negative or NaN, if ``times``
                is not 1-D, or if a concrete ``times`` grid is not strictly
                increasing.
        """
        # ``not (x >= 0)`` rejects NaN as well as negatives. A plain ``x < 0``
        # test lets NaN through, which would make the observation variance
        # ``R = noise_var + obs_noise_floor`` NaN for every step.
        if not (obs_noise_floor >= 0):
            raise ValueError(
                f"obs_noise_floor must be non-negative, got {obs_noise_floor!r}"
            )
        times_arr = jnp.asarray(times, dtype=jnp.result_type(times, 0.0))
        if times_arr.ndim != 1:
            raise ValueError(f"times must be 1-D, got shape {tuple(times_arr.shape)!r}")
        # Eager monotonicity check for concrete (non-traced) inputs only.
        # Under ``jax.jit`` / SVI / similar transforms ``times`` may arrive as
        # a tracer; the ``bool`` conversion would raise, so we silence that
        # path and let downstream Kalman steps trust the contract.
        if times_arr.shape[0] >= 2:
            try:
                if not bool(jnp.all(jnp.diff(times_arr) > 0)):
                    raise ValueError("times must be strictly increasing")
            except jax.errors.TracerBoolConversionError:
                pass
        self.sde_kernel = sde_kernel
        self.times = times_arr
        self.mean_fn = mean_fn
        self.obs_noise_floor = float(obs_noise_floor)

    @property
    def state_dim(self) -> int:
        """Dimension :math:`d` of the latent SDE state, taken from the kernel."""
        kernel = self.sde_kernel
        return kernel.state_dim

    def mean(self, times: Float[Array, " M"]) -> Float[Array, " M"]:
        """Prior mean at ``times``: ``mean_fn(times)`` if set, else zeros."""
        fn = self.mean_fn
        return jnp.zeros_like(times) if fn is None else fn(times)

    def _residual(self, y: Float[Array, " N"]) -> Float[Array, " N"]:
        """Observations minus the prior mean on the training grid."""
        prior_mean = self.mean(self.times)
        return y - prior_mean

    def _R(self, noise_var: Float[Array, ""]) -> Float[Array, ""]:
        """Effective observation variance ``noise_var + obs_noise_floor``."""
        floor = jnp.asarray(self.obs_noise_floor)
        return jnp.asarray(noise_var) + floor

    def filter(
        self,
        y: Float[Array, " N"],
        noise_var: Float[Array, ""],
    ) -> tuple[
        Float[Array, "N d"],
        Float[Array, "N d d"],
        Float[Array, "N d"],
        Float[Array, "N d d"],
        Float[Array, ""],
    ]:
        """Forward Kalman filter over the training times.

        Returns:
            ``(m_pred, P_pred, m_filt, P_filt, log_marginal)``: predicted and
            filtered state means ``(N, d)`` and covariances ``(N, d, d)``,
            plus the scalar log-likelihood ``log p(y | theta)``.
        """
        kernel = self.sde_kernel
        grid = self.times
        F, _, H, _, P_inf = kernel.sde_params()
        A_seq, Q_seq = kernel.discretise(_build_dt_full(grid))
        centred = self._residual(y)
        # Every training point takes part in the update step.
        observed = jnp.ones_like(grid)
        obs_var = jnp.broadcast_to(self._R(noise_var), grid.shape)
        return _kalman_filter(F, H, P_inf, A_seq, Q_seq, centred, observed, obs_var)

    def log_marginal(
        self,
        y: Float[Array, " N"],
        noise_var: Float[Array, ""],
    ) -> Float[Array, ""]:
        r"""Log marginal likelihood ``log p(y | theta)`` of the training data.

        Runs the forward Kalman filter and keeps only the final element of
        its output tuple.
        """
        filter_outputs = self.filter(y, noise_var)
        return filter_outputs[-1]

    def smooth(
        self,
        y: Float[Array, " N"],
        noise_var: Float[Array, ""],
    ) -> tuple[Float[Array, "N d"], Float[Array, "N d d"], Float[Array, ""]]:
        """Kalman filter followed by an RTS smoother on the training times.

        Returns:
            ``(m_smooth, P_smooth, log_marginal)``: smoothed state means
            ``(N, d)`` and covariances ``(N, d, d)`` at ``self.times``, plus
            the scalar log marginal likelihood from the forward pass.
        """
        t = self.times
        F, _L, H, _Qc, P_inf = self.sde_kernel.sde_params()
        A_seq, Q_seq = self.sde_kernel.discretise(_build_dt_full(t))
        centred = self._residual(y)
        observed = jnp.ones_like(t)
        noise_seq = jnp.broadcast_to(self._R(noise_var), t.shape)
        filtered = _kalman_filter(
            F, H, P_inf, A_seq, Q_seq, centred, observed, noise_seq
        )
        m_pred, P_pred, m_filt, P_filt, lml = filtered
        m_s, P_s = _rts_smoother(m_pred, P_pred, m_filt, P_filt, A_seq)
        return m_s, P_s, lml

    def condition(
        self,
        y: Float[Array, " N"],
        noise_var: Float[Array, ""],
    ) -> ConditionedMarkovGP:
        """Posterior given Gaussian-noise observations ``y`` at ``self.times``.

        Runs :meth:`smooth` once and packages the smoothed state moments and
        the log marginal likelihood into a :class:`ConditionedMarkovGP`.
        """
        means, covs, lml = self.smooth(y, noise_var)
        return ConditionedMarkovGP(  # ty: ignore[invalid-return-type]
            prior=self,
            y=y,
            noise_var=jnp.asarray(noise_var),
            smoothed_means=means,
            smoothed_covs=covs,
            log_marginal=lml,
        )

    def condition_nongauss(
        self,
        likelihood: Likelihood,
        y: Float[Array, " N"],
        *,
        strategy: _NonGaussMarkovStrategy,
    ) -> NonGaussConditionedMarkovGP:
        """Condition on a non-Gaussian likelihood through a site-based strategy.

        Thin convenience wrapper: the whole job is delegated to
        ``strategy.fit(self, likelihood, y)``, and its result is returned
        unchanged. Suitable strategies live in
        :mod:`pyrox.gp._inference_nongauss_markov`:
        :class:`pyrox.gp.LaplaceMarkovInference`,
        :class:`pyrox.gp.GaussNewtonMarkovInference`,
        :class:`pyrox.gp.PosteriorLinearizationMarkov` and
        :class:`pyrox.gp.ExpectationPropagationMarkov`. The result is a
        :class:`pyrox.gp.NonGaussConditionedMarkovGP` exposing the same
        ``predict`` API as the Gaussian-likelihood
        :class:`ConditionedMarkovGP`.
        """
        fitted = strategy.fit(self, likelihood, y)
        return fitted

    def log_prob(
        self,
        f: Float[Array, " N"],
        *,
        jitter: float = 1e-8,
    ) -> Float[Array, ""]:
        r"""Log density of a latent path :math:`f(t_n) = H x_n` under the prior.

        Evaluates ``log N(f | mu(times), K_NN + jitter * I)``, where ``K_NN`` is
        the dense Gram of the kernel encoded by ``sde_kernel`` on
        ``self.times``, built entry-wise as ``H expm(F |t_i - t_j|) P_inf H^T``.
        That is one ``expm`` per pairwise lag, :math:`O(N^2 d^3)` for the Gram
        plus :math:`O(N^3)` for the Cholesky. It is meant for sanity checks and
        small grids, not scalable inference; for training prefer
        :meth:`log_marginal`.

        Args:
            f: Latent function values at ``self.times``, shape ``(N,)``.
            jitter: Diagonal nugget added to ``K_NN`` before the Cholesky.
                The default ``1e-8`` matches the previously hard-coded value.
                Relative to an O(1) kernel variance, ``1e-8`` is below float32
                resolution (eps ~ 1.2e-7), so raise it (e.g. to ``1e-6``) when
                running in float32 or with closely spaced times.

        Returns:
            Scalar log density.
        """
        F, _L, H, _Qc, P_inf = self.sde_kernel.sde_params()
        times = self.times
        n = times.shape[0]
        lags = jnp.abs(times[:, None] - times[None, :])

        def _k(tau: Float[Array, ""]) -> Float[Array, ""]:
            # SDE autocovariance at lag tau >= 0.
            return (H @ jax.scipy.linalg.expm(F * tau) @ P_inf @ H.T)[0, 0]

        # Vectorise over the flattened (N, N) lag grid, then restore the shape.
        K = jax.vmap(_k)(lags.reshape(-1)).reshape(lags.shape)
        # Symmetrise away round-off asymmetry from expm before the Cholesky.
        K = 0.5 * (K + K.T)
        K = K + jitter * jnp.eye(n, dtype=K.dtype)
        residual = f - self.mean(times)
        L = jnp.linalg.cholesky(K)
        alpha = jax.scipy.linalg.solve_triangular(L, residual, lower=True)
        log_2pi = jnp.log(2.0 * jnp.pi).astype(K.dtype)
        return (
            -0.5 * (alpha @ alpha) - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * n * log_2pi
        )

state_dim property

SDE state dimension :math:d for this kernel.

condition(y, noise_var)

Condition on Gaussian-likelihood observations via filter + smoother.

Source code in src/pyrox/gp/_markov.py
def condition(
    self,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> ConditionedMarkovGP:
    """Posterior given Gaussian-noise observations ``y`` at ``self.times``.

    Runs :meth:`smooth` once and packages the smoothed state moments and
    the log marginal likelihood into a :class:`ConditionedMarkovGP`.
    """
    means, covs, lml = self.smooth(y, noise_var)
    return ConditionedMarkovGP(  # ty: ignore[invalid-return-type]
        prior=self,
        y=y,
        noise_var=jnp.asarray(noise_var),
        smoothed_means=means,
        smoothed_covs=covs,
        log_marginal=lml,
    )

condition_nongauss(likelihood, y, *, strategy)

Condition on a non-Gaussian likelihood via a site-based strategy.

Convenience that forwards to strategy.fit(self, likelihood, y). Pick any of the Markov-aware site-based strategies in :mod:pyrox.gp._inference_nongauss_markov: :class:pyrox.gp.LaplaceMarkovInference, :class:pyrox.gp.GaussNewtonMarkovInference, :class:pyrox.gp.PosteriorLinearizationMarkov, or :class:pyrox.gp.ExpectationPropagationMarkov. Returns a :class:pyrox.gp.NonGaussConditionedMarkovGP with the same predict API as the Gaussian-likelihood :class:ConditionedMarkovGP.

Source code in src/pyrox/gp/_markov.py
def condition_nongauss(
    self,
    likelihood: Likelihood,
    y: Float[Array, " N"],
    *,
    strategy: _NonGaussMarkovStrategy,
) -> NonGaussConditionedMarkovGP:
    """Condition on a non-Gaussian likelihood through a site-based strategy.

    Thin convenience wrapper: the whole job is delegated to
    ``strategy.fit(self, likelihood, y)``, and its result is returned
    unchanged. Suitable strategies live in
    :mod:`pyrox.gp._inference_nongauss_markov`:
    :class:`pyrox.gp.LaplaceMarkovInference`,
    :class:`pyrox.gp.GaussNewtonMarkovInference`,
    :class:`pyrox.gp.PosteriorLinearizationMarkov` and
    :class:`pyrox.gp.ExpectationPropagationMarkov`. The result is a
    :class:`pyrox.gp.NonGaussConditionedMarkovGP` exposing the same
    ``predict`` API as the Gaussian-likelihood
    :class:`ConditionedMarkovGP`.
    """
    fitted = strategy.fit(self, likelihood, y)
    return fitted

filter(y, noise_var)

Run the forward Kalman filter on the training grid.

Returns:

Type Description
tuple[Float[Array, 'N d'], Float[Array, 'N d d'], Float[Array, 'N d'], Float[Array, 'N d d'], Float[Array, '']]

Tuple (m_pred, P_pred, m_filt, P_filt, log_marginal) where each *_pred / *_filt is shaped (N, d) or (N, d, d) and log_marginal is the scalar log-likelihood log p(y | theta).

Source code in src/pyrox/gp/_markov.py
def filter(
    self,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> tuple[
    Float[Array, "N d"],
    Float[Array, "N d d"],
    Float[Array, "N d"],
    Float[Array, "N d d"],
    Float[Array, ""],
]:
    """Forward Kalman filter over the training times.

    Every training time is marked as observed (all-ones mask), and each
    step uses the same observation variance ``noise_var + obs_noise_floor``.

    Returns:
        ``(m_pred, P_pred, m_filt, P_filt, log_marginal)``: predicted and
        filtered state means ``(N, d)`` and covariances ``(N, d, d)``,
        plus the scalar log-likelihood ``log p(y | theta)``.
    """
    t = self.times
    F, _L, H, _Qc, P_inf = self.sde_kernel.sde_params()
    A_seq, Q_seq = self.sde_kernel.discretise(_build_dt_full(t))
    centred = self._residual(y)
    observed = jnp.ones_like(t)
    noise_seq = jnp.broadcast_to(self._R(noise_var), t.shape)
    return _kalman_filter(F, H, P_inf, A_seq, Q_seq, centred, observed, noise_seq)

log_marginal(y, noise_var)

Marginal log-likelihood log p(y | theta) via Kalman filtering.

Source code in src/pyrox/gp/_markov.py
def log_marginal(
    self,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> Float[Array, ""]:
    r"""Log marginal likelihood ``log p(y | theta)`` of the training data.

    Runs the forward Kalman filter and keeps only the final element of its
    output tuple.
    """
    filter_outputs = self.filter(y, noise_var)
    return filter_outputs[-1]

log_prob(f)

Log density of an exact-state path :math:f(t_n) = H x_n under the prior.

Evaluates log N(f | mu(times), K_NN) where K_NN is the dense Gram of the kernel encoded by sde_kernel on self.times. Computes the dense covariance via H exp(F |t_i - t_j|) P_inf H^T — one expm per pairwise lag, costing :math:O(N^2 d^3) for the Gram plus :math:O(N^3) for the Cholesky solve — intended for sanity checks and small-grid use rather than scalable inference. For training, prefer :meth:log_marginal.

Source code in src/pyrox/gp/_markov.py
def log_prob(
    self,
    f: Float[Array, " N"],
    *,
    jitter: float = 1e-8,
) -> Float[Array, ""]:
    r"""Log density of a latent path :math:`f(t_n) = H x_n` under the prior.

    Evaluates ``log N(f | mu(times), K_NN + jitter * I)``, where ``K_NN`` is
    the dense Gram of the kernel encoded by ``sde_kernel`` on ``self.times``,
    built entry-wise as ``H expm(F |t_i - t_j|) P_inf H^T``. That is one
    ``expm`` per pairwise lag, :math:`O(N^2 d^3)` for the Gram plus
    :math:`O(N^3)` for the Cholesky. It is meant for sanity checks and small
    grids, not scalable inference; for training prefer :meth:`log_marginal`.

    Args:
        f: Latent function values at ``self.times``, shape ``(N,)``.
        jitter: Diagonal nugget added to ``K_NN`` before the Cholesky. The
            default ``1e-8`` matches the previously hard-coded value.
            Relative to an O(1) kernel variance, ``1e-8`` is below float32
            resolution (eps ~ 1.2e-7), so raise it (e.g. to ``1e-6``) when
            running in float32 or with closely spaced times.

    Returns:
        Scalar log density.
    """
    F, _L, H, _Qc, P_inf = self.sde_kernel.sde_params()
    times = self.times
    n = times.shape[0]
    lags = jnp.abs(times[:, None] - times[None, :])

    def _k(tau: Float[Array, ""]) -> Float[Array, ""]:
        # SDE autocovariance at lag tau >= 0.
        return (H @ jax.scipy.linalg.expm(F * tau) @ P_inf @ H.T)[0, 0]

    # Vectorise over the flattened (N, N) lag grid, then restore the shape.
    K = jax.vmap(_k)(lags.reshape(-1)).reshape(lags.shape)
    # Symmetrise away round-off asymmetry from expm before the Cholesky.
    K = 0.5 * (K + K.T)
    K = K + jitter * jnp.eye(n, dtype=K.dtype)
    residual = f - self.mean(times)
    L = jnp.linalg.cholesky(K)
    alpha = jax.scipy.linalg.solve_triangular(L, residual, lower=True)
    log_2pi = jnp.log(2.0 * jnp.pi).astype(K.dtype)
    return (
        -0.5 * (alpha @ alpha) - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * n * log_2pi
    )

mean(times)

Evaluate the mean function at times; zero by default.

Source code in src/pyrox/gp/_markov.py
def mean(self, times: Float[Array, " M"]) -> Float[Array, " M"]:
    """Prior mean evaluated at ``times``.

    Uses the user-supplied ``mean_fn`` when one was given; otherwise the
    prior is zero-mean and an all-zeros array shaped like ``times`` is
    returned.
    """
    mean_fn = self.mean_fn
    return jnp.zeros_like(times) if mean_fn is None else mean_fn(times)

smooth(y, noise_var)

Run filter + RTS smoother on the training grid.

Returns (m_smooth, P_smooth, log_marginal) over the training times.

Source code in src/pyrox/gp/_markov.py
def smooth(
    self,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> tuple[Float[Array, "N d"], Float[Array, "N d d"], Float[Array, ""]]:
    """Kalman filter followed by an RTS smoother on the training times.

    Returns:
        ``(m_smooth, P_smooth, log_marginal)``: smoothed state means
        ``(N, d)`` and covariances ``(N, d, d)`` at ``self.times``, plus the
        scalar log marginal likelihood from the forward pass.
    """
    t = self.times
    F, _L, H, _Qc, P_inf = self.sde_kernel.sde_params()
    A_seq, Q_seq = self.sde_kernel.discretise(_build_dt_full(t))
    centred = self._residual(y)
    observed = jnp.ones_like(t)
    noise_seq = jnp.broadcast_to(self._R(noise_var), t.shape)
    filtered = _kalman_filter(
        F, H, P_inf, A_seq, Q_seq, centred, observed, noise_seq
    )
    m_pred, P_pred, m_filt, P_filt, lml = filtered
    m_s, P_s = _rts_smoother(m_pred, P_pred, m_filt, P_filt, A_seq)
    return m_s, P_s, lml

pyrox.gp.ConditionedMarkovGP

Bases: Module

Markov GP conditioned on Gaussian-likelihood observations.

Holds the smoothed posterior on the training grid plus the marginal log-likelihood. Use :meth:predict for marginal posterior mean / variance at arbitrary test times.

Attributes:

Name Type Description
prior MarkovGPPrior

The originating :class:MarkovGPPrior.

y Float[Array, ' N']

Observations of shape (N,).

noise_var Float[Array, '']

Observation variance used for conditioning.

smoothed_means Float[Array, 'N d']

(N, d) smoothed state means at training times.

smoothed_covs Float[Array, 'N d d']

(N, d, d) smoothed state covariances at training times.

log_marginal Float[Array, '']

Scalar :math:\log p(y \mid \theta).

Source code in src/pyrox/gp/_markov.py
class ConditionedMarkovGP(eqx.Module):
    """Markov GP conditioned on Gaussian-likelihood observations.

    Holds the smoothed posterior on the training grid plus the marginal
    log-likelihood. Use :meth:`predict` for marginal posterior mean / variance
    at arbitrary test times.

    Attributes:
        prior: The originating :class:`MarkovGPPrior`.
        y: Observations of shape ``(N,)``.
        noise_var: Observation variance used for conditioning.
        smoothed_means: ``(N, d)`` smoothed state means at training times.
        smoothed_covs: ``(N, d, d)`` smoothed state covariances at training
            times.
        log_marginal: Scalar :math:`\\log p(y \\mid \\theta)`.
    """

    prior: MarkovGPPrior
    y: Float[Array, " N"]
    noise_var: Float[Array, ""]
    smoothed_means: Float[Array, "N d"]
    smoothed_covs: Float[Array, "N d d"]
    log_marginal: Float[Array, ""]

    def predict(
        self,
        t_star: Float[Array, " M"],
    ) -> tuple[Float[Array, " M"], Float[Array, " M"]]:
        r"""Predictive marginals ``(mean, var)`` of the latent ``f`` at test times.

        Re-runs the filter + RTS smoother over the merged grid
        ``sort(times \cup t_star)`` with the test points masked out of the
        measurement update, then reads the smoothed marginals at the test
        positions as ``H @ m`` and ``H @ P @ H^T``. Cost is
        :math:`O((N + M)\,d^3)`. Training-grid lookups, forecasting,
        backcasting and within-window interpolation share this one code path.
        The returned variance is that of the latent function; no observation
        noise is added.

        Args:
            t_star: Test times, shape ``(M,)``. A 0-d scalar is accepted and
                treated as ``M = 1``.

        Returns:
            ``(means, vars)``, each of shape ``(M,)``.
        """
        F, _L, H, _Qc, P_inf = self.prior.sde_kernel.sde_params()
        times = self.prior.times
        # atleast_1d lets a single 0-d test time through as M = 1; without it
        # ``t_star.shape[0]`` below raises for scalar input.
        t_star = jnp.atleast_1d(jnp.asarray(t_star))

        N = times.shape[0]
        M = t_star.shape[0]
        merged = jnp.concatenate([times, t_star], axis=0)
        # Stable sort keeps the concatenation order for tied times: a training
        # point sorts before a duplicate test point, so the test point still
        # sees that observation's update earlier in the grid.
        order = jnp.argsort(merged, stable=True)
        merged_sorted = merged[order]

        # 1 marks a real observation, 0 a test point excluded from the update.
        is_obs = jnp.concatenate(
            [jnp.ones(N, dtype=times.dtype), jnp.zeros(M, dtype=times.dtype)]
        )[order]
        # Zero placeholders at the test points; they are masked out by is_obs.
        residual_full = jnp.concatenate(
            [self.y - self.prior.mean(times), jnp.zeros(M, dtype=self.y.dtype)]
        )[order]

        dt_full = _build_dt_full(merged_sorted)
        A_seq, Q_seq = self.prior.sde_kernel.discretise(dt_full)
        R_seq = jnp.broadcast_to(self.prior._R(self.noise_var), merged_sorted.shape)
        m_pred, P_pred, m_filt, P_filt, _ = _kalman_filter(
            F, H, P_inf, A_seq, Q_seq, residual_full, is_obs, R_seq
        )
        m_smooth, P_smooth = _rts_smoother(m_pred, P_pred, m_filt, P_filt, A_seq)

        # Inverse permutation: position in the sorted grid for each original
        # entry, then slice off the trailing M test entries.
        inv_order = jnp.argsort(order, stable=True)
        test_positions = inv_order[N:]
        m_test_state = m_smooth[test_positions]  # (M, d)
        P_test_state = P_smooth[test_positions]  # (M, d, d)
        means = (m_test_state @ H.T)[:, 0] + self.prior.mean(t_star)
        # var = H P H^T per test point — vmap over axis 0
        vars_ = jax.vmap(lambda P: (H @ P @ H.T)[0, 0])(P_test_state)
        return means, vars_

predict(t_star)

Predictive marginals (mean, var) at arbitrary test times.

Implementation: re-run the filter+smoother over the merged grid sort(times \cup t_star) with the test points masked out of the update step, then read off the smoothed marginals at the test positions via H @ m and H @ P @ H^T. Cost is :math:O((N + M)\,d^3). Handles training-grid lookups, forecasting, backcasting, and within-window interpolation under one code path.

Source code in src/pyrox/gp/_markov.py
def predict(
    self,
    t_star: Float[Array, " M"],
) -> tuple[Float[Array, " M"], Float[Array, " M"]]:
    r"""Predictive marginals ``(mean, var)`` of the latent ``f`` at test times.

    Re-runs the filter + RTS smoother over the merged grid
    ``sort(times \cup t_star)`` with the test points masked out of the
    measurement update, then reads the smoothed marginals at the test
    positions as ``H @ m`` and ``H @ P @ H^T``. Cost is
    :math:`O((N + M)\,d^3)`. Training-grid lookups, forecasting,
    backcasting and within-window interpolation share this one code path.
    The returned variance is that of the latent function; no observation
    noise is added.

    Args:
        t_star: Test times, shape ``(M,)``. A 0-d scalar is accepted and
            treated as ``M = 1``.

    Returns:
        ``(means, vars)``, each of shape ``(M,)``.
    """
    F, _L, H, _Qc, P_inf = self.prior.sde_kernel.sde_params()
    times = self.prior.times
    # atleast_1d lets a single 0-d test time through as M = 1; without it
    # ``t_star.shape[0]`` below raises for scalar input.
    t_star = jnp.atleast_1d(jnp.asarray(t_star))

    N = times.shape[0]
    M = t_star.shape[0]
    merged = jnp.concatenate([times, t_star], axis=0)
    # Stable sort keeps the concatenation order for tied times: a training
    # point sorts before a duplicate test point, so the test point still
    # sees that observation's update earlier in the grid.
    order = jnp.argsort(merged, stable=True)
    merged_sorted = merged[order]

    # 1 marks a real observation, 0 a test point excluded from the update.
    is_obs = jnp.concatenate(
        [jnp.ones(N, dtype=times.dtype), jnp.zeros(M, dtype=times.dtype)]
    )[order]
    # Zero placeholders at the test points; they are masked out by is_obs.
    residual_full = jnp.concatenate(
        [self.y - self.prior.mean(times), jnp.zeros(M, dtype=self.y.dtype)]
    )[order]

    dt_full = _build_dt_full(merged_sorted)
    A_seq, Q_seq = self.prior.sde_kernel.discretise(dt_full)
    R_seq = jnp.broadcast_to(self.prior._R(self.noise_var), merged_sorted.shape)
    m_pred, P_pred, m_filt, P_filt, _ = _kalman_filter(
        F, H, P_inf, A_seq, Q_seq, residual_full, is_obs, R_seq
    )
    m_smooth, P_smooth = _rts_smoother(m_pred, P_pred, m_filt, P_filt, A_seq)

    # Inverse permutation: position in the sorted grid for each original
    # entry, then slice off the trailing M test entries.
    inv_order = jnp.argsort(order, stable=True)
    test_positions = inv_order[N:]
    m_test_state = m_smooth[test_positions]  # (M, d)
    P_test_state = P_smooth[test_positions]  # (M, d, d)
    means = (m_test_state @ H.T)[:, 0] + self.prior.mean(t_star)
    # var = H P H^T per test point — vmap over axis 0
    vars_ = jax.vmap(lambda P: (H @ P @ H.T)[0, 0])(P_test_state)
    return means, vars_

pyrox.gp.markov_gp_factor(name, prior, y, noise_var)

Register the collapsed Markov-GP marginal log-likelihood with NumPyro.

Computes log p(y | times, theta) via Kalman filtering and adds it as numpyro.factor(name, ...). Use this inside a NumPyro model for Gaussian-likelihood temporal GP regression — the latent function is marginalized analytically.

Source code in src/pyrox/gp/_markov.py
def markov_gp_factor(
    name: str,
    prior: MarkovGPPrior,
    y: Float[Array, " N"],
    noise_var: Float[Array, ""],
) -> None:
    """Add the collapsed Markov-GP evidence to the NumPyro model as a factor.

    The evidence ``log p(y | times, theta)`` comes from
    ``prior.log_marginal`` (Kalman filtering) and is registered under
    ``name`` with ``numpyro.factor``. Intended for Gaussian-likelihood
    temporal GP regression inside a NumPyro model, where the latent function
    is integrated out analytically.
    """
    evidence = prior.log_marginal(y, noise_var)
    numpyro.factor(name, evidence)

pyrox.gp.markov_gp_sample(name, prior)

Sample a latent function f at the prior's training times.

Registers a single numpyro.sample(name, MVN(mu, K)) site where K is the dense Gram derived from the SDE autocovariance H exp(F|tau|) P_inf H^T. This is the simple, dense path — use it when N is small. Scalable Markov-aware sample sites land in a later wave alongside non-Gaussian likelihood support.

Source code in src/pyrox/gp/_markov.py
def markov_gp_sample(
    name: str,
    prior: MarkovGPPrior,
    *,
    jitter: float = 1e-8,
) -> Float[Array, " N"]:
    """Sample a latent function ``f`` at the prior's training times.

    Registers a single ``numpyro.sample(name, MVN(mu, K + jitter * I))`` site,
    where ``K`` is the dense Gram built from the SDE autocovariance
    ``H expm(F |tau|) P_inf H^T`` over every pairwise lag of ``prior.times``
    (``N^2`` ``expm`` calls). This is the simple, dense path, so use it only
    when ``N`` is small. Scalable Markov-aware sample sites land in a later
    wave alongside non-Gaussian likelihood support.

    Args:
        name: NumPyro sample-site name.
        prior: Markov GP prior supplying ``sde_kernel``, ``times`` and the
            mean function.
        jitter: Diagonal nugget added to ``K``. The default ``1e-8`` matches
            the previously hard-coded value. Relative to an O(1) kernel
            variance, ``1e-8`` is below float32 resolution, so increase it
            for float32 runs or closely spaced times.

    Returns:
        The value returned by the ``numpyro.sample`` site, i.e. ``f`` at
        ``prior.times``.
    """
    F, _L, H, _Qc, P_inf = prior.sde_kernel.sde_params()
    times = prior.times
    lags = jnp.abs(times[:, None] - times[None, :])

    def _k(tau: Float[Array, ""]) -> Float[Array, ""]:
        # SDE autocovariance at lag tau >= 0.
        return (H @ jax.scipy.linalg.expm(F * tau) @ P_inf @ H.T)[0, 0]

    K = jax.vmap(_k)(lags.reshape(-1)).reshape(lags.shape)
    # Symmetrise away round-off asymmetry from expm.
    K = 0.5 * (K + K.T)
    n = times.shape[0]
    K = K + jitter * jnp.eye(n, dtype=K.dtype)
    mu = prior.mean(times)
    return numpyro.sample(  # ty: ignore[invalid-return-type]
        name, dist.MultivariateNormal(mu, covariance_matrix=K)
    )

Component protocols

Abstract pyrox-local bases for the orthogonal component stack. Wave 2 ships only the contracts for Guide, Integrator, and Likelihood — concrete implementations land in later waves. Solver strategies live in gaussx._strategies.

pyrox.gp.Kernel

Bases: Module

Abstract base for GP covariance functions.

Subclasses implement :meth:__call__ returning the Gram matrix on a pair of input batches. :meth:gram and :meth:diag are convenience defaults that derive from :meth:__call__; structured subclasses (Kronecker, state-space, etc.) should override them for efficiency.

Source code in src/pyrox/gp/_protocols.py
class Kernel(eqx.Module):
    """Abstract base for GP covariance functions.

    Subclasses must implement :meth:`__call__`, mapping two input batches to
    their cross-covariance (Gram) matrix. :meth:`gram` and :meth:`diag` are
    generic fallbacks built on :meth:`__call__`; structured kernels
    (Kronecker, state-space, ...) should override them with cheaper
    specialisations.
    """

    @abstractmethod
    def __call__(
        self,
        X1: Float[Array, "N1 D"],
        X2: Float[Array, "N2 D"],
    ) -> Float[Array, "N1 N2"]:
        """Cross-covariance matrix ``K(X1, X2)``."""
        raise NotImplementedError

    def gram(self, X: Float[Array, "N D"]) -> Float[Array, "N N"]:
        """Self-covariance ``K(X, X)`` of a single input batch."""
        gram_matrix = self(X, X)
        return gram_matrix

    def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        """Main diagonal of ``K(X, X)``.

        This fallback materialises the full ``(N, N)`` Gram and takes its
        diagonal. Stationary kernels, whose diagonal is constant, should
        override it with an ``O(N)`` broadcast.
        """
        return jnp.diagonal(self(X, X))

diag(X)

Diagonal of K(X, X).

Default implementation extracts the diagonal of the full Gram. For stationary kernels with constant diagonal, override with a vectorized broadcast for the O(N) shortcut.

Source code in src/pyrox/gp/_protocols.py
def diag(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
    """Main diagonal of ``K(X, X)``.

    This fallback materialises the full ``(N, N)`` Gram and takes its
    diagonal. Stationary kernels, whose diagonal is constant, should
    override it with an ``O(N)`` broadcast.
    """
    return jnp.diagonal(self(X, X))

gram(X)

Symmetric Gram matrix K(X, X).

Source code in src/pyrox/gp/_protocols.py
def gram(self, X: Float[Array, "N D"]) -> Float[Array, "N N"]:
    """Self-covariance ``K(X, X)`` of a single input batch."""
    gram_matrix = self(X, X)
    return gram_matrix

pyrox.gp.Guide

Bases: Module

Abstract base for variational posterior families.

Concrete guides (DeltaGuide, MeanFieldGuide, LowRankGuide, FullRankGuide, etc.) land in the dedicated guide waves (#28, #29). The whitening principle keeps optimization geometry well-conditioned — sample from a unit-scale latent and unwhiten with the prior Cholesky.

Two distinct entry points:

  • :meth:sample / :meth:log_prob — pure variational draws and densities. sample(self, key) returns a draw from q(f); log_prob(self, f) evaluates log q(f). Neither touches the NumPyro trace.
  • register(name, prior) (optional) — the NumPyro-integration hook invoked by :func:pyrox.gp.gp_sample when a guide is supplied. Use it to register a sample / param site (or compose one out of guide state) under name and return the latent function value. Concrete guides that participate in :func:gp_sample should implement this; the protocol leaves it unspecified so guides usable purely outside NumPyro stay valid.
Source code in src/pyrox/gp/_protocols.py
class Guide(eqx.Module):
    """Abstract base for variational posterior families.

    Concrete guides (``DeltaGuide``, ``MeanFieldGuide``, ``LowRankGuide``,
    ``FullRankGuide``, etc.) land in the dedicated guide waves (#28, #29).
    The whitening principle keeps optimization geometry well-conditioned —
    sample from a unit-scale latent and unwhiten with the prior Cholesky.

    Two distinct entry points:

    * :meth:`sample` / :meth:`log_prob` — pure variational draws and
      densities. ``sample(self, key)`` returns a draw from ``q(f)``;
      ``log_prob(self, f)`` evaluates ``log q(f)``. Neither touches the
      NumPyro trace.
    * ``register(name, prior)`` (optional) — the NumPyro-integration hook
      invoked by :func:`pyrox.gp.gp_sample` when a guide is supplied. Use
      it to register a sample / param site (or compose one out of guide
      state) under ``name`` and return the latent function value. Concrete
      guides that participate in :func:`gp_sample` should implement this;
      the protocol leaves it unspecified (it is not declared on this base)
      so guides usable purely outside NumPyro stay valid.
    """

    @abstractmethod
    def sample(self, key: Any) -> Float[Array, " ..."]:
        """Return one draw ``f ~ q(f)``; must not touch the NumPyro trace.

        ``key`` is annotated ``Any``; presumably a JAX PRNG key — TODO
        confirm against concrete guides.
        """
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, f: Float[Array, " ..."]) -> Float[Array, ""]:
        """Scalar ``log q(f)``; must not touch the NumPyro trace."""
        raise NotImplementedError

pyrox.gp.Integrator

Bases: Module

Abstract base for Gaussian-expectation integrators.

Computes :math:\mathbb{E}_{q(f)}[g(f)] where q(f) = N(mean, var). Concrete integrators (Gauss-Hermite, sigma-points, cubature, Taylor, Monte Carlo) land in later waves and may delegate to gaussx's quadrature primitives.

Source code in src/pyrox/gp/_protocols.py
class Integrator(eqx.Module):
    """Abstract base for Gaussian-expectation integrators.

    Computes :math:`\\mathbb{E}_{q(f)}[g(f)]` where ``q(f) = N(mean, var)``.
    Concrete integrators (Gauss-Hermite, sigma-points, cubature, Taylor,
    Monte Carlo) land in later waves and may delegate to ``gaussx``'s
    quadrature primitives.
    """

    @abstractmethod
    def integrate(
        self,
        fn: Callable[[Float[Array, " ..."]], Float[Array, " ..."]],
        mean: Float[Array, " ..."],
        var: Float[Array, " ..."],
    ) -> Float[Array, " ..."]:
        """Approximate ``E[fn(f)]`` for ``f ~ N(mean, var)``.

        Args:
            fn: Integrand applied to latent values ``f``.
            mean: Mean of the Gaussian ``q(f)``.
            var: Variance of ``q(f)``. NOTE(review): presumably elementwise
                (marginal) variances with ``mean``'s shape — confirm against
                concrete integrators.

        Returns:
            The approximated expectation.
        """
        raise NotImplementedError

pyrox.gp.Likelihood

Bases: Module

Abstract base for observation models.

Implements the conditional p(y | f) and a default :meth:expected_log_prob that integrates over a Gaussian latent via an :class:Integrator. Concrete scalar-latent likelihoods (:class:GaussianLikelihood, :class:BernoulliLikelihood, :class:PoissonLikelihood, :class:StudentTLikelihood) and multi-latent ones (:class:SoftmaxLikelihood, :class:HeteroscedasticGaussianLikelihood) live in :mod:pyrox.gp._likelihoods.

Multi-latent likelihoods declare latent_dim: int as a static field (e.g. latent_dim = num_classes for softmax). Scalar likelihoods may omit the field; consumers should read getattr(lik, "latent_dim", 1).

Source code in src/pyrox/gp/_protocols.py
class Likelihood(eqx.Module):
    """Abstract base for observation models.

    Declares the conditional ``p(y | f)`` through the abstract
    :meth:`log_prob`. NOTE(review): earlier wording promised a default
    :meth:`expected_log_prob` that integrates over a Gaussian latent via an
    :class:`Integrator`, but this base class defines no such method. Confirm
    whether concrete likelihoods provide it, or whether it still needs to be
    added here. Concrete scalar-latent likelihoods
    (:class:`GaussianLikelihood`, :class:`BernoulliLikelihood`,
    :class:`PoissonLikelihood`, :class:`StudentTLikelihood`) and
    multi-latent ones (:class:`SoftmaxLikelihood`,
    :class:`HeteroscedasticGaussianLikelihood`) live in
    :mod:`pyrox.gp._likelihoods`.

    Multi-latent likelihoods declare ``latent_dim: int`` as a static
    field (e.g. ``latent_dim = num_classes`` for softmax). Scalar
    likelihoods may omit the field; consumers should read
    ``getattr(lik, "latent_dim", 1)``.
    """

    @abstractmethod
    def log_prob(
        self,
        f: Float[Array, " ..."],
        y: Float[Array, " ..."],
    ) -> Float[Array, ""]:
        """Conditional log density ``log p(y | f)``.

        Returns a scalar (per the annotation). NOTE(review): presumably summed
        over all observations — confirm against the concrete likelihoods.
        """
        raise NotImplementedError

Math primitives

Pure JAX kernel functions. Stateless, differentiable, composable — (Array, ..., hyperparams) -> Gram. No NumPyro, no protocols.

Pure JAX kernel evaluation primitives — math definitions only.

Each function takes two input matrices X1 of shape (N1, D), X2 of shape (N2, D), hyperparameters as JAX arrays, and returns the (N1, N2) Gram matrix. All inputs are 2-D; callers that have 1-D arrays must add a trailing singleton dimension first.

These are the canonical closed-form math for each kernel — small, readable, and tutorial-facing. The companion scalable construction surface (numerically stable matrix assembly, mixed-precision accumulation, implicit/structured operators, batched matvec) lives in gaussx; see :func:gaussx.stable_rbf_kernel and :class:gaussx.ImplicitKernelOperator for the production path.

The composition helpers :func:kernel_add and :func:kernel_mul act on already-evaluated Gram matrices, not on callables. Higher-level :class:pyrox.gp.Kernel classes (Wave 2 Layer 1, see issue #20) compose callables and may opt in to gaussx's scalable variants when needed.

Index axes are named via :mod:einops (einsum / rearrange) rather than raw broadcasting so shape intent stays legible at the call site.

constant_kernel(X1, X2, variance)

Constant kernel.

.. math:: k(x, x') = \sigma^2

A rank-one kernel useful as a scalar offset additive component.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar value.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix filled with variance.

Source code in src/pyrox/gp/_src/kernels.py
def constant_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Constant kernel.

    .. math::
        k(x, x') = \sigma^2

    Every entry equals ``variance`` regardless of the inputs, giving a
    rank-one Gram that acts as a scalar offset in additive kernel sums.

    Args:
        X1: ``(N1, D)`` inputs; only the row count (and dtype) is used.
        X2: ``(N2, D)`` inputs; only the row count is used.
        variance: Scalar value.

    Returns:
        ``(N1, N2)`` Gram matrix filled with ``variance``.
    """
    n1, n2 = X1.shape[0], X2.shape[0]
    ones = jnp.ones((n1, n2), dtype=X1.dtype)
    return variance * ones

cosine_kernel(X1, X2, variance, period)

Cosine kernel.

.. math:: k(x, x') = \sigma^2 \cos\!\left( \frac{2 \pi \|x - x'\|}{p} \right)

Useful as a simple periodic building block alongside :func:periodic_kernel; unlike the MacKay form, this one takes the plain cosine of the distance and can go negative.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
period Float[Array, '']

Scalar period.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def cosine_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    period: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Cosine kernel.

    .. math::
        k(x, x') = \sigma^2 \cos\!\left(
            \frac{2 \pi \|x - x'\|}{p}
        \right)

    A simple periodic building block to use alongside
    :func:`periodic_kernel`. Unlike the MacKay form, it takes the plain
    cosine of the Euclidean distance, so entries can be negative.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        period: Scalar period.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    sq_dist = _pairwise_sq_dist(X1, X2)
    # Floor before the sqrt: its derivative is infinite at 0, so coincident
    # points would otherwise give NaN gradients.
    dist = jnp.sqrt(jnp.clip(sq_dist, min=1e-30))
    phase = 2.0 * jnp.pi * dist / period
    return variance * jnp.cos(phase)

kernel_add(K1, K2)

Pointwise sum of two already-evaluated Gram matrices.

Source code in src/pyrox/gp/_src/kernels.py
def kernel_add(
    K1: Float[Array, "N1 N2"],
    K2: Float[Array, "N1 N2"],
) -> Float[Array, "N1 N2"]:
    """Combine two precomputed Gram matrices by elementwise addition.

    Args:
        K1: ``(N1, N2)`` Gram matrix.
        K2: ``(N1, N2)`` Gram matrix with the same shape as ``K1``.

    Returns:
        ``(N1, N2)`` matrix ``K1 + K2``.
    """
    summed = K1 + K2
    return summed

kernel_mul(K1, K2)

Pointwise (Hadamard) product of two already-evaluated Gram matrices.

Source code in src/pyrox/gp/_src/kernels.py
def kernel_mul(
    K1: Float[Array, "N1 N2"],
    K2: Float[Array, "N1 N2"],
) -> Float[Array, "N1 N2"]:
    """Combine two precomputed Gram matrices by entrywise (Hadamard) product.

    Args:
        K1: ``(N1, N2)`` Gram matrix.
        K2: ``(N1, N2)`` Gram matrix with the same shape as ``K1``.

    Returns:
        ``(N1, N2)`` matrix whose ``[i, j]`` entry is ``K1[i, j] * K2[i, j]``.
    """
    product = K1 * K2
    return product

linear_kernel(X1, X2, variance, bias)

Linear kernel.

.. math:: k(x, x') = \sigma^2\, x^\top x' + b

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar variance multiplier on the dot product.

required
bias Float[Array, '']

Scalar additive bias.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def linear_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    bias: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Linear kernel.

    .. math::
        k(x, x') = \sigma^2\, x^\top x' + b

    The dot product is scaled by ``variance``; ``bias`` is then added
    outside that scaling.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar variance multiplier on the dot product.
        bias: Scalar additive bias.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    # All pairwise inner products between rows of X1 and rows of X2.
    inner = einsum(X1, X2, "n1 d, n2 d -> n1 n2")
    return variance * inner + bias

matern_kernel(X1, X2, variance, lengthscale, nu)

Matern kernel with closed-form nu in {1/2, 3/2, 5/2}.

.. math:: k(x, x') = \sigma^2\, f_\nu(r / \ell), \qquad r = \|x - x'\|

Only the three common half-integer orders are supported because those admit closed-form expressions without Bessel evaluations. nu is a static Python float (not a JAX array) so the branch specializes at trace time.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
lengthscale Float[Array, '']

Scalar lengthscale.

required
nu float

Smoothness parameter; must be 0.5, 1.5, or 2.5.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def matern_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
    nu: float,
) -> Float[Array, "N1 N2"]:
    r"""Matern kernel with closed-form ``nu in {1/2, 3/2, 5/2}``.

    .. math::
        k(x, x') = \sigma^2\, f_\nu(r / \ell),
        \qquad r = \|x - x'\|

    Only the three common half-integer orders are supported because those
    admit closed-form expressions without Bessel evaluations. ``nu`` is a
    static Python float (not a JAX array) so the branch specializes at
    trace time.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        lengthscale: Scalar lengthscale.
        nu: Smoothness parameter; must be ``0.5``, ``1.5``, or ``2.5``.

    Returns:
        ``(N1, N2)`` Gram matrix.

    Raises:
        ValueError: If ``nu`` is not one of ``0.5``, ``1.5``, ``2.5``.
    """
    # Validate the static order before any O(N1 * N2) work, so an unsupported
    # ``nu`` fails immediately instead of after building the distance matrix.
    if nu not in (0.5, 1.5, 2.5):
        raise ValueError(f"matern_kernel supports nu in {{0.5, 1.5, 2.5}}, got {nu!r}")
    sq = _pairwise_sq_dist(X1, X2)
    # Floor before sqrt: sqrt' is unbounded at 0, so coincident points would
    # otherwise produce NaN gradients.
    r = jnp.sqrt(jnp.clip(sq, min=1e-30)) / lengthscale
    if nu == 0.5:
        shape = jnp.exp(-r)
    elif nu == 1.5:
        a = jnp.sqrt(3.0) * r
        shape = (1.0 + a) * jnp.exp(-a)
    else:  # nu == 2.5, guaranteed by the validation above
        a = jnp.sqrt(5.0) * r
        shape = (1.0 + a + (a * a) / 3.0) * jnp.exp(-a)
    return variance * shape

periodic_kernel(X1, X2, variance, lengthscale, period)

Periodic (MacKay) kernel.

.. math:: k(x, x') = \sigma^2 \exp\!\left( -\frac{2 \sin^2(\pi \|x - x'\| / p)}{\ell^2} \right)

For multi-dimensional inputs the argument uses the Euclidean distance, matching the common GPML convention.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
lengthscale Float[Array, '']

Scalar lengthscale.

required
period Float[Array, '']

Scalar period p.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def periodic_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
    period: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Periodic (MacKay) kernel.

    .. math::
        k(x, x') = \sigma^2 \exp\!\left(
            -\frac{2 \sin^2(\pi \|x - x'\| / p)}{\ell^2}
        \right)

    The sine argument is built from the Euclidean distance between inputs,
    following the usual GPML convention for multi-dimensional ``x``.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        lengthscale: Scalar lengthscale.
        period: Scalar period ``p``.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    # Floor the squared distance so sqrt never sees exactly 0; its gradient
    # there is unbounded and would turn into NaN for coincident points.
    dist = jnp.sqrt(jnp.clip(_pairwise_sq_dist(X1, X2), min=1e-30))
    sin_term = jnp.sin(jnp.pi * dist / period) ** 2
    ls_sq = lengthscale * lengthscale
    return variance * jnp.exp(-2.0 * sin_term / ls_sq)

polynomial_kernel(X1, X2, variance, bias, degree)

Polynomial kernel.

.. math:: k(x, x') = \sigma^2 \bigl(x^\top x' + b\bigr)^d

linear_kernel corresponds to degree == 1, except that it adds its bias outside the variance scaling (here the bias is also multiplied by sigma^2). degree is a static Python int, so the kernel specializes at trace time.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar multiplier.

required
bias Float[Array, '']

Scalar additive bias inside the power.

required
degree int

Positive integer polynomial degree.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def polynomial_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    bias: Float[Array, ""],
    degree: int,
) -> Float[Array, "N1 N2"]:
    r"""Polynomial kernel.

    .. math::
        k(x, x') = \sigma^2 \bigl(x^\top x' + b\bigr)^d

    :func:`linear_kernel` corresponds to ``degree == 1``, except that it adds
    its bias outside the ``variance`` scaling (here the bias is multiplied by
    ``sigma^2`` too). ``degree`` is a static Python int so the kernel
    specializes at trace time.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar multiplier.
        bias: Scalar additive bias inside the power.
        degree: Positive integer polynomial degree. Integral floats such as
            ``2.0`` are accepted and treated as the corresponding int.

    Returns:
        ``(N1, N2)`` Gram matrix.

    Raises:
        ValueError: If ``degree < 1`` or ``degree`` is not integral.
    """
    if degree < 1:
        raise ValueError(f"polynomial_kernel requires degree >= 1, got {degree!r}")
    # ``x^T x' + b`` can be negative, and a fractional power of a negative
    # number is NaN; reject non-integral degrees rather than silently
    # returning NaN entries.
    if not float(degree).is_integer():
        raise ValueError(f"polynomial_kernel requires an integer degree, got {degree!r}")
    dot = einsum(X1, X2, "n1 d, n2 d -> n1 n2")
    # Raise to a Python int so the power is an integer power, well defined
    # for negative bases.
    return variance * (dot + bias) ** int(degree)

rational_quadratic_kernel(X1, X2, variance, lengthscale, alpha)

Rational quadratic kernel.

.. math:: k(x, x') = \sigma^2 \left( 1 + \frac{\|x - x'\|^2}{2\alpha \ell^2} \right)^{-\alpha}

This kernel is a scale mixture of RBF kernels. In the limit alpha → ∞ it recovers the RBF, while small alpha yields heavier-tailed correlations.

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance.

required
lengthscale Float[Array, '']

Scalar lengthscale.

required
alpha Float[Array, '']

Scalar shape parameter; must be positive.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def rational_quadratic_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
    alpha: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Rational quadratic kernel.

    .. math::
        k(x, x') = \sigma^2 \left(
            1 + \frac{\|x - x'\|^2}{2\alpha \ell^2}
        \right)^{-\alpha}

    Equivalent to a scale mixture of RBF kernels: as ``alpha -> infinity``
    it tends to the RBF, while small ``alpha`` gives heavier-tailed
    correlations.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance.
        lengthscale: Scalar lengthscale.
        alpha: Scalar shape parameter; must be positive.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    scale = 2.0 * alpha * lengthscale * lengthscale
    base = 1.0 + _pairwise_sq_dist(X1, X2) / scale
    return variance * base ** (-alpha)

rbf_kernel(X1, X2, variance, lengthscale)

Radial basis function (squared exponential) kernel.

.. math:: k(x, x') = \sigma^2 \exp\!\left(-\frac{\|x - x'\|^2}{2\ell^2}\right)

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar signal variance sigma^2.

required
lengthscale Float[Array, '']

Scalar lengthscale ell.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) kernel Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def rbf_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
    lengthscale: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""Squared-exponential (radial basis function) kernel.

    .. math::
        k(x, x') = \sigma^2 \exp\!\left(-\frac{\|x - x'\|^2}{2\ell^2}\right)

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar signal variance ``sigma^2``.
        lengthscale: Scalar lengthscale ``ell``.

    Returns:
        ``(N1, N2)`` kernel Gram matrix.
    """
    ls_sq = lengthscale * lengthscale
    exponent = -0.5 * _pairwise_sq_dist(X1, X2) / ls_sq
    return variance * jnp.exp(exponent)

white_kernel(X1, X2, variance)

White-noise kernel.

.. math:: k(x, x') = \sigma^2 \,\delta(x, x')

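The kernel is nonzero only where X1[i] exactly matches X2[j] in every feature dimension (NaN coordinates never match). When evaluated at X1 == X2 with distinct rows, it yields sigma^2 * I. Duplicated rows also produce off-diagonal entries.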

Parameters:

Name Type Description Default
X1 Float[Array, 'N1 D']

(N1, D) inputs.

required
X2 Float[Array, 'N2 D']

(N2, D) inputs.

required
variance Float[Array, '']

Scalar noise variance.

required

Returns:

Type Description
Float[Array, 'N1 N2']

(N1, N2) Gram matrix.

Source code in src/pyrox/gp/_src/kernels.py
def white_kernel(
    X1: Float[Array, "N1 D"],
    X2: Float[Array, "N2 D"],
    variance: Float[Array, ""],
) -> Float[Array, "N1 N2"]:
    r"""White-noise kernel.

    .. math::
        k(x, x') = \sigma^2 \,\delta(x, x')

    Nonzero only where ``X1[i]`` exactly equals ``X2[j]`` in every feature
    dimension (IEEE equality, so identical infinities match and NaN never
    does). Evaluated at ``X1 == X2`` this yields ``sigma^2 * I`` provided the
    rows of ``X1`` are distinct; duplicated rows also produce off-diagonal
    entries.

    Args:
        X1: ``(N1, D)`` inputs.
        X2: ``(N2, D)`` inputs.
        variance: Scalar noise variance.

    Returns:
        ``(N1, N2)`` Gram matrix.
    """
    # Compare coordinates directly instead of testing ``X1 - X2 == 0``: the
    # subtraction turns identical infinite entries into NaN (inf - inf), and
    # on backends that flush subnormals to zero two distinct tiny values can
    # subtract to exactly 0 and be mistaken for a match.
    same_coord = X1[:, None, :] == X2[None, :, :]
    match = jnp.all(same_coord, axis=-1)
    return variance * match.astype(X1.dtype)