
Cycle Integration

Sequential assimilation (forecast, observe, analyse, repeat) is orchestrated by pipekit-cycle rather than reimplemented here. The two factories on this page are thin wrappers that assemble a pipekit_cycle.DACycle (filtering) or a pipekit_cycle.SmootherCycle (fixed-lag smoothing) from vardax parts: any model's .as_analysis_step(), any observation operator, and a forward model satisfying the ForwardModel Protocol.

Because the coupling is purely structural (runtime-checkable Protocols, no inheritance), the same cycle accepts every vardax method interchangeably: swapping 3DVar for strong-constraint 4DVar inside a cycle changes only the model argument. The full forecast/analysis loop is worked through in Six-Step Inference Cycle. Both factories are also accessible via the vardax.cycle submodule namespace.
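To make "purely structural" concrete, the sketch below uses only the standard library's typing.Protocol with @runtime_checkable. The protocol HasAnalysisStep, the two toy model classes and build_cycle are hypothetical stand-ins, not the actual pipekit-cycle or vardax definitions. Note that an isinstance check against a runtime-checkable Protocol only verifies that the named members exist; it does not check their signatures.

```python
from typing import Callable, Protocol, runtime_checkable


@runtime_checkable
class HasAnalysisStep(Protocol):
    """Hypothetical protocol: anything exposing .as_analysis_step()."""

    def as_analysis_step(self) -> Callable: ...


class ToyThreeDVar:
    # Hypothetical model; it does NOT inherit from HasAnalysisStep.
    def as_analysis_step(self):
        return lambda forecast, obs: ("3dvar", forecast, obs)


class ToyStrongFourDVar:
    # An unrelated class with the same method also satisfies the protocol.
    def as_analysis_step(self):
        return lambda forecast, obs: ("4dvar", forecast, obs)


def build_cycle(model):
    """Hypothetical factory: structural check, then pull the analysis step."""
    if not isinstance(model, HasAnalysisStep):
        raise AttributeError(f"{type(model).__name__} lacks .as_analysis_step()")
    return model.as_analysis_step()


# Swapping the assimilation method only changes the argument passed in.
step_a = build_cycle(ToyThreeDVar())
step_b = build_cycle(ToyStrongFourDVar())
```

The same duck-typing idea is what lets a single cycle factory accept any of the Layer 2 classes without a shared base class.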

Factories

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

VarDACycle

VarDACycle(
    forward: Any,
    obs_op: Any,
    model: Any,
    *,
    obs_source: Any | None = None,
    n_steps: int = 1,
    save_history: bool = True,
) -> DACycle

Build a pipekit_cycle.DACycle from a vardax model.

Calls model.as_analysis_step() and passes the resulting analysis step to pipekit-cycle. The same orchestration code works for any of the seven Layer 2 classes: OptimalInterpolation, ThreeDVar, StrongFourDVar, WeakFourDVar, IncrementalFourDVar, FourDVarNet, AmortizedPosterior.

Parameters:

forward (Any, required): Forward model satisfying pipekit_cycle.ForwardModel.

obs_op (Any, required): Observation operator satisfying pipekit_cycle.ObservationOperator.

model (Any, required): Vardax Layer 2 model exposing .as_analysis_step().

obs_source (Any | None, default None): Optional pipekit.Operator invoked as obs_source(step_index) to load observations per cycle. If None, the analysis step is skipped and the forecast propagates forward.

n_steps (int, default 1): Number of forecast-analysis cycles per call.

save_history (bool, default True): Append (forecast, analysis) pairs to cycle.history each cycle.
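The parameters above imply a simple control flow. The plain-Python sketch below illustrates it under stated assumptions: the forward model exposes step(state, dt) and a dt attribute (as IdentityForward in the example below does), the analysis step is called as analysis_step(forecast, obs), and obs_source(step_index) returns that step's observations. run_cycle, AddOne and nudge are hypothetical names; this is not pipekit-cycle's implementation.

```python
def run_cycle(forward, analysis_step, state, *, obs_source=None,
              n_steps=1, save_history=True):
    """Hypothetical forecast/analysis loop mirroring the documented parameters."""
    history = []
    for k in range(n_steps):
        # Forecast: propagate the current state by one model step.
        forecast = forward.step(state, forward.dt)
        if obs_source is None:
            # No observations: skip the analysis, the forecast propagates forward.
            analysis = forecast
        else:
            obs = obs_source(k)
            analysis = analysis_step(forecast, obs)
        if save_history:
            history.append((forecast, analysis))
        # The analysis becomes the initial condition of the next cycle.
        state = analysis
    return state, history


class AddOne:
    """Toy forward model: the state advances by dt each step."""
    dt = 1.0

    def step(self, state, dt):
        return state + dt


def nudge(forecast, obs):
    """Toy analysis step: move halfway from the forecast toward obs."""
    return forecast + 0.5 * (obs - forecast)


final, history = run_cycle(AddOne(), nudge, 0.0, n_steps=2)
# No obs_source, so every analysis equals its forecast: final == 2.0
```

With an obs_source supplied, each cycle pulls observations for its step index before analysing; with save_history=False the history list stays empty.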

Returns:

DACycle: Configured pipekit_cycle.DACycle ready to call.

Raises:

AttributeError: If model does not expose .as_analysis_step().

Examples:

>>> import equinox as eqx
>>> import jax, jax.numpy as jnp, lineax as lx
>>> import vardax as vdx
>>> class IdentityForward(eqx.Module):
...     dt: float = 1.0
...     state_signature: None = None
...
...     def step(self, state, dt):
...         return state
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((1, 4), jnp.float32))
>>> oi = vdx.OptimalInterpolation(
...     obs_op=vdx.LinearObs(H_mat=eye),
...     prior_mean=jnp.zeros((1, 4)),
...     prior_cov_op=eye,
...     obs_cov_op=eye,
... )
>>> cycle = vdx.VarDACycle(
...     forward=IdentityForward(),
...     obs_op=vdx.MaskedIdentity(),
...     model=oi,
...     n_steps=1,
... )
>>> type(cycle).__name__
'DACycle'
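For intuition about what the oi model in the example computes, recall the textbook optimal-interpolation (BLUE) update x_a = x_b + K(y - H x_b) with gain K = B Hᵀ(H B Hᵀ + R)⁻¹. With H, B and R all equal to the identity, as in the example, K = ½I, so the analysis lands halfway between the prior mean and the observations. The NumPy sketch below evaluates that formula directly; oi_analysis and the observation vector y are hypothetical, it flattens the (1, 4) state to a length-4 vector, and it is not vardax's implementation.

```python
import numpy as np


def oi_analysis(x_b, y, H, B, R):
    """Textbook OI/BLUE update: x_a = x_b + K (y - H x_b)."""
    S = H @ B @ H.T + R  # innovation covariance
    # K = B H^T S^{-1}, computed with a linear solve instead of an explicit inverse.
    K = np.linalg.solve(S.T, (B @ H.T).T).T
    return x_b + K @ (y - H @ x_b)


n = 4
eye = np.eye(n)
x_b = np.zeros(n)  # matches prior_mean=jnp.zeros((1, 4)), flattened
y = np.ones(n)     # hypothetical observations
x_a = oi_analysis(x_b, y, H=eye, B=eye, R=eye)  # every entry is 0.5
```

Doubling B (trusting the prior less) would raise the gain to 2/3, pulling the analysis closer to y.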
Source code in src/vardax/_src/cycle/da_cycle.py
def VarDACycle(
    forward: Any,
    obs_op: Any,
    model: Any,
    *,
    obs_source: Any | None = None,
    n_steps: int = 1,
    save_history: bool = True,
) -> DACycle:
    """Build a ``pipekit_cycle.DACycle`` from a vardax model.

    Pulls ``model.as_analysis_step()`` and hands it to pipekit-cycle.
    The same orchestration code works for any of the seven Layer 2
    classes — `OptimalInterpolation`, `ThreeDVar`, `StrongFourDVar`,
    `WeakFourDVar`, `IncrementalFourDVar`, `FourDVarNet`,
    `AmortizedPosterior`.

    Args:
        forward: Forward model satisfying ``pipekit_cycle.ForwardModel``.
        obs_op: Observation operator satisfying
            ``pipekit_cycle.ObservationOperator``.
        model: Vardax Layer 2 model exposing ``.as_analysis_step()``.
        obs_source: Optional ``pipekit.Operator`` invoked as
            ``obs_source(step_index)`` to load observations per cycle.
            If ``None``, the analysis step is skipped and the forecast
            propagates forward.
        n_steps: Number of forecast-analysis cycles per call.
        save_history: Append ``(forecast, analysis)`` pairs to
            ``cycle.history`` each cycle.

    Returns:
        Configured ``pipekit_cycle.DACycle`` ready to call.

    Raises:
        AttributeError: If ``model`` does not expose
            ``.as_analysis_step()``.

    Examples:
        >>> import equinox as eqx
        >>> import jax, jax.numpy as jnp, lineax as lx
        >>> import vardax as vdx
        >>> class IdentityForward(eqx.Module):
        ...     dt: float = 1.0
        ...     state_signature: None = None
        ...
        ...     def step(self, state, dt):
        ...         return state
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((1, 4), jnp.float32))
        >>> oi = vdx.OptimalInterpolation(
        ...     obs_op=vdx.LinearObs(H_mat=eye),
        ...     prior_mean=jnp.zeros((1, 4)),
        ...     prior_cov_op=eye,
        ...     obs_cov_op=eye,
        ... )
        >>> cycle = vdx.VarDACycle(
        ...     forward=IdentityForward(),
        ...     obs_op=vdx.MaskedIdentity(),
        ...     model=oi,
        ...     n_steps=1,
        ... )
        >>> type(cycle).__name__
        'DACycle'
    """
    if not hasattr(model, "as_analysis_step"):
        raise AttributeError(
            "VarDACycle requires a model exposing `.as_analysis_step()`; "
            f"got {type(model).__name__}."
        )
    return _pc.DACycle(
        forward_model=forward,
        obs_op=obs_op,
        analysis_step=model.as_analysis_step(),
        obs_source=obs_source,
        n_steps=n_steps,
        save_history=save_history,
    )

VarSmootherCycle

VarSmootherCycle(
    forward: Any,
    obs_op: Any,
    model: Any,
    *,
    window: int,
    stride: int = 1,
    obs_source: Any | None = None,
) -> SmootherCycle

Build a pipekit_cycle.SmootherCycle from a vardax model.

Retrospective windowed smoothing: each smoother window spans window forecast steps and is analysed by the model; consecutive windows start stride steps apart.
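One way to read window and stride: over a record of n_total steps, windows start at 0, stride, 2·stride and so on, and each covers window consecutive steps. The sketch below enumerates those index ranges under the assumption that only complete windows are analysed; how pipekit_cycle.SmootherCycle treats a trailing partial window is not stated here, and window_ranges is a hypothetical helper.

```python
def window_ranges(n_total: int, window: int, stride: int = 1):
    """Hypothetical helper: half-open [start, stop) ranges of complete windows."""
    if window < 1 or stride < 1:
        raise ValueError("window and stride must be positive")
    # Last valid start is n_total - window, so the final window ends at n_total.
    return [(s, s + window) for s in range(0, n_total - window + 1, stride)]


overlapping = window_ranges(6, window=3, stride=1)
disjoint = window_ranges(6, window=3, stride=3)
```

Setting stride equal to window gives non-overlapping windows; stride=1 gives maximally overlapping windows, so each step can be re-analysed up to window times.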

Parameters:

forward (Any, required): Forward model satisfying pipekit_cycle.ForwardModel.

obs_op (Any, required): Observation operator satisfying pipekit_cycle.ObservationOperator.

model (Any, required): Vardax Layer 2 model exposing .as_analysis_step().

window (int, required): Number of forecast steps per smoother window.

stride (int, default 1): Step between consecutive window starts.

obs_source (Any | None, default None): Optional observation loader.

Returns:

SmootherCycle: Configured pipekit_cycle.SmootherCycle.

Raises:

AttributeError: If model does not expose .as_analysis_step().

Source code in src/vardax/_src/cycle/da_cycle.py
def VarSmootherCycle(
    forward: Any,
    obs_op: Any,
    model: Any,
    *,
    window: int,
    stride: int = 1,
    obs_source: Any | None = None,
) -> SmootherCycle:
    """Build a ``pipekit_cycle.SmootherCycle`` from a vardax model.

    Retrospective windowed smoothing: the model analyses each
    ``window``-step window, sliding by ``stride`` between consecutive
    windows.

    Args:
        forward: Forward model.
        obs_op: Observation operator.
        model: Vardax Layer 2 model exposing ``.as_analysis_step()``.
        window: Number of forecast steps per smoother window.
        stride: Step between consecutive window starts.
        obs_source: Optional observation loader.

    Returns:
        Configured ``pipekit_cycle.SmootherCycle``.
    """
    if not hasattr(model, "as_analysis_step"):
        raise AttributeError(
            "VarSmootherCycle requires a model exposing `.as_analysis_step()`; "
            f"got {type(model).__name__}."
        )
    return _pc.SmootherCycle(
        forward_model=forward,
        obs_op=obs_op,
        analysis_step=model.as_analysis_step(),
        window=window,
        stride=stride,
        obs_source=obs_source,
    )