A variational analysis returns a point estimate; this layer attaches the
uncertainty around it. The design is deliberately lazy and matrix-free:
a posterior adapter never materialises the covariance matrix. Instead, it
builds a lineax operator for the Hessian (or precision) at the analysis
point and delegates inverses, diagonals (pointwise variances), and
ensemble estimates to gaussx (gaussx.inv, gaussx.diag, and
gaussx.ensemble_covariance), so structured and iterative solves come for
free. See Posterior Covariance in the Mathematical Reference for the
derivations and for when each approximation is trustworthy.
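To make the matrix-free idea concrete, here is a minimal NumPy sketch (not vardax or gaussx code). It assumes a linear observation matrix H and diagonal B and R stored as variance vectors; the helper names precision_matvec and conjugate_gradient are hypothetical. The precision \(H^\top R^{-1} H + B^{-1}\) is only ever applied to vectors, and a product \(P^* v\) is obtained with a conjugate-gradient solve, so \(P^*\) itself is never formed.

```python
import numpy as np


def precision_matvec(v, H, r_var, b_var):
    """Apply H^T R^{-1} H + B^{-1} to v without forming the matrix.

    Assumes diagonal R and B, given as variance vectors r_var and b_var.
    """
    return H.T @ ((H @ v) / r_var) + v / b_var


def conjugate_gradient(matvec, rhs, max_iter=50, tol=1e-10):
    """Solve A x = rhs for symmetric positive-definite A given only a mat-vec."""
    x = np.zeros_like(rhs)
    r = rhs - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


rng = np.random.default_rng(0)
n, k = 5, 3
H = rng.standard_normal((k, n))
r_var = np.full(k, 0.5)
b_var = np.full(n, 2.0)

# P* v via CG on the precision: the covariance is never materialised.
v = np.ones(n)
Pv = conjugate_gradient(lambda u: precision_matvec(u, H, r_var, b_var), v)

# Dense reference, only to check this toy example.
P_dense = np.linalg.inv(H.T @ np.diag(1 / r_var) @ H + np.diag(1 / b_var))
```

The dense inverse appears only to check the toy result; at realistic state sizes that line is exactly what the lazy design avoids.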
Adapters satisfy the PosteriorAdapter Protocol and all
produce the same Posterior container, so downstream diagnostics (e.g. the
posterior-agreement gate) are adapter-agnostic. Choose
LaplaceCovariance for the exact-Hessian Gaussian approximation around the
mode, GaussNewtonHessian when second derivatives of the observation
operator are unavailable or noisy, and EnsembleCovariance when samples
are cheaper than curvature.
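The adapter-agnostic design can be illustrated with a structural protocol. The sketch below is plain Python with NumPy; PosteriorAdapterLike, ToyPosterior, DiagonalPriorAdapter, and summarise are hypothetical stand-ins, and the call signature (analysis, model, batch) is assumed from the adapters' __call__ methods rather than taken from the actual PosteriorAdapter definition.

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Protocol, runtime_checkable

import numpy as np


@dataclass
class ToyPosterior:
    """Hypothetical stand-in for vardax.Posterior (a dataclass, not an eqx.Module)."""

    mean: np.ndarray
    cov: Optional[Any] = None
    samples: Optional[np.ndarray] = None
    provenance: dict = field(default_factory=dict)


@runtime_checkable
class PosteriorAdapterLike(Protocol):
    """Assumed adapter shape: callable on (analysis, model, batch)."""

    def __call__(self, analysis: Any, model: Any, batch: Any) -> ToyPosterior: ...


class DiagonalPriorAdapter:
    """Toy adapter that reports the prior variance as the posterior covariance."""

    def __init__(self, prior_var: np.ndarray):
        self.prior_var = prior_var

    def __call__(self, analysis, model, batch) -> ToyPosterior:
        return ToyPosterior(
            mean=analysis,
            cov=self.prior_var,
            provenance={"adapter": "DiagonalPriorAdapter"},
        )


def summarise(adapter: PosteriorAdapterLike, analysis: np.ndarray) -> dict:
    # Downstream code touches only the shared container, never adapter internals.
    post = adapter(analysis, model=None, batch=None)
    return {"adapter": post.provenance.get("adapter"), "dim": post.mean.shape[0]}
```

Any object with a matching __call__ satisfies the protocol, so summarise works unchanged whichever adapter produced the posterior.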
Attributes:
    mean (Float[Array, ...]): Posterior mean / MAP estimate.
    cov (lineax.AbstractLinearOperator | None): Posterior covariance as a
        lazy linear operator, or None for amortized / sampling-only heads.
        It is never materialised into a dense matrix unless the user asks.
    samples (Float[Array, ...] | None): Optional (B, M, ...) posterior
        samples from a flow head, ensemble, or score-based diffusion. None
        for analytical adapters.
    provenance (dict[str, Any]): Free-form dict carrying audit info, at
        minimum {forward_model_id, n_iter, J_star, converged,
        vardax_version}. Consumed by GaussianMarkLikelihood when
        serialising for downstream population models.
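Because provenance is free-form, a downstream consumer may want to check for the documented minimum keys before serialising. A minimal sketch; missing_provenance_keys is a hypothetical helper, not part of vardax.

```python
# The documented minimum provenance keys.
REQUIRED_PROVENANCE_KEYS = frozenset(
    {"forward_model_id", "n_iter", "J_star", "converged", "vardax_version"}
)


def missing_provenance_keys(provenance: dict) -> list[str]:
    """Return the documented minimum keys absent from provenance, sorted."""
    return sorted(REQUIRED_PROVENANCE_KEYS.difference(provenance))
```

For example, missing_provenance_keys({"adapter": "LaplaceCovariance"}) returns all five documented keys, sorted.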
```python
from typing import Any

import equinox as eqx
import lineax as lx
from jaxtyping import Array, Float


class Posterior(eqx.Module):
    """Posterior container — Decision D10.

    Attributes:
        mean: Posterior mean / MAP estimate.
        cov: Posterior covariance as ``lineax.AbstractLinearOperator``,
            or ``None`` for amortized / sampling-only heads. Stored
            lazily — never materialised into a dense matrix unless the
            user asks.
        samples: Optional ``(B, M, ...)`` posterior samples (from a
            flow head, ensemble, or score-based diffusion). ``None``
            for analytical adapters.
        provenance: Free-form dict carrying audit info — at minimum
            ``{forward_model_id, n_iter, J_star, converged,
            vardax_version}``. Consumed by
            [`GaussianMarkLikelihood`][vardax.GaussianMarkLikelihood]
            when serialising for downstream population models.

    Examples:
        >>> import jax.numpy as jnp
        >>> import vardax as vdx
        >>> post = vdx.Posterior(mean=jnp.zeros(4))
        >>> post.mean.shape
        (4,)
        >>> post.cov is None
        True
        >>> post.provenance
        {}
    """

    mean: Float[Array, ...]
    cov: lx.AbstractLinearOperator | None = None
    samples: Float[Array, ...] | None = None
    provenance: dict[str, Any] = eqx.field(default_factory=dict)
```
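Laplace approximation at MAP. Both prior_cov_op (the background-error
covariance \(B\)) and obs_cov_op (the observation-error covariance
\(R\)) are required so the adapter can build \(P^*\) lazily. They should
match the operators used by the analysis method that produced analysis.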
Examples:

```python
>>> import jax, jax.numpy as jnp, lineax as lx
>>> import vardax as vdx
>>> from types import SimpleNamespace
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((4,), jnp.float32))
>>> laplace = vdx.LaplaceCovariance(prior_cov_op=eye, obs_cov_op=eye)
>>> model = SimpleNamespace(obs_op=vdx.LinearObs(H_mat=eye))
>>> post = laplace(jnp.zeros(4), model, batch=None)
>>> post.mean.shape
(4,)
>>> # B = R = H = I, so P* = (I + I)^{-1} = I / 2
>>> bool(jnp.allclose(post.cov.mv(jnp.ones(4)), 0.5, atol=1e-2))
True
```
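The doctest's expected value follows from \(P^* = (B^{-1} + H^\top R^{-1} H)^{-1}\) with \(B = R = H = I\). A dense NumPy cross-check of the same formula (dense_laplace_cov is a hypothetical helper; it materialises every matrix, so it is only useful for small toy problems):

```python
import numpy as np


def dense_laplace_cov(B, R, H):
    """Dense reference for P* = (B^{-1} + H^T R^{-1} H)^{-1}; small problems only."""
    precision = np.linalg.inv(B) + H.T @ np.linalg.solve(R, H)
    return np.linalg.inv(precision)


eye = np.eye(4)
P_identity = dense_laplace_cov(eye, eye, eye)  # (I + I)^{-1} = I / 2

# Observe only the first two state components, with noise variance 0.25.
H_partial = np.eye(4)[:2]
P_partial = dense_laplace_cov(eye, 0.25 * np.eye(2), H_partial)
```

In the second case the observed components shrink from prior variance 1 to \(1 / (1 + 4) = 0.2\), while the unobserved components keep their prior variance of 1.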
Source code in src/vardax/_src/posterior/laplace.py
```python
from typing import Any

import equinox as eqx
import gaussx
import lineax as lx
from jaxtyping import Array, Float

# Posterior, _inverse_op and _CG_SOLVER come from elsewhere in vardax (not shown).


class LaplaceCovariance(eqx.Module):
    r"""Laplace approximation at MAP.

    Attributes:
        prior_cov_op: Background-error covariance $B$.
        obs_cov_op: Observation-error covariance $R$.

    Both required so the adapter can build $P^*$ lazily. They
    should match the operators used by the analysis method that
    produced ``analysis``.

    Examples:
        >>> import jax, jax.numpy as jnp, lineax as lx
        >>> import vardax as vdx
        >>> from types import SimpleNamespace
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((4,), jnp.float32))
        >>> laplace = vdx.LaplaceCovariance(prior_cov_op=eye, obs_cov_op=eye)
        >>> model = SimpleNamespace(obs_op=vdx.LinearObs(H_mat=eye))
        >>> post = laplace(jnp.zeros(4), model, batch=None)
        >>> post.mean.shape
        (4,)
        >>> # B = R = H = I, so P* = (I + I)^{-1} = I / 2
        >>> bool(jnp.allclose(post.cov.mv(jnp.ones(4)), 0.5, atol=1e-2))
        True
    """

    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator

    def __call__(
        self,
        analysis: Float[Array, ...],
        model: Any,  # AnalysisStep-compliant
        batch: Any,
    ) -> Posterior:
        """Build the Laplace posterior at ``analysis``.

        Args:
            analysis: MAP / posterior mean.
            model: The analysis-step instance (used to recover the
                observation operator). Expected to have a ``.model`` or
                be the underlying ``eqx.Module`` directly.
            batch: The batch the analysis was computed against (used to
                recover the mask, instrument bookkeeping, etc.).

        Returns:
            [`Posterior`][vardax.Posterior] with ``mean = analysis`` and
            ``cov`` as a lazy ``AbstractLinearOperator`` representing $P^*$.
        """
        # Pull the obs operator off either an .as_analysis_step()
        # wrapper or the raw model. Both expose ``obs_op``.
        underlying = getattr(model, "model", model)
        obs_op = getattr(underlying, "obs_op", None)
        if obs_op is None:
            raise AttributeError(
                "LaplaceCovariance requires the model to expose `obs_op`; "
                "got a model without one."
            )
        H = obs_op.linearize(analysis)
        # P*^{-1} = H^T R^{-1} H + B^{-1}
        # Build lazily via operator composition (gaussx.sandwich would
        # materialise the Jacobian, breaking the matrix-free design);
        # the inverses are gaussx.inv, which dispatches to closed forms
        # for structured operators and falls back to lineax.CG mat-vecs.
        precision = H.transpose() @ _inverse_op(self.obs_cov_op) @ H + _inverse_op(
            self.prior_cov_op
        )
        precision_tagged = lx.TaggedLinearOperator(
            precision, lx.positive_semidefinite_tag
        )
        return Posterior(
            mean=analysis,
            cov=gaussx.inv(precision_tagged, solver=_CG_SOLVER),
            samples=None,
            provenance={"adapter": "LaplaceCovariance"},
        )
```
Functionally similar to
LaplaceCovariance — both compute
\((B^{-1} + H^\top R^{-1} H)^{-1}\) (or its 4D extension)
— but the GN-Hessian adapter is the recommended path for
IncrementalFourDVar where the Hessian is already
materialised by the inner CG solver.
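The relation between the two adapters follows from the Hessian of a 3D-Var-style cost. Assuming the standard quadratic background and observation terms, with background state \(x_b\), observations \(y\), observation operator \(h\), and its Jacobian \(H\) at \(x\):

\[
J(x) = \tfrac{1}{2}(x - x_b)^\top B^{-1}(x - x_b)
     + \tfrac{1}{2}\big(y - h(x)\big)^\top R^{-1}\big(y - h(x)\big)
\]

\[
\nabla^2 J(x) = B^{-1} + H^\top R^{-1} H
  - \sum_i \big[R^{-1}\big(y - h(x)\big)\big]_i \, \nabla^2 h_i(x)
\]

\[
\nabla^2 J(x) \approx B^{-1} + H^\top R^{-1} H
  \qquad \text{(Gauss-Newton: the second-derivative term is dropped)}
\]

Inverting the last line gives \(P^* = (B^{-1} + H^\top R^{-1} H)^{-1}\). The dropped term vanishes when \(h\) is linear or when the residual \(y - h(x)\) is zero at the analysis; otherwise the Gauss-Newton precision is an approximation to the curvature of \(J\).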
Attributes:
    prior_cov_op (AbstractLinearOperator): \(B\).
    obs_cov_op (AbstractLinearOperator): \(R\).
    n_krylov (int): Maximum Krylov iterations for mat-vec evaluations.
Source code in src/vardax/_src/posterior/gauss_newton.py
```python
from typing import Any

import equinox as eqx
import gaussx
import lineax as lx
from jaxtyping import Array, Float

# Posterior, _inverse_op and _CG_SOLVER come from elsewhere in vardax (not shown).


class GaussNewtonHessian(eqx.Module):
    r"""Gauss-Newton Hessian inversion at MAP.

    Functionally similar to [`LaplaceCovariance`][vardax.LaplaceCovariance]
    — both compute $(B^{-1} + H^\top R^{-1} H)^{-1}$ (or its 4D extension)
    — but the GN-Hessian adapter is the recommended path for
    ``IncrementalFourDVar`` where the Hessian is already materialised by
    the inner CG solver.

    Attributes:
        prior_cov_op: $B$.
        obs_cov_op: $R$.
        n_krylov: Maximum Krylov iterations for mat-vec evaluations.
    """

    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    n_krylov: int = eqx.field(static=True, default=50)

    def __call__(
        self,
        analysis: Float[Array, ...],
        model: Any,
        batch: Any,
    ) -> Posterior:
        """Build the GN-Hessian posterior at ``analysis``."""
        underlying = getattr(model, "model", model)
        obs_op = getattr(underlying, "obs_op", None)
        if obs_op is None:
            raise AttributeError(
                "GaussNewtonHessian requires the model to expose `obs_op`."
            )
        H = obs_op.linearize(analysis)
        # Same lazy precision as LaplaceCovariance: H^T R^{-1} H + B^{-1}.
        precision = H.transpose() @ _inverse_op(self.obs_cov_op) @ H + _inverse_op(
            self.prior_cov_op
        )
        precision_tagged = lx.TaggedLinearOperator(
            precision, lx.positive_semidefinite_tag
        )
        return Posterior(
            mean=analysis,
            cov=gaussx.inv(precision_tagged, solver=_CG_SOLVER),
            samples=None,
            # n_krylov is recorded here; the solve above uses the
            # shared _CG_SOLVER.
            provenance={
                "adapter": "GaussNewtonHessian",
                "n_krylov": self.n_krylov,
            },
        )
```
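A scalar toy shows when the Gauss-Newton precision matches the exact Hessian of the cost. It is plain Python with an assumed quadratic observation operator \(h(x) = x^2\), not a vardax operator; cost_hessians is a hypothetical helper. The two Hessians agree when the residual at the analysis is zero and differ otherwise.

```python
def cost_hessians(x, y, x_b, b_var, r_var, h, dh, d2h):
    """Exact and Gauss-Newton Hessians of the scalar cost
    J(x) = (x - x_b)^2 / (2 b_var) + (y - h(x))^2 / (2 r_var).
    """
    gauss_newton = 1.0 / b_var + dh(x) ** 2 / r_var
    # The exact Hessian adds the residual-weighted second-derivative term.
    exact = gauss_newton - (y - h(x)) * d2h(x) / r_var
    return exact, gauss_newton


def h(x):
    return x**2  # hypothetical nonlinear observation operator


def dh(x):
    return 2 * x


def d2h(x):
    return 2.0


# Zero residual (y == h(x)): the two Hessians coincide.
exact0, gn0 = cost_hessians(1.5, y=2.25, x_b=1.0, b_var=1.0, r_var=0.1, h=h, dh=dh, d2h=d2h)

# Non-zero residual: the dropped term is visible.
exact1, gn1 = cost_hessians(1.5, y=3.25, x_b=1.0, b_var=1.0, r_var=0.1, h=h, dh=dh, d2h=d2h)
```

Here gn1 is 91 while exact1 is 71. Because \(B^{-1}\) and \(H^\top R^{-1} H\) are both positive semi-definite, the Gauss-Newton precision is always positive semi-definite, whereas the exact Hessian can lose definiteness for large residuals; this is one reason to prefer the GN form when second derivatives are unreliable.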
```python
from typing import Any

import equinox as eqx
import gaussx
import jax.numpy as jnp
from jaxtyping import Array, Float

# Posterior comes from elsewhere in vardax (not shown).


class EnsembleCovariance(eqx.Module):
    r"""Sample-covariance posterior from an ensemble of analyses.

    Attributes:
        n_members: Expected ensemble size (used for diagnostics).
        inflation: Multiplicative inflation factor $\lambda$ applied to the
            sample covariance. Default 1.0 (no inflation).
    """

    n_members: int = eqx.field(static=True)
    inflation: float = eqx.field(static=True, default=1.0)

    def __call__(
        self,
        analyses: Float[Array, "M ..."],
        model: Any,
        batch: Any,
    ) -> Posterior:
        """Build the ensemble posterior from per-member analyses.

        Args:
            analyses: Stack of ``M`` analyses with shape ``(M, ...)``.
            model: AnalysisStep instance (not used directly; kept for
                interface compatibility).
            batch: The batch (not used directly; kept for interface
                compatibility).

        Returns:
            [`Posterior`][vardax.Posterior] with ``mean`` = ensemble mean
            and ``cov`` = inflated, Bessel-corrected sample covariance as a
            lazy low-rank operator from ``gaussx.ensemble_covariance``
            (never assembled densely).
        """
        m = analyses.shape[0]
        # The configured ``n_members`` is advisory rather than enforced:
        # a runtime mismatch with ``m`` is allowed and not checked.
        mean = jnp.mean(analyses, axis=0)
        # Flatten everything but the leading ensemble dim; the
        # Bessel-corrected sample covariance comes from gaussx as a
        # rank-(m-1) LowRankUpdate operator (never dense), scaled by
        # the multiplicative inflation factor.
        cov_op = self.inflation * gaussx.ensemble_covariance(
            analyses.reshape(m, -1), bessel=True
        )
        return Posterior(
            mean=mean,
            cov=cov_op,
            samples=analyses,
            provenance={
                "adapter": "EnsembleCovariance",
                "n_members": int(m),
                "inflation": float(self.inflation),
            },
        )
```
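The low-rank structure described in the EnsembleCovariance comments (a Bessel-corrected sample covariance of rank at most \(M - 1\)) can be sketched in NumPy. This is not gaussx's implementation; ensemble_anomalies and ensemble_cov_matvec are hypothetical helpers. With scaled anomalies \(X'\), the covariance is \(\lambda X'^\top X'\), and a mat-vec costs \(O(MN)\) without ever forming the \(N \times N\) matrix.

```python
import numpy as np


def ensemble_anomalies(analyses):
    """Scaled anomalies X' with shape (M, N) so that cov = X'^T X'.

    Uses the Bessel-corrected 1 / (M - 1) normalisation.
    """
    m = analyses.shape[0]
    flat = analyses.reshape(m, -1)
    return (flat - flat.mean(axis=0)) / np.sqrt(m - 1)


def ensemble_cov_matvec(anomalies, v, inflation=1.0):
    """Apply inflation * X'^T X' to v without forming the N x N covariance."""
    return inflation * (anomalies.T @ (anomalies @ v))


rng = np.random.default_rng(1)
analyses = rng.standard_normal((8, 4, 5))  # M = 8 members, state shape (4, 5)
A = ensemble_anomalies(analyses)
v = rng.standard_normal(20)

# Dense reference: rows are members, so rowvar=False; np.cov uses ddof=1 by default.
dense = 1.1 * np.cov(analyses.reshape(8, -1), rowvar=False)
```

With 8 members and a 20-dimensional state, the anomaly matrix has rank 7, so the covariance is rank-deficient by construction; this is why inflation alone does not fill in directions the ensemble never sampled.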
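Serialise a Gaussian posterior into a mark-likelihood dict.

Attributes:
    posterior (Posterior): A vardax Posterior (with Gaussian cov).
    event_metadata (dict[str, Any]): Free-form dict (event_id, time,
        geometry, instruments_used, …). Passed through verbatim into the
        output.

Use .to_dict() to get a JSON-friendly representation that downstream
consumers (GeoCatalog, hierarchical Bayesian models, DuckDB ingest) can
store as-is.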
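The point of .to_dict() is that the result contains only lists, dicts, strings, numbers, and None, so json.dumps accepts it directly. A NumPy sketch of that contract; to_mark_dict is a hypothetical helper, and since vardax's own _to_list is not shown here, its behaviour is assumed to be "array to nested list".

```python
import json

import numpy as np


def to_mark_dict(mean, cov_diag=None, samples=None, provenance=None, event_metadata=None):
    """Hypothetical NumPy analogue of GaussianMarkLikelihood.to_dict()."""

    def as_list(x):
        # Arrays become nested Python lists; None stays None.
        return None if x is None else np.asarray(x).tolist()

    return {
        "mean": as_list(mean),
        "cov_diag": as_list(cov_diag),
        "samples": as_list(samples),
        "provenance": dict(provenance or {}),
        "event_metadata": dict(event_metadata or {}),
    }


mark = to_mark_dict(
    mean=np.zeros(3),
    cov_diag=np.full(3, 0.5),
    provenance={"adapter": "LaplaceCovariance"},
    event_metadata={"event_id": "e1"},
)
payload = json.dumps(mark)  # plain lists, dicts and None serialise directly
```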
```python
from typing import Any

import equinox as eqx
import gaussx

# Posterior and _to_list come from elsewhere in vardax (not shown).


class GaussianMarkLikelihood(eqx.Module):
    """Serialise a Gaussian posterior into a mark-likelihood dict.

    Attributes:
        posterior: A vardax ``Posterior`` (with Gaussian ``cov``).
        event_metadata: Free-form dict (event_id, time, geometry,
            instruments_used, …). Passed through verbatim into the
            output.

    Use ``.to_dict()`` to get a JSON-friendly representation that
    downstream consumers (GeoCatalog, hierarchical Bayesian models,
    DuckDB ingest) can store as-is.

    Examples:
        >>> import jax.numpy as jnp
        >>> import vardax as vdx
        >>> post = vdx.Posterior(mean=jnp.zeros(3))
        >>> mark = vdx.GaussianMarkLikelihood(
        ...     posterior=post, event_metadata={"event_id": "e1"}
        ... )
        >>> sorted(mark.to_dict())
        ['cov_diag', 'event_metadata', 'mean', 'provenance', 'samples']
    """

    posterior: Posterior
    event_metadata: dict[str, Any] = eqx.field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly mark-likelihood representation."""
        # _cov_diag() returns None when there is no covariance or when the
        # operator cannot report its diagonal; emit None in both cases.
        cov_diag = self._cov_diag()
        samples = self.posterior.samples
        return {
            "mean": _to_list(self.posterior.mean),
            "cov_diag": _to_list(cov_diag) if cov_diag is not None else None,
            "samples": _to_list(samples) if samples is not None else None,
            "provenance": dict(self.posterior.provenance),
            "event_metadata": dict(self.event_metadata),
        }

    def _cov_diag(self):
        """Extract the diagonal of the covariance operator via ``gaussx.diag``.

        Returns ``None`` if there is no covariance, or if the operator
        raises ``AttributeError`` / ``NotImplementedError`` when asked
        for its diagonal.
        """
        cov = self.posterior.cov
        if cov is None:
            return None
        try:
            return gaussx.diag(cov)
        except (AttributeError, NotImplementedError):
            # Lazy operators don't always expose a diagonal; emit a
            # null marker instead.
            return None
```
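When an operator exposes only mat-vecs, its exact diagonal costs one mat-vec per state dimension, since \(\mathrm{diag}(A)_i = e_i^\top A e_i\). This NumPy sketch shows that fallback strategy; diag_by_probing is a hypothetical helper, and which strategy gaussx.diag uses for a given operator type is not stated here. For a covariance defined as an inverse, as in the Laplace and GN adapters, each of those mat-vecs is itself an iterative solve.

```python
import numpy as np


def diag_by_probing(matvec, n):
    """Exact diagonal of a linear operator from n mat-vecs with unit vectors e_i."""
    out = np.empty(n)
    for i in range(n):
        e_i = np.zeros(n)
        e_i[i] = 1.0
        out[i] = matvec(e_i)[i]  # e_i^T A e_i
    return out


rng = np.random.default_rng(2)
L = rng.standard_normal((5, 5))
A = L @ L.T + 5 * np.eye(5)  # symmetric positive-definite test operator
d = diag_by_probing(lambda v: A @ v, 5)
```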