
Models

The seven analysis methods are peers: sibling equinox.Module classes with no inheritance between them, each owning exactly the assumptions its math requires. Pick by problem structure. A single analysis time wants OI (linear observation operator only) or 3DVar (linear or non-linear operator); an assimilation window with perfect-model dynamics wants strong-constraint 4DVar; admitting model error turns that into weak-constraint 4DVar; the operational linearise-and-iterate formulation with a control-variable transform is incremental 4DVar; and replacing the inner-loop optimiser with a trained ConvLSTM gives 4DVarNet.

Every model exposes .as_analysis_step(), returning a lightweight wrapper that satisfies the pipekit-cycle AnalysisStep Protocol — that is the seam through which all seven plug into VarDACycle / VarSmootherCycle interchangeably.

Classical methods

Closed-form and optimisation-based analyses. OptimalInterpolation is the linear-Gaussian BLUE solution and refuses non-linear observation operators at construction; ThreeDVar minimises the same cost iteratively and accepts non-linear operators; the three 4DVar variants extend the cost over a time window. IncrementalConfig collects the outer/inner-loop knobs of IncrementalFourDVar.
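The cost functions minimised by these classes, read off their `cost` implementations in the source blocks on this page, can be summarised as follows. The notation is ours rather than the library's: \(m\) and \(m_t\) are the observation masks, \(\odot\) is elementwise multiplication, \(\mathcal{M}\) is one call to `forward.step`, and \(H_m = \operatorname{diag}(m)\,H\) is the masked operator that OptimalInterpolation linearises.

```latex
% OptimalInterpolation (linear H, mask folded into the operator)
x^\ast = x_b + B H_m^\top \left(H_m B H_m^\top + R\right)^{-1} \, m \odot (y - H x_b)

% ThreeDVar
J_{\mathrm{3D}}(x) = \tfrac12 (x - x_b)^\top B^{-1} (x - x_b) + \tfrac12\, r^\top R^{-1} r,
\qquad r = m \odot \bigl(y - H(x)\bigr)

% StrongFourDVar: x_t = \mathcal{M}(x_{t-1}) for t = 1, \dots, T
J_{\mathrm{S}}(x_0) = \tfrac12 (x_0 - x_b)^\top B^{-1} (x_0 - x_b)
  + \tfrac12 \sum_{t=0}^{T} r_t^\top R^{-1} r_t,
\qquad r_t = m_t \odot \bigl(y_t - H(x_t)\bigr)

% WeakFourDVar: x_t = \mathcal{M}(x_{t-1}) + \eta_t, control (x_0, \eta_1, \dots, \eta_T)
J_{\mathrm{W}}(x_0, \eta_{1:T}) = \tfrac12 (x_0 - x_b)^\top B^{-1} (x_0 - x_b)
  + \tfrac12 \sum_{t=1}^{T} \eta_t^\top Q^{-1} \eta_t
  + \tfrac12 \sum_{t=0}^{T} r_t^\top R^{-1} r_t
```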

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

OptimalInterpolation

Bases: Module

BLUE / OI — closed-form linear-Gaussian analysis.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| obs_op | Any | Observation operator. Must be linear (its linearize(x) must return an operator independent of x). Use ThreeDVar for non-linear H. |
| prior_mean | Float[Array, 'T N'] | Background \(x_b\) of shape (T, N). |
| prior_cov_op | AbstractLinearOperator | Background-error covariance \(B\) as a lineax.AbstractLinearOperator. |
| obs_cov_op | AbstractLinearOperator | Observation-error covariance \(R\) as a lineax.AbstractLinearOperator. |
| cg_atol | float | CG absolute tolerance for the inner solve. Default 1e-6. |
| cg_rtol | float | CG relative tolerance for the inner solve. Default 1e-6. |
| cg_max_steps | int | CG iteration cap. Default 200. |

Examples:

With \(B = R = I\), identity \(H\) and everything observed, the analysis splits the innovation in half: \(x^* = y / 2\).

>>> import jax, jax.numpy as jnp, lineax as lx, vardax
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((1, 3), jnp.float32))
>>> oi = vardax.OptimalInterpolation(
...     obs_op=vardax.MaskedIdentity(),
...     prior_mean=jnp.zeros((1, 3)),
...     prior_cov_op=eye,
...     obs_cov_op=eye,
... )
>>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
>>> xa = oi(batch)
>>> xa.shape
(1, 1, 3)
>>> bool(jnp.allclose(xa, 0.5, atol=1e-4))
True
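The doctest's arithmetic can be checked with a dense NumPy version of the same update. This is an illustrative sketch: it forms matrices explicitly, which OptimalInterpolation avoids by composing lineax operators, and `blue` is a hypothetical helper, not part of vardax.

```python
import numpy as np

def blue(x_b, B, H, R, y):
    """Dense BLUE / OI analysis: x* = x_b + B H^T (H B H^T + R)^{-1} (y - H x_b)."""
    innovation = y - H @ x_b
    v = np.linalg.solve(H @ B @ H.T + R, innovation)  # solve instead of inverting
    return x_b + B @ H.T @ v

I = np.eye(3)
x_star = blue(np.zeros(3), B=I, H=I, R=I, y=np.ones(3))       # equal weights: y / 2
x_noisy = blue(np.zeros(3), B=I, H=I, R=3 * I, y=np.ones(3))  # R = 3B: y / 4
```

In the scalar case the gain is \(b/(b+r)\), so raising the observation-error variance to three times the background variance cuts the weight on \(y\) from 1/2 to 1/4.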
Source code in src/vardax/_src/models/optimal_interpolation.py
class OptimalInterpolation(eqx.Module):
    r"""BLUE / OI — closed-form linear-Gaussian analysis.

    Attributes:
        obs_op: Observation operator. Must be linear (its
            ``linearize(x)`` must return an operator independent of
            ``x``). Use [`ThreeDVar`][vardax.ThreeDVar] for non-linear
            ``H``.
        prior_mean: Background $x_b$ of shape ``(T, N)``.
        prior_cov_op: Background-error covariance $B$ as a
            ``lineax.AbstractLinearOperator``.
        obs_cov_op: Observation-error covariance $R$ as a
            ``lineax.AbstractLinearOperator``.
        cg_atol: CG absolute tolerance for the inner solve.
        cg_rtol: CG relative tolerance for the inner solve.
        cg_max_steps: CG iteration cap.

    Examples:
        With $B = R = I$, identity $H$ and everything observed, the
        analysis splits the innovation in half: $x^* = y / 2$.

        >>> import jax, jax.numpy as jnp, lineax as lx, vardax
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((1, 3), jnp.float32))
        >>> oi = vardax.OptimalInterpolation(
        ...     obs_op=vardax.MaskedIdentity(),
        ...     prior_mean=jnp.zeros((1, 3)),
        ...     prior_cov_op=eye,
        ...     obs_cov_op=eye,
        ... )
        >>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
        >>> xa = oi(batch)
        >>> xa.shape
        (1, 1, 3)
        >>> bool(jnp.allclose(xa, 0.5, atol=1e-4))
        True
    """

    obs_op: Any  # ObservationOperator (linear)
    prior_mean: Float[Array, "T N"]
    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    cg_atol: float = eqx.field(static=True, default=1e-6)
    cg_rtol: float = eqx.field(static=True, default=1e-6)
    cg_max_steps: int = eqx.field(static=True, default=200)

    def __call__(self, batch: Batch1D) -> Float[Array, "B T N"]:
        """BLUE analysis for each sample in the batch."""

        def _one(
            input_i: Float[Array, "T N"], mask_i: Float[Array, "T N"]
        ) -> Float[Array, "T N"]:
            # The effective observation operator is mask ⊙ H(·) — apply the
            # mask both to the predicted obs (innovation) AND to the
            # tangent-linear used to assemble (H B H^T + R). Without the
            # mask in `linearize`, missing entries are treated as
            # zero-innovation observations and bias the analysis whenever
            # B couples observed and missing components.
            def masked_obs(x: Float[Array, "T N"]) -> Float[Array, "T N"]:
                raw = (
                    self.obs_op(x, mask=mask_i)
                    if _accepts_mask(self.obs_op)
                    else self.obs_op(x)
                )
                return mask_i * raw

            y_pred = masked_obs(self.prior_mean)
            innovation = (input_i * mask_i) - y_pred

            # Build (H B H^T + R) as a composed lineax operator. We
            # tag it positive-semidefinite so lineax.CG accepts it
            # (the composition itself doesn't carry the tag).
            H = lx.JacobianLinearOperator(
                lambda x, _args=None: masked_obs(x), self.prior_mean
            )
            inner_raw = H @ self.prior_cov_op @ H.transpose() + self.obs_cov_op
            inner_op = lx.TaggedLinearOperator(inner_raw, lx.positive_semidefinite_tag)

            solver = lx.CG(
                atol=self.cg_atol,
                rtol=self.cg_rtol,
                max_steps=self.cg_max_steps,
            )
            v = lx.linear_solve(inner_op, innovation, solver=solver).value

            # x_star = x_b + B H^T v
            return self.prior_mean + self.prior_cov_op.mv(H.transpose().mv(v))

        return jax.vmap(_one)(batch.input, batch.mask)

    def as_analysis_step(self) -> _OIAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep`` (Decision D8)."""
        return _OIAnalysisStep(self)
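The comment in `__call__` about masking the tangent-linear can be illustrated with dense matrices. In this NumPy sketch (hypothetical helper `oi_masked`, identity H, \(R = I\)), \(B\) couples an observed and a missing component. Folding the mask into the linearised operator reproduces the answer obtained by deleting the missing row outright; leaving it out treats the missing entry as a zero-valued observation and drags both components towards zero.

```python
import numpy as np

def oi_masked(x_b, B, H, R, y, mask, mask_tangent=True):
    """Dense OI with effective operator mask * H (hypothetical helper).

    mask_tangent=False reproduces the pitfall: the innovation is masked,
    but the operator used in (H B H^T + R) and B H^T is not.
    """
    innovation = mask * y - mask * (H @ x_b)
    H_lin = np.diag(mask) @ H if mask_tangent else H
    v = np.linalg.solve(H_lin @ B @ H_lin.T + R, innovation)
    return x_b + B @ H_lin.T @ v

B = np.array([[1.0, 0.5], [0.5, 1.0]])  # couples the two components
I = np.eye(2)
y = np.array([1.0, 0.0])                # second entry is a placeholder
mask = np.array([1.0, 0.0])             # only the first component is observed

masked = oi_masked(np.zeros(2), B, I, I, y, mask)                        # [0.5, 0.25]
unmasked = oi_masked(np.zeros(2), B, I, I, y, mask, mask_tangent=False)  # [0.467, 0.133]

# Reference: assimilate only the observed row, dropping the missing one entirely.
H_obs = I[:1]
ref = B @ H_obs.T @ np.linalg.solve(H_obs @ B @ H_obs.T + I[:1, :1], y[:1])
```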

as_analysis_step

as_analysis_step() -> _OIAnalysisStep

Adapt to pipekit_cycle.AnalysisStep (Decision D8).


ThreeDVar

Bases: Module

3D variational analysis.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| obs_op | Any | Observation operator (linear or non-linear). |
| prior_mean | Float[Array, 'T N'] | Background \(x_b\) of shape (T, N). |
| prior_cov_op | AbstractLinearOperator | Background-error covariance \(B\) as a lineax.AbstractLinearOperator. |
| obs_cov_op | AbstractLinearOperator | Observation-error covariance \(R\) as a lineax.AbstractLinearOperator. |
| minimiser | AbstractMinimiser | optimistix.AbstractMinimiser for the inner iteration. Default optimistix.BFGS(rtol=1e-6, atol=1e-6). |
| minimiser_adjoint | AbstractAdjoint | optimistix.AbstractAdjoint for differentiating through the minimum (used only if the ThreeDVar analysis is itself nested inside a larger training loop). Default ImplicitAdjoint, which is exact at the optimum. |
| max_steps | int | Iteration cap on the inner solver. Default 200. If the cap is reached, the last iterate is returned without raising (throw=False). |

Examples:

With \(B = R = I\), identity \(H\) and everything observed, the analysis agrees with the closed-form BLUE, \(x^* = y / 2\).

>>> import jax, jax.numpy as jnp, lineax as lx, vardax
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((1, 3), jnp.float32))
>>> three = vardax.ThreeDVar(
...     obs_op=vardax.MaskedIdentity(),
...     prior_mean=jnp.zeros((1, 3)),
...     prior_cov_op=eye,
...     obs_cov_op=eye,
... )
>>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
>>> xa = three(batch)
>>> xa.shape
(1, 1, 3)
>>> bool(jnp.allclose(xa, 0.5, atol=1e-3))
True
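For a linear H, ThreeDVar and OptimalInterpolation minimise the same quadratic, which a short NumPy sketch can confirm with a correlated \(B\) and a partial mask. Plain fixed-step gradient descent stands in for optimistix.BFGS here (a deliberate simplification), and `threedvar_gd` / `blue` are hypothetical helpers; the cost and its masking mirror the `cost` function in the ThreeDVar source.

```python
import numpy as np

def threedvar_gd(x_b, B, H, R, y, mask, lr=0.2, n_iter=500):
    """Minimise the masked 3DVar cost by fixed-step gradient descent."""
    x = x_b.copy()
    for _ in range(n_iter):
        r = mask * (y - H @ x)                            # masked residual
        grad_bg = np.linalg.solve(B, x - x_b)             # B^{-1} (x - x_b)
        grad_obs = -H.T @ (mask * np.linalg.solve(R, r))  # -H^T diag(m) R^{-1} r
        x = x - lr * (grad_bg + grad_obs)
    return x

def blue(x_b, B, H, R, y, mask):
    """Closed-form minimiser of the same cost, using the masked operator diag(m) H."""
    H_m = np.diag(mask) @ H
    v = np.linalg.solve(H_m @ B @ H_m.T + R, mask * (y - H @ x_b))
    return x_b + B @ H_m.T @ v

B = np.array([[1.0, 0.3, 0.0], [0.3, 1.0, 0.3], [0.0, 0.3, 1.0]])
R = 0.5 * np.eye(3)
H = np.eye(3)
y = np.array([1.0, 2.0, 3.0])
mask = np.array([1.0, 0.0, 1.0])   # middle component unobserved

x_gd = threedvar_gd(np.zeros(3), B, H, R, y, mask)
x_blue = blue(np.zeros(3), B, H, R, y, mask)
```

The step size 0.2 is chosen below \(2/\lambda_{\max}\) of this problem's Hessian \(B^{-1} + H^\top \operatorname{diag}(m) R^{-1} \operatorname{diag}(m) H\); a quasi-Newton solver such as BFGS removes the need to tune it.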
Source code in src/vardax/_src/models/threedvar.py
class ThreeDVar(eqx.Module):
    r"""3D variational analysis.

    Attributes:
        obs_op: Observation operator (linear or nonlinear).
        prior_mean: Background $x_b$ of shape ``(T, N)``.
        prior_cov_op: $B$. ``lineax.AbstractLinearOperator``.
        obs_cov_op: $R$. ``lineax.AbstractLinearOperator``.
        minimiser: ``optimistix.AbstractMinimiser`` for the inner
            iteration. Default ``optimistix.BFGS(rtol=1e-6, atol=1e-6)``.
        minimiser_adjoint: ``optimistix.AbstractAdjoint`` for
            differentiating *through* the minimum (used only if the
            ``ThreeDVar`` analysis is itself nested inside a larger
            training loop). Default ``ImplicitAdjoint`` — exact at the
            optimum.
        max_steps: Iteration cap on the inner solver.

    Examples:
        With $B = R = I$, identity $H$ and everything observed, the
        analysis agrees with the closed-form BLUE, $x^* = y / 2$.

        >>> import jax, jax.numpy as jnp, lineax as lx, vardax
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((1, 3), jnp.float32))
        >>> three = vardax.ThreeDVar(
        ...     obs_op=vardax.MaskedIdentity(),
        ...     prior_mean=jnp.zeros((1, 3)),
        ...     prior_cov_op=eye,
        ...     obs_cov_op=eye,
        ... )
        >>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
        >>> xa = three(batch)
        >>> xa.shape
        (1, 1, 3)
        >>> bool(jnp.allclose(xa, 0.5, atol=1e-3))
        True
    """

    obs_op: Any
    prior_mean: Float[Array, "T N"]
    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    minimiser: optx.AbstractMinimiser = eqx.field(static=True)
    minimiser_adjoint: optx.AbstractAdjoint = eqx.field(static=True)
    max_steps: int = eqx.field(static=True, default=200)

    def __init__(
        self,
        obs_op: Any,
        prior_mean: Float[Array, "T N"],
        prior_cov_op: lx.AbstractLinearOperator,
        obs_cov_op: lx.AbstractLinearOperator,
        minimiser: optx.AbstractMinimiser | None = None,
        minimiser_adjoint: optx.AbstractAdjoint | None = None,
        max_steps: int = 200,
    ) -> None:
        self.obs_op = obs_op
        self.prior_mean = prior_mean
        self.prior_cov_op = prior_cov_op
        self.obs_cov_op = obs_cov_op
        self.minimiser = minimiser or optx.BFGS(rtol=1e-6, atol=1e-6)
        self.minimiser_adjoint = minimiser_adjoint or optx.ImplicitAdjoint()
        self.max_steps = max_steps

    def __call__(self, batch: Batch1D) -> Float[Array, "B T N"]:
        """3DVar analysis for each sample in the batch."""

        def _one(
            input_i: Float[Array, "T N"], mask_i: Float[Array, "T N"]
        ) -> Float[Array, "T N"]:
            def cost(x: Float[Array, "T N"], _args: Any) -> Float[Array, ""]:
                # Background term — apply B^{-1} via lineax solve.
                dx = x - self.prior_mean
                B_inv_dx = lx.linear_solve(
                    self.prior_cov_op,
                    dx,
                    solver=lx.CG(atol=1e-6, rtol=1e-6),
                ).value
                j_bg = 0.5 * jnp.sum(dx * B_inv_dx)
                # Observation term — apply R^{-1} via lineax solve.
                # Mask predictions AND observations symmetrically: for
                # obs_ops that don't take a `mask` kwarg (e.g.
                # LinearObs), unmasked predictions at missing entries
                # would otherwise penalise the analysis there.
                y_pred = (
                    self.obs_op(x, mask=mask_i)
                    if _accepts_mask(self.obs_op)
                    else self.obs_op(x)
                )
                residual = mask_i * (input_i - y_pred)
                R_inv_r = lx.linear_solve(
                    self.obs_cov_op,
                    residual,
                    solver=lx.CG(atol=1e-6, rtol=1e-6),
                ).value
                j_obs = 0.5 * jnp.sum(residual * R_inv_r)
                return j_bg + j_obs

            result = optx.minimise(
                fn=cost,
                solver=self.minimiser,
                y0=self.prior_mean,
                args=None,
                max_steps=self.max_steps,
                adjoint=self.minimiser_adjoint,
                throw=False,
            )
            return result.value

        return jax.vmap(_one)(batch.input, batch.mask)

    def as_analysis_step(self) -> _ThreeDVarAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep`` (Decision D8)."""
        return _ThreeDVarAnalysisStep(self)
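Both cost terms in ThreeDVar apply \(B^{-1}\) and \(R^{-1}\) by solving a linear system with lineax.CG rather than forming an inverse. The minimal NumPy conjugate-gradient loop below (hypothetical `cg_solve`, written for illustration; lineax's own implementation differs in details such as its stopping rule) shows what that solve computes for a symmetric positive-definite matrix and how the background term is then formed.

```python
import numpy as np

def cg_solve(A, b, atol=1e-10, rtol=1e-10, max_steps=200):
    """Solve A z = b for symmetric positive-definite A by conjugate gradients."""
    z = np.zeros_like(b)
    r = b - A @ z                      # initial residual
    p = r.copy()                       # initial search direction
    rs = r @ r
    tol = max(atol, rtol * np.linalg.norm(b))
    for _ in range(max_steps):
        if np.sqrt(rs) <= tol:
            break
        Ap = A @ p
        alpha = rs / (p @ Ap)          # exact line search along p
        z = z + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p      # next A-conjugate direction
        rs = rs_new
    return z

rng = np.random.default_rng(0)
L = rng.normal(size=(5, 5))
B = L @ L.T + 5 * np.eye(5)            # SPD background-error covariance
dx = rng.normal(size=5)                # x - x_b
B_inv_dx = cg_solve(B, dx)
j_bg = 0.5 * dx @ B_inv_dx             # background term of the cost
```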

as_analysis_step

as_analysis_step() -> _ThreeDVarAnalysisStep

Adapt to pipekit_cycle.AnalysisStep (Decision D8).


StrongFourDVar

Bases: Module

Strong-constraint 4DVar.

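Attributes:

| Name | Type | Description |
| --- | --- | --- |
| forward | Any | pipekit_cycle.ForwardModel supplying step(state, dt) -> state and a dt attribute, which the rollout reads. |
| obs_op | Any | Observation operator. |
| prior_mean | Float[Array, 'N'] | Background \(x_b\): initial state of shape (N,) (state vector at \(t=0\)). |
| prior_cov_op | AbstractLinearOperator | Background-error covariance \(B\). |
| obs_cov_op | AbstractLinearOperator | Observation-error covariance \(R\). |
| minimiser | AbstractMinimiser | optimistix.AbstractMinimiser for the outer optimisation over \(x_0\). Default optimistix.BFGS(rtol=1e-6, atol=1e-6). |
| minimiser_adjoint | AbstractAdjoint | optimistix.AbstractAdjoint for differentiating through the minimum. Default ImplicitAdjoint. |
| forward_adjoint | Any | diffrax.AbstractAdjoint-like. Currently stored as a tag only: the rollout calls forward.step(x, dt) without it, so the forward model's own adjoint applies. Default None. |
| max_steps | int | Iteration cap on the outer solver. Default 200. If the cap is reached, the last iterate is returned without raising (throw=False). |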

Examples:

Trivial dynamics (\(M_t(x) = x\)) and a single timestep reduce strong 4DVar to 3DVar, so with \(B = R = I\) the analysis is \(x_0^* = y / 2\).

>>> import jax, jax.numpy as jnp, lineax as lx, vardax
>>> class Identity:
...     dt = 1.0
...
...     def step(self, x, dt):
...         return x
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((3,), jnp.float32))
>>> strong = vardax.StrongFourDVar(
...     forward=Identity(),
...     obs_op=vardax.MaskedIdentity(),
...     prior_mean=jnp.zeros(3),
...     prior_cov_op=eye,
...     obs_cov_op=eye,
... )
>>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
>>> strong(batch).shape
(1, 3)
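With a linear model \(x_t = M x_{t-1}\) and linear H, the strong-constraint cost is quadratic in \(x_0\), so its minimiser has a closed form. The NumPy sketch below (hypothetical `strong_4dvar_linear`; dense matrices, no mask, explicit inverses acceptable only because the matrices are tiny) accumulates the normal equations over the window. It recovers the doctest's claim, \(x_0^* = y/2\) for identity dynamics, one timestep and \(B = R = I\); with \(T+1\) identical observations each one pulls \(x_0\) equally, giving \(x_0^* = (T+1)\,y/(T+2)\).

```python
import numpy as np

def strong_4dvar_linear(x_b, B, M, H, R, ys):
    """Closed-form strong-constraint 4DVar for linear M and H.

    ys has shape (T+1, n_obs): observations at t = 0..T, where x_t = M^t x_0.
    """
    B_inv = np.linalg.inv(B)
    R_inv = np.linalg.inv(R)
    hess = B_inv.copy()               # Hessian of the quadratic cost
    rhs = B_inv @ x_b
    G = np.eye(len(x_b))              # G = M^t, the map x_0 -> x_t
    for y_t in ys:
        HG = H @ G
        hess += HG.T @ R_inv @ HG
        rhs += HG.T @ R_inv @ y_t
        G = M @ G                     # advance to the next timestep
    return np.linalg.solve(hess, rhs)

I = np.eye(3)
x0_one = strong_4dvar_linear(np.zeros(3), I, I, I, I, np.ones((1, 3)))   # y / 2
x0_five = strong_4dvar_linear(np.zeros(3), I, I, I, I, np.ones((5, 3)))  # 5y / 6
```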
Source code in src/vardax/_src/models/strong_fourdvar.py
class StrongFourDVar(eqx.Module):
    r"""Strong-constraint 4DVar.

    Attributes:
        forward: ``pipekit_cycle.ForwardModel`` supplying
            ``step(state, dt) -> state``.
        obs_op: Observation operator.
        prior_mean: Background $x_b$ — initial state of shape
            ``(N,)`` (state vector at $t=0$).
        prior_cov_op: $B$.
        obs_cov_op: $R$.
        minimiser: ``optimistix.AbstractMinimiser`` for the outer
            optimisation over $x_0$.
        minimiser_adjoint: ``optimistix.AbstractAdjoint`` for
            differentiating through the minimum.
        forward_adjoint: ``diffrax.AbstractAdjoint``-like. Currently
            stored as a tag only: ``_rollout`` calls
            ``forward.step(x, dt)`` without it, so the forward model's
            own adjoint applies. Default ``None``.
        max_steps: Iteration cap on the outer solver.

    Examples:
        Trivial dynamics ($M_t(x) = x$) and a single timestep reduce
        strong 4DVar to 3DVar, so with $B = R = I$ the analysis is
        $x_0^* = y / 2$.

        >>> import jax, jax.numpy as jnp, lineax as lx, vardax
        >>> class Identity:
        ...     dt = 1.0
        ...
        ...     def step(self, x, dt):
        ...         return x
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((3,), jnp.float32))
        >>> strong = vardax.StrongFourDVar(
        ...     forward=Identity(),
        ...     obs_op=vardax.MaskedIdentity(),
        ...     prior_mean=jnp.zeros(3),
        ...     prior_cov_op=eye,
        ...     obs_cov_op=eye,
        ... )
        >>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
        >>> strong(batch).shape
        (1, 3)
    """

    forward: Any
    obs_op: Any
    prior_mean: Float[Array, "N"]
    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    minimiser: optx.AbstractMinimiser = eqx.field(static=True)
    minimiser_adjoint: optx.AbstractAdjoint = eqx.field(static=True)
    forward_adjoint: Any = eqx.field(static=True)
    max_steps: int = eqx.field(static=True, default=200)

    def __init__(
        self,
        forward: Any,
        obs_op: Any,
        prior_mean: Float[Array, "N"],
        prior_cov_op: lx.AbstractLinearOperator,
        obs_cov_op: lx.AbstractLinearOperator,
        minimiser: optx.AbstractMinimiser | None = None,
        minimiser_adjoint: optx.AbstractAdjoint | None = None,
        forward_adjoint: Any = None,
        max_steps: int = 200,
    ) -> None:
        self.forward = forward
        self.obs_op = obs_op
        self.prior_mean = prior_mean
        self.prior_cov_op = prior_cov_op
        self.obs_cov_op = obs_cov_op
        self.minimiser = minimiser or optx.BFGS(rtol=1e-6, atol=1e-6)
        self.minimiser_adjoint = minimiser_adjoint or optx.ImplicitAdjoint()
        self.forward_adjoint = forward_adjoint
        self.max_steps = max_steps

    def _rollout(self, x_0: Float[Array, "N"], n_steps: int) -> Float[Array, "T_plus_1 N"]:
        """Step the forward model ``n_steps`` times, returning the ``(n_steps + 1, N)`` trajectory."""
        dt = self.forward.dt

        def step_fn(
            x: Float[Array, "N"],
            _: None,
        ) -> tuple[Float[Array, "N"], Float[Array, "N"]]:
            x_new = self.forward.step(x, dt)
            return x_new, x_new

        _, trajectory = jax.lax.scan(step_fn, x_0, None, length=n_steps)
        # Prepend x_0 so the trajectory has length T+1.
        return jnp.concatenate([x_0[None, :], trajectory], axis=0)

    def __call__(self, batch: Batch1D) -> Float[Array, "B N"]:
        """Strong-4DVar analysis: minimise over $x_0$."""

        T = batch.input.shape[1] - 1  # rollout steps (input has T+1 timesteps)

        def _one(
            input_i: Float[Array, "T_plus_1 N"], mask_i: Float[Array, "T_plus_1 N"]
        ) -> Float[Array, "N"]:
            def cost(x_0: Float[Array, "N"], _args: Any) -> Float[Array, ""]:
                # Background term
                dx = x_0 - self.prior_mean
                B_inv_dx = lx.linear_solve(
                    self.prior_cov_op,
                    dx,
                    solver=lx.CG(atol=1e-6, rtol=1e-6),
                ).value
                j_bg = 0.5 * jnp.sum(dx * B_inv_dx)

                # Roll out the trajectory.
                trajectory = self._rollout(x_0, n_steps=T)

                # Observation term, summed across time.
                def _per_step(x_t, y_t, m_t):
                    y_pred = (
                        self.obs_op(x_t, mask=m_t)
                        if _accepts_mask(self.obs_op)
                        else self.obs_op(x_t)
                    )
                    residual = m_t * (y_t - y_pred)
                    R_inv_r = lx.linear_solve(
                        self.obs_cov_op,
                        residual,
                        solver=lx.CG(atol=1e-6, rtol=1e-6),
                    ).value
                    return 0.5 * jnp.sum(residual * R_inv_r)

                per_step_costs = jax.vmap(_per_step)(trajectory, input_i, mask_i)
                j_obs = jnp.sum(per_step_costs)
                return j_bg + j_obs

            result = optx.minimise(
                fn=cost,
                solver=self.minimiser,
                y0=self.prior_mean,
                args=None,
                max_steps=self.max_steps,
                adjoint=self.minimiser_adjoint,
                throw=False,
            )
            return result.value

        return jax.vmap(_one)(batch.input, batch.mask)

    def as_analysis_step(self) -> _StrongFourDVarAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep``."""
        return _StrongFourDVarAnalysisStep(self)
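`_rollout` relies on `jax.lax.scan` emitting one new state per step and then prepending \(x_0\), so a window with T+1 observation times needs `n_steps = T`. The pure-Python loop below follows the reference semantics that the JAX documentation gives for `lax.scan` (here with `xs=None` and an explicit `length`); `scan_reference`, `rollout` and the toy `decay` model are illustrative names, and NumPy stands in for `jax.numpy`.

```python
import numpy as np

def scan_reference(f, init, xs, length=None):
    """Loop with the same carry/output semantics as jax.lax.scan."""
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    # Stack per-step outputs; an empty scan yields a leading axis of length 0.
    stacked = np.stack(ys) if ys else np.empty((0,) + np.shape(init))
    return carry, stacked

def rollout(step, x_0, dt, n_steps):
    """Trajectory [x_0, x_1, ..., x_n] of shape (n_steps + 1, N)."""
    def step_fn(x, _):
        x_new = step(x, dt)
        return x_new, x_new
    _, trajectory = scan_reference(step_fn, x_0, None, length=n_steps)
    return np.concatenate([x_0[None, :], trajectory], axis=0)

decay = lambda x, dt: x * (1.0 - dt)   # toy forward model
traj = rollout(decay, np.ones(2), dt=0.5, n_steps=3)
# rows: [1, 1], [0.5, 0.5], [0.25, 0.25], [0.125, 0.125]
```

With `n_steps=0` the result is just `x_0[None, :]`, which is why the single-timestep doctests above reduce to 3DVar.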

WeakFourDVar

Bases: Module

Weak-constraint 4DVar with augmented control vector.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| forward | Any | pipekit_cycle.ForwardModel (the free model \(M_t^\text{free}\)). |
| obs_op | Any | Observation operator. |
| prior_mean | Float[Array, 'N'] | Background \(x_b\): initial state of shape (N,). |
| prior_cov_op | AbstractLinearOperator | Background-error covariance \(B\). |
| obs_cov_op | AbstractLinearOperator | Observation-error covariance \(R\). |
| model_err_cov_op | AbstractLinearOperator | \(Q\): covariance of the per-step model error \(\eta_t\). Required (the constructor has no default). A small-variance \(Q\) penalises \(\eta_t\) heavily and approaches strong-constraint 4DVar. |
| minimiser | AbstractMinimiser | As in StrongFourDVar. |
| minimiser_adjoint | AbstractAdjoint | As in StrongFourDVar. |
| max_steps | int | As in StrongFourDVar. |

Examples:

Trivial dynamics and a single timestep: no model-error steps remain (\(T = 0\)) and the analysis initial state reduces to the 3DVar / BLUE answer \(x_0^* = y / 2\) for \(B = R = I\).

>>> import jax, jax.numpy as jnp, lineax as lx, vardax
>>> class Identity:
...     dt = 1.0
...
...     def step(self, x, dt):
...         return x
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((3,), jnp.float32))
>>> weak = vardax.WeakFourDVar(
...     forward=Identity(),
...     obs_op=vardax.MaskedIdentity(),
...     prior_mean=jnp.zeros(3),
...     prior_cov_op=eye,
...     obs_cov_op=eye,
...     model_err_cov_op=eye,
... )
>>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
>>> x0_star, eta_star = weak(batch)
>>> x0_star.shape, eta_star.shape
((1, 3), (1, 0, 3))
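The augmented control vector and the role of \(Q\) can be seen in a scalar case where everything is linear and the weak-constraint cost has a closed form. `weak_4dvar_scalar` and `strong_4dvar_scalar` are hypothetical NumPy helpers (scalar state, no mask, dense normal equations), not WeakFourDVar's BFGS path. Shrinking the model-error variance q drives every \(\eta_t\) to zero and recovers the strong-constraint \(x_0\); with a single timestep there are no \(\eta_t\) and the answer is \(y/2\), as in the doctest.

```python
import numpy as np

def weak_4dvar_scalar(x_b, b, q, r, a, ys):
    """Closed-form weak-constraint 4DVar for x_t = a x_{t-1} + eta_t (scalar state).

    Control z = [x_0, eta_1, ..., eta_T]; ys[t] observes x_t for t = 0..T.
    """
    T = len(ys) - 1
    # L maps the control to the trajectory: x_t = a^t x_0 + sum_{s<=t} a^(t-s) eta_s.
    L = np.zeros((T + 1, T + 1))
    for t in range(T + 1):
        L[t, 0] = a ** t
        for s in range(1, t + 1):
            L[t, s] = a ** (t - s)
    prior_prec = np.diag([1.0 / b] + [1.0 / q] * T)    # blockdiag(B^-1, Q^-1, ...)
    prior_mean = np.concatenate([[x_b], np.zeros(T)])  # etas have zero prior mean
    hess = prior_prec + L.T @ L / r
    rhs = prior_prec @ prior_mean + L.T @ ys / r
    z = np.linalg.solve(hess, rhs)
    return z[0], z[1:]

def strong_4dvar_scalar(x_b, b, r, a, ys):
    """Strong constraint: every eta_t = 0, so x_t = a^t x_0."""
    g = a ** np.arange(len(ys))
    return (x_b / b + g @ ys / r) / (1.0 / b + g @ g / r)

ys = np.array([1.0, 0.2, 1.5, 0.4])
x0_strong = strong_4dvar_scalar(0.0, 1.0, 1.0, 0.9, ys)
x0_tight, eta_tight = weak_4dvar_scalar(0.0, 1.0, 1e-8, 1.0, 0.9, ys)  # q -> 0
x0_one, eta_one = weak_4dvar_scalar(0.0, 1.0, 1.0, 1.0, 0.9, np.array([1.0]))
```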
Source code in src/vardax/_src/models/weak_fourdvar.py
class WeakFourDVar(eqx.Module):
    r"""Weak-constraint 4DVar with augmented control vector.

    Attributes:
        forward: ``pipekit_cycle.ForwardModel`` (the free model
            $M_t^\text{free}$).
        obs_op: Observation operator.
        prior_mean: Background $x_b$ — initial state ``(N,)``.
        prior_cov_op: $B$.
        obs_cov_op: $R$.
        model_err_cov_op: $Q$, the covariance of the per-step model
            error $\eta_t$. Required (no default). A small-variance
            $Q$ penalises $\eta_t$ heavily and approaches
            strong-constraint 4DVar.
        minimiser: As in [`StrongFourDVar`][vardax.StrongFourDVar].
        minimiser_adjoint: As in [`StrongFourDVar`][vardax.StrongFourDVar].
        max_steps: As in [`StrongFourDVar`][vardax.StrongFourDVar].

    Examples:
        Trivial dynamics and a single timestep: no model-error steps
        remain ($T = 0$) and the analysis initial state reduces to the
        3DVar / BLUE answer $x_0^* = y / 2$ for $B = R = I$.

        >>> import jax, jax.numpy as jnp, lineax as lx, vardax
        >>> class Identity:
        ...     dt = 1.0
        ...
        ...     def step(self, x, dt):
        ...         return x
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((3,), jnp.float32))
        >>> weak = vardax.WeakFourDVar(
        ...     forward=Identity(),
        ...     obs_op=vardax.MaskedIdentity(),
        ...     prior_mean=jnp.zeros(3),
        ...     prior_cov_op=eye,
        ...     obs_cov_op=eye,
        ...     model_err_cov_op=eye,
        ... )
        >>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
        >>> x0_star, eta_star = weak(batch)
        >>> x0_star.shape, eta_star.shape
        ((1, 3), (1, 0, 3))
    """

    forward: Any
    obs_op: Any
    prior_mean: Float[Array, "N"]
    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    model_err_cov_op: lx.AbstractLinearOperator
    minimiser: optx.AbstractMinimiser = eqx.field(static=True)
    minimiser_adjoint: optx.AbstractAdjoint = eqx.field(static=True)
    max_steps: int = eqx.field(static=True, default=200)

    def __init__(
        self,
        forward: Any,
        obs_op: Any,
        prior_mean: Float[Array, "N"],
        prior_cov_op: lx.AbstractLinearOperator,
        obs_cov_op: lx.AbstractLinearOperator,
        model_err_cov_op: lx.AbstractLinearOperator,
        minimiser: optx.AbstractMinimiser | None = None,
        minimiser_adjoint: optx.AbstractAdjoint | None = None,
        max_steps: int = 200,
    ) -> None:
        self.forward = forward
        self.obs_op = obs_op
        self.prior_mean = prior_mean
        self.prior_cov_op = prior_cov_op
        self.obs_cov_op = obs_cov_op
        self.model_err_cov_op = model_err_cov_op
        self.minimiser = minimiser or optx.BFGS(rtol=1e-6, atol=1e-6)
        self.minimiser_adjoint = minimiser_adjoint or optx.ImplicitAdjoint()
        self.max_steps = max_steps

    def _rollout_with_error(
        self,
        x_0: Float[Array, "N"],
        etas: Float[Array, "T N"],
    ) -> Float[Array, "T_plus_1 N"]:
        r"""Roll out with per-step model error.

        Each step applies $x_t = M_t^\text{free}(x_{t-1}) + \eta_t$.
        """
        dt = self.forward.dt

        def step_fn(
            x: Float[Array, "N"],
            eta_t: Float[Array, "N"],
        ) -> tuple[Float[Array, "N"], Float[Array, "N"]]:
            x_new = self.forward.step(x, dt) + eta_t
            return x_new, x_new

        _, trajectory = jax.lax.scan(step_fn, x_0, etas)
        return jnp.concatenate([x_0[None, :], trajectory], axis=0)

    def __call__(
        self,
        batch: Batch1D,
    ) -> tuple[Float[Array, "B N"], Float[Array, "B T N"]]:
        """Weak-4DVar analysis.

        Returns ``(x_0_star, eta_star)`` — the analysis initial state
        and the per-step model-error trajectory (one row per
        timestep, total ``T``).
        """
        # T_plus_1 timesteps in the batch ⇒ T model-error steps.
        T_plus_1 = batch.input.shape[1]
        T = T_plus_1 - 1
        N = batch.input.shape[2]

        def _one(
            input_i: Float[Array, "T_plus_1 N"], mask_i: Float[Array, "T_plus_1 N"]
        ):
            def cost(
                control: Float[Array, "T_plus_1 N"], _args: Any
            ) -> Float[Array, ""]:
                x_0 = control[0]
                etas = control[1:]

                # Background term
                dx = x_0 - self.prior_mean
                B_inv_dx = lx.linear_solve(
                    self.prior_cov_op,
                    dx,
                    solver=lx.CG(atol=1e-6, rtol=1e-6),
                ).value
                j_bg = 0.5 * jnp.sum(dx * B_inv_dx)

                # Model-error term
                def _eta_cost(eta_t):
                    Q_inv_eta = lx.linear_solve(
                        self.model_err_cov_op,
                        eta_t,
                        solver=lx.CG(atol=1e-6, rtol=1e-6),
                    ).value
                    return 0.5 * jnp.sum(eta_t * Q_inv_eta)

                j_eta = jnp.sum(jax.vmap(_eta_cost)(etas))

                # Observation term
                trajectory = self._rollout_with_error(x_0, etas)

                def _per_step(x_t, y_t, m_t):
                    y_pred = (
                        self.obs_op(x_t, mask=m_t)
                        if _accepts_mask(self.obs_op)
                        else self.obs_op(x_t)
                    )
                    residual = m_t * (y_t - y_pred)
                    R_inv_r = lx.linear_solve(
                        self.obs_cov_op,
                        residual,
                        solver=lx.CG(atol=1e-6, rtol=1e-6),
                    ).value
                    return 0.5 * jnp.sum(residual * R_inv_r)

                per_step_costs = jax.vmap(_per_step)(trajectory, input_i, mask_i)
                j_obs = jnp.sum(per_step_costs)

                return j_bg + j_obs + j_eta

            # Initial guess: x_0 at the background, etas at zero.
            init_etas = jnp.zeros((T, N))
            y0 = jnp.concatenate([self.prior_mean[None, :], init_etas], axis=0)

            result = optx.minimise(
                fn=cost,
                solver=self.minimiser,
                y0=y0,
                args=None,
                max_steps=self.max_steps,
                adjoint=self.minimiser_adjoint,
                throw=False,
            )
            control_star = result.value
            return control_star[0], control_star[1:]

        x0_star, eta_star = jax.vmap(_one)(batch.input, batch.mask)
        return x0_star, eta_star

    def as_analysis_step(self) -> _WeakFourDVarAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep``."""
        return _WeakFourDVarAnalysisStep(self)

IncrementalFourDVar

Bases: Module

Operational incremental 4DVar (Decision D11).

Functionally equivalent to StrongFourDVar (same problem, same answer in the converged limit) but with a specialised inner solver: Gauss-Newton outer iterations and CG inner iterations on the linearised cost. Use this for production / long-window 4DVar.
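The outer/inner structure can be sketched in NumPy for a small non-linear model. Everything here is an illustrative assumption rather than vardax code: the toy model `step`, its tangent-linear `step_jac`, identity H, no mask, dense matrices, no control-variable transform, and the hypothetical helpers `cg`, `linearise` and `incremental_4dvar`. Each outer (Gauss-Newton) iteration reruns the non-linear model and builds the Hessian along the new trajectory; the inner CG solve then minimises the resulting quadratic in the increment \(\delta x_0\).

```python
import numpy as np

def step(x, dt):
    """Toy non-linear forward model (illustrative only)."""
    return x + dt * np.sin(x)

def step_jac(x, dt):
    """Tangent-linear of `step` at x (diagonal because the model acts elementwise)."""
    return np.diag(1.0 + dt * np.cos(x))

def cg(A, b, tol=1e-12, max_steps=100):
    """Conjugate gradients for SPD A (inner loop)."""
    z, r = np.zeros_like(b), b.copy()
    p, rs = r.copy(), r @ r
    for _ in range(max_steps):
        if np.sqrt(rs) <= tol:
            break
        Ap = A @ p
        alpha = rs / (p @ Ap)
        z, r = z + alpha * p, r - alpha * Ap
        rs_new = r @ r
        p, rs = r + (rs_new / rs) * p, rs_new
    return z

def linearise(x0, x_b, B_inv, R_inv, ys, dt):
    """Gauss-Newton Hessian and exact gradient of the full 4DVar cost at x0 (H = I)."""
    x, G = x0, np.eye(len(x0))           # G = d x_t / d x_0
    hess = B_inv.copy()
    grad = B_inv @ (x0 - x_b)
    for t, y_t in enumerate(ys):
        if t > 0:
            G = step_jac(x, dt) @ G      # tangent-linear evaluated at x_{t-1}
            x = step(x, dt)
        hess += G.T @ R_inv @ G
        grad -= G.T @ R_inv @ (y_t - x)
    return hess, grad

def incremental_4dvar(x_b, B, R, ys, dt, n_outer=10):
    B_inv, R_inv = np.linalg.inv(B), np.linalg.inv(R)
    x0 = x_b.copy()
    for _ in range(n_outer):
        hess, grad = linearise(x0, x_b, B_inv, R_inv, ys, dt)  # outer: relinearise
        x0 = x0 + cg(hess, -grad)                               # inner: CG on quadratic
    return x0

ys = np.array([[0.5, -0.3], [0.6, -0.35], [0.7, -0.4]])  # observations at t = 0, 1, 2
x0_star = incremental_4dvar(np.zeros(2), np.eye(2), np.eye(2), ys, dt=0.1)
```

At convergence the gradient of the full non-linear cost vanishes, which is the sense in which the result agrees with a direct StrongFourDVar minimisation in the converged limit.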

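Attributes:

| Name | Type | Description |
| --- | --- | --- |
| forward | Any | pipekit_cycle.ForwardModel. |
| obs_op | Any | Observation operator. |
| prior_mean | Float[Array, 'N'] | Background \(x_b\): initial state of shape (N,). |
| prior_cov_op | AbstractLinearOperator | Background-error covariance \(B\). |
| obs_cov_op | AbstractLinearOperator | Observation-error covariance \(R\). |
| config | IncrementalConfig | Outer/inner-loop settings. Defaults to IncrementalConfig() when omitted. |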

Examples:

Trivial dynamics (\(M_t(x) = x\)) and a single timestep reduce the problem to 3DVar, so with \(B = R = I\) the converged analysis is \(x_0^* = y / 2\).

>>> import jax, jax.numpy as jnp, lineax as lx, vardax
>>> class Identity:
...     dt = 1.0
...
...     def step(self, x, dt):
...         return x
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((3,), jnp.float32))
>>> inc = vardax.IncrementalFourDVar(
...     forward=Identity(),
...     obs_op=vardax.MaskedIdentity(),
...     prior_mean=jnp.zeros(3),
...     prior_cov_op=eye,
...     obs_cov_op=eye,
... )
>>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
>>> xa = inc(batch)
>>> xa.shape
(1, 3)
>>> bool(jnp.allclose(xa, 0.5, atol=1e-2))
True
Source code in src/vardax/_src/models/incremental_fourdvar.py
class IncrementalFourDVar(eqx.Module):
    r"""Operational incremental 4DVar (Decision D11).

    Functionally equivalent to [`StrongFourDVar`][vardax.StrongFourDVar]
    (same problem, same answer in the converged limit) but with a
    specialised inner solver: Gauss-Newton outer iterations and CG
    inner iterations on the linearised cost. Use this for production /
    long-window 4DVar.

    Attributes:
        forward: ``pipekit_cycle.ForwardModel``.
        obs_op: Observation operator.
        prior_mean: Background $x_b$ — initial state ``(N,)``.
        prior_cov_op: $B$.
        obs_cov_op: $R$.
        config: [`IncrementalConfig`][vardax.IncrementalConfig].

    Examples:
        Trivial dynamics ($M_t(x) = x$) and a single timestep reduce
        the problem to 3DVar, so with $B = R = I$ the converged
        analysis is $x_0^* = y / 2$.

        >>> import jax, jax.numpy as jnp, lineax as lx, vardax
        >>> class Identity:
        ...     dt = 1.0
        ...
        ...     def step(self, x, dt):
        ...         return x
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((3,), jnp.float32))
        >>> inc = vardax.IncrementalFourDVar(
        ...     forward=Identity(),
        ...     obs_op=vardax.MaskedIdentity(),
        ...     prior_mean=jnp.zeros(3),
        ...     prior_cov_op=eye,
        ...     obs_cov_op=eye,
        ... )
        >>> batch = vardax.Batch1D(input=jnp.ones((1, 1, 3)), mask=jnp.ones((1, 1, 3)))
        >>> xa = inc(batch)
        >>> xa.shape
        (1, 3)
        >>> bool(jnp.allclose(xa, 0.5, atol=1e-2))
        True
    """

    forward: Any
    obs_op: Any
    prior_mean: Float[Array, "N"]
    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    config: IncrementalConfig

    def __init__(
        self,
        forward: Any,
        obs_op: Any,
        prior_mean: Float[Array, "N"],
        prior_cov_op: lx.AbstractLinearOperator,
        obs_cov_op: lx.AbstractLinearOperator,
        config: IncrementalConfig | None = None,
    ) -> None:
        self.forward = forward
        self.obs_op = obs_op
        self.prior_mean = prior_mean
        self.prior_cov_op = prior_cov_op
        self.obs_cov_op = obs_cov_op
        self.config = config or IncrementalConfig()

    def _rollout(
        self,
        x_0: Float[Array, "N"],
        n_steps: int,
    ) -> Float[Array, "T_plus_1 N"]:
        dt = self.forward.dt

        def step_fn(x, _):
            x_new = self.forward.step(x, dt)
            return x_new, x_new

        _, trajectory = jax.lax.scan(step_fn, x_0, None, length=n_steps)
        return jnp.concatenate([x_0[None, :], trajectory], axis=0)

    def __call__(self, batch: Batch1D) -> Float[Array, "B N"]:
        """Incremental-4DVar analysis: GN outer + CG inner."""
        T = batch.input.shape[1] - 1

        def _one(input_i, mask_i):
            x_b = self.prior_mean  # outer-loop iterate; starts at the background

            def cost_grad_for_outer(x_iter):
                """Gradient of the full nonlinear cost at ``x_iter``."""

                def cost_full(x_0):
                    dx = x_0 - self.prior_mean
                    B_inv_dx = lx.linear_solve(
                        self.prior_cov_op,
                        dx,
                        solver=lx.CG(atol=1e-6, rtol=1e-6),
                    ).value
                    j_bg = 0.5 * jnp.sum(dx * B_inv_dx)
                    trajectory = self._rollout(x_0, n_steps=T)

                    def _per_step(x_t, y_t, m_t):
                        y_pred = (
                            self.obs_op(x_t, mask=m_t)
                            if _accepts_mask(self.obs_op)
                            else self.obs_op(x_t)
                        )
                        residual = m_t * (y_t - y_pred)
                        R_inv_r = lx.linear_solve(
                            self.obs_cov_op,
                            residual,
                            solver=lx.CG(atol=1e-6, rtol=1e-6),
                        ).value
                        return 0.5 * jnp.sum(residual * R_inv_r)

                    per_step = jax.vmap(_per_step)(trajectory, input_i, mask_i)
                    return j_bg + jnp.sum(per_step)

                return jax.grad(cost_full)(x_iter)

            # Gauss-Newton outer iterations on the nonlinear cost.
            for _ in range(self.config.n_outer):
                # Build J: x_0 -> masked observations of the model
                # trajectory, i.e. H applied to the rollout (shape (T_plus_1, N)).
                def J_full(x_0):
                    trajectory = self._rollout(x_0, n_steps=T)

                    def _per_step(x_t, m_t):
                        return (
                            self.obs_op(x_t, mask=m_t)
                            if _accepts_mask(self.obs_op)
                            else self.obs_op(x_t)
                        ) * m_t

                    return jax.vmap(_per_step)(trajectory, mask_i)

                # GN Hessian as a hand-rolled mat-vec on the state
                # space (N,). Uses jax.linearize + jax.vjp to apply
                # H^T R^{-1} H v plus B^{-1} v without forming dense
                # matrices.
                _, jvp_fn = jax.linearize(J_full, x_b)
                _, vjp_fn = jax.vjp(J_full, x_b)

                def gn_hessian_matvec(v, _jvp_fn=jvp_fn, _vjp_fn=vjp_fn):
                    Hv = _jvp_fn(v)

                    # Apply per-time-step R^{-1}: vmap CG over time
                    def _r_inv(y):
                        return lx.linear_solve(
                            self.obs_cov_op,
                            y,
                            solver=lx.CG(atol=1e-6, rtol=1e-6),
                        ).value

                    R_inv_Hv = jax.vmap(_r_inv)(Hv)
                    (HtR_inv_H_v,) = _vjp_fn(R_inv_Hv)
                    B_inv_v = lx.linear_solve(
                        self.prior_cov_op,
                        v,
                        solver=lx.CG(atol=1e-6, rtol=1e-6),
                    ).value
                    return HtR_inv_H_v + B_inv_v

                gn_op = lx.FunctionLinearOperator(
                    gn_hessian_matvec,
                    jax.ShapeDtypeStruct(x_b.shape, x_b.dtype),
                    lx.positive_semidefinite_tag,
                )

                # RHS = -grad J at the current outer iterate (held in x_b)
                rhs = -cost_grad_for_outer(x_b)

                # Inner CG solve: GN_Hessian @ dx = rhs
                dx_star = lx.linear_solve(
                    gn_op,
                    rhs,
                    solver=lx.CG(
                        atol=self.config.cg_atol,
                        rtol=self.config.cg_rtol,
                        max_steps=self.config.n_inner,
                    ),
                ).value
                x_b = x_b + dx_star

            return x_b

        return jax.vmap(_one)(batch.input, batch.mask)

    def as_analysis_step(self) -> _IncrementalFourDVarAnalysisStep:
        return _IncrementalFourDVarAnalysisStep(self)
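The outer/inner structure of __call__ can be reproduced on a problem small enough to write out by hand. The sketch below is a NumPy illustration under stated assumptions, not vardax code. It uses a hypothetical elementwise model \(M(x) = x + \Delta t \sin x\), full observations (\(H = I\), mask of ones), and diagonal \(B^{-1}\), \(R^{-1}\). vardax applies the Gauss-Newton Hessian matrix-free through jax.linearize / jax.vjp and solves it with CG. The sketch instead forms the small tangent-linear Jacobians explicitly and uses np.linalg.solve.

```python
import numpy as np

dt, T = 0.1, 4                      # hypothetical step size and number of model steps
B_inv = np.diag([2.0, 2.0])         # assumed inverse background-error covariance
R_inv = np.diag([10.0, 10.0])       # assumed inverse observation-error covariance

def step(x):
    # Hypothetical nonlinear model M(x) = x + dt * sin(x), applied elementwise.
    return x + dt * np.sin(x)

def rollout(x0):
    # Trajectory x_0..x_T and tangent-linear Jacobians d x_t / d x_0.
    xs, jacs = [x0], [np.eye(x0.size)]
    for _ in range(T):
        x = xs[-1]
        jacs.append(np.diag(1.0 + dt * np.cos(x)) @ jacs[-1])
        xs.append(step(x))
    return np.array(xs), jacs

truth = np.array([0.5, -0.3])
y = rollout(truth)[0] + 0.01 * np.array([1.0, -1.0])  # synthetic obs at t = 0..T
x_b = truth + np.array([0.2, -0.1])                    # background

def gradient(x0):
    # Gradient of the full nonlinear cost (background + observation terms).
    xs, jacs = rollout(x0)
    g = B_inv @ (x0 - x_b)
    for x_t, y_t, J_t in zip(xs, y, jacs):
        g -= J_t.T @ R_inv @ (y_t - x_t)
    return g

x_k = x_b.copy()                    # outer iterate starts at the background
for _ in range(6):                  # Gauss-Newton outer loop (n_outer)
    _, jacs = rollout(x_k)          # relinearise about the current iterate
    hess = B_inv + sum(J.T @ R_inv @ J for J in jacs)  # GN Hessian
    dx = np.linalg.solve(hess, -gradient(x_k))         # inner solve (CG in vardax)
    x_k = x_k + dx

print(np.linalg.norm(gradient(x_k)))  # close to zero at convergence
```

Each outer pass relinearises the trajectory about the current iterate, so the gradient shrinks rapidly on this mildly nonlinear toy. Strongly nonlinear dynamics or large residuals may need more outer iterations.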

IncrementalConfig

Bases: Module

Configuration for IncrementalFourDVar.

Attributes:

Name Type Description
n_outer int

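Number of Gauss-Newton outer iterations (default 3).

n_inner int

Maximum CG iterations per outer iteration (default 30; typically 20-50).

cg_atol float

Absolute tolerance of the inner CG solve (default 1e-6).

cg_rtol float

Relative tolerance of the inner CG solve (default 1e-6). cg_atol and cg_rtol govern only the Gauss-Newton inner solve. The nested \(B^{-1}\) and \(R^{-1}\) solves always use a fixed tolerance of 1e-6.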

Source code in src/vardax/_src/models/incremental_fourdvar.py
class IncrementalConfig(eqx.Module):
    """Configuration for ``IncrementalFourDVar``.

    Attributes:
        n_outer: Number of Gauss-Newton outer iterations (default 3).
        n_inner: Max CG iterations per outer iteration (default 30;
            typically 20-50).
        cg_atol: Absolute tolerance of the inner CG solve.
        cg_rtol: Relative tolerance of the inner CG solve.
    """

    n_outer: int = eqx.field(static=True, default=3)
    n_inner: int = eqx.field(static=True, default=30)
    cg_atol: float = eqx.field(static=True, default=1e-6)
    cg_rtol: float = eqx.field(static=True, default=1e-6)
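In __call__, n_inner is passed to the inner lx.CG as max_steps, and cg_atol / cg_rtol become its tolerances. The hypothetical plain-NumPy conjugate-gradient loop below shows how the three knobs interact. It assumes the two-norm stopping rule \(\|r\| \le \text{atol} + \text{rtol}\,\|b\|\). lineax's CG uses its own configurable norm, so iteration counts will not match it exactly.

```python
import numpy as np

def cg(matvec, b, atol=1e-6, rtol=1e-6, max_steps=30):
    """Conjugate gradient for a symmetric positive-definite mat-vec.

    Returns (x, n_iters). Stops once ||r|| <= atol + rtol * ||b||,
    or after max_steps iterations (the role n_inner plays).
    """
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    tol = atol + rtol * np.linalg.norm(b)
    for k in range(max_steps):
        if np.sqrt(rs) <= tol:
            return x, k
        Ap = matvec(p)
        step = rs / (p @ Ap)
        x = x + step * p
        r = r - step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, max_steps

# A well-conditioned SPD stand-in for the Gauss-Newton Hessian B^{-1} + H^T R^{-1} H.
rng = np.random.default_rng(0)
G = rng.standard_normal((20, 20))
A = G.T @ G + 20.0 * np.eye(20)
b = rng.standard_normal(20)

x_full, n_full = cg(lambda v: A @ v, b, max_steps=30)  # converges before the cap
x_cut, n_cut = cg(lambda v: A @ v, b, max_steps=2)     # truncated inner solve
print(n_full, np.linalg.norm(A @ x_full - b))
print(n_cut, np.linalg.norm(A @ x_cut - b))
```

A small max_steps therefore acts as an inexact inner solve. The outer Gauss-Newton loop still makes progress, but each increment is less accurate.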

Learned solvers — 4DVarNet

End-to-end-trainable 4DVar: the variational cost is kept explicit, but the inner-loop descent direction is produced by a ConvLSTM gradient modulator instead of a hand-tuned optimiser. The 1D variant operates on Batch1D (e.g. Lorenz-96 trajectories); the 2D variant on Batch2D / Batch2DMultivar fields (e.g. SSH reconstruction). The inner-loop iteration functions live on the Costs, Priors & Solvers page; training utilities on Training & Adjoints.


FourDVarNet1D

Bases: Module

End-to-end 4DVarNet model for 1-D spatiotemporal reconstruction.

Minimises the variational cost

\[ J(x) = \|\mathbf{m} \odot (x - y)\|^2 + \lambda \|x - \varphi(x)\|^2 \]

using n_solver_steps learned gradient steps modulated by a ConvLSTM, with the differentiation strategy selected by solver_adjoint.
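For a fixed prior, the cost is a sum of two squared norms, so its gradient is \(\nabla J(x) = 2\,\mathbf{m}\odot\mathbf{m}\odot(x - y) + 2\lambda\,(I - J_\varphi)^\top(x - \varphi(x))\), where \(J_\varphi\) is the Jacobian of \(\varphi\). The minimal NumPy sketch below checks this against finite differences. It replaces the learned BilinAEPrior1D with a hypothetical linear stand-in \(\varphi(x) = S x\) (a 3-point temporal moving average), so that \(J_\varphi = S\).

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, lam = 5, 4, 1.0                          # window length, state size, prior weight
y = rng.standard_normal((T, N))                # observations
m = (rng.random((T, N)) < 0.5).astype(float)   # binary observation mask

# Hypothetical linear prior phi(x) = S @ x: 3-point moving average along time.
S = np.zeros((T, T))
for t in range(T):
    lo, hi = max(t - 1, 0), min(t + 2, T)
    S[t, lo:hi] = 1.0 / (hi - lo)
D = np.eye(T) - S                              # x - phi(x) = D @ x

def cost(x):
    return np.sum((m * (x - y)) ** 2) + lam * np.sum((D @ x) ** 2)

def grad(x):
    return 2 * m * m * (x - y) + 2 * lam * D.T @ (D @ x)

x = m * y                                      # initial guess, as in _call_unrolled
eps = 1e-6
fd = np.zeros_like(x)
for idx in np.ndindex(*x.shape):               # central differences, one entry at a time
    e = np.zeros_like(x)
    e[idx] = eps
    fd[idx] = (cost(x + e) - cost(x - e)) / (2 * eps)

print(np.max(np.abs(fd - grad(x))))            # round-off level
```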

Attributes:

Name Type Description
n_solver_steps int

Number of solver iterations to unroll.

alpha float

Gradient step-size.

prior_weight float

Weight \(\lambda\) for the prior cost term.

solver_adjoint AbstractAdjoint

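optimistix.AbstractAdjoint selecting the differentiation strategy. Defaults to RecursiveCheckpointAdjoint (standard backprop through the unrolled solver). Use OneStepAdjoint() for O(1)-memory training (Bolte et al. 2023). Use ImplicitAdjoint() for the fixed-point projection solver, which uses only prior and ignores grad_mod and alpha. Adjoints that are neither a KStepAdjoint nor an ImplicitAdjoint fall back to unrolled backprop.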

prior BilinAEPrior1D

BilinAEPrior1D learned prior.

grad_mod ConvLSTMGradMod1D

ConvLSTMGradMod1D learned gradient modulator.
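The solver_adjoint choice changes which derivative training sees, not only the memory cost. A scalar toy makes this concrete. Here the unrolled solver is plain gradient descent \(x_{k+1} = x_k - \alpha (x_k - \theta)\) on \(J(x) = \tfrac12 (x - \theta)^2\), started at \(x_0 = 0\). The sketch assumes, as a hypothesis rather than a statement about vardax internals, that a k-step adjoint backpropagates only through the last k iterations. The implicit derivative is taken at the exact fixed point \(x^* = \theta\).

```python
alpha, K, theta = 0.2, 15, 1.0       # step size, unrolled iterations, parameter

def solve(theta, n=K, x0=0.0):
    x = x0
    for _ in range(n):
        x = x - alpha * (x - theta)  # gradient step on J(x) = 0.5 * (x - theta)**2
    return x

def dx_dtheta_last_k(k):
    # x_K = (1 - alpha)**k * x_{K-k} + (1 - (1 - alpha)**k) * theta;
    # treating x_{K-k} as a constant leaves only the second term.
    return 1.0 - (1.0 - alpha) ** k

unrolled = dx_dtheta_last_k(K)       # full backprop (x_0 = 0 does not depend on theta)
one_step = dx_dtheta_last_k(1)       # last iteration only
implicit = 1.0                       # dx*/dtheta at the fixed point x* = theta

h = 1e-6
fd = (solve(theta + h) - solve(theta - h)) / (2 * h)  # check of the unrolled value
print(unrolled, one_step, implicit, fd)  # unrolled ≈ 0.965, one_step ≈ 0.2
```

In this toy, the one-step value (0.2) is far from the implicit derivative (1.0). How closely a truncated adjoint tracks the full gradient depends on how contractive the final iterations are.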

Source code in src/vardax/_src/model.py
class FourDVarNet1D(eqx.Module):
    r"""End-to-end 4DVarNet model for 1-D spatiotemporal reconstruction.

    Minimises the variational cost

    $$
    J(x) = \|\mathbf{m} \odot (x - y)\|^2 + \lambda \|x - \varphi(x)\|^2
    $$

    using ``n_solver_steps`` learned gradient steps modulated by a ConvLSTM,
    with the differentiation strategy selected by ``solver_adjoint``.

    Attributes:
        n_solver_steps: Number of solver iterations to unroll.
        alpha: Gradient step-size.
        prior_weight: Weight $\lambda$ for the prior cost term.
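        solver_adjoint: ``optimistix.AbstractAdjoint`` selecting the
            differentiation strategy. Defaults to
            ``RecursiveCheckpointAdjoint`` (standard backprop). Use
            ``OneStepAdjoint()`` for O(1)-memory training (Bolte et al.
            2023) or ``ImplicitAdjoint()`` for the fixed-point projection
            solver, which uses only ``prior`` (``grad_mod`` and ``alpha``
            are ignored). Adjoints that are neither ``KStepAdjoint`` nor
            ``ImplicitAdjoint`` fall back to unrolled backprop.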
        prior: BilinAEPrior1D learned prior.
        grad_mod: ConvLSTMGradMod1D learned gradient modulator.
    """

    n_solver_steps: int = eqx.field(static=True)
    alpha: float = eqx.field(static=True)
    prior_weight: float = eqx.field(static=True)
    solver_adjoint: optx.AbstractAdjoint = eqx.field(static=True)
    prior: BilinAEPrior1D
    grad_mod: ConvLSTMGradMod1D

    def __init__(
        self,
        state_dim: int,
        n_time: int,
        latent_dim: int = 32,
        hidden_dim: int = 64,
        n_solver_steps: int = 15,
        alpha: float = 0.2,
        prior_weight: float = 1.0,
        solver_adjoint: optx.AbstractAdjoint | None = None,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.n_solver_steps = n_solver_steps
        self.alpha = alpha
        self.prior_weight = prior_weight
        self.solver_adjoint = solver_adjoint or _default_adjoint()
        k_prior, k_grad = jax.random.split(key)
        self.prior = BilinAEPrior1D(
            state_dim=state_dim,
            latent_dim=latent_dim,
            n_time=n_time,
            key=k_prior,
        )
        self.grad_mod = ConvLSTMGradMod1D(
            state_channels=n_time,
            hidden_dim=hidden_dim,
            key=k_grad,
        )

    def __call__(self, batch: Batch1D) -> Float[Array, "B T N"]:
        """Run the solver and return the final state estimate.

        Dispatches on ``solver_adjoint`` to pick the differentiation
        path.
        """
        if isinstance(self.solver_adjoint, KStepAdjoint):
            return self._call_one_step(batch)
        if isinstance(self.solver_adjoint, optx.ImplicitAdjoint):
            return self._call_implicit(batch)
        # Default: optimistix.RecursiveCheckpointAdjoint() and anything
        # else falls through to the unrolled backprop path.
        return self._call_unrolled(batch)

    def _call_unrolled(self, batch: Batch1D) -> Float[Array, "B T N"]:
        b, _, n = batch.input.shape
        x = batch.input * batch.mask
        lstm = LSTMState1D.zeros(b, self.grad_mod.hidden_dim, n)

        for _ in range(self.n_solver_steps):

            def cost_fn(x_):
                obs_diff = batch.mask * (x_ - batch.input)
                j_obs = jnp.sum(obs_diff**2)
                j_prior = self.prior_weight * jnp.sum((x_ - self.prior(x_)) ** 2)
                return j_obs + j_prior

            grad = jax.grad(cost_fn)(x)
            update, lstm = self.grad_mod(grad, x, lstm)
            x = x - self.alpha * update

        return x

    def _call_one_step(self, batch: Batch1D) -> Float[Array, "B T N"]:
        from .solver import one_step_solve_4dvarnet_1d

        return one_step_solve_4dvarnet_1d(
            batch,
            self.prior,
            self.grad_mod,
            n_steps=self.n_solver_steps,
            hidden_dim=self.grad_mod.hidden_dim,
            alpha=self.alpha,
            prior_weight=self.prior_weight,
            k=getattr(self.solver_adjoint, "k", 1),
        )

    def _call_implicit(self, batch: Batch1D) -> Float[Array, "B T N"]:
        # Uses the fixed-point projection solver (prior only; grad_mod
        # and alpha are not used in this mode).
        from .solver import solve_4dvarnet_1d_fixedpoint

        return solve_4dvarnet_1d_fixedpoint(
            batch, self.prior, n_fp_steps=self.n_solver_steps
        )

    def as_analysis_step(self) -> _FourDVarNet1DAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep`` (Decision D8)."""
        return _FourDVarNet1DAnalysisStep(self)
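_call_implicit delegates to solve_4dvarnet_1d_fixedpoint, whose body is not shown on this page. A common form of the 4DVarNet fixed-point solver keeps observed values and refills the gaps from the prior: \(x \leftarrow \mathbf{m}\odot y + (1 - \mathbf{m})\odot\varphi(x)\). The NumPy sketch below assumes that update, together with a hypothetical 3-point moving-average \(\varphi\) on a single 1-D series.

```python
import numpy as np

T = 7
m = np.zeros(T)
m[[0, 3, 6]] = 1.0                                   # observed at t = 0, 3, 6
y = np.array([0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0])    # only observed entries matter

def phi(x):
    # Hypothetical stand-in prior: 3-point moving average, truncated at the edges.
    out = np.empty_like(x)
    for t in range(len(x)):
        lo, hi = max(t - 1, 0), min(t + 2, len(x))
        out[t] = x[lo:hi].mean()
    return out

x = m * y                                  # initial guess (assumed)
for _ in range(100):                       # plays the role of n_fp_steps
    x = m * y + (1.0 - m) * phi(x)         # keep observations, refill gaps from the prior

print(np.round(x, 6))                      # [0. 1. 2. 3. 2. 1. 0.]
```

With this smoother, the gaps converge to straight-line interpolation between observed values. At the fixed point, each unobserved entry must equal the mean of its two neighbours.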

as_analysis_step

as_analysis_step() -> _FourDVarNet1DAnalysisStep

Adapt to pipekit_cycle.AnalysisStep (Decision D8).


FourDVarNet2D

Bases: Module

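End-to-end 4DVarNet model for 2-D spatiotemporal reconstruction. It minimises the same variational cost as FourDVarNet1D, with the state, observations and mask shaped (B, T, H, W).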

Attributes:

Name Type Description
n_solver_steps int

Number of solver iterations to unroll.

alpha float

Gradient step-size.

prior_weight float

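Weight \(\lambda\) for the prior cost term.

solver_adjoint AbstractAdjoint

optimistix.AbstractAdjoint selecting the differentiation strategy. Defaults to RecursiveCheckpointAdjoint (standard backprop). ImplicitAdjoint is not yet implemented for 2-D models: calling the model with it raises NotImplementedError, so use RecursiveCheckpointAdjoint or OneStepAdjoint instead.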

prior BilinAEPrior2D

BilinAEPrior2D learned prior.

grad_mod ConvLSTMGradMod2D

ConvLSTMGradMod2D learned gradient modulator.
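_call_unrolled follows the same pattern in 1-D and 2-D: take the gradient of the cost, pass it through grad_mod, and step by alpha. Below is a NumPy sketch of that loop on a (T, H, W) field, with two hypothetical substitutions. grad_mod is replaced by the identity, so the loop is plain gradient descent, and the learned BilinAEPrior2D is replaced by a periodic 4-neighbour average \(\varphi\). The step size is 0.1 rather than the default 0.2. Without a learned modulator, 0.2 can sit at the edge of stability for this quadratic.

```python
import numpy as np

rng = np.random.default_rng(1)
T, H, W = 3, 8, 8
alpha, lam, n_steps = 0.1, 1.0, 15

truth = np.broadcast_to(np.sin(np.linspace(0.0, np.pi, W)), (T, H, W))
m = (rng.random((T, H, W)) < 0.3).astype(float)  # ~30% of pixels observed
y = m * truth                                     # masked observations

def phi(x):
    # Hypothetical prior: mean of the 4 spatial neighbours (periodic boundaries).
    return 0.25 * (np.roll(x, 1, axis=-1) + np.roll(x, -1, axis=-1)
                   + np.roll(x, 1, axis=-2) + np.roll(x, -1, axis=-2))

def cost(x):
    return np.sum((m * (x - y)) ** 2) + lam * np.sum((x - phi(x)) ** 2)

def grad(x):
    r = x - phi(x)
    # phi is a symmetric linear map, so (I - phi)^T r = r - phi(r).
    return 2 * m * m * (x - y) + 2 * lam * (r - phi(r))

x = m * y                                # initialisation used by _call_unrolled
costs = [cost(x)]
for _ in range(n_steps):
    update = grad(x)                     # identity stands in for grad_mod
    x = x - alpha * update
    costs.append(cost(x))

print(costs[0], costs[-1])
```

In the real model, the ConvLSTM replaces the identity modulator and learns both the direction and the scaling of each update, which is why alpha alone does not set the effective step.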

Source code in src/vardax/_src/model.py
class FourDVarNet2D(eqx.Module):
    """End-to-end 4DVarNet model for 2-D spatiotemporal reconstruction.

    Attributes:
        n_solver_steps: Number of solver iterations to unroll.
        alpha: Gradient step-size.
        prior_weight: Weight for the prior cost term.
        solver_adjoint: ``optimistix.AbstractAdjoint`` selecting the
            differentiation strategy. Defaults to
            ``RecursiveCheckpointAdjoint`` (standard backprop).
            ``ImplicitAdjoint`` is not yet implemented for 2-D models.
        prior: BilinAEPrior2D learned prior.
        grad_mod: ConvLSTMGradMod2D learned gradient modulator.
    """

    n_solver_steps: int = eqx.field(static=True)
    alpha: float = eqx.field(static=True)
    prior_weight: float = eqx.field(static=True)
    solver_adjoint: optx.AbstractAdjoint = eqx.field(static=True)
    prior: BilinAEPrior2D
    grad_mod: ConvLSTMGradMod2D

    def __init__(
        self,
        n_time: int,
        height: int,
        width: int,
        latent_dim: int = 64,
        hidden_dim: int = 64,
        n_solver_steps: int = 15,
        alpha: float = 0.2,
        prior_weight: float = 1.0,
        solver_adjoint: optx.AbstractAdjoint | None = None,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.n_solver_steps = n_solver_steps
        self.alpha = alpha
        self.prior_weight = prior_weight
        self.solver_adjoint = solver_adjoint or _default_adjoint()
        k_prior, k_grad = jax.random.split(key)
        self.prior = BilinAEPrior2D(
            latent_dim=latent_dim,
            n_time=n_time,
            height=height,
            width=width,
            key=k_prior,
        )
        self.grad_mod = ConvLSTMGradMod2D(
            state_channels=n_time,
            hidden_dim=hidden_dim,
            key=k_grad,
        )

    def __call__(self, batch: Batch2D) -> Float[Array, "B T H W"]:
        """Run the solver and return the final state estimate.

        Dispatches on ``solver_adjoint`` to pick the differentiation
        path.
        """
        if isinstance(self.solver_adjoint, KStepAdjoint):
            return self._call_one_step(batch)
        if isinstance(self.solver_adjoint, optx.ImplicitAdjoint):
            return self._call_implicit(batch)
        return self._call_unrolled(batch)

    def _call_unrolled(self, batch: Batch2D) -> Float[Array, "B T H W"]:
        b, _, h, w = batch.input.shape
        x = batch.input * batch.mask
        lstm = LSTMState2D.zeros(b, self.grad_mod.hidden_dim, h, w)

        for _ in range(self.n_solver_steps):

            def cost_fn(x_):
                obs_diff = batch.mask * (x_ - batch.input)
                j_obs = jnp.sum(obs_diff**2)
                j_prior = self.prior_weight * jnp.sum((x_ - self.prior(x_)) ** 2)
                return j_obs + j_prior

            grad = jax.grad(cost_fn)(x)
            update, lstm = self.grad_mod(grad, x, lstm)
            x = x - self.alpha * update

        return x

    def _call_one_step(self, batch: Batch2D) -> Float[Array, "B T H W"]:
        from .solver import one_step_solve_4dvarnet_2d

        return one_step_solve_4dvarnet_2d(
            batch,
            self.prior,
            self.grad_mod,
            n_steps=self.n_solver_steps,
            hidden_dim=self.grad_mod.hidden_dim,
            alpha=self.alpha,
            prior_weight=self.prior_weight,
            k=getattr(self.solver_adjoint, "k", 1),
        )

    def _call_implicit(self, batch: Batch2D) -> Float[Array, "B T H W"]:
        raise NotImplementedError(
            "ImplicitAdjoint for FourDVarNet2D is not yet implemented; "
            "use RecursiveCheckpointAdjoint or OneStepAdjoint."
        )

    def as_analysis_step(self) -> _FourDVarNet2DAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep`` (Decision D8)."""
        return _FourDVarNet2DAnalysisStep(self)

as_analysis_step

as_analysis_step() -> _FourDVarNet2DAnalysisStep

Adapt to pipekit_cycle.AnalysisStep (Decision D8).
