Posterior Adapter
Per Decision D10, every vardax analysis emits a Posterior container —
not just a point estimate. The posterior carries enough information to feed
downstream population models (Tier V TMTPP, hierarchical Bayesian
inversions) without re-running the inner inference.
Two of the seven Layer 2 classes have closed-form fast paths that emit the posterior as part of the analysis call:
- OptimalInterpolation: the closed-form \(P^* = (B^{-1} + H^\top R^{-1} H)^{-1}\) is returned directly by .posterior(batch).
- IncrementalFourDVar: the Gauss-Newton Hessian assembled during the last outer iteration is reused to form \(P^*\).

For the other five (ThreeDVar, StrongFourDVar, WeakFourDVar,
FourDVarNet, AmortizedPosterior), pair the analysis with an
explicit PosteriorAdapter.
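
The OptimalInterpolation closed form can be sanity-checked numerically. The sketch below is plain NumPy on small dense matrices, not the vardax or gaussx API, and both function names are hypothetical. It compares the information form \(P^* = (B^{-1} + H^\top R^{-1} H)^{-1}\) with the algebraically equivalent gain form \(P^* = (I - KH)B\), \(K = BH^\top(HBH^\top + R)^{-1}\), which follows from the Woodbury identity.

```python
import numpy as np


def oi_posterior_cov(B: np.ndarray, H: np.ndarray, R: np.ndarray) -> np.ndarray:
    """Information form: P* = (B^{-1} + H^T R^{-1} H)^{-1} (dense, small problems only)."""
    precision = np.linalg.inv(B) + H.T @ np.linalg.solve(R, H)
    return np.linalg.inv(precision)


def oi_posterior_cov_gain_form(B: np.ndarray, H: np.ndarray, R: np.ndarray) -> np.ndarray:
    """Gain form: P* = (I - K H) B with K = B H^T (H B H^T + R)^{-1}."""
    S = H @ B @ H.T + R                    # innovation covariance (m x m)
    K = np.linalg.solve(S, H @ B).T        # (S^{-1} H B)^T = B H^T S^{-1} since S, B symmetric
    return (np.eye(B.shape[0]) - K @ H) @ B


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
B = A @ A.T + 4.0 * np.eye(4)              # SPD background covariance
H = rng.normal(size=(2, 4))                # linear observation operator
R = 0.5 * np.eye(2)                        # observation-error covariance
P1 = oi_posterior_cov(B, H, R)
P2 = oi_posterior_cov_gain_form(B, H, R)
print(np.allclose(P1, P2))  # True
```

The gain form only inverts the \(m \times m\) innovation covariance, which is why it is usually preferred when there are far fewer observations than state variables.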
The Posterior container
```python
import equinox as eqx
from jax import Array
from lineax import AbstractLinearOperator

class Posterior(eqx.Module):
    mean: Array                         # MAP / posterior mean
    cov: AbstractLinearOperator | None  # gaussx / lineax operator (may be None)
    samples: Array | None               # (B, M, ...) for ensembles / flow samples
    provenance: dict                    # forward_model_id, n_iter, J_star, …
```
Conventions:

- mean: always populated.
- cov: None for amortized samples / score-based heads. Otherwise an AbstractLinearOperator supporting mat-vec (not necessarily materialised).
- samples: None for Laplace / GN-Hessian (Gaussian-only) posteriors. Populated for ensemble / flow heads.
- provenance: dict with at minimum {forward_model_id, obs_ops_used, n_iter, J_star, converged, gaussx_op_hash, model_hash}.

Three posterior adapter families
LaplaceCovariance: cheap, Gaussian-only
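At the MAP \(x^*\), the Laplace approximation gives

\[
P^* \approx \left(B^{-1} + H'^\top R^{-1} H'\right)^{-1},
\]

where \(H' = \partial H / \partial x\) is evaluated at \(x^*\). The result is returned as an AbstractLinearOperator, so products with \(P^*\) reduce to cheap lineax.CG solves against the precision; the full matrix is materialised only on request.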
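
A minimal NumPy sketch of the "operator, not matrix" idea, assuming dense \(B^{-1}\), \(R^{-1}\) and \(H'\) are available for brevity: only the precision mat-vec \(v \mapsto B^{-1}v + H'^\top R^{-1} H' v\) is defined, and products with \(P^*\) come from a conjugate-gradient solve, the role lineax.CG plays here. The function names are hypothetical, not vardax API.

```python
import numpy as np


def conjugate_gradient(matvec, b, max_iter=100, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only v -> A v."""
    x = np.zeros_like(b)
    r = b.copy()                      # residual b - A x for the initial guess x = 0
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def laplace_precision_matvec(B_inv, H_lin, R_inv):
    """v -> (B^{-1} + H'^T R^{-1} H') v, never forming the N x N matrix itself."""
    return lambda v: B_inv @ v + H_lin.T @ (R_inv @ (H_lin @ v))


def apply_posterior_cov(B_inv, H_lin, R_inv, u, max_iter=100):
    """P* u = (B^{-1} + H'^T R^{-1} H')^{-1} u, computed by a CG solve."""
    return conjugate_gradient(laplace_precision_matvec(B_inv, H_lin, R_inv), u, max_iter)


rng = np.random.default_rng(1)
n, m = 6, 3
B_inv = np.diag(rng.uniform(1.0, 2.0, size=n))   # diagonal prior precision
H_lin = rng.normal(size=(m, n))                  # linearised observation operator H'
R_inv = np.eye(m) / 0.1                          # observation-error precision
u = rng.normal(size=n)
Pu = apply_posterior_cov(B_inv, H_lin, R_inv, u)
P_dense = np.linalg.inv(B_inv + H_lin.T @ R_inv @ H_lin)   # reference only
print(np.allclose(Pu, P_dense @ u))  # True
```

Each CG iteration costs one precision mat-vec, i.e. one application of \(H'\) and one of \(H'^\top\), which is why the cost is counted in Hessian-vector products.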
When to use. Gaussian likelihood, single mode, MAP near posterior mean
(confirmed by SBC). Default for ThreeDVar, StrongFourDVar, WeakFourDVar and FourDVarNet.
Cost. One family of Hessian-vector products per query. Posterior samples are drawn as \(x^* + (H'^\top R^{-1} H' + B^{-1})^{-1/2} \xi\) with \(\xi \sim \mathcal{N}(0, I)\); this requires a square root of the inverse Hessian, which may be expensive for unstructured \(B\).
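
Any factor \(L\) with \(LL^\top = A\), where \(A = H'^\top R^{-1} H' + B^{-1}\) is the precision, can stand in for the symmetric \(A^{-1/2}\): \(x^* + L^{-\top}\xi\) has covariance \(L^{-\top}L^{-1} = A^{-1}\). The NumPy sketch below uses a dense Cholesky factor, which is exactly the cost warned about for large unstructured \(B\). `laplace_samples` is a hypothetical helper, not part of vardax.

```python
import numpy as np


def laplace_samples(x_star, precision, n_samples, rng):
    """Draw x* + L^{-T} xi with precision = L L^T and xi ~ N(0, I); Cov = precision^{-1}."""
    L = np.linalg.cholesky(precision)                 # lower-triangular factor
    xi = rng.standard_normal((precision.shape[0], n_samples))
    # Solve L^T z = xi, so z = L^{-T} xi has covariance L^{-T} L^{-1} = precision^{-1}.
    z = np.linalg.solve(L.T, xi)
    return x_star[:, None] + z                        # shape (N, n_samples)


rng = np.random.default_rng(2)
A = np.array([[4.0, 1.0], [1.0, 3.0]])                # toy posterior precision
x_star = np.array([1.0, -2.0])                        # toy MAP
S = laplace_samples(x_star, A, 200_000, rng)
print(np.cov(S))                                      # close to np.linalg.inv(A)
```

With a structured \(B\) (diagonal, Kronecker, low-rank plus diagonal) the factorisation can exploit that structure instead of a dense Cholesky.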
```python
class LaplaceCovariance(eqx.Module):
    def __call__(self, analysis: Array, model: AnalysisStep, batch: Batch) -> Posterior: ...
```
GaussNewtonHessian: mid-cost, exact at MAP
Krylov / Lanczos inversion of \(J''(x^*)\) via lineax.CG. For incremental
4DVar this is the natural posterior — the inner Hessian is already assembled
during the last outer iteration.
When to use. Operational incremental 4DVar (IncrementalFourDVar). The
Hessian operator from the last GN outer iteration is reused.
Cost. n_krylov mat-vec products. Materialise only the diagonal /
required marginals.
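
To illustrate "materialise only the required marginals", the NumPy sketch below computes selected variances \(P^*_{ii} = e_i^\top (J'')^{-1} e_i\) with one CG solve per requested index, each capped at n_krylov Hessian mat-vecs. It is a stand-in rather than the GaussNewtonHessian implementation: the Hessian is taken as \(I + H'^\top R^{-1} H'\) (identity prior precision, an assumption for the demo) and all names are hypothetical.

```python
import numpy as np


def cg_solve(hess_matvec, b, n_krylov, tol=1e-10):
    """Conjugate gradients for J'' x = b, truncated after n_krylov Hessian mat-vecs."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rs = r @ r
    for _ in range(n_krylov):
        if np.sqrt(rs) < tol:
            break
        Ap = hess_matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def marginal_variances(hess_matvec, n, indices, n_krylov=50):
    """P*_ii = e_i^T (J'')^{-1} e_i for the requested indices only (no N x N matrix)."""
    variances = {}
    for i in indices:
        e_i = np.zeros(n)
        e_i[i] = 1.0
        variances[i] = float(e_i @ cg_solve(hess_matvec, e_i, n_krylov))
    return variances


rng = np.random.default_rng(3)
n, m = 8, 5
H_lin = rng.normal(size=(m, n))        # linearised H' at the last outer iterate
R_inv = 4.0 * np.eye(m)


def gn_hessian_matvec(v):
    """Gauss-Newton Hessian mat-vec with identity prior precision (demo assumption)."""
    return v + H_lin.T @ (R_inv @ (H_lin @ v))


var = marginal_variances(gn_hessian_matvec, n, indices=[0, 3])
P = np.linalg.inv(np.eye(n) + H_lin.T @ R_inv @ H_lin)    # dense reference, demo only
print(np.allclose([var[0], var[3]], [P[0, 0], P[3, 3]]))  # True
```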
```python
class GaussNewtonHessian(eqx.Module):
    n_krylov: int = eqx.field(static=True, default=50)

    def __call__(self, analysis: Array, model: AnalysisStep, batch: Batch) -> Posterior: ...
```
EnsembleCovariance: non-Gaussian-aware
Delegates to filterax: builds \(M\) analyses from \(M\) perturbed initial
conditions / observation realisations and returns their sample covariance.
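
For a linear-Gaussian problem the ensemble recipe can be checked against the closed form. The NumPy sketch below updates each perturbed background draw with its own perturbed observation (the stochastic "perturbed observations" scheme, used here as an illustration, not as filterax's method); the sample covariance of the resulting analyses approaches \(P^* = (I - KH)B\). `perturbed_analysis_ensemble` is a hypothetical name.

```python
import numpy as np


def perturbed_analysis_ensemble(x_b, B, y, H, R, n_members, rng):
    """M analyses from M perturbed backgrounds and M perturbed observation realisations."""
    K = B @ H.T @ np.linalg.inv(H @ B @ H.T + R)               # Kalman gain
    xb_pert = rng.multivariate_normal(x_b, B, size=n_members)  # (M, N)
    y_pert = rng.multivariate_normal(y, R, size=n_members)     # (M, P)
    innovations = y_pert - xb_pert @ H.T                       # (M, P)
    return xb_pert + innovations @ K.T                         # (M, N) analyses


rng = np.random.default_rng(4)
B = np.array([[1.0, 0.5, 0.0], [0.5, 1.0, 0.5], [0.0, 0.5, 1.0]])
H = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])               # observe first and last state
R = 0.25 * np.eye(2)
analyses = perturbed_analysis_ensemble(np.zeros(3), B, np.array([1.0, -1.0]), H, R, 50_000, rng)
C = np.cov(analyses, rowvar=False)                             # (N, N) sample covariance
K = B @ H.T @ np.linalg.inv(H @ B @ H.T + R)
P_star = (np.eye(3) - K @ H) @ B
print(np.abs(C - P_star).max())  # small; shrinks roughly like 1 / sqrt(M)
```

In the non-Gaussian settings where EnsembleCovariance is recommended, there is no closed form to compare against; the sample covariance (and the samples themselves) then carry information a Laplace approximation would miss.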
When to use. Multimodal posterior, non-Gaussian likelihood, hybrid EnVar inversion.
Cost. \(M\) vardax analyses (run in parallel via eqx.filter_vmap). Materialising the full sample covariance costs \(O(M N^2)\); localised / shrunk estimators are available via filterax.
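
The inflation and localization_radius knobs can be pictured with the NumPy sketch below. It assumes that inflation multiplies the sample covariance and that localisation is a Schur (element-wise) product with a Gaussian taper \(\exp(-d^2 / 2L^2)\) over grid distances. filterax's actual estimators (compact-support tapers, shrinkage) may differ, and `localised_sample_cov` is a hypothetical name.

```python
import numpy as np


def localised_sample_cov(members, coords, inflation=1.0, localization_radius=None):
    """members: (M, N) ensemble of analyses; coords: (N,) grid positions -> (N, N) covariance."""
    M = members.shape[0]
    X = (members - members.mean(axis=0)) / np.sqrt(M - 1)  # (M, N) scaled anomalies
    C = inflation * (X.T @ X)            # full materialisation: O(M N^2) work, O(N^2) memory
    if localization_radius is not None:
        d = np.abs(coords[:, None] - coords[None, :])       # pairwise grid distances
        taper = np.exp(-0.5 * (d / localization_radius) ** 2)
        C = C * taper                    # Schur product with a PSD taper keeps C PSD
    return C


rng = np.random.default_rng(5)
members = rng.normal(size=(10, 40))      # M = 10 members, N = 40 grid points
coords = np.arange(40.0)
C_raw = localised_sample_cov(members, coords)
C_loc = localised_sample_cov(members, coords, inflation=1.1, localization_radius=3.0)
print(np.allclose(C_raw, np.cov(members, rowvar=False)))  # True
print(np.linalg.matrix_rank(C_raw))                       # at most M - 1 = 9
print(np.allclose(np.diag(C_loc), 1.1 * np.diag(C_raw)))  # True: taper is 1 on the diagonal
```

Without localisation the estimate has rank at most \(M - 1\), so for \(M \ll N\) it is better kept in anomaly form \(XX^\top\) (\(O(MN)\) storage) than materialised.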
```python
class EnsembleCovariance(eqx.Module):
    n_members: int = eqx.field(static=True)
    inflation: float = 1.0
    localization_radius: float | None = None

    def __call__(self, analyses: Array, model: AnalysisStep, batch: Batch) -> Posterior: ...
```
Posterior selection heuristic

| Inference family | Default posterior adapter |
|---|---|
| OptimalInterpolation | closed-form via .posterior(batch) (no adapter) |
| ThreeDVar | LaplaceCovariance at MAP |
| StrongFourDVar, WeakFourDVar | LaplaceCovariance at MAP |
| IncrementalFourDVar | GaussNewtonHessian via .posterior(batch) (Hessian reused) |
| FourDVarNet | LaplaceCovariance at converged solver state |
| AmortizedPosterior | direct sampling (Posterior.samples, cov=None) |
| Hybrid EnVarFourDVar (Epic 9) | EnsembleCovariance |
| Hybrid EnVar (EnVarDA*, Epic 9) | EnsembleCovariance |

Export to population models: GaussianMarkLikelihood
Per-event posteriors feed Tier V (TMTPP, hierarchical Bayesian) via a mark-likelihood serialisation:
```python
import equinox as eqx


class GaussianMarkLikelihood(eqx.Module):
    """Posterior → mark-likelihood for population models (Tier V, Decision D10)."""

    posterior: Posterior   # the container defined above
    event_metadata: dict   # event_id, time, location, instruments_used, …

    def to_dict(self) -> dict:
        return {
            "event_id": self.event_metadata["event_id"],
            "time": self.event_metadata["time"],
            "geometry": self.event_metadata["geometry"],
            "mean": self.posterior.mean.tolist(),
            "cov": self._serialise_cov(),  # diagonal / Cholesky / LowRank repr
            "samples": (self.posterior.samples.tolist()
                        if self.posterior.samples is not None else None),
            "provenance": self.posterior.provenance,
        }

    def _serialise_cov(self) -> dict:
        """Serialise gaussx AbstractLinearOperator to JSON-compatible form."""
        ...  # body elided
```
The serialised form is consumed by population models without re-instantiating the vardax inference, so upgrades to the Tier I-IV forward models flow through to Tier V without changes on the population-model side.
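
The body of _serialise_cov is elided above. As one possible shape for the "Cholesky" option, the NumPy sketch below stores the lower-triangular Cholesky factor of a dense covariance as a JSON-compatible dict and rebuilds the matrix on the consumer side. The dict keys ("kind", "dim", "cholesky_lower") and both function names are hypothetical; they are not the vardax or gaussx format.

```python
import json

import numpy as np


def serialise_cov_cholesky(cov: np.ndarray) -> dict:
    """Store only the n(n+1)/2 lower-triangular Cholesky entries as plain floats."""
    L = np.linalg.cholesky(cov)
    rows, cols = np.tril_indices(cov.shape[0])
    return {"kind": "cholesky", "dim": int(cov.shape[0]),
            "cholesky_lower": L[rows, cols].tolist()}


def deserialise_cov(payload: dict) -> np.ndarray:
    """Rebuild the dense covariance L L^T from the serialised factor."""
    n = payload["dim"]
    L = np.zeros((n, n))
    L[np.tril_indices(n)] = payload["cholesky_lower"]
    return L @ L.T


cov = np.array([[2.0, 0.3], [0.3, 1.0]])
payload = json.loads(json.dumps(serialise_cov_cholesky(cov)))  # JSON round trip
print(np.allclose(deserialise_cov(payload), cov))              # True
```

Shipping the factor rather than the matrix guarantees that whatever the population model rebuilds is positive semi-definite, and it halves the stored entries.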
Posterior validation gates (Decision D12)
The six-step cycle requires:

- Posterior agreement. Step 4 (emulator MAP) lies within \(1\sigma_\text{post}\) of Step 2 (physics MAP). Fails when CV(amortized) / CV(physics) > 2 or when |mean_amortized - mean_physics| / σ_physics > 1.
- Adjoint calibration. \(\|\partial H_\text{em} / \partial x - \partial H_\text{phys} / \partial x\|_\text{op} < 5\%\), measured by random-vector probing.
- Simulation-based calibration (SBC). Rank histograms are uniform across parameters, stratified by met regime. Fails when the χ² test of uniformity rejects at p < 0.01.

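The posterior-agreement thresholds translate directly into a per-parameter check. The NumPy sketch below takes means and standard deviations from the two posteriors and assumes CV means the coefficient of variation \(\sigma / |\mu|\). `posterior_agreement_failures` is a hypothetical name, a stand-in for the kind of check a helper such as assert_posterior_agreement might perform.

```python
import numpy as np


def posterior_agreement_failures(mean_am, std_am, mean_phys, std_phys,
                                 tolerance_sigma=1.0, max_cv_ratio=2.0):
    """Per-parameter boolean mask: True where the amortized posterior fails either gate."""
    mean_am, std_am = np.asarray(mean_am, dtype=float), np.asarray(std_am, dtype=float)
    mean_phys, std_phys = np.asarray(mean_phys, dtype=float), np.asarray(std_phys, dtype=float)
    # Gate 1: mean shift measured in physics-posterior standard deviations.
    z = np.abs(mean_am - mean_phys) / std_phys
    # Gate 2: coefficient of variation (sigma / |mu|) inflated by more than max_cv_ratio.
    # Undefined for zero means; a production check would need a fallback there.
    cv_ratio = (std_am / np.abs(mean_am)) / (std_phys / np.abs(mean_phys))
    return (z > tolerance_sigma) | (cv_ratio > max_cv_ratio)


fails = posterior_agreement_failures(
    mean_am=[10.0, 5.0, 2.0], std_am=[1.0, 0.5, 1.5],
    mean_phys=[10.5, 7.0, 2.0], std_phys=[1.0, 1.0, 0.5],
)
print(fails.tolist())  # [False, True, True]
```

The second parameter fails on the 2σ mean shift; the third has identical means but a CV three times larger, so it fails the spread gate.
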
These gates run on a held-out validation set in CI:
```python
# tests/test_six_step_validation.py
def test_posterior_agreement(physics_model, emulator_model, val_batches):
    for batch in val_batches:
        p_phys = LaplaceCovariance()(physics_model(batch), physics_model, batch)
        p_em = LaplaceCovariance()(emulator_model(batch), emulator_model, batch)
        assert_posterior_agreement(p_em, p_phys, tolerance_sigma=1.0)


def test_adjoint_calibration(physics_obs_op, emulator_obs_op, val_states):
    for x in val_states:
        H_phys = physics_obs_op.linearize(x)
        H_em = emulator_obs_op.linearize(x)
        op_norm = estimate_operator_norm_diff(H_phys, H_em, n_probe=20)
        assert op_norm < 0.05
```
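
estimate_operator_norm_diff is not defined in the excerpt. One way to implement random-vector probing, sketched in NumPy under the assumption that each linearisation can be applied as a function \(v \mapsto H'v\), is to take the largest \(\|(H'_\text{em} - H'_\text{phys})v\| / \|v\|\) over Gaussian probes. The signature here is hypothetical and differs from the call in the test.

```python
import numpy as np


def estimate_operator_norm_diff(apply_a, apply_b, dim, n_probe=20, seed=0):
    """Lower-bound ||A - B||_op by max ||(A - B) v|| / ||v|| over random Gaussian probes v."""
    rng = np.random.default_rng(seed)
    best = 0.0
    for _ in range(n_probe):
        v = rng.standard_normal(dim)
        best = max(best, float(np.linalg.norm(apply_a(v) - apply_b(v)) / np.linalg.norm(v)))
    return best


J_phys = np.array([[1.0, 0.0], [0.0, 2.0]])
J_em = J_phys + np.array([[0.01, 0.0], [0.0, -0.02]])  # emulator Jacobian with small errors
est = estimate_operator_norm_diff(lambda v: J_phys @ v, lambda v: J_em @ v, dim=2)
exact = np.linalg.norm(J_em - J_phys, ord=2)           # true spectral norm, about 0.02
print(est, exact)
```

Probing only ever under-estimates the operator norm, so a pass is weaker evidence than a fail; more probes, or a few power-iteration steps using the adjoints, tighten the bound.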
Provenance schema
Posterior.provenance carries at minimum:

| Key | Type | Description |
|---|---|---|
| forward_model_id | str | Identifier for the ForwardModel used (e.g. "plumax.tier1.gaussian") |
| forward_model_hash | str | Content hash from pipekit-experiment.ModelRegistry |
| obs_ops_used | list[str] | Instrument IDs whose obs operators contributed |
| n_iter | int | Total inner iterations (or outer × inner for incremental) |
| J_star | float | Final cost value |
| converged | bool | True if the convergence criterion was met |
| model_hash | str \| None | For learned heads, content hash of trained weights |
| gaussx_op_hash | str \| None | Hash of \(B\) / \(R\) operators used |
| met_source | str \| None | Met forcing identifier (e.g. "era5_2024-01-15T12Z") |
| vardax_version | str | Package version |

This schema is the contract between the inference layer (vardax) and the audit layer (catalog / database). Don't break it lightly.
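
Because the schema is a contract, a cheap guard is to validate provenance dicts against it before they reach the catalog. The plain-Python sketch below checks presence and type of every key in the table; PROVENANCE_SCHEMA and validate_provenance are hypothetical names, not part of vardax, and the example hash, instrument ID and version strings are placeholders.

```python
PROVENANCE_SCHEMA = {
    "forward_model_id": (str,),
    "forward_model_hash": (str,),
    "obs_ops_used": (list,),
    "n_iter": (int,),
    "J_star": (float,),
    "converged": (bool,),
    "model_hash": (str, type(None)),
    "gaussx_op_hash": (str, type(None)),
    "met_source": (str, type(None)),
    "vardax_version": (str,),
}


def validate_provenance(prov: dict) -> list[str]:
    """Return a list of problems; an empty list means the dict matches the schema."""
    problems = []
    for key, types in PROVENANCE_SCHEMA.items():
        if key not in prov:
            problems.append(f"missing key: {key}")
        # bool is a subclass of int in Python, so reject it explicitly for int fields.
        elif not isinstance(prov[key], types) or (int in types and isinstance(prov[key], bool)):
            problems.append(f"{key}: unexpected type {type(prov[key]).__name__}")
    ops = prov.get("obs_ops_used")
    if isinstance(ops, list) and not all(isinstance(s, str) for s in ops):
        problems.append("obs_ops_used: entries must be str")
    return problems


prov = {
    "forward_model_id": "plumax.tier1.gaussian",
    "forward_model_hash": "placeholder-hash",
    "obs_ops_used": ["instrument_a"],
    "n_iter": 42,
    "J_star": 17.3,
    "converged": True,
    "model_hash": None,
    "gaussx_op_hash": None,
    "met_source": "era5_2024-01-15T12Z",
    "vardax_version": "0.0.0",
}
print(validate_provenance(prov))  # []
```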