PDE parameter estimation
Part 6.3: Fitting a PDE Model to Data
Estimate advection velocity (U) and diffusion coefficient (κ) of the 1D advection-diffusion equation from synthetic observations.
Stack: equinox (model) · diffrax (PDE solver) · optax (optimizer) · coordax (state wrapping)
import time as _time
import coordax as cx
import jax
import jax.numpy as jnp
try:
import equinox as eqx
has_eqx = True
except ImportError:
has_eqx = False
print("equinox not installed.")
try:
import diffrax
has_diffrax = True
except ImportError:
has_diffrax = False
print("diffrax not installed.")
try:
import optax
has_optax = True
except ImportError:
has_optax = False
print("optax not installed.")
# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2
Problem Setup: 1D Advection-Diffusion
PDE: ∂T/∂t = −U·∂T/∂x + κ·∂²T/∂x² (periodic domain)
# Spatial grid: n_space points on a periodic 0–100 km domain. The endpoint is
# excluded, so x = 100 km coincides with x = 0.
n_space = 64
x_values = jnp.linspace(0, 100, n_space, endpoint=False)  # km
dx = float(x_values[1] - x_values[0])                      # grid spacing, km
x_axis = cx.LabeledAxis('space', x_values)

# Initial condition: a Gaussian warm pulse on a uniform background,
#   T(x, 0) = T_base + T_amp * exp(-((x - x_center) / sigma)**2)
T_base = 288.0   # K
T_amp = 20.0     # K
x_center = 25.0  # km
sigma = 5.0      # km
T_initial_data = T_base + T_amp * jnp.exp(-((x_values - x_center) / sigma)**2)
T_initial = cx.field(T_initial_data, x_axis)  # dims ('space',), shape (64,), K

# Ground-truth parameters that the inversion below tries to recover.
U_true = 5.0      # advection velocity, km/h
kappa_true = 0.5  # diffusion coefficient, km²/h

print(
    f"Spatial domain: {float(x_values.min()):.1f} — "
    f"{float(x_values.max()):.1f} km ({n_space} pts)"
)
print(f"True params: U={U_true} km/h, κ={kappa_true} km²/h")
# Output: Spatial domain: 0.0 — 98.4 km (64 pts)
True params: U=5.0 km/h, κ=0.5 km²/h
PDE Right-Hand Side (Spectral, Periodic BC)
def advection_diffusion_rhs(t, T_data, args):
    """Right-hand side of the 1D advection-diffusion equation on a periodic grid.

    Computes dT/dt = -U * dT/dx + kappa * d2T/dx2. The derivatives are spectral
    (real FFT), which assumes the field is periodic along its last axis.

    Args:
        t: Current time. Unused because the equation is autonomous; kept for the
            diffrax ``ODETerm`` signature ``f(t, y, args)``.
        T_data: Temperature samples on a uniform grid, shape ``(..., n)``. The
            last axis is space.
        args: Dict with keys ``'U'`` (advection velocity), ``'kappa'``
            (diffusion coefficient) and ``'dx'`` (grid spacing). It may also
            contain ``'x_axis'``, a coordinate whose ``ticks`` length must
            equal ``n``.

    Returns:
        Array with the same shape as ``T_data`` holding dT/dt.

    Raises:
        ValueError: If ``args['x_axis']`` is given and its length does not
            match the last axis of ``T_data``.
    """
    U, kappa, dx = args['U'], args['kappa'], args['dx']
    # Operate on the raw array. Wrapping it in a Field and untagging it again,
    # on every solver stage, adds work without changing the result.
    T = jnp.asarray(T_data)
    n = T.shape[-1]
    xa = args.get('x_axis')
    if xa is not None and len(xa.ticks) != n:
        raise ValueError(
            f"T_data has {n} spatial points but x_axis has {len(xa.ticks)}"
        )
    # rfftfreq returns cycles per unit length; multiplying by 2*pi*i gives the
    # spectral symbol of d/dx.
    ik = 2j * jnp.pi * jnp.fft.rfftfreq(n, d=dx)
    T_hat = jnp.fft.rfft(T, axis=-1)
    dT_dx = jnp.fft.irfft(ik * T_hat, n=n, axis=-1)         # K/km
    d2T_dx2 = jnp.fft.irfft(ik**2 * T_hat, n=n, axis=-1)    # K/km², ik**2 = -(2πk)²
    return -U * dT_dx + kappa * d2T_dx2
# Sanity check: evaluate the RHS once at the true parameters and report its range.
args_true = dict(U=U_true, kappa=kappa_true, dx=dx, x_axis=x_axis)
rhs_test = advection_diffusion_rhs(0.0, T_initial.data, args_true)  # K/h
print(
    f"RHS range: {float(rhs_test.min()):.3f} — "
    f"{float(rhs_test.max()):.3f} K/h"
)
# Output: RHS range: -17.034 — 16.797 K/h
Generate Synthetic Observations
# Observation times, identical in both branches: 11 snapshots over 5 h.
t_obs = jnp.linspace(0, 5.0, 11)
time_axis = cx.LabeledAxis('time', t_obs)

if not has_diffrax:
    # No solver available: zero placeholders keep the names and shapes that
    # later cells expect.
    T_obs_noisy = jnp.zeros((11, n_space))
    T_obs_field = cx.field(T_obs_noisy, time_axis, x_axis)
    print("Skipping observation generation (diffrax not installed).")
else:
    # Step-size cap: half of the smaller of the advective time scale dx / U
    # and the diffusive limit 0.5 * dx**2 / kappa.
    dt_max = 0.5 * min(dx / U_true, 0.5 * dx**2 / kappa_true)
    solution_true = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(advection_diffusion_rhs),
        solver=diffrax.Tsit5(),
        t0=0.0,
        t1=5.0,
        dt0=None,  # initial step chosen by the adaptive controller
        y0=T_initial.data,
        args=args_true,
        saveat=diffrax.SaveAt(ts=t_obs),
        stepsize_controller=diffrax.PIDController(
            rtol=1e-5, atol=1e-7, dtmax=dt_max
        ),
        max_steps=50000,
    )
    T_obs = solution_true.ys  # clean trajectory, shape (11, 64)

    # Corrupt the clean trajectory with i.i.d. Gaussian noise of std noise_level.
    key = jax.random.PRNGKey(0)
    noise_level = 0.5  # K
    T_obs_noisy = T_obs + noise_level * jax.random.normal(key, T_obs.shape)
    T_obs_field = cx.field(T_obs_noisy, time_axis, x_axis)  # ('time', 'space'), K

    print(f"Observations: {T_obs_field.dims}, shape: {T_obs_field.shape}")
    print(f"Noise: {noise_level} K (Gaussian)")
# Output: Observations: ('time', 'space'), shape: (11, 64)
Noise: 0.5 K (Gaussian)
PDE Model with Equinox
if has_eqx and has_diffrax:

    class AdvDiffModel(eqx.Module):
        """Differentiable 1D advection-diffusion PDE model.

        ``dx``, ``x_axis`` and ``dt_max`` are static fields, so they are not
        pytree leaves. When ``eqx.is_array`` filtering is applied, only ``U``
        and ``kappa`` are trainable, and only when they hold JAX arrays.

        Attributes:
            U: Advection velocity (km/h). It is annotated ``float`` but should
                hold a scalar JAX array so that ``eqx.is_array`` selects it.
            kappa: Diffusion coefficient (km²/h). The same note as ``U`` applies.
            dx: Grid spacing (km), static.
            x_axis: Spatial coordinate forwarded to the RHS, static.
            dt_max: Maximum adaptive step size for the solver, static.
        """

        U: float
        kappa: float
        dx: float = eqx.field(static=True)
        x_axis: cx.Coordinate = eqx.field(static=True)
        dt_max: float = eqx.field(static=True)

        def __call__(self, t_span, y0):
            """Integrate from ``t_span[0]`` to ``t_span[-1]``, saving at every ``t_span``.

            Args:
                t_span: 1D array of output times. Its first and last entries
                    are the integration bounds.
                y0: Initial temperature on the spatial grid.

            Returns:
                ``sol.ys`` with shape ``(n_time, n_space)``: the predicted
                temperature trajectory in K.
            """
            args = {'U': self.U, 'kappa': self.kappa,
                    'dx': self.dx, 'x_axis': self.x_axis}
            sol = diffrax.diffeqsolve(
                terms=diffrax.ODETerm(advection_diffusion_rhs),
                solver=diffrax.Tsit5(),
                t0=t_span[0], t1=t_span[-1], dt0=None,
                y0=y0,
                args=args,
                saveat=diffrax.SaveAt(ts=t_span),
                stepsize_controller=diffrax.PIDController(
                    rtol=1e-5, atol=1e-7, dtmax=self.dt_max
                ),
                max_steps=50000,
                # Gradients are taken by backpropagating through the solver's
                # internal steps, with checkpointing.
                adjoint=diffrax.RecursiveCheckpointAdjoint(),
            )
            return sol.ys  # (n_time, n_space)

    # NOTE(review): dt_max is derived from the *true* parameters, which a real
    # inversion would not know. Here it only caps the adaptive step size.
    dt_max = min(dx / U_true, 0.5 * dx**2 / kappa_true) * 0.5
    model_true_pde = AdvDiffModel(U=U_true, kappa=kappa_true, dx=dx,
                                  x_axis=x_axis, dt_max=dt_max)
    # Perturbed initial guess — use jnp.array so eqx.is_array recognizes them
    model_init = AdvDiffModel(U=jnp.array(3.0), kappa=jnp.array(0.2), dx=dx,
                              x_axis=x_axis, dt_max=dt_max)
    # Use fixed precision: float(jnp.array(0.2)) is the float32 value
    # 0.20000000298..., which the default repr prints in full.
    print(f"Initial guess: U={float(model_init.U):.4f}, "
          f"κ={float(model_init.kappa):.4f}")
# Output: Initial guess: U=3.0000, κ=0.2000
Loss Function and Optimization
if has_eqx and has_diffrax and has_optax:

    def pde_loss(model, T_obs, t_span, y0):
        """MSE between model predictions and observations.

        Args:
            model: Callable as ``model(t_span, y0)``, e.g. an ``AdvDiffModel``.
            T_obs: Observations with the same shape as ``model(t_span, y0)``.
            t_span: Observation times.
            y0: Initial state.

        Returns:
            Scalar mean squared error (K² for temperature data).
        """
        T_pred = model(t_span, y0)
        return jnp.mean((T_pred - T_obs)**2)

    # Baselines. With Gaussian observation noise, the loss at the true
    # parameters is expected to be roughly noise_level**2.
    loss_init = pde_loss(model_init, T_obs_noisy, t_obs, T_initial.data)
    loss_true = pde_loss(model_true_pde, T_obs_noisy, t_obs, T_initial.data)
    print(f"Loss at initial guess: {float(loss_init):.4f}")
    print(f"Loss at true params: {float(loss_true):.4f}")

    # Training hyper-parameters, defined once so the loop and the report agree.
    learning_rate = 1e-2
    n_steps = 500
    log_every = 25

    optimizer = optax.adam(learning_rate)
    model_opt = model_init
    opt_state = optimizer.init(eqx.filter(model_opt, eqx.is_array))

    @eqx.filter_jit
    def train_step(model, opt_state, T_obs, t_span, y0):
        """Take one Adam step on ``pde_loss`` over the array leaves of ``model``.

        Returns:
            ``(new_model, new_opt_state, loss)``. ``loss`` is evaluated at the
            incoming ``model``, i.e. before this update.
        """
        loss, grads = eqx.filter_value_and_grad(pde_loss)(model, T_obs, t_span, y0)
        updates, new_opt_state = optimizer.update(
            eqx.filter(grads, eqx.is_array), opt_state
        )
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_opt_state, loss

    losses = []
    # The timer starts before the first call, so `elapsed` includes the JIT
    # compilation of train_step.
    t0 = _time.perf_counter()
    for i in range(n_steps):
        model_opt, opt_state, loss = train_step(
            model_opt, opt_state, T_obs_noisy, t_obs, T_initial.data
        )
        losses.append(float(loss))
        if (i + 1) % log_every == 0:
            # The logged loss is pre-update; U and κ are post-update.
            print(f" iter {i+1:3d}: loss={float(loss):.4f} "
                  f"U={float(model_opt.U):.3f} κ={float(model_opt.kappa):.4f}")
    elapsed = _time.perf_counter() - t0
    # In the recorded run, κ was still decreasing at step 500. Consider more
    # steps or per-parameter scaling for a tighter fit.
    print(f"\nOptimization: {elapsed:.1f}s for {n_steps} steps")
    print(f"True: U={U_true}, κ={kappa_true}")
    print(f"Recovered: U={float(model_opt.U):.4f}, κ={float(model_opt.kappa):.4f}")
    print(f"U error: {abs(float(model_opt.U) - U_true):.4f}")
    print(f"κ error: {abs(float(model_opt.kappa) - kappa_true):.4f}")
# Output: Loss at initial guess: 17.1521
Loss at true params: 0.2363
iter 25: loss=13.6256 U=3.251 κ=0.4448
iter 50: loss=10.2988 U=3.501 κ=0.6630
iter 75: loss=7.4664 U=3.745 κ=0.8407
iter 100: loss=5.2027 U=3.974 κ=0.9729
iter 125: loss=3.5026 U=4.183 κ=1.0598
iter 150: loss=2.2995 U=4.367 κ=1.1055
iter 175: loss=1.4947 U=4.523 κ=1.1174
iter 200: loss=0.9843 U=4.651 κ=1.1038
iter 225: loss=0.6763 U=4.753 κ=1.0725
iter 250: loss=0.4981 U=4.831 κ=1.0303
iter 275: loss=0.3976 U=4.889 κ=0.9822
iter 300: loss=0.3410 U=4.930 κ=0.9319
iter 325: loss=0.3079 U=4.958 κ=0.8817
iter 350: loss=0.2871 U=4.977 κ=0.8333
iter 375: loss=0.2729 U=4.988 κ=0.7877
iter 400: loss=0.2626 U=4.995 κ=0.7456
iter 425: loss=0.2549 U=4.999 κ=0.7075
iter 450: loss=0.2491 U=5.001 κ=0.6734
iter 475: loss=0.2449 U=5.002 κ=0.6436
iter 500: loss=0.2419 U=5.003 κ=0.6179
Optimization: 18.8s for 500 steps
True: U=5.0, κ=0.5
Recovered: U=5.0029, κ=0.6179
U error: 0.0029
κ error: 0.1179
Coordax: Wrap Predictions and Residuals
if has_eqx and has_diffrax:
    # Use the fitted model. If optax is missing, model_opt was never created,
    # so fall back to the true-parameter model.
    T_pred_data = (model_opt if has_optax else model_true_pde)(t_obs, T_initial.data)

    # Label predictions and observations with the same (time, space) coordinates.
    T_pred_field = cx.field(T_pred_data, time_axis, x_axis)  # ('time','space'), K
    obs_field2 = cx.field(T_obs_noisy, time_axis, x_axis)    # ('time','space'), K
    residual = T_pred_field - obs_field2                     # pred − obs, K

    print(f"Predictions: {T_pred_field.dims}, shape: {T_pred_field.shape}")
    print(f"Residual RMS: {float(jnp.sqrt(jnp.mean(residual.data**2))):.4f} K")

    # Reductions with cmap: untag the dim being reduced so the mapped function
    # sees it as a positional axis. The remaining named dim is mapped over.
    T_pred_mean = cx.cmap(lambda series: jnp.mean(series, axis=0))(
        T_pred_field.untag('time')
    )  # dims ('space',): time-mean of predicted T at each grid point
    print(f"Prediction time-mean: {T_pred_mean.dims}, shape: {T_pred_mean.shape}")

    T_spatial_mean_ts = cx.cmap(lambda profile: jnp.mean(profile, axis=-1))(
        T_pred_field.untag('space')
    )  # dims ('time',): domain mean at each of the 11 observation times
    print(f"Spatial mean time series: {T_spatial_mean_ts.dims}, shape: {T_spatial_mean_ts.shape}")
# Output: Predictions: ('time', 'space'), shape: (11, 64)
Residual RMS: 0.4918 K
Prediction time-mean: ('space',), shape: (64,)
Spatial mean time series: ('time',), shape: (11,)
Key Patterns
- equinox.Module — structured, differentiable model with static fields
- diffrax inside `__call__` — fully differentiable PDE solve
- optax.adam — gradient-based optimizer
- eqx.filter_value_and_grad — differentiates only array leaves
- coordax — wraps prediction/observation arrays with time & space coordinates
- Pattern: generate obs → define loss → optimize → wrap results as Fields