
Costs, Priors & Solvers

The functional core under the model classes: pure cost functions that score a candidate state against observations and prior, prior modules that supply the regularisation term, and the inner-loop solver functions that drive the 4DVarNet iteration. The model classes are thin, stateful-looking wrappers over these pieces — drop down to this layer when building custom methods or instrumenting the optimisation.

Cost functions

The variational cost \(U(x) = \alpha_\text{obs} J_\text{obs}(x) + \alpha_\text{prior} J_\text{prior}(x)\) and its gradient. The observation and prior terms are also available as separate functions (obs_cost_1d, obs_cost_2d, prior_cost), and decomposed_loss returns the two weighted terms alongside their total, for logging. The _1d / _2d suffixes match the Batch1D / Batch2D carriers. See 3DVar and strong-constraint 4DVar for the math each term implements.

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

variational_cost

variational_cost(
    x: Float[Array, ...],
    batch: Batch1D,
    prior_fn: Callable[..., Any],
    alpha_obs: float = 0.5,
    alpha_prior: float = 0.5,
) -> Float[Array, ""]

Compute the variational cost \(U(x)\).

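\[ U(x) = \frac{\alpha_{obs}}{D} \|m \odot (x - y)\|^2 + \frac{\alpha_{prior}}{D} \|x - \varphi(x)\|^2 \]

where \(D\) is the number of entries of x. Both squared norms are averaged over all entries (the implementation uses jnp.mean), which is why the example below evaluates to 0.5 rather than 4.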

Parameters:

x (Float[Array, ...], required): Current state estimate.
batch (Batch1D, required): Observed data batch with input (y) and mask (m).
prior_fn (Callable[..., Any], required): Callable x -> x_prior.
alpha_obs (float, default 0.5): Weight for the observation term.
alpha_prior (float, default 0.5): Weight for the prior term.

Returns:

Float[Array, '']: Scalar cost value.

Examples:

With the trivial IdentityPrior the prior term vanishes, leaving the weighted observation MSE.

>>> import jax.numpy as jnp, vardax
>>> batch = vardax.Batch1D(input=jnp.zeros((1, 2, 4)), mask=jnp.ones((1, 2, 4)))
>>> x = jnp.ones((1, 2, 4))
>>> float(vardax.variational_cost(x, batch, vardax.IdentityPrior()))
0.5
Source code in src/vardax/_src/costs.py
def variational_cost(
    x: Float[Array, ...],
    batch: Batch1D,
    prior_fn: Callable[..., Any],
    alpha_obs: float = 0.5,
    alpha_prior: float = 0.5,
) -> Float[Array, ""]:
    r"""Compute the variational cost $U(x)$.

    $$
    U(x) = \frac{\alpha_{obs}}{D} \|m \odot (x - y)\|^2
          + \frac{\alpha_{prior}}{D} \|x - \varphi(x)\|^2
    $$

    where $D$ is the number of entries of ``x``.

    Args:
        x: Current state estimate.
        batch: Observed data batch with ``input`` (``y``) and ``mask``
            (``m``).
        prior_fn: Callable ``x -> x_prior``.
        alpha_obs: Weight for the observation term (default ``0.5``).
        alpha_prior: Weight for the prior term (default ``0.5``).

    Returns:
        Scalar cost value.

    Examples:
        With the trivial [`IdentityPrior`][vardax.IdentityPrior] the
        prior term vanishes, leaving the weighted observation MSE.

        >>> import jax.numpy as jnp, vardax
        >>> batch = vardax.Batch1D(input=jnp.zeros((1, 2, 4)), mask=jnp.ones((1, 2, 4)))
        >>> x = jnp.ones((1, 2, 4))
        >>> float(vardax.variational_cost(x, batch, vardax.IdentityPrior()))
        0.5
    """
    obs_diff = batch.mask * (x - batch.input)
    j_obs = jnp.mean(obs_diff**2)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return alpha_obs * j_obs + alpha_prior * j_prior
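A worked check of the mean normalisation, written as a minimal sketch in plain jax.numpy rather than a call into vardax: variational_cost_sketch mirrors the body above, and shrink (\(\varphi(x) = 0.5\,x\)) is a hypothetical prior chosen only because its term is easy to evaluate by hand.

```python
import jax.numpy as jnp


def variational_cost_sketch(x, y, m, prior_fn, alpha_obs=0.5, alpha_prior=0.5):
    # Same arithmetic as variational_cost: both terms are means over all entries.
    j_obs = jnp.mean((m * (x - y)) ** 2)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return alpha_obs * j_obs + alpha_prior * j_prior


def shrink(x):
    # Hypothetical prior used only for illustration: phi(x) = 0.5 * x.
    return 0.5 * x


x = jnp.ones((1, 2, 4))
y = jnp.zeros((1, 2, 4))
m = jnp.ones((1, 2, 4))
cost = variational_cost_sketch(x, y, m, shrink)
# j_obs = mean(1^2) = 1.0 and j_prior = mean(0.5^2) = 0.25,
# so cost = 0.5 * 1.0 + 0.5 * 0.25 = 0.625
print(float(cost))
```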

variational_cost_grad

variational_cost_grad(
    x: Float[Array, ...],
    batch: Batch1D,
    prior_fn: Callable[..., Any],
    alpha_obs: float = 0.5,
    alpha_prior: float = 0.5,
) -> Float[Array, ...]

Gradient of variational_cost w.r.t. x.

Parameters:

x (Float[Array, ...], required): Current state estimate.
batch (Batch1D, required): Observed data batch.
prior_fn (Callable[..., Any], required): Callable x -> x_prior.
alpha_obs (float, default 0.5): Weight for the observation term.
alpha_prior (float, default 0.5): Weight for the prior term.

Returns:

Float[Array, ...]: Gradient array with the same shape as x.

Source code in src/vardax/_src/costs.py
def variational_cost_grad(
    x: Float[Array, ...],
    batch: Batch1D,
    prior_fn: Callable[..., Any],
    alpha_obs: float = 0.5,
    alpha_prior: float = 0.5,
) -> Float[Array, ...]:
    """Gradient of [`variational_cost`][vardax.variational_cost] w.r.t. ``x``.

    Args:
        x: Current state estimate.
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        alpha_obs: Weight for the observation term.
        alpha_prior: Weight for the prior term.

    Returns:
        Gradient array with the same shape as ``x``.
    """
    return jax.grad(variational_cost)(x, batch, prior_fn, alpha_obs, alpha_prior)
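For a quick sanity check of what jax.grad returns, the gradient can also be written out by hand. Differentiating the mean-normalised cost gives \(\nabla U = \frac{2\alpha_{obs}}{D}\, m \odot m \odot (x - y) + \frac{2\alpha_{prior}}{D} (I - J_\varphi)^\top (x - \varphi(x))\), where \(J_\varphi\) is the Jacobian of the prior. The sketch below assumes a hypothetical linear prior \(\varphi(x) = c\,x\), for which the second term reduces to \(\frac{2\alpha_{prior}}{D}(1 - c)^2 x\), and compares both routes. It re-implements the cost in jax.numpy instead of importing vardax.

```python
import jax
import jax.numpy as jnp

ALPHA_OBS, ALPHA_PRIOR, C = 0.5, 0.5, 0.5


def cost(x, y, m):
    # Mirrors variational_cost with the hypothetical linear prior phi(x) = C * x.
    j_obs = jnp.mean((m * (x - y)) ** 2)
    j_prior = jnp.mean((x - C * x) ** 2)
    return ALPHA_OBS * j_obs + ALPHA_PRIOR * j_prior


def analytic_grad(x, y, m):
    d = x.size  # number of entries averaged over by jnp.mean
    g_obs = 2.0 * ALPHA_OBS / d * m**2 * (x - y)
    g_prior = 2.0 * ALPHA_PRIOR / d * (1.0 - C) ** 2 * x
    return g_obs + g_prior


x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4))
y = jnp.zeros_like(x)
# Every other entry observed.
m = (jnp.arange(x.size).reshape(x.shape) % 2).astype(x.dtype)

g_auto = jax.grad(cost)(x, y, m)  # differentiates w.r.t. the first argument
g_ref = analytic_grad(x, y, m)
print(bool(jnp.allclose(g_auto, g_ref, atol=1e-6)))
```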

obs_cost_1d

obs_cost_1d(
    state: Float[Array, "B T N"],
    obs: Float[Array, "B T N"],
    mask: Float[Array, "B T N"],
) -> Float[Array, ""]

Observation cost for 1-D data.

Computes the masked mean-squared error between the state and observations:

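\[ J_{obs} = \frac{1}{D} \sum_{i \in \Omega} (x_i - y_i)^2, \qquad D = B \cdot T \cdot N \]

where \(\Omega\) is the set of observed locations (mask == 1). Unobserved entries contribute zero to the sum but still count in \(D\), so \(J_{obs}\) equals the mean-squared error over \(\Omega\) scaled by \(|\Omega| / D\).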

Parameters:

state (Float[Array, 'B T N'], required): Current state estimate of shape (B, T, N).
obs (Float[Array, 'B T N'], required): Observations of shape (B, T, N).
mask (Float[Array, 'B T N'], required): Binary observation mask of shape (B, T, N); a value of 1 marks an observed location.

Returns:

Float[Array, '']: Scalar observation cost.

Examples:

>>> import jax.numpy as jnp
>>> from vardax import obs_cost_1d
>>> state = jnp.ones((1, 1, 4))
>>> obs = jnp.zeros((1, 1, 4))
>>> mask = jnp.ones((1, 1, 4))
>>> float(obs_cost_1d(state, obs, mask))
1.0
Source code in src/vardax/_src/costs.py
def obs_cost_1d(
    state: Float[Array, "B T N"],
    obs: Float[Array, "B T N"],
    mask: Float[Array, "B T N"],
) -> Float[Array, ""]:
    r"""Observation cost for 1-D data.

    Computes the masked mean-squared error between the state and observations:

    $$
    J_{obs} = \frac{1}{D} \sum_{i \in \Omega} (x_i - y_i)^2, \qquad D = B \cdot T \cdot N
    $$

    where $\Omega$ is the set of observed locations (``mask == 1``);
    unobserved entries contribute zero to the sum but still count in $D$.

    Args:
        state: Current state estimate of shape ``(B, T, N)``.
        obs: Observations of shape ``(B, T, N)``.
        mask: Binary observation mask of shape ``(B, T, N)``.
            A value of ``1`` indicates an observed location.

    Returns:
        Scalar observation cost.

    Examples:
        >>> import jax.numpy as jnp
        >>> from vardax import obs_cost_1d
        >>> state = jnp.ones((1, 1, 4))
        >>> obs = jnp.zeros((1, 1, 4))
        >>> mask = jnp.ones((1, 1, 4))
        >>> float(obs_cost_1d(state, obs, mask))
        1.0
    """
    diff = mask * (state - obs)
    return jnp.mean(diff**2)
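Because the denominator counts every entry, a partially observed window yields a smaller value than the mean-squared error over the observed points alone. The sketch below (plain jax.numpy, with a hypothetical per-observation normalisation for comparison) makes the factor \(|\Omega| / D\) concrete.

```python
import jax.numpy as jnp

state = jnp.ones((1, 1, 4))
obs = jnp.zeros((1, 1, 4))
mask = jnp.array([[[1.0, 1.0, 0.0, 0.0]]])  # only the first two points observed

# What obs_cost_1d computes: the mean over all B*T*N = 4 entries.
j_all = jnp.mean((mask * (state - obs)) ** 2)  # (1 + 1 + 0 + 0) / 4 = 0.5

# Hypothetical alternative: normalise by the number of observed points |Omega| = 2.
j_obs_only = jnp.sum((mask * (state - obs)) ** 2) / jnp.sum(mask)  # 2 / 2 = 1.0

print(float(j_all), float(j_obs_only))
```

The ratio between the two values, 0.5 here, is exactly the observed fraction \(|\Omega| / D\).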

obs_cost_2d

obs_cost_2d(
    state: Float[Array, "B T H W"],
    obs: Float[Array, "B T H W"],
    mask: Float[Array, "B T H W"],
) -> Float[Array, ""]

Observation cost for 2-D data.

Computes the masked mean-squared error between the state and observations:

\[ J_{obs} = \frac{1}{D} \sum_{i \in \Omega} (x_i - y_i)^2, \qquad D = B \cdot T \cdot H \cdot W \]

where \(\Omega\) is the set of observed locations (mask == 1). As in the 1-D case, unobserved entries still count in \(D\).

Parameters:

state (Float[Array, 'B T H W'], required): Current state estimate of shape (B, T, H, W).
obs (Float[Array, 'B T H W'], required): Observations of shape (B, T, H, W).
mask (Float[Array, 'B T H W'], required): Binary observation mask of shape (B, T, H, W); a value of 1 marks an observed location.

Returns:

Float[Array, '']: Scalar observation cost.

Source code in src/vardax/_src/costs.py
def obs_cost_2d(
    state: Float[Array, "B T H W"],
    obs: Float[Array, "B T H W"],
    mask: Float[Array, "B T H W"],
) -> Float[Array, ""]:
    r"""Observation cost for 2-D data.

    Computes the masked mean-squared error between the state and observations:

    $$
    J_{obs} = \frac{1}{D} \sum_{i \in \Omega} (x_i - y_i)^2, \qquad D = B \cdot T \cdot H \cdot W
    $$

    where $\Omega$ is the set of observed locations (``mask == 1``);
    unobserved entries still count in $D$.

    Args:
        state: Current state estimate of shape ``(B, T, H, W)``.
        obs: Observations of shape ``(B, T, H, W)``.
        mask: Binary observation mask of shape ``(B, T, H, W)``.

    Returns:
        Scalar observation cost.
    """
    diff = mask * (state - obs)
    return jnp.mean(diff**2)
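obs_cost_2d has the same body as obs_cost_1d, so flattening the two spatial axes into one should not change the result. A small check of that claim in plain jax.numpy; masked_mse is a local stand-in for either function, not an import from vardax.

```python
import jax
import jax.numpy as jnp


def masked_mse(state, obs, mask):
    # Same body as obs_cost_1d and obs_cost_2d.
    return jnp.mean((mask * (state - obs)) ** 2)


def flatten_space(a):
    # (B, T, H, W) -> (B, T, H * W)
    return a.reshape(a.shape[0], a.shape[1], -1)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
state = jax.random.normal(k1, (2, 3, 5, 6))  # (B, T, H, W)
obs = jax.random.normal(k2, (2, 3, 5, 6))
mask = (jax.random.uniform(k3, (2, 3, 5, 6)) > 0.7).astype(state.dtype)

cost_2d = masked_mse(state, obs, mask)
cost_1d = masked_mse(flatten_space(state), flatten_space(obs), flatten_space(mask))
print(bool(jnp.allclose(cost_2d, cost_1d)))
```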

prior_cost

prior_cost(
    state: Float[Array, ...],
    prior_reconstruction: Float[Array, ...],
) -> Float[Array, ""]

Prior cost based on learned autoencoder reconstruction.

Computes the mean-squared error between the state and its reconstruction through the learned prior (autoencoder):

\[ J_{prior} = \frac{1}{D} \|x - \varphi(x)\|^2 \]

where \(D\) is the number of entries of x.

Parameters:

state (Float[Array, ...], required): Current state estimate of arbitrary shape.
prior_reconstruction (Float[Array, ...], required): Autoencoder reconstruction of the state, same shape as state.

Returns:

Float[Array, '']: Scalar prior cost.

Source code in src/vardax/_src/costs.py
def prior_cost(
    state: Float[Array, ...],
    prior_reconstruction: Float[Array, ...],
) -> Float[Array, ""]:
    r"""Prior cost based on learned autoencoder reconstruction.

    Computes the mean-squared error between the state and its reconstruction
    through the learned prior (autoencoder):

    $$
    J_{prior} = \frac{1}{D} \|x - \varphi(x)\|^2
    $$

    where $D$ is the number of entries of ``state``.

    Args:
        state: Current state estimate of arbitrary shape.
        prior_reconstruction: Autoencoder reconstruction of the state,
            same shape as ``state``.

    Returns:
        Scalar prior cost.
    """
    return jnp.mean((state - prior_reconstruction) ** 2)
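To see what the prior term measures, take a hypothetical linear autoencoder that projects onto the span of the first two coordinates: states on that subspace cost nothing, and every other state is charged for its off-subspace energy. A sketch in plain jax.numpy (prior_cost_sketch and project_first_two are illustrative names, not vardax API):

```python
import jax.numpy as jnp


def prior_cost_sketch(state, prior_reconstruction):
    # Same body as prior_cost: mean of the squared reconstruction error.
    return jnp.mean((state - prior_reconstruction) ** 2)


def project_first_two(x):
    # Hypothetical "perfect" linear autoencoder whose manifold is the span of
    # the first two coordinates: keep those, zero the rest.
    return x.at[..., 2:].set(0.0)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
on_manifold = jnp.array([1.0, 2.0, 0.0, 0.0])

cost_off = prior_cost_sketch(x, project_first_two(x))  # (3^2 + 4^2) / 4 = 6.25
cost_on = prior_cost_sketch(on_manifold, project_first_two(on_manifold))  # 0.0
print(float(cost_off), float(cost_on))
```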

decomposed_loss

decomposed_loss(
    x: Float[Array, ...],
    batch: Batch1D,
    prior_fn: Callable[..., Any],
    alpha_obs: float = 0.5,
    alpha_prior: float = 0.5,
) -> dict[str, Float[Array, ""]]

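Compute the decomposed variational loss.

Returns the weighted observation and prior terms (each already multiplied by its alpha) alongside their sum under "total", which equals variational_cost for the same arguments. This matches the ModelLoss pattern from the legacy codebase.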

Parameters:

x (Float[Array, ...], required): Current state estimate.
batch (Batch1D, required): Observed data batch.
prior_fn (Callable[..., Any], required): Callable x -> x_prior.
alpha_obs (float, default 0.5): Weight for the observation term.
alpha_prior (float, default 0.5): Weight for the prior term.

Returns:

dict[str, Float[Array, '']]: Dictionary with keys "obs", "prior", and "total".

Source code in src/vardax/_src/costs.py
def decomposed_loss(
    x: Float[Array, ...],
    batch: Batch1D,
    prior_fn: Callable[..., Any],
    alpha_obs: float = 0.5,
    alpha_prior: float = 0.5,
) -> dict[str, Float[Array, ""]]:
    """Compute the decomposed variational loss.

    Returns individual observation and prior components alongside the
    total, matching the ``ModelLoss`` pattern from the legacy codebase.

    Args:
        x: Current state estimate.
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        alpha_obs: Weight for the observation term.
        alpha_prior: Weight for the prior term.

    Returns:
        Dictionary with keys ``"obs"``, ``"prior"``, and ``"total"``.
    """
    obs_diff = batch.mask * (x - batch.input)
    obs = alpha_obs * jnp.mean(obs_diff**2)
    prior = alpha_prior * jnp.mean((x - prior_fn(x)) ** 2)
    return {"obs": obs, "prior": prior, "total": obs + prior}
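Because the returned terms are already weighted, "total" reproduces variational_cost for the same arguments. A minimal sketch of the same arithmetic in plain jax.numpy, including the float conversion one might do before handing the terms to a logger; decomposed_loss_sketch and the 0.5·x prior are hypothetical.

```python
import jax.numpy as jnp


def decomposed_loss_sketch(x, y, m, prior_fn, alpha_obs=0.5, alpha_prior=0.5):
    # Same arithmetic as decomposed_loss: each term is already weighted.
    obs = alpha_obs * jnp.mean((m * (x - y)) ** 2)
    prior = alpha_prior * jnp.mean((x - prior_fn(x)) ** 2)
    return {"obs": obs, "prior": prior, "total": obs + prior}


x = jnp.ones((1, 2, 4))
y = jnp.zeros((1, 2, 4))
m = jnp.ones((1, 2, 4))
terms = decomposed_loss_sketch(x, y, m, lambda v: 0.5 * v)  # hypothetical prior

# Convert to Python floats, e.g. before logging.
log_row = {k: float(v) for k, v in terms.items()}
print(log_row)  # {'obs': 0.5, 'prior': 0.125, 'total': 0.625}
```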

Priors

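Implementations of the Prior Protocol, usable as the prior_fn callable (x -> x_prior) in the cost functions above. IdentityPrior returns its input, so the prior term vanishes and only the observations drive the estimate. L63Prior and L96Prior are small MLP autoencoders sized for Lorenz-63 and Lorenz-96 states. The remaining autoencoder priors (MLPAEPrior1D, ConvAEPrior1D, and the bilinear BilinAEPrior1D, BilinAEPrior2D and BilinAEPrior2DMultivar) are likewise learned priors that penalise distance from a trained reconstruction manifold.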


IdentityPrior

Bases: Module

Trivial identity prior: \(\varphi(x) = x\).

Zero parameters. Useful as a pure obs-driven baseline (the prior cost vanishes everywhere) and as a sanity-check building block in the linear-Gaussian agreement tests.

Examples:

>>> import jax.numpy as jnp
>>> from vardax import IdentityPrior
>>> prior = IdentityPrior()
>>> x = jnp.arange(6.0).reshape(1, 2, 3)
>>> bool(jnp.all(prior(x) == x))
True
Source code in src/vardax/_src/priors.py
class IdentityPrior(eqx.Module):
    r"""Trivial identity prior: $\varphi(x) = x$.

    Zero parameters. Useful as a pure obs-driven baseline (the prior
    cost vanishes everywhere) and as a sanity-check building block in
    the linear-Gaussian agreement tests.

    Examples:
        >>> import jax.numpy as jnp
        >>> from vardax import IdentityPrior
        >>> prior = IdentityPrior()
        >>> x = jnp.arange(6.0).reshape(1, 2, 3)
        >>> bool(jnp.all(prior(x) == x))
        True
    """

    def __call__(self, x: Float[Array, ...]) -> Float[Array, ...]:
        """Return the input unchanged."""
        return x
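Any callable x -> x_prior can play the role of prior_fn in the cost functions above. As an illustration only (NeighbourMeanPrior is hypothetical, not a vardax class), a prior that replaces each point by the mean of its two periodic neighbours turns \(x - \varphi(x)\) into minus one half of the discrete Laplacian, so the prior term becomes a Tikhonov-style smoothness penalty. This contrasts with IdentityPrior, whose term is identically zero.

```python
import jax.numpy as jnp


class NeighbourMeanPrior:
    """Hypothetical prior: each point becomes the mean of its two periodic
    neighbours along the last axis. Not part of vardax."""

    def __call__(self, x):
        return 0.5 * (jnp.roll(x, 1, axis=-1) + jnp.roll(x, -1, axis=-1))


def prior_term(x, prior_fn):
    # The prior term used by variational_cost and prior_cost.
    return jnp.mean((x - prior_fn(x)) ** 2)


prior = NeighbourMeanPrior()
constant = jnp.full((1, 1, 4), 3.0)
zigzag = jnp.array([[[1.0, -1.0, 1.0, -1.0]]])

smooth_cost = prior_term(constant, prior)  # 0.0: constant fields are not penalised
rough_cost = prior_term(zigzag, prior)     # 4.0: x - phi(x) = 2x, and (2x)^2 = 4
identity_cost = prior_term(zigzag, lambda v: v)  # 0.0, as with IdentityPrior
print(float(smooth_cost), float(rough_cost), float(identity_cost))
```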

L63Prior

Bases: Module

Learned prior for the Lorenz-63 system.

A simple MLP autoencoder designed for the 3-dimensional Lorenz-63 attractor. The state is treated as a flat vector of length 3.

Attributes:

latent_dim: Dimensionality of the latent code (default 3).
hidden_dim: Hidden layer width (default 32).
state_dim: Dimensionality of the state vector (default 3).

Source code in src/vardax/_src/priors.py
class L63Prior(eqx.Module):
    """Learned prior for the Lorenz-63 system.

    A simple MLP autoencoder designed for the 3-dimensional Lorenz-63
    attractor. The state is treated as a flat vector of length ``3``.

    Attributes:
        latent_dim: Dimensionality of the latent code (default ``3``).
        hidden_dim: Hidden layer width.
        state_dim: Dimensionality of the state vector (default ``3``).
    """

    enc1: eqx.nn.Linear
    enc2: eqx.nn.Linear
    dec1: eqx.nn.Linear
    dec2: eqx.nn.Linear

    def __init__(
        self,
        latent_dim: int = 3,
        hidden_dim: int = 32,
        state_dim: int = 3,
        *,
        key: PRNGKeyArray,
    ) -> None:
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.enc1 = eqx.nn.Linear(state_dim, hidden_dim, key=k1)
        self.enc2 = eqx.nn.Linear(hidden_dim, latent_dim, key=k2)
        self.dec1 = eqx.nn.Linear(latent_dim, hidden_dim, key=k3)
        self.dec2 = eqx.nn.Linear(hidden_dim, state_dim, key=k4)

    def __call__(self, x: Float[Array, "B N"]) -> Float[Array, "B N"]:
        z = jnp.tanh(jax.vmap(self.enc1)(x))
        z = jax.vmap(self.enc2)(z)
        h = jnp.tanh(jax.vmap(self.dec1)(z))
        return jax.vmap(self.dec2)(h)
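The forward pass of L63Prior is Linear, tanh, Linear on the encoder side and the same pattern on the decoder side. A functional sketch of that layer sequence in plain JAX shows the (B, 3) -> (B, 3) shape contract. The init_dense helper is hypothetical and assumes a PyTorch-style uniform initialisation; it is not claimed to match eqx.nn.Linear exactly.

```python
import jax
import jax.numpy as jnp


def init_dense(key, n_in, n_out):
    # Hypothetical uniform(-1/sqrt(n_in), 1/sqrt(n_in)) initialisation.
    w_key, b_key = jax.random.split(key)
    bound = 1.0 / jnp.sqrt(n_in)
    w = jax.random.uniform(w_key, (n_out, n_in), minval=-bound, maxval=bound)
    b = jax.random.uniform(b_key, (n_out,), minval=-bound, maxval=bound)
    return w, b


def dense(params, x):
    # Batched affine map, equivalent to vmapping a Linear layer over rows.
    w, b = params
    return x @ w.T + b


def l63_autoencoder(params, x):
    # x: (B, 3). Same layer sequence as L63Prior.__call__.
    z = jnp.tanh(dense(params["enc1"], x))
    z = dense(params["enc2"], z)
    h = jnp.tanh(dense(params["dec1"], z))
    return dense(params["dec2"], h)


state_dim, hidden_dim, latent_dim = 3, 32, 3  # L63Prior defaults
k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "enc1": init_dense(k1, state_dim, hidden_dim),
    "enc2": init_dense(k2, hidden_dim, latent_dim),
    "dec1": init_dense(k3, latent_dim, hidden_dim),
    "dec2": init_dense(k4, hidden_dim, state_dim),
}
x = jax.random.normal(kx, (8, state_dim))
out = l63_autoencoder(params, x)
print(out.shape)  # (8, 3)
```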

L96Prior

Bases: Module

Learned prior for the Lorenz-96 system.

A simple MLP autoencoder designed for the N-dimensional Lorenz-96 attractor. The state is treated as a flat vector of length N.

Attributes:

latent_dim: Dimensionality of the latent code (default 16).
hidden_dim: Hidden layer width (default 64).
state_dim: Dimensionality N of the state vector (default 40).

Source code in src/vardax/_src/priors.py
class L96Prior(eqx.Module):
    """Learned prior for the Lorenz-96 system.

    A simple MLP autoencoder designed for the N-dimensional Lorenz-96
    attractor. The state is treated as a flat vector of length ``N``.

    Attributes:
        latent_dim: Dimensionality of the latent code.
        hidden_dim: Hidden layer width.
        state_dim: Dimensionality of the state vector.
    """

    enc1: eqx.nn.Linear
    enc2: eqx.nn.Linear
    dec1: eqx.nn.Linear
    dec2: eqx.nn.Linear

    def __init__(
        self,
        latent_dim: int = 16,
        hidden_dim: int = 64,
        state_dim: int = 40,
        *,
        key: PRNGKeyArray,
    ) -> None:
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.enc1 = eqx.nn.Linear(state_dim, hidden_dim, key=k1)
        self.enc2 = eqx.nn.Linear(hidden_dim, latent_dim, key=k2)
        self.dec1 = eqx.nn.Linear(latent_dim, hidden_dim, key=k3)
        self.dec2 = eqx.nn.Linear(hidden_dim, state_dim, key=k4)

    def __call__(self, x: Float[Array, "B N"]) -> Float[Array, "B N"]:
        z = jnp.tanh(jax.vmap(self.enc1)(x))
        z = jax.vmap(self.enc2)(z)
        h = jnp.tanh(jax.vmap(self.dec1)(z))
        return jax.vmap(self.dec2)(h)
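With eqx.nn.Linear's default bias, each layer has n_in·n_out + n_out parameters, so the size of the L63Prior and L96Prior autoencoders follows directly from the constructor arguments. A small helper for that arithmetic (hypothetical, plain Python):

```python
def linear_params(n_in, n_out):
    # Weight of shape (n_out, n_in) plus a bias of shape (n_out,).
    return n_in * n_out + n_out


def mlp_ae_params(state_dim, hidden_dim, latent_dim):
    # enc1, enc2, dec1, dec2 as in L63Prior / L96Prior.
    return (
        linear_params(state_dim, hidden_dim)
        + linear_params(hidden_dim, latent_dim)
        + linear_params(latent_dim, hidden_dim)
        + linear_params(hidden_dim, state_dim)
    )


l96_count = mlp_ae_params(40, 64, 16)  # L96Prior defaults
l63_count = mlp_ae_params(3, 32, 3)    # L63Prior defaults
print(l96_count, l63_count)  # 7352 454
```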

MLPAEPrior1D

Bases: Module

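MLP autoencoder prior for 1-D data. Each (T, N) sample is flattened to a vector of length T·N, passed through a Linear-ReLU-Linear encoder and a Linear-ReLU-Linear decoder, and reshaped back to (T, N).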

Attributes:

state_dim: Spatial size of the input (N).
latent_dim: Dimensionality of the latent code.
hidden_dim: Hidden layer width (default 64).
n_time (int): Number of time steps (T, default 1).

Source code in src/vardax/_src/priors.py
class MLPAEPrior1D(eqx.Module):
    """MLP autoencoder prior for 1-D data.

    Attributes:
        state_dim: Spatial size of the input (``N``).
        latent_dim: Dimensionality of the latent code.
        hidden_dim: Hidden layer width.
        n_time: Number of time steps (``T``).
    """

    n_time: int = eqx.field(static=True)
    enc1: eqx.nn.Linear
    enc2: eqx.nn.Linear
    dec1: eqx.nn.Linear
    dec2: eqx.nn.Linear

    def __init__(
        self,
        state_dim: int,
        latent_dim: int,
        hidden_dim: int = 64,
        n_time: int = 1,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.n_time = n_time
        in_features = n_time * state_dim
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.enc1 = eqx.nn.Linear(in_features, hidden_dim, key=k1)
        self.enc2 = eqx.nn.Linear(hidden_dim, latent_dim, key=k2)
        self.dec1 = eqx.nn.Linear(latent_dim, hidden_dim, key=k3)
        self.dec2 = eqx.nn.Linear(hidden_dim, in_features, key=k4)

    def __call__(self, x: Float[Array, "B T N"]) -> Float[Array, "B T N"]:
        b, t, n = x.shape
        x_flat = x.reshape(b, t * n)
        z = jax.nn.relu(jax.vmap(self.enc1)(x_flat))
        z = jax.vmap(self.enc2)(z)
        h = jax.nn.relu(jax.vmap(self.dec1)(z))
        out = jax.vmap(self.dec2)(h)
        return out.reshape(b, t, n)
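MLPAEPrior1D treats each (T, N) window as one flat vector, so enc1 has n_time * state_dim inputs and the prior only accepts windows of exactly that size. A sketch of the reshape round trip in plain jax.numpy, showing the row-major ordering of the flattened vector:

```python
import jax.numpy as jnp

b, t, n = 2, 3, 4
x = jnp.arange(b * t * n, dtype=jnp.float32).reshape(b, t, n)

x_flat = x.reshape(b, t * n)  # what MLPAEPrior1D feeds to enc1
# Row-major flattening: entry (i, j, k) lands at column j * n + k.
assert x_flat[1, 2 * n + 3] == x[1, 2, 3]

# A decoder output of shape (b, t * n) reshaped back recovers the layout exactly.
x_back = x_flat.reshape(b, t, n)
print(bool(jnp.array_equal(x_back, x)), x_flat.shape)  # True (2, 12)
```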

ConvAEPrior1D

Bases: Module

Convolutional autoencoder prior for 1-D spatially-structured data.

Uses circular (periodic) padding suitable for systems with periodic boundary conditions such as Lorenz-96. Operates on inputs of shape (B, T, N) where N is the spatial dimension.

Attributes:

latent_channels: Number of channels in the latent representation (default 16).
kernel_size (int): Convolution kernel size; must be a positive odd integer (default 3).
n_time (int): Number of time steps T (default 1); used as the input/output channels and validated against the runtime input shape.

Source code in src/vardax/_src/priors.py
class ConvAEPrior1D(eqx.Module):
    """Convolutional autoencoder prior for 1-D spatially-structured data.

    Uses circular (periodic) padding suitable for systems with periodic
    boundary conditions such as Lorenz-96. Operates on inputs of shape
    ``(B, T, N)`` where ``N`` is the spatial dimension.

    Attributes:
        latent_channels: Number of channels in the latent representation.
        kernel_size: Convolution kernel size (must be a positive odd integer).
        n_time: Number of time steps ``T``; used as the input/output channels
            and validated against the runtime input shape.
    """

    kernel_size: int = eqx.field(static=True)
    n_time: int = eqx.field(static=True)
    _enc_conv: eqx.nn.Conv1d
    _dec_conv: eqx.nn.Conv1d

    def __init__(
        self,
        latent_channels: int = 16,
        kernel_size: int = 3,
        n_time: int = 1,
        *,
        key: PRNGKeyArray,
    ) -> None:
        if kernel_size <= 0 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}."
            )
        self.kernel_size = kernel_size
        self.n_time = n_time
        k_enc, k_dec = jax.random.split(key)
        # eqx.nn.Conv1d uses channels-first: (in_channels, length)
        self._enc_conv = eqx.nn.Conv1d(
            in_channels=n_time,
            out_channels=latent_channels,
            kernel_size=kernel_size,
            padding=0,  # we apply circular padding manually
            key=k_enc,
        )
        self._dec_conv = eqx.nn.Conv1d(
            in_channels=latent_channels,
            out_channels=n_time,
            kernel_size=kernel_size,
            padding=0,
            key=k_dec,
        )

    def __call__(self, x: Float[Array, "B T N"]) -> Float[Array, "B T N"]:
        t = x.shape[1]
        if t != self.n_time:
            raise ValueError(
                f"Input time dimension {t} does not match n_time={self.n_time}."
            )
        # eqx Conv1d expects channels-first: (in_channels=T, length=N)
        # x is already (B, T, N) so vmap over the batch dim works directly.
        pad = self.kernel_size // 2

        def _forward(xi: Float[Array, "T N"]) -> Float[Array, "T N"]:
            # Circular padding along spatial axis (axis=1)
            if pad > 0:
                h = jnp.concatenate([xi[:, -pad:], xi, xi[:, :pad]], axis=1)
            else:
                h = xi
            h = self._enc_conv(h)
            h = jax.nn.relu(h)
            if pad > 0:
                h = jnp.concatenate([h[:, -pad:], h, h[:, :pad]], axis=1)
            return self._dec_conv(h)

        return jax.vmap(_forward)(x)
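The manual concatenation in _forward is ordinary periodic (wrap-around) padding. The sketch below, in plain jax.numpy, checks it against jnp.pad with mode="wrap", then uses a hypothetical 3-point moving-average kernel in place of the learned Conv1d to confirm that a valid convolution on the padded signal keeps length N and wraps around the boundaries.

```python
import jax.numpy as jnp

kernel_size = 3
pad = kernel_size // 2
xi = jnp.arange(2 * 6, dtype=jnp.float32).reshape(2, 6)  # (T, N), channels-first

# Padding as written in ConvAEPrior1D._forward.
manual = jnp.concatenate([xi[:, -pad:], xi, xi[:, :pad]], axis=1)
# Equivalent NumPy-style wrap padding.
wrapped = jnp.pad(xi, ((0, 0), (pad, pad)), mode="wrap")
print(bool(jnp.array_equal(manual, wrapped)))  # True

# A 'valid' convolution returns length N + 2 * pad - kernel_size + 1 = N.
kernel = jnp.ones((kernel_size,)) / kernel_size  # hypothetical smoothing kernel
smoothed = jnp.stack([jnp.convolve(row, kernel, mode="valid") for row in manual])
print(smoothed.shape)  # (2, 6)
# The first output averages the last, first and second samples: (5 + 0 + 1) / 3.
print(float(smoothed[0, 0]))
```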

BilinAEPrior1D

Bases: Module

Bilinear autoencoder prior for 1-D data.

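The encoder maps the input to a low-dimensional latent code; the decoder (a single Linear layer) maps it back to the original space. Calling the prior returns decode(encode(x)), so the prior cost is the mean of (x - decode(encode(x)))^2 over all entries.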

Attributes:

state_dim (int): Spatial size of the input (N).
latent_dim (int): Dimensionality of the latent code.
n_time (int): Number of time steps (T, default 1).

Examples:

>>> import jax, jax.numpy as jnp
>>> from vardax import BilinAEPrior1D
>>> prior = BilinAEPrior1D(
...     state_dim=4, latent_dim=2, n_time=3, key=jax.random.PRNGKey(0)
... )
>>> prior(jnp.ones((2, 3, 4))).shape
(2, 3, 4)
Source code in src/vardax/_src/priors.py
class BilinAEPrior1D(eqx.Module):
    """Bilinear autoencoder prior for 1-D data.

    The encoder maps the input to a low-dimensional latent code; the decoder
    reconstructs the original space. The prior cost is
    ``||x - decode(encode(x))||^2``.

    Attributes:
        state_dim: Spatial size of the input (``N``).
        latent_dim: Dimensionality of the latent code.
        n_time: Number of time steps (``T``).

    Examples:
        >>> import jax, jax.numpy as jnp
        >>> from vardax import BilinAEPrior1D
        >>> prior = BilinAEPrior1D(
        ...     state_dim=4, latent_dim=2, n_time=3, key=jax.random.PRNGKey(0)
        ... )
        >>> prior(jnp.ones((2, 3, 4))).shape
        (2, 3, 4)
    """

    state_dim: int = eqx.field(static=True)
    latent_dim: int = eqx.field(static=True)
    n_time: int = eqx.field(static=True)
    _bilin: _BilinearBlock
    _decode_dense: eqx.nn.Linear

    def __init__(
        self,
        state_dim: int,
        latent_dim: int,
        n_time: int = 1,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.state_dim = state_dim
        self.latent_dim = latent_dim
        self.n_time = n_time
        in_features = n_time * state_dim
        key_bilin, key_dec = jax.random.split(key)
        self._bilin = _BilinearBlock(in_features, latent_dim, key=key_bilin)
        self._decode_dense = eqx.nn.Linear(latent_dim, n_time * state_dim, key=key_dec)

    def __call__(self, x: Float[Array, "B T N"]) -> Float[Array, "B T N"]:
        b, t, n = x.shape
        x_flat = x.reshape(b, t * n)
        z = self._bilin(x_flat)
        out = jax.vmap(self._decode_dense)(z)
        return out.reshape(b, t, n)

    def encode(self, x: Float[Array, "B T N"]) -> Float[Array, "B Z"]:
        """Encode input to latent space."""
        b, t, n = x.shape
        x_flat = x.reshape(b, t * n)
        return self._bilin(x_flat)

    def decode(self, z: Float[Array, "B Z"]) -> Float[Array, "B T N"]:
        """Decode latent code to state space."""
        out = jax.vmap(self._decode_dense)(z)
        return out.reshape(-1, self.n_time, self.state_dim)

encode

encode(x: Float[Array, 'B T N']) -> Float[Array, 'B Z']

Encode input to latent space.

Source code in src/vardax/_src/priors.py
def encode(self, x: Float[Array, "B T N"]) -> Float[Array, "B Z"]:
    """Encode input to latent space."""
    b, t, n = x.shape
    x_flat = x.reshape(b, t * n)
    return self._bilin(x_flat)

decode

decode(z: Float[Array, 'B Z']) -> Float[Array, 'B T N']

Decode latent code to state space.

Source code in src/vardax/_src/priors.py
def decode(self, z: Float[Array, "B Z"]) -> Float[Array, "B T N"]:
    """Decode latent code to state space."""
    out = jax.vmap(self._decode_dense)(z)
    return out.reshape(-1, self.n_time, self.state_dim)

BilinAEPrior2D

Bases: Module

Bilinear autoencoder prior for 2-D data.

Attributes:

latent_dim: Dimensionality of the latent code.
n_time (int): Number of time steps (T).
height (int): Spatial height H.
width (int): Spatial width W.

Source code in src/vardax/_src/priors.py
class BilinAEPrior2D(eqx.Module):
    """Bilinear autoencoder prior for 2-D data.

    Attributes:
        latent_dim: Dimensionality of the latent code.
        n_time: Number of time steps (``T``).
        height: Spatial height ``H``.
        width: Spatial width ``W``.
    """

    n_time: int = eqx.field(static=True)
    height: int = eqx.field(static=True)
    width: int = eqx.field(static=True)
    _bilin: _BilinearBlock
    _decode_dense: eqx.nn.Linear

    def __init__(
        self,
        latent_dim: int,
        n_time: int,
        height: int,
        width: int,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.n_time = n_time
        self.height = height
        self.width = width
        in_features = n_time * height * width
        k_bilin, k_dec = jax.random.split(key)
        self._bilin = _BilinearBlock(in_features, latent_dim, key=k_bilin)
        self._decode_dense = eqx.nn.Linear(latent_dim, in_features, key=k_dec)

    def __call__(self, x: Float[Array, "B T H W"]) -> Float[Array, "B T H W"]:
        b, t, h, w = x.shape
        x_flat = x.reshape(b, t * h * w)
        z = self._bilin(x_flat)
        out = jax.vmap(self._decode_dense)(z)
        return out.reshape(b, t, h, w)

BilinAEPrior2DMultivar

Bases: Module

Bilinear autoencoder prior for 2-D multivariate data.

Attributes:

latent_dim: Dimensionality of the latent code.
n_time (int): Number of time steps (T).
n_channels (int): Number of channels C.
height (int): Spatial height H.
width (int): Spatial width W.

Source code in src/vardax/_src/priors.py
class BilinAEPrior2DMultivar(eqx.Module):
    """Bilinear autoencoder prior for 2-D multivariate data.

    Attributes:
        latent_dim: Dimensionality of the latent code.
        n_time: Number of time steps (``T``).
        n_channels: Number of channels ``C``.
        height: Spatial height ``H``.
        width: Spatial width ``W``.
    """

    n_time: int = eqx.field(static=True)
    n_channels: int = eqx.field(static=True)
    height: int = eqx.field(static=True)
    width: int = eqx.field(static=True)
    _bilin: _BilinearBlock
    _decode_dense: eqx.nn.Linear

    def __init__(
        self,
        latent_dim: int,
        n_time: int,
        n_channels: int,
        height: int,
        width: int,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.n_time = n_time
        self.n_channels = n_channels
        self.height = height
        self.width = width
        in_features = n_time * n_channels * height * width
        k_bilin, k_dec = jax.random.split(key)
        self._bilin = _BilinearBlock(in_features, latent_dim, key=k_bilin)
        self._decode_dense = eqx.nn.Linear(latent_dim, in_features, key=k_dec)

    def __call__(self, x: Float[Array, "B T C H W"]) -> Float[Array, "B T C H W"]:
        b, t, c, h, w = x.shape
        x_flat = x.reshape(b, t * c * h * w)
        z = self._bilin(x_flat)
        out = jax.vmap(self._decode_dense)(z)
        return out.reshape(b, t, c, h, w)
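Because the decoder is a single dense layer from the latent code to the full T·C·H·W window, its size grows linearly with the window. The arithmetic below covers only _decode_dense (the size of the private _BilinearBlock encoder is not shown on this page), and the grid and latent sizes are a hypothetical configuration.

```python
def decode_dense_params(latent_dim, n_time, n_channels, height, width):
    # _decode_dense is eqx.nn.Linear(latent_dim, in_features) with its default bias.
    in_features = n_time * n_channels * height * width
    return latent_dim * in_features + in_features


# Hypothetical configuration: 5 time steps, 2 variables, 64 x 64 grid, 32 latents.
n_params = decode_dense_params(latent_dim=32, n_time=5, n_channels=2, height=64, width=64)
print(n_params)  # 1351680 parameters in the decoder alone
```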

4DVarNet inner-loop solvers

The unrolled (and fixed-point) inner loop of 4DVarNet, exposed as pure functions over an explicit SolverState: initialise with init_solver_state_*, advance one modulated-gradient step with solver_step_* (or fp_solver_step_1d for the fixed-point formulation), or run the whole loop with solve_4dvarnet_*. The one_step_* variants pair with OneStepAdjoint for memory-frugal training.
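The init / step / solve structure can be sketched with pure functions. This is a minimal illustration under stated assumptions, not the vardax implementation: ToyState is a hypothetical stand-in for SolverState1D without LSTM memory, and plain gradient descent with a fixed step replaces the learned ConvLSTM modulator.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for SolverState1D: no LSTM memory.
    x: jax.Array
    step: int


def cost(x, y, m, prior_fn, alpha_obs=0.5, alpha_prior=0.5):
    # Same formula as variational_cost (means over all entries).
    j_obs = jnp.mean((m * (x - y)) ** 2)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return alpha_obs * j_obs + alpha_prior * j_prior


def init(y, m):
    # As in init_solver_state_1d: start from the masked observations.
    return ToyState(x=y * m, step=0)


def step(state, y, m, prior_fn, lr):
    # Plain gradient descent stands in for the learned gradient modulator.
    g = jax.grad(cost)(state.x, y, m, prior_fn)
    return ToyState(x=state.x - lr * g, step=state.step + 1)


def solve(y, m, prior_fn, n_steps=20, lr=5.0):
    # Unrolled loop: a fixed number of steps, each returning a new state.
    state = init(y, m)
    for _ in range(n_steps):
        state = step(state, y, m, prior_fn, lr)
    return state


def smooth(v):
    # Hypothetical periodic neighbour-mean prior.
    return 0.5 * (jnp.roll(v, 1, axis=-1) + jnp.roll(v, -1, axis=-1))


key_y, key_m = jax.random.split(jax.random.PRNGKey(0))
y = jax.random.normal(key_y, (1, 4, 16))
m = (jax.random.uniform(key_m, y.shape) > 0.5).astype(y.dtype)
final = solve(y, m, smooth)
initial_cost = float(cost(y * m, y, m, smooth))
final_cost = float(cost(final.x, y, m, smooth))
print(final.step, final_cost < initial_cost)  # 20 True
```

The step size 5.0 is only suitable for this small example: since the cost averages over D = 64 entries its curvature is small, and a learned modulator is what removes the need to hand-tune it.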


SolverState1D

Bases: Module

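Solver state for 1-D problems. Like every eqx.Module it is immutable: each solver step returns a new instance rather than modifying the old one.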

Attributes:

x (Float[Array, 'B T N']): Current state estimate of shape (B, T, N).
lstm (LSTMState1D): Current LSTM hidden/cell state for the gradient modulator.
step (int): Current iteration index.

Source code in src/vardax/_src/solver.py
class SolverState1D(eqx.Module):
    """Mutable solver state for 1-D problems.

    Attributes:
        x: Current state estimate of shape ``(B, T, N)``.
        lstm: Current LSTM hidden/cell state for the gradient modulator.
        step: Current iteration index.
    """

    x: Float[Array, "B T N"]
    lstm: LSTMState1D
    step: int

SolverState2D

Bases: Module

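Solver state for 2-D problems. Like every eqx.Module it is immutable: each solver step returns a new instance rather than modifying the old one.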

Attributes:

x (Float[Array, 'B T H W']): Current state estimate of shape (B, T, H, W).
lstm (LSTMState2D): Current LSTM hidden/cell state for the gradient modulator.
step (int): Current iteration index.

Source code in src/vardax/_src/solver.py
class SolverState2D(eqx.Module):
    """Mutable solver state for 2-D problems.

    Attributes:
        x: Current state estimate of shape ``(B, T, H, W)``.
        lstm: Current LSTM hidden/cell state for the gradient modulator.
        step: Current iteration index.
    """

    x: Float[Array, "B T H W"]
    lstm: LSTMState2D
    step: int

init_solver_state_1d

init_solver_state_1d(
    batch: Batch1D, hidden_dim: int
) -> SolverState1D

Initialise a 1-D solver state from a batch.

Parameters:

batch (Batch1D, required): Input batch. The initial state is set to the masked input (zeros where unobserved).
hidden_dim (int, required): Hidden dimension of the ConvLSTM gradient modulator.

Returns:

SolverState1D: A new SolverState1D with x set to the masked input, a zero-initialised LSTM state, and step = 0.

Source code in src/vardax/_src/solver.py
def init_solver_state_1d(
    batch: Batch1D,
    hidden_dim: int,
) -> SolverState1D:
    """Initialise a 1-D solver state from a batch.

    Args:
        batch: Input batch.  The initial state is set to the masked input
            (zeros where unobserved).
        hidden_dim: Hidden dimension of the ConvLSTM gradient modulator.

    Returns:
        A [`SolverState1D`][vardax.SolverState1D] with ``x`` set to the masked
        input, a zero-initialised LSTM state and ``step=0``.
    """
    b, _, n = batch.input.shape
    x0 = batch.input * batch.mask
    lstm = LSTMState1D.zeros(b, hidden_dim, n)
    return SolverState1D(x=x0, lstm=lstm, step=0)
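Only the LSTM memory starts at zero; the state itself starts from the observations. A sketch of the initialisation in plain jax.numpy, assuming LSTMState1D.zeros(b, hidden_dim, n) allocates arrays of shape (b, hidden_dim, n). The names h0 and c0 are illustrative, since the real class's field names are not shown here.

```python
import jax.numpy as jnp

b, t, n, hidden_dim = 2, 3, 5, 8
y = jnp.arange(b * t * n, dtype=jnp.float32).reshape(b, t, n) + 1.0  # nonzero data
mask = (jnp.arange(n) % 2 == 0).astype(jnp.float32) * jnp.ones((b, t, n))

# Initial state: observations where observed, zeros elsewhere.
x0 = y * mask
# Hypothetical hidden/cell arrays with the assumed (b, hidden_dim, n) layout.
h0 = jnp.zeros((b, hidden_dim, n))
c0 = jnp.zeros((b, hidden_dim, n))

unobserved_zero = bool(jnp.all(x0[..., 1::2] == 0.0))
observed_kept = bool(jnp.all(x0[..., ::2] == y[..., ::2]))
print(unobserved_zero, observed_kept, h0.shape)  # True True (2, 8, 5)
```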

init_solver_state_2d

init_solver_state_2d(
    batch: Batch2D, hidden_dim: int
) -> SolverState2D

Initialise a 2-D solver state from a batch.

Parameters:

batch (Batch2D, required): Input batch. The initial state is set to the masked input.
hidden_dim (int, required): Hidden dimension of the ConvLSTM gradient modulator.

Returns:

SolverState2D: A new SolverState2D with x set to the masked input, a zero-initialised LSTM state, and step = 0.

Source code in src/vardax/_src/solver.py
def init_solver_state_2d(
    batch: Batch2D,
    hidden_dim: int,
) -> SolverState2D:
    """Initialise a 2-D solver state from a batch.

    Args:
        batch: Input batch.  The initial state is set to the masked input.
        hidden_dim: Hidden dimension of the ConvLSTM gradient modulator.

    Returns:
        Initial [`SolverState2D`][vardax.SolverState2D]: the masked input as
        ``x``, a zero-initialised LSTM state and ``step = 0``.
    """
    b, _, h, w = batch.input.shape
    x0 = batch.input * batch.mask
    lstm = LSTMState2D.zeros(b, hidden_dim, h, w)
    return SolverState2D(x=x0, lstm=lstm, step=0)

solver_step_1d

solver_step_1d(
    solver_state: SolverState1D,
    batch: Batch1D,
    prior_fn: Any,
    grad_mod_fn: Any,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
) -> SolverState1D

Perform a single 1-D solver iteration.

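Computes the gradient of the variational cost \(J(x) = \|m \odot (x - y)\|^2 + \lambda \|x - \varphi(x)\|^2\), with \(\lambda\) = prior_weight, passes it through the learned gradient modulator to obtain an update, and sets \(x \leftarrow x - \alpha \cdot \text{update}\). Unlike variational_cost, the observation term carries no weight.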

Parameters:

Name Type Description Default
solver_state SolverState1D

Current solver state.

required
batch Batch1D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior (prior autoencoder forward pass).

required
grad_mod_fn Any

Callable (grad, x, lstm) -> (update, new_lstm).

required
alpha float

Step-size scaling factor.

1.0
prior_weight float

Weighting factor \(\lambda\) for the prior cost term.

1.0

Returns:

Type Description
SolverState1D

Updated SolverState1D.
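With an identity modulator the update reduces to plain gradient descent on the cost, which makes the arithmetic easy to check by hand. The self-contained sketch below mirrors the body of solver_step_1d. `ToyBatch`, `toy_solver_step`, the identity modulator and the shrinkage prior \(\varphi(x) = 0.5x\) are illustrative assumptions, not vardax components.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyBatch(NamedTuple):
    """Hypothetical stand-in for vardax.Batch1D."""

    input: jnp.ndarray
    mask: jnp.ndarray


def toy_solver_step(x, lstm, batch, prior_fn, grad_mod_fn, alpha=1.0, prior_weight=1.0):
    """Mirrors the body of solver_step_1d, returning (x_new, new_lstm)."""

    def cost_fn(x_):
        j_obs = jnp.sum((batch.mask * (x_ - batch.input)) ** 2)
        j_prior = prior_weight * jnp.sum((x_ - prior_fn(x_)) ** 2)
        return j_obs + j_prior

    grad = jax.grad(cost_fn)(x)
    update, new_lstm = grad_mod_fn(grad, x, lstm)
    return x - alpha * update, new_lstm


# Identity modulator: passes the raw gradient through and leaves the LSTM slot alone.
identity_mod = lambda grad, x, lstm: (grad, lstm)
# Shrinkage prior phi(x) = 0.5 x, so the prior term is ||0.5 x||^2 = 0.25 ||x||^2.
half_prior = lambda x: 0.5 * x

batch = ToyBatch(input=jnp.ones((1, 1, 4)), mask=jnp.array([[[1.0, 0.0, 1.0, 0.0]]]))
x = jnp.full((1, 1, 4), 2.0)
x_new, _ = toy_solver_step(x, None, batch, half_prior, identity_mod, alpha=0.25)
# grad J = 2 m (x - y) + 0.5 x = [3, 1, 3, 1], so x_new = 2 - 0.25 * grad = [1.25, 1.75, 1.25, 1.75]
```

Because the mask is binary, \(m^2 = m\) and the observation gradient is \(2m \odot (x - y)\); a fractional mask would enter squared.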

Source code in src/vardax/_src/solver.py
def solver_step_1d(
    solver_state: SolverState1D,
    batch: Batch1D,
    prior_fn: Any,
    grad_mod_fn: Any,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
) -> SolverState1D:
    r"""Perform a single 1-D solver iteration.

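    Computes the gradient of the variational cost
    $J(x) = \|m \odot (x - y)\|^2 + \lambda \|x - \varphi(x)\|^2$, passes it
    through the learned gradient modulator to obtain an update, and sets
    $x \leftarrow x - \alpha \cdot \mathrm{update}$.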

    Args:
        solver_state: Current solver state.
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior`` (prior autoencoder forward pass).
        grad_mod_fn: Callable ``(grad, x, lstm) -> (update, new_lstm)``.
        alpha: Step-size scaling factor.
        prior_weight: Weighting factor $\lambda$ for the prior cost term.

    Returns:
        Updated [`SolverState1D`][vardax.SolverState1D].
    """
    x = solver_state.x

    def cost_fn(x_):
        x_prior = prior_fn(x_)
        obs_diff = batch.mask * (x_ - batch.input)
        j_obs = jnp.sum(obs_diff**2)
        j_prior = prior_weight * jnp.sum((x_ - x_prior) ** 2)
        return j_obs + j_prior

    grad = jax.grad(cost_fn)(x)
    update, new_lstm = grad_mod_fn(grad, x, solver_state.lstm)
    x_new = x - alpha * update

    return SolverState1D(x=x_new, lstm=new_lstm, step=solver_state.step + 1)

solver_step_2d

solver_step_2d(
    solver_state: SolverState2D,
    batch: Batch2D,
    prior_fn: Any,
    grad_mod_fn: Any,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
) -> SolverState2D

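Perform a single 2-D solver iteration.

Computes the gradient of the variational cost \(J(x) = \|m \odot (x - y)\|^2 + \lambda \|x - \varphi(x)\|^2\) on (B, T, H, W) arrays, passes it through the learned gradient modulator to obtain an update, and sets \(x \leftarrow x - \alpha \cdot \text{update}\).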

Parameters:

Name Type Description Default
solver_state SolverState2D

Current solver state.

required
batch Batch2D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior.

required
grad_mod_fn Any

Callable (grad, x, lstm) -> (update, new_lstm).

required
alpha float

Step-size scaling factor.

1.0
prior_weight float

Weighting factor \(\lambda\) for the prior cost term.

1.0

Returns:

Type Description
SolverState2D

Updated SolverState2D.
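The solver functions simply thread the `lstm` slot from one `grad_mod_fn` call to the next, so the modulator can carry optimiser-like memory across iterations. To make that contract concrete, the sketch below swaps the ConvLSTM for hypothetical heavy-ball momentum on `(B, T, H, W)` arrays, with the prior term left out. It illustrates the interface only and is not how vardax's learned modulator works.

```python
import jax
import jax.numpy as jnp


def momentum_mod(grad, x, velocity, beta=0.9):
    """Hypothetical grad_mod_fn: heavy-ball momentum kept in the `lstm` slot.

    Follows the documented contract (grad, x, lstm) -> (update, new_lstm);
    `x` is unused here, whereas a learned modulator may condition on it.
    """
    new_velocity = beta * velocity + grad
    return new_velocity, new_velocity


def obs_cost(x, y, m):
    # Observation term only; the prior is omitted to keep the arithmetic readable.
    return jnp.sum((m * (x - y)) ** 2)


y = jnp.ones((1, 1, 2, 2))        # (B, T, H, W)
m = jnp.ones_like(y)
x = jnp.zeros_like(y)
velocity = jnp.zeros_like(y)      # plays the role of the LSTM state
for _ in range(2):
    g = jax.grad(obs_cost)(x, y, m)
    update, velocity = momentum_mod(g, x, velocity)
    x = x - 0.1 * update          # alpha = 0.1
# step 1: g = -2, v = -2, x = 0.2; step 2: g = -1.6, v = -3.4, x = 0.54
```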

Source code in src/vardax/_src/solver.py
def solver_step_2d(
    solver_state: SolverState2D,
    batch: Batch2D,
    prior_fn: Any,
    grad_mod_fn: Any,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
) -> SolverState2D:
    r"""Perform a single 2-D solver iteration.

    Args:
        solver_state: Current solver state.
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        grad_mod_fn: Callable ``(grad, x, lstm) -> (update, new_lstm)``.
        alpha: Step-size scaling factor.
        prior_weight: Weighting factor $\lambda$ for the prior cost term.

    Returns:
        Updated [`SolverState2D`][vardax.SolverState2D].
    """
    x = solver_state.x

    def cost_fn(x_):
        x_prior = prior_fn(x_)
        obs_diff = batch.mask * (x_ - batch.input)
        j_obs = jnp.sum(obs_diff**2)
        j_prior = prior_weight * jnp.sum((x_ - x_prior) ** 2)
        return j_obs + j_prior

    grad = jax.grad(cost_fn)(x)
    update, new_lstm = grad_mod_fn(grad, x, solver_state.lstm)
    x_new = x - alpha * update

    return SolverState2D(x=x_new, lstm=new_lstm, step=solver_state.step + 1)

fp_solver_step_1d

fp_solver_step_1d(
    x: Float[Array, "B T N"], batch: Batch1D, prior_fn: Any
) -> Float[Array, "B T N"]

Perform a single 1-D fixed-point projection step.

Applies the prior projection then re-inserts observations at observed locations:

\[ x \leftarrow \varphi(x), \quad x \leftarrow m \odot y + (1 - m) \odot x \]

Parameters:

Name Type Description Default
x Float[Array, 'B T N']

Current state estimate of shape (B, T, N).

required
batch Batch1D

Observed data batch containing input (observations y) and mask (m).

required
prior_fn Any

Callable x -> x_prior (prior autoencoder forward pass).

required

Returns:

Type Description
Float[Array, 'B T N']

Updated state estimate of shape (B, T, N).
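A worked single step: observed entries come back exactly as \(y\), whatever the prior returns, and only the gaps take the prior's value. The neighbour-average prior, `ToyBatch` and the `toy_fp_step` helper are hypothetical, chosen so the result can be checked by hand.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyBatch(NamedTuple):
    """Hypothetical stand-in for vardax.Batch1D."""

    input: jnp.ndarray
    mask: jnp.ndarray


def neighbour_mean_prior(x):
    """Hypothetical smoothing prior: mean of left/right neighbours along N (edge-padded)."""
    padded = jnp.pad(x, ((0, 0), (0, 0), (1, 1)), mode="edge")
    return 0.5 * (padded[..., :-2] + padded[..., 2:])


def toy_fp_step(x, batch, prior_fn):
    """Mirrors fp_solver_step_1d: project with the prior, then re-insert observations."""
    x_phi = prior_fn(x)
    return batch.mask * batch.input + (1 - batch.mask) * x_phi


batch = ToyBatch(input=jnp.array([[[1.0, -99.0, 3.0]]]),  # middle value is masked out
                 mask=jnp.array([[[1.0, 0.0, 1.0]]]))
x0 = batch.input * batch.mask                # [[[1, 0, 3]]]
x1 = toy_fp_step(x0, batch, neighbour_mean_prior)
# the prior gives [0.5, 2.0, 1.5]; after re-insertion x1 holds [[[1, 2, 3]]]
```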

Source code in src/vardax/_src/solver.py
def fp_solver_step_1d(
    x: Float[Array, "B T N"],
    batch: Batch1D,
    prior_fn: Any,
) -> Float[Array, "B T N"]:
    r"""Perform a single 1-D fixed-point projection step.

    Applies the prior projection then re-inserts observations at observed
    locations:

    $$
    x \leftarrow \varphi(x), \quad
    x \leftarrow m \odot y + (1 - m) \odot x
    $$

    Args:
        x: Current state estimate of shape ``(B, T, N)``.
        batch: Observed data batch containing ``input`` (observations ``y``)
            and ``mask`` (``m``).
        prior_fn: Callable ``x -> x_prior`` (prior autoencoder forward pass).

    Returns:
        Updated state estimate of shape ``(B, T, N)``.
    """
    x_phi = prior_fn(x)
    return batch.mask * batch.input + (1 - batch.mask) * x_phi

solve_4dvarnet_1d

solve_4dvarnet_1d(
    batch: Batch1D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
) -> Float[Array, "B T N"]

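Run the full 1-D 4DVarNet solver for n_steps iterations.

Gradients taken through the result flow through every unrolled step; one_step_solve_4dvarnet_1d truncates them to the last k steps. The prior term uses the default prior_weight of 1.0.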

Parameters:

Name Type Description Default
batch Batch1D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior.

required
grad_mod_fn Any

Callable (grad, x, lstm) -> (update, new_lstm).

required
n_steps int

Number of gradient-descent steps to unroll.

required
hidden_dim int

Hidden dimension of the ConvLSTM gradient modulator.

required
alpha float

Step-size scaling factor.

1.0

Returns:

Type Description
Float[Array, 'B T N']

Final state estimate of shape (B, T, N).
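solve_4dvarnet_1d drives the iteration with a plain Python `for` loop, so under `jax.jit` the step is traced `n_steps` times and compile time grows with the unroll length. If that becomes a problem, one option is the `jax.lax.scan` pattern that solve_4dvarnet_1d_fixedpoint already uses, which traces the step once. The sketch checks that both forms give the same iterate; `toy_step` (shrinkage prior, identity modulator, a plain array standing in for the LSTM state) and the two solve functions are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_step(x, lstm, y, m, alpha=0.1):
    """Hypothetical stand-in for solver_step_1d: prior phi(x) = 0.5 x, identity modulator."""

    def cost_fn(x_):
        return jnp.sum((m * (x_ - y)) ** 2) + jnp.sum((x_ - 0.5 * x_) ** 2)

    return x - alpha * jax.grad(cost_fn)(x), lstm


def solve_unrolled(y, m, n_steps):
    x, lstm = y * m, jnp.zeros_like(y)
    for _ in range(n_steps):  # step traced n_steps times
        x, lstm = toy_step(x, lstm, y, m)
    return x


def solve_scanned(y, m, n_steps):
    def body(carry, _):
        x, lstm = carry
        return toy_step(x, lstm, y, m), None  # step traced once

    (x, _), _ = jax.lax.scan(body, (y * m, jnp.zeros_like(y)), None, length=n_steps)
    return x


y = jnp.linspace(-1.0, 1.0, 12).reshape(1, 3, 4)
m = (jnp.arange(12).reshape(1, 3, 4) % 3 != 0).astype(y.dtype)
x_loop = jax.jit(solve_unrolled, static_argnums=2)(y, m, 10)
x_scan = jax.jit(solve_scanned, static_argnums=2)(y, m, 10)
```

Both versions still backpropagate through every step; scan changes the size of the traced program, not which steps receive gradients.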

Source code in src/vardax/_src/solver.py
def solve_4dvarnet_1d(
    batch: Batch1D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
) -> Float[Array, "B T N"]:
    """Run the full 1-D 4DVarNet solver for ``n_steps`` iterations.

    Args:
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        grad_mod_fn: Callable ``(grad, x, lstm) -> (update, new_lstm)``.
        n_steps: Number of gradient-descent steps to unroll.
        hidden_dim: Hidden dimension of the ConvLSTM gradient modulator.
        alpha: Step-size scaling factor.

    Returns:
        Final state estimate of shape ``(B, T, N)``.
    """
    state = init_solver_state_1d(batch, hidden_dim)
    for _ in range(n_steps):
        state = solver_step_1d(state, batch, prior_fn, grad_mod_fn, alpha)
    return state.x

solve_4dvarnet_2d

solve_4dvarnet_2d(
    batch: Batch2D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
) -> Float[Array, "B T H W"]

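Run the full 2-D 4DVarNet solver for n_steps iterations.

Gradients taken through the result flow through every unrolled step; one_step_solve_4dvarnet_2d truncates them to the last k steps. The prior term uses the default prior_weight of 1.0.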

Parameters:

Name Type Description Default
batch Batch2D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior.

required
grad_mod_fn Any

Callable (grad, x, lstm) -> (update, new_lstm).

required
n_steps int

Number of gradient-descent steps to unroll.

required
hidden_dim int

Hidden dimension of the ConvLSTM gradient modulator.

required
alpha float

Step-size scaling factor.

1.0

Returns:

Type Description
Float[Array, 'B T H W']

Final state estimate of shape (B, T, H, W).
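When the solve is used inside a training loss, the outer gradient with respect to the prior's parameters flows through all unrolled steps, each of which already contains an inner `jax.grad` of the cost, so training differentiates a gradient. A toy check on `(B, T, H, W)` arrays compares that gradient against a central finite difference. The scalar prior \(\varphi_\theta(x) = \theta x\), the identity modulator and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_solve_2d(theta, y, m, n_steps=3, alpha=0.1):
    """Hypothetical unrolled solve: prior phi_theta(x) = theta * x, identity modulator."""

    def cost_fn(x_):
        return jnp.sum((m * (x_ - y)) ** 2) + jnp.sum((x_ - theta * x_) ** 2)

    x = y * m
    for _ in range(n_steps):
        x = x - alpha * jax.grad(cost_fn)(x)  # inner gradient, as in solver_step_2d
    return x


def outer_loss(theta, y, m):
    # Training-style loss: reconstruction error of the final iterate against y.
    return jnp.mean((toy_solve_2d(theta, y, m) - y) ** 2)


y = jnp.ones((1, 1, 2, 2))                   # (B, T, H, W)
m = jnp.array([[[[1.0, 0.0], [0.0, 1.0]]]])
g = jax.grad(outer_loss)(0.5, y, m)          # differentiates through all 3 steps
eps = 1e-2
fd = (outer_loss(0.5 + eps, y, m) - outer_loss(0.5 - eps, y, m)) / (2 * eps)
```

Here the gradient is negative: raising \(\theta\) towards 1 weakens the shrinkage, so the observed entries stay closer to \(y\).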

Source code in src/vardax/_src/solver.py
def solve_4dvarnet_2d(
    batch: Batch2D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
) -> Float[Array, "B T H W"]:
    """Run the full 2-D 4DVarNet solver for ``n_steps`` iterations.

    Args:
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        grad_mod_fn: Callable ``(grad, x, lstm) -> (update, new_lstm)``.
        n_steps: Number of gradient-descent steps to unroll.
        hidden_dim: Hidden dimension of the ConvLSTM gradient modulator.
        alpha: Step-size scaling factor.

    Returns:
        Final state estimate of shape ``(B, T, H, W)``.
    """
    state = init_solver_state_2d(batch, hidden_dim)
    for _ in range(n_steps):
        state = solver_step_2d(state, batch, prior_fn, grad_mod_fn, alpha)
    return state.x

solve_4dvarnet_1d_fixedpoint

solve_4dvarnet_1d_fixedpoint(
    batch: Batch1D, prior_fn: Any, n_fp_steps: int
) -> Float[Array, "B T N"]

Run n_fp_steps fixed-point projection steps using jax.lax.scan.

Initialises the state from the masked observations, then iterates the fixed-point update fp_solver_step_1d for n_fp_steps steps.

Parameters:

Name Type Description Default
batch Batch1D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior.

required
n_fp_steps int

Number of fixed-point iterations.

required

Returns:

Type Description
Float[Array, 'B T N']

Final state estimate of shape (B, T, N).
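Repeating the projection propagates information from observed points into the gaps. With a hypothetical neighbour-average prior, the iteration is a Jacobi sweep for the discrete Laplace equation, so the gaps between two observations converge to the straight line joining them. The helper below mirrors the scan structure of solve_4dvarnet_1d_fixedpoint on raw arrays; its name and the prior are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def neighbour_mean_prior(x):
    """Hypothetical smoothing prior: mean of left/right neighbours along N (edge-padded)."""
    padded = jnp.pad(x, ((0, 0), (0, 0), (1, 1)), mode="edge")
    return 0.5 * (padded[..., :-2] + padded[..., 2:])


def toy_fixedpoint_solve(y, m, prior_fn, n_fp_steps):
    """Mirrors solve_4dvarnet_1d_fixedpoint on raw arrays instead of a Batch1D."""

    def scan_fn(x, _):
        return m * y + (1 - m) * prior_fn(x), None  # project, then re-insert obs

    x_final, _ = jax.lax.scan(scan_fn, y * m, None, length=n_fp_steps)
    return x_final


y = jnp.array([[[0.0, 0.0, 0.0, 0.0, 4.0]]])  # only the two ends are observed
m = jnp.array([[[1.0, 0.0, 0.0, 0.0, 1.0]]])
x = toy_fixedpoint_solve(y, m, neighbour_mean_prior, n_fp_steps=100)
# the gaps converge to the linear interpolation [0, 1, 2, 3, 4]
```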

Source code in src/vardax/_src/solver.py
def solve_4dvarnet_1d_fixedpoint(
    batch: Batch1D,
    prior_fn: Any,
    n_fp_steps: int,
) -> Float[Array, "B T N"]:
    """Run ``n_fp_steps`` fixed-point projection steps using ``jax.lax.scan``.

    Initialises the state from the masked observations, then iterates the
    fixed-point update [`fp_solver_step_1d`][vardax.fp_solver_step_1d] for
    ``n_fp_steps`` steps.

    Args:
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        n_fp_steps: Number of fixed-point iterations.

    Returns:
        Final state estimate of shape ``(B, T, N)``.
    """
    x0 = batch.input * batch.mask

    def scan_fn(carry: Float[Array, "B T N"], _: None) -> tuple:
        x_new = fp_solver_step_1d(carry, batch, prior_fn)
        return x_new, None

    x_final, _ = jax.lax.scan(scan_fn, x0, None, length=n_fp_steps)
    return x_final

one_step_solve_4dvarnet_1d

one_step_solve_4dvarnet_1d(
    batch: Batch1D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
    k: int = 1,
) -> Float[Array, "B T N"]

Solve 4DVarNet-1D using k-step differentiation (Bolte et al., 2023).

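Runs n_steps - k warmup iterations, applies jax.lax.stop_gradient to the resulting iterate and LSTM state, then performs the final k steps, through which gradients flow. Backpropagation therefore only traverses the last k steps, so memory grows with k rather than n_steps. With k=1 this is the one-step estimator of Bolte et al., which approximates the implicit-differentiation gradient (closely when the iteration converges fast) while being as simple to implement as unrolled backprop. When k >= n_steps nothing is detached and the solve is fully differentiable.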

Reference

Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of iterative algorithms." https://arxiv.org/abs/2305.13768

Parameters:

Name Type Description Default
batch Batch1D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior.

required
grad_mod_fn Any

Callable (grad, x, lstm) -> (update, new_lstm).

required
n_steps int

Total number of solver iterations (warmup = n_steps - k, then k differentiable steps).

required
hidden_dim int

Hidden dimension of the ConvLSTM gradient modulator.

required
alpha float

Step-size scaling factor.

1.0
prior_weight float

Weighting factor \(\lambda\) for the prior cost term.

1.0
k int

Number of trailing differentiable steps, clipped to the range [1, n_steps].

1

Returns:

Type Description
Float[Array, 'B T N']

Final state estimate of shape (B, T, N).
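The effect of the truncation is easiest to see on a scalar contraction \(x \leftarrow (1 - a)x + a\theta\), whose fixed point is \(x^* = \theta\), so implicit differentiation gives \(dx^*/d\theta = 1\). The sketch reproduces the warmup / `stop_gradient` / live-step split of one_step_solve_4dvarnet_1d on this toy map (not the 4DVarNet solver; the function name is hypothetical). Each live step adds one term, and \(k\) live steps give \(1 - (1 - a)^k\). The one-step value \(a\) matches the implicit gradient only when the map's derivative in \(x\) vanishes (\(a = 1\)); for slower contractions, larger \(k\) reduces the bias.

```python
import jax
import jax.numpy as jnp


def k_step_fixed_point(theta, a=0.5, n_steps=5, k=1):
    """Hypothetical toy: iterate x <- (1 - a) x + a * theta from x = 0.

    Uses the same warmup / stop_gradient / live-step split as
    one_step_solve_4dvarnet_1d, with k clipped to [1, n_steps].
    """
    live_steps = min(max(k, 1), n_steps)
    warmup_steps = n_steps - live_steps
    x = jnp.zeros(())
    for _ in range(warmup_steps):
        x = (1 - a) * x + a * theta
    if warmup_steps > 0:
        x = jax.lax.stop_gradient(x)  # cut the warmup out of the gradient
    for _ in range(live_steps):
        x = (1 - a) * x + a * theta
    return x


estimates = {k: float(jax.grad(k_step_fixed_point)(1.0, k=k)) for k in (1, 2, 5)}
# {1: 0.5, 2: 0.75, 5: 0.96875}; implicit differentiation gives 1.0
exact_when_a_is_1 = float(jax.grad(k_step_fixed_point)(1.0, a=1.0, k=1))  # 1.0
```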

Source code in src/vardax/_src/solver.py
def one_step_solve_4dvarnet_1d(
    batch: Batch1D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
    k: int = 1,
) -> Float[Array, "B T N"]:
    r"""Solve 4DVarNet-1D using k-step differentiation (Bolte et al., 2023).

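    Runs ``n_steps - k`` warmup iterations, applies ``jax.lax.stop_gradient``
    to the resulting iterate and LSTM state, then performs the final ``k``
    steps, through which gradients flow.  Backpropagation therefore only
    traverses the last ``k`` steps, so memory grows with ``k`` rather than
    ``n_steps``.  With ``k=1`` this is the one-step estimator of Bolte et al.,
    which approximates the implicit-differentiation gradient (closely when the
    iteration converges fast) while being as simple to implement as unrolled
    backprop.  When ``k >= n_steps`` nothing is detached and the solve is
    fully differentiable.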

    Reference:
        Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of
        iterative algorithms." https://arxiv.org/abs/2305.13768

    Args:
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        grad_mod_fn: Callable ``(grad, x, lstm) -> (update, new_lstm)``.
        n_steps: Total number of solver iterations (warmup = n_steps - k,
            then k differentiable steps).
        hidden_dim: Hidden dimension of the ConvLSTM gradient modulator.
        alpha: Step-size scaling factor.
        prior_weight: Weighting factor $\lambda$ for the prior cost term.
        k: Number of trailing differentiable steps (clipped to ``[1, n_steps]``).

    Returns:
        Final state estimate of shape ``(B, T, N)``.
    """
    # --- warmup: n_steps - k steps whose gradient contribution is cut below ---
    state = init_solver_state_1d(batch, hidden_dim)
    live_steps = min(max(k, 1), n_steps) if n_steps >= 1 else 0
    warmup_steps = max(n_steps - live_steps, 0)
    for _ in range(warmup_steps):
        state = solver_step_1d(state, batch, prior_fn, grad_mod_fn, alpha, prior_weight)

    # detach the iterate so earlier steps don't contribute to the
    # gradient — but only when there was a warmup to detach: with
    # k >= n_steps the solve is fully differentiable, including
    # gradients with respect to the batch through the initial state.
    if warmup_steps > 0:
        state = SolverState1D(
            x=jax.lax.stop_gradient(state.x),
            lstm=jax.lax.stop_gradient(state.lstm),
            step=state.step,
        )

    # --- k differentiable steps ---
    for _ in range(live_steps):
        state = solver_step_1d(state, batch, prior_fn, grad_mod_fn, alpha, prior_weight)

    return state.x

one_step_solve_4dvarnet_2d

one_step_solve_4dvarnet_2d(
    batch: Batch2D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
    k: int = 1,
) -> Float[Array, "B T H W"]

Solve 4DVarNet-2D using k-step differentiation (Bolte et al., 2023).

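Runs n_steps - k warmup iterations, applies jax.lax.stop_gradient to the resulting iterate and LSTM state, then performs the final k steps, through which gradients flow. Backpropagation therefore only traverses the last k steps, so memory grows with k rather than n_steps. With k=1 this is the one-step estimator of Bolte et al., which approximates the implicit-differentiation gradient (closely when the iteration converges fast) while being as simple to implement as unrolled backprop. When k >= n_steps nothing is detached and the solve is fully differentiable.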

Reference

Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of iterative algorithms." https://arxiv.org/abs/2305.13768

Parameters:

Name Type Description Default
batch Batch2D

Observed data batch.

required
prior_fn Any

Callable x -> x_prior.

required
grad_mod_fn Any

Callable (grad, x, lstm) -> (update, new_lstm).

required
n_steps int

Total number of solver iterations (warmup = n_steps - k, then k differentiable steps).

required
hidden_dim int

Hidden dimension of the ConvLSTM gradient modulator.

required
alpha float

Step-size scaling factor.

1.0
prior_weight float

Weighting factor \(\lambda\) for the prior cost term.

1.0
k int

Number of trailing differentiable steps, clipped to the range [1, n_steps].

1

Returns:

Type Description
Float[Array, 'B T H W']

Final state estimate of shape (B, T, H, W).

Source code in src/vardax/_src/solver.py
def one_step_solve_4dvarnet_2d(
    batch: Batch2D,
    prior_fn: Any,
    grad_mod_fn: Any,
    n_steps: int,
    hidden_dim: int,
    alpha: float = 1.0,
    prior_weight: float = 1.0,
    k: int = 1,
) -> Float[Array, "B T H W"]:
    r"""Solve 4DVarNet-2D using k-step differentiation (Bolte et al., 2023).

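    Runs ``n_steps - k`` warmup iterations, applies ``jax.lax.stop_gradient``
    to the resulting iterate and LSTM state, then performs the final ``k``
    steps, through which gradients flow.  Backpropagation therefore only
    traverses the last ``k`` steps, so memory grows with ``k`` rather than
    ``n_steps``.  With ``k=1`` this is the one-step estimator of Bolte et al.,
    which approximates the implicit-differentiation gradient (closely when the
    iteration converges fast) while being as simple to implement as unrolled
    backprop.  When ``k >= n_steps`` nothing is detached and the solve is
    fully differentiable.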

    Reference:
        Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of
        iterative algorithms." https://arxiv.org/abs/2305.13768

    Args:
        batch: Observed data batch.
        prior_fn: Callable ``x -> x_prior``.
        grad_mod_fn: Callable ``(grad, x, lstm) -> (update, new_lstm)``.
        n_steps: Total number of solver iterations (warmup = n_steps - k,
            then k differentiable steps).
        hidden_dim: Hidden dimension of the ConvLSTM gradient modulator.
        alpha: Step-size scaling factor.
        prior_weight: Weighting factor $\lambda$ for the prior cost term.
        k: Number of trailing differentiable steps (clipped to ``[1, n_steps]``).

    Returns:
        Final state estimate of shape ``(B, T, H, W)``.
    """
    # --- warmup: n_steps - k steps whose gradient contribution is cut below ---
    state = init_solver_state_2d(batch, hidden_dim)
    live_steps = min(max(k, 1), n_steps) if n_steps >= 1 else 0
    warmup_steps = max(n_steps - live_steps, 0)
    for _ in range(warmup_steps):
        state = solver_step_2d(state, batch, prior_fn, grad_mod_fn, alpha, prior_weight)

    # detach the iterate so earlier steps don't contribute to the
    # gradient — but only when there was a warmup to detach: with
    # k >= n_steps the solve is fully differentiable, including
    # gradients with respect to the batch through the initial state.
    if warmup_steps > 0:
        state = SolverState2D(
            x=jax.lax.stop_gradient(state.x),
            lstm=jax.lax.stop_gradient(state.lstm),
            step=state.step,
        )

    # --- k differentiable steps ---
    for _ in range(live_steps):
        state = solver_step_2d(state, batch, prior_fn, grad_mod_fn, alpha, prior_weight)

    return state.x