Amortized Inference
Whereas the variational methods on the Models page solve an optimisation problem for every analysis, an amortized posterior pays that cost once, at training time: a network is fit to map observations directly to an approximate posterior, so inference is a single forward pass. See Amortized Inference in the Mathematical Reference for the underlying theory and the fidelity/speed trade-offs.
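To make that trade-off concrete, the sketch below uses a linear-Gaussian toy model in which the exact posterior mean is a linear function of the observation. It is plain NumPy and not vardax API: a least-squares fit on simulated (x, y) pairs stands in for network training, after which inference for a new batch of observations is one matrix multiply. The learned map can be checked against the analytic gain P Hᵀ(H P Hᵀ + R)⁻¹.

```python
import numpy as np

rng = np.random.default_rng(0)
d_x, d_y, n_train = 3, 2, 200_000

# Toy linear-Gaussian model: x ~ N(0, P), y = H x + eps, eps ~ N(0, R).
P = np.diag([1.0, 2.0, 0.5])
H = rng.normal(size=(d_y, d_x))
R = 0.1 * np.eye(d_y)

# "Training" (paid once): simulate (x, y) pairs and fit a linear map y -> x.
x = rng.multivariate_normal(np.zeros(d_x), P, size=n_train)
y = x @ H.T + rng.multivariate_normal(np.zeros(d_y), R, size=n_train)
W, *_ = np.linalg.lstsq(y, x, rcond=None)  # x ≈ y @ W, W has shape (d_y, d_x)

# Analytic posterior-mean map for comparison: K = P H^T (H P H^T + R)^-1.
K = P @ H.T @ np.linalg.inv(H @ P @ H.T + R)  # shape (d_x, d_y)

# "Inference" (per analysis): one matrix multiply, no optimisation loop.
y_new = rng.normal(size=(5, d_y))
x_mean_amortized = y_new @ W
x_mean_exact = y_new @ K.T
```

The regression recovers the posterior mean only because this model is linear and Gaussian; a nonlinear model is where a learned network earns its keep.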
AmortizedPosterior composes two interchangeable parts: an observation
encoder that summarises (possibly masked, possibly irregular) observations
into a conditioning vector, and a posterior head that turns that vector
into a distribution over states. Heads span the fidelity ladder, from
diagonal-Gaussian regression (RegressionHead) through full densities via
conditional normalising flows (ConditionalFlowHead) to score-based
diffusion sampling (ScoreDiffusionHead). Amortized posteriors should pass
the same validation gates (simulation-based calibration, posterior
agreement) as their variational counterparts.
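Simulation-based calibration, the first of those gates, checks that the rank of the true state among posterior draws is uniform across simulated datasets. A minimal NumPy sketch follows; the scalar conjugate-Gaussian model, the helper names, and the outer-decile summary statistic are illustrative assumptions, not part of vardax. The exact posterior passes, while an overconfident one piles ranks up at the extremes.

```python
import numpy as np

NOISE_VAR = 0.25                           # y = x + N(0, NOISE_VAR), prior x ~ N(0, 1)
POST_VAR = 1.0 / (1.0 + 1.0 / NOISE_VAR)   # exact posterior variance (0.2)

def sbc_ranks(sample_posterior, n_sims=2000, n_post=99, seed=0):
    """Rank of the true x among n_post posterior draws, once per simulation."""
    rng = np.random.default_rng(seed)
    ranks = np.empty(n_sims, dtype=int)
    for i in range(n_sims):
        x_true = rng.normal()                           # draw from the prior
        y = x_true + np.sqrt(NOISE_VAR) * rng.normal()  # simulate an observation
        draws = sample_posterior(y, n_post, rng)        # query the posterior
        ranks[i] = np.sum(draws < x_true)               # rank in {0, ..., n_post}
    return ranks

def exact_posterior(y, n, rng):
    mean = POST_VAR * y / NOISE_VAR
    return mean + np.sqrt(POST_VAR) * rng.normal(size=n)

def overconfident_posterior(y, n, rng):
    # Correct mean, but variance shrunk tenfold.
    mean = POST_VAR * y / NOISE_VAR
    return mean + np.sqrt(POST_VAR / 10) * rng.normal(size=n)

def outer_decile_fraction(ranks, n_post=99):
    """Share of ranks in the bottom or top 10% of {0, ..., n_post}; ~0.2 if calibrated."""
    n_bins = n_post + 1
    return np.mean((ranks < 0.1 * n_bins) | (ranks >= 0.9 * n_bins))

exact_frac = outer_decile_fraction(sbc_ranks(exact_posterior))
overconfident_frac = outer_decile_fraction(sbc_ranks(overconfident_posterior))
```

A full check would inspect the whole rank histogram; the outer-decile share is just a compact summary for this sketch.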
Posterior and configuration
vardax — Modular variational data assimilation with learned components.
All public symbols are re-exported from the private _src subpackage so
that user code imports from the top-level namespace:
AmortizedPosterior
Bases: Module
Amortized variational posterior \(q_\phi(x \mid y)\).
Attributes:
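| Name | Type | Description |
|---|---|---|
| encoder | Any | Observation encoder that maps observations and mask to a conditioning vector. |
| head | Any | Posterior head that turns the conditioning vector into a distribution over states. |
| config | AmortizedConfig | Head type and default sample count. |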
Source code in src/vardax/_src/amortized/posterior.py
sample
Draw posterior samples per batch element.
Returns an array of shape (B, n, *state_shape).
Source code in src/vardax/_src/amortized/posterior.py
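The (B, n, *state_shape) layout places the sample axis directly after the batch axis. One way a head with a diagonal-Gaussian output could produce it is reparameterised sampling with broadcasting, sketched here in NumPy. The function name and the Gaussian form are assumptions for illustration, not the vardax implementation of sample.

```python
import numpy as np

def sample_diag_gaussian(mu, log_var, n, rng):
    """Draw n reparameterised samples per batch element.

    mu, log_var: arrays of shape (B, *state_shape) produced by a head.
    Returns an array of shape (B, n, *state_shape).
    """
    eps = rng.normal(size=(mu.shape[0], n) + mu.shape[1:])
    # Insert a sample axis at position 1 so (B, 1, *state_shape)
    # broadcasts against (B, n, *state_shape).
    return mu[:, None] + np.exp(0.5 * log_var)[:, None] * eps

rng = np.random.default_rng(0)
B, n, state_shape = 4, 8, (16, 2)
mu = np.zeros((B,) + state_shape)
log_var = np.zeros((B,) + state_shape)
samples = sample_diag_gaussian(mu, log_var, n, rng)  # samples.shape == (4, 8, 16, 2)
```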
log_prob
Per-sample log-density of x under q_φ(·|y).
For ScoreDiffusionHead this raises NotImplementedError
(no closed-form density). Use sample instead.
Source code in src/vardax/_src/amortized/posterior.py
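Because log_prob exists for some heads but not others, calling code may want to degrade gracefully. The NumPy sketch below uses two hypothetical head classes (not vardax classes): one returns a per-sample diagonal-Gaussian log-density of shape (B, n), the other raises NotImplementedError like a sampler-only head. A small helper returns None in the second case.

```python
import numpy as np

class DiagGaussianHeadSketch:
    """Hypothetical head with a closed-form density, in the spirit of RegressionHead."""

    def __init__(self, mu, log_var):
        self.mu, self.log_var = mu, log_var  # each (B, *state_shape)

    def log_prob(self, x):
        """x has shape (B, n, *state_shape); returns log-densities of shape (B, n)."""
        mu, log_var = self.mu[:, None], self.log_var[:, None]  # add sample axis
        per_dim = -0.5 * (np.log(2 * np.pi) + log_var + (x - mu) ** 2 / np.exp(log_var))
        return per_dim.reshape(x.shape[0], x.shape[1], -1).sum(axis=-1)

class SamplerOnlyHeadSketch:
    """Hypothetical head without a tractable density, like ScoreDiffusionHead."""

    def log_prob(self, x):
        raise NotImplementedError("no closed-form density; use sample instead")

def log_prob_or_none(head, x):
    """Return head.log_prob(x), or None when the head has no density."""
    try:
        return head.log_prob(x)
    except NotImplementedError:
        return None

x = np.zeros((2, 5, 3))  # B=2, n=5, state_shape=(3,)
gaussian = DiagGaussianHeadSketch(np.zeros((2, 3)), np.zeros((2, 3)))
lp = log_prob_or_none(gaussian, x)                       # shape (2, 5)
missing = log_prob_or_none(SamplerOnlyHeadSketch(), x)   # None
```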
as_analysis_step
AmortizedConfig
Bases: Module
Configuration for an AmortizedPosterior.
Attributes:
| Name | Type | Description |
|---|---|---|
| head_type | str | One of |
| n_samples | int | Default number of posterior samples to draw in |
Source code in src/vardax/_src/amortized/config.py
Observation encoders
IdentityObsEncoder passes observations (with their mask) straight through
as a flat vector, which is appropriate when they already have a fixed
size; MLPObsEncoder learns the summary.
IdentityObsEncoder
Bases: Module
Concatenate input and mask into a flat context vector.
Zero parameters — the encoder is purely structural. The output
dimension is input.size + mask.size.
Source code in src/vardax/_src/amortized/encoder.py
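A NumPy sketch of the same structural encoding, assuming the input and mask arrive as arrays of any shape and that the mask is cast to floating point before concatenation (the function name is hypothetical):

```python
import numpy as np

def identity_obs_encode(y, mask):
    """Flatten y and its mask and concatenate them; there are no parameters.

    The context has length y.size + mask.size.
    """
    return np.concatenate([np.ravel(y).astype(float), np.ravel(mask).astype(float)])

y = np.arange(20.0).reshape(4, 5)             # observation field
mask = np.arange(20).reshape(4, 5) % 3 == 0   # True where observed (toy pattern)
context = identity_obs_encode(y, mask)        # length 40
```

Keeping the mask in the context presumably lets a downstream head tell a missing value apart from an observed zero. Because the output length is fixed by the input shape, this only suits fixed-size observations, matching the guidance above.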
MLPObsEncoder
Bases: Module
Two-layer MLP encoder from flat (input, mask) to context.
Attributes:
| Name | Type | Description |
|---|---|---|
| mlp | MLP | Two-layer MLP from the flattened (input, mask) vector to the context. |
| input_size | int | Flattened size of the input field ( |
Source code in src/vardax/_src/amortized/encoder.py
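A minimal NumPy sketch of a two-layer MLP encoder over the flattened (input, mask) pair. The tanh activation, the initialisation scale, the parameter names, and the assumption that the mask has the same flattened size as the input (so the first layer sees 2 * input_size features) are illustrative choices, not taken from vardax.

```python
import numpy as np

def init_mlp_encoder(input_size, hidden, context_dim, rng):
    """Hypothetical parameters for a two-layer MLP over concat(input, mask)."""
    d_in = 2 * input_size  # flattened input plus a same-sized flattened mask
    return {
        "W1": rng.normal(scale=1 / np.sqrt(d_in), size=(d_in, hidden)),
        "b1": np.zeros(hidden),
        "W2": rng.normal(scale=1 / np.sqrt(hidden), size=(hidden, context_dim)),
        "b2": np.zeros(context_dim),
    }

def mlp_encode(params, y, mask):
    # Flatten and concatenate as the identity encoder does, then apply the MLP.
    h = np.concatenate([np.ravel(y).astype(float), np.ravel(mask).astype(float)])
    h = np.tanh(h @ params["W1"] + params["b1"])  # hidden layer
    return h @ params["W2"] + params["b2"]        # linear output layer

rng = np.random.default_rng(0)
params = init_mlp_encoder(input_size=20, hidden=64, context_dim=16, rng=rng)
y = rng.normal(size=(4, 5))
mask = rng.random(size=(4, 5)) > 0.3
context = mlp_encode(params, y, mask)  # shape (16,)
```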
Posterior heads
RegressionHead
Bases: Module
Gaussian regression head: q_φ(x|y) = N(μ_φ(y), diag(σ²_φ(y))).
Two MLPs share the context: one for the mean, one for the log
variance. Outputs are reshaped to the user-specified
state_shape.
Attributes:
| Name | Type | Description |
|---|---|---|
| mlp_mu | MLP | MLP from context to flat mean (size |
| mlp_log_var | MLP | MLP from context to flat log variance. |
| state_shape | tuple[int, ...] | Shape of a single posterior sample (e.g. |
Source code in src/vardax/_src/amortized/heads.py
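A NumPy sketch of that forward pass, with single linear maps standing in for mlp_mu and mlp_log_var to keep it short (the parameter names and linear layers are assumptions). Both maps read the same batched context, emit flat vectors of size prod(state_shape), and are reshaped to (B, *state_shape). Exponentiating the predicted log variance keeps σ² positive without an explicit constraint.

```python
import numpy as np

def regression_head_forward(params, context, state_shape):
    """Return mean and variance, each of shape (B, *state_shape)."""
    B = context.shape[0]
    flat_mu = context @ params["W_mu"] + params["b_mu"]       # (B, prod(state_shape))
    flat_log_var = context @ params["W_lv"] + params["b_lv"]  # (B, prod(state_shape))
    mu = flat_mu.reshape((B,) + state_shape)
    var = np.exp(flat_log_var).reshape((B,) + state_shape)    # always > 0
    return mu, var

rng = np.random.default_rng(0)
B, context_dim, state_shape = 3, 10, (4, 2)
d_state = int(np.prod(state_shape))
params = {
    "W_mu": rng.normal(size=(context_dim, d_state)), "b_mu": np.zeros(d_state),
    "W_lv": 0.1 * rng.normal(size=(context_dim, d_state)), "b_lv": np.zeros(d_state),
}
mu, var = regression_head_forward(params, rng.normal(size=(B, context_dim)), state_shape)
```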
ConditionalFlowHead
Bases: Module
Conditional normalising flow head (stub).
Represents x = f_φ(z; c_ψ(y)) with z ~ N(0, I), so the density is exact
via the change-of-variables formula. Requires gauss_flows for the
flow primitives; ships as a stub until that dependency is added to
pyproject.toml.
Source code in src/vardax/_src/amortized/heads.py
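Since ConditionalFlowHead is a stub, the sketch below only illustrates the change-of-variables computation it describes, using the simplest conditional flow: one elementwise affine bijection x = shift + exp(log_scale) * z, where shift and log_scale would be produced from the context. It uses NumPy rather than gauss_flows, and the function names are hypothetical. Real flows stack many nonlinear bijections, but the bookkeeping is the same: log q(x | y) = log N(z; 0, I) - log|det ∂x/∂z|.

```python
import numpy as np

def affine_flow_forward(z, shift, log_scale):
    """x = f(z; c): elementwise affine map conditioned on the context c(y)."""
    return shift + np.exp(log_scale) * z

def affine_flow_log_prob(x, shift, log_scale):
    """Exact log q(x | y) by the change-of-variables formula.

    Invert z = f^{-1}(x); the Jacobian of f is diagonal with entries
    exp(log_scale), so log|det df/dz| = sum(log_scale).
    """
    z = (x - shift) * np.exp(-log_scale)
    log_base = -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=-1)  # log N(z; 0, I)
    return log_base - np.sum(log_scale, axis=-1)

rng = np.random.default_rng(0)
shift, log_scale = rng.normal(size=4), 0.3 * rng.normal(size=4)  # would come from c(y)
z = rng.normal(size=(1000, 4))
x = affine_flow_forward(z, shift, log_scale)
lp = affine_flow_log_prob(x, shift, log_scale)
```

A single affine layer is just a diagonal Gaussian, which makes the result easy to check against the closed-form density.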
ScoreDiffusionHead
Bases: Module
Score-based diffusion head (stub).
Learns s_φ(x, t | y) ≈ ∇_x log p_t(x | y) and samples by solving the
reverse-time SDE with diffrax. Ships as a stub pending the reverse-SDE
pipeline.
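ScoreDiffusionHead is a stub, so this is only a sketch of the sampling technique it names: integrating the reverse-time SDE with Euler-Maruyama. Several assumptions stand in for the real pipeline: NumPy instead of diffrax, a variance-exploding forward SDE dx = g dW with constant g on t in [0, 1], and the analytic score of a Gaussian target in place of the learned s_φ(x, t | y).

```python
import numpy as np

def reverse_sde_sample(score, x_init, g, n_steps, rng):
    """Euler-Maruyama for the reverse of the forward SDE dx = g dW, t in [0, 1].

    Each step moves from t to t - dt:
        x <- x + g**2 * score(x, t) * dt + g * sqrt(dt) * noise
    """
    x = x_init.copy()
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = 1.0 - k * dt
        x = x + g**2 * score(x, t) * dt + g * np.sqrt(dt) * rng.normal(size=x.shape)
    return x

# Toy conditional target (an assumption): q(x | y) = N(0.8 * y, 0.2).
y = 1.5
m, v0, g = 0.8 * y, 0.2, 2.0

def exact_score(x, t):
    # Under dx = g dW the time-t marginal is N(m, v0 + g**2 * t), so its score
    # is known in closed form; a trained s_phi(x, t | y) would take its place.
    return -(x - m) / (v0 + g**2 * t)

rng = np.random.default_rng(0)
x_init = m + np.sqrt(v0 + g**2) * rng.normal(size=20_000)  # exact time-1 marginal
samples = reverse_sde_sample(exact_score, x_init, g, n_steps=1000, rng=rng)
```

In practice the starting point is drawn from a broad reference distribution rather than the exact time-1 marginal used here, since that marginal is unknown for a real posterior.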