Kernel Approximations: Nystrom & Random Fourier Features

Full kernel matrices cost $O(N^2)$ memory and $O(N^3)$ time to factorize. For large-scale GP inference we need structured approximations that exploit the approximately low-rank structure of smooth kernels.

This notebook showcases six gaussx utilities:

| Function | Returns | Purpose |
| --- | --- | --- |
| nystrom_operator | LowRankUpdate | Nystrom low-rank approximation |
| rff_operator | LowRankUpdate | Random Fourier features approximation |
| center_kernel | MatrixLinearOperator | Center a kernel matrix ($HKH$) |
| centering_operator | LowRankUpdate | Centering matrix $H = I - \mathbf{1}\mathbf{1}^\top / n$ |
| hsic | scalar | Hilbert-Schmidt Independence Criterion |
| mmd_squared | scalar | Maximum Mean Discrepancy (squared) |

1. Background

Nystrom approximation. Given $N$ data points and $M \ll N$ inducing (landmark) points $Z$, the Nystrom method approximates the full kernel matrix as

$$K_{ff} \approx Q_{ff} = K_{fM} K_{MM}^{-1} K_{Mf}$$

This is a rank-$M$ positive semi-definite matrix. Internally, gaussx stores it as a LowRankUpdate with zero base and factor $U = K_{fM} L_{MM}^{-\top}$, where $L_{MM}$ is the Cholesky factor of $K_{MM}$, so that $Q_{ff} = U U^\top$. Adding noise gives $Q_{ff} + \sigma^2 I$, a LowRankUpdate that can be solved in $O(NM^2)$ via the Woodbury identity.
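
In code, the factor construction is just a Cholesky factorization of $K_{MM}$ followed by a triangular solve. A minimal sketch of the idea in plain JAX (nystrom_factor and the jitter term are our own illustration; nystrom_operator handles this internally):

import jax.numpy as jnp
import jax.scipy.linalg as jsl

def nystrom_factor(K_fM, K_MM, jitter=1e-10):
    """Illustrative helper: build U so that Q_ff = U @ U.T."""
    M = K_MM.shape[0]
    L = jnp.linalg.cholesky(K_MM + jitter * jnp.eye(M))  # K_MM = L @ L.T
    # U = K_fM @ inv(L).T, computed via the triangular solve L X = K_fM.T
    return jsl.solve_triangular(L, K_fM.T, lower=True).T  # shape (N, M)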

Random Fourier Features (RFF). Rahimi & Recht (2007) showed that any shift-invariant kernel can be approximated via a finite-dimensional random feature map:

$$k(x, x') \approx \phi(x)^\top \phi(x'), \qquad \phi_j(x) = \sqrt{\frac{2}{D}} \cos(\omega_j^\top x + b_j)$$

where $\omega_j$ is drawn from the spectral density of the kernel and $b_j \sim \mathrm{Uniform}(0, 2\pi)$. For the RBF kernel with lengthscale $\ell$, the spectral density is $\omega \sim \mathcal{N}(0, \ell^{-2} I)$. The feature matrix $\Phi \in \mathbb{R}^{N \times D}$ gives a rank-$D$ approximation $K \approx \Phi \Phi^\top$, again stored as a LowRankUpdate.
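
The feature map itself is a one-liner; a minimal sketch under the conventions above (rff_features is our own hypothetical helper, not part of gaussx):

import jax.numpy as jnp

def rff_features(X, omega, b):
    """Phi[i, j] = sqrt(2/D) * cos(omega_j . x_i + b_j), shape (N, D)."""
    D = omega.shape[0]
    return jnp.sqrt(2.0 / D) * jnp.cos(X @ omega.T + b)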

from __future__ import annotations

import warnings


warnings.filterwarnings("ignore", message=r".*IProgress.*")

import jax
import jax.numpy as jnp
import lineax as lx
import matplotlib.pyplot as plt

import gaussx


jax.config.update("jax_enable_x64", True)

2. Setup: 1D GP regression data

We generate 500 noisy observations from a smooth function and define an RBF kernel helper.

key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

n_train = 500
noise_std = 0.2

f_true = lambda x: jnp.sin(3 * x) * jnp.exp(-0.5 * x**2)

x_train = jnp.sort(jax.random.uniform(k1, (n_train,), minval=-4.0, maxval=4.0))
y_train = f_true(x_train) + noise_std * jax.random.normal(k2, (n_train,))

x_plot = jnp.linspace(-4.5, 4.5, 300)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x_plot, f_true(x_plot), "k--", lw=1.5, label="True function", zorder=4)
ax.scatter(
    x_train,
    y_train,
    s=30,
    c="C0",
    edgecolors="k",
    linewidths=0.5,
    alpha=0.5,
    label="Training data (N=500)",
    zorder=5,
)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend(fontsize=9)
ax.set_title("1D Regression Data")
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()

RBF kernel helper

def rbf_kernel(x1, x2, lengthscale, variance):
    """Squared-exponential kernel matrix."""
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


lengthscale = 0.8
variance = 1.0
noise_var = noise_std**2

# Full kernel matrix for comparison
K_full = rbf_kernel(x_train, x_train, lengthscale, variance)
K_full_op = lx.MatrixLinearOperator(K_full)

3. Nystrom Approximation

We select $M$ inducing points on an evenly spaced grid and build the Nystrom low-rank approximation.

m_inducing = 30
z_inducing = jnp.linspace(-4.0, 4.0, m_inducing)

# Cross-covariance K_XZ (N, M) and inducing kernel K_ZZ (M, M)
K_XZ = rbf_kernel(x_train, z_inducing, lengthscale, variance)
K_ZZ = rbf_kernel(z_inducing, z_inducing, lengthscale, variance)
K_ZZ_op = lx.MatrixLinearOperator(K_ZZ)

# Build the Nystrom operator -> LowRankUpdate
K_nys = gaussx.nystrom_operator(K_XZ, K_ZZ_op)
print(f"Nystrom operator type: {type(K_nys).__name__}")
Nystrom operator type: LowRankUpdate

Approximation quality

We compare the Nystrom approximation against the full kernel using the Frobenius norm error and by overlaying the matrix diagonals.

K_nys_dense = K_nys.as_matrix()
frob_error = jnp.linalg.norm(K_full - K_nys_dense, "fro")
frob_relative = frob_error / jnp.linalg.norm(K_full, "fro")

print(f"Frobenius norm error:    {frob_error:.4f}")
print(f"Relative Frobenius err:  {frob_relative:.4f}")
Frobenius norm error:    0.0000
Relative Frobenius err:  0.0000
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(
    x_train,
    jnp.diag(K_full),
    "k--",
    lw=1.5,
    label="Full kernel diagonal",
    zorder=4,
)
ax.plot(
    x_train,
    jnp.diag(K_nys_dense),
    "--",
    color="C1",
    lw=2,
    label=f"Nystrom diagonal (M={m_inducing})",
    zorder=3,
)
ax.set_xlabel("x")
ax.set_ylabel("k(x, x)")
ax.legend(fontsize=9)
ax.set_title("Diagonal Comparison: Full vs Nystrom")
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()

GP regression with the Nystrom approximation

The system to solve is $(K_{\mathrm{nys}} + \sigma^2 I)\,\alpha = y$. Because $K_{\mathrm{nys}}$ is a LowRankUpdate, gaussx exploits the Woodbury identity automatically, giving $O(NM^2)$ cost instead of $O(N^3)$.
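
Concretely, with the low-rank factor $U \in \mathbb{R}^{N \times M}$ from the Background section, the Woodbury identity reads

$$(\sigma^2 I + U U^\top)^{-1} = \sigma^{-2} I - \sigma^{-2}\, U \left(\sigma^2 I_M + U^\top U\right)^{-1} U^\top,$$

so the only dense factorization is of the $M \times M$ matrix $\sigma^2 I_M + U^\top U$.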

# Build K_nys + sigma^2 I as a LowRankUpdate with diagonal base
K_noisy = gaussx.LowRankUpdate(
    base=lx.DiagonalLinearOperator(jnp.full(n_train, noise_var)),
    U=K_nys.U,
    d=K_nys.d,
    V=K_nys.V,
)

alpha_nys = gaussx.solve(K_noisy, y_train)

# Predictions at test points. The exact cross-covariance K(x_*, X) is used
# for the predictive mean; only the training solve is approximated.
K_star_train = rbf_kernel(x_plot, x_train, lengthscale, variance)
mu_nys = K_star_train @ alpha_nys

# Full GP for reference
K_noisy_full = lx.MatrixLinearOperator(K_full + noise_var * jnp.eye(n_train))
alpha_full = gaussx.solve(K_noisy_full, y_train)
mu_full = K_star_train @ alpha_full
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x_plot, f_true(x_plot), "k--", lw=1.5, label="True function", zorder=4)
ax.plot(x_plot, mu_full, "C0-", lw=2, label="Full GP", zorder=3)
ax.plot(x_plot, mu_nys, "C1--", lw=2, label=f"Nystrom GP (M={m_inducing})", zorder=3)
ax.scatter(
    x_train,
    y_train,
    s=30,
    c="grey",
    edgecolors="k",
    linewidths=0.5,
    alpha=0.3,
    zorder=5,
)
ax.scatter(
    z_inducing,
    jnp.zeros_like(z_inducing),
    marker="|",
    s=100,
    c="C1",
    zorder=5,
    label="Inducing points",
)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend(fontsize=9)
ax.set_title("GP Predictions: Full vs Nystrom")
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()

4. Random Fourier Features

For the RBF kernel with lengthscale $\ell$, we sample $\omega \sim \mathcal{N}(0, \ell^{-2} I)$ and $b \sim \mathrm{Uniform}(0, 2\pi)$.

k_rff = jax.random.PRNGKey(99)
D_rff = 500

k_omega, k_b = jax.random.split(k_rff)
omega = jax.random.normal(k_omega, (D_rff, 1)) / lengthscale  # shape (D_rff, input_dim)
b = jax.random.uniform(k_b, (D_rff,), minval=0.0, maxval=2.0 * jnp.pi)

# X must be (N, input_dim) -- reshape 1D inputs
X_2d = x_train[:, None]  # (N, 1)

K_rff = gaussx.rff_operator(X_2d, omega, b)
print(f"RFF operator type: {type(K_rff).__name__}")
RFF operator type: LowRankUpdate
K_rff_dense = K_rff.as_matrix()
frob_rff = jnp.linalg.norm(K_full - K_rff_dense, "fro")
frob_rff_rel = frob_rff / jnp.linalg.norm(K_full, "fro")

print(f"RFF Frobenius error (D={D_rff}): {frob_rff:.4f}")
print(f"RFF relative error:              {frob_rff_rel:.4f}")
RFF Frobenius error (D=500): 24.3464
RFF relative error:              0.1189

Convergence with increasing feature dimension

As $D_{\mathrm{rff}}$ grows, the approximation error decreases at the Monte Carlo rate $O(1/\sqrt{D})$.

D_values = [100, 500, 1000]
errors = []

for D in D_values:
    k_o, k_bb = jax.random.split(jax.random.PRNGKey(D))
    om = jax.random.normal(k_o, (D, 1)) / lengthscale
    bb = jax.random.uniform(k_bb, (D,), minval=0.0, maxval=2.0 * jnp.pi)
    K_approx = gaussx.rff_operator(X_2d, om, bb).as_matrix()
    err = jnp.linalg.norm(K_full - K_approx, "fro") / jnp.linalg.norm(K_full, "fro")
    errors.append(float(err))
    print(f"  D_rff = {D:5d}  |  relative error = {err:.4f}")
  D_rff =   100  |  relative error = 0.3016
  D_rff =   500  |  relative error = 0.0947
  D_rff =  1000  |  relative error = 0.0488
fig, ax = plt.subplots(figsize=(10, 4))
ax.loglog(D_values, errors, "o-", color="C2", lw=2)
ax.set_xlabel("Number of random features (D)")
ax.set_ylabel("Relative Frobenius error")
ax.set_title("RFF Convergence")
ax.grid(True, which="major", alpha=0.3)
ax.grid(True, which="minor", alpha=0.1)
ax.minorticks_on()
plt.tight_layout()
plt.show()

5. Kernel Centering

Centering a kernel matrix is required for kernel PCA and many other kernel methods. The centered kernel is $\tilde{K} = H K H$, where $H = I - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$ is the centering matrix.
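
Because $H$ is the identity plus a rank-1 update, it never needs to be materialized: applying it to a vector is just mean subtraction. A minimal sketch (center_vec is our own hypothetical helper; gaussx presumably applies $H$ in this matrix-free fashion):

def center_vec(v):
    """H @ v = v - (1/n) * 1 * (1^T v) = v - mean(v), in O(n)."""
    return v - v.mean()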

n = n_train

# Build the centering operator H = I - (1/n) 11^T
H_op = gaussx.centering_operator(n)
print(f"Centering operator type: {type(H_op).__name__}")

H_dense = H_op.as_matrix()
print(f"H shape: {H_dense.shape}")

# Verify it is I - (1/n) 11^T
H_manual = jnp.eye(n) - jnp.ones((n, n)) / n
print(f"H matches manual: {jnp.allclose(H_dense, H_manual, atol=1e-12)}")
Centering operator type: LowRankUpdate
H shape: (500, 500)
H matches manual: True

Center the kernel matrix

K_centered_op = gaussx.center_kernel(K_full_op)
K_centered = K_centered_op.as_matrix()

# Manual centering: HKH
K_centered_manual = H_manual @ K_full @ H_manual

print(f"Centered kernel type: {type(K_centered_op).__name__}")
print(f"Matches manual HKH:  {jnp.allclose(K_centered, K_centered_manual, atol=1e-10)}")
print(f"Centered row sums ~ 0: {jnp.abs(K_centered.sum(axis=1)).max():.2e}")
Centered kernel type: MatrixLinearOperator
Matches manual HKH:  True
Centered row sums ~ 0: 8.26e-14

6. HSIC: Independence Testing

The Hilbert-Schmidt Independence Criterion measures dependence between two random variables via their kernel embeddings:

$$\mathrm{HSIC}(F, G) = \frac{1}{n^2} \operatorname{tr}(K_f H K_g H)$$

$\mathrm{HSIC} = 0$ if and only if $F$ and $G$ are independent (in the population, with characteristic kernels). We compare a correlated pair $(x,\; y = x + \epsilon)$ with an independent pair $(x, z)$.
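
As a sanity check, the estimator above can be computed densely in a few lines (hsic_dense is our own reference sketch, assuming the biased $1/n^2$ estimator; gaussx.hsic may organize the computation differently):

def hsic_dense(K_f, K_g):
    """Biased empirical HSIC: tr(K_f H K_g H) / n^2."""
    n = K_f.shape[0]
    H = jnp.eye(n) - jnp.ones((n, n)) / n  # centering matrix
    return jnp.trace(K_f @ H @ K_g @ H) / n**2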

key_hsic = jax.random.PRNGKey(7)
k1, k2, k3 = jax.random.split(key_hsic, 3)

n_hsic = 200
x_hsic = jax.random.normal(k1, (n_hsic,))

# Correlated pair: y depends on x
y_corr = x_hsic + 0.3 * jax.random.normal(k2, (n_hsic,))

# Independent pair: z is independent of x
z_indep = jax.random.normal(k3, (n_hsic,))

# Build kernel matrices (RBF, lengthscale=1)
K_x = rbf_kernel(x_hsic, x_hsic, 1.0, 1.0)
K_y_corr = rbf_kernel(y_corr, y_corr, 1.0, 1.0)
K_z_indep = rbf_kernel(z_indep, z_indep, 1.0, 1.0)

K_x_op = lx.MatrixLinearOperator(K_x)
K_y_op = lx.MatrixLinearOperator(K_y_corr)
K_z_op = lx.MatrixLinearOperator(K_z_indep)
hsic_corr = gaussx.hsic(K_x_op, K_y_op)
hsic_indep = gaussx.hsic(K_x_op, K_z_op)

print(f"HSIC (x, y_correlated):  {hsic_corr:.6f}   (should be > 0)")
print(f"HSIC (x, z_independent): {hsic_indep:.6f}   (should be ~ 0)")
HSIC (x, y_correlated):  0.055927   (should be > 0)
HSIC (x, z_independent): 0.002026   (should be ~ 0)

7. MMD: Two-Sample Testing

The Maximum Mean Discrepancy measures the distance between two distributions in a reproducing kernel Hilbert space:

$$\mathrm{MMD}^2(P, Q) = \mathbb{E}[k(x, x')] + \mathbb{E}[k(y, y')] - 2\,\mathbb{E}[k(x, y)]$$

We compare $P = \mathcal{N}(0, 1)$ vs. $Q = \mathcal{N}(1, 1)$ (different distributions) and $P$ vs. $P$ (same distribution).
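
The plug-in (biased) estimator of this quantity is just three Gram-matrix means; a minimal reference sketch (mmd2_dense is our own helper, and we assume the biased V-statistic -- gaussx.mmd_squared may use the unbiased variant):

def mmd2_dense(K_xx, K_yy, K_xy):
    """Biased MMD^2 estimate: mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)."""
    return K_xx.mean() + K_yy.mean() - 2.0 * K_xy.mean()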

key_mmd = jax.random.PRNGKey(13)
k1, k2, k3 = jax.random.split(key_mmd, 3)

n_mmd = 200
x_p = jax.random.normal(k1, (n_mmd,))  # P = N(0, 1)
y_q = jax.random.normal(k2, (n_mmd,)) + 1.0  # Q = N(1, 1)
x_p2 = jax.random.normal(k3, (n_mmd,))  # Another sample from P

# Kernel matrices (raw arrays, not operators)
ls_mmd = 1.0
K_xx = rbf_kernel(x_p, x_p, ls_mmd, 1.0)
K_yy = rbf_kernel(y_q, y_q, ls_mmd, 1.0)
K_xy = rbf_kernel(x_p, y_q, ls_mmd, 1.0)

K_xx2 = K_xx  # Gram matrix of the first sample from P (identical to K_xx above)
K_pp2 = rbf_kernel(x_p2, x_p2, ls_mmd, 1.0)
K_xp2 = rbf_kernel(x_p, x_p2, ls_mmd, 1.0)
mmd_diff = gaussx.mmd_squared(K_xx, K_yy, K_xy)
mmd_same = gaussx.mmd_squared(K_xx2, K_pp2, K_xp2)

print(f"MMD^2(P, Q) [different]: {mmd_diff:.6f}   (should be > 0)")
print(f"MMD^2(P, P) [same]:      {mmd_same:.6f}   (should be ~ 0)")
MMD^2(P, Q) [different]: 0.148538   (should be > 0)
MMD^2(P, P) [same]:      0.004874   (should be ~ 0)

8. Plots: Approximation Quality

A side-by-side comparison of the kernel matrix approximations.

fig, axes = plt.subplots(1, 3, figsize=(12, 3.5))

vmin, vmax = K_full.min(), K_full.max()
kw = dict(vmin=vmin, vmax=vmax, cmap="viridis", aspect="auto")

axes[0].imshow(K_full, **kw)
axes[0].set_title("Full Kernel")

axes[1].imshow(K_nys_dense, **kw)
axes[1].set_title(f"Nystrom (M={m_inducing})")

axes[2].imshow(K_rff_dense, **kw)
axes[2].set_title(f"RFF (D={D_rff})")

for ax in axes:
    ax.set_xlabel("Index")
    ax.set_ylabel("Index")

plt.tight_layout()
plt.show()

9. Summary

| Function | Output | Cost |
| --- | --- | --- |
| nystrom_operator(K_XZ, K_ZZ_op) | LowRankUpdate | $O(NM^2)$ |
| rff_operator(X, omega, b) | LowRankUpdate | $O(ND)$ |
| center_kernel(K) | MatrixLinearOperator | $O(N^2)$ |
| centering_operator(n) | LowRankUpdate | $O(1)$ |
| hsic(K_f, K_g) | scalar | $O(N^2)$ |
| mmd_squared(K_xx, K_yy, K_xy) | scalar | $O(N^2)$ |

Key takeaways:

  • Both Nystrom and RFF produce LowRankUpdate operators, so downstream solves and log-determinants exploit the Woodbury identity automatically.
  • center_kernel and centering_operator handle the centering boilerplate needed by kernel PCA and HSIC.
  • hsic and mmd_squared are one-line kernel statistics for independence and two-sample testing.