
Adjoint Methods & Automatic Differentiation

CSIC
UCM
IGEO

Almost every method in this section (3DVar, strong / weak / incremental 4DVar, 4DVarNet, amortized inference) boils down to minimising a cost, and minimising needs a gradient. This note is the cross-cutting reference for how that gradient is computed: the adjoint method, its modern incarnation as reverse-mode automatic differentiation, and the handful of adjoint strategies (diffrax for dynamics, optimistix for optimisers) that trade memory against compute.

Historically, the “adjoint model” was hand-coded line by line, a notorious source of bugs (Errico, 1997). Modern autodiff makes it a library feature: you write the forward model, and the adjoint is derived automatically. The art is no longer writing adjoints but choosing which adjoint strategy fits the problem.

Two Foundational Operations: JVP and VJP

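Every autodiff system is built from two primitives, the Jacobian-vector product (JVP) and the vector-Jacobian product (VJP). For a function $\boldsymbol{f}: \mathbb{R}^{D} \to \mathbb{R}^{M}$ with Jacobian $\mathbf{J}_{\boldsymbol{f}}(\mathbf{x}) = \partial \boldsymbol{f}/\partial \mathbf{x}$, the JVP computes $\mathbf{J}_{\boldsymbol{f}}\mathbf{v}$ and the VJP computes $\mathbf{J}_{\boldsymbol{f}}^\top\mathbf{w}$, without ever forming $\mathbf{J}_{\boldsymbol{f}}$.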

In the DA notation we have used throughout, $\mathbf{J}_{\boldsymbol{f}}$ is the tangent-linear operator and $\mathbf{J}_{\boldsymbol{f}}^\top$ its adjoint; likewise $\mathbf{J}_{H}$ / $\mathbf{J}_{H}^\top$ for the observation operator and $\mathbf{J}_{M}$ / $\mathbf{J}_{M}^\top$ for the dynamics.

Table 1: The two autodiff primitives. $D$ = number of inputs, $M$ = number of outputs.

| Primitive | Direction | Maps | “Wiggle…” | Cheap when |
|---|---|---|---|---|
| JVP $\mathbf{J}\mathbf{v}$ | forward / tangent-linear | $\mathbb{R}^{D} \to \mathbb{R}^{M}$ | one input, watch all outputs | few inputs ($D$ small) |
| VJP $\mathbf{J}^\top\mathbf{w}$ | reverse / adjoint | $\mathbb{R}^{M} \to \mathbb{R}^{D}$ | one output, watch all inputs | few outputs ($M$ small) |

The gradient of any scalar cost ($M = 1$) is a single VJP seeded with $\mathbf{w} = 1$:

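```python
import jax
import jax.numpy as jnp

def f(x):                       # toy model f: R^3 -> R^2 (illustrative)
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

def J(x):                       # toy scalar cost built on f
    return 0.5 * jnp.sum(f(x) ** 2)

x, v, w = jnp.array([0.1, 0.2, 0.3]), jnp.ones(3), jnp.ones(2)

# forward mode (tangent-linear / JVP):  y = f(x), Jv = J_f v
y, Jv = jax.jvp(f, (x,), (v,))

# reverse mode (adjoint / VJP):  y = f(x), then JT_w = J_fᵀ w
y, vjp_fn = jax.vjp(f, x)
(JT_w,) = vjp_fn(w)

# a scalar cost's gradient IS a VJP seeded with 1.0; this is all jax.grad does:
grad_J = jax.grad(J)(x)        #  ==  jax.vjp(J, x)[1](1.0)[0]
```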
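To see why the “Cheap when” column of Table 1 holds, a full Jacobian can be assembled either column by column with $D$ JVPs or row by row with $M$ VJPs. The sketch below does both for a hypothetical toy map $\boldsymbol{f}: \mathbb{R}^4 \to \mathbb{R}^2$ (the function and evaluation point are illustrative assumptions), batching the passes with jax.vmap; this is essentially how jax.jacfwd and jax.jacrev are built.

```python
import jax
import jax.numpy as jnp

def f(x):                                   # hypothetical toy map R^4 -> R^2
    return jnp.array([jnp.sum(jnp.sin(x)), jnp.prod(x)])

x = jnp.array([0.5, 1.0, 1.5, 2.0])
D, M = x.size, 2

# forward mode: one JVP per input direction (D passes); each pass yields a column
cols = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(jnp.eye(D))   # shape (D, M)
J_fwd = cols.T                                                        # shape (M, D)

# reverse mode: one VJP per output direction (M passes); each pass yields a row
_, vjp_fn = jax.vjp(f, x)
J_rev = jax.vmap(lambda w: vjp_fn(w)[0])(jnp.eye(M))                 # shape (M, D)
```

Here $D = 4 > M = 2$, so the reverse-mode construction needs fewer passes; for a scalar cost ($M = 1$) it needs exactly one.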

The Adjoint of a Dynamical Model

For 4DVar, the dynamics march the state forward, $\mathbf{x}_t = M_t(\mathbf{x}_0)$, and the cost compares the trajectory to observations $\mathbf{y}_t$. The gradient with respect to the initial state pulls back through every step (this is just the chain rule):

$$
\nabla_{\mathbf{x}_0} J = \mathbf{B}^{-1}(\mathbf{x}_0 - \mathbf{x}_b) + \sum_{t} (\mathbf{J}_{M_t})^\top (\mathbf{J}_{H_t})^\top \mathbf{R}_t^{-1}\big(H_t(M_t(\mathbf{x}_0)) - \mathbf{y}_t\big).
$$

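Rather than form the dense $(\mathbf{J}_{M_t})^\top$, introduce an adjoint state (a costate / Lagrange multiplier) $\boldsymbol{\lambda}_t$ and sweep it backward in time. In continuous time, with $\mathbf{J}_{M}(\mathbf{x}(t))$ now denoting the Jacobian of the vector field along the trajectory, it satisfies the adjoint ODE, integrated backward from $t = T$: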

$$
\frac{\mathrm{d}\boldsymbol{\lambda}}{\mathrm{d}t} = -\big(\mathbf{J}_{M}(\mathbf{x}(t))\big)^\top \boldsymbol{\lambda}(t) - \nabla_{\mathbf{x}(t)} J, \qquad \boldsymbol{\lambda}(T) = \boldsymbol{0}, \qquad \nabla_{\mathbf{x}_0} J = \mathbf{B}^{-1}(\mathbf{x}_0 - \mathbf{x}_b) + \boldsymbol{\lambda}_0.
$$
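In discrete time the same backward sweep is a recursion in which every step is one VJP. Writing a single model step as $\mathbf{x}_{t+1} = m(\mathbf{x}_t)$ over $K$ steps (the symbol $m$ is introduced only for this illustration) and the observation term as $J_o = \tfrac{1}{2}\sum_t \big(H_t(\mathbf{x}_t) - \mathbf{y}_t\big)^\top \mathbf{R}_t^{-1}\big(H_t(\mathbf{x}_t) - \mathbf{y}_t\big)$:

$$
\boldsymbol{\lambda}_K = \nabla_{\mathbf{x}_K} J_o, \qquad
\boldsymbol{\lambda}_t = \big(\mathbf{J}_{m}(\mathbf{x}_t)\big)^\top \boldsymbol{\lambda}_{t+1} + \nabla_{\mathbf{x}_t} J_o \quad (t = K-1, \dots, 0),
$$

$$
\nabla_{\mathbf{x}_t} J_o = (\mathbf{J}_{H_t})^\top \mathbf{R}_t^{-1}\big(H_t(\mathbf{x}_t) - \mathbf{y}_t\big), \qquad
\nabla_{\mathbf{x}_0} J = \mathbf{B}^{-1}(\mathbf{x}_0 - \mathbf{x}_b) + \boldsymbol{\lambda}_0,
$$

with $\nabla_{\mathbf{x}_t} J_o = \boldsymbol{0}$ at times without observations. Unrolling the recursion, and using $\mathbf{J}_{M_t} = \mathbf{J}_m(\mathbf{x}_{t-1}) \cdots \mathbf{J}_m(\mathbf{x}_0)$, reproduces the sum over $t$ in the gradient of $J$ with respect to $\mathbf{x}_0$.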

The adjoint of one step is a VJP, so the whole backward sweep needs no hand-coding:

```python
import jax
from typing import Callable

def init_step_adjoint(step: Callable) -> Callable:
    """Adjoint of one forward step via VJP:  λ ↦ (∂step/∂x)ᵀ λ."""
    def step_adj(x, lam):
        _, vjp = jax.vjp(step, x)
        return vjp(lam)[0]
    return step_adj
```
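As a check on the idea, the sketch below writes the backward sweep out by hand for a toy strong-constraint 4DVar problem and compares it with jax.grad on the unrolled cost. Everything here is an illustrative assumption: the step function, $\mathbf{B} = \mathbf{R} = \mathbf{I}$, an identity observation operator, and observations at every step $t = 1, \dots, K$. Each pull-back is exactly what init_step_adjoint(step) returns, written inline so the block runs on its own.

```python
import jax
import jax.numpy as jnp

def step(x):
    """Hypothetical toy step: damped rotation plus a weak nonlinearity."""
    c, s = jnp.cos(0.1), jnp.sin(0.1)
    rot = jnp.array([[c, -s], [s, c]])
    return 0.99 * (rot @ x) + 0.01 * jnp.sin(x)

def cost(x0, xb, ys):
    """Strong-constraint 4DVar cost with B = R = I, H = identity, obs at t = 1..K."""
    J = 0.5 * jnp.sum((x0 - xb) ** 2)
    x = x0
    for y in ys:
        x = step(x)
        J = J + 0.5 * jnp.sum((x - y) ** 2)
    return J

def adjoint_gradient(x0, xb, ys):
    """Hand-rolled backward sweep; each pull-back is one VJP of `step`."""
    xs, x = [], x0
    for _ in range(len(ys)):          # forward pass: store x_1 .. x_K
        x = step(x)
        xs.append(x)
    lam = jnp.zeros_like(x0)
    for t in reversed(range(len(ys))):
        lam = lam + (xs[t] - ys[t])   # add the observation forcing at x_{t+1}
        x_prev = x0 if t == 0 else xs[t - 1]
        _, vjp = jax.vjp(step, x_prev)
        (lam,) = vjp(lam)             # lam <- (d step / d x_prev)^T lam
    return (x0 - xb) + lam            # B^{-1}(x0 - xb) + lambda_0

xb = jnp.array([1.0, 0.0])
x0 = jnp.array([0.8, 0.3])
ys = 0.5 * jnp.ones((5, 2))           # K = 5 synthetic observations
g_sweep = adjoint_gradient(x0, xb, ys)
g_auto = jax.grad(cost)(x0, xb, ys)
```

Storing every $\mathbf{x}_t$ in the forward pass is the $O(K)$-memory choice; the strategies below differ precisely in how they avoid or reduce that storage.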

Adjoints Through Dynamics — Choosing a Strategy

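The backward sweep needs the forward states $\mathbf{x}(t)$. How you obtain them during the backward pass (store them all, recompute them from checkpoints, or re-integrate the dynamics backward), or whether you skip the backward pass altogether with forward mode, is the whole trade-off; diffrax exposes each option as a selectable AbstractAdjoint. Let $K$ be the number of forward steps and $P$ the number of parameters being differentiated.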

Table 2: diffrax adjoint strategies for differentiating through an ODE solve.

| Strategy | Memory | Time | Use when |
|---|---|---|---|
| RecursiveCheckpointAdjoint() (default checkpoints) | $O(\log K)$ | $O(K \log K)$ | default, balanced |
| BacksolveAdjoint() | $O(1)$ | $\sim 2\times$ forward | long windows, memory-bound |
| ForwardMode() | $O(P)$ | $O(P) \times$ forward | few parameters, $P \ll N$ |
| DirectAdjoint() | $O(K)$ | $O(K)$ | short windows, debugging |
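```python
import jax.numpy as jnp
import diffrax as dfx

def vector_field(t, y, args):        # diffrax vector fields take (t, y, args)
    return -y + jnp.sin(t)            # toy forced linear ODE, illustrative only

T, dt, x0 = 1.0, 0.01, jnp.array([1.0, 2.0])

sol = dfx.diffeqsolve(
    dfx.ODETerm(vector_field), solver=dfx.Tsit5(),
    t0=0.0, t1=T, dt0=dt, y0=x0,
    adjoint=dfx.RecursiveCheckpointAdjoint(),   # ← swap the strategy here
)
# jax.grad over a cost that calls diffeqsolve now flows through the chosen adjoint
```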

Adjoints Through an Optimiser — Implicit Differentiation

The harder case: the quantity you want to differentiate is itself the output of an optimisation, $\mathbf{x}^\star(\boldsymbol{\theta}) = \operatorname*{arg\,min}_{\mathbf{x}} f(\mathbf{x}; \boldsymbol{\theta})$, for example when training a hyperparameter of 3DVar or the inner solver of 4DVarNet. You could unroll every optimiser iteration and backpropagate through it, but memory then grows with the number of iterations, and you obtain the derivative of the iteration path rather than of the optimum itself.

Differentiating the optimality condition $\nabla_{\mathbf{x}} f(\mathbf{x}^\star; \boldsymbol{\theta}) = \boldsymbol{0}$ with respect to $\boldsymbol{\theta}$ gives the implicit-function-theorem gradient (Blondel et al., 2022):

$$
\frac{\partial \mathbf{x}^\star}{\partial \boldsymbol{\theta}} = -\big(\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{x}^\star; \boldsymbol{\theta})\big)^{-1}\, \nabla^2_{\mathbf{x}\boldsymbol{\theta}} f(\mathbf{x}^\star; \boldsymbol{\theta}).
$$

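The Hessian inverse is applied implicitly through a matrix-free linear solve (e.g. CG via lineax) and is never formed. optimistix exposes this as ImplicitAdjoint, an AbstractAdjoint passed through the adjoint argument of its solve functions (for example optx.minimise).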
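The mechanics can be sketched in plain JAX, without optimistix. Assumptions, all hypothetical: a small separable inner cost $f(\mathbf{x}; \theta) = \tfrac{\theta}{2}\|\mathbf{x} - \mathbf{x}_b\|^2 + \tfrac12\|\tanh(\mathbf{x}) - \mathbf{y}\|^2$, gradient descent as the inner solver, and jax.scipy.sparse.linalg.cg applied to Hessian-vector products so the Hessian is never formed. This illustrates the formula above; it is not optimistix's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical inner problem: f(x; theta) = theta/2 |x - xb|^2 + 1/2 |tanh(x) - y|^2
xb = jnp.array([0.2, -0.1, 0.4])
y = jnp.array([0.5, 0.3, -0.2])

def f(x, theta):
    return 0.5 * theta * jnp.sum((x - xb) ** 2) + 0.5 * jnp.sum((jnp.tanh(x) - y) ** 2)

grad_x = jax.grad(f, argnums=0)

def solve_inner(theta, iters=300):
    """Plain gradient descent to x*(theta); stands in for any converged solver."""
    return jax.lax.fori_loop(0, iters, lambda _, x: x - 0.25 * grad_x(x, theta), xb)

def implicit_grad(theta):
    """dx*/dtheta = -(∇²_xx f)^{-1} ∇²_xθ f, with the Hessian applied matrix-free."""
    x_star = solve_inner(theta)
    b = jax.jacfwd(grad_x, argnums=1)(x_star, theta)        # ∇²_xθ f, shape (3,)
    hvp = lambda v: jax.jvp(lambda x: grad_x(x, theta), (x_star,), (v,))[1]
    dx, _ = cg(hvp, -b, tol=1e-6)                             # solve H dx = -b by CG
    return dx

theta = 2.0
dx_implicit = implicit_grad(theta)
dx_unrolled = jax.jacfwd(solve_inner)(theta)   # differentiate the unrolled iterations
```

Both routes give the same derivative here because the inner solve is converged; the implicit route's cost does not depend on how many iterations the solver took.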

Two Slots, Two Purposes

A method that has both dynamics and an inner optimiser carries two independent adjoint choices, and conflating them is a common source of confusion:

Table 3: The two adjoint slots in a variational-DA solver.

| Slot | Differentiates through… | Library | Typical default |
|---|---|---|---|
| forward_adjoint | the dynamics rollout $M_t$ | diffrax.AbstractAdjoint | RecursiveCheckpointAdjoint() |
| minimiser_adjoint / solver_adjoint | the optimiser / inner solve | optimistix.AbstractAdjoint | ImplicitAdjoint() (classical); checkpoint / one-step (learned) |
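```python
import diffrax as dfx
import optimistix as optx

# StrongFourDVar and its inputs (dynamics, obs_op, x_b, B_op, R_op) come from the
# surrounding DA library and are assumed to be defined elsewhere (not shown here).
# strong-4DVar: long window → constant-memory dynamics adjoint;
#               converged inner solve → exact implicit optimiser adjoint
model = StrongFourDVar(
    forward=dynamics, obs_op=obs_op,
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.NonlinearCG(rtol=1e-5, atol=1e-5),   # optimistix solvers take both tolerances
    forward_adjoint=dfx.BacksolveAdjoint(),     # O(1) memory through M_t
    minimiser_adjoint=optx.ImplicitAdjoint(),   # exact at the MAP
)
```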

Forward vs. Adjoint Sensitivity

The forward/reverse distinction in autodiff is the same as the forward/adjoint distinction in sensitivity analysis.

A handy rule: count the thing you have fewer of. Fewer inputs → forward; fewer outputs → adjoint. Variational DA always has one scalar cost, so it is adjoint territory. Linearisation-based uncertainty propagation (the delta method, $\mathbb{E}[h(\mathbf{x})] \approx h(\boldsymbol{\mu}) + \tfrac{1}{2}\mathrm{Tr}[\nabla^2 h\,\boldsymbol{\Sigma}]$) uses the same Jacobians/Hessians these primitives compute.
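The delta-method formula is a one-liner once the Hessian is available, and jax.hessian builds it by composing the two primitives (forward-over-reverse). A minimal sketch, assuming a hypothetical quadratic $h(\mathbf{x}) = \mathbf{x}^\top\mathbf{A}\mathbf{x}$, for which the second-order formula is exact, $\mathbb{E}[h(\mathbf{x})] = \boldsymbol{\mu}^\top\mathbf{A}\boldsymbol{\mu} + \mathrm{Tr}(\mathbf{A}\boldsymbol{\Sigma})$, so the result can be checked:

```python
import jax
import jax.numpy as jnp

def delta_method_mean(h, mu, Sigma):
    """Second-order delta method: E[h(x)] ≈ h(mu) + 0.5 * Tr[∇²h(mu) Sigma]."""
    H = jax.hessian(h)(mu)            # jax.hessian = jacfwd(jacrev(h)): forward-over-reverse
    return h(mu) + 0.5 * jnp.trace(H @ Sigma)

# Illustrative quadratic h(x) = xᵀ A x with a symmetric A
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
h = lambda x: x @ A @ x
mu = jnp.array([1.0, -1.0])
Sigma = jnp.array([[0.3, 0.1], [0.1, 0.2]])

approx = delta_method_mean(h, mu, Sigma)
exact = mu @ A @ mu + jnp.trace(A @ Sigma)   # 2.0 + 0.9 = 2.9
```

For large states one would estimate the trace from Hessian-vector products instead of forming $\nabla^2 h$; the dense Hessian is only affordable in a small illustration like this one.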

Adjoints in Differentiable Physics

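The same machinery powers differentiable simulators and models beyond DA; neural ODEs, for example, are trained by differentiating through an ODE solve with these same adjoint strategies (Chen et al., 2018; Kidger, 2021).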

Validation: Do They Agree?

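Every adjoint strategy should yield the same gradient on a converged problem, up to floating-point tolerance (plus, for BacksolveAdjoint, the discretisation error of its backward solve). This is the single most useful test in the whole stack: a disagreement beyond that tolerance is a bug.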

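```python
import jax
import jax.numpy as jnp
import diffrax as dfx

def grad_with_adjoint(cost_fn, x0, adjoint):
    """cost_fn(x0, adjoint) -> scalar. ForwardMode supports only forward-mode AD."""
    if isinstance(adjoint, dfx.ForwardMode):
        return jax.jacfwd(cost_fn)(x0, adjoint)
    return jax.grad(cost_fn)(x0, adjoint)

def test_adjoints_agree(cost_fn, x0, atol=1e-4):
    strategies = [dfx.RecursiveCheckpointAdjoint(), dfx.BacksolveAdjoint(),
                  dfx.DirectAdjoint(), dfx.ForwardMode()]
    grads = {type(a).__name__: grad_with_adjoint(cost_fn, x0, a) for a in strategies}
    ref = grads["RecursiveCheckpointAdjoint"]
    for name, g in grads.items():
        assert jnp.allclose(g, ref, atol=atol), f"{name} disagrees: adjoint bug"
```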

Two classic culprits when they don’t agree: (1) reverse-time instability in BacksolveAdjoint (switch to an implicit integrator backward), or (2) an under-converged inner solve breaking ImplicitAdjoint’s exactness assumption (tighten the optimiser tolerance).
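When an adjoint is hand-written or wrapped around external code (for instance via jax.custom_vjp), a bug in it is shared by every reverse-mode strategy, so comparing strategies may not catch it. The classic dot-product test checks the tangent-linear and adjoint models against each other instead: $\langle \mathbf{J}\mathbf{v}, \mathbf{w}\rangle = \langle \mathbf{v}, \mathbf{J}^\top\mathbf{w}\rangle$ for any $\mathbf{v}, \mathbf{w}$. A minimal sketch, with a hypothetical step function standing in for the model under test (JAX-derived JVPs and VJPs pass by construction, so in practice fn would be the custom operator):

```python
import jax
import jax.numpy as jnp

def step(x):
    """Hypothetical nonlinear forward step (illustrative only)."""
    return x + 0.1 * jnp.sin(x) * jnp.roll(x, 1)

def dot_product_test(fn, x, key):
    """Relative mismatch between <J v, w> and <v, Jᵀ w> for random v, w."""
    kv, kw = jax.random.split(key)
    v = jax.random.normal(kv, x.shape)
    _, Jv = jax.jvp(fn, (x,), (v,))            # tangent-linear model
    w = jax.random.normal(kw, Jv.shape)
    _, vjp_fn = jax.vjp(fn, x)
    (JTw,) = vjp_fn(w)                          # adjoint model
    lhs, rhs = jnp.vdot(Jv, w), jnp.vdot(v, JTw)
    scale = jnp.linalg.norm(Jv) * jnp.linalg.norm(w)
    return jnp.abs(lhs - rhs) / scale

x = jnp.linspace(-1.0, 1.0, 40)
err = dot_product_test(step, x, jax.random.PRNGKey(0))
```

A correct adjoint gives a mismatch at the level of floating-point round-off; anything much larger points to a bug in the tangent-linear or adjoint code.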

Selection Heuristics

Table 4: Which adjoint to reach for.

| Scenario | Reach for |
|---|---|
| Short window ($T \le 10$), default | RecursiveCheckpointAdjoint() |
| Long window ($T \ge 100$), memory-bound | BacksolveAdjoint() |
| Parameter estimation, $P \ll N$ | ForwardMode() |
| Debugging a suspected adjoint bug | DirectAdjoint() |
| Training 4DVarNet, memory-bound | one-step differentiation (Bolte et al., 2023) |
| Training 4DVarNet, well-converged inner solve | ImplicitAdjoint() |
| Training classical hyperparameters (inner solve converged to the MAP) | ImplicitAdjoint() |
References
  1. Errico, R. M. (1997). What Is an Adjoint Model? Bulletin of the American Meteorological Society, 78(11), 2577–2591. https://doi.org/10.1175/1520-0477(1997)078<2577:WIAAM>2.0.CO;2
  2. Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural Ordinary Differential Equations. Advances in Neural Information Processing Systems (NeurIPS).
  3. Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., Pedregosa, F., & Vert, J.-P. (2022). Efficient and Modular Implicit Differentiation. Advances in Neural Information Processing Systems (NeurIPS).
  4. Bolte, J., Pauwels, E., & Vaiter, S. (2023). One-Step Differentiation of Iterative Algorithms. Advances in Neural Information Processing Systems (NeurIPS).
  5. Kidger, P. (2021). On Neural Differential Equations [PhD thesis]. University of Oxford.