Every method up to now — OI, 3DVar, strong / weak / incremental 4DVar — solves a fresh minimisation per event. A new observation triggers a new analysis: fine for retrospective studies, too slow for real-time alerting. Amortized inference instead learns a function that maps observations directly to posteriors in a single forward pass. The training cost amortises over many events; the per-event cost collapses to one network evaluation.
This is not a departure from variational data assimilation — it is its endpoint. The variational cost we have minimised all along is the posterior’s energy; amortized inference learns to produce that posterior for all observations at once.
From Per-Event Optimisation to a Learned Posterior Map¶
Recall the variational cost shared by 3DVar and 4DVar, $J(x; y)$. It is exactly the negative log of the joint, $J(x; y) = -\log p(x, y)$ up to an additive constant, so the posterior is a Gibbs density in $J$:

$$p(x \mid y) \propto \exp\bigl(-J(x; y)\bigr) \tag{1}$$

The methods so far all chase the mode of (1):

$$x^\star(y) = \arg\min_x J(x; y),$$

one optimisation per observation. Variational inference instead approximates the whole posterior with a tractable $q(x)$, and amortized variational inference trains a single $q_\phi(x \mid y)$ that does this for every $y$: the amortized regime of the inference schema, with the conditioning network playing the role of the variational parameters.
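To make "one optimisation per observation" concrete, here is a minimal JAX sketch on a toy linear-Gaussian problem. The operator `H`, the covariances, the step size and the iteration count are illustrative assumptions and do not come from the text; plain gradient descent stands in for whichever solver 3DVar or 4DVar would actually use.

```python
import jax
import jax.numpy as jnp

# Toy linear-Gaussian problem (all values are illustrative assumptions).
H = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]])      # 2 observations of a 3-d state
B_inv = jnp.eye(3)                    # inverse background covariance
R_inv = jnp.eye(2) / 0.25             # inverse observation-error covariance
N_STEPS, LR = 500, 0.05               # gradient-descent settings


def cost(x, y):
    """Variational cost J(x; y): observation misfit plus background term."""
    r = y - H @ x
    return 0.5 * r @ R_inv @ r + 0.5 * x @ B_inv @ x


@jax.jit
def solve_map(y):
    """Per-event solve: gradient descent on J for a single observation y."""
    grad_J = jax.grad(cost)
    step = lambda _, x: x - LR * grad_J(x, y)
    return jax.lax.fori_loop(0, N_STEPS, step, jnp.zeros(3))


y = jnp.array([1.0, -0.5])
x_map = solve_map(y)                  # a fresh minimisation for this y
# Closed-form minimiser, available only because the toy problem is linear.
x_exact = jnp.linalg.solve(H.T @ R_inv @ H + B_inv, H.T @ R_inv @ y)
```

Every new `y` repeats the whole loop inside `solve_map`; that repeated cost is what an amortized map is meant to remove.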
The amortization spectrum¶
The DA methods form a ladder of how much is moved from per-event solve-time into trained-once weights:
Table 1: How much each method amortises. = inner iterations per event.
| Method | Per-event work | Amortised (trained once) |
|---|---|---|
| OI / 3DVar / 4DVar | full optimisation | nothing |
| Incremental 4DVar | Gauss–Newton + CG | nothing (just faster) |
| 4DVarNet | learned inner steps | the prior and solver |
| Amortized inference | a single forward pass | the entire posterior map |
Read top-to-bottom, amortized inference is the limit of 4DVarNet: there is no inner optimisation left at all — the network emits the posterior directly.
The Variational Objective¶
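Train the conditional density $q_\phi(x \mid y)$ to match the exact posterior in expectation over observations:

$$\min_\phi \; \mathbb{E}_{p(y)}\Bigl[\mathrm{KL}\bigl(q_\phi(x \mid y) \,\|\, p(x \mid y)\bigr)\Bigr]$$

For each $y$, that KL is one ELBO away from the evidence,

$$\mathrm{KL}\bigl(q_\phi(x \mid y) \,\|\, p(x \mid y)\bigr) = \log p(y) - \mathrm{ELBO}(\phi; y), \qquad \mathrm{ELBO}(\phi; y) = \mathbb{E}_{q_\phi(x \mid y)}\bigl[\log p(x, y) - \log q_\phi(x \mid y)\bigr],$$

so minimising the KL maximises the ELBO. Substituting $\log p(x, y) = -J(x; y)$ (up to a constant independent of $x$) from (1), with $J$ the variational cost, gives the clean variational-DA reading:

$$\mathrm{ELBO}(\phi; y) = -\mathbb{E}_{q_\phi(x \mid y)}\bigl[J(x; y)\bigr] + \mathbb{H}\bigl[q_\phi(\cdot \mid y)\bigr]$$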
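The identity "log evidence = ELBO + KL" can be checked by hand on a scalar Gaussian model, where every term has a closed form. The sketch below assumes a toy model (prior variance `b2`, observation-error variance `r2`, one observation `y`) chosen purely for illustration, and writes the ELBO as minus the expected cost plus the entropy of the approximation.

```python
import math

# Toy scalar model (illustrative assumption): x ~ N(0, b2), y | x ~ N(x, r2).
b2, r2, y = 1.0, 0.5, 0.8


def elbo(m, s2):
    """ELBO for q = N(m, s2): -E_q[J] + entropy, with J = -log p(x, y)."""
    expected_J = (0.5 * math.log(2 * math.pi * r2) + ((y - m) ** 2 + s2) / (2 * r2)
                  + 0.5 * math.log(2 * math.pi * b2) + (m ** 2 + s2) / (2 * b2))
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)
    return -expected_J + entropy


def kl_to_posterior(m, s2):
    """KL(q || p(x | y)) between two scalar Gaussians."""
    v = 1.0 / (1.0 / r2 + 1.0 / b2)   # exact posterior variance
    mu = v * y / r2                   # exact posterior mean
    return 0.5 * (math.log(v / s2) + (s2 + (m - mu) ** 2) / v - 1.0)


# Marginally y ~ N(0, b2 + r2), which gives the log evidence in closed form.
log_evidence = -0.5 * math.log(2 * math.pi * (b2 + r2)) - y ** 2 / (2 * (b2 + r2))

# Any q satisfies the decomposition; the gap is zero up to rounding.
gap = elbo(0.1, 0.7) + kl_to_posterior(0.1, 0.7) - log_evidence
```

Because the KL term is non-negative, the ELBO never exceeds the log evidence, and it reaches it exactly when `q` is the true posterior.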
Simulation-Based Training¶
When the prior and the forward model are samplable — exactly the simulator track of the generative model — the training set comes from simulation: draw a truth, push it through the physics, add noise, and learn to invert it.
This is simulation-based inference (Cranmer et al., 2020): every $(x, y)$ pair has known ground truth and the training distribution is fully controlled. The deployment distribution must match it, however, or $q_\phi$ is miscalibrated (the central risk, below).
```python
import jax


def simulate_pair(key, prior_sample, forward, obs_noise):
    """One synthetic (x, y) pair from the generative model."""
    kx, ke = jax.random.split(key)
    x = prior_sample(kx)              # x ~ p(x)
    y = forward(x) + obs_noise(ke)    # y = H(M(x)) + ε
    return x, y
```

The Posterior Family¶
The conditioning network (an encoder over the gappy observations) feeds a head that defines $q_\phi(x \mid y)$. Three choices trade expressiveness against cost:
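Conditional normalizing flow. An invertible map from a Gaussian base to the posterior, conditioned on the observation context $c$, with exact density and exact samples (Papamakarios et al., 2021):

$$x = f_\phi(z;\, c), \quad z \sim \mathcal{N}(0, I), \qquad \log q_\phi(x \mid y) = \log \mathcal{N}\bigl(f_\phi^{-1}(x;\, c);\, 0, I\bigr) + \log\Bigl|\det \frac{\partial f_\phi^{-1}(x;\, c)}{\partial x}\Bigr|$$

Works well up to moderate state dimension; higher-dimensional flows are open research.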
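A minimal sketch of a conditional flow on a 2-d state shows where the exact density and exact samples come from. Everything here is an assumption made for illustration: the one-layer `tanh` encoder, the parameter names `W`, `b`, `a`, and the single coupling step (a practical flow stacks many such layers).

```python
import jax
import jax.numpy as jnp


def context(params, y):
    """Hypothetical encoder: observations -> 4 context values."""
    return jnp.tanh(params["W"] @ y + params["b"])


def flow_forward(params, z, y):
    """x = f(z; c): affine in z, with the shift of x2 depending on x1."""
    c = context(params, y)
    x1 = c[0] + jnp.exp(c[1]) * z[0]
    x2 = c[2] + params["a"] * jnp.tanh(x1) + jnp.exp(c[3]) * z[1]
    return jnp.stack([x1, x2])


def flow_inverse(params, x, y):
    """z = f^{-1}(x; c) and log|det dz/dx|, both in closed form."""
    c = context(params, y)
    z1 = (x[0] - c[0]) * jnp.exp(-c[1])
    z2 = (x[1] - c[2] - params["a"] * jnp.tanh(x[0])) * jnp.exp(-c[3])
    return jnp.stack([z1, z2]), -(c[1] + c[3])


def log_q(params, x, y):
    """Exact log-density: standard-normal base density plus log-determinant."""
    z, log_det = flow_inverse(params, x, y)
    return -0.5 * jnp.sum(z**2) - jnp.log(2 * jnp.pi) + log_det


def sample(params, y, key):
    """Exact sample: push one base draw through the forward map."""
    return flow_forward(params, jax.random.normal(key, (2,)), y)


params = {"W": 0.1 * jnp.ones((4, 3)), "b": jnp.array([0.1, 0.2, -0.3, 0.4]), "a": 0.7}
y = jnp.array([1.0, -2.0, 0.5])
x = sample(params, y, jax.random.PRNGKey(0))
```

The triangular structure is what keeps the determinant cheap: the Jacobian of the inverse is lower triangular, so its log-determinant is just the sum of the two log-scales with a minus sign.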
Score-based diffusion. Learn the score of a noise-perturbed posterior at noise level $t$ and sample by a reverse SDE (Song et al., 2021). This is sampling-only (no exact density), but has high capacity for the multimodal posteriors that flows struggle with:

$$s_\phi(x_t, t, c) \approx \nabla_{x_t} \log p_t(x_t \mid y)$$

The reverse SDE can be integrated with diffrax; disconnected posterior components are representable, which the Laplace approximation cannot do.
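The sampling step can be sketched without a trained network by plugging in an analytic score. The example below assumes a variance-preserving forward SDE with constant rate `BETA` and a stand-in Gaussian posterior `N(MU, VAR)`, and it uses a hand-rolled Euler-Maruyama loop rather than diffrax so that it stays dependency-free; all constants are illustrative.

```python
import jax
import jax.numpy as jnp

BETA, T, N_STEPS = 1.0, 8.0, 800   # illustrative VP-SDE settings
MU, VAR = 1.5, 0.25                # stand-in Gaussian posterior N(MU, VAR)


def score(x, t):
    """Analytic score of the noise-perturbed target; a trained head would replace it."""
    decay = jnp.exp(-BETA * t)
    mean_t = MU * jnp.sqrt(decay)
    var_t = VAR * decay + (1.0 - decay)
    return -(x - mean_t) / var_t


def reverse_sample(key, n_samples):
    """Euler-Maruyama on the reverse SDE, from t = T (pure noise) down to t = 0."""
    dt = T / N_STEPS
    k_init, k_steps = jax.random.split(key)
    x_T = jax.random.normal(k_init, (n_samples,))

    def step(x, inputs):
        t, k = inputs
        drift = -0.5 * BETA * x - BETA * score(x, t)   # f(x, t) - g(t)^2 * score
        noise = jnp.sqrt(BETA * dt) * jax.random.normal(k, x.shape)
        return x - drift * dt + noise, None            # one step backward in time

    ts = T - dt * jnp.arange(N_STEPS)
    x_0, _ = jax.lax.scan(step, x_T, (ts, jax.random.split(k_steps, N_STEPS)))
    return x_0


samples = reverse_sample(jax.random.PRNGKey(0), 5000)
```

Nothing in the loop assumes the target is unimodal; only the score function knows the shape of the posterior, which is why a learned score can represent disconnected components.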
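Gaussian regression. Predict the posterior moments directly; this is the cheapest option and is restricted to unimodal Gaussians:

$$q_\phi(x \mid y) = \mathcal{N}\bigl(x;\, \mu_\phi(y),\, \Sigma_\phi(y)\bigr)$$

A learned drop-in for the Laplace covariance around an incremental-4DVar MAP, when the posterior is known to be Gaussian.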
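A regression head of this kind reduces to two outputs and one log-density. The sketch below assumes a purely linear head and a diagonal covariance, both simplifications made for brevity, and the parameter names are hypothetical; a full covariance would predict a Cholesky factor instead.

```python
import jax.numpy as jnp


def gaussian_head(params, y):
    """Hypothetical linear head: posterior mean and diagonal log-variance from y."""
    mean = params["W_mu"] @ y + params["b_mu"]
    log_var = params["W_lv"] @ y + params["b_lv"]
    return mean, log_var


def log_q(params, x, y):
    """log N(x; mean(y), diag(exp(log_var(y))))."""
    mean, log_var = gaussian_head(params, y)
    return -0.5 * jnp.sum(
        jnp.log(2 * jnp.pi) + log_var + (x - mean) ** 2 * jnp.exp(-log_var)
    )


# 3-d state, 2-d observation; zero weights give a standard-normal posterior.
params = {"W_mu": jnp.zeros((3, 2)), "b_mu": jnp.zeros(3),
          "W_lv": jnp.zeros((3, 2)), "b_lv": jnp.zeros(3)}
value = log_q(params, jnp.zeros(3), jnp.array([0.3, -1.0]))
```

Predicting the log-variance rather than the variance keeps the covariance positive without any constraint on the weights.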
Training Objective¶
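Flow and regression heads admit a tractable density, so train by maximum likelihood on simulated pairs; score heads use denoising score matching:

$$\mathcal{L}_{\mathrm{ML}}(\phi) = -\mathbb{E}_{p(x, y)}\bigl[\log q_\phi(x \mid y)\bigr], \qquad \mathcal{L}_{\mathrm{DSM}}(\phi) = \mathbb{E}_{t,\,(x, y),\,x_t}\Bigl[\lambda(t)\,\bigl\| s_\phi(x_t, t, c) - \nabla_{x_t} \log p(x_t \mid x) \bigr\|^2\Bigr]$$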
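For the score head, a sketch of denoising score matching at a single noise level might look as follows. The signature `score_fn(x_noisy, y, sigma)` is a hypothetical interface, the Gaussian perturbation kernel and the `sigma**2` weighting are common choices rather than anything fixed by the text, and in practice `sigma` would be drawn from a noise schedule for each batch element.

```python
import jax
import jax.numpy as jnp


def dsm_loss(score_fn, x, y, sigma, key):
    """Denoising score matching at one noise level sigma.

    x: (batch, d) simulated truths, y: (batch, m) matching observations.
    """
    eps = jax.random.normal(key, x.shape)
    x_noisy = x + sigma * eps
    target = -eps / sigma                     # score of N(x_noisy; x, sigma^2 I)
    err = score_fn(x_noisy, y, sigma) - target
    return jnp.mean(sigma**2 * jnp.sum(err**2, axis=-1))


# Toy check with x ~ N(0, 1), where the perturbed marginal is N(0, 1 + sigma^2).
k_data, k_noise = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_data, (20000, 1))
y = jnp.zeros((20000, 1))                     # unused by these stand-in scores
true_score = lambda xn, y, s: -xn / (1.0 + s**2)
zero_score = lambda xn, y, s: jnp.zeros_like(xn)
loss_true = dsm_loss(true_score, x, y, 1.0, k_noise)
loss_zero = dsm_loss(zero_score, x, y, 1.0, k_noise)
```

The loss of the exact score is not zero, because the regression target is a noisy single-sample estimate; what matters is that the exact score attains the minimum.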
Either way it is a standard minibatch loop over simulated data — the same optax training machinery as 4DVarNet:
```python
import jax
import jax.numpy as jnp
import optax

# Assumed to exist already: `params`, the pytree (φ, ψ) of head + encoder weights,
# and `log_q(params, x, y)`, the batched log-density of the chosen head.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state, x, y):
    # maximum likelihood: L = -E log q_φ(x | y)
    loss, grads = jax.value_and_grad(lambda p: -jnp.mean(log_q(p, x, y)))(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss
```

Calibration: the Hard Part¶
Amortized inference is dangerous precisely because it looks confident. A tight $q_\phi(x \mid y)$ centred in the wrong place gives a confident wrong answer with no warning, the opposite failure mode to a slow-but-honest solver. Promote a head to operational use only after three gates pass:
Table 2: Calibration gates before deploying an amortized head.
| Gate | Check | Why it matters |
|---|---|---|
| Posterior agreement | mode of $q_\phi(x \mid y)$ vs. the 4DVar MAP | the mode must match a slower 4DVar oracle |
| Adjoint calibration | Jacobian of the learned map with respect to $y$ vs. the tangent-linear / adjoint | the sensitivity to data must match the physics |
| SBC | rank of the true $x$ among samples from $q_\phi(x \mid y)$ is uniform (Talts et al., 2018) | detects over-/under-confidence |
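The SBC row of the table can be illustrated on a scalar Gaussian model where the exact posterior is known. In the sketch below the variances, the number of simulations and the `var_scale` knob (which deliberately shrinks the posterior variance to mimic an over-confident head) are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

B2, R2 = 1.0, 0.5            # toy prior and observation-error variances
N_SIMS, N_DRAWS = 2000, 19   # 19 draws give 20 possible rank values


def sbc_ranks(key, var_scale):
    """Rank of the true x among draws from q(x | y); var_scale=1 is the exact posterior."""
    kx, ke, kq = jax.random.split(key, 3)
    x = jnp.sqrt(B2) * jax.random.normal(kx, (N_SIMS,))        # x ~ p(x)
    y = x + jnp.sqrt(R2) * jax.random.normal(ke, (N_SIMS,))    # y ~ p(y | x)
    post_var = 1.0 / (1.0 / B2 + 1.0 / R2)
    post_mean = post_var * y / R2
    draws = post_mean[:, None] + jnp.sqrt(var_scale * post_var) * jax.random.normal(
        kq, (N_SIMS, N_DRAWS))
    return jnp.sum(draws < x[:, None], axis=1)                 # rank in 0..N_DRAWS


def chi2_uniform(ranks):
    """Chi-square distance of the rank histogram from the uniform distribution."""
    counts = jnp.bincount(ranks, length=N_DRAWS + 1)
    expected = N_SIMS / (N_DRAWS + 1)
    return jnp.sum((counts - expected) ** 2 / expected)


key = jax.random.PRNGKey(0)
chi2_calibrated = chi2_uniform(sbc_ranks(key, 1.0))      # exact posterior
chi2_overconfident = chi2_uniform(sbc_ranks(key, 0.25))  # variance 4x too small
```

An over-confident head piles ranks up at both ends of the histogram (the truth falls outside the sampled spread too often), so its chi-square distance is far larger than the calibrated one.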
The adjoint-calibration gate is the one that keeps amortized inference inside DA rather than outside it: the learned map’s Jacobian must match the physics-based tangent-linear / adjoint of the true inverse problem. Matching the oracle MAP on the training set is not enough — without matching gradients, the head extrapolates badly the moment the observation distribution shifts.
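A sketch of the adjoint-calibration gate on a linear-Gaussian toy problem: there the sensitivity of the analysis to the data is the gain matrix, so it can serve as the physics-based reference. The operator, covariances and the two stand-in heads are assumptions for illustration; for a nonlinear problem the reference would come from the solver's tangent-linear or adjoint instead.

```python
import jax
import jax.numpy as jnp

H = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]])
B = jnp.eye(3)
R = 0.25 * jnp.eye(2)

# Reference sensitivity d(analysis)/dy: the gain K = B H^T (H B H^T + R)^-1.
K = B @ H.T @ jnp.linalg.inv(H @ B @ H.T + R)


def jacobian_gap(head_mean, y):
    """Relative Frobenius gap between d(head mean)/dy and the reference gain."""
    J_head = jax.jacobian(head_mean)(y)
    return jnp.linalg.norm(J_head - K) / jnp.linalg.norm(K)


good_head = lambda y: K @ y             # correct sensitivity everywhere
bad_head = lambda y: K @ jnp.tanh(y)    # correct only near y = 0

y_train_like = jnp.zeros(2)
y_shifted = jnp.array([2.0, -2.0])
gap_good = jacobian_gap(good_head, y_shifted)
gap_bad_in = jacobian_gap(bad_head, y_train_like)
gap_bad_out = jacobian_gap(bad_head, y_shifted)
```

The second head passes the gate where it was fitted and fails it once the observations move, which is the shift behaviour the gate is there to catch.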
The Oracle Pattern¶
This is what ties amortized inference back to the rest of the DA hierarchy. The right design is both, in sequence:
1. Run a solver-based method (incremental / strong 4DVar) to produce oracle posteriors on a representative sample.
2. Train the amortized head against the oracle (and/or by simulation).
3. Validate with the three gates.
4. Deploy the head for real-time, high-throughput work.
The physics-based solvers are the teachers; the amortized head is the student that runs in milliseconds once it has earned trust. Amortized inference does not replace variational DA — it compiles it.
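The teacher-student sequence can be sketched end to end on a linear-Gaussian toy problem. Here the closed-form MAP stands in for the slow solver, a linear least-squares fit stands in for training, and only the first gate is checked; the operator, covariances and sample sizes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

H = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 1.0]])
B_inv, R_inv = jnp.eye(3), jnp.eye(2) / 0.25


def oracle_map(y):
    """Stand-in for the slow solver: exact MAP of the linear-Gaussian cost."""
    return jnp.linalg.solve(H.T @ R_inv @ H + B_inv, H.T @ R_inv @ y)


key_train, key_test = jax.random.split(jax.random.PRNGKey(0))

# Step 1: oracle analyses on a representative sample of observations.
y_train = jax.random.normal(key_train, (256, 2))
x_oracle = jax.vmap(oracle_map)(y_train)

# Step 2: train the student, here a linear head x ≈ y @ W fitted by least squares.
W, *_ = jnp.linalg.lstsq(y_train, x_oracle)

# Step 3 (posterior-agreement gate only): compare on held-out observations.
y_test = jax.random.normal(key_test, (64, 2))
max_gap = jnp.max(jnp.abs(y_test @ W - jax.vmap(oracle_map)(y_test)))

# Step 4: deployment is a single matrix product per event.
x_fast = jnp.array([1.0, -0.5]) @ W
```

The student is exact here only because the oracle map is itself linear; with a nonlinear forward model the held-out gap becomes the quantity to monitor.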
When It Helps, and the Trade-offs¶
| Regime | Amortized helps? |
|---|---|
| Single retrospective analysis | no: a solver is fine |
| Real-time alerts | yes: sub-second inference |
| Many independent events (catalogue reprocessing) | yes: training cost is amortised |
| Same forward model, varying observations | yes: train once, infer many times |
| Each event has a different forward model | no: would need retraining |
| Multimodal posterior | yes: flow/score heads represent it; Laplace can’t |

The trade-offs against solver-based methods, aspect by aspect:

| Aspect | Solver-based | Amortized |
|---|---|---|
| Cost / event | high (iterative) | low (one pass) |
| Training cost | none–low | high |
| Generalisation | strong (physics) | limited (training distribution) |
| Multimodality | hard (Laplace) | easy (flow / score heads) |
| Adjoint correctness | exact (autodiff) | approximate (gated) |
| Posterior shape | Gaussian | flexible |
| Failure mode | slow | confident-but-wrong |
- Cranmer, K., Brehmer, J., & Louppe, G. (2020). The Frontier of Simulation-Based Inference. Proceedings of the National Academy of Sciences, 117(48), 30055–30062. 10.1073/pnas.1912789117
- Papamakarios, G., Nalisnick, E., Rezende, D. J., Mohamed, S., & Lakshminarayanan, B. (2021). Normalizing Flows for Probabilistic Modeling and Inference. Journal of Machine Learning Research, 22(57), 1–64.
- Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations (ICLR).
- Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian Inference Algorithms with Simulation-Based Calibration. arXiv Preprint arXiv:1804.06788.