Posterior Covariance

A variational analysis returns a point estimate; this layer attaches the uncertainty around it. The design is deliberately lazy and matrix-free: a posterior adapter never materialises the covariance matrix. Instead it builds a lineax operator for the Hessian (or precision) at the analysis point and delegates inverses, diagonals (pointwise variances), and ensemble estimates to gaussx (gaussx.inv, gaussx.diag, and gaussx.ensemble_covariance), so structured and iterative solves come for free. See Posterior Covariance in the Mathematical Reference for derivations and when each approximation is trustworthy.
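
The matrix-free idea can be illustrated without any of the libraries named above. The sketch below is plain NumPy, not the vardax or gaussx implementation: it assumes the covariance is available only as a mat-vec closure (the hypothetical `cov_matvec`) and recovers pointwise variances by probing with unit vectors, one mat-vec per state dimension.

```python
import numpy as np


def diag_from_matvec(matvec, n):
    """Recover diag(A) from a mat-vec closure by probing with unit vectors."""
    diag = np.empty(n)
    for i in range(n):
        e_i = np.zeros(n)
        e_i[i] = 1.0
        diag[i] = matvec(e_i)[i]  # (A e_i)_i = A_ii
    return diag


# Hypothetical covariance P = D + u u^T, stored only through its factors.
d = np.array([1.0, 2.0, 3.0, 4.0])
u = np.array([0.5, -0.5, 1.0, 0.0])


def cov_matvec(v):
    # Applies P to v without ever forming the 4x4 matrix.
    return d * v + u * (u @ v)


variances = diag_from_matvec(cov_matvec, n=4)
```

Probing costs n mat-vecs, which is why structured operators with a closed-form diagonal are preferable when they are available.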

Adapters satisfy the PosteriorAdapter Protocol and all produce the same Posterior container, so downstream diagnostics (e.g. the posterior-agreement gate) are adapter-agnostic. Choose LaplaceCovariance for the exact-Hessian Gaussian approximation around the mode, GaussNewtonHessian when second derivatives of the observation operator are unavailable or noisy, and EnsembleCovariance when samples are cheaper than curvature.

Container

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

Posterior

Bases: Module

Posterior container — Decision D10.

Attributes:

Name Type Description
mean Float[Array, ...]

Posterior mean / MAP estimate.

cov AbstractLinearOperator | None

Posterior covariance as lineax.AbstractLinearOperator, or None for amortized / sampling-only heads. Stored lazily — never materialised into a dense matrix unless the user asks.

samples Float[Array, ...] | None

Optional (B, M, ...) posterior samples (from a flow head, ensemble, or score-based diffusion). None for analytical adapters.

provenance dict[str, Any]

Free-form dict carrying audit info — at minimum {forward_model_id, n_iter, J_star, converged, vardax_version}. Consumed by GaussianMarkLikelihood when serialising for downstream population models.

Examples:

>>> import jax.numpy as jnp
>>> import vardax as vdx
>>> post = vdx.Posterior(mean=jnp.zeros(4))
>>> post.mean.shape
(4,)
>>> post.cov is None
True
>>> post.provenance
{}
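
The default provenance in the example above is an empty dict, so the minimum audit keys appear to be a convention rather than something the container enforces. A consumer that relies on them may want to check first. A minimal sketch, where `missing_provenance_keys` is a hypothetical helper and not part of vardax:

```python
# Minimum audit keys listed in the attribute description above.
REQUIRED_PROVENANCE_KEYS = frozenset(
    {"forward_model_id", "n_iter", "J_star", "converged", "vardax_version"}
)


def missing_provenance_keys(provenance):
    """Return the required audit keys absent from ``provenance``, sorted."""
    return sorted(REQUIRED_PROVENANCE_KEYS - set(provenance))


# Illustrative values only; none of these come from a real run.
provenance = {
    "forward_model_id": "toy-model",
    "n_iter": 12,
    "J_star": 0.37,
    "converged": True,
}
missing = missing_provenance_keys(provenance)
```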
Source code in src/vardax/_src/posterior/container.py
class Posterior(eqx.Module):
    """Posterior container — Decision D10.

    Attributes:
        mean: Posterior mean / MAP estimate.
        cov: Posterior covariance as ``lineax.AbstractLinearOperator``,
            or ``None`` for amortized / sampling-only heads. Stored
            lazily — never materialised into a dense matrix unless the
            user asks.
        samples: Optional ``(B, M, ...)`` posterior samples (from a
            flow head, ensemble, or score-based diffusion). ``None``
            for analytical adapters.
        provenance: Free-form dict carrying audit info — at minimum
            ``{forward_model_id, n_iter, J_star, converged,
            vardax_version}``. Consumed by
            [`GaussianMarkLikelihood`][vardax.GaussianMarkLikelihood]
            when serialising for downstream population models.

    Examples:
        >>> import jax.numpy as jnp
        >>> import vardax as vdx
        >>> post = vdx.Posterior(mean=jnp.zeros(4))
        >>> post.mean.shape
        (4,)
        >>> post.cov is None
        True
        >>> post.provenance
        {}
    """

    mean: Float[Array, ...]
    cov: lx.AbstractLinearOperator | None = None
    samples: Float[Array, ...] | None = None
    provenance: dict[str, Any] = eqx.field(default_factory=dict)

Adapters

LaplaceCovariance

Bases: Module

Laplace approximation at MAP.

Attributes:

Name Type Description
prior_cov_op AbstractLinearOperator

Background-error covariance \(B\).

obs_cov_op AbstractLinearOperator

Observation-error covariance \(R\).

Both required so the adapter can build \(P^*\) lazily. They should match the operators used by the analysis method that produced analysis.

Examples:

>>> import jax, jax.numpy as jnp, lineax as lx
>>> import vardax as vdx
>>> from types import SimpleNamespace
>>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((4,), jnp.float32))
>>> laplace = vdx.LaplaceCovariance(prior_cov_op=eye, obs_cov_op=eye)
>>> model = SimpleNamespace(obs_op=vdx.LinearObs(H_mat=eye))
>>> post = laplace(jnp.zeros(4), model, batch=None)
>>> post.mean.shape
(4,)
>>> # B = R = H = I, so P* = (I + I)^{-1} = I / 2
>>> bool(jnp.allclose(post.cov.mv(jnp.ones(4)), 0.5, atol=1e-2))
True
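
For a case where the operators are not all identities, the same formula can be checked densely in NumPy. This is only a small-scale sanity check under assumed toy values for B, R and H; the adapter itself never forms these matrices.

```python
import numpy as np

# Toy problem: 3 state variables, 2 observations (values are illustrative).
B = np.diag([1.0, 2.0, 4.0])          # background-error covariance
R = np.diag([0.5, 0.5])               # observation-error covariance
H = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 1.0]])       # linear observation operator

# Precision form: P*^{-1} = H^T R^{-1} H + B^{-1}
precision = H.T @ np.linalg.inv(R) @ H + np.linalg.inv(B)
P_star = np.linalg.inv(precision)

# Equivalent Kalman (Woodbury) form: P* = B - B H^T (H B H^T + R)^{-1} H B
S = H @ B @ H.T + R
P_kalman = B - B @ H.T @ np.linalg.solve(S, H @ B)

# Pointwise posterior variances, to compare against diag(B).
posterior_var = np.diag(P_star)
```

The two forms agree, and each posterior variance is no larger than the corresponding prior variance, since observations can only add information under this Gaussian model.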
Source code in src/vardax/_src/posterior/laplace.py
class LaplaceCovariance(eqx.Module):
    r"""Laplace approximation at MAP.

    Attributes:
        prior_cov_op: Background-error covariance $B$.
        obs_cov_op: Observation-error covariance $R$.

    Both required so the adapter can build $P^*$ lazily. They
    should match the operators used by the analysis method that
    produced ``analysis``.

    Examples:
        >>> import jax, jax.numpy as jnp, lineax as lx
        >>> import vardax as vdx
        >>> from types import SimpleNamespace
        >>> eye = lx.IdentityLinearOperator(jax.ShapeDtypeStruct((4,), jnp.float32))
        >>> laplace = vdx.LaplaceCovariance(prior_cov_op=eye, obs_cov_op=eye)
        >>> model = SimpleNamespace(obs_op=vdx.LinearObs(H_mat=eye))
        >>> post = laplace(jnp.zeros(4), model, batch=None)
        >>> post.mean.shape
        (4,)
        >>> # B = R = H = I, so P* = (I + I)^{-1} = I / 2
        >>> bool(jnp.allclose(post.cov.mv(jnp.ones(4)), 0.5, atol=1e-2))
        True
    """

    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator

    def __call__(
        self,
        analysis: Float[Array, ...],
        model: Any,  # AnalysisStep-compliant
        batch: Any,
    ) -> Posterior:
        """Build the Laplace posterior at ``analysis``.

        Args:
            analysis: MAP / posterior mean.
            model: The analysis-step instance (used to recover the
                observation operator). Expected to have a ``.model``
                or be the underlying ``eqx.Module`` directly.
            batch: The batch the analysis was computed against (used
                to recover the mask, instrument bookkeeping, etc.).

        Returns:
            [`Posterior`][vardax.Posterior] with ``mean = analysis``
            and ``cov`` as a lazy ``AbstractLinearOperator``
            representing $P^*$.
        """
        # Pull the obs operator off either an .as_analysis_step()
        # wrapper or the raw model. Both expose ``obs_op``.
        underlying = getattr(model, "model", model)
        obs_op = getattr(underlying, "obs_op", None)
        if obs_op is None:
            raise AttributeError(
                "LaplaceCovariance requires the model to expose `obs_op`; "
                "got a model without one."
            )

        H = obs_op.linearize(analysis)
        # P*^{-1} = H^T R^{-1} H + B^{-1}
        # Build lazily via operator composition (gaussx.sandwich would
        # materialise the Jacobian, breaking the matrix-free design);
        # the inverses are gaussx.inv, which dispatches to closed forms
        # for structured operators and falls back to lineax.CG mat-vecs.
        precision = H.transpose() @ _inverse_op(self.obs_cov_op) @ H + _inverse_op(
            self.prior_cov_op
        )
        precision_tagged = lx.TaggedLinearOperator(
            precision, lx.positive_semidefinite_tag
        )

        return Posterior(
            mean=analysis,
            cov=gaussx.inv(precision_tagged, solver=_CG_SOLVER),
            samples=None,
            provenance={"adapter": "LaplaceCovariance"},
        )
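
The comment about falling back to CG mat-vecs can be made concrete. The sketch below is a hand-written conjugate-gradient loop in NumPy, not the lineax.CG solver the adapter uses: it applies P* to a vector by solving `precision @ x = v`, touching the precision only through a mat-vec closure. The diagonal B and R and the matrix H are assumed toy values.

```python
import numpy as np


def cg_solve(matvec, b, tol=1e-10, max_iter=100):
    """Solve A x = b for symmetric positive-definite A given only x -> A x."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


b_diag = np.array([1.0, 2.0, 4.0])   # diagonal of B
r_diag = np.array([0.5, 0.5])        # diagonal of R
H = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 1.0]])


def precision_matvec(v):
    # (H^T R^{-1} H + B^{-1}) v, composed from mat-vecs only.
    return H.T @ ((H @ v) / r_diag) + v / b_diag


def cov_matvec(v):
    # P* v is the solution x of precision @ x = v.
    return cg_solve(precision_matvec, v)


v = np.ones(3)
Pv = cov_matvec(v)
```

Each covariance mat-vec therefore costs one iterative solve, which is the trade made for never storing the dense inverse.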

GaussNewtonHessian

Bases: Module

Gauss-Newton Hessian inversion at MAP.

Functionally similar to LaplaceCovariance — both compute \((B^{-1} + H^\top R^{-1} H)^{-1}\) (or its 4D extension) — but the GN-Hessian adapter is the recommended path for IncrementalFourDVar where the Hessian is already materialised by the inner CG solver.

Attributes:

Name Type Description
prior_cov_op AbstractLinearOperator

\(B\).

obs_cov_op AbstractLinearOperator

\(R\).

n_krylov int

Maximum Krylov iterations for mat-vec evaluations.
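
For reference, the standard relation between the full Hessian and the Gauss-Newton Hessian of a 3D-Var cost is sketched below. The form of the cost function and the symbols \(x_b\), \(y\) and \(h\) are assumptions made here (they are not defined on this page); the Mathematical Reference remains the authority for the library's own derivation.

```latex
\begin{aligned}
J(x) &= \tfrac12 (x - x_b)^\top B^{-1} (x - x_b)
      + \tfrac12 \bigl(y - h(x)\bigr)^\top R^{-1} \bigl(y - h(x)\bigr), \\
\nabla^2 J(x) &= B^{-1} + H^\top R^{-1} H
      - \sum_i \bigl[R^{-1}\bigl(y - h(x)\bigr)\bigr]_i \, \nabla^2 h_i(x),
      \qquad H = \frac{\partial h}{\partial x}, \\
\nabla^2 J_{\mathrm{GN}}(x) &= B^{-1} + H^\top R^{-1} H .
\end{aligned}
```

The Gauss-Newton form drops the last term of the full Hessian, which involves second derivatives of the observation operator. That term vanishes when \(h\) is linear or when the residual \(y - h(x)\) is zero, so the two Hessians coincide in those cases.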

Source code in src/vardax/_src/posterior/gauss_newton.py
class GaussNewtonHessian(eqx.Module):
    r"""Gauss-Newton Hessian inversion at MAP.

    Functionally similar to
    [`LaplaceCovariance`][vardax.LaplaceCovariance] — both compute
    $(B^{-1} + H^\top R^{-1} H)^{-1}$ (or its 4D extension)
    — but the GN-Hessian adapter is the recommended path for
    ``IncrementalFourDVar`` where the Hessian is already
    materialised by the inner CG solver.

    Attributes:
        prior_cov_op: $B$.
        obs_cov_op: $R$.
        n_krylov: Maximum Krylov iterations for mat-vec evaluations.
    """

    prior_cov_op: lx.AbstractLinearOperator
    obs_cov_op: lx.AbstractLinearOperator
    n_krylov: int = eqx.field(static=True, default=50)

    def __call__(
        self,
        analysis: Float[Array, ...],
        model: Any,
        batch: Any,
    ) -> Posterior:
        """Build the GN-Hessian posterior at ``analysis``."""
        underlying = getattr(model, "model", model)
        obs_op = getattr(underlying, "obs_op", None)
        if obs_op is None:
            raise AttributeError(
                "GaussNewtonHessian requires the model to expose `obs_op`."
            )

        H = obs_op.linearize(analysis)
        precision = H.transpose() @ _inverse_op(self.obs_cov_op) @ H + _inverse_op(
            self.prior_cov_op
        )
        precision_tagged = lx.TaggedLinearOperator(
            precision, lx.positive_semidefinite_tag
        )

        return Posterior(
            mean=analysis,
            cov=gaussx.inv(precision_tagged, solver=_CG_SOLVER),
            samples=None,
            provenance={
                "adapter": "GaussNewtonHessian",
                "n_krylov": self.n_krylov,
            },
        )

EnsembleCovariance

Bases: Module

Sample-covariance posterior from an ensemble of analyses.

Attributes:

Name Type Description
n_members int

Expected ensemble size (used for diagnostics).

inflation float

Multiplicative inflation factor \(\lambda\) applied to the sample covariance. Default 1.0 (no inflation).
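
The Bessel-corrected sample covariance with multiplicative inflation can be applied without forming an n-by-n matrix. The sketch below is plain NumPy and only mirrors what gaussx.ensemble_covariance is described as doing; `ensemble_cov_matvec` is a hypothetical helper, not a vardax or gaussx function.

```python
import numpy as np


def ensemble_cov_matvec(analyses, v, inflation=1.0):
    """Apply the inflated, Bessel-corrected sample covariance to ``v``.

    Uses the anomaly factorisation C = A^T A / (M - 1), so the
    n-by-n matrix is never formed.
    """
    m = analyses.shape[0]
    flat = analyses.reshape(m, -1)        # flatten all but the ensemble dim
    anomalies = flat - flat.mean(axis=0)  # deviations from the ensemble mean
    return inflation * (anomalies.T @ (anomalies @ v)) / (m - 1)


rng = np.random.default_rng(0)
analyses = rng.normal(size=(5, 2, 3))   # M = 5 members, state shape (2, 3)
v = rng.normal(size=6)                  # vector in the flattened state space

Cv = ensemble_cov_matvec(analyses, v, inflation=1.1)
```

Because the anomalies sum to zero across members, the resulting covariance has rank at most M - 1, so directions outside the ensemble span receive zero variance regardless of the inflation factor.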

Source code in src/vardax/_src/posterior/ensemble.py
class EnsembleCovariance(eqx.Module):
    r"""Sample-covariance posterior from an ensemble of analyses.

    Attributes:
        n_members: Expected ensemble size (used for diagnostics).
        inflation: Multiplicative inflation factor $\lambda$ applied
            to the sample covariance. Default 1.0 (no inflation).
    """

    n_members: int = eqx.field(static=True)
    inflation: float = eqx.field(static=True, default=1.0)

    def __call__(
        self,
        analyses: Float[Array, "M ..."],
        model: Any,
        batch: Any,
    ) -> Posterior:
        """Build the ensemble posterior from per-member analyses.

        Args:
            analyses: Stack of ``M`` analyses with shape ``(M, ...)``.
            model: AnalysisStep instance (not used directly; kept for
                interface compatibility).
            batch: The batch (not used directly; kept for interface
                compatibility).

        Returns:
            [`Posterior`][vardax.Posterior] with mean = ensemble mean,
            ``cov`` = inflated sample covariance as a lazy low-rank
            operator from ``gaussx.ensemble_covariance`` (never
            assembled densely), and ``samples`` = the input analyses.
        """
        m = analyses.shape[0]
        if m != self.n_members:
            # Allow runtime mismatch; the configured ``n_members`` is
            # advisory rather than enforced.
            pass

        mean = jnp.mean(analyses, axis=0)
        # Flatten everything but the leading ensemble dim; the
        # Bessel-corrected sample covariance comes from gaussx as a
        # rank-(m-1) LowRankUpdate operator (never dense), scaled by
        # the multiplicative inflation factor.
        cov_op = self.inflation * gaussx.ensemble_covariance(
            analyses.reshape(m, -1), bessel=True
        )

        return Posterior(
            mean=mean,
            cov=cov_op,
            samples=analyses,
            provenance={
                "adapter": "EnsembleCovariance",
                "n_members": int(m),
                "inflation": float(self.inflation),
            },
        )

Likelihoods

GaussianMarkLikelihood

Bases: Module

Serialise a Gaussian posterior into a mark-likelihood dict.

Attributes:

Name Type Description
posterior Posterior

A vardax Posterior (with Gaussian cov).

event_metadata dict[str, Any]

Free-form dict (event_id, time, geometry, instruments_used, …). Passed through verbatim into the output.

Use .to_dict() to get a JSON-friendly representation that downstream consumers (GeoCatalog, hierarchical Bayesian models, DuckDB ingest) can store as-is.

Examples:

>>> import jax.numpy as jnp
>>> import vardax as vdx
>>> post = vdx.Posterior(mean=jnp.zeros(3))
>>> mark = vdx.GaussianMarkLikelihood(
...     posterior=post, event_metadata={"event_id": "e1"}
... )
>>> sorted(mark.to_dict())
['cov_diag', 'event_metadata', 'mean', 'provenance', 'samples']
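
As a rough illustration of what "JSON-friendly" buys downstream, the sketch below round-trips a dict with the same five keys through the standard-library `json` module. The payload is hand-written with made-up values and stands in for the output of `.to_dict()`; it does not call vardax.

```python
import json

# Hypothetical payload with the five keys listed above; values are made up.
mark = {
    "mean": [0.0, 0.0, 0.0],
    "cov_diag": None,          # no covariance attached to this posterior
    "samples": None,           # analytical adapter, so no samples
    "provenance": {"adapter": "LaplaceCovariance"},
    "event_metadata": {"event_id": "e1"},
}

encoded = json.dumps(mark, sort_keys=True)
decoded = json.loads(encoded)
```

Python `None` maps to JSON `null` and back, so a missing covariance or sample set survives the round trip unchanged.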
Source code in src/vardax/_src/posterior/adapter.py
class GaussianMarkLikelihood(eqx.Module):
    """Serialise a Gaussian posterior into a mark-likelihood dict.

    Attributes:
        posterior: A vardax ``Posterior`` (with Gaussian ``cov``).
        event_metadata: Free-form dict (event_id, time, geometry,
            instruments_used, …). Passed through verbatim into the
            output.

    Use ``.to_dict()`` to get a JSON-friendly representation that
    downstream consumers (GeoCatalog, hierarchical Bayesian models,
    DuckDB ingest) can store as-is.

    Examples:
        >>> import jax.numpy as jnp
        >>> import vardax as vdx
        >>> post = vdx.Posterior(mean=jnp.zeros(3))
        >>> mark = vdx.GaussianMarkLikelihood(
        ...     posterior=post, event_metadata={"event_id": "e1"}
        ... )
        >>> sorted(mark.to_dict())
        ['cov_diag', 'event_metadata', 'mean', 'provenance', 'samples']
    """

    posterior: Posterior
    event_metadata: dict[str, Any] = eqx.field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly mark-likelihood representation."""
        out: dict[str, Any] = {
            "mean": _to_list(self.posterior.mean),
            "cov_diag": (
                _to_list(self._cov_diag()) if self.posterior.cov is not None else None
            ),
            "samples": (
                _to_list(self.posterior.samples)
                if self.posterior.samples is not None
                else None
            ),
            "provenance": dict(self.posterior.provenance),
            "event_metadata": dict(self.event_metadata),
        }
        return out

    def _cov_diag(self):
        """Extract the diagonal of the covariance operator.

        Delegates to ``gaussx.diag``. For materialised operators this
        is cheap; for lazy operators it can cost one mat-vec per state
        dim (probe with unit vectors). If the operator does not
        support diagonal extraction, return ``None`` as a null marker.
        """
        cov = self.posterior.cov
        if cov is None:
            return None
        try:
            return gaussx.diag(cov)
        except (AttributeError, NotImplementedError):
            # Lazy operators don't always materialise; just emit a
            # null marker.
            return None

to_dict

to_dict() -> dict[str, Any]

Return a JSON-friendly mark-likelihood representation.
