Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

PDE parameter estimation

Part 6.3: Fitting a PDE Model to Data

Estimate the advection velocity (U) and diffusion coefficient (κ) of the 1D advection–diffusion equation by fitting a differentiable PDE solver to noisy synthetic observations.

Stack: equinox (model) · diffrax (time integration of the spatially discretized PDE) · optax (optimizer) · coordax (labelled coordinates for states and observations)

import time as _time

import coordax as cx
import jax
import jax.numpy as jnp

try:
    import equinox as eqx
    has_eqx = True
except ImportError:
    has_eqx = False
    print("equinox not installed.")

try:
    import diffrax
    has_diffrax = True
except ImportError:
    has_diffrax = False
    print("diffrax not installed.")

try:
    import optax
    has_optax = True
except ImportError:
    has_optax = False
    print("optax not installed.")

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

Problem Setup: 1D Advection-Diffusion

PDE: ∂T/∂t = −U·∂T/∂x + κ·∂²T/∂x² on the periodic domain x ∈ [0, 100) km (T in K, U in km/h, κ in km²/h)

# Periodic grid: n_space equally spaced points on [0, 100) km. The right
# endpoint is excluded because, on a periodic domain, x = 100 km is x = 0.
n_space = 64
x_values = jnp.linspace(0, 100, n_space, endpoint=False)  # km
dx = float(x_values[1] - x_values[0])  # uniform grid spacing, km
x_axis = cx.LabeledAxis('space', x_values)

# Initial condition: Gaussian warm pulse on a uniform background,
#   T(x, 0) = T_base + T_amp * exp(-((x - x_center) / sigma)**2)
T_base = 288.0   # K, background temperature
T_amp = 20.0     # K, pulse amplitude
x_center = 25.0  # km, pulse centre
sigma = 5.0      # km, pulse width
T_initial_data = T_base + T_amp * jnp.exp(-((x_values - x_center) / sigma) ** 2)
T_initial = cx.field(T_initial_data, x_axis)  # dims ('space',), shape (n_space,), K

# Parameters used to generate the synthetic data; the fit should recover them.
U_true = 5.0       # km/h, advection velocity
kappa_true = 0.5   # km²/h, diffusion coefficient

# x_values is strictly increasing, so its first/last entries are its min/max.
print(f"Spatial domain: {float(x_values[0]):.1f} — {float(x_values[-1]):.1f} km  ({n_space} pts)")
print(f"True params: U={U_true} km/h, κ={kappa_true} km²/h")
Spatial domain: 0.0 — 98.4 km  (64 pts)
True params: U=5.0 km/h, κ=0.5 km²/h

PDE Right-Hand Side (Spectral, Periodic BC)

def advection_diffusion_rhs(t, T_data, args):
    """Right-hand side of the periodic 1D advection-diffusion equation.

    Evaluates ``-U * dT/dx + kappa * d²T/dx²`` with both spatial derivatives
    computed pseudo-spectrally through the real FFT. This treats the grid as
    periodic with uniform spacing ``args['dx']``.

    Args:
        t: Time. Not used (the equation is autonomous); it is accepted to match
            the diffrax ``ODETerm`` vector-field signature ``f(t, y, args)``.
        T_data: Temperature values on the grid (K). Must match the length of
            the ``'space'`` axis in ``args['x_axis']``; wrapping it as a
            coordax Field below enforces that.
        args: Dict with keys ``'U'`` (km/h), ``'kappa'`` (km²/h), ``'dx'``
            (km) and ``'x_axis'`` (the ``'space'`` coordinate of the grid).

    Returns:
        dT/dt on the grid (K/h), same length as ``T_data``.
    """
    U, kappa, spacing, space_coord = (
        args[name] for name in ('U', 'kappa', 'dx', 'x_axis')
    )
    n_points = len(space_coord.ticks)

    # Round-trip through a Field so a shape/axis mismatch fails loudly.
    temperature = cx.field(T_data, space_coord).untag('space').data

    # rfftfreq returns frequencies in cycles/km; 2π converts them to angular
    # wavenumbers, giving the spectral multipliers i·k and -k².
    freqs = jnp.fft.rfftfreq(n_points, d=spacing)
    first_deriv_factor = 2j * jnp.pi * freqs
    second_deriv_factor = -(2 * jnp.pi * freqs) ** 2

    spectrum = jnp.fft.rfft(temperature)
    grad_T = jnp.fft.irfft(first_deriv_factor * spectrum, n=n_points)   # K/km
    lap_T = jnp.fft.irfft(second_deriv_factor * spectrum, n=n_points)   # K/km²

    advection = -U * grad_T
    diffusion = kappa * lap_T
    return advection + diffusion


# Parameter dict in the layout advection_diffusion_rhs reads, at the true values.
args_true = dict(U=U_true, kappa=kappa_true, dx=dx, x_axis=x_axis)

# Sanity check: tendency dT/dt of the initial pulse at t = 0.
rhs_test = advection_diffusion_rhs(0.0, T_initial.data, args_true)
print(f"RHS range: {float(rhs_test.min()):.3f} — {float(rhs_test.max()):.3f} K/h")
RHS range: -17.034 — 16.797 K/h

Generate Synthetic Observations

# Observation schedule: n_obs equally spaced snapshots over [0, t_end] h.
# Defined once, outside the branch, so the solve window, the SaveAt grid and
# the fallback array shape cannot drift apart.
t_end = 5.0  # h
n_obs = 11
t_obs = jnp.linspace(0, t_end, n_obs)
time_axis = cx.LabeledAxis('time', t_obs)

if has_diffrax:
    # Cap on the adaptive step: the smaller of the advective (dx/U) and
    # explicit-diffusion (dx²/(2κ)) time scales, halved. The PID controller
    # may still take smaller steps when its error estimate requires it.
    dt_max = min(dx / U_true, 0.5 * dx**2 / kappa_true) * 0.5
    solution_true = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(advection_diffusion_rhs),
        solver=diffrax.Tsit5(),
        t0=0.0, t1=t_end, dt0=None,  # dt0=None: controller picks the first step
        y0=T_initial.data,
        args=args_true,
        saveat=diffrax.SaveAt(ts=t_obs),
        stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7, dtmax=dt_max),
        max_steps=50000,
    )
    T_obs = solution_true.ys  # (n_obs, n_space), K — noise-free trajectory

    # Additive Gaussian noise; a fixed PRNG key keeps the data reproducible.
    key = jax.random.PRNGKey(0)
    noise_level = 0.5  # K
    T_obs_noisy = T_obs + jax.random.normal(key, T_obs.shape) * noise_level

    T_obs_field = cx.field(T_obs_noisy, time_axis, x_axis)
    # dims: ('time', 'space') | shape: (n_obs, n_space) | units: K — noisy observations
    print(f"Observations: {T_obs_field.dims}, shape: {T_obs_field.shape}")
    print(f"Noise: {noise_level} K (Gaussian)")
else:
    # Zero placeholders with the observation shape so the names are defined.
    T_obs_noisy = jnp.zeros((n_obs, n_space))
    T_obs_field = cx.field(T_obs_noisy, time_axis, x_axis)
    print("Skipping observation generation (diffrax not installed).")
Observations: ('time', 'space'), shape: (11, 64)
Noise: 0.5 K (Gaussian)

PDE Model with Equinox

if has_eqx and has_diffrax:
    class AdvDiffModel(eqx.Module):
        """Differentiable 1D advection-diffusion PDE model.

        ``U`` (advection velocity, km/h) and ``kappa`` (diffusion coefficient,
        km²/h) are the parameters. Pass them as JAX arrays so that
        ``eqx.is_array`` / ``eqx.filter_value_and_grad`` treat them as
        trainable leaves; plain Python floats are not differentiated.
        ``dx``, ``x_axis`` and ``dt_max`` are static fields. They are part of
        the PyTree structure, not leaves, so they are never differentiated.
        """
        U: float      # advection velocity (km/h); a jnp scalar when trainable
        kappa: float  # diffusion coefficient (km²/h); a jnp scalar when trainable
        dx: float       = eqx.field(static=True)  # grid spacing (km)
        x_axis: cx.Coordinate = eqx.field(static=True)  # 'space' coordinate of the grid
        dt_max: float   = eqx.field(static=True)  # upper bound on the solver step (h)

        def __call__(self, t_span, y0):
            """Integrate the PDE and return the state at every time in ``t_span``.

            Args:
                t_span: 1D array of non-decreasing output times (h). Its first
                    and last entries are the integration bounds.
                y0: Initial temperature on the grid (K).

            Returns:
                Solution saved at each time in ``t_span``; shape
                ``(len(t_span), len(y0))``.
            """
            args = {'U': self.U, 'kappa': self.kappa,
                    'dx': self.dx, 'x_axis': self.x_axis}
            sol = diffrax.diffeqsolve(
                terms=diffrax.ODETerm(advection_diffusion_rhs),
                solver=diffrax.Tsit5(),
                t0=t_span[0], t1=t_span[-1], dt0=None,
                y0=y0,
                args=args,
                saveat=diffrax.SaveAt(ts=t_span),
                stepsize_controller=diffrax.PIDController(
                    rtol=1e-5, atol=1e-7, dtmax=self.dt_max
                ),
                max_steps=50000,
                # Backpropagate through the solver steps with checkpointing.
                adjoint=diffrax.RecursiveCheckpointAdjoint(),
            )
            # sol.ys: shape (n_time, n_space) | units: K — predicted temperature trajectory
            return sol.ys

    # Heuristic step cap from the advective (dx/U) and explicit-diffusion
    # (dx²/(2κ)) time scales at the true parameters, halved.
    dt_max = min(dx / U_true, 0.5 * dx**2 / kappa_true) * 0.5
    model_true_pde = AdvDiffModel(U=U_true, kappa=kappa_true, dx=dx,
                                  x_axis=x_axis, dt_max=dt_max)

    # Perturbed initial guess — jnp arrays so eqx.is_array treats them as
    # trainable leaves.
    model_init = AdvDiffModel(U=jnp.array(3.0), kappa=jnp.array(0.2), dx=dx,
                              x_axis=x_axis, dt_max=dt_max)
    # Format explicitly: float() of a float32 scalar exposes representation
    # noise (0.2 prints as 0.20000000298023224).
    print(f"Initial guess: U={float(model_init.U):.4f}, κ={float(model_init.kappa):.4f}")
Initial guess: U=3.0, κ=0.20000000298023224

Loss Function and Optimization

if has_eqx and has_diffrax and has_optax:
    def pde_loss(model, T_obs, t_span, y0):
        """Mean-squared error between model predictions and observations.

        Args:
            model: Model called as ``model(t_span, y0)`` to produce a trajectory.
            T_obs: Observed temperatures at the times in ``t_span`` (K).
            t_span: Observation times (h).
            y0: Initial condition (K).

        Returns:
            Scalar MSE (K²).
        """
        T_pred = model(t_span, y0)
        return jnp.mean((T_pred - T_obs) ** 2)

    loss_init = pde_loss(model_init, T_obs_noisy, t_obs, T_initial.data)
    loss_true = pde_loss(model_true_pde, T_obs_noisy, t_obs, T_initial.data)
    print(f"Loss at initial guess: {float(loss_init):.4f}")
    print(f"Loss at true params:   {float(loss_true):.4f}")

    # Optimisation settings, named once so the loop and the summary agree.
    n_steps = 500
    log_every = 25
    learning_rate = 1e-2

    optimizer = optax.adam(learning_rate)
    model_opt = model_init
    # Optimizer state covers only the array leaves (U, kappa).
    opt_state = optimizer.init(eqx.filter(model_opt, eqx.is_array))

    @eqx.filter_jit
    def train_step(model, opt_state, T_obs, t_span, y0):
        """One Adam step.

        Returns ``(updated_model, updated_opt_state, loss)``. ``loss`` is
        evaluated at the *input* model, i.e. before the update.
        """
        loss, grads = eqx.filter_value_and_grad(pde_loss)(model, T_obs, t_span, y0)
        updates, new_opt_state = optimizer.update(
            eqx.filter(grads, eqx.is_array), opt_state
        )
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_opt_state, loss

    losses = []
    # Wall-clock timer; the first train_step call also pays JIT compilation.
    t0 = _time.perf_counter()
    for i in range(n_steps):
        model_opt, opt_state, loss = train_step(
            model_opt, opt_state, T_obs_noisy, t_obs, T_initial.data
        )
        losses.append(float(loss))
        if (i + 1) % log_every == 0:
            # The loss is the pre-update value; U/κ are the post-update parameters.
            print(f"  iter {i+1:3d}: loss={losses[-1]:.4f}  "
                  f"U={float(model_opt.U):.3f}  κ={float(model_opt.kappa):.4f}")
    elapsed = _time.perf_counter() - t0
    print(f"\nOptimization: {elapsed:.1f}s for {n_steps} steps (incl. JIT compilation)")
    print(f"True:      U={U_true}, κ={kappa_true}")
    print(f"Recovered: U={float(model_opt.U):.4f}, κ={float(model_opt.kappa):.4f}")
    print(f"U error: {abs(float(model_opt.U) - U_true):.4f}")
    print(f"κ error: {abs(float(model_opt.kappa) - kappa_true):.4f}")
Loss at initial guess: 17.1521
Loss at true params:   0.2363
  iter  25: loss=13.6256  U=3.251  κ=0.4448
  iter  50: loss=10.2988  U=3.501  κ=0.6630
  iter  75: loss=7.4664  U=3.745  κ=0.8407
  iter 100: loss=5.2027  U=3.974  κ=0.9729
  iter 125: loss=3.5026  U=4.183  κ=1.0598
  iter 150: loss=2.2995  U=4.367  κ=1.1055
  iter 175: loss=1.4947  U=4.523  κ=1.1174
  iter 200: loss=0.9843  U=4.651  κ=1.1038
  iter 225: loss=0.6763  U=4.753  κ=1.0725
  iter 250: loss=0.4981  U=4.831  κ=1.0303
  iter 275: loss=0.3976  U=4.889  κ=0.9822
  iter 300: loss=0.3410  U=4.930  κ=0.9319
  iter 325: loss=0.3079  U=4.958  κ=0.8817
  iter 350: loss=0.2871  U=4.977  κ=0.8333
  iter 375: loss=0.2729  U=4.988  κ=0.7877
  iter 400: loss=0.2626  U=4.995  κ=0.7456
  iter 425: loss=0.2549  U=4.999  κ=0.7075
  iter 450: loss=0.2491  U=5.001  κ=0.6734
  iter 475: loss=0.2449  U=5.002  κ=0.6436
  iter 500: loss=0.2419  U=5.003  κ=0.6179

Optimization: 18.8s for 500 steps
True:      U=5.0, κ=0.5
Recovered: U=5.0029, κ=0.6179
U error: 0.0029
κ error: 0.1179

Coordax: Wrap Predictions and Residuals

if has_eqx and has_diffrax:
    # Predict with the fitted model when the optax fit ran, else with the true model.
    T_pred_data = (model_opt if has_optax else model_true_pde)(t_obs, T_initial.data)
    T_pred_field = cx.field(T_pred_data, time_axis, x_axis)
    # dims: ('time', 'space') | shape: (len(t_obs), n_space) | units: K — model predictions

    # T_obs_field already wraps T_obs_noisy on the same (time_axis, x_axis);
    # reuse it instead of building an identical Field a second time.
    obs_field2 = T_obs_field
    residual = T_pred_field - obs_field2
    # dims: ('time', 'space') | units: K — residual = pred − obs

    print(f"Predictions: {T_pred_field.dims}, shape: {T_pred_field.shape}")
    print(f"Residual RMS: {float(jnp.sqrt(jnp.mean(residual.data**2))):.4f} K")

    # Time mean: untag 'time' so cmap maps over the remaining named 'space'
    # dim and the lambda sees one positional time series at a time.
    T_pred_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(T_pred_field.untag('time'))
    # dims: ('space',) | units: K — time-mean of predicted T
    print(f"Prediction time-mean: {T_pred_mean.dims}, shape: {T_pred_mean.shape}")

    # Spatial mean per observation time: untag 'space', map over 'time'.
    T_spatial_mean_ts = cx.cmap(lambda x: jnp.mean(x, axis=-1))(
        T_pred_field.untag('space')
    )
    # dims: ('time',) | units: K — spatial mean at each observation time
    print(f"Spatial mean time series: {T_spatial_mean_ts.dims}, shape: {T_spatial_mean_ts.shape}")
Predictions: ('time', 'space'), shape: (11, 64)
Residual RMS: 0.4918 K
Prediction time-mean: ('space',), shape: (64,)
Spatial mean time series: ('time',), shape: (11,)

Key Patterns

  • equinox.Module — structured, differentiable model with static fields
  • diffrax inside __call__ — differentiable solve of the spectrally discretized PDE (Tsit5 time stepping, gradients via RecursiveCheckpointAdjoint)
  • optax.adam — gradient-based optimizer
  • eqx.filter_value_and_grad — differentiates only the floating-point array leaves (here U and κ); static fields are left untouched
  • coordax — wraps prediction/observation arrays with time & space coordinates
  • Pattern: generate obs → define loss → optimize → wrap results as Fields