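For 4D analysis, observations $\mathbf{y}_t$ at times $t$ across the assimilation window must be connected to the initial state $\mathbf{x}_0$ through the dynamics. The dynamical model propagates the state forward in time,

$$
\mathbf{x}_t = \mathcal{M}_t(\mathbf{x}_0), \tag{1}
$$

and the dynamics enter the cost through the composition $\mathcal{H} \circ \mathcal{M}_t$, that is, the observation operator applied to the propagated state.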
As in the observation-model note, the code is library-agnostic: shapes via jaxtyping, array ops via einx, ODE integration and adjoints via diffrax, and gradients via JAX autodiff.
## The Forward-Model Interface
A forward model only needs to expose how to advance the state. Everything else — multi-step rollouts, tangent-linear, adjoint — is built on top.
```python
from typing import Protocol

from jaxtyping import Array, Float


class ForwardModel(Protocol):
    dt: float

    def step(self, state: Float[Array, "N"], dt: float) -> Float[Array, "N"]:
        """Advance the state by one time step dt (discrete dynamics)."""
        ...

    def vector_field(self, t: float, state: Float[Array, "N"], args) -> Float[Array, "N"]:
        """Continuous-time RHS dy/dt = f(t, y) for ODE integration."""
        ...
```

Any model satisfying this protocol is interchangeable from the assimilation code's point of view. For multi-step rollouts, either integrate `vector_field` with `diffrax.diffeqsolve` (when the model is an ODE) or iterate `step` directly (when it is discrete). The integrator choice (Tsit5, Dopri5, Heun, …) belongs to the dynamics, not to the assimilation layer.
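As an illustration of the interface, the sketch below implements it for Lorenz-96, a common toy model in data assimilation, and rolls it out by direct iteration with `jax.lax.scan`. The class name `Lorenz96`, the choice of an RK4 `step`, and the helper `scan_rollout` are hypothetical choices made for this example. They are not part of any library.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Lorenz96:
    """Hypothetical example model satisfying ForwardModel (illustrative only)."""
    dt: float = 0.01
    forcing: float = 8.0

    def vector_field(self, t, state, args):
        # dx_i/dt = (x_{i+1} - x_{i-2}) x_{i-1} - x_i + F, with cyclic indices
        return (jnp.roll(state, -1) - jnp.roll(state, 2)) * jnp.roll(state, 1) - state + self.forcing

    def step(self, state, dt):
        # One classical RK4 step of the vector field (the discrete dynamics)
        f = lambda y: self.vector_field(0.0, y, None)
        k1 = f(state)
        k2 = f(state + 0.5 * dt * k1)
        k3 = f(state + 0.5 * dt * k2)
        k4 = f(state + dt * k3)
        return state + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def scan_rollout(model, x0, n_steps):
    """Iterate model.step n_steps times; return the (n_steps + 1, N) trajectory including x0."""
    def body(x, _):
        x_next = model.step(x, model.dt)
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, None, length=n_steps)
    return jnp.concatenate([x0[None, :], xs], axis=0)
```

Because the same object also exposes `vector_field`, it could instead be integrated with `diffrax.diffeqsolve`. Which path to take is a property of the model, not of the assimilation code.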
## Why the Adjoint Matters
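The 4DVar cost sums a background term and an observation term over time,

$$
J(\mathbf{x}_0) = \tfrac{1}{2}\,\lVert \mathbf{x}_0 - \mathbf{x}_b \rVert^2_{\mathbf{B}^{-1}} + \tfrac{1}{2} \sum_t \bigl\lVert \mathbf{y}_t - \mathcal{H}(\mathcal{M}_t(\mathbf{x}_0)) \bigr\rVert^2_{\mathbf{R}^{-1}}, \tag{2}
$$

where $\lVert \mathbf{v} \rVert^2_{\mathbf{A}} = \mathbf{v}^\top \mathbf{A}\, \mathbf{v}$. Minimising $J$ over $\mathbf{x}_0$ requires $\nabla_{\mathbf{x}_0} J$, which by the chain rule pulls back through every $\mathcal{M}_t$:

$$
\nabla_{\mathbf{x}_0} J = \mathbf{B}^{-1}(\mathbf{x}_0 - \mathbf{x}_b) - \sum_t \mathbf{M}_t^\top \mathbf{H}_t^\top \mathbf{R}^{-1} \bigl(\mathbf{y}_t - \mathcal{H}(\mathcal{M}_t(\mathbf{x}_0))\bigr), \tag{3}
$$

where $\mathbf{M}_t = \partial \mathcal{M}_t / \partial \mathbf{x}_0$ is the tangent-linear of the propagation, $\mathbf{H}_t$ is the Jacobian of $\mathcal{H}$ at $\mathbf{x}_t$, and the transposes $\mathbf{M}_t^\top$ are the adjoint.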
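The sketch below shows what autodiff is composing. It assembles (3) by hand for a toy linear problem, where $\mathcal{M}_t = \mathbf{A}^t$ and $\mathcal{H}(\mathbf{x}) = \mathbf{H}\mathbf{x}$, and compares the result with `jax.grad` of (2). The matrices, sizes, and seeds are arbitrary assumptions chosen only for this check.

```python
import jax
import jax.numpy as jnp

# Toy problem (illustrative assumption): linear dynamics x_{t+1} = A x_t,
# so M_t = A^t, and a linear observation operator H(x) = Hmat @ x.
n, m, T = 4, 2, 5
kA, kH, kx, ky = jax.random.split(jax.random.PRNGKey(0), 4)
A = jnp.eye(n) + 0.1 * jax.random.normal(kA, (n, n))
Hmat = jax.random.normal(kH, (m, n))
B_inv = 2.0 * jnp.eye(n)
R_inv = 0.5 * jnp.eye(m)
xb = jnp.zeros(n)
x0 = jax.random.normal(kx, (n,))
ys = jax.random.normal(ky, (T, m))

# Stack the propagators M_0 = I, M_1 = A, ..., M_{T-1} = A^{T-1}: shape (T, n, n)
Ms = jnp.stack([jnp.linalg.matrix_power(A, t) for t in range(T)])


def residuals(x):
    # r_t = y_t - H(M_t x), shape (T, m)
    return ys - jnp.einsum("mn,tnk,k->tm", Hmat, Ms, x)


def cost(x):
    # Equation (2) for this linear toy problem
    db = x - xb
    r = residuals(x)
    return 0.5 * db @ B_inv @ db + 0.5 * jnp.einsum("tm,ml,tl->", r, R_inv, r)


def grad_by_hand(x):
    # Equation (3): B^{-1}(x - xb) - sum_t M_t^T H^T R^{-1} r_t
    pulled_back = jnp.einsum("tnk,mn,ml,tl->k", Ms, Hmat, R_inv, residuals(x))
    return B_inv @ (x - xb) - pulled_back
```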
In practice you never assemble (3) by hand. Write the cost (2), then let autodiff compose the pullback:
```python
import jax
import einx
from jaxtyping import Array, Float


def fourdvar_cost(
    x0: Float[Array, "N"],
    xb: Float[Array, "N"],
    B_inv: Float[Array, "N N"],
    ys: Float[Array, "T M"],
    rollout,  # x0 -> trajectory (T, N); diffrax handles the adjoint
    H,  # observation operator applied per step
    R_inv: Float[Array, "M M"],
) -> Float[Array, ""]:
    # background term: ½ ‖x0 - xb‖²_{B⁻¹}
    db = x0 - xb
    Jb = 0.5 * einx.dot("N, N ->", db, einx.dot("N K, K -> N", B_inv, db))
    # observation term: ½ Σ_t ‖y_t - H(M_t(x0))‖²_{R⁻¹}
    traj = rollout(x0)  # (T, N)
    pred = jax.vmap(H)(traj)  # (T, M)
    res = ys - pred  # (T, M)
    Rr = jax.vmap(lambda r: einx.dot("M K, K -> M", R_inv, r))(res)
    Jo = 0.5 * einx.dot("T M, T M ->", res, Rr)
    return Jb + Jo


grad_J = jax.grad(fourdvar_cost)  # adjoint pullback through every M_t, for free
```

## Choosing an Adjoint Method
Expose the adjoint as a configurable slot on your 4DVar solver and thread it into
diffeqsolve; the choice flows straight through to diffrax. The trade-offs:
Table 1: diffrax adjoint methods. $K$ = number of forward steps, $P$ = number of parameters.

| Method | Memory | Extra cost | Use when |
|---|---|---|---|
| `RecursiveCheckpointAdjoint(checkpoints=N)` | grows with `checkpoints` | extra forward solves (recomputation) | default; balances memory and recompute |
| `BacksolveAdjoint()` | $O(1)$ in $K$ | one backward ODE solve | long windows, memory-constrained |
| `ForwardMode()` | $O(P)$, linear in parameters | one extra forward per parameter | few parameters (parameter estimation) |
| `DirectAdjoint()` | $O(K)$, every step stored | none beyond standard reverse mode | short windows, debugging |
- **Default: `RecursiveCheckpointAdjoint`.** Standard reverse mode with recursive checkpointing. This is the best general-purpose balance.
- **Long windows: `BacksolveAdjoint`.** Solves the continuous adjoint ODE backwards in time, with memory constant in trajectory length. Caveat: the reverse-time ODE can be stiff even when the forward one is not, so use a stable (often implicit) integrator for the backward solve. The resulting gradient also only approximates the gradient of the discretised forward solve.
- **Parameter estimation: `ForwardMode`.** When the sensitivity is to a small number of parameters (e.g. a single emission rate in a methane Tier-I problem), forward sensitivity beats reverse.
- **Debugging: `DirectAdjoint`.** Plain reverse mode through the `lax.scan` of forward steps. It is slow but transparent.
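diffrax's `ForwardMode` is the ODE-solver form of forward sensitivity. The sketch below shows the same trade-off in plain JAX, without diffrax, on a hypothetical one-parameter box model $\dot{x} = -k x + q$ with emission rate $q$, stepped with explicit Euler inside `jax.lax.scan`. A single `jax.jacfwd` pass yields the sensitivity of the whole trajectory to $q$, whereas reverse mode would need one pullback per output. The model, its constants, and `emission_rollout` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def emission_rollout(q, x0=0.0, decay=0.5, dt=0.1, n_steps=50):
    """Toy box model (hypothetical): dx/dt = -decay * x + q, explicit Euler."""
    def body(x, _):
        x_next = x + dt * (-decay * x + q)
        return x_next, x_next

    _, xs = jax.lax.scan(body, jnp.asarray(x0, dtype=jnp.float32), None, length=n_steps)
    return xs  # (n_steps,) state after each step


# One input, many outputs: one forward-mode pass gives dx_k/dq for every step k.
q_true = jnp.float32(2.0)
sensitivity = jax.jacfwd(emission_rollout)(q_true)
```

For this linear model the sensitivity obeys $s_k = (1 - k\,\Delta t)\, s_{k-1} + \Delta t$ with $s_0 = 0$, so it can be checked in closed form.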
## Dynamics Regimes
The full nonlinear model $\mathcal{M}$ is integrated at every iteration of the minimiser, and diffrax's adjoint handles the gradient. This is the plain strong- or weak-constraint 4DVar setting. Nothing special is required beyond choosing an adjoint method.
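If $\mathcal{M}$ is linear, $\mathcal{M}(\mathbf{x}) = \mathbf{M}\mathbf{x}$, its tangent-linear is $\mathbf{M}$ itself, so no linearisation is needed. For incremental 4DVar, the nonlinear model is linearised once per outer iteration. The inner loop then reuses that tangent-linear, and its adjoint, across many conjugate-gradient (CG) iterations.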
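The following is a minimal sketch of one incremental outer iteration. It rests on simplifying assumptions: a hypothetical toy nonlinear step, $\mathcal{H}$ equal to the identity, and scalar (diagonal) $\mathbf{B}^{-1}$ and $\mathbf{R}^{-1}$. The whole window is linearised once with `jax.linearize` and its adjoint comes from `jax.linear_transpose`. The inner quadratic problem, whose normal equations are $(\mathbf{B}^{-1} + \mathbf{M}'^\top \mathbf{R}^{-1} \mathbf{M}')\,\delta\mathbf{x} = \mathbf{B}^{-1}(\mathbf{x}_b - \mathbf{x}^k) + \mathbf{M}'^\top \mathbf{R}^{-1}\mathbf{d}$, is solved matrix-free with `jax.scipy.sparse.linalg.cg`. The names `trajectory` and `outer_iteration` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

DT, N_STEPS = 0.1, 10  # hypothetical toy settings


def trajectory(x0):
    """Nonlinear rollout of a toy bistable model (hypothetical); returns (N_STEPS + 1, N)."""
    def body(x, _):
        x_next = x + DT * (x - x**3)
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, None, length=N_STEPS)
    return jnp.concatenate([x0[None], xs])  # H = identity, so this is also H(M_t(x0))


def outer_iteration(x_k, xb, ys, b_inv, r_inv):
    """One incremental 4DVar outer step with scalar B^{-1} = b_inv I and R^{-1} = r_inv I."""
    traj_k, tl = jax.linearize(trajectory, x_k)  # run the nonlinear model once, keep M'
    adj = jax.linear_transpose(tl, x_k)  # M'^T, the adjoint over the whole window
    d = ys - traj_k  # innovations at the outer iterate

    def gauss_newton_hessian(dx):
        # (B^{-1} + M'^T R^{-1} M') dx, applied matrix-free
        (g,) = adj(r_inv * tl(dx))
        return b_inv * dx + g

    (obs_term,) = adj(r_inv * d)
    rhs = b_inv * (xb - x_k) + obs_term  # minus the inner-cost gradient at dx = 0
    dx, _ = cg(gauss_newton_hessian, rhs, maxiter=50)  # inner loop
    return x_k + dx
```

Each CG iteration costs one tangent-linear and one adjoint application. The nonlinear model is rerun only when `outer_iteration` is called again.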
```python
import jax

# Assumes `forward` satisfies ForwardModel, `x_b` is the current outer iterate,
# `dx` is a state increment and `residual` is a state-space cotangent.
# At the outer iterate x_b, build the tangent-linear of one step.
_, M_lin = jax.linearize(lambda s: forward.step(s, forward.dt), x_b)
dy = M_lin(dx)  # M'_t · dx (tangent-linear)
M_adj = jax.linear_transpose(M_lin, x_b)  # (M'_t)ᵀ (adjoint)
(g,) = M_adj(residual)  # (M'_t)ᵀ · residual
```

For SDEs (Langevin dynamics in Lagrangian transport, stochastic parameterisations), diffrax still applies: `diffeqsolve` integrates SDEs (`EulerMaruyama`, `Heun`, …) and the adjoints work too, with care around non-deterministic reverse integration. Seed realisations deterministically (carry a `jax.random.PRNGKey` through the batch) so that each gradient is taken along a fixed noise path and is therefore well-defined.
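The sketch below illustrates deterministic seeding in plain JAX rather than diffrax. It runs an Euler–Maruyama rollout of a hypothetical Ornstein–Uhlenbeck process whose Brownian increments are drawn once from a fixed key. With the key held fixed, the loss is an ordinary differentiable function of $\theta$, so repeated gradients agree and match finite differences. In diffrax itself, the Brownian-motion object (for example `VirtualBrownianTree`) is built from a key and plays the same role. `em_rollout`, `loss`, and the constants are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def em_rollout(x0, theta, key, dt=0.01, n_steps=100, sigma=0.3):
    """Euler-Maruyama for dX = -theta X dt + sigma dW (hypothetical OU example)."""
    # Draw the whole Brownian path up front: fixed key -> fixed noise path.
    dW = jnp.sqrt(dt) * jax.random.normal(key, (n_steps,) + x0.shape)

    def body(x, dw):
        x_next = x - theta * x * dt + sigma * dw
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, dW)
    return xs  # (n_steps, N)


def loss(theta, key):
    # Mean square of the final state over a batch of 8 independent paths
    x0 = jnp.ones(8)
    return jnp.mean(em_rollout(x0, theta, key)[-1] ** 2)


sde_key = jax.random.PRNGKey(0)
g = jax.grad(loss)(0.5, sde_key)  # gradient along one fixed noise realisation
```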
## Forward Rollout in Practice
A typical rollout integrates the vector field and saves the trajectory, threading
the chosen adjoint so jax.grad flows back through it:
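```python
import jax.numpy as jnp
import diffrax as dfx
from jaxtyping import Array, Float


def rollout(
    x0: Float[Array, "N"],
    forward,  # exposes forward.vector_field(t, y, args)
    n_steps: int,
    dt: float,
    *,
    adjoint: dfx.AbstractAdjoint,
) -> Float[Array, "T N"]:
    """Integrate forward and return the trajectory (n_steps + 1, N)."""
    t1 = n_steps * dt
    sol = dfx.diffeqsolve(
        dfx.ODETerm(forward.vector_field),
        solver=dfx.Tsit5(),
        t0=0.0,
        t1=t1,
        dt0=dt,
        y0=x0,
        # save at every step including t0; linspace keeps the last save time exactly at t1
        saveat=dfx.SaveAt(ts=jnp.linspace(0.0, t1, n_steps + 1)),
        adjoint=adjoint,
    )
    return sol.ys
```

To use it inside `fourdvar_cost`, bind every argument except `x0` (for example with `functools.partial(rollout, forward=model, n_steps=n_steps, dt=dt, adjoint=dfx.RecursiveCheckpointAdjoint())`) and pass the result as the `rollout` argument. `ys` must then hold one observation per saved time, i.e. `n_steps + 1` rows. When `jax.grad(fourdvar_cost)` is evaluated at `x0`, the gradient flows back through `rollout` along the adjoint path you selected.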
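Before trusting the adjoint path in a real configuration, a common sanity check is the gradient test: the finite-difference slope of $J$ along a direction $\mathbf{d}$ should match $\langle \nabla J, \mathbf{d} \rangle$. The sketch below applies a hypothetical helper, `gradient_test`, to a toy cost. To keep it dependency-free, a `jax.lax.scan` rollout stands in for the diffrax `rollout` and is bound with `functools.partial` in the same way. `discrete_rollout`, `toy_cost`, and the toy step are illustrative assumptions.

```python
import functools

import jax
import jax.numpy as jnp


def discrete_rollout(x0, step, n_steps):
    """Scan-based stand-in for the diffrax rollout; returns (n_steps + 1, N)."""
    def body(x, _):
        x_next = step(x)
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, None, length=n_steps)
    return jnp.concatenate([x0[None], xs])


def toy_cost(x0, xb, ys, rollout):
    # Equation (2) with B = R = I and H = identity, for brevity
    return 0.5 * jnp.sum((x0 - xb) ** 2) + 0.5 * jnp.sum((ys - rollout(x0)) ** 2)


def gradient_test(J, x, direction, eps=1e-2):
    """Central-difference slope along `direction` divided by <grad J(x), direction>.

    A ratio close to 1 indicates the autodiff gradient is consistent with J.
    """
    g = jax.grad(J)(x)
    fd = (J(x + eps * direction) - J(x - eps * direction)) / (2 * eps)
    return fd / jnp.vdot(g, direction)


# Bind everything except x0, as you would bind forward / n_steps / dt / adjoint
bound_rollout = functools.partial(discrete_rollout, step=lambda x: x + 0.1 * jnp.sin(x), n_steps=20)
J = functools.partial(toy_cost, xb=jnp.zeros(5), ys=jnp.zeros((21, 5)), rollout=bound_rollout)
ratio = gradient_test(J, 0.5 * jnp.ones(5), jnp.ones(5))
```

With the diffrax rollout in place of `discrete_rollout`, the same test could also be used to compare adjoint choices on a given window.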