Inference API

The pyrox.inference subpackage exposes ensemble-of-MAP and ensemble-of-VI runners as a layered surface — pick the level of control that fits your use case.

Layer 1 — Functional primitives

Roll your own training loop on top of the vmapped state primitives.

pyrox.inference.ensemble_init(init_fn, optimizer, *, ensemble_size, seed)

Initialize an ensemble of (params, opt_state) by vmap over keys.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `init_fn` | `Callable[[PRNGKeyArray], PyTree]` | `key -> params`. Called once per ensemble member. | *required* |
| `optimizer` | `GradientTransformation` | `optax.GradientTransformation`. | *required* |
| `ensemble_size` | `int` | Number of independent ensemble members `E`. | *required* |
| `seed` | `PRNGKeyArray` | PRNG key, split into `E` per-member init keys. | *required* |

Returns:

| Type | Description |
|---|---|
| `EnsembleState` | `EnsembleState` with stacked `params` and `opt_state`. Array leaves carry a leading `(E,)` axis. |

Source code in src/pyrox/inference/_ensemble.py
def ensemble_init(
    init_fn: Callable[[PRNGKeyArray], PyTree],
    optimizer: optax.GradientTransformation,
    *,
    ensemble_size: int,
    seed: PRNGKeyArray,
) -> EnsembleState:
    """Initialize an ensemble of (params, opt_state) by vmap over keys.

    Args:
        init_fn: ``key -> params``. Called once per ensemble member.
        optimizer: ``optax.GradientTransformation``.
        ensemble_size: Number of independent ensemble members ``E``.
        seed: PRNG key, split into ``E`` per-member init keys.

    Returns:
        :class:`EnsembleState` with stacked ``params`` and ``opt_state``.
        Array leaves carry a leading ``(E,)`` axis.
    """
    keys = jr.split(seed, ensemble_size)

    @eqx.filter_vmap
    def _per_member(key: PRNGKeyArray) -> tuple[PyTree, Any]:
        params = init_fn(key)
        opt_state = optimizer.init(eqx.filter(params, eqx.is_inexact_array))
        return params, opt_state

    params, opt_state = _per_member(keys)
    return EnsembleState(params=params, opt_state=opt_state)
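The contract here is that each member draws its own PRNG key, so ensemble members start from independent parameter draws. A minimal pure-Python sketch of the same split-then-init pattern, using `random.Random` in place of `jax.random` (`toy_split` and `toy_init` are illustrative stand-ins, not pyrox APIs):

```python
import random

def toy_split(seed: int, n: int) -> list[int]:
    # Stand-in for jr.split(seed, ensemble_size): derive n child seeds.
    rng = random.Random(seed)
    return [rng.randrange(2**32) for _ in range(n)]

def toy_init(key: int) -> list[float]:
    # Stand-in for init_fn: key -> params (here, two scalar weights).
    rng = random.Random(key)
    return [rng.uniform(-1, 1), rng.uniform(-1, 1)]

keys = toy_split(0, 4)
ensemble = [toy_init(k) for k in keys]  # "stacked" params: leading (E,) axis
assert len(ensemble) == 4
# Independent keys => members start at different points.
assert len({tuple(p) for p in ensemble}) == 4
```

In the real function the per-member loop is a `filter_vmap`, so the result is one PyTree whose array leaves are stacked along a new leading axis rather than a Python list.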

pyrox.inference.ensemble_loss(log_joint, *, prior_weight=1.0, scale=1.0)

Build the `filter_value_and_grad` loss from a log-joint.

Returns a function `loss_fn(params, x_batch, y_batch) -> (loss, grads)` that computes

$$
\mathcal{L}(\theta) = -\,\text{scale} \cdot \log p(y \mid x, \theta) \;-\; w_{\text{prior}} \cdot \log p(\theta).
$$

`prior_weight=0` short-circuits the prior term so the user may return a placeholder `0.0` for `logprior`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `log_joint` | `Callable[[PyTree, Array, Array], tuple[Float[Array, ""], Float[Array, ""]]]` | `(params, x_batch, y_batch) -> (loglik, logprior)`. | *required* |
| `prior_weight` | `float` | Weight on the `logprior` term. `0.0` ⇒ MLE, `1.0` ⇒ MAP. | `1.0` |
| `scale` | `float` | Multiplicative weight on the `loglik` term. Set to N/\|B\| for unbiased mini-batch SGD-MAP. | `1.0` |
Source code in src/pyrox/inference/_ensemble.py
def ensemble_loss(
    log_joint: Callable[
        [PyTree, Array, Array],
        tuple[Float[Array, ""], Float[Array, ""]],
    ],
    *,
    prior_weight: float = 1.0,
    scale: float = 1.0,
) -> Callable[[PyTree, Array, Array], tuple[Float[Array, ""], PyTree]]:
    r"""Build the ``filter_value_and_grad`` loss from a log-joint.

    Returns a function ``loss_fn(params, x_batch, y_batch) -> (loss, grads)``
    that computes

    .. math::

        \mathcal{L}(\theta) = -\, \text{scale} \cdot \log p(y \mid x, \theta)
            \;-\; w_{\text{prior}} \cdot \log p(\theta).

    ``prior_weight=0`` short-circuits the prior term so the user may
    return a placeholder ``0.0`` for ``logprior``.

    Args:
        log_joint: ``(params, x_batch, y_batch) -> (loglik, logprior)``.
        prior_weight: Weight on the ``logprior`` term. ``0.0`` ⇒ MLE,
            ``1.0`` ⇒ MAP.
        scale: Multiplicative weight on the ``loglik`` term. Set to
            ``N / |B|`` for unbiased mini-batch SGD-MAP.
    """
    use_prior = prior_weight != 0.0

    @eqx.filter_value_and_grad
    def _loss(params: PyTree, xb: Array, yb: Array) -> Float[Array, ""]:
        ll, lp = log_joint(params, xb, yb)
        nll = -scale * ll
        if use_prior:
            return nll - prior_weight * lp
        return nll

    return _loss
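The objective is just a weighted sum of two scalars, which is easy to sanity-check outside JAX. A pure-Python sketch of the same arithmetic, including the `prior_weight=0` short-circuit (`toy_loss` is an illustrative stand-in, not a pyrox API):

```python
def toy_loss(loglik: float, logprior: float, *, prior_weight: float = 1.0,
             scale: float = 1.0) -> float:
    # L(theta) = -scale * loglik - prior_weight * logprior
    nll = -scale * loglik
    if prior_weight != 0.0:  # short-circuit mirrors the source above
        return nll - prior_weight * logprior
    return nll

# prior_weight=0 ignores logprior entirely (MLE), even a NaN placeholder:
assert toy_loss(-2.0, float("nan"), prior_weight=0.0) == 2.0
# prior_weight=1 gives the MAP objective:
assert toy_loss(-2.0, -3.0, prior_weight=1.0) == 5.0
# scale=N/|B| reweights only the likelihood term for mini-batching:
assert toy_loss(-2.0, -3.0, scale=10.0) == 23.0
```

The NaN case is why the short-circuit exists: with `prior_weight=0` the user's `logprior` placeholder never enters the computation graph, so it cannot poison the gradients.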

pyrox.inference.ensemble_step(state, x_batch, y_batch, *, log_joint, optimizer, prior_weight=1.0, scale=1.0)

Perform one ensemble update step on a batch.

Each ensemble member computes ∇L(θ_e) independently (vmapped), advances its optax state, and returns the updated params + state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `EnsembleState` | Current `EnsembleState`. | *required* |
| `x_batch` | `Array` | Inputs. | *required* |
| `y_batch` | `Array` | Targets. | *required* |
| `log_joint` | `Callable[[PyTree, Array, Array], tuple[Float[Array, ""], Float[Array, ""]]]` | `(params, x, y) -> (loglik, logprior)`. | *required* |
| `optimizer` | `GradientTransformation` | Same `optax.GradientTransformation` passed to `ensemble_init`. Stateless transforms can be re-built per call; stateful schedules need to share the same instance. | *required* |
| `prior_weight` | `float` | Weight on `logprior`; `0` ⇒ MLE. | `1.0` |
| `scale` | `float` | Weight on `loglik` (use N/\|B\| for mini-batch). | `1.0` |

Returns:

| Type | Description |
|---|---|
| `tuple[EnsembleState, Float[Array, " E"]]` | `(new_state, per_member_losses)` where losses has shape `(E,)`. |

Source code in src/pyrox/inference/_ensemble.py
def ensemble_step(
    state: EnsembleState,
    x_batch: Array,
    y_batch: Array,
    *,
    log_joint: Callable[
        [PyTree, Array, Array],
        tuple[Float[Array, ""], Float[Array, ""]],
    ],
    optimizer: optax.GradientTransformation,
    prior_weight: float = 1.0,
    scale: float = 1.0,
) -> tuple[EnsembleState, Float[Array, " E"]]:
    """Perform one ensemble update step on a batch.

    Each ensemble member computes ``∇L(θ_e)`` independently (vmapped),
    advances its optax state, and returns the updated params + state.

    Args:
        state: Current :class:`EnsembleState`.
        x_batch: Inputs.
        y_batch: Targets.
        log_joint: ``(params, x, y) -> (loglik, logprior)``.
        optimizer: Same ``optax.GradientTransformation`` passed to
            :func:`ensemble_init`. Stateless transforms can be re-built
            per call; stateful schedules need to share the same
            instance.
        prior_weight: Weight on ``logprior``; ``0`` ⇒ MLE.
        scale: Weight on ``loglik`` (use ``N / |B|`` for mini-batch).

    Returns:
        ``(new_state, per_member_losses)`` where losses has shape
        ``(E,)``.
    """
    loss_fn = ensemble_loss(log_joint, prior_weight=prior_weight, scale=scale)

    @eqx.filter_vmap
    def _per_member(
        params: PyTree, opt_state: Any
    ) -> tuple[PyTree, Any, Float[Array, ""]]:
        loss, grads = loss_fn(params, x_batch, y_batch)
        updates, opt_state = optimizer.update(
            grads, opt_state, eqx.filter(params, eqx.is_inexact_array)
        )
        params = eqx.apply_updates(params, updates)
        return params, opt_state, loss

    params, opt_state, losses = _per_member(state.params, state.opt_state)
    return EnsembleState(params=params, opt_state=opt_state), losses
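The vmapped `_per_member` above is semantically a Python loop over members: each computes its own gradient and advances its own optimizer state, with no cross-member coupling. A pure-Python sketch of one such step loop, with plain SGD on a toy quadratic loss standing in for the MAP objective (all names are illustrative, not pyrox APIs):

```python
def grad_loss(theta: float) -> float:
    # d/dtheta of 0.5*(theta - 3)^2, a stand-in for the per-member loss gradient.
    return theta - 3.0

def ensemble_step_toy(thetas: list[float], lr: float = 0.1) -> tuple[list[float], list[float]]:
    # Each member updates independently, exactly like the vmapped body:
    losses = [0.5 * (t - 3.0) ** 2 for t in thetas]
    new = [t - lr * grad_loss(t) for t in thetas]  # plain SGD in place of optax
    return new, losses

thetas = [0.0, 6.0]  # E = 2 members, different starting points
for _ in range(200):
    thetas, losses = ensemble_step_toy(thetas)
# On this unimodal toy loss all members converge to the same mode;
# multi-modal posteriors are where the ensemble actually pays off.
assert all(abs(t - 3.0) < 1e-6 for t in thetas)
```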

Layer 2 — NumPyro-like inference ops

init / update / run triplets that mirror numpyro.infer.SVI.

pyrox.inference.EnsembleMAP

Bases: Module

NumPyro-like ensemble MAP/MLE runner.

Mirrors `numpyro.infer.SVI`'s `init` / `update` / `run` triplet, but every operation is ensembled by `vmap` over the leading `(E,)` axis.

Per-member objective is the tempered negative log-posterior

$$
\mathcal{L}_e(\theta_e) = -\frac{N}{|B|}\sum_{i \in B} \log p(y_i \mid x_i, \theta_e) \;-\; w_{\text{prior}} \cdot \log p(\theta_e),
$$

where $w_{\text{prior}}$ is `prior_weight`.

Example:

```python
runner = EnsembleMAP(
    log_joint=log_joint,
    init_fn=init_fn,
    optimizer=optax.adam(5e-3),
    ensemble_size=16,
)
# numpyro-style three-method API
state = runner.init(jr.PRNGKey(0))
for _ in range(2000):
    state, losses = runner.update(state, x, y)
# or one-shot
result = runner.run(jr.PRNGKey(0), 2000, x, y)
```

Source code in src/pyrox/inference/_ensemble.py
class EnsembleMAP(eqx.Module):
    r"""NumPyro-like ensemble MAP/MLE runner.

    Mirrors :class:`numpyro.infer.SVI`'s ``init`` / ``update`` / ``run``
    triplet, but every operation is ensembled by ``vmap`` over the
    leading ``(E,)`` axis.

    Per-member objective is the tempered negative log-posterior

    .. math::

        \mathcal{L}_e(\theta_e) = -\frac{N}{|B|}\sum_{i \in B}
            \log p(y_i \mid x_i, \theta_e)
            \;-\; w_{\text{prior}} \cdot \log p(\theta_e),

    where :math:`w_{\text{prior}}` is :attr:`prior_weight`.

    Example:
        >>> runner = EnsembleMAP(
        ...     log_joint=log_joint,
        ...     init_fn=init_fn,
        ...     optimizer=optax.adam(5e-3),
        ...     ensemble_size=16,
        ... )
        >>> # numpyro-style three-method API
        >>> state = runner.init(jr.PRNGKey(0))
        >>> for _ in range(2000):
        ...     state, losses = runner.update(state, x, y)
        >>> # or one-shot
        >>> result = runner.run(jr.PRNGKey(0), 2000, x, y)
    """

    log_joint: Callable[
        [PyTree, Array, Array],
        tuple[Float[Array, ""], Float[Array, ""]],
    ]
    init_fn: Callable[[PRNGKeyArray], PyTree]
    optimizer: Any  # optax.GradientTransformation, but optax may not be installed
    ensemble_size: int = eqx.field(static=True, default=16)
    prior_weight: float = eqx.field(static=True, default=1.0)

    def init(self, seed: PRNGKeyArray) -> EnsembleState:
        """Initialize the ensemble. Mirrors ``numpyro.infer.SVI.init``."""
        return ensemble_init(
            self.init_fn,
            self.optimizer,
            ensemble_size=self.ensemble_size,
            seed=seed,
        )

    def update(
        self,
        state: EnsembleState,
        x_batch: Array,
        y_batch: Array,
        *,
        scale: float = 1.0,
    ) -> tuple[EnsembleState, Float[Array, " E"]]:
        """One ensemble update step. Mirrors ``numpyro.infer.SVI.update``.

        Args:
            state: Current :class:`EnsembleState`.
            x_batch: Batch inputs.
            y_batch: Batch targets.
            scale: ``N / |B|`` for mini-batch unbiased SGD-MAP. Defaults
                to ``1.0`` (full-batch).
        """
        return ensemble_step(
            state,
            x_batch,
            y_batch,
            log_joint=self.log_joint,
            optimizer=self.optimizer,
            prior_weight=self.prior_weight,
            scale=scale,
        )

    def run(
        self,
        seed: PRNGKeyArray,
        num_epochs: int,
        x: Array,
        y: Array,
        *,
        batch_size: int | None = None,
    ) -> EnsembleResult:
        """Fit the ensemble end-to-end. Mirrors ``numpyro.infer.SVI.run``.

        Internally drives :func:`ensemble_step` via ``lax.scan`` for
        speed; equivalent to a hand-written Python loop over
        :meth:`update`.

        Args:
            seed: PRNG key used for both init and (when applicable)
                mini-batch index permutation.
            num_epochs: Number of optimizer steps per member.
            x: Inputs.
            y: Targets.
            batch_size: Optional mini-batch size. ``None`` ⇒ full-batch.

        Returns:
            :class:`EnsembleResult` with stacked final params + loss
            history of shape ``(E, num_epochs)``.
        """
        n = y.shape[0]
        # Clamp batch_size to n so that batch_size > n falls back cleanly
        # to full-batch with scale=1, instead of silently downweighting
        # the likelihood by n/batch_size < 1 and biasing toward the prior.
        bsz = n if batch_size is None else min(batch_size, n)
        scale = float(n) / float(bsz)
        use_minibatch = batch_size is not None and bsz < n

        init_key, perm_key = jr.split(seed)
        state = self.init(init_key)

        # Partition non-array leaves out of the scan carry; lax.scan only
        # accepts JAX-typed carry, but eqx.Module trees may contain
        # captured Python callables (e.g. jax.nn.tanh). Static is shared
        # across ensemble members, so we recombine after the scan.
        arrays0, static = eqx.partition(state.params, eqx.is_inexact_array)

        def epoch(
            carry: tuple[PyTree, Any, PRNGKeyArray],
            _: None,
        ) -> tuple[tuple[PyTree, Any, PRNGKeyArray], Float[Array, " E"]]:
            arrays, opt_state, key = carry
            if use_minibatch:
                key, sub = jr.split(key)
                idx = jr.permutation(sub, n)[:bsz]
                xb, yb = x[idx], y[idx]
            else:
                xb, yb = x, y
            params = eqx.combine(arrays, static)
            new_state, losses = ensemble_step(
                EnsembleState(params=params, opt_state=opt_state),
                xb,
                yb,
                log_joint=self.log_joint,
                optimizer=self.optimizer,
                prior_weight=self.prior_weight,
                scale=scale,
            )
            new_arrays, _ = eqx.partition(new_state.params, eqx.is_inexact_array)
            return (new_arrays, new_state.opt_state, key), losses

        (arrays_final, _, _), losses = jax.lax.scan(
            epoch, (arrays0, state.opt_state, perm_key), None, length=num_epochs
        )
        final_params = eqx.combine(arrays_final, static)
        # losses is (T, E) from scan; transpose to (E, T) per the public contract.
        return EnsembleResult(params=final_params, losses=losses.T)

init(seed)

Initialize the ensemble. Mirrors `numpyro.infer.SVI.init`.

Source code in src/pyrox/inference/_ensemble.py
def init(self, seed: PRNGKeyArray) -> EnsembleState:
    """Initialize the ensemble. Mirrors ``numpyro.infer.SVI.init``."""
    return ensemble_init(
        self.init_fn,
        self.optimizer,
        ensemble_size=self.ensemble_size,
        seed=seed,
    )

run(seed, num_epochs, x, y, *, batch_size=None)

Fit the ensemble end-to-end. Mirrors `numpyro.infer.SVI.run`.

Internally drives `ensemble_step` via `lax.scan` for speed; equivalent to a hand-written Python loop over `update`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `seed` | `PRNGKeyArray` | PRNG key used for both init and (when applicable) mini-batch index permutation. | *required* |
| `num_epochs` | `int` | Number of optimizer steps per member. | *required* |
| `x` | `Array` | Inputs. | *required* |
| `y` | `Array` | Targets. | *required* |
| `batch_size` | `int \| None` | Optional mini-batch size. `None` ⇒ full-batch. | `None` |

Returns:

| Type | Description |
|---|---|
| `EnsembleResult` | `EnsembleResult` with stacked final params + loss history of shape `(E, num_epochs)`. |

Source code in src/pyrox/inference/_ensemble.py
def run(
    self,
    seed: PRNGKeyArray,
    num_epochs: int,
    x: Array,
    y: Array,
    *,
    batch_size: int | None = None,
) -> EnsembleResult:
    """Fit the ensemble end-to-end. Mirrors ``numpyro.infer.SVI.run``.

    Internally drives :func:`ensemble_step` via ``lax.scan`` for
    speed; equivalent to a hand-written Python loop over
    :meth:`update`.

    Args:
        seed: PRNG key used for both init and (when applicable)
            mini-batch index permutation.
        num_epochs: Number of optimizer steps per member.
        x: Inputs.
        y: Targets.
        batch_size: Optional mini-batch size. ``None`` ⇒ full-batch.

    Returns:
        :class:`EnsembleResult` with stacked final params + loss
        history of shape ``(E, num_epochs)``.
    """
    n = y.shape[0]
    # Clamp batch_size to n so that batch_size > n falls back cleanly
    # to full-batch with scale=1, instead of silently downweighting
    # the likelihood by n/batch_size < 1 and biasing toward the prior.
    bsz = n if batch_size is None else min(batch_size, n)
    scale = float(n) / float(bsz)
    use_minibatch = batch_size is not None and bsz < n

    init_key, perm_key = jr.split(seed)
    state = self.init(init_key)

    # Partition non-array leaves out of the scan carry; lax.scan only
    # accepts JAX-typed carry, but eqx.Module trees may contain
    # captured Python callables (e.g. jax.nn.tanh). Static is shared
    # across ensemble members, so we recombine after the scan.
    arrays0, static = eqx.partition(state.params, eqx.is_inexact_array)

    def epoch(
        carry: tuple[PyTree, Any, PRNGKeyArray],
        _: None,
    ) -> tuple[tuple[PyTree, Any, PRNGKeyArray], Float[Array, " E"]]:
        arrays, opt_state, key = carry
        if use_minibatch:
            key, sub = jr.split(key)
            idx = jr.permutation(sub, n)[:bsz]
            xb, yb = x[idx], y[idx]
        else:
            xb, yb = x, y
        params = eqx.combine(arrays, static)
        new_state, losses = ensemble_step(
            EnsembleState(params=params, opt_state=opt_state),
            xb,
            yb,
            log_joint=self.log_joint,
            optimizer=self.optimizer,
            prior_weight=self.prior_weight,
            scale=scale,
        )
        new_arrays, _ = eqx.partition(new_state.params, eqx.is_inexact_array)
        return (new_arrays, new_state.opt_state, key), losses

    (arrays_final, _, _), losses = jax.lax.scan(
        epoch, (arrays0, state.opt_state, perm_key), None, length=num_epochs
    )
    final_params = eqx.combine(arrays_final, static)
    # losses is (T, E) from scan; transpose to (E, T) per the public contract.
    return EnsembleResult(params=final_params, losses=losses.T)

update(state, x_batch, y_batch, *, scale=1.0)

One ensemble update step. Mirrors `numpyro.infer.SVI.update`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `EnsembleState` | Current `EnsembleState`. | *required* |
| `x_batch` | `Array` | Batch inputs. | *required* |
| `y_batch` | `Array` | Batch targets. | *required* |
| `scale` | `float` | N/\|B\| for mini-batch unbiased SGD-MAP. Defaults to `1.0` (full-batch). | `1.0` |
Source code in src/pyrox/inference/_ensemble.py
def update(
    self,
    state: EnsembleState,
    x_batch: Array,
    y_batch: Array,
    *,
    scale: float = 1.0,
) -> tuple[EnsembleState, Float[Array, " E"]]:
    """One ensemble update step. Mirrors ``numpyro.infer.SVI.update``.

    Args:
        state: Current :class:`EnsembleState`.
        x_batch: Batch inputs.
        y_batch: Batch targets.
        scale: ``N / |B|`` for mini-batch unbiased SGD-MAP. Defaults
            to ``1.0`` (full-batch).
    """
    return ensemble_step(
        state,
        x_batch,
        y_batch,
        log_joint=self.log_joint,
        optimizer=self.optimizer,
        prior_weight=self.prior_weight,
        scale=scale,
    )

pyrox.inference.EnsembleVI

Bases: Module

NumPyro-like ensemble variational-inference runner.

Wraps `numpyro.infer.SVI` + `numpyro.infer.Trace_ELBO` with the same ensemble surface as `EnsembleMAP`.

Per-member objective is the tempered ELBO

$$
\mathrm{ELBO}(\phi_e) = \mathbb{E}_{q_{\phi_e}}\bigl[\log p(y \mid x, \theta)\bigr] - \beta\, \mathrm{KL}\bigl(q_{\phi_e}\,\|\,p\bigr),
$$

where $\beta$ is `kl_weight`.

Source code in src/pyrox/inference/_ensemble.py
class EnsembleVI(eqx.Module):
    r"""NumPyro-like ensemble variational-inference runner.

    Wraps :class:`numpyro.infer.SVI` + :class:`numpyro.infer.Trace_ELBO`
    with the same ensemble surface as :class:`EnsembleMAP`.

    Per-member objective is the tempered ELBO

    .. math::

        \mathrm{ELBO}(\phi_e) = \mathbb{E}_{q_{\phi_e}}\!
            \bigl[\log p(y \mid x, \theta)\bigr]
            - \beta\, \mathrm{KL}\!\bigl(q_{\phi_e}\,\|\,p\bigr),

    where :math:`\beta` is :attr:`kl_weight`.
    """

    model_fn: Callable[..., None]
    guide_fn: Callable[..., None]
    optimizer: Any  # numpyro.optim or optax.GradientTransformation
    ensemble_size: int = eqx.field(static=True, default=16)
    kl_weight: float = eqx.field(static=True, default=1.0)
    num_particles: int = eqx.field(static=True, default=1)

    def _build_svi(self) -> Any:
        import numpyro
        from numpyro.infer import SVI, Trace_ELBO

        opt = self.optimizer
        # Duck-type rather than isinstance against numpyro.optim._NumPyroOptim:
        # the latter is a private symbol that may move under refactors. The
        # public contract of a numpyro optimizer is `init_fn` + `update_fn` +
        # `get_params`; if any are missing, assume it's an optax transform.
        is_numpyro_optim = (
            callable(getattr(opt, "init_fn", None))
            and callable(getattr(opt, "update_fn", None))
            and callable(getattr(opt, "get_params", None))
        )
        if not is_numpyro_optim:
            opt = numpyro.optim.optax_to_numpyro(opt)
        if self.kl_weight == 1.0:
            elbo: Any = Trace_ELBO(num_particles=self.num_particles)
        else:
            elbo = _TemperedTraceELBO(
                kl_weight=self.kl_weight, num_particles=self.num_particles
            )
        return SVI(self.model_fn, self.guide_fn, opt, loss=elbo)

    def init(self, seed: PRNGKeyArray, *args: Any, **kwargs: Any) -> Any:
        """Initialize the ensemble of SVI states."""
        svi = self._build_svi()
        keys = jr.split(seed, self.ensemble_size)
        return jax.vmap(lambda k: svi.init(k, *args, **kwargs))(keys)

    def update(
        self, state: Any, *args: Any, **kwargs: Any
    ) -> tuple[Any, Float[Array, " E"]]:
        """One ensemble SVI update. Mirrors ``numpyro.infer.SVI.update``."""
        svi = self._build_svi()
        return jax.vmap(lambda s: svi.update(s, *args, **kwargs))(state)

    def run(
        self,
        seed: PRNGKeyArray,
        num_epochs: int,
        *args: Any,
        **kwargs: Any,
    ) -> EnsembleResult:
        """Fit the ensemble end-to-end via vmapped ``svi.run``."""
        svi = self._build_svi()
        keys = jr.split(seed, self.ensemble_size)

        @jax.vmap
        def _run_one(
            k: PRNGKeyArray,
        ) -> tuple[PyTree, Float[Array, " T"]]:
            r = svi.run(k, num_epochs, *args, progress_bar=False, **kwargs)
            return r.params, r.losses

        params, losses = _run_one(keys)
        return EnsembleResult(params=params, losses=losses)
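The duck-typing check in `_build_svi` can be isolated and tested on its own. A sketch under the assumption stated in the source comment: a numpyro optimizer exposes `init_fn` / `update_fn` / `get_params`, while an optax transform exposes `init` / `update` (the `Fake*` classes are illustrative stand-ins, not real optimizers):

```python
class FakeNumpyroOptim:
    # Exposes the public numpyro-optimizer surface.
    def init_fn(self, params): ...
    def update_fn(self, step, grads, state): ...
    def get_params(self, state): ...

class FakeOptaxTransform:
    # optax GradientTransformations expose init/update instead.
    def init(self, params): ...
    def update(self, grads, state, params=None): ...

def is_numpyro_optim(opt: object) -> bool:
    # Same duck-type test as EnsembleVI._build_svi: all three callables present.
    return all(
        callable(getattr(opt, name, None))
        for name in ("init_fn", "update_fn", "get_params")
    )

assert is_numpyro_optim(FakeNumpyroOptim())
assert not is_numpyro_optim(FakeOptaxTransform())  # would be auto-wrapped
```

Checking attributes rather than `isinstance` against `numpyro.optim._NumPyroOptim` keeps the check robust to numpyro renaming that private class, at the cost of accepting any object that happens to expose the same three callables.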

init(seed, *args, **kwargs)

Initialize the ensemble of SVI states.

Source code in src/pyrox/inference/_ensemble.py
def init(self, seed: PRNGKeyArray, *args: Any, **kwargs: Any) -> Any:
    """Initialize the ensemble of SVI states."""
    svi = self._build_svi()
    keys = jr.split(seed, self.ensemble_size)
    return jax.vmap(lambda k: svi.init(k, *args, **kwargs))(keys)

run(seed, num_epochs, *args, **kwargs)

Fit the ensemble end-to-end via vmapped `svi.run`.

Source code in src/pyrox/inference/_ensemble.py
def run(
    self,
    seed: PRNGKeyArray,
    num_epochs: int,
    *args: Any,
    **kwargs: Any,
) -> EnsembleResult:
    """Fit the ensemble end-to-end via vmapped ``svi.run``."""
    svi = self._build_svi()
    keys = jr.split(seed, self.ensemble_size)

    @jax.vmap
    def _run_one(
        k: PRNGKeyArray,
    ) -> tuple[PyTree, Float[Array, " T"]]:
        r = svi.run(k, num_epochs, *args, progress_bar=False, **kwargs)
        return r.params, r.losses

    params, losses = _run_one(keys)
    return EnsembleResult(params=params, losses=losses)

update(state, *args, **kwargs)

One ensemble SVI update. Mirrors `numpyro.infer.SVI.update`.

Source code in src/pyrox/inference/_ensemble.py
def update(
    self, state: Any, *args: Any, **kwargs: Any
) -> tuple[Any, Float[Array, " E"]]:
    """One ensemble SVI update. Mirrors ``numpyro.infer.SVI.update``."""
    svi = self._build_svi()
    return jax.vmap(lambda s: svi.update(s, *args, **kwargs))(state)

Layer 3 — One-shot sugar

pyrox.inference.ensemble_map(log_joint, init_fn, *, ensemble_size, num_epochs, data, seed, batch_size=None, learning_rate=0.005, prior_weight=1.0, optimizer=None)

One-shot wrapper around `EnsembleMAP`.

Equivalent to `EnsembleMAP(log_joint, init_fn, optimizer, ensemble_size=E, prior_weight=w).run(seed, num_epochs, *data, batch_size=B)`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `log_joint` | `Callable[[PyTree, Array, Array], tuple[Float[Array, ""], Float[Array, ""]]]` | `(params, x_batch, y_batch) -> (loglik, logprior)`. | *required* |
| `init_fn` | `Callable[[PRNGKeyArray], PyTree]` | `key -> params`. | *required* |
| `ensemble_size` | `int` | Number of independent MAP fits `E`. | *required* |
| `num_epochs` | `int` | Optimizer steps per member. | *required* |
| `data` | `tuple[Array, Array]` | `(x, y)`. | *required* |
| `seed` | `PRNGKeyArray` | PRNG key. | *required* |
| `batch_size` | `int \| None` | Mini-batch size. `None` ⇒ full-batch. | `None` |
| `learning_rate` | `float` | Default-Adam learning rate. Ignored if `optimizer` is supplied. | `0.005` |
| `prior_weight` | `float` | `0` ⇒ MLE, `1` ⇒ MAP. | `1.0` |
| `optimizer` | `GradientTransformation \| None` | Optional `optax.GradientTransformation`. Defaults to `optax.adam(learning_rate)`. | `None` |

Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, Float[Array, "E T"]]` | `(params_stacked, losses)`: leading `(E,)` axis on `params`; `losses` shape `(E, num_epochs)`. |

Example:

```python
params, losses = ensemble_map(
    log_joint, init_fn,
    ensemble_size=16, num_epochs=2000,
    data=(X, y), seed=jr.PRNGKey(0),
)
```

Source code in src/pyrox/inference/_ensemble.py
def ensemble_map(
    log_joint: Callable[
        [PyTree, Array, Array],
        tuple[Float[Array, ""], Float[Array, ""]],
    ],
    init_fn: Callable[[PRNGKeyArray], PyTree],
    *,
    ensemble_size: int,
    num_epochs: int,
    data: tuple[Array, Array],
    seed: PRNGKeyArray,
    batch_size: int | None = None,
    learning_rate: float = 5e-3,
    prior_weight: float = 1.0,
    optimizer: optax.GradientTransformation | None = None,
) -> tuple[PyTree, Float[Array, "E T"]]:
    """One-shot wrapper around :class:`EnsembleMAP`.

    Equivalent to ``EnsembleMAP(log_joint, init_fn, optimizer,
    ensemble_size=E, prior_weight=w).run(seed, num_epochs, *data,
    batch_size=B)``.

    Args:
        log_joint: ``(params, x_batch, y_batch) -> (loglik, logprior)``.
        init_fn: ``key -> params``.
        ensemble_size: Number of independent MAP fits ``E``.
        num_epochs: Optimizer steps per member.
        data: ``(x, y)``.
        seed: PRNG key.
        batch_size: Mini-batch size. ``None`` ⇒ full-batch.
        learning_rate: Default-Adam learning rate. Ignored if
            ``optimizer`` is supplied.
        prior_weight: ``0`` ⇒ MLE, ``1`` ⇒ MAP.
        optimizer: Optional ``optax.GradientTransformation``. Defaults
            to ``optax.adam(learning_rate)``.

    Returns:
        ``(params_stacked, losses)`` — leading ``(E,)`` axis on
        ``params``; ``losses`` shape ``(E, num_epochs)``.

    Example:
        >>> params, losses = ensemble_map(
        ...     log_joint, init_fn,
        ...     ensemble_size=16, num_epochs=2000,
        ...     data=(X, y), seed=jr.PRNGKey(0),
        ... )
    """
    optax = _require_optax()
    opt = optimizer if optimizer is not None else optax.adam(learning_rate)
    runner = EnsembleMAP(
        log_joint=log_joint,
        init_fn=init_fn,
        optimizer=opt,
        ensemble_size=ensemble_size,
        prior_weight=prior_weight,
    )
    result = runner.run(seed, num_epochs, *data, batch_size=batch_size)  # ty: ignore[unresolved-attribute]
    return result.params, result.losses

pyrox.inference.ensemble_vi(model_fn, guide_fn, *, ensemble_size, num_epochs, data, seed, kl_weight=1.0, learning_rate=0.005, optimizer=None, num_particles=1)

One-shot wrapper around `EnsembleVI`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_fn` | `Callable[..., None]` | NumPyro model `(x, y) -> None`. | *required* |
| `guide_fn` | `Callable[..., None]` | NumPyro guide. | *required* |
| `ensemble_size` | `int` | Number of SVI fits `E`. | *required* |
| `num_epochs` | `int` | Steps per member. | *required* |
| `data` | `tuple[Array, Array]` | `(x, y)`. | *required* |
| `seed` | `PRNGKeyArray` | PRNG key. | *required* |
| `kl_weight` | `float` | ELBO temper `β`. `1.0` ⇒ standard ELBO. | `1.0` |
| `learning_rate` | `float` | Default-Adam learning rate. | `0.005` |
| `optimizer` | `Any` | Optional `numpyro.optim` or `optax.GradientTransformation` (auto-wrapped). | `None` |
| `num_particles` | `int` | MC particles per ELBO estimate. | `1` |

Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, Float[Array, "E T"]]` | `(guide_params_stacked, losses)` with leading `(E,)` axis. |

Source code in src/pyrox/inference/_ensemble.py
def ensemble_vi(
    model_fn: Callable[..., None],
    guide_fn: Callable[..., None],
    *,
    ensemble_size: int,
    num_epochs: int,
    data: tuple[Array, Array],
    seed: PRNGKeyArray,
    kl_weight: float = 1.0,
    learning_rate: float = 5e-3,
    optimizer: Any = None,
    num_particles: int = 1,
) -> tuple[PyTree, Float[Array, "E T"]]:
    """One-shot wrapper around :class:`EnsembleVI`.

    Args:
        model_fn: NumPyro model ``(x, y) -> None``.
        guide_fn: NumPyro guide.
        ensemble_size: Number of SVI fits ``E``.
        num_epochs: Steps per member.
        data: ``(x, y)``.
        seed: PRNG key.
        kl_weight: ELBO temper ``β``. ``1.0`` ⇒ standard ELBO.
        learning_rate: Default-Adam learning rate.
        optimizer: Optional ``numpyro.optim`` or
            ``optax.GradientTransformation`` (auto-wrapped).
        num_particles: MC particles per ELBO estimate.

    Returns:
        ``(guide_params_stacked, losses)`` with leading ``(E,)`` axis.
    """
    import numpyro

    if optimizer is None:
        optimizer = numpyro.optim.Adam(learning_rate)
    runner = EnsembleVI(
        model_fn=model_fn,
        guide_fn=guide_fn,
        optimizer=optimizer,
        ensemble_size=ensemble_size,
        kl_weight=kl_weight,
        num_particles=num_particles,
    )
    result = runner.run(seed, num_epochs, *data)  # ty: ignore[unresolved-attribute]
    return result.params, result.losses
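The `kl_weight` temper is a one-line change to the ELBO, which is easy to verify numerically. A pure-Python sketch of that formula from the `EnsembleVI` docstring, with scalar stand-ins for the Monte Carlo estimates (`tempered_elbo` is illustrative, not a pyrox API):

```python
def tempered_elbo(expected_loglik: float, kl: float, *, kl_weight: float = 1.0) -> float:
    # ELBO(phi) = E_q[log p(y | x, theta)] - beta * KL(q || p)
    return expected_loglik - kl_weight * kl

# beta = 1 recovers the standard ELBO:
assert tempered_elbo(-10.0, 2.0) == -12.0
# beta < 1 downweights the KL (prior-matching) term:
assert tempered_elbo(-10.0, 2.0, kl_weight=0.5) == -11.0
```

This mirrors why `_build_svi` uses plain `Trace_ELBO` when `kl_weight == 1.0` and only falls back to the tempered loss otherwise.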

pyrox.inference.ensemble_predict(params_stacked, predict_fn, x_new)

Vmap `predict_fn` over the leading ensemble axis of params.

Uses `equinox.filter_vmap` so it works whether `params_stacked` is a pure-array PyTree or an `equinox.Module` containing non-array leaves (e.g. a captured `jax.nn.tanh`). Array leaves are mapped over axis 0; non-array leaves are broadcast.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params_stacked` | `PyTree` | PyTree returned by `ensemble_map` / `ensemble_vi` / `EnsembleMAP.run`; every array leaf has a leading `(E,)` axis. | *required* |
| `predict_fn` | `Callable[[PyTree, Array], Array]` | `(params, x) -> y`. | *required* |
| `x_new` | `Array` | Inputs to predict at; shared across all members. | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Stacked predictions with leading `(E,)` axis. |

Source code in src/pyrox/inference/_ensemble.py
def ensemble_predict(
    params_stacked: PyTree,
    predict_fn: Callable[[PyTree, Array], Array],
    x_new: Array,
) -> Array:
    """Vmap ``predict_fn`` over the leading ensemble axis of params.

    Uses :func:`equinox.filter_vmap` so it works whether
    ``params_stacked`` is a pure-array PyTree or an
    :class:`equinox.Module` containing non-array leaves (e.g. captured
    ``jax.nn.tanh``). Array leaves are mapped over axis 0; non-array
    leaves are broadcast.

    Args:
        params_stacked: PyTree returned by :func:`ensemble_map` /
            :func:`ensemble_vi` / :class:`EnsembleMAP.run`; every array
            leaf has a leading ``(E,)`` axis.
        predict_fn: ``(params, x) -> y``.
        x_new: Inputs to predict at; shared across all members.

    Returns:
        Stacked predictions with leading ``(E,)`` axis.
    """
    return eqx.filter_vmap(predict_fn, in_axes=(eqx.if_array(0), None))(
        params_stacked, x_new
    )
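The stacked `(E, ...)` predictions are typically reduced to a per-point mean and spread across members. A pure-Python sketch of that reduction, with nested lists standing in for the stacked array (`ensemble_mean_std` is illustrative, not a pyrox API):

```python
def ensemble_mean_std(preds: list[list[float]]) -> tuple[list[float], list[float]]:
    # preds has shape (E, N): one row of predictions per ensemble member,
    # as stacked by ensemble_predict. Reduce over the leading E axis.
    E = len(preds)
    n = len(preds[0])
    mean = [sum(p[i] for p in preds) / E for i in range(n)]
    std = [
        (sum((p[i] - mean[i]) ** 2 for p in preds) / E) ** 0.5
        for i in range(n)
    ]
    return mean, std

preds = [[1.0, 2.0], [3.0, 2.0]]  # E=2 members, N=2 test points
mean, std = ensemble_mean_std(preds)
assert mean == [2.0, 2.0]
assert std == [1.0, 0.0]  # members disagree on point 0, agree on point 1
```

With real JAX arrays the same reduction is `preds.mean(axis=0)` and `preds.std(axis=0)`; the per-point spread is the ensemble's epistemic-uncertainty signal.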

Result containers

pyrox.inference.EnsembleState

Bases: NamedTuple

Stacked state for an ensemble of optimizer runs.

Attributes:

| Name | Type | Description |
|---|---|---|
| `params` | `PyTree` | Stacked parameter PyTree. Array leaves carry a leading `(E,)` axis; non-array leaves (e.g. captured activation functions inside an `equinox.Module`) are shared. |
| `opt_state` | `PyTree` | Stacked optax optimizer state with the same axis convention. |

Source code in src/pyrox/inference/_ensemble.py
class EnsembleState(NamedTuple):
    """Stacked state for an ensemble of optimizer runs.

    Attributes:
        params: Stacked parameter PyTree. Array leaves carry a leading
            ``(E,)`` axis; non-array leaves (e.g. captured activation
            functions inside an :class:`equinox.Module`) are shared.
        opt_state: Stacked optax optimizer state with the same axis
            convention.
    """

    params: PyTree
    opt_state: PyTree

pyrox.inference.EnsembleResult

Bases: NamedTuple

Output of a full `EnsembleMAP.run` / `EnsembleVI.run`.

Attributes:

| Name | Type | Description |
|---|---|---|
| `params` | `PyTree` | Final stacked parameters with leading `(E,)` axis. |
| `losses` | `Float[Array, "E T"]` | `(E, num_epochs)` per-step loss history. |

Source code in src/pyrox/inference/_ensemble.py
class EnsembleResult(NamedTuple):
    """Output of a full :meth:`EnsembleMAP.run` / :meth:`EnsembleVI.run`.

    Attributes:
        params: Final stacked parameters with leading ``(E,)`` axis.
        losses: ``(E, num_epochs)`` per-step loss history.
    """

    params: PyTree
    losses: Float[Array, "E T"]