Amortized Inference

Where the variational methods on the Models page solve an optimisation problem per analysis, an amortized posterior pays the cost once at training time: a network is fit to map observations directly to an approximate posterior, so inference is a single forward pass. See Amortized Inference in the Mathematical Reference for the underlying theory and the fidelity/speed trade-offs.
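To make the trade-off concrete, here is a minimal sketch in plain JAX (not vardax code). It uses a toy linear-Gaussian 3D-Var problem, where the exact amortized map happens to be known in closed form. The per-analysis route runs gradient descent on the cost for every new y. The amortized route computes a gain matrix K once, then analyses each y with a single matrix-vector product. The matrices H, B and R and all helper names are illustrative assumptions. In AmortizedPosterior the map is a trained network, not a closed-form gain.

```python
import jax
import jax.numpy as jnp

# Toy linear-Gaussian problem; every size and matrix here is illustrative.
n_state, n_obs = 4, 3
k_h, k_xb, k_y = jax.random.split(jax.random.PRNGKey(0), 3)
H = 0.5 * jax.random.normal(k_h, (n_obs, n_state))  # observation operator
B = jnp.eye(n_state)                                # background-error covariance
R = 0.5 * jnp.eye(n_obs)                            # observation-error covariance
x_b = jax.random.normal(k_xb, (n_state,))           # background state


def cost(x, y):
    """3D-Var cost J(x) = 1/2 |x - x_b|^2_{B^-1} + 1/2 |y - H x|^2_{R^-1}."""
    dx = x - x_b
    dy = y - H @ x
    return 0.5 * dx @ jnp.linalg.solve(B, dx) + 0.5 * dy @ jnp.linalg.solve(R, dy)


# Per-analysis route: gradient descent on J, repeated for every new y.
hessian = jnp.linalg.inv(B) + H.T @ jnp.linalg.solve(R, H)
step = 1.0 / jnp.max(jnp.linalg.eigvalsh(hessian))  # 1/L is a safe step for a quadratic
grad_cost = jax.grad(cost)  # gradient with respect to x (first argument)


def analyse_by_optimisation(y, n_iter=2000):
    return jax.lax.fori_loop(0, n_iter, lambda _, x: x - step * grad_cost(x, y), x_b)


# Amortized route: pay once for a map y -> x_a, then reuse it for every y.
# For a linear-Gaussian problem the optimal map is linear (the Kalman gain),
# so "training" collapses to a single matrix computation.
K = B @ H.T @ jnp.linalg.inv(H @ B @ H.T + R)


def analyse_amortized(y):
    return x_b + K @ (y - H @ x_b)


ys = jax.random.normal(k_y, (5, n_obs))       # five independent analyses
x_opt = jax.vmap(analyse_by_optimisation)(ys)  # 5 x 2000 gradient steps
x_amo = jax.vmap(analyse_amortized)(ys)        # 5 matrix-vector products
```

Both routes reach the same analysis here. Amortization only pays off when many analyses reuse the same map. For nonlinear problems the learned map is approximate, which is why amortized posteriors are held to the same validation gates as variational ones.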

AmortizedPosterior composes two interchangeable parts: an observation encoder that summarises (possibly masked, possibly irregular) observations into a conditioning vector, and a posterior head that turns that vector into a distribution over states. The heads are meant to span a fidelity ladder: a diagonal Gaussian fitted by regression (RegressionHead), full densities via conditional normalizing flows (ConditionalFlowHead), and score-based diffusion sampling (ScoreDiffusionHead). At present only RegressionHead is implemented; the other two ship as stubs whose constructors raise NotImplementedError. Amortized posteriors should pass the same validation gates (simulation-based calibration, posterior agreement) as their variational counterparts.
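The following plain-JAX sketch mirrors how the encoder/head split lets heads be swapped without touching the surrounding code. PointHead and GaussianHead are hypothetical toy heads, not the vardax classes. The encoder is a parameter-free flatten-and-concatenate. Only two things are taken from the API documented on this page: the sample(ctx, key, n) interface and the vmap-over-batch pattern.

```python
import jax
import jax.numpy as jnp

T, N = 5, 8                       # hypothetical observation-window shape (time, space)
state_shape = (T, N)
state_dim, ctx_dim = T * N, 2 * T * N


def encoder(inp, mask):
    # Structural encoder: flatten and concatenate the field and its mask.
    return jnp.concatenate([inp.ravel(), mask.ravel()])


class PointHead:
    """Hypothetical lowest-fidelity head: every 'draw' is the same estimate."""

    def __init__(self, w):
        self.w = w

    def map_estimate(self, ctx):
        return (self.w @ ctx).reshape(state_shape)

    def sample(self, ctx, key, n):
        return jnp.broadcast_to(self.map_estimate(ctx), (n, *state_shape))


class GaussianHead(PointHead):
    """Hypothetical head adding a fixed diagonal spread around the same mean."""

    def __init__(self, w, log_var):
        super().__init__(w)
        self.log_var = log_var

    def sample(self, ctx, key, n):
        eps = jax.random.normal(key, (n, *state_shape))
        return self.map_estimate(ctx) + jnp.exp(0.5 * self.log_var) * eps


def posterior_sample(encoder, head, inputs, masks, key, n):
    # Same pattern as AmortizedPosterior.sample: one key per batch element, vmapped.
    keys = jax.random.split(key, inputs.shape[0])
    return jax.vmap(lambda i, m, k: head.sample(encoder(i, m), k, n))(inputs, masks, keys)


k_w, k_x, k_s = jax.random.split(jax.random.PRNGKey(0), 3)
w = 0.1 * jax.random.normal(k_w, (state_dim, ctx_dim))
inputs = jax.random.normal(k_x, (3, T, N))  # batch of three windows
masks = jnp.ones((3, T, N))                 # fully observed

point_draws = posterior_sample(encoder, PointHead(w), inputs, masks, k_s, n=4)
gauss_draws = posterior_sample(
    encoder, GaussianHead(w, jnp.full(state_shape, -2.0)), inputs, masks, k_s, n=4
)
```

Both calls return arrays of shape (B, n, *state_shape). Only the head object changes between them.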

Posterior and configuration

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

AmortizedPosterior

Bases: Module

Amortized variational posterior \(q_\phi(x \mid y)\).

Attributes:

encoder (Any): eqx.Module mapping (input, mask) to a context vector. Vmapped over the batch axis at call time.

head (Any): RegressionHead, ConditionalFlowHead or ScoreDiffusionHead. Exposes map_estimate(ctx), sample(ctx, key, n) and log_prob(x, ctx).

config (AmortizedConfig): AmortizedConfig carrying the head type and sampling defaults.

Source code in src/vardax/_src/amortized/posterior.py
class AmortizedPosterior(eqx.Module):
    r"""Amortized variational posterior $q_\phi(x \mid y)$.

    Attributes:
        encoder: ``eqx.Module`` mapping ``(input, mask)`` to a context
            vector. Vmapped over the batch axis at call time.
        head: ``RegressionHead`` / ``ConditionalFlowHead`` /
            ``ScoreDiffusionHead``. Exposes ``map_estimate(ctx)``,
            ``sample(ctx, key, n)``, ``log_prob(x, ctx)``.
        config: ``AmortizedConfig`` carrying head type and sampling
            defaults.
    """

    encoder: Any
    head: Any
    config: AmortizedConfig

    def __call__(self, batch: Batch1D | Batch2D) -> Float[Array, ...]:
        r"""Return the per-sample MAP / mode of $q_\phi(x \mid y)$."""

        def _one(input_i, mask_i):
            ctx = self.encoder(input_i, mask_i)
            return self.head.map_estimate(ctx)

        return jax.vmap(_one)(batch.input, batch.mask)

    def sample(
        self,
        batch: Batch1D | Batch2D,
        key: PRNGKeyArray,
        n: int | None = None,
    ) -> Float[Array, ...]:
        """Draw posterior samples per batch element.

        Returns an array of shape ``(B, n, *state_shape)``.
        """
        n = self.config.n_samples if n is None else n
        b = batch.input.shape[0]
        keys = jax.random.split(key, b)

        def _one(input_i, mask_i, key_i):
            ctx = self.encoder(input_i, mask_i)
            return self.head.sample(ctx, key_i, n)

        return jax.vmap(_one)(batch.input, batch.mask, keys)

    def log_prob(
        self,
        x: Float[Array, ...],
        batch: Batch1D | Batch2D,
    ) -> Float[Array, " B"]:
        """Per-sample log-density of ``x`` under ``q_φ(·|y)``.

        For ``ScoreDiffusionHead`` this raises ``NotImplementedError``
        (no closed-form density). Use ``sample`` instead.
        """

        def _one(x_i, input_i, mask_i):
            ctx = self.encoder(input_i, mask_i)
            return self.head.log_prob(x_i, ctx)

        return jax.vmap(_one)(x, batch.input, batch.mask)

    def as_analysis_step(self) -> _AmortizedAnalysisStep:
        """Adapt to ``pipekit_cycle.AnalysisStep`` (Decision D8)."""
        return _AmortizedAnalysisStep(self)

sample

sample(
    batch: Batch1D | Batch2D,
    key: PRNGKeyArray,
    n: int | None = None,
) -> Float[Array, ...]

Draw posterior samples per batch element.

Returns an array of shape (B, n, *state_shape).
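sample splits key into one subkey per batch element for a specific reason. Under vmap, a key that is closed over rather than mapped is shared by every element, so all elements would receive exactly the same noise. A minimal JAX sketch (draw is a hypothetical helper):

```python
import jax
import jax.numpy as jnp


def draw(mu, key, n):
    # n standard-normal perturbations around one element's mean.
    return mu + jax.random.normal(key, (n, *mu.shape))


mus = jnp.zeros((2, 3))       # two batch elements with identical means
key = jax.random.PRNGKey(42)

# One independent key per element, as AmortizedPosterior.sample does.
keys = jax.random.split(key, mus.shape[0])
split_draws = jax.vmap(lambda m, k: draw(m, k, 4))(mus, keys)  # shape (2, 4, 3)

# Reusing one key for the whole batch: both elements get identical noise.
shared_draws = jax.vmap(lambda m: draw(m, key, 4))(mus)
```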

log_prob

log_prob(
    x: Float[Array, ...], batch: Batch1D | Batch2D
) -> Float[Array, " B"]

Per-sample log-density of x under q_φ(·|y).

For ScoreDiffusionHead this raises NotImplementedError (no closed-form density). Use sample instead.

as_analysis_step

as_analysis_step() -> _AmortizedAnalysisStep

Adapt to pipekit_cycle.AnalysisStep (Decision D8).

AmortizedConfig

Bases: Module

Configuration for an AmortizedPosterior.

Attributes:

head_type (str): One of "regression", "flow", "score". Determines the density family used by the head. Currently only "regression" is fully implemented; "flow" requires gauss_flows and "score" is pending the diffrax reverse-SDE pipeline. Set on construction; the head instance is responsible for matching the type.

n_samples (int): Default number of posterior samples to draw in AmortizedPosterior.sample() when n is not provided.

Source code in src/vardax/_src/amortized/config.py
class AmortizedConfig(eqx.Module):
    """Configuration for an ``AmortizedPosterior``.

    Attributes:
        head_type: One of ``"regression"``, ``"flow"``, ``"score"``.
            Determines the density family used by the head. Currently
            only ``"regression"`` is fully implemented; ``"flow"``
            requires ``gauss_flows`` and ``"score"`` is pending the
            diffrax reverse-SDE pipeline. Set on construction; the head
            instance is responsible for matching the type.
        n_samples: Default number of posterior samples to draw in
            ``AmortizedPosterior.sample()`` when ``n`` is not provided.
    """

    head_type: str = eqx.field(static=True, default="regression")
    n_samples: int = eqx.field(static=True, default=64)

    def __post_init__(self) -> None:
        valid = {"regression", "flow", "score"}
        if self.head_type not in valid:
            raise ValueError(
                f"AmortizedConfig.head_type must be one of {sorted(valid)}; "
                f"got {self.head_type!r}."
            )
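The config validates head_type but leaves it to the head instance to match that type. The plain-Python sketch below covers three things: the same validation, the n default rule used by AmortizedPosterior.sample, and an optional pairing check a user might add. ConfigSketch, check_head_matches, HEAD_TYPE_OF and the stand-in head classes are all hypothetical and not part of vardax.

```python
from dataclasses import dataclass
from typing import Optional

VALID_HEAD_TYPES = ("regression", "flow", "score")


# Stand-ins for the vardax head classes so that the sketch runs on its own.
class RegressionHead: ...
class ConditionalFlowHead: ...
class ScoreDiffusionHead: ...


# Hypothetical pairing table: which head_type each head class is meant to match.
HEAD_TYPE_OF = {
    RegressionHead: "regression",
    ConditionalFlowHead: "flow",
    ScoreDiffusionHead: "score",
}


@dataclass(frozen=True)
class ConfigSketch:
    """Plain-Python mirror of AmortizedConfig's fields, defaults and validation."""

    head_type: str = "regression"
    n_samples: int = 64

    def __post_init__(self) -> None:
        if self.head_type not in VALID_HEAD_TYPES:
            raise ValueError(
                f"head_type must be one of {sorted(VALID_HEAD_TYPES)}; got {self.head_type!r}."
            )


def check_head_matches(config: ConfigSketch, head: object) -> None:
    """Hypothetical guard; AmortizedConfig itself does not enforce this pairing."""
    expected = HEAD_TYPE_OF.get(type(head))
    if expected != config.head_type:
        raise ValueError(
            f"{type(head).__name__} pairs with head_type={expected!r}, "
            f"but the config says {config.head_type!r}."
        )


def resolve_n_samples(config: ConfigSketch, n: Optional[int] = None) -> int:
    # Same rule as AmortizedPosterior.sample: an explicit n overrides the default.
    return config.n_samples if n is None else n
```

In the real AmortizedConfig, both fields are declared with eqx.field(static=True). They therefore belong to the module's pytree structure rather than its array leaves. Under JAX transformations such as jit, a different head_type or n_samples produces a fresh trace rather than reusing a compiled one.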

Observation encoders

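IdentityObsEncoder flattens the observations and their mask and concatenates them, with no learned parameters; this suits small, fixed-size observation fields. MLPObsEncoder applies a learned MLP to the same concatenation to produce a compact summary.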

IdentityObsEncoder

Bases: Module

Concatenate input and mask into a flat context vector.

Zero parameters — the encoder is purely structural. The output dimension is input.size + mask.size.

Source code in src/vardax/_src/amortized/encoder.py
class IdentityObsEncoder(eqx.Module):
    """Concatenate ``input`` and ``mask`` into a flat context vector.

    Zero parameters — the encoder is purely structural. The output
    dimension is ``input.size + mask.size``.
    """

    def __call__(
        self,
        input: Float[Array, ...],
        mask: Float[Array, ...],
    ) -> Float[Array, " D"]:
        return jnp.concatenate([input.ravel(), mask.ravel()])
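Keeping the mask in the context matters whenever missing values are filled with a placeholder. The minimal sketch below assumes gaps in input are filled with zeros; that is an assumption about the data pipeline, not something the encoder enforces. Without the mask, an unobserved point and a genuine observation of 0.0 would give identical contexts.

```python
import jax.numpy as jnp


def identity_encode(inp, mask):
    # Mirrors IdentityObsEncoder.__call__: flatten both arrays, then concatenate.
    return jnp.concatenate([inp.ravel(), mask.ravel()])


T, N = 4, 6
inp = jnp.zeros((T, N))                      # assumed convention: gaps are filled with 0.0
fully_observed = jnp.ones((T, N))
with_gap = fully_observed.at[0, 0].set(0.0)  # first grid point unobserved

ctx_observed = identity_encode(inp, fully_observed)  # size input.size + mask.size = 48
ctx_gap = identity_encode(inp, with_gap)
# Only the mask half of the context differs between the two cases.
```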

MLPObsEncoder

Bases: Module

MLP encoder from flat (input, mask) to context, with depth hidden layers (default depth=2, i.e. three linear layers).

Attributes:

mlp (MLP): eqx.nn.MLP taking 2 * input_size features and emitting context_dim.

input_size (int): Flattened size of the input field (T * N for Batch1D, T * H * W for Batch2D).

Source code in src/vardax/_src/amortized/encoder.py
class MLPObsEncoder(eqx.Module):
    """Two-layer MLP encoder from flat ``(input, mask)`` to context.

    Attributes:
        mlp: ``eqx.nn.MLP`` taking ``2 * input_size`` features and
            emitting ``context_dim``.
        input_size: Flattened size of the input field (``T * N`` for
            ``Batch1D``, ``T * H * W`` for ``Batch2D``).
    """

    mlp: eqx.nn.MLP
    input_size: int = eqx.field(static=True)

    def __init__(
        self,
        input_size: int,
        context_dim: int,
        *,
        hidden_dim: int = 64,
        depth: int = 2,
        key: PRNGKeyArray,
    ) -> None:
        self.input_size = input_size
        self.mlp = eqx.nn.MLP(
            in_size=2 * input_size,
            out_size=context_dim,
            width_size=hidden_dim,
            depth=depth,
            key=key,
        )

    def __call__(
        self,
        input: Float[Array, ...],
        mask: Float[Array, ...],
    ) -> Float[Array, " context_dim"]:
        x = jnp.concatenate([input.ravel(), mask.ravel()])
        return self.mlp(x)
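In eqx.nn.MLP, depth=2 yields three linear layers: 2 * input_size -> hidden_dim -> hidden_dim -> context_dim. The pure-JAX sketch below reproduces that layout to make the shapes and parameter count explicit. init_mlp, mlp_apply and mlp_encode are hypothetical helpers, and the uniform initialisation is illustrative rather than Equinox's exact scheme.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_size, out_size, width, depth):
    # Same layer layout as eqx.nn.MLP: `depth` hidden layers, i.e. depth + 1 linear maps.
    sizes = [in_size] + [width] * depth + [out_size]
    params = []
    for k, n_in, n_out in zip(jax.random.split(key, len(sizes) - 1), sizes[:-1], sizes[1:]):
        bound = n_in ** -0.5  # simple uniform fan-in initialisation (illustrative)
        w = jax.random.uniform(k, (n_out, n_in), minval=-bound, maxval=bound)
        params.append((w, jnp.zeros((n_out,))))
    return params


def mlp_apply(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)  # hidden layers (ReLU, eqx.nn.MLP's default activation)
    w, b = params[-1]
    return w @ x + b  # linear output layer


def mlp_encode(params, inp, mask):
    # Mirrors MLPObsEncoder.__call__: flatten, concatenate, then apply the MLP.
    return mlp_apply(params, jnp.concatenate([inp.ravel(), mask.ravel()]))


T, N, context_dim = 5, 8, 16
input_size = T * N  # 40 for a (T, N) = (5, 8) window
params = init_mlp(jax.random.PRNGKey(0), 2 * input_size, context_dim, width=64, depth=2)
ctx = mlp_encode(params, jnp.ones((T, N)), jnp.ones((T, N)))
n_params = sum(w.size + b.size for w, b in params)  # (80*64+64) + (64*64+64) + (64*16+16)
```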

Posterior heads

RegressionHead

Bases: Module

Gaussian regression head: q_φ(x|y) = N(μ_φ(y), diag(σ²_φ(y))).

Two MLPs share the context: one for the mean, one for the log variance. Outputs are reshaped to the user-specified state_shape.

Attributes:

mlp_mu (MLP): MLP from context to flat mean (size prod(state_shape)).

mlp_log_var (MLP): MLP from context to flat log variance.

state_shape (tuple[int, ...]): Shape of a single posterior sample (e.g. (T, N) for Batch1D targets).

Source code in src/vardax/_src/amortized/heads.py
class RegressionHead(eqx.Module):
    """Gaussian regression head: ``q_φ(x|y) = N(μ_φ(y), diag(σ²_φ(y)))``.

    Two MLPs share the context: one for the mean, one for the log
    variance. Outputs are reshaped to the user-specified
    ``state_shape``.

    Attributes:
        mlp_mu: MLP from context to flat mean (size ``prod(state_shape)``).
        mlp_log_var: MLP from context to flat log variance.
        state_shape: Shape of a single posterior sample (e.g.
            ``(T, N)`` for ``Batch1D`` targets).
    """

    mlp_mu: eqx.nn.MLP
    mlp_log_var: eqx.nn.MLP
    state_shape: tuple[int, ...] = eqx.field(static=True)

    def __init__(
        self,
        context_dim: int,
        state_shape: tuple[int, ...],
        *,
        hidden_dim: int = 64,
        depth: int = 2,
        key: PRNGKeyArray,
    ) -> None:
        self.state_shape = tuple(state_shape)
        flat_dim = 1
        for d in self.state_shape:
            flat_dim *= d
        k_mu, k_lv = jax.random.split(key)
        self.mlp_mu = eqx.nn.MLP(
            in_size=context_dim,
            out_size=flat_dim,
            width_size=hidden_dim,
            depth=depth,
            key=k_mu,
        )
        self.mlp_log_var = eqx.nn.MLP(
            in_size=context_dim,
            out_size=flat_dim,
            width_size=hidden_dim,
            depth=depth,
            key=k_lv,
        )

    def map_estimate(self, ctx: Float[Array, " D"]) -> Float[Array, ...]:
        return self.mlp_mu(ctx).reshape(self.state_shape)

    def sample(
        self,
        ctx: Float[Array, " D"],
        key: PRNGKeyArray,
        n: int,
    ) -> Float[Array, ...]:
        mu = self.mlp_mu(ctx).reshape(self.state_shape)
        log_var = self.mlp_log_var(ctx).reshape(self.state_shape)
        sigma = jnp.exp(0.5 * log_var)
        eps = jax.random.normal(key, (n, *self.state_shape))
        return mu + sigma * eps

    def log_prob(
        self,
        x: Float[Array, ...],
        ctx: Float[Array, " D"],
    ) -> Float[Array, ""]:
        mu = self.mlp_mu(ctx).reshape(self.state_shape)
        log_var = self.mlp_log_var(ctx).reshape(self.state_shape)
        # log N(x | mu, diag(exp(log_var))) — sum over state dims.
        return -0.5 * jnp.sum(((x - mu) ** 2) * jnp.exp(-log_var) + log_var + _LOG_2PI)
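The minimal JAX sketch below shows the two formulas RegressionHead relies on, with mu and log_var drawn at random in place of the two MLP outputs. The first is reparameterised sampling, x = mu + exp(log_var / 2) * eps. The second is the diagonal-Gaussian log-density summed over all state dimensions. The helper names are hypothetical. _LOG_2PI is defined here as log(2 pi), which is assumed to match the module-level constant the source uses.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

_LOG_2PI = math.log(2.0 * math.pi)  # assumed value of the source's module constant


def diag_gaussian_log_prob(x, mu, log_var):
    # Same expression as RegressionHead.log_prob, summed over all state dims.
    return -0.5 * jnp.sum((x - mu) ** 2 * jnp.exp(-log_var) + log_var + _LOG_2PI)


def diag_gaussian_sample(key, mu, log_var, n):
    # Reparameterised draws: x = mu + sigma * eps with eps ~ N(0, I).
    eps = jax.random.normal(key, (n, *mu.shape))
    return mu + jnp.exp(0.5 * log_var) * eps


state_shape = (3, 4)
k_mu, k_lv, k_x, k_s = jax.random.split(jax.random.PRNGKey(0), 4)
mu = jax.random.normal(k_mu, state_shape)             # stands in for mlp_mu(ctx)
log_var = 0.3 * jax.random.normal(k_lv, state_shape)  # stands in for mlp_log_var(ctx)
x = jax.random.normal(k_x, state_shape)

lp = diag_gaussian_log_prob(x, mu, log_var)
lp_ref = jnp.sum(norm.logpdf(x, loc=mu, scale=jnp.exp(0.5 * log_var)))  # reference value
draws = diag_gaussian_sample(k_s, mu, log_var, n=20_000)
```

Predicting the log variance rather than sigma keeps the standard deviation positive for any network output, since exp is always positive.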

ConditionalFlowHead

Bases: Module

Conditional normalising flow head (stub).

Implements x = f_φ(z; c_ψ(y)) with z ~ N(0, I). Exact density via change-of-variables. Requires gauss_flows for the flow primitives; ships as a stub until that dependency is added to pyproject.toml.

Source code in src/vardax/_src/amortized/heads.py
class ConditionalFlowHead(eqx.Module):
    """Conditional normalising flow head (stub).

    Implements ``x = f_φ(z; c_ψ(y))`` with ``z ~ N(0, I)``. Exact
    density via change-of-variables. Requires ``gauss_flows`` for the
    flow primitives; ships as a stub until that dependency is added to
    ``pyproject.toml``.
    """

    def __init__(self, *_args: object, **_kwargs: object) -> None:
        raise NotImplementedError(
            "ConditionalFlowHead requires `gauss_flows`, which is not yet "
            "a vardax dependency. Use `RegressionHead` for now; flow support "
            "lands in a follow-up PR alongside the gauss_flows dependency."
        )
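ConditionalFlowHead is a stub, but the change-of-variables rule it refers to can be shown with the simplest conditional flow: one elementwise affine layer whose shift and log-scale are produced from the context. This is a hypothetical pure-JAX sketch, not the gauss_flows API. A practical flow would stack many such invertible layers with richer conditioners. The last lines confirm the analytic log-determinant against a Jacobian computed by jax.jacfwd.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def conditioner(params, ctx):
    # Hypothetical conditioner: context -> per-dimension shift and bounded log-scale.
    w_shift, w_scale = params
    return w_shift @ ctx, jnp.tanh(w_scale @ ctx)


def flow_forward(params, z, ctx):
    # x = f(z; c): elementwise affine map, invertible because exp(.) > 0.
    shift, log_scale = conditioner(params, ctx)
    return shift + jnp.exp(log_scale) * z


def flow_log_prob(params, x, ctx):
    # Change of variables: log q(x|c) = log N(f^{-1}(x); 0, I) - log|det df/dz|.
    shift, log_scale = conditioner(params, ctx)
    z = (x - shift) * jnp.exp(-log_scale)
    return jnp.sum(norm.logpdf(z)) - jnp.sum(log_scale)


d, c = 4, 3
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (jax.random.normal(k1, (d, c)), jax.random.normal(k2, (d, c)))
ctx = jax.random.normal(k3, (c,))
z = jax.random.normal(k4, (d,))
x = flow_forward(params, z, ctx)

lp_flow = flow_log_prob(params, x, ctx)
# An affine flow of N(0, I) is Gaussian, so the density has a closed-form reference.
shift, log_scale = conditioner(params, ctx)
lp_ref = jnp.sum(norm.logpdf(x, loc=shift, scale=jnp.exp(log_scale)))
# The log|det| term from autodiff should equal sum(log_scale).
_, logdet = jnp.linalg.slogdet(jax.jacfwd(flow_forward, argnums=1)(params, z, ctx))
```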

ScoreDiffusionHead

Bases: Module

Score-based diffusion head (stub).

Learns s_φ(x, t | y) ≈ ∇_x log p_t(x | y); samples via reverse SDE solved with diffrax. Ships as a stub pending the reverse-SDE pipeline.

Source code in src/vardax/_src/amortized/heads.py
class ScoreDiffusionHead(eqx.Module):
    """Score-based diffusion head (stub).

    Learns ``s_φ(x, t | y) ≈ ∇_x log p_t(x | y)``; samples via reverse
    SDE solved with diffrax. Ships as a stub pending the reverse-SDE
    pipeline.
    """

    def __init__(self, *_args: object, **_kwargs: object) -> None:
        raise NotImplementedError(
            "ScoreDiffusionHead is not yet implemented. Pending the diffrax-based "
            "reverse-SDE pipeline. Use `RegressionHead` for now."
        )
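ScoreDiffusionHead is also a stub. The sketch below shows reverse-SDE sampling in its simplest form under three assumptions. The forward SDE is dx = G dW with constant G. The toy target p_0(x | y) = N(y, S0^2) has an exactly known score, which stands in for a trained s_φ(x, t | y). Integration uses plain Euler-Maruyama steps in jax.lax.scan instead of a diffrax solver. All names and constants are illustrative.

```python
import jax
import jax.numpy as jnp

G = 1.0      # constant diffusion coefficient of the forward SDE dx = G dW (assumption)
T_END = 1.0  # forward time horizon
S0 = 0.5     # std of the toy data distribution p_0(x | y) = N(y, S0^2)


def score(x, t, y):
    # Exact score of p_t(x | y) = N(y, S0^2 + G^2 t); stands in for a learned s_phi(x, t | y).
    return -(x - y) / (S0**2 + G**2 * t)


def reverse_sde_sample(key, y, n, n_steps=500):
    dt = T_END / n_steps
    k_init, k_loop = jax.random.split(key)
    # Start from the exact terminal marginal p_T(x | y).
    x = y + jnp.sqrt(S0**2 + G**2 * T_END) * jax.random.normal(k_init, (n,))

    def step(x, inputs):
        t, k = inputs
        # Euler-Maruyama step of the reverse SDE dx = -G^2 score dt + G dW, run from t to t - dt.
        x = x + G**2 * score(x, t, y) * dt + G * jnp.sqrt(dt) * jax.random.normal(k, x.shape)
        return x, None

    ts = T_END - dt * jnp.arange(n_steps)  # t = T, T - dt, ..., dt
    keys = jax.random.split(k_loop, n_steps)
    x, _ = jax.lax.scan(step, x, (ts, keys))
    return x


y_obs = 2.0
samples = reverse_sde_sample(jax.random.PRNGKey(0), y_obs, n=20_000)
```

With an exact score, the samples recover p_0(x | y) up to discretisation and Monte Carlo error. With a learned score, any score error propagates into the samples, which is what the validation gates are meant to catch.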