Skip to content

Core API

The `pyrox._core` subpackage owns the Equinox-to-NumPyro bridge that every other subpackage builds on.

Public surface

Core: Equinox-to-NumPyro bridge primitives.

Public surface:

  • `PyroxModule` — Equinox module with `pyrox_param` / `pyrox_sample`
  • `PyroxParam` — declarative parameter descriptor
  • `PyroxSample` — declarative sample descriptor
  • `Parameterized` — param registry with priors, guides, and modes
  • `pyrox_method` — decorator that activates the per-call context

Parameterized

Bases: PyroxModule

Shared base for modules with priors, constraints, and mode switching.

Subclasses typically declare parameters inside `setup`, which is invoked automatically after `__init__` completes. Use `register_param` to declare a parameter, `set_prior` to attach a prior, `autoguide` to pick a guide type, and `set_mode` to switch between sampling from the prior and sampling from the guide.

Per-instance state (params, priors, guides, mode) lives in a class-level registry keyed by `id(self)`. Cleanup happens via `weakref.finalize` when the instance is collected; call `_teardown` for explicit cleanup.

Source code in src/pyrox/_core/parameterized.py
class Parameterized(PyroxModule):
    """Shared base for modules with priors, constraints, and mode switching.

    Subclasses typically declare parameters inside :meth:`setup`, which is
    invoked automatically after ``__init__`` completes. Use
    :meth:`register_param` to declare a parameter, :meth:`set_prior` to
    attach a prior, :meth:`autoguide` to pick a guide type, and
    :meth:`set_mode` to switch between sampling from the prior and
    sampling from the guide.

    Per-instance state (params, priors, guides, mode) lives in a
    class-level registry keyed by ``id(self)``. Cleanup happens via
    :mod:`weakref.finalize` when the instance is collected; call
    :meth:`_teardown` for explicit cleanup.

    In ``"guide"`` mode every prior-bearing parameter is exposed as a
    *sample* site carrying the same name as the model's latent site, so
    model and guide line up under NumPyro's SVI. Variational parameters
    are stored under derived names (``{name}_loc`` / ``{name}_scale``).
    """

    _registry: ClassVar[dict[int, _State]] = {}

    def __post_init__(self) -> None:
        # Invoke the optional ``setup`` hook once field initialization is done.
        setup = getattr(self, "setup", None)
        if callable(setup):
            setup()

    def _state(self) -> _State:
        """Return this instance's ``_State``, creating it on first access."""
        key = id(self)
        state = Parameterized._registry.get(key)
        if state is None:
            state = _State()
            Parameterized._registry[key] = state
            # Instances that cannot be weakly referenced raise TypeError here;
            # their entry then stays in the registry until ``_teardown``.
            with contextlib.suppress(TypeError):
                weakref.finalize(self, Parameterized._registry.pop, key, None)
        return state

    def _entry(self, name: str) -> _Entry:
        """Return the registered entry for ``name``.

        Raises:
            KeyError: If ``name`` was never passed to :meth:`register_param`.
        """
        entry = self._state().params.get(name)
        if entry is None:
            raise KeyError(
                f"parameter {name!r} not registered; call register_param first"
            )
        return entry

    def register_param(
        self,
        name: str,
        init_value: Any,
        constraint: Any = None,
    ) -> None:
        """Declare parameter ``name`` with an initial value and constraint.

        Registering an existing name replaces its entry with a fresh
        ``_Entry`` built from ``init_value`` and ``constraint`` only.
        """
        self._state().params[name] = _Entry(
            init_value=init_value, constraint=constraint
        )

    def set_prior(self, name: str, prior: Any) -> None:
        """Attach ``prior`` (a distribution or ``(self) -> value`` callable).

        Raises:
            KeyError: If ``name`` is not registered.
        """
        self._entry(name).prior = prior

    def autoguide(self, name: str, guide_type: GuideType) -> None:
        """Select the guide family used for ``name`` in ``"guide"`` mode.

        Raises:
            ValueError: If ``guide_type`` is not in ``_VALID_GUIDES``.
            KeyError: If ``name`` is not registered.
        """
        if guide_type not in _VALID_GUIDES:
            raise ValueError(
                f"guide_type must be one of {sorted(_VALID_GUIDES)!r}, "
                f"got {guide_type!r}"
            )
        self._entry(name).guide_type = guide_type

    def set_mode(self, mode: Mode) -> None:
        """Switch between ``"model"`` and ``"guide"`` behavior.

        Raises:
            ValueError: If ``mode`` is neither ``"model"`` nor ``"guide"``.
        """
        if mode not in ("model", "guide"):
            raise ValueError(f"mode must be 'model' or 'guide', got {mode!r}")
        self._state().mode = mode

    def get_param(self, name: str) -> Any:
        """Materialize parameter ``name`` according to the current mode.

        * No prior attached: a plain NumPyro param site.
        * ``"model"`` mode with a prior: a sample site drawn from the prior.
        * ``"guide"`` mode with a prior: the configured guide's sample site.
        """
        entry = self._entry(name)
        mode = self._state().mode
        if entry.prior is not None:
            if mode == "model":
                return self.pyrox_sample(name, entry.prior)
            if mode == "guide":
                return self._guide_param(name, entry)
        return self.pyrox_param(name, entry.init_value, constraint=entry.constraint)

    def load_pyro_samples(self) -> None:
        """Register a site for every declared parameter (in declaration order)."""
        # Snapshot the keys so registration side effects cannot disturb iteration.
        for name in list(self._state().params):
            self.get_param(name)

    def _teardown(self) -> None:
        """Drop this instance's registry state and cached context."""
        Parameterized._registry.pop(id(self), None)
        super()._teardown()

    def _guide_param(self, name: str, entry: _Entry) -> Any:
        """Dispatch to the guide implementation selected for ``entry``.

        Raises:
            NotImplementedError: For guide types accepted by
                :meth:`autoguide` but not implemented at this level.
        """
        guide = entry.guide_type
        if guide == "delta":
            return self._guide_delta(name, entry)
        if guide == "normal":
            return self._guide_normal(name, entry)
        raise NotImplementedError(
            f"guide_type {guide!r} is not yet supported at the "
            "get_param level; materialize via a dedicated guide layer."
        )

    def _guide_delta(self, name: str, entry: _Entry) -> Any:
        """Point-mass guide: a learned estimate exposed as a ``Delta`` site.

        The point estimate is a constrained param ``{name}_loc``; the latent
        itself must be a *sample* site named ``name`` because NumPyro's
        ``replay`` handler rejects a model latent whose guide counterpart is
        not a sample site. When the prior is a concrete distribution its
        ``event_dim`` is reused so the ``Delta`` has matching event shape.
        """
        point = self.pyrox_param(
            f"{name}_loc", entry.init_value, constraint=entry.constraint
        )
        prior = entry.prior
        event_dim = prior.event_dim if isinstance(prior, dist.Distribution) else 0
        return self.pyrox_sample(name, dist.Delta(point, event_dim=event_dim))

    def _guide_normal(self, name: str, entry: _Entry) -> Any:
        """Mean-field normal guide in unconstrained space.

        When ``entry.constraint`` is non-trivial, the latent site is a
        ``TransformedDistribution`` wrapping ``Normal(loc, scale)`` with
        the constraint's bijection, so guide draws always land in the
        prior's support. ``loc`` is initialized by the inverse transform
        of ``init_value`` so guide and prior agree at step zero.
        """
        init = jnp.asarray(entry.init_value)
        if not jnp.issubdtype(init.dtype, jnp.inexact):
            # Integer/bool params cannot be differentiated by JAX; promote
            # to the default floating dtype before registering ``loc``.
            init = init.astype(jnp.result_type(float))
        transform = (
            None if _is_real_support(entry.constraint) else _biject_to(entry.constraint)
        )
        loc_init = init if transform is None else transform.inv(init)
        loc = self.pyrox_param(f"{name}_loc", loc_init)
        scale = self.pyrox_param(
            f"{name}_scale",
            jnp.ones_like(init) * 0.1,
            constraint=dist.constraints.positive,
        )
        base = dist.Normal(loc, scale)
        if transform is None:
            return self.pyrox_sample(name, base)
        return self.pyrox_sample(name, dist.TransformedDistribution(base, transform))

PyroxModule

Bases: Module

Equinox module with NumPyro site registration and per-call caching.

Subclasses register deterministic parameters via `pyrox_param` and random variables via `pyrox_sample`. Wrap the method that drives registration (typically `__call__`) with `pyrox_method` so the per-call `_Context` is active for the duration of the call.

Without the decorator the cache is inactive and duplicate references to the same site within one trace will hit NumPyro's uniqueness check.

Source code in src/pyrox/_core/pyrox_module.py
class PyroxModule(eqx.Module):
    """Equinox module that registers NumPyro sites with per-call caching.

    Deterministic parameters are declared through :meth:`pyrox_param` and
    random variables through :meth:`pyrox_sample`. Decorate whichever
    method triggers registration (usually ``__call__``) with
    :func:`pyrox_method` so this instance's ``_Context`` is active while
    that method runs.

    If the decorator is missing the cache is never consulted, so touching
    the same site twice within one trace trips NumPyro's duplicate-site
    check.
    """

    _contexts: ClassVar[dict[int, _Context]] = {}

    def _get_context(self) -> _Context:
        """Return this instance's ``_Context``, creating it on first use."""
        key = id(self)
        existing = PyroxModule._contexts.get(key)
        if existing is not None:
            return existing
        created = _Context()
        PyroxModule._contexts[key] = created
        # Evict the entry when ``self`` is collected; objects that cannot be
        # weakly referenced raise TypeError and simply skip auto-cleanup.
        with contextlib.suppress(TypeError):
            weakref.finalize(self, PyroxModule._contexts.pop, key, None)
        return created

    def _pyrox_scope_name(self) -> str:
        """Per-instance prefix used to build fully-qualified site names.

        A non-empty string ``pyrox_name`` attribute (field or class
        variable) wins. Otherwise a ``{ClassName}_{hex id}`` tag is used so
        that sibling instances of one class get distinct sites in a single
        trace. That fallback is only stable within one Python process; set
        ``pyrox_name`` for names that survive across runs.
        """
        explicit = getattr(self, "pyrox_name", None)
        if not (isinstance(explicit, str) and explicit):
            return f"{type(self).__name__}_{id(self):x}"
        return explicit

    def _pyrox_fullname(self, name: str) -> str:
        """Qualify ``name`` as ``"<scope>.<name>"``."""
        scope = self._pyrox_scope_name()
        return f"{scope}.{name}"

    def pyrox_param(
        self,
        name: str,
        init_value: Any,
        *,
        constraint: Any = None,
        event_dim: int | None = None,
    ) -> Any:
        """Register (or reuse) the NumPyro param site for ``name``.

        While the context is active, a non-``None`` cached value for the
        qualified name is returned directly. Otherwise ``numpyro.param`` is
        called, forwarding ``constraint`` / ``event_dim`` only when they are
        not ``None``, and the result is passed through ``_Context.set``.
        """
        context = self._get_context()
        site = self._pyrox_fullname(name)
        hit = context.get(site) if context.active else None
        if hit is not None:
            return hit
        options: dict[str, Any] = {
            option: setting
            for option, setting in (("constraint", constraint), ("event_dim", event_dim))
            if setting is not None
        }
        return context.set(site, numpyro.param(site, init_value, **options))

    def pyrox_sample(self, name: str, prior: Any) -> Any:
        """Register (or reuse) a random or deterministic site for ``name``.

        ``prior`` may be a distribution or a callable taking this module;
        callables that are not distributions are invoked first. A resolved
        distribution becomes a ``numpyro.sample`` site, anything else a
        ``numpyro.deterministic`` site. Caching matches :meth:`pyrox_param`.
        """
        context = self._get_context()
        site = self._pyrox_fullname(name)
        hit = context.get(site) if context.active else None
        if hit is not None:
            return hit
        is_distribution = isinstance(prior, dist.Distribution)
        resolved = prior if is_distribution or not callable(prior) else prior(self)
        record = (
            numpyro.sample
            if isinstance(resolved, dist.Distribution)
            else numpyro.deterministic
        )
        return context.set(site, record(site, resolved))

    def _teardown(self) -> None:
        """Forget this instance's cached ``_Context``.

        The class-level registry is keyed by ``id(self)`` and is normally
        pruned by :mod:`weakref.finalize`. Call this directly when weak
        references are unavailable or cleanup must happen at a known point.
        """
        PyroxModule._contexts.pop(id(self), None)

PyroxParam

Bases: NamedTuple

Lightweight metadata container for a parameter site.

Bundles init value, constraint, and optional event dimension as a single descriptor. This type is a plain value object — higher-level APIs that consume it (for example a future declarative registration helper) live elsewhere; `PyroxModule.pyrox_param` takes the fields individually as keyword arguments.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `init_value` | `Any` | Initial value, lazy callable, or `None` to look up an existing param site by name. |
| `constraint` | `Any` | NumPyro constraint on the parameter domain; `None` means unconstrained real. |
| `event_dim` | `int \| None` | Number of rightmost event dimensions, or `None`. |

Source code in src/pyrox/_core/descriptors.py
class PyroxParam(NamedTuple):
    """Immutable descriptor for a NumPyro parameter site.

    Groups the three inputs needed to register a param — initial value,
    constraint, and event dimensionality — into one tuple. The class has
    no behavior of its own: :meth:`PyroxModule.pyrox_param` receives these
    fields as separate arguments, and any API that consumes the whole
    descriptor lives elsewhere.

    Attributes:
        init_value: Initial value, lazy callable, or ``None`` to look up
            an existing param site by name.
        constraint: NumPyro constraint on the parameter domain; ``None``
            means unconstrained real.
        event_dim: Number of rightmost event dimensions, or ``None``.
    """

    # Every field defaults to None, so ``PyroxParam()`` is a valid descriptor.
    init_value: Any = None
    constraint: Any = None
    event_dim: int | None = None

PyroxSample dataclass

Lightweight metadata container for a random sample site.

Wraps the prior — either a `numpyro.distributions.Distribution` or a callable `(self) -> Distribution` for dependent priors that reference other sampled values on the same module. Like `PyroxParam`, this is a plain value object; call `PyroxModule.pyrox_sample` with the underlying prior directly.

Source code in src/pyrox/_core/descriptors.py
@dataclass(frozen=True)
class PyroxSample:
    """Immutable descriptor for a random sample site.

    Holds a single prior: a :class:`numpyro.distributions.Distribution`,
    or a callable ``(self) -> Distribution`` for priors that depend on
    other values sampled on the same module. Like :class:`PyroxParam` it
    carries no behavior; pass the wrapped prior straight to
    :meth:`PyroxModule.pyrox_sample`.
    """

    # Distribution instance or ``(module) -> Distribution`` factory.
    prior: Any | Callable[[Any], Any]

pyrox_method(fn)

Wrap a method so its body runs inside the module's per-call context.

Apply to `__call__` (and any other method that registers pyrox sites) so the `_Context` cache is active for the duration of the call. The cache is cleared when the outermost decorated call returns.

Source code in src/pyrox/_core/pyrox_module.py
def pyrox_method(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate a method so it executes inside the instance's ``_Context``.

    Use on ``__call__`` and on any other method that registers pyrox sites,
    so the per-call cache is active for the whole call. Entering and
    exiting are delegated to the object returned by ``_get_context()``;
    per that context's contract, the cache is cleared when the outermost
    decorated call returns.
    """

    @functools.wraps(fn)
    def _within_context(self: PyroxModule, *args: Any, **kwargs: Any) -> Any:
        context = self._get_context()
        with context:
            return fn(self, *args, **kwargs)

    return _within_context

PyroxModule

pyrox._core.PyroxModule

Bases: Module

Equinox module with NumPyro site registration and per-call caching.

Subclasses register deterministic parameters via `pyrox_param` and random variables via `pyrox_sample`. Wrap the method that drives registration (typically `__call__`) with `pyrox_method` so the per-call `_Context` is active for the duration of the call.

Without the decorator the cache is inactive and duplicate references to the same site within one trace will hit NumPyro's uniqueness check.

Source code in src/pyrox/_core/pyrox_module.py
class PyroxModule(eqx.Module):
    """Equinox module that registers NumPyro sites with per-call caching.

    Deterministic parameters are declared through :meth:`pyrox_param` and
    random variables through :meth:`pyrox_sample`. Decorate whichever
    method triggers registration (usually ``__call__``) with
    :func:`pyrox_method` so this instance's ``_Context`` is active while
    that method runs.

    If the decorator is missing the cache is never consulted, so touching
    the same site twice within one trace trips NumPyro's duplicate-site
    check.
    """

    _contexts: ClassVar[dict[int, _Context]] = {}

    def _get_context(self) -> _Context:
        """Return this instance's ``_Context``, creating it on first use."""
        key = id(self)
        existing = PyroxModule._contexts.get(key)
        if existing is not None:
            return existing
        created = _Context()
        PyroxModule._contexts[key] = created
        # Evict the entry when ``self`` is collected; objects that cannot be
        # weakly referenced raise TypeError and simply skip auto-cleanup.
        with contextlib.suppress(TypeError):
            weakref.finalize(self, PyroxModule._contexts.pop, key, None)
        return created

    def _pyrox_scope_name(self) -> str:
        """Per-instance prefix used to build fully-qualified site names.

        A non-empty string ``pyrox_name`` attribute (field or class
        variable) wins. Otherwise a ``{ClassName}_{hex id}`` tag is used so
        that sibling instances of one class get distinct sites in a single
        trace. That fallback is only stable within one Python process; set
        ``pyrox_name`` for names that survive across runs.
        """
        explicit = getattr(self, "pyrox_name", None)
        if not (isinstance(explicit, str) and explicit):
            return f"{type(self).__name__}_{id(self):x}"
        return explicit

    def _pyrox_fullname(self, name: str) -> str:
        """Qualify ``name`` as ``"<scope>.<name>"``."""
        scope = self._pyrox_scope_name()
        return f"{scope}.{name}"

    def pyrox_param(
        self,
        name: str,
        init_value: Any,
        *,
        constraint: Any = None,
        event_dim: int | None = None,
    ) -> Any:
        """Register (or reuse) the NumPyro param site for ``name``.

        While the context is active, a non-``None`` cached value for the
        qualified name is returned directly. Otherwise ``numpyro.param`` is
        called, forwarding ``constraint`` / ``event_dim`` only when they are
        not ``None``, and the result is passed through ``_Context.set``.
        """
        context = self._get_context()
        site = self._pyrox_fullname(name)
        hit = context.get(site) if context.active else None
        if hit is not None:
            return hit
        options: dict[str, Any] = {
            option: setting
            for option, setting in (("constraint", constraint), ("event_dim", event_dim))
            if setting is not None
        }
        return context.set(site, numpyro.param(site, init_value, **options))

    def pyrox_sample(self, name: str, prior: Any) -> Any:
        """Register (or reuse) a random or deterministic site for ``name``.

        ``prior`` may be a distribution or a callable taking this module;
        callables that are not distributions are invoked first. A resolved
        distribution becomes a ``numpyro.sample`` site, anything else a
        ``numpyro.deterministic`` site. Caching matches :meth:`pyrox_param`.
        """
        context = self._get_context()
        site = self._pyrox_fullname(name)
        hit = context.get(site) if context.active else None
        if hit is not None:
            return hit
        is_distribution = isinstance(prior, dist.Distribution)
        resolved = prior if is_distribution or not callable(prior) else prior(self)
        record = (
            numpyro.sample
            if isinstance(resolved, dist.Distribution)
            else numpyro.deterministic
        )
        return context.set(site, record(site, resolved))

    def _teardown(self) -> None:
        """Forget this instance's cached ``_Context``.

        The class-level registry is keyed by ``id(self)`` and is normally
        pruned by :mod:`weakref.finalize`. Call this directly when weak
        references are unavailable or cleanup must happen at a known point.
        """
        PyroxModule._contexts.pop(id(self), None)

pyrox_method

pyrox._core.pyrox_method(fn)

Wrap a method so its body runs inside the module's per-call context.

Apply to `__call__` (and any other method that registers pyrox sites) so the `_Context` cache is active for the duration of the call. The cache is cleared when the outermost decorated call returns.

Source code in src/pyrox/_core/pyrox_module.py
def pyrox_method(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate a method so it executes inside the instance's ``_Context``.

    Use on ``__call__`` and on any other method that registers pyrox sites,
    so the per-call cache is active for the whole call. Entering and
    exiting are delegated to the object returned by ``_get_context()``;
    per that context's contract, the cache is cleared when the outermost
    decorated call returns.
    """

    @functools.wraps(fn)
    def _within_context(self: PyroxModule, *args: Any, **kwargs: Any) -> Any:
        context = self._get_context()
        with context:
            return fn(self, *args, **kwargs)

    return _within_context

PyroxParam

pyrox._core.PyroxParam

Bases: NamedTuple

Lightweight metadata container for a parameter site.

Bundles init value, constraint, and optional event dimension as a single descriptor. This type is a plain value object — higher-level APIs that consume it (for example a future declarative registration helper) live elsewhere; `PyroxModule.pyrox_param` takes the fields individually as keyword arguments.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `init_value` | `Any` | Initial value, lazy callable, or `None` to look up an existing param site by name. |
| `constraint` | `Any` | NumPyro constraint on the parameter domain; `None` means unconstrained real. |
| `event_dim` | `int \| None` | Number of rightmost event dimensions, or `None`. |

Source code in src/pyrox/_core/descriptors.py
class PyroxParam(NamedTuple):
    """Immutable descriptor for a NumPyro parameter site.

    Groups the three inputs needed to register a param — initial value,
    constraint, and event dimensionality — into one tuple. The class has
    no behavior of its own: :meth:`PyroxModule.pyrox_param` receives these
    fields as separate arguments, and any API that consumes the whole
    descriptor lives elsewhere.

    Attributes:
        init_value: Initial value, lazy callable, or ``None`` to look up
            an existing param site by name.
        constraint: NumPyro constraint on the parameter domain; ``None``
            means unconstrained real.
        event_dim: Number of rightmost event dimensions, or ``None``.
    """

    # Every field defaults to None, so ``PyroxParam()`` is a valid descriptor.
    init_value: Any = None
    constraint: Any = None
    event_dim: int | None = None

PyroxSample

pyrox._core.PyroxSample dataclass

Lightweight metadata container for a random sample site.

Wraps the prior — either a `numpyro.distributions.Distribution` or a callable `(self) -> Distribution` for dependent priors that reference other sampled values on the same module. Like `PyroxParam`, this is a plain value object; call `PyroxModule.pyrox_sample` with the underlying prior directly.

Source code in src/pyrox/_core/descriptors.py
@dataclass(frozen=True)
class PyroxSample:
    """Immutable descriptor for a random sample site.

    Holds a single prior: a :class:`numpyro.distributions.Distribution`,
    or a callable ``(self) -> Distribution`` for priors that depend on
    other values sampled on the same module. Like :class:`PyroxParam` it
    carries no behavior; pass the wrapped prior straight to
    :meth:`PyroxModule.pyrox_sample`.
    """

    # Distribution instance or ``(module) -> Distribution`` factory.
    prior: Any | Callable[[Any], Any]

Parameterized

pyrox._core.Parameterized

Bases: PyroxModule

Shared base for modules with priors, constraints, and mode switching.

Subclasses typically declare parameters inside `setup`, which is invoked automatically after `__init__` completes. Use `register_param` to declare a parameter, `set_prior` to attach a prior, `autoguide` to pick a guide type, and `set_mode` to switch between sampling from the prior and sampling from the guide.

Per-instance state (params, priors, guides, mode) lives in a class-level registry keyed by `id(self)`. Cleanup happens via `weakref.finalize` when the instance is collected; call `_teardown` for explicit cleanup.

Source code in src/pyrox/_core/parameterized.py
class Parameterized(PyroxModule):
    """Shared base for modules with priors, constraints, and mode switching.

    Subclasses typically declare parameters inside :meth:`setup`, which is
    invoked automatically after ``__init__`` completes. Use
    :meth:`register_param` to declare a parameter, :meth:`set_prior` to
    attach a prior, :meth:`autoguide` to pick a guide type, and
    :meth:`set_mode` to switch between sampling from the prior and
    sampling from the guide.

    Per-instance state (params, priors, guides, mode) lives in a
    class-level registry keyed by ``id(self)``. Cleanup happens via
    :mod:`weakref.finalize` when the instance is collected; call
    :meth:`_teardown` for explicit cleanup.

    In ``"guide"`` mode every prior-bearing parameter is exposed as a
    *sample* site carrying the same name as the model's latent site, so
    model and guide line up under NumPyro's SVI. Variational parameters
    are stored under derived names (``{name}_loc`` / ``{name}_scale``).
    """

    _registry: ClassVar[dict[int, _State]] = {}

    def __post_init__(self) -> None:
        # Invoke the optional ``setup`` hook once field initialization is done.
        setup = getattr(self, "setup", None)
        if callable(setup):
            setup()

    def _state(self) -> _State:
        """Return this instance's ``_State``, creating it on first access."""
        key = id(self)
        state = Parameterized._registry.get(key)
        if state is None:
            state = _State()
            Parameterized._registry[key] = state
            # Instances that cannot be weakly referenced raise TypeError here;
            # their entry then stays in the registry until ``_teardown``.
            with contextlib.suppress(TypeError):
                weakref.finalize(self, Parameterized._registry.pop, key, None)
        return state

    def _entry(self, name: str) -> _Entry:
        """Return the registered entry for ``name``.

        Raises:
            KeyError: If ``name`` was never passed to :meth:`register_param`.
        """
        entry = self._state().params.get(name)
        if entry is None:
            raise KeyError(
                f"parameter {name!r} not registered; call register_param first"
            )
        return entry

    def register_param(
        self,
        name: str,
        init_value: Any,
        constraint: Any = None,
    ) -> None:
        """Declare parameter ``name`` with an initial value and constraint.

        Registering an existing name replaces its entry with a fresh
        ``_Entry`` built from ``init_value`` and ``constraint`` only.
        """
        self._state().params[name] = _Entry(
            init_value=init_value, constraint=constraint
        )

    def set_prior(self, name: str, prior: Any) -> None:
        """Attach ``prior`` (a distribution or ``(self) -> value`` callable).

        Raises:
            KeyError: If ``name`` is not registered.
        """
        self._entry(name).prior = prior

    def autoguide(self, name: str, guide_type: GuideType) -> None:
        """Select the guide family used for ``name`` in ``"guide"`` mode.

        Raises:
            ValueError: If ``guide_type`` is not in ``_VALID_GUIDES``.
            KeyError: If ``name`` is not registered.
        """
        if guide_type not in _VALID_GUIDES:
            raise ValueError(
                f"guide_type must be one of {sorted(_VALID_GUIDES)!r}, "
                f"got {guide_type!r}"
            )
        self._entry(name).guide_type = guide_type

    def set_mode(self, mode: Mode) -> None:
        """Switch between ``"model"`` and ``"guide"`` behavior.

        Raises:
            ValueError: If ``mode`` is neither ``"model"`` nor ``"guide"``.
        """
        if mode not in ("model", "guide"):
            raise ValueError(f"mode must be 'model' or 'guide', got {mode!r}")
        self._state().mode = mode

    def get_param(self, name: str) -> Any:
        """Materialize parameter ``name`` according to the current mode.

        * No prior attached: a plain NumPyro param site.
        * ``"model"`` mode with a prior: a sample site drawn from the prior.
        * ``"guide"`` mode with a prior: the configured guide's sample site.
        """
        entry = self._entry(name)
        mode = self._state().mode
        if entry.prior is not None:
            if mode == "model":
                return self.pyrox_sample(name, entry.prior)
            if mode == "guide":
                return self._guide_param(name, entry)
        return self.pyrox_param(name, entry.init_value, constraint=entry.constraint)

    def load_pyro_samples(self) -> None:
        """Register a site for every declared parameter (in declaration order)."""
        # Snapshot the keys so registration side effects cannot disturb iteration.
        for name in list(self._state().params):
            self.get_param(name)

    def _teardown(self) -> None:
        """Drop this instance's registry state and cached context."""
        Parameterized._registry.pop(id(self), None)
        super()._teardown()

    def _guide_param(self, name: str, entry: _Entry) -> Any:
        """Dispatch to the guide implementation selected for ``entry``.

        Raises:
            NotImplementedError: For guide types accepted by
                :meth:`autoguide` but not implemented at this level.
        """
        guide = entry.guide_type
        if guide == "delta":
            return self._guide_delta(name, entry)
        if guide == "normal":
            return self._guide_normal(name, entry)
        raise NotImplementedError(
            f"guide_type {guide!r} is not yet supported at the "
            "get_param level; materialize via a dedicated guide layer."
        )

    def _guide_delta(self, name: str, entry: _Entry) -> Any:
        """Point-mass guide: a learned estimate exposed as a ``Delta`` site.

        The point estimate is a constrained param ``{name}_loc``; the latent
        itself must be a *sample* site named ``name`` because NumPyro's
        ``replay`` handler rejects a model latent whose guide counterpart is
        not a sample site. When the prior is a concrete distribution its
        ``event_dim`` is reused so the ``Delta`` has matching event shape.
        """
        point = self.pyrox_param(
            f"{name}_loc", entry.init_value, constraint=entry.constraint
        )
        prior = entry.prior
        event_dim = prior.event_dim if isinstance(prior, dist.Distribution) else 0
        return self.pyrox_sample(name, dist.Delta(point, event_dim=event_dim))

    def _guide_normal(self, name: str, entry: _Entry) -> Any:
        """Mean-field normal guide in unconstrained space.

        When ``entry.constraint`` is non-trivial, the latent site is a
        ``TransformedDistribution`` wrapping ``Normal(loc, scale)`` with
        the constraint's bijection, so guide draws always land in the
        prior's support. ``loc`` is initialized by the inverse transform
        of ``init_value`` so guide and prior agree at step zero.
        """
        init = jnp.asarray(entry.init_value)
        if not jnp.issubdtype(init.dtype, jnp.inexact):
            # Integer/bool params cannot be differentiated by JAX; promote
            # to the default floating dtype before registering ``loc``.
            init = init.astype(jnp.result_type(float))
        transform = (
            None if _is_real_support(entry.constraint) else _biject_to(entry.constraint)
        )
        loc_init = init if transform is None else transform.inv(init)
        loc = self.pyrox_param(f"{name}_loc", loc_init)
        scale = self.pyrox_param(
            f"{name}_scale",
            jnp.ones_like(init) * 0.1,
            constraint=dist.constraints.positive,
        )
        base = dist.Normal(loc, scale)
        if transform is None:
            return self.pyrox_sample(name, base)
        return self.pyrox_sample(name, dist.TransformedDistribution(base, transform))