
SSH Reconstruction Example

Sea surface height (SSH) reconstruction from satellite altimetry is the canonical large-scale 4DVarNet benchmark (Fablet et al. 2021, 2023). The state is a sequence of 2D SSH fields over an assimilation window. The observations are along-track samples from altimetric satellites (Jason, SARAL/AltiKa, SWOT), and the gaps between tracks must be filled.

This chapter shows three approaches to the same problem: OI as the classical baseline, IncrementalFourDVar as the operational pattern, and FourDVarNet as the learned alternative. All three run on the same data, so their analyses can be compared side by side.

Problem setup

State: SSH on a regular lon-lat grid over a window of \(T\) time steps, with \(H \times W = 128 \times 128\) typical for a regional patch (here \(H\) is the grid height, not the observation operator). Observations: along-track samples from one or more altimeters, with along-track noise \(\sigma \sim 3\) cm and cross-track gaps of \(\sim 100\) km between tracks.

import equinox as eqx
import jax.numpy as jnp
import lineax as lx
import gaussx as gx
import optimistix as optx
import diffrax as dfx

from vardax import Batch2D
from vardax.obs_operators import MaskedIdentity, LinearObs

# Build a Batch2D from preprocessed altimetric data
batch = Batch2D(
    input=ssh_observed,          # (B, T, H, W) with NaN in gaps
    mask=track_mask,             # 1 along-track, 0 in gaps
    target=ssh_truth,            # ground truth (training only)
    obs_err=altika_uncertainty,  # along-track noise std, same units as SSH (squared to form R)
)
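For experimenting without real altimetry, the `input`/`mask` convention above can be mimicked with synthetic tracks. This is a minimal sketch under stated assumptions: the random "truth", the diagonal track geometry, and the small 32 × 32 grid are hypothetical stand-ins; only the 3 cm noise level, the `(B, T, H, W)` layout, and the NaN-in-gaps convention come from the text.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumption); the text uses 128 x 128 regional patches.
n_batch, n_time, height, width = 1, 4, 32, 32
k_truth, k_noise = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical stand-in for the true SSH field, in metres.
ssh_truth = 0.1 * jax.random.normal(k_truth, (n_batch, n_time, height, width))

# Hypothetical track geometry: diagonal lines every 12 cells,
# shifted at each time step to mimic successive satellite passes.
rows = jnp.arange(height)[:, None]
cols = jnp.arange(width)[None, :]
tracks = jnp.stack([(rows + cols + 7 * t) % 12 == 0 for t in range(n_time)])
track_mask = tracks[None].astype(jnp.float32)  # (B, T, H, W): 1 on track, 0 in gaps

# Along-track noise with sigma = 3 cm; NaN everywhere off-track.
noise = 0.03 * jax.random.normal(k_noise, ssh_truth.shape)
ssh_observed = jnp.where(track_mask > 0, ssh_truth + noise, jnp.nan)
obs_std = jnp.full(ssh_truth.shape, 0.03)  # per-cell noise std (hypothetical name)
```

Only about one cell in twelve is observed, which is the kind of sparsity the three methods below have to cope with.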

Approach 1 — OptimalInterpolation

The classical baseline. With a linear \(H\) (masked identity), a Gaussian Matérn prior, and diagonal Gaussian observation error, the analysis has a closed form, so no iterative solver is needed.

from vardax.models import OptimalInterpolation

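# Assumption: length_scale is in the same units as grid_coords. If 100.0 is
# meant as 100 km, project lon/lat to km before building the operator.
coords = jnp.stack(jnp.meshgrid(lon, lat, indexing="ij"), axis=-1)
B_op = gx.MaternLinearOperator(
    grid_coords=coords.reshape(-1, 2), length_scale=100.0, sigma=0.1, nu=1.5,
)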
R_op = lx.DiagonalLinearOperator(altika_uncertainty.flatten() ** 2)

oi = OptimalInterpolation(
    obs_op=MaskedIdentity(),
    prior_mean=climatology_ssh,
    prior_cov_op=B_op,
    obs_cov_op=R_op,
)

# Single forward pass — no iteration
x_oi = oi(batch)
posterior_oi = oi.posterior(batch)
# Reconstruction in gaps comes from the Matérn correlation length.

What this gives you: the DUACS-style OI baseline used in OceanBench. The reconstruction is smoothed over the prior length scale, which makes it easy to interpret and validate, but it misses mesoscale features shorter than the correlation length.
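The closed-form fast path is the BLUE/OI update \(x_a = x_b + BH^\top(HBH^\top + R)^{-1}(y - Hx_b)\) with posterior covariance \(P_a = B - BH^\top(HBH^\top + R)^{-1}HB\). A minimal dense sketch on a hypothetical 1D transect (50 cells at 10 km spacing, three observed crossings) shows why gaps are only smoothed. The Matérn-3/2 kernel mirrors `nu=1.5`, `length_scale=100.0`, `sigma=0.1` from the setup above; the grid, observation values, and helper names are illustrative assumptions, not the gaussx implementation.

```python
import jax.numpy as jnp

def matern32(r, length_scale, sigma):
    """Matérn-3/2 covariance as a function of distance r."""
    s = jnp.sqrt(3.0) * r / length_scale
    return sigma**2 * (1.0 + s) * jnp.exp(-s)

# 1D transect with 10 km spacing (assumption); prior parameters as in the text.
n = 50
grid_km = jnp.arange(n) * 10.0
dist = jnp.abs(grid_km[:, None] - grid_km[None, :])
B = matern32(dist, length_scale=100.0, sigma=0.1)

# Three along-track crossings observed with 3 cm noise (illustrative values).
obs_idx = jnp.array([5, 20, 35])
sigma_o = 0.03
y = jnp.array([0.10, -0.05, 0.08])
x_b = jnp.zeros(n)  # background (e.g. climatology anomaly)

# Masked-identity H: H B H^T picks rows and columns, B H^T picks columns.
BHt = B[:, obs_idx]
S = B[obs_idx][:, obs_idx] + sigma_o**2 * jnp.eye(obs_idx.size)

# OI analysis mean and covariance in one linear solve, no iteration.
innovation = y - x_b[obs_idx]
x_a = x_b + BHt @ jnp.linalg.solve(S, innovation)
P_a = B - BHt @ jnp.linalg.solve(S, BHt.T)
post_std = jnp.sqrt(jnp.diag(P_a))
```

Between tracks the analysis relaxes back toward the background and the posterior standard deviation climbs back toward the prior 0.1 over roughly the Matérn length scale, which is why features shorter than that scale are not recovered.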

Approach 2 — IncrementalFourDVar with shallow-water dynamics

Add a somax.ShallowWaterModel as the dynamics. Multi-time observations now constrain the trajectory through the physics — gaps get filled by dynamical propagation rather than just smoothing.
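Why dynamics fill gaps can be seen in a toy strong-constraint 4DVar. This sketch replaces the shallow-water model with a hypothetical periodic advection (`jnp.roll` by one cell per step) and observes a single grid cell at four times. Because the model carries different parts of the initial state through that cell, four initial-state cells are recovered from one observed location. The diagonal \(B\), noise levels, and all names are illustrative assumptions, not vardax or somax API.

```python
import jax
import jax.numpy as jnp

n, n_times = 16, 4
sigma_b, sigma_o = 1.0, 0.01

def model_step(x):
    # Toy dynamics (assumption): periodic advection by one cell per step.
    return jnp.roll(x, 1)

def trajectory(x0):
    states = [x0]
    for _ in range(n_times - 1):
        states.append(model_step(states[-1]))
    return jnp.stack(states)  # (n_times, n)

obs_cell = 0  # only this grid cell is ever observed
x_true = jnp.sin(2 * jnp.pi * jnp.arange(n) / n)
x_b = jnp.zeros(n)
y = trajectory(x_true)[:, obs_cell]  # noise-free observations for clarity

def cost(x0):
    # Strong-constraint 4DVar cost with diagonal B = sigma_b^2 I.
    j_b = 0.5 * jnp.sum((x0 - x_b) ** 2) / sigma_b**2
    j_o = 0.5 * jnp.sum((y - trajectory(x0)[:, obs_cell]) ** 2) / sigma_o**2
    return j_b + j_o

# With linear dynamics J is quadratic, so one Newton step from x_b is exact.
grad = jax.grad(cost)(x_b)
hess = jax.hessian(cost)(x_b)
x_a = x_b - jnp.linalg.solve(hess, grad)
```

With a diagonal \(B\), cells that never pass through the observed location stay at the background; a correlated prior such as the Matérn `B_op` would spread the correction further. For a nonlinear model the cost is no longer quadratic, which is why IncrementalFourDVar relinearises in its outer loops.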

import somax
from vardax.models import IncrementalFourDVar
from vardax import IncrementalConfig

swm = somax.ShallowWaterModel(grid=grid, params=ssh_params)

incremental = IncrementalFourDVar(
    forward=swm,
    obs_op=MaskedIdentity(),
    prior_mean=climatology_ssh,
    prior_cov_op=B_op,
    obs_cov_op=R_op,
    config=IncrementalConfig(n_outer=3, n_inner=20, cvt=True),
    forward_adjoint=dfx.RecursiveCheckpointAdjoint(),
)

x_inc = incremental(batch)
posterior_inc = incremental.posterior(batch)
# Mesoscale eddies now propagate consistently across observation gaps;
# the posterior reflects model + observation uncertainty.

What this gives you: a physically consistent reconstruction. The incremental machinery scales to operational grids (\(\sim 10^6\) to \(10^7\) cells) because each inner problem is solved by conjugate gradient (CG) preconditioned with the control-variable transform (CVT).
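One inner loop can be sketched for a single time window with a masked-identity \(H\) and no dynamics, which isolates the CVT-preconditioned CG solve. Writing \(\delta x = L v\) with \(B = L L^\top\), the inner cost becomes \(\tfrac12\|v\|^2 + \tfrac12\|d - H L v\|^2_{R^{-1}}\), whose Hessian \(I + L^\top H^\top R^{-1} H L\) is passed to `jax.scipy.sparse.linalg.cg`. The dense Cholesky factor, 1D grid, and parameter values are assumptions for illustration; an operational system would apply \(B^{1/2}\) matrix-free, and this is not vardax's internal code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)  # allows a tight CG tolerance in this toy

def matern32(r, length_scale, sigma):
    s = jnp.sqrt(3.0) * r / length_scale
    return sigma**2 * (1.0 + s) * jnp.exp(-s)

n = 64
grid = jnp.arange(n) * 1.0
B = matern32(jnp.abs(grid[:, None] - grid[None, :]), length_scale=8.0, sigma=1.0)
L = jnp.linalg.cholesky(B + 1e-8 * jnp.eye(n))  # B ~= L L^T, i.e. a B^{1/2}

obs_idx = jnp.arange(0, n, 8)
sigma_o = 0.05
x_b = jnp.zeros(n)
y = jnp.sin(2 * jnp.pi * grid[obs_idx] / 32.0)
d = y - x_b[obs_idx]  # innovation

def H(x):
    return x[obs_idx]

def H_T(z):
    return jnp.zeros(n).at[obs_idx].set(z)

def cvt_hessian(v):
    # (I + L^T H^T R^{-1} H L) v: identity plus a term of rank <= n_obs.
    return v + L.T @ H_T(H(L @ v) / sigma_o**2)

rhs = L.T @ H_T(d / sigma_o**2)
v, _ = cg(cvt_hessian, rhs, tol=1e-10, maxiter=100)
x_a = x_b + L @ v  # map the control variable back to state space
```

Because the Hessian is the identity plus a term of rank at most the number of observations, CG converges in at most \(n_\text{obs} + 1\) iterations in exact arithmetic and never needs the ill-conditioned \(B^{-1}\). In this linear case the result matches the OI formula \(x_b + BH^\top(HBH^\top + R)^{-1} d\).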

Approach 3 — FourDVarNet with learned prior

Train a FourDVarNet on (observed, truth) patches. The learned prior \(\varphi_\theta\) captures the SSH manifold; the learned grad modulator \(\Phi_\phi\) accelerates inner iterations.
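The shape of the inner solver can be illustrated with hand-written stand-ins: a fixed smoothing filter in place of the learned prior \(\varphi_\theta\) and an identity map in place of \(\Phi_\phi\). The cost \(J(x) = \tfrac12\|m \odot (x - y)\|^2 + \tfrac{\lambda}{2}\|x - \varphi(x)\|^2\) and update \(x_{k+1} = x_k - \alpha\,\Phi(\nabla J(x_k))\) are a simplified 4DVarNet formulation assumed here, not read from vardax; the values `n_steps=15` and `alpha=0.2` mirror the `SolverConfig` in the code that follows.

```python
import jax
import jax.numpy as jnp

def prior(x):
    # Stand-in for the learned prior phi_theta: a fixed periodic 3-point smoother.
    return (jnp.roll(x, 1) + x + jnp.roll(x, -1)) / 3.0

def cost(x, y, mask, lam=1.0):
    obs_term = 0.5 * jnp.sum(mask * (x - y) ** 2)
    prior_term = 0.5 * lam * jnp.sum((x - prior(x)) ** 2)
    return obs_term + prior_term

def grad_mod(g):
    # Stand-in for the learned modulator Phi_phi (a ConvLSTM in the text).
    return g

def solve(y, mask, n_steps=15, alpha=0.2):
    grad_fn = jax.grad(cost)  # gradient with respect to the state x
    x = jnp.where(mask > 0, y, 0.0)  # start from observations, zeros in gaps
    for _ in range(n_steps):
        x = x - alpha * grad_mod(grad_fn(x, y, mask))
    return x

n = 64
signal = jnp.sin(2 * jnp.pi * jnp.arange(n) / 16.0)
mask = (jnp.arange(n) % 4 == 0).astype(jnp.float32)
# Gaps hold 0.0 rather than NaN: 0 * NaN is NaN, so NaNs would poison the cost.
y = jnp.where(mask > 0, signal, 0.0)
x_hat = solve(y, mask)
```

With these stand-ins the loop is plain gradient descent on \(J\); in FourDVarNet both components are trained end to end so that a small fixed number of steps already yields a good reconstruction.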

import optax

from vardax.models import FourDVarNet
from vardax.priors import BilinAEPrior2D
from vardax.grad_mod import ConvLSTMGradMod2D
from vardax import SolverConfig
from vardax.adjoints import OneStepAdjoint
from vardax.training import train_step

model = FourDVarNet(
    prior=BilinAEPrior2D(latent_dim=128, n_time=10, height=128, width=128),
    obs_op=MaskedIdentity(),
    grad_mod=ConvLSTMGradMod2D(state_channels=10, hidden_dim=64),
    config=SolverConfig(n_steps=15, alpha=0.2),
    solver_adjoint=OneStepAdjoint(),       # O(1) memory training
)

# Train on the OceanBench dataset
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def step(model, opt_state, batch):
    model, opt_state, loss = train_step(model, batch, optimizer, opt_state)
    return model, opt_state, loss

for epoch in range(100):
    for batch in oceanbench_loader:
        model, opt_state, loss = step(model, opt_state, batch)

# Reconstruction (here on the last training batch; evaluate on held-out data in practice)
x_net = model(batch)

What this gives you: state-of-the-art reconstruction on the distribution the network was trained on. It beats OI in information-rich regions and is comparable to incremental 4DVar without needing a somax dynamical model. Limitations: it extrapolates poorly outside the training distribution, and a Laplace-approximation posterior is less defensible than the Gauss-Newton (GN) Hessian posterior from incremental 4DVar.

Comparison

| Method | Setup cost | Per-event cost | Recovers mesoscale? | Posterior quality | When to use |
|---|---|---|---|---|---|
| OptimalInterpolation | trivial | low (single pass) | only down to the prior length scale | exact (closed form) | Linear-Gaussian baseline |
| IncrementalFourDVar | requires a somax model | medium (3 outer × 20 inner CG) | yes, via dynamics | calibrated (GN Hessian) | Production / operational |
| FourDVarNet | requires training | low (15 solver steps) | yes, via learned prior | needs validation | Data-rich regime |

The OceanBench benchmark (Le Guillou et al. 2023) provides a standard dataset for these comparisons. In reconstruction skill, the expected ordering is OI < IncrementalFourDVar ≈ FourDVarNet: FourDVarNet matches IncrementalFourDVar on in-distribution data but degrades faster when the input distribution drifts, because it extrapolates poorly outside its training data.

Operational cycling

For continuous SSH analysis (every overpass, every few hours), wrap the chosen analysis in a pipekit_cycle.DACycle:

import pipekit_cycle as pc

ssh_cycle = pc.DACycle(
    forward_model=swm,
    obs_op=MaskedIdentity(),
    analysis_step=incremental.as_analysis_step(),    # or the oi / model (FourDVarNet) equivalent
    obs_source=along_track_loader,
    n_steps=n_assimilation_windows,
)

# Run continuously
result, final_state = ssh_cycle(initial_state, pc.DAState(t=0.0, cycle_count=0))

The analysis_step slot accepts any of the three methods: switching between OptimalInterpolation, IncrementalFourDVar, and FourDVarNet means changing only that one argument. Nothing in the cycling, data loading, or output serialisation changes.
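The one-slot design can be pictured as a plain forecast/analysis loop in which only the analysis function varies. This is a toy in pure JAX, not pipekit_cycle's implementation: `forecast`, `no_analysis`, and `make_nudging` are hypothetical stand-ins for the shallow-water model and the three analysis methods, and the observations are simply a full-state array per window.

```python
import jax
import jax.numpy as jnp

def forecast(x):
    # Toy forecast model (assumption): damped periodic advection.
    return 0.95 * jnp.roll(x, 1)

def no_analysis(x_f, y):
    return x_f  # free run: observations are ignored

def make_nudging(gain):
    # Hypothetical analysis step: x_a = x_f + gain * (y - x_f).
    def analysis(x_f, y):
        return x_f + gain * (y - x_f)
    return analysis

def run_cycle(analysis_step, x0, observations):
    """Forecast then analysis once per window; analysis_step is the only slot that changes."""
    def one_window(x, y):
        x_a = analysis_step(forecast(x), y)
        return x_a, x_a
    final_state, analyses = jax.lax.scan(one_window, x0, observations)
    return final_state, analyses

x0 = jnp.ones(8)
obs = jnp.linspace(-1.0, 1.0, 5 * 8).reshape(5, 8)  # 5 windows of full-state obs

free_final, _ = run_cycle(no_analysis, x0, obs)
nudged_final, nudged = run_cycle(make_nudging(0.5), x0, obs)
replaced_final, replaced = run_cycle(make_nudging(1.0), x0, obs)
```

Swapping `no_analysis` for a nudging step (or, in the real pipeline, OI for IncrementalFourDVar) touches neither the loop nor the forecast model, which is the property the `analysis_step` slot is meant to provide.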

Persistence

For trained FourDVarNet, persist via pipekit-jax.JaxModelOp + pipekit-experiment.ModelRegistry:

from pipekit_jax import JaxModelOp
from pipekit_experiment import LocalModelRegistry

registry = LocalModelRegistry(root="./ssh_models")
model_op = JaxModelOp(model)
hash_ = registry.store(
    model_op, weights=model_op.serialize_weights(),
    tags={"task": "ssh_oceanbench", "version": "v1"},
)

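# Later: reload into a fresh skeleton. make_fresh_skeleton() stands for rebuilding
# an untrained FourDVarNet with the same constructor arguments as above.
template = JaxModelOp(make_fresh_skeleton())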
reloaded = template.with_weights(registry.load_weights(hash_))

OptimalInterpolation and IncrementalFourDVar have no learned parameters; they need no persistence (the prior_cov_op and config are reconstructed from numerical parameters).
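The fresh-skeleton reload works because a JAX model is a pytree: the skeleton fixes the tree structure and array shapes, and the stored file only needs the leaf arrays. The sketch below shows that idea with `jax.tree_util` and NumPy. `ToyRegistry` and `content_hash` are hypothetical and only mimic the store/load shape of the pipekit calls above; pipekit's actual hashing and file format are not specified here.

```python
import hashlib
import pathlib
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

def content_hash(params):
    """SHA-256 over each leaf's dtype, shape and raw bytes (deterministic)."""
    h = hashlib.sha256()
    for leaf in jax.tree_util.tree_leaves(params):
        arr = np.asarray(leaf)
        h.update(f"{arr.dtype}{arr.shape}".encode())
        h.update(arr.tobytes())
    return h.hexdigest()

class ToyRegistry:
    """Hypothetical content-addressed weight store (not pipekit's implementation)."""

    def __init__(self, root):
        self.root = pathlib.Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def store(self, params):
        key = content_hash(params)
        leaves = [np.asarray(leaf) for leaf in jax.tree_util.tree_leaves(params)]
        with open(self.root / f"{key}.npz", "wb") as f:
            np.savez(f, *leaves)
        return key

    def load(self, key, skeleton):
        # The skeleton supplies the tree structure; the file supplies the leaves.
        treedef = jax.tree_util.tree_structure(skeleton)
        with np.load(self.root / f"{key}.npz") as data:
            leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(len(data.files))]
        return jax.tree_util.tree_unflatten(treedef, leaves)

registry = ToyRegistry(tempfile.mkdtemp())
params = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.ones(3)}
key = registry.store(params)
skeleton = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}  # untrained, shapes only
reloaded = registry.load(key, skeleton)
```

Equinox provides `eqx.tree_serialise_leaves` and `eqx.tree_deserialise_leaves` for the same skeleton-plus-leaves pattern on real `eqx.Module` models.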

See also

  • Chapter 4 — OptimalInterpolation (BLUE / OI)
  • Chapter 8 — IncrementalFourDVar (operational fast path)
  • Chapter 9 — FourDVarNet (learned variant)
  • Chapter 14 — six-step cycle (the methodology that ties this together)

References

  • Fablet, R., Febvre, Q., & Chapron, B. (2023). Multimodal 4DVarNets for the reconstruction of sea surface dynamics from NADIR and wide-swath altimetry. IEEE TGRS 61.
  • Le Guillou, F., et al. (2023). Mapping altimetry in the forthcoming SWOT era by back-and-forth nudging a one-layer quasigeostrophic model. JTECH 40(1). [OceanBench reference.]
  • Ubelmann, C., Klein, P., & Fu, L.-L. (2015). Dynamic interpolation of sea surface height and potential applications for future high-resolution altimetry mapping. JTECH 32.