
Adjoint Methods

Every analysis method that involves minimisation or dynamics needs gradients. For 4DVar, that's gradients through the ODE rollout. For FourDVarNet, that's gradients through the inner learned solver. For training amortized heads, it's standard backprop.

Vardax does not own adjoint code, with one exception: OneStepAdjoint, an optimistix.AbstractAdjoint subclass intended for upstream contribution. Gradients through dynamics come from diffrax.AbstractAdjoint, and gradients through inner minimisers come from optimistix.AbstractAdjoint (Decision D15). This chapter explains the choices and when to use which.

Why composable adjoints

Historical DA libraries hand-coded adjoint models. Every forward operator \(M_t\) required a companion adjoint \((M'_t)^\top\) of its tangent-linear model. That adjoint was either written by hand or generated by source-to-source automatic-differentiation tools (TAMC, OpenAD). Keeping the adjoint in sync with an evolving forward model was a significant engineering tax.

JAX changed the picture by making reverse-mode autodiff a first-class language feature. diffrax and optimistix build on this by exposing named adjoint strategies as user-selectable types. The user chooses between memory and time, between exact and approximate; the library handles the differentiation.

Vardax exposes these strategies as constructor slots on the Layer 2 models. There is no vardax-owned grad_mode enum, and no in-house adjoint implementation beyond OneStepAdjoint (Decision D15).

Adjoints through dynamics — diffrax.AbstractAdjoint

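For 4DVar with dynamics, writing \(M_t\) for the model map from the initial state \(x_0\) to time \(t\), the cost is

\[ J(x_0) = J_b(x_0) + \tfrac{1}{2}\sum_t \|y_t - H_t(M_t(x_0))\|^2_{R_t^{-1}}, \qquad J_b(x_0) = \tfrac{1}{2}\|x_0 - x_b\|^2_{B^{-1}}, \]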

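and the gradient pulls back through every \(M_t\):

\[ \nabla_{x_0} J = B^{-1}(x_0 - x_b) + \sum_t (M'_t)^\top (H'_t)^\top R_t^{-1} \big(H_t(M_t(x_0)) - y_t\big). \]

Here \(M'_t\) is the tangent-linear model (the Jacobian of \(M_t\) at \(x_0\)), and \(H'_t\) is the Jacobian of \(H_t\) at \(M_t(x_0)\).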
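As a sanity check on the formula, here is a minimal JAX sketch under stated assumptions:

- a hypothetical linear toy model \(M_t(x_0) = A^t x_0\);
- a fixed linear observation operator \(H\);
- the ½-weighted cost, whose gradient is exactly the expression above.

Because the model is linear, \(M'_t = A^t\). The hand-written pullback should therefore agree with jax.grad to floating-point precision. All names and numbers are illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear toy problem (illustrative numbers only):
# M_t(x0) = A^t x0, with H_t = H and R_t = R at every observation time.
A = jnp.array([[0.9, 0.1], [-0.2, 0.95]])
H = jnp.array([[1.0, 0.0]])             # observe the first component only
B_inv = jnp.eye(2) / 0.5                # B = 0.5 I
R_inv = jnp.eye(1) / 0.1                # R = 0.1 I
x_b = jnp.array([1.0, 0.0])
ys = jnp.array([[0.8], [0.7], [0.5]])   # observations at t = 1, 2, 3


def model_map(x0, t):
    """M_t(x0) = A^t x0."""
    return jnp.linalg.matrix_power(A, t) @ x0


def cost(x0):
    """J = 1/2 ||x0 - x_b||^2_{B^-1} + 1/2 sum_t ||y_t - H M_t(x0)||^2_{R^-1}."""
    d = x0 - x_b
    J = 0.5 * d @ B_inv @ d
    for t, y in enumerate(ys, start=1):
        r = y - H @ model_map(x0, t)
        J = J + 0.5 * r @ R_inv @ r
    return J


def adjoint_gradient(x0):
    """B^-1 (x0 - x_b) + sum_t (M'_t)^T H^T R^-1 (H M_t(x0) - y_t)."""
    g = B_inv @ (x0 - x_b)
    for t, y in enumerate(ys, start=1):
        Mt = jnp.linalg.matrix_power(A, t)   # linear model, so M'_t = A^t
        g = g + Mt.T @ H.T @ R_inv @ (H @ (Mt @ x0) - y)
    return g


x0 = jnp.array([0.5, 0.2])
g_autodiff = jax.grad(cost)(x0)      # reverse-mode autodiff
g_adjoint = adjoint_gradient(x0)     # explicit adjoint formula
```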

diffrax.diffeqsolve(..., adjoint=...) computes this pullback through the numerical solve, and the adjoint argument selects the strategy. Four choices:

RecursiveCheckpointAdjoint(checkpoints=N) — default

Recursive checkpointing of the forward trajectory over \(K\) solver steps. It stores \(N\) checkpoints (on the order of \(\log K\) by default) and recomputes intermediate states from them during the backward pass.

| Memory | Time | When |
|---|---|---|
| \(O(\log K)\) | \(O(K \log K)\) | Default; balanced |

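With checkpoints=None (the default), diffrax chooses \(N\) on the order of \(\log\) max_steps, which requires a finite max_steps in diffeqsolve. This default is usually right. If memory is available and the backward pass is the bottleneck, raise \(N\): more checkpoints mean less recomputation.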
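The memory/recompute trade-off can be seen in plain JAX. The sketch below is not diffrax's implementation. Diffrax uses an online recursive (binomial) schedule, whereas this sketch uses a simpler fixed two-level scheme with jax.checkpoint.

- Only the \(K/B\) block-boundary states are kept.
- Each block of \(B\) steps is recomputed during the backward pass.
- Peak memory is therefore roughly \(O(K/B + B)\) states.

The step function is a hypothetical stand-in for one solver step.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in for one solver step (a diffusive nonlinear update).
    return x + 0.1 * jnp.tanh(jnp.roll(x, 1) - x)


def rollout_plain(x0, n_steps):
    # Plain reverse mode: the scan keeps residuals for every one of the K steps.
    def body(x, _):
        return step(x), None
    xK, _ = jax.lax.scan(body, x0, None, length=n_steps)
    return xK


def rollout_checkpointed(x0, n_blocks, block_len):
    # Two-level checkpointing: only the n_blocks block-boundary states are saved,
    # and jax.checkpoint makes each block's interior be recomputed on the backward pass.
    @jax.checkpoint
    def block(x):
        return rollout_plain(x, block_len)

    def body(x, _):
        return block(x), None
    xK, _ = jax.lax.scan(body, x0, None, length=n_blocks)
    return xK


x0 = jnp.linspace(-1.0, 1.0, 16)
# K = 100 steps, either plain or as 10 checkpointed blocks of 10 steps.
g_plain = jax.grad(lambda x: jnp.sum(rollout_plain(x, 100) ** 2))(x0)
g_ckpt = jax.grad(lambda x: jnp.sum(rollout_checkpointed(x, 10, 10) ** 2))(x0)
```

Both rollouts return the same gradient. Only the memory profile and the amount of recomputation differ.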

BacksolveAdjoint() — continuous adjoint

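Solve the adjoint ODE backwards in time. Write the dynamics as \(dx/dt = f(x)\), so that \(M_t\) is the flow of \(f\) from \(0\) to \(t\). The adjoint state \(\lambda(t)\), the sensitivity of the observation term of \(J\) to \(x(t)\), satisfies

\[ \frac{d\lambda}{dt} = -\Big(\frac{\partial f}{\partial x}(x(t))\Big)^\top \lambda(t). \]

The adjoint is integrated from \(t = T\) to \(t = 0\), starting from \(\lambda(T^+) = 0\). At each observation time \(t\) it takes a jump

\[ \lambda(t^-) = \lambda(t^+) + (H'_t)^\top R_t^{-1}\big(H_t(x(t)) - y_t\big), \]

and the gradient is \(\nabla_{x_0} J = B^{-1}(x_0 - x_b) + \lambda(0)\). The state \(x(t)\) is re-integrated backwards alongside \(\lambda\) rather than stored, so memory is constant in the forward trajectory length.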

| Memory | Time | When |
|---|---|---|
| \(O(1)\) | \(\sim 2 \times\) forward | Long windows, memory-constrained |

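Caveat: reverse-time instability. BacksolveAdjoint re-integrates \(x(t)\) backwards, and dynamics that are dissipative forwards are anti-dissipative backwards. The backward solve can therefore blow up or need tiny steps even when the forward solve with an explicit integrator is well-behaved. If so:

- pass an implicit solver for the backward pass, e.g. dfx.BacksolveAdjoint(solver=dfx.Kvaerno5()) or dfx.ImplicitEuler();
- or fall back to RecursiveCheckpointAdjoint.

BacksolveAdjoint also returns the gradient of the continuous problem. It therefore matches the discrete gradient of the other adjoints only up to solver tolerance.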
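A minimal pure-JAX sketch of the backsolve idea, under these assumptions:

- a hypothetical two-dimensional vector field;
- a single observation at \(t = T\) with \(H = I\) and \(R = I\), so the only adjoint input is \(\lambda(T) = x(T) - y\);
- fixed-step RK4 in both directions.

It keeps only \(x(T)\) from the forward pass. It then integrates the augmented system \((x, \lambda)\) backwards, using jax.vjp for \((\partial f/\partial x)^\top \lambda\). diffrax's BacksolveAdjoint does the same job with its own solvers and step-size controllers; this sketch is an illustration, not its code.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[-0.5, 1.0], [-1.0, -0.5]])
y_obs = jnp.array([0.3, -0.2])
T, N_STEPS = 1.0, 200


def f(x):
    # Hypothetical vector field dx/dt = f(x): damped rotation plus a cubic term.
    return A @ x - 0.1 * x**3


def rk4_step(g, z, h):
    # Classical fourth-order Runge-Kutta step for dz/dt = g(z); h may be negative.
    k1 = g(z)
    k2 = g(z + 0.5 * h * k1)
    k3 = g(z + 0.5 * h * k2)
    k4 = g(z + h * k3)
    return z + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def solve_forward(x0):
    h = T / N_STEPS
    xT, _ = jax.lax.scan(lambda x, _: (rk4_step(f, x, h), None), x0, None, length=N_STEPS)
    return xT


def cost(x0):
    # Single observation at t = T: J = 1/2 ||x(T) - y||^2.
    r = solve_forward(x0) - y_obs
    return 0.5 * r @ r


def backsolve_gradient(x0):
    n = x0.shape[0]
    xT = solve_forward(x0)            # the forward pass keeps only the final state
    lam_T = xT - y_obs                # lambda(T) = dJ/dx(T)

    def augmented_rhs(z):
        x, lam = z[:n], z[n:]
        _, vjp_fn = jax.vjp(f, x)
        (fx_T_lam,) = vjp_fn(lam)     # (df/dx)^T lambda
        return jnp.concatenate([f(x), -fx_T_lam])

    h = T / N_STEPS
    z0, _ = jax.lax.scan(
        lambda z, _: (rk4_step(augmented_rhs, z, -h), None),  # step backwards in time
        jnp.concatenate([xT, lam_T]), None, length=N_STEPS)
    return z0[n:]                     # lambda(0) = dJ/dx0


x0 = jnp.array([1.0, 0.5])
g_discrete = jax.grad(cost)(x0)       # reverse mode through the stored RK4 steps
g_backsolve = backsolve_gradient(x0)  # continuous adjoint, constant memory
```

The two gradients agree only up to discretisation error. This is why backsolve gradients match discrete ones only to within solver tolerance.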

ForwardMode() — forward sensitivity

Compute gradients via forward-mode autodiff — solve sensitivity ODEs forward in time alongside the state.

| Memory | Time | When |
|---|---|---|
| \(O(P)\) | \(O(P) \times\) forward | Few parameters (\(P\) small) |

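Use this when you're differentiating with respect to a small number of parameters, e.g. a single emission rate \(Q\) in methane Tier I. Forward mode costs roughly one tangent propagation per parameter, so its cost grows with \(P\). Reverse mode obtains the full gradient of the scalar cost \(J\) in one backward pass regardless of \(P\). Forward mode therefore wins only when \(P\) is small, and in that regime it also avoids storing or recomputing the trajectory.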
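A plain-JAX sketch of the forward-mode idea for a single parameter. It assumes a hypothetical one-dimensional advection-decay toy in which an emission rate \(Q\) enters through a point source.

- One jax.jvp with tangent \(\delta Q = 1\) propagates \(\partial x / \partial Q\) alongside the state, so nothing is stored for a backward pass.
- jax.grad gives the same number by reverse mode.
- With \(P\) parameters, the forward approach needs \(P\) such tangents (e.g. via jax.jacfwd).

diffrax's ForwardMode is the analogous switch inside diffeqsolve; it is not used here.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1-D advection-decay toy: x_{k+1} = a * shift(x_k) + Q * s,
# where the emission rate Q is the only parameter (P = 1).
N, a = 50, 0.9
s = jnp.zeros(N).at[10].set(1.0)      # unit point source at grid cell 10
y = jnp.linspace(0.0, 1.0, N)         # illustrative observations of the final state


def final_state(Q, n_steps=30):
    def body(x, _):
        return a * jnp.roll(x, 1) + Q * s, None
    xK, _ = jax.lax.scan(body, jnp.zeros(N), None, length=n_steps)
    return xK


def cost(Q):
    r = final_state(Q) - y
    return 0.5 * r @ r


Q0 = jnp.asarray(2.0)
# Forward mode: one JVP with tangent dQ = 1 carries dx/dQ alongside x.
_, dJ_forward = jax.jvp(cost, (Q0,), (jnp.ones_like(Q0),))
# Reverse mode: one backward pass over the scan.
dJ_reverse = jax.grad(cost)(Q0)
```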

DirectAdjoint() — plain reverse

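Differentiates directly through the solver's step loop, without the online checkpointing schedule of RecursiveCheckpointAdjoint. Unlike RecursiveCheckpointAdjoint, it also supports forward-mode autodiff (e.g. jax.jvp).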

| Memory | Time | When |
|---|---|---|
| \(O(K)\) | \(O(K)\) | Short windows; debugging |

Diffrax recommends RecursiveCheckpointAdjoint unless you need forward-mode support, so reserve DirectAdjoint for debugging adjoint issues. Its gradient is ordinary autodiff through the solver loop, with no checkpointing schedule in between.

Adjoints through inner minimisers — optimistix.AbstractAdjoint

For methods that train through an inner minimisation (FourDVarNet, AmortizedPosterior, or any classical method whose hyperparameters are themselves trained), the training gradient flows through optimistix.minimise(...) via its adjoint argument.

RecursiveCheckpointAdjoint()

Unroll the inner iteration with checkpointing. Standard backprop through the iterative solver. Used by FourDVarNet for end-to-end training when memory permits.

| Memory | Convergence requirement |
|---|---|
| \(O(\log K)\) checkpoints by default, for \(K\) inner iterations | None |

ImplicitAdjoint() — IFT at the optimum

At the optimum \(x^*\) of \(\min_x f(x; \theta)\):

\[ \frac{d x^*}{d \theta} = -\big( \nabla^2_x f(x^*; \theta) \big)^{-1} \nabla^2_{x \theta} f(x^*; \theta). \]

optimistix never explicitly inverts the Hessian. Instead, it uses lineax to solve linear systems in \(\nabla^2_x f\), which applies \((\nabla^2_x f)^{-1}\) to the vectors the backward pass needs.

| Memory | Convergence requirement |
|---|---|
| \(O(1)\) | Yes (must reach the optimum) |

Used by classical methods (ThreeDVar, StrongFourDVar, WeakFourDVar) when training hyperparameters. The exact-at-optimum property is the right model — at a well-converged MAP, IFT gives the correct sensitivity.
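To make the IFT formula concrete, here is a minimal JAX sketch of one implicit vector-Jacobian product. It assumes a hypothetical strictly convex inner objective \(f(x;\theta) = \tfrac12 x^\top Q x + \tfrac14 \sum_i x_i^4 - \theta^\top x\).

1. The inner solve (Newton) is not differentiated.
2. The backward pass solves \(\nabla^2_x f\, w = v\) with conjugate gradients on Hessian-vector products.
3. It then contracts \(w\) with \(\nabla^2_{x\theta} f\).

jax.scipy.sparse.linalg.cg stands in for lineax here; this is not optimistix's ImplicitAdjoint code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])


def f(x, theta):
    # Hypothetical strictly convex inner objective f(x; theta).
    return 0.5 * x @ Q @ x + 0.25 * jnp.sum(x**4) - theta @ x


def argmin_newton(theta, n_iter=30):
    # Inner solve by Newton's method; its iterations are not differentiated.
    g = jax.grad(f, argnums=0)
    Hf = jax.hessian(f, argnums=0)
    x = jnp.zeros_like(theta)
    for _ in range(n_iter):
        x = x - jnp.linalg.solve(Hf(x, theta), g(x, theta))
    return jax.lax.stop_gradient(x)


def implicit_vjp(x_star, theta, v):
    # v^T dx*/dtheta = -w^T grad^2_{x theta} f, where grad^2_x f(x*) w = v.
    grad_x = jax.grad(f, argnums=0)

    def hvp(w):
        # Hessian-vector product; the Hessian is never formed or inverted.
        return jax.jvp(lambda x: grad_x(x, theta), (x_star,), (w,))[1]

    w, _ = cg(hvp, v)
    _, vjp_theta = jax.vjp(lambda th: grad_x(x_star, th), theta)
    (out,) = vjp_theta(w)
    return -out


theta = jnp.array([1.0, -0.5])
x_star = argmin_newton(theta)
v = jnp.array([1.0, 2.0])
vjp_implicit = implicit_vjp(x_star, theta, v)
# Reference: for this f, grad^2_{x theta} f = -I, so v^T dx*/dtheta = H^{-1} v.
H = jax.hessian(f)(x_star, theta)
expected = jnp.linalg.solve(H, v)
```

Because the mixed derivative is \(-I\) for this objective, the implicit product reduces to \(H^{-1} v\), which is what `expected` computes directly.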

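Plain reverse through the solver loop

optimistix's adjoint family consists of ImplicitAdjoint and RecursiveCheckpointAdjoint; it has no separate DirectAdjoint. To backpropagate plainly through every iteration, give RecursiveCheckpointAdjoint enough checkpoints to store each step, e.g. optx.RecursiveCheckpointAdjoint(checkpoints=max_steps). This has the same trade-off as diffrax.DirectAdjoint.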

vardax.adjoints.OneStepAdjoint() — Bolte et al. 2023

A custom, vardax-owned adjoint, intended for upstream contribution:

| Memory | Convergence requirement |
|---|---|
| \(O(1)\) | Inner solver approximately converged |

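OneStepAdjoint runs \(K-1\) inner iterations under jax.lax.stop_gradient, then a single differentiable step. The training gradient therefore picks up only the last step's contribution, and memory no longer grows with \(K\).

Bolte et al. show that the error of this one-step Jacobian is governed by the convergence rate of the inner iteration:

- at the fixed point it is exact for superlinearly convergent solvers such as Newton's method;
- it is only approximate for linearly convergent solvers such as gradient descent.

In practice it gives approximate gradients for converged inner solvers while dramatically reducing memory.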

from vardax.models import FourDVarNet
from vardax.adjoints import OneStepAdjoint

model = FourDVarNet(
    prior=prior, obs_op=obs_op, grad_mod=grad_mod, config=config,
    solver_adjoint=OneStepAdjoint(),     # O(1) training memory through the solver
)

OneStepAdjoint is an optimistix.AbstractAdjoint subclass. The plan is to contribute upstream once stable, joining the standard adjoint family (Decision D6).
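The contrast in Bolte et al.'s analysis can be reproduced on a hypothetical quadratic inner problem \(f(x;\theta) = \tfrac12 x^\top Q x - \theta^\top x\). Its solution \(x^*(\theta) = Q^{-1}\theta\) gives an exact reference gradient. The sketch runs \(K-1\) iterations, blocks their gradient with jax.lax.stop_gradient, and differentiates a single final step. It does this once with Newton steps and once with gradient-descent steps. It is a minimal illustration of one-step differentiation, not vardax's OneStepAdjoint.

```python
import jax
import jax.numpy as jnp

Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
target = jnp.array([0.2, -0.1])


def grad_f(x, theta):
    # Hypothetical inner objective f(x; theta) = 1/2 x^T Q x - theta^T x.
    return Q @ x - theta


def gd_step(x, theta, lr=0.2):
    return x - lr * grad_f(x, theta)


def newton_step(x, theta):
    return x - jnp.linalg.solve(Q, grad_f(x, theta))


def one_step_solve(step, theta, n_iter=200):
    x = jnp.zeros_like(theta)
    for _ in range(n_iter - 1):          # K - 1 iterations...
        x = step(x, theta)
    x = jax.lax.stop_gradient(x)         # ...with their gradient blocked
    return step(x, theta)                # one differentiable step


def outer_loss(theta, step):
    x = one_step_solve(step, theta)
    return 0.5 * jnp.sum((x - target) ** 2)


theta = jnp.array([1.0, -0.5])
x_star = jnp.linalg.solve(Q, theta)
exact = jnp.linalg.solve(Q, x_star - target)   # IFT: (dx*/dtheta)^T (x* - target)
g_newton = jax.grad(lambda th: outer_loss(th, newton_step))(theta)
g_gd = jax.grad(lambda th: outer_loss(th, gd_step))(theta)
```

With Newton steps, the one-step gradient matches the exact one. With gradient descent, it collapses to the step size times \(x - \text{target}\), far from the true sensitivity.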

Two slots, two purposes

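Vardax models that involve both dynamics and inner minimisation carry two adjoint slots:

- a diffrax one for the dynamics (forward_adjoint);
- an optimistix one for the inner minimisation, named minimiser_adjoint on classical methods and solver_adjoint on FourDVarNet.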

| Slot | Type | Used by | Default |
|---|---|---|---|
| forward_adjoint | diffrax.AbstractAdjoint | StrongFourDVar, WeakFourDVar, IncrementalFourDVar, FourDVarNet (when prior is DynamicalPrior) | RecursiveCheckpointAdjoint() |
| minimiser_adjoint | optimistix.AbstractAdjoint | ThreeDVar, StrongFourDVar, WeakFourDVar | ImplicitAdjoint() |
| solver_adjoint | optimistix.AbstractAdjoint | FourDVarNet (through the learned inner solver) | RecursiveCheckpointAdjoint() |

The diffrax slot and the optimistix slot are chosen independently. A typical operational 4DVar configuration:

import diffrax as dfx
import optimistix as optx
from vardax.models import StrongFourDVar

model = StrongFourDVar(
    forward=somax_model,
    obs_op=obs_op,
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.NonlinearCG(rtol=1e-5, atol=1e-8),   # optimistix requires both rtol and atol
    minimiser_adjoint=optx.ImplicitAdjoint(),
    forward_adjoint=dfx.BacksolveAdjoint(),     # constant memory through dynamics
)

A typical FourDVarNet training configuration:

from vardax.models import FourDVarNet
from vardax.adjoints import OneStepAdjoint

model = FourDVarNet(
    prior=DynamicalPrior(forward=somax_model, n_steps=5,
                         forward_adjoint=dfx.BacksolveAdjoint()),
    obs_op=obs_op,
    grad_mod=ConvLSTMGradMod2D(hidden_dim=64),
    config=SolverConfig(n_steps=15),
    solver_adjoint=OneStepAdjoint(),            # O(1) training memory through solver
    forward_adjoint=dfx.BacksolveAdjoint(),     # O(1) through the dynamical prior
)

Validation: do they agree?

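All adjoint choices should produce the same gradient for converged problems, up to solver tolerance and floating-point error. The solver-tolerance part matters because BacksolveAdjoint differentiates the continuous problem. The test suite includes: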

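import jax.numpy as jnp
from diffrax import (BacksolveAdjoint, DirectAdjoint, ForwardMode,
                     RecursiveCheckpointAdjoint)


def test_adjoint_choices_agree(cost_fn, x_b):
    # compute_gradient is a test-suite helper (not shown) that differentiates
    # cost_fn at x_b using the given diffrax adjoint.
    grads = {}
    for adj in [RecursiveCheckpointAdjoint(), BacksolveAdjoint(),
                DirectAdjoint(), ForwardMode()]:
        grads[type(adj).__name__] = compute_gradient(cost_fn, x_b, adj)

    reference = grads["RecursiveCheckpointAdjoint"]
    for name, g in grads.items():
        # atol must cover BacksolveAdjoint's solver-tolerance gap, not just rounding.
        assert jnp.allclose(g, reference, atol=1e-4), \
            f"{name} disagrees with reference"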

Disagreement beyond that tolerance is a bug. The most common causes are:

- reverse-time instability in BacksolveAdjoint (switch the backward pass to an implicit integrator);
- under-converged optimisation in ImplicitAdjoint (tighten the tolerance).
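A complementary check, common in the DA literature, is the dot-product test of a tangent-linear/adjoint pair. For any perturbations \(\delta x\) and \(\delta y\), \(\langle M'\delta x, \delta y\rangle = \langle \delta x, (M')^\top \delta y\rangle\). Below is a minimal JAX sketch, with jax.jvp as the tangent-linear model and jax.vjp as the adjoint, applied to a hypothetical nonlinear map. It is an illustration, not part of the vardax test suite.

```python
import jax
import jax.numpy as jnp


def model(x):
    # Hypothetical nonlinear map standing in for x0 -> H(M(x0)).
    return jnp.tanh(jnp.roll(x, 1)) * x + 0.1 * x**2


key_x, key_dx, key_dy = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (8,))
dx = jax.random.normal(key_dx, (8,))
dy = jax.random.normal(key_dy, (8,))

_, tl = jax.jvp(model, (x,), (dx,))   # tangent-linear: M' dx
_, adj_fn = jax.vjp(model, x)
(adj,) = adj_fn(dy)                   # adjoint: (M')^T dy

lhs = jnp.dot(tl, dy)                 # <M' dx, dy>
rhs = jnp.dot(dx, adj)                # <dx, (M')^T dy>
```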

Adjoint selection heuristics

| Scenario | Pick |
|---|---|
| Short window (\(T \le 10\)), default | RecursiveCheckpointAdjoint() |
| Long window (\(T \ge 100\)), memory-bound | BacksolveAdjoint() |
| Parameter estimation, few parameters (\(P\) small) | ForwardMode() |
| Debugging | DirectAdjoint() |
| Training FourDVarNet, memory-bound | OneStepAdjoint() |
| Training FourDVarNet, well-converged inner solver | ImplicitAdjoint() |
| Training FourDVarNet, default | RecursiveCheckpointAdjoint() |
| Training classical hyperparameters | ImplicitAdjoint() (assumes the minimiser has converged to the MAP) |

See also

  • Chapter 3 — dynamical model (forward_adjoint use)
  • Chapter 6 — StrongFourDVar (both slots active)
  • Chapter 8 — IncrementalFourDVar (uses forward_adjoint only; inner solver is hand-rolled GN+CG, not optimistix)
  • Chapter 9 — FourDVarNet (the solver_adjoint choice matters most here)
  • Design doc: design/decisions.md#d15

References

  • Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural ordinary differential equations. NeurIPS.
  • Kidger, P. (2021). On neural differential equations. PhD thesis, University of Oxford. [diffrax author's thesis; see Ch. 5 on adjoints.]
  • Errico, R. M. (1997). What is an adjoint model? BAMS 78(11).
  • Bolte, J., Pauwels, E., & Vaiter, S. (2023). One-step differentiation of iterative algorithms. NeurIPS.
  • Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., Pedregosa, F., & Vert, J.-P. (2022). Efficient and modular implicit differentiation. NeurIPS. [jaxopt / optimistix's ImplicitAdjoint foundation.]