# Adjoint Methods
Every analysis method that involves minimisation or dynamics needs
gradients. For 4DVar, that's gradients through the ODE rollout. For
FourDVarNet, that's gradients through the inner learned solver. For
training amortized heads, it's standard backprop.
Vardax does not own adjoint code. Gradients through dynamics come from
diffrax.AbstractAdjoint; gradients through inner minimisers come
from optimistix.AbstractAdjoint (Decision D15). This chapter
explains the choices and when to use which.
## Why composable adjoints
Historical DA libraries hand-coded adjoint models: every forward operator \(M_t\) required a companion adjoint \(M^\top_t\), either written by hand or generated by source-to-source automatic-differentiation tools (TAMC, OpenAD) and then maintained by hand. Keeping the adjoint code in step with an evolving forward model was a significant engineering tax.
JAX changed the picture by making reverse-mode autodiff a first-class
language feature. diffrax and optimistix build on this by exposing
named adjoint strategies as user-selectable types. The user
chooses between memory and time, between exact and approximate; the
library handles the differentiation.
Vardax exposes these strategies as constructor slots on the Layer 2
models. No vardax-owned grad_mode enum, no in-house adjoint
implementation (Decision D15).
## Adjoints through dynamics: diffrax.AbstractAdjoint
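For strong-constraint 4DVar with dynamics \(x_{t+1} = M_t(x_t)\), the cost over the initial state is

\[
J(x_0) = \tfrac{1}{2}\,(x_0 - x_b)^\top B^{-1} (x_0 - x_b) + \tfrac{1}{2} \sum_{t=0}^{T} \big(y_t - H_t(x_t)\big)^\top R_t^{-1} \big(y_t - H_t(x_t)\big),
\]

and the gradient pulls back through every \(M_t\) via the transposed Jacobians \(\mathbf{M}_t = \partial M_t / \partial x\) and \(\mathbf{H}_t = \partial H_t / \partial x\):

\[
\nabla_{x_0} J = B^{-1}(x_0 - x_b) - \sum_{t=0}^{T} \mathbf{M}_0^\top \cdots \mathbf{M}_{t-1}^\top \, \mathbf{H}_t^\top R_t^{-1} \big(y_t - H_t(x_t)\big).
\]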
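To make the pull-back concrete, here is a minimal sketch in plain JAX, assuming a toy linear model (constant \(M\), linear \(H\), random data; none of these names come from vardax). It compares jax.grad through the rollout with a hand-written backward recursion \(\lambda_t = \mathbf{M}^\top \lambda_{t+1} - \mathbf{H}^\top R^{-1} d_t\), which is the adjoint model that reverse-mode autodiff generates for free.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: linear dynamics M, linear observation operator H.
n, m, T = 4, 2, 10
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
M = jnp.eye(n) + 0.1 * jax.random.normal(k1, (n, n))
H = jax.random.normal(k2, (m, n))
B_inv = jnp.eye(n)                        # background precision B^{-1}
R_inv = 4.0 * jnp.eye(m)                  # observation precision R^{-1}
x_b = jax.random.normal(k3, (n,))
ys = jax.random.normal(k4, (T + 1, m))    # observations y_0..y_T

def rollout(x0):
    # Returns the states x_0..x_T with x_{t+1} = M x_t.
    def step(x, _):
        return M @ x, x
    _, xs = jax.lax.scan(step, x0, None, length=T + 1)
    return xs

def cost(x0):
    xs = rollout(x0)
    d = ys - xs @ H.T                     # innovations y_t - H x_t
    jb = 0.5 * (x0 - x_b) @ B_inv @ (x0 - x_b)
    jo = 0.5 * jnp.einsum("ti,ij,tj->", d, R_inv, d)
    return jb + jo

def adjoint_gradient(x0):
    # Hand-written adjoint: pull the observation forcing back through M^T.
    xs = rollout(x0)
    d = ys - xs @ H.T
    forcing = -(d @ R_inv.T) @ H          # row t is -H^T R^{-1} d_t
    lam = jnp.zeros(n)
    for t in range(T, -1, -1):            # t = T, ..., 0
        lam = M.T @ lam + forcing[t]
    return B_inv @ (x0 - x_b) + lam

x0 = x_b + 0.1
g_auto = jax.grad(cost)(x0)
g_adj = adjoint_gradient(x0)
```

The two gradients agree to round-off; the diffrax adjoints below differ only in how they organise this backward sweep in memory and time.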
diffrax.diffeqsolve(..., adjoint=...) handles this. Four choices:
### RecursiveCheckpointAdjoint(checkpoints=N) (default)
Recursive checkpointing of the forward trajectory. Stores \(N \sim \log K\) checkpoints; recomputes intermediate states from checkpoints during the backward pass.
| Memory | Time | When |
|---|---|---|
| \(O(\log K)\) | \(O(K \log K)\) | Default. Balanced. |
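By default diffrax sets \(N\) to roughly \(\log\) of max_steps, which is usually right. Each extra checkpoint costs one stored state but saves recomputation, so raise \(N\) if memory allows and lower it only if profiling shows memory pressure.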
### BacksolveAdjoint(): continuous adjoint
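Solve the adjoint ODE backwards in time. For \(\dot x = f(x, t; \theta)\) and a loss \(L\) evaluated at \(x(T)\), the adjoint state \(\lambda(t) = \partial L / \partial x(t)\) satisfies

\[
\frac{d\lambda}{dt} = -\left(\frac{\partial f}{\partial x}\right)^{\!\top} \lambda(t), \qquad \lambda(T) = \frac{\partial L}{\partial x(T)}, \qquad \frac{dL}{d\theta} = \int_0^T \lambda(t)^\top \frac{\partial f}{\partial \theta}\, dt,
\]

integrated from \(t = T\) to \(t = 0\) together with the state \(x(t)\), which is recomputed in reverse time rather than stored. Memory is constant in the forward trajectory length.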
| Memory | Time | When |
|---|---|---|
| \(O(1)\) | \(\sim 2 \times\) forward | Long windows, memory-constrained |
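Caveat: reverse-time instability. BacksolveAdjoint reconstructs \(x(t)\) by integrating the dynamics backwards, and dissipative dynamics that are well-conditioned forward with an explicit integrator can become unstable or stiff in reverse time. If so, pass an implicit solver for the backward pass, e.g. dfx.BacksolveAdjoint(solver=dfx.Kvaerno5()) or dfx.ImplicitEuler(). Because it differentiates the continuous adjoint rather than the discrete solve, its gradients match the other adjoints only up to the integration tolerance.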
### ForwardMode(): forward sensitivity
Compute gradients via forward-mode autodiff — solve sensitivity ODEs forward in time alongside the state.
| Memory | Time | When |
|---|---|---|
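| \(O(NP)\), independent of \(K\) | \(O(P) \times\) forward | Few parameters (\(P\) small) |
Use when you're differentiating with respect to a small number of parameters (e.g. a single emission rate \(Q\) in methane Tier I), or when you need the Jacobian of many outputs (such as observation residuals) with respect to few inputs. Forward mode costs roughly one extra solve per parameter, whereas reverse mode returns the full gradient of a scalar cost at a small constant multiple of one solve, so reverse mode wins once \(P\) exceeds a handful. ForwardMode supports forward-mode autodiff only: use it with jax.jvp or jax.jacfwd, not jax.grad.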
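The idea behind forward sensitivity can be sketched in plain JAX (this is jax.jvp through a hand-rolled Euler loop, not diffrax.ForwardMode itself). Assumed for illustration: a 1D concentration field with a hypothetical Gaussian emission footprint scaled by a single rate \(Q\), linear decay, and a least-squares misfit at the final time. One tangent vector of size \(N\) travels with the state, so nothing is stored per step.

```python
import jax
import jax.numpy as jnp

N, K, dt = 50, 200, 0.01
# Hypothetical emission footprint and decay rate (illustration only).
source = jnp.exp(-0.5 * ((jnp.arange(N) - 25.0) / 4.0) ** 2)
decay = 0.3
y_obs = jnp.ones(N)

def cost(Q):
    # Explicit Euler rollout of dc/dt = Q * source - decay * c.
    def step(c, _):
        return c + dt * (Q * source - decay * c), None
    c_final, _ = jax.lax.scan(step, jnp.zeros(N), None, length=K)
    return 0.5 * jnp.sum((c_final - y_obs) ** 2)

Q = jnp.asarray(2.0)
# Forward mode: push the unit tangent dQ = 1 through the rollout alongside the state.
_, dJ_dQ_fwd = jax.jvp(cost, (Q,), (jnp.ones_like(Q),))
# Reverse mode for comparison (stores the trajectory for the backward sweep).
dJ_dQ_rev = jax.grad(cost)(Q)
```

With \(P\) parameters the forward approach needs \(P\) such tangents (for example via jax.jacfwd), which is why it only pays off when \(P\) is small.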
### DirectAdjoint(): plain reverse
Standard reverse-mode autodiff through the lax.scan of integration
steps. No checkpointing.
| Memory | Time | When |
|---|---|---|
| \(O(K)\) | \(O(K)\) | Short windows; debugging |
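Memory-hungry but transparent. Useful when debugging adjoint issues: the gradient is literally the reverse-mode autodiff result with no checkpointing tricks. DirectAdjoint also supports forward-mode autodiff, which RecursiveCheckpointAdjoint does not.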
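When debugging an adjoint, a standard complementary check is the gradient test: compare the directional derivative \(\langle \nabla J, \delta \rangle\) from autodiff against a central finite difference. A minimal sketch, assuming a hypothetical toy model \(x \leftarrow x + \Delta t \sin x\) and a quadratic cost at the final state; the step \(\varepsilon = 10^{-2}\) suits float32 and is an assumption, not a tuned value.

```python
import jax
import jax.numpy as jnp

def cost(x0, n_steps=50, dt=0.05):
    # Toy nonlinear dynamics x <- x + dt * sin(x), quadratic cost on the final state.
    def step(x, _):
        return x + dt * jnp.sin(x), None
    xT, _ = jax.lax.scan(step, x0, None, length=n_steps)
    return 0.5 * jnp.sum(xT ** 2)

def gradient_test(f, x, direction, eps=1e-2):
    """Central-difference estimate vs autodiff value of <grad f(x), direction>."""
    fd = (f(x + eps * direction) - f(x - eps * direction)) / (2 * eps)
    ad = jnp.vdot(jax.grad(f)(x), direction)
    return fd, ad

x0 = jnp.linspace(0.1, 1.0, 8)
direction = jnp.ones(8) / jnp.sqrt(8.0)
fd, ad = gradient_test(cost, x0, direction)
```

Repeating the test while shrinking eps and watching the discrepancy fall like \(\varepsilon^2\) (until round-off takes over) is a more sensitive version of the same check.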
## Adjoints through inner minimisers: optimistix.AbstractAdjoint
For methods that train through an inner minimisation (FourDVarNet,
AmortizedPosterior, or any classical method whose hyperparameters are
themselves trained), the training gradient flows through
optimistix.minimise(...) via its adjoint argument.
### RecursiveCheckpointAdjoint()
Unroll the inner iteration with checkpointing. Standard backprop
through the iterative solver. Used by FourDVarNet for end-to-end
training when memory permits.
| Memory | Convergence requirement |
|---|---|
| \(O(N)\) checkpoints (default \(N \sim \log K\)) | None |
### ImplicitAdjoint(): IFT at the optimum
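At the optimum \(x^*(\theta)\) of \(\min_x f(x; \theta)\), the first-order condition \(\nabla_x f(x^*; \theta) = 0\) holds, so the implicit function theorem gives

\[
\frac{\partial x^*}{\partial \theta} = -\left(\nabla^2_x f(x^*; \theta)\right)^{-1} \nabla^2_{x\theta} f(x^*; \theta).
\]

optimistix never inverts the Hessian explicitly: it uses lineax to solve linear systems with \(\nabla^2_x f\), i.e. to apply \((\nabla^2_x f)^{-1}\) to vectors.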
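The formula can be checked by hand in plain JAX. A minimal sketch, assuming a hypothetical quadratic inner objective (background term weighted by \(\theta\) plus a linear observation term) so that \(x^*(\theta)\) has a closed form to compare against; in practice \(x^*\) would come from an iterative minimiser.

```python
import jax
import jax.numpy as jnp

# Hypothetical inner objective f(x; theta).
H = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
y = jnp.array([1.0, 2.0])
x_b = jnp.array([0.5, 0.0, -0.5])

def f(x, theta):
    return 0.5 * theta * jnp.sum((x - x_b) ** 2) + 0.5 * jnp.sum((H @ x - y) ** 2)

def x_star(theta):
    # Closed-form minimiser: (theta I + H^T H) x = theta x_b + H^T y.
    A = theta * jnp.eye(3) + H.T @ H
    return jnp.linalg.solve(A, theta * x_b + H.T @ y)

def ift_jacobian(theta):
    x = x_star(theta)
    hess = jax.hessian(f, argnums=0)(x, theta)                       # nabla^2_x f
    mixed = jax.jacfwd(jax.grad(f, argnums=0), argnums=1)(x, theta)  # nabla^2_{x theta} f
    return -jnp.linalg.solve(hess, mixed)                            # solve, never invert

theta = 0.8
dx_ift = ift_jacobian(theta)
dx_exact = jax.jacobian(x_star)(theta)   # differentiate the closed form directly
```

Only a linear solve at \(x^*\) is needed, which is why the memory cost does not grow with the number of inner iterations.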
| Memory | Convergence requirement |
|---|---|
| \(O(1)\) | Yes (must reach optimum) |
Used by classical methods (ThreeDVar, StrongFourDVar,
WeakFourDVar) when training hyperparameters. The exact-at-optimum
property is the right model — at a well-converged MAP, IFT gives the
correct sensitivity.
### DirectAdjoint()
Plain reverse through the optimistix solver loop. Same trade-off as
diffrax.DirectAdjoint.
### vardax.adjoints.OneStepAdjoint(): Bolte et al. 2023
The one vardax-owned adjoint, an exception to D15 kept as a candidate for upstream contribution:
| Memory | Convergence requirement |
|---|---|
| \(O(1)\) | Approximately convergent |
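Runs the first \(K-1\) inner iterations under jax.lax.stop_gradient, then a
single differentiable step. The training gradient picks up only the
last step's contribution. Bolte et al. show that this one-step derivative is
exact at the fixed point for superlinearly convergent inner solvers (e.g.
Newton's method); for linearly convergent solvers its error is proportional to
the contraction rate. In practice it gives near-correct gradients for fast,
well-converged inner solvers and dramatically reduces memory.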
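The behaviour can be sketched on a scalar toy problem in plain JAX (this is not the vardax implementation). Assumed for illustration: \(f(x;\theta) = \tfrac{a}{2}x^2 - \theta x\), so \(x^* = \theta/a\) and \(dx^*/d\theta = 1/a\), minimised by gradient descent with step lr. The one-step derivative is \(\partial_\theta F = \text{lr}\), exact when the contraction rate \(1 - \text{lr}\,a\) is zero (here lr = 1/a coincides with a Newton step) and off by that rate otherwise.

```python
import jax
import jax.numpy as jnp

A_CURV = 2.0  # curvature a of the toy objective; x* = theta / a

def gd_step(x, theta, lr):
    # One gradient-descent step on f(x) = a x^2 / 2 - theta x.
    return x - lr * (A_CURV * x - theta)

def one_step_solve(theta, lr, n_iter=50):
    # K-1 iterations with no gradient tracking ...
    theta_ng = jax.lax.stop_gradient(theta)
    x = jax.lax.fori_loop(0, n_iter - 1, lambda _, x: gd_step(x, theta_ng, lr), jnp.zeros(()))
    # ... then a single differentiable step from the detached iterate.
    return gd_step(jax.lax.stop_gradient(x), theta, lr)

def unrolled_solve(theta, lr, n_iter=50):
    # Full backprop through every iteration, for comparison.
    return jax.lax.fori_loop(0, n_iter, lambda _, x: gd_step(x, theta, lr), jnp.zeros(()))

exact = 1.0 / A_CURV
g_fast = jax.grad(one_step_solve)(1.0, 0.5)    # contraction rate 0: one-step is exact
g_slow = jax.grad(one_step_solve)(1.0, 0.25)   # contraction rate 0.5: half the true value
g_unrolled = jax.grad(unrolled_solve)(1.0, 0.25)
```

The unrolled gradient recovers the exact value but stores every iterate; the one-step version stores one.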
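```python
from vardax.adjoints import OneStepAdjoint
from vardax.models import FourDVarNet

# prior, obs_op, grad_mod and config are assumed to be constructed elsewhere.
model = FourDVarNet(
    prior=prior, obs_op=obs_op, grad_mod=grad_mod, config=config,
    solver_adjoint=OneStepAdjoint(),  # O(1) training memory through the inner solver
)
```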
OneStepAdjoint is an optimistix.AbstractAdjoint subclass. The plan
is to contribute upstream once stable, joining the standard adjoint
family (Decision D6).
## Two slots, two purposes
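Vardax models that involve both dynamics and inner minimisation carry two adjoint slots: a diffrax adjoint for the dynamics (forward_adjoint) and an optimistix adjoint for the inner minimisation (minimiser_adjoint on the classical models, solver_adjoint on FourDVarNet):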
| Slot | Used by | Default |
|---|---|---|
| forward_adjoint: diffrax.AbstractAdjoint | StrongFourDVar, WeakFourDVar, IncrementalFourDVar, FourDVarNet (when prior is DynamicalPrior) | RecursiveCheckpointAdjoint() |
| minimiser_adjoint: optimistix.AbstractAdjoint | ThreeDVar, StrongFourDVar, WeakFourDVar | ImplicitAdjoint() |
| solver_adjoint: optimistix.AbstractAdjoint | FourDVarNet (through the learned inner solver) | RecursiveCheckpointAdjoint() |
The two are independent. A typical operational 4DVar configuration:
```python
import diffrax as dfx
import optimistix as optx
from vardax.models import StrongFourDVar

# somax_model, obs_op, x_b, B_op and R_op are assumed to be defined elsewhere.
model = StrongFourDVar(
    forward=somax_model,
    obs_op=obs_op,
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.NonlinearCG(rtol=1e-5, atol=1e-8),  # optimistix solvers take rtol and atol
    minimiser_adjoint=optx.ImplicitAdjoint(),  # IFT at the converged MAP
    forward_adjoint=dfx.BacksolveAdjoint(),    # constant memory through dynamics
)
```
A typical FourDVarNet training configuration:
```python
import diffrax as dfx
from vardax.adjoints import OneStepAdjoint
from vardax.models import FourDVarNet

# DynamicalPrior, ConvLSTMGradMod2D and SolverConfig are vardax classes (imports not shown);
# somax_model and obs_op are assumed to be defined elsewhere.
model = FourDVarNet(
    prior=DynamicalPrior(forward=somax_model, n_steps=5,
                         forward_adjoint=dfx.BacksolveAdjoint()),
    obs_op=obs_op,
    grad_mod=ConvLSTMGradMod2D(hidden_dim=64),
    config=SolverConfig(n_steps=15),
    solver_adjoint=OneStepAdjoint(),         # O(1) training memory through solver
    forward_adjoint=dfx.BacksolveAdjoint(),  # O(1) through the dynamical prior
)
```
## Validation: do they agree?
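All adjoint choices should produce the same gradient for converged problems, up to numerical tolerance: floating-point round-off for the discrete adjoints, and the ODE solver tolerance for BacksolveAdjoint, which differentiates the continuous adjoint rather than the discrete solve. The test suite includes: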
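```python
import jax.numpy as jnp
from diffrax import (BacksolveAdjoint, DirectAdjoint, ForwardMode,
                     RecursiveCheckpointAdjoint)

def test_adjoint_choices_agree(cost_fn, x_b):
    # compute_gradient is a test helper (not shown); for ForwardMode it must use
    # forward-mode AD (e.g. jax.jacfwd), since ForwardMode does not support jax.grad.
    grads = {}
    for adj in [RecursiveCheckpointAdjoint(), BacksolveAdjoint(),
                DirectAdjoint(), ForwardMode()]:
        grads[type(adj).__name__] = compute_gradient(cost_fn, x_b, adj)
    reference = grads['RecursiveCheckpointAdjoint']
    for name, g in grads.items():
        assert jnp.allclose(g, reference, atol=1e-4), \
            f"{name} disagrees with reference"
```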
Disagreement is a bug. The most common cause: reverse-time instability
in BacksolveAdjoint (switch to an implicit integrator), or
under-converged optimisation in ImplicitAdjoint (tighten the
tolerance).
## Adjoint selection heuristics
| Scenario | Pick |
|---|---|
| Short window (\(T \le 10\)), default | RecursiveCheckpointAdjoint() |
| Long window (\(T \ge 100\)), memory-bound | BacksolveAdjoint() |
| Parameter estimation, few parameters (\(P\) small) | ForwardMode() |
| Debugging | DirectAdjoint() |
| Training FourDVarNet, memory-bound | OneStepAdjoint() |
| Training FourDVarNet, well-converged inner | ImplicitAdjoint() |
| Training FourDVarNet, default | RecursiveCheckpointAdjoint() |
| Training classical hyperparameters | ImplicitAdjoint() (assumes the inner solve converges to the MAP) |
## See also
- Chapter 3: dynamical model (forward_adjoint use)
- Chapter 6: StrongFourDVar (both slots active)
- Chapter 8: IncrementalFourDVar (uses forward_adjoint only; the inner solver is a hand-rolled GN+CG, not optimistix)
- Chapter 9: FourDVarNet (the solver_adjoint choice matters most here)
- Design doc: design/decisions.md#d15
## References
- Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural ordinary differential equations. NeurIPS.
- Kidger, P. (2021). On neural differential equations. PhD thesis, University of Oxford. [diffrax author's thesis; see Ch. 5 on adjoints.]
- Errico, R. M. (1997). What is an adjoint model? BAMS 78(11).
- Bolte, J., Pauwels, E., & Vaiter, S. (2023). One-step differentiation of iterative algorithms. NeurIPS.
- Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., Pedregosa, F., & Vert, J.-P. (2022). Efficient and modular implicit differentiation. NeurIPS. [jaxopt / optimistix's ImplicitAdjoint foundation.]