Latent GP Classification — the Three Patterns

Open In Colab

For non-conjugate likelihoods (Bernoulli, Poisson, StudentT, …) the GP latent function cannot be marginalized analytically — we need gp_sample inside the NumPyro model. This notebook walks the three pyrox patterns through a 2D binary-classification problem.

What you’ll learn:

  1. Use gp_sample to register the latent GP function as a sample site under an arbitrary observation likelihood.
  2. Compose it with a Bernoulli(logits=f) observation model.
  3. Fit kernel hyperparameters with SVI across all three patterns.

Background

Joint model

For a non-conjugate likelihood $p(y \mid f)$, the hierarchical model is

$$\theta \sim p(\theta), \quad f \mid X, \theta \sim \mathcal{N}\!\bigl(0, K_{XX}(\theta)\bigr), \quad y_i \mid f_i \sim \mathrm{Bernoulli}\!\bigl(\sigma(f_i)\bigr),$$

where $\sigma(z) = 1 / (1 + e^{-z})$ is the logistic link. The joint factorizes as

$$p(\theta, f, y \mid X) = p(\theta)\, p(f \mid X, \theta)\, \prod_{i=1}^{N} p(y_i \mid f_i).$$

Because the likelihood is not Gaussian, the marginal $p(y \mid X, \theta) = \int p(y \mid f)\, p(f \mid X, \theta)\, df$ has no closed form — we can’t collapse $f$ out the way we did for regression.

Variational inference + the whitening trick

We approximate $p(\theta, f \mid X, y)$ by a factorized variational distribution and maximize the evidence lower bound

$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi}\!\bigl[\log p(y \mid f)\bigr] - \mathrm{KL}\!\bigl(q_\phi(\theta, f) \,\|\, p(\theta, f \mid X)\bigr).$$

A naive parameterization places $q_\phi(f)$ directly over the latent values at the training inputs. Under NumPyro’s AutoNormal this becomes a mean-field Gaussian — independent across the $N$ training latents. That is a terrible approximation here, because the prior $p(f \mid X, \theta) = \mathcal{N}(0, K_{XX})$ has strong off-diagonal correlations, and the optimum of the ELBO ends up dominated by the KL term, not by the data fit. In practice the posterior over $f$ barely moves and the predictive decision boundary is essentially noise.
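
For reference, that KL term has a closed form when $q_\phi(f) = \mathcal{N}(m, \operatorname{diag}(s^2))$ and the prior is $\mathcal{N}(0, K_{XX})$ (the standard Gaussian-to-Gaussian KL, spelled out only to make the claim above concrete):

$$\mathrm{KL}\bigl(q_\phi \,\|\, p\bigr) = \tfrac{1}{2}\Bigl[\operatorname{tr}\bigl(K_{XX}^{-1}\operatorname{diag}(s^2)\bigr) + m^\top K_{XX}^{-1} m - N + \log\lvert K_{XX}\rvert - \sum_{i} \log s_i^2\Bigr].$$

When $K_{XX}$ has strong off-diagonal structure, the $m^\top K_{XX}^{-1} m$ term heavily penalizes any mean vector that is not smooth under the prior, so the optimizer keeps $q_\phi(f)$ pinned near the prior mean rather than fitting the labels.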

The standard fix is the whitening reparameterization (Murray & Adams, 2010; Hensman et al., 2015). Cholesky-factor the prior covariance once,

$$K_{XX}(\theta) + \epsilon I = L L^\top, \qquad u \sim \mathcal{N}(0, I_N), \qquad f = L u,$$

so the latent enters the likelihood through a deterministic transformation of an i.i.d. unit-Gaussian site. Now mean-field $q_\phi(u)$ is well-conditioned — there are no a-priori correlations between the $u_i$ to begin with — and the GP correlations come back into $f$ for free through $L$. Empirically this is the difference between a chaotic boundary and one that tracks the data.
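
To make the trick concrete, here is a minimal hand-rolled sketch at the NumPyro level. It is illustrative only: gp_sample(..., whitened=True) below wraps the same construction (with GPPrior supplying the mean, kernel matrix, and jitter), so none of the patterns in this notebook write it by hand.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def whitened_latent(name, mean, K, jitter=1e-4):
    """Hand-rolled whitened GP latent: u ~ N(0, I_N), f = mean + L u."""
    n = K.shape[0]
    # Cholesky of the jittered prior covariance, computed once per model evaluation.
    L = jnp.linalg.cholesky(K + jitter * jnp.eye(n))
    u = numpyro.sample(f"{name}_u", dist.Normal(0.0, 1.0).expand([n]).to_event(1))
    return numpyro.deterministic(name, mean + L @ u)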

gp_sample modes — when to reach for which

pyrox exposes gp_sample("f", prior, *, whitened, guide) as the NumPyro-aware registration for the latent function. The three modes are mutually exclusive:

  • default (whitened=False, guide=None) — registers a single $\mathcal{N}(\mu(X), K_{XX} + \epsilon I)$ sample site via gaussx.MultivariateNormal. The right call for MCMC (NUTS samples directly in the original $f$-space and handles the prior covariance correctly) and for conjugate workflows where you marginalize $f$ out via gp_factor.
  • whitened (whitened=True) — registers an $N$-dimensional unit-Gaussian site $u \sim \mathcal{N}(0, I)$ and returns the deterministic $f = \mu(X) + L u$ shown above. The right call for SVI on non-conjugate likelihoods with NumPyro auto-guides such as AutoNormal. This is what every pattern in this notebook uses.
  • guide (guide=...) — delegates to a concrete sparse variational guide (FullRankGuide, MeanFieldGuide, or WhitenedGuide over inducing values), which is the right call for the sparse SVGP workflow when $N$ is too large for an $N$-dimensional latent.

Posterior predictive

For a test input $x_*$ the posterior predictive is

$$p(y_* = 1 \mid x_*, \mathcal{D}) = \int \sigma(f_*)\, p(f_* \mid x_*, \mathcal{D})\, df_*,$$

which we estimate by Monte Carlo + the logistic-Gaussian approximation further below.

The three patterns

| Pattern | Kernel hyperparameters live in | When to reach for it |
| --- | --- | --- |
| A | Pure eqx.Module + numpyro.sample + eqx.tree_at | Lightweight; no base class required. |
| B | Custom PyroxModule kernel that calls pyrox_sample in __call__ | Self-contained probabilistic kernel. |
| C | Parameterized kernel (shipped RBF, Matern, …) with set_prior | Full registry, constraints, autoguides, modes. |

Only the kernel construction differs across patterns — the whitened latent and the Bernoulli likelihood are identical in all three models (see the shared skeleton below).

Setup

import subprocess
import sys


try:
    import google.colab  # noqa: F401

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "pyrox[colab] @ git+https://github.com/jejjohnson/pyrox@main",
        ],
        check=True,
    )
import warnings


warnings.filterwarnings("ignore", message=r".*IProgress.*")

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from jaxtyping import Array, Float
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

from pyrox._core import PyroxModule, pyrox_method
from pyrox.gp import RBF, GPPrior, Kernel, gp_sample
from pyrox.gp._src.kernels import rbf_kernel


jax.config.update("jax_enable_x64", True)

Reproducibility readout.

import importlib.util


try:
    from IPython import get_ipython

    ipython = get_ipython()
except ImportError:
    ipython = None

if ipython is not None and importlib.util.find_spec("watermark") is not None:
    ipython.run_line_magic("load_ext", "watermark")
    ipython.run_line_magic(
        "watermark",
        "-v -m -p jax,equinox,numpyro,gaussx,pyrox,matplotlib",
    )
else:
    print("watermark extension not installed; skipping reproducibility readout.")
Python implementation: CPython
Python version       : 3.12.13
IPython version      : 9.12.0

jax       : 0.8.3
equinox   : 0.13.7
numpyro   : 0.20.1
gaussx    : 0.0.10
pyrox     : 0.0.6
matplotlib: 3.10.8

Compiler    : GCC 14.3.0
OS          : Linux
Release     : 6.8.0-1044-azure
Machine     : x86_64
Processor   : x86_64
CPU cores   : 16
Architecture: 64bit

Toy 2D binary dataset

Two interleaved half-moons — the boundary between them is the function the latent GP has to learn.

key = jr.PRNGKey(0)


def make_half_moons(key, n_per_class=30):
    k1, k2, k3, k4 = jr.split(key, 4)
    theta0 = jnp.pi * jr.uniform(k1, (n_per_class,))
    x0 = jnp.stack([jnp.cos(theta0), jnp.sin(theta0)], axis=-1)
    x0 = x0 + 0.1 * jr.normal(k2, x0.shape)
    theta1 = jnp.pi + jnp.pi * jr.uniform(k3, (n_per_class,))
    x1 = jnp.stack([1.0 + jnp.cos(theta1), 0.5 + jnp.sin(theta1)], axis=-1)
    x1 = x1 + 0.1 * jr.normal(k4, x1.shape)
    X = jnp.concatenate([x0, x1], axis=0)
    y = jnp.concatenate([jnp.zeros(n_per_class), jnp.ones(n_per_class)], axis=0).astype(
        jnp.int32
    )
    return X, y


X_train, y_train = make_half_moons(key, n_per_class=20)

grid_lo, grid_hi = -1.6, 2.6
grid_steps = 40
xx, yy = jnp.meshgrid(
    jnp.linspace(grid_lo, grid_hi, grid_steps),
    jnp.linspace(grid_lo, grid_hi, grid_steps),
)
X_grid = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)

print(f"Training points: {X_train.shape[0]}")
print(f"Grid points:     {X_grid.shape[0]}")


def scatter_data(ax):
    mask0 = y_train == 0
    ax.scatter(
        X_train[mask0, 0],
        X_train[mask0, 1],
        s=40,
        c="C0",
        edgecolors="k",
        linewidths=0.5,
        label="class 0",
        zorder=5,
    )
    ax.scatter(
        X_train[~mask0, 0],
        X_train[~mask0, 1],
        s=40,
        c="C1",
        edgecolors="k",
        linewidths=0.5,
        label="class 1",
        zorder=5,
    )
    ax.set_xlim(grid_lo, grid_hi)
    ax.set_ylim(grid_lo, grid_hi)
    ax.set_xlabel(r"$x_1$")
    ax.set_ylabel(r"$x_2$")


fig, ax = plt.subplots(figsize=(6, 5))
scatter_data(ax)
ax.set_title("Half-moons binary classification")
ax.legend()
plt.show()
Training points: 40
Grid points:     1600
<Figure size 600x500 with 1 Axes>

The common model skeleton

All three patterns share the same whitened latent + Bernoulli likelihood — the only thing that changes is how the kernel is built (which is the whole point of comparing the patterns):

prior = GPPrior(kernel, X, jitter=1e-4)   # kernel built per-pattern
f     = gp_sample("f", prior, whitened=True)
        # samples u ~ N(0, I_N) and returns the deterministic
        # f = mu(X) + L u with L = chol(K + jitter*I)
obs   ~ Bernoulli(logits=f)

gp_sample(..., whitened=True) is the model-layer entry point that registers the unit-Gaussian site as "f_u" and the unwhitened function value as the deterministic "f". Each pattern’s model function then reads as “build the kernel, then run the standard whitened-classification pipeline.”
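
If you prefer to factor that boilerplate out, a small helper works across all three patterns. The name whitened_bernoulli_gp below is ours, not part of pyrox, and the pattern models in this notebook simply inline the same three lines instead.

def whitened_bernoulli_gp(kernel, X, y, jitter=1e-4):
    """Shared tail of every pattern: whitened GP latent + Bernoulli likelihood."""
    prior = GPPrior(kernel=kernel, X=X, jitter=jitter)
    f = gp_sample("f", prior, whitened=True)
    numpyro.sample("obs", dist.Bernoulli(logits=f), obs=y)
    return f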

Pattern A — pure Equinox + eqx.tree_at

class RBFLite(Kernel):
    """Minimal Equinox-native RBF kernel."""

    variance: Float[Array, ""]
    lengthscale: Float[Array, ""]

    def __call__(
        self, X1: Float[Array, "N1 D"], X2: Float[Array, "N2 D"]
    ) -> Float[Array, "N1 N2"]:
        return rbf_kernel(X1, X2, self.variance, self.lengthscale)


def model_pattern_a(X, y):
    variance = numpyro.sample("variance", dist.LogNormal(0.0, 1.0))
    lengthscale = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    kernel = RBFLite(variance=jnp.array(1.0), lengthscale=jnp.array(1.0))
    kernel = eqx.tree_at(
        lambda k: (k.variance, k.lengthscale),
        kernel,
        (variance, lengthscale),
    )
    prior = GPPrior(kernel=kernel, X=X, jitter=1e-4)
    f = gp_sample("f", prior, whitened=True)
    numpyro.sample("obs", dist.Bernoulli(logits=f), obs=y)

Pattern B — PyroxModule kernel with pyrox_sample

class RBFPyrox(Kernel, PyroxModule):
    """PyroxModule kernel with inline prior registration."""

    pyrox_name: str = "RBFPyrox"

    @pyrox_method
    def __call__(
        self, X1: Float[Array, "N1 D"], X2: Float[Array, "N2 D"]
    ) -> Float[Array, "N1 N2"]:
        variance = self.pyrox_sample("variance", dist.LogNormal(0.0, 1.0))
        lengthscale = self.pyrox_sample("lengthscale", dist.LogNormal(0.0, 1.0))
        return rbf_kernel(X1, X2, variance, lengthscale)


def model_pattern_b(X, y):
    kernel = RBFPyrox()
    prior = GPPrior(kernel=kernel, X=X, jitter=1e-4)
    f = gp_sample("f", prior, whitened=True)
    numpyro.sample("obs", dist.Bernoulli(logits=f), obs=y)

Pattern C — Parameterized kernel with set_prior

def model_pattern_c(X, y):
    kernel = RBF()
    kernel.set_prior("variance", dist.LogNormal(0.0, 1.0))
    kernel.set_prior("lengthscale", dist.LogNormal(0.0, 1.0))
    prior = GPPrior(kernel=kernel, X=X, jitter=1e-4)
    f = gp_sample("f", prior, whitened=True)
    numpyro.sample("obs", dist.Bernoulli(logits=f), obs=y)

Fit all three with the same SVI loop

def fit(model_fn, seed, n_steps=400):
    guide = AutoNormal(model_fn)
    svi = SVI(model_fn, guide, Adam(2e-2), Trace_ELBO())
    state = svi.init(jr.PRNGKey(seed), X_train, y_train)
    losses = []
    for _ in range(n_steps):
        state, loss = svi.update(state, X_train, y_train)
        losses.append(float(loss))
    return state, svi, guide, losses


state_a, svi_a, guide_a, losses_a = fit(model_pattern_a, 1)
state_b, svi_b, guide_b, losses_b = fit(model_pattern_b, 2)
state_c, svi_c, guide_c, losses_c = fit(model_pattern_c, 3)


def posterior_hyperparams(svi, state, variance_site, lengthscale_site):
    params = svi.get_params(state)
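    # AutoNormal keeps guide locations in unconstrained space; LogNormal sites are
    # mapped back through exp, so exp(loc) is the guide's posterior median.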
    variance = float(jnp.exp(params[f"{variance_site}_auto_loc"]))
    lengthscale = float(jnp.exp(params[f"{lengthscale_site}_auto_loc"]))
    return variance, lengthscale


v_a, ls_a = posterior_hyperparams(svi_a, state_a, "variance", "lengthscale")
v_b, ls_b = posterior_hyperparams(
    svi_b, state_b, "RBFPyrox.variance", "RBFPyrox.lengthscale"
)
v_c, ls_c = posterior_hyperparams(svi_c, state_c, "RBF.variance", "RBF.lengthscale")

print(f"{'pattern':<10} {'variance':>10} {'lengthscale':>12}")
print("-" * 36)
print(f"{'A':<10} {v_a:>10.3f} {ls_a:>12.3f}")
print(f"{'B':<10} {v_b:>10.3f} {ls_b:>12.3f}")
print(f"{'C':<10} {v_c:>10.3f} {ls_c:>12.3f}")
pattern      variance  lengthscale
------------------------------------
A               4.707        0.704
B               4.070        0.699
C               4.326        0.767
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(losses_a, "C0-", label="Pattern A", alpha=0.85)
ax.plot(losses_b, "C1-", label="Pattern B", alpha=0.85)
ax.plot(losses_c, "C2-", label="Pattern C", alpha=0.85)
ax.set_xlabel("SVI step")
ax.set_ylabel(r"$-$ELBO")
ax.set_title("SVI convergence — three patterns")
ax.legend()
plt.show()
<Figure size 1200x500 with 1 Axes>

Decision boundaries — MC-averaged posterior predictive

We want the Bayesian posterior predictive

$$p(y_* = 1 \mid x_*, \mathcal{D}) = \iiint \sigma(f_*)\, p(f_* \mid x_*, f, \theta)\, q(\theta, u)\, df_*\, du\, d\theta,$$

where $q(\theta, u)$ is the SVI guide posterior over the kernel hyperparameters and the whitened latent, and $p(f_* \mid x_*, f, \theta)$ is the standard noise-free GP conditional

$$p(f_* \mid x_*, f, \theta) = \mathcal{N}\!\bigl(\mu_*, \sigma_*^2\bigr), \qquad \mu_* = k_*^\top\bigl(K_{XX} + \epsilon I\bigr)^{-1} f, \qquad \sigma_*^2 = k_{**} - k_*^\top\bigl(K_{XX} + \epsilon I\bigr)^{-1} k_*,$$

with $k_* = K(x_*, X)$ and $k_{**} = K(x_*, x_*)$. Here $\epsilon$ is the same jitter (1e-4) passed to GPPrior in the model functions above.

We estimate the outer integral by Monte Carlo and the inner integral by the logistic-Gaussian approximation

$$\int \sigma(f_*)\, \mathcal{N}(f_* \mid \mu_*, \sigma_*^2)\, df_* \;\approx\; \sigma\bigl(\kappa(\sigma_*^2)\, \mu_*\bigr), \qquad \kappa(v) = \bigl(1 + \pi v / 8\bigr)^{-1/2},$$

which is exact for the probit link and within a percent for the logistic link that we actually use (MacKay, 1992).
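
A quick numerical sanity check of that approximation (not part of the fitted pipeline; we just compare it against a brute-force Monte Carlo estimate of the inner integral for one arbitrary mean/variance pair):

mu_check, var_check = 1.3, 2.0
kappa_check = 1.0 / jnp.sqrt(1.0 + jnp.pi * var_check / 8.0)
approx = jax.nn.sigmoid(kappa_check * mu_check)

f_draws = mu_check + jnp.sqrt(var_check) * jr.normal(jr.PRNGKey(0), (200_000,))
mc_estimate = jnp.mean(jax.nn.sigmoid(f_draws))

print(f"logistic-Gaussian approximation: {float(approx):.4f}")
print(f"Monte Carlo estimate:            {float(mc_estimate):.4f}")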

Putting it together — for each MC sample $s$:

  1. Draw $(\theta_s, u_s) \sim q_\phi$ from the SVI guide.
  2. Reconstruct the training-time latent $f_s = L(\theta_s)\, u_s$ with the matching Cholesky.
  3. Build the GP prior at $\theta_s$ and condition on $f_s$ to get $(\mu_{*,s}, \sigma_{*,s}^2)$ at the grid points.
  4. Apply the logistic-Gaussian approximation to get per-sample class probabilities.

Average:

$$p(y_* = 1 \mid x_*, \mathcal{D}) \;\approx\; \frac{1}{S}\sum_{s=1}^{S} \sigma\!\bigl(\kappa(\sigma_{*, s}^2)\, \mu_{*, s}\bigr).$$

This uses the actual fitted Bernoulli model — no Gaussian-likelihood stand-in anywhere.

N_MC = 16


def predict_prob(guide, svi_obj, state, variance_site, lengthscale_site):
    """MC-averaged posterior predictive for the Bernoulli GP.

    Each guide sample yields ``(theta_s, u_s)`` over the kernel
    hyperparameters and the whitened latent. We reconstruct the
    training-time latent ``f_s = L(theta_s) @ u_s``, condition the GP
    on ``f_s`` as noise-free observations, and apply the
    logistic-Gaussian approximation to get class probabilities at the
    grid points before averaging across samples.
    """
    params = svi_obj.get_params(state)
    posterior = guide.sample_posterior(jr.PRNGKey(42), params, sample_shape=(N_MC,))
    variances = posterior[variance_site]
    lengthscales = posterior[lengthscale_site]
    u_samples = posterior["f_u"]  # (N_MC, N_train)

    probs = []
    for v, ls, u in zip(variances, lengthscales, u_samples):
        kernel = RBF(init_variance=float(v), init_lengthscale=float(ls))
        prior = GPPrior(kernel=kernel, X=X_train, jitter=1e-4)
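        # Rebuild the same Cholesky factor used by the whitened "f_u" site so that
        # f_train = L @ u reproduces this sample's training-time latent function.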
        L = jnp.linalg.cholesky(prior._prior_operator().as_matrix())
        f_train = L @ u
        # Condition on the reconstructed latent as noise-free observations.
        cond = prior.condition(f_train, jnp.array(1e-4))
        mean, var = cond.predict(X_grid)
        kappa = 1.0 / jnp.sqrt(1.0 + jnp.pi * jnp.clip(var, min=0.0) / 8.0)
        probs.append(jax.nn.sigmoid(kappa * mean))
    return jnp.mean(jnp.stack(probs), axis=0).reshape(grid_steps, grid_steps)


prob_a = predict_prob(guide_a, svi_a, state_a, "variance", "lengthscale")
prob_b = predict_prob(
    guide_b, svi_b, state_b, "RBFPyrox.variance", "RBFPyrox.lengthscale"
)
prob_c = predict_prob(guide_c, svi_c, state_c, "RBF.variance", "RBF.lengthscale")

fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
for ax, prob, title in zip(
    axes, [prob_a, prob_b, prob_c], ["Pattern A", "Pattern B", "Pattern C"]
):
    cf = ax.contourf(
        xx, yy, prob, levels=10, cmap="RdBu_r", alpha=0.7, vmin=0.0, vmax=1.0
    )
    ax.contour(xx, yy, prob, levels=[0.5], colors="k", linewidths=1.5)
    scatter_data(ax)
    ax.set_title(title)
    ax.legend(fontsize=9)
fig.colorbar(cf, ax=axes, shrink=0.9, label=r"$p(y=1 \mid x)$")
plt.show()
<Figure size 1800x500 with 4 Axes>

When to use which

The non-conjugate story is exactly the conjugate story with gp_factor swapped for gp_sample + an explicit likelihood. The pattern choice is still purely ergonomic:

  • Pattern A — lightest; sample hyperparameters at the model level and splice into a pure Equinox kernel. Good for one-off Bayesian extensions.
  • Pattern B — custom PyroxModule kernel that owns its priors. Good when the kernel is a reusable probabilistic building block.
  • Pattern C — full registry via Parameterized. Use the shipped RBF, Matern, etc., and attach priors / autoguides / modes declaratively.

Switch the likelihood, change nothing else — Bernoulli(logits=f) becomes Poisson(rate=jnp.exp(f)) for counts, StudentT(...) for robust regression, and so on.
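
For example, a Poisson-count variant of Pattern C is a one-line change in the observation site (a sketch, assuming integer count targets y):

def model_poisson_counts(X, y):
    kernel = RBF()
    kernel.set_prior("variance", dist.LogNormal(0.0, 1.0))
    kernel.set_prior("lengthscale", dist.LogNormal(0.0, 1.0))
    prior = GPPrior(kernel=kernel, X=X, jitter=1e-4)
    f = gp_sample("f", prior, whitened=True)
    # Log link: the latent GP models the log-intensity, so exponentiate for the rate.
    numpyro.sample("obs", dist.Poisson(rate=jnp.exp(f)), obs=y)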