4DVarNet (Fablet et al., 2021) keeps the shape of classical 4DVar (minimise a variational cost that balances a data term against a prior) but makes two of its components learnable:
- the Gaussian background becomes a learned prior (an autoencoder), able to capture nonlinear, multimodal manifold structure a Gaussian cannot;
- the inner gradient-descent solver becomes a learned iterative update (a ConvLSTM / MLP / attention network).
The whole pipeline is differentiable end-to-end and is trained on pairs $(y, x_{\mathrm{gt}})$ of observations and ground truth. The “Net” is a method name (Fablet et al.), not a separate paradigm: with the right (trivial) choices it reduces to classical 4DVar and to OI.
The notation follows the DA notes: state $x$, observations $y$ (here on a subdomain / mask $\Omega$), observation operator $H$. The learnable parameters are collected as $\Theta = (\theta, \psi)$: prior parameters $\theta$ and solver parameters $\psi$.
Setup
We observe a noisy, gappy field $y$: the truth $x_{\mathrm{gt}}$ corrupted by noise $\varepsilon$ and seen only on a subdomain $\Omega$ (the observed pixels):

$$
y = \big(H(x_{\mathrm{gt}}) + \varepsilon\big)\big|_{\Omega}. \tag{1}
$$

The goal is a state $x$ that matches the observations on $\Omega$ and fills the gaps on the complement $\bar{\Omega}$. Classical 4DVar fills gaps using a Gaussian background; 4DVarNet fills them using a learned prior.
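As a minimal sketch of this setup, the snippet below simulates one gappy observation from a known field. It assumes an identity observation operator, i.i.d. Gaussian noise, and a random Bernoulli mask; `make_observation`, `obs_fraction`, and `noise_std` are illustrative names, not part of any 4DVarNet code base.

```python
import jax
import jax.numpy as jnp


def make_observation(key, x_true, obs_fraction=0.3, noise_std=0.1):
    """Return (y, mask): noisy values on the observed pixels, zeros in the gaps."""
    key_mask, key_noise = jax.random.split(key)
    # Indicator of the observed subdomain: 1.0 = observed, 0.0 = gap.
    mask = jax.random.bernoulli(key_mask, obs_fraction, x_true.shape).astype(x_true.dtype)
    noise = noise_std * jax.random.normal(key_noise, x_true.shape, dtype=x_true.dtype)
    # Assumption: H is the identity, and gaps are stored as 0 (a placeholder value).
    y = mask * (x_true + noise)
    return y, mask


x_true = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 64))
y, mask = make_observation(jax.random.PRNGKey(0), x_true)
```

Storing the mask alongside `y` matters because a zero in `y` is ambiguous on its own: it may be a gap or a genuine observed value.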
The Variational Cost
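The inner cost balances an observation term (restricted to $\Omega$) against a learned-prior term:

$$
J_\theta(x; y, \Omega) = a_{\mathrm{obs}} \, \| H(x) - y \|^2_{\Omega} + a_{\mathrm{prior}} \, \| x - \varphi_\theta(x) \|^2, \tag{2}
$$

where $\| \cdot \|^2_{\Omega}$ is the quadratic norm evaluated only on the observed subdomain and $a_{\mathrm{obs}}$, $a_{\mathrm{prior}}$ weight the two terms. The prior term measures how far $x$ is from its own learned reconstruction $\varphi_\theta(x)$: it is small on the manifold of plausible states and large off it. This replaces, rather than complements, the Gaussian background.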
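To make the "small on the manifold, large off it" behaviour concrete, here is a small worked example. It is only a sketch under stated assumptions: the learned prior is replaced by a hypothetical stand-in (orthogonal projection onto a one-dimensional subspace), the observation operator is the identity, and the weights `a_obs` and `a_prior` are illustrative values.

```python
import jax.numpy as jnp


def variational_cost(x, y, mask, phi, a_obs=0.99, a_prior=0.01):
    obs_term = jnp.mean(mask * (x - y) ** 2)   # data misfit, observed pixels only
    prior_term = jnp.mean((x - phi(x)) ** 2)   # distance to the prior's reconstruction
    return a_obs * obs_term + a_prior * prior_term


# Hypothetical stand-in prior: orthogonal projection onto span{u}, with ||u|| = 1.
u = jnp.ones(4) / 2.0
phi = lambda x: u * jnp.dot(u, x)

y = jnp.array([1.0, 0.0, 0.0, 1.0])      # gaps stored as 0
mask = jnp.array([1.0, 0.0, 0.0, 1.0])   # pixels 0 and 3 are observed

x_on = jnp.ones(4)                         # fits the data and lies in span{u}
x_off = jnp.array([1.0, -1.0, 1.0, 1.0])   # fits the data but leaves span{u}

cost_on = variational_cost(x_on, y, mask, phi)    # both terms vanish
cost_off = variational_cost(x_off, y, mask, phi)  # only the prior term is non-zero
```

Both candidates match the observations exactly, so the observation term cannot tell them apart; the prior term is what decides how the gaps are filled.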
State vs. Parameters: a Bi-Level Problem
This is the conceptual heart of 4DVarNet, and where it differs from classical 4DVar. There are two distinct kinds of unknown, optimised in two nested loops:
| | State $x$ | Parameters $\Theta = (\theta, \psi)$ |
|---|---|---|
| scope | one per example (this observation $y$) | shared across the whole dataset |
| role | the thing we want to estimate | how to estimate it (prior + solver) |
| found by | the inner solver (estimation) | the outer training loop (learning) |
| optimiser | learned gradient steps | optax (Adam/SGD) over the dataset |
| data needed | a single $y$ | a dataset of many examples |
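Inner problem: state estimation. For fixed parameters $\Theta$ and a given observation $y$, run the solver to produce the analysis

$$
x^\star(\Theta; y) = \mathrm{solve}_{\psi}\big(J_\theta(\,\cdot\,; y, \Omega)\big) \approx \arg\min_x J_\theta(x; y, \Omega). \tag{3}
$$

Outer problem: parameter learning. Choose the parameters that make the inner solver’s output match the truth, averaged over the dataset:

$$
\Theta^\star = \arg\min_{\Theta} \; \mathbb{E}_{(y,\, x_{\mathrm{gt}})} \Big[ \mathcal{L}\big(x^\star(\Theta; y),\, x_{\mathrm{gt}}\big) \Big]. \tag{4}
$$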
The Learned Inner Solver
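Classical 4DVar takes plain gradient steps on the cost, $x_{k+1} = x_k - \alpha \nabla_x J(x_k)$. 4DVarNet replaces the fixed step with a learned update that reads the cost gradient and a recurrent hidden state $h_k$:

$$
(\delta_k, h_{k+1}) = \Phi_\psi\big(\alpha \nabla_x J_\theta(x_k),\, h_k\big), \qquad x_{k+1} = x_k - \delta_k. \tag{5}
$$

$\Phi_\psi$ (a ConvLSTM, MLP, or attention network) learns adaptive step sizes, momentum, and problem-specific preconditioning. After each step the observed pixels are reset to the data and only the gaps are updated, which is the $\Omega$-projection:

$$
x_{k+1} \leftarrow \mathbb{1}_{\Omega} \odot y + (1 - \mathbb{1}_{\Omega}) \odot x_{k+1}. \tag{6}
$$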
Three flavours of inner update, from hand-built to fully learned:
A learned reconstruction is applied, then the observed pixels are overwritten (the DINEOF / DINCAE pattern):
```python
import einx
from jaxtyping import Array, Float


def project_step(
    x: Float[Array, "... D"],
    y: Float[Array, "... D"],
    mask: Float[Array, "... D"],  # Ω indicator (1 = observed)
    phi,  # learned reconstruction φ_θ, called as phi(x)
) -> Float[Array, "... D"]:
    x_hat = phi(x)
    # Keep the data on Ω and the reconstruction in the gaps.
    return (
        einx.multiply("... D, ... D -> ... D", mask, y)
        + einx.multiply("... D, ... D -> ... D", 1.0 - mask, x_hat)
    )
```

A fixed-step descent on the variational cost (2). This is the `IdentityGradMod` baseline, i.e. classical 4DVar:
```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


# `phi(x, theta)` is the learned prior φ_θ; it is assumed to be defined in the enclosing scope.
def energy(x, y, mask, theta, *, a_obs=0.99, a_prior=0.01) -> Float[Array, ""]:
    loss_obs = jnp.mean(mask * (x - y) ** 2)  # ‖H(x) - y‖²_Ω (H = identity here)
    loss_prior = jnp.mean((phi(x, theta) - x) ** 2)  # ‖x - φ_θ(x)‖²
    return a_obs * loss_obs + a_prior * loss_prior


def gradient_step(x, y, mask, theta, alpha: float) -> Float[Array, "... D"]:
    # jax.grad differentiates with respect to the first argument, the state x.
    return x - alpha * jax.grad(energy)(x, y, mask, theta)
```

The cost gradient is fed to a recurrent network that emits the actual update, as in (5). This is the 4DVarNet proper:
```python
import jax
from jaxtyping import Array, Float, PyTree


def learned_step(
    x: Float[Array, "... D"],
    y: Float[Array, "... D"],
    mask: Float[Array, "... D"],
    theta: PyTree,  # prior parameters θ
    grad_mod,  # solver network Φ_ψ (carries its own parameters ψ), called as grad_mod(g, h)
    h: PyTree,  # recurrent hidden state h_k
    alpha: float,
) -> tuple[Float[Array, "... D"], PyTree]:
    # `energy` is the variational cost (2), assumed to be defined as above.
    g = alpha * jax.grad(energy)(x, y, mask, theta)  # α ∇ₓ J
    update, h = grad_mod(g, h)  # learned modulation, returns the new hidden state
    return x - update, h
```

Iterating (5) for $K$ steps and applying (6) at each step is the inner solver of (3).
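The iteration just described can be sketched as a `jax.lax.scan` loop. This is an illustrative sketch rather than the reference implementation: the prior `phi` is a hypothetical periodic smoother standing in for a trained autoencoder, the modulator is the identity (so the loop reduces to fixed-step descent), and `inner_solve`, `n_steps`, and the zero initial hidden state are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def phi(x, theta):
    # Hypothetical stand-in prior: scaled 3-point periodic smoother.
    return theta * (jnp.roll(x, 1) + x + jnp.roll(x, -1)) / 3.0


def energy(x, y, mask, theta, a_obs=0.99, a_prior=0.01):
    loss_obs = jnp.mean(mask * (x - y) ** 2)
    loss_prior = jnp.mean((phi(x, theta) - x) ** 2)
    return a_obs * loss_obs + a_prior * loss_prior


def identity_grad_mod(g, h):
    # Identity modulator: pass the scaled gradient through, leave h untouched.
    return g, h


def inner_solve(theta, y, mask, grad_mod, n_steps=50, alpha=50.0):
    x0 = mask * y             # start from the data, gaps filled with 0
    h0 = jnp.zeros_like(y)    # placeholder hidden state

    def body(carry, _):
        x, h = carry
        g = alpha * jax.grad(energy)(x, y, mask, theta)  # scaled cost gradient
        update, h = grad_mod(g, h)                       # (possibly learned) update
        x = x - update
        x = mask * y + (1.0 - mask) * x                  # reset observed pixels to the data
        return (x, h), None

    (x_star, _), _ = jax.lax.scan(body, (x0, h0), None, length=n_steps)
    return x_star


x_true = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 16, endpoint=False))
mask = (jnp.arange(16) % 3 == 0).astype(jnp.float32)
y = mask * x_true
x_star = inner_solve(1.0, y, mask, identity_grad_mod)
```

Using `scan` instead of a Python loop keeps the compiled graph size independent of the number of steps, and the whole solve remains differentiable with respect to `theta` and any parameters captured by `grad_mod`.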
Why We Need optax
The inner loop takes “gradient steps,” but those steps optimise the state $x$ for one example: they are part of the forward pass. The thing we actually train is the parameters $\Theta$, in the outer loop (4), and that is an ordinary supervised deep-learning problem: minimise a loss over a dataset with respect to network weights. That is precisely what a gradient-based optimiser library like optax is for (Adam/SGD with momentum, schedules, clipping, and optimiser state), so we do not hand-roll it.
The outer gradient flows through the inner solver iterations (see adjoint choices below); optax then consumes it to update the parameters:
```python
import jax
import jax.numpy as jnp
import optax

# Assumed to exist: `params`, the pytree Θ = (θ, ψ), and
# `inner_solve(params, y, mask)`, which runs the K learned steps.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state, y, mask, x_gt):
    def loss_fn(params):
        x_star = inner_solve(params, y, mask)  # K learned steps -> x*(Θ; y)
        return jnp.mean((x_star - x_gt) ** 2)  # outer reconstruction loss

    loss, grads = jax.value_and_grad(loss_fn)(params)  # ∇_Θ L, through the inner solve
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

Training
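The reconstruction loss (4) comes in two flavours, mirroring the data you have:
A dataset of pairs $(y, x_{\mathrm{gt}})$, typically from simulated / emulated truth used for pre-training. The loss is the reconstruction error against the ground truth, $\mathcal{L} = \| x^\star(\Theta; y) - x_{\mathrm{gt}} \|^2$.
No ground truth, only observations. Hold out some observed pixels, reconstruct them, and score against the held-out values on $\Omega_{\mathrm{ho}} \subset \Omega$:

$$
\mathcal{L} = \big\| x^\star\big(\Theta;\, y|_{\Omega \setminus \Omega_{\mathrm{ho}}}\big) - y \big\|^2_{\Omega_{\mathrm{ho}}}.
$$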
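The hold-out variant can be sketched as follows. The names `split_mask`, `self_supervised_loss`, and `holdout_fraction` are hypothetical, and `mean_fill` is a deliberately trivial reconstructor standing in for the trained inner solver, so that the snippet runs on its own.

```python
import jax
import jax.numpy as jnp


def split_mask(key, mask, holdout_fraction=0.2):
    """Split the observed pixels into a solver-input set and a held-out scoring set."""
    drop = jax.random.bernoulli(key, holdout_fraction, mask.shape).astype(mask.dtype)
    mask_holdout = mask * drop            # observed pixels hidden from the solver
    mask_input = mask - mask_holdout      # observed pixels the solver still sees
    return mask_input, mask_holdout


def self_supervised_loss(reconstruct, y, mask, key):
    mask_in, mask_ho = split_mask(key, mask)
    x_star = reconstruct(mask_in * y, mask_in)      # solver never sees held-out values
    n_ho = jnp.maximum(jnp.sum(mask_ho), 1.0)       # guard against an empty hold-out set
    return jnp.sum(mask_ho * (x_star - y) ** 2) / n_ho


def mean_fill(y_in, mask_in):
    # Stand-in reconstructor: fill every pixel with the mean of the visible data.
    mean = jnp.sum(y_in) / jnp.maximum(jnp.sum(mask_in), 1.0)
    return jnp.ones_like(y_in) * mean


key = jax.random.PRNGKey(0)
mask = (jnp.arange(32) % 2 == 0).astype(jnp.float32)
y = mask * jnp.cos(jnp.linspace(0.0, 3.0, 32))
loss = self_supervised_loss(mean_fill, y, mask, key)
```

Drawing a fresh hold-out split per training step (a new `key` each time) keeps the solver from memorising which pixels are scored.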
Differentiating Through the Inner Solver
Because $x^\star$ is produced by $K$ iterations, the outer gradient must backpropagate through them. Three strategies trade memory against exactness:
Table 1: Adjoint strategies for the learned inner solver. $K$ = number of inner steps.
| Strategy | Memory | Exact? | Notes |
|---|---|---|---|
| Recursive checkpointing | $O(\log K)$ | yes (unrolled) | standard backprop through all steps, with recomputation |
| One-step | $O(1)$ | approximately, at convergence (exact only for superlinear solvers) | `stop_gradient` on the first $K-1$ steps, differentiate the last (Bolte et al., 2023) |
| Implicit (IFT) | $O(1)$ | at convergence | differentiate the fixed point directly |
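The implicit strategy differentiates the fixed point $x^\star = F_\Theta(x^\star; y)$ of the one-step solver map $F_\Theta$ via the implicit function theorem, avoiding storage of the iterates entirely:

$$
\frac{\partial x^\star}{\partial \Theta} = \left( I - \frac{\partial F_\Theta}{\partial x} \right)^{-1} \frac{\partial F_\Theta}{\partial \Theta} \,\Bigg|_{x = x^\star}.
$$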
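A scalar toy problem makes the three strategies easy to compare. The map `step` below is a hypothetical contraction chosen for illustration, not a 4DVarNet solver; for a scalar fixed point $x^\star = F(x^\star, \theta)$ the implicit-function-theorem gradient reduces to $\partial_\theta F / (1 - \partial_x F)$.

```python
import jax
import jax.numpy as jnp


def step(x, theta):
    # Hypothetical one-step map, a contraction in x since |dF/dx| <= 0.5.
    return 0.5 * jnp.tanh(x) + theta


def solve_unrolled(theta, n_steps=60):
    x = jnp.zeros_like(theta)
    for _ in range(n_steps):
        x = step(x, theta)
    return x


def solve_one_step(theta, n_steps=60):
    x = jnp.zeros_like(theta)
    for _ in range(n_steps - 1):
        x = step(x, theta)
    x = jax.lax.stop_gradient(x)   # cut the graph before the final step
    return step(x, theta)


theta = jnp.asarray(0.3)
x_star = solve_unrolled(theta)

# 1) Backprop through every iteration (reference).
grad_unrolled = jax.grad(solve_unrolled)(theta)

# 2) One-step differentiation: only the last iteration is differentiated.
grad_one_step = jax.grad(solve_one_step)(theta)

# 3) Implicit differentiation at the fixed point: dF/dθ / (1 - dF/dx).
dF_dx, dF_dtheta = jax.grad(step, argnums=(0, 1))(x_star, theta)
grad_ift = dF_dtheta / (1.0 - dF_dx)
```

In this example the implicit gradient agrees with the unrolled one, while the one-step gradient drops the $(1 - \partial_x F)^{-1}$ factor and is therefore biased whenever the solver converges only linearly.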
The Prior and Solver Families
The two learnable pieces are swappable. The prior $\varphi_\theta$:
| Prior | Form | When |
|---|---|---|
| Bilinear AE | | standard 4DVarNet starting point |
| Conv AE | periodic 1D/2D convolutional autoencoder | gridded periodic fields (e.g. Lorenz-96) |
| MLP AE | dense autoencoder | low-dimensional states (e.g. Lorenz-63) |
| Identity | $\varphi(x) = x$ | pure obs-driven baseline → classical 4DVar |
| Dynamical | via a forward model | physics-informed: regulariser = “consistency with the dynamics” |
The gradient modulator $\Phi_\psi$:
| Modulator | Form |
|---|---|
| ConvLSTM | ConvLSTM over space, recurrent over the iterations |
| MLP | dense, dimension-agnostic via flattening |
| Attention | self-attention over the spatial axis |
| Identity | $\Phi(g, h) = g$: classical fixed-step descent |
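All modulators share one calling convention: they take the (scaled) cost gradient and a hidden state, and return an update and a new hidden state. The sketch below illustrates that interface with the identity modulator and a small dense recurrent cell. The cell is a hypothetical Elman-style stand-in written in plain JAX, not the ConvLSTM used in published 4DVarNet models; `init_mlp_modulator`, `mlp_modulator`, and the weight names are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


def identity_modulator(g, h):
    # Classical fixed-step descent: the update is the scaled gradient itself.
    return g, h


def init_mlp_modulator(key, dim, hidden=16, scale=0.1):
    k_g, k_h, k_out = jax.random.split(key, 3)
    return {
        "w_g": scale * jax.random.normal(k_g, (dim, hidden)),
        "w_h": scale * jax.random.normal(k_h, (hidden, hidden)),
        "w_out": scale * jax.random.normal(k_out, (hidden, dim)),
    }


def mlp_modulator(psi, g, h):
    # The hidden state mixes the current gradient with the memory of past steps.
    h_new = jnp.tanh(g @ psi["w_g"] + h @ psi["w_h"])
    update = h_new @ psi["w_out"]
    return update, h_new


dim, hidden = 8, 16
psi = init_mlp_modulator(jax.random.PRNGKey(0), dim, hidden)
grad_mod = partial(mlp_modulator, psi)   # bind ψ so the call is grad_mod(g, h)

g = jnp.ones(dim)
h = jnp.zeros(hidden)
update, h_next = grad_mod(g, h)
```

Binding the solver parameters with `partial` keeps the step function agnostic to which modulator is plugged in.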
When to Use 4DVarNet
- A training set of pairs $(y, x_{\mathrm{gt}})$ is available.
- A classical Gaussian background covariance $B$ is inadequate (nonlinear / multimodal state manifold).
- Data-rich regime with smooth manifolds and a stable train↔test distribution.
Posterior
A Laplace approximation works as in the classical methods, but the Hessian at the MAP uses the learned-prior curvature, which may be poorly conditioned outside the training manifold. Treat 4DVarNet posteriors with caution and validate them; when full uncertainty matters, an ensemble method is safer.
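A minimal sketch of the Laplace step, under simplifying assumptions: the cost is written as a properly scaled negative log-posterior (sums, not means), the observation noise is Gaussian with variance `obs_var`, and a quadratic term with variance `prior_var` stands in for the learned prior so that the answer can be checked by hand. With a learned prior, the same `jax.hessian` call would pick up the prior's curvature instead.

```python
import jax
import jax.numpy as jnp

obs_var, prior_var = 0.1, 1.0


def neg_log_posterior(x, y, mask):
    obs = 0.5 * jnp.sum(mask * (x - y) ** 2) / obs_var
    prior = 0.5 * jnp.sum(x ** 2) / prior_var   # stand-in for the learned-prior term
    return obs + prior


y = jnp.array([1.0, 0.0, -1.0])
mask = jnp.array([1.0, 0.0, 1.0])

# Closed-form MAP for this quadratic toy problem.
x_map = mask * y / (1.0 + obs_var / prior_var)

# Laplace approximation: posterior covariance = inverse Hessian at the MAP.
hessian = jax.hessian(neg_log_posterior)(x_map, y, mask)
cov = jnp.linalg.inv(hessian)
std = jnp.sqrt(jnp.diag(cov))

# A large condition number signals that the inverse is unreliable.
cond = jnp.linalg.cond(hessian)
```

The unobserved middle pixel keeps the full prior variance, while the observed pixels are tightened by the data; checking `cond` before inverting is one cheap way to spot the poorly conditioned case mentioned above.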
Perspectives: Explicit vs. Implicit
The learned map can be framed three ways, which clarifies what gets penalised:
- Explicit: a function predicts the truth, with a model term added to the loss.
- Conditional explicit: the map reads the observations too, so no extra loss term is needed.
- Implicit (fixed point): the truth is the solution of a fixed-point equation. This is the iterative solver view, and the regime where implicit differentiation shines.

A probabilistic head replaces the squared error with a Gaussian negative log-likelihood, predicting a mean and variance, $(\mu, \sigma^2)$, which can give calibrated per-pixel uncertainty when trained with the NLL loss.
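The Gaussian negative log-likelihood can be sketched in a few lines. Predicting the log-variance rather than the variance is an assumption made here for numerical convenience (it keeps the variance positive without clipping); `gaussian_nll` is an illustrative name.

```python
import jax.numpy as jnp


def gaussian_nll(mean, log_var, target):
    """Per-pixel Gaussian negative log-likelihood, averaged over pixels."""
    sq_err = (target - mean) ** 2
    return 0.5 * jnp.mean(log_var + sq_err / jnp.exp(log_var) + jnp.log(2.0 * jnp.pi))


target = jnp.zeros(4)
mean = 2.0 * jnp.ones(4)   # every pixel is off by 2

# Overconfident head: claims unit variance despite the error of 2.
nll_overconfident = gaussian_nll(mean, jnp.zeros(4), target)

# Matched head: predicted variance equals the squared error (σ² = 4).
nll_matched = gaussian_nll(mean, jnp.log(4.0) * jnp.ones(4), target)
```

The matched head scores a lower NLL than the overconfident one for the same mean error, which is the mechanism that pushes the predicted variance towards the actual error level.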
History
4DVarNet originated with Fablet et al. (2021, 2023) for satellite sea-surface-height reconstruction; the original implementation was ocean-specific (PyTorch). The formulation here is framework-agnostic (jaxtyping + JAX autodiff + optax) and positions 4DVarNet as one method among many in the DA hierarchy: a learnable generalisation of, not a replacement for, classical variational assimilation.
- Fablet, R., Chapron, B., Drumetz, L., Mémin, E., Pannekoucke, O., & Rousseau, F. (2021). Learning Variational Data Assimilation Models and Solvers. Journal of Advances in Modeling Earth Systems, 13(10), e2021MS002572. 10.1029/2021MS002572
- Bolte, J., Pauwels, E., & Vaiter, S. (2023). One-Step Differentiation of Iterative Algorithms. Advances in Neural Information Processing Systems (NeurIPS).
- Fablet, R., Febvre, Q., & Chapron, B. (2023). Multimodal 4DVarNets for the Reconstruction of Sea Surface Dynamics from NADIR and Wide-Swath Altimetry. IEEE Transactions on Geoscience and Remote Sensing, 61, 1–14. 10.1109/TGRS.2023.3268370