
GP on a 2D Grid with Kronecker Structure

When data lies on a regular grid and the kernel is separable across input dimensions, the kernel matrix factorizes as $K = K_x \otimes K_y$. This enables GP inference on grids with thousands of points that would be intractable with a dense kernel.

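This example fits a 2D GP to noisy observations on a grid. The kernel is built as a Kronecker operator. The noisy solve and log-marginal likelihood currently fall back to a dense matrix, while the log-determinant, trace, and Cholesky factor of the noise-free kernel use the Kronecker structure directly.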

Background

When inputs lie on a Cartesian product grid $\mathcal{X} = \{x_1, \ldots, x_{n_1}\} \times \{y_1, \ldots, y_{n_2}\}$ and the kernel is separable,

$$k\bigl((x,y),(x',y')\bigr) = k_x(x,x')\, k_y(y,y'),$$

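the kernel matrix factorizes as a Kronecker product $K = K_x \otimes K_y$, where $K_x \in \mathbb{R}^{n_1 \times n_1}$ and $K_y \in \mathbb{R}^{n_2 \times n_2}$. The factorization assumes the grid points are ordered with $x$ as the outer (slow) index and $y$ as the inner (fast) index, which is what row-major flattening of a `meshgrid(..., indexing="ij")` array produces. With $N = n_1 n_2$, this reduces: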

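  • Storage from $O(N^2)$ to $O(n_1^2 + n_2^2)$.
  • Solve from $O(N^3)$ to $O(n_1^3 + n_2^3 + N(n_1 + n_2))$: each factor is decomposed once, then applied along its own axis of the reshaped right-hand side.
  • Log-determinant via $\log|K_x \otimes K_y| = n_2 \log|K_x| + n_1 \log|K_y|$, which needs only the two factor decompositions.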

These savings make GP inference feasible on grids with thousands of points that would be intractable with a dense kernel matrix. See Saatci (2012) for a comprehensive treatment and Gilboa et al. (2015) for extensions to higher-dimensional grids.
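The log-determinant identity and the structured solve can be checked numerically with plain `jax.numpy` on small factors. This is a minimal sketch, not the gaussx implementation: it assumes small RBF factors `A` and `B` (sizes chosen arbitrarily and deliberately unequal, with a `1e-2` jitter so the dense reference is well conditioned), and it solves by applying each factor's Cholesky solve along one axis of the reshaped right-hand side.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def rbf_1d(coords, lengthscale, variance):
    # Same 1D RBF kernel as used later in this example.
    sq_dist = (coords[:, None] - coords[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


# Small, unequal factor sizes; the jitter keeps the dense reference well conditioned.
n1, n2 = 6, 5
A = rbf_1d(jnp.linspace(0.0, 4.0, n1), 1.0, 1.0) + 1e-2 * jnp.eye(n1)
B = rbf_1d(jnp.linspace(0.0, 4.0, n2), 1.0, 1.0) + 1e-2 * jnp.eye(n2)
K = jnp.kron(A, B)  # dense (n1*n2) x (n1*n2) reference

# Log-determinant identity: log|A kron B| = n2 log|A| + n1 log|B|.
ld_structured = n2 * jnp.linalg.slogdet(A)[1] + n1 * jnp.linalg.slogdet(B)[1]
ld_dense = jnp.linalg.slogdet(K)[1]

# Structured solve. With row-major flattening, (A kron B) vec(X) = vec(A X B^T),
# so the solution of (A kron B) x = b is vec(A^{-1} R B^{-1}) with R = b.reshape(n1, n2)
# (B is symmetric, so B^{-T} = B^{-1}).
b = jax.random.normal(jax.random.PRNGKey(0), (n1 * n2,))
R = b.reshape(n1, n2)
cA, cB = cho_factor(A), cho_factor(B)  # O(n1^3 + n2^3)
AinvR = cho_solve(cA, R)  # A^{-1} R, O(n1^2 n2)
X = cho_solve(cB, AinvR.T).T  # (A^{-1} R) B^{-1}, O(n1 n2^2)
x_structured = X.ravel()
x_dense = jnp.linalg.solve(K, b)

print(f"logdet structured={ld_structured:.6f}  dense={ld_dense:.6f}")
print(f"max solve difference: {jnp.max(jnp.abs(x_structured - x_dense)):.1e}")
```

Using unequal sizes $n_1 \neq n_2$ is a cheap way to catch transposition or ordering mistakes that a square grid would hide.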

from __future__ import annotations

import warnings


warnings.filterwarnings("ignore", message=r".*IProgress.*")

import jax
import jax.numpy as jnp
import lineax as lx
import matplotlib.pyplot as plt

import gaussx


jax.config.update("jax_enable_x64", True)

Generate 2D grid data

# Grid
nx, ny = 15, 15
N = nx * ny
x = jnp.linspace(0, 4, nx)
y = jnp.linspace(0, 4, ny)
xx, yy = jnp.meshgrid(x, y, indexing="ij")


# True function: a smooth 2D surface
def f_true(x, y):
    return jnp.sin(x) * jnp.cos(1.5 * y)


z_true = f_true(xx, yy)

# Noisy observations
noise_std = 0.15
key = jax.random.PRNGKey(7)
z_obs = z_true + noise_std * jax.random.normal(key, z_true.shape)

print(f"Grid: {nx}x{ny} = {N} points")
Grid: 15x15 = 225 points
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
im0 = axes[0].pcolormesh(xx, yy, z_true, cmap="RdBu_r", shading="auto")
axes[0].set_title("True function")
plt.colorbar(im0, ax=axes[0])

im1 = axes[1].pcolormesh(xx, yy, z_obs, cmap="RdBu_r", shading="auto")
axes[1].set_title(f"Noisy observations ($\\sigma$={noise_std})")
plt.colorbar(im1, ax=axes[1])

for ax in axes:
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_aspect("equal")

plt.tight_layout()
plt.show()
[Figure: two panels, "True function" and "Noisy observations (σ=0.15)", plotted over x and y.]

Build Kronecker kernel

Separable RBF: $k((x_1,y_1), (x_2,y_2)) = k_x(x_1,x_2) \cdot k_y(y_1,y_2)$

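Any product of 1D kernels is separable by construction, so products of 1D RBF, Matérn, or periodic kernels all yield Kronecker structure on a grid. The RBF is a convenient special case: the 2D RBF with one lengthscale per axis is exactly such a product. A Matérn kernel evaluated on the 2D Euclidean distance, by contrast, is not a product of 1D kernels and does not factorize. For non-separable kernels, sums of Kronecker products $\sum_r K_x^{(r)} \otimes K_y^{(r)}$ can be used as approximations.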
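One way to confirm both separability and the point ordering is to build the dense 2D kernel from flattened grid coordinates and compare it with `jnp.kron`. This is a sketch under stated assumptions: `rbf_2d_ard` is a hypothetical helper (an RBF with one lengthscale per axis, evaluated densely), and the small grid and lengthscales are arbitrary. The swapped factor order is included to show that the order of the Kronecker factors must match the flattening order.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf_1d(coords, lengthscale, variance):
    sq_dist = (coords[:, None] - coords[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def rbf_2d_ard(points, ls_x, ls_y, variance):
    # Hypothetical helper: dense 2D RBF with one lengthscale per axis, on (N, 2) points.
    diff = points[:, None, :] - points[None, :, :]
    scaled_sq = (diff[..., 0] / ls_x) ** 2 + (diff[..., 1] / ls_y) ** 2
    return variance * jnp.exp(-0.5 * scaled_sq)


nx, ny = 4, 3
x = jnp.linspace(0, 4, nx)
y = jnp.linspace(0, 4, ny)
xx, yy = jnp.meshgrid(x, y, indexing="ij")
# Row-major flattening of "ij" grids: x is the slow index, y the fast one.
points = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)

K_dense = rbf_2d_ard(points, 1.0, 0.5, 1.0)
K_kron = jnp.kron(rbf_1d(x, 1.0, 1.0), rbf_1d(y, 0.5, 1.0))
K_swapped = jnp.kron(rbf_1d(y, 0.5, 1.0), rbf_1d(x, 1.0, 1.0))

print(jnp.allclose(K_dense, K_kron))  # True: factor order matches point order
print(jnp.allclose(K_dense, K_swapped))  # False: factors in the wrong order
```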

def rbf_1d(coords, lengthscale, variance):
    sq_dist = (coords[:, None] - coords[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


ls_x, ls_y = 1.0, 1.0
var_x, var_y = 1.0, 1.0
noise_var = noise_std**2

Kx = rbf_1d(x, ls_x, var_x)
Ky = rbf_1d(y, ls_y, var_y)

Kx_op = lx.MatrixLinearOperator(Kx, lx.positive_semidefinite_tag)
Ky_op = lx.MatrixLinearOperator(Ky, lx.positive_semidefinite_tag)

# Noise-free kernel Kx kron Ky as a structured operator (noise is added in the next step)
K_kron = gaussx.Kronecker(Kx_op, Ky_op)

print(f"Kronecker kernel: {K_kron.in_size()}x{K_kron.out_size()}")
print(f"Storage: {nx}x{nx} + {ny}x{ny} = {nx**2 + ny**2}")
print(f"Dense would be: {N}x{N} = {N**2:,}")
Kronecker kernel: 225x225
Storage: 15x15 + 15x15 = 450
Dense would be: 225x225 = 50,625

Add noise and solve

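We need to solve $(K_x \otimes K_y + \sigma^2 I)\,\alpha = z_\mathrm{obs}$. For now we add the noise to the dense matrix; a full implementation would keep this structured, for example with the per-factor eigendecomposition discussed in the log-marginal likelihood section.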

K_noisy = K_kron.as_matrix() + noise_var * jnp.eye(N)
op = lx.MatrixLinearOperator(K_noisy, lx.positive_semidefinite_tag)

# Flatten observations for solve
z_vec = z_obs.ravel()
alpha = gaussx.solve(op, z_vec)

# Posterior mean on the same grid
z_pred = (K_kron.as_matrix() @ alpha).reshape(nx, ny)

print(f"Solve residual: {jnp.max(jnp.abs(op.mv(alpha) - z_vec)):.2e}")
Solve residual: 3.15e-14
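The posterior-mean line above still materializes the dense $225 \times 225$ kernel. A structured alternative applies each factor along one axis of the reshaped weights, using $(A \otimes B)\,\mathrm{vec}(V) = \mathrm{vec}(A V B^\top)$ for row-major flattening; the same identity with rectangular cross-covariance factors gives predictions on a finer grid. This is a minimal sketch in plain `jax.numpy`: `kron_mv` and `rbf_cross` are hypothetical helpers, the 40 by 30 test grid is arbitrary, and `alpha` is recomputed densely here only so the block runs on its own.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf_cross(a, b, lengthscale, variance):
    # Hypothetical helper: 1D RBF cross-covariance between coordinates a (m,) and b (n,).
    sq_dist = (a[:, None] - b[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def kron_mv(A, B, v):
    # Hypothetical helper: (A kron B) @ v for a row-major v, without forming A kron B.
    # A is (p, n1), B is (q, n2); cost O(p n1 n2 + p q n2) instead of O(p q n1 n2).
    V = v.reshape(A.shape[1], B.shape[1])
    return (A @ V @ B.T).ravel()


# Same data and hyperparameters as the example above.
nx, ny, noise_std = 15, 15, 0.15
x = jnp.linspace(0, 4, nx)
y = jnp.linspace(0, 4, ny)
xx, yy = jnp.meshgrid(x, y, indexing="ij")
z_true = jnp.sin(xx) * jnp.cos(1.5 * yy)
z_obs = z_true + noise_std * jax.random.normal(jax.random.PRNGKey(7), z_true.shape)

Kx = rbf_cross(x, x, 1.0, 1.0)
Ky = rbf_cross(y, y, 1.0, 1.0)
K_noisy = jnp.kron(Kx, Ky) + noise_std**2 * jnp.eye(nx * ny)
alpha = jnp.linalg.solve(K_noisy, z_obs.ravel())  # dense, as in the example

# Posterior mean on the training grid: (Kx kron Ky) alpha.
mean_train = kron_mv(Kx, Ky, alpha).reshape(nx, ny)
mean_dense = (jnp.kron(Kx, Ky) @ alpha).reshape(nx, ny)

# Posterior mean on a finer grid: the cross-covariance is K_{x*,x} kron K_{y*,y}.
xs = jnp.linspace(0, 4, 40)
ys = jnp.linspace(0, 4, 30)
Kxs, Kys = rbf_cross(xs, x, 1.0, 1.0), rbf_cross(ys, y, 1.0, 1.0)
mean_test = kron_mv(Kxs, Kys, alpha).reshape(len(xs), len(ys))

print(f"max |structured - dense| on training grid: {jnp.max(jnp.abs(mean_train - mean_dense)):.1e}")
print(f"test-grid mean shape: {mean_test.shape}")
```

The cross-covariance factors are only $40 \times 15$ and $30 \times 15$, so predicting on the finer grid never requires the $1200 \times 225$ dense cross-covariance.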

Log-marginal likelihood

The log-marginal likelihood (LML) for GP regression is

$$\log p(\mathbf{z} \mid X) = -\tfrac{1}{2}\mathbf{z}^\top (K + \sigma^2 I)^{-1}\mathbf{z} - \tfrac{1}{2}\log|K + \sigma^2 I| - \tfrac{N}{2}\log 2\pi.$$

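Note that the noise term $\sigma^2 I$ breaks the pure Kronecker structure of $K_x \otimes K_y$. It can still be handled exactly with per-factor eigendecompositions (Saatci, 2012). Writing $K_x = Q_x \Lambda_x Q_x^\top$ and $K_y = Q_y \Lambda_y Q_y^\top$,

$$K_x \otimes K_y + \sigma^2 I = (Q_x \otimes Q_y)(\Lambda_x \otimes \Lambda_y + \sigma^2 I)(Q_x \otimes Q_y)^\top,$$

so the solve needs only rotations by the small factors and a diagonal division, and the log-determinant is $\sum_{i,j} \log(\lambda^{x}_i \lambda^{y}_j + \sigma^2)$. Krylov methods such as conjugate gradients are an alternative that needs only fast Kronecker matrix-vector products.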

def lml(z, op):
    alpha = gaussx.solve(op, z)
    n = len(z)
    ld = gaussx.logdet(op)
    return -0.5 * jnp.dot(z, alpha) - 0.5 * ld - 0.5 * n * jnp.log(2 * jnp.pi)


print(f"Log-marginal likelihood: {lml(z_vec, op):.2f}")
Log-marginal likelihood: 55.04
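For comparison with the dense `lml` above, here is a minimal sketch of the eigendecomposition approach (Saatci, 2012) for $K_x \otimes K_y + \sigma^2 I$, written in plain `jax.numpy` rather than gaussx. `kron_eig_lml` is a hypothetical helper: it eigendecomposes each factor once, rotates the reshaped data into the joint eigenbasis, divides by the eigenvalues $\lambda^x_i \lambda^y_j + \sigma^2$, and rotates back. The dense computation at the end is only a check.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf_1d(coords, lengthscale, variance):
    sq_dist = (coords[:, None] - coords[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def kron_eig_lml(Kx, Ky, noise_var, z):
    """Hypothetical helper: LML and alpha for N(z | 0, Kx kron Ky + noise_var * I)."""
    n1, n2 = Kx.shape[0], Ky.shape[0]
    lam_x, Qx = jnp.linalg.eigh(Kx)  # O(n1^3)
    lam_y, Qy = jnp.linalg.eigh(Ky)  # O(n2^3)
    # Eigenvalues of Kx kron Ky + s*I, laid out to match row-major flattening.
    evals = jnp.outer(lam_x, lam_y) + noise_var  # (n1, n2)
    Z = z.reshape(n1, n2)
    Z_rot = Qx.T @ Z @ Qy  # (Qx kron Qy)^T z
    alpha = (Qx @ (Z_rot / evals) @ Qy.T).ravel()  # rotate back after the diagonal solve
    logdet = jnp.sum(jnp.log(evals))
    N = n1 * n2
    lml = -0.5 * jnp.dot(z, alpha) - 0.5 * logdet - 0.5 * N * jnp.log(2 * jnp.pi)
    return lml, alpha


# Same data and hyperparameters as the example above.
nx, ny, noise_std = 15, 15, 0.15
x = jnp.linspace(0, 4, nx)
y = jnp.linspace(0, 4, ny)
xx, yy = jnp.meshgrid(x, y, indexing="ij")
z_true = jnp.sin(xx) * jnp.cos(1.5 * yy)
z_vec = (z_true + noise_std * jax.random.normal(jax.random.PRNGKey(7), z_true.shape)).ravel()

Kx, Ky = rbf_1d(x, 1.0, 1.0), rbf_1d(y, 1.0, 1.0)
lml_kron, alpha_kron = kron_eig_lml(Kx, Ky, noise_std**2, z_vec)

# Dense check, O(N^3); only feasible because N = 225 here.
K_noisy = jnp.kron(Kx, Ky) + noise_std**2 * jnp.eye(nx * ny)
alpha_dense = jnp.linalg.solve(K_noisy, z_vec)
lml_dense = (
    -0.5 * jnp.dot(z_vec, alpha_dense)
    - 0.5 * jnp.linalg.slogdet(K_noisy)[1]
    - 0.5 * nx * ny * jnp.log(2 * jnp.pi)
)
print(f"LML eig/Kronecker: {lml_kron:.4f}   dense: {lml_dense:.4f}")
```

The two values should agree up to round-off, while the structured path costs $O(n_1^3 + n_2^3 + N(n_1 + n_2))$ instead of $O(N^3)$. Because every eigenvalue is shifted by $\sigma^2$, the ill-conditioning of the noise-free factors does not affect this computation.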

Visualize results

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

im0 = axes[0].pcolormesh(xx, yy, z_obs, cmap="RdBu_r", shading="auto")
axes[0].set_title("Observations")
plt.colorbar(im0, ax=axes[0])

im1 = axes[1].pcolormesh(xx, yy, z_pred, cmap="RdBu_r", shading="auto")
axes[1].set_title("GP posterior mean")
plt.colorbar(im1, ax=axes[1])

im2 = axes[2].pcolormesh(xx, yy, jnp.abs(z_true - z_pred), cmap="Reds", shading="auto")
axes[2].set_title("|True - Predicted|")
plt.colorbar(im2, ax=axes[2])

for ax in axes:
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_aspect("equal")

plt.tight_layout()
plt.show()
[Figure: three panels, "Observations", "GP posterior mean", and "|True - Predicted|", plotted over x and y.]

Kronecker-only operations

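Even though we fell back to a dense matrix for the noisy solve, the Kronecker structure still makes operations on the noise-free kernel $K_x \otimes K_y$ cheap: the log-determinant, trace, and Cholesky factor below are all derived from the two $15 \times 15$ factors. The two log-determinants disagree, most likely because $K_x \otimes K_y$ without noise is severely ill-conditioned (an RBF with lengthscale 1.0 on a grid spacing of about 0.29), so many of its eigenvalues fall below floating-point round-off; the dense `slogdet` of the $225 \times 225$ matrix is likely the more affected of the two. The trace, which depends only on the diagonal, agrees exactly.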

print("logdet(Kx kron Ky):")
print(f"  Structured: {gaussx.logdet(K_kron):.4f}")
print(f"  Dense:      {jnp.linalg.slogdet(K_kron.as_matrix())[1]:.4f}")

print("\ntrace(Kx kron Ky):")
print(f"  Structured: {gaussx.trace(K_kron):.4f}")
print(f"  Dense:      {jnp.trace(K_kron.as_matrix()):.4f}")

L_kron = gaussx.cholesky(K_kron)
print(f"\ncholesky type: {type(L_kron).__name__}")
logdet(Kx kron Ky):
  Structured: -4064.3464
  Dense:      -3904.5906

trace(Kx kron Ky):
  Structured: 225.0000
  Dense:      225.0000

cholesky type: Kronecker
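The trace and Cholesky results above follow from two standard Kronecker identities, $\operatorname{tr}(A \otimes B) = \operatorname{tr}(A)\operatorname{tr}(B)$ and $\operatorname{chol}(A \otimes B) = \operatorname{chol}(A) \otimes \operatorname{chol}(B)$. The second holds because a Kronecker product of lower-triangular factors with positive diagonals is itself lower triangular with a positive diagonal, so by uniqueness it is the Cholesky factor. The sketch below checks both in plain `jax.numpy` and prints the condition number of one factor, which bears on the log-determinant mismatch. The `1e-2` jitter in the Cholesky check is an assumption added so the factors are safely positive definite; without it, Cholesky of these ill-conditioned RBF matrices may break down.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf_1d(coords, lengthscale, variance):
    sq_dist = (coords[:, None] - coords[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


n = 15
coords = jnp.linspace(0, 4, n)  # x and y use the same coordinates in this example
Kx = rbf_1d(coords, 1.0, 1.0)
Ky = rbf_1d(coords, 1.0, 1.0)

# Trace identity: tr(Kx kron Ky) = tr(Kx) * tr(Ky).
tr_structured = jnp.trace(Kx) * jnp.trace(Ky)
tr_dense = jnp.trace(jnp.kron(Kx, Ky))

# cond(Kx kron Ky) = cond(Kx) * cond(Ky): the noise-free grid kernel compounds
# the ill-conditioning of each 1D factor.
cond_x = jnp.linalg.cond(Kx)

# Cholesky identity on jittered factors (the jitter is an assumption for this check).
jitter = 1e-2
A = Kx + jitter * jnp.eye(n)
B = Ky + jitter * jnp.eye(n)
L_structured = jnp.kron(jnp.linalg.cholesky(A), jnp.linalg.cholesky(B))
L_dense = jnp.linalg.cholesky(jnp.kron(A, B))

print(f"trace structured={tr_structured:.4f}  dense={tr_dense:.4f}")
print(f"cond(Kx) = {cond_x:.1e}")
print(f"max |L_structured - L_dense| = {jnp.max(jnp.abs(L_structured - L_dense)):.1e}")
```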

References

  • Gilboa, E., Saatci, Y., & Cunningham, J. P. (2015). Scaling multidimensional inference for structured Gaussian processes. IEEE TPAMI, 37(2), 424–436.
  • Rasmussen, C. E. & Williams, C. K. I. (2006). Gaussian Processes for Machine Learning. MIT Press.
  • Saatci, Y. (2012). Scalable Inference for Structured Gaussian Process Models. PhD thesis, University of Cambridge.