# pyrox
Probabilistic modeling with Equinox and NumPyro: Bayesian neural networks, Gaussian processes, and composable GP building blocks.
pyrox bridges Equinox modules and NumPyro traces so one module class can host deterministic parameters, random sample sites, priors, guides, and mode switching — without duplicating inference logic. NumPyro owns inference; pyrox just makes modules visible to it.
## Package layout
- `pyrox._core` — the Equinox-to-NumPyro bridge: `PyroxModule`, `PyroxParam`, `PyroxSample`, `Parameterized`, `pyrox_method`.
- `pyrox.gp` — Gaussian process building blocks and protocols (Wave 2+).
- `pyrox.nn` — Bayesian and uncertainty-aware NN layers (Wave 3+).

Wave 1 ships `pyrox._core`. The GP and NN subpackages are scaffolded as placeholders until their dedicated waves land.
## Installation
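Assuming the package is published on PyPI under the name `pyrox`:

```shell
pip install pyrox
```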
Or with uv:
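Assuming the same package name, with uv's project workflow:

```shell
uv add pyrox
```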
## Three modeling patterns
### Pattern A — pure Equinox module injected into a NumPyro model
```python
import equinox as eqx
import jax
import jax.random as jr
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    net = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=1,
                     key=jr.PRNGKey(0))  # any eqx.Module
    # Sample a prior over one weight leaf, then swap it in functionally.
    W = numpyro.sample(
        "W",
        dist.Normal(0.0, 1.0).expand(net.layers[0].weight.shape).to_event(2),
    )
    net = eqx.tree_at(lambda m: m.layers[0].weight, net, W)
    f = numpyro.deterministic("f", jax.vmap(net)(x))
    numpyro.sample("obs", dist.Normal(f, 0.1), obs=y)
```
### Pattern B — PyroxModule owns its probabilistic semantics
```python
import jax.numpy as jnp
import numpyro.distributions as dist

from pyrox._core import PyroxModule, pyrox_method

class BayesianLinear(PyroxModule):
    pyrox_name = "BayesianLinear"

    in_features: int
    out_features: int

    @pyrox_method
    def __call__(self, x):
        W = self.pyrox_sample(
            "weight",
            dist.Normal(0, 1)
            .expand([self.in_features, self.out_features])
            .to_event(2),
        )
        b = self.pyrox_param("bias", jnp.zeros(self.out_features))
        return x @ W + b
```
### Pattern C — Parameterized for constrained params, priors, and guides
```python
import jax.numpy as jnp
import numpyro.distributions as dist

from pyrox._core import Parameterized, pyrox_method

class RBFKernel(Parameterized):
    pyrox_name = "RBFKernel"

    def setup(self):
        self.register_param(
            "variance", jnp.array(1.0),
            constraint=dist.constraints.positive,
        )
        self.register_param(
            "lengthscale", jnp.array(1.0),
            constraint=dist.constraints.positive,
        )
        self.set_prior("variance", dist.LogNormal(0.0, 1.0))
        self.autoguide("variance", "normal")

    @pyrox_method
    def __call__(self, X1, X2):
        v = self.get_param("variance")
        ls = self.get_param("lengthscale")
        sq = jnp.sum((X1[:, None] - X2[None, :]) ** 2 / ls ** 2, axis=-1)
        return v * jnp.exp(-0.5 * sq)
```
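Independent of pyrox, the body of `__call__` is the standard RBF covariance. A plain NumPy check of the math (the function name and inputs here are illustrative):

```python
import numpy as np

def rbf(X1, X2, variance=1.0, lengthscale=1.0):
    # Pairwise squared Euclidean distances, scaled by the lengthscale.
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2 / lengthscale ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq)

X = np.array([[0.0], [1.0]])
K = rbf(X, X)
# K is 2x2 and symmetric, with ones on the diagonal and exp(-0.5) off it
```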
Call `kernel.set_mode("guide")` to draw variational parameters instead of sampling the prior. `"normal"` autoguides respect the registered constraint via a `TransformedDistribution`.