Deep Random Feature Expansions: Stacked Spectral GPs (Cutajar et al. 2017)
This notebook is the deep companion to the RFF → SSGP → VSSGP notebook. The single-layer methods covered there can only represent functions in the RKHS of one stationary kernel. Stacking layers can give non-stationary, hierarchical representations, for example long-range trends and short-range texture in the same model. Cutajar, Bonilla, Michiardi & Filippone (ICML 2017) made this practical by stacking RFF layers and fitting them with a doubly-stochastic, reparameterised ELBO.
Three deep methods, each the depth-$L$ analogue of one single-layer method from the companion notebook:
| Single layer | This notebook | $\Omega_l$ | $W_l$ | Objective |
|---|---|---|---|---|
| Learned RFF MAP | Deep RFF MAP | trained (point) | trained (point) | regularised MSE |
| SSGP | Deep SSGP | trained (point) per layer | trained (point) for $l < L-1$, marginalised at $l = L-1$ (readout) | log marginal likelihood |
| VSSGP | Deep VSSGP | variational $q(\Omega_l)$ | variational $q(W_l)$ | tempered ELBO |
The architecture is the same in all three:
$$F_0 = X, \qquad F_{l+1} = \Phi_l(F_l;\, \Omega_l, \ell_l)\, W_l, \qquad l = 0, \dots, L-1,$$
with $\Phi_l(F; \Omega_l, \ell_l) = \sqrt{1/M}\,[\cos(F\Omega_l/\ell_l), \sin(F\Omega_l/\ell_l)] \in \mathbb{R}^{N \times 2M}$. The implementations below also add a per-layer bias $b_l$ after each projection $W_l$.
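The architecture equation fits in a few lines of JAX. What follows is a minimal sketch: the helper names (rff_features, deep_rff_forward) and the toy sizes are hypothetical and are not the pyrox API, and biases are left out to match the equation. The sketch also shows a useful property of the $\sqrt{1/M}$ scaling. Every feature vector has unit norm, because $\cos^2 + \sin^2 = 1$ holds for each of the $M$ frequencies.

```python
import jax.numpy as jnp
import jax.random as jr


def rff_features(F, omega, ell):
    """Phi(F; Omega, ell) = sqrt(1/M) [cos(F Omega / ell), sin(F Omega / ell)], shape (N, 2M)."""
    M = omega.shape[1]
    z = F @ omega / ell
    return jnp.sqrt(1.0 / M) * jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)


def deep_rff_forward(X, omegas, ells, weights):
    """F_0 = X, F_{l+1} = Phi_l(F_l) W_l (biases omitted to match the equation)."""
    F = X
    for omega, ell, W in zip(omegas, ells, weights):
        F = rff_features(F, omega, ell) @ W
    return F


# Hypothetical toy sizes: N inputs, M features per layer, hidden width D_h.
N, M, D_h = 5, 8, 3
k0, k1, k2, k3 = jr.split(jr.PRNGKey(0), 4)
X = jnp.linspace(-1.0, 1.0, N).reshape(-1, 1)
omegas = [jr.normal(k0, (1, M)), jr.normal(k1, (D_h, M))]  # each (D_l, M)
weights = [jr.normal(k2, (2 * M, D_h)), jr.normal(k3, (2 * M, 1))]  # each (2M, D_{l+1})
ells = [0.3, 0.3]

Phi0 = rff_features(X, omegas[0], ells[0])
out = deep_rff_forward(X, omegas, ells, weights)
print(Phi0.shape, out.shape)  # (5, 16) (5, 1)
print(jnp.sum(Phi0**2, axis=1))  # every row has unit squared norm
```

Because every feature vector has unit norm, a Bayesian linear readout on these features has the same prior variance $\sigma_\beta^2\|\phi(x)\|^2 = \sigma_\beta^2$ at every input.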
Forward-pointers: the single-layer versions live in the SSGP notebook; the fixed-Ω + ensemble alternatives are in the RFF as Neural Networks notebook.
When might depth help, and at what cost?
A single-layer SSGP/VSSGP is a stationary model: the kernel is the same function of $x - x'$ everywhere in input space. The linear head $\beta$ is a constant set of weights on the features $\Phi(x)$ — it cannot say "here, weight the high-frequency features more; over there, weight the low-frequency ones." For some non-stationary targets, this is a real limitation.
Stacking $L$ RFF blocks produces non-stationarity through input warping. The input to layer $l+1$ is itself a learned nonlinear transform of $x$, so the effective kernel of the deep model depends on position. The canonical demonstration target, used here and in Damianou & Lawrence (AISTATS 2013), is the compositional warp
$$f(x) = \sin\!\bigl(3\pi\,\sin(\pi x)\bigr).$$
Its instantaneous (angular) frequency is the derivative of the phase $3\pi\sin(\pi x)$, namely $\omega(x) = 3\pi^2\cos(\pi x)$. Its magnitude falls from $3\pi^2 \approx 30$ at $x = 0$ to $0$ at $x = \pm \tfrac{1}{2}$, then rises back to $\approx 30$ at $x = \pm 1$. A single-layer model with one global lengthscale must compromise. A depth-$2$ model can in principle learn $\sin(\pi x)$ in the first layer and $\sin(3\pi z)$ in the second.
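A quick way to see why a learned first layer yields non-stationarity is to put a stationary RBF kernel on top of a fixed warp. The sketch below assumes the ideal first layer $g(x) = \sin(\pi x)$, chosen by hand rather than learned, and an RBF with $\ell = 0.3$. Two input pairs with the same separation of 0.1 then receive different correlations, depending on where they sit.

```python
import jax.numpy as jnp


def rbf(a, b, ell):
    """Stationary RBF kernel: a function of a - b only."""
    return jnp.exp(-0.5 * (a - b) ** 2 / ell**2)


def warp(x):
    """Hand-picked first-layer warp (an assumption): g(x) = sin(pi x)."""
    return jnp.sin(jnp.pi * x)


def warped_rbf(x, x_prime, ell):
    """Effective kernel of 'RBF after warp': k(x, x') = rbf(g(x), g(x'))."""
    return rbf(warp(x), warp(x_prime), ell)


ell = 0.3
# Two input pairs with the same separation x' - x = 0.1, at x = 0 and x = 0.5.
k_stat = [float(rbf(a, a + 0.1, ell)) for a in (0.0, 0.5)]
k_warp = [float(warped_rbf(a, a + 0.1, ell)) for a in (0.0, 0.5)]
print(k_stat)  # equal: the stationary kernel ignores position
print(k_warp)  # differ: correlation decays much faster near x = 0
```

Near $x = 0$, where the warp is steep, the correlation across a step of 0.1 drops to about 0.59. Near $x = 0.5$, where the warp is flat, it stays near 0.99. The plain RBF gives about 0.95 at both locations. The warp therefore shortens the effective lengthscale exactly where the target oscillates fastest.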
The catch. Stacking breaks the per-layer Gaussian conjugacy of single-layer SSGP. The input to layer $l$ is itself a stochastic function of $\Omega_{0:l-1}, W_{0:l-1}$, so the marginal likelihood at the top is no longer Gaussian. We therefore need one of two things:

(a) a deterministic deep feature stack with a closed-form SSGP only on the readout (deep SSGP), or
(b) a fully variational doubly-stochastic ELBO over all $(\Omega_l, W_l)$ (deep VSSGP).

The second is harder to optimise; see §3.
Honest expectation. Depth-2 with modest hidden width on a 1D toy is not always a clear win over a well-tuned single-layer SSGP (covered in the SSGP notebook) — for moderate problems the single-layer methods are usually enough. Treat this notebook as a recipe for building the deep stack; whether you actually need it for your problem should be settled by benchmarking against the single-layer baseline.
Setup
import subprocess
import sys
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
if IN_COLAB:
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "pyrox[colab] @ git+https://github.com/jejjohnson/pyrox@main",
        ],
        check=True,
    )
import warnings
warnings.filterwarnings("ignore", message=r".*IProgress.*")
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
jax.config.update("jax_enable_x64", True)
import importlib.util
try:
    from IPython import get_ipython
    ipython = get_ipython()
except ImportError:
    ipython = None
if ipython is not None and importlib.util.find_spec("watermark") is not None:
    ipython.run_line_magic("load_ext", "watermark")
    ipython.run_line_magic(
        "watermark",
        "-v -m -p jax,equinox,numpyro,pyrox,matplotlib",
    )
else:
    print("watermark extension not installed; skipping reproducibility readout.")
Python implementation: CPython
Python version       : 3.13.5
IPython version      : 9.10.0
jax       : 0.9.2
equinox   : 0.13.6
numpyro   : 0.20.1
pyrox     : 0.0.6
matplotlib: 3.10.8
Compiler    : GCC 11.2.0
OS          : Linux
Release     : 6.8.0-1044-azure
Machine     : x86_64
Processor   : x86_64
CPU cores   : 16
Architecture: 64bit
Dataset: Damianou-Lawrence compositional warp
We use the canonical 1D deep-GP demo target
$$f^\star(x) = \sin\!\bigl(3\pi\,\sin(\pi x)\bigr), \qquad x \in [-1, 1],$$
observed with Gaussian noise $\sigma = 0.05$. The function has a position-dependent effective frequency. It oscillates rapidly near $x = 0$ and $x = \pm 1$, where the warp $\sin(\pi x)$ has maximum slope. It is locally flat near $x = \pm \tfrac{1}{2}$, where the warp's slope vanishes.
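You can check the position-dependent frequency with automatic differentiation. The phase of $f^\star$ is $3\pi\sin(\pi x)$, and its derivative is the local angular frequency. Here is a minimal sketch using jax.grad:

```python
import jax
import jax.numpy as jnp


def phase(x):
    """Phase of the target: f*(x) = sin(phase(x)), with phase(x) = 3 pi sin(pi x)."""
    return 3.0 * jnp.pi * jnp.sin(jnp.pi * x)


# Local angular frequency = d phase / dx = 3 pi^2 cos(pi x).
inst_freq = jax.vmap(jax.grad(phase))
xs = jnp.array([0.0, 0.5, 1.0])
omega = jnp.abs(inst_freq(xs))
print(omega)  # about 29.6 at x = 0, essentially 0 at x = 0.5, about 29.6 at x = 1
```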
N_OBS, NOISE_STD = 200, 0.05
key = jr.PRNGKey(0)
x_full = jnp.linspace(-1.0, 1.0, N_OBS).reshape(-1, 1)
def truth(xs: jax.Array) -> jax.Array:
    x = xs[:, 0]
    return jnp.sin(3.0 * jnp.pi * jnp.sin(jnp.pi * x))
# Held-out gap to test predictive uncertainty in the gap region.
mask = (x_full[:, 0] < -0.2) | (x_full[:, 0] > 0.4)
x_obs = x_full[mask]
y_obs = truth(x_obs) + NOISE_STD * jr.normal(jr.PRNGKey(1), x_obs.shape[0:1])
x_test = jnp.linspace(-1.2, 1.2, 400).reshape(-1, 1)
y_truth = truth(x_test)
print(f"observations: N = {x_obs.shape[0]} (gap removed from x ∈ (-0.2, 0.4))")
observations: N = 140 (gap removed from x ∈ (-0.2, 0.4))
fig, ax = plt.subplots(figsize=(11, 4))
ax.scatter(
    x_obs[:, 0],
    y_obs,
    s=10,
    color="C1",
    edgecolors="k",
    linewidths=0.5,
    label="observations",
    zorder=5,
)
ax.plot(x_test[:, 0], y_truth, "k--", linewidth=1.5, label=r"truth $f^\star(x)$")
ax.axvspan(-0.2, 0.4, color="0.85", alpha=0.4, label="held-out gap")
ax.set_xlabel("$x$")
ax.set_ylabel("$f(x)$")
ax.set_title(r"Compositional warp — $f(x) = \sin(3\pi\,\sin(\pi x))$")
ax.legend(loc="upper right", fontsize=9)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
All three deep methods use the same architecture:
- $D_{\mathrm{in}} = 1$, $D_{\mathrm{out}} = 1$
- hidden dimension $D_h = 16$
- $M = 32$ Fourier features per layer
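- depth $L = 2$ (stacked RFF blocks; the last one is the readout)
- initial lengthscale $\ell = 0.3$ for every layer (trained afterwards)
This keeps the three models architecturally comparable. What differs between them is how the parameters are inferred (point estimate, analytic readout, or full variational posterior), not capacity.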
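To make "comparable" concrete, the sketch below counts the trainable scalars each model passes to Adam. It follows the array shapes in the three class definitions in this notebook, and it is an illustrative helper, not part of pyrox. With $D_h = 16$ and $M = 32$, deep MAP and deep SSGP come within 4% of each other. The manual deep VSSGP roughly doubles the count, because each entry of $\Omega_l$ and $W_l$ carries both a mean and a log-scale.

```python
def count_params(d_in=1, d_h=16, d_out=1, M=32, depth=2):
    """Count trainable scalars per model, following the class definitions in this notebook."""
    in_dims = [d_in] + [d_h] * (depth - 1)  # D_l for l = 0..L-1
    out_dims = [d_h] * (depth - 1) + [d_out]  # D_{l+1}
    n_omega = sum(d * M for d in in_dims)  # Omega_l: (D_l, M)
    n_w = sum(2 * M * d for d in out_dims)  # W_l: (2M, D_{l+1})
    n_b = sum(out_dims)  # b_l: (D_{l+1},)
    n_ell = depth  # one log-lengthscale per layer
    deep_map = n_omega + n_w + n_b + n_ell
    # Deep SSGP: no W_l, b_l on the readout layer (beta is marginalised),
    # plus log sigma_n and log sigma_beta.
    deep_ssgp = n_omega + n_ell + (2 * M * d_h + d_h) * (depth - 1) + 2
    # Manual deep VSSGP: a mean and a log-sigma for every Omega and W entry;
    # lengthscales and biases stay point estimates.
    deep_vssgp = 2 * (n_omega + n_w) + n_b + n_ell
    return {"deep_map": deep_map, "deep_ssgp": deep_ssgp, "deep_vssgp": deep_vssgp}


print(count_params())  # {'deep_map': 1651, 'deep_ssgp': 1588, 'deep_vssgp': 3283}
```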
DEPTH = 2
HIDDEN = 16
M_FEAT = 32
LENGTHSCALE_INIT = 0.3
1. Deep RFF (MAP): the deterministic baseline
Math. Same architecture as the deep VSSGP, but every parameter is a PyTree leaf trained by gradient descent on the regularised MSE
$$\mathcal{L}(\Theta) = \tfrac{1}{2\sigma_n^2}\,\bigl\|y - F_L(X; \Theta)\bigr\|^2 + \tfrac{M}{2\sigma_W^2}\sum_{l=0}^{L-1}\|W_l\|_F^2 + \tfrac{1}{2}\sum_{l=0}^{L-1}\|\Omega_l\|_F^2 / \sigma_\Omega^2,$$
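where $\Theta = \{\Omega_l, W_l, \ell_l\}_{l=0}^{L-1}$. The frequency penalty is a Gaussian prior on $\Omega_l$. With $\sigma_\Omega = 1$ it is the spectral density of a lengthscale-1 RBF kernel, so it pulls each $\Omega_l$ toward that density, and the learned $\ell_l$ then rescales. The fit_deep_map code below uses a simplified version of this objective: the mean squared error plus $\lambda \sum_l \|W_l\|_F^2$ with $\lambda = 10^{-3}$ (beta_l2). It omits the frequency penalty, which is equivalent to $\sigma_\Omega \to \infty$, and absorbs $\sigma_n$, $\sigma_W$, $M$ and $N$ into $\lambda$.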
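The fit_deep_map loop below minimises the mean squared error plus a penalty on $W$ only. For reference, here is a term-by-term sketch of the full objective above. The hyperparameter names sigma_w and sigma_omega are assumptions for illustration, since the notebook never sets them.

```python
import jax.numpy as jnp


def map_objective(y, pred, omegas, weights, *, sigma_n, sigma_w, sigma_omega, M):
    """Full regularised-MSE MAP objective, one term per summand of the equation."""
    data_fit = 0.5 / sigma_n**2 * jnp.sum((y - pred) ** 2)
    w_penalty = 0.5 * M / sigma_w**2 * sum(jnp.sum(W**2) for W in weights)
    omega_penalty = 0.5 / sigma_omega**2 * sum(jnp.sum(O**2) for O in omegas)
    return data_fit + w_penalty + omega_penalty


# Tiny hand-checkable example: 0.5 * 1 + 0.5 * 2 * 2 + 0.5 / 4 * 8 = 3.5
y = jnp.array([1.0, 0.0])
pred = jnp.zeros(2)
weights = [jnp.ones((2, 1))]
omegas = [2.0 * jnp.ones((1, 2))]
value = map_objective(
    y, pred, omegas, weights, sigma_n=1.0, sigma_w=1.0, sigma_omega=2.0, M=2
)
print(float(value))  # 3.5
```

To switch the notebook's loop over to the full objective, you would add the omega_penalty term inside loss_fn.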
What MAP cannot give you. A point estimate of the deep network — no per-prediction uncertainty. The fit through the held-out gap is whatever the optimiser settled on, with no signal that the model is uncertain there.
Wide initialisation matters per layer. The saddle issue is the same as for the single-layer learned RFF: the $\cos(F\Omega/\ell)$ block has zero gradient with respect to $\Omega$ at $\Omega = 0$. To escape the saddle, we initialise each $\Omega_l \sim \mathcal{N}(0, \sigma^2)$ with $\sigma = 2$, the w_init_scale default in DeepLearnedRFF.init.
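jax.grad shows the saddle directly. In this sketch (toy inputs, one layer, no projection), the cos block has exactly zero gradient at $\Omega = 0$. The sin block sends the same gradient to every column of $\Omega$, so $M$ frequencies started at zero would move in lockstep. A random, wide initialisation avoids both problems.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0.3], [-0.7], [1.2]])  # toy (N, 1) inputs
ell = 0.3


def cos_block(omega):
    """Sum of the cos features as a function of Omega (shape (1, M))."""
    return jnp.sum(jnp.cos(x @ omega / ell))


def sin_block(omega):
    """Sum of the sin features as a function of Omega."""
    return jnp.sum(jnp.sin(x @ omega / ell))


omega0 = jnp.zeros((1, 4))  # M = 4 frequencies, all at the saddle
g_cos = jax.grad(cos_block)(omega0)  # -sin(0) * x / ell = 0 for every column
g_sin = jax.grad(sin_block)(omega0)  # cos(0) * sum(x) / ell, identical columns
print(g_cos)
print(g_sin)
```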
class DeepLearnedRFF(eqx.Module):
    """Deep RFF stack with all parameters trainable as PyTree leaves."""
    omegas: list  # length L: each (D_l, M)
    log_ells: list  # length L
    weights: list  # length L: each (2M, D_{l+1})
    biases: list  # length L: each (D_{l+1},)

    @classmethod
    def init(
        cls,
        key,
        in_features,
        hidden_features,
        out_features,
        *,
        depth,
        n_features,
        lengthscale,
        w_init_scale=2.0,
        beta_init_scale=0.01,
    ):
        keys = jr.split(key, 2 * depth)
        omegas, log_ells, weights, biases = [], [], [], []
        for layer_idx in range(depth):
            in_dim = in_features if layer_idx == 0 else hidden_features
            out_dim = out_features if layer_idx == depth - 1 else hidden_features
            omegas.append(
                w_init_scale * jr.normal(keys[2 * layer_idx], (in_dim, n_features))
            )
            log_ells.append(jnp.log(jnp.array(lengthscale)))
            weights.append(
                beta_init_scale
                * jr.normal(keys[2 * layer_idx + 1], (2 * n_features, out_dim))
            )
            biases.append(jnp.zeros((out_dim,)))
        return cls(omegas=omegas, log_ells=log_ells, weights=weights, biases=biases)

    def __call__(self, x):
        z = x
        n_features = self.weights[0].shape[0] // 2
        scale = jnp.sqrt(1.0 / n_features)
        for omega, log_ell, W, b in zip(
            self.omegas,
            self.log_ells,
            self.weights,
            self.biases,
            strict=True,
        ):
            ell = jnp.exp(log_ell)
            zw = z @ omega / ell
            phi = scale * jnp.concatenate([jnp.cos(zw), jnp.sin(zw)], axis=-1)
            z = phi @ W + b
        return z
def fit_deep_map(model, x_obs, y_obs, *, n_steps=4000, lr=5e-3, beta_l2=1e-3):
    opt = optax.adam(lr)
    state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    def loss_fn(m):
        # Simplified MAP loss: mean squared error + L2 penalty on the W_l only
        # (no frequency penalty on Omega_l).
        pred = m(x_obs)[:, 0]
        mse = jnp.mean((pred - y_obs) ** 2)
        l2 = sum(jnp.sum(W**2) for W in m.weights)
        return mse + beta_l2 * l2

    @eqx.filter_jit
    def step(m, s):
        loss, grads = eqx.filter_value_and_grad(loss_fn)(m)
        upd, s = opt.update(grads, s, eqx.filter(m, eqx.is_inexact_array))
        return eqx.apply_updates(m, upd), s, loss

    losses = []
    for _ in range(n_steps):
        model, state, loss = step(model, state)
        losses.append(float(loss))
    return model, losses
deep_map = DeepLearnedRFF.init(
    jr.PRNGKey(11),
    in_features=1,
    hidden_features=HIDDEN,
    out_features=1,
    depth=DEPTH,
    n_features=M_FEAT,
    lengthscale=LENGTHSCALE_INIT,
)
deep_map, map_losses = fit_deep_map(deep_map, x_obs, y_obs)
y_pred_map = deep_map(x_test)[:, 0]
mse_map = float(jnp.mean((y_pred_map - y_truth) ** 2))
print(f"deep RFF MAP: final loss = {map_losses[-1]:.4f}, MSE = {mse_map:.4f}")
deep RFF MAP: final loss = 0.0092, MSE = 0.1837
fig, ax = plt.subplots(figsize=(11, 4.5))
ax.scatter(
    x_obs[:, 0],
    y_obs,
    s=10,
    color="C1",
    edgecolors="k",
    linewidths=0.5,
    label="observations",
    zorder=5,
)
ax.plot(x_test[:, 0], y_truth, "k--", linewidth=1.5, label="truth")
ax.plot(
    x_test[:, 0],
    y_pred_map,
    "C3",
    linewidth=1.8,
    label=f"deep RFF MAP (MSE={mse_map:.4f})",
)
ax.axvspan(-0.2, 0.4, color="0.85", alpha=0.4, label="held-out gap")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_title(f"Deep RFF MAP, depth $L = {DEPTH}$ — deterministic fit")
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right", fontsize=9)
plt.tight_layout()
plt.show()
Clean fit at the data, but the curve through the held-out gap is just whatever the optimiser found, with no uncertainty attached. Depth gives the model the flexibility to track the warp's varying instantaneous frequency without region-specific features. There is still no posterior to read from.
2. Deep SSGP: analytic readout, deterministic deep features
Math. Layer-wise marginalisation of $W_l$ in the deep stack is not closed-form because $F_l$ depends stochastically on $W_{0:l-1}$. The clean compromise: treat the first $L - 1$ layers as deterministic feature extractors with point-estimate $(\Omega_l, W_l)$, and apply the closed-form SSGP marginal likelihood only at the readout.
Concretely, for a depth-$L$ network:
$$F_0 = X, \qquad F_{l+1} = \Phi_l(F_l;\, \Omega_l, \ell_l)\, W_l \quad (l = 0, \dots, L-2;\ \text{deterministic}), \qquad y = \Phi_{L-1}(F_{L-1};\, \Omega_{L-1}, \ell_{L-1})\, \beta + \varepsilon,$$
where $\beta \sim \mathcal{N}(0, \sigma_\beta^2 I)$ is marginalised analytically as in single-layer SSGP. The marginal likelihood becomes
$$\log p(y \mid X, \Theta_{\mathrm{deep}}, \Omega_{L-1}, \ell_{L-1}, \sigma_\beta, \sigma_n) = \log \mathcal{N}\!\bigl(y \,\big|\, 0,\; K_{L-1} + \sigma_n^2 I\bigr), \qquad K_{L-1} = \sigma_\beta^2 \Phi_{L-1} \Phi_{L-1}^\top.$$
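neg_log_marginal evaluates this Gaussian log density without forming the $N \times N$ Gram matrix. It applies the Woodbury identity and the matrix determinant lemma to the $2M \times 2M$ matrix $B = \sigma_n^2 I + \sigma_\beta^2 \Phi^\top\Phi$, so the cost is $O(NM^2)$ rather than $O(N^3)$. The standalone sketch below checks the fast form against the dense density on synthetic random features; the data are made up for the check.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)  # as in the notebook setup


def ssgp_log_ml(Phi, y, sigma_n, sigma_b):
    """log N(y | 0, sigma_b^2 Phi Phi^T + sigma_n^2 I) via the 2M x 2M matrix B."""
    N, twoM = Phi.shape
    sn2, sb2 = sigma_n**2, sigma_b**2
    B = sn2 * jnp.eye(twoM) + sb2 * Phi.T @ Phi
    L = jnp.linalg.cholesky(B)
    a = solve_triangular(L, Phi.T @ y, lower=True)
    # Woodbury: y^T C^{-1} y = (y^T y - sb2 * |L^{-1} Phi^T y|^2) / sn2
    quad = -0.5 / sn2 * (y @ y) + 0.5 * sb2 / sn2 * (a @ a)
    # Determinant lemma: log|C| = (N - 2M) log sn2 + log|B|
    half_logdet = 0.5 * (N - twoM) * jnp.log(sn2) + jnp.sum(jnp.log(jnp.diag(L)))
    return quad - half_logdet - 0.5 * N * jnp.log(2.0 * jnp.pi)


# Synthetic check data (not the notebook's dataset).
k1, k2 = jr.split(jr.PRNGKey(0))
Phi = jr.normal(k1, (50, 8)) / 2.0
y = jr.normal(k2, (50,))
sigma_n, sigma_b = 0.1, 1.0

fast = ssgp_log_ml(Phi, y, sigma_n, sigma_b)
K = sigma_b**2 * Phi @ Phi.T + sigma_n**2 * jnp.eye(50)
dense = multivariate_normal.logpdf(y, jnp.zeros(50), K)
print(float(fast), float(dense))
```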
All hyperparameters $\Theta = \{\Omega_l, W_l, b_l, \ell_l\}_{l < L-1} \cup \{\Omega_{L-1}, \ell_{L-1}, \sigma_\beta, \sigma_n\}$ are trained jointly by maximising this objective. The predictive mean and variance use the same closed-form formulas as single-layer SSGP, applied to the readout features $\Phi_{L-1}(F_{L-1};\, \Omega_{L-1}, \ell_{L-1})$.
Why this works. The marginal likelihood acts as both a model selection criterion (Occam's razor on $\sigma_\beta, \sigma_n$) and as a regulariser on the deep features (preferring features that produce a Gram matrix $\Phi\Phi^\top$ that matches the data covariance well). The trade-off vs. the full deep VSSGP: we get closed-form predictive variance for free, but uncertainty in the lower-layer features is collapsed to a point estimate.
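The closed-form predictive in predict works in weight space. The posterior over $\beta$ has mean $\sigma_\beta^2 B^{-1}\Phi^\top y$ and covariance $\sigma_n^2\sigma_\beta^2 B^{-1}$. The sketch below uses synthetic features to check that this matches the textbook function-space GP predictive for the kernel $\sigma_\beta^2\Phi\Phi^\top$. Both variances here are for the latent function, before predict adds $\sigma_n^2$.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)

# Synthetic features standing in for Phi_{L-1}(F_{L-1}) at train and query inputs.
k1, k2, k3 = jr.split(jr.PRNGKey(1), 3)
N, twoM, Nq = 40, 10, 5
Phi = jr.normal(k1, (N, twoM)) / jnp.sqrt(twoM / 2)
Phi_q = jr.normal(k2, (Nq, twoM)) / jnp.sqrt(twoM / 2)
y = jr.normal(k3, (N,))
sn2, sb2 = 0.05**2, 1.0

# Weight space (what predict does): only a 2M x 2M solve, cost O(N M^2).
B = sn2 * jnp.eye(twoM) + sb2 * Phi.T @ Phi
mu_beta = sb2 * jnp.linalg.solve(B, Phi.T @ y)
Sigma_beta = sb2 * sn2 * jnp.linalg.inv(B)
mean_weight = Phi_q @ mu_beta
var_weight = jnp.einsum("qi,ij,qj->q", Phi_q, Sigma_beta, Phi_q)

# Function space: textbook GP predictive with kernel sb2 * Phi Phi^T, cost O(N^3).
K = sb2 * Phi @ Phi.T + sn2 * jnp.eye(N)
K_q = sb2 * Phi_q @ Phi.T
mean_func = K_q @ jnp.linalg.solve(K, y)
prior_var = sb2 * jnp.sum(Phi_q**2, axis=1)
var_func = prior_var - jnp.einsum("qn,nq->q", K_q, jnp.linalg.solve(K, K_q.T))

print(jnp.max(jnp.abs(mean_weight - mean_func)), jnp.max(jnp.abs(var_weight - var_func)))
```

The weight-space route only needs a $2M \times 2M$ solve, which is why the notebook uses it. The function-space route appears here purely as a cross-check.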
class DeepSSGP(eqx.Module):
    """Deep RFF feature extractor + analytic SSGP readout.

    Lower layers are deterministic feature extractors with point-estimate
    parameters (including biases). The final layer marginalises its head
    ``β`` via the SSGP closed form; all hyperparameters trained on the
    joint marginal likelihood.
    """
    omegas: list  # length L
    log_ells: list  # length L
    weights: list  # length L-1 (no head W on final layer; β is marginalised)
    biases: list  # length L-1
    log_sigma_n: jax.Array
    log_sigma_beta: jax.Array

    @classmethod
    def init(
        cls,
        key,
        in_features,
        hidden_features,
        *,
        depth,
        n_features,
        lengthscale,
        w_init_scale=2.0,
        sigma_n=0.05,
        sigma_beta=1.0,
    ):
        if depth < 1:
            raise ValueError("depth >= 1 required.")
        keys = jr.split(key, 2 * depth)
        omegas, log_ells = [], []
        for layer_idx in range(depth):
            in_dim = in_features if layer_idx == 0 else hidden_features
            omegas.append(
                w_init_scale * jr.normal(keys[2 * layer_idx], (in_dim, n_features))
            )
            log_ells.append(jnp.log(jnp.array(lengthscale)))
        # Linear projection W_l for l = 0..depth-2 (last layer's head is marginalised)
        weights, biases = [], []
        for layer_idx in range(depth - 1):
            weights.append(
                0.1
                * jr.normal(keys[2 * layer_idx + 1], (2 * n_features, hidden_features))
            )
            biases.append(jnp.zeros((hidden_features,)))
        return cls(
            omegas=omegas,
            log_ells=log_ells,
            weights=weights,
            biases=biases,
            log_sigma_n=jnp.log(jnp.array(sigma_n)),
            log_sigma_beta=jnp.log(jnp.array(sigma_beta)),
        )

    def features(self, x):
        """Return Φ_{L-1} (the readout RFF) after passing through the deep stack."""
        z = x
        n_features = self.omegas[0].shape[1]
        scale = jnp.sqrt(1.0 / n_features)
        # Lower layers: deterministic feature extractor
        for layer_idx in range(len(self.weights)):
            ell = jnp.exp(self.log_ells[layer_idx])
            zw = z @ self.omegas[layer_idx] / ell
            phi = scale * jnp.concatenate([jnp.cos(zw), jnp.sin(zw)], axis=-1)
            z = phi @ self.weights[layer_idx] + self.biases[layer_idx]
        # Final layer's RFF (no projection; the head is marginalised below)
        ell_top = jnp.exp(self.log_ells[-1])
        zw_top = z @ self.omegas[-1] / ell_top
        return scale * jnp.concatenate([jnp.cos(zw_top), jnp.sin(zw_top)], axis=-1)

    def neg_log_marginal(self, x, y):
        # Woodbury + determinant lemma: work with the 2M x 2M matrix
        # B = σ_n² I + σ_β² ΦᵀΦ instead of the N x N Gram matrix.
        Phi = self.features(x)
        N = x.shape[0]
        twoM = Phi.shape[1]
        sigma_n2 = jnp.exp(2.0 * self.log_sigma_n)
        sigma_b2 = jnp.exp(2.0 * self.log_sigma_beta)
        B = sigma_n2 * jnp.eye(twoM) + sigma_b2 * Phi.T @ Phi
        L = jnp.linalg.cholesky(B)
        v = Phi.T @ y
        L_inv_v = jax.scipy.linalg.solve_triangular(L, v, lower=True)
        quad = -0.5 / sigma_n2 * jnp.sum(y**2) + 0.5 * sigma_b2 / sigma_n2 * jnp.sum(
            L_inv_v**2
        )
        log_det = 0.5 * (N - twoM) * jnp.log(sigma_n2) + jnp.sum(jnp.log(jnp.diag(L)))
        return -(quad - log_det - 0.5 * N * jnp.log(2.0 * jnp.pi))

    def predict(self, x_train, y_train, x_query):
        # Weight-space posterior: mean σ_β² B⁻¹Φᵀy, covariance σ_n²σ_β² B⁻¹.
        Phi = self.features(x_train)
        Phi_q = self.features(x_query)
        twoM = Phi.shape[1]
        sigma_n2 = jnp.exp(2.0 * self.log_sigma_n)
        sigma_b2 = jnp.exp(2.0 * self.log_sigma_beta)
        B = sigma_n2 * jnp.eye(twoM) + sigma_b2 * Phi.T @ Phi
        L = jnp.linalg.cholesky(B)
        v = Phi.T @ y_train
        mu_beta = sigma_b2 * jax.scipy.linalg.cho_solve((L, True), v)
        mean = Phi_q @ mu_beta
        L_inv_phi_q = jax.scipy.linalg.solve_triangular(L, Phi_q.T, lower=True)
        var_f = sigma_b2 * sigma_n2 * jnp.sum(L_inv_phi_q**2, axis=0)
        return mean, sigma_n2 + var_f
def fit_deep_ssgp(model, x_obs, y_obs, *, n_steps=2000, lr=5e-3):
    opt = optax.adam(lr)
    state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def step(m, s):
        loss, grads = eqx.filter_value_and_grad(
            lambda mm: mm.neg_log_marginal(x_obs, y_obs)
        )(m)
        upd, s = opt.update(grads, s, eqx.filter(m, eqx.is_inexact_array))
        return eqx.apply_updates(m, upd), s, loss

    losses = []
    for _ in range(n_steps):
        model, state, loss = step(model, state)
        losses.append(float(loss))
    return model, losses
deep_ssgp = DeepSSGP.init(
    jr.PRNGKey(13),
    in_features=1,
    hidden_features=HIDDEN,
    depth=DEPTH,
    n_features=M_FEAT,
    lengthscale=LENGTHSCALE_INIT,
)
deep_ssgp, ssgp_losses = fit_deep_ssgp(deep_ssgp, x_obs, y_obs)
mean_ssgp, var_ssgp = deep_ssgp.predict(x_obs, y_obs, x_test)
std_ssgp = jnp.sqrt(var_ssgp)
mse_ssgp = float(jnp.mean((mean_ssgp - y_truth) ** 2))
print(
    f"deep SSGP: final neg-log-ML = {ssgp_losses[-1]:.2f}, "
    f"σ_n = {float(jnp.exp(deep_ssgp.log_sigma_n)):.4f}, MSE = {mse_ssgp:.4f}"
)
deep SSGP: final neg-log-ML = -182.11, σ_n = 0.0525, MSE = 0.2645
fig, ax = plt.subplots(figsize=(11, 4.5))
ax.scatter(
    x_obs[:, 0],
    y_obs,
    s=10,
    color="C1",
    edgecolors="k",
    linewidths=0.5,
    label="observations",
    zorder=5,
)
ax.plot(x_test[:, 0], y_truth, "k--", linewidth=1.5, label="truth")
ax.plot(
    x_test[:, 0],
    mean_ssgp,
    "C0",
    linewidth=1.8,
    label=f"deep SSGP mean (MSE={mse_ssgp:.4f})",
)
ax.fill_between(
    x_test[:, 0],
    mean_ssgp - 2 * std_ssgp,
    mean_ssgp + 2 * std_ssgp,
    color="C0",
    alpha=0.25,
    label=r"$\pm 2\sigma$ closed form",
)
ax.axvspan(-0.2, 0.4, color="0.85", alpha=0.4, label="held-out gap")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_title(
    f"Deep SSGP, depth $L = {DEPTH}$ — analytic readout, deterministic features"
)
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right", fontsize=9)
plt.tight_layout()
plt.show()
Fit quality is in the same range as deep MAP (test-grid MSE 0.26 vs 0.18 in this run), and we also get a closed-form predictive band from the readout posterior. The band opens visibly across the gap. The deep features map gap inputs to readout-feature locations far from any training point, and the SSGP variance formula picks that up automatically. The hyperparameters $\sigma_n$, $\sigma_\beta$ and all $\ell_l$ are tuned by ML-II within the same optimisation, with no separate validation loop.
What this band misses. The lower-layer parameters are point estimates, so the band reflects only the readout uncertainty. If the deep features themselves were ambiguous (multiple lower-layer solutions consistent with the data), this method couldn't say so. That is what deep VSSGP adds.
3. Deep VSSGP: full variational stack
Math. Variational posteriors on every layer's $(\Omega_l, W_l)$:
$$q(\Omega_l) = \mathcal{N}\!\bigl(\mu_{\Omega,l}, \mathrm{diag}(\sigma_{\Omega,l}^2)\bigr), \qquad q(W_l) = \mathcal{N}\!\bigl(\mu_{W,l}, \mathrm{diag}(\sigma_{W,l}^2)\bigr).$$
Doubly-stochastic ELBO (Cutajar et al. 2017, Eq. 7-12; the same doubly-stochastic idea also underlies the inducing-point deep GPs of Salimbeni & Deisenroth 2017):
$$\mathcal{L}_\beta(\phi) = \mathbb{E}_{q(\Theta)}\!\left[\sum_i \log p(y_i \mid x_i, \Theta)\right] - \beta\,\sum_{l=0}^{L-1}\Bigl(\mathrm{KL}\!\bigl[q(\Omega_l) \,\Vert\, p(\Omega_l)\bigr] + \mathrm{KL}\!\bigl[q(W_l) \,\Vert\, p(W_l)\bigr]\Bigr),$$
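with Monte Carlo samples of the full $\Theta = \{\Omega_l, W_l\}_{l=0}^{L-1}$ at each step. "Doubly stochastic" refers to two sources of noise: MC samples of $\Theta$ and minibatches of data. The code below averages $N_{\mathrm{MC}} = 8$ samples per step over the full batch of 140 observations, so only the MC noise is present. Both $\Omega_l$ and $W_l$ are reparameterised at every step: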
$$\Omega_l = \mu_{\Omega,l} + \sigma_{\Omega,l} \odot \varepsilon_{\Omega,l}, \qquad W_l = \mu_{W,l} + \sigma_{W,l} \odot \varepsilon_{W,l}, \qquad \varepsilon \sim \mathcal{N}(0, I).$$
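The reparameterisation is what makes the MC objective differentiable in the variational parameters. Here is a minimal sketch on a single scalar. With $\Omega = \mu + \sigma\varepsilon$, the MC gradient of $\mathbb{E}_q[\Omega^2] = \mu^2 + \sigma^2$ should approach $2\mu$ with respect to $\mu$ and $2\sigma^2$ with respect to $\log\sigma$. As ManualDeepVSSGP does, the sketch parameterises the scale by its log.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def mc_objective(params, eps):
    """MC estimate of E_q[omega^2] with omega = mu + sigma * eps (reparameterised)."""
    mu, log_sigma = params
    omega = mu + jnp.exp(log_sigma) * eps
    return jnp.mean(omega**2)


mu, sigma = 0.7, 0.4
params = (jnp.array(mu), jnp.log(jnp.array(sigma)))
eps = jr.normal(jr.PRNGKey(0), (200_000,))  # fixed noise; gradients flow through mu, sigma
g_mu, g_log_sigma = jax.grad(mc_objective)(params, eps)

# Exact: d/dmu = 2 mu = 1.4 and d/dlog(sigma) = 2 sigma^2 = 0.32.
print(float(g_mu), float(g_log_sigma))
```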
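All KL terms are closed-form Gaussian-Gaussian; the code below uses standard-normal priors $p(\Omega_l) = p(W_l) = \mathcal{N}(0, I)$. We use a tempered ELBO with $\beta = 0.005$ (KL_BETA). The weight reaches this value through a linear warm-up from 0 over the first 2000 steps (KL_WARMUP_STEPS), in the spirit of the KL annealing in the single-layer VSSGP notebook.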
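ManualDeepVSSGP.kl uses the closed form $\mathrm{KL}[\mathcal{N}(\mu, \sigma^2)\,\Vert\,\mathcal{N}(0,1)] = \tfrac12(\mu^2 + \sigma^2 - 1 - 2\log\sigma)$, summed over entries, and fit_deep_vssgp scales it by a linearly warmed-up temperature. The sketch below checks the closed form against a Monte Carlo estimate of $\mathbb{E}_q[\log q - \log p]$, then evaluates the warm-up schedule. The function names are illustrative.

```python
import jax.numpy as jnp
import jax.random as jr


def gaussian_kl_std_normal(mu, log_sigma):
    """KL[N(mu, sigma^2) || N(0, 1)] summed over entries (same formula as ManualDeepVSSGP.kl)."""
    sigma2 = jnp.exp(2.0 * log_sigma)
    return 0.5 * jnp.sum(mu**2 + sigma2 - 1.0 - 2.0 * log_sigma)


def kl_weight(step, beta=0.005, warmup_steps=2000):
    """Linear warm-up 0 -> beta over warmup_steps, then constant (as in fit_deep_vssgp)."""
    return beta * min(1.0, step / warmup_steps)


# Monte Carlo check of the closed form for one scalar q = N(0.5, 0.8^2).
mu, sigma = 0.5, 0.8
eps = jr.normal(jr.PRNGKey(0), (400_000,))
w = mu + sigma * eps
log_q = -0.5 * eps**2 - jnp.log(sigma)  # the -0.5 log(2 pi) constants cancel
log_p = -0.5 * w**2
kl_mc = jnp.mean(log_q - log_p)
kl_exact = gaussian_kl_std_normal(jnp.array(mu), jnp.log(jnp.array(sigma)))
print(float(kl_exact), float(kl_mc))
print([kl_weight(s) for s in (0, 1000, 2000, 5000)])
```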
Implementation. pyrox.nn.DeepVSSGP registers exactly $3L$ sample sites with NumPyro priors: layer_{l}.W_freq, layer_{l}.lengthscale and layer_{l}.W_proj. It is the production path described in §5. The notebook code below instead uses an inline manual class with reparameterised sampling, so that the tempered ELBO, the KL warm-up schedule and the MAP warm-start trick are visible. Unlike DeepVSSGP, the manual class keeps the lengthscales and biases as point estimates, so it has $2L$ variational sites ($\Omega_l$, $W_l$).
MAP warm-start. A direct cold start on the deep variational ELBO routinely fails to fit the training data. MC sampling noise plus KL pressure on $\mu_\Omega$ overwhelm the data-fit gradient before a sensible solution is established. A pragmatic fix works in the same spirit as the careful initialisations used for deep variational GPs (Salimbeni & Deisenroth 2017; Cutajar et al. 2017). We initialise the variational means at a trained MAP solution, with a very tight initial $\sigma$ so that $q$ starts as a near-delta around the MAP point, and then continue training with the tempered ELBO. Because the means already start at a good fit, the variational training mostly has to learn appropriate values of $\sigma$, although the means keep training too. We use this trick below.
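How tight is the warm-started posterior? The sketch below draws from a mean-field $q$ centred on a stand-in "MAP" matrix, using init_from_map's default log-sigma of $-3$. The stand-in holds random values at scale 2 (the w_init_scale used here), not the trained deep_map. Each draw deviates from the centre by about $e^{-3} \approx 0.05$, a few percent of the spread of the frequencies themselves.

```python
import jax.numpy as jnp
import jax.random as jr

# Stand-in "MAP" Omega (hypothetical values, shape (D_h, M) = (16, 32)).
mu_map = 2.0 * jr.normal(jr.PRNGKey(0), (16, 32))
log_sigma = jnp.full_like(mu_map, -3.0)  # init_from_map default: sigma = e^-3

# 1000 reparameterised draws from q: mu + sigma * eps.
eps = jr.normal(jr.PRNGKey(1), (1000,) + mu_map.shape)
samples = mu_map + jnp.exp(log_sigma) * eps

spread = jnp.std(samples - mu_map)  # typical deviation of a draw from the centre
rel = spread / jnp.std(mu_map)  # relative to the spread of the frequencies
print(float(jnp.exp(-3.0)), float(spread), float(rel))
```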
KL_BETA = 0.005
KL_WARMUP_STEPS = 2000
N_MC = 8
NOISE_VAR = NOISE_STD**2
class ManualDeepVSSGP(eqx.Module):
    """Deep VSSGP with explicit reparameterised mean-field q over (Ω_l, W_l).

    For each layer l we maintain (μ_Ω_l, log σ_Ω_l) and (μ_W_l, log σ_W_l)
    as PyTree leaves. Per-layer lengthscales and biases are deterministic
    PyTree leaves; making them variational adds optimisation noise without
    obvious benefit at this scale.

    The pyrox-native equivalent is :class:`pyrox.nn.DeepVSSGP` driven by
    a NumPyro ``AutoNormal`` guide; we use the manual form here so the
    tempered ELBO, the KL annealing schedule, and the per-layer
    reparameterisation are visible.
    """
    mu_omegas: list  # length L: each (D_l, M)
    log_sigma_omegas: list
    mu_weights: list  # length L: each (2M, D_{l+1})
    log_sigma_weights: list
    log_ells: list  # length L, trainable point estimate
    biases: list  # length L, trainable point estimate

    @classmethod
    def init(
        cls,
        key,
        in_features,
        hidden_features,
        out_features,
        *,
        depth,
        n_features,
        lengthscale,
        mu_init_scale=5.0,
        log_sigma_init=-2.0,
    ):
        keys = jr.split(key, 2 * depth)
        mu_omegas, log_sigma_omegas = [], []
        mu_weights, log_sigma_weights = [], []
        log_ells, biases = [], []
        for layer_idx in range(depth):
            in_dim = in_features if layer_idx == 0 else hidden_features
            out_dim = out_features if layer_idx == depth - 1 else hidden_features
            mu_omegas.append(
                mu_init_scale * jr.normal(keys[2 * layer_idx], (in_dim, n_features))
            )
            log_sigma_omegas.append(jnp.full((in_dim, n_features), log_sigma_init))
            mu_weights.append(
                0.01 * jr.normal(keys[2 * layer_idx + 1], (2 * n_features, out_dim))
            )
            log_sigma_weights.append(
                jnp.full((2 * n_features, out_dim), log_sigma_init)
            )
            log_ells.append(jnp.log(jnp.array(lengthscale)))
            biases.append(jnp.zeros((out_dim,)))
        return cls(
            mu_omegas=mu_omegas,
            log_sigma_omegas=log_sigma_omegas,
            mu_weights=mu_weights,
            log_sigma_weights=log_sigma_weights,
            log_ells=log_ells,
            biases=biases,
        )

    @classmethod
    def init_from_map(cls, map_model: "DeepLearnedRFF", *, log_sigma_init=-3.0):
        """Warm-start VSSGP from a trained deep MAP solution.

        Each variational mean is initialised at the MAP point estimate; the
        log-sigma starts very tight (default ``e^{-3} ≈ 0.05``) so q begins
        as a near-delta around the MAP. Training then learns appropriate
        sigma values via the data + tempered KL trade-off.

        This bypasses the variational optimisation pathology: direct cold
        starts on the deep ELBO often fail to fit the training data because
        MC noise + KL pressure dominates before a good fit is established.
        """
        mu_omegas = [jnp.array(o) for o in map_model.omegas]
        mu_weights = [jnp.array(w) for w in map_model.weights]
        log_ells = [jnp.array(le) for le in map_model.log_ells]
        biases = [jnp.array(b) for b in map_model.biases]
        log_sigma_omegas = [jnp.full_like(o, log_sigma_init) for o in mu_omegas]
        log_sigma_weights = [jnp.full_like(w, log_sigma_init) for w in mu_weights]
        return cls(
            mu_omegas=mu_omegas,
            log_sigma_omegas=log_sigma_omegas,
            mu_weights=mu_weights,
            log_sigma_weights=log_sigma_weights,
            log_ells=log_ells,
            biases=biases,
        )

    def _forward(self, x, sample_omegas, sample_weights):
        z = x
        n_features = sample_omegas[0].shape[1]
        scale = jnp.sqrt(1.0 / n_features)
        for omega, W, log_ell, b in zip(
            sample_omegas,
            sample_weights,
            self.log_ells,
            self.biases,
            strict=True,
        ):
            ell = jnp.exp(log_ell)
            zw = z @ omega / ell
            phi = scale * jnp.concatenate([jnp.cos(zw), jnp.sin(zw)], axis=-1)
            z = phi @ W + b
        return z

    def sample_q(self, key):
        # Reparameterised draws: Ω_l = μ + σ ⊙ ε and W_l = μ + σ ⊙ ε, ε ~ N(0, I).
        sigma_omegas = [jnp.exp(s) for s in self.log_sigma_omegas]
        sigma_weights = [jnp.exp(s) for s in self.log_sigma_weights]
        keys = jr.split(key, 2 * len(self.mu_omegas))
        sampled_omegas = [
            mu + s * jr.normal(k, mu.shape)
            for mu, s, k in zip(
                self.mu_omegas,
                sigma_omegas,
                keys[: len(self.mu_omegas)],
                strict=True,
            )
        ]
        sampled_weights = [
            mu + s * jr.normal(k, mu.shape)
            for mu, s, k in zip(
                self.mu_weights,
                sigma_weights,
                keys[len(self.mu_omegas) :],
                strict=True,
            )
        ]
        return sampled_omegas, sampled_weights

    def kl(self):
        # Closed-form KL[N(μ, σ²) || N(0, 1)], summed over every entry of every
        # Ω_l and W_l (standard-normal priors).
        total = jnp.array(0.0)
        for mu, log_sigma in zip(
            self.mu_omegas + self.mu_weights,
            self.log_sigma_omegas + self.log_sigma_weights,
            strict=True,
        ):
            sigma2 = jnp.exp(2.0 * log_sigma)
            total = total + 0.5 * jnp.sum(mu**2 + sigma2 - 1.0 - 2.0 * log_sigma)
        return total
def deep_vssgp_elbo(model, x, y, key, kl_weight):
    # Returns the negative tempered ELBO (a loss to minimise), up to additive
    # constants: MC-averaged Gaussian NLL with fixed noise + kl_weight * KL.
    keys = jr.split(key, N_MC)

    def one_sample(k):
        sample_omegas, sample_weights = model.sample_q(k)
        pred = model._forward(x, sample_omegas, sample_weights)[:, 0]
        return -0.5 * jnp.sum((pred - y) ** 2) / NOISE_VAR

    nll = -jnp.mean(jax.vmap(one_sample)(keys))
    return nll + kl_weight * model.kl()
def fit_deep_vssgp(model, x_obs, y_obs, *, n_steps=8000, lr=5e-3, seed=0):
    """Train with linear KL warmup: weight ramps 0 → KL_BETA over the first
    KL_WARMUP_STEPS steps, then constant. This lets q(Ω) escape the
    prior's bandwidth before the KL pulls μ_Ω back toward 0.
    """
    opt = optax.adam(lr)
    state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def step(m, s, k, klw):
        loss, grads = eqx.filter_value_and_grad(deep_vssgp_elbo)(
            m, x_obs, y_obs, k, klw
        )
        upd, s = opt.update(grads, s, eqx.filter(m, eqx.is_inexact_array))
        return eqx.apply_updates(m, upd), s, loss

    losses = []
    key = jr.PRNGKey(seed)
    for step_idx in range(n_steps):
        key, sub = jr.split(key)
        klw = jnp.asarray(KL_BETA * min(1.0, step_idx / KL_WARMUP_STEPS))
        model, state, loss = step(model, state, sub, klw)
        losses.append(float(loss))
    return model, losses
deep_vssgp = ManualDeepVSSGP.init_from_map(deep_map)
deep_vssgp, losses_vssgp = fit_deep_vssgp(deep_vssgp, x_obs, y_obs)
print(f"deep VSSGP: final tempered-ELBO loss = {losses_vssgp[-1]:.2f}")
print(
    "deep VSSGP: trained ℓ = "
    + ", ".join(f"{float(jnp.exp(le)):.3f}" for le in deep_vssgp.log_ells)
)
# Predictive: draw S samples from q, push each through the deep stack,
# compute mean + std, add observation noise.
def predict_deep_vssgp(model, x_query, key, n_samples=128):
    keys = jr.split(key, n_samples)

    def one_pred(k):
        sample_omegas, sample_weights = model.sample_q(k)
        return model._forward(x_query, sample_omegas, sample_weights)[:, 0]

    preds = jax.vmap(one_pred)(keys)
    mean = jnp.mean(preds, axis=0)
    var = jnp.var(preds, axis=0) + NOISE_VAR
    return mean, jnp.sqrt(var), preds
mean_vssgp, std_vssgp, preds_vssgp = predict_deep_vssgp(
    deep_vssgp, x_test, jr.PRNGKey(99)
)
mse_vssgp = float(jnp.mean((mean_vssgp - y_truth) ** 2))
print(f"deep VSSGP: predictive MSE = {mse_vssgp:.4f}")
deep VSSGP: final tempered-ELBO loss = 67.15
deep VSSGP: trained ℓ = 0.286, 1.074
deep VSSGP: predictive MSE = 0.2094
fig, ax = plt.subplots(figsize=(11, 4.5))
ax.scatter(
    x_obs[:, 0],
    y_obs,
    s=10,
    color="C1",
    edgecolors="k",
    linewidths=0.5,
    label="observations",
    zorder=5,
)
ax.plot(x_test[:, 0], y_truth, "k--", linewidth=1.5, label="truth")
for member in preds_vssgp[:8]:
    ax.plot(x_test[:, 0], member, color="C2", alpha=0.15, linewidth=0.8)
ax.plot(
    x_test[:, 0],
    mean_vssgp,
    "C2",
    linewidth=1.8,
    label=f"deep VSSGP mean (MSE={mse_vssgp:.4f})",
)
ax.fill_between(
    x_test[:, 0],
    mean_vssgp - 2 * std_vssgp,
    mean_vssgp + 2 * std_vssgp,
    color="C2",
    alpha=0.25,
    label=r"$\pm 2\sigma$ MC band",
)
ax.axvspan(-0.2, 0.4, color="0.85", alpha=0.4, label="held-out gap")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_title(
    f"Deep VSSGP, depth $L = {DEPTH}$ — variational posterior over every layer"
)
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right", fontsize=9)
plt.tight_layout()
plt.show()
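The faint green traces are independent draws from $\prod_l q(\Omega_l)\,q(W_l)$, each run end-to-end through the deep stack, with lengthscales and biases held at their point estimates. The band reflects uncertainty in every layer's $\Omega_l$ and $W_l$, not just the readout. In this run, though, it widens only slightly across the gap: the sanity check in §4 reports a gap/data std ratio of about 1.1x, against 3.2x for deep SSGP. Starting $q$ as a near-delta around the MAP solution, with a small KL weight, keeps the posterior tight. The MC band is also noisier than SSGP's closed-form band, since it is estimated from 128 samples.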
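How noisy is a band estimated from 128 draws? For Gaussian draws, the sample standard deviation has a relative standard error of roughly $1/\sqrt{2(S-1)}$, about 6% for $S = 128$ (n_samples in predict_deep_vssgp). The sketch below checks this by simulation. It assumes Gaussian predictive draws with unit standard deviation, which is a simplification of the real deep-stack outputs.

```python
import jax.numpy as jnp
import jax.random as jr

S = 128  # draws per prediction, as in predict_deep_vssgp
n_repeats = 5000  # independent replications to measure the estimator's spread
true_std = 1.0

draws = true_std * jr.normal(jr.PRNGKey(0), (n_repeats, S))
est_std = jnp.std(draws, axis=1)  # ddof=0, matching jnp.var in the notebook
rel_error = jnp.std(est_std) / true_std
theory = 1.0 / jnp.sqrt(2.0 * (S - 1))
print(float(rel_error), float(theory))
```

The error scales as $1/\sqrt{S}$, so halving this jitter takes about four times as many samples.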
4. Three deep methods, one plot
Same target, same data, three predictive bands.
fig, axes = plt.subplots(1, 3, figsize=(18, 4.5), sharey=True)
def _decorate(ax, title):
    ax.scatter(
        x_obs[:, 0],
        y_obs,
        s=10,
        color="C1",
        edgecolors="k",
        linewidths=0.5,
        label="observations",
        zorder=5,
    )
    ax.plot(x_test[:, 0], y_truth, "k--", linewidth=1.5, label="truth")
    ax.axvspan(-0.2, 0.4, color="0.85", alpha=0.4, label="held-out gap")
    ax.set_title(title)
    ax.set_xlabel("$x$")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper right", fontsize=8)
ax = axes[0]
ax.plot(x_test[:, 0], y_pred_map, "C3", linewidth=1.8, label=f"MAP (MSE={mse_map:.4f})")
_decorate(ax, "Deep RFF MAP — point estimate, no uncertainty")
ax.set_ylabel("$y$")
ax = axes[1]
ax.plot(
    x_test[:, 0],
    mean_ssgp,
    "C0",
    linewidth=1.8,
    label=f"SSGP mean (MSE={mse_ssgp:.4f})",
)
ax.fill_between(
    x_test[:, 0],
    mean_ssgp - 2 * std_ssgp,
    mean_ssgp + 2 * std_ssgp,
    color="C0",
    alpha=0.25,
    label=r"$\pm 2\sigma$ closed form",
)
_decorate(ax, "Deep SSGP — analytic readout band")
ax = axes[2]
for member in preds_vssgp[:8]:
    ax.plot(x_test[:, 0], member, color="C2", alpha=0.15, linewidth=0.8)
ax.plot(
    x_test[:, 0],
    mean_vssgp,
    "C2",
    linewidth=1.8,
    label=f"VSSGP mean (MSE={mse_vssgp:.4f})",
)
ax.fill_between(
    x_test[:, 0],
    mean_vssgp - 2 * std_vssgp,
    mean_vssgp + 2 * std_vssgp,
    color="C2",
    alpha=0.25,
    label=r"$\pm 2\sigma$ MC",
)
_decorate(ax, "Deep VSSGP — full variational stack")
plt.tight_layout()
plt.show()
# Sanity-check uncertainty values.
# "trained": windows inside the observed data on either side of the gap.
trained = ((x_test[:, 0] > -0.5) & (x_test[:, 0] < -0.3)) | (
    (x_test[:, 0] > 0.5) & (x_test[:, 0] < 0.7)
)
gap_region = (x_test[:, 0] > -0.2) & (x_test[:, 0] < 0.4)
for name, std_arr in [("deep SSGP", std_ssgp), ("deep VSSGP", std_vssgp)]:
    ratio = float(jnp.mean(std_arr[gap_region]) / jnp.mean(std_arr[trained]))
    print(f"{name}: gap/data std ratio = {ratio:.1f}x")
deep SSGP: gap/data std ratio = 3.2x
deep VSSGP: gap/data std ratio = 1.1x
Reading the figure. What changes between methods is mostly the uncertainty, not the point fit:
- Deep MAP: point estimate, no band; whatever curve the optimiser chose.
- Deep SSGP: closed-form band on the readout posterior, ML-II hyperparameters. The band opens across the gap.
- Deep VSSGP: MC band over the mean-field posterior on every layer's $(\Omega_l, W_l)$. It captures lower-layer uncertainty too. The band is noisier than SSGP's because it is estimated by MC. It stays tight here because we warm-started from MAP, which we did to get a usable optimisation (see §3).
Whether depth itself wins on this target is best judged by comparing to the single-layer SSGP / VSSGP from the SSGP notebook. The architecture used here is much harder to optimise than a single-layer model — depth is worth reaching for when the target has clear compositional structure (warps of warps, hierarchical scales) that a stationary kernel cannot represent.
5. Implementation note: pyrox.nn.DeepVSSGP is the production path
The notebook's ManualDeepVSSGP above is written inline so the tempered ELBO and the per-layer reparameterisation are visible — same pedagogical choice as the single-layer VSSGP class in the SSGP notebook. For production use, the same model is packaged as pyrox.nn.DeepVSSGP:
from pyrox.nn import DeepVSSGP
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
def model(x, y=None):
    net = DeepVSSGP.init(
        in_features=1, hidden_features=HIDDEN, out_features=1,
        depth=DEPTH, n_features=M_FEAT, lengthscale=LENGTHSCALE_INIT,
        pyrox_name="dgp",
    )
    pred = net(x)[:, 0]
    sigma_n = numpyro.sample("sigma_n", numpyro.distributions.LogNormal(jnp.log(0.05), 0.5))
    numpyro.sample("y", numpyro.distributions.Normal(pred, sigma_n), obs=y)
guide = AutoNormal(model, init_scale=0.1)
svi = SVI(model, guide, numpyro.optim.Adam(5e-3), loss=Trace_ELBO(num_particles=1))
state = svi.init(key, x_obs, y_obs)
# ... training loop ...
Each forward call to DeepVSSGP registers $3L$ NumPyro sample sites (layer_{l}.W_freq, layer_{l}.lengthscale, layer_{l}.W_proj for $l = 0, \dots, L-1$) with the pyrox sampling protocol. Standard NumPyro AutoNormal + SVI on top discovers all sites and learns mean-field Gaussian posteriors with no custom variational bookkeeping. The price for that elegance is more JIT-compile overhead than the manual loop above; both are correct.
Sanity check at $L = 1$. A DeepVSSGP of depth 1 reduces to a single VSSGP layer with the same architecture as VariationalFourierFeatures composed with DenseReparameterization. The unit test tests/nn/test_deep_vssgp.py::test_depth_one_reduces_to_single_vssgp enforces this — at depth 1 there are exactly 3 sample sites and the output dimension is out_features directly.
Takeaways
- Stacking RFF layers gives a path toward non-stationary kernels. The composition $\Phi_{L-1}(\Phi_{L-2}(\cdots \Phi_0(x) \cdots))$ produces an effective kernel whose dependence on $x$ is no longer translation-invariant. Whether your problem benefits from this depends on whether it has the kind of compositional / hierarchical structure that a stationary kernel cannot represent.
- Layer-wise marginalisation breaks at depth $> 1$. The Gaussian conjugacy that gives single-layer SSGP its closed form only works because the input to the readout is deterministic. Stacking makes that input itself stochastic, so the marginal likelihood is no longer Gaussian — hence the deep VSSGP doubly-stochastic ELBO.
- Three deep flavours, three trade-offs.
  - Deep MAP: cheapest, no uncertainty.
  - Deep SSGP: closed-form readout band + ML-II hyperparameters, but the lower layers are point estimates (no uncertainty in the deep features).
  - Deep VSSGP (Cutajar et al. 2017): variational posterior over every layer, and an MC predictive band that captures hierarchical uncertainty. Hardest to fit. Started cold, the data-fit gradient on the deep ELBO routinely loses to the KL, and the model fails to fit even the training data. We therefore warm-start the variational means from a trained MAP solution and start log-sigma very tight.
- pyrox.nn.DeepVSSGP packages the full variational stack as a single PyroxModule, with standard NumPyro AutoNormal + SVI on top.
- Benchmark before reaching for depth. For a moderate 1D problem, a single-layer SSGP with a large enough $M$ usually fits as well as or better than depth 2, with a fraction of the optimisation pain. Treat depth as a tool for high-dimensional non-stationary problems with clear hierarchical structure, not as a default upgrade from a single layer.
Where to next.
- For the single-layer versions and the apples-to-apples baseline, see the RFF → SSGP → VSSGP notebook.
- For fixed-Ω + ensemble RFF, see the RFF as Neural Networks notebook.
- The industrial-strength cousin of deep VSSGP for large datasets is doubly-stochastic Deep GPs with inducing points (Salimbeni & Deisenroth, NeurIPS 2017) — the inducing-point analogue of the random-feature stack here.