Skip to content

Layer 2 — Model Examples

End-to-end workflows for each of the seven peer analysis classes.


OptimalInterpolation — closed-form fast path

import gaussx as gx
import lineax as lx
from vardax.models import OptimalInterpolation
from vardax.obs_operators import LinearObs

# Build a closed-form OI analysis. along_track_op, climatology_ssh, coords,
# altika_variances and batch are user-supplied placeholders not defined here.
model = OptimalInterpolation(
    obs_op=LinearObs(H_mat=along_track_op),     # must be linear
    prior_mean=climatology_ssh,                 # background mean
    prior_cov_op=gx.MaternLinearOperator(coords, length_scale=100.0, sigma=0.1),  # B
    obs_cov_op=lx.DiagonalLinearOperator(altika_variances),  # R (diagonal)
)

# Single forward pass — no iteration
x_star = model(batch)               # analysis state
posterior = model.posterior(batch)  # posterior for the same batch

Use when \(H\) is linear and the background and observation errors are Gaussian with covariances \(B\) and \(R\). This is the right default for SSH altimetry with along-track observations and a Matérn prior.


ThreeDVar — nonlinear, single time

import optimistix as optx
from vardax.models import ThreeDVar

# NOTE(review): AveragingKernel is not imported in this snippet — presumably it
# lives in vardax.obs_operators alongside LinearObs; confirm and add the import.
model = ThreeDVar(
    obs_op=AveragingKernel(A=A, x_a=xa, h=h),   # nonlinear via h ⊙ x
    prior_mean=x_b,
    prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.GaussNewton(rtol=1e-5, atol=1e-5),  # iterative solve (no closed form)
    minimiser_adjoint=optx.ImplicitAdjoint(),  # optimistix adjoint used when differentiating through the solve
)
x_star = model(batch)

Use when \(H\) is nonlinear (averaging kernel, RTM, custom retrieval) but there are no dynamics — snapshot inversion.


StrongFourDVar — multi-time, exact dynamics

import diffrax as dfx
import optimistix as optx
from vardax.models import StrongFourDVar

# Strong-constraint 4DVar: the dynamics are treated as exact (no model-error
# covariance, unlike WeakFourDVar).
# NOTE(review): MaskedIdentity is not imported in this snippet — presumably from
# vardax.obs_operators; confirm and add the import.
model = StrongFourDVar(
    forward=somax_model,                      # pipekit_cycle.ForwardModel
    obs_op=MaskedIdentity(),
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.NonlinearCG(rtol=1e-5, atol=1e-5),
    minimiser_adjoint=optx.ImplicitAdjoint(),
    forward_adjoint=dfx.RecursiveCheckpointAdjoint(),  # dynamics adjoint; see alternatives below
)
x_star = model(batch)

For long assimilation windows, switch the dynamics adjoint by passing a different forward_adjoint to StrongFourDVar:

# Pass one of these as StrongFourDVar(..., forward_adjoint=...) in place of
# dfx.RecursiveCheckpointAdjoint(); rebinding this name alone does not change
# a model that has already been constructed.
forward_adjoint = dfx.BacksolveAdjoint()    # continuous adjoint, O(1) memory
# NOTE(review): diffrax documents BacksolveAdjoint gradients as approximate
# (optimise-then-discretise) — confirm the accuracy is acceptable here.
# or
forward_adjoint = dfx.ForwardMode()         # forward sensitivity (small parameter dim)

WeakFourDVar — model-error-aware

from vardax.models import WeakFourDVar

# Weak-constraint 4DVar: jointly estimates the initial state and a model-error
# trajectory eta_t, whose prior covariance is Q_op.
model = WeakFourDVar(
    forward=somax_model,                      # M_t^free
    obs_op=obs_op,
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    model_err_cov_op=Q_op,                    # Q for the η_t prior
    # optimistix minimisers take both rtol and atol (neither has a default);
    # tolerances match the StrongFourDVar example.
    minimiser=optx.NonlinearCG(rtol=1e-5, atol=1e-5),
    forward_adjoint=dfx.RecursiveCheckpointAdjoint(),
)

# Returns both the analysis IC and the model-error trajectory
x_0_star, eta_star = model(batch)

Use when the model has systematic errors (e.g. climatological drift) and you need to estimate the time-varying model error \(\eta_t\) alongside the initial state.


IncrementalFourDVar — operational fast path

from vardax.models import IncrementalFourDVar
from vardax import IncrementalConfig

# Incremental 4DVar with a Matérn background covariance and diagonal R.
model = IncrementalFourDVar(
    forward=somax_model,
    obs_op=AveragingKernel(A=A, x_a=xa, h=h),
    prior_mean=x_b,
    prior_cov_op=gx.MaternLinearOperator(coords, length_scale=10.0, sigma=0.1),
    obs_cov_op=lx.DiagonalLinearOperator(obs_uncertainty),
    # NOTE(review): presumably n_outer = outer (relinearisation) iterations,
    # n_inner = inner iterations per outer loop, cvt = control-variable
    # transform on/off — confirm against IncrementalConfig.
    config=IncrementalConfig(n_outer=3, n_inner=20, cvt=True),
)

# Operational analysis
x_star = model(batch)

# Posterior reuses GN Hessian from last outer iteration
posterior = model.posterior(batch)

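Use for production/operational 4DVar with structured Matérn priors. It targets the same strong-constraint problem as StrongFourDVar, but with the operational pattern hard-wired: Gauss–Newton outer loops, conjugate-gradient inner loops and a control-variable transform (GN+CG+CVT).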


FourDVarNet — learned 4DVar

import equinox as eqx
import optax
from vardax.models import FourDVarNet
from vardax import SolverConfig
from vardax.adjoints import OneStepAdjoint
from vardax.training import train_step

# Learned 4DVar: a convolutional auto-encoder prior plus a learned
# gradient modulator (ConvLSTM), unrolled for n_steps solver iterations.
model = FourDVarNet(
    prior=ConvAEPrior(encoder=enc_2d, decoder=dec_2d),
    obs_op=MaskedIdentity(),
    grad_mod=ConvLSTMGradMod2D(hidden_dim=64),
    config=SolverConfig(n_steps=15, alpha=0.2),
    solver_adjoint=OneStepAdjoint(),
)

optimizer = optax.adam(1e-3)
# Initialise the optimiser state from the array leaves of the model only;
# non-array leaves are filtered out and are not optimised.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def step(model, opt_state, batch):
    """Apply one compiled training update.

    Thin jit-compiled wrapper around ``train_step`` that closes over the
    module-level ``optimizer``, so only the model, optimiser state and batch
    are passed in.

    Returns:
        A ``(model, opt_state, loss)`` triple: the updated model, the updated
        optimiser state and the loss reported by ``train_step``.
    """
    updated = train_step(model, batch, optimizer, opt_state)
    new_model, new_opt_state, batch_loss = updated
    return new_model, new_opt_state, batch_loss

# 100 passes over the data loader; model and opt_state are rebound after every
# step so the next step sees the updated values. loss is the per-step loss.
for epoch in range(100):
    for batch in dataloader:
        model, opt_state, loss = step(model, opt_state, batch)

After training, model.as_analysis_step() gives the operational interface for use in pipekit_cycle.DACycle.


AmortizedPosterior — direct head

from vardax.models import AmortizedPosterior
from vardax import AmortizedConfig

# `...` stands for constructor arguments elided from this example.
model = AmortizedPosterior(
    encoder=ConvObsEncoder(...),       # eqx.Module: (y, mask) → context
    head=ConditionalFlowHead(...),     # gauss_flows-based head
    # NOTE(review): n_samples presumably sets a default sample count;
    # model.sample below passes n=200 explicitly — confirm precedence.
    config=AmortizedConfig(head_type="flow", n_samples=64),
)

# Training data from simulation
def sample_train_pair(key):
    """Draw one simulated (observation, truth) training pair.

    A state ``x`` is drawn from the prior, pushed through the forward model
    and corrupted with observation noise. The result is packed into a
    ``Batch2D`` with the noisy observations as ``input`` and ``x`` as
    ``target``.

    Args:
        key: PRNG key (assumed to be a JAX key). It is split so that the
            prior draw and the noise draw use independent randomness.

    Returns:
        ``Batch2D(input=y, mask=quality_mask, target=x)``.
    """
    import jax

    # Reusing one key for both draws would make the observation noise
    # statistically dependent on the sampled state.
    prior_key, noise_key = jax.random.split(key)
    x = prior_distribution.sample(prior_key)
    y_clean = forward_model(x)
    y = y_clean + obs_noise.sample(noise_key)
    return Batch2D(input=y, mask=quality_mask, target=x)

# Fresh optimiser state for this model. The opt_state from the FourDVarNet
# example was initialised from a different model's parameters and does not
# match this model's parameter tree. (optimizer is the optax optimiser
# defined in the FourDVarNet example.)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

for batch in simulation_loader:
    model, opt_state, loss = train_step(model, batch, optimizer, opt_state)

# Inference: sub-second
# batch: the observations to analyse (here, the last training batch);
# key: a PRNG key supplied by the caller.
x_map = model(batch)
samples = model.sample(batch, key, n=200)

Validation gates per Decision D12:

from vardax.utils.validation import (
    assert_posterior_agreement, simulation_based_calibration,
)

# On each validation batch, build Laplace posterior approximations around the
# amortized analysis and the physics-based 4DVar analysis, then require them to
# agree within tolerance_sigma.
# NOTE(review): LaplaceCovariance is not imported here, and strong_4dvar is
# assumed to be a StrongFourDVar model built as in the example above — confirm.
for val_batch in val_loader:
    p_amort = LaplaceCovariance()(model(val_batch),
                                    model.as_analysis_step(), val_batch)
    p_phys = LaplaceCovariance()(strong_4dvar(val_batch),
                                   strong_4dvar.as_analysis_step(), val_batch)
    assert_posterior_agreement(p_amort, p_phys, tolerance_sigma=1.0)

# Simulation-based calibration of the amortized posterior over 200 runs.
simulation_based_calibration(model, prior_distribution, forward_model, n_runs=200)

Cycling any model through pipekit_cycle.DACycle

All seven satisfy AnalysisStep via .as_analysis_step() — the cycling code is identical:

import pipekit_cycle as pc

# Plug the chosen analysis step into the generic DA cycle. The other arguments
# (forward model, observation operator, data source, number of windows) do not
# depend on which analysis method is used.
da_cycle = pc.DACycle(
    forward_model=somax_model,
    obs_op=AveragingKernel(...),
    analysis_step=model.as_analysis_step(),   # any of the seven
    obs_source=satellite_loader,
    n_steps=n_assimilation_windows,
)

# Start from initial_state at t=0 with no completed cycles.
result, final_state = da_cycle(initial_state, pc.DAState(t=0.0, cycle_count=0))

The orchestration is decoupled from the inference algorithm: to swap OptimalInterpolation for IncrementalFourDVar, or for FourDVarNet, change only the analysis_step slot. Nothing else in the pipeline changes.


Linear-Gaussian agreement test (Decision D14 invariant)

from vardax.utils.validation import assert_methods_agree

# The five methods below must agree in the linear-Gaussian limit
# (WeakFourDVar and AmortizedPosterior are not part of this check).
batch = make_linear_gaussian_batch()

# Positional arguments follow the keyword order shown in the examples above.
# NOTE(review): static_forward is presumably identity dynamics, so the 4DVar
# variants reduce to the single-time problem — confirm.
oi = OptimalInterpolation(linear_obs_op, x_b, B_op, R_op)
# optimistix minimisers require both rtol and atol (neither has a default).
threedvar = ThreeDVar(linear_obs_op, x_b, B_op, R_op,
                      optx.GaussNewton(rtol=1e-7, atol=1e-7))
strong = StrongFourDVar(static_forward, linear_obs_op, x_b, B_op, R_op,
                         optx.NonlinearCG(rtol=1e-7, atol=1e-7))
incremental = IncrementalFourDVar(static_forward, linear_obs_op, x_b, B_op, R_op,
                                    IncrementalConfig(n_outer=5, n_inner=50))
fourdvarnet = FourDVarNet(IdentityPrior(), linear_obs_op,
                           IdentityGradMod(alpha=0.05),
                           SolverConfig(n_steps=200))

assert_methods_agree(
    {"oi": oi, "3dvar": threedvar, "strong": strong,
     "incremental": incremental, "fourdvarnet": fourdvarnet},
    batch, atol=1e-3,
)

This is the canonical correctness baseline: every method under test, classical and learned, must produce the same posterior mean (within atol) in the linear-Gaussian limit.