Gaussian copula on the observations — modelling spatial dependence
Annual temperature extremes over Spain — part 4: a Gaussian copula on the residuals¶
This is the fourth notebook in the Kronecker-GP-with-GEV thread, following 03
What changes is the observation model. nbs 02–03 assumed conditional independence: given the latent fields, at different stations are independent draws from their marginal GEVs. Real spatial extremes don’t behave like that — synoptic weather systems span hundreds of kilometres, so a hot summer in Madrid is also (with high probability) a hot summer in Toledo. This notebook adds a Gaussian copula on the per-year cross-station residuals, parameterised by a two-range exponential dependogram. The marginals stay GEV; only the joint structure changes.
Two practical consequences of getting this right:
- Joint return periods stop being absurd. Under conditional independence, the probability that 5 nearby stations all simultaneously exceed their 25-year return level is approximately . The true (correlated) probability is several orders of magnitude larger — that’s the difference between “essentially impossible” and “expect once a century in the cluster”.
- Posterior intervals become honest. Treating correlated draws as independent inflates the effective sample size, giving artificially tight posteriors. The copula corrects this — wider but well-calibrated intervals.
Background — Gaussian copula in three steps¶
The factorisation¶
Sklar’s theorem guarantees that for a continuous joint distribution with continuous marginals, the joint density factorises uniquely as
where are the marginals (already done — non-stationary GEV from nb 03), are their CDFs, and is the copula density on the unit hypercube. Choosing a Gaussian copula means
with a correlation matrix (unit diagonal) determining how the latent normals co-vary. The copula density vanishes () iff , which recovers conditional independence. Any deviation tells us the data are jointly non-trivial after marginalisation.
Two-range dependogram¶
We parameterise — correlation as a function of inter-station distance — using the two-range exponential mixture from jej_vc_snippets/extremes/models/temp_gevd_gp_copula.py:
Two ranges because synoptic extremes have at least two characteristic scales: a tight ~50 km convective band (high local correlation, decays fast) and a broader ~500 km synoptic envelope (lower correlation but persists across the whole peninsula). One range can’t capture both; the convex mixture interpolates smoothly between them.
Per-year likelihood, conditional on the latents¶
Conditional on the latent fields and the copula parameters , years are independent (the temporal trend is encoded in the deterministic piece of the location; whatever’s left is i.i.d. across years). Stations within a year are jointly Gaussian-coupled. So
The right-hand sum is the marginal log-likelihood from nb 03. The left-hand term is new and non-factorisable across stations within a year.
Variational treatment — plug-in copula¶
Marginal: integrate over via the same 3-D Gauss–Hermite as nb 03, per .
Copula: integrating over exactly would require an -dimensional Gauss-Hermite — intractable. Instead, we plug in the posterior means at the point of evaluation:
This is exactly the Inference-Functions-for-Margins (IFM) approximation from the spatial-extremes literature (Joe 2005), brought into a variational setting. The bias is — small when posterior variance is small, which it is for in our setup. (ξ is the noisy one, but it sits inside a CDF so its sensitivity is bounded.)
The full ELBO is then
Synthetic data — with spatial correlation¶
This is the substantive change from nb 03. We can no longer draw independent GEVs per (s, t); we must draw jointly correlated observations whose marginals match the per-station GEV . The standard recipe — the inverse of Sklar’s theorem — does this in three steps:
Step 1: sample the correlated latent normals for each year :
Step 2: probability-integral-transform to the unit hypercube via the standard normal CDF:
By construction, each is uniform on marginally, while jointly has Gaussian-copula dependence with correlation .
Step 3: marginal-CDF inverse at the per-station GEV parameters to land back in observation space:
Sanity check. Each is marginally regardless of (uniformity of is preserved through the inverse CDF). If we set , step 1 gives independent ’s, ’s, and ’s — recovering nb 03’s data-generating process. So the Gaussian-copula extension is a strict generalisation.
We pick true copula parameters
which gives a healthy mix of fast-decay near-field correlation and slow-decay long-range correlation across the Iberian peninsula.
Setup¶
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsla
import jax.scipy.stats as jss
import matplotlib.pyplot as plt
import numpy as np
import optax
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import lineax as lx
import numpyro.distributions as nd
from numpyro.distributions import constraints
from jaxtyping import Array, Float
from scipy.stats import genextreme
import gaussx
from pyrox.gp._src.kernels import matern_kernel
jax.config.update("jax_enable_x64", True)
EPS_U = 1e-6 # PIT clipping for numerical stability of Phi^{-1}
JITTER_R = 1e-6 # diagonal nudge to keep R strictly PD/anaconda/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Stations + GMST (same as nbs 02–03)¶
SPAIN_BBOX = (-9.5, 3.5, 36.0, 43.8)
def _on_land(lon, lat):
west_of_galicia = (lon < -8.8) & (lat > 42.3)
below_africa = lat < 36.2
east_mediterranean = (lon > 2.5) & (lat < 39.0)
return ~(west_of_galicia | below_africa | east_mediterranean)
def _sample_stations(key, n_target):
cands = jr.uniform(
key, (n_target * 4, 2),
minval=jnp.array([SPAIN_BBOX[0], SPAIN_BBOX[2]]),
maxval=jnp.array([SPAIN_BBOX[1], SPAIN_BBOX[3]]),
)
return cands[_on_land(cands[:, 0], cands[:, 1])][:n_target]
S = 40
T = 40
YEAR_0 = 1985
YEARS = jnp.arange(YEAR_0, YEAR_0 + T)
key = jr.PRNGKey(2024)
key, key_st = jr.split(key)
stations = _sample_stations(key_st, S)
lon_st = stations[:, 0]
lat_st = stations[:, 1]
def synthetic_gmst(years, key):
t_norm = (years - years[0]) / jnp.maximum(years[-1] - years[0], 1)
trend = 0.1 + 0.8 * t_norm
alpha = 0.6
white = 0.05 * jr.normal(key, (years.shape[0],))
def step(prev, eps):
cur = alpha * prev + eps
return cur, cur
_, red = jax.lax.scan(step, white[0], white[1:])
return trend + jnp.concatenate([white[:1], red])
key, key_g = jr.split(key)
gmst = synthetic_gmst(YEARS, key_g)
d_vec = gmst - jnp.mean(gmst)Pairwise station distances (degrees)¶
def pairwise_distance(coords):
diff = coords[:, None, :] - coords[None, :, :]
return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
D_mat = pairwise_distance(stations) # (S, S), Euclidean in lon/lat
print(f"station distance range: {float(D_mat[D_mat > 0].min()):.2f} -> {float(D_mat.max()):.2f} degrees")station distance range: 0.10 -> 13.01 degrees
Truth — four latent fields (same as nb 03) plus the copula¶
def _matern32_gram(X1, X2, var, ls):
return matern_kernel(X1, X2, jnp.asarray(var), jnp.asarray(ls), nu=1.5)
def draw_from_gp(K, key, jitter=1e-4):
L = jnp.linalg.cholesky(K + jitter * jnp.eye(K.shape[0]))
return L @ jr.normal(key, (K.shape[0],))
TRUTH = {
"mu0": 35.0,
"beta0": 1.2,
"logsig0": jnp.log(1.8),
"xi0": 0.12,
"k_mu_var": 4.0, "k_mu_ls": 2.0,
"k_beta_var": 0.25, "k_beta_ls": 3.0,
"k_sig_var": 0.05, "k_sig_ls": 3.0,
"k_xi_var": 0.003, "k_xi_ls": 3.0,
# Copula truth
"c0": 0.5,
"c1": 0.5, # degrees, ~55 km
"c2": 4.0, # degrees, ~440 km
}
K_mu_t = _matern32_gram(stations, stations, TRUTH["k_mu_var"], TRUTH["k_mu_ls"])
K_b_t = _matern32_gram(stations, stations, TRUTH["k_beta_var"], TRUTH["k_beta_ls"])
K_sig_t = _matern32_gram(stations, stations, TRUTH["k_sig_var"], TRUTH["k_sig_ls"])
K_xi_t = _matern32_gram(stations, stations, TRUTH["k_xi_var"], TRUTH["k_xi_ls"])
key, *subs = jr.split(key, 5)
mu_truth = draw_from_gp(K_mu_t, subs[0])
beta_truth = TRUTH["beta0"] + draw_from_gp(K_b_t, subs[1])
logsig_truth = TRUTH["logsig0"] + draw_from_gp(K_sig_t, subs[2])
xi_truth = TRUTH["xi0"] + draw_from_gp(K_xi_t, subs[3])
sigma_truth = jnp.exp(logsig_truth)
# Latent location field
f_truth = TRUTH["mu0"] + mu_truth[:, None] + beta_truth[:, None] * d_vec[None, :] # (S, T)True copula correlation matrix and its dependogram¶
def correlation_matrix(c0, c1, c2, dist):
n = dist.shape[0]
R = c0 * jnp.exp(-dist / c1) + (1.0 - c0) * jnp.exp(-dist / c2)
R = R.at[jnp.diag_indices(n)].set(1.0 + JITTER_R)
return R
R_truth = correlation_matrix(TRUTH["c0"], TRUTH["c1"], TRUTH["c2"], D_mat)
print(f"R* shape: {R_truth.shape} PD check: min eigval = {float(jnp.linalg.eigvalsh(R_truth).min()):.3e}")R* shape: (40, 40) PD check: min eigval = 9.946e-02
Inverse-PIT simulation (the math from above, in code)¶
def gev_inverse_cdf(u, loc, scale, shape):
"""Quantile function of GEV(loc, scale, shape). Handles xi -> 0 (Gumbel) limit."""
log_u = jnp.log(u)
small_xi = jnp.abs(shape) < 1e-6
safe_shape = jnp.where(small_xi, 1.0, shape)
y_gev = loc + (scale / safe_shape) * ((-log_u) ** (-safe_shape) - 1.0)
y_gumbel = loc - scale * jnp.log(-log_u)
return jnp.where(small_xi, y_gumbel, y_gev)
# Step 1: z_t ~ N(0, R*) for each year, jointly across stations
key, key_z = jr.split(key)
L_truth = jnp.linalg.cholesky(R_truth)
xi_iid = jr.normal(key_z, (T, S)) # (T, S) iid normal
z_truth = xi_iid @ L_truth.T # (T, S) correlated normals — Cov = L L^T = R*
# Step 2: PIT to uniform via Phi
u_truth = jss.norm.cdf(z_truth) # (T, S)
# Step 3: inverse GEV at the marginal parameters
# broadcast: tau (S, T) -> (T, S); sigma_truth (S,); xi_truth (S,)
y_obs = gev_inverse_cdf(
u_truth,
loc=f_truth.T, # (T, S)
scale=sigma_truth[None, :],
shape=xi_truth[None, :],
)
# Reshape to the (S, T) convention used everywhere else
y_obs = y_obs.T # (S, T)
print(f"y_obs shape: {y_obs.shape} range: [{float(y_obs.min()):.1f}, {float(y_obs.max()):.1f}] °C")
# Empirical correlation of the underlying z field — should be close to R_truth
emp_corr_z = jnp.corrcoef(z_truth.T)
print(f"empirical |emp_corr_z - R_truth| max: {float(jnp.max(jnp.abs(emp_corr_z - R_truth))):.3f}")
print(f" (rough: 1/sqrt(T) ~ {1.0 / np.sqrt(T):.3f}, so a few percent is expected)")y_obs shape: (40, 40) range: [28.1, 60.1] °C
empirical |emp_corr_z - R_truth| max: 0.502
(rough: 1/sqrt(T) ~ 0.158, so a few percent is expected)
GEV likelihood (verbatim from nb 03)¶
class GeneralizedExtremeValue(nd.Distribution):
arg_constraints = {
"loc": constraints.real, "scale": constraints.positive, "shape": constraints.real,
}
support = constraints.real
reparametrized_params = ["loc", "scale"]
def __init__(self, loc, scale, shape, *, validate_args=None):
self.loc = jnp.asarray(loc)
self.scale = jnp.asarray(scale)
self.shape = jnp.asarray(shape)
bs = jax.lax.broadcast_shapes(jnp.shape(self.loc), jnp.shape(self.scale), jnp.shape(self.shape))
super().__init__(batch_shape=bs, validate_args=validate_args)
def log_prob(self, value):
z = (value - self.loc) / self.scale
small_xi = jnp.abs(self.shape) < 1e-6
safe_shape = jnp.where(small_xi, 1.0, self.shape)
arg = 1.0 + safe_shape * z
safe_arg = jnp.where(arg > 0, arg, 1.0)
t = safe_arg ** (-1.0 / safe_shape)
gev_lp = -jnp.log(self.scale) - (1.0 + 1.0 / safe_shape) * jnp.log(safe_arg) - t
gev_lp = jnp.where(arg > 0, gev_lp, -jnp.inf)
gumbel_lp = -jnp.log(self.scale) - z - jnp.exp(-z)
return jnp.where(small_xi, gumbel_lp, gev_lp)
def sample(self, key, sample_shape=()):
u = jr.uniform(key, sample_shape + self.batch_shape, minval=1e-12, maxval=1.0)
small_xi = jnp.abs(self.shape) < 1e-6
safe_shape = jnp.where(small_xi, 1.0, self.shape)
gev_draw = self.loc + self.scale * ((-jnp.log(u)) ** (-safe_shape) - 1.0) / safe_shape
gumbel_draw = self.loc - self.scale * jnp.log(-jnp.log(u))
return jnp.where(small_xi, gumbel_draw, gev_draw)
def gev_cdf(y, loc, scale, shape):
"""Closed-form GEV CDF (handles the xi -> 0 limit)."""
z = (y - loc) / scale
small_xi = jnp.abs(shape) < 1e-6
safe_shape = jnp.where(small_xi, 1.0, shape)
arg = 1.0 + safe_shape * z
safe_arg = jnp.where(arg > 0, arg, 1e-12)
cdf_gev = jnp.exp(-safe_arg ** (-1.0 / safe_shape))
cdf_gumbel = jnp.exp(-jnp.exp(-z))
return jnp.where(small_xi, cdf_gumbel, cdf_gev)
# Sanity-check vs scipy
y_g = jnp.linspace(-2.0, 25.0, 60)
ours = GeneralizedExtremeValue(3.0, 1.5, 0.2).log_prob(y_g)
theirs = genextreme.logpdf(np.asarray(y_g), c=-0.2, loc=3.0, scale=1.5)
print(f"GEV vs scipy (xi=0.2) max |diff| = {float(jnp.max(jnp.abs(ours - theirs))):.2e}")GEV vs scipy (xi=0.2) max |diff| = 1.14e-13
Inspect the truth — four fields + the copula correlation as a function of distance¶
def plot_stations(ax, values, *, cmap, vlim, label):
ax.set_extent(
[SPAIN_BBOX[0] - 1, SPAIN_BBOX[1] + 1, SPAIN_BBOX[2] - 1, SPAIN_BBOX[3] + 1],
crs=ccrs.PlateCarree(),
)
ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
ax.add_feature(cfeature.BORDERS, linestyle=":", linewidth=0.5)
ax.add_feature(cfeature.OCEAN, facecolor="#e8f3fa")
ax.add_feature(cfeature.LAND, facecolor="#fdf9ec")
sc = ax.scatter(
np.asarray(lon_st), np.asarray(lat_st),
c=np.asarray(values), cmap=cmap, vmin=vlim[0], vmax=vlim[1],
s=55, edgecolors="k", linewidths=0.4,
transform=ccrs.PlateCarree(), zorder=5,
)
gl = ax.gridlines(draw_labels=True, linewidth=0.2, alpha=0.4)
gl.top_labels = gl.right_labels = False
plt.colorbar(sc, ax=ax, shrink=0.75, label=label)
fig = plt.figure(figsize=(16, 9))
specs = [
(1, mu_truth, "RdBu_r", (-5, 5), r"$\mu^*(s)$ [°C]", r"Spatial offset $\mu^*(s)$"),
(2, beta_truth, "viridis", (0.0, 2.4), r"$\beta^*(s)$", r"Amplification $\beta^*(s)$"),
(3, sigma_truth, "plasma", (1.0, 2.9), r"$\sigma^*(s)$ [°C]", r"GEV scale $\sigma^*(s)$"),
(4, xi_truth, "coolwarm", (0.0, 0.25), r"$\xi^*(s)$", r"GEV shape $\xi^*(s)$"),
]
for idx, vals, cmap, vlim, label, title in specs:
ax = fig.add_subplot(2, 3, idx, projection=ccrs.PlateCarree())
plot_stations(ax, vals, cmap=cmap, vlim=vlim, label=label)
ax.set_title(title)
# Dependogram: scatter of pairwise (distance, R*) plus the analytic curve
ax_dep = fig.add_subplot(2, 3, 5)
mask = np.triu(np.ones_like(np.asarray(D_mat), dtype=bool), k=1)
d_pairs = np.asarray(D_mat)[mask]
r_pairs = np.asarray(R_truth)[mask]
ax_dep.scatter(d_pairs, r_pairs, s=8, alpha=0.4, color="grey", label="pairs")
d_grid = np.linspace(0, float(D_mat.max()) * 1.05, 200)
r_curve = TRUTH["c0"] * np.exp(-d_grid / TRUTH["c1"]) + (1 - TRUTH["c0"]) * np.exp(-d_grid / TRUTH["c2"])
ax_dep.plot(d_grid, r_curve, "k-", lw=2, label=r"$R^*(d)$ analytic")
ax_dep.set_xlabel("pairwise distance [degrees]")
ax_dep.set_ylabel(r"$R^*_{ij}$")
ax_dep.set_title(r"True dependogram (two-range exponential)")
ax_dep.legend()
ax_dep.grid(alpha=0.3)
ax_dep.set_ylim(-0.05, 1.05)
# Empirical: correlation of the latent z field across years
ax_emp = fig.add_subplot(2, 3, 6)
emp_pairs = np.asarray(emp_corr_z)[mask]
ax_emp.scatter(d_pairs, emp_pairs, s=8, alpha=0.4, color="C2", label="empirical corr(z)")
ax_emp.plot(d_grid, r_curve, "k-", lw=2, label=r"$R^*(d)$")
ax_emp.set_xlabel("pairwise distance [degrees]")
ax_emp.set_ylabel("correlation")
ax_emp.set_title(r"Empirical $\widehat{\mathrm{corr}}(z)$ matches $R^*(d)$")
ax_emp.legend()
ax_emp.grid(alpha=0.3)
ax_emp.set_ylim(-0.3, 1.05)
plt.tight_layout()
plt.show()
Bottom-row check: the empirical pair-correlation of the truth field tracks the analytic dependogram, with scatter . That’s the structure the copula model needs to recover from (not from — which is unobserved).
Kernel + variational factor classes (same as nb 03)¶
class MaternLite(eqx.Module):
log_variance: Float[Array, ""]
log_lengthscale: Float[Array, ""]
@classmethod
def init(cls, variance=1.0, lengthscale=1.0):
return cls(
log_variance=jnp.log(jnp.asarray(variance)),
log_lengthscale=jnp.log(jnp.asarray(lengthscale)),
)
def __call__(self, X1, X2):
return matern_kernel(X1, X2, jnp.exp(self.log_variance), jnp.exp(self.log_lengthscale), nu=1.5)
def _tril(L):
tri = jnp.tril(L, k=-1)
diag = jax.nn.softplus(jnp.diag(L))
return tri + jnp.diag(diag)
class VariationalFactor(eqx.Module):
mean: Float[Array, " N"]
raw_L: Float[Array, "N N"]
@classmethod
def init(cls, n: int, scale: float = 0.3):
raw_L = jnp.eye(n) * jnp.log(jnp.expm1(jnp.asarray(scale)))
return cls(mean=jnp.zeros(n), raw_L=raw_L)
@property
def L(self):
return _tril(self.raw_L)
@property
def cov(self):
L = self.L
return L @ L.T
@property
def variance_diag(self):
L = self.L
return jnp.sum(L * L, axis=-1)Model — nb 03 + three copula scalars¶
The new fields are stored in unconstrained chart so the optimiser sees a parameter; properties give us back and .
class Model(eqx.Module):
k_mu: MaternLite
k_beta: MaternLite
k_sig: MaternLite
k_xi: MaternLite
q_mu: VariationalFactor
q_beta: VariationalFactor
q_sig: VariationalFactor
q_xi: VariationalFactor
mu0: Float[Array, ""]
beta0: Float[Array, ""]
logsig0: Float[Array, ""]
xi0: Float[Array, ""]
raw_c0: Float[Array, ""] # sigmoid(raw_c0) = c0 in (0, 1)
log_c1: Float[Array, ""] # exp(log_c1) = c1 > 0
log_c2: Float[Array, ""] # exp(log_c2) = c2 > 0
jitter: float = eqx.field(static=True, default=1e-4)
@classmethod
def init(
cls, S: int, *,
mu0_init=30.0, beta0_init=1.0, sigma_init=2.0, xi_init=0.05,
c0_init=0.5, c1_init=1.0, c2_init=2.0,
):
return cls(
k_mu=MaternLite.init(variance=1.0, lengthscale=3.0),
k_beta=MaternLite.init(variance=0.2, lengthscale=3.0),
k_sig=MaternLite.init(variance=0.05, lengthscale=3.0),
k_xi=MaternLite.init(variance=0.003, lengthscale=3.0),
q_mu=VariationalFactor.init(S, scale=0.5),
q_beta=VariationalFactor.init(S, scale=0.3),
q_sig=VariationalFactor.init(S, scale=0.1),
q_xi=VariationalFactor.init(S, scale=0.05),
mu0=jnp.asarray(mu0_init),
beta0=jnp.asarray(beta0_init),
logsig0=jnp.log(jnp.asarray(sigma_init)),
xi0=jnp.asarray(xi_init),
raw_c0=jnp.log(jnp.asarray(c0_init / (1.0 - c0_init))), # logit
log_c1=jnp.log(jnp.asarray(c1_init)),
log_c2=jnp.log(jnp.asarray(c2_init)),
)
@property
def c0(self):
return jax.nn.sigmoid(self.raw_c0)
@property
def c1(self):
return jnp.exp(self.log_c1)
@property
def c2(self):
return jnp.exp(self.log_c2)
def _K(self, kernel, X):
K = kernel(X, X) + self.jitter * jnp.eye(X.shape[0])
return lx.MatrixLinearOperator(K, lx.positive_semidefinite_tag)ELBO — KL + 3-D GH ELL (nb 03) + plug-in copula correction¶
def kl_factor(q: VariationalFactor, K_op):
zeros = jnp.zeros_like(q.mean)
q_cov = lx.MatrixLinearOperator(q.cov, lx.positive_semidefinite_tag)
return gaussx.dist_kl_divergence(q.mean, q_cov, zeros, K_op)
_GH3 = gaussx.GaussHermiteIntegrator(order=8)
def ell_point(m_tau, v_tau, m_logsig, v_logsig, m_xi, v_xi, y):
r"""$\mathbb{E}_q[\log \mathrm{GEV}(y \mid \tau, e^{\log\sigma}, \xi)]$ via 3-D GH per (s, t)."""
mean3 = jnp.stack([m_tau, m_logsig, m_xi])
var3 = jnp.stack([v_tau, v_logsig, v_xi])
cov3 = jnp.diag(var3)
state = gaussx.GaussianState(
mean=mean3,
cov=lx.MatrixLinearOperator(cov3, lx.positive_semidefinite_tag),
)
def log_lik(f):
return GeneralizedExtremeValue(loc=f[0], scale=jnp.exp(f[1]), shape=f[2]).log_prob(y)
return gaussx.log_likelihood_expectation(log_lik, state, _GH3)
def total_marginal_ell(model: Model, d_vec, y):
m_mu_s = model.q_mu.mean
v_mu_s = model.q_mu.variance_diag
m_b_s = model.beta0 + model.q_beta.mean
v_b_s = model.q_beta.variance_diag
m_sig_s = model.logsig0 + model.q_sig.mean
v_sig_s = model.q_sig.variance_diag
m_xi_s = model.xi0 + model.q_xi.mean
v_xi_s = model.q_xi.variance_diag
m_tau = model.mu0 + m_mu_s[:, None] + m_b_s[:, None] * d_vec[None, :]
v_tau = v_mu_s[:, None] + v_b_s[:, None] * (d_vec ** 2)[None, :]
m_logsig = jnp.broadcast_to(m_sig_s[:, None], (S, T))
v_logsig = jnp.broadcast_to(v_sig_s[:, None], (S, T))
m_xi_b = jnp.broadcast_to(m_xi_s[:, None], (S, T))
v_xi_b = jnp.broadcast_to(v_xi_s[:, None], (S, T))
ell_vmapped = jax.vmap(jax.vmap(ell_point), in_axes=(0,) * 7)
return jnp.sum(ell_vmapped(m_tau, v_tau, m_logsig, v_logsig, m_xi_b, v_xi_b, y))
def copula_correction(model: Model, dist_mat, d_vec, y):
"""Plug-in (IFM) copula log-density at the posterior means.
For each year t:
u_st = F_GEV(y_st | tau_st, sigma_s, xi_s) # PIT
z_st = Phi^{-1}(u_st) # decorrelate
log c(u_t; R) = -0.5 log|R| - 0.5 z_t^T (R^{-1} - I) z_t
"""
# Posterior means of the latent fields
m_mu_s = model.q_mu.mean
m_b_s = model.beta0 + model.q_beta.mean
sigma_s = jnp.exp(model.logsig0 + model.q_sig.mean)
xi_s = model.xi0 + model.q_xi.mean
# tau on the (S, T) grid
tau_st = model.mu0 + m_mu_s[:, None] + m_b_s[:, None] * d_vec[None, :] # (S, T)
# PIT and inverse-Phi
u = gev_cdf(y, tau_st, sigma_s[:, None], xi_s[:, None])
u = jnp.clip(u, EPS_U, 1.0 - EPS_U)
z = jss.norm.ppf(u) # (S, T)
# Build R(c0, c1, c2) and Cholesky
R = correlation_matrix(model.c0, model.c1, model.c2, dist_mat) # (S, S)
L = jnp.linalg.cholesky(R)
log_det_R = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
# Per-year quadratic form: z_t^T R^{-1} z_t = || L^{-1} z_t ||_2^2
Linv_z = jsla.solve_triangular(L, z, lower=True) # (S, T)
quad_per_year = jnp.sum(Linv_z ** 2, axis=0) # (T,)
z_norm_per_year = jnp.sum(z ** 2, axis=0) # (T,)
log_c_per_year = -0.5 * log_det_R - 0.5 * (quad_per_year - z_norm_per_year)
return jnp.sum(log_c_per_year)
def neg_elbo(model: Model, X_s, dist_mat, d_vec, y):
kl = (
kl_factor(model.q_mu, model._K(model.k_mu, X_s))
+ kl_factor(model.q_beta, model._K(model.k_beta, X_s))
+ kl_factor(model.q_sig, model._K(model.k_sig, X_s))
+ kl_factor(model.q_xi, model._K(model.k_xi, X_s))
)
ell_marginal = total_marginal_ell(model, d_vec, y)
ell_copula = copula_correction(model, dist_mat, d_vec, y)
return kl - ell_marginal - ell_copulaTraining¶
model = Model.init(
S=S,
mu0_init=float(jnp.mean(y_obs)),
beta0_init=1.0,
sigma_init=2.0,
xi_init=0.05,
c0_init=0.5,
c1_init=1.0,
c2_init=2.0,
)
optimiser = optax.chain(optax.clip_by_global_norm(5.0), optax.adam(5e-3))
opt_state = optimiser.init(eqx.filter(model, eqx.is_inexact_array))
@eqx.filter_jit
def train_step(model, opt_state):
loss, grads = eqx.filter_value_and_grad(neg_elbo)(model, stations, D_mat, d_vec, y_obs)
updates, opt_state = optimiser.update(grads, opt_state, model)
return eqx.apply_updates(model, updates), opt_state, loss
n_steps = 3000
losses: list[float] = []
for _ in range(n_steps):
model, opt_state, loss = train_step(model, opt_state)
losses.append(float(loss))
beta_post = model.beta0 + model.q_beta.mean
sig_post = jnp.exp(model.logsig0 + model.q_sig.mean)
xi_post = model.xi0 + model.q_xi.mean
c0_fit, c1_fit, c2_fit = float(model.c0), float(model.c1), float(model.c2)
print(f"final -ELBO = {losses[-1]:.2f}")
print(f"fitted mu0 = {float(model.mu0):.2f} (truth {TRUTH['mu0']:.2f})")
print(f"fitted beta0 = {float(model.beta0):.3f} (truth mean beta* {float(beta_truth.mean()):.3f})")
print(f"fitted logsig0 = {float(model.logsig0):.3f} (truth {float(TRUTH['logsig0']):.3f})")
print(f"fitted xi0 = {float(model.xi0):.3f} (truth {TRUTH['xi0']:.3f})")
print(f"fitted sigma(s) rng = [{float(sig_post.min()):.2f}, {float(sig_post.max()):.2f}] °C "
f"(truth [{float(sigma_truth.min()):.2f}, {float(sigma_truth.max()):.2f}])")
print(f"fitted xi(s) range = [{float(xi_post.min()):.3f}, {float(xi_post.max()):.3f}] "
f"(truth [{float(xi_truth.min()):.3f}, {float(xi_truth.max()):.3f}])")
print(f"fitted c0 = {c0_fit:.3f} (truth {TRUTH['c0']:.3f})")
print(f"fitted c1 = {c1_fit:.3f} (truth {TRUTH['c1']:.3f}) [degrees]")
print(f"fitted c2 = {c2_fit:.3f} (truth {TRUTH['c2']:.3f}) [degrees]")final -ELBO = 3425.27
fitted mu0 = 34.40 (truth 35.00)
fitted beta0 = 1.465 (truth mean beta* 1.235)
fitted logsig0 = 0.621 (truth 0.588)
fitted xi0 = 0.098 (truth 0.120)
fitted sigma(s) rng = [1.47, 2.25] °C (truth [1.53, 2.54])
fitted xi(s) range = [0.094, 0.103] (truth [-0.012, 0.223])
fitted c0 = 0.480 (truth 0.500)
fitted c1 = 0.412 (truth 0.500) [degrees]
fitted c2 = 3.729 (truth 4.000) [degrees]
Loss curve¶
fig, ax = plt.subplots(figsize=(10, 3.2))
ax.plot(losses, "C5-", lw=1.2)
ax.set_xlabel("step")
ax.set_ylabel("-ELBO")
ax.set_yscale("symlog", linthresh=100.0)
ax.set_title("Training — 4-GP marginals + Gaussian copula on residuals")
ax.grid(alpha=0.3, which="both")
plt.show()
Recovery of the four latent fields¶
Same recovery panels as nb 03 — the copula doesn’t change the marginal recovery in either direction (modulo small effects from the joint correction).
mu_mean = model.q_mu.mean
mu_std = jnp.sqrt(model.q_mu.variance_diag)
beta_mean = model.beta0 + model.q_beta.mean
beta_std = jnp.sqrt(model.q_beta.variance_diag)
sig_mean = jnp.exp(model.logsig0 + model.q_sig.mean)
sig_std = sig_mean * jnp.sqrt(model.q_sig.variance_diag)
xi_mean = model.xi0 + model.q_xi.mean
xi_std = jnp.sqrt(model.q_xi.variance_diag)
rows = [
(r"$\mu$", r"$\mu^*(s)$", r"$\hat\mu(s)$", mu_truth, mu_mean, mu_std, "RdBu_r", (-5, 5)),
(r"$\beta$", r"$\beta^*(s)$", r"$\hat\beta(s)$", beta_truth, beta_mean, beta_std, "viridis", (0.0, 2.4)),
(r"$\sigma$", r"$\sigma^*(s)$", r"$\hat\sigma(s)$", sigma_truth, sig_mean, sig_std, "plasma", (1.0, 2.9)),
(r"$\xi$", r"$\xi^*(s)$", r"$\hat\xi(s)$", xi_truth, xi_mean, xi_std, "coolwarm", (0.0, 0.25)),
]
fig = plt.figure(figsize=(16, 15))
for r, (short, truth_lab, post_lab, truth_v, post_m, post_s, cmap, vlim) in enumerate(rows):
ax = fig.add_subplot(4, 3, 3 * r + 1, projection=ccrs.PlateCarree())
plot_stations(ax, truth_v, cmap=cmap, vlim=vlim, label=truth_lab)
ax.set_title(truth_lab + " — truth")
ax = fig.add_subplot(4, 3, 3 * r + 2, projection=ccrs.PlateCarree())
plot_stations(ax, post_m, cmap=cmap, vlim=vlim, label=post_lab)
ax.set_title(post_lab + " — posterior mean")
ax = fig.add_subplot(4, 3, 3 * r + 3)
ax.errorbar(
np.asarray(truth_v), np.asarray(post_m), yerr=2 * np.asarray(post_s),
fmt="o", ms=4, alpha=0.75, capsize=2,
)
lo, hi = vlim
ax.plot([lo, hi], [lo, hi], "k--", lw=1)
ax.set_xlim(lo, hi)
ax.set_ylim(lo, hi)
ax.set_xlabel(truth_lab + " truth")
ax.set_ylabel(post_lab + r" $\pm 2\sigma$")
ax.set_title(short + " recovery")
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
The killer plot — copula recovery¶
The point of the copula is to model the residual cross-station dependence. Three things should line up on a single dependogram axis:
- Truth
- Empirical residual correlation of the fitted — i.e. apply the model’s inferred margins to the data, then read off pairwise correlations
- Fitted parametric
Agreement of all three says the copula is recovering both the parametric form and the empirical residual structure consistently.
# Empirical z's at posterior means
tau_post = model.mu0 + mu_mean[:, None] + beta_mean[:, None] * d_vec[None, :]
u_hat = gev_cdf(y_obs, tau_post, sig_mean[:, None], xi_mean[:, None])
u_hat = jnp.clip(u_hat, EPS_U, 1.0 - EPS_U)
z_hat = jss.norm.ppf(u_hat)
emp_corr_hat = jnp.corrcoef(z_hat) # (S, S) station-by-station correlation across years
emp_pairs_hat = np.asarray(emp_corr_hat)[mask]
R_fit = correlation_matrix(model.c0, model.c1, model.c2, D_mat)
fit_pairs = np.asarray(R_fit)[mask]
d_grid = np.linspace(0, float(D_mat.max()) * 1.05, 200)
truth_curve = TRUTH["c0"] * np.exp(-d_grid / TRUTH["c1"]) + (1 - TRUTH["c0"]) * np.exp(-d_grid / TRUTH["c2"])
fit_curve = c0_fit * np.exp(-d_grid / c1_fit) + (1 - c0_fit) * np.exp(-d_grid / c2_fit)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Left: scatter of empirical vs fitted parametric, with truth curve
ax = axes[0]
ax.scatter(d_pairs, emp_pairs_hat, s=10, alpha=0.5, color="C2", label=r"empirical $\widehat{\mathrm{corr}}(\hat z)$")
ax.scatter(d_pairs, fit_pairs, s=10, alpha=0.5, color="C5", label=r"fitted $\hat R_{ij}$")
ax.plot(d_grid, truth_curve, "k-", lw=2, label=r"truth $R^*(d)$")
ax.plot(d_grid, fit_curve, "C5--", lw=1.6, label=r"fitted $\hat R(d)$")
ax.set_xlabel("pairwise distance [degrees]")
ax.set_ylabel("correlation")
ax.set_title("Dependogram recovery — three should agree")
ax.legend()
ax.grid(alpha=0.3)
ax.set_ylim(-0.3, 1.05)
# Right: side-by-side bar of c0, c1, c2
ax = axes[1]
labels = [r"$c_0$", r"$c_1$ [deg]", r"$c_2$ [deg]"]
truths = [TRUTH["c0"], TRUTH["c1"], TRUTH["c2"]]
fits = [c0_fit, c1_fit, c2_fit]
x = np.arange(len(labels))
ax.bar(x - 0.18, truths, width=0.35, color="k", alpha=0.7, label="truth")
ax.bar(x + 0.18, fits, width=0.35, color="C5", alpha=0.85, label="fitted")
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_title("Copula scalar recovery")
ax.legend()
ax.grid(alpha=0.3, axis="y")
plt.tight_layout()
plt.show()
Joint exceedance probabilities — where the copula visibly wins¶
The copula’s biggest practical pay-off: it correctly accounts for the probability that multiple nearby stations exceed their return level in the same year. Pick the 5 stations closest to the centre of the peninsula; ask “what’s the probability that all 5 exceed their station-specific 25-year return level in the same year?”
Under conditional independence (nb 03’s implicit model), this is just .
Under the fitted Gaussian copula, the joint probability is computed by the multivariate normal CDF:
where on the sub-correlation matrix indexed by .
# Pick 5 nearest-to-centre stations
centre = jnp.array([(SPAIN_BBOX[0] + SPAIN_BBOX[1]) / 2, (SPAIN_BBOX[2] + SPAIN_BBOX[3]) / 2])
dist_to_centre = jnp.sqrt(jnp.sum((stations - centre[None, :]) ** 2, axis=1))
core_idx = jnp.argsort(dist_to_centre)[:5]
p = 1.0 / 25.0 # exceedance probability for the 25-year level
threshold_z = float(jss.norm.ppf(1.0 - p)) # ~1.751
# Independent: (p)^K
prob_indep = p ** len(core_idx)
# Copula: MC-estimate Pr(all Z_i > threshold) under N(0, R_I)
R_I = R_fit[jnp.ix_(core_idx, core_idx)]
key, key_mc = jr.split(key)
n_mc = 200_000
L_I = jnp.linalg.cholesky(R_I)
Z_mc = jr.normal(key_mc, (n_mc, len(core_idx))) @ L_I.T
exceed = jnp.all(Z_mc > threshold_z, axis=1)
prob_copula = float(jnp.mean(exceed))
# Empirical sanity (very noisy at p^K events with T=40 — but worth a glance).
# Threshold is the per-(s, t) GEV inverse CDF at quantile (1 - p), using fitted margins.
core_np = np.asarray(core_idx)
core_thresholds = np.asarray(
gev_inverse_cdf(
jnp.asarray(1.0 - p),
loc=tau_post[core_idx, :],
scale=sig_mean[core_idx, None],
shape=xi_mean[core_idx, None],
)
) # (5, T) — per-station, per-year 25-yr return level
y_obs_np = np.asarray(y_obs)
exceed_emp_count = int(np.sum(np.all(y_obs_np[core_np, :] > core_thresholds, axis=0)))
prob_emp = exceed_emp_count / T
print("Pr(all 5 nearest-centre stations simultaneously exceed their 25-yr return level):")
print(f" conditional independence (nb 03 implicit): {prob_indep:.2e}")
print(f" Gaussian copula (this notebook, MC): {prob_copula:.2e} "
f"({prob_copula / prob_indep:.0f}x larger)")
print(f" observed in {T} years of data: {prob_emp:.2e} "
f"({exceed_emp_count}/{T} years)")Pr(all 5 nearest-centre stations simultaneously exceed their 25-yr return level):
conditional independence (nb 03 implicit): 1.02e-07
Gaussian copula (this notebook, MC): 1.60e-04 (1562x larger)
observed in 40 years of data: 0.00e+00 (0/40 years)
Three orders of magnitude or more — that’s the practical cost of ignoring spatial dependence in extremes. The empirical count from years is a noisy estimate (we’re counting at most a handful of events at probabilities), but it sits in the copula’s ballpark, not the independence prediction’s.
Summary¶
- The four-GP marginal model from nb 03 is unchanged. The copula adds a single per-year correction term, , computed at the posterior means (plug-in / IFM approximation) — exactly the recipe in
jej_vc_snippets/extremes/models/temp_gevd_gp_copula.py, ported from numpyro MCMC into our variational scaffold. - The two-range exponential dependogram recovers cleanly with three trainable scalars: a convective-scale , a synoptic-scale , and a mixing weight .
- The empirical residual correlation matches both the truth and the fitted parametric curve — a one-glance sanity check that the model isn’t compensating elsewhere.
- On joint return periods (5 nearby stations all exceeding their 25-year level in the same year), the copula gives an answer that is several orders of magnitude larger than the conditional-independence approximation — and consistent with the empirical event count. For any application that cares about compound climate risk (insurance, infrastructure stress-tests, grid load), this is the correction that matters.
Follow-ups¶
- Reparameterised-MC copula contribution instead of plug-in: removes the Jensen bias at the cost of a stochastic ELBO. Worth doing when posterior variance on the latents is large (e.g. shorter records, more aggressive priors).
- Non-stationary copula: that varies across the domain, capturing region-specific dependence regimes (Atlantic vs continental vs Mediterranean). Needs a more flexible parameterisation than the two-range exponential.
- Bivariate non-Gaussian copulas (Student-, Gumbel, mixed) for tail dependence — the Gaussian copula is asymptotically tail-independent, which under-models the tendency of extreme heatwaves to “synchronise” across regions.