The Dynamical Model

CSIC
UCM
IGEO

For 4D analysis, observations $\boldsymbol{y}_t$ at times $t = 0, 1, \ldots, T$ must be connected to the initial state $\boldsymbol{x}_0$ through the dynamics. The dynamical model $M_t$ propagates the state forward in time,

$$\boldsymbol{x}_t = M_t(\boldsymbol{x}_0), \tag{1}$$

and the dynamics enter the cost through the composition $H_t \circ M_t$, that is, the observation operator applied to the propagated state.

As in the observation-model note, the code is library-agnostic: shapes via jaxtyping, array ops via einx, ODE integration and adjoints via diffrax, and gradients via JAX autodiff.

The Forward-Model Interface

A forward model only needs to expose how to advance the state. Everything else — multi-step rollouts, tangent-linear, adjoint — is built on top.

from typing import Protocol
from jaxtyping import Array, Float

class ForwardModel(Protocol):
    dt: float

    def step(self, state: Float[Array, "N"], dt: float) -> Float[Array, "N"]:
        """Advance the state by one time step dt (discrete dynamics)."""
        ...

    def vector_field(self, t: float, state: Float[Array, "N"], args) -> Float[Array, "N"]:
        """Continuous-time RHS  dy/dt = f(t, y)  for ODE integration."""
        ...

Any model satisfying this protocol is interchangeable from the assimilation code’s point of view. For multi-step rollouts, build the trajectory either by integrating vector_field with diffrax.diffeqsolve (when the model is an ODE) or by iterating step directly (when it is discrete). The integrator choice (Tsit5, Dopri5, Heun, …) belongs to the dynamics, not the assimilation layer.
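As a concrete illustration of the interface, here is a minimal sketch of a model that provides both methods, together with a direct-iteration rollout. The Lorenz63 class, its parameter values, the RK4 step, and the discrete_rollout helper are assumptions made for this example and are not part of the original code; plain JAX arrays are used instead of jaxtyping annotations to keep the block dependency-light.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Lorenz63:
    """Hypothetical example model; name and parameters are illustrative."""
    dt: float = 0.01
    sigma: float = 10.0
    rho: float = 28.0
    beta: float = 8.0 / 3.0

    def vector_field(self, t, state, args):
        """Continuous-time RHS dy/dt = f(t, y)."""
        x, y, z = state
        return jnp.array([
            self.sigma * (y - x),
            x * (self.rho - z) - y,
            x * y - self.beta * z,
        ])

    def step(self, state, dt):
        """One classical RK4 step of size dt (autonomous system, so t is unused)."""
        f = lambda s: self.vector_field(0.0, s, None)
        k1 = f(state)
        k2 = f(state + 0.5 * dt * k1)
        k3 = f(state + 0.5 * dt * k2)
        k4 = f(state + dt * k3)
        return state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def discrete_rollout(model, x0, n_steps):
    """Iterate model.step directly; returns the trajectory (n_steps + 1, N)."""
    def body(state, _):
        nxt = model.step(state, model.dt)
        return nxt, nxt

    _, tail = jax.lax.scan(body, x0, None, length=n_steps)
    return jnp.concatenate([x0[None, :], tail], axis=0)


model = Lorenz63()
x0 = jnp.array([1.0, 1.0, 1.0])
traj = discrete_rollout(model, x0, n_steps=100)   # row 0 is x0 itself
```

Because the assimilation code only ever calls step or vector_field, replacing Lorenz63 with another model that has the same two methods should not require changes elsewhere.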

Why the Adjoint Matters

The 4DVar cost sums a background term and an observation term over time,

$$J(\boldsymbol{x}_0) = \tfrac{1}{2} \, \| \boldsymbol{x}_0 - \boldsymbol{x}_b \|^2_{\mathbf{B}^{-1}} + \tfrac{1}{2} \sum_{t} \| \boldsymbol{y}_t - H_t(M_t(\boldsymbol{x}_0)) \|^2_{\mathbf{R}_t^{-1}}. \tag{2}$$

Minimising over $\boldsymbol{x}_0$ requires $\nabla_{\boldsymbol{x}_0} J$, which by the chain rule pulls back through every $M_t$:

$$\nabla_{\boldsymbol{x}_0} J(\boldsymbol{x}_0) = \mathbf{B}^{-1}(\boldsymbol{x}_0 - \boldsymbol{x}_b) + \sum_{t} (M'_t)^\top \, (H'_t)^\top \, \mathbf{R}_t^{-1} \, \big( H_t(M_t(\boldsymbol{x}_0)) - \boldsymbol{y}_t \big). \tag{3}$$

In practice you never assemble (3) by hand. Write the cost (2), then let autodiff compose the pullback:

import jax
import einx
from jaxtyping import Array, Float

def fourdvar_cost(
    x0:     Float[Array, "N"],
    xb:     Float[Array, "N"],
    B_inv:  Float[Array, "N N"],
    ys:     Float[Array, "T M"],
    rollout,                       # x0 -> trajectory (T, N); diffrax handles the adjoint
    H,                             # observation operator applied per step
    R_inv:  Float[Array, "M M"],
) -> Float[Array, ""]:
    # background term:  ½ ‖x0 - xb‖²_{B⁻¹}
    db = x0 - xb
    Jb = 0.5 * einx.dot("N, N ->", db, einx.dot("N K, K -> N", B_inv, db))

    # observation term:  ½ Σ_t ‖y_t - H(M_t(x0))‖²_{R⁻¹}
    traj = rollout(x0)                                  # (T, N)
    pred = jax.vmap(H)(traj)                            # (T, M)
    res  = ys - pred                                    # (T, M)
    Rr   = jax.vmap(lambda r: einx.dot("M K, K -> M", R_inv, r))(res)
    Jo   = 0.5 * einx.dot("T M, T M ->", res, Rr)

    return Jb + Jo

grad_J = jax.grad(fourdvar_cost)        # adjoint pullback through every M_t, for free
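To see that autodiff reproduces the hand-derived gradient (3), the following sketch compares the two on a small linear problem where the tangent-linear model is known exactly. Everything here is an assumption for illustration: the dynamics are taken as $M_t = A^t$ with a fixed matrix A, the observation operator is a constant matrix Hm, the sizes and random data are arbitrary, and plain jax.numpy replaces einx to keep the block short.

```python
import jax
import jax.numpy as jnp


def rollout_linear(A, x0, n_obs):
    """Trajectory [x0, A x0, ..., A^(n_obs-1) x0], shape (n_obs, N)."""
    def body(x, _):
        return A @ x, x
    _, traj = jax.lax.scan(body, x0, None, length=n_obs)
    return traj


def cost(x0, xb, B_inv, ys, A, Hm, R_inv):
    """Cost (2) for linear dynamics and a linear observation operator."""
    db = x0 - xb
    Jb = 0.5 * db @ B_inv @ db
    res = ys - rollout_linear(A, x0, ys.shape[0]) @ Hm.T     # (T, M)
    Jo = 0.5 * jnp.sum(res * (res @ R_inv.T))
    return Jb + Jo


def grad_by_hand(x0, xb, B_inv, ys, A, Hm, R_inv):
    """Gradient (3), assembled term by term with M'_t = A^t and H'_t = Hm."""
    g = B_inv @ (x0 - xb)
    Mt = jnp.eye(x0.shape[0])                # A^0
    for y_t in ys:
        innov = Hm @ (Mt @ x0) - y_t         # H(M_t(x0)) - y_t
        g = g + Mt.T @ (Hm.T @ (R_inv @ innov))
        Mt = A @ Mt                          # advance to A^(t+1)
    return g


N, M, T = 3, 2, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jnp.eye(N) + 0.1 * jax.random.normal(k1, (N, N))
Hm = jnp.eye(M, N)                           # observe the first two components
x0 = jax.random.normal(k2, (N,))
ys = jax.random.normal(k3, (T, M))
xb, B_inv, R_inv = jnp.zeros(N), jnp.eye(N), 4.0 * jnp.eye(M)

g_auto = jax.grad(cost)(x0, xb, B_inv, ys, A, Hm, R_inv)
g_hand = grad_by_hand(x0, xb, B_inv, ys, A, Hm, R_inv)
max_diff = jnp.max(jnp.abs(g_auto - g_hand))  # expected to be at float32 round-off level
```

For nonlinear dynamics the same comparison would need the Jacobians of $M_t$ and $H_t$ at each step, which is exactly the bookkeeping that jax.grad takes over.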

Choosing an Adjoint Method

Expose the adjoint as a configurable slot on your 4DVar solver and thread it into diffeqsolve; the choice flows straight through to diffrax. The trade-offs:

Table 1: diffrax adjoint methods. $K$ = number of forward steps, $P$ = number of parameters, $N$ = number of checkpoints.

| Method | Memory | Extra cost | Use when |
|---|---|---|---|
| RecursiveCheckpointAdjoint(checkpoints=N) | $O(N)$ | $O(\log K)$ extra forward solves | default; balances memory and recompute |
| BacksolveAdjoint() | $O(1)$ | backward ODE solve | long windows, memory-constrained |
| ForwardMode() | $O(P)$ in params | one extra forward per parameter | few parameters (parameter estimation) |
| DirectAdjoint() | $O(K)$ steps | standard reverse | short windows, debugging |

Dynamics Regimes

Regimes: nonlinear (default), linear / incremental, stochastic (SDE).

Nonlinear (default)

The full nonlinear $M_t$ is integrated at each outer step, and diffrax’s adjoint handles the gradient. This is the unwrapped strong/weak-constraint 4DVar setting; nothing special is required beyond choosing an adjoint method.

Forward Rollout in Practice

A typical rollout integrates the vector field and saves the trajectory, threading the chosen adjoint so jax.grad flows back through it:

import jax.numpy as jnp
import diffrax as dfx
from jaxtyping import Array, Float

def rollout(
    x0: Float[Array, "N"],
    forward,                       # exposes forward.vector_field(t, y, args)
    n_steps: int,
    dt: float,
    *,
    adjoint: dfx.AbstractAdjoint,
) -> Float[Array, "T N"]:
    """Integrate forward and return the trajectory (n_steps + 1, N)."""
    sol = dfx.diffeqsolve(
        dfx.ODETerm(lambda t, y, args: forward.vector_field(t, y, args)),
        solver=dfx.Tsit5(),
        t0=0.0, t1=n_steps * dt, dt0=dt,
        y0=x0,
        saveat=dfx.SaveAt(ts=jnp.arange(n_steps + 1) * dt),
        adjoint=adjoint,
    )
    return sol.ys

Compose this inside fourdvar_cost; when jax.grad(J)(x0) is called, the gradient flows back through rollout along the adjoint path you selected.
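The cost function expects a one-argument rollout(x0), while the rollout above takes the model, step count, step size, and adjoint as well, so the extra arguments have to be bound first. A minimal sketch of that wiring with functools.partial follows. To keep it dependency-free it uses a hypothetical explicit-Euler stand-in (euler_rollout) rather than the diffrax rollout, a simplified cost with $\mathbf{B} = \mathbf{R} = \mathbf{I}$ and identity observation operator, and a fixed-step gradient descent; all of these are assumptions for illustration only.

```python
from functools import partial

import jax
import jax.numpy as jnp


def euler_rollout(x0, vector_field, n_steps, dt):
    """Hypothetical stand-in for the diffrax rollout; returns (n_steps + 1, N)."""
    def body(x, _):
        nxt = x + dt * vector_field(0.0, x, None)
        return nxt, nxt
    _, tail = jax.lax.scan(body, x0, None, length=n_steps)
    return jnp.concatenate([x0[None, :], tail], axis=0)


def cost(x0, xb, ys, rollout):
    """Simplified 4DVar cost with B = R = I and H = identity."""
    Jb = 0.5 * jnp.sum((x0 - xb) ** 2)
    Jo = 0.5 * jnp.sum((ys - rollout(x0)) ** 2)
    return Jb + Jo


decay = lambda t, y, args: -0.5 * y          # toy linear vector field
bound_rollout = partial(euler_rollout, vector_field=decay, n_steps=20, dt=0.05)

x_true = jnp.array([2.0, -1.0])
ys = bound_rollout(x_true)                   # synthetic, noise-free observations
xb = jnp.zeros(2)                            # deliberately poor background

J = partial(cost, xb=xb, ys=ys, rollout=bound_rollout)
grad_J = jax.grad(J)                         # differentiates through the rollout

x0 = xb
history = [float(J(x0))]
for _ in range(50):
    x0 = x0 - 0.05 * grad_J(x0)              # plain gradient descent, fixed step
    history.append(float(J(x0)))
```

With the diffrax rollout the binding would look the same, with forward, n_steps, dt, and adjoint supplied through partial; in practice a quasi-Newton optimiser would usually replace the fixed-step loop.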