Utilities & Diagnostics
Supporting cast: a reference dynamical system for experiments, the statistical gates that decide whether an inference setup can be trusted, and the plotting helpers used throughout the end-to-end examples.
Dynamical systems
Lorenz-96 is a standard chaotic testbed for assimilation experiments;
simulate_lorenz96 generates trajectories for the
Lorenz examples and the test suite.
vardax — Modular variational data assimilation with learned components.
All public symbols are re-exported from the private _src subpackage so
that user code imports from the top-level namespace:
simulate_lorenz96
simulate_lorenz96(
key: Array,
*,
N: int = 40,
F: float = 8.0,
dt: float = 0.01,
n_steps: int = 5000,
n_burn_in: int = 1000,
) -> tuple[Float[Array, "T"], Float[Array, "T N"]]
Simulate the Lorenz-96 system and return the time coordinates and the state trajectory.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | JAX PRNG key used to perturb the initial condition. | required |
| N | int | Number of variables (spatial dimension). | 40 |
| F | float | Forcing constant. | 8.0 |
| dt | float | Integration time step. | 0.01 |
| n_steps | int | Total number of integration steps after burn-in. | 5000 |
| n_burn_in | int | Number of initial steps to discard (burn-in). | 1000 |
Returns:

| Type | Description |
|---|---|
| tuple[Float[Array, "T"], Float[Array, "T N"]] | Time coordinates of shape (T,) (starting at 0) and the state trajectory of shape (T, N). |
Source code in src/vardax/_src/utils/dynamical_systems.py
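The Lorenz-96 tendency is \(\dot x_i = (x_{i+1} - x_{i-2})\,x_{i-1} - x_i + F\) with cyclic indices. A minimal JAX sketch of the kind of simulation simulate_lorenz96 performs follows. The fourth-order Runge-Kutta integrator, the 0.01 perturbation around the fixed point \(x_i = F\), and the name simulate_lorenz96_sketch are assumptions, not necessarily what vardax does internally.

```python
import jax
import jax.numpy as jnp


def lorenz96_tendency(x: jax.Array, forcing: float) -> jax.Array:
    # dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F, indices taken modulo N.
    # jnp.roll(x, -1)[i] == x[i+1], jnp.roll(x, 2)[i] == x[i-2], jnp.roll(x, 1)[i] == x[i-1].
    return (jnp.roll(x, -1) - jnp.roll(x, 2)) * jnp.roll(x, 1) - x + forcing


def rk4_step(x: jax.Array, dt: float, forcing: float) -> jax.Array:
    # Classical fourth-order Runge-Kutta step (assumed integrator).
    k1 = lorenz96_tendency(x, forcing)
    k2 = lorenz96_tendency(x + 0.5 * dt * k1, forcing)
    k3 = lorenz96_tendency(x + 0.5 * dt * k2, forcing)
    k4 = lorenz96_tendency(x + dt * k3, forcing)
    return x + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def simulate_lorenz96_sketch(
    key: jax.Array,
    *,
    N: int = 40,
    F: float = 8.0,
    dt: float = 0.01,
    n_steps: int = 5000,
    n_burn_in: int = 1000,
) -> tuple[jax.Array, jax.Array]:
    # x_i = F is an unstable fixed point; a small random kick starts the chaos.
    x0 = F + 0.01 * jax.random.normal(key, (N,))

    def step(x, _):
        x_next = rk4_step(x, dt, F)
        return x_next, x  # record the state *before* stepping

    # Burn-in: integrate and discard so the state settles onto the attractor.
    x_start, _ = jax.lax.scan(step, x0, None, length=n_burn_in)
    # Record n_steps states; states[0] is the first post-burn-in state (t = 0).
    _, states = jax.lax.scan(step, x_start, None, length=n_steps)
    times = dt * jnp.arange(n_steps)
    return times, states  # shapes (n_steps,) and (n_steps, N)
```

Using jax.lax.scan keeps the whole integration inside one traced loop, so it compiles once instead of unrolling thousands of Python-level steps.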
Validation gates
The six-step methodology's go/no-go checks (Decision D12):
simulation_based_calibration ranks each true state within its posterior
samples, and those ranks must be uniformly distributed; assert_posterior_agreement
cross-checks two posterior adapters against each other; and
assert_adjoint_calibrated verifies that a cheap
adjoint tracks the exact gradient before it is used for
training. Run these checks before trusting any uncertainty estimate.
simulation_based_calibration
simulation_based_calibration(
sample_posterior: Callable[
[Array, PRNGKeyArray, int], Array
],
sample_prior: Callable[[PRNGKeyArray], Array],
simulate_obs: Callable[[Array, PRNGKeyArray], Array],
*,
key: PRNGKeyArray,
n_runs: int = 200,
n_samples: int = 200,
) -> Float[Array, " n_runs"]
Simulation-based calibration (Talts et al., 2018): rank histogram of true draws within posterior samples.
Procedure for each of n_runs independent draws:

- Sample \(x^{(j)} \sim p(x)\) from the prior.
- Simulate \(y^{(j)} = \mathrm{simulate\_obs}(x^{(j)})\).
- Draw n_samples posterior samples from \(q_\phi(\cdot \mid y^{(j)})\).
- Compute the rank of (a flattened scalar reduction of) \(x^{(j)}\) in the sample set.
A well-calibrated posterior produces ranks uniformly distributed
over [0, n_samples]. Excess mass near both edges (a ∪-shaped histogram) means the
posterior is over-confident (too narrow); a bump in the centre means it is under-confident (too wide).
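The procedure above can be exercised end to end on a toy model whose exact posterior is known. The sketch assumes a standard-normal prior on a 3-dimensional state, additive Gaussian noise with standard deviation 0.5, the call order sample_posterior(y, key, n), and the state mean as the scalar reduction. sbc_ranks_sketch is a hypothetical stand-in with the same callable signatures as simulation_based_calibration, not its source.

```python
import jax
import jax.numpy as jnp

NOISE_STD = 0.5  # assumed observation noise of the toy model
DIM = 3          # assumed state dimension


def sample_prior(key):
    # x ~ N(0, I)
    return jax.random.normal(key, (DIM,))


def simulate_obs(x, key):
    # y = x + noise, noise ~ N(0, NOISE_STD^2 I)
    return x + NOISE_STD * jax.random.normal(key, x.shape)


def sample_posterior(y, key, n):
    # Exact conjugate posterior: mean y / (1 + s^2), variance s^2 / (1 + s^2).
    s2 = NOISE_STD**2
    mean = y / (1.0 + s2)
    std = jnp.sqrt(s2 / (1.0 + s2))
    return mean + std * jax.random.normal(key, (n,) + y.shape)


def sbc_ranks_sketch(sample_posterior, sample_prior, simulate_obs, *, key, n_runs=200, n_samples=200):
    def one_run(run_key):
        k_prior, k_obs, k_post = jax.random.split(run_key, 3)
        x = sample_prior(k_prior)
        y = simulate_obs(x, k_obs)
        samples = sample_posterior(y, k_post, n_samples)
        # Scalar reduction (assumed): mean over all state components.
        truth_stat = jnp.mean(x)
        sample_stats = jnp.mean(samples.reshape(n_samples, -1), axis=1)
        # Rank = number of posterior samples below the truth, in {0, ..., n_samples}.
        return jnp.sum(sample_stats < truth_stat).astype(jnp.float32)

    # Independent runs are vectorised with vmap over split keys.
    return jax.vmap(one_run)(jax.random.split(key, n_runs))
```

Because sample_posterior here draws from the exact posterior, the truth and the samples are exchangeable and the ranks come out close to uniform; shrinking std (an over-confident posterior) would push ranks into the outer bins.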
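Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sample_posterior | Callable[[Array, PRNGKeyArray, int], Array] | Called with an observation \(y\), a PRNG key, and a sample count; returns that many draws from \(q_\phi(\cdot \mid y)\). | required |
| sample_prior | Callable[[PRNGKeyArray], Array] | Draws one state \(x \sim p(x)\) from a PRNG key. | required |
| simulate_obs | Callable[[Array, PRNGKeyArray], Array] | Maps a state \(x\) and a PRNG key to a synthetic observation \(y\). | required |
| key | PRNGKeyArray | Top-level PRNG key. | required |
| n_runs | int | Number of (prior, obs, posterior) triples. | 200 |
| n_samples | int | Posterior samples per run (defines the rank histogram resolution). | 200 |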
Returns:

| Type | Description |
|---|---|
| Float[Array, ' n_runs'] | Rank of each true draw among its n_samples posterior samples, one entry per run. |
Source code in src/vardax/_src/utils/validation.py
assert_posterior_agreement
assert_posterior_agreement(
p_fast: Posterior,
p_oracle: Posterior,
*,
tolerance_sigma: float = 1.0,
) -> None
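Check that p_fast.mean lies within tolerance_sigma standard
deviations of p_oracle.mean.
Marginal-only test: each component of the mean must satisfy
\[
\left| \mu^{\mathrm{fast}}_i - \mu^{\mathrm{oracle}}_i \right| \le \mathrm{tolerance\_sigma} \cdot \sigma_i ,
\]
where \(\mu^{\mathrm{fast}}\) and \(\mu^{\mathrm{oracle}}\) are p_fast.mean and p_oracle.mean.
The oracle marginal \(\sigma_i = \sqrt{e_i^T \Sigma e_i}\) is extracted from
p_oracle.cov by probing \(e_i^T \Sigma e_i\) with one matvec per
component. This is cheap for moderate state sizes; for very large
states, use Hutchinson estimation upstream and supply the diagonal
directly via Posterior.cov materialised as a diagonal operator.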
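The marginal test can be sketched in a few lines of JAX. The sketch represents the oracle covariance as a matrix-vector product callable cov_matvec (an assumption standing in for p_oracle.cov), extracts each \(\sigma_i\) with one probe \(e_i\), and, as the large-state alternative, estimates the diagonal with Rademacher probes \(z\) via \(\operatorname{diag}(\Sigma) \approx \mathbb{E}[z \odot \Sigma z]\). All function names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def marginal_sigma_by_probing(cov_matvec, n):
    # sigma_i = sqrt(e_i^T Sigma e_i): one matvec per basis vector e_i.
    basis = jnp.eye(n)
    variances = jax.vmap(lambda e: jnp.dot(e, cov_matvec(e)))(basis)
    return jnp.sqrt(variances)


def marginal_sigma_hutchinson(cov_matvec, n, key, n_probes=2000):
    # Stochastic diagonal estimate: diag(Sigma) ~ mean_k z_k * (Sigma z_k),
    # where each z_k has independent +/-1 (Rademacher) entries.
    z = jax.random.rademacher(key, (n_probes, n)).astype(jnp.float32)
    diag = jnp.mean(z * jax.vmap(cov_matvec)(z), axis=0)
    return jnp.sqrt(jnp.maximum(diag, 0.0))  # guard against small negative estimates


def agreement_violations(mean_fast, mean_oracle, sigma_oracle, tolerance_sigma=1.0):
    # Indices i with |mu_fast_i - mu_oracle_i| > tolerance_sigma * sigma_i.
    z = jnp.abs(mean_fast - mean_oracle) / sigma_oracle
    return jnp.flatnonzero(z > tolerance_sigma)


def check_agreement(mean_fast, mean_oracle, sigma_oracle, tolerance_sigma=1.0):
    # Gate: raise if any component falls outside the tolerance band.
    bad = agreement_violations(mean_fast, mean_oracle, sigma_oracle, tolerance_sigma)
    if bad.size > 0:
        raise AssertionError(f"components {bad.tolist()} exceed {tolerance_sigma} sigma")
```

Exact probing costs n matvecs; the Rademacher estimate costs n_probes matvecs regardless of n, trading exactness for a cost that does not grow with the state size.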
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| p_fast | Posterior | Posterior produced by the amortized / fast model. | required |
| p_oracle | Posterior | Posterior produced by the oracle (e.g. …). | required |
| tolerance_sigma | float | Allowed deviation in units of \(\sigma\). | 1.0 |
Raises:

| Type | Description |
|---|---|
| AssertionError | If any component exceeds the tolerance. |
Source code in src/vardax/_src/utils/validation.py
assert_adjoint_calibrated
assert_adjoint_calibrated(
fn_fast: Callable[[Array], Array],
fn_oracle: Callable[[Array], Array],
y: Float[Array, ...],
*,
key: PRNGKeyArray,
threshold: float = 0.05,
n_probes: int = 10,
) -> None
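Random-vector probe of the Jacobian agreement at y.
Tests
\[
\frac{\lVert J_{\mathrm{fast}}(y)\, v - J_{\mathrm{oracle}}(y)\, v \rVert}{\lVert J_{\mathrm{oracle}}(y)\, v \rVert} \le \mathrm{threshold}
\]
for n_probes random unit vectors \(v\), where \(J_{\mathrm{fast}}\) and \(J_{\mathrm{oracle}}\) are the
Jacobians of fn_fast and fn_oracle. Neither Jacobian is materialised: jax.jvp
applies each one to \(v\) on demand. This is cheaper than a dense Jacobian
comparison and is the operational test used by the six-step cycle.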
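The probe test can be sketched with jax.jvp, which returns the Jacobian-vector product \(J(y)\,v\) without forming \(J\). jvp_relative_errors and check_adjoint_calibrated are hypothetical helpers; drawing Gaussian probes and normalising them to unit length is an assumed choice of probe distribution.

```python
import jax
import jax.numpy as jnp


def jvp_relative_errors(fn_fast, fn_oracle, y, *, key, n_probes=10):
    # Random probe directions with the same shape as y, each scaled to unit norm.
    v = jax.random.normal(key, (n_probes,) + y.shape)
    norms = jnp.linalg.norm(v.reshape(n_probes, -1), axis=1)
    v = v / norms.reshape((n_probes,) + (1,) * y.ndim)

    def relative_error(v_k):
        # jax.jvp returns (f(y), J(y) v_k); only the tangent output is needed.
        _, jv_fast = jax.jvp(fn_fast, (y,), (v_k,))
        _, jv_oracle = jax.jvp(fn_oracle, (y,), (v_k,))
        diff = jnp.linalg.norm((jv_fast - jv_oracle).ravel())
        return diff / jnp.linalg.norm(jv_oracle.ravel())

    return jax.vmap(relative_error)(v)


def check_adjoint_calibrated(fn_fast, fn_oracle, y, *, key, threshold=0.05, n_probes=10):
    # Gate: every probe's relative error must stay within the threshold.
    errors = jvp_relative_errors(fn_fast, fn_oracle, y, key=key, n_probes=n_probes)
    if bool(jnp.any(errors > threshold)):
        raise AssertionError(f"max relative JVP error {float(errors.max()):.3g} > {threshold}")
```

For example, a surrogate equal to 1.01 * tanh checked against tanh gives a relative error of 0.01 on every probe, so it passes the default 0.05 threshold.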
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| fn_fast | Callable[[Array], Array] | Fast callable under test. | required |
| fn_oracle | Callable[[Array], Array] | Same signature, but the oracle. | required |
| y | Float[Array, ...] | Observation tensor. | required |
| key | PRNGKeyArray | PRNG key for the probe vectors. | required |
| threshold | float | Maximum allowed relative error. | 0.05 |
| n_probes | int | Number of probe vectors. | 10 |
Raises:

| Type | Description |
|---|---|
| AssertionError | If any probe exceeds the threshold. |
Source code in src/vardax/_src/utils/validation.py
Visualization
plot_l96_trajectories
plot_l96_trajectories(
states: ndarray,
time_coords: ndarray,
*,
n_vars: int = 5,
ax: Axes | None = None,
) -> tuple[Figure, Axes]
Line plot of selected Lorenz-96 variables over time.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | ndarray | State array of shape … | required |
| time_coords | ndarray | Time coordinate array of shape … | required |
| n_vars | int | Number of evenly-spaced variables to plot. | 5 |
| ax | Axes \| None | Optional existing … | None |
Returns:

| Type | Description |
|---|---|
| tuple[Figure, Axes] | Figure and axes holding the plot. |
Source code in src/vardax/_src/utils/viz.py
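Picking n_vars evenly spaced variables and overlaying their time series can be sketched with matplotlib as below. The (T, N) layout of states, the name plot_trajectories_sketch, and the styling are assumptions; this is not the vardax source.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_trajectories_sketch(states, time_coords, *, n_vars=5, ax=None):
    # Reuse the caller's axes if given, otherwise create a new figure.
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 3))
    else:
        fig = ax.figure
    n_total = states.shape[1]  # assumed layout: (T, N)
    # Evenly spaced indices, e.g. N=40, n_vars=5 -> 0, 8, 16, 24, 32.
    idx = np.linspace(0, n_total, n_vars, endpoint=False).astype(int)
    for i in idx:
        ax.plot(time_coords, states[:, i], lw=0.8, label=f"$x_{{{i}}}$")
    ax.set_xlabel("time")
    ax.set_ylabel("state")
    ax.legend(loc="upper right", fontsize="small")
    return fig, ax
```

Accepting an optional ax lets the caller place the plot inside a larger multi-panel figure.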
plot_l96_grid
plot_l96_grid(
states: ndarray,
time_coords: ndarray,
*,
ax: Axes | None = None,
) -> tuple[Figure, Axes]
Hovmöller-style image of Lorenz-96 states over time.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | ndarray | State array of shape … | required |
| time_coords | ndarray | Time coordinate array of shape … | required |
| ax | Axes \| None | Optional existing … | None |
Returns:

| Type | Description |
|---|---|
| tuple[Figure, Axes] | Figure and axes holding the plot. |
Source code in src/vardax/_src/utils/viz.py
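A Hovmöller diagram puts time on one axis and the spatial index on the other, colouring each cell by the state value. A minimal matplotlib sketch, assuming states has shape (T, N); plot_grid_sketch and the colormap choice are hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_grid_sketch(states, time_coords, *, ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 3))
    else:
        fig = ax.figure
    n_vars = states.shape[1]  # assumed layout: (T, N)
    # Transpose to (N, T): rows are variables (y axis), columns are times (x axis).
    mesh = ax.pcolormesh(time_coords, np.arange(n_vars), states.T, shading="nearest", cmap="RdBu_r")
    fig.colorbar(mesh, ax=ax, label="$x_i$")
    ax.set_xlabel("time")
    ax.set_ylabel("variable index $i$")
    return fig, ax
```

On a chaotic Lorenz-96 run, this view typically shows the diagonal streaks of the model's travelling waves.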
plot_reconstruction_comparison
plot_reconstruction_comparison(
target: ndarray,
masked_input: ndarray,
reconstruction: ndarray,
*,
sample_idx: int = 0,
) -> tuple[Figure, ndarray]
Side-by-side comparison of target, masked input, and reconstruction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target | ndarray | Ground-truth states of shape … | required |
| masked_input | ndarray | Masked / noisy observations of shape … | required |
| reconstruction | ndarray | Model reconstruction of shape … | required |
| sample_idx | int | Which batch element to visualize. | 0 |
|
Returns:

| Type | Description |
|---|---|
| tuple[Figure, ndarray] | |
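A three-panel comparison can be sketched with matplotlib subplots that share one colour scale, so differences between panels reflect the data rather than rescaling. The (batch, T, N) layout, NaN as the marker for masked entries, and the name plot_reconstruction_sketch are all assumptions.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_reconstruction_sketch(target, masked_input, reconstruction, *, sample_idx=0):
    # Assumed layout: arrays of shape (batch, T, N); masked entries stored as NaN.
    panels = [
        ("target", target[sample_idx]),
        ("masked input", masked_input[sample_idx]),
        ("reconstruction", reconstruction[sample_idx]),
    ]
    # Shared colour limits taken from the target so the panels are comparable.
    vmin, vmax = float(np.nanmin(panels[0][1])), float(np.nanmax(panels[0][1]))
    fig, axes = plt.subplots(1, 3, figsize=(12, 3), sharey=True)
    for ax, (title, field) in zip(axes, panels):
        # Transpose so time runs along x and the variable index along y; NaN
        # cells take the colormap's "bad" colour (transparent by default).
        im = ax.imshow(field.T, aspect="auto", origin="lower", cmap="RdBu_r", vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel("time step")
    axes[0].set_ylabel("variable index")
    fig.colorbar(im, ax=axes.ravel().tolist())
    return fig, axes
```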