# Efficient machinery for SSH pathwise sampling — a tour of gaussx + pyrox
The companion note 00_ssh_pathwise_sampling.md writes the algorithm out as if you were going to implement it from scratch — dense Gram matrices, dense Cholesky, dense cross-covariance, hand-rolled RFF prior. That description makes the math obvious but it carries three concrete costs that get unpleasant the moment you push to realistic Mediterranean grids:
| Bottleneck in the naive write-up | Why it hurts |
|---|---|
| Materialise $K_{X_*,X}$ | A dense $n \times m$ block at SSH-realistic $n$ and $m$ — the dominant memory term |
| Cholesky of $C = K_{XX} + \sigma^2 I$ | $O(m^3)$ flops — single-threaded LAPACK, ~40 s at SSH-realistic $m$ |
| Per-sample matvec $K_{X_*,X}\,\beta$, repeated $S$ times | $O(nm)$ flops per sample, plus the same $O(nm)$ buffer reused $S$ times |
This note walks through the gaussx and pyrox primitives that fix each of those, with pseudocode close enough to real Python that you could turn it into runnable code in an afternoon. The point is not to refactor the SSH pipeline today — it is to leave a paper trail of which existing primitive solves which step so future-me (or future-anyone) can pick the right tool without re-deriving the algorithm.
Throughout, the four “objects” from the 00 derivation keep their names: the Gram matrix $K_{XX}$, the noisy Gram $C = K_{XX} + \sigma^2 I$, the cross-covariance $K_{X_*,X}$, and the prior path $\tilde f$.
## Machinery 1 — matrix-free kernels (kills the $O(nm)$ peak)
The cross-covariance $K_{X_*,X}$ is only ever multiplied against vectors in the algorithm: once for the posterior mean ($K_{X_*,X}\,\alpha$), and once per sample for the correction ($K_{X_*,X}\,\beta$). There is no reason to allocate the dense $n \times m$ block.
### The primitive
gaussx.ImplicitCrossKernelOperator wraps any pure-function kernel k(x_i, z_j) -> scalar as a lineax.AbstractLinearOperator whose .mv(v) evaluates

$$(K_{X_*,X}\,v)_i \;=\; \sum_{j=1}^{m} k(x^*_i,\, x_j)\, v_j$$

via jax.lax.scan over rows of $X_*$, in batches of size batch_size. Peak memory per matvec is $O(\text{batch\_size} \times m)$ — set batch_size to fit in cache and you are done. It carries a custom_jvp so autodiff through hyperparameters stays cheap (vmap-vectorised JVP, transposes cleanly into a VJP).
The square sibling, gaussx.ImplicitKernelOperator, does the same trick for $K_{XX}$ and accepts an optional noise_var=σ² so the noisy Gram $C$ is one operator with no extra allocation.
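For intuition, the sketch below shows roughly the scan-and-batch pattern the cross operator implements, hand-rolled in plain JAX. It is not gaussx's actual implementation: the function name is mine, it assumes $n$ is a multiple of batch_size, and it omits the custom JVP mentioned above.

```python
import jax
import jax.numpy as jnp

def matrix_free_cross_mv(kernel_fn, X_grid, X_train, v, batch_size=4096):
    """Compute K_{X*,X} @ v without ever holding the full (n, m) block.

    Sketch only: assumes n is a multiple of batch_size; the real operator
    also handles ragged batches and differentiates through hyperparameters.
    """
    n = X_grid.shape[0]
    X_batched = X_grid.reshape(n // batch_size, batch_size, -1)

    def one_batch(_, X_batch):
        # (batch_size, m) kernel block — the only dense buffer alive at any time
        K_block = jax.vmap(lambda xi: jax.vmap(lambda xj: kernel_fn(xi, xj))(X_train))(X_batch)
        return None, K_block @ v

    _, chunks = jax.lax.scan(one_batch, None, X_batched)
    return chunks.reshape(n)
```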
### Pseudocode
```python
import gaussx as gx
import lineax as lx
def k_st(x, z, sigma2, Ls, Lt):
"""One scalar evaluation of σ²·k_s·k_t. Pure function — no batching."""
d_space = great_circle(x[:2], z[:2])
d_time = jnp.abs(x[2] - z[2])
spatial = matern32(d_space, Ls)
temporal = jnp.exp(-d_time / Lt)
return sigma2 * spatial * temporal
# Noisy Gram C = K_XX + σ²I — never materialised.
C_op = gx.ImplicitKernelOperator(
kernel_fn = lambda xi, xj: k_st(xi, xj, sigma2, Ls, Lt),
X = X_train, # (m, 3)
noise_var = sigma_obs ** 2,
tags = lx.positive_semidefinite_tag,
)
# Cross-covariance K_{X*, X} — never materialised.
Kxs_op = gx.ImplicitCrossKernelOperator(
kernel_fn = lambda xi, zj: k_st(xi, zj, sigma2, Ls, Lt),
X_data = X_grid, # (n, 3) — the prediction inputs
X_inducing = X_train, # (m, 3)
batch_size = 4096, # rows of X_grid per scan step
)
mean_field = Kxs_op.mv(alpha) # (n,) — Step 5 of the 00 algorithm
correction = Kxs_op.mv(beta)   # (n,) — Step 10, called once per sample
```

### What you save
| Object | Naive | Matrix-free |
|---|---|---|
| Peak memory for cross-cov | $8nm$ B (full $n \times m$ block in float64) | $\approx 8 \cdot \text{batch\_size} \cdot m$ B at batch_size=4096 |
| Peak memory for Gram | $8m^2$ B | $\approx 8m$ B per row |
| Per-matvec flops | $O(nm)$ | $O(nm)$ — same |
The flop count does not change — but the working set goes from “blows your laptop” to “comfortably fits in L3 cache”.
## Machinery 2 — avoiding the Cholesky
You have two routes, depending on whether you want an exact solve against the dense Gram or an approximate solve against a low-rank surrogate.
### Route A — preconditioned CG against the implicit Gram (exact)
gaussx.PreconditionedCGSolver builds a rank-$r$ pivoted partial Cholesky preconditioner via matfree.low_rank.cholesky_partial_pivot and feeds it into lineax.CG. With a preconditioner rank of a few dozen — the textbook target of preconditioned GP CG — convergence is essentially independent of $m$ once the preconditioner captures the leading spectrum. Each CG iteration is one $C$-matvec ⇒ one ImplicitKernelOperator matvec ⇒ no dense matrix ever exists.
```python
solver = gx.PreconditionedCGSolver(
preconditioner_rank = 50, # rank-50 partial Cholesky
shift = sigma_obs ** 2, # the σ²I in C — used by Woodbury inside the precond
rtol = 1e-6,
max_steps = 200, # 50–200 iters typical
lanczos_order = 30, # only used for SLQ logdet, irrelevant here
)
alpha = gx.solve(C_op, y, solver=solver) # (m,) — Step 3 of the 00 algorithm
# Per-sample correction also reuses the same solver:
beta = gx.solve(C_op, y - f_tilde_X, solver=solver)
```

Cost: $O(n_{\text{iters}}\, m^2)$ flops, $O(mr)$ memory for the preconditioner. With n_iters ≈ 100 that is $\sim 100\, m^2$ flops vs. $O(m^3)$ for dense Cholesky — about 15× faster at SSH-realistic $m$, and you skipped the $m \times m$ Gram allocation along the way.
### Route B — RFF + Woodbury (approximate, but very cheap when $r \ll m$)
The 00 note is already going to compute RFF features for the prior path. Reuse them for the Gram itself: by Bochner’s theorem $K_{XX} \approx \Phi\Phi^\top$ where $\Phi \in \mathbb{R}^{m \times r}$ holds the random cosine features. Then $C \approx \Phi\Phi^\top + \sigma^2 I$, which is exactly the structure that gaussx.LowRankUpdate is built for, and gx.solve dispatches Woodbury on it automatically.
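For reference, the standard cosine feature map and the Woodbury identity the dispatch relies on — generic RFF notation rather than anything gaussx-specific, with $\sigma_\eta^2$ the signal variance and $b_j$ the random phases:

$$\Phi_{ij} \;=\; \sqrt{\tfrac{2\sigma_\eta^2}{r}}\,\cos\!\big(\omega_j^\top x_i + b_j\big),
\qquad \omega_j \sim p(\omega),\;\; b_j \sim \mathcal{U}[0, 2\pi),$$

$$\big(\Phi\Phi^\top + \sigma^2 I\big)^{-1}
\;=\; \sigma^{-2} I \;-\; \sigma^{-2}\,\Phi\,\big(\sigma^2 I_r + \Phi^\top\Phi\big)^{-1}\Phi^\top ,$$

so the only dense solve left is the $r \times r$ inner system.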
gaussx.rff_operator constructs this in one call:
```python
import jax.random as jr
# Spectral frequencies for σ²·k_s·k_t — separable kernel ⇒ ω = (ω_s, ω_t)
omega = sample_spectral_frequencies(key_omega, n_features=2000) # (r, 3)
phase = jr.uniform(key_phase, (2000,), maxval=2 * jnp.pi)
K_lr = gx.rff_operator(X_train, omega=omega, b=phase) # K ≈ ΦΦᵀ as LowRankUpdate
C_lr = K_lr + lx.IdentityLinearOperator(in_struct, in_struct) * sigma_obs ** 2
# (or: gx.low_rank_plus_diag(...) — see __init__.py for the helper)
alpha = gx.solve(C_lr, y)              # Woodbury under the hood
```

Cost: $O(mr^2 + r^3)$ flops for the inner solve, $O(mr)$ memory. With $r = 2000$ that is $\approx 4 \times 10^6\, m$ flops — comparable to a dense Cholesky at the 00 note’s $m$, but it scales as $O(m)$ instead of $O(m^3)$, so it pulls ahead fast as $m$ grows.
gaussx.nystrom_operator is the inducing-point cousin if uniformly subsampled altimeter tracks make a better basis than RFF — same LowRankUpdate output, same Woodbury dispatch.
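A sketch of that inducing-point route, by analogy with rff_operator — the exact signature of gx.nystrom_operator should be checked against __init__.py, and subsample_tracks is a hypothetical helper written for this note:

```python
import jax.random as jr

# Hypothetical helper: pick ~500 inducing points by uniform subsampling of the tracks.
Z_inducing = subsample_tracks(key_z, X_train, n_inducing=500)   # (p, 3), p ≪ m

# Assumed call shape, mirroring gx.rff_operator: kernel + data + inducing set.
K_nys = gx.nystrom_operator(
    kernel_fn = lambda xi, xj: k_st(xi, xj, sigma2, Ls, Lt),
    X         = X_train,
    Z         = Z_inducing,
)                                                    # LowRankUpdate, K ≈ K_XZ K_ZZ⁻¹ K_ZX
C_nys = K_nys + lx.IdentityLinearOperator(in_struct, in_struct) * sigma_obs ** 2
alpha = gx.solve(C_nys, y)                           # same Woodbury dispatch as Route B
```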
## Machinery 3 — the whole Matheron loop is already in pyrox
pyrox.gp.PathwiseSampler implements Steps 6–10 of the 00 algorithm verbatim. It combines:

- pyrox._basis._rff.draw_rff_cosine_basis — draws (variance, lengthscale, ω, phase, weights) from the kernel’s spectral density; supports RBF and Matérn (any ν — including the ν = 3/2 used here);
- evaluate_rff_cosine_paths — evaluates the prior path at any $X$, vectorised over n_paths;
- a frozen (X1, X2) -> K(X1, X2) callable that bakes in the same hyperparameter draw used for the RFF, so the correction stays consistent.
The result is a PathwiseFunction you call at any $X_*$ — the per-call cost is $O\big(n_*(r + m)\big)$ per path.
### Pseudocode — drop-in replacement for Steps 6–10
```python
from pyrox.gp import GPPrior, PathwiseSampler, Matern
# Build a Matern-3/2 kernel on (lon, lat, t). Could also be a product of
# spatial-Matern × temporal-OU (Matern-1/2) — pyrox supports kernel_mul.
kernel = Matern(nu=1.5, lengthscale=Ls, variance=sigma_eta_2)
prior = GPPrior(kernel=kernel, X=X_train) # ConditionedGP carries C internally
posterior = prior.condition(y, noise_var=sigma_obs ** 2)
sampler = PathwiseSampler(posterior, n_features=2000)
paths = sampler.sample_paths(key, n_paths=100) # PathwiseFunction
samples = paths(X_grid)                      # (100, n) — Step 10
```

That is the entire pathwise loop in five lines.
### Caveat — swap the inner solver
PathwiseSampler currently uses gaussx.cholesky for the $C^{-1}\big(y - \tilde f(X)\big)$ solve inside Matheron’s correction. To combine wins #2 and #3 you would either (a) use Route B (low-rank $C$) so the Cholesky is cheap, or (b) pass a different solver into the sampler. The latter is a small monkeypatch / fork — flag it if you ever need it at large $m$; a sketch of what the fork would look like is below.
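A minimal sketch of option (b), assuming a hypothetical forked PathwiseSampler that accepts a solver= keyword — that keyword does not exist today and is shown only to pin down the shape of the change:

```python
# Hypothetical forked signature: PathwiseSampler(posterior, n_features=..., solver=...).
# Today the sampler hard-codes gaussx.cholesky; the fork would thread `solver`
# through to every C⁻¹(y - f̃(X)) solve in the Matheron correction.
pcg_solver = gx.PreconditionedCGSolver(
    preconditioner_rank = 50,
    shift               = sigma_obs ** 2,
    rtol                = 1e-6,
    max_steps           = 200,
)
sampler = PathwiseSampler(posterior, n_features=2000, solver=pcg_solver)  # fork only
paths   = sampler.sample_paths(key, n_paths=100)
```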
## Machinery 4 — pointwise posterior variance without the row-loop
The 00 note flags pointwise variance at $O(n\,m^2)$ — one solve per test point. gaussx.love_cache (LanczOs Variance Estimates, Pleiss et al. 2018) precomputes a rank-$k$ Lanczos factorisation of $C^{-1}$ once, after which every test-point variance costs $O(mk)$ instead of $O(m^2)$.
```python
cache = gx.love_cache(C_op, lanczos_order=50)   # one-time, ~50 CG-matvecs
def post_var_at(x_star):
k_star_row = jax.vmap(lambda xj: k_st(x_star, xj, sigma2, Ls, Lt))(X_train) # (m,)
return sigma_eta_2 - gx.love_variance(cache, k_star_row)
post_var_map = jax.vmap(post_var_at)(X_grid)    # (n,)
```

For lanczos_order $k = 50$ and SSH-realistic $n$, $m$, the cost drops from $O(n\,m^2)$ flops (one solve per test point) to $O(n\,m\,k)$ — about 100× cheaper. Use the empirical sample variance from the pathwise draws when a few samples suffice; switch to LOVE when you specifically want a smooth, non-Monte-Carlo-noisy uncertainty map.
## Two structural exploits the libraries don’t do for you
These are not in gaussx/pyrox as primitives but they compose with the operators above and are worth coding by hand for an SSH-shaped problem.
### Exploit A — the temporal factor of $K_{X_*,X}$ has only 3 distinct values
All prediction points share time $t_*$, so for any observation in time group $t_* - \tau$, $t_*$, or $t_* + \tau$ the temporal weight is one of three scalars: $e^{-\tau/L_t}$, $1$, $e^{-\tau/L_t}$. Build spatial-only implicit cross operators per time group and combine with scalar weights at matvec time:
```python
def make_block(X_train_group, weight):
spatial_op = gx.ImplicitCrossKernelOperator(
kernel_fn = lambda x, z: sigma2 * matern32_great_circle(x[:2], z[:2], Ls),
X_data = X_grid_xy, # (n, 2) — drop time axis
X_inducing = X_train_group, # (m_group, 2)
batch_size = 4096,
)
return weight, spatial_op
w_minus = jnp.exp(-tau / Lt)
blocks = [
make_block(X_train_minus_xy, w_minus),
make_block(X_train_zero_xy, 1.0),
make_block(X_train_plus_xy, w_minus),
]
def Kxs_mv(beta): # beta is partitioned to match groups
out = 0.0
for (w, op), b_group in zip(blocks, beta_partition):
out = out + w * op.mv(b_group)
    return out
```

This halves the kernel-eval count (one trig stack per spatial pair, not one per spatiotemporal pair) and lets you cache the spatial cross-covs across re-fits with different $L_t$ — handy when you sweep temporal length scales during hyperparameter learning.
### Exploit B — the block structure of $C$
The training inputs split into 3 time groups, so $K_{XX}$ is a $3 \times 3$ block matrix whose $(a, b)$ block is

$$K_{XX}^{(a,b)} \;=\; e^{-|t_a - t_b|/L_t}\; K^{\mathrm{space}}(X_a, X_b).$$

Six unique spatial blocks (3 diagonal + 3 off-diagonal); the temporal weights are scalars. Build the spatial blocks once with ImplicitKernelOperator / ImplicitCrossKernelOperator, wrap with a temporally-weighted block matvec, and you avoid roughly a third of the redundant great-circle evaluations on every Gram matvec inside CG.
This composes with gaussx.BlockDiag for the within-group blocks, with the off-diagonal cross-blocks added back as scaled implicit operators inside a custom matvec. Worth the effort once you push to many time groups (a sliding window spanning several days, say); a sketch of the weighted block matvec follows.
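A hand-rolled sketch of that temporally-weighted block matvec, written for this note rather than taken from gaussx: the group arrays (X_minus_xy etc.), the curried spatial kernel k_space (which carries the σ² amplitude), and the assumption that the cross operator exposes a .transpose() are all mine.

```python
import numpy as np
import jax.numpy as jnp

groups = [X_minus_xy, X_zero_xy, X_plus_xy]        # spatial coords per time group
times  = [-tau, 0.0, tau]                          # group timestamps
sizes  = [g.shape[0] for g in groups]
split_idx = np.cumsum(sizes)[:-1]                  # static split points for jnp.split

# Six unique spatial blocks: 3 diagonal square operators + 3 off-diagonal cross operators.
diag_ops  = [gx.ImplicitKernelOperator(kernel_fn=k_space, X=g) for g in groups]
cross_ops = {(a, b): gx.ImplicitCrossKernelOperator(
                 kernel_fn=k_space, X_data=groups[a], X_inducing=groups[b],
                 batch_size=4096)
             for a in range(3) for b in range(a + 1, 3)}

def C_mv(v):
    """(K_XX + σ_obs²·I) @ v with the temporal weights applied per block."""
    parts = jnp.split(v, split_idx)
    out = [sigma_obs ** 2 * p for p in parts]                 # the σ²I term
    for a in range(3):
        out[a] = out[a] + diag_ops[a].mv(parts[a])            # temporal weight exp(0) = 1
        for b in range(a + 1, 3):
            w = jnp.exp(-abs(times[a] - times[b]) / Lt)
            out[a] = out[a] + w * cross_ops[(a, b)].mv(parts[b])
            # Mirror block (b, a): assumes the cross operator implements .transpose();
            # if it does not, build the swapped operator and give up the sharing.
            out[b] = out[b] + w * cross_ops[(a, b)].transpose().mv(parts[a])
    return jnp.concatenate(out)
```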
## Putting it all together — the efficient SSH algorithm
The same ten-step algorithm from 00, rewritten with the machinery above. Storage is now linear in $m$ and $n$ throughout — never the $m \times m$ Gram, never the $n \times m$ cross-covariance.
### Offline (once per time window)
```python
# --- Choose a structured representation of C ---
# Route A (exact, preferred when m ≲ 5e4 and you have time for ~100 CG iters):
C_op = gx.ImplicitKernelOperator(k_st_curried, X_train, noise_var=sigma_obs ** 2,
tags=lx.positive_semidefinite_tag)
solver = gx.PreconditionedCGSolver(preconditioner_rank=50, shift=sigma_obs ** 2,
rtol=1e-6, max_steps=200)
# Route B (low-rank, preferred when m ≫ 1e4 and r ≪ m):
# K_lr = gx.rff_operator(X_train, omega=omega, b=phase)
# C_lr = K_lr + identity(m) * sigma_obs ** 2 # LowRankUpdate ⇒ Woodbury
# solver = None # gx.solve auto-dispatches Woodbury
# --- Step 3: solve for the dual weights α ---
alpha = gx.solve(C_op, y, solver=solver) # (m,)
# --- Steps 4–5: matrix-free posterior mean field ---
Kxs_op = gx.ImplicitCrossKernelOperator(
kernel_fn = k_st_curried,
X_data = X_grid, # (n, 3)
X_inducing = X_train, # (m, 3)
batch_size = 4096,
)
mu_field = Kxs_op.mv(alpha)                 # (n,) — never materialise (n, m)
```

### Per posterior sample
```python
# --- Steps 6–7: RFF prior path at training + grid (one consistent draw) ---
variance, lengthscale, omega, phase, weights = pyrox.rff.draw_rff_cosine_basis(
kernel = matern_3_2,
key = sample_key,
n_paths = 1, n_features=2000, in_features=3, dtype=jnp.float64,
)
f_tilde_X = pyrox.rff.evaluate_rff_cosine_paths(
X_train, variance=variance, lengthscale=lengthscale,
omega=omega, phase=phase, weights=weights)[0]
f_tilde_grid = pyrox.rff.evaluate_rff_cosine_paths(
X_grid, variance=variance, lengthscale=lengthscale,
omega=omega, phase=phase, weights=weights)[0]
# --- Step 8–9: innovation + correction solve (reuses solver) ---
delta = y - f_tilde_X # (m,)
beta = gx.solve(C_op, delta, solver=solver) # (m,) — same CG, same precond
# --- Step 10: matrix-free correction ---
eta_sample = f_tilde_grid + Kxs_op.mv(beta)   # (n,) — one exact posterior draw
```

Or, equivalently, the four-line version using the pre-built sampler:
```python
posterior = GPPrior(kernel=matern_3_2_x_OU, X=X_train).condition(y, sigma_obs ** 2)
sampler = PathwiseSampler(posterior, n_features=2000)
paths = sampler.sample_paths(key, n_paths=S)
samples = paths(X_grid)      # (S, n)
```

### Pointwise variance map (optional, when sample-variance is too noisy)
```python
cache = gx.love_cache(C_op, lanczos_order=50)
post_var_map = jax.vmap(
lambda x_star: sigma_eta_2 - gx.love_variance(
cache,
jax.vmap(lambda xj: k_st(x_star, xj, sigma2, Ls, Lt))(X_train),
)
)(X_grid)                    # (n,)
```

## Cost recap
Same complexity table as 00, side-by-side with what the machinery delivers.
| Phase | Naive (00) | With machinery (here) |
|---|---|---|
| Build $C$ + Cholesky | $O(m^2)$ memory, $O(m^3)$ flops | $O(m)$ memory, $O(n_{\text{iters}}\,m^2)$ flops via PCG (Route A); or $O(mr)$ memory + $O(mr^2)$ flops via Woodbury (Route B) |
| Build $K_{X_*,X}$ + posterior mean | $O(nm)$ memory, $O(nm)$ flops | $O(\text{batch\_size}\cdot m)$ memory, $O(nm)$ flops |
| $S$ posterior samples | $O(nm)$ memory, $O(S\,nm)$ flops | same flops, but cross-cov memory shrinks from $O(nm)$ to $O(\text{batch\_size}\cdot m)$ — and the solve is the cheap PCG/Woodbury solve, not a fresh Cholesky |
| Pointwise variance | $O(n\,m^2)$ flops | $O(n\,m\,k)$ flops via LOVE, $k \approx 50$ |
The memory bound drops from $O(nm)$ to $O(m + n)$ — i.e. from 4 GB to a few hundred MB at SSH-realistic scales. The flop bound drops from $O(m^3)$ to $O(n_{\text{iters}}\,m^2)$ with the same constants. Both wins compound: matrix-free + PCG means you can push to a full Mediterranean month of observations on a single GPU, where the naive code would OOM long before it finished its first Cholesky.
## Wall-clock estimates for realistic SSH reconstructions
Order-of-magnitude time budgets to reconstruct daily SSH fields over the Mediterranean Sea, North Atlantic, and global ocean, for both an analysis-day window (operational, 2 time groups) and a reanalysis-day window (3 time groups, lookahead allowed). All numbers below are accurate only to within a small factor — they exist to flag which configurations are tractable on what hardware, not to commit to a specific runtime.
### Daily observation counts (Copernicus Marine Service catalogue, 2024–2026)
CMEMS publishes the along-track and SWOT L3 streams as SEALEVEL_GLO_PHY_L3_*_OBSERVATIONS_008_* (NRT and reprocessed multi-mission), the SWOT KaRIn 7.5 km L3 as SEALEVEL_GLO_PHY_L3_MY_008_069, and tide-gauge SSH as INSITU_GLO_PHY_SSH_DISCRETE_NRT_013_059. Post-QC daily counts:
| Stream | Per-satellite raw | After QC | Active in 2024–2026 |
|---|---|---|---|
| Nadir altimeters (1 Hz) | /day/sat | /day/sat | S3A, S3B, S6-MF, SARAL, CryoSat-2, HY-2B/C — 6 sats ⇒ /day globally |
| SWOT KaRIn (2 km native, 7.5 km L3) | /day | /day | 1 mission, operational since 2023 |
| In-situ SSH (tide gauges, GLOSS) | — | /day | Negligible vs altimetry; anchors absolute datum |
| Combined total (post-QC) | /day globally | SWOT dominates |
Region fractions (area-weighted):
| Region | Area | Fraction | Daily obs (with SWOT) | Daily obs (nadir-only, archived years) |
|---|---|---|---|---|
| Mediterranean Sea | ||||
| North Atlantic | ||||
| Global ocean |
### Per-day problem sizes
With a time window of up to 3 days and obs grouped by time-band:
| Region | Grid points $n$ (at target res) | $m$ — analysis (2 groups, with SWOT) | $m$ — reanalysis (3 groups, with SWOT) | $m$ — reanalysis (nadir-only) |
|---|---|---|---|---|
| Mediterranean | $\sim 10^5$ | | | |
| North Atlantic | $\sim 10^5$ | | | |
| Global ocean | $\sim 10^6$ | | | |
### Per-day compute, two regimes
Two implementation paths from above, with $r$ RFF features and $S$ posterior samples per day. The “fully-RFF” column uses RFF for both the prior path and the cross-covariance — the per-sample correction matvec then collapses from $O(nm)$ to $O(nr)$, removing $m$ from the inner loop entirely (a sketch of that variant follows the table).
| Region (reanalysis day, with SWOT) | Exact GP via PCG (Route A) — flops/day | Fully-RFF Woodbury (Route B+) — flops/day |
|---|---|---|
| Mediterranean (, ) | ||
| North Atlantic (, ) | ||
| Global ocean (, ) | — infeasible |
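The fully-RFF variant is not spelled out above, so here is a minimal sketch in the notation of the “Putting it all together” section. phi is assumed to be the cosine feature map matching the omega/phase draw used for the prior path (a hypothetical helper for this note), so the correction reuses $K_{X_*,X} \approx \Phi_* \Phi^\top$:

```python
# Fully-RFF correction: K_{X*,X} β ≈ Φ_* (Φᵀ β), so the per-sample matvec costs
# O(nr + mr) instead of O(nm). `phi(X, omega, phase)` is a hypothetical helper
# returning the (len(X), r) cosine feature matrix used for the prior path.
Phi_train = phi(X_train, omega, phase)        # (m, r) — built once per window
Phi_grid  = phi(X_grid,  omega, phase)        # (n, r) — built once per window

def correction_rff(beta):
    return Phi_grid @ (Phi_train.T @ beta)    # (n,) — no O(nm) term anywhere

eta_sample = f_tilde_grid + correction_rff(beta)
```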
Effective sustained throughput on common targets: a modern Xeon (32 cores, MKL) hits f64 on these matvec/scan kernels; an A100 GPU hits f64 (well below peak because the workload is bandwidth-bound). Per-day wall-clock follows directly:
| Region | Route A — A100 / day | Route A — CPU / day | Fully-RFF — A100 / day | Fully-RFF — CPU / day |
|---|---|---|---|---|
| Mediterranean | ||||
| North Atlantic | ||||
| Global | weeks |
Analysis day (2 time groups instead of 3) reduces $m$ by roughly a third, so per-day cost drops by roughly $2\times$ for Route A (the $m^2$ term) and by roughly a third for Route B+ (linear in $m$). Numbers below take this into account.
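The arithmetic behind that scaling, assuming the three time groups contribute comparable observation counts (so dropping one group takes $m \to \tfrac{2}{3}m$):

$$\text{Route A }(\propto m^2):\;\; \Big(\tfrac{2}{3}\Big)^{2} \approx 0.44,
\qquad
\text{Route B+ }(\propto m):\;\; \tfrac{2}{3} \approx 0.67 .$$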
### Full-window wall-clock
Multiply per-day by 30 / 90 / 180. Reanalysis with the full SWOT-era constellation, on a single A100:
| Region | 1 month (30 d) | 3 months (90 d) | 6 months (180 d) |
|---|---|---|---|
| Mediterranean | Route A: · RFF+: | Route A: · RFF+: | Route A: · RFF+: |
| North Atlantic | Route A: · RFF+: | Route A: · RFF+: | Route A: · RFF+: |
| Global | Route A: infeasible · RFF+: | RFF+: | RFF+: |
Same totals on a 32-core CPU (the Route A column for global is dropped — at exact-GP scale it would take days of CPU per reconstructed day):
| Region | 1 month | 3 months | 6 months |
|---|---|---|---|
| Mediterranean | Route A: · RFF+: | Route A: · RFF+: | Route A: · RFF+: |
| North Atlantic | Route A: · RFF+: | Route A: · RFF+: | Route A: · RFF+: |
| Global | RFF+: | RFF+: | RFF+: |
Analysis day (operational, single window per run) is roughly half these per-day numbers — Mediterranean analysis is sub-second on any hardware; global analysis with RFF+ is a few seconds per A100 day.
### Where things break
- Route A on global SWOT-era data is hopeless on a single node. Global $m$ pushes the per-CG-iteration cost ($O(m^2)$ via the implicit-kernel scan) into the $10^{13}$-flop range, with $O(10^2)$ iterations per solve and one solve per posterior sample per day; the resulting per-day flop count works out to days of wall-clock per reconstructed day on an A100. Either drop SWOT (reverts to the nadir column, where Route A is viable for NA but still tight for global), reduce resolution, switch to inducing-point sparse GPs (gaussx.nystrom_operator + LowRankUpdate), or move to a state-space SPDE formulation (which gets you sparse precision matrices and smoothing, at the cost of requiring an isotropic Matérn kernel and more setup work).
- RFF+ is the right default for North Atlantic and global. The accuracy loss vs. exact GP is the $O(1/\sqrt{r})$ kernel-approximation error; for SSH-mesoscale fields at $r = 2000$ this is well below the altimeter noise floor.
- Hyperparameter learning is not in these numbers. Estimating $(\sigma_\eta^2, L_s, L_t)$ via marginal-likelihood gradient descent multiplies the cost by the number of optimiser steps — up to ~100. Do this once on a representative window, then freeze the hyperparameters for the full reanalysis pass.
- I/O is not in these numbers. At SWOT-era daily observation counts, the download from CMEMS at typical link speeds is comparable to or larger than the compute on Mediterranean/NA. Pre-stage the data; do not re-fetch per day.
- Posterior variance maps via LOVE add only a modest fraction to the per-day budget — cheap when needed, skip when the empirical sample variance from the $S$ pathwise draws suffices.