Non-stationary GEV — a mechanistic ODE
A forced energy-balance trajectory for the warming location
Abstract
The parametric notebook let the GEV location follow a straight line. A line is
the strongest possible assumption about the shape of a trend: it cannot bend,
lag, or saturate, and it extrapolates without limit. Here we replace it with a
mechanistic ordinary differential equation — a one-box energy-balance model
in which the location relaxes toward a forced equilibrium with a physical
response time τ. The trajectory is integrated with diffrax inside the
NumPyro model and inferred end-to-end by NUTS. The result is a grey-box that
sits between the rigid line of Non-stationary GEV — a parametric trend and the
free Gaussian process of Non-stationary GEV — a state-space GP trend: few, physically
interpretable parameters, but a trajectory whose curvature the data can shape.
Non-stationary extremes: a mechanistic ODE
In Non-stationary GEV — a parametric trend the GEV location drifted along a line, $\mu(t) = a + b\,t$. That worked (WAIC preferred it, and the warming was clear), but the linear form is an assumption imposed on the trend, not learned. The real climate does not warm along a ruler: forcing accumulates, and the surface temperature lags it with a memory of years to decades, so the warming curve can be concave, can accelerate, can saturate.
This notebook encodes that mechanism directly. We model the location as

$$\mu(t) = \mu_0 + T(t), \tag{1}$$

where $T(t)$, a temperature anomaly, is the solution of a differential
equation, not a fixed algebraic curve. We integrate that ODE with
diffrax inside the probabilistic
model, so its physical parameters are inferred jointly with the GEV scale and
shape by the same NUTS sampler we have used throughout.
Background
A one-box energy-balance model
The simplest climate model is a single heat reservoir relaxing toward the equilibrium set by a radiative forcing $F(t)$:

$$\tau \frac{dT}{dt} = \beta F(t) - T, \qquad T(t_0) = 0. \tag{2}$$
Two parameters carry all the physics:
- τ: the response time (years). Small τ: the system tracks the forcing almost instantly, so $T$ inherits $F$'s shape. Large τ: the system is sluggish, averaging and lagging the forcing into a smoother, flatter curve.
- β: the sensitivity (°C). With $F$ normalized to end at 1, β is the equilibrium warming the system is heading toward; the realized warming over a finite record is always less, because of the lag.
This is a genuine grey box: far more structured than a free curve (only two knobs, both interpretable), yet, unlike the line, able to bend. The line of the previous notebook is in fact a slowly-forced limiting case of (2).
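The lag described above can be made explicit. As a short derivation (not part of the original notebook), take the relaxation equation $\tau\,dT/dt = \beta F(t) - T$ with $T(t_0) = 0$ and apply an integrating factor:

```latex
\frac{d}{dt}\!\left(e^{t/\tau}\,T\right) = \frac{\beta}{\tau}\,e^{t/\tau}F(t)
\quad\Longrightarrow\quad
T(t) = \frac{\beta}{\tau}\int_{t_0}^{t} e^{-(t-s)/\tau}\,F(s)\,ds .

% Special case: a step forcing F \equiv 1 switched on at t_0
T(t) = \beta\left(1 - e^{-(t-t_0)/\tau}\right).
```

The response is an exponentially weighted average of past forcing with memory τ, scaled by β. In the step case the system has covered about 63% of the distance to equilibrium after one response time, which is why the realized warming over a finite record falls short of β.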
A forcing proxy
The honest input would be an observed effective-forcing series, but to keep every notebook offline we prescribe a stylized anthropogenic ramp that grows faster in recent decades (as greenhouse forcing has),

$$F(t) = \frac{e^{\alpha u} - 1}{e^{\alpha} - 1}, \qquad u = \frac{t - t_0}{t_1 - t_0}, \tag{3}$$

with $\alpha$ fixed and known (we use a mild acceleration $\alpha = 2$). Swapping in a real
forcing time series is a one-line change to forcing(t); the inference machinery
is identical.
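As a rough illustration of that swap, the sketch below builds a forcing function from a tabulated series by linear interpolation and rescales it to run from 0 at the first year to 1 at the last, the same normalization as the proxy. The arrays `obs_years` and `obs_forcing` hold made-up placeholder values, not real data, and `make_forcing` is a hypothetical helper. Inside the JAX model, `jnp.interp` would take the place of `np.interp`.

```python
import numpy as np

# Placeholder values for illustration only; NOT an observed forcing record.
obs_years = np.array([1901.0, 1950.0, 1980.0, 2000.0, 2025.0])
obs_forcing = np.array([0.2, 0.5, 1.1, 1.9, 3.0])  # arbitrary units


def make_forcing(years, values):
    """Return forcing(t): linear interpolation, rescaled to 0 at start and 1 at end."""
    lo, hi = values[0], values[-1]

    def forcing(t):
        return (np.interp(t, years, values) - lo) / (hi - lo)

    return forcing


forcing_tab = make_forcing(obs_years, obs_forcing)
print(forcing_tab(np.array([1901.0, 1965.0, 2025.0])))
```

Keeping the end value at 1 preserves the reading of β as the equilibrium warming for the final-year forcing.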
What the data can and cannot pin down
With ~115 noisy maxima we should expect the amount of warming (β, and the realized end-to-end shift) to be well constrained, but the mechanism (the response time τ) only weakly so: many $(\beta, \tau)$ pairs trace nearly the same century-long curve. We will see exactly that, and it is the right lesson about fitting dynamics to short records.
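To make that trade-off tangible before any fitting, here is a small stand-alone sketch. It integrates the one-box equation with a hand-written RK4 loop in NumPy (a stand-in for the diffrax solve used later) for two response times, then rescales β so both trajectories end at the same warming. The 1901 to 2025 span and α = 2 follow the notebook; the target of 1.5 °C is an assumed illustrative value.

```python
import numpy as np

T0_YEAR, T1_YEAR, ALPHA = 1901.0, 2025.0, 2.0  # record span and forcing shape


def forcing(t):
    u = (t - T0_YEAR) / (T1_YEAR - T0_YEAR)
    return (np.exp(ALPHA * u) - 1.0) / (np.exp(ALPHA) - 1.0)


def response(beta, tau, dt=0.25):
    """Classical RK4 solve of tau*dT/dt = beta*F(t) - T with T(t0) = 0."""
    n = int(round((T1_YEAR - T0_YEAR) / dt))
    t = T0_YEAR + dt * np.arange(n + 1)
    T = np.zeros(n + 1)
    f = lambda s, x: (beta * forcing(s) - x) / tau
    for i in range(n):
        k1 = f(t[i], T[i])
        k2 = f(t[i] + dt / 2, T[i] + dt * k1 / 2)
        k3 = f(t[i] + dt / 2, T[i] + dt * k2 / 2)
        k4 = f(t[i] + dt, T[i] + dt * k3)
        T[i + 1] = T[i] + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return t, T


target = 1.5  # assumed end-of-record warming (degC), for illustration only
curves = {}
for tau in (10.0, 40.0):
    _, unit = response(1.0, tau)   # the ODE is linear in beta
    beta = target / unit[-1]       # rescale so both curves end at `target`
    curves[tau] = (beta, beta * unit)
    print(f"tau = {tau:4.0f} yr  ->  beta = {beta:.2f} degC")

gap = np.max(np.abs(curves[10.0][1] - curves[40.0][1]))
print(f"largest gap between the two trajectories: {gap:.3f} degC")
```

A fast system with a modest β and a slow system with a much larger β stay within a small fraction of a degree of each other across the whole record, so noisy annual maxima have little leverage to tell them apart.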
Setup
from __future__ import annotations

import sys
import pathlib

try:
    import spatial_extremes  # noqa: F401  installed editable in the project venv
except ModuleNotFoundError:
    # Fall back to a source checkout: search cwd and its parents for a src/ dir.
    _here = pathlib.Path.cwd().resolve()
    _roots = (_here, *_here.parents)
    _cands = [r / "src" for r in _roots]
    _cands += [r / "projects" / "spatial_extremes" / "src" for r in _roots]
    _src = next((c for c in _cands if (c / "spatial_extremes").exists()), None)
    if _src is None:
        raise RuntimeError("cannot locate spatial_extremes/src") from None
    sys.path.insert(0, str(_src))

import warnings
warnings.filterwarnings("ignore", message=r".*IProgress.*")
import jax
jax.config.update("jax_enable_x64", True) # GEV + ODE want float64
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import seaborn as sns
import numpyro
import numpyro.distributions as ndist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.initialization import init_to_median
from diffrax import ODETerm, Tsit5, diffeqsolve, SaveAt, ConstantStepSize
from xtremax import GeneralizedExtremeValueDistribution as GEV
from xtremax import gev_log_prob, gev_survival
from spatial_extremes import data, viz
sns.set_theme(style="whitegrid", context="notebook", palette="deep")

# Same long station as the parametric notebook: Albacete, 1901-2025.
years, maxima, meta = data.load_single_station()
place = meta["name"]
y_np = np.asarray(maxima, float)
yr = np.asarray(years, int)
m = y_np.size
y = jnp.asarray(y_np)
y_mean, y_std = float(y.mean()), float(y.std())
t0, t1 = float(yr.min()), float(yr.max())
ts = jnp.asarray(yr, float) # observation times for the ODE solve
ALPHA = 2.0 # forcing acceleration (fixed)
print("source:", "REAL CDS" if meta["is_real"] else "synthetic")
print(f"{place} — id {meta['station_id']}, {m} annual maxima, {yr.min()}-{yr.max()}")

source: REAL CDS
Albacete — id SP000008280, 115 annual maxima, 1901-2025
The forcing and the solver
forcing(t) is the fixed proxy (3); integrate(beta, tau, tq)
solves the energy-balance ODE (2) and returns $T(t)$ at the requested times.
We use a fixed-step solver (ConstantStepSize, ½-year steps): the ODE is a
simple non-stiff relaxation, and a constant step keeps the step count bounded and
differentiable, which matters when NUTS probes tiny τ, where an adaptive
controller would refine to a halt.
def forcing(t):
    # normalized ramp (3): 0 at t0, 1 at t1
    u = (t - t0) / (t1 - t0)
    return (jnp.exp(ALPHA * u) - 1.0) / (jnp.exp(ALPHA) - 1.0)

def integrate(beta, tau, tq):
    # energy-balance ODE (2), started from T(t0) = 0
    def vf(t, T, args):
        return (beta * forcing(t) - T) / tau
    sol = diffeqsolve(ODETerm(vf), Tsit5(), t0=t0, t1=t1, dt0=0.5, y0=0.0,
                      saveat=SaveAt(ts=tq), stepsize_controller=ConstantStepSize(),
                      max_steps=2000)
    return sol.ys

# sanity: the solve is differentiable end-to-end (needed for NUTS gradients)
g = jax.grad(lambda p: integrate(p[0], p[1], ts)[-1])(jnp.array([2.0, 20.0]))
print("dT_end/d(beta,tau) =", [float(v) for v in g])

dT_end/d(beta,tau) = [0.7179999997438439, -0.021273917616068297]
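Because the proxy forcing is an exponential ramp, the linear relaxation equation with $T(t_0) = 0$ also has a closed-form solution, which gives an independent check on the printed gradient. The expression below was derived for this sketch and is not taken from the notebook's code. It assumes the record spans 1901 to 2025 with α = 2, and the helper name `T_exact` is hypothetical.

```python
import math

T0_YEAR, T1_YEAR, ALPHA = 1901.0, 2025.0, 2.0


def T_exact(t, beta, tau):
    """Closed form of tau*dT/dt = beta*F(t) - T, T(t0) = 0, for the exponential-ramp F."""
    k = ALPHA / (T1_YEAR - T0_YEAR)  # growth rate of the ramp (1/yr)
    s = t - T0_YEAR
    a = k * tau
    bracket = math.exp(k * s) / (1 + a) - 1 + a / (1 + a) * math.exp(-s / tau)
    return beta * bracket / (math.exp(ALPHA) - 1)


# T is linear in beta, so dT_end/dbeta equals T_end evaluated at beta = 1.
dT_dbeta = T_exact(T1_YEAR, 1.0, 20.0)

# Central finite difference in tau at (beta, tau) = (2, 20).
h = 1e-4
dT_dtau = (T_exact(T1_YEAR, 2.0, 20.0 + h) - T_exact(T1_YEAR, 2.0, 20.0 - h)) / (2 * h)

print(dT_dbeta, dT_dtau)
```

Both values should agree with the autodiff gradient printed above (about 0.718 and -0.0213), which suggests the half-year fixed step resolves this ODE comfortably at τ = 20.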
Before fitting, look at what the model can say. For a fixed forcing, the response time τ alone reshapes the trajectory — fast systems (small τ) inherit the forcing’s acceleration, sluggish ones (large τ) lag into a straighter, flatter curve. This is the flexibility the data will exploit.
tgrid = jnp.linspace(t0, t1, 200)
fig, (axF, axT) = plt.subplots(1, 2, figsize=(11, 4.2))
axF.plot(np.asarray(tgrid), np.asarray(forcing(tgrid)), color="0.3", lw=2)
axF.set(xlabel="year", ylabel="normalized forcing F(t)",
title=f"Prescribed forcing proxy (alpha={ALPHA:g})")
for tau, c in zip((5.0, 20.0, 60.0), sns.color_palette("viridis", 3)):
    axT.plot(np.asarray(tgrid), np.asarray(integrate(1.0, tau, tgrid)),
             color=c, lw=2, label=f"tau = {tau:g} yr")
axT.set(xlabel="year", ylabel="response T(t) (beta = 1)",
title="Same forcing, different memory")
axT.legend()
fig.tight_layout()
plt.show()
The model
The probabilistic model is the GEV likelihood with a location that is the ODE
solution. Inside numpyro we sample the physical parameters, call integrate
to get $T(t)$ at every observation year, and feed $\mu_0 + T(t)$ to the GEV.
NUTS differentiates straight through the solver.
def gev_ode(obs):
    mu0 = numpyro.sample("mu0", ndist.Normal(y_mean, 5.0))  # baseline location
    beta = numpyro.sample("beta", ndist.Normal(0.0, 3.0))  # warming amplitude (degC)
    tau = numpyro.sample("tau", ndist.LogNormal(np.log(20.0), 0.8))  # response time (yr)
    log_sigma = numpyro.sample("log_sigma", ndist.Normal(np.log(y_std), 0.5))
    xi = numpyro.sample("xi", ndist.Normal(0.0, 0.25))
    T = integrate(beta, tau, ts)  # anomaly at each year
    numpyro.sample("obs", GEV(loc=mu0 + T, scale=jnp.exp(log_sigma),
                              concentration=xi), obs=obs)

mcmc = MCMC(NUTS(gev_ode, target_accept_prob=0.95, init_strategy=init_to_median),
            num_warmup=1000, num_samples=1000, num_chains=2,
            chain_method="sequential", progress_bar=False)
mcmc.run(jr.PRNGKey(0), y)
mcmc.print_summary()
post = mcmc.get_samples()
print(f"divergences: {int(mcmc.get_extra_fields()['diverging'].sum())}")
mean std median 5.0% 95.0% n_eff r_hat
beta 2.56 1.01 2.43 0.96 4.15 784.18 1.00
log_sigma 0.39 0.07 0.38 0.26 0.50 1364.64 1.00
mu0 37.53 0.24 37.53 37.14 37.91 1404.74 1.00
tau 27.58 20.45 22.21 3.35 54.81 973.87 1.00
xi -0.09 0.05 -0.09 -0.17 -0.02 1240.97 1.00
Number of divergences: 1
divergences: 1
Reading the physics
β is the equilibrium warming the climate is heading toward; the realized warming over the record, $T(t_1) - T(t_0)$, is what actually happened and is the quantity to compare with the parametric notebook's +1.4°C. The response time τ is far less certain: its posterior is broad and right-skewed, the signature of a mechanism only weakly identified by a century of noisy maxima.
def ci(a): return np.quantile(a, [0.025, 0.5, 0.975])
beta_s = np.asarray(post["beta"]); tau_s = np.asarray(post["tau"])
# realized warming over the record for each draw
sub = jr.choice(jr.PRNGKey(1), beta_s.size, (600,), replace=False)
T_ends = np.asarray(jax.vmap(lambda b, ta: integrate(b, ta, jnp.array([t0, t1]))
)(post["beta"][sub], post["tau"][sub]))
realized = T_ends[:, 1] - T_ends[:, 0]
blo, bmd, bhi = ci(beta_s); rlo, rmd, rhi = ci(realized)
tlo, tmd, thi = ci(tau_s)
print(f"beta (equilibrium warming) : {bmd:.2f} degC (95% CI {blo:.2f}, {bhi:.2f})")
print(f"realized warming {yr.min()}-{yr.max()}: {rmd:.2f} degC (95% CI {rlo:.2f}, {rhi:.2f})")
print(f"tau (response time) : {tmd:.0f} yr (95% CI {tlo:.0f}, {thi:.0f})")
print(f"P(beta > 0 | data) = {float((beta_s > 0).mean()):.3f}")

beta (equilibrium warming) : 2.43 degC (95% CI 0.95, 5.00)
realized warming 1901-2025: 1.67 degC (95% CI 0.60, 2.54)
tau (response time) : 22 yr (95% CI 5, 79)
P(beta > 0 | data) = 0.998
The joint posterior makes the weak identifiability concrete: β and τ trade off along a ridge — a stronger, slower response looks much like a weaker, faster one over 124 years — while the realized warming (the integral the data actually see) is pinned far more tightly.
draws = pd.DataFrame({"beta (degC)": beta_s[sub], "tau (yr)": tau_s[sub],
"realized (degC)": realized})
g = sns.pairplot(draws, corner=True, diag_kind="kde",
plot_kws=dict(s=8, alpha=0.25, edgecolor=None),
diag_kws=dict(fill=True))
g.figure.suptitle("Joint posterior: amplitude vs memory vs realized warming", y=1.02)
plt.show()
The fitted trajectory and the line it replaces
Now the payoff of a mechanistic trend: a curved warming trajectory with uncertainty. We overlay the posterior $\mu(t)$ on the maxima and, for a direct contrast, refit the linear model of Non-stationary GEV — a parametric trend and draw its straight $\mu(t)$ through the same data. The ODE bends where the line cannot.
# posterior trajectory band on a dense grid
mu_draws = (post["mu0"][sub][:, None]
+ jax.vmap(lambda b, ta: integrate(b, ta, tgrid))(
post["beta"][sub], post["tau"][sub]))
mu_draws = np.asarray(mu_draws)
mu_med, mu_lo, mu_hi = (np.quantile(mu_draws, q, 0) for q in (0.5, 0.025, 0.975))
# quick refit of the linear-location model for comparison
z_np = (yr - yr.mean()) / yr.std()
z = jnp.asarray(z_np)
def gev_lin(obs, zc):
    a = numpyro.sample("a", ndist.Normal(y_mean, 5.0))
    b = numpyro.sample("b", ndist.Normal(0.0, 2.0))
    ls = numpyro.sample("ls", ndist.Normal(np.log(y_std), 0.5))
    xi = numpyro.sample("xi", ndist.Normal(0.0, 0.25))
    numpyro.sample("obs", GEV(loc=a + b * zc, scale=jnp.exp(ls),
                              concentration=xi), obs=obs)

lin = MCMC(NUTS(gev_lin, target_accept_prob=0.99, init_strategy=init_to_median),
           num_warmup=1000, num_samples=1000, num_chains=2,
           chain_method="vectorized", progress_bar=False)
lin.run(jr.PRNGKey(2), y, z)
plin = lin.get_samples()
zg = (np.asarray(tgrid) - yr.mean()) / yr.std()
mu_lin = np.median(plin["a"])[None] + np.median(plin["b"]) * zg
fig, ax = plt.subplots(figsize=(10, 4.8))
ax.scatter(yr, y_np, color="0.35", s=26, zorder=4, label="annual maxima")
ax.fill_between(np.asarray(tgrid), mu_lo, mu_hi, color="#C44E52", alpha=0.20,
label="ODE mu(t): 95% band")
ax.plot(np.asarray(tgrid), mu_med, color="#C44E52", lw=2.2, label="ODE mu(t): median")
ax.plot(np.asarray(tgrid), mu_lin, color="#4C72B0", lw=2, ls="--",
label="linear mu(t) (NB10)")
ax.set(xlabel="year", ylabel="location mu(t) (degC)",
title=f"{place}: mechanistic warming trajectory vs the straight line")
ax.legend(loc="upper left", fontsize=9)
plt.show()
Does the mechanism earn its keep? WAIC
The ODE is barely richer than the linear model: its location has three parameters ($\mu_0$, $\beta$, $\tau$) against the line's two (intercept, slope), and WAIC penalizes effective complexity, so the contest is fair. We score the stationary, linear, and ODE models on the same maxima.
def waic(ll):
    # ll: (draws, observations) pointwise log-likelihood
    S = ll.shape[0]
    lppd = logsumexp(ll, axis=0) - jnp.log(S)
    return float(-2.0 * (lppd.sum() - ll.var(axis=0).sum()))

# stationary baseline (constant location)
def gev_stat(obs):
    mu = numpyro.sample("mu", ndist.Normal(y_mean, 5.0))
    sigma = numpyro.sample("sigma", ndist.HalfNormal(y_std))
    xi = numpyro.sample("xi", ndist.Normal(0.0, 0.25))
    numpyro.sample("obs", GEV(loc=mu, scale=sigma, concentration=xi), obs=obs)

stat = MCMC(NUTS(gev_stat, target_accept_prob=0.99, init_strategy=init_to_median),
            num_warmup=1000, num_samples=1000, num_chains=2,
            chain_method="vectorized", progress_bar=False)
stat.run(jr.PRNGKey(3), y)
ps = stat.get_samples()
ll_stat = gev_log_prob(y[None, :], ps["mu"][:, None], ps["sigma"][:, None],
ps["xi"][:, None])
ll_lin = gev_log_prob(y[None, :], plin["a"][:, None] + plin["b"][:, None] * z[None, :],
jnp.exp(plin["ls"])[:, None], plin["xi"][:, None])
T_all = jax.vmap(lambda b, ta: integrate(b, ta, ts))(post["beta"], post["tau"])
ll_ode = gev_log_prob(y[None, :], post["mu0"][:, None] + T_all,
jnp.exp(post["log_sigma"])[:, None], post["xi"][:, None])
tab = pd.DataFrame({"WAIC": [waic(ll_stat), waic(ll_lin), waic(ll_ode)]},
index=["stationary", "linear mu(t)", "ODE mu(t)"])
tab["dWAIC"] = tab["WAIC"] - tab["WAIC"].min()
print(tab.round(2).to_string())

              WAIC  dWAIC
stationary 448.04 9.21
linear mu(t) 442.72 3.89
ODE mu(t) 438.83 0.00
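The waic helper above folds two quantities into one number: the log pointwise predictive density and a penalty equal to the summed posterior variance of each observation's log-likelihood. The NumPy sketch below separates them on a toy Gaussian-mean problem so the penalty can be read off directly. The data and the "posterior draws" are synthetic stand-ins invented for the illustration, and `waic_parts` is a hypothetical name.

```python
import numpy as np


def waic_parts(ll):
    """ll: (S draws, n observations) matrix of pointwise log-likelihoods."""
    S = ll.shape[0]
    mx = ll.max(axis=0)
    # numerically stable log of the posterior-mean likelihood, per observation
    lppd = mx + np.log(np.exp(ll - mx).sum(axis=0)) - np.log(S)
    p_waic = ll.var(axis=0)  # effective-parameter penalty, per observation
    total_lppd, total_p = lppd.sum(), p_waic.sum()
    return {"lppd": total_lppd, "p_waic": total_p,
            "waic": -2.0 * (total_lppd - total_p)}


rng = np.random.default_rng(0)
y_toy = rng.normal(0.0, 1.0, size=50)                             # synthetic data
mu_draws = rng.normal(y_toy.mean(), 1 / np.sqrt(50), size=2000)   # stand-in posterior for the mean
ll_toy = -0.5 * np.log(2 * np.pi) - 0.5 * (y_toy[None, :] - mu_draws[:, None]) ** 2

parts = waic_parts(ll_toy)
print(parts)
```

For this one-parameter toy model the penalty lands near 1, which is the sense in which WAIC counts effective rather than nominal parameters.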
Time-varying return levels
As in the parametric notebook the return level is now a function of year. We read (5) at the ODE location for the first and last years, and place each observed maximum at its own year's climate (its return period under the fitted GEV for that year), coloured by year, so the cloud rides the ODE trajectory from the early envelope up to the recent one.
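For reference, the standard GEV return level for a return period of $T$ years is $z_T = \mu + \frac{\sigma}{\xi}\left[(-\log(1 - 1/T))^{-\xi} - 1\right]$, with the Gumbel form as $\xi \to 0$. The stand-alone sketch below implements it in plain Python. It assumes the library's `concentration` argument follows this usual sign convention for ξ, and the parameter values are rounded illustrations, not fitted output.

```python
import math


def gev_return_level(T, mu, sigma, xi):
    """Level exceeded with probability 1/T in a given year under GEV(mu, sigma, xi)."""
    yp = -math.log(1.0 - 1.0 / T)
    if abs(xi) < 1e-9:  # Gumbel limit
        return mu - sigma * math.log(yp)
    return mu + sigma / xi * (yp ** (-xi) - 1.0)


def gev_cdf(z, mu, sigma, xi):
    """GEV distribution function, used here only to check the return level."""
    if abs(xi) < 1e-9:
        return math.exp(-math.exp(-(z - mu) / sigma))
    t = 1.0 + xi * (z - mu) / sigma
    if t <= 0:  # outside the support
        return 0.0 if xi > 0 else 1.0
    return math.exp(-t ** (-1.0 / xi))


sigma, xi = 1.5, -0.1            # illustrative values only
for mu in (37.5, 39.0):          # two hypothetical locations, 1.5 degC apart
    levels = [round(gev_return_level(T, mu, sigma, xi), 2) for T in (10, 50, 100)]
    print(f"mu = {mu}: z_10, z_50, z_100 = {levels}")
```

With scale and shape held fixed, a shift in location moves every return level by the same amount, so in the location-only model the early and recent curves are vertical translates of each other.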
# median-parameter ODE location at each year, and at the two reference years
mu0_m = float(np.median(post["mu0"]))
sig_m = float(np.exp(np.median(post["log_sigma"])))
xi_m = float(np.median(post["xi"]))
beta_m, tau_m = float(np.median(beta_s)), float(np.median(tau_s))
loc_year = mu0_m + np.asarray(integrate(beta_m, tau_m, ts))
loc_first, loc_last = float(loc_year[0]), float(loc_year[-1])
# each maximum at its own-year climate
p_t = np.asarray(gev_survival(y, jnp.asarray(loc_year), sig_m, xi_m))
p_t = np.clip(p_t, 0.5 / m, None)
T_own = 1.0 / p_t
t_lo = max(1.02, float(T_own.min()))
periods = jnp.logspace(np.log10(t_lo), 3, 80)
idx = jr.choice(jr.PRNGKey(5), post["mu0"].size, (600,), replace=False)
def rl_curve_at(tref):
    # each posterior draw's location at year tref = mu0 + T(tref; beta, tau)
    Tref = jax.vmap(lambda b, ta: integrate(b, ta, jnp.array([float(tref)]))[0])(
        post["beta"][idx], post["tau"][idx])
    locs = post["mu0"][idx] + Tref
    sig = jnp.exp(post["log_sigma"])[idx]
    xi = post["xi"][idx]
    rl = jax.vmap(lambda i: GEV(loc=locs[i], scale=sig[i], concentration=xi[i])
                  .return_level(periods))(jnp.arange(locs.size))
    return np.asarray(rl)

fig, ax = plt.subplots(figsize=(8.8, 5))
for tref, c, lab in [(t0, "#4C72B0", f"climate of {yr.min()}"),
                     (t1, "#C44E52", f"climate of {yr.max()}")]:
    rl = rl_curve_at(tref)
    med, lo, hi = (np.quantile(rl, q, 0) for q in (0.5, 0.025, 0.975))
    ax.fill_between(np.asarray(periods), lo, hi, color=c, alpha=0.16)
    ax.plot(np.asarray(periods), med, color=c, lw=2, label=lab)
sc = ax.scatter(T_own, y_np, c=yr, cmap="coolwarm", s=28, zorder=6,
edgecolor="0.25", linewidth=0.3, label="observed maxima")
fig.colorbar(sc, ax=ax, pad=0.02).set_label("year of the maximum")
ax.set_xscale("log")
ax.set(xlabel="return period T (years)", ylabel=r"return level $z_T$ (degC)",
title=f"{place}: return levels under the ODE trajectory")
ax.legend(loc="upper left")
plt.show()
Extension: couple the scale to the same state
A clean feature of the grey box is that one latent state can drive several GEV parameters. We let the log-scale share the trajectory, $\log \sigma(t) = s_0 + b_\sigma T(t)$, so the spread of summer maxima rises (or falls) with the same warming signal, and check whether the data want it.
def gev_ode_sigma(obs):
    mu0 = numpyro.sample("mu0", ndist.Normal(y_mean, 5.0))
    beta = numpyro.sample("beta", ndist.Normal(0.0, 3.0))
    tau = numpyro.sample("tau", ndist.LogNormal(np.log(20.0), 0.8))
    s0 = numpyro.sample("s0", ndist.Normal(np.log(y_std), 0.5))  # baseline log-scale
    b_sig = numpyro.sample("b_sig", ndist.Normal(0.0, 0.3))  # log-scale change per degC of T
    xi = numpyro.sample("xi", ndist.Normal(0.0, 0.25))
    T = integrate(beta, tau, ts)
    numpyro.sample("obs", GEV(loc=mu0 + T, scale=jnp.exp(s0 + b_sig * T),
                              concentration=xi), obs=obs)

mcmc_s = MCMC(NUTS(gev_ode_sigma, target_accept_prob=0.95, init_strategy=init_to_median),
              num_warmup=1000, num_samples=1000, num_chains=2,
              chain_method="sequential", progress_bar=False)
mcmc_s.run(jr.PRNGKey(4), y)
ps2 = mcmc_s.get_samples()
bsig = np.asarray(ps2["b_sig"]); s_lo, s_md, s_hi = ci(bsig)
T_all2 = jax.vmap(lambda b, ta: integrate(b, ta, ts))(ps2["beta"], ps2["tau"])
ll_ode_s = gev_log_prob(y[None, :], ps2["mu0"][:, None] + T_all2,
jnp.exp(ps2["s0"][:, None] + ps2["b_sig"][:, None] * T_all2),
ps2["xi"][:, None])
print(f"b_sig (scale coupling) = {s_md:+.3f} (95% CI {s_lo:+.3f}, {s_hi:+.3f})")
print(f"WAIC: ODE mu only {waic(ll_ode):.1f} | ODE mu+sigma {waic(ll_ode_s):.1f}")
print("scale coupling " + ("helps" if waic(ll_ode_s) < waic(ll_ode) - 2
                           else "is not supported by the data"))

b_sig (scale coupling) = -0.029 (95% CI -0.280, +0.274)
WAIC: ODE mu only 438.8 | ODE mu+sigma 440.4
scale coupling is not supported by the data
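Since the coupling acts on the log-scale, the coefficient reads most naturally as a multiplicative change in the GEV scale per degree of the latent state. The short sketch below converts the printed summary values. The numbers are copied from the output above, and `scale_factor` is a hypothetical helper.

```python
import math

# Values copied from the printed summary above.
b_sig_median, b_sig_lo, b_sig_hi = -0.029, -0.280, 0.274


def scale_factor(b_sig, dT=1.0):
    """Multiplicative change in the GEV scale for a dT-degree rise in the state T."""
    return math.exp(b_sig * dT)


for name, b in [("median", b_sig_median), ("2.5%", b_sig_lo), ("97.5%", b_sig_hi)]:
    print(f"{name:>6}: scale x {scale_factor(b):.3f} per degC")
```

The interval runs from roughly a quarter narrower to about a third wider per degree and comfortably includes a factor of 1, consistent with the WAIC verdict that the data do not ask for the extra coupling.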
Recap & where next
We swapped the straight line for a mechanism. The GEV location is now the
solution of a forced energy-balance ODE (2), integrated with diffrax
inside the model and inferred by NUTS. The realized warming (median 1.67°C) is
consistent with the parametric fit's +1.4°C, but the trajectory can curve, and
the posterior is candid about what a century of maxima can support: the warming
amplitude is well determined, the response time τ much less so.
The grey box buys interpretability and physically-bounded extrapolation at the price of a committed functional form: if the true trend does not look like a one-box relaxation, the ODE cannot represent it. The final notebook, Non-stationary GEV — a state-space GP trend, goes to the opposite pole — a Gaussian process over time that assumes only smoothness and lets the data draw the warming curve freely — and puts all three trend models, line, ODE, and GP, on the same axes.