Amortized Inference

CSIC
UCM
IGEO

Every method up to now (OI, 3DVar, strong, weak, and incremental 4DVar) solves a fresh minimisation per event. Each new observation triggers a new analysis: fine for retrospective studies, too slow for real-time alerting. Amortized inference instead learns a function $q_{\boldsymbol{\phi}}(\mathbf{x} \mid \mathbf{y})$ that maps observations directly to posteriors in a single forward pass. The training cost amortises over many events; the per-event cost collapses to one network evaluation.

This is not a departure from variational data assimilation but its endpoint. The variational cost we have minimised all along is the posterior's energy; amortized inference learns to produce that posterior for every observation drawn from the training distribution at once.

From Per-Event Optimisation to a Learned Posterior Map

Recall the variational cost shared by 3DVar and 4DVar, $J(\mathbf{x}) = \tfrac{1}{2}\|\mathbf{x}-\mathbf{x}_b\|^2_{\mathbf{B}^{-1}} + \tfrac{1}{2}\sum_t \|\mathbf{y}_t - H_t(M_t(\mathbf{x}))\|^2_{\mathbf{R}_t^{-1}}$. With Gaussian background and observation errors, it is the negative log of the joint density up to an additive constant, so the posterior is a Gibbs density in $J$:

$$p(\mathbf{x} \mid \mathbf{y}) \;\propto\; p(\mathbf{y} \mid \mathbf{x})\, p(\mathbf{x}) \;\propto\; \exp\!\big(-J(\mathbf{x})\big). \tag{1}$$

The methods so far all chase the mode of (1):

$$\mathbf{x}^\star = \operatorname*{arg\,min}_{\mathbf{x}} \, J(\mathbf{x}) = \operatorname*{arg\,max}_{\mathbf{x}} \, p(\mathbf{x} \mid \mathbf{y}),$$

one optimisation per observation. Variational inference instead approximates the whole posterior with a tractable $q$, and amortized variational inference trains a single $q_{\boldsymbol{\phi}}(\mathbf{x}\mid\mathbf{y})$ that does this for every $\mathbf{y}$. This is the amortized regime of the inference schema: the conditioning network's output takes the place of the per-observation variational parameters.

The amortization spectrum

The DA methods form a ladder, ordered by how much work moves from per-event solve time into weights trained once:

Table 1: How much each method amortises. $K$ = inner iterations per event.

| Method | Per-event work | Amortised (trained once) |
|---|---|---|
| OI / 3DVar / 4DVar | full optimisation | nothing |
| Incremental 4DVar | Gauss–Newton + CG | nothing (just faster) |
| 4DVarNet | $K$ learned inner steps | the prior $\boldsymbol{\phi}_{\boldsymbol{\theta}}$ and solver $\boldsymbol{\Phi}_{\boldsymbol{\psi}}$ |
| Amortized inference | a single forward pass | the entire posterior map $q_{\boldsymbol{\phi}}(\mathbf{x}\mid\mathbf{y})$ |

Read top to bottom, amortized inference is the $K \to 0$ limit of 4DVarNet: no inner optimisation is left at all, and the network emits the posterior directly.

The Variational Objective

Train the conditional density to match the exact posterior in expectation over observations:

$$\boldsymbol{\phi}^\star = \operatorname*{arg\,min}_{\boldsymbol{\phi}} \; \mathbb{E}_{\mathbf{y} \sim p(\mathbf{y})} \, \mathrm{KL}\!\big( q_{\boldsymbol{\phi}}(\cdot \mid \mathbf{y}) \,\|\, p(\cdot \mid \mathbf{y}) \big).$$

For each $\mathbf{y}$, that KL is exactly the gap between the log-evidence and the ELBO:

$$\mathrm{KL}\big( q_{\boldsymbol{\phi}}(\cdot \mid \mathbf{y}) \,\|\, p(\cdot \mid \mathbf{y}) \big) = \log p(\mathbf{y}) - \underbrace{\mathbb{E}_{q_{\boldsymbol{\phi}}}\!\big[ \log p(\mathbf{x},\mathbf{y}) - \log q_{\boldsymbol{\phi}}(\mathbf{x}\mid\mathbf{y}) \big]}_{\mathrm{ELBO}(\boldsymbol{\phi};\,\mathbf{y})},$$

so, since $\log p(\mathbf{y})$ does not depend on $\boldsymbol{\phi}$, minimising the KL is equivalent to maximising the ELBO. Substituting $\log p(\mathbf{x},\mathbf{y}) = -J(\mathbf{x}) + \text{const}$ from (1) gives the clean variational-DA reading:

$$\mathrm{ELBO}(\boldsymbol{\phi};\,\mathbf{y}) = -\,\mathbb{E}_{q_{\boldsymbol{\phi}}}\!\big[ J(\mathbf{x}) \big] + \mathcal{H}\!\big[ q_{\boldsymbol{\phi}}(\cdot\mid\mathbf{y}) \big] + \text{const}.$$
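The first term pulls $q_{\boldsymbol{\phi}}$ toward low-cost states; the entropy term stops it from collapsing onto the 4DVar minimum. A minimal NumPy sketch can check the identity on a hypothetical scalar 3DVar problem (background $x_b = 0$, $B = 1$, one observation $y = 1$ with $H = 1$, $R = 0.5$; all values are illustrative assumptions). It estimates the ELBO of a Gaussian $q = \mathcal{N}(\mu, s^2)$ by Monte Carlo with the reparameterisation $x = \mu + s\varepsilon$. At the exact posterior the ELBO should match $\log \int e^{-J}\,dx$ because the KL vanishes, and shifting the mean should lower it by the KL of the shifted $q$.

```python
import numpy as np

def elbo_gaussian_q(J, mu, s, n_samples=200_000, seed=0):
    """Monte Carlo estimate of -E_q[J(x)] + H[q] for a 1-D Gaussian q = N(mu, s^2).

    Uses the reparameterisation x = mu + s * eps and drops the additive constant.
    A fixed seed gives common random numbers across calls.
    """
    eps = np.random.default_rng(seed).standard_normal(n_samples)
    x = mu + s * eps
    entropy = 0.5 * np.log(2.0 * np.pi * np.e * s**2)   # H[N(mu, s^2)]
    return -np.mean(J(x)) + entropy

# Hypothetical scalar 3DVar cost: x_b = 0, B = 1, one observation y = 1 with H = 1, R = 0.5.
x_b, B, y_obs, R = 0.0, 1.0, 1.0, 0.5

def J(x):
    return 0.5 * (x - x_b) ** 2 / B + 0.5 * (y_obs - x) ** 2 / R

# Exact Gaussian posterior of this quadratic cost.
P = 1.0 / (1.0 / B + 1.0 / R)                    # posterior variance
m = P * (x_b / B + y_obs / R)                    # posterior mean (= the 3DVar minimiser)
log_Z = -J(m) + 0.5 * np.log(2.0 * np.pi * P)    # log of the integral of exp(-J), exact here

elbo_exact = elbo_gaussian_q(J, m, np.sqrt(P))
elbo_shifted = elbo_gaussian_q(J, m + 0.5, np.sqrt(P))
print(elbo_exact, log_Z)          # close: the KL is zero at the exact posterior
print(elbo_exact - elbo_shifted)  # close to 0.5**2 / (2 * P) = 0.375, the KL of the shifted q
```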

Simulation-Based Training

When the prior $p(\mathbf{x})$ can be sampled and the forward model can be run (exactly the simulator track of the generative model), the training set comes from simulation: draw a truth, push it through the physics, add noise, and learn to invert it.

$$\mathbf{x} \sim p(\mathbf{x}), \qquad \mathbf{y} = H(M(\mathbf{x})) + \boldsymbol{\varepsilon}, \qquad \text{train } q_{\boldsymbol{\phi}} \text{ to recover } \mathbf{x} \text{ from } \mathbf{y}.$$

This is simulation-based inference (Cranmer et al., 2020): every pair has known ground truth and the training distribution is fully controlled, but the deployment distribution must match it, or $q_{\boldsymbol{\phi}}$ will be miscalibrated (the central risk, discussed below).

```python
import jax

def simulate_pair(key, prior_sample, forward, obs_noise):
    """One synthetic (x, y) pair from the generative model.

    prior_sample(key) -> x, forward(x) -> H(M(x)) and obs_noise(key) -> ε
    are user-supplied callables.
    """
    kx, ke = jax.random.split(key)
    x = prior_sample(kx)                       # x ~ p(x)
    y = forward(x) + obs_noise(ke)             # y = H(M(x)) + ε
    return x, y
```
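`simulate_pair` draws one pair; training needs batches. A minimal NumPy sketch draws a whole batch from a hypothetical toy generative model: a smooth periodic prior (white noise through a 5-point moving average), $M$ = one explicit diffusion step, and $H$ = observing every fourth grid point. None of these choices come from the text; they only stand in for the real prior, physics, and observation network.

```python
import numpy as np

def simulate_batch(rng, n_pairs, n_state=32, obs_every=4, sigma_obs=0.1, kappa=0.2):
    """Draw n_pairs synthetic (x, y) pairs from a toy generative model (illustrative only)."""
    # x ~ p(x): white noise smoothed by a periodic, centred 5-point moving average
    white = rng.standard_normal((n_pairs, n_state))
    padded = np.concatenate([white[:, -2:], white, white[:, :2]], axis=1)
    x = sum(padded[:, k:k + n_state] for k in range(5)) / 5.0
    # M: one explicit diffusion step on a periodic grid (stable for kappa <= 0.5)
    mx = x + kappa * (np.roll(x, 1, axis=1) - 2.0 * x + np.roll(x, -1, axis=1))
    # H: gappy observation operator, keep every obs_every-th grid point
    hx = mx[:, ::obs_every]
    # y = H(M(x)) + eps, eps ~ N(0, sigma_obs^2 I)
    y = hx + sigma_obs * rng.standard_normal(hx.shape)
    return x, y

rng = np.random.default_rng(0)
x, y = simulate_batch(rng, n_pairs=256)
print(x.shape, y.shape)   # (256, 32) (256, 8)
```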

The Posterior Family

The conditioning network $c_{\boldsymbol{\psi}}(\mathbf{y})$ (an encoder over the gappy observations) feeds a head that defines $q_{\boldsymbol{\phi}}$. Three choices trade expressiveness against cost:

- Conditional flow
- Score-based diffusion
- Gaussian regression

The conditional flow is an invertible map from a Gaussian base to the posterior, conditioned on the observation context, which gives exact densities and exact samples (Papamakarios et al., 2021):

$$\mathbf{x} = f_{\boldsymbol{\phi}}(\mathbf{z}; c_{\boldsymbol{\psi}}(\mathbf{y})), \quad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \qquad \log q_{\boldsymbol{\phi}}(\mathbf{x}\mid\mathbf{y}) = \log p_{\mathbf{z}}\!\big(f_{\boldsymbol{\phi}}^{-1}(\mathbf{x})\big) - \log\left| \det \frac{\partial f_{\boldsymbol{\phi}}}{\partial \mathbf{z}} \right|.$$

Flows work well up to moderate state dimension ($\sim 10^4$); higher-dimensional flows remain an open research problem.
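The change-of-variables formula can be checked on the simplest possible conditional flow: a single elementwise affine layer $\mathbf{x} = \boldsymbol{\mu}(\mathbf{c}) + \exp(\mathbf{s}(\mathbf{c})) \odot \mathbf{z}$ whose shift and log-scale are linear in a context vector $\mathbf{c}$, a hypothetical stand-in for $c_{\boldsymbol{\psi}}(\mathbf{y})$. The weights `W_mu` and `W_logs` are illustrative. One such layer is just a diagonal Gaussian, i.e. a Gaussian-regression head; practical flows stack many invertible layers so that $q_{\boldsymbol{\phi}}$ can be non-Gaussian, but the log-density bookkeeping is the same.

```python
import numpy as np

def flow_forward(z, c, W_mu, W_logs):
    """x = f(z; c) = mu(c) + exp(log_s(c)) * z, elementwise (one affine layer)."""
    mu, log_s = c @ W_mu, c @ W_logs
    return mu + np.exp(log_s) * z

def flow_log_q(x, c, W_mu, W_logs):
    """log q(x | c) = log p_z(f^{-1}(x; c)) - log|det df/dz|."""
    mu, log_s = c @ W_mu, c @ W_logs
    z = (x - mu) * np.exp(-log_s)                     # z = f^{-1}(x; c)
    d = x.shape[-1]
    log_pz = -0.5 * np.sum(z**2, axis=-1) - 0.5 * d * np.log(2.0 * np.pi)
    log_det = np.sum(log_s, axis=-1)                  # Jacobian is diag(exp(log_s))
    return log_pz - log_det

rng = np.random.default_rng(0)
n_ctx, d = 3, 4
W_mu = rng.standard_normal((n_ctx, d))
W_logs = 0.1 * rng.standard_normal((n_ctx, d))
c = rng.standard_normal((5, n_ctx))                  # a batch of 5 context vectors
z = rng.standard_normal((5, d))
x = flow_forward(z, c, W_mu, W_logs)                 # exact samples from q(x | c)
print(flow_log_q(x, c, W_mu, W_logs))                # exact log-densities, shape (5,)
```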

Training Objective

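Flow and regression heads admit a tractable density, so they are trained by maximum likelihood on simulated pairs; score heads use denoising score matching (Song et al., 2021). Maximum likelihood on pairs drawn from the joint minimises the forward KL, $\mathbb{E}_{\mathbf{y}}\,\mathrm{KL}\big(p(\cdot\mid\mathbf{y}) \,\|\, q_{\boldsymbol{\phi}}(\cdot\mid\mathbf{y})\big)$, rather than the reverse KL of the ELBO objective above: it needs only simulated samples, never an evaluation of $J$, and tends to give mass-covering rather than mode-seeking posteriors.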

$$\mathcal{L}_{\text{MLE}}(\boldsymbol{\phi}) = -\,\mathbb{E}_{(\mathbf{x},\mathbf{y}) \sim p_{\text{sim}}} \log q_{\boldsymbol{\phi}}(\mathbf{x}\mid\mathbf{y}), \qquad \mathcal{L}_{\text{DSM}}(\boldsymbol{\phi}) = \mathbb{E}_{t,\mathbf{x},\mathbf{y},\boldsymbol{\varepsilon}} \big\| s_{\boldsymbol{\phi}}(\mathbf{x}_t, t \mid \mathbf{y}) - \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t \mid \mathbf{x}) \big\|^2.$$
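For the common choice of a Gaussian perturbation kernel $p_t(\mathbf{x}_t \mid \mathbf{x}) = \mathcal{N}(\mathbf{x}, \sigma_t^2 \mathbf{I})$, the regression target reduces to $-\boldsymbol{\varepsilon}/\sigma_t$. A minimal sketch works at one fixed noise level. This is an assumption: in practice $t$ is sampled and the terms are weighted. It is checked on a hypothetical toy where $\mathbf{x} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ independently of $\mathbf{y}$. There the exact noisy score $-\mathbf{x}_t/(1+\sigma^2)$ reaches the lower loss $d/(\sigma^2(1+\sigma^2))$, compared with $d/\sigma^2$ for a zero score.

```python
import numpy as np

def dsm_loss(score_fn, x, y, sigma, rng):
    """Monte Carlo denoising-score-matching loss at a single noise level sigma."""
    eps = rng.standard_normal(x.shape)
    x_t = x + sigma * eps                    # x_t ~ p_t(x_t | x) = N(x, sigma^2 I)
    target = -eps / sigma                    # grad_{x_t} log p_t(x_t | x) = -(x_t - x) / sigma^2
    resid = score_fn(x_t, sigma, y) - target
    return np.mean(np.sum(resid**2, axis=-1))

# Toy check: x ~ N(0, I) in d = 2, observations ignored, sigma = 1.
rng = np.random.default_rng(0)
d, sigma = 2, 1.0
x = rng.standard_normal((200_000, d))
y = None                                       # this toy score does not use y
exact_score = lambda x_t, s, y: -x_t / (1.0 + s**2)
zero_score = lambda x_t, s, y: np.zeros_like(x_t)
print(dsm_loss(exact_score, x, y, sigma, rng))  # close to d / (sigma^2 (1 + sigma^2)) = 1.0
print(dsm_loss(zero_score, x, y, sigma, rng))   # close to d / sigma^2 = 2.0
```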

Either way, training is a standard minibatch loop over simulated data, built on the same optax machinery as 4DVarNet:

```python
import jax
import jax.numpy as jnp
import optax

# Assumed to exist (defined by the chosen head, not shown here):
#   params            pytree (φ, ψ): head + encoder parameters
#   log_q(p, x, y)    per-sample log q_φ(x | y) for a minibatch, shape (batch,)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    # maximum likelihood:  L = -E log q_φ(x | y), averaged over the minibatch
    loss, grads = jax.value_and_grad(lambda p: -jnp.mean(log_q(p, x, y)))(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss
```
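The linear-Gaussian case makes the training claim testable without any neural network. Suppose $\mathbf{x} \sim \mathcal{N}(\mathbf{0}, \mathbf{B})$ and $\mathbf{y} = \mathbf{H}\mathbf{x} + \boldsymbol{\varepsilon}$ with $\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{R})$. The exact posterior is then $\mathcal{N}(\mathbf{K}\mathbf{y}, \mathbf{P})$, with the OI gain $\mathbf{K} = \mathbf{B}\mathbf{H}^\top(\mathbf{H}\mathbf{B}\mathbf{H}^\top + \mathbf{R})^{-1}$ and $\mathbf{P} = \mathbf{B} - \mathbf{K}\mathbf{H}\mathbf{B}$. A Gaussian-regression head $q(\mathbf{x}\mid\mathbf{y}) = \mathcal{N}(\mathbf{A}\mathbf{y}, \mathbf{S})$ trained by maximum likelihood on simulated pairs should therefore learn $\mathbf{A} \approx \mathbf{K}$ and $\mathbf{S} \approx \mathbf{P}$. Here the MLE has a closed form: least squares for $\mathbf{A}$, then the residual covariance for $\mathbf{S}$. NumPy replaces the optax loop only for that reason. The matrices are random, hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, N = 4, 2, 400_000                  # state dim, obs dim, number of simulated pairs

# Hypothetical linear-Gaussian problem: x ~ N(0, B), y = H x + eps, eps ~ N(0, R)
L = rng.standard_normal((n, n))
B = L @ L.T + n * np.eye(n)
H = rng.standard_normal((m, n))
R = 0.5 * np.eye(m)

x = rng.multivariate_normal(np.zeros(n), B, size=N)
y = x @ H.T + rng.multivariate_normal(np.zeros(m), R, size=N)

# Maximum likelihood for q(x | y) = N(A y, S): least squares, then residual covariance
A_T, *_ = np.linalg.lstsq(y, x, rcond=None)   # solves y @ A_T ~= x
A = A_T.T
resid = x - y @ A_T
S = resid.T @ resid / N

# Oracle: OI / Kalman gain and exact posterior covariance
K = B @ H.T @ np.linalg.inv(H @ B @ H.T + R)
P = B - K @ H @ B
print(np.abs(A - K).max(), np.abs(S - P).max())   # both small, shrinking like 1/sqrt(N)
```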

Calibration: the Hard Part

Amortized inference is dangerous precisely because it looks confident. A tight $q_{\boldsymbol{\phi}}$ centred in the wrong place gives a confident wrong answer with no warning, the opposite failure mode to a slow but honest solver. Promote a head to operational use only after all three gates pass:

Table 2: Calibration gates before deploying an amortized head.

| Gate | Check | Why it matters |
|---|---|---|
| Posterior agreement | $\lvert \mathbf{x}^\star_{\text{amort}} - \mathbf{x}^\star_{\text{oracle}} \rvert / \sigma_{\text{post}} \le 1$ | the mode must match a slower 4DVar oracle |
| Adjoint calibration | $\lVert \partial \mathbf{x}^\star_{\text{amort}}/\partial \mathbf{y} - \partial \mathbf{x}^\star_{\text{oracle}}/\partial \mathbf{y} \rVert_{\text{op}} \,/\, \lVert \partial \mathbf{x}^\star_{\text{oracle}}/\partial \mathbf{y} \rVert_{\text{op}} < 0.05$ | the sensitivity to data must match the physics |
| SBC | the rank of each true $\mathbf{x}^{(j)}$ among samples from $q_{\boldsymbol{\phi}}(\cdot \mid \mathbf{y}^{(j)})$ is uniform (Talts et al., 2018) | detects over- and under-confidence |
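The SBC gate takes only a few lines once $p(\mathbf{x})$, the observation model, and a sampler for $q_{\boldsymbol{\phi}}$ are available. A minimal sketch uses a hypothetical scalar model ($x \sim \mathcal{N}(0, 1)$, $y = x + \varepsilon$, $R = 0.5$), whose exact posterior is $\mathcal{N}(2y/3, 1/3)$. For each simulated $(x, y)$ it ranks the true $x$ among $L = 99$ draws from $q$. A calibrated $q$ gives ranks uniform on $\{0, \dots, 99\}$. An over-confident $q$ piles ranks up at both extremes, giving a U-shaped histogram. The over-confident $q$ here is the exact posterior with its standard deviation shrunk by an illustrative factor of 0.3.

```python
import numpy as np

def sbc_ranks(q_sample, n_trials=2000, n_draws=99, seed=0):
    """Simulation-based calibration ranks for a scalar toy model (illustrative)."""
    rng = np.random.default_rng(seed)
    ranks = np.empty(n_trials, dtype=int)
    for i in range(n_trials):
        x_true = rng.normal(0.0, 1.0)                  # x ~ p(x) = N(0, 1)
        y = x_true + rng.normal(0.0, np.sqrt(0.5))     # y = x + eps, R = 0.5
        draws = q_sample(rng, y, n_draws)              # samples from q(x | y)
        ranks[i] = np.sum(draws < x_true)              # rank in {0, ..., n_draws}
    return ranks

exact_q = lambda rng, y, n: rng.normal(2.0 * y / 3.0, np.sqrt(1.0 / 3.0), n)
overconfident_q = lambda rng, y, n: rng.normal(2.0 * y / 3.0, 0.3 * np.sqrt(1.0 / 3.0), n)

for name, q in [("exact", exact_q), ("over-confident", overconfident_q)]:
    ranks = sbc_ranks(q)
    extreme = np.mean((ranks < 5) | (ranks > 94))      # outer 10 of the 100 rank values
    print(name, extreme)   # about 0.1 if calibrated; much larger if over-confident
```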

The adjoint-calibration gate is the one that keeps amortized inference inside DA rather than outside it: the learned map's Jacobian $\partial \mathbf{x}^\star / \partial \mathbf{y}$ must match the sensitivity given by the physics-based tangent-linear and adjoint models of the true inverse problem. Matching the oracle MAP on the training set is not enough: without matching sensitivities, the head can extrapolate badly as soon as the observation distribution shifts.
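The adjoint-calibration check needs only the two Jacobians. A minimal sketch assumes the amortized map can be called as a black-box function of $\mathbf{y}$; its Jacobian is estimated by central finite differences here, whereas a JAX head would use `jax.jacfwd`. It also assumes the oracle Jacobian is known; for a linear-Gaussian oracle it is the gain $\mathbf{K}$. The matrix `K` and both candidate heads are hypothetical.

```python
import numpy as np

def fd_jacobian(f, y, h=1e-6):
    """Central finite-difference Jacobian of f at y (rows: outputs, columns: inputs)."""
    y = np.asarray(y, dtype=float)
    cols = []
    for j in range(y.size):
        e = np.zeros_like(y)
        e[j] = h
        cols.append((f(y + e) - f(y - e)) / (2.0 * h))
    return np.stack(cols, axis=1)

def adjoint_gate(jac_amort, jac_oracle, tol=0.05):
    """Relative operator-norm (spectral) mismatch; the gate passes if it is below tol."""
    rel = np.linalg.norm(jac_amort - jac_oracle, 2) / np.linalg.norm(jac_oracle, 2)
    return rel, rel < tol

# Hypothetical oracle gain (3 state variables, 2 observations) and evaluation point
K = np.array([[0.5, 0.1], [0.2, 0.4], [0.0, 0.3]])
y0 = np.array([0.3, -1.2])

good_head = lambda y: K @ (y + 0.02 * np.sin(y))   # Jacobian K diag(1 + 0.02 cos y)
bad_head = lambda y: 1.2 * (K @ y)                 # sensitivities 20% too large

for head in (good_head, bad_head):
    print(adjoint_gate(fd_jacobian(head, y0), K))
```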

The Oracle Pattern

This is what ties amortized inference back to the rest of the DA hierarchy. The right design uses both solver-based and amortized inference, in sequence:

  1. Run a solver-based method (incremental / strong 4DVar) to produce oracle posteriors on a representative sample.

  2. Train the amortized head against the oracle (and/or by simulation).

  3. Validate with the three gates.

  4. Deploy the head for real-time, high-throughput work.
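The four steps can be traced end to end on a hypothetical scalar problem ($x_b = 0$, $B = 1$, one observation with $H = 1$, $R = 0.5$). Gradient descent on $J$ stands in for the 4DVar oracle, and the "head" is a single learned gain $a$ with $x^\star_{\text{amort}} = a\,y$. Only the posterior-agreement gate is checked. Every name here is illustrative, not part of any library.

```python
import numpy as np

B, R = 1.0, 0.5                                   # hypothetical background / obs error variances
sigma_post = np.sqrt(1.0 / (1.0 / B + 1.0 / R))   # posterior standard deviation

def oracle_map(y, n_iter=200, lr=0.2):
    """Stand-in solver: gradient descent on J(x) = x^2 / (2B) + (y - x)^2 / (2R)."""
    x = np.zeros_like(y)
    for _ in range(n_iter):
        x = x - lr * (x / B - (y - x) / R)        # step along -dJ/dx
    return x

# Step 1: oracle analyses on a representative sample of observations
rng = np.random.default_rng(0)
y_train = rng.normal(0.0, np.sqrt(B + R), 1000)   # marginal of y is N(0, B + R)
x_oracle = oracle_map(y_train)

# Step 2: train the head x = a * y against the oracle (least squares)
a = np.sum(y_train * x_oracle) / np.sum(y_train**2)

# Step 3: posterior-agreement gate on held-out observations
y_test = rng.normal(0.0, np.sqrt(B + R), 200)
gap = np.abs(a * y_test - oracle_map(y_test)) / sigma_post
passed = bool(np.max(gap) <= 1.0)

# Step 4: deploy only if the gate passed
head = (lambda y: a * y) if passed else None
print(a, passed)   # a is close to the exact gain B / (B + R) = 2/3
```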

The physics-based solvers are the teachers; the amortized head is the student that runs in milliseconds once it has earned trust. Amortized inference does not replace variational DA — it compiles it.

When It Helps, and the Trade-offs

| Regime | Amortized helps? |
|---|---|
| Single retrospective analysis | No: a solver is fine |
| Real-time alerts | Yes: sub-second inference |
| Many independent events (catalogue reprocessing) | Yes: the training cost is amortised |
| Same forward model, varying observations | Yes: train once, infer $N$ times |
| Each event has a different forward model | No: would need retraining |
| Multimodal posterior | Yes: flow and score heads can represent it; a Laplace approximation cannot |
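One back-of-envelope way to read the throughput rows of the table uses three costs: $c_{\text{solve}}$ per event for the solver, a one-off training cost $C_{\text{train}}$ (which should include generating any oracle posteriors), and $c_{\text{net}}$ per event for the network. Amortization pays off once $N\,c_{\text{solve}} > C_{\text{train}} + N\,c_{\text{net}}$, that is for $N > C_{\text{train}}/(c_{\text{solve}} - c_{\text{net}})$. The cost figures in the usage lines are made up for illustration.

```python
import math

def break_even_events(train_cost, solve_cost, net_cost):
    """Smallest number of events N for which training once plus N forward passes
    is strictly cheaper than N independent solves (costs in any common unit)."""
    if net_cost >= solve_cost:
        return math.inf                      # amortization never pays off
    return math.floor(train_cost / (solve_cost - net_cost)) + 1

# Illustrative numbers only: 3600 units to train, 10 per solve, 0.01 per forward pass
n_star = break_even_events(3600.0, 10.0, 0.01)
print(n_star)   # 361
```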
References
  1. Cranmer, K., Brehmer, J., & Louppe, G. (2020). The Frontier of Simulation-Based Inference. Proceedings of the National Academy of Sciences, 117(48), 30055–30062. 10.1073/pnas.1912789117
  2. Papamakarios, G., Nalisnick, E., Rezende, D. J., Mohamed, S., & Lakshminarayanan, B. (2021). Normalizing Flows for Probabilistic Modeling and Inference. Journal of Machine Learning Research, 22(57), 1–64.
  3. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations (ICLR).
  4. Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian Inference Algorithms with Simulation-Based Calibration. arXiv Preprint arXiv:1804.06788.