
Sparse Variational GP with Inducing Points

Sparse GPs approximate the full kernel matrix with $M \ll N$ inducing points:

$$K \approx K_{NM} K_{MM}^{-1} K_{MN}$$

The resulting covariance is $\sigma^2 I + U U^\top$, which is a LowRankUpdate in gaussx. We optimize the ELBO with respect to the kernel hyperparameters (lengthscale, variance, and noise) via jax.grad. The inducing point locations stay fixed at their initial grid.

Background

Sparse GP methods address the $O(N^3)$ cost of exact GP inference by introducing $M \ll N$ inducing points $Z = \{z_m\}_{m=1}^M$. The key idea is to approximate the full GP posterior with a variational distribution that conditions on the inducing variables $u = f(Z)$. The Nystrom approximation gives:

$$K_{ff} \approx Q_{ff} = K_{fM} K_{MM}^{-1} K_{Mf}$$
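As a quick numerical check of the Nystrom approximation, here is a minimal plain-JAX sketch that does not use gaussx. It forms $Q_{ff}$ for two evenly spaced inducing sets and measures $\operatorname{tr}(K_{ff} - Q_{ff})$. That quantity should be non-negative and should shrink as $M$ grows. The helper nystrom_q, the unit hyperparameters, and the grid sizes are illustrative assumptions, not part of the notebook.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, ls=1.0, var=1.0):
    """Squared-exponential kernel matrix between two 1-D input arrays."""
    return var * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ls**2)


def nystrom_q(x, z, jitter=1e-8):
    """Hypothetical helper: Q_ff = K_fM K_MM^{-1} K_Mf via a Cholesky factor."""
    k_mm = rbf(z, z) + jitter * jnp.eye(z.shape[0])
    l_mm = jnp.linalg.cholesky(k_mm)
    # U = K_fM L_MM^{-T}, so that Q_ff = U U^T
    u = solve_triangular(l_mm, rbf(x, z).T, lower=True).T
    return u @ u.T


x = jnp.linspace(-4.0, 4.0, 200)
k_ff = rbf(x, x)

gaps = {}
for m in (5, 20):
    z = jnp.linspace(-4.0, 4.0, m)
    gaps[m] = float(jnp.trace(k_ff - nystrom_q(x, z)))
    print(f"M={m:2d}  tr(K_ff - Q_ff) = {gaps[m]:.3e}")
```

The trace gap is the same quantity that appears in the VFE correction term.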

The variational free energy (VFE) framework (Titsias, 2009) optimizes a lower bound on the log-marginal likelihood:

$$\mathcal{L}_{\text{VFE}} = \log \mathcal{N}(y \mid 0, Q_{ff} + \sigma^2 I) - \frac{1}{2\sigma^2}\operatorname{tr}(K_{ff} - Q_{ff})$$

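The trace correction penalizes the approximation error. It is non-negative because $K_{ff} - Q_{ff}$ is positive semidefinite. It vanishes when $Q_{ff} = K_{ff}$, for example when the inducing points coincide with the training inputs, and in that case the bound equals the exact log-marginal likelihood.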
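The bound property can be checked directly on a small problem. The sketch below evaluates the VFE bound densely in plain JAX (an $O(N^3)$ check, not how gaussx computes it) and compares it with the exact log-marginal likelihood. It assumes unit RBF hyperparameters, noise variance 0.04, and synthetic data. With five inducing points the bound should sit below the exact value. With $Z = X$ it should match the exact value up to the small jitter.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, ls=1.0, var=1.0):
    return var * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ls**2)


def vfe_bound(x, y, z, noise, jitter=1e-6):
    """Dense evaluation of the Titsias bound, for checking only."""
    n, m = x.shape[0], z.shape[0]
    l_mm = jnp.linalg.cholesky(rbf(z, z) + jitter * jnp.eye(m))
    u = solve_triangular(l_mm, rbf(x, z).T, lower=True).T
    q_ff = u @ u.T
    log_lik = multivariate_normal.logpdf(y, jnp.zeros(n), q_ff + noise * jnp.eye(n))
    trace = jnp.trace(rbf(x, x)) - jnp.trace(q_ff)  # tr(K_ff - Q_ff)
    return log_lik - 0.5 * trace / noise


def exact_lml(x, y, noise):
    n = x.shape[0]
    return multivariate_normal.logpdf(y, jnp.zeros(n), rbf(x, x) + noise * jnp.eye(n))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jnp.sort(jax.random.uniform(key_x, (40,), minval=-4.0, maxval=4.0))
y = jnp.sin(2 * x) + 0.2 * jax.random.normal(key_y, (40,))
noise = 0.04

exact = exact_lml(x, y, noise)
sparse = vfe_bound(x, y, jnp.linspace(-3.5, 3.5, 5), noise)
full = vfe_bound(x, y, x, noise)  # Z = X, so Q_ff = K_ff up to jitter
print(f"exact {float(exact):.3f}  M=5 bound {float(sparse):.3f}  Z=X bound {float(full):.3f}")
```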

from __future__ import annotations

import warnings


warnings.filterwarnings("ignore", message=r".*IProgress.*")

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import gaussx


jax.config.update("jax_enable_x64", True)

Generate data

key = jax.random.PRNGKey(123)
k1, k2 = jax.random.split(key)

n_data = 300
noise_std = 0.2

f_true = lambda x: jnp.sin(2 * x) + 0.3 * jnp.cos(5 * x)

x_data = jnp.sort(jax.random.uniform(k1, (n_data,), minval=-4.0, maxval=4.0))
y_data = f_true(x_data) + noise_std * jax.random.normal(k2, (n_data,))

x_plot = jnp.linspace(-4.5, 4.5, 500)

Build Nystrom approximation

Let $L_{MM}$ be the Cholesky factor of $K_{MM}$, so that $K_{MM} = L_{MM} L_{MM}^\top$. The Nystrom factor $U = K_{NM} L_{MM}^{-\top}$ then gives $Q_{NN} = UU^\top$. The noisy covariance $\sigma^2 I + UU^\top$ is therefore a rank-$M$ update of a diagonal, which is exactly a LowRankUpdate. This is why gaussx can apply the Woodbury identity for an $O(NM^2 + M^3)$ solve and the matrix determinant lemma for an $O(NM^2 + M^3)$ log-determinant, instead of paying the $O(N^3)$ dense cost.
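To make the two identities concrete, here is a minimal plain-JAX sketch of a solve and a log-determinant on $\sigma^2 I + UU^\top$, checked against dense linear algebra. It illustrates the Woodbury identity and the matrix determinant lemma. It is not gaussx's actual implementation, and the function name woodbury_solve_logdet is hypothetical. Only the $M \times M$ matrix $B = I_M + U^\top U / \sigma^2$ is factorized.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

jax.config.update("jax_enable_x64", True)


def woodbury_solve_logdet(noise, u, y):
    """Solve (noise*I + U U^T) x = y and return log|noise*I + U U^T|.

    Factorizes only B = I_M + U^T U / noise: O(N M^2 + M^3) instead of O(N^3).
    """
    n, m = u.shape
    l_b = jnp.linalg.cholesky(jnp.eye(m) + (u.T @ u) / noise)
    # Woodbury: (noise I + U U^T)^{-1} y = (y - U B^{-1} U^T y / noise) / noise
    x = (y - u @ cho_solve((l_b, True), u.T @ y) / noise) / noise
    # Determinant lemma: log|noise I + U U^T| = N log(noise) + log|B|
    logdet = n * jnp.log(noise) + 2.0 * jnp.sum(jnp.log(jnp.diag(l_b)))
    return x, logdet


key_u, key_y = jax.random.split(jax.random.PRNGKey(1))
u = jax.random.normal(key_u, (300, 12))
y = jax.random.normal(key_y, (300,))
noise = 0.04

x_wb, logdet_wb = woodbury_solve_logdet(noise, u, y)

# Dense O(N^3) reference
dense = noise * jnp.eye(300) + u @ u.T
x_dense = jnp.linalg.solve(dense, y)
_, logdet_dense = jnp.linalg.slogdet(dense)
print(jnp.allclose(x_wb, x_dense), jnp.allclose(logdet_wb, logdet_dense))
```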

n_inducing = 12


def rbf(x1, x2, ls, var):
    return var * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ls**2)


def build_sparse_gp(x_data, x_inducing, ls, var, noise):
    """Build the sparse GP covariance as a LowRankUpdate."""
    K_nm = rbf(x_data, x_inducing, ls, var)
    K_mm = rbf(x_inducing, x_inducing, ls, var)

    # Nystrom: U = K_nm @ chol(K_mm)^{-T}
    jitter = 1e-5 * jnp.eye(len(x_inducing))
    L_mm = jnp.linalg.cholesky(K_mm + jitter)
    U = jax.scipy.linalg.solve_triangular(L_mm, K_nm.T, lower=True).T

    # Sigma = noise * I + U U^T
    return gaussx.low_rank_plus_diag(noise * jnp.ones(len(x_data)), U)


# Initial setup
x_inducing = jnp.linspace(-3.5, 3.5, n_inducing)
ls, var, noise = 1.0, 1.0, noise_std**2

sigma = build_sparse_gp(x_data, x_inducing, ls, var, noise)
print(f"Data: {n_data}, Inducing: {n_inducing}")
print(f"Operator rank: {sigma.rank}")
Data: 300, Inducing: 12
Operator rank: 12

ELBO (variational lower bound)

$$\mathcal{L} = \log \mathcal{N}(y \mid 0, \sigma^2 I + U U^\top) - \frac{1}{2\sigma^2} \mathrm{tr}(K_{NN} - Q_{NN})$$

The first term uses gaussx.solve for the quadratic form $y^\top (\sigma^2 I + UU^\top)^{-1} y$ and gaussx.logdet for the log-determinant. The trace correction penalizes the information lost by the approximation.
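Expanding the log-density and the trace term gives the four quantities that the elbo function accumulates. Here $\sigma_f^2$ is the kernel variance (var in the code), $\sigma^2$ is the noise variance (noise), $\Sigma = \sigma^2 I + UU^\top$, and $\alpha = \Sigma^{-1} y$. The trace simplifies for two reasons: the RBF kernel has $k(x, x) = \sigma_f^2$, and $(UU^\top)_{nn} = \sum_m U_{nm}^2$.

$$
\mathcal{L}
= \underbrace{-\tfrac{1}{2}\, y^\top \alpha}_{\text{data\_fit}}
\; \underbrace{-\,\tfrac{1}{2}\log\lvert\Sigma\rvert}_{\text{complexity}}
\; \underbrace{-\,\tfrac{N}{2}\log 2\pi}_{\text{const}}
\; \underbrace{-\,\frac{1}{2\sigma^2}\Big(N\sigma_f^2 - \sum_{n=1}^{N}\sum_{m=1}^{M} U_{nm}^2\Big)}_{\text{trace\_correction}}
$$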

def elbo(log_params, x_data, y_data, x_inducing):
    ls = jnp.exp(log_params[0])
    var = jnp.exp(log_params[1])
    noise = jnp.exp(log_params[2])

    sigma_op = build_sparse_gp(x_data, x_inducing, ls, var, noise)

    # Log-likelihood term
    alpha = gaussx.solve(sigma_op, y_data)
    data_fit = -0.5 * jnp.dot(y_data, alpha)
    complexity = -0.5 * gaussx.logdet(sigma_op)
    const = -0.5 * len(y_data) * jnp.log(2 * jnp.pi)

    # Trace correction: -tr(K_nn - Q_nn) / (2 * noise)
    # diag(K_nn) = var, since the RBF kernel gives k(x, x) = var
    # diag(Q_nn) = diag(U U^T) = row-wise sum of U**2
    Q_diag = jnp.sum(sigma_op.U**2, axis=1)
    trace_correction = -0.5 * jnp.sum(var - Q_diag) / noise

    return data_fit + complexity + const + trace_correction


log_params = jnp.log(jnp.array([ls, var, noise]))
print(f"Initial ELBO: {elbo(log_params, x_data, y_data, x_inducing):.2f}")
Initial ELBO: -165.14

Optimize hyperparameters

grad_fn = jax.jit(jax.grad(lambda p: -elbo(p, x_data, y_data, x_inducing)))

lr = 0.01
for _i in range(100):
    g = grad_fn(log_params)
    # Clip gradients for stability
    g = jnp.clip(g, -10.0, 10.0)
    log_params = log_params - lr * g

ls_opt, var_opt, noise_opt = jnp.exp(log_params)
print(f"Optimized: ls={ls_opt:.3f}, var={var_opt:.3f}, noise={noise_opt:.4f}")
print(f"Final ELBO: {elbo(log_params, x_data, y_data, x_inducing):.2f}")
Optimized: ls=1.020, var=1.504, noise=0.0922
Final ELBO: -96.68
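The loop above updates only the three log-hyperparameters, and x_inducing stays at its initial grid. The inducing inputs enter the bound only through differentiable kernel evaluations, so jax.grad can move them too. The sketch below is a minimal illustration under several assumptions: it uses a hypothetical pure-JAX reimplementation of the same VFE bound (vfe_elbo, which does not use gaussx), a smaller synthetic dataset, and 8 inducing points. It treats the pair (log-hyperparameters, inducing inputs) as one pytree and reuses the notebook's clipped gradient step.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, ls, var):
    return var * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ls**2)


def vfe_elbo(params, x, y, jitter=1e-6):
    """Hypothetical VFE bound; params = (log hyperparameters, inducing inputs z)."""
    log_hypers, z = params
    ls, var, noise = jnp.exp(log_hypers)
    n, m = x.shape[0], z.shape[0]
    l_mm = jnp.linalg.cholesky(rbf(z, z, ls, var) + jitter * jnp.eye(m))
    u = solve_triangular(l_mm, rbf(x, z, ls, var).T, lower=True).T  # (n, m)
    # Woodbury + determinant lemma through B = I_M + U^T U / noise
    l_b = jnp.linalg.cholesky(jnp.eye(m) + u.T @ u / noise)
    c = solve_triangular(l_b, u.T @ y, lower=True) / noise
    quad = y @ y / noise - c @ c  # y^T (noise I + U U^T)^{-1} y
    logdet = n * jnp.log(noise) + 2.0 * jnp.sum(jnp.log(jnp.diag(l_b)))
    log_lik = -0.5 * (quad + logdet + n * jnp.log(2.0 * jnp.pi))
    trace = n * var - jnp.sum(u**2)  # tr(K_nn - Q_nn)
    return log_lik - 0.5 * trace / noise


key_x, key_y = jax.random.split(jax.random.PRNGKey(123))
x = jnp.sort(jax.random.uniform(key_x, (100,), minval=-4.0, maxval=4.0))
y = jnp.sin(2 * x) + 0.3 * jnp.cos(5 * x) + 0.2 * jax.random.normal(key_y, (100,))

# Hyperparameters and inducing inputs are leaves of one tuple pytree
params = (jnp.log(jnp.array([1.0, 1.0, 0.04])), jnp.linspace(-3.5, 3.5, 8))
grad_fn = jax.jit(jax.grad(lambda p: -vfe_elbo(p, x, y)))

elbo_init = vfe_elbo(params, x, y)
lr = 0.01
for _ in range(100):
    # Same element-wise clipping as the notebook, applied to every leaf
    g = jax.tree_util.tree_map(lambda t: jnp.clip(t, -10.0, 10.0), grad_fn(params))
    params = jax.tree_util.tree_map(lambda p, t: p - lr * t, params, g)
elbo_final = vfe_elbo(params, x, y)
print(f"ELBO: {float(elbo_init):.2f} -> {float(elbo_final):.2f}")
print("inducing inputs:", params[1])
```

Keeping every parameter in one tuple means the update loop does not need to know how many parameter groups there are. jax.grad returns gradients with the same tuple structure.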

Predict

sigma_opt = build_sparse_gp(x_data, x_inducing, ls_opt, var_opt, noise_opt)
alpha_opt = gaussx.solve(sigma_opt, y_data)

# Predictive mean: K_*m K_mm^{-1} K_mn (noise * I + U U^T)^{-1} y
K_star_m = rbf(x_plot, x_inducing, ls_opt, var_opt)
K_mm = rbf(x_inducing, x_inducing, ls_opt, var_opt)
K_nm = rbf(x_data, x_inducing, ls_opt, var_opt)

weights = jnp.linalg.solve(
    K_mm + 1e-5 * jnp.eye(n_inducing),
    K_nm.T @ alpha_opt,
)
y_pred = K_star_m @ weights
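The cell above computes the predictive mean as $K_{*M} K_{MM}^{-1} K_{MN} (\sigma^2 I + UU^\top)^{-1} y$. By the push-through identity this equals $K_{*M} (\sigma^2 K_{MM} + K_{MN} K_{NM})^{-1} K_{MN} y$, which needs only an $M \times M$ solve. The sketch below checks in plain JAX that the two forms agree when the same jittered $K_{MM}$ is used in both. It runs on synthetic data similar to the notebook's, with fixed hyperparameters (unit lengthscale and variance, noise 0.04) rather than the optimized ones. It is an illustration and not how gaussx computes predictions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def rbf(x1, x2, ls=1.0, var=1.0):
    return var * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ls**2)


key_x, key_y = jax.random.split(jax.random.PRNGKey(2))
x = jnp.sort(jax.random.uniform(key_x, (300,), minval=-4.0, maxval=4.0))
y = jnp.sin(2 * x) + 0.2 * jax.random.normal(key_y, (300,))
z = jnp.linspace(-3.5, 3.5, 12)
x_star = jnp.linspace(-4.5, 4.5, 50)
noise, jitter = 0.04, 1e-5

k_mm = rbf(z, z) + jitter * jnp.eye(12)  # same jittered K_mm in both forms
k_nm = rbf(x, z)
k_sm = rbf(x_star, z)

# Form 1 (as in the notebook): N x N solve with noise * I + Q_nn
l_mm = jnp.linalg.cholesky(k_mm)
u = solve_triangular(l_mm, k_nm.T, lower=True).T
alpha = jnp.linalg.solve(noise * jnp.eye(300) + u @ u.T, y)
mean_n = k_sm @ jnp.linalg.solve(k_mm, k_nm.T @ alpha)

# Form 2: only an M x M system (noise * K_mm + K_mn K_nm)^{-1} K_mn y
mean_m = k_sm @ jnp.linalg.solve(noise * k_mm + k_nm.T @ k_nm, k_nm.T @ y)

print(float(jnp.max(jnp.abs(mean_n - mean_m))))
```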

Visualize

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x_plot, f_true(x_plot), "k--", lw=1.5, label="True function", zorder=4)
ax.scatter(
    x_data,
    y_data,
    s=30,
    c="C0",
    alpha=0.3,
    edgecolors="k",
    linewidths=0.5,
    label="Data",
    zorder=5,
)
ax.plot(x_plot, y_pred, "C1-", lw=2, label="Sparse GP mean", zorder=3)
ax.scatter(
    x_inducing,
    jnp.full_like(x_inducing, -1.8),
    marker="^",
    s=50,
    c="C2",
    zorder=5,
    label=f"Inducing pts ({n_inducing})",
)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title(
    f"Sparse Variational GP: {n_data} data, {n_inducing} inducing "
    f"(ls={ls_opt:.2f}, var={var_opt:.2f})"
)
ax.legend(fontsize=9)
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()

Summary

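| Component | Role in this notebook |
|---|---|
| gaussx.low_rank_plus_diag | Build the Nystrom covariance $\sigma^2 I + UU^\top$ |
| gaussx.solve | Woodbury solve for $\alpha = (\sigma^2 I + UU^\top)^{-1} y$ |
| gaussx.logdet | Matrix determinant lemma for the ELBO |
| jax.grad | Differentiate the ELBO through every step |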

References

  • Quiñonero-Candela, J. & Rasmussen, C. E. (2005). A unifying view of sparse approximate Gaussian process regression. JMLR, 6, 1939-1959.
  • Titsias, M. (2009). Variational learning of inducing variables in sparse Gaussian processes. Proc. AISTATS.
  • Hensman, J., Fusi, N., & Lawrence, N. D. (2013). Gaussian processes for big data. Proc. UAI.