Protocols & Types
vardax components are coupled structurally: anything with the right
methods conforms, with no base class to inherit. Three of the Protocols —
AnalysisStep, ForwardModel, and ObservationOperator — are re-exported
from pipekit-cycle and define the
seam along which vardax plugs into assimilation
cycles; the rest are vardax-specific and define the seams
inside a variational method (prior, cost, gradient modulator, posterior
adapter, minimiser). All are runtime-checkable, so isinstance checks work
at the boundaries, although such a check only confirms that the named
members exist, not that their signatures match.
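The following standard-library sketch shows structural conformance and how shallow the runtime check is. `Stepper`, `DecayModel`, `WrongSignature` and `NoStep` are hypothetical stand-ins, not vardax or pipekit-cycle names. `DecayModel` never inherits from `Stepper`, yet `isinstance` accepts it. `WrongSignature` is accepted too, because only the member's presence is checked.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Stepper(Protocol):
    """Hypothetical protocol: anything with a `step(state, dt)` method."""

    def step(self, state: float, dt: float) -> float: ...


class DecayModel:
    """Conforms structurally: no base class, just the right method."""

    def step(self, state: float, dt: float) -> float:
        return state * (1.0 - dt)


class WrongSignature:
    """Also passes isinstance: runtime checks ignore signatures."""

    def step(self) -> None:
        return None


class NoStep:
    """Lacks `step`, so the runtime check rejects it."""


print(isinstance(DecayModel(), Stepper))      # True
print(isinstance(WrongSignature(), Stepper))  # True
print(isinstance(NoStep(), Stepper))          # False
```

Because the runtime check is this shallow, mismatched signatures are caught by a static type checker such as mypy or pyright, not by `isinstance`.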
pipekit-cycle protocols
The orchestration contract. ForwardModel propagates state between cycle
times, ObservationOperator maps state to observation space (with
linearize for the tangent-linear operator), and AnalysisStep is what every
vardax model's .as_analysis_step() returns.
vardax: Modular variational data assimilation with learned components.
All public symbols are re-exported from the private _src subpackage so
that user code imports from the top-level namespace.
AnalysisStep
Bases: Protocol
Combine forecast state with observations to produce analysis state.
Implementations: ensemble Kalman analyses (EnKF, ETKF, LETKF), variational solvers (3D/4D-Var), particle filters, smoothers. Algorithm libraries supply concrete classes.
Members
__call__(forecast, obs, *, obs_op, obs_err_cov): Return the analysis state given the forecast, the observations, the observation operator, and the observation-error covariance.
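Below is a minimal sketch of a conforming analysis step. It assumes a linear observation operator, NumPy arrays for state and observations, and a fixed background-error covariance `B` stored on the object. `BlueAnalysis` and `SelectObs` are hypothetical names. The update is the textbook best linear unbiased estimate (the Kalman or optimal-interpolation update), not one of the vardax solvers.

```python
import numpy as np


class SelectObs:
    """Hypothetical linear observation operator: picks observed grid points."""

    def __init__(self, idx: np.ndarray, n: int):
        self.H = np.eye(n)[idx]  # (n_obs, n_state) selection matrix

    def __call__(self, state: np.ndarray) -> np.ndarray:
        return self.H @ state

    def linearize(self, state: np.ndarray) -> np.ndarray:
        return self.H  # linear operator: the Jacobian is H everywhere


class BlueAnalysis:
    """Hypothetical AnalysisStep: x_a = x_f + K (y - H x_f)."""

    def __init__(self, B: np.ndarray):
        self.B = B  # background-error covariance (assumed known)

    def __call__(self, forecast, obs, *, obs_op, obs_err_cov):
        H = obs_op.linearize(forecast)
        innovation = obs - obs_op(forecast)
        S = H @ self.B @ H.T + obs_err_cov       # innovation covariance
        K = np.linalg.solve(S, H @ self.B).T     # gain B H^T S^-1 (B, S symmetric)
        return forecast + K @ innovation


n = 4
step = BlueAnalysis(B=np.eye(n))
obs_op = SelectObs(np.array([0, 2]), n)
xa = step(np.zeros(n), np.array([1.0, 2.0]), obs_op=obs_op, obs_err_cov=np.eye(2))
# xa == [0.5, 0.0, 1.0, 0.0]: B and R are equal, so observed points move halfway.
```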
ForwardModel
Bases: Protocol
Advance a model state forward in time by dt.
Implementations: domain forward models (chemistry transport, ocean state, plume dispersion), neural emulators, hybrid physics + ML models. Algorithm libraries provide adapters that satisfy this protocol structurally.
Members
step(state, dt): Return the state advanced by dt.
dt: Default integration step.
state_signature: Optional pipekit.Signature describing the
shape / dtype of the state carrier. None if the model
doesn't track named dimensions.
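Here is a minimal sketch of a `ForwardModel`-shaped class. It decays every state component exponentially at a rate `k` and exposes the documented `step`, `dt` and `state_signature` members. Conformance is structural, so nothing here imports pipekit-cycle. `DecayForward` and `k` are illustrative names.

```python
import math
from dataclasses import dataclass


@dataclass
class DecayForward:
    """Hypothetical forward model: dx/dt = -k x, integrated exactly."""

    k: float = 0.5
    dt: float = 0.1          # default integration step
    state_signature = None   # class attribute: no named dimensions tracked

    def step(self, state: list[float], dt: float) -> list[float]:
        factor = math.exp(-self.k * dt)
        return [x * factor for x in state]


model = DecayForward()
x = [1.0, 2.0]
for _ in range(10):          # ten default steps span t = 1.0
    x = model.step(x, model.dt)
# Matches the exact solution x(1) = x(0) * exp(-k * 1.0).
```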
ObservationOperator
Bases: Protocol
Map model state → predicted observations.
The H operator in classical data-assimilation notation:
H(x) produces "what would the observations look like if the
state were x?". The innovation in DA is then obs - H(forecast).
Members
__call__(state): Return predicted observations for state.
linearize(state): Optional tangent-linear operator at
state (returns a callable / matrix). Implementations
that don't expose a linearisation may raise
NotImplementedError.
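The next sketch shows a nonlinear observation operator. It assumes NumPy arrays and a hypothetical instrument that measures the square of the state at a few grid points. `SquaredAt` is an illustrative name. Its `linearize` returns the Jacobian matrix, and a finite-difference check confirms it. The innovation is formed as `obs - H(forecast)`.

```python
import numpy as np


class SquaredAt:
    """Hypothetical H: observes x[i]**2 at the grid indices in `idx`."""

    def __init__(self, idx):
        self.idx = np.asarray(idx)

    def __call__(self, state: np.ndarray) -> np.ndarray:
        return state[self.idx] ** 2

    def linearize(self, state: np.ndarray) -> np.ndarray:
        # Jacobian dH/dx: row j holds 2 * x[idx[j]] in column idx[j].
        J = np.zeros((self.idx.size, state.size))
        J[np.arange(self.idx.size), self.idx] = 2.0 * state[self.idx]
        return J


H = SquaredAt([0, 2])
forecast = np.array([1.0, -3.0, 2.0])
obs = np.array([1.5, 3.0])
innovation = obs - H(forecast)            # [0.5, -1.0]

# Tangent-linear check: (H(x + eps*dx) - H(x)) / eps should approach J @ dx.
eps, dx = 1e-6, np.array([0.3, 0.1, -0.2])
fd = (H(forecast + eps * dx) - H(forecast)) / eps
tl = H.linearize(forecast) @ dx
```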
vardax protocols
The internal seams: implement these to swap in custom priors, cost functions, gradient modulators, posterior adapters, or inner-loop minimisers without touching the model classes.
Prior
Bases: Protocol
Prior model: maps state to its regularised reconstruction.
For an autoencoder prior φ_θ, the variational
cost includes ||x - φ_θ(x)||^2. For a dynamical prior wrapping a
ForwardModel, φ(x) is the forward integration. For the
identity prior, φ(x) = x.
Members
__call__(x) -> x_prior: apply the prior model.
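This sketch shows two priors and the regularisation term they feed, assuming NumPy state vectors. `IdentityPrior` matches the identity case above, so its penalty is always zero. `SmoothingPrior` is a hypothetical, untrained stand-in for a learned φ_θ (a periodic 3-point moving average), so it penalises rough states more than smooth ones. `prior_penalty` is also an illustrative helper.

```python
import numpy as np


class IdentityPrior:
    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x


class SmoothingPrior:
    """Hypothetical φ: periodic 3-point moving average (not a learned model)."""

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return (np.roll(x, 1) + x + np.roll(x, -1)) / 3.0


def prior_penalty(prior, x: np.ndarray) -> float:
    """The ||x - φ(x)||^2 term of the variational cost."""
    r = x - prior(x)
    return float(r @ r)


smooth = np.ones(6)
rough = np.array([1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
# The moving average leaves `smooth` unchanged (penalty 0) and maps `rough`
# to -rough / 3, so its penalty is 6 * (4/3)**2 = 96/9.
```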
CostFunction
Bases: Protocol
Variational cost function J(x, batch, **kwargs) -> scalar.
Implementations include vardax.costs.ThreeDVarCost,
StrongConstraintCost, WeakConstraintCost, IncrementalCost,
and FourDVarNetCost.
Members
__call__(x, batch, **kwargs) -> scalar
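The following is a minimal sketch of the `J(x, batch, **kwargs) -> scalar` convention. It combines a masked data misfit with a weighted prior term, the general shape of a 4DVarNet-style cost. `MaskedPriorCost` and `alpha` are hypothetical, and this section does not document the weighting or normalisation vardax's own costs use. The batch is a `SimpleNamespace` with `input` and `mask` fields, standing in for the batch types described below.

```python
from types import SimpleNamespace

import numpy as np


class MaskedPriorCost:
    """Hypothetical CostFunction: masked data misfit + weighted prior term."""

    def __init__(self, prior, alpha: float = 1.0):
        self.prior = prior
        self.alpha = alpha  # assumed weight on the prior term

    def __call__(self, x, batch, **kwargs) -> float:
        misfit = batch.mask * (x - batch.input)  # unobserved points contribute 0
        reg = x - self.prior(x)
        return float(np.sum(misfit**2) + self.alpha * np.sum(reg**2))


def identity(x):
    return x


batch = SimpleNamespace(
    input=np.array([[1.0, 0.0, 3.0]]),  # value at the unobserved point is ignored
    mask=np.array([[1.0, 0.0, 1.0]]),
)
cost = MaskedPriorCost(identity)
print(cost(np.array([[1.0, 99.0, 2.0]]), batch))  # 1.0: only the third point misfits
```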
GradModulator
Bases: Protocol
Learned gradient modulator for the FourDVarNet inner solver.
Takes the current variational-cost gradient, the current state, and the
modulator's own carry, and returns a state update and the new carry. Used
only by FourDVarNet; the classical analysis methods use
optimistix.AbstractMinimiser instead.
Members
__call__(grad, state, carry) -> (update, new_carry)
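This sketch shows how the carry threads through an unrolled inner solver. A hand-written modulator replaces the learned ConvLSTM: heavy-ball momentum, whose velocity plays the role of the carry. `MomentumModulator`, its `lr` and `beta`, and the `unrolled_solve` loop are all hypothetical. The update convention (`x <- x + update`) is an assumption, not a documented vardax rule.

```python
import numpy as np


class MomentumModulator:
    """Hypothetical GradModulator: heavy-ball momentum; carry = velocity."""

    def __init__(self, lr: float = 0.1, beta: float = 0.5):
        self.lr, self.beta = lr, beta

    def __call__(self, grad, state, carry):
        new_carry = self.beta * carry + grad   # accumulate gradient history
        update = -self.lr * new_carry          # step against the gradient
        return update, new_carry


def unrolled_solve(grad_fn, modulator, x0, n_iter: int = 50):
    """Fixed number of modulated steps, as in an unrolled inner solver."""
    x, carry = x0, np.zeros_like(x0)
    for _ in range(n_iter):
        update, carry = modulator(grad_fn(x), x, carry)
        x = x + update
    return x


target = np.array([2.0, -1.0])


def grad_fn(x):
    return x - target  # gradient of 0.5 * ||x - target||^2


x_star = unrolled_solve(grad_fn, MomentumModulator(), np.zeros(2))
```

A learned modulator keeps the same `(update, new_carry)` contract but computes both with trainable weights. That is why the solver loop does not need to know what the carry contains.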
PosteriorAdapter
Bases: Protocol
Turns an analysis output into a Posterior container.
Implementations: LaplaceCovariance, GaussNewtonHessian,
EnsembleCovariance. Each computes the posterior covariance via a
different approximation, but all return a Posterior with the same contract.
Members
__call__(analysis, model, batch) -> Posterior
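Here is a sketch of the contract in the linear-Gaussian case. There the Gauss-Newton Hessian of the 3D-Var cost is exact, and the posterior covariance is (B⁻¹ + Hᵀ R⁻¹ H)⁻¹. The `Posterior` dataclass, the `LinearGaussNewton` adapter, and the `B_inv`, `H` and `R_inv` attributes read from `model` are hypothetical stand-ins. This section does not document the fields of vardax's own `Posterior`.

```python
from dataclasses import dataclass
from types import SimpleNamespace

import numpy as np


@dataclass
class Posterior:
    """Hypothetical container: Gaussian mean and covariance."""
    mean: np.ndarray
    cov: np.ndarray


class LinearGaussNewton:
    """Hypothetical PosteriorAdapter for a linear H and Gaussian errors."""

    def __call__(self, analysis, model, batch) -> Posterior:
        # Hessian of J(x) = 0.5|x - xb|^2_{B^-1} + 0.5|y - Hx|^2_{R^-1}
        hessian = model.B_inv + model.H.T @ model.R_inv @ model.H
        return Posterior(mean=analysis, cov=np.linalg.inv(hessian))


model = SimpleNamespace(
    B_inv=np.eye(2), H=np.array([[1.0, 0.0]]), R_inv=np.array([[1.0]])
)
post = LinearGaussNewton()(np.zeros(2), model, batch=None)
# Observed component: variance 1 / (1 + 1) = 0.5; unobserved stays at 1.0.
```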
Minimiser
Bases: Protocol
Wrapper protocol around optimistix.AbstractMinimiser.
A Minimiser knows how to minimise a CostFunction from an
initial guess x0 against a batch. Implementations adapt
optimistix solvers (GaussNewton, BFGS, NonlinearCG, …) to vardax's
cost-function calling convention.
Members
__call__(cost_fn, x0, batch) -> x_star
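Below is a dependency-free sketch of the calling convention. It minimises `cost_fn(x, batch)` from `x0` using plain gradient descent with central finite-difference gradients. `GradientDescent`, `lr`, `n_steps` and `h` are hypothetical and stand in for the optimistix solvers named above. With optimistix itself, the adaptation could be as small as passing `batch` through the `args` argument of `optimistix.minimise`, whose objective is called as `fn(y, args)`.

```python
from types import SimpleNamespace

import numpy as np


class GradientDescent:
    """Hypothetical Minimiser: fixed-step descent with numerical gradients."""

    def __init__(self, lr: float = 0.2, n_steps: int = 200, h: float = 1e-6):
        self.lr, self.n_steps, self.h = lr, n_steps, h

    def _grad(self, cost_fn, x, batch):
        g = np.zeros_like(x)
        for i in range(x.size):  # central difference along each coordinate
            e = np.zeros_like(x)
            e.flat[i] = self.h
            g.flat[i] = (cost_fn(x + e, batch) - cost_fn(x - e, batch)) / (2 * self.h)
        return g

    def __call__(self, cost_fn, x0, batch):
        x = np.array(x0, dtype=float)
        for _ in range(self.n_steps):
            x = x - self.lr * self._grad(cost_fn, x, batch)
        return x


def quad_cost(x, batch):
    """0.5 * ||x - batch.target||^2, minimised at batch.target."""
    return 0.5 * float(np.sum((x - batch.target) ** 2))


batch = SimpleNamespace(target=np.array([1.0, -2.0, 0.5]))
x_star = GradientDescent()(quad_cost, np.zeros(3), batch)
```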
Batch & state types
The typed carriers that flow through the solvers and training loops: 1D and 2D (single- and multi-variable) observation batches, and the recurrent-state containers of the ConvLSTM gradient modulators.
Batch1D
Bases: Module
Batch of 1-D spatiotemporal data.
Attributes:
| Name | Type | Description |
|---|---|---|
| input | Float[Array, 'B T N'] | Observed (masked) input field of shape (B, T, N). |
| mask | Float[Array, 'B T N'] | Binary observation mask of shape (B, T, N). |
| target | Float[Array, 'B T N'] \| None | Ground-truth field of shape (B, T, N), or None. |
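This sketch shows one way to assemble a `Batch1D`-shaped batch from a ground-truth field. It draws a binary mask, zeroes the unobserved entries to form `input`, and keeps the truth as `target`. It uses NumPy and a plain dataclass, `Batch1DLike`, as a hypothetical stand-in for the equinox-based class. The 30% observation rate and the zero fill for gaps are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class Batch1DLike:
    """Hypothetical stand-in for Batch1D (arrays shaped (B, T, N))."""
    input: np.ndarray
    mask: np.ndarray
    target: Optional[np.ndarray] = None


rng = np.random.default_rng(0)
B, T, N = 2, 5, 40
truth = np.sin(np.linspace(0, 2 * np.pi, N))[None, None, :] * np.ones((B, T, 1))
mask = (rng.random((B, T, N)) < 0.3).astype(np.float32)  # ~30% observed
batch = Batch1DLike(input=truth * mask, mask=mask, target=truth)  # gaps -> 0
```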
Batch2D
Bases: Module
Batch of 2-D spatiotemporal data.
Attributes:
| Name | Type | Description |
|---|---|---|
| input | Float[Array, 'B T H W'] | Observed (masked) input field of shape (B, T, H, W). |
| mask | Float[Array, 'B T H W'] | Binary observation mask of shape (B, T, H, W). |
| target | Float[Array, 'B T H W'] \| None | Ground-truth field of shape (B, T, H, W), or None. |
Batch2DMultivar
Bases: Module
Batch of 2-D multivariate spatiotemporal data.
Attributes:
| Name | Type | Description |
|---|---|---|
| input | Float[Array, 'B T C H W'] | Observed (masked) input field of shape (B, T, C, H, W). |
| mask | Float[Array, 'B T C H W'] | Binary observation mask of shape (B, T, C, H, W). |
| target | Float[Array, 'B T C H W'] \| None | Ground-truth field of shape (B, T, C, H, W), or None. |
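The multivariate layout places a channel axis C right after time. Per-variable arrays shaped (B, T, H, W), as in `Batch2D`, can therefore be combined with `np.stack(..., axis=2)`. The sketch uses two hypothetical variables, `sst` and `ssh`, filled with random placeholders, and NumPy in place of JAX arrays. Each variable keeps its own mask, so observation coverage can differ between variables. The zero fill for gaps is an assumption.

```python
import numpy as np

rng = np.random.default_rng(1)
B, T, H, W = 2, 3, 8, 8

# Hypothetical per-variable fields and masks, each shaped (B, T, H, W).
sst, ssh = rng.normal(size=(2, B, T, H, W))
sst_mask = (rng.random((B, T, H, W)) < 0.5).astype(np.float32)
ssh_mask = (rng.random((B, T, H, W)) < 0.1).astype(np.float32)

# Stack along axis 2 to get the (B, T, C, H, W) layout of Batch2DMultivar.
target = np.stack([sst, ssh], axis=2)
mask = np.stack([sst_mask, ssh_mask], axis=2)
inp = target * mask  # zero-filled gaps (an assumption)

# Channel c along axis 2 recovers variable c unchanged.
```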
LSTMState1D
Bases: Module
Hidden state for a 1-D ConvLSTM gradient modulator.
Attributes:
| Name | Type | Description |
|---|---|---|
| h | Float[Array, 'B H_dim N'] | Hidden state tensor of shape (B, H_dim, N). |
| c | Float[Array, 'B H_dim N'] | Cell state tensor of shape (B, H_dim, N). |
zeros (classmethod)
zeros(
batch_size: int, hidden_dim: int, seq_len: int
) -> LSTMState1D
Create a zero-initialised LSTM state.
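This sketch shows what `zeros` is expected to produce, inferred from the attribute shapes in the table above. Both `h` and `c` have shape (B, H_dim, N), with `seq_len` filling the N slot. `LSTMState1DLike` is a hypothetical NumPy stand-in, not the equinox class.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class LSTMState1DLike:
    """Hypothetical stand-in for LSTMState1D: h, c shaped (B, H_dim, N)."""
    h: np.ndarray
    c: np.ndarray

    @classmethod
    def zeros(cls, batch_size: int, hidden_dim: int, seq_len: int) -> "LSTMState1DLike":
        shape = (batch_size, hidden_dim, seq_len)
        return cls(h=np.zeros(shape), c=np.zeros(shape))


state = LSTMState1DLike.zeros(batch_size=4, hidden_dim=16, seq_len=128)
```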
LSTMState2D
Bases: Module
Hidden state for a 2-D ConvLSTM gradient modulator.
Attributes:
| Name | Type | Description |
|---|---|---|
| h | Float[Array, 'B H_dim H W'] | Hidden state tensor of shape (B, H_dim, H, W). |
| c | Float[Array, 'B H_dim H W'] | Cell state tensor of shape (B, H_dim, H, W). |
zeros (classmethod)
zeros(
batch_size: int,
hidden_dim: int,
height: int,
width: int,
) -> LSTMState2D
Create a zero-initialised LSTM state.
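When driving a 2-D modulator, the spatial arguments to `zeros` have to match the batch's grid. This sketch derives them from a (B, T, H, W) input array, assuming `h` and `c` are shaped (B, H_dim, H, W) as in the table. Time does not appear in that shape, so it is dropped. `LSTMState2DLike` and `init_state_for` are hypothetical NumPy stand-ins. `hidden_dim` is presumably a model hyperparameter.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class LSTMState2DLike:
    """Hypothetical stand-in for LSTMState2D: h, c shaped (B, H_dim, H, W)."""
    h: np.ndarray
    c: np.ndarray

    @classmethod
    def zeros(cls, batch_size, hidden_dim, height, width):
        shape = (batch_size, hidden_dim, height, width)
        return cls(h=np.zeros(shape), c=np.zeros(shape))


def init_state_for(batch_input: np.ndarray, hidden_dim: int) -> LSTMState2DLike:
    """Read B, H, W off a (B, T, H, W) input; time is not part of the state."""
    b, _, height, width = batch_input.shape
    return LSTMState2DLike.zeros(b, hidden_dim, height, width)


x = np.zeros((3, 5, 32, 48))  # B=3, T=5, H=32, W=48
state = init_state_for(x, hidden_dim=8)
```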