Skip to article frontmatter · Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

ODE integration (Diffrax)

Part 6: The JAX Superpower — JIT, vmap, and PDE Integration (Diffrax)

Coordax Fields are JAX pytrees, so they work seamlessly with jax.jit, jax.grad, and libraries like Diffrax.

import time as _time

import coordax as cx
import jax
import jax.numpy as jnp

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

6.1 JIT Compilation of a Complex Analysis Pipeline

# Grid sizes: number of time steps, latitude rows and longitude columns.
n_time, n_lat, n_lon = 20, 16, 32

# Coordinate ticks. Time is sampled every 6 hours; longitude stops short of
# 360 so the periodic grid has no duplicated column.
lat_values = jnp.linspace(-90, 90, n_lat)
lon_values = jnp.linspace(0, 360, n_lon, endpoint=False)
time_values = jnp.arange(n_time) * 6.0

# 2-D (latitude, longitude) mesh for building the synthetic spatial pattern.
lat_mesh, lon_mesh = jnp.meshgrid(lat_values, lon_values, indexing='ij')

# Synthetic field: warm equator, one sinusoidal cycle over the whole record
# (period n_time * 6 h), plus a zonal wavenumber-3 perturbation.
base_temp = 288.0 + 30.0 * jnp.cos(jnp.deg2rad(lat_mesh))
time_factor = jnp.sin(2 * jnp.pi * time_values / (n_time * 6))[:, jnp.newaxis, jnp.newaxis]
spatial_waves = 5.0 * jnp.sin(3 * jnp.deg2rad(lon_mesh))
temperature_data = (
    base_temp[jnp.newaxis] * (1 + 0.1 * time_factor)
    + spatial_waves[jnp.newaxis]
)

# One labeled coordinate per dimension, in (time, latitude, longitude) order.
time_axis, lat_axis, lon_axis = (
    cx.LabeledAxis(dim_name, dim_ticks)
    for dim_name, dim_ticks in (
        ('time', time_values),
        ('latitude', lat_values),
        ('longitude', lon_values),
    )
)

# Expected: shape (20, 16, 32), dims ('time', 'latitude', 'longitude'), values in K.
temperature = cx.field(temperature_data, time_axis, lat_axis, lon_axis)
print(f"Temperature: {temperature.dims}, shape: {temperature.shape}")
Temperature: ('time', 'latitude', 'longitude'), shape: (20, 16, 32)
_R_EARTH = 6.371e6  # reference constant (unused in this section)

def climate_analysis(temp_field):
    """Run a small analysis pipeline on a (time, latitude, longitude) field.

    Computes the temporal mean, the anomaly relative to that mean, and a
    per-time-step spatial mean. The function has no side effects, so it can
    be wrapped in ``jax.jit`` unchanged.

    Args:
        temp_field: coordax Field with named dimensions 'time', 'latitude'
            and 'longitude'.

    Returns:
        dict with keys:
            'temp_mean'    -- mean over 'time'; dims ('latitude', 'longitude').
            'temp_anomaly' -- temp_field minus temp_mean, broadcast over 'time';
                              dims ('time', 'latitude', 'longitude').
            'spatial_mean' -- unweighted mean over 'latitude' and 'longitude';
                              dims ('time',).
    """
    # Temporal mean: untag 'time' so it becomes positional axis 0, which the
    # mapped function reduces over; the remaining named dims are mapped over.
    temp_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(temp_field.untag('time'))

    # Anomaly T' = T - T̄; named-dimension arithmetic aligns T̄ with the
    # (latitude, longitude) dims of temp_field and broadcasts over 'time'.
    temp_anomaly = temp_field - temp_mean

    # Spatial mean per time step. NOTE: this is a plain arithmetic mean over
    # grid points, NOT an area-weighted (cos-latitude) global mean.
    spatial_mean = cx.cmap(lambda x: jnp.mean(x, axis=(-2, -1)))(
        temp_field.untag('latitude', 'longitude')
    )

    return {
        'temp_mean': temp_mean,
        'temp_anomaly': temp_anomaly,
        'spatial_mean': spatial_mean,
    }


# Without JIT (op-by-op eager execution). JAX dispatches work asynchronously,
# so wait for the results before stopping the clock; otherwise only dispatch
# is timed, and the comparison with the (blocked) JIT timings below is unfair.
t0 = _time.perf_counter()
results = climate_analysis(temperature)
jax.block_until_ready(results)
t_eager = _time.perf_counter() - t0
print(f"Eager: {t_eager*1000:.1f} ms")

# JIT-compiled. The first call includes tracing and compilation.
jitted_analysis = jax.jit(climate_analysis)
t0 = _time.perf_counter()
results_jit = jitted_analysis(temperature)
jax.block_until_ready(results_jit)
t_compile = _time.perf_counter() - t0
print(f"JIT (first, with compilation): {t_compile*1000:.1f} ms")

# Second call on the same input reuses the compiled function.
t0 = _time.perf_counter()
results_jit2 = jitted_analysis(temperature)
jax.block_until_ready(results_jit2)
t_jit = _time.perf_counter() - t0
print(f"JIT (subsequent): {t_jit*1000:.1f} ms")

# Verify correctness: eager and jitted outputs should agree up to float32
# rounding (compiled code may order arithmetic differently).
for key, eager_value in results.items():
    diff = float(jnp.max(jnp.abs(eager_value.data - results_jit2[key].data)))
    print(f"  {key}: max diff = {diff:.2e}")
Eager: 270.0 ms
JIT (first, with compilation): 102.5 ms
JIT (subsequent): 0.3 ms
  temp_mean: max diff = 0.00e+00
  temp_anomaly: max diff = 1.43e-05
  spatial_mean: max diff = 0.00e+00

6.2 Solving a 1D Advection-Diffusion PDE (Diffrax)

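Equation (periodic in $x$): $\frac{\partial T}{\partial t} = -U\,\frac{\partial T}{\partial x} + \kappa\,\frac{\partial^2 T}{\partial x^2}$, where $U$ is the advection speed and $\kappa$ the diffusivity. The $x$-derivatives are evaluated spectrally with the FFT. The resulting system of ODEs in time is then integrated with Diffrax (method of lines).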

# 1-D periodic domain: 64 points covering [0, 100) km.
n_space = 64
x_values = jnp.linspace(0, 100, n_space, endpoint=False)  # km
dx = float(x_values[1] - x_values[0])  # uniform grid spacing, km

x_axis = cx.LabeledAxis('space', x_values)

# Initial condition: Gaussian warm bump on a uniform background,
# T(x, 0) = T_base + T_amp * exp(-((x - x_center) / sigma)**2), in K.
T_base = 288.0
T_amp = 20.0
x_center = 25.0
sigma = 5.0
T_initial_data = T_base + T_amp * jnp.exp(-(((x_values - x_center) / sigma) ** 2))
T_initial = cx.field(T_initial_data, x_axis)  # dims ('space',), shape (64,)

# Transport parameters: advection speed U [km/h], diffusivity kappa [km²/h].
U = 5.0
kappa = 0.5

print(f"Initial condition: shape={T_initial.shape}, dims={T_initial.dims}")
print(f"T range: {float(T_initial.data.min()):.1f} — {float(T_initial.data.max()):.1f} K")
Initial condition: shape=(64,), dims=('space',)
T range: 288.0 — 308.0 K
def advection_diffusion_rhs(t, T_data, args):
    """Right-hand side of the 1-D advection-diffusion equation.

    Evaluates dT/dt = -U * dT/dx + kappa * d²T/dx². The x-derivatives are
    taken spectrally with the real FFT, which implies periodic boundaries.
    The ``(t, y, args)`` signature is the one passed to ``diffrax.ODETerm``.

    Args:
        t: Time; not used, because the right-hand side does not depend on t.
        T_data: Temperature samples on the grid described by ``args['x_axis']``.
        args: Mapping with keys:
            'U' (advection speed), 'kappa' (diffusivity),
            'dx' (grid spacing), and
            'x_axis' (coordax axis whose ticks set the number of grid points).

    Returns:
        Array of dT/dt with one value per grid point.
    """
    speed, diffusivity, spacing, axis = (
        args['U'], args['kappa'], args['dx'], args['x_axis']
    )
    n_points = len(axis.ticks)

    # Round-trip through a coordax Field, then untag 'space' to get a plain array.
    samples = cx.field(T_data, axis).untag('space').data

    # rfftfreq gives wavenumbers in cycles per unit length; 2πk is the angular wavenumber.
    freqs = jnp.fft.rfftfreq(n_points, d=spacing)
    spectrum = jnp.fft.rfft(samples)

    # Multiply by i·2πk (first derivative) or -(2πk)² (second derivative) in Fourier space.
    first_deriv = jnp.fft.irfft(2j * jnp.pi * freqs * spectrum, n=n_points)
    second_deriv = jnp.fft.irfft(-(2 * jnp.pi * freqs) ** 2 * spectrum, n=n_points)

    advection_term = -speed * first_deriv
    diffusion_term = diffusivity * second_deriv
    return advection_term + diffusion_term


# Parameters handed to the RHS as its ``args`` mapping (same keys, same order).
args = dict(U=U, kappa=kappa, dx=dx, x_axis=x_axis)

# Sanity check: evaluate the tendency once at t = 0 on the initial condition.
dT_dt_test = advection_diffusion_rhs(0.0, T_initial.data, args)
print(
    f"RHS shape: {dT_dt_test.shape}, range: {float(dT_dt_test.min()):.2f} — {float(dT_dt_test.max()):.2f} K/h"
)
RHS shape: (64,), range: -17.03 — 16.80 K/h

Solve with Diffrax

# Diffrax is optional: record whether it is importable and skip the solve if not.
try:
    import diffrax
except ImportError:
    has_diffrax = False
    print("diffrax not installed — skipping ODE solve.")
else:
    has_diffrax = True

if has_diffrax:
    t0_sim, t1_sim = 0.0, 5.0  # hours (shortened for speed)
    save_times = jnp.linspace(t0_sim, t1_sim, 26)

    # Step-size cap from simple explicit-scheme heuristics:
    # - advective CFL limit dx / U,
    # - diffusive limit 0.5 * dx**2 / kappa,
    # - a further 0.5 safety factor.
    # NOTE(review): these are finite-difference style estimates. The spectral
    # operator's largest wavenumber is pi/dx, so its stiffest modes are faster
    # than these estimates assume. The adaptive PID controller is presumably
    # what keeps Tsit5 stable here — confirm if parameters change.
    dt_max_cfl = dx / U
    dt_max_diff = 0.5 * dx**2 / kappa
    dt_safe = min(dt_max_cfl, dt_max_diff) * 0.5

    ode_term = diffrax.ODETerm(advection_diffusion_rhs)
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(ts=save_times)
    step_ctrl = diffrax.PIDController(rtol=1e-5, atol=1e-7, dtmax=dt_safe)

    t0 = _time.perf_counter()
    solution = diffrax.diffeqsolve(
        terms=ode_term,
        solver=solver,
        t0=t0_sim,
        t1=t1_sim,
        dt0=None,  # let the adaptive controller pick the initial step
        y0=T_initial.data,
        args=args,
        saveat=saveat,
        stepsize_controller=step_ctrl,
        max_steps=50000,
    )
    # JAX dispatch is asynchronous: wait for the result so the timing covers
    # the actual solve (including first-call compilation), not just dispatch.
    jax.block_until_ready(solution.ys)
    elapsed = _time.perf_counter() - t0
    print(f"ODE solve: {elapsed:.2f}s, steps={solution.stats['num_accepted_steps']}")

    # One row per save time: shape (len(save_times), n_space) = (26, 64), in K.
    T_solution = solution.ys
    print(f"Solution shape: {T_solution.shape}")

    # Final state at t = t1, wrapped back as a coordax Field on the 'space' axis.
    T_final = cx.field(T_solution[-1], x_axis)
    print(f"Final T: {T_final.dims}, range: {float(T_final.data.min()):.1f} — {float(T_final.data.max()):.1f} K")

    # Verify advection: the anomaly's centre of mass should move by ~U * t1.
    # Diffusion spreads the bump symmetrically and does not shift it.
    # NOTE: this is a plain (non-circular) mean over x. It is only meaningful
    # while the bump stays clear of the periodic boundary at x = 0 / 100 km.
    T_anom_0 = T_solution[0] - T_base
    T_anom_f = T_solution[-1] - T_base
    com_0 = float(jnp.sum(x_values * T_anom_0) / jnp.sum(T_anom_0))
    com_f = float(jnp.sum(x_values * T_anom_f) / jnp.sum(T_anom_f))
    expected = U * t1_sim
    print(f"COM motion (anomaly): {com_f - com_0:.2f} km  (expected ~{expected:.1f} km)")
ODE solve: 1.45s, steps=34
Solution shape: (26, 64)
Final T: ('space',), range: 288.0 — 304.9 K
COM motion (anomaly): 25.00 km  (expected ~25.0 km)

6.3 Automatic Differentiation Through the PDE

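Compute the gradient of a mismatch loss with respect to the diffusion coefficient κ by differentiating through a short forward-Euler simulation. This gradient is the key ingredient for fitting κ with gradient-based optimization.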

if has_diffrax:
    # NOTE(review): this section uses a hand-written forward-Euler loop, not
    # diffrax. The guard only keeps the notebook flow consistent with 6.2.

    def _euler_integrate(T0_data, rhs_args, n_steps, dt):
        """Advance ``T0_data`` by ``n_steps`` forward-Euler steps of size ``dt``."""
        T = T0_data
        for _ in range(n_steps):
            T = T + dt * advection_diffusion_rhs(0.0, T, rhs_args)
        return T

    def loss_kappa(kappa_val, T0_data, T_target_data, args_template, n_steps=50, dt=0.01):
        """MSE between a short forward-Euler simulation and a target state.

        Args:
            kappa_val: Diffusivity to simulate with (the differentiated input).
            T0_data: Initial temperature array.
            T_target_data: Target final-state array, same shape as ``T0_data``.
            args_template: RHS args mapping; its 'kappa' entry is overridden
                by ``kappa_val``.
            n_steps: Number of Euler steps.
            dt: Euler step size in hours. It should match the step used to
                generate the target.

        Returns:
            Scalar mean-squared error between simulated and target states.
        """
        # Forward Euler is only conditionally stable: keep dt and n_steps small.
        rhs_args = {**args_template, 'kappa': kappa_val}
        T = _euler_integrate(T0_data, rhs_args, n_steps, dt)
        return jnp.mean((T - T_target_data)**2)

    # Target: same integrator and step size, run with kappa_true=0.8.
    args_true = dict(args, kappa=0.8)
    n_euler = 30
    T_target = _euler_integrate(T_initial.data, args_true, n_euler, 0.01)

    # One pass computes both the loss and dL/dkappa.
    kappa_guess = jnp.array(0.3)
    loss_val, grad_val = jax.value_and_grad(loss_kappa)(
        kappa_guess, T_initial.data, T_target, args, n_steps=n_euler
    )
    print(f"Loss at kappa={float(kappa_guess):.2f}: {float(loss_val):.4f}")
    print(f"dL/dkappa: {float(grad_val):.4f}  (sign points toward kappa_true=0.8)")
Loss at kappa=0.30: 0.0026
dL/dkappa: -0.0105  (sign points toward kappa_true=0.8)

Key Patterns

  • cx.Field objects are JAX pytrees — pass them directly to jax.jit, jax.grad, jax.vmap
  • Diffrax integrates any ODE whose state is a JAX array
  • The untag + spectral/FD derivative pattern works inside diffrax.ODETerm
  • Wrap Diffrax solutions back as coordax Fields with the appropriate axes