Layer 0 — Primitive Examples¶
This page gives examples of the Layer 0 primitives: pure-JAX cost terms, closed-form BLUE, the control-variable transform, posterior primitives, and adjoint composition.
Closed-form BLUE (OptimalInterpolation)¶
import gaussx as gx
import lineax as lx
from vardax.costs import blue_analysis
# Structured B and R via gaussx / lineax
B_op = gx.MaternLinearOperator(coords, length_scale=10.0, nu=1.5, sigma=1.0)
R_op = lx.DiagonalLinearOperator(obs_variances)
H_op = obs_op.linearize(x_b) # AbstractLinearOperator (since obs_op is linear)
x_star, P_star_op = blue_analysis(x_b, y, B_op, R_op, H_op)
# P_star_op is an AbstractLinearOperator — mat-vec via lineax.CG, no full materialisation
The implementation chooses between the \(B\)-space and \(R\)-space forms, which are equivalent by the Sherman-Morrison-Woodbury identity, based on the state and observation dimensions.
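The equivalence of the two forms can be checked on a small dense problem. The sketch below uses plain `jax.numpy` on explicit matrices and is not the vardax implementation: the function names are hypothetical, and the dispatch rule (use the observation-space form when there are no more observations than state variables) is an assumption about what "based on dimensions" might mean.

```python
import jax.numpy as jnp


def blue_state_space(x_b, y, B, R, H):
    """Information form: inverts an n x n matrix (n = state dimension)."""
    R_inv = jnp.linalg.inv(R)
    P_star = jnp.linalg.inv(jnp.linalg.inv(B) + H.T @ R_inv @ H)
    x_star = x_b + P_star @ H.T @ R_inv @ (y - H @ x_b)
    return x_star, P_star


def blue_obs_space(x_b, y, B, R, H):
    """Gain form: solves an m x m system (m = number of observations)."""
    S = H @ B @ H.T + R                   # innovation covariance
    K = jnp.linalg.solve(S, H @ B).T      # K = B H^T S^{-1} (S and B symmetric)
    x_star = x_b + K @ (y - H @ x_b)
    P_star = B - K @ H @ B
    return x_star, P_star


def blue_dense(x_b, y, B, R, H):
    """Hypothetical dispatcher: solve whichever linear system is smaller."""
    m, n = H.shape
    form = blue_obs_space if m <= n else blue_state_space
    return form(x_b, y, B, R, H)


n = 4
x_b = jnp.array([1.0, 2.0, 3.0, 4.0])
B = 0.5 * jnp.eye(n) + 0.1 * jnp.ones((n, n))   # SPD background covariance
H = jnp.array([[1.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0]])           # observe components 0 and 2
R = 0.2 * jnp.eye(2)
y = jnp.array([1.5, 2.5])

x_state, P_state = blue_state_space(x_b, y, B, R, H)
x_obs, P_obs = blue_obs_space(x_b, y, B, R, H)
x_auto, P_auto = blue_dense(x_b, y, B, R, H)
```

Both forms return the same analysis and covariance up to floating-point error; only the size of the linear system differs.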
Variational cost¶
import jax
from vardax.costs import obs_cost, prior_cost, variational_cost
# Observation cost (classical with R^{-1})
j_obs = obs_cost(x, batch.input, batch.mask, obs_operator=obs_op, R_inv_op=R_inv)
# Prior cost — classical with B^{-1}
j_prior_classical = prior_cost(x, prior_mean=x_b, B_inv_op=B_inv)
# Prior cost — learned (FourDVarNet)
j_prior_learned = prior_cost(x, prior_fn=ae_model)
# Total (FourDVarNet form)
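j_total = variational_cost(
    x, batch, prior_fn=ae_model, obs_operator=obs_op,
    alpha_obs=1.0, alpha_prior=0.1,
)
# Gradient with respect to the state x (first argument); the remaining
# arguments are passed by keyword, matching the call above
grad_J = jax.grad(variational_cost)(
    x, batch, prior_fn=ae_model, obs_operator=obs_op,
    alpha_obs=1.0, alpha_prior=0.1,
)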
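To make the structure of the FourDVarNet-form cost concrete, here is a minimal pure-JAX sketch. It assumes the cost is a weighted sum of a masked squared observation misfit and a squared distance between the state and its image under the learned prior; the actual vardax normalisation and argument layout may differ, and `sketch_variational_cost` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sketch_variational_cost(x, y, mask, prior_fn, obs_fn, alpha_obs, alpha_prior):
    # Observation misfit, counted only where mask == 1
    j_obs = jnp.sum(mask * (obs_fn(x) - y) ** 2)
    # Learned-prior misfit: distance between x and its projection prior_fn(x)
    j_prior = jnp.sum((x - prior_fn(x)) ** 2)
    return alpha_obs * j_obs + alpha_prior * j_prior


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.zeros(3)
mask = jnp.array([1.0, 0.0, 1.0])   # middle component is unobserved
prior_fn = lambda s: 0.5 * s        # stand-in for a trained autoencoder
obs_fn = lambda s: s                # identity observation operator

# Differentiate with respect to the state only, closing over everything else
grad_x = jax.grad(
    lambda s: sketch_variational_cost(s, y, mask, prior_fn, obs_fn, 1.0, 0.1)
)(x)
# grad_x = 2 * mask * (x - y) + 0.05 * x = [2.05, 0.1, 6.15]
```

The unobserved component receives gradient from the prior term only, which is how the learned prior fills gaps in the observations.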
Incremental 4DVar inner loop¶
import jax
import jax.numpy as jnp
import lineax as lx
from vardax.solver import gauss_newton_inner, incremental_outer
# At outer iterate x_b, linearise via jax.linearize, which returns
# (primal_out, jvp_fn); only the linear map jvp_fn is kept here
_, forward_lin = jax.linearize(lambda s: forward_model.step(s, dt), x_b)
obs_op_lin = obs_op.linearize(x_b)
# Inner CG solve
dx_star = gauss_newton_inner(
    dx0=jnp.zeros_like(x_b),
    x_b=x_b, batch=batch,
    forward_lin=forward_lin, obs_op_lin=obs_op_lin,
    B_inv_op=B_inv, R_inv_op=R_inv,
    n_inner=20, cg_atol=1e-5, cg_rtol=1e-5,
)
# Outer update
x_b = x_b + dx_star
# Or use the high-level helper
x_star = incremental_outer(x0, batch, forward, obs_op, B_op, R_op, config)
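The inner/outer structure can be sketched in pure JAX with dense \(B^{-1}\) and \(R^{-1}\). This is an illustration under stated assumptions, not the vardax solver: `obs_fn` stands for the composition of model propagation and observation operator, the function names are hypothetical, and `jax.scipy.sparse.linalg.cg` replaces whatever linear solver vardax uses internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def sketch_gauss_newton_inner(x_k, x_b, y, obs_fn, B_inv, R_inv, n_inner=20, tol=1e-5):
    # Tangent-linear map (JVP) and adjoint (VJP) of obs_fn at the current iterate
    hx, H_tl = jax.linearize(obs_fn, x_k)
    _, H_ad = jax.vjp(obs_fn, x_k)

    def hessian_matvec(dx):
        # Gauss-Newton Hessian B^{-1} + H^T R^{-1} H, applied matrix-free
        return B_inv @ dx + H_ad(R_inv @ H_tl(dx))[0]

    # Right-hand side: negative gradient of the full cost at x_k
    rhs = H_ad(R_inv @ (y - hx))[0] - B_inv @ (x_k - x_b)
    dx, _ = cg(hessian_matvec, rhs, x0=jnp.zeros_like(x_k), tol=tol, maxiter=n_inner)
    return dx


def sketch_incremental_outer(x_b, y, obs_fn, B_inv, R_inv, n_outer=3):
    x_k = x_b
    for _ in range(n_outer):
        # Relinearise around the updated iterate, then solve the inner problem
        x_k = x_k + sketch_gauss_newton_inner(x_k, x_b, y, obs_fn, B_inv, R_inv)
    return x_k


n = 4
x_b = jnp.array([1.0, 2.0, 3.0, 4.0])
B = 0.5 * jnp.eye(n) + 0.1 * jnp.ones((n, n))
H = jnp.array([[1.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0]])
R = 0.2 * jnp.eye(2)
y = jnp.array([1.5, 2.5])

# With a linear obs_fn, a single outer iteration already gives the BLUE analysis
x_gn = sketch_incremental_outer(
    x_b, y, lambda s: H @ s, jnp.linalg.inv(B), jnp.linalg.inv(R), n_outer=1
)
x_blue = x_b + B @ H.T @ jnp.linalg.solve(H @ B @ H.T + R, y - H @ x_b)
```

For a nonlinear `obs_fn`, further outer iterations relinearise around the updated state, which is the point of the incremental formulation.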
Control-variable transform¶
import gaussx as gx
from vardax.cvt import cvt_transform, cvt_inverse
B_half = gx.MaternLinearOperator(
    grid_coords=coords, length_scale=10.0, nu=1.5, sigma=1.0,
).half()
chi = cvt_transform(x, x_b, B_half) # χ = B^{-1/2}(x - x_b)
x = cvt_inverse(chi, x_b, B_half) # x = x_b + B^{1/2} χ
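In CVT coordinates the prior term \((x - x_b)^\top B^{-1}(x - x_b)\) becomes \(\|\chi\|^2\), so its Hessian is a multiple of the identity; this acts as a preconditioner for the inner CG solve.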
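The identity behind the transform can be verified numerically on a dense covariance. The sketch below assumes a Cholesky factor \(L\) with \(B = L L^\top\) as the square root, which is one valid choice but not necessarily what `.half()` returns; the function names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def sketch_cvt_transform(x, x_b, L):
    # chi = L^{-1} (x - x_b), computed by a triangular solve
    return solve_triangular(L, x - x_b, lower=True)


def sketch_cvt_inverse(chi, x_b, L):
    # x = x_b + L chi
    return x_b + L @ chi


n = 4
B = 0.5 * jnp.eye(n) + 0.1 * jnp.ones((n, n))   # SPD background covariance
L = jnp.linalg.cholesky(B)                      # lower-triangular factor
x_b = jnp.array([1.0, 2.0, 3.0, 4.0])
x = jnp.array([1.5, 1.0, 3.5, 4.2])

chi = sketch_cvt_transform(x, x_b, L)
x_back = sketch_cvt_inverse(chi, x_b, L)

# The B^{-1}-weighted prior term in x equals the plain squared norm of chi
prior_in_x = (x - x_b) @ jnp.linalg.solve(B, x - x_b)
prior_in_chi = chi @ chi
```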
Adjoint composition (Decision D15)¶
Gradients through ODE integration come from diffrax.AbstractAdjoint;
gradients through inner minimisation come from
optimistix.AbstractAdjoint. Vardax exposes both as constructor slots.
import diffrax as dfx
import optimistix as optx
from vardax.models import StrongFourDVar
# Memory-efficient operational configuration
model = StrongFourDVar(
    forward=somax_model,
    obs_op=obs_op,
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    minimiser=optx.GaussNewton(rtol=1e-5, atol=1e-5),
    minimiser_adjoint=optx.ImplicitAdjoint(),
    forward_adjoint=dfx.BacksolveAdjoint(),  # continuous adjoint, O(1) memory
)
# Standard configuration
model = StrongFourDVar(
    ...,
    minimiser_adjoint=optx.RecursiveCheckpointAdjoint(),
    forward_adjoint=dfx.RecursiveCheckpointAdjoint(),
)
# Forward sensitivity (good when params << state)
forward_adjoint = dfx.ForwardMode()
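The choice of forward adjoint changes how the derivative is computed, not its value. As a rough analogy in plain JAX (this is not diffrax; `euler_step` and `integrate` are illustrative names), the sketch below differentiates a fixed-step Euler integration of \(\dot{x} = -kx\) in three ways: ordinary reverse mode, reverse mode with `jax.checkpoint` recomputing each step instead of storing it, and forward mode via `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp


def euler_step(x, dt, k):
    # One explicit Euler step of dx/dt = -k x
    return x + dt * (-k * x)


def integrate(k, x0, dt=0.01, n_steps=100, remat=False):
    # With remat=True, each step's intermediates are recomputed on the
    # backward pass rather than stored
    step_fn = jax.checkpoint(euler_step) if remat else euler_step

    def body(x, _):
        return step_fn(x, dt, k), None

    x_final, _ = jax.lax.scan(body, x0, None, length=n_steps)
    return x_final


k = jnp.asarray(1.0)
x0 = jnp.asarray(1.0)

g_reverse = jax.grad(integrate)(k, x0)
g_remat = jax.grad(lambda kk: integrate(kk, x0, remat=True))(k)
g_forward = jax.jacfwd(integrate)(k, x0)

# Closed form for this scheme: d/dk [x0 (1 - k dt)^N] = -N dt x0 (1 - k dt)^(N-1)
g_exact = -(0.99 ** 99)
```

All three agree with the closed-form value. Forward mode costs one pass per differentiated input, which is why it suits the case of few parameters.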
For FourDVarNet, the inner-solver adjoint controls how training
gradients flow through the unrolled learned iteration:
from vardax.models import FourDVarNet
from vardax.adjoints import OneStepAdjoint # Bolte et al. 2023
# Three adjoint options for FourDVarNet's inner solver:
solver_adjoint = optx.RecursiveCheckpointAdjoint() # O(K) memory, standard backprop
solver_adjoint = OneStepAdjoint() # O(1) memory, last step only
solver_adjoint = optx.ImplicitAdjoint() # O(1) memory, exact at fixed point
model = FourDVarNet(
    prior=ae_prior, obs_op=obs_op, grad_mod=conv_lstm,
    config=SolverConfig(n_steps=15),
    solver_adjoint=solver_adjoint,
)
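The difference between the three options shows up on a toy fixed-point iteration. The sketch below is plain JAX and does not use the vardax or optimistix adjoint classes; it assumes a scalar iteration \(x_{k+1} = 0.1\,x_k + \theta\), whose fixed point is \(x^\star = \theta / 0.9\), and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def step(x, theta):
    # Contraction with rate 0.1; fixed point is theta / 0.9
    return 0.1 * x + theta


def solve_unrolled(theta, n_steps=15):
    # Backprop through every iteration: memory grows with n_steps
    x = jnp.zeros_like(theta)
    for _ in range(n_steps):
        x = step(x, theta)
    return x


def solve_one_step(theta, n_steps=15):
    # Treat the first n_steps - 1 iterations as constant,
    # differentiate only through the last step
    x = jax.lax.stop_gradient(solve_unrolled(theta, n_steps - 1))
    return step(x, theta)


def implicit_grad(theta, n_steps=15):
    # Implicit function theorem at the fixed point:
    # dx*/dtheta = (1 - df/dx)^{-1} df/dtheta
    x_star = solve_unrolled(theta, n_steps)
    df_dx = jax.grad(step, argnums=0)(x_star, theta)
    df_dtheta = jax.grad(step, argnums=1)(x_star, theta)
    return df_dtheta / (1.0 - df_dx)


theta = jnp.asarray(2.0)
g_unrolled = jax.grad(solve_unrolled)(theta)   # approx 1.1111
g_one_step = jax.grad(solve_one_step)(theta)   # 1.0
g_implicit = implicit_grad(theta)              # approx 1.1111
```

The one-step gradient misses the contribution that propagates through earlier iterates, so its error scales with the contraction rate of the iteration (0.1 here); the implicit gradient matches the unrolled one once the iteration has converged.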
Laplace posterior covariance¶
import jax
from vardax.posterior import laplace_covariance, gauss_newton_hessian
# Laplace at MAP — for ThreeDVar / StrongFourDVar / WeakFourDVar / FourDVarNet
P_star = laplace_covariance(
    x_star, cost_grad_fn=jax.grad(threedvar_cost),
    B_inv_op=B_inv, R_inv_op=R_inv,
)
# Gauss-Newton Hessian — reused inside IncrementalFourDVar.posterior(batch)
H_inv = gauss_newton_hessian(
    x_star, batch, forward, obs_op, B_op, R_op, n_krylov=50,
)
OptimalInterpolation and IncrementalFourDVar return posteriors
directly via .posterior(batch) — no laplace_covariance call needed.
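For a linear observation operator and Gaussian errors, the Laplace covariance (inverse Hessian of the cost at the MAP) coincides with the BLUE posterior covariance. A minimal dense check in pure JAX follows; it assumes the cost carries the conventional factor of one half, and `sketch_threedvar_cost` and `sketch_laplace_covariance` are hypothetical names, not the vardax functions.

```python
import jax
import jax.numpy as jnp


def sketch_threedvar_cost(x, x_b, y, H, B_inv, R_inv):
    dx = x - x_b
    dy = y - H @ x
    return 0.5 * dx @ B_inv @ dx + 0.5 * dy @ R_inv @ dy


def sketch_laplace_covariance(cost_fn, x_star):
    # Laplace approximation: covariance = inverse Hessian at the MAP
    return jnp.linalg.inv(jax.hessian(cost_fn)(x_star))


n = 4
x_b = jnp.array([1.0, 2.0, 3.0, 4.0])
B = 0.5 * jnp.eye(n) + 0.1 * jnp.ones((n, n))
H = jnp.array([[1.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0]])
R = 0.2 * jnp.eye(2)
y = jnp.array([1.5, 2.5])
B_inv, R_inv = jnp.linalg.inv(B), jnp.linalg.inv(R)

# MAP estimate and covariance from the closed-form gain
K = B @ H.T @ jnp.linalg.inv(H @ B @ H.T + R)
x_star = x_b + K @ (y - H @ x_b)
P_blue = B - K @ H @ B

cost = lambda x: sketch_threedvar_cost(x, x_b, y, H, B_inv, R_inv)
grad_at_map = jax.grad(cost)(x_star)            # close to zero at the MAP
P_laplace = sketch_laplace_covariance(cost, x_star)
```

A dense Hessian is only practical for small states; at scale the same quantity would be applied matrix-free, as the operator-based API above suggests.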