Skip to content

Ecosystem Integration

How vardax composes with somax, plumax, gaussx, filterax, pipekit-cycle, pipekit-experiment, pipekit-jax, pipekit-train, georeader, coordax, and GeoCatalog.


somax — Geophysical forward as dynamical prior

# Example: a somax shallow-water model used as the dynamics of two vardax
# solvers, a learned FourDVarNet and a classical IncrementalFourDVar.
# NOTE(review): grid, params, x_b, B_op and R_op are assumed to be defined
# elsewhere. FourDVarNet, MaskedIdentity, ConvLSTMGradMod2D, SolverConfig,
# IncrementalFourDVar and IncrementalConfig are not imported in this snippet.
# IncrementalFourDVar lives in vardax.models (see the plumax example); confirm
# the import paths of the others before running.
import somax
from vardax.priors import DynamicalPrior

# somax forwards satisfy pipekit_cycle.ForwardModel natively
swm = somax.ShallowWaterModel(grid=grid, params=params)

# Wrap as a Prior for the variational cost.
# n_steps here presumably counts the forward steps the prior integrates. It is
# separate from SolverConfig(n_steps=...) below; TODO confirm.
prior = DynamicalPrior(forward=swm, n_steps=10)

# Plug into FourDVarNet
model = FourDVarNet(
    prior=prior,
    obs_op=MaskedIdentity(),
    grad_mod=ConvLSTMGradMod2D(hidden_dim=64),
    config=SolverConfig(n_steps=15),
)

# Or into IncrementalFourDVar (no prior wrap — forward is used directly)
incremental = IncrementalFourDVar(
    forward=swm,                  # used as the dynamics M_t
    obs_op=MaskedIdentity(),
    prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
    config=IncrementalConfig(),
)

plumax — Methane forward chain (Tier I)

Decision D7: vardax does not own methane forwards. plumax provides the forward chain (Tier I Gaussian plume → Tier IV coupled RTM). vardax does the inversion.

# Example: Tier I methane inversion, with plumax as the forward and vardax as
# the inverse.
# `gx` is used for the prior covariance below, so gaussx is imported here
# instead of relying on the later gaussx section.
import gaussx as gx
import plumax
from vardax.priors import DynamicalPrior
from vardax.obs_operators import AveragingKernel, MultiInstrumentFusion
from vardax.models import IncrementalFourDVar

# NOTE(review): InstrumentRegistry, IncrementalConfig, GaussNewtonHessian and
# the `lx` alias are used below but not imported in this snippet; add their
# imports (their paths are not shown on this page). met_field, the *_l2_ds
# datasets, Q_prior_mean, Q_prior_mu, Q_prior_sigma_log and batch are assumed
# to be defined by the caller.

# Forward: parameters Q, x_0, u, θ → predicted XCH4 enhancement
plume_fwd = plumax.tier1.GaussianPlume(met=met_field, dispersion="MO")

# Observation: averaging kernel from TROPOMI L2 metadata.
# Single-instrument alternative: tropomi_ak is not used by the fused
# inversion below.
tropomi_ak = AveragingKernel.from_l2(tropomi_l2_ds)

# Multi-instrument fusion (per-instrument bias as joint state — Epic 9)
fusion = MultiInstrumentFusion(
    registry=InstrumentRegistry.from_l2_dict({
        "TROPOMI": tropomi_l2_ds,
        "EMIT": emit_l2_ds,
        "GHGSat": ghgsat_l2_ds,
    }),
)

# Invert
inversion = IncrementalFourDVar(
    forward=plume_fwd,
    obs_op=fusion,
    prior_mean=Q_prior_mean,
    # NOTE(review): Q_prior_mu / Q_prior_sigma_log look like log-space
    # parameters; confirm they agree with Q_prior_mean.
    prior_cov_op=gx.LogNormalLinearOperator(Q_prior_mu, Q_prior_sigma_log),
    # One R block per registered instrument.
    # NOTE(review): confirm that the iteration order of registry.entries
    # matches the order in which `fusion` stacks observations.
    obs_cov_op=lx.BlockDiagonalLinearOperator([
        spec.R_op for spec in fusion.registry.entries.values()
    ]),
    config=IncrementalConfig(n_outer=3, n_inner=20),
)

x_star = inversion(batch)
# Posterior approximation from the Gauss-Newton Hessian at the analysis x_star.
posterior = GaussNewtonHessian()(x_star, inversion.as_analysis_step(), batch)

gaussx — Structured prior covariance

Decision D11: incremental 4DVar uses the control-variable transform with a gaussx.MaternLinearOperator factorisation of \(B\).

# Example: structured prior covariances B built with gaussx operators.
# NOTE(review): coords, lon and lat are assumed to be grid coordinate arrays
# defined elsewhere. Their expected shapes are not shown on this page.
import gaussx as gx

# Matérn-3/2 on a regular grid — supports Kronecker structure
B_op = gx.MaternLinearOperator(
    grid_coords=coords,
    length_scale=10.0,         # km — basin-dependent
    nu=1.5,                     # Matérn-3/2
    sigma=1.0,                  # presumably the marginal std dev — confirm
)

# Returns B^{1/2}, the square-root factor used by the control-variable
# transform (CVT) in incremental 4DVar
B_half = B_op.half()

# Or a Kronecker product (separable in lon/lat):
# NOTE(review): units of these length scales are not stated; presumably the
# same units as lon/lat.
B_kron = gx.KroneckerLinearOperator([
    gx.Matern1DLinearOperator(coords=lon, length_scale=5.0),
    gx.Matern1DLinearOperator(coords=lat, length_scale=3.0),
])

For non-separable structured priors, compose \(B\) from LowRankUpdate or BlockDiag operators, or factorise it via Lanczos.


filterax — Hybrid ensemble-variational (Epic 9)

# Example: hybrid ensemble-variational DA. filterax produces the ensemble and
# vardax runs the inversion with a blended background covariance.
import jax
import filterax as fx
from vardax.posterior import EnsembleCovariance
from vardax.models import IncrementalFourDVar

# NOTE(review): IncrementalConfig is used below but not imported here.
# somax_model, key, state_shape, x_b, B_climatological, fusion, R_op and batch
# are assumed to be defined earlier (e.g. by the sections above).

# A single ensemble size shared by the forecast, its initial perturbations and
# the posterior estimator, so the three cannot drift apart.
n_members = 32

# Ensemble forecast (filterax owns this)
ensemble_forecast = fx.EnsembleForecast(
    forward=somax_model,
    n_members=n_members,
    initial_perturbations=jax.random.normal(key, (n_members, *state_shape)),
)
ensemble_states = ensemble_forecast(x_b)

# Hybrid B: equal-weight blend of climatological and ensemble covariance
B_hybrid = 0.5 * B_climatological + 0.5 * fx.ensemble_covariance(ensemble_states)

# Run vardax inversion with hybrid B
model = IncrementalFourDVar(
    forward=somax_model,
    obs_op=fusion,
    # Ensemble mean as the background state. This assumes axis 0 of
    # ensemble_states is the member axis, as in the perturbation shape above.
    # TODO confirm.
    prior_mean=ensemble_states.mean(0),
    prior_cov_op=B_hybrid,
    obs_cov_op=R_op,
    config=IncrementalConfig(),
)
x_star = model(batch)

# Posterior from the ensemble itself
posterior = EnsembleCovariance(n_members=n_members)(
    ensemble_states, model.as_analysis_step(), batch
)

pipekit-cycle — Operational DA cycling

# Example: operational DA cycling, with a vardax model supplying the analysis
# step of a pipekit_cycle DACycle.
# NOTE(review): somax_model, x_climatology, B_op, R_op, load_satellite_op,
# n_assimilation_windows and initial_state are assumed to be defined
# elsewhere. AveragingKernel(...) is a placeholder for the real arguments.
import pipekit_cycle as pc
import vardax as vdx

# Build the model once
model = vdx.models.IncrementalFourDVar(
    forward=somax_model,
    obs_op=vdx.obs_operators.AveragingKernel(...),
    prior_mean=x_climatology,
    prior_cov_op=B_op, obs_cov_op=R_op,
    config=vdx.IncrementalConfig(),
)

# Cycle it. The cycle reuses the model's own observation operator and analysis
# step, and is given the same forward model the solver was built with.
# n_steps presumably counts assimilation windows, judging by the argument
# name; confirm.
da_cycle = pc.DACycle(
    forward_model=somax_model,
    obs_op=model.obs_op,
    analysis_step=model.as_analysis_step(),
    obs_source=load_satellite_op,
    n_steps=n_assimilation_windows,
)
# Start at t=0.0 with a zero cycle count; the call returns a
# (result, final_state) pair.
result, final_state = da_cycle(initial_state, pc.DAState(t=0.0, cycle_count=0))

pipekit-experiment + pipekit-jax — Trained model persistence

# Example: persisting and reloading a trained model, first through the
# pipekit registry directly and then through the vardax shortcuts.
# NOTE(review): trained_vardanet and make_fresh_vardanet_skeleton are assumed
# to be defined elsewhere (presumably a trained model and a factory that
# builds an untrained one).
from pipekit_jax import JaxModelOp
from pipekit_experiment import LocalModelRegistry

# Wrap trained FourDVarNet for serialisation
model_op = JaxModelOp(trained_vardanet)

# Store with content-addressed hash (`hash_` avoids shadowing the builtin hash)
registry = LocalModelRegistry(root="./models")
hash_ = registry.store(
    model_op,
    weights=model_op.serialize_weights(),
    tags={"task": "ssh_oceanbench", "version": "v1"},
)

# Later — reload with a fresh skeleton: load the stored weights by hash and
# apply them to a newly built model wrapper.
template = JaxModelOp(make_fresh_vardanet_skeleton())
reloaded = template.with_weights(registry.load_weights(hash_))

# Or use vardax shortcuts (D13). This is an alternative to the steps above;
# running both rebinds hash_ and reloaded.
from vardax.persist import save, load

hash_ = save(trained_vardanet, registry, tags={"task": "ssh"})
# The factory is passed uncalled here, unlike make_fresh_vardanet_skeleton()
# above.
reloaded = load(registry, hash_, skeleton_factory=make_fresh_vardanet_skeleton)

pipekit-train — Training callbacks + metric writers

# Example: training a vardax model with pipekit-train callbacks and metric
# writers.
# JaxModelOp is used below, so it is imported here rather than relying on the
# persistence example.
from pipekit_jax import JaxModelOp
from pipekit_train import (
    MSE, EarlyStopping, Checkpoint, LogToExperiment, JSONLWriter, TrainingLoop,
)
from pipekit_experiment import WandbTracker

# NOTE(review): TrainerCarryState is used below but not imported in this
# snippet. model, oceanbench_ssh_dataset and the model `registry` are assumed
# to come from earlier sections.

tracker = WandbTracker(project="vardax-ssh", run_name="vardanet-v1")
# NOTE(review): metric_writer is created but never passed to the loop below;
# wire it into a callback or drop it.
metric_writer = JSONLWriter("./logs/run.jsonl")

# Wrap the model once and reuse the same wrapper for construction and for the
# run, instead of building two independent JaxModelOp instances.
model_op = JaxModelOp(model)

loop = TrainingLoop(
    dataset=oceanbench_ssh_dataset,
    model_op=model_op,
    loss=MSE(),
    callbacks=[
        EarlyStopping(metric="val_mse", patience=10),
        Checkpoint(registry=registry, every_n_epochs=1, metric="val_mse"),
        LogToExperiment(tracker),
    ],
)
trained_op, final_state = loop(model_op, TrainerCarryState(...))

pipekit-train is an optional [train] extra; vardax's train_step is the inner primitive plugged into pipekit_train.MSE.


georeader — Sensor data loading

# Example: loading per-instrument L2 products with georeader and assembling a
# vardax Batch from them.
# NOTE(review): the "..." in each URL is a placeholder for the real object
# path. build_batch_from_readers and met_field are not defined in this snippet.
import georeader as gr

# Per-instrument readers know how to handle each L1/L2 format
tropomi_l2 = gr.TROPOMI_L2.from_url("s3://tropomi/2024/01/15/...")
emit_l2    = gr.EMIT_L2.from_url("s3://emit/2024/01/15/...")
ghgsat_l2  = gr.GHGSat_L2.from_url("s3://ghgsat/2024/01/15/...")

# Each returns a GeoTensor with attached AK metadata
# Build vardax Batch from the multi-instrument set
batch = build_batch_from_readers([tropomi_l2, emit_l2, ghgsat_l2], met=met_field)

coordax — Coordinate-aware fields

This is open question 1 in boundaries.md. In the current design, the vardax Batch* types hold raw Arrays; the optional [coords] extra exposes coordax adapters:

# Example: building a vardax Batch from a coordax dataset (requires the
# optional [coords] extra).
from vardax.adapters.coordax import batch_from_coordax_dataset

# NOTE(review): coordax_ds is assumed to be a coordax dataset that holds the
# "ssh", "ssh_obs" and "mask" variables named below.
batch = batch_from_coordax_dataset(coordax_ds, state_var="ssh",
                                   obs_var="ssh_obs", mask_var="mask")
# Coordinates preserved in Batch.metadata for posterior provenance.

GeoCatalog — Event-driven batch assembly

# Example: event-driven pipeline. Query a catalog for overpasses, assemble a
# batch, invert, and write the posterior back to the catalog.
# NOTE(review): event, met_field, build_event_batch, inversion (built in the
# plumax example), GaussNewtonHessian and GaussianMarkLikelihood are not
# defined or imported in this snippet.
from geotoolz import GeoCatalog

# Query overpasses across instruments for a methane event.
# buffer_km=50 pads the event bbox (presumably by 50 km, per the argument
# name). window(hours=2) presumably spans 2 hours around the event; confirm.
catalog = GeoCatalog.from_geoparquet("s3://catalog/overpasses.parquet")
overpasses = catalog.query(
    geometry=event.bbox(buffer_km=50),
    interval=event.window(hours=2),
    instruments=["TROPOMI", "EMIT", "GHGSat"],
    quality_min="usable",
)

# Assemble multi-instrument batch
batch = build_event_batch(overpasses, met=met_field)

# Invert
x_star = inversion(batch)
posterior = GaussNewtonHessian()(x_star, inversion.as_analysis_step(), batch)

# Write back to catalog, keyed by the event id
catalog.write_posterior(event.id, GaussianMarkLikelihood(posterior, event.metadata))

Composition patterns

| Pattern | Components | Use case |
|---|---|---|
| Learned 4DVarNet | AE prior + masked obs + ConvLSTM | Standard 4DVarNet SSH mapping |
| Physics-informed | somax DynamicalPrior + learned grad mod | Hybrid learned + physics DA |
| Operational SSH | IncrementalFourDVar + somax SWM + altimetry + gaussx \(B\) | Production ocean DA |
| Methane Tier I | plumax Gaussian plume + AK + multi-instrument fusion + IncrementalFourDVar | Single-overpass facility attribution |
| Methane Tier II | plumax Lagrangian footprint + linear inversion + gaussx Matérn \(B\) | Basin-scale regional inversion |
| Methane Tier IV | plumax Eulerian + neural RTM + multi-instrument + amortized | Operational alerts |
| Hybrid EnVar | filterax ensemble cov + IncrementalFourDVar + AveragingKernel | Non-Gaussian regimes |
| Catalog-driven | GeoCatalog query → georeader → vardax → catalog write | End-to-end pipeline |
| Research → operations | Same vardax model: notebook (research), CI batch (regression), API (production) | The whole arc |