vardax — Design Decisions
Each decision is referenced by ID throughout the design docs and math reference. New in v0.5.0: D17 (latent DA peer family). New in v0.4.0: D14, D15, D16. Revised: D8 (seven peers), D11 (rename), D6 (clarified upstream adjoint contribution).
Index
| ID | Title | Layer |
|---|---|---|
| D1 | Equinox over Flax NNX | Foundation |
| D2 | Protocol-driven extensibility | Layer 1 |
| D3 | Dimensional inheritance over duplication | Layer 2 |
| D4 | Nested module configuration | Layer 1 |
| D5 | Training step as library code, training loop as example | Layer 2 |
| D6 | optimistix for novel solvers — contribute upstream | Layer 0 |
| D7 | Demo-quality dynamical priors (forwards live elsewhere) | Boundary |
| D8 | Direct pipekit-cycle protocol satisfaction (seven peer AnalysisStep classes) | Layer 1/2 |
| D9 | Averaging kernel + multi-instrument as first-class | Layer 1 |
| D10 | Posterior export adapter pattern | Layer 1 |
| D11 | IncrementalFourDVar as the operational fast path of StrongFourDVar | Layer 2 |
| D12 | Six-step inference cycle as testing scaffold | Methodology |
| D13 | pipekit-jax JaxModelOp + ModelRegistry for persistence | Layer 1 |
| D14 | DA hierarchy as horizontal peer classes | Layer 2 |
| D15 | Lean on optimistix / diffrax adjoints, not in-house grad modes | Layer 0 |
| D16 | BLUE / OI as a first-class method | Layer 2 |
| D17 | Latent DA as a first-class peer family | Layer 2 |
D1: Equinox over Flax NNX
All components are eqx.Module, not nnx.Module. The migration enables
direct use of lineax, optimistix, diffrax, and pipekit-jax
without adapter layers, and gives a simpler immutable pytree model than
Flax NNX's mutable state.
Apply: nnx.Module → eqx.Module, nnx.Linear → eqx.nn.Linear,
nnx.Optimizer → optax + eqx.filter_value_and_grad.
D2: Protocol-driven extensibility
Vardax defines vardax-specific runtime-checkable protocols for Prior,
GradModulator, CostFunction, PosteriorAdapter, and Minimiser.
Where pipekit-cycle already names the contract (ForwardModel,
ObservationOperator, AnalysisStep), vardax re-exports and satisfies
those directly — no parallel Abstract* hierarchy (see D8).
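As a minimal sketch of what a runtime-checkable protocol provides, the following uses typing.Protocol from the standard library. The method name cost and the classes QuadraticPrior and NotAPrior are assumptions made for illustration; the source does not spell out the methods of vardax's actual Prior protocol.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Prior(Protocol):
    """Hypothetical prior contract: one method returning a scalar cost."""

    def cost(self, x: list[float]) -> float: ...


class QuadraticPrior:
    """Satisfies Prior structurally; it does not inherit from Prior."""

    def cost(self, x: list[float]) -> float:
        return 0.5 * sum(v * v for v in x)


class NotAPrior:
    """Has no cost method, so the isinstance check fails."""


# runtime_checkable lets isinstance test for the presence of the methods.
is_prior = isinstance(QuadraticPrior(), Prior)
is_not_prior = isinstance(NotAPrior(), Prior)
```

The check is structural and only verifies that the method exists, not its signature, so it works as a cheap guard in a test suite rather than as a full type check.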
D3: Dimensional inheritance over duplication
Each Layer 2 class holds a dimension-agnostic algorithm; *1D, *2D,
*3D subclasses set defaults (conv kernel shape, ConvLSTM layout). The
algorithm is identical regardless of spatial dimension; only tensor
shapes differ.
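The pattern can be sketched with plain dataclasses standing in for eqx.Module. The class name GradStep, its fields, and the default kernel shapes are hypothetical; only the idea of subclasses that override defaults while leaving the algorithm untouched comes from the text above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GradStep:
    """Dimension-agnostic base: the update rule never inspects ndim."""

    kernel_shape: tuple[int, ...]
    step_size: float = 0.1

    @property
    def ndim(self) -> int:
        return len(self.kernel_shape)

    def update(self, x: float, grad: float) -> float:
        # Same arithmetic whatever the spatial dimension.
        return x - self.step_size * grad


@dataclass(frozen=True)
class GradStep1D(GradStep):
    # Only the default changes; no method is overridden.
    kernel_shape: tuple[int, ...] = (3,)


@dataclass(frozen=True)
class GradStep2D(GradStep):
    kernel_shape: tuple[int, ...] = (3, 3)


step_1d = GradStep1D()
step_2d = GradStep2D()
```

Both subclasses share update verbatim, so a fix to the algorithm lands in one place.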
D4: Nested module configuration
SolverConfig, IncrementalConfig, AmortizedConfig are themselves
eqx.Module. Serialisable, JIT-friendly, round-trip through
pipekit-experiment.ModelRegistry.
D5: Training step as library code, training loop as example
train_step and eval_step ship as library functions — they encode the
correctness-critical differentiation through the inner solver. fit()
is example-only. Production training composes vardax train_step with
pipekit-train.Loss / Callback / MetricWriter protocols.
D6: optimistix for novel solvers — contribute upstream
Updated in v0.4.0. Use existing optimistix.AbstractMinimiser
subclasses for standard minimisation. The unique contribution from the
4DVarNet line of work — the learned ConvLSTM gradient step, and the
one-step differentiation of Bolte et al. (2023) — should be packaged as
optimistix.AbstractMinimiser and optimistix.AbstractAdjoint
subclasses respectively, and contributed upstream rather than
maintained as parallel implementations in vardax.
Apply: vardax ships vardax._src.adjoints.one_step.OneStepAdjoint(
optimistix.AbstractAdjoint) initially, with the goal of upstreaming
once stable. Same for any learned-step minimiser.
D7: Demo-quality dynamical priors (forwards live elsewhere)
The Lorenz-63 (L63), Lorenz-96 (L96), and shallow-water (SWM) simulators are demo utilities in
vardax._src.utils.dynamical_systems. Production forward models come
from somax (geophysics) or plumax (atmospheric transport / methane).
Vardax owns inference, not physics.
D8: Direct pipekit-cycle protocol satisfaction (seven peer AnalysisStep classes)
Updated in v0.4.0 with the new class hierarchy.
Vardax classes directly satisfy pipekit_cycle.{ForwardModel,
ObservationOperator, AnalysisStep} protocols. Seven Layer 2 classes
implement AnalysisStep as peers, not a parent–child family:
- OptimalInterpolation (closed-form BLUE / OI)
- ThreeDVar
- StrongFourDVar
- WeakFourDVar
- IncrementalFourDVar
- FourDVarNet
- AmortizedPosterior
Each exposes .as_analysis_step() returning a callable matching
(forecast, obs, *, obs_op, obs_err_cov) → analysis. The training
interface model(batch) → x_recon is preserved on the learned variants
(FourDVarNet, AmortizedPosterior).
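A sketch of the callable shape described above, in NumPy. ToyAnalysis and its scalar gain are hypothetical stand-ins for a real Layer 2 class, and treating obs_op as a plain matrix is a simplification; only the signature (forecast, obs, *, obs_op, obs_err_cov) → analysis is taken from the text.

```python
from typing import Callable

import numpy as np

Array = np.ndarray
AnalysisFn = Callable[..., Array]


class ToyAnalysis:
    """Hypothetical stand-in for a Layer 2 class; holds a fixed scalar gain."""

    def __init__(self, gain: float) -> None:
        self.gain = gain

    def as_analysis_step(self) -> AnalysisFn:
        def step(forecast: Array, obs: Array, *, obs_op, obs_err_cov) -> Array:
            # Innovation in observation space; obs_op is a plain matrix here.
            innovation = obs - obs_op @ forecast
            # obs_err_cov is accepted to match the signature; this toy ignores it.
            return forecast + self.gain * (obs_op.T @ innovation)

        return step


step = ToyAnalysis(gain=0.5).as_analysis_step()
analysis = step(
    np.zeros(2),
    np.array([1.0, 2.0]),
    obs_op=np.eye(2),
    obs_err_cov=np.eye(2),
)
```

Because obs_op and obs_err_cov are keyword-only, a caller cannot silently swap them, which matters when every peer class is driven by the same cycle code.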
Apply:
- vardax.protocols re-exports the three pipekit-cycle protocols.
- Vardax-specific protocols (Prior, GradModulator, CostFunction,
PosteriorAdapter, Minimiser) are added for concepts pipekit-cycle
doesn't name.
- tests/test_pipekit_protocols.py enforces isinstance(...) checks on
every public model and obs operator (added as part of Epic 1).
pipekit-cycle is a required dependency. It has zero third-party deps,
so the cost is minimal.
D9: Averaging kernel + multi-instrument as first-class
AveragingKernel(A, x_a, h) and MultiInstrumentFusion(registry) are
part of the day-one vardax.obs_operators package. Skipping the
averaging kernel is the most common cause of bias in operational
satellite inversions.
MultiInstrumentFusion returns dict[str, Array] natively and
satisfies pipekit_cycle.ObservationOperator via its
.to_observation_operator() adapter (block-diagonal flattening).
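The two ideas can be sketched in NumPy under stated assumptions. The reading of AveragingKernel(A, x_a, h) as the standard retrieval relation h^T (x_a + A (x - x_a)), with x_a the a priori profile and h a column-weighting vector, is an assumption about the argument roles, not something the source confirms. The function names, the instrument keys, and the sorted-key ordering are likewise hypothetical.

```python
import numpy as np


def apply_averaging_kernel(x, A, x_a, h):
    """Column value seen by the instrument: h^T (x_a + A (x - x_a))."""
    smoothed = x_a + A @ (x - x_a)
    return float(h @ smoothed)


def flatten_observations(obs: dict[str, np.ndarray]) -> np.ndarray:
    """Concatenate per-instrument vectors in sorted key order."""
    return np.concatenate([np.ravel(obs[k]) for k in sorted(obs)])


def block_diag_cov(covs: dict[str, np.ndarray]) -> np.ndarray:
    """Place per-instrument error covariances on the diagonal, same key order."""
    keys = sorted(covs)
    n = sum(covs[k].shape[0] for k in keys)
    out = np.zeros((n, n))
    i = 0
    for k in keys:
        m = covs[k].shape[0]
        out[i:i + m, i:i + m] = covs[k]
        i += m
    return out


x = np.array([2.0, 2.0, 2.0])
x_a = np.ones(3)
h = np.full(3, 1.0 / 3.0)
# A kernel of 0.5 * I pulls the retrieved column halfway back to the prior.
column = apply_averaging_kernel(x, 0.5 * np.eye(3), x_a, h)

flat = flatten_observations(
    {"instr_b": np.array([3.0]), "instr_a": np.array([1.0, 2.0])}
)
R = block_diag_cov({"instr_a": np.eye(2), "instr_b": np.array([[4.0]])})
```

In this toy case the true column is 2.0 but the instrument reports 1.5, which is the bias that appears when the averaging kernel is left out of the observation operator.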
D10: Posterior export adapter pattern
Every Layer 2 model emits a Posterior(mean, cov, samples, provenance)
via a PosteriorAdapter. GaussianMarkLikelihood serialises to
JSON-friendly form for downstream population models (Tier V TMTPP,
hierarchical Bayesian inversions).
Adapter selection by inference family:
| Family | Default adapter |
|---|---|
| OptimalInterpolation | direct (closed-form posterior cov) |
| ThreeDVar, StrongFourDVar, WeakFourDVar | LaplaceCovariance at MAP |
| IncrementalFourDVar | GaussNewtonHessian (reuses last outer Hessian) |
| FourDVarNet | LaplaceCovariance |
| AmortizedPosterior | direct sampling (Posterior.samples) |
| Hybrid EnVar (filterax) | EnsembleCovariance |
D11: IncrementalFourDVar as the operational fast path of StrongFourDVar
Renamed in v0.4.0 (was IncrementalVarDA).
IncrementalFourDVar is functionally a StrongFourDVar with a
specialised inner solver: Gauss-Newton outer iterations on the full
nonlinear cost, CG / Lanczos inner iterations on the linearised
quadratic subproblem, control-variable transform via
gaussx.MaternLinearOperator.half() for preconditioning.
For users:
- Use StrongFourDVar with minimiser=optimistix.NonlinearCG(...) for
general 4DVar problems.
- Use IncrementalFourDVar when you want the operational pattern — fewer
outer iterations, exact tangent-linear via jax.linearize, and CG
inner with gaussx-preconditioned Hessian.
The two share ForwardModel, ObservationOperator, and Posterior
contracts. The choice between them is performance-driven, not
semantic — both solve the same minimisation problem.
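The outer / inner structure can be sketched in NumPy on a toy problem. This is an illustration under stated assumptions, not vardax's implementation: the function names are hypothetical, an analytic Jacobian stands in for jax.linearize, dense inverse covariances stand in for gaussx operators, and the control-variable transform is omitted.

```python
import numpy as np


def conjugate_gradient(matvec, b, iters=50, rtol=1e-10):
    """Plain CG for a symmetric positive-definite operator given as a mat-vec."""
    x = np.zeros_like(b)
    r = b.copy()  # residual for the zero initial guess
    p = r.copy()
    rs = r @ r
    stop = (rtol ** 2) * rs
    for _ in range(iters):
        if rs <= stop:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def incremental_analysis(x_b, y, h, jac, B_inv, R_inv, outer=10):
    """Gauss-Newton outer loop; each increment solved by CG on the quadratic subproblem."""
    x = x_b.copy()
    for _ in range(outer):
        H = jac(x)    # tangent-linear operator at the current iterate
        d = y - h(x)  # innovation against the nonlinear operator
        rhs = B_inv @ (x_b - x) + H.T @ (R_inv @ d)

        def hess(v, H=H):
            # Gauss-Newton Hessian mat-vec: (B^-1 + H^T R^-1 H) v
            return B_inv @ v + H.T @ (R_inv @ (H @ v))

        x = x + conjugate_gradient(hess, rhs)
    return x


def h(x):
    return x ** 2  # toy nonlinear observation operator


def jac(x):
    return np.diag(2.0 * x)  # its exact Jacobian


x_b = np.array([1.0, 2.0])
y = np.array([1.44, 3.24])
B_inv = np.eye(2)
R_inv = 100.0 * np.eye(2)
x_hat = incremental_analysis(x_b, y, h, jac, B_inv, R_inv)
# Gradient of the full nonlinear cost at the result; it should vanish at the minimum.
grad = B_inv @ (x_hat - x_b) - jac(x_hat).T @ (R_inv @ (y - h(x_hat)))
```

The inner solve only ever sees the quadratic subproblem, which is why it can use CG; the nonlinearity is handled entirely by relinearising in the outer loop.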
D12: Six-step inference cycle as testing scaffold
The cycle (physics → MAP / MCMC → emulator → faster inference →
amortized → improve) is the validation methodology. Step N validates
against Step N–1 as oracle. Hard gates (adjoint calibration, posterior
agreement, simulation-based calibration) are part of
tests/test_six_step_validation.py, not just documentation.
vardax._src.utils.validation exposes assert_posterior_agreement,
assert_adjoint_calibrated, simulation_based_calibration.
D13: pipekit-jax JaxModelOp + ModelRegistry for persistence
Trained models are persisted via pipekit-jax.JaxModelOp (weight
serialisation) + pipekit-experiment.ModelRegistry (content-addressed
storage). Vardax provides thin shortcuts (vardax.persist.save,
vardax.persist.load) but does not define its own persistence format.
pipekit-jax and pipekit-experiment are [persist] extras.
D14: DA hierarchy as horizontal peer classes
New in v0.4.0.
The previous v0.3 design organised methods into three families —
VarDANet* (learned), IncrementalVarDA* (operational), AmortizedVarDA*
(direct). This was 4DVarNet-centric: it treated the learned method as
the canonical case and the classical methods as historical variants.
The v0.4 design treats classical and learned DA methods as siblings.
Seven peer classes, all implementing pipekit_cycle.AnalysisStep, none
inheriting from any of the others:
| Method | Use when |
|---|---|
| OptimalInterpolation | Linear \(H\), Gaussian \(B\), \(R\). Static field. The right default. |
| ThreeDVar | Nonlinear \(H\), single time. Snapshot inversion. |
| StrongFourDVar | Multi-time, control = \(x_0\), model treated as exact. |
| WeakFourDVar | Multi-time, control = \((x_0, \eta_1, \ldots, \eta_T)\). Model-error-aware. |
| IncrementalFourDVar | Operational fast path: GN outer + CG inner + CVT (= StrongFourDVar). |
| FourDVarNet | Learned prior + learned grad modulator. Research, data-rich regimes. |
| AmortizedPosterior | Direct \(q_\phi(x \mid y)\). Real-time / many-event regimes. |
Why horizontal.
- BLUE / OI is the right tool when the regime allows it. Forcing it
  through ThreeDVar with an "isLinear" branch hides the closed-form
  fast path and confuses the reader about complexity.
- Strong- and weak-constraint 4DVar are different problems (different
  control vectors, different cost terms), not configurations of a
  single class. Burying them under a mode: Literal["strong", "weak"]
  flag obscures the structural difference.
- FourDVarNet is one variant of 4DVar with learned components, but it
  is not the parent of classical 4DVar. Treating it as such was a
  category error inherited from the 4DVarNet-starter codebase.
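To make "different control vectors, different cost terms" concrete, the two costs can be written side by side in standard textbook form. The symbols \(M_t\) (model step), \(H_t\) (observation operator at time \(t\)), \(Q_t\) (model-error covariance), and \(x_b\) (background) are assumed notation and may differ from the math reference.

\[
J_{\text{strong}}(x_0) = \tfrac{1}{2} (x_0 - x_b)^\top B^{-1} (x_0 - x_b)
+ \tfrac{1}{2} \sum_{t=0}^{T} \big(y_t - H_t(x_t)\big)^\top R_t^{-1} \big(y_t - H_t(x_t)\big),
\qquad x_t = M_t(x_{t-1})
\]

\[
J_{\text{weak}}(x_0, \eta_1, \ldots, \eta_T) = \tfrac{1}{2} (x_0 - x_b)^\top B^{-1} (x_0 - x_b)
+ \tfrac{1}{2} \sum_{t=0}^{T} \big(y_t - H_t(x_t)\big)^\top R_t^{-1} \big(y_t - H_t(x_t)\big)
+ \tfrac{1}{2} \sum_{t=1}^{T} \eta_t^\top Q_t^{-1} \eta_t,
\qquad x_t = M_t(x_{t-1}) + \eta_t
\]

The weak form adds both an argument (the \(\eta_t\)) and a penalty term, so its Hessian and posterior live in a larger space than those of the strong form. That is the structural difference a mode flag would hide.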
Apply.
- All seven classes live as siblings in vardax/_src/models/.
- All seven implement .as_analysis_step() returning the same protocol
  shape.
- Naming convention: spelled-out names (StrongFourDVar, not
  VarDA4DStrong or FourDVarStrong). No VarDA prefix on the class
  names: vardax is the package, and the classes are named for the method.
- Math reference chapters 4–10 cover one method each, with chapter 1
  (problem setting), chapter 2 (observation model), and chapter 3
  (dynamical model) supplying the shared foundation.
D15: Lean on optimistix / diffrax adjoints, not in-house grad modes
New in v0.4.0.
The v0.3 design defined grad_mode: Literal["unrolled", "one_step",
"implicit"] on FourDVarNet* and hand-rolled the three corresponding
differentiation paths in vardax._src.solver. This was useful as a
proof of concept but is the wrong long-term design:
- optimistix.AbstractAdjoint already provides
  RecursiveCheckpointAdjoint and ImplicitAdjoint for gradients through
  minimisers.
- diffrax.AbstractAdjoint already provides RecursiveCheckpointAdjoint,
  BacksolveAdjoint (continuous adjoint), ForwardMode, and DirectAdjoint
  for gradients through ODE integration.
Vardax exposes both as constructor slots and delegates the actual gradient computation upstream:
```python
import diffrax
import equinox as eqx
import optimistix


class StrongFourDVar(eqx.Module):
    ...
    minimiser: optimistix.AbstractMinimiser
    minimiser_adjoint: optimistix.AbstractAdjoint = optimistix.ImplicitAdjoint()
    forward_adjoint: diffrax.AbstractAdjoint = diffrax.RecursiveCheckpointAdjoint()
```
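A user who wants the operational memory profile passes forward_adjoint=diffrax.BacksolveAdjoint(), the continuous adjoint recommended below for memory-constrained 4DVar.
A user training FourDVarNet:

```python
minimiser_adjoint = optimistix.RecursiveCheckpointAdjoint()
# or, once contributed upstream:
minimiser_adjoint = vardax.adjoints.OneStepAdjoint()  # Bolte et al. 2023
```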
Apply.
- Drop GradMode enum from the public API.
- Drop unrolled_solve, one_step_solve, implicit_solve as separate
  Layer 0 primitives. Their logic moves into the appropriate
  optimistix.AbstractAdjoint subclasses.
- Bolte 2023 "one-step" differentiation is implemented as
  vardax._src.adjoints.one_step.OneStepAdjoint(optimistix.AbstractAdjoint)
  with the goal of upstreaming (continues D6).
- The continuous adjoint via diffrax.BacksolveAdjoint becomes the
  recommended default for memory-constrained 4DVar with long
  assimilation windows.
This decision composes with D6: vardax's novel solver / adjoint
contributions become upstream optimistix / diffrax subclasses, not
parallel in-tree code.
D16: BLUE / OI as a first-class method
New in v0.4.0.
Optimal interpolation (BLUE) — the closed-form linear-Gaussian analysis
— is the first method vardax users should reach for when the regime
permits. It is not a degenerate case of 3DVar to be invoked via an
auto-detected branch; it is its own OptimalInterpolation class with
its own fast path.
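The analysis is

\[ x_a = x_b + B H^\top \left( H B H^\top + R \right)^{-1} \left( y - H x_b \right) \]

with posterior covariance

\[ P_a = B - B H^\top \left( H B H^\top + R \right)^{-1} H B \]
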
When \(B\), \(R\), \(H\) are structured (Matérn, Kronecker, AveragingKernel),
gaussx solves these expressions efficiently without materialising
dense matrices. The whole analysis runs in a handful of structured
mat-vec operations.
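A dense NumPy sketch of the closed form, written two ways that are related by the Sherman–Morrison–Woodbury identity. The function names are hypothetical and this is not the blue_analysis primitive itself: a structured implementation would avoid the explicit inverses and dense products used here.

```python
import numpy as np


def blue_obs_space(x_b, y, B, R, H):
    """Gain built from an m-by-m solve (m = number of observations)."""
    S = H @ B @ H.T + R              # innovation covariance
    K = np.linalg.solve(S, H @ B).T  # K = B H^T S^{-1}; uses symmetry of S and B
    x_a = x_b + K @ (y - H @ x_b)
    P_a = B - K @ H @ B
    return x_a, P_a


def blue_state_space(x_b, y, B, R, H):
    """Equivalent information form built from an n-by-n solve (n = state size)."""
    B_inv = np.linalg.inv(B)
    R_inv = np.linalg.inv(R)
    P_a = np.linalg.inv(B_inv + H.T @ R_inv @ H)
    x_a = x_b + P_a @ H.T @ R_inv @ (y - H @ x_b)
    return x_a, P_a


B = np.array([[1.0, 0.5, 0.0], [0.5, 1.0, 0.5], [0.0, 0.5, 1.0]])
R = 0.1 * np.eye(2)
H = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # observe first and last state
x_b = np.zeros(3)
y = np.array([1.0, -1.0])

x_obs, P_obs = blue_obs_space(x_b, y, B, R, H)
x_state, P_state = blue_state_space(x_b, y, B, R, H)
```

Which form is cheaper depends on whether the state or the observation vector is larger. The unobserved middle state is still updated through the off-diagonal entries of B.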
Why first-class.
- The closed form is exact; no iteration, no convergence criterion, no adjoint. When applicable it is faster and more accurate than 3DVar's iterative solution.
- The posterior covariance comes for free in the same expression; no
  separate PosteriorAdapter call is needed.
- It is the canonical baseline. Every more sophisticated method must
  agree with OI in the linear-Gaussian limit; if it doesn't, something
  is wrong. Making OI a first-class method makes this baseline easy to
  produce.
- The Kalman filter analysis step is OI. Code reuse between vardax and ensemble libraries (filterax) is direct.
Apply.
- vardax/_src/models/optimal_interpolation.py implements
  OptimalInterpolation(eqx.Module) with no minimiser slot.
- linearize() is required on the observation operator. If \(H\) is
  intrinsically nonlinear, the user picks ThreeDVar instead;
  OptimalInterpolation.__init__ validates linearity and refuses
  nonlinear obs_op.
- A new Layer 0 primitive blue_analysis(x_b, y, B_op, R_op, H_op)
  implements the closed-form math via gaussx.
- Math chapter 4 covers BLUE / OI in detail, including the
  Sherman–Morrison–Woodbury identity used to pick between the
  \(B\)-space and the \(R\)-space form.
D17: Latent DA as a first-class peer family
New in v0.5.0.
Latent data assimilation — performing the variational analysis in a learned low-dimensional latent space \(\mathcal{Z}\) — is added as a peer family of three new Layer-2 classes alongside the seven established in D14:
- LatentThreeDVar: single-time inversion in \(\mathcal{Z}\).
- LatentStrongFourDVar: multi-time, latent dynamics \(M_z\).
- LatentHybridFourDVar: multi-time, physics forecast in
  \(\mathcal{X}\), control and background in \(\mathcal{Z}\).
Context. Vardax already uses learned subspaces in three places —
the AE priors inside FourDVarNet*, the observation encoder inside
AmortizedPosterior, and the heads of the amortised family. None of
these promote \(z\) itself to the control variable; the variational
problem still solves in \(\mathcal{X}\). The benchmark literature
(Peyron 2021, Cheng 2023, Fablet 2021) consistently shows that
promoting \(z\) to control gives an order-of-magnitude wall-clock win
and a structurally smaller posterior covariance.
Options.
- (A) Add space: Literal["x", "z"] to existing methods and dispatch
  internally.
- (B) Add three new peer classes; reuse the AE priors as LatentMaps.
- (C) Treat latent DA as a configuration of FourDVarNet.
Decision. Option B.
Rationale.
- The control vector (\(z\) vs \(x\)), the cost dimensionality, the Hessian object, and the posterior covariance all live in genuinely different spaces. The D14 pattern argues we should not hide a structural difference behind a flag.
- The AE priors of FourDVarNet are regularisers, not
  parameterisations. Conflating regularisation and parameterisation
  (Option C) confuses the reader and forces users to thread the prior
  slot to mean two different things.
- Three peer classes mirror the three flavours formalised in the
  pipekit_cycle.LatentDACycle orchestrator: strong (z forecast + z
  update), prior-only (x forecast + x update, AE inside cost), hybrid
  (x forecast + z update). The vardax peers map 1-to-1 to those
  flavours.
Apply.
- New module vardax/_src/models/latent.py with the three classes.
- New module vardax/_src/latent.py with LatentPrior and
  NeuralLatentForwardModel adapter.
- Existing BilinAEPrior1D/2D, MLPAEPrior1D, ConvAEPrior1D gain
  latent_dim and state_signature properties so they satisfy
  pipekit_cycle.LatentMap structurally. No behaviour change.
- New Layer-0 primitives variational_cost_latent,
  latent_incremental_cost, plus an identity_latent_map(N) helper for
  the regression test in §18.6 of the math reference.
- New posterior adapter LatentLaplaceCovariance in
  vardax/_src/posterior/latent.py.
- Math chapter 18 covers the formalism; design doc
  design/latent_da.md covers the API.
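Why \(z\) as control changes the size of every object can be sketched with a linear decoder in NumPy. This is a deliberate simplification: real latent maps are nonlinear autoencoders, and the function names, the decoder matrix D, and the random test data are all hypothetical. With an identity decoder the latent solve reduces to the state-space solve, which is one natural use of an identity latent map as a consistency check; the content of the actual regression test is not reproduced here.

```python
import numpy as np


def analysis(mean_b, cov_b, y, R, G):
    """Closed-form linear-Gaussian update for a control vector observed through G."""
    S = G @ cov_b @ G.T + R
    K = np.linalg.solve(S, G @ cov_b).T
    return mean_b + K @ (y - G @ mean_b), cov_b - K @ G @ cov_b


def latent_analysis(z_b, B_z, y, R, H, D):
    """Solve for z, then decode; the posterior covariance is k-by-k, not n-by-n."""
    z_a, P_z = analysis(z_b, B_z, y, R, H @ D)
    return D @ z_a, z_a, P_z


rng = np.random.default_rng(0)
n, k, m = 4, 2, 3  # state, latent, and observation sizes
H = rng.normal(size=(m, n))
D = rng.normal(size=(n, k))  # hypothetical linear decoder from Z to X
R = 0.1 * np.eye(m)
y = rng.normal(size=m)

# Latent solve with a k-dimensional control vector.
x_lat, z_a, P_z = latent_analysis(np.zeros(k), np.eye(k), y, R, H, D)

# Identity decoder: the latent solve must reproduce the state-space solve.
x_id, _, P_id = latent_analysis(np.zeros(n), np.eye(n), y, R, H, np.eye(n))
x_ref, P_ref = analysis(np.zeros(n), np.eye(n), y, R, H)
```

The decoded analysis x_lat still lives in the full state space, but the covariance that has to be stored and exported is only k-by-k.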
Consequences.
- Three more peers in the Layer-2 family (now ten). D14's "no
  parent–child" rule applies: latent peers do not inherit from
  StrongFourDVar or each other.
- LatentMap is a substrate-neutral pipekit-cycle protocol, not a
  vardax-owned one. Cross-library compatibility with filterax
  (LatentETKF, LatentLETKF) is automatic.
- WeakLatentFourDVar and a VAE-aware LatentMap are explicitly deferred
  to v0.6 (see design/latent_da.md §13).