Training & Adjoints
Training a learned solver means differentiating through an inner optimisation loop, and how you do that determines memory cost and gradient quality. The adjoint strategies here make that choice explicit and swappable — see Adjoint Methods in the Mathematical Reference for the trade-offs. Around them sit the loss functions and train/eval steps for 4DVarNet and the amortized posteriors, and the ConvLSTM gradient modulators that 4DVarNet learns in place of a hand-tuned inner optimiser.
Adjoint strategies
vardax.adjoints (Decision D15) implements three strategies: full
backpropagation with checkpointed memory (RecursiveCheckpointAdjoint);
truncated one-step gradients (OneStepAdjoint, which pairs with the
one_step_solve_* solver functions); and implicit differentiation at a
fixed point (ImplicitAdjoint, which pairs with
solve_4dvarnet_1d_fixedpoint). All three are also accessible via the
vardax.adjoints submodule namespace. Use
assert_adjoint_calibrated to verify a cheap adjoint against
the exact one before trusting it.
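To make the calibration idea concrete, here is a minimal sketch that compares a one-step gradient with a fully unrolled one on a toy quadratic inner problem. Everything in it (`inner_step`, `outer_loss`, `cosine_similarity`, the constants) is a hypothetical stand-in written for this illustration; it does not reproduce the signature or internals of assert_adjoint_calibrated.

```python
import jax
import jax.numpy as jnp

# Toy problem constants (assumptions for this sketch only).
A_DIAG = jnp.array([1.0, 2.0])  # diagonal of the inner quadratic cost
STEP = 0.4                      # inner gradient-descent step size
TARGET = jnp.zeros(2)           # target of the outer loss


def inner_step(x, theta):
    # One gradient step on J(x; theta) = 0.5 * x^T A x - theta^T x.
    return x - STEP * (A_DIAG * x - theta)


def outer_loss(theta, n_steps, n_live):
    """Outer loss after n_steps inner iterations; only the last n_live are differentiable."""
    x = jnp.zeros(2)
    theta_frozen = jax.lax.stop_gradient(theta)
    for _ in range(n_steps - n_live):  # warmup, cut from the gradient path
        x = inner_step(x, theta_frozen)
    for _ in range(n_live):            # live steps
        x = inner_step(x, theta)
    return 0.5 * jnp.sum((x - TARGET) ** 2)


def cosine_similarity(a, b):
    # Direction agreement between two gradient vectors.
    return jnp.vdot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


theta = jnp.array([1.0, 2.0])
exact = jax.grad(outer_loss)(theta, 30, 30)  # full unroll
cheap = jax.grad(outer_loss)(theta, 30, 1)   # one-step truncation
cos = cosine_similarity(cheap, exact)
rel_err = jnp.linalg.norm(cheap - exact) / jnp.linalg.norm(exact)
```

In this toy case the two gradients point in a similar direction but differ noticeably in magnitude, which is why a check of this kind usually looks at both the cosine and a relative norm error before a cheap adjoint is adopted.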
RecursiveCheckpointAdjoint and ImplicitAdjoint are re-exported from
optimistix for one-stop import;
see the optimistix documentation for their full signatures:
- vardax.RecursiveCheckpointAdjoint is optimistix.RecursiveCheckpointAdjoint: exact reverse-mode backpropagation through the unrolled inner loop with binomial checkpointing (the default).
- vardax.ImplicitAdjoint is optimistix.ImplicitAdjoint: implicit-function-theorem differentiation at a fixed point; pair with solve_4dvarnet_1d_fixedpoint.
KStepAdjoint(k) is vardax's own truncated adjoint — warmup under
stop_gradient, then k differentiable steps; OneStepAdjoint is the
k=1 alias (Bolte, Pauwels &
Vaiter, NeurIPS 2023):
vardax — Modular variational data assimilation with learned components.
All public symbols are re-exported from the private _src subpackage so
that user code imports from the top-level namespace:
KStepAdjoint
Bases: AbstractAdjoint
K-step differentiation: warmup under stop_gradient, then k live steps.
Use as the solver_adjoint argument to
FourDVarNet1D /
FourDVarNet2D:
from vardax.adjoints import KStepAdjoint
model = FourDVarNet1D(
state_dim=N, n_time=T, ...,
solver_adjoint=KStepAdjoint(k=3),
key=key,
)
Attributes:
| Name | Type | Description |
|---|---|---|
| k | int | Number of trailing solver iterations that propagate gradients. Must be at least 1; values larger than … |
References
Bolte, J., Pauwels, E. & Vaiter, S. (2023). One-step differentiation of iterative algorithms. NeurIPS 36. arXiv:2305.13768.
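The warmup-then-live-steps pattern can be sketched in a few lines of JAX. This is an illustration under stated assumptions, not the code in k_step.py: `inner_step`, `k_step_solve`, and `outer_loss` are hypothetical names, the inner problem is a toy quadratic, and the sketch assumes 1 <= k <= n_steps.

```python
import jax
import jax.numpy as jnp


def inner_step(x, theta):
    # Gradient-descent step on J(x; theta) = 0.5 * ||x - theta||^2 with step size 0.5.
    return x - 0.5 * (x - theta)


def k_step_solve(theta, x0, n_steps, k):
    """Run n_steps iterations; only the trailing k carry gradients.

    Assumes 1 <= k <= n_steps (hypothetical contract for this sketch).
    """
    # Warmup: both the parameters and the iterate are cut from the gradient path.
    theta_frozen = jax.lax.stop_gradient(theta)
    x = x0
    for _ in range(n_steps - k):
        x = inner_step(x, theta_frozen)
    x = jax.lax.stop_gradient(x)
    # Live steps: gradients flow through these k iterations only.
    for _ in range(k):
        x = inner_step(x, theta)
    return x


def outer_loss(theta, k):
    x = k_step_solve(theta, jnp.zeros_like(theta), n_steps=10, k=k)
    return jnp.sum(x)


theta = jnp.array([1.0, -2.0, 0.5])
# For this toy problem each gradient entry equals 1 - 0.5**k,
# so larger k approaches the exact fixed-point sensitivity of 1.
grads = {k: jax.grad(outer_loss)(theta, k) for k in (1, 3, 10)}
```

The trade-off is visible directly: the backward pass stores only k steps of activations, at the price of a gradient that is biased towards zero when the inner iteration contracts slowly.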
Source code in src/vardax/_src/adjoints/k_step.py
apply
apply(
primal_fn: Callable,
rewrite_fn: Callable,
inputs: PyTree,
tags: frozenset[object],
) -> PyTree[Array]
Not used by vardax's custom learned solver — dispatch happens
in vardax._src.solver via isinstance on the adjoint type.
Implementing this method for upstream optimistix.minimise
compatibility is tracked under the planned upstream
contribution (Decision D6).
Source code in src/vardax/_src/adjoints/k_step.py
OneStepAdjoint
Bases: KStepAdjoint
One-step differentiation (Bolte et al., 2023).
Run K - 1 warmup iterations of the inner solver with
jax.lax.stop_gradient, then one differentiable step. Gives
O(1) memory and is exact at the fixed point of the inner
iteration.
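How close a single differentiable step gets to the true sensitivity can be read off the implicit function theorem. The derivation below is a sketch in notation assumed here (F for one inner iteration, θ for the learned parameters, x* for the fixed point, ρ for a contraction factor); none of these symbols come from the vardax source.

```latex
% Fixed point of the inner iteration and its partial Jacobians
x^\star = F(x^\star, \theta), \qquad
A = \partial_x F(x^\star, \theta), \qquad
B = \partial_\theta F(x^\star, \theta)

% Implicit differentiation, split into the one-step term and a remainder
\frac{\mathrm{d} x^\star}{\mathrm{d} \theta}
  = (I - A)^{-1} B
  = \underbrace{B}_{\text{one-step Jacobian}} + A\,(I - A)^{-1} B

% Size of the remainder when the iteration is a contraction
\lVert A \rVert \le \rho < 1
  \;\Longrightarrow\;
  \bigl\lVert A\,(I - A)^{-1} B \bigr\rVert \le \frac{\rho}{1 - \rho}\,\lVert B \rVert
```

Under these assumptions the one-step Jacobian coincides with the implicit one when A vanishes at the fixed point (as it does for Newton-type iterations), and otherwise the gap shrinks with the contraction factor ρ.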
Use as the solver_adjoint argument to
FourDVarNet1D /
FourDVarNet2D:
from vardax.adjoints import OneStepAdjoint
model = FourDVarNet1D(
state_dim=N, n_time=T, ...,
solver_adjoint=OneStepAdjoint(),
key=key,
)
References
Bolte, J., Pauwels, E. & Vaiter, S. (2023). One-step differentiation of iterative algorithms. NeurIPS 36. arXiv:2305.13768.
Source code in src/vardax/_src/adjoints/one_step.py
to_optimistix_adjoint
Map an adjoint spec onto an optimistix.AbstractAdjoint.
Mapping:
- ImplicitAdjoint() → optx.ImplicitAdjoint(): exact at a converged fixed point, O(1) memory, one Hessian linear solve.
- RecursiveCheckpointAdjoint(checkpoints) → optx.RecursiveCheckpointAdjoint(checkpoints): exact; saves memory by recomputing intermediate iterates on the backward pass.
- TruncatedAdjoint(k) → KStepAdjoint(k): warmup under stop_gradient, then k differentiable steps (k=1 is OneStepAdjoint).
- DirectAdjoint / BacksolveAdjoint → ValueError: the first is the plain unrolled default (pass optx.RecursiveCheckpointAdjoint() or nothing); the second exists only at the dynamics layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| spec | Any | A spec from … | required |

Returns:

| Type | Description |
|---|---|
| AbstractAdjoint | The corresponding optimistix adjoint instance. |

Raises:

| Type | Description |
|---|---|
| ValueError | For layer-inappropriate or unrecognised specs. |
Examples:
>>> import optimistix as optx
>>> from pipekit_cycle.adjoints import TruncatedAdjoint
>>> from vardax.adjoints import to_optimistix_adjoint
>>> to_optimistix_adjoint(TruncatedAdjoint(k=3))
KStepAdjoint(k=3)
>>> isinstance(
... to_optimistix_adjoint(optx.ImplicitAdjoint()), optx.ImplicitAdjoint
... )
True
Source code in src/vardax/_src/adjoints/mapping.py
Gradient modulators
The learned components of 4DVarNet: ConvLSTM cells that map the raw
variational-cost gradient to a descent update, satisfying the
GradModulator Protocol. Their recurrent state is carried
in the LSTMState1D / LSTMState2D containers.
ConvLSTMGradMod1D
Bases: Module
1-D ConvLSTM-based gradient modulator.
Accepts the concatenation of the current state and its gradient as input and produces a modulated gradient update (and updated LSTM state).
Attributes:
| Name | Type | Description |
|---|---|---|
| state_channels | int | Number of channels in the state / gradient (…) |
| hidden_dim | int | Number of hidden channels in the LSTM. |
| kernel_size | int | 1-D convolution kernel size. |
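As an illustration of what a 1-D ConvLSTM gradient modulator computes, the sketch below builds one step in plain JAX: concatenate state, gradient, and hidden state along the channel axis, convolve to four gates, update the LSTM cell, and project the hidden state back to state_channels. The function names, parameter layout, and the final projection layer are assumptions for this sketch; the actual ConvLSTMGradMod1D in grad_mod.py may be organised differently.

```python
import jax
import jax.numpy as jnp


def conv1d(x, w, b):
    # x: (channels_in, length); w: (channels_out, channels_in, kernel); b: (channels_out,)
    y = jax.lax.conv_general_dilated(x[None], w, window_strides=(1,), padding="SAME")
    return y[0] + b[:, None]


def init_params(key, state_channels, hidden_dim, kernel_size):
    k_gates, k_out = jax.random.split(key)
    in_ch = 2 * state_channels + hidden_dim  # [state, gradient, hidden]
    return {
        "w_gates": 0.1 * jax.random.normal(k_gates, (4 * hidden_dim, in_ch, kernel_size)),
        "b_gates": jnp.zeros(4 * hidden_dim),
        "w_out": 0.1 * jax.random.normal(k_out, (state_channels, hidden_dim, kernel_size)),
        "b_out": jnp.zeros(state_channels),
    }


def grad_mod_step(params, state, grad, lstm_state):
    """Map (state, raw gradient) to a modulated update and a new LSTM state."""
    h, c = lstm_state
    z = jnp.concatenate([state, grad, h], axis=0)
    gates = conv1d(z, params["w_gates"], params["b_gates"])
    i, f, g, o = jnp.split(gates, 4, axis=0)  # input, forget, candidate, output
    c_new = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_new = jax.nn.sigmoid(o) * jnp.tanh(c_new)
    update = conv1d(h_new, params["w_out"], params["b_out"])
    return update, (h_new, c_new)


# Usage with hypothetical sizes: 2 state channels, 8 hidden channels, length 16.
key = jax.random.PRNGKey(0)
params = init_params(key, state_channels=2, hidden_dim=8, kernel_size=3)
state = jnp.ones((2, 16))
grad = 0.1 * jnp.ones((2, 16))
lstm_state = (jnp.zeros((8, 16)), jnp.zeros((8, 16)))
update, (h_new, c_new) = grad_mod_step(params, state, grad, lstm_state)
```

Carrying `(h_new, c_new)` from one inner iteration to the next is what lets the modulator behave like a learned optimiser with memory rather than a fixed preconditioner.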
Source code in src/vardax/_src/grad_mod.py
ConvLSTMGradMod2D
Bases: Module
2-D ConvLSTM-based gradient modulator.
Attributes:
| Name | Type | Description |
|---|---|---|
| state_channels | int | Number of time channels in the state / gradient. |
| hidden_dim | int | Number of hidden channels in the LSTM. |
| kernel_size | int | 2-D convolution kernel size. |
Source code in src/vardax/_src/grad_mod.py
Losses & steps
Outer-loop training: reconstruction-based losses for 4DVarNet, the
negative-log-likelihood loss for amortized posteriors, and the
optax-driven train_step / amortized_train_step / eval_step that
consume them.
reconstruction_loss
Mean-squared reconstruction loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred | Float[Array, ...] | Model predictions, arbitrary shape. | required |
| target | Float[Array, ...] | Ground-truth targets, same shape as … | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ''] | Scalar mean-squared error. |
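For reference, a mean-squared reconstruction loss of this kind reduces to a single line of JAX. The function below is a stand-in named `mse` for illustration; it is assumed, not confirmed, to match what reconstruction_loss computes internally.

```python
import jax.numpy as jnp


def mse(pred, target):
    # Average of squared element-wise errors over every axis; returns a scalar.
    return jnp.mean((pred - target) ** 2)


pred = jnp.array([[1.0, 2.0], [3.0, 4.0]])
target = jnp.array([[1.0, 0.0], [3.0, 2.0]])
loss = mse(pred, target)  # squared errors are 0, 4, 0, 4, so the mean is 2.0
```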
Source code in src/vardax/_src/training.py
train_loss_fn
Compute the training loss for a single batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | Equinox module implementing … | required |
| batch | Batch1D \| Batch2D | Training batch (must have a non-… | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ''] | Scalar reconstruction loss. |
Source code in src/vardax/_src/training.py
train_step
train_step(
model: Any,
batch: Batch1D | Batch2D,
optimizer: GradientTransformation,
opt_state: OptState,
) -> tuple[Any, OptState, Float[Array, ""]]
Perform a single training step (forward + backward + update).
This is the correctness-critical primitive: gradients flow through
the FourDVarNet inner solver according to whichever differentiation
strategy ("unrolled" / "one_step" / "implicit") the
model is configured with. Users should compose this primitive into
their training loop (notebook-level or pipekit_train.TrainingLoop)
rather than reimplementing it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | Equinox module to optimise. | required |
| batch | Batch1D \| Batch2D | Training batch. | required |
| optimizer | GradientTransformation | Optax gradient transformation (e.g. … | required |
| opt_state | OptState | Current optimiser state. | required |

Returns:

| Type | Description |
|---|---|
| tuple[Any, OptState, Float[Array, '']] | Tuple of (updated model, updated optimiser state, scalar loss). |
Source code in src/vardax/_src/training.py
eval_step
Compute the evaluation loss for a single batch (no gradient).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | Equinox module. | required |
| batch | Batch1D \| Batch2D | Evaluation batch (must have a non-… | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ''] | Scalar reconstruction loss. |
Source code in src/vardax/_src/training.py
amortized_nll_loss_fn
Negative log-likelihood for amortized inference (Epic 8).
For AmortizedPosterior with flow / regression heads, training minimises the
maximum-likelihood objective on simulated pairs, i.e. the negative
log-likelihood averaged over the batch.
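In the usual notation for amortized inference (an assumption here, since the symbols are not defined on this page), that objective can be written as the batch-averaged negative log-density of each simulated state under the posterior predicted from its observations:

```latex
% Assumed notation: q_\phi is the amortized posterior with parameters \phi,
% (x_b, y_b) are simulated state / observation pairs, B is the batch size.
\mathcal{L}(\phi) = -\frac{1}{B} \sum_{b=1}^{B} \log q_\phi\!\left(x_b \mid y_b\right)
```

Minimising this quantity over simulated pairs is the standard maximum-likelihood fit of a conditional density; how the flow and regression heads parameterise the density is not specified here.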
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | | required |
| batch | Batch1D \| Batch2D | Training batch with … | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ''] | Scalar NLL averaged over the batch. |
Source code in src/vardax/_src/training.py
amortized_train_step
amortized_train_step(
model: Any,
batch: Batch1D | Batch2D,
optimizer: GradientTransformation,
opt_state: OptState,
) -> tuple[Any, OptState, Float[Array, ""]]
Single training step for AmortizedPosterior.
Same signature as train_step, but uses amortized_nll_loss_fn
instead of the MSE reconstruction loss. Use this for simulation-based
training of amortized variants; use train_step for
FourDVarNet and classical models that reconstruct fields.