Layer 1 — Component Examples
Implementing protocols and composing eqx.Module operators.
Implementing Prior — Autoencoder
import equinox as eqx
from vardax.protocols import Prior # runtime-checkable Protocol
class ConvAEPrior(eqx.Module):
    """Autoencoder prior: maps a state through the encoder, then the decoder.

    Conforms to ``vardax.protocols.Prior`` structurally, through ``__call__``
    alone; no base class from vardax is required.
    """

    encoder: eqx.nn.Sequential
    decoder: eqx.nn.Sequential

    def __call__(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)


# Structural conformance — no inheritance needed
assert isinstance(ConvAEPrior(enc, dec), Prior)
Implementing ObservationOperator — Masked identity
Vardax observation operators satisfy pipekit_cycle.ObservationOperator
directly (Decision D8). Implement both __call__ and linearize:
from pipekit_cycle import ObservationOperator
from lineax import JacobianLinearOperator
class MaskedIdentity(eqx.Module):
    """H(x) = mask ⊙ x, or the plain identity H(x) = x when no mask is given."""

    def __call__(self, x, mask=None):
        return x * mask if mask is not None else x

    def linearize(self, x, mask=None):
        """Return the tangent-linear operator of H at ``x`` (for the same mask).

        lineax's ``JacobianLinearOperator`` evaluates ``fn(x, args)``, so the
        mask is forwarded as ``args``. Without it, lineax would call
        ``__call__(x, None)`` and always return the *unmasked* identity
        Jacobian, even when the forward operator is masked.
        """
        return JacobianLinearOperator(self.__call__, x, args=mask)


assert isinstance(MaskedIdentity(), ObservationOperator)
Implementing ObservationOperator — Averaging kernel (D9)
import lineax as lx
from vardax.obs_operators import AveragingKernel
ak = AveragingKernel(
    A=lx.MatrixLinearOperator(A_matrix),  # or gaussx structured op
    x_a=retrieval_prior,
    h=weighting_vector,
)
# ŷ = A · (h ⊙ x + (1-h) ⊙ x_a)
y_pred = ak(x)
# Tangent-linear operator for incremental 4DVar. For the affine map above it
# is A · diag(h), the same at every x.
H_lin = ak.linearize(x)
# Adjoint application Hᵀ r. lineax's `@` only composes two linear operators;
# apply an operator to a vector with `.mv(...)`.
y_adjoint = H_lin.T.mv(residual)
Implementing ObservationOperator — Multi-instrument fusion
from vardax.obs_operators import (
AveragingKernel, MultiInstrumentFusion, InstrumentRegistry, InstrumentSpec,
)
# One InstrumentSpec per sensor: forward operator, observation mask, and the
# observation-error covariance R as a lineax operator.
tropomi_spec = InstrumentSpec(
obs_op=AveragingKernel(A=tropomi_A, x_a=tropomi_xa, h=tropomi_h),
mask=tropomi_qa_flag,
# NOTE(review): DiagonalLinearOperator takes the diagonal of R, i.e.
# variances. If tropomi_uncertainty is a 1σ value it must be squared — confirm.
R_op=lx.DiagonalLinearOperator(tropomi_uncertainty),
instrument_id="TROPOMI",
)
# `...` is a placeholder: fill in the EMIT / GHGSat fields like TROPOMI above.
emit_spec = InstrumentSpec(...)
ghgsat_spec = InstrumentSpec(...)
fusion = MultiInstrumentFusion(
registry=InstrumentRegistry(entries={
"TROPOMI": tropomi_spec,  # key mirrors the spec's instrument_id
"EMIT": emit_spec,
"GHGSat": ghgsat_spec,
}),
)
# Returns dict[instrument_id, predicted_obs]
predictions = fusion(x, batch)
# For strict pipekit_cycle.ObservationOperator contexts, use the adapter:
fusion_as_op = fusion.to_observation_operator()
assert isinstance(fusion_as_op, ObservationOperator)
Implementing GradModulator — ConvLSTM (FourDVarNet only)
from vardax.protocols import GradModulator
class ConvLSTMGradMod(eqx.Module):
    """Gradient modulator: runs ``conv_lstm`` over the gradient, then projects.

    Each call consumes the previous ``(h, c)`` carry and returns the modulated
    update together with the new carry for the next solver step.
    """

    conv_lstm: eqx.Module
    output_proj: eqx.nn.Conv2d

    def __call__(self, grad, carry):
        hidden, cell = carry
        hidden, cell = self.conv_lstm(grad, (hidden, cell))
        update = self.output_proj(hidden)
        return update, (hidden, cell)


# Structural conformance — no inheritance needed
assert isinstance(ConvLSTMGradMod(lstm, proj), GradModulator)
The grad-modulator family is FourDVarNet-specific. Classical methods
instead use an optimistix.AbstractMinimiser as the inner solver.
Wrapping a Minimiser (classical methods)
import optimistix as optx
from vardax.minimisers import Minimiser
from vardax.models import ThreeDVar
# Wrap optimistix solvers; `adjoint` is optional (bfgs / ncg below omit it).
# NOTE(review): optx.GaussNewton is an optimistix least-squares solver, not an
# optx.AbstractMinimiser. It only fits if the model poses its cost as residuals
# for optx.least_squares. BFGS and NonlinearCG are true minimisers — confirm
# which solver family the model expects.
gn = Minimiser(optx.GaussNewton(rtol=1e-5, atol=1e-5),
adjoint=optx.ImplicitAdjoint())
bfgs = Minimiser(optx.BFGS(rtol=1e-5, atol=1e-5))
ncg = Minimiser(optx.NonlinearCG(rtol=1e-5, atol=1e-5))
# Unpack the wrapper: the solver and its adjoint are passed separately.
model = ThreeDVar(
obs_op=obs_op,
prior_mean=x_b, prior_cov_op=B_op, obs_cov_op=R_op,
minimiser=gn.minimiser, # the underlying optimistix solver
minimiser_adjoint=gn.adjoint,
)
The Minimiser wrapper is a convenience; you can also pass the
optimistix.AbstractMinimiser directly into the model constructor.
Wrapping a somax / plumax forward model as a DynamicalPrior
import somax
from vardax.priors import DynamicalPrior
import diffrax as dfx
# somax already satisfies pipekit_cycle.ForwardModel
swm = somax.ShallowWaterModel(grid=grid, params=params)
# Wrap as a Prior (integrates n_steps forward)
# NOTE(review): BacksolveAdjoint saves memory by solving the continuous adjoint
# backwards in time, but its gradients can be less accurate. diffrax's default,
# RecursiveCheckpointAdjoint, is the safer choice unless memory is the limit.
prior = DynamicalPrior(
forward=swm, n_steps=10,
forward_adjoint=dfx.BacksolveAdjoint(), # memory-efficient
)
For methane / plumax:
import plumax
# Gaussian plume forward model driven by a meteorological field.
# dispersion="MO" presumably selects a Monin–Obukhov scheme — confirm in plumax.
plume = plumax.GaussianPlumeForward(met=met_field, dispersion="MO")
prior = DynamicalPrior(forward=plume, n_steps=1) # single-shot for Tier I
Composing components into an OptimalInterpolation
import gaussx as gx
from vardax.models import OptimalInterpolation
from vardax.obs_operators import LinearObs
# Optimal interpolation needs no minimiser: the observation operator is linear,
# so the analysis and posterior are computed in closed form.
model = OptimalInterpolation(
obs_op=LinearObs(H_mat=H_lin_op), # must be linear
prior_mean=x_b,
prior_cov_op=gx.MaternLinearOperator(coords, length_scale=10.0, sigma=0.1),
# NOTE(review): the diagonal of R must hold variances; square obs_uncertainty
# if it is a 1σ value — confirm.
obs_cov_op=lx.DiagonalLinearOperator(obs_uncertainty),
)
x_star = model(batch) # closed-form
posterior = model.posterior(batch) # closed-form (mean, cov, provenance)
Composing components into a ThreeDVar
import optimistix as optx
from vardax.models import ThreeDVar
# NOTE(review): optx.GaussNewton is a least-squares solver, not an
# optx.AbstractMinimiser; confirm that ThreeDVar poses its cost as residuals.
model = ThreeDVar(
obs_op=AveragingKernel(A=A, x_a=xa, h=h), # affine in x: ŷ = A·(h⊙x + (1-h)⊙x_a) with h fixed
prior_mean=x_b,
prior_cov_op=B_op, obs_cov_op=R_op,
minimiser=optx.GaussNewton(rtol=1e-5, atol=1e-5),
minimiser_adjoint=optx.ImplicitAdjoint(),
)
x_star = model(batch)
Composing components into an IncrementalFourDVar
from vardax.models import IncrementalFourDVar
from vardax import IncrementalConfig
import diffrax as dfx
# Incremental 4DVar over a forward model (e.g. the somax model wrapped above).
model = IncrementalFourDVar(
forward=somax_model,
obs_op=AveragingKernel(A=A, x_a=xa, h=h),
prior_mean=x_b,
prior_cov_op=B_op, obs_cov_op=R_op,
# Presumably n_outer relinearisations × n_inner inner iterations; cvt likely
# toggles a control-variable transform — confirm against IncrementalConfig.
config=IncrementalConfig(n_outer=3, n_inner=20, cvt=True),
forward_adjoint=dfx.BacksolveAdjoint(),
)
x_star = model(batch)
posterior = model.posterior(batch) # reuses GN Hessian from last outer iter
Composing components into a FourDVarNet
from vardax.models import FourDVarNet
from vardax import SolverConfig
from vardax.adjoints import OneStepAdjoint
# Learned 4DVar: the autoencoder prior, masked-identity observation operator
# and ConvLSTM gradient modulator defined earlier on this page.
model = FourDVarNet(
prior=ConvAEPrior(encoder=enc, decoder=dec),
obs_op=MaskedIdentity(),
grad_mod=ConvLSTMGradMod(lstm, proj),
config=SolverConfig(n_steps=15, alpha=0.2),
solver_adjoint=OneStepAdjoint(), # O(1) memory training
)
Posterior adapters
from vardax.posterior import LaplaceCovariance, GaussianMarkLikelihood
# strong_4dvar, oi and incremental stand for models built as in earlier sections.
# Classical: pair with explicit adapter
x_star = strong_4dvar(batch)
posterior = LaplaceCovariance()(x_star, strong_4dvar.as_analysis_step(), batch)
# OI / Incremental: direct call, no adapter
# (Alternatives: each line rebinds `posterior`; the export below uses the last.)
posterior = oi.posterior(batch)
posterior = incremental.posterior(batch)
# Export to population model
mark = GaussianMarkLikelihood(
posterior=posterior,
event_metadata={"event_id": "ev_001", "time": ..., "geometry": ...},  # `...` = fill in
).to_dict()