4DVarNet — Learned 4DVar

CNRS
MEOM

4DVarNet (Fablet et al., 2021) keeps the shape of classical 4DVar, minimising a variational cost that balances a data term against a prior, but makes two of its components learnable:

  1. the Gaussian background $\tfrac{1}{2}\|\mathbf{x}-\mathbf{x}_b\|^2_{\mathbf{B}^{-1}}$ becomes a learned prior $\boldsymbol{\phi}_{\boldsymbol{\theta}}$ (an autoencoder), able to capture nonlinear, multimodal manifold structure that a Gaussian cannot;

  2. the inner gradient-descent solver becomes a learned iterative update $\boldsymbol{\Phi}_{\boldsymbol{\psi}}$ (a ConvLSTM / MLP / attention network).

The whole pipeline is differentiable end-to-end and is trained on $(\text{observation}, \text{target})$ pairs. The “Net” is a method name (Fablet et al.), not a separate paradigm: with the right (trivial) choices it reduces to classical 4DVar and to optimal interpolation (OI).

The notation follows the DA notes: state $\mathbf{x}$, observations $\mathbf{y}$ (here on a subdomain / mask $\boldsymbol{\Omega}$), observation operator $H$. The learnable parameters are collected as $\boldsymbol{\Theta} = (\boldsymbol{\theta}, \boldsymbol{\psi})$, with prior parameters $\boldsymbol{\theta}$ and solver parameters $\boldsymbol{\psi}$.

Setup

We observe a noisy, gappy field: the truth corrupted by noise and seen only on a subdomain $\boldsymbol{\Omega}$ (the observed pixels):

$$
\mathbf{y} = H(\mathbf{x}_{\text{gt}}) + \boldsymbol{\varepsilon}, \qquad \text{seen only on } \boldsymbol{\Omega}. \tag{1}
$$

The goal is a state $\mathbf{x}$ that matches the observations on $\boldsymbol{\Omega}$ and fills the gaps on the complement $\bar{\boldsymbol{\Omega}}$. Classical 4DVar fills gaps using a Gaussian background; 4DVarNet fills them using a learned prior.

The Variational Cost

The inner cost balances an observation term (restricted to $\boldsymbol{\Omega}$) against a learned-prior term:

$$
J_{\boldsymbol{\theta}}(\mathbf{x}) = \tfrac{\alpha_{\text{obs}}}{2}\,\big\| H(\mathbf{x}) - \mathbf{y} \big\|^2_{\boldsymbol{\Omega}} + \alpha_{\text{prior}}\,\big\| \mathbf{x} - \boldsymbol{\phi}_{\boldsymbol{\theta}}(\mathbf{x}) \big\|^2_2, \tag{2}
$$

where $\|\cdot\|_{\boldsymbol{\Omega}}^2$ is the quadratic norm evaluated only on the observed subdomain. The prior term measures how far $\mathbf{x}$ is from its own learned reconstruction $\boldsymbol{\phi}_{\boldsymbol{\theta}}(\mathbf{x})$: it is small on the manifold of plausible states and large off it. This term replaces the Gaussian background rather than complementing it.
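The cost above can be sketched in a few lines of JAX. This is a minimal illustration, not the reference implementation: it assumes an identity observation operator by default, a 0/1 float mask for $\boldsymbol{\Omega}$, and a callable `phi` standing in for the trained autoencoder. The names `variational_cost` and `phi_toy` are hypothetical.

```python
import jax
import jax.numpy as jnp


def variational_cost(x, y, mask, phi, H=lambda x: x, alpha_obs=1.0, alpha_prior=1.0):
    """Masked observation misfit plus learned-prior misfit (illustrative sketch)."""
    obs_residual = mask * (H(x) - y)      # zero outside the observed subdomain Ω
    prior_residual = x - phi(x)           # distance of x from its own reconstruction
    return (0.5 * alpha_obs * jnp.sum(obs_residual ** 2)
            + alpha_prior * jnp.sum(prior_residual ** 2))


# Hypothetical stand-in for a trained prior: shrinks the state by half.
phi_toy = lambda x: 0.5 * x

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.5, 0.0, 3.0])            # value at the unobserved pixel is ignored
mask = jnp.array([1.0, 0.0, 1.0])         # 1 = observed, 0 = gap

cost = variational_cost(x, y, mask, phi_toy)
grad_x = jax.grad(variational_cost)(x, y, mask, phi_toy)   # gradient w.r.t. the state
```

Note that the gradient at the unobserved pixel comes from the prior term alone, which is how the prior ends up filling the gaps.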

State vs. Parameters: a Bi-Level Problem

This is the conceptual heart of 4DVarNet, and where it differs from classical 4DVar. There are two distinct kinds of unknown, optimised in two nested loops:

| | State $\mathbf{x}$ | Parameters $\boldsymbol{\Theta}=(\boldsymbol{\theta},\boldsymbol{\psi})$ |
|---|---|---|
| scope | one per example (this observation $\mathbf{y}$) | shared across the whole dataset |
| role | the thing we want to estimate | how to estimate it (prior + solver) |
| found by | the inner solver (estimation) | the outer training loop (learning) |
| optimiser | learned gradient steps $\boldsymbol{\Phi}_{\boldsymbol{\psi}}$ | optax (Adam/SGD) over the dataset |
| data needed | a single $(\mathbf{y}, \boldsymbol{\Omega})$ | a dataset $\mathcal{D}$ of many examples |

Inner problem (state estimation). For fixed parameters $\boldsymbol{\Theta}$ and a given observation $\mathbf{y}$, run the solver to produce the analysis

$$
\mathbf{x}^\star(\boldsymbol{\Theta}; \mathbf{y}) \;\approx\; \operatorname*{arg\,min}_{\mathbf{x}} \; J_{\boldsymbol{\theta}}(\mathbf{x}). \tag{3}
$$

Outer problem (parameter learning). Choose the parameters that make the inner solver’s output match the truth, averaged over the dataset:

$$
\boldsymbol{\Theta}^\star = \operatorname*{arg\,min}_{\boldsymbol{\Theta}} \; \mathcal{L}(\boldsymbol{\Theta}), \qquad \mathcal{L}(\boldsymbol{\Theta}) = \mathbb{E}_{(\mathbf{y}, \mathbf{x}_{\text{gt}}) \sim \mathcal{D}} \big\| \mathbf{x}^\star(\boldsymbol{\Theta}; \mathbf{y}) - \mathbf{x}_{\text{gt}} \big\|_2^2. \tag{4}
$$

The Learned Inner Solver

Classical 4DVar takes plain gradient steps on the cost, $\mathbf{x}^{(k+1)} = \mathbf{x}^{(k)} - \lambda \nabla_{\mathbf{x}} J_{\boldsymbol{\theta}}(\mathbf{x}^{(k)})$. 4DVarNet replaces the fixed step with a learned update that reads the cost gradient and a recurrent hidden state $\mathbf{h}^{(k)}$:

$$
\mathbf{x}^{(k+1)} = \mathbf{x}^{(k)} - \boldsymbol{\Phi}_{\boldsymbol{\psi}}\!\big( \nabla_{\mathbf{x}} J_{\boldsymbol{\theta}}(\mathbf{x}^{(k)}); \, \mathbf{h}^{(k)} \big), \qquad k = 0, \ldots, K-1. \tag{5}
$$

$\boldsymbol{\Phi}_{\boldsymbol{\psi}}$ (a ConvLSTM, MLP, or attention network) learns adaptive step sizes, momentum, and problem-specific preconditioning. After each step the observed pixels are reset to the data and only the gaps keep the updated value, written $\tilde{\mathbf{x}}^{(k+1)}$ for the output of the learned update. This is the $\boldsymbol{\Omega}$-projection:

$$
\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega}), \qquad \mathbf{x}^{(k+1)}(\bar{\boldsymbol{\Omega}}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol{\Omega}}). \tag{6}
$$

Three flavours of inner update, from hand-built to fully learned: projection, gradient, and learned (LSTM/CNN).

Projection. A learned reconstruction $\boldsymbol{\phi}_{\boldsymbol{\theta}}$ is applied, then the observed pixels are overwritten (the DINEOF / DINCAE pattern):

$$
\tilde{\mathbf{x}}^{(k+1)} = \boldsymbol{\phi}_{\boldsymbol{\theta}}(\mathbf{x}^{(k)}), \qquad \mathbf{x}^{(k+1)} = \boldsymbol{\Omega} \odot \mathbf{y} + (1 - \boldsymbol{\Omega}) \odot \tilde{\mathbf{x}}^{(k+1)}.
$$
```python
import einx
from jaxtyping import Array, Float

def project_step(
    x:    Float[Array, "... D"],
    y:    Float[Array, "... D"],
    mask: Float[Array, "... D"],     # Ω indicator (1 = observed)
    phi,                              # learned reconstruction φ_θ
) -> Float[Array, "... D"]:
    x_hat = phi(x)                    # reconstruct the full field
    # Keep the data on Ω, keep the reconstruction on the gaps.
    return einx.multiply("... D, ... D -> ... D", mask, y) \
         + einx.multiply("... D, ... D -> ... D", 1.0 - mask, x_hat)
```

Iterating (5) for $K$ steps and applying (6) at each step gives the inner solver $\mathbf{x}^\star(\boldsymbol{\Theta}; \mathbf{y})$ of (3).
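A minimal sketch of that loop with `jax.lax.scan` follows. It is an illustration under stated assumptions rather than the reference solver: the observation operator is the identity, `toy_prior` is a hypothetical neighbour-averaging stand-in for the autoencoder, and `momentum_modulator` is a hand-built stand-in for the learned ConvLSTM (it only carries a momentum buffer as its hidden state).

```python
import jax
import jax.numpy as jnp


def inner_solve(params, y, mask, *, phi, modulator, n_steps=50,
                alpha_obs=1.0, alpha_prior=1.0):
    """K modulated gradient steps, each followed by the Ω-projection."""
    theta, psi = params

    def cost(x):
        obs = mask * (x - y)                       # H = identity (assumption)
        pri = x - phi(theta, x)
        return 0.5 * alpha_obs * jnp.sum(obs ** 2) + alpha_prior * jnp.sum(pri ** 2)

    def step(carry, _):
        x, h = carry
        g = jax.grad(cost)(x)                      # gradient of the inner cost
        update, h = modulator(psi, g, h)           # modulated update and new hidden state
        x_tilde = x - update
        x = mask * y + (1.0 - mask) * x_tilde      # reset observed pixels to the data
        return (x, h), None

    x0 = mask * y                                  # gaps initialised to zero
    h0 = jnp.zeros_like(y)
    (x_star, _), _ = jax.lax.scan(step, (x0, h0), None, length=n_steps)
    return x_star


def toy_prior(theta, x):
    # Hypothetical prior: periodic average of the two neighbours, scaled by theta.
    return theta * 0.5 * (jnp.roll(x, 1) + jnp.roll(x, -1))


def momentum_modulator(psi, g, h):
    # Hypothetical modulator: heavy-ball momentum with step size lr.
    lr, beta = psi
    h = beta * h + g
    return lr * h, h


x_gt = jnp.sin(2.0 * jnp.pi * jnp.arange(8) / 8.0)
mask = jnp.array([1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
y = mask * x_gt                                    # noise-free, gappy observation

params = (1.0, (0.1, 0.5))                         # (theta, psi)
x_star = inner_solve(params, y, mask, phi=toy_prior, modulator=momentum_modulator)
```

Because the loop is a `scan` over pure functions, `jax.grad` can differentiate `x_star` with respect to `params`. Binding `phi` and `modulator` with `functools.partial` gives a function with the `(params, y, mask)` call shape.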

Why We Need optax

The inner loop takes “gradient steps,” but those steps optimise the state $\mathbf{x}$ for one example, so they are part of the forward pass. The thing we actually train is the parameters $\boldsymbol{\Theta}$, in the outer loop (4), and that is an ordinary supervised deep-learning problem: minimise a loss over a dataset with respect to network weights. That is precisely what a gradient-based optimiser library like optax is for (Adam/SGD with momentum, schedules, clipping, and optimiser state), so we do not hand-roll it.

The outer gradient $\nabla_{\boldsymbol{\Theta}}\mathcal{L}$ flows through the $K$ inner solver iterations (see adjoint choices below); optax then consumes it to update the parameters:

```python
import jax
import jax.numpy as jnp
import optax

# `params` (the pytree Θ = (θ, ψ)) and `inner_solve` are assumed to be defined elsewhere.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)            # params = Θ = (θ, ψ)

@jax.jit
def train_step(params, opt_state, y, mask, x_gt):
    def loss_fn(params):
        x_star = inner_solve(params, y, mask)               # K learned steps -> x*(Θ; y)
        return jnp.mean((x_star - x_gt) ** 2)               # outer reconstruction loss
    loss, grads = jax.value_and_grad(loss_fn)(params)       # ∇_Θ L, through the inner solve
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

Training

The reconstruction loss (4) comes in two flavours, mirroring the data you have: supervised and self-supervised.

Supervised. A dataset $\mathcal{D}_{\text{sup}} = \{(\mathbf{y}^{(i)}, \mathbf{x}_{\text{gt}}^{(i)})\}_{i=1}^N$ of $(\mathbf{y}, \mathbf{x}_{\text{gt}})$ pairs, typically from simulated / emulated truth used for pre-training. The loss is the reconstruction error against the ground truth.

Differentiating Through the Inner Solver

Because $\mathbf{x}^\star$ is produced by $K$ iterations, the outer gradient must backpropagate through them. Three strategies trade memory against exactness:

Table 1: Adjoint strategies for the learned inner solver. $K$ = number of inner steps.

| Strategy | Memory | Exact? | Notes |
|---|---|---|---|
| Recursive checkpointing | $O(K)$ | yes (unrolled) | standard backprop through all $K$ steps |
| One-step | $O(1)$ | at convergence, up to an error set by the solver's contraction rate | `stop_gradient` on $K{-}1$ steps, differentiate the last (Bolte et al., 2023) |
| Implicit (IFT) | $O(1)$ | at convergence | differentiate the fixed point directly |

The implicit strategy differentiates the fixed point $\mathbf{x}^\star = \boldsymbol{F}(\mathbf{x}^\star, \boldsymbol{\Theta})$ via the implicit function theorem, avoiding storage of the iterates entirely:

$$
\frac{\partial \mathbf{x}^\star}{\partial \boldsymbol{\Theta}} = \Big( \mathbf{I} - \partial_{\mathbf{x}} \boldsymbol{F} \Big)^{-1} \, \partial_{\boldsymbol{\Theta}} \boldsymbol{F}.
$$
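The three strategies can be compared on a toy fixed-point map. The sketch below assumes a small hand-picked contraction `F` (hypothetical, not a 4DVarNet solver) and forms dense Jacobians for clarity; a practical implementation would use vector-Jacobian products and a linear solver instead.

```python
import jax
import jax.numpy as jnp

W = jnp.array([[0.3, -0.2], [0.1, 0.4]])


def F(x, theta):
    # Hypothetical contraction in x, so the iteration x <- F(x, theta) converges.
    return 0.5 * jnp.tanh(W @ x) + theta


def solve_unrolled(theta, n_steps=100):
    x = jnp.zeros_like(theta)
    for _ in range(n_steps):
        x = F(x, theta)
    return x


def solve_one_step(theta, n_steps=100):
    # Block gradients through the first K-1 steps, differentiate only the last.
    x = jax.lax.stop_gradient(solve_unrolled(theta, n_steps - 1))
    return F(x, theta)


theta = jnp.array([0.2, -0.1])
x_star = solve_unrolled(theta)

# Unrolled: backprop through every iterate.
J_unrolled = jax.jacobian(solve_unrolled)(theta)

# Implicit: (I - dF/dx)^{-1} dF/dtheta, evaluated at the fixed point only.
dF_dx = jax.jacobian(F, argnums=0)(x_star, theta)
dF_dtheta = jax.jacobian(F, argnums=1)(x_star, theta)
J_implicit = jnp.linalg.solve(jnp.eye(2) - dF_dx, dF_dtheta)

# One-step: reduces to dF/dtheta at the fixed point.
J_one_step = jax.jacobian(solve_one_step)(theta)
```

On this example the implicit and unrolled Jacobians agree once the iteration has converged, while the one-step Jacobian equals $\partial_{\boldsymbol{\Theta}} \boldsymbol{F}$ and so matches them only to the extent that $\partial_{\mathbf{x}} \boldsymbol{F}$ is small.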

The Prior and Solver Families

The two learnable pieces are swappable. The prior $\boldsymbol{\phi}_{\boldsymbol{\theta}}$:

| Prior | Form | When |
|---|---|---|
| Bilinear AE | $\boldsymbol{\phi}(\mathbf{x}) = \mathrm{decode}\big(\mathrm{ReLU}(\mathbf{A}\mathbf{x}) \odot \tanh(\mathbf{B}\mathbf{x})\big)$ | standard 4DVarNet starting point |
| Conv AE | periodic 1D/2D convolutional autoencoder | gridded periodic fields (e.g. Lorenz-96) |
| MLP AE | dense autoencoder | low-dimensional states (e.g. Lorenz-63) |
| Identity | $\boldsymbol{\phi}(\mathbf{x}) = \mathbf{x}$ | pure obs-driven baseline → classical 4DVar |
| Dynamical | $\boldsymbol{\phi}(\mathbf{x}) = M^n(\mathbf{x})$ via a forward model | physics-informed: regulariser = “consistency with the dynamics” |
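As an illustration of the bilinear autoencoder row, here is a minimal dense sketch. It assumes a linear decoder `C` and a simple scaled-normal initialisation; the names `init_bilinear_ae` and `bilinear_ae` and the parameter layout are hypothetical, and the published models use convolutional layers rather than dense matrices.

```python
import jax
import jax.numpy as jnp


def init_bilinear_ae(key, dim, latent):
    """Random parameters theta = {A, B, C} for a dense bilinear autoencoder."""
    k_a, k_b, k_c = jax.random.split(key, 3)
    return {
        "A": jax.random.normal(k_a, (latent, dim)) / jnp.sqrt(dim),
        "B": jax.random.normal(k_b, (latent, dim)) / jnp.sqrt(dim),
        "C": jax.random.normal(k_c, (dim, latent)) / jnp.sqrt(latent),  # linear decoder
    }


def bilinear_ae(theta, x):
    """phi(x) = decode(ReLU(A x) * tanh(B x)), applied over the last axis."""
    z = jax.nn.relu(x @ theta["A"].T) * jnp.tanh(x @ theta["B"].T)
    return z @ theta["C"].T


theta = init_bilinear_ae(jax.random.PRNGKey(0), dim=8, latent=4)
x = jnp.ones((3, 8))                      # batch of 3 states of dimension 8
x_rec = bilinear_ae(theta, x)             # same shape as x
prior_misfit = jnp.sum((x - x_rec) ** 2)  # the quantity the prior term penalises
```

Any callable with the same `(theta, x) -> x_rec` shape contract could be swapped in for the other rows of the table.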

The gradient modulator $\boldsymbol{\Phi}_{\boldsymbol{\psi}}$:

| Modulator | Form |
|---|---|
| ConvLSTM | ConvLSTM over space, recurrent over the $K$ iterations |
| MLP | dense, dimension-agnostic via flattening |
| Attention | self-attention over the spatial axis |
| Identity | $\text{update} = -\alpha\, \nabla_{\mathbf{x}} J$ (classical fixed-step descent) |

When to Use 4DVarNet

Use it when…
  • A training set of $(\text{observation}, \text{target})$ pairs is available.

  • A classical $\mathbf{B}$ is inadequate (nonlinear / multimodal state manifold).

  • Data-rich regime with smooth manifolds and a stable train↔test distribution.

Posterior

A Laplace approximation works as in the classical methods, but the Hessian at the MAP uses the learned-prior curvature, which may be poorly conditioned outside the training manifold. Treat 4DVarNet posteriors with caution and validate them; when full uncertainty matters, an ensemble method is safer.
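A minimal sketch of such a Laplace approximation on a three-pixel toy problem follows. It assumes the inner cost plays the role of a negative log-posterior, uses a hypothetical linear prior `phi_toy` so the cost is quadratic and a single Newton step reaches the MAP, and forms the dense Hessian, which is only feasible for small states.

```python
import jax
import jax.numpy as jnp


def cost(x, y, mask, phi, alpha_obs=1.0, alpha_prior=1.0):
    obs = mask * (x - y)
    pri = x - phi(x)
    return 0.5 * alpha_obs * jnp.sum(obs ** 2) + alpha_prior * jnp.sum(pri ** 2)


phi_toy = lambda x: 0.5 * x               # hypothetical stand-in for a trained prior
y = jnp.array([1.5, 0.0, 3.0])
mask = jnp.array([1.0, 0.0, 1.0])         # middle pixel is unobserved

grad_fn = jax.grad(cost)
hess_fn = jax.hessian(cost)

# The toy cost is quadratic, so one Newton step from zero lands on the MAP.
x0 = jnp.zeros(3)
x_map = x0 - jnp.linalg.solve(hess_fn(x0, y, mask, phi_toy), grad_fn(x0, y, mask, phi_toy))

hessian = hess_fn(x_map, y, mask, phi_toy)        # curvature at the MAP
posterior_cov = jnp.linalg.inv(hessian)           # Laplace covariance
posterior_var = jnp.diag(posterior_cov)
condition_number = jnp.linalg.cond(hessian)       # large values signal an unreliable inverse
```

In this toy case the variance is largest at the unobserved pixel, where only the prior curvature constrains the state. Checking the condition number before inverting is one cheap way to act on the caution above.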

Perspectives: Explicit vs. Implicit

The learned map can be framed three ways, which clarifies what gets penalised:

A probabilistic head replaces the squared error with a Gaussian negative log-likelihood, predicting a mean and variance, $p(\mathbf{x}^{(k+1)} \mid \mathbf{x}^{(k)}) = \mathcal{N}\big(\mathbf{x}^{(k+1)} \mid \boldsymbol{\mu}_{\boldsymbol{\theta}}(\mathbf{x}^{(k)}), \boldsymbol{\sigma}^2_{\boldsymbol{\theta}}(\mathbf{x}^{(k)})\big)$, which can give calibrated per-pixel uncertainty when trained with the NLL loss.
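The Gaussian NLL for such a head can be sketched as below. The function name `gaussian_nll` and the choice to predict a log-variance (so the variance stays positive without a constraint) are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def gaussian_nll(target, mean, log_var):
    """Per-pixel Gaussian negative log-likelihood, averaged over all pixels."""
    sq_err = (target - mean) ** 2
    return 0.5 * jnp.mean(log_var + sq_err / jnp.exp(log_var) + jnp.log(2.0 * jnp.pi))


target = jnp.array([1.0, 2.0, 3.0])
mean = jnp.array([1.0, 2.0, 3.0])
log_var = jnp.zeros(3)                         # unit variance

nll_perfect = gaussian_nll(target, mean, log_var)

# With a fixed error of 2 per pixel, the NLL is stationary in log_var
# when the predicted variance equals the squared error (here 4).
grad_log_var = jax.grad(gaussian_nll, argnums=2)(
    target, mean + 2.0, jnp.full(3, jnp.log(4.0))
)
```

The stationarity check shows why the loss pushes the predicted variance towards the squared error at each pixel, which is the mechanism behind the calibration claim.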

History

4DVarNet originated with Fablet et al. (2021, 2023) for satellite sea-surface-height reconstruction; the original implementation was ocean-specific (PyTorch). The formulation here is framework-agnostic (jaxtyping + JAX autodiff + optax) and positions 4DVarNet as one method among many in the DA hierarchy: a learnable generalisation of, not a replacement for, classical variational assimilation.

References
  1. Fablet, R., Chapron, B., Drumetz, L., Mémin, E., Pannekoucke, O., & Rousseau, F. (2021). Learning Variational Data Assimilation Models and Solvers. Journal of Advances in Modeling Earth Systems, 13(10), e2021MS002572. 10.1029/2021MS002572
  2. Bolte, J., Pauwels, E., & Vaiter, S. (2023). One-Step Differentiation of Iterative Algorithms. Advances in Neural Information Processing Systems (NeurIPS).
  3. Fablet, R., Febvre, Q., & Chapron, B. (2023). Multimodal 4DVarNets for the Reconstruction of Sea Surface Dynamics from NADIR and Wide-Swath Altimetry. IEEE Transactions on Geoscience and Remote Sensing, 61, 1–14. 10.1109/TGRS.2023.3268370