vardax: Vision
What vardax is
vardax is the data assimilation inference layer for JAX. It provides the seven classical and modern analysis methods that turn observations into state estimates:
- OptimalInterpolation: BLUE / OI, the closed-form linear-Gaussian analysis. The right tool when \(H\) is linear and the background and observation errors are Gaussian with covariances \(B\) and \(R\).
- ThreeDVar: 3D variational analysis. Nonlinear observation operator, snapshot in time.
- StrongFourDVar: strong-constraint 4DVar. Control variable is the initial state \(x_0\); dynamics \(M_t\) are treated as exact.
- WeakFourDVar: weak-constraint 4DVar. Augmented control vector \((x_0, \eta_1, \ldots, \eta_T)\) admits model error.
- IncrementalFourDVar: the operational fast path of strong-constraint 4DVar. Gauss-Newton outer iterations on the full nonlinear cost, CG inner iterations on the linearised quadratic subproblem, control-variable transform for preconditioning.
- FourDVarNet: the learned variant of 4DVar. Prior \(\varphi_\theta\) replaces the Gaussian \(B\); the inner solver becomes a learned gradient modulator \(\Phi_\phi\).
- AmortizedPosterior: the direct head \(q_\phi(x \mid y)\). Real-time inference via conditional flow, score-based diffusion, or regression.
All seven are siblings, with no parent–child relationships between them. They satisfy the
same pipekit_cycle.AnalysisStep protocol and compose with the same
observation operators, forward models, priors, and posterior adapters.
The DA hierarchy is horizontal; you pick the method that matches the
regime, not a family.
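To make the simplest of the seven concrete, here is a minimal sketch of the closed-form BLUE / OI update in plain JAX. It is not the vardax implementation: the function name `blue_analysis` and the dense-matrix arguments are assumptions made for illustration only.

```python
import jax.numpy as jnp

def blue_analysis(x_b, B, H, R, y):
    """Closed-form BLUE / OI update for a linear observation operator H."""
    innovation = y - H @ x_b               # observation minus background
    S = H @ B @ H.T + R                    # innovation covariance
    # Solve S z = innovation rather than forming an explicit inverse.
    z = jnp.linalg.solve(S, innovation)
    x_a = x_b + B @ H.T @ z                # analysis mean
    # Kalman gain K = B H^T S^{-1}, using the symmetry of B and S.
    K = jnp.linalg.solve(S, H @ B).T
    P_a = B - K @ H @ B                    # analysis covariance
    return x_a, P_a

# Two-variable state, only the first component is observed.
x_b = jnp.array([0.0, 0.0])
B = jnp.eye(2)
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[1.0]])
y = jnp.array([2.0])

x_a, P_a = blue_analysis(x_b, B, H, R, y)
```

With equal background and observation variance, the observed component moves halfway to the observation and its variance halves; the unobserved component is untouched because \(B\) carries no cross-correlation in this toy setup.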
The single equation
Every analysis method in vardax is a special case of a single unified cost. The table shows how each method specialises it:
| Method | \(T\) | Prior | Model error \(\boldsymbol{\eta}\) | Inner solver |
|---|---|---|---|---|
| OptimalInterpolation | 0 | Gaussian \(B\), linear \(H\) | n/a | closed-form |
| ThreeDVar | 0 | Gaussian \(B\), nonlinear \(H\) | n/a | iterative (optimistix) |
| StrongFourDVar | \(>0\) | Gaussian \(B\) | absent | iterative (optimistix) |
| WeakFourDVar | \(>0\) | Gaussian \(B\), \(Q\) | active | iterative (optimistix) |
| IncrementalFourDVar | \(>0\) | Gaussian \(B\) | absent | GN outer + CG inner |
| FourDVarNet | \(>0\) | learned \(\varphi_\theta\) | absent | learned \(\Phi_\phi\) |
| AmortizedPosterior | any | implicit in \(q_\phi\) | implicit | none (single forward pass) |
Each method specialises the unified cost differently, but the surface
that user code touches — model.as_analysis_step() — is identical.
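For orientation, a unified cost of this kind can be written in the standard weak-constraint form. This is a textbook sketch, not necessarily the exact expression vardax uses; the symbols \(x_b\) and \(y_t\), and the choice of time-independent \(H\), \(R\), and \(Q\), are assumptions.

```latex
J(x_0, \eta_1, \ldots, \eta_T)
  = \tfrac{1}{2} (x_0 - x_b)^\top B^{-1} (x_0 - x_b)
  + \tfrac{1}{2} \sum_{t=0}^{T} \bigl(y_t - H(x_t)\bigr)^\top R^{-1} \bigl(y_t - H(x_t)\bigr)
  + \tfrac{1}{2} \sum_{t=1}^{T} \eta_t^\top Q^{-1} \eta_t,
\qquad x_t = M_t(x_{t-1}) + \eta_t .
```

Under these assumptions, setting \(T = 0\) drops the dynamics and the model-error sum (3DVar, or OI when \(H\) is linear), fixing \(\eta_t = 0\) gives the strong-constraint problem, and replacing the first term by a learned \(\varphi_\theta\) gives the 4DVarNet-style cost.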
Composable gradients
vardax does not re-implement automatic differentiation. Gradients flow through two upstream libraries:
- diffrax: gradients through ODE integration via RecursiveCheckpointAdjoint, BacksolveAdjoint (continuous adjoint), ForwardMode (forward sensitivity), or DirectAdjoint. This is how 4DVar's adjoint of \(M_t\) is computed.
- optimistix: gradients through the inner minimisation via RecursiveCheckpointAdjoint or ImplicitAdjoint (IFT-based). This is how training gradients propagate through FourDVarNet's solver.
Both are exposed as constructor slots on the Layer 2 classes:
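import diffrax
import optimistix

# Both adjoint strategies are plain constructor arguments.
model = StrongFourDVar(
    forward=somax_model,
    obs_op=AveragingKernel(...),
    prior_mean=x_b,
    prior_cov_op=B_op,
    obs_cov_op=R_op,
    minimiser=optimistix.GaussNewton(rtol=1e-5, atol=1e-5),
    minimiser_adjoint=optimistix.ImplicitAdjoint(),  # gradients through the inner solve
    forward_adjoint=diffrax.BacksolveAdjoint(),      # gradients through the ODE integration
)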
A user who wants memory-efficient operational 4DVar reaches for
diffrax.BacksolveAdjoint(). A user training FourDVarNet reaches for
optimistix.RecursiveCheckpointAdjoint() or a custom one-step adjoint.
vardax owns the DA algorithm, not the differentiation strategy.
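The idea of delegating the adjoint of \(M_t\) to automatic differentiation can be seen in plain JAX, without diffrax. The sketch below assumes toy rotation dynamics, an identity observation operator, and \(B = R = I\); `step` and `cost` are illustrative names, not vardax API.

```python
import jax
import jax.numpy as jnp

def step(x, _):
    """Toy linear dynamics: a small rotation standing in for a real forward model."""
    theta = 0.1
    M = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta),  jnp.cos(theta)]])
    x_next = M @ x
    return x_next, x_next  # (carry, stacked output)

def cost(x0, x_b, ys):
    """Strong-constraint 4DVar cost with B = R = I and H = identity (assumed)."""
    _, traj = jax.lax.scan(step, x0, None, length=ys.shape[0])
    background = 0.5 * jnp.sum((x0 - x_b) ** 2)
    obs = 0.5 * jnp.sum((traj - ys) ** 2)
    return background + obs

# Synthetic observations generated from a known initial state.
x_true = jnp.array([1.0, 0.0])
_, ys = jax.lax.scan(step, x_true, None, length=5)
x_b = jnp.array([0.8, 0.2])

# Reverse-mode AD through the scan plays the role of the adjoint model.
grad_at_truth = jax.grad(cost)(x_true, x_b, ys)
```

At the true initial state the observation residuals vanish, so the gradient reduces to the background term `x_true - x_b`. That makes this a convenient sanity check before swapping in a real forward model and a dedicated adjoint strategy.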
Three-layer architecture
Layer 2  Models (each satisfies pipekit_cycle.AnalysisStep)
         OptimalInterpolation, ThreeDVar, StrongFourDVar, WeakFourDVar,
         IncrementalFourDVar, FourDVarNet, AmortizedPosterior
            ↑
Layer 1  Components (eqx.Module operators)
         priors, observation operators, gradient modulators (4DVarNet only),
         cost functions, posterior adapters, minimiser adapters
            ↑
Layer 0  Primitives (pure JAX)
         cost terms, control-variable transform, Laplace covariance,
         adjoint wiring (diffrax + optimistix passthrough)
Users enter at the level appropriate to their task:
- Layer 0: integrating vardax algorithms into another framework
- Layer 1: building custom DA pipelines (for example, a new prior combined with an averaging-kernel observation operator)
- Layer 2: operational analysis or training
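As an illustration of what a Layer 0 primitive might look like, here is a hedged sketch of the control-variable transform as pure JAX functions. It assumes a dense \(B\) with Cholesky factor \(L\), so that \(x = x_b + L v\) turns the background term into \(\tfrac{1}{2} v^\top v\). The function names are hypothetical and do not reflect vardax's actual module layout.

```python
import jax.numpy as jnp

def to_state(v, x_b, L):
    """Map control variable v to state x = x_b + L v, where B = L L^T."""
    return x_b + L @ v

def background_cost_state(x, x_b, B):
    """Background term in state space: 0.5 (x - x_b)^T B^{-1} (x - x_b)."""
    d = x - x_b
    return 0.5 * d @ jnp.linalg.solve(B, d)

def background_cost_control(v):
    """The same term in control space: 0.5 v^T v."""
    return 0.5 * v @ v

B = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])
L = jnp.linalg.cholesky(B)   # lower-triangular factor of B
x_b = jnp.array([1.0, -1.0])
v = jnp.array([0.3, -0.7])

x = to_state(v, x_b, L)
j_state = background_cost_state(x, x_b, B)
j_control = background_cost_control(v)
```

The two costs agree, but in control space the background Hessian is the identity, which is why the transform acts as a preconditioner for the inner CG iterations.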
What vardax is NOT
(See boundaries.md for the full ownership map.)
- It does not define forward models. Use somax (geophysics) or plumax (atmospheric transport / methane). The Lorenz-63 / Lorenz-96 demos in vardax._src.utils.dynamical_systems are toy examples.
- It does not own ensemble methods. Use filterax. vardax exposes hooks (EnsembleCovariance posterior adapter, ensemble batch dimension), but EnKF / EnKS / EnKI propagation lives elsewhere.
- It does not own structured linear algebra. Use gaussx for Matérn factorisations and Kronecker / LowRank / BlockDiag operators.
- It does not own data I/O. Use georeader (sensors) and coordax (labelled arrays). vardax consumes Batch* containers; how they're populated is upstream.
- It does not own optimisers or ODE solvers. Use optimistix and diffrax. vardax composes them via the adjoint slots described above.
- It does not own experiment orchestration. Use pipekit-cycle for DA cycles and pipekit-experiment for run tracking. vardax satisfies the protocols.
The six-step research-to-operations cycle
vardax is engineered around the cycle that turns a new forward model into operational analysis:
(1) Physics forward (somax / plumax)
→ (2) Classical inference: OI / 3DVar / 4DVar — slow, exact
→ (3) Neural emulator of the forward (trained from Step 1) — fast surrogate
→ (4) Emulator-based inference: same vardax loop — 100–1000× faster
→ (5) Amortized predictor: y → posterior directly — sub-second
→ (6) Improve: swap any block; prior step is the oracle
Crucially: Step 2 uses the same vardax code as Step 4. The forward
model is swapped via the pipekit_cycle.ForwardModel protocol; the
analysis class doesn't know whether \(M_t\) is physics or an emulator. This
is what makes the cycle a cycle and not a rewrite.
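The swap between Step 2 and Step 4 can be sketched with an analysis routine that only ever sees a callable. This is a toy stand-in for the pipekit_cycle.ForwardModel protocol, whose real signature is not shown here; `physics_forward`, `emulator_forward`, and `analysis_gradient` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def physics_forward(x):
    """Stand-in 'physics' model: one explicit Euler step of dx/dt = -x, dt = 0.1."""
    return x - 0.1 * x

def emulator_forward(x):
    """Stand-in 'emulator': a fitted linear map approximating the physics step."""
    return 0.9 * x

def analysis_gradient(forward, x0, x_b, y):
    """Gradient of a one-step 4DVar-style cost (B = R = I, H = identity assumed).

    `forward` is any differentiable callable; the routine does not know
    whether it is physics or an emulator.
    """
    def cost(x):
        background = 0.5 * jnp.sum((x - x_b) ** 2)
        obs = 0.5 * jnp.sum((forward(x) - y) ** 2)
        return background + obs
    return jax.grad(cost)(x0)

x0 = jnp.array([1.0, 2.0])
x_b = jnp.array([0.5, 1.5])
y = jnp.array([0.9, 1.7])

g_phys = analysis_gradient(physics_forward, x0, x_b, y)
g_emul = analysis_gradient(emulator_forward, x0, x_b, y)
```

Because the emulator here reproduces the physics step exactly, the two gradients coincide. With a real emulator, the size of the discrepancy is the kind of quantity a validation gate would monitor.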
The validation gates between steps (adjoint calibration, posterior agreement, simulation-based calibration) are part of the test suite, not just documentation. See chapter 14 in the math reference.
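One common form of adjoint check is the dot-product test, which verifies that \(\langle J v, w \rangle = \langle v, J^\top w \rangle\) for the tangent-linear and adjoint of a forward model. The sketch below uses `jax.jvp` and `jax.vjp` on a toy nonlinear map; whether vardax's own gate is implemented this way is an assumption.

```python
import jax
import jax.numpy as jnp

def forward(x):
    """Toy nonlinear forward model: one step of a quadratic map (illustrative only)."""
    return x + 0.1 * jnp.array([x[1] * x[2], -x[0] * x[2], x[0] * x[1]])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 0.25])   # input-space (tangent) perturbation
w = jnp.array([-0.3, 0.7, 1.1])    # output-space (cotangent) vector

# Tangent-linear model applied to v: J v
_, Jv = jax.jvp(forward, (x,), (v,))

# Adjoint model applied to w: J^T w
_, vjp_fn = jax.vjp(forward, x)
(JTw,) = vjp_fn(w)

lhs = jnp.dot(Jv, w)    # <J v, w>
rhs = jnp.dot(v, JTw)   # <v, J^T w>
```

The two inner products should agree to floating-point precision. A mismatch indicates that the tangent-linear and adjoint code paths are inconsistent.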
Why a JAX-native DA library
The DA community has excellent Fortran/C++ implementations of incremental
4DVar (IFS, GSI, ICON, WRFDA), excellent Python ensemble libraries
(DAPPER, PyOSSE), and excellent research code for 4DVarNet
(4dvarnet-starter). What's missing is a single library where:
- All classical methods are first-class. BLUE / OI is not a footnote; it's the first thing you should reach for when the regime is linear-Gaussian. Today this means dropping into NumPy / SciPy or rolling your own.
- The same code transitions research → operations. A ThreeDVar prototyped in a notebook becomes the analysis step in a pipekit-cycle pipeline without changes to the inversion logic.
- Gradients are uniform. diffrax adjoints handle the dynamics, optimistix adjoints handle the inner solver, and vardax just composes them. No per-method bespoke differentiation.
- Learned methods coexist with classical ones. FourDVarNet and AmortizedPosterior are siblings of OptimalInterpolation, not replacements. They use the same priors, observation operators, and posterior adapters.
This is the gap vardax fills.
Framework stack
| Package | Role | Required? |
|---|---|---|
| equinox | Module system | yes |
| optax | Outer-loop training optimiser | yes |
| optimistix | Inner-loop minimisers + adjoints | yes |
| diffrax | ODE integration + adjoints | yes |
| lineax | Linear solvers, AbstractLinearOperator | yes |
| gaussx | Structured operators (Matérn, Kronecker, LowRank) | yes |
| pipekit + pipekit-cycle | Operator base + protocol contracts | yes |
| jaxtyping | Shape annotations | yes |
| pipekit-jax | JaxModelOp for weight persistence | optional [persist] |
| pipekit-experiment | ModelRegistry, run tracking | optional [persist] |
| pipekit-train | Loss / Callback / MetricWriter protocols | optional [train] |
| filterax | Ensemble methods (EnKF / EnKS / EnKI) | optional [ensemble] |
| coordax | Coordinate-aware fields | optional [coords] |
| numpyro | Full Bayesian fallback | optional [mcmc] |
flax and jaxopt are no longer dependencies after the migration to equinox.