vardax
JAX-native data assimilation inference.
Formerly fourdvarjax — renamed to vardax.
Status — v0.4.0 design reference (forward-looking)
These docs describe the target API of vardax after the equinox
migration roadmap (Epics 0–13; see Boundaries). The shipped package
implements only 4DVarNet (a single learned method built on Flax
NNX); the seven-method DA hierarchy and conformance to the
pipekit-cycle protocols are the design target. References to
vardax.models.*, vardax.obs_operators.*, vardax.adjoints.*,
the pipekit_cycle protocols, tests/test_pipekit_protocols.py,
and vardax._src.utils.validation describe that design target;
they do not yet run against the current package. Code
snippets are design pseudocode showing intended call sites.
(The package was previously published as fourdvarjax v0.1.x;
vardax is now the canonical name.)
vardax provides the seven classical and modern DA analysis methods as
peer pipekit_cycle.AnalysisStep
implementations:
| Class | Method | Use when |
|---|---|---|
| OptimalInterpolation | BLUE / OI, closed-form linear-Gaussian update | Linear \(H\), Gaussian \(B\) / \(R\). The recommended default. |
| ThreeDVar | 3D variational, nonlinear \(H\) | Single-time (snapshot) inversion |
| StrongFourDVar | Strong-constraint 4DVar, control = \(x_0\) | Multiple observation times, model assumed exact |
| WeakFourDVar | Weak-constraint 4DVar, control = \((x_0, \boldsymbol{\eta})\) | Multiple observation times, model error not negligible |
| IncrementalFourDVar | Gauss–Newton outer loop + conjugate-gradient (CG) inner loop + control-variable transform (CVT) | Operational fast path |
| FourDVarNet | Learned prior \(\varphi_\theta\) + learned solver \(\Phi_\phi\) | Learned variant of 4DVar |
| AmortizedPosterior | Direct \(q_\phi(x \mid y)\) head | Real-time / many-event regimes |
Gradients through the dynamics and through the inner minimiser are
obtained from diffrax.AbstractAdjoint and optimistix.AbstractAdjoint
subclasses respectively, so vardax contains no in-house adjoint code.
The Bolte 2023 one-step method appears as
vardax.adjoints.OneStepAdjoint, an optimistix.AbstractAdjoint
subclass intended for contribution upstream to optimistix.
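To illustrate what an adjoint computes, here is a hand-written adjoint (backward) sweep for the strong-constraint cost with a toy linear model \(x_t = M x_{t-1}\) and linear \(H\), checked against central finite differences. This is a sketch under those assumptions: NumPy stands in for JAX, the function names are hypothetical, and vardax itself is designed to obtain such gradients from the diffrax and optimistix adjoints rather than from hand-written code like this.

```python
import numpy as np


def strong_4dvar_cost(x0, x_b, ys, M, H, B_inv, R_inv):
    """J(x0) = ||x0 - x_b||^2_{B^-1} + sum_t ||y_t - H M^t x0||^2_{R^-1}."""
    r = x0 - x_b
    cost = r @ B_inv @ r
    x = x0
    for t, y in enumerate(ys):
        if t > 0:
            x = M @ x
        d = y - H @ x
        cost += d @ R_inv @ d
    return float(cost)


def strong_4dvar_grad_adjoint(x0, x_b, ys, M, H, B_inv, R_inv):
    """Gradient of strong_4dvar_cost from one forward sweep plus one backward (adjoint) sweep."""
    # Forward sweep: store the trajectory x_0, ..., x_T.
    xs = [x0]
    for _ in range(len(ys) - 1):
        xs.append(M @ xs[-1])
    # Backward sweep: carry the adjoint variable lam from t = T down to t = 0,
    # adding the observation forcing at each time and stepping back with M^T.
    lam = np.zeros_like(x0)
    for t in reversed(range(len(ys))):
        d = ys[t] - H @ xs[t]
        lam = lam - 2.0 * H.T @ (R_inv @ d)   # observation forcing at time t
        if t > 0:
            lam = M.T @ lam                   # adjoint model: step back to t - 1
    return 2.0 * B_inv @ (x0 - x_b) + lam


rng = np.random.default_rng(1)
n, p, T = 4, 2, 5
M = np.eye(n) + 0.1 * rng.standard_normal((n, n))
H = rng.standard_normal((p, n))
B_inv, R_inv = np.eye(n), 10.0 * np.eye(p)
x_b, x0 = rng.standard_normal(n), rng.standard_normal(n)
ys = rng.standard_normal((T + 1, p))

g_adj = strong_4dvar_grad_adjoint(x0, x_b, ys, M, H, B_inv, R_inv)

# Central finite differences, one coordinate direction at a time.
eps = 1e-6
args = (x_b, ys, M, H, B_inv, R_inv)
g_fd = np.array([
    (strong_4dvar_cost(x0 + eps * e, *args) - strong_4dvar_cost(x0 - eps * e, *args)) / (2 * eps)
    for e in np.eye(n)
])
```

The cost of the backward sweep is one adjoint-model application per time step, independent of the state dimension, which is why adjoints (rather than finite differences) are used for high-dimensional controls.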
The single equation
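Every analysis method in vardax is a special case of a single
variational objective over the window \(t = 0, \dots, T\), written here in
the standard weak-constraint form (\(\mathcal{M}\) is the forward model and
\(Q\) the model-error covariance):

\[
\mathcal{J}(x_0, \boldsymbol{\eta}) = \|x_0 - x_b\|^2_{B^{-1}}
+ \sum_{t=0}^{T} \|y_t - H(x_t)\|^2_{R^{-1}}
+ \sum_{t=1}^{T} \|\eta_t\|^2_{Q^{-1}},
\qquad x_t = \mathcal{M}(x_{t-1}) + \eta_t,
\]

and the analysis is \(x^\star = \arg\min \mathcal{J}\).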
Different methods specialise differently:
- \(T = 0\) + linear \(H\) + Gaussian \(B/R\) → OptimalInterpolation (closed form)
- \(T = 0\) + nonlinear \(H\) → ThreeDVar
- \(T > 0\), model-error term absent → StrongFourDVar / IncrementalFourDVar
- \(T > 0\), model-error term active → WeakFourDVar
- Learned \(\varphi_\theta\) replacing \(\|x - x_b\|^2_{B^{-1}}\) + learned inner solver → FourDVarNet
- Direct posterior head \(q_\phi(x \mid y)\) → AmortizedPosterior
See the Problem Setting chapter for the full derivation.
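As a concrete check on the specialisations, the sketch below evaluates a weak-constraint cost (a prior term, the observation misfit at each time, and a model-error penalty) for a toy linear model \(x_t = M x_{t-1} + \eta_t\) with linear \(H\). The function names, matrices, and the model-error covariance \(Q\) are illustrative assumptions, not vardax API. Setting \(\eta = 0\) gives the strong-constraint cost, and a single observation time (\(T = 0\)) gives the 3DVar / OI cost.

```python
import numpy as np


def mahalanobis_sq(r, cov):
    """Squared weighted norm ||r||^2_{cov^{-1}} = r^T cov^{-1} r, via a linear solve."""
    return float(r @ np.linalg.solve(cov, r))


def da_cost(x0, eta, x_b, ys, M, H, B, R, Q):
    """Toy weak-constraint cost for linear dynamics x_t = M x_{t-1} + eta_t and linear H.

    ys: observations, shape (T + 1, n_obs); eta: model errors, shape (T, n_state).
    """
    cost = mahalanobis_sq(x0 - x_b, B)               # background (prior) term
    x = x0
    for t, y in enumerate(ys):
        if t > 0:
            x = M @ x + eta[t - 1]                   # step the model, add model error
            cost += mahalanobis_sq(eta[t - 1], Q)    # model-error term
        cost += mahalanobis_sq(y - H @ x, R)         # observation misfit at time t
    return cost


rng = np.random.default_rng(0)
n, p, T = 3, 2, 4
M = 0.9 * np.eye(n)
H = rng.standard_normal((p, n))
B, R, Q = np.eye(n), 0.1 * np.eye(p), 0.01 * np.eye(n)
x_b = np.zeros(n)
ys = rng.standard_normal((T + 1, p))
x0 = rng.standard_normal(n)

# Strong constraint: eta = 0, so every model-error term vanishes.
j_strong = da_cost(x0, np.zeros((T, n)), x_b, ys, M, H, B, R, Q)
# Weak constraint: nonzero eta adds the Q-weighted penalty and shifts the trajectory.
j_weak = da_cost(x0, 0.1 * np.ones((T, n)), x_b, ys, M, H, B, R, Q)
# T = 0: one observation time and no dynamics, i.e. the 3DVar / OI cost.
j_3dvar = da_cost(x0, np.zeros((0, n)), x_b, ys[:1], M, H, B, R, Q)
```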
Installation
vardax is not yet on PyPI; install it from a local clone of the repository.
Quickstart — Optimal Interpolation
```python
import gaussx as gx
import lineax as lx

from vardax.models import OptimalInterpolation
from vardax.obs_operators import LinearObs

# Design pseudocode: along_track_op, climatology_ssh, coords, altika_variances
# and batch are placeholders for user data.
model = OptimalInterpolation(
    obs_op=LinearObs(H_mat=along_track_op),
    prior_mean=climatology_ssh,
    prior_cov_op=gx.MaternLinearOperator(coords, length_scale=100.0, sigma=0.1),
    obs_cov_op=lx.DiagonalLinearOperator(altika_variances),
)

# Closed form: a single forward pass, with no iteration or convergence criterion.
x_star = model(batch)
posterior = model.posterior(batch)
```
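The closed form behind the OI quickstart is the standard BLUE update: gain \(K = BH^\top(HBH^\top + R)^{-1}\), analysis \(x_a = x_b + K(y - Hx_b)\), and posterior covariance \((I - KH)B\). A dense NumPy sketch follows; `oi_analysis` is a hypothetical name, and it uses explicit matrices where vardax is designed to use structured gaussx/lineax operators. The demo checks that \(x_a\) zeroes the gradient of the \(T = 0\) cost, which is why OI needs no iteration.

```python
import numpy as np


def oi_analysis(x_b, y, H, B, R):
    """BLUE / optimal-interpolation update with dense H, B, R."""
    S = H @ B @ H.T + R                       # innovation covariance
    K = np.linalg.solve(S, H @ B).T           # K = B H^T S^{-1} (B and S are symmetric)
    x_a = x_b + K @ (y - H @ x_b)             # analysis mean
    A = (np.eye(len(x_b)) - K @ H) @ B        # analysis (posterior) covariance
    return x_a, A


rng = np.random.default_rng(0)
n, p = 5, 3
C = rng.standard_normal((n, n))
B = C @ C.T + n * np.eye(n)                   # symmetric positive-definite prior covariance
R = 0.5 * np.eye(p)
H = rng.standard_normal((p, n))
x_b, y = rng.standard_normal(n), rng.standard_normal(p)
x_a, A = oi_analysis(x_b, y, H, B, R)

# Gradient (up to a factor 2) of ||x - x_b||^2_{B^-1} + ||y - H x||^2_{R^-1} at x_a.
grad = np.linalg.solve(B, x_a - x_b) - H.T @ np.linalg.solve(R, y - H @ x_a)
```

Solving against \(S\) (size of the observation space) rather than inverting \(B\) is the usual reason OI is cheap when observations are few relative to the state.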
Quickstart — Incremental 4DVar with control-variable transform¶
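```python
import diffrax as dfx
import gaussx as gx
import lineax as lx

from vardax import IncrementalConfig
from vardax.models import IncrementalFourDVar

# Design pseudocode: somax_model, AveragingKernel, A, xa, h, x_b, coords,
# obs_uncertainty and batch are placeholders defined elsewhere.
model = IncrementalFourDVar(
    forward=somax_model,
    obs_op=AveragingKernel(A=A, x_a=xa, h=h),
    prior_mean=x_b,
    prior_cov_op=gx.MaternLinearOperator(coords, length_scale=10.0, sigma=0.1),
    obs_cov_op=lx.DiagonalLinearOperator(obs_uncertainty),
    # Gauss-Newton outer loops, CG inner iterations, control-variable transform on
    config=IncrementalConfig(n_outer=3, n_inner=20, cvt=True),
    forward_adjoint=dfx.BacksolveAdjoint(),  # constant memory through dynamics
)
x_star = model(batch)
posterior = model.posterior(batch)
```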
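IncrementalConfig(n_outer=3, n_inner=20, cvt=True) names three ingredients: a Gauss–Newton outer loop, a conjugate-gradient inner loop, and a control-variable transform. The sketch below shows them in dense NumPy on a toy problem; everything in it is an assumption for illustration (a linear model \(M\), an elementwise observation operator \(h(x) = x + 0.1x^3\) observing the full state, noise-free synthetic observations, hypothetical function names), not the vardax implementation. With the transform \(x_0 = x_b + Lv\), where \(B = LL^\top\), the background term becomes \(\|v\|^2\), so each inner problem has Hessian \(I + \sum_t G_t^\top R^{-1} G_t\) and is well suited to CG.

```python
import numpy as np


def h(x):
    """Toy nonlinear observation operator, applied elementwise (assumption)."""
    return x + 0.1 * x**3


def h_jac_diag(x):
    """Diagonal of the Jacobian of h."""
    return 1.0 + 0.3 * x**2


def cg(apply_A, b, n_iter, rtol=1e-10):
    """Conjugate gradient for a symmetric positive-definite A, given only w -> A @ w."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rs = r @ r
    b_norm = np.sqrt(b @ b)
    for _ in range(n_iter):
        if np.sqrt(rs) <= rtol * b_norm:
            break
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def incremental_4dvar(x_b, ys, M, L, R_inv, n_outer=3, n_inner=20):
    """Gauss-Newton outer loop, CG inner loop, and CVT x0 = x_b + L v with B = L L^T."""
    Ms = [np.linalg.matrix_power(M, t) for t in range(len(ys))]  # M^t, dense toy model
    v = np.zeros_like(x_b)
    for _ in range(n_outer):
        x0 = x_b + L @ v
        # Outer loop: linearise h(M^t x0) around the current trajectory, per time t.
        Gs = [h_jac_diag(Mt @ x0)[:, None] * Mt @ L for Mt in Ms]  # Jacobians w.r.t. v
        ds = [y - h(Mt @ x0) for y, Mt in zip(ys, Ms)]            # innovations
        # Inner loop: minimise ||v||^2 + sum_t ||d_t - G_t (v - v_k)||^2_{R^-1} by CG
        # on its normal equations (I + sum G^T R^-1 G) v = sum G^T R^-1 (d_t + G_t v_k).
        def hess_vp(w, Gs=Gs):
            return w + sum(G.T @ (R_inv @ (G @ w)) for G in Gs)
        rhs = sum(G.T @ (R_inv @ (d + G @ v)) for G, d in zip(Gs, ds))
        v = cg(hess_vp, rhs, n_inner)
    return x_b + L @ v


rng = np.random.default_rng(0)
n, T = 4, 3
M = 0.95 * np.eye(n) + 0.05 * np.roll(np.eye(n), 1, axis=1)
L = np.linalg.cholesky(0.5 * np.eye(n) + 0.1 * np.ones((n, n)))  # B = L L^T
R_inv = np.eye(n) / 0.05
x_true = rng.standard_normal(n)
ys = np.array([h(np.linalg.matrix_power(M, t) @ x_true) for t in range(T + 1)])
x_b = x_true + 0.5 * rng.standard_normal(n)
x_a = incremental_4dvar(x_b, ys, M, L, R_inv)
```

The design point of the CVT is that the prior block of the inner Hessian becomes the identity, so CG needs no explicit \(B^{-1}\) and converges in few iterations when observations are informative.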
Cycling any model through pipekit_cycle.DACycle
All seven methods satisfy pipekit_cycle.AnalysisStep via
.as_analysis_step() — the orchestration code is identical:
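```python
import pipekit_cycle as pc

# Design pseudocode: somax_model, AveragingKernel, satellite_loader, initial_state,
# n_assimilation_windows and model are placeholders defined elsewhere.
da_cycle = pc.DACycle(
    forward_model=somax_model,
    obs_op=AveragingKernel(...),
    analysis_step=model.as_analysis_step(),  # any of the seven
    obs_source=satellite_loader,
    n_steps=n_assimilation_windows,
)
result, final_state = da_cycle(initial_state, pc.DAState(t=0.0, cycle_count=0))
```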
To switch between OptimalInterpolation, IncrementalFourDVar, and
FourDVarNet, change only the analysis_step argument; nothing else in
the pipeline changes.
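The "swap one slot" property amounts to every analysis step sharing one call signature, (forecast state, observation) → analysis state, so the forecast/analysis loop never needs to know which method it runs. A minimal sketch of that contract follows; `run_cycle`, `AnalysisFn`, and both toy analysis steps are hypothetical and do not reproduce pipekit_cycle's actual signatures.

```python
from collections.abc import Callable, Sequence

import numpy as np

# Hypothetical contract: (forecast state, observation) -> analysis state.
AnalysisFn = Callable[[np.ndarray, np.ndarray], np.ndarray]


def run_cycle(x0, forecast, analysis: AnalysisFn, observations: Sequence[np.ndarray]):
    """Alternate forecast and analysis; the loop never inspects which analysis it runs."""
    x = x0
    analyses = []
    for y in observations:
        x = forecast(x)      # propagate the previous analysis to the next window
        x = analysis(x, y)   # assimilate this window's observations
        analyses.append(x)
    return np.stack(analyses)


def make_oi_step(H, B, R) -> AnalysisFn:
    """OI analysis step with a static background covariance B."""
    K = np.linalg.solve(H @ B @ H.T + R, H @ B).T

    def step(x_f, y):
        return x_f + K @ (y - H @ x_f)

    return step


def free_run(x_f, y):
    """Baseline 'analysis' that ignores the observations."""
    return x_f


M = np.array([[0.9, 0.1, 0.0], [0.0, 0.9, 0.1], [0.1, 0.0, 0.9]])


def forecast(x):
    return M @ x


# Synthetic truth and noise-free full-state observations over 10 windows.
truth, x = [], np.array([1.0, 2.0, 3.0])
for _ in range(10):
    x = M @ x
    truth.append(x)

x0 = np.zeros(3)
oi_run = run_cycle(x0, forecast, make_oi_step(np.eye(3), np.eye(3), 0.1 * np.eye(3)), truth)
free = run_cycle(x0, forecast, free_run, truth)   # same loop, different analysis slot
err_oi = np.linalg.norm(oi_run[-1] - truth[-1])
err_free = np.linalg.norm(free[-1] - truth[-1])
```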
Documentation
This site has two main sections:
- Mathematical Reference — 17 chapters covering the Bayesian foundation (1–3), each of the seven analysis methods (4–10), cross-cutting concerns (11–14), and end-to-end examples on Lorenz / SSH / methane (15–17).
- Design Docs — architecture, API contracts, ecosystem boundaries, and the decision log (D1–D16). The "why" behind the "what".
Ecosystem
vardax does not own forward models, optimisers, ODE solvers, ensemble methods, or structured linear algebra. It composes them:
| Concern | Owner |
|---|---|
| Geophysical forward models | somax |
| Atmospheric transport / methane | plumax |
| Optimisers + adjoints | optimistix |
| ODE integration + adjoints | diffrax |
| Linear solvers | lineax |
| Structured operators | gaussx |
| Ensemble methods | filterax |
| Operator composition + DA cycle protocols | pipekit + pipekit-cycle |
See Boundaries for the full ownership map.
License
MIT.