Almost every method in this section (3DVar; strong, weak, and incremental 4DVar; 4DVarNet; amortized inference) boils down to minimising a cost, and minimising needs a gradient. This note is the cross-cutting reference for how that gradient is computed: the adjoint method, its modern incarnation as reverse-mode automatic differentiation, and the handful of adjoint strategies (in diffrax for dynamics, in optimistix for optimisers) that trade memory against compute.
Historically the “adjoint model” was hand-coded line by line, a notorious source of bugs (Errico, 1997). Modern autodiff makes it a library feature: you write the forward model, and the adjoint comes for free. The art is no longer writing adjoints but choosing which adjoint strategy fits the problem.
Two Foundational Operations: JVP and VJP
Every autodiff system is built from two primitives. For a function $f:\mathbb{R}^n \to \mathbb{R}^m$ with Jacobian $\mathbf{J} = \partial f/\partial x \in \mathbb{R}^{m\times n}$:

- Jacobian–vector product (JVP): forward mode, the tangent-linear model. Push a tangent $v$ (an input perturbation) forward to an output perturbation: $v \mapsto \mathbf{J}v$.
- Vector–Jacobian product (VJP): reverse mode, the adjoint model. Pull a cotangent $w$ (an output sensitivity) backward to an input sensitivity: $w \mapsto \mathbf{J}^\top w$.

In the DA notation we have used throughout, $\mathbf{J}$ is the tangent-linear operator and $\mathbf{J}^\top$ its adjoint; likewise $\mathbf{H}$ / $\mathbf{H}^\top$ for the observation operator and $\mathbf{M}$ / $\mathbf{M}^\top$ for the dynamics.
Table 1: The two autodiff primitives. $n$ = #inputs, $m$ = #outputs.
| Primitive | Direction | Maps | “Wiggle…” | Cheap when |
|---|---|---|---|---|
| JVP | forward / tangent-linear | $v \mapsto \mathbf{J}v$ | one input, watch all outputs | few inputs ($n$ small) |
| VJP | reverse / adjoint | $w \mapsto \mathbf{J}^\top w$ | one output, watch all inputs | few outputs ($m$ small) |
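To make the table concrete, here is a minimal sketch in JAX. The function `f` and the evaluation point `x` are made up for illustration. One JVP with a basis tangent recovers a single column of the Jacobian, and one VJP with a basis cotangent recovers a single row, which is why forward mode is cheap for few inputs and reverse mode for few outputs.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical toy map R^3 -> R^2, used only for illustration.
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0]])

x = jnp.array([1.0, 2.0, 0.5])

# Dense 2x3 Jacobian, built here only as a reference to compare against.
jac = jax.jacobian(f)(x)

# One JVP with the basis tangent e_0 yields column 0 of the Jacobian.
e0 = jnp.array([1.0, 0.0, 0.0])
_, col0 = jax.jvp(f, (x,), (e0,))

# One VJP with the basis cotangent e_1 yields row 1 of the Jacobian.
_, vjp_fn = jax.vjp(f, x)
(row1,) = vjp_fn(jnp.array([0.0, 1.0]))
```

Recovering the full Jacobian this way would take three JVPs (one per input) or two VJPs (one per output); a scalar cost has a single output, so one VJP suffices.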
The gradient of any scalar cost is a single VJP seeded with the cotangent $1$:
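```python
import jax
import jax.numpy as jnp

# Toy stand-ins so the snippet runs; replace with your own model and cost.
f = lambda x: jnp.sin(x) * x[::-1]   # vector-valued map
J = lambda x: jnp.sum(f(x) ** 2)     # scalar cost
x = jnp.array([0.5, 1.0, 2.0])
v = jnp.ones_like(x)                 # tangent (input perturbation)
w = jnp.ones_like(x)                 # cotangent (output sensitivity)

# forward mode (tangent-linear / JVP): y, J v
y, Jv = jax.jvp(f, (x,), (v,))

# reverse mode (adjoint / VJP): y, then Jᵀ w
y, vjp_fn = jax.vjp(f, x)
(JT_w,) = vjp_fn(w)

# a scalar cost's gradient IS a VJP seeded with 1.0; this is all jax.grad does:
grad_J = jax.grad(J)(x)  # == jax.vjp(J, x)[1](1.0)[0]
```

The Adjoint of a Dynamical Model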
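For 4DVar, the state is marched forward by the dynamics, $x_{t+1} = \mathcal{M}_t(x_t)$, and the cost $\mathcal{J}(x_0) = \sum_{t=0}^{T} \ell_t(x_t)$ compares the trajectory to observations, where $\ell_t$ collects the cost terms evaluated at step $t$. The gradient w.r.t. the initial state pulls back through every step (this is just the chain rule):

$$
\nabla_{x_0}\mathcal{J} = \sum_{t=0}^{T} \left(\frac{\partial x_t}{\partial x_0}\right)^{\!\top} \nabla_{x_t}\ell_t,
\qquad
\frac{\partial x_t}{\partial x_0} = \mathbf{M}_{t-1}\cdots\mathbf{M}_1\mathbf{M}_0,
\quad
\mathbf{M}_k = \frac{\partial \mathcal{M}_k}{\partial x_k}.
\tag{1}
$$

Rather than form the dense $\partial x_t/\partial x_0$, introduce an adjoint state $\lambda$ (a costate / Lagrange multiplier) and sweep it backward in time. In continuous time, with dynamics $\dot{x} = F(x, t)$ and running cost $\ell(x, t)$, it satisfies the adjoint ODE:

$$
\dot{\lambda} = -\left(\frac{\partial F}{\partial x}\right)^{\!\top}\lambda - \left(\frac{\partial \ell}{\partial x}\right)^{\!\top},
\qquad \lambda(T) = 0.
\tag{2}
$$

The adjoint of one step is a VJP, so the whole backward sweep needs no hand-coding: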
```python
import jax
from typing import Callable

def init_step_adjoint(step: Callable) -> Callable:
    """Adjoint of one forward step via VJP: λ ↦ (∂step/∂x)ᵀ λ."""
    def step_adj(x, lam):
        _, vjp = jax.vjp(step, x)
        return vjp(lam)[0]
    return step_adj
```

Where the adjoint ODE (2) comes from: the Lagrangian
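Treat the dynamics $\dot{x} = F(x, t)$ as a constraint and adjoin it to the cost $\mathcal{J} = \int_0^T \ell(x, t)\,dt$ with a multiplier $\lambda(t)$:

$$
\mathcal{L} = \int_0^T \ell(x, t)\,dt + \int_0^T \lambda^\top\big(F(x, t) - \dot{x}\big)\,dt.
\tag{3}
$$

Requiring stationarity of $\mathcal{L}$ w.r.t. the state $x(t)$ (integrate the $\lambda^\top\dot{x}$ term by parts) forces $\lambda$ to satisfy the adjoint ODE (2), $\dot{\lambda} = -(\partial F/\partial x)^\top\lambda - (\partial \ell/\partial x)^\top$ with $\lambda(T) = 0$; stationarity w.r.t. $x_0$ then gives the gradient $\nabla_{x_0}\mathcal{J} = \lambda(0)$, plus the gradient of any term that depends on $x_0$ directly, such as the background term. The multiplier is the adjoint state: “Lagrange multiplier,” “costate,” and “adjoint variable” are three names for $\lambda$.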
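The discrete-time analogue of this costate sweep can be sketched in a few lines of JAX. Everything named here (`step`, `stage_cost`, the step count) is a made-up toy, not a model from this section. The sketch stores the forward trajectory, sweeps the costate backward with one VJP per step, and compares the result against `jax.grad` applied to the unrolled cost.

```python
import jax
import jax.numpy as jnp

def step(x):
    # Hypothetical one-step model: explicit-Euler damped pendulum.
    return x + 0.1 * jnp.array([x[1], -jnp.sin(x[0]) - 0.1 * x[1]])

def stage_cost(x):
    # Hypothetical per-step misfit (observations taken to be zero).
    return 0.5 * jnp.sum(x ** 2)

def cost(x0, n_steps=20):
    # Unrolled cost: sum of stage costs along the trajectory.
    x, total = x0, stage_cost(x0)
    for _ in range(n_steps):
        x = step(x)
        total = total + stage_cost(x)
    return total

def adjoint_gradient(x0, n_steps=20):
    # Forward pass: store the whole trajectory.
    xs = [x0]
    for _ in range(n_steps):
        xs.append(step(xs[-1]))
    # Backward sweep: start from the last stage cost, then apply
    # lam_t = M_t^T lam_{t+1} + grad(stage_cost)(x_t) down to t = 0.
    lam = jax.grad(stage_cost)(xs[-1])
    for x in reversed(xs[:-1]):
        _, vjp_fn = jax.vjp(step, x)
        lam = vjp_fn(lam)[0] + jax.grad(stage_cost)(x)
    return lam  # costate at t = 0 is the gradient w.r.t. x0

x0 = jnp.array([1.0, 0.0])
g_adjoint = adjoint_gradient(x0)
g_autodiff = jax.grad(cost)(x0)
```

The two gradients should agree to floating-point tolerance, since the hand-written sweep is exactly what reverse-mode autodiff does to the unrolled loop.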
Adjoints Through Dynamics: Choosing a Strategy
The backward sweep needs the forward states $x_t$. How you obtain them during the backward pass is the entire trade-off, and diffrax exposes each option as a selectable `AbstractAdjoint`. Let $T$ be the number of forward steps.
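Table 2: diffrax adjoint strategies for differentiating through an ODE solve.
| Strategy | Memory | Time | Use when |
|---|---|---|---|
| `RecursiveCheckpointAdjoint(checkpoints=N)` | $O(N)$ stored states | forward + recomputation between checkpoints | default, balanced |
| `BacksolveAdjoint()` | $O(1)$ in the number of steps | forward + one backward solve | long windows, memory-bound |
| `ForwardMode()` | $O(1)$ in the number of steps | forward, one tangent per input | few parameters |
| `DirectAdjoint()` | grows with the number of steps | forward + plain reverse pass | short windows, debugging |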
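The store-versus-recompute trade behind these strategies can be illustrated in plain JAX, without diffrax. This is only a sketch of the idea, not diffrax's checkpointing scheme: the `step` function is a made-up toy, and `jax.checkpoint` here simply drops each step's intermediates and rebuilds them on the backward pass while still keeping the carried state.

```python
import jax
import jax.numpy as jnp

def step(x, _):
    # Hypothetical cheap dynamics step, used only for illustration.
    return x + 0.01 * jnp.sin(x), None

def rollout_cost(x0, body):
    # Roll the state forward 200 steps and score the final state.
    x_final, _ = jax.lax.scan(body, x0, xs=None, length=200)
    return jnp.sum(x_final ** 2)

x0 = jnp.ones(4)

# Store-everything: reverse mode keeps each step's intermediates.
grad_stored = jax.grad(lambda x: rollout_cost(x, step))(x0)

# Recompute: jax.checkpoint discards per-step intermediates and rebuilds
# them during the backward pass, trading extra compute for less memory.
grad_remat = jax.grad(lambda x: rollout_cost(x, jax.checkpoint(step)))(x0)
```

Both calls return the same gradient; only the memory and compute profile of the backward pass differs, which is the axis along which the strategies in Table 2 are spread.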
```python
import diffrax as dfx

# vector_field(t, y, args), T, dt, and x0 are assumed to be defined by the caller.
sol = dfx.diffeqsolve(
    dfx.ODETerm(vector_field), solver=dfx.Tsit5(),
    t0=0.0, t1=T, dt0=dt, y0=x0,
    adjoint=dfx.RecursiveCheckpointAdjoint(),  # ← swap the strategy here
)
# jax.grad over a cost that calls diffeqsolve now flows through the chosen adjoint
```

Adjoints Through an Optimiser: Implicit Differentiation
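The harder case: the quantity you want to differentiate is itself the output of an optimisation, $x^\star(\theta) = \arg\min_x \mathcal{J}(x, \theta)$; for example, training a hyperparameter of 3DVar, or the inner solver of 4DVarNet. You could unroll every optimiser iteration and backprop through it, but that is memory-hungry and wasteful.
Differentiating the optimality condition $\nabla_x \mathcal{J}(x^\star, \theta) = 0$ gives the implicit-function-theorem gradient (Blondel et al., 2022):

$$
\frac{\partial x^\star}{\partial \theta} = -\left[\nabla^2_{xx}\mathcal{J}(x^\star, \theta)\right]^{-1} \nabla^2_{x\theta}\mathcal{J}(x^\star, \theta).
\tag{4}
$$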
The Hessian inverse is applied implicitly (a matrix-free linear solve, e.g. CG
via lineax) — never formed. optimistix exposes these as an
AbstractAdjoint on the solver:
- `ImplicitAdjoint()`: applies (4) at the optimum. Memory is independent of the iteration count, and the gradient is exact at convergence. The default for classical methods, whose inner solve always converges to a MAP.
- `RecursiveCheckpointAdjoint()`: unroll-and-checkpoint through the iterates. Memory is set by the number of checkpoints, with no convergence assumption. Use when the solver may not fully converge.
- One-step (Bolte et al., 2023): run all but the last iteration under `stop_gradient`, then differentiate only the last step. Memory of a single step; exact at a fixed point when the last step is Newton-like, near-exact for well-converged solvers. The memory-saver for training 4DVarNet.
- `DirectAdjoint()`: plain reverse through the solver loop; transparent, for debugging.
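The implicit-function-theorem route can be sketched on a toy problem. This is an illustration under stated assumptions, not how optimistix implements `ImplicitAdjoint`: the ridge-regularised inner cost, the data `A`, `b`, `x_target`, and the closed-form `solve_inner` (standing in for a converged iterative solver) are all made up. The Hessian is never formed; it is applied matrix-free inside a CG solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical toy data for a ridge-regularised least-squares inner problem.
A = jnp.array([[2.0, 0.5], [0.5, 1.0], [0.0, 1.5]])
b = jnp.array([1.0, -1.0, 0.5])
x_target = jnp.array([0.3, -0.2])

def inner_cost(x, theta):
    return 0.5 * jnp.sum((A @ x - b) ** 2) + 0.5 * theta * jnp.sum(x ** 2)

def solve_inner(theta):
    # Closed-form minimiser, standing in for a converged iterative solver.
    return jnp.linalg.solve(A.T @ A + theta * jnp.eye(2), A.T @ b)

def outer_loss(x):
    return 0.5 * jnp.sum((x - x_target) ** 2)

def implicit_grad(theta):
    x_star = solve_inner(theta)
    grad_x = jax.grad(inner_cost, argnums=0)
    # Matrix-free Hessian-vector product of the inner cost at x_star.
    hvp = lambda v: jax.jvp(lambda x: grad_x(x, theta), (x_star,), (v,))[1]
    # Adjoint solve: Hessian @ u = gradient of the outer loss.
    u, _ = cg(hvp, jax.grad(outer_loss)(x_star))
    # Cross term: derivative of grad_x(inner_cost) with respect to theta.
    cross = jax.jvp(lambda th: grad_x(x_star, th), (theta,), (jnp.ones_like(theta),))[1]
    return -jnp.vdot(u, cross)

theta = jnp.asarray(0.7)
g_implicit = implicit_grad(theta)
# Reference: differentiate straight through the closed-form solve.
g_reference = jax.grad(lambda th: outer_loss(solve_inner(th)))(theta)
```

The implicit gradient needs only the converged `x_star`, one linear solve, and one cross-derivative, regardless of how many iterations the inner solver took to get there.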
Two Slots, Two Purposes
A method that has both dynamics and an inner optimiser carries two independent adjoint choices — and conflating them is a common confusion:
Table 3: The two adjoint slots in a variational-DA solver.
| Slot | Differentiates through… | Library | Typical default |
|---|---|---|---|
| `forward_adjoint` | the dynamics rollout | `diffrax.AbstractAdjoint` | `RecursiveCheckpointAdjoint()` |
| `minimiser_adjoint` / `solver_adjoint` | the optimiser / inner solve | `optimistix.AbstractAdjoint` | `ImplicitAdjoint()` (classical), checkpoint/one-step (learned) |
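```python
import diffrax as dfx
import optimistix as optx

# StrongFourDVar is assumed to be this section's strong-constraint 4DVar class
# (not part of diffrax or optimistix); dynamics, obs_op, x_b, B_op, and R_op
# are assumed to be defined by the caller.
# strong-4DVar: long window → constant-memory dynamics adjoint;
# converged inner solve → exact implicit optimiser adjoint
model = StrongFourDVar(
    forward=dynamics, obs_op=obs_op,
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.NonlinearCG(rtol=1e-5, atol=1e-5),  # optimistix takes both tolerances
    forward_adjoint=dfx.BacksolveAdjoint(),    # O(1) memory through M_t
    minimiser_adjoint=optx.ImplicitAdjoint(),  # exact at the MAP
)
```

Forward vs. Adjoint Sensitivity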
The forward/reverse distinction is the same one that sensitivity analysis draws between forward and adjoint sensitivities:

- Forward sensitivity carries $\partial x/\partial \theta$ alongside the state: one tangent run per parameter. Best when there are few inputs (parameter estimation with a handful of parameters).
- Adjoint sensitivity carries $\lambda$ backward: one sweep per output. Best when there are few outputs (a scalar cost, $m = 1$, with a high-dimensional state).

A handy rule: count the thing you have fewer of. Fewer inputs → forward; fewer outputs → adjoint. Variational DA always has one scalar cost, so it is adjoint territory. Linearisation-based uncertainty propagation (the delta method, $\Sigma_y \approx \mathbf{J}\,\Sigma_x\,\mathbf{J}^\top$) uses the same Jacobians/Hessians these primitives compute.
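As a small illustration of the delta method mentioned above, the sketch below propagates an input covariance through a nonlinear map using its Jacobian. The polar-to-Cartesian map `h`, the mean, and the covariance are made-up values. It also builds the Jacobian both ways, with `jax.jacfwd` (one JVP per input) and `jax.jacrev` (one VJP per output), to show that the two modes compute the same object.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Hypothetical nonlinear map: polar (r, phi) to Cartesian (x, y).
    return jnp.array([x[0] * jnp.cos(x[1]), x[0] * jnp.sin(x[1])])

x_mean = jnp.array([2.0, 0.0])             # r = 2, phi = 0
cov_x = jnp.diag(jnp.array([0.01, 0.04]))  # input covariance

# Same Jacobian, two routes: one JVP per input vs. one VJP per output.
jac_fwd = jax.jacfwd(h)(x_mean)
jac_rev = jax.jacrev(h)(x_mean)

# Delta method: first-order propagation of the covariance through h.
cov_y = jac_fwd @ cov_x @ jac_fwd.T
```

At this point the Jacobian is diagonal with entries 1 and 2, so the angular variance is scaled by $r^2 = 4$ in the Cartesian $y$ direction while the radial variance passes through unchanged.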
Adjoints in Differentiable Physics
The same machinery powers differentiable simulators beyond DA:
- Implicit solvers (linear systems, eigenproblems, fixed points) appear inside radiative-transfer and PDE codes. Differentiating through a lineax/CG solve uses (4), which needs only the solver’s answer and not its iterations, and numerically delicate operations get a hand-written `jax.custom_jvp`. See the differentiable RTM note.
- PINNs put a PDE residual in the loss, so training backprops through the differential operators themselves: autodiff differentiating an autodiff expression. See physics-informed neural networks.
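A hand-written derivative rule for a numerically delicate operation might look like the sketch below. It follows the standard `log(1 + exp(x))` example from the JAX documentation rather than any code from the RTM note: the automatically derived gradient multiplies a vanishing factor by an overflowing one and can come out as `nan` in single precision, while the hand-written rule expresses the same derivative as a sigmoid.

```python
import jax
import jax.numpy as jnp

def log1pexp_naive(x):
    # Autodiff differentiates through exp, which overflows for large x.
    return jnp.log1p(jnp.exp(x))

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Same derivative, written as a sigmoid so it stays finite for large x.
    return log1pexp(x), x_dot * jax.nn.sigmoid(x)

# Gradient at a large input: the naive version may overflow in float32.
g_naive = jax.grad(log1pexp_naive)(100.0)
g_safe = jax.grad(log1pexp)(100.0)
```

Because the rule is linear in `x_dot`, JAX can transpose it, so the single `custom_jvp` definition serves both forward and reverse mode.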
Validation: Do They Agree?
Every adjoint strategy must yield the same gradient (to floating-point tolerance) on a converged problem. This is the single most useful test in the whole stack — a disagreement is a bug:
```python
import jax
import jax.numpy as jnp
import diffrax as dfx

def grad_with_adjoint(cost_fn, x0, adjoint):
    # Assumed helper: cost_fn(x0, adjoint) runs diffeqsolve with that adjoint.
    # ForwardMode supports forward-mode autodiff only, so use jacfwd for it.
    diff = jax.jacfwd if isinstance(adjoint, dfx.ForwardMode) else jax.grad
    return diff(lambda x: cost_fn(x, adjoint))(x0)

def test_adjoints_agree(cost_fn, x0, atol=1e-4):
    strategies = [dfx.RecursiveCheckpointAdjoint(), dfx.BacksolveAdjoint(),
                  dfx.DirectAdjoint(), dfx.ForwardMode()]
    grads = {type(a).__name__: grad_with_adjoint(cost_fn, x0, a) for a in strategies}
    ref = grads["RecursiveCheckpointAdjoint"]
    for name, g in grads.items():
        assert jnp.allclose(g, ref, atol=atol), f"{name} disagrees: adjoint bug"
```

Two classic culprits when they don’t agree: (1) reverse-time instability in `BacksolveAdjoint` (switch to an implicit integrator backward), or (2) an under-converged inner solve breaking `ImplicitAdjoint`’s exactness assumption (tighten the optimiser tolerance).
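A complementary check, useful before comparing whole strategies, is the classical dot-product test for a tangent-linear/adjoint pair: for any $v$ and $w$, $\langle \mathbf{J}v, w\rangle$ must equal $\langle v, \mathbf{J}^\top w\rangle$. The sketch below applies it to a made-up operator `f` and adds a central finite-difference check of the tangent-linear model; the step size `eps` is an arbitrary choice suited to single precision.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical nonlinear operator R^3 -> R^2, for illustration only.
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

key_v, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jnp.array([0.3, -1.2, 0.8])
v = jax.random.normal(key_v, (3,))  # random input perturbation
w = jax.random.normal(key_w, (2,))  # random output sensitivity

_, Jv = jax.jvp(f, (x,), (v,))      # tangent-linear model applied to v
_, vjp_fn = jax.vjp(f, x)
(JTw,) = vjp_fn(w)                  # adjoint model applied to w

# Dot-product test: <J v, w> should match <v, J^T w>.
lhs = jnp.vdot(Jv, w)
rhs = jnp.vdot(v, JTw)

# Central finite difference as an independent check of the tangent-linear model.
eps = 1e-3
fd = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)
```

The dot-product identity holds to rounding error regardless of the model, so a mismatch points at the adjoint; the finite-difference comparison is looser and mainly catches a wrong tangent-linear model.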
Selection Heuristics
Table 4: Which adjoint to reach for.
| Scenario | Reach for |
|---|---|
| Short window, default | `RecursiveCheckpointAdjoint()` |
| Long window, memory-bound | `BacksolveAdjoint()` |
| Parameter estimation with few parameters | `ForwardMode()` |
| Debugging a suspected adjoint bug | `DirectAdjoint()` |
| Training 4DVarNet, memory-bound | one-step (Bolte) |
| Training 4DVarNet, well-converged inner | `ImplicitAdjoint()` |
| Training classical hyperparameters (always converged at MAP) | `ImplicitAdjoint()` |
- Errico, R. M. (1997). What Is an Adjoint Model? Bulletin of the American Meteorological Society, 78(11), 2577–2591. https://doi.org/10.1175/1520-0477(1997)078<2577:WIAAM>2.0.CO;2
- Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural Ordinary Differential Equations. Advances in Neural Information Processing Systems (NeurIPS).
- Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., Pedregosa, F., & Vert, J.-P. (2022). Efficient and Modular Implicit Differentiation. Advances in Neural Information Processing Systems (NeurIPS).
- Bolte, J., Pauwels, E., & Vaiter, S. (2023). One-Step Differentiation of Iterative Algorithms. Advances in Neural Information Processing Systems (NeurIPS).
- Kidger, P. (2021). On Neural Differential Equations [PhD thesis]. University of Oxford.