Protocols & Types

vardax components are coupled structurally: any object with the right methods conforms, and there is no base class to inherit from. Three of the Protocols (AnalysisStep, ForwardModel, and ObservationOperator) are re-exported from pipekit-cycle and define the seam along which vardax plugs into assimilation cycles. The rest are vardax-specific and define the seams inside a variational method: prior, cost, gradient modulator, posterior adapter, and minimiser. All are runtime-checkable, so isinstance checks work at the boundaries; such checks confirm only that the required members exist, not that their signatures match.
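As a minimal sketch of what structural conformance means in practice, the following uses only the standard library. `HasStep`, `Decay`, `WrongArity`, and `NoStep` are hypothetical names chosen for illustration and are not part of vardax or pipekit-cycle.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasStep(Protocol):
    """Hypothetical stand-in for a structural contract with one method."""

    def step(self, state: Any, dt: float) -> Any: ...


class Decay:
    """Conforms without inheriting from HasStep: it simply has `step`."""

    def step(self, state: float, dt: float) -> float:
        return state * (1.0 - dt)


class WrongArity:
    """Also passes isinstance, because only the member name is checked."""

    def step(self) -> None:
        return None


class NoStep:
    """Has no `step` member, so it does not conform."""


conforms = isinstance(Decay(), HasStep)
wrong_arity_passes = isinstance(WrongArity(), HasStep)
missing_fails = isinstance(NoStep(), HasStep)
```

The `WrongArity` case is the reason to treat isinstance as a coarse boundary check and to rely on a static type checker for signature agreement.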

pipekit-cycle protocols

The orchestration contract. ForwardModel propagates state between cycle times, ObservationOperator maps state to observation space (with linearize for the tangent-linear operator), and AnalysisStep is the type returned by every vardax model's .as_analysis_step().

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

AnalysisStep

Bases: Protocol

Combine forecast state with observations to produce analysis state.

Implementations: ensemble Kalman analyses (EnKF, ETKF, LETKF), variational solvers (3D/4D-Var), particle filters, smoothers. Algorithm libraries supply concrete classes.

Members

__call__(forecast, obs, *, obs_op, obs_err_cov): Return the analysis state given the forecast, the observations, the observation operator, and the observation-error covariance.

Source code in .venv/lib/python3.12/site-packages/pipekit_cycle/protocols.py
@runtime_checkable
class AnalysisStep(Protocol):
    """Combine forecast state with observations to produce analysis state.

    Implementations: ensemble Kalman analyses (EnKF, ETKF, LETKF),
    variational solvers (3D/4D-Var), particle filters, smoothers.
    Algorithm libraries supply concrete classes.

    Members:
        __call__(forecast, obs, *, obs_op, obs_err_cov): Return the
            analysis state given the forecast, the observations, the
            observation operator, and the observation-error covariance.
    """

    def __call__(
        self,
        forecast: Any,
        obs: Any,
        *,
        obs_op: ObservationOperator,
        obs_err_cov: Any,
    ) -> Any: ...
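To make the calling convention concrete, here is a hedged sketch of a scalar Kalman-style update that satisfies the shape of this protocol. The `AnalysisStep` class below is a local copy of the signature shown above so that the block runs on its own; `ScaleOperator` and `ScalarKalmanAnalysis` are hypothetical names, and keeping a fixed forecast variance on the analysis object is an assumption made for illustration.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class AnalysisStep(Protocol):
    """Local stand-in mirroring the protocol signature shown above."""

    def __call__(
        self, forecast: Any, obs: Any, *, obs_op: Any, obs_err_cov: Any
    ) -> Any: ...


class ScaleOperator:
    """Hypothetical linear operator H(x) = h * x for a scalar state."""

    def __init__(self, h: float) -> None:
        self.h = h

    def __call__(self, state: float) -> float:
        return self.h * state

    def linearize(self, state: float) -> float:
        # The tangent-linear of a linear scalar map is its slope.
        return self.h


class ScalarKalmanAnalysis:
    """Hypothetical scalar analysis with a fixed forecast variance."""

    def __init__(self, forecast_var: float) -> None:
        self.forecast_var = forecast_var

    def __call__(self, forecast, obs, *, obs_op, obs_err_cov):
        h = obs_op.linearize(forecast)
        innovation = obs - obs_op(forecast)
        # Scalar Kalman gain: K = P h / (h P h + R).
        gain = self.forecast_var * h / (h * self.forecast_var * h + obs_err_cov)
        return forecast + gain * innovation


step = ScalarKalmanAnalysis(forecast_var=1.0)
analysis = step(2.0, 5.0, obs_op=ScaleOperator(2.0), obs_err_cov=4.0)
is_analysis_step = isinstance(step, AnalysisStep)
```

With these inputs the innovation is 1.0 and the gain is 0.25, so the analysis moves from 2.0 to 2.25.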

ForwardModel

Bases: Protocol

Advance a model state forward in time by dt.

Implementations: domain forward models (chemistry transport, ocean state, plume dispersion), neural emulators, hybrid physics + ML models. Algorithm libraries provide adapters that satisfy this protocol structurally.

Members

step(state, dt): Return the state advanced by dt.

dt: Default integration step.

state_signature: Optional pipekit.Signature describing the shape / dtype of the state carrier. None if the model doesn't track named dimensions.

Source code in .venv/lib/python3.12/site-packages/pipekit_cycle/protocols.py
@runtime_checkable
class ForwardModel(Protocol):
    """Advance a model state forward in time by ``dt``.

    Implementations: domain forward models (chemistry transport,
    ocean state, plume dispersion), neural emulators, hybrid
    physics + ML models. Algorithm libraries provide adapters
    that satisfy this protocol structurally.

    Members:
        step(state, dt): Return the state advanced by ``dt``.
        dt: Default integration step.
        state_signature: Optional `pipekit.Signature` describing the
            shape / dtype of the state carrier. ``None`` if the model
            doesn't track named dimensions.
    """

    def step(self, state: Any, dt: float) -> Any: ...

    @property
    def dt(self) -> float: ...

    @property
    def state_signature(self) -> Any: ...
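A small sketch of a conforming model may help here. The `ForwardModel` class below is a local copy of the members listed above, and `LinearDecay` is a hypothetical model that advances dx/dt = -rate * x with explicit Euler; neither the model nor the integrator is taken from pipekit-cycle.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ForwardModel(Protocol):
    """Local stand-in mirroring the protocol members shown above."""

    def step(self, state: Any, dt: float) -> Any: ...

    @property
    def dt(self) -> float: ...

    @property
    def state_signature(self) -> Any: ...


class LinearDecay:
    """Hypothetical model dx/dt = -rate * x, advanced with explicit Euler."""

    def __init__(self, rate: float, dt: float) -> None:
        self.rate = rate
        self._dt = dt

    def step(self, state: float, dt: float) -> float:
        return state + dt * (-self.rate * state)

    @property
    def dt(self) -> float:
        return self._dt

    @property
    def state_signature(self) -> None:
        # This toy model does not track named dimensions.
        return None


model = LinearDecay(rate=0.5, dt=0.5)
state = 8.0
for _ in range(2):
    # Advance by the model's default step: 8.0 -> 6.0 -> 4.5.
    state = model.step(state, model.dt)

is_forward_model = isinstance(model, ForwardModel)
```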

ObservationOperator

Bases: Protocol

Map model state → predicted observations.

The H operator in classical data-assimilation notation: H(x) produces "what would the observations look like if the state were x?". The innovation in DA is then obs - H(forecast).

Members

__call__(state): Return predicted observations for state.

linearize(state): Optional tangent-linear operator at state (returns a callable / matrix). Implementations that don't expose a linearisation may raise NotImplementedError.

Source code in .venv/lib/python3.12/site-packages/pipekit_cycle/protocols.py
@runtime_checkable
class ObservationOperator(Protocol):
    """Map model state → predicted observations.

    The H operator in classical data-assimilation notation:
    ``H(x)`` produces "what would the observations look like if the
    state were ``x``?". The innovation in DA is then ``obs - H(forecast)``.

    Members:
        __call__(state): Return predicted observations for ``state``.
        linearize(state): Optional tangent-linear operator at
            ``state`` (returns a callable / matrix). Implementations
            that don't expose a linearisation may raise
            ``NotImplementedError``.
    """

    def __call__(self, state: Any) -> Any: ...

    def linearize(self, state: Any) -> Any: ...
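The innovation obs - H(forecast) can be sketched with a point-sampling operator. `PointSampler` is a hypothetical class, the `ObservationOperator` below is a local copy of the signature shown above, and representing the tangent-linear as a nested list (a dense selection matrix) is an assumption made to keep the example free of array libraries.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ObservationOperator(Protocol):
    """Local stand-in mirroring the protocol signature shown above."""

    def __call__(self, state: Any) -> Any: ...

    def linearize(self, state: Any) -> Any: ...


class PointSampler:
    """Hypothetical operator that observes the state at fixed indices."""

    def __init__(self, indices: list[int]) -> None:
        self.indices = indices

    def __call__(self, state: list[float]) -> list[float]:
        return [state[i] for i in self.indices]

    def linearize(self, state: list[float]) -> list[list[float]]:
        # Sampling is linear, so the tangent-linear is a constant
        # selection matrix with one row per observed index.
        n = len(state)
        return [[1.0 if j == i else 0.0 for j in range(n)] for i in self.indices]


obs_op = PointSampler([0, 2])
forecast = [1.0, 2.0, 3.0]
obs = [1.5, 2.0]

# Innovation: observations minus predicted observations.
innovation = [y - hy for y, hy in zip(obs, obs_op(forecast))]
jacobian = obs_op.linearize(forecast)
is_obs_op = isinstance(obs_op, ObservationOperator)
```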

vardax protocols

The internal seams: implement these to swap in custom priors, cost functions, gradient modulators, posterior adapters, or inner-loop minimisers without touching the model classes.

Prior

Bases: Protocol

Prior model: maps state to its regularised reconstruction.

For an autoencoder prior \(\varphi_\theta\), the variational cost includes \(\lVert x - \varphi(x) \rVert^2\). For a dynamical prior wrapping a ForwardModel, \(\varphi(x)\) is the forward integration. For the identity prior, \(\varphi(x) = x\).

Members

__call__(x) -> x_prior — apply the prior model.

Source code in src/vardax/_src/protocols.py
@runtime_checkable
class Prior(Protocol):
    r"""Prior model: maps state to its regularised reconstruction.

    For an autoencoder prior $\varphi_\theta$, the variational
    cost includes ``||x - φ(x)||^2``. For a dynamical prior wrapping a
    `ForwardModel`, ``φ(x)`` is the forward integration. For the
    identity prior, ``φ(x) = x``.

    Members:
        ``__call__(x) -> x_prior`` — apply the prior model.
    """

    def __call__(self, x: Float[Array, ...]) -> Float[Array, ...]: ...
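The prior term ||x - φ(x)||^2 can be illustrated with two toy priors on plain Python lists. `IdentityPrior`, `MovingAveragePrior`, and `prior_cost` are hypothetical names; the moving average is only a stand-in for a learned autoencoder, and real vardax priors operate on arrays rather than lists.

```python
class IdentityPrior:
    """phi(x) = x, so the prior term vanishes."""

    def __call__(self, x: list[float]) -> list[float]:
        return list(x)


class MovingAveragePrior:
    """Hypothetical smoothing prior: 3-point mean with repeated edges."""

    def __call__(self, x: list[float]) -> list[float]:
        padded = [x[0], *x, x[-1]]
        return [
            (padded[i] + padded[i + 1] + padded[i + 2]) / 3.0
            for i in range(len(x))
        ]


def prior_cost(prior, x: list[float]) -> float:
    """Squared distance between x and its reconstruction phi(x)."""
    return sum((xi - pi) ** 2 for xi, pi in zip(x, prior(x)))


x = [0.0, 3.0, 0.0]
smoothed = MovingAveragePrior()(x)
identity_cost = prior_cost(IdentityPrior(), x)
smoothing_cost = prior_cost(MovingAveragePrior(), x)
```

A spike such as `[0.0, 3.0, 0.0]` is smoothed to `[1.0, 1.0, 1.0]`, so the smoothing prior penalises it with a cost of 6.0 while the identity prior assigns zero cost.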

CostFunction

Bases: Protocol

Variational cost function J(x, batch, **kwargs) -> scalar.

Implementations include vardax.costs.ThreeDVarCost, StrongConstraintCost, WeakConstraintCost, IncrementalCost, and FourDVarNetCost.

Members

__call__(x, batch, **kwargs) -> scalar

Source code in src/vardax/_src/protocols.py
@runtime_checkable
class CostFunction(Protocol):
    """Variational cost function ``J(x, batch, **kwargs) -> scalar``.

    Implementations include `vardax.costs.ThreeDVarCost`,
    `StrongConstraintCost`, `WeakConstraintCost`, `IncrementalCost`,
    and `FourDVarNetCost`.

    Members:
        ``__call__(x, batch, **kwargs) -> scalar``
    """

    def __call__(
        self,
        x: Float[Array, ...],
        batch: Any,
        **kwargs: Any,
    ) -> Float[Array, ""]: ...
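As a rough sketch of the `J(x, batch, **kwargs)` convention, the following toy cost combines a masked observation term with a background term. `ToyBatch`, `MaskedObsCost`, and the `alpha` keyword are hypothetical and are not the API of the vardax.costs classes listed above; states are plain lists so the block needs no array library.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    """Hypothetical flat batch with the same field names as Batch1D."""

    input: list[float]
    mask: list[float]


class MaskedObsCost:
    """Hypothetical cost: masked misfit plus weighted background term."""

    def __init__(self, background: list[float]) -> None:
        self.background = background

    def __call__(self, x: list[float], batch: ToyBatch, **kwargs) -> float:
        alpha = kwargs.get("alpha", 1.0)  # assumed background weight
        # Only observed points (mask == 1) contribute to the misfit.
        obs_term = sum(
            m * (xi - yi) ** 2 for xi, yi, m in zip(x, batch.input, batch.mask)
        )
        bg_term = sum((xi - bi) ** 2 for xi, bi in zip(x, self.background))
        return obs_term + alpha * bg_term


cost = MaskedObsCost(background=[1.0, 1.0, 1.0])
batch = ToyBatch(input=[2.0, 0.0, 3.0], mask=[1.0, 0.0, 1.0])
x = [1.0, 2.0, 3.0]

j_default = cost(x, batch)
j_weighted = cost(x, batch, alpha=0.5)
```

Here the observation term is 1.0 and the background term is 5.0, giving 6.0 with the default weight and 3.5 with `alpha=0.5`.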

GradModulator

Bases: Protocol

Learned gradient modulator for the FourDVarNet inner solver.

Takes the current variational-cost gradient, the current state, and the modulator's own carry; returns a state update and the new carry. Used only by FourDVarNet; the classical analysis methods use optimistix.AbstractMinimiser instead.

Members

__call__(grad, state, carry) -> (update, new_carry)

Source code in src/vardax/_src/protocols.py
@runtime_checkable
class GradModulator(Protocol):
    """Learned gradient modulator for the FourDVarNet inner solver.

    Takes the current variational-cost gradient and the modulator's
    own carry state, returns a state update and the new carry. Used
    only by `FourDVarNet`; the classical analysis methods use
    ``optimistix.AbstractMinimiser`` instead.

    Members:
        ``__call__(grad, state, carry) -> (update, new_carry)``
    """

    def __call__(
        self,
        grad: Float[Array, ...],
        state: Float[Array, ...],
        carry: Any,
    ) -> tuple[Float[Array, ...], Any]: ...
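The `(grad, state, carry) -> (update, new_carry)` shape can be sketched with a hand-written momentum rule in place of a learned network. `MomentumModulator` is a hypothetical class, and applying the update as `x + update` inside the loop is an assumption about how a solver might consume it, not a description of FourDVarNet's inner loop.

```python
class MomentumModulator:
    """Hypothetical non-learned modulator: heavy-ball momentum."""

    def __init__(self, lr: float, beta: float) -> None:
        self.lr = lr
        self.beta = beta

    def __call__(self, grad, state, carry):
        # The carry holds the running velocity between inner iterations.
        velocity = [self.beta * v + g for v, g in zip(carry, grad)]
        update = [-self.lr * v for v in velocity]
        return update, velocity


def grad_half_square(x):
    """Gradient of the toy cost J(x) = 0.5 * sum(x_i ** 2)."""
    return list(x)


modulator = MomentumModulator(lr=0.5, beta=0.5)
x = [4.0]
carry = [0.0]
for _ in range(2):
    update, carry = modulator(grad_half_square(x), x, carry)
    # Assumed convention: the solver adds the update to the state.
    x = [xi + ui for xi, ui in zip(x, update)]
```

A learned modulator such as a ConvLSTM would replace the fixed momentum rule with a network, and the carry would hold its recurrent state instead of a velocity.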

PosteriorAdapter

Bases: Protocol

Turns an analysis output into a Posterior container.

Implementations: LaplaceCovariance, GaussNewtonHessian, EnsembleCovariance. Each computes the posterior covariance via a different approximation, but all return the same Posterior contract.

Members

__call__(analysis, model, batch) -> Posterior

Source code in src/vardax/_src/protocols.py
@runtime_checkable
class PosteriorAdapter(Protocol):
    """Turns an analysis output into a `Posterior` container.

    Implementations: `LaplaceCovariance`, `GaussNewtonHessian`,
    `EnsembleCovariance`. Each computes the posterior covariance via a
    different approximation; the contract returned is the same.

    Members:
        ``__call__(analysis, model, batch) -> Posterior``
    """

    def __call__(
        self,
        analysis: Float[Array, ...],
        model: AnalysisStep,
        batch: Any,
    ) -> Any: ...

Minimiser

Bases: Protocol

Wrapper protocol around optimistix.AbstractMinimiser.

A Minimiser knows how to minimise a CostFunction from an initial guess x0 against a batch. Implementations adapt optimistix solvers (GaussNewton, BFGS, NonlinearCG, …) to vardax's cost-function calling convention.

Members

__call__(cost_fn, x0, batch) -> x_star

Source code in src/vardax/_src/protocols.py
@runtime_checkable
class Minimiser(Protocol):
    """Wrapper protocol around ``optimistix.AbstractMinimiser``.

    A `Minimiser` knows how to minimise a `CostFunction` from an
    initial guess ``x0`` against a batch. Implementations adapt
    optimistix solvers (GaussNewton, BFGS, NonlinearCG, …) to vardax's
    cost-function calling convention.

    Members:
        ``__call__(cost_fn, x0, batch) -> x_star``
    """

    def __call__(
        self,
        cost_fn: CostFunction,
        x0: Float[Array, ...],
        batch: Any,
    ) -> Float[Array, ...]: ...
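To show the `(cost_fn, x0, batch) -> x_star` convention without depending on optimistix, here is a sketch that swaps in plain gradient descent with central finite differences. `FiniteDifferenceDescent` and `quadratic_cost` are hypothetical, and this is deliberately not one of the optimistix solvers named above.

```python
class FiniteDifferenceDescent:
    """Hypothetical minimiser: gradient descent with numerical gradients."""

    def __init__(self, lr: float = 0.1, steps: int = 100, eps: float = 1e-6):
        self.lr = lr
        self.steps = steps
        self.eps = eps

    def __call__(self, cost_fn, x0, batch):
        x = list(x0)
        for _ in range(self.steps):
            grad = []
            for i in range(len(x)):
                # Central difference in coordinate i.
                hi = x.copy()
                lo = x.copy()
                hi[i] += self.eps
                lo[i] -= self.eps
                grad.append(
                    (cost_fn(hi, batch) - cost_fn(lo, batch)) / (2 * self.eps)
                )
            x = [xi - self.lr * gi for xi, gi in zip(x, grad)]
        return x


def quadratic_cost(x, batch, **kwargs):
    """Toy cost: squared distance to a target carried in the batch."""
    return sum((xi - ti) ** 2 for xi, ti in zip(x, batch))


minimiser = FiniteDifferenceDescent(lr=0.25, steps=50)
x_star = minimiser(quadratic_cost, [0.0, 0.0], [1.0, -2.0])
```

The minimiser only ever calls `cost_fn(x, batch)`, which is the point of the wrapper: any solver that can be driven through that call can sit behind the same interface.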

Batch & state types

The typed carriers that flow through the solvers and training loops: 1D and 2D (single- and multi-variable) observation batches, and the recurrent-state containers of the ConvLSTM gradient modulators.

Batch1D

Bases: Module

Batch of 1-D spatiotemporal data.

Attributes:

input (Float[Array, 'B T N']): Observed (masked) input field of shape (B, T, N).

mask (Float[Array, 'B T N']): Binary observation mask of shape (B, T, N).

target (Float[Array, 'B T N'] | None): Ground-truth field of shape (B, T, N). Optional; absent in operational use where only observations are available.

Source code in src/vardax/_src/_types.py
class Batch1D(eqx.Module):
    """Batch of 1-D spatiotemporal data.

    Attributes:
        input: Observed (masked) input field of shape ``(B, T, N)``.
        mask: Binary observation mask of shape ``(B, T, N)``.
        target: Ground-truth field of shape ``(B, T, N)``. Optional; absent
            in operational use where only observations are available.
    """

    input: Float[Array, "B T N"]
    mask: Float[Array, "B T N"]
    target: Float[Array, "B T N"] | None = None

Batch2D

Bases: Module

Batch of 2-D spatiotemporal data.

Attributes:

input (Float[Array, 'B T H W']): Observed (masked) input field of shape (B, T, H, W).

mask (Float[Array, 'B T H W']): Binary observation mask of shape (B, T, H, W).

target (Float[Array, 'B T H W'] | None): Ground-truth field of shape (B, T, H, W). Optional.

Source code in src/vardax/_src/_types.py
class Batch2D(eqx.Module):
    """Batch of 2-D spatiotemporal data.

    Attributes:
        input: Observed (masked) input field of shape ``(B, T, H, W)``.
        mask: Binary observation mask of shape ``(B, T, H, W)``.
        target: Ground-truth field of shape ``(B, T, H, W)``. Optional.
    """

    input: Float[Array, "B T H W"]
    mask: Float[Array, "B T H W"]
    target: Float[Array, "B T H W"] | None = None

Batch2DMultivar

Bases: Module

Batch of 2-D multivariate spatiotemporal data.

Attributes:

input (Float[Array, 'B T C H W']): Observed (masked) input field of shape (B, T, C, H, W).

mask (Float[Array, 'B T C H W']): Binary observation mask of shape (B, T, C, H, W).

target (Float[Array, 'B T C H W'] | None): Ground-truth field of shape (B, T, C, H, W). Optional.

Source code in src/vardax/_src/_types.py
class Batch2DMultivar(eqx.Module):
    """Batch of 2-D multivariate spatiotemporal data.

    Attributes:
        input: Observed (masked) input field of shape ``(B, T, C, H, W)``.
        mask: Binary observation mask of shape ``(B, T, C, H, W)``.
        target: Ground-truth field of shape ``(B, T, C, H, W)``. Optional.
    """

    input: Float[Array, "B T C H W"]
    mask: Float[Array, "B T C H W"]
    target: Float[Array, "B T C H W"] | None = None

LSTMState1D

Bases: Module

Hidden state for a 1-D ConvLSTM gradient modulator.

Attributes:

h (Float[Array, 'B H_dim N']): Hidden state tensor of shape (B, H_dim, N).

c (Float[Array, 'B H_dim N']): Cell state tensor of shape (B, H_dim, N).

Source code in src/vardax/_src/_types.py
class LSTMState1D(eqx.Module):
    """Hidden state for a 1-D ConvLSTM gradient modulator.

    Attributes:
        h: Hidden state tensor of shape ``(B, H_dim, N)``.
        c: Cell state tensor of shape ``(B, H_dim, N)``.
    """

    h: Float[Array, "B H_dim N"]
    c: Float[Array, "B H_dim N"]

    @classmethod
    def zeros(
        cls,
        batch_size: int,
        hidden_dim: int,
        seq_len: int,
    ) -> LSTMState1D:
        """Create a zero-initialised LSTM state."""
        return cls(
            h=jnp.zeros((batch_size, hidden_dim, seq_len)),
            c=jnp.zeros((batch_size, hidden_dim, seq_len)),
        )

zeros classmethod

zeros(
    batch_size: int, hidden_dim: int, seq_len: int
) -> LSTMState1D

Create a zero-initialised LSTM state.

LSTMState2D

Bases: Module

Hidden state for a 2-D ConvLSTM gradient modulator.

Attributes:

h (Float[Array, 'B H_dim H W']): Hidden state tensor of shape (B, H_dim, H, W).

c (Float[Array, 'B H_dim H W']): Cell state tensor of shape (B, H_dim, H, W).

Source code in src/vardax/_src/_types.py
class LSTMState2D(eqx.Module):
    """Hidden state for a 2-D ConvLSTM gradient modulator.

    Attributes:
        h: Hidden state tensor of shape ``(B, H_dim, H, W)``.
        c: Cell state tensor of shape ``(B, H_dim, H, W)``.
    """

    h: Float[Array, "B H_dim H W"]
    c: Float[Array, "B H_dim H W"]

    @classmethod
    def zeros(
        cls,
        batch_size: int,
        hidden_dim: int,
        height: int,
        width: int,
    ) -> LSTMState2D:
        """Create a zero-initialised LSTM state."""
        return cls(
            h=jnp.zeros((batch_size, hidden_dim, height, width)),
            c=jnp.zeros((batch_size, hidden_dim, height, width)),
        )

zeros classmethod

zeros(
    batch_size: int,
    hidden_dim: int,
    height: int,
    width: int,
) -> LSTMState2D

Create a zero-initialised LSTM state.
