Costs, Priors & Solvers
The functional core under the model classes: pure cost functions that score a candidate state against observations and prior, prior modules that supply the regularisation term, and the inner-loop solver functions that drive the 4DVarNet iteration. The model classes are thin, stateful-looking wrappers over these pieces — drop down to this layer when building custom methods or instrumenting the optimisation.
Cost functions
The variational cost \(J(x) = \alpha_\text{obs} J_\text{obs}(x) + \alpha_\text{prior} J_\text{prior}(x)\) and its
gradient, with the observation and prior terms also available separately
(decomposed_loss returns them unsummed for logging). The _1d / _2d
suffixes match the Batch1D / Batch2D carriers. See 3DVar and
strong-constraint 4DVar for the math each term implements.
vardax — Modular variational data assimilation with learned components.
All public symbols are re-exported from the private _src subpackage so
that user code imports from the top-level namespace:
variational_cost
variational_cost(
x: Float[Array, ...],
batch: Batch1D,
prior_fn: Callable[..., Any],
alpha_obs: float = 0.5,
alpha_prior: float = 0.5,
) -> Float[Array, ""]
Compute the weighted variational cost \(J(x) = \alpha_\text{obs} J_\text{obs}(x) + \alpha_\text{prior} J_\text{prior}(x)\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, ...] | Current state estimate. | required |
| batch | Batch1D | Observed data batch (observations and observation mask). | required |
| prior_fn | Callable[..., Any] | Prior callable mapping a state to its reconstruction (e.g. IdentityPrior). | required |
| alpha_obs | float | Weight for the observation term. | 0.5 |
| alpha_prior | float | Weight for the prior term. | 0.5 |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar cost value. |
Examples:
With the trivial IdentityPrior, the prior term vanishes, leaving the
weighted observation MSE.
>>> import jax.numpy as jnp, vardax
>>> batch = vardax.Batch1D(input=jnp.zeros((1, 2, 4)), mask=jnp.ones((1, 2, 4)))
>>> x = jnp.ones((1, 2, 4))
>>> float(vardax.variational_cost(x, batch, vardax.IdentityPrior()))
0.5
Source code in src/vardax/_src/costs.py
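For reference, the weighted cost can be sketched in plain jax.numpy. This is a minimal sketch, not the vardax implementation: it assumes the observation term is the squared residual averaged over observed entries and the prior term is the squared reconstruction error averaged over all entries. masked_mse and variational_cost_sketch are hypothetical names. With the same inputs as the doctest it also gives 0.5.

```python
import jax.numpy as jnp


def masked_mse(x, y, mask):
    # Squared residuals, zeroed where mask == 0, averaged over observed entries.
    # The maximum(..., 1) guard avoids dividing by zero when nothing is observed.
    return jnp.sum(mask * (x - y) ** 2) / jnp.maximum(jnp.sum(mask), 1.0)


def variational_cost_sketch(x, obs, mask, prior_fn, alpha_obs=0.5, alpha_prior=0.5):
    # Observation term (assumed form): misfit to the data at observed locations.
    j_obs = masked_mse(x, obs, mask)
    # Prior term (assumed form): mean squared distance between x and phi(x).
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return alpha_obs * j_obs + alpha_prior * j_prior


# Same inputs as the doctest: all-observed zeros, a state of ones, identity prior.
obs = jnp.zeros((1, 2, 4))
mask = jnp.ones((1, 2, 4))
x = jnp.ones((1, 2, 4))
cost = variational_cost_sketch(x, obs, mask, lambda z: z)
print(float(cost))  # 0.5
```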
variational_cost_grad
variational_cost_grad(
x: Float[Array, ...],
batch: Batch1D,
prior_fn: Callable[..., Any],
alpha_obs: float = 0.5,
alpha_prior: float = 0.5,
) -> Float[Array, ...]
Gradient of variational_cost w.r.t. x.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, ...] | Current state estimate. | required |
| batch | Batch1D | Observed data batch. | required |
| prior_fn | Callable[..., Any] | Prior callable mapping a state to its reconstruction. | required |
| alpha_obs | float | Weight for the observation term. | 0.5 |
| alpha_prior | float | Weight for the prior term. | 0.5 |
Returns:
| Type | Description |
|---|---|
| Float[Array, ...] | Gradient array with the same shape as x. |
Source code in src/vardax/_src/costs.py
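Because the cost is a scalar function of x, its gradient can be obtained with jax.grad. The sketch below is not the vardax code: it assumes a masked-MSE observation term (squared residual averaged over observed entries) and a mean squared reconstruction error for the prior, and cost_sketch is a hypothetical name. With an identity prior it checks jax.grad against the closed form \(2\alpha_\text{obs}\, m \odot (x - y) / |\Omega|\).

```python
import jax
import jax.numpy as jnp


def cost_sketch(x, obs, mask, prior_fn, alpha_obs=0.5, alpha_prior=0.5):
    # Assumed form: masked observation MSE plus mean squared reconstruction error.
    j_obs = jnp.sum(mask * (x - obs) ** 2) / jnp.sum(mask)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return alpha_obs * j_obs + alpha_prior * j_prior


# jax.grad differentiates w.r.t. the first argument (the state x) by default,
# so the result has the same shape as x.
cost_grad = jax.grad(cost_sketch)

obs = jnp.zeros((1, 2, 4))
mask = jnp.array([[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]])
x = jnp.ones((1, 2, 4))
g = cost_grad(x, obs, mask, lambda z: z)

# With an identity prior only the observation term contributes:
# d/dx [alpha_obs * sum(m * (x - y)^2) / |Omega|] = 2 * alpha_obs * m * (x - y) / |Omega|
expected = 2 * 0.5 * mask * (x - obs) / jnp.sum(mask)
```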
obs_cost_1d
obs_cost_1d(
state: Float[Array, "B T N"],
obs: Float[Array, "B T N"],
mask: Float[Array, "B T N"],
) -> Float[Array, ""]
Observation cost for 1-D data.
Computes the masked mean-squared error between the state \(x\) and the observations \(y\):
\[J_\text{obs}(x) = \frac{1}{|\Omega|} \sum_{i \in \Omega} (x_i - y_i)^2,\]
where \(\Omega\) is the set of observed locations (mask == 1).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Float[Array, 'B T N'] | Current state estimate of shape (B, T, N). | required |
| obs | Float[Array, 'B T N'] | Observations of shape (B, T, N). | required |
| mask | Float[Array, 'B T N'] | Binary observation mask of shape (B, T, N). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar observation cost. |
Examples:
>>> import jax.numpy as jnp
>>> from vardax import obs_cost_1d
>>> state = jnp.ones((1, 1, 4))
>>> obs = jnp.zeros((1, 1, 4))
>>> mask = jnp.ones((1, 1, 4))
>>> float(obs_cost_1d(state, obs, mask))
1.0
Source code in src/vardax/_src/costs.py
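The doctest uses a full mask, so it cannot tell whether the mean is taken over the observed entries or over the whole array. The sketch below assumes normalisation by the number of observed entries \(|\Omega|\); obs_cost_1d_sketch is a hypothetical name. The unobserved residuals (the 5.0 entries) are ignored and the result is 1.0; normalising by the full array size would give 0.5 instead.

```python
import jax.numpy as jnp


def obs_cost_1d_sketch(state, obs, mask):
    # Squared residuals, zeroed wherever mask == 0 (unobserved).
    sq = mask * (state - obs) ** 2
    # Assumed normalisation: divide by |Omega|, the number of observed entries.
    return jnp.sum(sq) / jnp.sum(mask)


state = jnp.ones((1, 1, 4))
obs = jnp.array([[[0.0, 0.0, 5.0, 5.0]]])
mask = jnp.array([[[1.0, 1.0, 0.0, 0.0]]])  # last two entries unobserved
c = obs_cost_1d_sketch(state, obs, mask)
print(float(c))  # 1.0: the 5.0 entries never enter the cost
```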
obs_cost_2d
obs_cost_2d(
state: Float[Array, "B T H W"],
obs: Float[Array, "B T H W"],
mask: Float[Array, "B T H W"],
) -> Float[Array, ""]
Observation cost for 2-D data.
Computes the masked mean-squared error between the state \(x\) and the observations \(y\):
\[J_\text{obs}(x) = \frac{1}{|\Omega|} \sum_{i \in \Omega} (x_i - y_i)^2,\]
where \(\Omega\) is the set of observed locations (mask == 1).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Float[Array, 'B T H W'] | Current state estimate of shape (B, T, H, W). | required |
| obs | Float[Array, 'B T H W'] | Observations of shape (B, T, H, W). | required |
| mask | Float[Array, 'B T H W'] | Binary observation mask of shape (B, T, H, W). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar observation cost. |
Source code in src/vardax/_src/costs.py
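Under a masked-MSE assumption (squared residual averaged over observed entries), the 2-D cost equals the 1-D cost applied after flattening the (H, W) grid into a single spatial axis, because neither the sum nor the count of observed entries depends on the layout. A hypothetical check, with masked_mse and flat as illustrative helpers rather than vardax functions:

```python
import jax.numpy as jnp


def masked_mse(state, obs, mask):
    # Masked MSE over observed entries; works for arrays of any rank.
    return jnp.sum(mask * (state - obs) ** 2) / jnp.sum(mask)


B, T, H, W = 2, 3, 4, 5
size = B * T * H * W
state = jnp.arange(size, dtype=jnp.float32).reshape(B, T, H, W) / 10.0
obs = jnp.zeros((B, T, H, W))
# Observe every third grid point.
mask = (jnp.arange(size).reshape(B, T, H, W) % 3 == 0).astype(jnp.float32)

cost_2d = masked_mse(state, obs, mask)


def flat(a):
    # Collapse (H, W) into one spatial axis N = H * W, matching the 1-D layout.
    return a.reshape(B, T, H * W)


cost_1d = masked_mse(flat(state), flat(obs), flat(mask))
```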
prior_cost
prior_cost(
state: Float[Array, ...],
prior_reconstruction: Float[Array, ...],
) -> Float[Array, ""]
Prior cost based on learned autoencoder reconstruction.
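Computes the mean-squared error between the state and its reconstruction \(\varphi(x)\) through the learned prior (autoencoder):
\[J_\text{prior}(x) = \frac{1}{n} \|x - \varphi(x)\|^2,\]
where \(n\) is the number of elements of \(x\).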
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Float[Array, ...] | Current state estimate of arbitrary shape. | required |
| prior_reconstruction | Float[Array, ...] | Autoencoder reconstruction of the state, same shape as state. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar prior cost. |
Source code in src/vardax/_src/costs.py
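To see what the prior term rewards, here is a sketch with a toy linear "autoencoder" whose reconstruction manifold is the first coordinate axis. encode, decode, phi, and prior_cost_sketch are hypothetical, and averaging over all elements is an assumption. States on the manifold cost nothing; states off it pay the mean squared distance to their reconstruction.

```python
import jax.numpy as jnp


def prior_cost_sketch(state, reconstruction):
    # Mean squared reconstruction error, averaged over every element.
    return jnp.mean((state - reconstruction) ** 2)


# Toy linear autoencoder: encode keeps the first coordinate, decode pads a zero.
# Its reconstruction manifold is the line spanned by e1 = (1, 0).
def encode(x):
    return x[..., :1]


def decode(z):
    return jnp.concatenate([z, jnp.zeros_like(z)], axis=-1)


def phi(x):
    return decode(encode(x))


on_manifold = jnp.array([[3.0, 0.0]])
off_manifold = jnp.array([[3.0, 2.0]])
c_on = prior_cost_sketch(on_manifold, phi(on_manifold))    # 0.0
c_off = prior_cost_sketch(off_manifold, phi(off_manifold))  # mean([0, 4]) = 2.0
```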
decomposed_loss
decomposed_loss(
x: Float[Array, ...],
batch: Batch1D,
prior_fn: Callable[..., Any],
alpha_obs: float = 0.5,
alpha_prior: float = 0.5,
) -> dict[str, Float[Array, ""]]
Compute the decomposed variational loss.
Returns individual observation and prior components alongside the
total, matching the ModelLoss pattern from the legacy codebase.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, ...] | Current state estimate. | required |
| batch | Batch1D | Observed data batch. | required |
| prior_fn | Callable[..., Any] | Prior callable mapping a state to its reconstruction. | required |
| alpha_obs | float | Weight for the observation term. | 0.5 |
| alpha_prior | float | Weight for the prior term. | 0.5 |
Returns:
| Type | Description |
|---|---|
| dict[str, Float[Array, '']] | Dictionary holding the observation component, the prior component, and the total. |
Source code in src/vardax/_src/costs.py
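A sketch of the decomposition for logging. The key names "obs", "prior", and "total" are hypothetical (the real keys are not listed here), as is the choice to report the two components unweighted and weight only the total; the cost forms are the same masked-MSE and reconstruction-MSE assumptions used elsewhere in this sketch.

```python
import jax.numpy as jnp


def decomposed_loss_sketch(x, obs, mask, prior_fn, alpha_obs=0.5, alpha_prior=0.5):
    # Unweighted components (an assumption) ...
    j_obs = jnp.sum(mask * (x - obs) ** 2) / jnp.sum(mask)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    # ... and their weighted sum. Key names are hypothetical.
    return {
        "obs": j_obs,
        "prior": j_prior,
        "total": alpha_obs * j_obs + alpha_prior * j_prior,
    }


x = jnp.ones((1, 2, 4))
obs = jnp.zeros((1, 2, 4))
mask = jnp.ones((1, 2, 4))
losses = decomposed_loss_sketch(x, obs, mask, lambda z: 0.5 * z)
# Convert to Python floats before handing them to a logger.
logs = {name: float(value) for name, value in losses.items()}
print(logs)  # {'obs': 1.0, 'prior': 0.25, 'total': 0.625}
```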
Priors
Implementations of the Prior Protocol. IdentityPrior makes the prior term
vanish, giving a pure observation-driven baseline. All other priors are
learned autoencoders that penalise distance from a trained reconstruction
manifold: L63Prior / L96Prior are small MLP autoencoders for the Lorenz-63
and Lorenz-96 states, and the remaining priors come in MLP and convolutional
1D variants plus bilinear 1D, 2D, and 2D-multivariate variants.
IdentityPrior
Bases: Module
Trivial identity prior: \(\varphi(x) = x\).
Zero parameters. Useful as a pure obs-driven baseline (the prior cost vanishes everywhere) and as a sanity-check building block in the linear-Gaussian agreement tests.
Examples:
>>> import jax.numpy as jnp
>>> from vardax import IdentityPrior
>>> prior = IdentityPrior()
>>> x = jnp.arange(6.0).reshape(1, 2, 3)
>>> bool(jnp.all(prior(x) == x))
True
Source code in src/vardax/_src/priors.py
L63Prior
Bases: Module
Learned prior for the Lorenz-63 system.
A simple MLP autoencoder designed for the 3-dimensional Lorenz-63
attractor. The state is treated as a flat vector of length 3.
Attributes:
| Name | Type | Description |
|---|---|---|
| latent_dim | | Dimensionality of the latent code. |
| hidden_dim | | Hidden layer width. |
| state_dim | | Dimensionality of the state vector (the Lorenz-63 state has length 3). |
Source code in src/vardax/_src/priors.py
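L63Prior's exact layer layout is not documented here. As a rough, hypothetical pure-JAX analogue, an MLP autoencoder for a length-3 state maps state to hidden to latent to hidden to state, with the latent width below the state width forming the bottleneck. init_mlp_ae and mlp_ae are illustrative names; the tanh activations, default widths, and initialisation scale are assumptions, not the vardax module.

```python
import jax
import jax.numpy as jnp


def init_mlp_ae(key, state_dim=3, hidden_dim=16, latent_dim=2):
    # Four dense layers: state -> hidden -> latent -> hidden -> state.
    sizes = [state_dim, hidden_dim, latent_dim, hidden_dim, state_dim]
    layer_keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for layer_key, (n_in, n_out) in zip(layer_keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(layer_key, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp_ae(params, x):
    # tanh after the two hidden layers (indices 0 and 2);
    # the latent code (index 1) and the output (index 3) stay linear.
    h = x
    for i, (w, b) in enumerate(params):
        h = h @ w + b
        if i in (0, 2):
            h = jnp.tanh(h)
    return h


params = init_mlp_ae(jax.random.PRNGKey(0))
x = jnp.ones((5, 3))  # five Lorenz-63 states
y = mlp_ae(params, x)
prior_term = jnp.mean((x - y) ** 2)  # reconstruction MSE
```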
L96Prior
Bases: Module
Learned prior for the Lorenz-96 system.
A simple MLP autoencoder designed for the N-dimensional Lorenz-96
attractor. The state is treated as a flat vector of length N.
Attributes:
| Name | Type | Description |
|---|---|---|
| latent_dim | | Dimensionality of the latent code. |
| hidden_dim | | Hidden layer width. |
| state_dim | | Dimensionality of the state vector. |
Source code in src/vardax/_src/priors.py
MLPAEPrior1D
Bases: Module
MLP autoencoder prior for 1-D data.
Attributes:
| Name | Type | Description |
|---|---|---|
| state_dim | | Spatial size of the input. |
| latent_dim | | Dimensionality of the latent code. |
| hidden_dim | | Hidden layer width. |
| n_time | int | Number of time steps. |
Source code in src/vardax/_src/priors.py
ConvAEPrior1D
Bases: Module
Convolutional autoencoder prior for 1-D spatially-structured data.
Uses circular (periodic) padding suitable for systems with periodic
boundary conditions such as Lorenz-96. Operates on inputs of shape
(B, T, N) where N is the spatial dimension.
Attributes:
| Name | Type | Description |
|---|---|---|
| latent_channels | | Number of channels in the latent representation. |
| kernel_size | int | Convolution kernel size (must be a positive odd integer). |
| n_time | int | Number of time steps. |
Source code in src/vardax/_src/priors.py
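Circular padding can be sketched for a single channel with jnp.pad(..., mode="wrap") followed by a "valid" correlation. This is an assumption about how ConvAEPrior1D pads, not its code, and circular_conv1d is a hypothetical helper. An odd kernel size lets symmetric padding of (k - 1) // 2 per side keep the output length equal to N, and wrap padding makes the layer commute with circular shifts, matching the ring topology of Lorenz-96.

```python
import jax.numpy as jnp


def circular_conv1d(x, kernel):
    # Odd kernel: symmetric padding of (k - 1) // 2 per side preserves length N.
    k = kernel.shape[0]
    if k % 2 != 1:
        raise ValueError("kernel_size must be a positive odd integer")
    p = (k - 1) // 2
    # mode="wrap" copies values from the opposite edge: periodic boundaries.
    padded = jnp.pad(x, (p, p), mode="wrap")
    # A "valid" correlation over N + 2p samples returns exactly N outputs.
    return jnp.correlate(padded, kernel, mode="valid")


x = jnp.arange(8.0)
kernel = jnp.array([1.0, 2.0, 3.0])
y = circular_conv1d(x, kernel)
# Shift equivariance: convolving a rolled signal equals rolling the output.
y_of_shifted = circular_conv1d(jnp.roll(x, 1), kernel)
```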
BilinAEPrior1D
Bases: Module
Bilinear autoencoder prior for 1-D data.
The encoder maps the input to a low-dimensional latent code; the decoder
maps it back to the original space. The prior cost is
\(\|x - \mathrm{decode}(\mathrm{encode}(x))\|^2\).
Attributes:
| Name | Type | Description |
|---|---|---|
| state_dim | int | Spatial size of the input. |
| latent_dim | int | Dimensionality of the latent code. |
| n_time | int | Number of time steps. |
Examples:
>>> import jax, jax.numpy as jnp
>>> from vardax import BilinAEPrior1D
>>> prior = BilinAEPrior1D(
... state_dim=4, latent_dim=2, n_time=3, key=jax.random.PRNGKey(0)
... )
>>> prior(jnp.ones((2, 3, 4))).shape
(2, 3, 4)
Source code in src/vardax/_src/priors.py
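The docstring does not spell out what makes the autoencoder "bilinear". Purely as an assumption, the sketch below reads it as a linear encoder followed by a decoder that adds a linear term to the elementwise product of two linear terms, acting on the last (spatial) axis only. init_bilinear_ae, bilinear_ae, and bilinear_prior_cost are hypothetical names, and the real module may mix time steps or use convolutions. It reproduces the (2, 3, 4) output shape from the doctest.

```python
import jax
import jax.numpy as jnp


def init_bilinear_ae(key, state_dim, latent_dim):
    k_enc, k_a, k_b, k_c = jax.random.split(key, 4)
    scale = 1.0 / jnp.sqrt(latent_dim)
    return {
        "enc": jax.random.normal(k_enc, (state_dim, latent_dim)) / jnp.sqrt(state_dim),
        # Three linear read-outs of the latent code.
        "a": jax.random.normal(k_a, (latent_dim, state_dim)) * scale,
        "b": jax.random.normal(k_b, (latent_dim, state_dim)) * scale,
        "c": jax.random.normal(k_c, (latent_dim, state_dim)) * scale,
    }


def bilinear_ae(params, x):
    z = x @ params["enc"]  # encode: (..., N) -> (..., latent_dim)
    # Decode: a linear term plus the elementwise product of two linear terms.
    return z @ params["a"] + (z @ params["b"]) * (z @ params["c"])


def bilinear_prior_cost(params, x):
    # ||x - decode(encode(x))||^2, summed over every element.
    return jnp.sum((x - bilinear_ae(params, x)) ** 2)


params = init_bilinear_ae(jax.random.PRNGKey(0), state_dim=4, latent_dim=2)
x = jnp.ones((2, 3, 4))  # (B, T, N)
out = bilinear_ae(params, x)
```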
BilinAEPrior2D
Bases: Module
Bilinear autoencoder prior for 2-D data.
Attributes:
| Name | Type | Description |
|---|---|---|
| latent_dim | | Dimensionality of the latent code. |
| n_time | int | Number of time steps. |
| height | int | Spatial height. |
| width | int | Spatial width. |
Source code in src/vardax/_src/priors.py
BilinAEPrior2DMultivar
Bases: Module
Bilinear autoencoder prior for 2-D multivariate data.
Attributes:
| Name | Type | Description |
|---|---|---|
| latent_dim | | Dimensionality of the latent code. |
| n_time | int | Number of time steps. |
| n_channels | int | Number of channels. |
| height | int | Spatial height. |
| width | int | Spatial width. |
Source code in src/vardax/_src/priors.py
4DVarNet inner-loop solvers
The unrolled (and fixed-point) inner loop of 4DVarNet, exposed as pure
functions over an explicit SolverState: initialise with
init_solver_state_*, advance one modulated-gradient step with
solver_step_* (or fp_solver_step_1d for the fixed-point formulation),
or run the whole loop with solve_4dvarnet_*. The one_step_* variants
pair with OneStepAdjoint for memory-frugal training.
SolverState1D
Bases: Module
Solver state for 1-D problems, carried explicitly from one iteration to the next.
Attributes:
| Name | Type | Description |
|---|---|---|
| x | Float[Array, 'B T N'] | Current state estimate of shape (B, T, N). |
| lstm | LSTMState1D | Current LSTM hidden/cell state for the gradient modulator. |
| step | int | Current iteration index. |
Source code in src/vardax/_src/solver.py
SolverState2D
Bases: Module
Solver state for 2-D problems, carried explicitly from one iteration to the next.
Attributes:
| Name | Type | Description |
|---|---|---|
| x | Float[Array, 'B T H W'] | Current state estimate of shape (B, T, H, W). |
| lstm | LSTMState2D | Current LSTM hidden/cell state for the gradient modulator. |
| step | int | Current iteration index. |
Source code in src/vardax/_src/solver.py
init_solver_state_1d
init_solver_state_1d(
batch: Batch1D, hidden_dim: int
) -> SolverState1D
Initialise a 1-D solver state from a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch1D | Input batch. The initial state is set to the masked input (zeros where unobserved). | required |
| hidden_dim | int | Hidden dimension of the ConvLSTM gradient modulator. | required |
Returns:
| Type | Description |
|---|---|
| SolverState1D | Zero-initialised SolverState1D, with x set to the masked input. |
Source code in src/vardax/_src/solver.py
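A hypothetical NamedTuple analogue illustrates the initialisation contract: x starts as the masked input, while the modulator memory and step counter start at zero. The (B, hidden_dim, T, N) LSTM layout and all names here are assumptions, not the vardax classes. Updating with _replace returns a new state and leaves the old one untouched, which is how a pure solver loop threads state between iterations.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LSTMStateSketch(NamedTuple):
    hidden: jax.Array
    cell: jax.Array


class SolverStateSketch(NamedTuple):
    x: jax.Array
    lstm: LSTMStateSketch
    step: int


def init_solver_state_sketch(obs, mask, hidden_dim):
    b, t, n = obs.shape
    # Observed values where mask == 1, zeros elsewhere.
    x0 = mask * obs
    # Hypothetical layout: hidden_dim channels over the (T, N) grid.
    zeros = jnp.zeros((b, hidden_dim, t, n))
    return SolverStateSketch(x=x0, lstm=LSTMStateSketch(zeros, zeros), step=0)


obs = jnp.array([[[1.0, 2.0, 3.0]]])
mask = jnp.array([[[1.0, 0.0, 1.0]]])
state0 = init_solver_state_sketch(obs, mask, hidden_dim=4)
# A functional update returns a new state; state0 itself is unchanged.
state1 = state0._replace(step=state0.step + 1)
```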
init_solver_state_2d
init_solver_state_2d(
batch: Batch2D, hidden_dim: int
) -> SolverState2D
Initialise a 2-D solver state from a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch2D | Input batch. The initial state is set to the masked input. | required |
| hidden_dim | int | Hidden dimension of the ConvLSTM gradient modulator. | required |
Returns:
| Type | Description |
|---|---|
| SolverState2D | Zero-initialised SolverState2D, with x set to the masked input. |
Source code in src/vardax/_src/solver.py
solver_step_1d
solver_step_1d(
solver_state: SolverState1D,
batch: Batch1D,
prior_fn: Any,
grad_mod_fn: Any,
alpha: float = 1.0,
prior_weight: float = 1.0,
) -> SolverState1D
Perform a single 1-D solver iteration.
Computes the gradient of the variational cost, then passes it through the learned gradient modulator to obtain a state update.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver_state | SolverState1D | Current solver state. | required |
| batch | Batch1D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| grad_mod_fn | Any | Learned gradient-modulator callable. | required |
| alpha | float | Step-size scaling factor. | 1.0 |
| prior_weight | float | Weighting factor \(\lambda\) for the prior cost term. | 1.0 |
Returns:
| Type | Description |
|---|---|
| SolverState1D | Updated SolverState1D. |
Source code in src/vardax/_src/solver.py
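One iteration can be sketched as: differentiate the cost, let the modulator map (gradient, memory) to (update, new memory), then step. This is a sketch under stated assumptions, not the vardax implementation: the cost is taken to be the masked observation MSE plus prior_weight times the reconstruction MSE, the update rule is x - alpha * update, and grad_mod_fn is assumed to have the signature (grad, memory) -> (update, memory). All names are hypothetical. With a pass-through modulator the step reduces to plain gradient descent.

```python
import jax
import jax.numpy as jnp


def cost(x, obs, mask, prior_fn, prior_weight):
    # Assumed form: masked observation MSE + prior_weight * reconstruction MSE.
    j_obs = jnp.sum(mask * (x - obs) ** 2) / jnp.sum(mask)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return j_obs + prior_weight * j_prior


def solver_step_sketch(x, memory, obs, mask, prior_fn, grad_mod_fn,
                       alpha=1.0, prior_weight=1.0):
    # 1. Gradient of the variational cost at the current iterate.
    g = jax.grad(cost)(x, obs, mask, prior_fn, prior_weight)
    # 2. The modulator maps (gradient, memory) to (update, new memory).
    update, memory = grad_mod_fn(g, memory)
    # 3. Move the state along the modulated direction.
    return x - alpha * update, memory


obs = jnp.array([[[1.0, -1.0, 2.0, 0.5]]])
mask = jnp.array([[[1.0, 1.0, 0.0, 1.0]]])
shrink = lambda z: 0.5 * z          # toy prior
pass_through = lambda g, m: (g, m)  # toy modulator: plain gradient descent
x0 = jnp.zeros_like(obs)
x1, mem = solver_step_sketch(x0, None, obs, mask, shrink, pass_through, alpha=0.5)

before = cost(x0, obs, mask, shrink, 1.0)
after = cost(x1, obs, mask, shrink, 1.0)
```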
solver_step_2d
solver_step_2d(
solver_state: SolverState2D,
batch: Batch2D,
prior_fn: Any,
grad_mod_fn: Any,
alpha: float = 1.0,
prior_weight: float = 1.0,
) -> SolverState2D
Perform a single 2-D solver iteration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver_state | SolverState2D | Current solver state. | required |
| batch | Batch2D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| grad_mod_fn | Any | Learned gradient-modulator callable. | required |
| alpha | float | Step-size scaling factor. | 1.0 |
| prior_weight | float | Weighting factor \(\lambda\) for the prior cost term. | 1.0 |
Returns:
| Type | Description |
|---|---|
| SolverState2D | Updated SolverState2D. |
Source code in src/vardax/_src/solver.py
fp_solver_step_1d
fp_solver_step_1d(
x: Float[Array, "B T N"], batch: Batch1D, prior_fn: Any
) -> Float[Array, "B T N"]
Perform a single 1-D fixed-point projection step.
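Applies the prior projection, then re-inserts the observations at observed locations:
\[x \leftarrow m \odot y + (1 - m) \odot \varphi(x),\]
where \(m\) is the binary mask, \(y\) the observations, and \(\varphi\) the prior.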
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, 'B T N'] | Current state estimate of shape (B, T, N). | required |
| batch | Batch1D | Observed data batch containing the observations and observation mask. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'B T N'] | Updated state estimate of shape (B, T, N). |
Source code in src/vardax/_src/solver.py
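The projection-then-reinsert update fits in one line; fp_step_sketch is a hypothetical name and the shrinking prior_fn is only an illustration. Observed entries come back exactly equal to the data, and unobserved ones take the prior's value.

```python
import jax.numpy as jnp


def fp_step_sketch(x, obs, mask, prior_fn):
    # Project through the prior, then overwrite observed entries with the data.
    return mask * obs + (1.0 - mask) * prior_fn(x)


x = jnp.ones((1, 1, 4))
obs = jnp.full((1, 1, 4), 5.0)
mask = jnp.array([[[1.0, 0.0, 1.0, 0.0]]])
out = fp_step_sketch(x, obs, mask, lambda z: 0.5 * z)
print(out)  # [[[5.  0.5 5.  0.5]]]
```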
solve_4dvarnet_1d
solve_4dvarnet_1d(
batch: Batch1D,
prior_fn: Any,
grad_mod_fn: Any,
n_steps: int,
hidden_dim: int,
alpha: float = 1.0,
) -> Float[Array, "B T N"]
Run the full 1-D 4DVarNet solver for n_steps iterations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch1D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| grad_mod_fn | Any | Learned gradient-modulator callable. | required |
| n_steps | int | Number of gradient-descent steps to unroll. | required |
| hidden_dim | int | Hidden dimension of the ConvLSTM gradient modulator. | required |
| alpha | float | Step-size scaling factor. | 1.0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'B T N'] | Final state estimate of shape (B, T, N). |
Source code in src/vardax/_src/solver.py
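The full loop can be sketched by scanning a single-step update n_steps times with jax.lax.scan, carrying the (state, modulator memory) pair. As a stated assumption, the cost is the masked observation MSE plus the reconstruction MSE, the update is x - alpha * update, and grad_mod_fn maps (grad, memory) to (update, memory). solve_sketch is a hypothetical name, and the pass-through modulator with a scalar memory stands in for the ConvLSTM.

```python
import jax
import jax.numpy as jnp


def cost(x, obs, mask, prior_fn):
    j_obs = jnp.sum(mask * (x - obs) ** 2) / jnp.sum(mask)
    j_prior = jnp.mean((x - prior_fn(x)) ** 2)
    return j_obs + j_prior


def solve_sketch(obs, mask, prior_fn, grad_mod_fn, memory0, n_steps, alpha=1.0):
    x0 = mask * obs  # start from the masked observations

    def body(carry, _):
        x, memory = carry
        g = jax.grad(cost)(x, obs, mask, prior_fn)
        update, memory = grad_mod_fn(g, memory)
        return (x - alpha * update, memory), None

    # scan runs body n_steps times inside one traced loop; xs=None because
    # the iterations consume no per-step inputs.
    (x_final, _), _ = jax.lax.scan(body, (x0, memory0), None, length=n_steps)
    return x_final


obs = jnp.arange(1.0, 9.0).reshape(1, 2, 4)
mask = jnp.array([[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]])
shrink = lambda z: 0.5 * z
pass_through = lambda g, m: (g, m)
x_final = solve_sketch(obs, mask, shrink, pass_through, jnp.zeros(()), n_steps=10)
```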
solve_4dvarnet_2d
solve_4dvarnet_2d(
batch: Batch2D,
prior_fn: Any,
grad_mod_fn: Any,
n_steps: int,
hidden_dim: int,
alpha: float = 1.0,
) -> Float[Array, "B T H W"]
Run the full 2-D 4DVarNet solver for n_steps iterations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch2D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| grad_mod_fn | Any | Learned gradient-modulator callable. | required |
| n_steps | int | Number of gradient-descent steps to unroll. | required |
| hidden_dim | int | Hidden dimension of the ConvLSTM gradient modulator. | required |
| alpha | float | Step-size scaling factor. | 1.0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'B T H W'] | Final state estimate of shape (B, T, H, W). |
Source code in src/vardax/_src/solver.py
solve_4dvarnet_1d_fixedpoint
solve_4dvarnet_1d_fixedpoint(
batch: Batch1D, prior_fn: Any, n_fp_steps: int
) -> Float[Array, "B T N"]
Run n_fp_steps fixed-point projection steps using jax.lax.scan.
Initialises the state from the masked observations, then iterates the
fixed-point update fp_solver_step_1d for
n_fp_steps steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch1D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| n_fp_steps | int | Number of fixed-point iterations. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'B T N'] | Final state estimate of shape (B, T, N). |
Source code in src/vardax/_src/solver.py
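The fixed-point solver can be sketched as a jax.lax.scan over the projection-then-reinsert step. solve_fixedpoint_sketch is a hypothetical name, and the neighbour-averaging prior_fn with periodic wrap is only an illustration. With observations at positions 0 and 2, each unobserved entry settles at the average of its two neighbours.

```python
import jax
import jax.numpy as jnp


def solve_fixedpoint_sketch(obs, mask, prior_fn, n_fp_steps):
    def body(x, _):
        # One projection-then-reinsert step; nothing is stacked per step.
        return mask * obs + (1.0 - mask) * prior_fn(x), None

    x0 = mask * obs  # masked observations, zeros where unobserved
    x_final, _ = jax.lax.scan(body, x0, None, length=n_fp_steps)
    return x_final


def neighbour_average(z):
    # Toy prior: mean of the left and right neighbours along N (periodic).
    return 0.5 * (jnp.roll(z, 1, axis=-1) + jnp.roll(z, -1, axis=-1))


obs = jnp.array([[[1.0, 0.0, 3.0, 0.0]]])
mask = jnp.array([[[1.0, 0.0, 1.0, 0.0]]])
x_fp = solve_fixedpoint_sketch(obs, mask, neighbour_average, n_fp_steps=20)
print(x_fp)  # [[[1. 2. 3. 2.]]]
```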
one_step_solve_4dvarnet_1d
one_step_solve_4dvarnet_1d(
batch: Batch1D,
prior_fn: Any,
grad_mod_fn: Any,
n_steps: int,
hidden_dim: int,
alpha: float = 1.0,
prior_weight: float = 1.0,
k: int = 1,
) -> Float[Array, "B T N"]
Solve 4DVarNet-1D using k-step differentiation (Bolte et al., 2023).
Runs n_steps - k solver iterations with jax.lax.stop_gradient
applied to the iterate, then performs k final steps through which
gradients flow. Backpropagation memory therefore grows with k rather
than with n_steps (k=1 gives the one-step estimator, an approximation
to implicit differentiation), while the code stays as simple as unrolled backprop.
Reference
Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of iterative algorithms." https://arxiv.org/abs/2305.13768
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch1D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| grad_mod_fn | Any | Learned gradient-modulator callable. | required |
| n_steps | int | Total number of solver iterations (warmup = n_steps - k, then k differentiable steps). | required |
| hidden_dim | int | Hidden dimension of the ConvLSTM gradient modulator. | required |
| alpha | float | Step-size scaling factor. | 1.0 |
| prior_weight | float | Weighting factor \(\lambda\) for the prior cost term. | 1.0 |
| k | int | Number of trailing differentiable steps (clipped to n_steps). | 1 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'B T N'] | Final state estimate of shape (B, T, N). |
Source code in src/vardax/_src/solver.py
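The k-step scheme can be sketched on a toy problem in which each solver step is plain gradient descent with a learnable prior weight theta. All names, the toy cost, and the 0.1 step size are assumptions, not vardax code. Because stop_gradient only cuts the backward pass, the forward solution is the same for every k; only the gradient with respect to theta changes, and k = n_steps recovers full unrolled backpropagation.

```python
import jax
import jax.numpy as jnp


def step(x, theta, obs, mask):
    # One gradient-descent step (step size 0.1) on the toy cost
    # sum(mask * (x - obs)^2) + theta * sum(x^2).
    toy_cost = lambda z: jnp.sum(mask * (z - obs) ** 2) + theta * jnp.sum(z ** 2)
    return x - 0.1 * jax.grad(toy_cost)(x)


def k_step_solve(theta, obs, mask, n_steps, k=1):
    k = min(k, n_steps)  # clip k to n_steps
    x = mask * obs
    for _ in range(n_steps - k):  # warm-up iterations
        x = step(x, theta, obs, mask)
    # Detach the warm-up iterate: no gradient flows back through it.
    x = jax.lax.stop_gradient(x)
    for _ in range(k):  # trailing differentiable iterations
        x = step(x, theta, obs, mask)
    return x


def outer_loss(theta, obs, mask, n_steps, k):
    # Toy training loss: squared norm of the solution (a target of zeros).
    return jnp.sum(k_step_solve(theta, obs, mask, n_steps, k) ** 2)


obs = jnp.array([[[1.0, 2.0]]])
mask = jnp.ones_like(obs)
n = 5
x_k1 = k_step_solve(0.5, obs, mask, n, k=1)
x_full = k_step_solve(0.5, obs, mask, n, k=n)
g_k1 = jax.grad(outer_loss)(0.5, obs, mask, n, 1)
g_full = jax.grad(outer_loss)(0.5, obs, mask, n, n)
```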
one_step_solve_4dvarnet_2d
one_step_solve_4dvarnet_2d(
batch: Batch2D,
prior_fn: Any,
grad_mod_fn: Any,
n_steps: int,
hidden_dim: int,
alpha: float = 1.0,
prior_weight: float = 1.0,
k: int = 1,
) -> Float[Array, "B T H W"]
Solve 4DVarNet-2D using k-step differentiation (Bolte et al., 2023).
Runs n_steps - k solver iterations with jax.lax.stop_gradient
applied to the iterate, then performs k final steps through which
gradients flow. Backpropagation memory therefore grows with k rather
than with n_steps (k=1 gives the one-step estimator, an approximation
to implicit differentiation), while the code stays as simple as unrolled backprop.
Reference
Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of iterative algorithms." https://arxiv.org/abs/2305.13768
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch2D | Observed data batch. | required |
| prior_fn | Any | Prior callable mapping a state to its reconstruction. | required |
| grad_mod_fn | Any | Learned gradient-modulator callable. | required |
| n_steps | int | Total number of solver iterations (warmup = n_steps - k, then k differentiable steps). | required |
| hidden_dim | int | Hidden dimension of the ConvLSTM gradient modulator. | required |
| alpha | float | Step-size scaling factor. | 1.0 |
| prior_weight | float | Weighting factor \(\lambda\) for the prior cost term. | 1.0 |
| k | int | Number of trailing differentiable steps (clipped to n_steps). | 1 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'B T H W'] | Final state estimate of shape (B, T, H, W). |