Utilities & Diagnostics

Supporting cast: a reference dynamical system for experiments, the statistical gates that decide whether an inference setup can be trusted, and the plotting helpers used throughout the end-to-end examples.

Dynamical systems

Lorenz-96 is the standard chaotic testbed for assimilation experiments; simulate_lorenz96 generates trajectories for the Lorenz examples and the test suite.

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

simulate_lorenz96

simulate_lorenz96(
    key: Array,
    *,
    N: int = 40,
    F: float = 8.0,
    dt: float = 0.01,
    n_steps: int = 5000,
    n_burn_in: int = 1000,
) -> tuple[Float[Array, T], Float[Array, "T N"]]

Simulate the Lorenz-96 system and return state trajectory.

Parameters:

key (Array, required): JAX PRNG key used to perturb the initial condition.
N (int, default 40): Number of variables (spatial dimension).
F (float, default 8.0): Forcing constant.
dt (float, default 0.01): Integration time step.
n_steps (int, default 5000): Total number of integration steps after burn-in.
n_burn_in (int, default 1000): Number of initial steps to discard (burn-in).

Returns:

tuple[Float[Array, T], Float[Array, 'T N']]: (time_coords, states). The time coordinates have shape (T,) and start at 0; the state trajectory has shape (n_steps + 1, N), so T = n_steps + 1.

Source code in src/vardax/_src/utils/dynamical_systems.py
def simulate_lorenz96(
    key: Array,
    *,
    N: int = 40,
    F: float = 8.0,
    dt: float = 0.01,
    n_steps: int = 5000,
    n_burn_in: int = 1000,
) -> tuple[Float[Array, T], Float[Array, "T N"]]:  # type: ignore[unresolved-reference]  # ty:ignore[unresolved-reference]
    """Simulate the Lorenz-96 system and return state trajectory.

    Args:
        key: JAX PRNG key used to perturb the initial condition.
        N: Number of variables (spatial dimension).
        F: Forcing constant.
        dt: Integration time step.
        n_steps: Total number of integration steps *after* burn-in.
        n_burn_in: Number of initial steps to discard (burn-in).

    Returns:
        ``(time_coords, states)`` — time coordinates of shape ``(T,)``
        (starting at 0) and the state trajectory of shape
        ``(n_steps + 1, N)``.
    """
    model = Lorenz96(F=F)

    # Standard L96 initialization: uniform forcing with a small random perturbation
    x0 = jnp.full((N,), F)
    noise = jax.random.normal(key, shape=(N,)) * 0.01
    x0 = x0 + noise

    total_steps = n_burn_in + n_steps
    t0 = 0.0
    t1 = total_steps * dt

    save_times = jnp.linspace(t0, t1, total_steps + 1)

    sol = diffeqsolve(
        ODETerm(model),  # type: ignore[arg-type]  # ty:ignore[invalid-argument-type]
        Tsit5(),
        t0=t0,
        t1=t1,
        dt0=dt,
        y0=x0,
        saveat=SaveAt(ts=save_times),
    )

    states = sol.ys[n_burn_in:]
    time_coords = save_times[n_burn_in:] - save_times[n_burn_in]
    return time_coords, states
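To make the burn-in bookkeeping concrete, here is a minimal NumPy sketch of the same recipe. It is not the library's implementation: it assumes the standard Lorenz-96 tendency dx_i/dt = (x_{i+1} - x_{i-2}) x_{i-1} - x_i + F with periodic indices, and it swaps the adaptive Tsit5 solver for a fixed-step RK4 loop. The names l96_tendency, rk4_step and simulate_l96_sketch are hypothetical.

```python
import numpy as np


def l96_tendency(x: np.ndarray, forcing: float) -> np.ndarray:
    """Lorenz-96 right-hand side with periodic (cyclic) indexing."""
    # np.roll(x, -1)[i] == x[i+1], np.roll(x, 2)[i] == x[i-2], np.roll(x, 1)[i] == x[i-1]
    return (np.roll(x, -1) - np.roll(x, 2)) * np.roll(x, 1) - x + forcing


def rk4_step(x: np.ndarray, forcing: float, dt: float) -> np.ndarray:
    """One classical Runge-Kutta step (a stand-in for the adaptive solver)."""
    k1 = l96_tendency(x, forcing)
    k2 = l96_tendency(x + 0.5 * dt * k1, forcing)
    k3 = l96_tendency(x + 0.5 * dt * k2, forcing)
    k4 = l96_tendency(x + dt * k3, forcing)
    return x + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def simulate_l96_sketch(seed=0, n=40, forcing=8.0, dt=0.01, n_steps=500, n_burn_in=200):
    rng = np.random.default_rng(seed)
    # Uniform state at the forcing value plus a small random perturbation.
    x = np.full(n, forcing) + 0.01 * rng.standard_normal(n)

    total_steps = n_burn_in + n_steps
    save_times = np.linspace(0.0, total_steps * dt, total_steps + 1)
    states = np.empty((total_steps + 1, n))
    states[0] = x
    for i in range(total_steps):
        x = rk4_step(x, forcing, dt)
        states[i + 1] = x

    # Drop the burn-in rows and shift the time axis so it starts at zero.
    kept_states = states[n_burn_in:]
    time_coords = save_times[n_burn_in:] - save_times[n_burn_in]
    return time_coords, kept_states


time_coords, states = simulate_l96_sketch()
print(time_coords.shape, states.shape)  # (501,) (501, 40)
```

Slicing from n_burn_in keeps the state at the end of the burn-in as the first row, which is why the output has n_steps + 1 rows rather than n_steps.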

Validation gates

The six-step methodology's go/no-go checks (Decision D12): simulation_based_calibration ranks each true state among its posterior samples, and those ranks must be uniformly distributed; assert_posterior_agreement cross-checks the means of two posterior adapters against each other; and assert_adjoint_calibrated verifies that a cheap adjoint tracks the exact gradient before it is used for training. Run these before trusting any uncertainty estimate.

simulation_based_calibration

simulation_based_calibration(
    sample_posterior: Callable[
        [Array, PRNGKeyArray, int], Array
    ],
    sample_prior: Callable[[PRNGKeyArray], Array],
    simulate_obs: Callable[[Array, PRNGKeyArray], Array],
    *,
    key: PRNGKeyArray,
    n_runs: int = 200,
    n_samples: int = 200,
) -> Float[Array, " n_runs"]

Per Talts et al. 2018: rank histogram of true draws within posterior samples.

Procedure for each of n_runs independent draws:

  1. Sample \(x^{(j)} \sim p(x)\) from the prior.
  2. Simulate \(y^{(j)} = \mathrm{simulate\_obs}(x^{(j)})\).
  3. Draw n_samples posterior samples from \(q_\phi(\cdot \mid y^{(j)})\).
  4. Compute the rank of (a flattened scalar reduction of) \(x^{(j)}\) in the sample set.

A well-calibrated posterior produces uniformly distributed ranks over [0, n_samples]. Bumps near the edges → over-confident; centre-mass bump → under-confident.
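Written out for the scalar reduction used in the source listing for this function (the L2 norm of the flattened state), the rank for run \(j\) is the number of posterior samples whose norm falls strictly below the norm of the true draw. The symbol \(\tilde{x}^{(j)}_k\) for the \(k\)-th posterior sample is notation introduced here for illustration:

\[ r^{(j)} = \sum_{k=1}^{n_\text{samples}} \mathbb{1}\left[ \|\tilde{x}^{(j)}_k\|_2 < \|x^{(j)}\|_2 \right], \qquad \tilde{x}^{(j)}_k \sim q_\phi(\cdot \mid y^{(j)}). \]

Under calibration, \(r^{(j)}\) is uniform on the integers \(\{0, 1, \dots, n_\text{samples}\}\). Because the test sees only the norm, it checks calibration of that one summary and not of every state component.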

Parameters:

sample_posterior (Callable[[Array, PRNGKeyArray, int], Array], required): (y, key, n) -> Array of shape (n, *state_shape), posterior samples conditioned on y. For an AmortizedPosterior, the natural wrapper is lambda y, k, n: model.sample(Batch(input=y[None], ...), k, n)[0].
sample_prior (Callable[[PRNGKeyArray], Array], required): key -> x. Single prior draw.
simulate_obs (Callable[[Array, PRNGKeyArray], Array], required): (x, key) -> y. Forward model plus noise.
key (PRNGKeyArray, required): Top-level PRNG key.
n_runs (int, default 200): Number of (prior, obs, posterior) triples.
n_samples (int, default 200): Posterior samples per run (defines the rank histogram resolution).

Returns:

Float[Array, ' n_runs']: (n_runs,) array of integer ranks in [0, n_samples]. The annotation says Float, but the implementation returns the ranks with dtype int32.

Source code in src/vardax/_src/utils/validation.py
def simulation_based_calibration(
    sample_posterior: Callable[[Array, PRNGKeyArray, int], Array],
    sample_prior: Callable[[PRNGKeyArray], Array],
    simulate_obs: Callable[[Array, PRNGKeyArray], Array],
    *,
    key: PRNGKeyArray,
    n_runs: int = 200,
    n_samples: int = 200,
) -> Float[Array, " n_runs"]:
    r"""Per Talts et al. 2018: rank histogram of true draws within
    posterior samples.

    Procedure for each of ``n_runs`` independent draws:

    1. Sample $x^{(j)} \sim p(x)$ from the prior.
    2. Simulate $y^{(j)} = \mathrm{simulate\_obs}(x^{(j)})$.
    3. Draw ``n_samples`` posterior samples from
       $q_\phi(\cdot \mid y^{(j)})$.
    4. Compute the rank of (a flattened scalar reduction of)
       $x^{(j)}$ in the sample set.

    A well-calibrated posterior produces uniformly distributed ranks
    over ``[0, n_samples]``. Bumps near the edges → over-confident;
    centre-mass bump → under-confident.

    Args:
        sample_posterior: ``(y, key, n) -> Array`` of shape
            ``(n, *state_shape)`` — posterior samples conditioned on
            ``y``. For an ``AmortizedPosterior``, the natural wrapper is
            ``lambda y, k, n: model.sample(Batch(input=y[None], ...), k, n)[0]``.
        sample_prior: ``key -> x``. Single prior draw.
        simulate_obs: ``(x, key) -> y``. Forward + noise.
        key: Top-level PRNG key.
        n_runs: Number of (prior, obs, posterior) triples.
        n_samples: Posterior samples per run (defines the rank
            histogram resolution).

    Returns:
        ``(n_runs,)`` int array of ranks in ``[0, n_samples]``.
    """
    keys = jax.random.split(key, n_runs)
    ranks = []
    for k in keys:
        k_prior, k_obs, k_post = jax.random.split(k, 3)
        x_true = sample_prior(k_prior)
        y = simulate_obs(x_true, k_obs)
        samples = sample_posterior(y, k_post, n_samples)
        # Flatten and use the L2 norm as the scalar reduction.
        scalar_true = float(jnp.linalg.norm(x_true))
        scalar_samples = jnp.linalg.norm(samples.reshape(n_samples, -1), axis=-1)
        rank = int(jnp.sum(scalar_samples < scalar_true))
        ranks.append(rank)
    return jnp.asarray(ranks, dtype=jnp.int32)
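The edge-versus-centre reading of the rank histogram can be checked on a toy problem where the exact posterior is known. The sketch below is a NumPy illustration, not a call into vardax: it assumes a scalar conjugate Gaussian model (prior x ~ N(0, 1), observation y | x ~ N(x, 1), hence posterior N(y/2, 1/2)) and ranks the scalar itself instead of an L2 norm. The helper names sbc_ranks and edge_fraction, and the posterior_std_scale knob that deliberately miscalibrates the sampler, are hypothetical.

```python
import numpy as np


def sbc_ranks(posterior_std_scale, n_runs=400, n_samples=100, seed=0):
    """Rank of the true draw among posterior samples, once per run."""
    rng = np.random.default_rng(seed)
    ranks = np.empty(n_runs, dtype=np.int64)
    for j in range(n_runs):
        x_true = rng.standard_normal()        # x ~ N(0, 1)
        y = x_true + rng.standard_normal()    # y | x ~ N(x, 1)
        post_mean = 0.5 * y                   # exact posterior mean
        post_std = np.sqrt(0.5)               # exact posterior std
        # posterior_std_scale = 1 is calibrated; < 1 over-confident; > 1 under-confident.
        samples = post_mean + posterior_std_scale * post_std * rng.standard_normal(n_samples)
        ranks[j] = np.sum(samples < x_true)
    return ranks


def edge_fraction(ranks, n_samples=100, width=10):
    """Share of ranks falling in the lowest or highest `width` rank values."""
    return float(np.mean((ranks < width) | (ranks > n_samples - width)))


ranks_calibrated = sbc_ranks(1.0)
ranks_over = sbc_ranks(0.25)   # posterior too narrow
ranks_under = sbc_ranks(3.0)   # posterior too wide

for name, r in [("calibrated", ranks_calibrated), ("over", ranks_over), ("under", ranks_under)]:
    print(name, edge_fraction(r))
```

With a calibrated sampler roughly a fifth of the ranks land in the two edge bands (20 of the 101 possible rank values); the narrow posterior pushes most ranks to the edges, and the wide posterior empties them.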

assert_posterior_agreement

assert_posterior_agreement(
    p_fast: Posterior,
    p_oracle: Posterior,
    *,
    tolerance_sigma: float = 1.0,
) -> None

Check that p_fast.mean lies within tolerance_sigma standard deviations of p_oracle.mean.

Marginal-only test: each component of the mean must satisfy

\[ \frac{|x^*_\text{fast} - x^*_\text{oracle}|}{\sigma_\text{post, oracle}} \le \text{tolerance\_sigma}. \]

The oracle marginal \(\sigma_i\) is extracted from p_oracle.cov by probing \(e_i^T \Sigma e_i\) via one matvec per component. Cheap for moderate state size; for very large state sizes use Hutchinson estimation upstream and supply the diagonal directly via Posterior.cov materialised as a diagonal operator.

Parameters:

p_fast (Posterior, required): Posterior produced by the amortized / fast model.
p_oracle (Posterior, required): Posterior produced by the oracle (e.g. StrongFourDVar + LaplaceCovariance).
tolerance_sigma (float, default 1.0): Allowed deviation in units of \(\sigma\).

Raises:

AssertionError: If any component exceeds the tolerance.
ValueError: If p_oracle.cov is None, or if p_fast.mean and p_oracle.mean have different shapes.

Source code in src/vardax/_src/utils/validation.py
def assert_posterior_agreement(
    p_fast: Posterior,
    p_oracle: Posterior,
    *,
    tolerance_sigma: float = 1.0,
) -> None:
    r"""Check that ``p_fast.mean`` lies within ``tolerance_sigma`` standard
    deviations of ``p_oracle.mean``.

    Marginal-only test: each component of the mean must satisfy

    $$
    \frac{|x^*_\text{fast} - x^*_\text{oracle}|}{\sigma_\text{post, oracle}}
    \le \text{tolerance\_sigma}.
    $$

    The oracle marginal $\sigma_i$ is extracted from
    ``p_oracle.cov`` by probing $e_i^T \Sigma e_i$ via one matvec per
    component. Cheap for moderate state size; for very large state
    sizes use Hutchinson estimation upstream and supply the diagonal
    directly via ``Posterior.cov`` materialised as a diagonal operator.

    Args:
        p_fast: Posterior produced by the amortized / fast model.
        p_oracle: Posterior produced by the oracle (e.g. ``StrongFourDVar``
            + [`LaplaceCovariance`][vardax.LaplaceCovariance]).
        tolerance_sigma: Allowed deviation in units of $\sigma$.

    Raises:
        AssertionError: If any component exceeds the tolerance.
    """
    if p_oracle.cov is None:
        raise ValueError(
            "assert_posterior_agreement requires p_oracle.cov to be set "
            "(used to extract the marginal standard deviations)."
        )
    if p_fast.mean.shape != p_oracle.mean.shape:
        raise ValueError(
            "p_fast.mean and p_oracle.mean must have the same shape; got "
            f"{p_fast.mean.shape} vs {p_oracle.mean.shape}."
        )
    sigma = _marginal_std(p_oracle.cov, p_oracle.mean)
    z = jnp.abs(p_fast.mean - p_oracle.mean) / jnp.maximum(sigma, 1e-12)
    if jnp.any(z > tolerance_sigma):
        worst = float(jnp.max(z))
        raise AssertionError(
            f"posterior agreement failed: max z = {worst:.3f} > {tolerance_sigma:.3f}."
        )
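The private helper _marginal_std is not shown on this page, so the following NumPy sketch is only one plausible reading of "one matvec per component": it probes a covariance operator with each unit vector to recover the diagonal, then forms the per-component z-scores used by the check. The names marginal_std_via_matvec and posterior_z_scores, and the 3x3 covariance, are made up for illustration.

```python
import numpy as np


def marginal_std_via_matvec(cov_matvec, n):
    """Recover sqrt(diag(Sigma)) using one matrix-vector product per component."""
    diag = np.empty(n)
    for i in range(n):
        e_i = np.zeros(n)
        e_i[i] = 1.0
        diag[i] = e_i @ cov_matvec(e_i)   # e_i^T Sigma e_i
    return np.sqrt(diag)


def posterior_z_scores(mean_fast, mean_oracle, cov_matvec):
    """|mean_fast - mean_oracle| in units of the oracle's marginal std."""
    sigma = marginal_std_via_matvec(cov_matvec, mean_oracle.size)
    return np.abs(mean_fast - mean_oracle) / np.maximum(sigma, 1e-12)


# Hypothetical oracle covariance with marginal stds 2.0, 1.0 and 0.5.
cov = np.array([[4.0, 1.0, 0.0],
                [1.0, 1.0, 0.2],
                [0.0, 0.2, 0.25]])
mean_oracle = np.array([0.0, 1.0, -1.0])
mean_fast = np.array([1.0, 1.5, -0.4])

z = posterior_z_scores(mean_fast, mean_oracle, lambda v: cov @ v)
passes = bool(np.all(z <= 1.0))
print(np.round(z, 3), passes)  # z is approximately [0.5, 0.5, 1.2], so the check fails
```

The third component is off by only 0.6 in absolute terms, yet it fails because the oracle is most certain there. The off-diagonal covariance entries play no role, which is what "marginal-only" means.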

assert_adjoint_calibrated

assert_adjoint_calibrated(
    fn_fast: Callable[[Array], Array],
    fn_oracle: Callable[[Array], Array],
    y: Float[Array, ...],
    *,
    key: PRNGKeyArray,
    threshold: float = 0.05,
    n_probes: int = 10,
) -> None

Random-vector probe of the Jacobian agreement at y.

Tests

\[ \frac{\|J_\text{fast} v - J_\text{oracle} v\|} {\|J_\text{oracle} v\|} < \text{threshold} \]

for n_probes random unit vectors \(v\). Avoids materialising either Jacobian — uses jax.jvp to apply each as needed. Cheaper than dense Jacobian comparison and is the operational test used by the six-step cycle.

Parameters:

fn_fast (Callable[[Array], Array], required): Callable y → analysis (e.g. lambda y_: amortized(Batch1D(input=y_, mask=mask, target=None))).
fn_oracle (Callable[[Array], Array], required): Same signature, but the oracle.
y (Float[Array, ...], required): Observation tensor.
key (PRNGKeyArray, required): PRNG key for the probe vectors.
threshold (float, default 0.05): Maximum allowed relative error.
n_probes (int, default 10): Number of probe vectors.

Raises:

AssertionError: If any probe exceeds the threshold.

Source code in src/vardax/_src/utils/validation.py
def assert_adjoint_calibrated(
    fn_fast: Callable[[Array], Array],
    fn_oracle: Callable[[Array], Array],
    y: Float[Array, ...],
    *,
    key: PRNGKeyArray,
    threshold: float = 0.05,
    n_probes: int = 10,
) -> None:
    r"""Random-vector probe of the Jacobian agreement at ``y``.

    Tests

    $$
    \frac{\|J_\text{fast} v - J_\text{oracle} v\|}
         {\|J_\text{oracle} v\|} < \text{threshold}
    $$

    for ``n_probes`` random unit vectors $v$. Avoids
    materialising either Jacobian — uses ``jax.jvp`` to apply each as
    needed. Cheaper than dense Jacobian comparison and is the
    operational test used by the six-step cycle.

    Args:
        fn_fast: Callable ``y → analysis`` (e.g.
            ``lambda y_: amortized(Batch1D(input=y_, mask=mask, target=None))``).
        fn_oracle: Same signature, but the oracle.
        y: Observation tensor.
        key: PRNG key for the probe vectors.
        threshold: Maximum allowed relative error.
        n_probes: Number of probe vectors.

    Raises:
        AssertionError: If any probe exceeds the threshold.
    """
    keys = jax.random.split(key, n_probes)
    worst_rel = 0.0
    for k in keys:
        v = jax.random.normal(k, y.shape)
        v = v / jnp.linalg.norm(v)
        _, jv_fast = jax.jvp(fn_fast, (y,), (v,))
        _, jv_oracle = jax.jvp(fn_oracle, (y,), (v,))
        denom = jnp.linalg.norm(jv_oracle) + 1e-12
        rel = float(jnp.linalg.norm(jv_fast - jv_oracle) / denom)
        worst_rel = max(worst_rel, rel)
    if worst_rel > threshold:
        raise AssertionError(
            f"adjoint calibration failed: max relative error {worst_rel:.4f} > "
            f"{threshold:.4f} (over {n_probes} probes)."
        )
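The probe logic can be tried without JAX. In the NumPy sketch below, a central finite difference stands in for jax.jvp, so it approximates each Jacobian-vector product instead of computing it exactly; everything else follows the loop above. The functions oracle, close and far are toy maps chosen for illustration, and worst_relative_jacobian_error is a hypothetical name.

```python
import numpy as np


def directional_derivative(fn, y, v, eps=1e-5):
    """Central-difference approximation of J(y) @ v (stand-in for jax.jvp)."""
    return (fn(y + eps * v) - fn(y - eps * v)) / (2.0 * eps)


def worst_relative_jacobian_error(fn_fast, fn_oracle, y, n_probes=10, seed=0):
    """Largest relative mismatch of J_fast v and J_oracle v over random unit probes."""
    rng = np.random.default_rng(seed)
    worst = 0.0
    for _ in range(n_probes):
        v = rng.standard_normal(y.shape)
        v /= np.linalg.norm(v)
        jv_fast = directional_derivative(fn_fast, y, v)
        jv_oracle = directional_derivative(fn_oracle, y, v)
        rel = np.linalg.norm(jv_fast - jv_oracle) / (np.linalg.norm(jv_oracle) + 1e-12)
        worst = max(worst, float(rel))
    return worst


rng = np.random.default_rng(1)
A = rng.standard_normal((6, 6))
y0 = 0.01 * rng.standard_normal(6)   # small inputs keep tanh close to linear


def oracle(y):
    return np.tanh(A @ y)


def close(y):
    return np.tanh((1.01 * A) @ y)   # Jacobian about 1% off


def far(y):
    return np.tanh((1.5 * A) @ y)    # Jacobian about 50% off


close_err = worst_relative_jacobian_error(close, oracle, y0)
far_err = worst_relative_jacobian_error(far, oracle, y0)
print(close_err < 0.05, far_err < 0.05)  # True False
```

Each probe costs two forward passes per function here, against one jvp call in the JAX version, and in both cases the full Jacobian is never formed.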

Visualization

plot_l96_trajectories

plot_l96_trajectories(
    states: ndarray,
    time_coords: ndarray,
    *,
    n_vars: int = 5,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]

Line plot of selected Lorenz-96 variables over time.

Parameters:

states (ndarray, required): State array of shape (T, N).
time_coords (ndarray, required): Time coordinate array of shape (T,).
n_vars (int, default 5): Number of evenly-spaced variables to plot.
ax (Axes | None, default None): Optional existing Axes to draw on.

Returns:

tuple[Figure, Axes]: (fig, ax), the Matplotlib figure and axes.

Source code in src/vardax/_src/utils/viz.py
def plot_l96_trajectories(
    states: np.ndarray,
    time_coords: np.ndarray,
    *,
    n_vars: int = 5,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Line plot of selected Lorenz-96 variables over time.

    Args:
        states: State array of shape ``(T, N)``.
        time_coords: Time coordinate array of shape ``(T,)``.
        n_vars: Number of evenly-spaced variables to plot.
        ax: Optional existing ``Axes`` to draw on.

    Returns:
        ``(fig, ax)`` — Matplotlib figure and axes.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
        assert fig is not None

    N = states.shape[1]
    if n_vars <= 0:
        raise ValueError(f"n_vars must be a positive integer, got {n_vars}.")
    n_selected = min(n_vars, N)
    indices = np.unique(np.linspace(0, N - 1, n_selected, dtype=int))
    for i in indices:
        ax.plot(time_coords, states[:, i], label=f"x{i}")
    ax.set_xlabel("time")
    ax.set_ylabel("state")
    ax.legend()
    return fig, ax  # type: ignore[return-value]  # ty:ignore[invalid-return-type]
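Which variables "evenly-spaced" picks is easiest to see by running the index selection on its own. This short sketch lifts the three selection lines from the listing above into a standalone function; select_variable_indices is a hypothetical name, not part of vardax.

```python
import numpy as np


def select_variable_indices(n_total: int, n_vars: int) -> list[int]:
    """Evenly spaced column indices, capped at the number of available variables."""
    if n_vars <= 0:
        raise ValueError(f"n_vars must be a positive integer, got {n_vars}.")
    n_selected = min(n_vars, n_total)
    # Integer dtype rounds the evenly spaced floats down; unique() removes any repeats.
    indices = np.unique(np.linspace(0, n_total - 1, n_selected, dtype=int))
    return indices.tolist()


print(select_variable_indices(40, 5))  # [0, 9, 19, 29, 39]
print(select_variable_indices(3, 5))   # [0, 1, 2]
```

The first and last variables are always included when more than one is requested, and asking for more variables than exist simply plots all of them.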

plot_l96_grid

plot_l96_grid(
    states: ndarray,
    time_coords: ndarray,
    *,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]

Hovmöller-style image of Lorenz-96 states over time.

Parameters:

states (ndarray, required): State array of shape (T, N).
time_coords (ndarray, required): Time coordinate array of shape (T,).
ax (Axes | None, default None): Optional existing Axes to draw on.

Returns:

tuple[Figure, Axes]: (fig, ax), the Matplotlib figure and axes.

Source code in src/vardax/_src/utils/viz.py
def plot_l96_grid(
    states: np.ndarray,
    time_coords: np.ndarray,
    *,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Hovmöller-style image of Lorenz-96 states over time.

    Args:
        states: State array of shape ``(T, N)``.
        time_coords: Time coordinate array of shape ``(T,)``.
        ax: Optional existing ``Axes`` to draw on.

    Returns:
        ``(fig, ax)`` — Matplotlib figure and axes.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
        assert fig is not None

    ax.imshow(
        states.T,
        aspect="auto",
        origin="lower",
        extent=(
            float(time_coords[0]),
            float(time_coords[-1]),
            0.0,
            float(states.shape[1]),
        ),
    )
    ax.set_xlabel("time")
    ax.set_ylabel("variable index")
    return fig, ax  # type: ignore[return-value]  # ty:ignore[invalid-return-type]

plot_reconstruction_comparison

plot_reconstruction_comparison(
    target: ndarray,
    masked_input: ndarray,
    reconstruction: ndarray,
    *,
    sample_idx: int = 0,
) -> tuple[Figure, ndarray]

Side-by-side comparison of target, masked input, and reconstruction.

Parameters:

target (ndarray, required): Ground-truth states of shape (B, T, N).
masked_input (ndarray, required): Masked / noisy observations of shape (B, T, N).
reconstruction (ndarray, required): Model reconstruction of shape (B, T, N).
sample_idx (int, default 0): Which batch element to visualize.

Returns:

tuple[Figure, ndarray]: (fig, axes), the Matplotlib figure and array of axes.

Source code in src/vardax/_src/utils/viz.py
def plot_reconstruction_comparison(
    target: np.ndarray,
    masked_input: np.ndarray,
    reconstruction: np.ndarray,
    *,
    sample_idx: int = 0,
) -> tuple[Figure, np.ndarray]:
    """Side-by-side comparison of target, masked input, and reconstruction.

    Args:
        target: Ground-truth states of shape ``(B, T, N)``.
        masked_input: Masked / noisy observations of shape ``(B, T, N)``.
        reconstruction: Model reconstruction of shape ``(B, T, N)``.
        sample_idx: Which batch element to visualize.

    Returns:
        ``(fig, axes)`` — Matplotlib figure and array of axes.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    data = [
        (target[sample_idx], "Target"),
        (masked_input[sample_idx], "Masked input"),
        (reconstruction[sample_idx], "Reconstruction"),
    ]
    for ax, (arr, title) in zip(axes, data, strict=False):
        ax.imshow(arr.T, aspect="auto", origin="lower")
        ax.set_title(title)
        ax.set_xlabel("time")
        ax.set_ylabel("feature")
    fig.tight_layout()
    return fig, axes