Training & Adjoints

Training a learned solver means differentiating through an inner optimisation loop, and how you do that determines memory cost and gradient quality. The adjoint strategies here make that choice explicit and swappable — see Adjoint Methods in the Mathematical Reference for the trade-offs. Around them sit the loss functions and train/eval steps for 4DVarNet and the amortized posteriors, and the ConvLSTM gradient modulators that 4DVarNet learns in place of a hand-tuned inner optimiser.

Adjoint strategies

Implementations of vardax.adjoints (Decision D15): full backpropagation with checkpointed memory (RecursiveCheckpointAdjoint), truncated one-step gradients (OneStepAdjoint, pairing with the one_step_solve_* solver functions), and implicit differentiation at a fixed point (ImplicitAdjoint, pairing with solve_4dvarnet_1d_fixedpoint). All three are also accessible via the vardax.adjoints submodule namespace. Use assert_adjoint_calibrated to verify a cheap adjoint against the exact one before trusting it.
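As a rough illustration of what such a calibration check measures, the sketch below compares a truncated gradient with the fully unrolled one on a toy quadratic inner problem. Everything in it (the matrix A, inner_solve, outer_loss, cosine) is hypothetical and written in plain JAX; it neither uses nor reproduces the signature of assert_adjoint_calibrated.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy inner problem: minimise 0.5 * x^T A x - theta^T x over x.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])


def inner_solve(theta, n_steps, k):
    """Gradient descent in which only the last k steps carry gradients."""
    x = jnp.zeros_like(theta)
    for n in range(n_steps):
        if n == n_steps - k:
            # Everything computed so far is treated as a constant.
            x = jax.lax.stop_gradient(x)
        x = x - 0.4 * (A @ x - theta)
    return x


def outer_loss(theta, k, n_steps=30):
    return jnp.sum(inner_solve(theta, n_steps, k) ** 2)


def cosine(u, v):
    return jnp.vdot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))


theta = jnp.array([1.0, -1.0])
exact_grad = jax.grad(outer_loss)(theta, 30)  # all 30 steps differentiable
cheap_grad = jax.grad(outer_loss)(theta, 1)   # one differentiable step
alignment = cosine(cheap_grad, exact_grad)
```

A cosine close to 1 suggests the cheap gradient points in nearly the same direction as the exact one; how close is close enough is a project-specific threshold.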

RecursiveCheckpointAdjoint and ImplicitAdjoint are re-exported from optimistix for one-stop import; see the optimistix documentation for their full signatures:

  • vardax.RecursiveCheckpointAdjoint → optimistix.RecursiveCheckpointAdjoint, exact reverse-mode backpropagation through the unrolled inner loop with binomial checkpointing (the default).
  • vardax.ImplicitAdjoint → optimistix.ImplicitAdjoint, implicit-function-theorem differentiation at a fixed point; pair with solve_4dvarnet_1d_fixedpoint.

KStepAdjoint(k) is vardax's own truncated adjoint — warmup under stop_gradient, then k differentiable steps; OneStepAdjoint is the k=1 alias (Bolte, Pauwels & Vaiter, NeurIPS 2023):

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

KStepAdjoint

Bases: AbstractAdjoint

K-step differentiation: warmup under stop_gradient, then k live steps.

Use as the solver_adjoint argument to FourDVarNet1D / FourDVarNet2D:

from vardax.adjoints import KStepAdjoint

model = FourDVarNet1D(
    state_dim=N, n_time=T, ...,
    solver_adjoint=KStepAdjoint(k=3),
    key=key,
)

Attributes:

  • k (int): Number of trailing solver iterations that propagate gradients. Must be at least 1; values larger than n_solver_steps are clipped to a fully differentiable solve.
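A minimal plain-JAX sketch of the idea, assuming a toy scalar contraction in place of the learned solver step. inner_step and k_step_solve are hypothetical names; the real dispatch lives in vardax._src.solver and may be organised differently.

```python
import jax
import jax.numpy as jnp


def inner_step(x, theta):
    # Hypothetical toy contraction standing in for one solver iteration.
    return 0.5 * x + 0.5 * theta


def k_step_solve(theta, x0, n_steps, k):
    """Run n_steps iterations; only the trailing k propagate gradients."""
    k = min(k, n_steps)  # larger k falls back to a fully differentiable solve
    x = jnp.asarray(x0)
    for _ in range(n_steps - k):   # warmup iterations
        x = inner_step(x, theta)
    x = jax.lax.stop_gradient(x)   # cut the graph: warmup carries no gradient
    for _ in range(k):             # live iterations
        x = inner_step(x, theta)
    return x


grad_k1 = jax.grad(k_step_solve)(2.0, 0.0, 10, 1)
grad_k3 = jax.grad(k_step_solve)(2.0, 0.0, 10, 3)
grad_full = jax.grad(k_step_solve)(2.0, 0.0, 10, 50)
```

For this toy step the derivative with respect to theta is 1 - 0.5**k, so a larger k moves the truncated gradient toward the value 1 obtained at the fixed point, at the price of storing more steps for backpropagation.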

References

Bolte, J., Pauwels, E. & Vaiter, S. (2023). One-step differentiation of iterative algorithms. NeurIPS 36. arXiv:2305.13768.

Source code in src/vardax/_src/adjoints/k_step.py
class KStepAdjoint(optx.AbstractAdjoint):
    """K-step differentiation: warmup under ``stop_gradient``, then ``k`` live steps.

    Use as the ``solver_adjoint`` argument to
    [`FourDVarNet1D`][vardax.FourDVarNet1D] /
    [`FourDVarNet2D`][vardax.FourDVarNet2D]:

    ```python
    from vardax.adjoints import KStepAdjoint

    model = FourDVarNet1D(
        state_dim=N, n_time=T, ...,
        solver_adjoint=KStepAdjoint(k=3),
        key=key,
    )
    ```

    Attributes:
        k: Number of trailing solver iterations that propagate
            gradients. Must be at least 1; values larger than
            ``n_solver_steps`` are clipped to a fully differentiable
            solve.

    References:
        Bolte, J., Pauwels, E. & Vaiter, S. (2023). One-step
        differentiation of iterative algorithms. NeurIPS 36.
        [arXiv:2305.13768](https://arxiv.org/abs/2305.13768).
    """

    k: int = 1

    def __check_init__(self) -> None:
        if self.k < 1:
            raise ValueError(f"KStepAdjoint needs k >= 1; got k={self.k}.")

    def apply(
        self,
        primal_fn: Callable,
        rewrite_fn: Callable,
        inputs: PyTree,
        tags: frozenset[object],
    ) -> PyTree[Array]:
        """Not used by vardax's custom learned solver — dispatch happens
        in ``vardax._src.solver`` via ``isinstance`` on the adjoint type.

        Implementing this method for upstream ``optimistix.minimise``
        compatibility is tracked under the planned upstream
        contribution (Decision D6).
        """
        raise NotImplementedError(
            "KStepAdjoint is currently a marker / strategy selector for the "
            "FourDVarNet inner solver. Generic apply() support for "
            "optimistix.minimise is planned as part of the upstream "
            "contribution. Use it via FourDVarNet*(solver_adjoint=KStepAdjoint(k=...))."
        )

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        return f"KStepAdjoint(k={self.k})"

apply

apply(
    primal_fn: Callable,
    rewrite_fn: Callable,
    inputs: PyTree,
    tags: frozenset[object],
) -> PyTree[Array]

Not used by vardax's custom learned solver — dispatch happens in vardax._src.solver via isinstance on the adjoint type.

Implementing this method for upstream optimistix.minimise compatibility is tracked under the planned upstream contribution (Decision D6).

Source code in src/vardax/_src/adjoints/k_step.py
def apply(
    self,
    primal_fn: Callable,
    rewrite_fn: Callable,
    inputs: PyTree,
    tags: frozenset[object],
) -> PyTree[Array]:
    """Not used by vardax's custom learned solver — dispatch happens
    in ``vardax._src.solver`` via ``isinstance`` on the adjoint type.

    Implementing this method for upstream ``optimistix.minimise``
    compatibility is tracked under the planned upstream
    contribution (Decision D6).
    """
    raise NotImplementedError(
        "KStepAdjoint is currently a marker / strategy selector for the "
        "FourDVarNet inner solver. Generic apply() support for "
        "optimistix.minimise is planned as part of the upstream "
        "contribution. Use it via FourDVarNet*(solver_adjoint=KStepAdjoint(k=...))."
    )

OneStepAdjoint

Bases: KStepAdjoint

One-step differentiation (Bolte et al., 2023).

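Run K - 1 warmup iterations of the inner solver with jax.lax.stop_gradient (K being the total number of solver iterations), then one differentiable step. Gives O(1) memory in K; the gradient is exact at the fixed point when the inner iteration converges superlinearly and an approximation otherwise.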

Use as the solver_adjoint argument to FourDVarNet1D / FourDVarNet2D:

from vardax.adjoints import OneStepAdjoint

model = FourDVarNet1D(
    state_dim=N, n_time=T, ...,
    solver_adjoint=OneStepAdjoint(),
    key=key,
)

References

Bolte, J., Pauwels, E. & Vaiter, S. (2023). One-step differentiation of iterative algorithms. NeurIPS 36. arXiv:2305.13768.
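The recipe can be sketched in plain JAX on Newton's iteration for a square root, one of the examples analysed by Bolte et al. For this superlinearly convergent iteration the one-step derivative agrees with the analytic one at the fixed point; for slower, linearly convergent iterations it should be read as an approximation. newton_step and one_step_sqrt are hypothetical names and not part of vardax.

```python
import jax
import jax.numpy as jnp


def newton_step(x, theta):
    # One Newton iteration for x**2 = theta (the Babylonian square-root update).
    return 0.5 * (x + theta / x)


def one_step_sqrt(theta, n_steps=8):
    x = jnp.asarray(1.0)
    for _ in range(n_steps - 1):   # K - 1 warmup iterations
        x = newton_step(x, theta)
    x = jax.lax.stop_gradient(x)   # gradients do not flow through the warmup
    return newton_step(x, theta)   # single differentiable step


theta = 4.0
one_step_grad = jax.grad(one_step_sqrt)(theta)
reference_grad = 0.5 / jnp.sqrt(theta)  # analytic derivative of sqrt(theta)
```

Only the final step is recorded for backpropagation, so the memory cost does not grow with n_steps.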

Source code in src/vardax/_src/adjoints/one_step.py
class OneStepAdjoint(KStepAdjoint):
    """One-step differentiation (Bolte et al., 2023).

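    Run ``K - 1`` warmup iterations of the inner solver with
    ``jax.lax.stop_gradient`` (``K`` being the total number of solver
    iterations), then one differentiable step. Gives O(1) memory in
    ``K``; the gradient is exact at the fixed point when the inner
    iteration converges superlinearly and an approximation otherwise.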

    Use as the ``solver_adjoint`` argument to
    [`FourDVarNet1D`][vardax.FourDVarNet1D] /
    [`FourDVarNet2D`][vardax.FourDVarNet2D]:

    ```python
    from vardax.adjoints import OneStepAdjoint

    model = FourDVarNet1D(
        state_dim=N, n_time=T, ...,
        solver_adjoint=OneStepAdjoint(),
        key=key,
    )
    ```

    References:
        Bolte, J., Pauwels, E. & Vaiter, S. (2023). One-step
        differentiation of iterative algorithms. NeurIPS 36.
        [arXiv:2305.13768](https://arxiv.org/abs/2305.13768).
    """

    k: int = 1

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        return "OneStepAdjoint()"

to_optimistix_adjoint

to_optimistix_adjoint(spec: Any) -> AbstractAdjoint

Map an adjoint spec onto an optimistix.AbstractAdjoint.

Mapping:

  • ImplicitAdjoint() → optx.ImplicitAdjoint(): exact at a converged fixed point, O(1) memory, one Hessian linear solve.
  • RecursiveCheckpointAdjoint(checkpoints) → optx.RecursiveCheckpointAdjoint(checkpoints): exact, recomputing.
  • TruncatedAdjoint(k) → KStepAdjoint(k): warmup under stop_gradient, then k differentiable steps (k=1 is OneStepAdjoint).
  • DirectAdjoint / BacksolveAdjoint → ValueError: the first is the plain unrolled default (pass optx.RecursiveCheckpointAdjoint() or nothing), the second only exists at the dynamics layer.

Parameters:

  • spec (Any, required): A spec from pipekit_cycle.adjoints or any structurally identical object. A ready-made optx.AbstractAdjoint passes through unchanged.

Returns:

  • AbstractAdjoint: The corresponding optimistix adjoint instance.

Raises:

  • ValueError: for layer-inappropriate or unrecognised specs.
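The phrase "structurally identical object" refers to dispatch on the class name rather than on isinstance. A minimal sketch of that pattern, using hypothetical local dataclasses and a toy classify_spec function that returns plain tuples instead of real adjoint objects:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TruncatedAdjoint:
    """Hypothetical local spec; shares only its class name and ``k`` field with the upstream one."""

    k: int = 1


@dataclass(frozen=True)
class BacksolveAdjoint:
    """Hypothetical spec that belongs to a different layer."""


def classify_spec(spec):
    """Toy dispatcher keyed on the class name rather than on isinstance."""
    name = type(spec).__name__
    if name == "TruncatedAdjoint":
        return ("k_step", getattr(spec, "k", 1))
    if name in ("DirectAdjoint", "BacksolveAdjoint"):
        raise ValueError(f"{name} does not apply at the inner-solve layer")
    raise ValueError(f"Unrecognised adjoint spec: {spec!r}")


result = classify_spec(TruncatedAdjoint(k=3))
try:
    classify_spec(BacksolveAdjoint())
    rejected = False
except ValueError:
    rejected = True
```

Name-based dispatch lets a spec defined in another package be accepted without importing that package, at the cost of also accepting any unrelated class that happens to share the name.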

Examples:

>>> import optimistix as optx
>>> from pipekit_cycle.adjoints import TruncatedAdjoint
>>> from vardax.adjoints import to_optimistix_adjoint
>>> to_optimistix_adjoint(TruncatedAdjoint(k=3))
KStepAdjoint(k=3)
>>> isinstance(
...     to_optimistix_adjoint(optx.ImplicitAdjoint()), optx.ImplicitAdjoint
... )
True

Source code in src/vardax/_src/adjoints/mapping.py
def to_optimistix_adjoint(spec: Any) -> optx.AbstractAdjoint:
    """Map an adjoint spec onto an ``optimistix.AbstractAdjoint``.

    Mapping:

    - ``ImplicitAdjoint()`` → ``optx.ImplicitAdjoint()`` — exact at a
      converged fixed point, O(1) memory, one Hessian linear solve.
    - ``RecursiveCheckpointAdjoint(checkpoints)`` →
      ``optx.RecursiveCheckpointAdjoint(checkpoints)`` — exact,
      recomputing.
    - ``TruncatedAdjoint(k)`` →
      [`KStepAdjoint(k)`][vardax.adjoints.KStepAdjoint] — warmup under
      ``stop_gradient``, then ``k`` differentiable steps (``k=1`` is
      [`OneStepAdjoint`][vardax.adjoints.OneStepAdjoint]).
    - ``DirectAdjoint`` / ``BacksolveAdjoint`` → ``ValueError``: the
      first is the plain unrolled default (pass
      ``optx.RecursiveCheckpointAdjoint()`` or nothing), the second
      only exists at the dynamics layer.

    Args:
        spec: A spec from ``pipekit_cycle.adjoints`` or any
            structurally identical object. A ready-made
            ``optx.AbstractAdjoint`` passes through unchanged.

    Returns:
        The corresponding optimistix adjoint instance.

    Raises:
        ValueError: for layer-inappropriate or unrecognised specs.

    Examples:
        >>> import optimistix as optx
        >>> from pipekit_cycle.adjoints import TruncatedAdjoint
        >>> from vardax.adjoints import to_optimistix_adjoint
        >>> to_optimistix_adjoint(TruncatedAdjoint(k=3))
        KStepAdjoint(k=3)
        >>> isinstance(
        ...     to_optimistix_adjoint(optx.ImplicitAdjoint()), optx.ImplicitAdjoint
        ... )
        True
    """
    if isinstance(spec, optx.AbstractAdjoint):
        return spec
    name = type(spec).__name__
    if name == "ImplicitAdjoint":
        return optx.ImplicitAdjoint()
    if name == "RecursiveCheckpointAdjoint":
        return optx.RecursiveCheckpointAdjoint(
            checkpoints=getattr(spec, "checkpoints", None)
        )
    if name == "TruncatedAdjoint":
        return KStepAdjoint(k=getattr(spec, "k", 1))
    if name in ("DirectAdjoint", "BacksolveAdjoint"):
        raise ValueError(
            f"{name} does not apply at the inner-solve layer: DirectAdjoint is "
            "the plain unrolled default (use optx.RecursiveCheckpointAdjoint() "
            "or omit solver_adjoint), and BacksolveAdjoint only exists at the "
            "dynamics layer (see pipekit_jax.DiffraxForwardModel)."
        )
    raise ValueError(f"Unrecognised adjoint spec: {spec!r}")

Gradient modulators

The learned components of 4DVarNet: ConvLSTM cells that map the raw variational-cost gradient to a descent update, satisfying the GradModulator Protocol. Their recurrent state is carried in the LSTMState1D / LSTMState2D containers.

ConvLSTMGradMod1D

Bases: Module

1-D ConvLSTM-based gradient modulator.

Accepts the concatenation of the current state and its gradient as input and produces a modulated gradient update (and updated LSTM state).

Attributes:

  • state_channels (int): Number of channels in the state / gradient (T).
  • hidden_dim (int): Number of hidden channels in the LSTM.
  • kernel_size (int): 1-D convolution kernel size.
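The gate arithmetic can be sketched in plain JAX. To keep the example short it assumes kernel size 1, so each convolution reduces to a matrix product over channels; lstm_cell and the random weights are hypothetical stand-ins, not the Equinox layers used by the module.

```python
import jax
import jax.numpy as jnp


def lstm_cell(x, h, c, w_in, w_hid, w_out):
    """ConvLSTM-style update with kernel size 1 (convolution == matmul over channels)."""
    gates = w_in @ x + w_hid @ h              # (4H, N)
    i, f, g, o = jnp.split(gates, 4, axis=0)  # four (H, N) gate pre-activations
    c_new = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_new = jax.nn.sigmoid(o) * jnp.tanh(c_new)
    return w_out @ h_new, h_new, c_new        # (T, N), (H, N), (H, N)


B, T, N, H = 2, 3, 16, 8  # batch, time channels, grid points, hidden channels
keys = jax.random.split(jax.random.PRNGKey(0), 5)
grad = jax.random.normal(keys[0], (B, T, N))
state = jax.random.normal(keys[1], (B, T, N))
w_in = 0.1 * jax.random.normal(keys[2], (4 * H, 2 * T))
w_hid = 0.1 * jax.random.normal(keys[3], (4 * H, H))
w_out = 0.1 * jax.random.normal(keys[4], (T, H))

x = jnp.concatenate([grad, state], axis=1)    # (B, 2T, N)
h0 = jnp.zeros((B, H, N))
c0 = jnp.zeros((B, H, N))

# Share the weights across the batch; map over the leading axis of x, h and c.
batched_cell = jax.vmap(lstm_cell, in_axes=(0, 0, 0, None, None, None))
update, h1, c1 = batched_cell(x, h0, c0, w_in, w_hid, w_out)
```

The update has the same shape as the gradient it modulates, while the hidden and cell states keep H channels and are passed on to the next solver iteration.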

Source code in src/vardax/_src/grad_mod.py
class ConvLSTMGradMod1D(eqx.Module):
    """1-D ConvLSTM-based gradient modulator.

    Accepts the concatenation of the current state and its gradient as input
    and produces a modulated gradient update (and updated LSTM state).

    Attributes:
        state_channels: Number of channels in the state / gradient (``T``).
        hidden_dim: Number of hidden channels in the LSTM.
        kernel_size: 1-D convolution kernel size.
    """

    state_channels: int = eqx.field(static=True)
    hidden_dim: int = eqx.field(static=True)
    kernel_size: int = eqx.field(static=True)
    _gates_input_conv: eqx.nn.Conv1d
    _gates_hidden_conv: eqx.nn.Conv1d
    _output_conv: eqx.nn.Conv1d

    def __init__(
        self,
        state_channels: int,
        hidden_dim: int,
        kernel_size: int = 3,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.state_channels = state_channels
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        k1, k2, k3 = jax.random.split(key, 3)
        pad = kernel_size // 2
        self._gates_input_conv = eqx.nn.Conv1d(
            in_channels=2 * state_channels,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=pad,
            key=k1,
        )
        self._gates_hidden_conv = eqx.nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=pad,
            key=k2,
        )
        self._output_conv = eqx.nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=state_channels,
            kernel_size=kernel_size,
            padding=pad,
            key=k3,
        )

    def __call__(
        self,
        grad: Float[Array, "B T N"],
        state: Float[Array, "B T N"],
        lstm_state: LSTMState1D,
    ) -> tuple[Float[Array, "B T N"], LSTMState1D]:
        """Forward pass.

        Args:
            grad: Gradient of variational cost w.r.t. state, shape ``(B, T, N)``.
            state: Current state estimate, shape ``(B, T, N)``.
            lstm_state: Current LSTM hidden/cell state with hidden_dim channels.

        Returns:
            Tuple of (modulated gradient update, new LSTM state).
        """
        # eqx.nn.Conv1d operates on (channels, spatial). Our arrays are already
        # (B, channels, spatial) with T=channels for grad/state and H=channels
        # for the LSTM state — vmap over the leading batch dim.

        def _forward(
            grad_i: Float[Array, "T N"],
            state_i: Float[Array, "T N"],
            h_i: Float[Array, "H N"],
            c_i: Float[Array, "H N"],
        ) -> tuple[Float[Array, "T N"], Float[Array, "H N"], Float[Array, "H N"]]:
            x = jnp.concatenate([grad_i, state_i], axis=0)  # (2T, N)
            gates = self._gates_input_conv(x) + self._gates_hidden_conv(h_i)
            i, f, g, o = jnp.split(gates, 4, axis=0)
            i = jax.nn.sigmoid(i)
            f = jax.nn.sigmoid(f)
            g = jnp.tanh(g)
            o = jax.nn.sigmoid(o)
            c_new = f * c_i + i * g
            h_new = o * jnp.tanh(c_new)
            out = self._output_conv(h_new)
            return out, h_new, c_new

        out, h_new, c_new = jax.vmap(_forward)(grad, state, lstm_state.h, lstm_state.c)
        return out, LSTMState1D(h=h_new, c=c_new)

ConvLSTMGradMod2D

Bases: Module

2-D ConvLSTM-based gradient modulator.

Attributes:

  • state_channels (int): Number of time channels in the state / gradient.
  • hidden_dim (int): Number of hidden channels in the LSTM.
  • kernel_size (int): 2-D convolution kernel size.

Source code in src/vardax/_src/grad_mod.py
class ConvLSTMGradMod2D(eqx.Module):
    """2-D ConvLSTM-based gradient modulator.

    Attributes:
        state_channels: Number of time channels in the state / gradient.
        hidden_dim: Number of hidden channels in the LSTM.
        kernel_size: 2-D convolution kernel size.
    """

    state_channels: int = eqx.field(static=True)
    hidden_dim: int = eqx.field(static=True)
    kernel_size: int = eqx.field(static=True)
    _gates_input_conv: eqx.nn.Conv2d
    _gates_hidden_conv: eqx.nn.Conv2d
    _output_conv: eqx.nn.Conv2d

    def __init__(
        self,
        state_channels: int,
        hidden_dim: int,
        kernel_size: int = 3,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.state_channels = state_channels
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        k1, k2, k3 = jax.random.split(key, 3)
        pad = kernel_size // 2
        self._gates_input_conv = eqx.nn.Conv2d(
            in_channels=2 * state_channels,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=pad,
            key=k1,
        )
        self._gates_hidden_conv = eqx.nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=pad,
            key=k2,
        )
        self._output_conv = eqx.nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=state_channels,
            kernel_size=kernel_size,
            padding=pad,
            key=k3,
        )

    def __call__(
        self,
        grad: Float[Array, "B T H W"],
        state: Float[Array, "B T H W"],
        lstm_state: LSTMState2D,
    ) -> tuple[Float[Array, "B T H W"], LSTMState2D]:
        """Forward pass.

        Args:
            grad: Gradient of variational cost w.r.t. state, shape ``(B, T, H, W)``.
            state: Current state estimate, shape ``(B, T, H, W)``.
            lstm_state: Current LSTM hidden/cell state.

        Returns:
            Tuple of (modulated gradient update, new LSTM state).
        """

        def _forward(
            grad_i: Float[Array, "T H W"],
            state_i: Float[Array, "T H W"],
            h_i: Float[Array, "H_dim H W"],
            c_i: Float[Array, "H_dim H W"],
        ) -> tuple[
            Float[Array, "T H W"],
            Float[Array, "H_dim H W"],
            Float[Array, "H_dim H W"],
        ]:
            x = jnp.concatenate([grad_i, state_i], axis=0)  # (2T, H, W)
            gates = self._gates_input_conv(x) + self._gates_hidden_conv(h_i)
            i, f, g, o = jnp.split(gates, 4, axis=0)
            i = jax.nn.sigmoid(i)
            f = jax.nn.sigmoid(f)
            g = jnp.tanh(g)
            o = jax.nn.sigmoid(o)
            c_new = f * c_i + i * g
            h_new = o * jnp.tanh(c_new)
            out = self._output_conv(h_new)
            return out, h_new, c_new

        out, h_new, c_new = jax.vmap(_forward)(grad, state, lstm_state.h, lstm_state.c)
        return out, LSTMState2D(h=h_new, c=c_new)

Losses & steps

Outer-loop training: reconstruction-based losses for 4DVarNet, the negative-log-likelihood loss for amortized posteriors, and the optax-driven train_step / amortized_train_step / eval_step that consume them.

reconstruction_loss

reconstruction_loss(
    pred: Float[Array, ...], target: Float[Array, ...]
) -> Float[Array, ""]

Mean-squared reconstruction loss.

Parameters:

  • pred (Float[Array, ...], required): Model predictions, arbitrary shape.
  • target (Float[Array, ...], required): Ground-truth targets, same shape as pred.

Returns:

  • Float[Array, '']: Scalar mean-squared error.

Source code in src/vardax/_src/training.py
def reconstruction_loss(
    pred: Float[Array, ...],
    target: Float[Array, ...],
) -> Float[Array, ""]:
    """Mean-squared reconstruction loss.

    Args:
        pred: Model predictions, arbitrary shape.
        target: Ground-truth targets, same shape as ``pred``.

    Returns:
        Scalar mean-squared error.
    """
    return jnp.mean((pred - target) ** 2)

train_loss_fn

train_loss_fn(
    model: Any, batch: Batch1D | Batch2D
) -> Float[Array, ""]

Compute the training loss for a single batch.

Parameters:

  • model (Any, required): Equinox module implementing __call__(batch) -> prediction.
  • batch (Batch1D | Batch2D, required): Training batch (must have a non-None target).

Returns:

  • Float[Array, '']: Scalar reconstruction loss.

Source code in src/vardax/_src/training.py
def train_loss_fn(
    model: Any,
    batch: Batch1D | Batch2D,
) -> Float[Array, ""]:
    """Compute the training loss for a single batch.

    Args:
        model: Equinox module implementing ``__call__(batch) -> prediction``.
        batch: Training batch (must have a non-``None`` ``target``).

    Returns:
        Scalar reconstruction loss.
    """
    pred = model(batch)
    target = batch.target
    if target is None:
        raise ValueError("train_loss_fn requires batch.target to be set.")
    return reconstruction_loss(pred, target)

train_step

train_step(
    model: Any,
    batch: Batch1D | Batch2D,
    optimizer: GradientTransformation,
    opt_state: OptState,
) -> tuple[Any, OptState, Float[Array, ""]]

Perform a single training step (forward + backward + update).

This is the correctness-critical primitive: gradients flow through the FourDVarNet inner solver according to whichever differentiation strategy ("unrolled" / "one_step" / "implicit") the model is configured with. Users should compose this primitive into their training loop (notebook-level or pipekit_train.TrainingLoop) rather than reimplementing it.

Parameters:

  • model (Any, required): Equinox module to optimise.
  • batch (Batch1D | Batch2D, required): Training batch.
  • optimizer (GradientTransformation, required): Optax gradient transformation (e.g. optax.adam(1e-3)).
  • opt_state (OptState, required): Current optimiser state.

Returns:

  • tuple[Any, OptState, Float[Array, '']]: Tuple of (updated model, updated optimiser state, scalar loss).

Source code in src/vardax/_src/training.py
@eqx.filter_jit
def train_step(
    model: Any,
    batch: Batch1D | Batch2D,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Any, optax.OptState, Float[Array, ""]]:
    """Perform a single training step (forward + backward + update).

    This is the correctness-critical primitive: gradients flow through
    the FourDVarNet inner solver according to whichever differentiation
    strategy (``"unrolled"`` / ``"one_step"`` / ``"implicit"``) the
    model is configured with. Users should compose this primitive into
    their training loop (notebook-level or ``pipekit_train.TrainingLoop``)
    rather than reimplementing it.

    Args:
        model: Equinox module to optimise.
        batch: Training batch.
        optimizer: Optax gradient transformation (e.g. ``optax.adam(1e-3)``).
        opt_state: Current optimiser state.

    Returns:
        Tuple of (updated model, updated optimiser state, scalar loss).
    """
    loss, grads = eqx.filter_value_and_grad(train_loss_fn)(model, batch)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

eval_step

eval_step(
    model: Any, batch: Batch1D | Batch2D
) -> Float[Array, ""]

Compute the evaluation loss for a single batch (no gradient).

Parameters:

  • model (Any, required): Equinox module.
  • batch (Batch1D | Batch2D, required): Evaluation batch (must have a non-None target).

Returns:

  • Float[Array, '']: Scalar reconstruction loss.

Source code in src/vardax/_src/training.py
@eqx.filter_jit
def eval_step(
    model: Any,
    batch: Batch1D | Batch2D,
) -> Float[Array, ""]:
    """Compute the evaluation loss for a single batch (no gradient).

    Args:
        model: Equinox module.
        batch: Evaluation batch (must have a non-``None`` ``target``).

    Returns:
        Scalar reconstruction loss.
    """
    pred = model(batch)
    target = batch.target
    if target is None:
        raise ValueError("eval_step requires batch.target to be set.")
    return reconstruction_loss(pred, target)

amortized_nll_loss_fn

amortized_nll_loss_fn(
    model: Any, batch: Batch1D | Batch2D
) -> Float[Array, ""]

Negative log-likelihood for amortized inference (Epic 8).

For AmortizedPosterior with flow / regression heads the maximum-likelihood objective on simulated pairs is

\[ \mathcal{L}_\text{MLE}(\phi) = -\mathbb{E}_{(x, y)} \log q_\phi(x \mid y). \]

Parameters:

  • model (Any, required): AmortizedPosterior (any head with a .log_prob method).
  • batch (Batch1D | Batch2D, required): Training batch with target = x and input = y.

Returns:

  • Float[Array, '']: Scalar NLL averaged over the batch.
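For intuition, the objective can be evaluated by hand for a diagonal Gaussian q(x | y). The sketch below assumes a head that outputs a mean and a standard deviation per state component; gaussian_nll is a hypothetical helper and not the log_prob of any vardax head.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussian_nll(x, mean, std):
    """Batch-averaged negative log-density of x under a diagonal Gaussian."""
    # Sum component-wise log-densities to get one log q(x | y) per sample.
    log_q = norm.logpdf(x, loc=mean, scale=std).sum(axis=-1)
    return -jnp.mean(log_q)


x = jnp.zeros((2, 3))  # two samples, three state components
nll = gaussian_nll(x, mean=jnp.zeros((2, 3)), std=jnp.ones((2, 3)))
```

With x equal to the predicted mean and unit standard deviation, each component contributes 0.5 * log(2π), so the three-component NLL is 1.5 * log(2π).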

Source code in src/vardax/_src/training.py
def amortized_nll_loss_fn(
    model: Any,
    batch: Batch1D | Batch2D,
) -> Float[Array, ""]:
    r"""Negative log-likelihood for amortized inference (Epic 8).

    For ``AmortizedPosterior`` with flow / regression heads the maximum-
    likelihood objective on simulated pairs is

    $$
    \mathcal{L}_\text{MLE}(\phi) = -\mathbb{E}_{(x, y)} \log q_\phi(x \mid y).
    $$

    Args:
        model: ``AmortizedPosterior`` (any head with a ``.log_prob``
            method).
        batch: Training batch with ``target = x`` and ``input = y``.

    Returns:
        Scalar NLL averaged over the batch.
    """
    target = batch.target
    if target is None:
        raise ValueError(
            "amortized_nll_loss_fn requires batch.target (the true x) to be set."
        )
    log_p = model.log_prob(target, batch)
    return -jnp.mean(log_p)

amortized_train_step

amortized_train_step(
    model: Any,
    batch: Batch1D | Batch2D,
    optimizer: GradientTransformation,
    opt_state: OptState,
) -> tuple[Any, OptState, Float[Array, ""]]

Single training step for AmortizedPosterior.

Same shape as train_step but uses amortized_nll_loss_fn instead of the MSE reconstruction loss. Use this for simulation-based training of amortized variants; use train_step for FourDVarNet and classical models that reconstruct fields.

Source code in src/vardax/_src/training.py
@eqx.filter_jit
def amortized_train_step(
    model: Any,
    batch: Batch1D | Batch2D,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
) -> tuple[Any, optax.OptState, Float[Array, ""]]:
    """Single training step for ``AmortizedPosterior``.

    Same shape as ``train_step`` but uses ``amortized_nll_loss_fn``
    instead of the MSE reconstruction loss. Use this for simulation-
    based training of amortized variants; use ``train_step`` for
    ``FourDVarNet`` and classical models that reconstruct fields.
    """
    loss, grads = eqx.filter_value_and_grad(amortized_nll_loss_fn)(model, batch)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss