
NN API

The pyrox.nn subpackage ships uncertainty-aware neural network layers in three families:

  1. Dense / Bayesian-linear layers (pyrox.nn._layers) — twelve layers covering reparameterization, Flipout, NCP, MC-Dropout, and several random-feature variants.
  2. Bayesian Neural Field stack (pyrox.nn._bnf) — five layers that together implement the BNF architecture (Saad et al., Nature Communications, 2024).
  3. Pure-JAX feature helpers (pyrox.nn._features) — pandas-free building blocks the BNF layers wrap.

Dense / Bayesian-linear layers

pyrox.nn.DenseReparameterization

Bases: PyroxModule

Bayesian dense layer via the reparameterization trick.

Samples weight and bias from learned Gaussian posteriors at every forward pass. Registers NumPyro sample sites so the KL between the variational posterior and the prior is tracked by the ELBO.

.. math::

W \sim \mathcal{N}(\mu_W, \sigma_W^2), \quad
b \sim \mathcal{N}(\mu_b, \sigma_b^2), \quad
y = x W + b.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `bias` | `bool` | Whether to include a bias term. |
| `prior_scale` | `float` | Scale of the isotropic Gaussian prior on weights and bias. The prior mean is zero. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class DenseReparameterization(PyroxModule):
    r"""Bayesian dense layer via the reparameterization trick.

    Samples weight and bias from learned Gaussian posteriors at every
    forward pass. Registers NumPyro sample sites so the KL between the
    variational posterior and the prior is tracked by the ELBO.

    .. math::

        W \sim \mathcal{N}(\mu_W, \sigma_W^2), \quad
        b \sim \mathcal{N}(\mu_b, \sigma_b^2), \quad
        y = x W + b.

    Attributes:
        in_features: Input dimension.
        out_features: Output dimension.
        bias: Whether to include a bias term.
        prior_scale: Scale of the isotropic Gaussian prior on weights
            and bias. The prior mean is zero.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int
    out_features: int
    bias: bool = True
    prior_scale: float = 1.0
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_out"]:
        prior_w = dist.Normal(
            jnp.zeros((self.in_features, self.out_features)),
            self.prior_scale,
        ).to_event(2)
        W = self.pyrox_sample("weight", prior_w)
        out = x @ W
        if self.bias:
            prior_b = dist.Normal(
                jnp.zeros(self.out_features), self.prior_scale
            ).to_event(1)
            b = self.pyrox_sample("bias", prior_b)
            out = out + b
        return out
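The reparameterization trick the layer relies on can be sketched outside pyrox: a Gaussian weight is written as its mean plus scaled standard noise, so gradients flow through the deterministic transform. A minimal NumPy sketch of one stochastic forward pass (the real layer registers these draws as NumPyro sample sites instead):

```python
import numpy as np

def dense_reparam(x, mu_W, log_sigma_W, mu_b, log_sigma_b, rng):
    """One stochastic forward pass: W = mu + sigma * eps, eps ~ N(0, I).

    Because eps is drawn independently of (mu, sigma), gradients can flow
    through the deterministic transform -- the reparameterization trick.
    """
    eps_W = rng.standard_normal(mu_W.shape)
    eps_b = rng.standard_normal(mu_b.shape)
    W = mu_W + np.exp(log_sigma_W) * eps_W
    b = mu_b + np.exp(log_sigma_b) * eps_b
    return x @ W + b

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
mu_W, mu_b = rng.standard_normal((3, 2)), np.zeros(2)
# As sigma -> 0 the layer collapses to a deterministic dense layer.
y_det = dense_reparam(x, mu_W, np.full((3, 2), -30.0), mu_b, np.full(2, -30.0), rng)
```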

pyrox.nn.DenseFlipout

Bases: PyroxModule

Bayesian dense layer with Flipout sign-flip structure.

Samples weight from the prior and applies per-example Rademacher sign flips to the weight perturbation (Wen et al., 2018). Under a NumPyro guide that learns the posterior mean, the sign flips decorrelate gradient estimates across minibatch examples.

In model mode (no guide) this is equivalent to `DenseReparameterization` — the Flipout variance reduction activates when a guide provides a posterior centered at a learned mean.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `bias` | `bool` | Whether to include a bias term. |
| `prior_scale` | `float` | Scale of the isotropic Gaussian prior. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class DenseFlipout(PyroxModule):
    r"""Bayesian dense layer with Flipout sign-flip structure.

    Samples weight from the prior and applies per-example Rademacher
    sign flips to the weight perturbation (Wen et al., 2018). Under a
    NumPyro guide that learns the posterior mean, the sign flips
    decorrelate gradient estimates across minibatch examples.

    In model mode (no guide) this is equivalent to
    :class:`DenseReparameterization` — the Flipout variance reduction
    activates when a guide provides a posterior centered at a learned
    mean.

    Attributes:
        in_features: Input dimension.
        out_features: Output dimension.
        bias: Whether to include a bias term.
        prior_scale: Scale of the isotropic Gaussian prior.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int
    out_features: int
    bias: bool = True
    prior_scale: float = 1.0
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_out"]:
        prior_w = dist.Normal(
            jnp.zeros((self.in_features, self.out_features)),
            self.prior_scale,
        ).to_event(2)
        W = self.pyrox_sample("weight", prior_w)
        out = x @ W

        if self.bias:
            prior_b = dist.Normal(
                jnp.zeros(self.out_features), self.prior_scale
            ).to_event(1)
            b = self.pyrox_sample("bias", prior_b)
            out = out + b
        return out
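The Flipout estimator itself can be illustrated without pyrox. The sketch below (assuming the standard Wen et al. construction; pyrox's guide-side implementation is not shown here) shares one weight perturbation across the minibatch and decorrelates it with per-example Rademacher signs:

```python
import numpy as np

def dense_flipout(x, W_mean, delta_W, rng):
    """Flipout forward pass -- a sketch of the pattern, not pyrox's code.

    One weight perturbation ``delta_W`` is shared across the minibatch, but
    independent per-example Rademacher signs on the input and output make
    the effective perturbations pseudo-independent across examples, which
    decorrelates the minibatch gradient estimates.
    """
    n, d_in = x.shape
    d_out = W_mean.shape[1]
    s_in = rng.choice([-1.0, 1.0], size=(n, d_in))
    s_out = rng.choice([-1.0, 1.0], size=(n, d_out))
    return x @ W_mean + (((x * s_in) @ delta_W) * s_out)

rng = np.random.default_rng(1)
x = rng.standard_normal((5, 3))
W_mean = rng.standard_normal((3, 2))
# With a zero perturbation the layer is exactly the deterministic mean pass.
y_mean = dense_flipout(x, W_mean, np.zeros((3, 2)), rng)
```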

pyrox.nn.DenseVariational

Bases: PyroxModule

Dense layer with a user-supplied prior factory.

Provides flexibility over the weight prior by accepting a callable that builds the prior distribution given the layer shape. The model samples from the prior; the posterior is handled by a NumPyro guide (e.g., `AutoNormal`).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `make_prior` | `Callable[..., Any]` | Callable `(in_features, out_features) -> Distribution`. |
| `bias` | `bool` | Whether to include a bias term. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class DenseVariational(PyroxModule):
    r"""Dense layer with a user-supplied prior factory.

    Provides flexibility over the weight prior by accepting a callable
    that builds the prior distribution given the layer shape. The
    model samples from the prior; the posterior is handled by a NumPyro
    guide (e.g., ``AutoNormal``).

    Attributes:
        in_features: Input dimension.
        out_features: Output dimension.
        make_prior: Callable ``(in_features, out_features) -> Distribution``.
        bias: Whether to include a bias term.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int
    out_features: int
    make_prior: Callable[..., Any] = eqx.field(static=True)
    bias: bool = True
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_out"]:
        prior = self.make_prior(self.in_features, self.out_features)
        W = self.pyrox_sample("weight", prior)
        out = x @ W
        if self.bias:
            b = self.pyrox_sample(
                "bias",
                dist.Normal(jnp.zeros(self.out_features), 1.0).to_event(1),
            )
            out = out + b
        return out

pyrox.nn.MCDropout

Bases: Module

Always-on dropout for Monte Carlo uncertainty estimation.

Unlike standard dropout, `MCDropout` stays active at inference time — repeated forward passes with different keys produce a distribution of outputs whose spread approximates predictive uncertainty (Gal & Ghahramani, 2016).

Not a `PyroxModule` — no NumPyro sites are registered. The stochasticity comes from the explicit PRNG `key` argument.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `rate` | `float` | Dropout probability in `[0, 1)`. |

Source code in src/pyrox/nn/_layers.py
class MCDropout(eqx.Module):
    """Always-on dropout for Monte Carlo uncertainty estimation.

    Unlike standard dropout, :class:`MCDropout` stays active at
    inference time — repeated forward passes with different keys
    produce a distribution of outputs whose spread approximates
    predictive uncertainty (Gal & Ghahramani, 2016).

    Not a :class:`PyroxModule` — no NumPyro sites are registered.
    The stochasticity comes from the explicit PRNG ``key`` argument.

    Attributes:
        rate: Dropout probability in :math:`[0, 1)`.
    """

    rate: float = 0.5

    def __post_init__(self) -> None:
        if not 0.0 <= self.rate < 1.0:
            raise ValueError(f"rate must be in [0, 1), got {self.rate}.")

    def __call__(
        self,
        x: Float[Array, ...],
        *,
        key: Array,
    ) -> Float[Array, ...]:
        """Apply dropout, scaling survivors by ``1 / (1 - rate)``."""
        keep = jax.random.bernoulli(key, 1.0 - self.rate, x.shape)
        return jnp.where(keep, x / (1.0 - self.rate), 0.0)

__call__(x, *, key)

Apply dropout, scaling survivors by 1 / (1 - rate).

Source code in src/pyrox/nn/_layers.py
def __call__(
    self,
    x: Float[Array, ...],
    *,
    key: Array,
) -> Float[Array, ...]:
    """Apply dropout, scaling survivors by ``1 / (1 - rate)``."""
    keep = jax.random.bernoulli(key, 1.0 - self.rate, x.shape)
    return jnp.where(keep, x / (1.0 - self.rate), 0.0)
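The two-line mechanism above is easy to exercise directly. A NumPy sketch of the always-on dropout loop and the resulting predictive ensemble:

```python
import numpy as np

def mc_dropout(x, rate, rng):
    """Always-on dropout: survivors are rescaled by 1 / (1 - rate)."""
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones(10)
# Repeated stochastic passes form a predictive ensemble; the spread of the
# ensemble is the MC-dropout uncertainty estimate.
samples = np.stack([mc_dropout(x, 0.5, rng) for _ in range(2000)])
ens_mean = samples.mean(axis=0)  # close to x: rescaling preserves the mean
ens_std = samples.std(axis=0)    # strictly positive: mask-induced spread
```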

pyrox.nn.DenseNCP

Bases: PyroxModule

Noise Contrastive Prior dense layer (Hafner et al., 2019).

Decomposes a dense layer into a prior-regularized backbone plus a scaled stochastic perturbation:

.. math::

y = \underbrace{x W_d + b_d}_{\text{backbone}}
  + \underbrace{\sigma \cdot (x W_s + b_s)}_{\text{perturbation}},

where all weights are `pyrox_sample` sites with Gaussian priors and σ has a `LogNormal` prior. The backbone carries the bulk of the signal; the perturbation branch adds calibrated uncertainty that can be trained via a noise contrastive objective.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `init_scale` | `float` | Initial value for the perturbation scale σ. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class DenseNCP(PyroxModule):
    r"""Noise Contrastive Prior dense layer (Hafner et al., 2019).

    Decomposes a dense layer into a prior-regularized backbone plus a
    scaled stochastic perturbation:

    .. math::

        y = \underbrace{x W_d + b_d}_{\text{backbone}}
          + \underbrace{\sigma \cdot (x W_s + b_s)}_{\text{perturbation}},

    where all weights are ``pyrox_sample`` sites with Gaussian priors
    and :math:`\sigma` has a ``LogNormal`` prior. The backbone carries
    the bulk of the signal; the perturbation branch adds calibrated
    uncertainty that can be trained via a noise contrastive objective.

    Attributes:
        in_features: Input dimension.
        out_features: Output dimension.
        init_scale: Initial value for the perturbation scale
            :math:`\sigma`.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int
    out_features: int
    init_scale: float = 1.0
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_out"]:
        W_d = self.pyrox_sample(
            "weight_det",
            dist.Normal(jnp.zeros((self.in_features, self.out_features)), 1.0).to_event(
                2
            ),
        )
        b_d = self.pyrox_sample(
            "bias_det",
            dist.Normal(jnp.zeros(self.out_features), 1.0).to_event(1),
        )
        det = x @ W_d + b_d

        W_s = self.pyrox_sample(
            "weight_stoch",
            dist.Normal(jnp.zeros((self.in_features, self.out_features)), 1.0).to_event(
                2
            ),
        )
        b_s = self.pyrox_sample(
            "bias_stoch",
            dist.Normal(jnp.zeros(self.out_features), 1.0).to_event(1),
        )
        scale = self.pyrox_sample(
            "scale",
            dist.LogNormal(jnp.log(jnp.maximum(jnp.array(self.init_scale), 1e-6)), 1.0),
        )
        stoch = scale * (x @ W_s + b_s)

        return det + stoch
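The backbone/perturbation decomposition can be checked in isolation. A NumPy sketch with fixed weights (the real layer samples all of these from their priors):

```python
import numpy as np

def dense_ncp(x, W_d, b_d, W_s, b_s, scale):
    """NCP decomposition: deterministic backbone plus a scaled perturbation."""
    backbone = x @ W_d + b_d
    perturbation = scale * (x @ W_s + b_s)
    return backbone + perturbation

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))
W_d, b_d = rng.standard_normal((3, 2)), rng.standard_normal(2)
W_s, b_s = rng.standard_normal((3, 2)), rng.standard_normal(2)
y_backbone = dense_ncp(x, W_d, b_d, W_s, b_s, 0.0)  # scale -> 0: backbone only
y_full = dense_ncp(x, W_d, b_d, W_s, b_s, 1.0)      # perturbation branch active
```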

pyrox.nn.NCPContinuousPerturb

Bases: Module

Input perturbation for the Noise Contrastive Prior pattern.

Adds Gaussian noise scaled by a learned positive scale to the input:

.. math::

\tilde{x} = x + \sigma \epsilon, \qquad
\epsilon \sim \mathcal{N}(0, I).

Place before a deterministic network to inject input uncertainty; pair with `DenseNCP` at the output for the full NCP architecture (Hafner et al., 2019).

Not a `PyroxModule` — stochasticity comes from the explicit PRNG `key`.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `scale` | `float \| Float[Array, '']` | Perturbation scale σ. |

Source code in src/pyrox/nn/_layers.py
class NCPContinuousPerturb(eqx.Module):
    r"""Input perturbation for the Noise Contrastive Prior pattern.

    Adds Gaussian noise scaled by a learned positive scale to the
    input:

    .. math::

        \tilde{x} = x + \sigma \epsilon, \qquad
        \epsilon \sim \mathcal{N}(0, I).

    Place before a deterministic network to inject input uncertainty;
    pair with :class:`DenseNCP` at the output for the full NCP
    architecture (Hafner et al., 2019).

    Not a :class:`PyroxModule` — stochasticity comes from the
    explicit PRNG ``key``.

    Attributes:
        scale: Perturbation scale :math:`\sigma`.
    """

    scale: float | Float[Array, ""] = 1.0

    def __call__(
        self,
        x: Float[Array, "*batch D"],
        *,
        key: Array,
    ) -> Float[Array, "*batch D"]:
        eps = jax.random.normal(key, x.shape, dtype=x.dtype)
        return x + self.scale * eps

pyrox.nn.RBFFourierFeatures

Bases: PyroxModule

SSGP-style RFF layer with RBF spectral density.

Both the spectral frequencies W and the lengthscale ℓ are `pyrox_sample` sites — W has a standard normal prior (the RBF spectral density) and ℓ has a `LogNormal` prior. Under SVI, the guide learns a posterior over both; under a seed handler, they are drawn from the prior.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of frequency pairs (output dim `2 * n_features`). |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class RBFFourierFeatures(PyroxModule):
    r"""SSGP-style RFF layer with RBF spectral density.

    Both the spectral frequencies :math:`W` and the lengthscale
    :math:`\ell` are ``pyrox_sample`` sites — :math:`W` has a
    standard normal prior (the RBF spectral density) and :math:`\ell`
    has a ``LogNormal`` prior. Under SVI, the guide learns a posterior
    over both; under a seed handler, they are drawn from the prior.

    Attributes:
        in_features: Input dimension.
        n_features: Number of frequency pairs (output dim
            ``2 * n_features``).
        init_lengthscale: Prior location for the lengthscale.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    init_lengthscale: float = 1.0
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        lengthscale: float = 1.0,
    ) -> RBFFourierFeatures:
        if lengthscale <= 0:
            raise ValueError(f"lengthscale must be > 0, got {lengthscale}.")
        return cls(
            in_features=in_features,
            n_features=n_features,
            init_lengthscale=lengthscale,
        )

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        W = self.pyrox_sample(
            "W",
            dist.Normal(0.0, 1.0)
            .expand([self.in_features, self.n_features])
            .to_event(2),
        )
        ls = self.pyrox_sample(
            "lengthscale",
            dist.LogNormal(jnp.log(jnp.asarray(self.init_lengthscale)), 1.0),
        )
        return _rff_forward(W, ls, self.n_features, x)
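The feature map can be sanity-checked against the kernel it approximates. `_rff_forward` is not shown in these docs, so the sketch below assumes the standard paired `[cos, sin]` map with `sqrt(1/n_features)` scaling; feature inner products then converge to the RBF kernel:

```python
import numpy as np

def rff(x, W, lengthscale):
    """Paired [cos, sin] RFF map (assumed form of ``_rff_forward``):
    phi(x) = sqrt(1/D) [cos(x W / l), sin(x W / l)]."""
    z = x @ W / lengthscale
    scale = np.sqrt(1.0 / W.shape[1])
    return scale * np.concatenate([np.cos(z), np.sin(z)], axis=-1)

rng = np.random.default_rng(0)
d, n_feat, ls = 3, 4000, 1.5
W = rng.standard_normal((d, n_feat))  # standard normal = RBF spectral density
x = rng.standard_normal((1, d))
y = rng.standard_normal((1, d))
# Feature inner products approximate the RBF kernel, with Monte Carlo
# error shrinking as O(1 / sqrt(n_feat)).
approx = float(rff(x, W, ls) @ rff(y, W, ls).T)
exact = float(np.exp(-np.sum((x - y) ** 2) / (2.0 * ls**2)))
```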

pyrox.nn.RBFCosineFeatures

Bases: PyroxModule

Cosine-bias variant of random Fourier features for the RBF kernel.

Uses the single-cosine feature map with a bias term:

.. math::

\phi(x) = \sqrt{2 / D}\,\cos(x W / \ell + b)

where W ∼ N(0, I) and b ∼ Uniform(0, 2π). This variant produces `n_features`-dimensional output (half the dimension of the `[cos, sin]` variant in `RBFFourierFeatures`) and is commonly used in Random Kitchen Sinks implementations.

All parameters (W, b, ℓ) are `pyrox_sample` sites.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of random features (= output dimension). |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class RBFCosineFeatures(PyroxModule):
    r"""Cosine-bias variant of random Fourier features for the RBF kernel.

    Uses the single-cosine feature map with a bias term:

    .. math::

        \phi(x) = \sqrt{2 / D}\,\cos(x W / \ell + b)

    where :math:`W \sim \mathcal{N}(0, I)` and
    :math:`b \sim \mathrm{Uniform}(0, 2\pi)`. This variant produces
    ``n_features``-dimensional output (half the dimension of the
    ``[cos, sin]`` variant in :class:`RBFFourierFeatures`) and is
    commonly used in Random Kitchen Sinks implementations.

    All parameters (:math:`W`, :math:`b`, :math:`\ell`) are
    ``pyrox_sample`` sites.

    Attributes:
        in_features: Input dimension.
        n_features: Number of random features (= output dimension).
        init_lengthscale: Prior location for the lengthscale.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    init_lengthscale: float = 1.0
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        lengthscale: float = 1.0,
    ) -> RBFCosineFeatures:
        return cls(
            in_features=in_features,
            n_features=n_features,
            init_lengthscale=lengthscale,
        )

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        W = self.pyrox_sample(
            "W",
            dist.Normal(0.0, 1.0)
            .expand([self.in_features, self.n_features])
            .to_event(2),
        )
        b = self.pyrox_sample(
            "b",
            dist.Uniform(0.0, 2.0 * jnp.pi).expand([self.n_features]).to_event(1),
        )
        ls = self.pyrox_sample(
            "lengthscale",
            dist.LogNormal(jnp.log(jnp.asarray(self.init_lengthscale)), 1.0),
        )
        z = x @ W / ls + b
        return jnp.sqrt(2.0 / self.n_features) * jnp.cos(z)
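The same kernel check works for the single-cosine map, using the exact feature form given above:

```python
import numpy as np

def rff_cos(x, W, b, lengthscale):
    """Single-cosine map: phi(x) = sqrt(2/D) cos(x W / l + b)."""
    return np.sqrt(2.0 / W.shape[1]) * np.cos(x @ W / lengthscale + b)

rng = np.random.default_rng(0)
d, n_feat = 2, 8000
W = rng.standard_normal((d, n_feat))
b = rng.uniform(0.0, 2.0 * np.pi, size=n_feat)
x = np.array([[0.3, -0.7]])
y = np.array([[1.1, 0.4]])
# The random phase b makes E[2 cos(u + b) cos(v + b)] = cos(u - v), so the
# inner product targets the same RBF kernel as the paired [cos, sin] map.
approx = float(rff_cos(x, W, b, 1.0) @ rff_cos(y, W, b, 1.0).T)
exact = float(np.exp(-np.sum((x - y) ** 2) / 2.0))
```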

pyrox.nn.MaternFourierFeatures

Bases: PyroxModule

SSGP-style RFF layer with Matern spectral density.

Spectral frequencies W have a `StudentT(df=2ν)` prior (the Matern spectral density). The smoothness ν controls the regularity: `nu=0.5` (Laplace), `nu=1.5` (Matern-3/2), `nu=2.5` (Matern-5/2).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of frequency pairs. |
| `nu` | `float` | Smoothness parameter ν. |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class MaternFourierFeatures(PyroxModule):
    r"""SSGP-style RFF layer with Matern spectral density.

    Spectral frequencies :math:`W` have a ``StudentT(df=2\nu)`` prior
    (the Matern spectral density). The smoothness :math:`\nu` controls
    the regularity: ``nu=0.5`` (Laplace), ``nu=1.5`` (Matern-3/2),
    ``nu=2.5`` (Matern-5/2).

    Attributes:
        in_features: Input dimension.
        n_features: Number of frequency pairs.
        nu: Smoothness parameter :math:`\nu`.
        init_lengthscale: Prior location for the lengthscale.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    nu: float = eqx.field(static=True, default=1.5)
    init_lengthscale: float = 1.0
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        nu: float = 1.5,
        lengthscale: float = 1.0,
    ) -> MaternFourierFeatures:
        if lengthscale <= 0:
            raise ValueError(f"lengthscale must be > 0, got {lengthscale}.")
        if nu <= 0:
            raise ValueError(f"nu must be > 0, got {nu}.")
        return cls(
            in_features=in_features,
            n_features=n_features,
            nu=nu,
            init_lengthscale=lengthscale,
        )

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        W = self.pyrox_sample(
            "W",
            dist.StudentT(df=2.0 * self.nu, loc=0.0, scale=1.0)
            .expand([self.in_features, self.n_features])
            .to_event(2),
        )
        ls = self.pyrox_sample(
            "lengthscale",
            dist.LogNormal(jnp.log(jnp.asarray(self.init_lengthscale)), 1.0),
        )
        return _rff_forward(W, ls, self.n_features, x)
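The spectral correspondence can be verified in one dimension: with unit lengthscale, the characteristic function of the `StudentT(df=3)` frequency prior is exactly the Matern-3/2 kernel. (The layer draws independent frequencies per input dimension; this sketch checks the 1-D marginal.)

```python
import numpy as np

# For nu = 1.5 the frequency prior is StudentT(df = 3). In 1-D with unit
# lengthscale, E[cos(w r)] = (1 + sqrt(3)|r|) exp(-sqrt(3)|r|), the
# Matern-3/2 kernel evaluated at separation r = |x - y|.
rng = np.random.default_rng(0)
w = rng.standard_t(df=3.0, size=200_000)  # Matern-3/2 spectral density
r = 0.8
approx = float(np.cos(w * r).mean())      # Monte Carlo estimate of the kernel
exact = float((1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r))
```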

pyrox.nn.LaplaceFourierFeatures

Bases: PyroxModule

SSGP-style RFF layer with Laplace (Matern-1/2) spectral density.

Spectral frequencies W have a `Cauchy` prior (Student-t with `df = 1`).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of frequency pairs. |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class LaplaceFourierFeatures(PyroxModule):
    r"""SSGP-style RFF layer with Laplace (Matern-1/2) spectral density.

    Spectral frequencies :math:`W` have a ``Cauchy`` prior (Student-t
    with ``df = 1``).

    Attributes:
        in_features: Input dimension.
        n_features: Number of frequency pairs.
        init_lengthscale: Prior location for the lengthscale.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    init_lengthscale: float = 1.0
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        lengthscale: float = 1.0,
    ) -> LaplaceFourierFeatures:
        return cls(
            in_features=in_features,
            n_features=n_features,
            init_lengthscale=lengthscale,
        )

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        W = self.pyrox_sample(
            "W",
            dist.StudentT(df=1.0, loc=0.0, scale=1.0)
            .expand([self.in_features, self.n_features])
            .to_event(2),
        )
        ls = self.pyrox_sample(
            "lengthscale",
            dist.LogNormal(jnp.log(jnp.asarray(self.init_lengthscale)), 1.0),
        )
        return _rff_forward(W, ls, self.n_features, x)
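Here, too, the 1-D spectral identity is directly checkable: the Cauchy characteristic function is the exponential kernel.

```python
import numpy as np

# StudentT with df = 1 is the Cauchy distribution, whose characteristic
# function is exp(-|r|): the Laplace (Matern-1/2, exponential) kernel.
rng = np.random.default_rng(0)
w = rng.standard_cauchy(size=200_000)
r = 1.3
approx = float(np.cos(w * r).mean())
exact = float(np.exp(-abs(r)))
```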

pyrox.nn.ArcCosineFourierFeatures

Bases: PyroxModule

Random features for the arc-cosine kernel (Cho & Saul, 2009).

The arc-cosine kernel of order p corresponds to an infinite-width single-layer ReLU network. The random feature map is:

.. math::

\phi(x) = \sqrt{2 / D}\,\max(0,\, x W / \ell)^p

where W ∼ N(0, I).

`order=0` gives the Heaviside (step) feature; `order=1` gives the ReLU feature (the most common); `order=2` gives the squared ReLU feature.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of random features (= output dimension). |
| `order` | `int` | Kernel order (0, 1, or 2). |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class ArcCosineFourierFeatures(PyroxModule):
    r"""Random features for the arc-cosine kernel (Cho & Saul, 2009).

    The arc-cosine kernel of order :math:`p` corresponds to an
    infinite-width single-layer ReLU network. The random feature map
    is:

    .. math::

        \phi(x) = \sqrt{2 / D}\,\max(0,\, x W / \ell)^p

    where :math:`W \sim \mathcal{N}(0, I)`.

    ``order=0`` gives the Heaviside (step) feature; ``order=1`` gives
    the ReLU feature (the most common); ``order=2`` gives the squared
    ReLU feature.

    Attributes:
        in_features: Input dimension.
        n_features: Number of random features (= output dimension).
        order: Kernel order (0, 1, or 2).
        init_lengthscale: Prior location for the lengthscale.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    order: int = eqx.field(static=True, default=1)
    init_lengthscale: float = 1.0
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        order: int = 1,
        lengthscale: float = 1.0,
    ) -> ArcCosineFourierFeatures:
        return cls(
            in_features=in_features,
            n_features=n_features,
            order=order,
            init_lengthscale=lengthscale,
        )

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        W = self.pyrox_sample(
            "W",
            dist.Normal(0.0, 1.0)
            .expand([self.in_features, self.n_features])
            .to_event(2),
        )
        ls = self.pyrox_sample(
            "lengthscale",
            dist.LogNormal(jnp.log(jnp.asarray(self.init_lengthscale)), 1.0),
        )
        z = x @ W / ls
        if self.order == 0:
            h = (z > 0.0).astype(x.dtype)
        else:
            h = jnp.maximum(z, 0.0) ** self.order
        return jnp.sqrt(2.0 / self.n_features) * h
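A quick self-consistency check of the order-1 feature map: on a unit-norm input, the exact order-1 arc-cosine kernel value k1(x, x) = ||x||^2 is 1, and the random features recover it in expectation.

```python
import numpy as np

def arccos_features(x, W, order):
    """Random arc-cosine features (lengthscale 1): sqrt(2/D) * relu(x W)^order."""
    z = x @ W
    h = (z > 0.0).astype(float) if order == 0 else np.maximum(z, 0.0) ** order
    return np.sqrt(2.0 / W.shape[1]) * h

rng = np.random.default_rng(0)
d, n_feat = 4, 20_000
W = rng.standard_normal((d, n_feat))
x = rng.standard_normal((1, d))
x = x / np.linalg.norm(x)  # unit norm, so the exact order-1 value k1(x, x) is 1
# Half the ReLU units fire, and E[relu(z)^2] = 1/2 for z ~ N(0, 1), so the
# 2/D scaling recovers ||x||^2 on the diagonal.
phi = arccos_features(x, W, 1)
approx = float(phi @ phi.T)
```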

pyrox.nn.RandomKitchenSinks

Bases: PyroxModule

Random Kitchen Sinks: RFF + a learned linear head.

Composes any RFF layer (`RBFFourierFeatures`, `MaternFourierFeatures`, `LaplaceFourierFeatures`) with a trainable linear projection:

.. math::

y = \phi(x)\, \beta + b

The linear head (`beta`, `bias`) is registered via `pyrox_sample` with `Normal` priors.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `rff` | `RBFFourierFeatures \| MaternFourierFeatures \| LaplaceFourierFeatures` | The underlying RFF feature layer. |
| `init_beta` | `Float[Array, 'D_rff D_out']` | Initial linear weights. |
| `init_bias` | `Float[Array, ' D_out']` | Initial bias vector. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class RandomKitchenSinks(PyroxModule):
    r"""Random Kitchen Sinks: RFF + a learned linear head.

    Composes any RFF layer (:class:`RBFFourierFeatures`,
    :class:`MaternFourierFeatures`, :class:`LaplaceFourierFeatures`)
    with a trainable linear projection:

    .. math::

        y = \phi(x)\, \beta + b

    The linear head (``beta``, ``bias``) is registered via
    ``pyrox_sample`` with ``Normal`` priors.

    Attributes:
        rff: The underlying RFF feature layer.
        init_beta: Initial linear weights.
        init_bias: Initial bias vector.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    rff: RBFFourierFeatures | MaternFourierFeatures | LaplaceFourierFeatures
    init_beta: Float[Array, "D_rff D_out"]
    init_bias: Float[Array, " D_out"]
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        rff: RBFFourierFeatures | MaternFourierFeatures | LaplaceFourierFeatures,
        out_features: int,
    ) -> RandomKitchenSinks:
        """Construct from a pre-built RFF layer with zero-initialized head."""
        beta = jnp.zeros((2 * rff.n_features, out_features))
        bias = jnp.zeros(out_features)
        return cls(rff=rff, init_beta=beta, init_bias=bias)

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_out"]:
        phi = self.rff(x)
        beta = self.pyrox_sample(
            "beta",
            dist.Normal(self.init_beta, 1.0).to_event(2),
        )
        bias = self.pyrox_sample(
            "bias",
            dist.Normal(self.init_bias, 1.0).to_event(1),
        )
        return phi @ beta + bias

init(rff, out_features) classmethod

Construct from a pre-built RFF layer with zero-initialized head.

Source code in src/pyrox/nn/_layers.py
@classmethod
def init(
    cls,
    rff: RBFFourierFeatures | MaternFourierFeatures | LaplaceFourierFeatures,
    out_features: int,
) -> RandomKitchenSinks:
    """Construct from a pre-built RFF layer with zero-initialized head."""
    beta = jnp.zeros((2 * rff.n_features, out_features))
    bias = jnp.zeros(out_features)
    return cls(rff=rff, init_beta=beta, init_bias=bias)
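The composition can be sketched end-to-end with a frozen feature map and a classical least-squares head standing in for the Bayesian one (pyrox learns `beta` and `bias` as sample sites instead):

```python
import numpy as np

rng = np.random.default_rng(0)

# Frozen random Fourier features for a 1-D input ...
n_feat, ls = 300, 0.5
W = rng.standard_normal((1, n_feat))

def phi(x):
    z = x @ W / ls
    return np.sqrt(1.0 / n_feat) * np.concatenate([np.cos(z), np.sin(z)], axis=-1)

# ... plus a linear head -- the Random Kitchen Sinks recipe. Here the head
# is fitted by least squares; pyrox instead gives (beta, bias) Normal
# priors so SVI can learn a posterior over the head.
x_train = np.linspace(-3.0, 3.0, 200)[:, None]
y_train = np.sin(2.0 * x_train).ravel()
features = np.concatenate([phi(x_train), np.ones((200, 1))], axis=1)
coef, *_ = np.linalg.lstsq(features, y_train, rcond=None)
rmse = float(np.sqrt(np.mean((features @ coef - y_train) ** 2)))
```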

Wave-4 spectral layers (#41)

pyrox.nn.VariationalFourierFeatures

Bases: PyroxModule

VSSGP — RFF with a learnable variational posterior over frequencies.

Standard RFF (e.g. `RBFFourierFeatures`) treats the spectral frequencies W as a frozen prior draw; VSSGP (Gal & Turner, 2015) treats W as a latent with a learnable mean-field posterior, recovering spectral *uncertainty* on top of the feature-space uncertainty.

Prior: p(W) = N(0, I) (the RBF spectral density in lengthscale-1 units). The lengthscale is itself a sampled site (`LogNormal(log init_lengthscale, 1)`) so that frequencies are rescaled to the physical kernel.

Under SVI, attach a NumPyro `AutoNormal` guide to learn the posterior on `W`; under prior-only seeds, the layer behaves identically to `RBFFourierFeatures`.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `in_features` | `int` | Input dimension D. |
| `n_features` | `int` | Number of frequency pairs (output dim `2 * n_features`). |
| `init_lengthscale` | `float` | Prior location for the kernel lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |

Source code in src/pyrox/nn/_layers.py
class VariationalFourierFeatures(PyroxModule):
    r"""VSSGP — RFF with a learnable variational posterior over frequencies.

    Standard RFF (e.g. :class:`RBFFourierFeatures`) treats the spectral
    frequencies :math:`W` as a frozen prior draw; VSSGP (Gal & Turner,
    2015) treats :math:`W` as a latent with a learnable mean-field
    posterior, recovering spectral *uncertainty* on top of the
    feature-space uncertainty.

    Prior: :math:`p(W) = \mathcal{N}(0, I)` (RBF spectral density in
    lengthscale-1 units). The lengthscale is itself a sampled site
    (``LogNormal(log init_lengthscale, 1)``) so that frequencies are
    rescaled to the physical kernel.

    Under SVI, attach an :class:`~numpyro.infer.autoguide.AutoNormal` to
    learn the posterior on ``W``; under prior-only seeds, behaves
    identically to :class:`RBFFourierFeatures`.

    Attributes:
        in_features: Input dimension :math:`D`.
        n_features: Number of frequency pairs (output dim ``2 * n_features``).
        init_lengthscale: Prior location for the kernel lengthscale.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    init_lengthscale: float = 1.0
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        lengthscale: float = 1.0,
    ) -> VariationalFourierFeatures:
        if lengthscale <= 0:
            raise ValueError(f"lengthscale must be > 0, got {lengthscale}.")
        return cls(
            in_features=in_features,
            n_features=n_features,
            init_lengthscale=lengthscale,
        )

    @pyrox_method
    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        # Same prior as RBFFourierFeatures — the *posterior* is what differs
        # under SVI: an attached AutoGuide learns q(W) instead of forcing W
        # to its prior draw.
        W = self.pyrox_sample(
            "W",
            dist.Normal(0.0, 1.0)
            .expand([self.in_features, self.n_features])
            .to_event(2),
        )
        ls = self.pyrox_sample(
            "lengthscale",
            dist.LogNormal(jnp.log(jnp.asarray(self.init_lengthscale)), 1.0),
        )
        return _rff_forward(W, ls, self.n_features, x)
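Both this layer and `RBFFourierFeatures` rest on the random-feature identity: for standard-normal frequencies, the cos/sin feature inner product is an unbiased estimate of the RBF kernel. The pure-NumPy sketch below illustrates that identity; `_rff_forward`'s exact normalization is internal to pyrox, so the `1/sqrt(m)` scaling here is an assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
D, m, ls = 3, 4000, 1.5  # input dim, frequency pairs, kernel lengthscale

# Spectral draw for the lengthscale-1 RBF kernel, rescaled by `ls` below.
W = rng.standard_normal((D, m))

def rff(x):
    # phi(x) such that phi(x) @ phi(y).T ~= exp(-||x - y||^2 / (2 ls^2))
    z = x @ W / ls
    return np.concatenate([np.cos(z), np.sin(z)], axis=-1) / np.sqrt(m)

x = rng.standard_normal((1, D))
y = rng.standard_normal((1, D))
approx = float(rff(x) @ rff(y).T)
exact = float(np.exp(-np.sum((x - y) ** 2) / (2.0 * ls**2)))
```

With 4000 frequency pairs the Monte-Carlo error of `approx` against `exact` is already at the percent level; a variational posterior over `W` (as above) simply replaces the frozen draw with learned frequencies.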

pyrox.nn.OrthogonalRandomFeatures

Bases: Module

Orthogonal Random Features (Yu et al., 2016) — variance-reduced RFF.

Frequencies are drawn from blocks of Haar-orthogonal matrices scaled by independent chi-distributed magnitudes, giving the same RBF kernel approximation as plain `RBFFourierFeatures` in expectation but with provably lower variance for finite `n_features`.

Frozen at construction time: no priors, no SVI on `W`. The frequency matrix is built once from a `key` and stored as a static array.

Attributes:

- `in_features` (`int`): Input dimension `D`.
- `n_features` (`int`): Number of feature pairs. Must satisfy `n_features % in_features == 0` so that ORF blocks tile cleanly.
- `lengthscale` (`Float[Array, ""]`): Fixed kernel lengthscale (no prior; pass a value).
- `W` (`Float[Array, "D_in D_orf"]`): Pre-built frequency matrix of shape `(in_features, n_features)`.

Note: For a learnable lengthscale or a full Bayesian treatment of the frequencies, prefer `VariationalFourierFeatures`.

Source code in src/pyrox/nn/_layers.py
class OrthogonalRandomFeatures(eqx.Module):
    r"""Orthogonal Random Features (Yu et al., 2016) — variance-reduced RFF.

    Frequencies are drawn from blocks of Haar-orthogonal matrices scaled by
    independent chi-distributed magnitudes, giving the same RBF kernel
    approximation as plain :class:`RBFFourierFeatures` *in expectation* but
    with provably lower variance for finite ``n_features``.

    Frozen at construction time — no priors, no SVI on ``W``. The frequency
    matrix is built once from a ``key`` and stored as a static array.

    Attributes:
        in_features: Input dimension :math:`D`.
        n_features: Number of feature pairs. Must satisfy
            ``n_features % in_features == 0`` so that ORF blocks tile cleanly.
        lengthscale: Fixed kernel lengthscale (no prior; pass a value).
        W: Pre-built frequency matrix of shape ``(in_features, n_features)``.

    Note:
        For learnable lengthscale or full Bayesian treatment of the
        frequencies, prefer :class:`VariationalFourierFeatures`.
    """

    in_features: int = eqx.field(static=True)
    n_features: int = eqx.field(static=True)
    lengthscale: Float[Array, ""]
    W: Float[Array, "D_in D_orf"]

    @classmethod
    def init(
        cls,
        in_features: int,
        n_features: int,
        *,
        key: jax.Array,
        lengthscale: float = 1.0,
    ) -> OrthogonalRandomFeatures:
        if lengthscale <= 0:
            raise ValueError(f"lengthscale must be > 0, got {lengthscale}.")
        if n_features % in_features != 0:
            raise ValueError(
                f"n_features ({n_features}) must be divisible by in_features "
                f"({in_features}) so ORF blocks tile cleanly."
            )
        n_blocks = n_features // in_features
        W = _orthogonal_blocks(in_features, n_blocks, key=key)
        return cls(
            in_features=in_features,
            n_features=n_features,
            lengthscale=jnp.asarray(lengthscale),
            W=W,
        )

    def __call__(self, x: Float[Array, "*batch D_in"]) -> Float[Array, "*batch D_rff"]:
        return _rff_forward(self.W, self.lengthscale, self.n_features, x)
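The `_orthogonal_blocks` helper is not shown above, but the standard ORF construction (Yu et al., 2016) builds each square block from a QR-derived Haar-orthogonal matrix with chi-distributed row magnitudes. The NumPy sketch below is an assumption-labeled reconstruction of one such block, not pyrox's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # in_features; one ORF block is d x d

# Haar-orthogonal factor: QR of a Gaussian matrix, with the sign of R's
# diagonal folded into Q so the distribution is exactly Haar.
G = rng.standard_normal((d, d))
Q, R = np.linalg.qr(G)
Q = Q * np.sign(np.diag(R))

# Row norms of an i.i.d. Gaussian matrix are chi(d)-distributed; resampling
# them restores the marginal norms of unstructured Gaussian frequencies.
s = np.linalg.norm(rng.standard_normal((d, d)), axis=1)
W_block = s[:, None] * Q  # rows remain mutually orthogonal
```

Because the rows of `W_block` stay orthogonal, the cos/sin features built from them decorrelate, which is where the variance reduction over i.i.d. Gaussian frequencies comes from.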

pyrox.nn.HSGPFeatures

Bases: PyroxModule

Hilbert-Space Gaussian Process feature layer (Riutort-Mayol et al., 2023).

A deterministic Laplacian-eigenfunction basis on the bounded box `[-L, L]^D` plus learnable per-basis amplitudes with a kernel-spectral-density prior:

.. math::

\hat{f}(x) = \sum_{j=1}^{M} \alpha_j\,\sqrt{S(\sqrt{\lambda_j})}\,\phi_j(x),
\quad \alpha_j \sim \mathcal{N}(0, 1).

This is the NN-side dual of `pyrox.gp.FourierInducingFeatures`: same basis, different prior wiring. As `M` and `L` grow, the induced GP converges to the kernel passed in.

Attributes:

- `in_features` (`int`): Input dimension `D`.
- `num_basis_per_dim` (`tuple[int, ...]`): Per-axis number of 1D eigenfunctions; total basis count is `prod(num_basis_per_dim)`.
- `L` (`tuple[float, ...]`): Per-axis box half-width.
- `kernel` (`Kernel`): A stationary kernel from `pyrox.gp` whose spectral density supplies the per-basis prior variance. Currently `pyrox.gp.RBF` and `pyrox.gp.Matern` are supported by `pyrox._basis.spectral_density`.
- `pyrox_name` (`str | None`): Explicit scope name for NumPyro site registration.

Source code in src/pyrox/nn/_layers.py
class HSGPFeatures(PyroxModule):
    r"""Hilbert-Space Gaussian Process feature layer (Riutort-Mayol et al., 2023).

    A *deterministic* Laplacian-eigenfunction basis on the bounded box
    :math:`[-L, L]^D` plus learnable per-basis amplitudes with a
    kernel-spectral-density prior:

    .. math::

        \hat{f}(x) = \sum_{j=1}^{M} \alpha_j\,\sqrt{S(\sqrt{\lambda_j})}\,\phi_j(x),
        \quad \alpha_j \sim \mathcal{N}(0, 1).

    This is the NN-side dual of :class:`pyrox.gp.FourierInducingFeatures`
    — same basis, different prior wiring. As ``M`` and ``L`` grow, the
    induced GP converges to the kernel passed in.

    Attributes:
        in_features: Input dimension :math:`D`.
        num_basis_per_dim: Per-axis number of 1D eigenfunctions; total
            basis count is ``prod(num_basis_per_dim)``.
        L: Per-axis box half-width.
        kernel: A stationary kernel from :mod:`pyrox.gp` whose spectral
            density supplies the per-basis prior variance. Currently
            :class:`pyrox.gp.RBF` and :class:`pyrox.gp.Matern` are
            supported by :func:`pyrox._basis.spectral_density`.
        pyrox_name: Explicit scope name for NumPyro site registration.
    """

    in_features: int = eqx.field(static=True)
    num_basis_per_dim: tuple[int, ...] = eqx.field(static=True)
    L: tuple[float, ...] = eqx.field(static=True)
    kernel: Kernel
    pyrox_name: str | None = None

    @classmethod
    def init(
        cls,
        in_features: int,
        num_basis_per_dim: int | tuple[int, ...],
        L: float | tuple[float, ...],
        *,
        kernel: Kernel,
    ) -> HSGPFeatures:
        if isinstance(num_basis_per_dim, int):
            num_basis_per_dim = (num_basis_per_dim,) * in_features
        if isinstance(L, int | float):
            L = (float(L),) * in_features
        if len(num_basis_per_dim) != in_features:
            raise ValueError(
                f"num_basis_per_dim length ({len(num_basis_per_dim)}) "
                f"must match in_features ({in_features})."
            )
        if len(L) != in_features:
            raise ValueError(
                f"L length ({len(L)}) must match in_features ({in_features})."
            )
        if any(L_d <= 0 for L_d in L):
            raise ValueError(f"L must be all positive; got {L}.")
        if any(M_d < 1 for M_d in num_basis_per_dim):
            raise ValueError(
                f"num_basis_per_dim must be all >= 1; got {num_basis_per_dim}."
            )
        return cls(
            in_features=in_features,
            num_basis_per_dim=tuple(num_basis_per_dim),
            L=tuple(float(L_d) for L_d in L),
            kernel=kernel,
        )

    @property
    def num_basis(self) -> int:
        n = 1
        for m in self.num_basis_per_dim:
            n *= m
        return n

    @pyrox_method
    def __call__(self, x: Float[Array, "N D_in"]) -> Float[Array, " N"]:
        Phi, lam = fourier_basis(x, self.num_basis_per_dim, self.L)  # (N, M), (M,)
        # Spectral density evaluated under the kernel's own context so any
        # priors on (variance, lengthscale) register exactly once.
        with _kernel_context(self.kernel):
            S = spectral_density(self.kernel, lam, D=self.in_features)
        sqrt_S = jnp.sqrt(S)
        alpha = self.pyrox_sample(
            "alpha",
            dist.Normal(0.0, 1.0).expand([self.num_basis]).to_event(1),
        )
        return jnp.einsum("nm,m->n", Phi, sqrt_S * alpha)
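In one dimension the Laplacian eigenbasis that `fourier_basis` evaluates has a closed form. The NumPy sketch below uses the standard Dirichlet-boundary form from Solin and Sarkka (the convention pyrox uses internally is not shown here, so this normalization is an assumption) and checks orthonormality on `[-L, L]` numerically:

```python
import numpy as np

L, M = 2.0, 6  # box half-width, number of eigenfunctions
x = np.linspace(-L, L, 20001)
dx = x[1] - x[0]

def phi(j, x):
    # j-th Laplacian eigenfunction on [-L, L] with Dirichlet boundaries
    return np.sqrt(1.0 / L) * np.sin(j * np.pi * (x + L) / (2.0 * L))

lam = (np.arange(1, M + 1) * np.pi / (2.0 * L)) ** 2  # eigenvalues

# The basis is orthonormal on [-L, L]; check the Gram matrix numerically.
Phi = np.stack([phi(j, x) for j in range(1, M + 1)])  # (M, N)
gram = (Phi @ Phi.T) * dx
```

The kernel enters only through `sqrt(S(sqrt(lam_j)))`, which scales the standard-normal amplitudes `alpha` so the weighted basis sum has (approximately) the kernel's covariance.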

Bayesian Neural Field stack

pyrox.nn.Standardization

Bases: PyroxModule

Apply a fixed-coefficient affine standardization.

.. math::

\tilde x \;=\; \frac{x - \mu}{\sigma}.

Both `mu` and `std` are static (fit-time) constants, not learned. Use `pyrox.preprocessing.fit_standardization` to construct from a pandas DataFrame.

Attributes:

- `mu` (`Float[Array, " D"]`): Per-feature mean, shape `(D,)`.
- `std` (`Float[Array, " D"]`): Per-feature standard deviation, shape `(D,)`. Must be strictly positive; guard upstream.
- `pyrox_name` (`str | None`): Optional override for the per-instance scope name.

Source code in src/pyrox/nn/_bnf.py
class Standardization(PyroxModule):
    r"""Apply a fixed-coefficient affine standardization.

    .. math::

        \tilde x \;=\; \frac{x - \mu}{\sigma}.

    Both ``mu`` and ``std`` are static (fit-time) constants, not
    learned. Use :func:`pyrox.preprocessing.fit_standardization` to
    construct from a pandas DataFrame.

    Attributes:
        mu: Per-feature mean, shape ``(D,)``.
        std: Per-feature standard deviation, shape ``(D,)``. Must be
            strictly positive — guard upstream.
        pyrox_name: Optional override for the per-instance scope name.
    """

    mu: Float[Array, " D"]
    std: Float[Array, " D"]
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        return (x - self.mu) / self.std

pyrox.nn.FourierFeatures

Bases: PyroxModule

Per-input dyadic-frequency Fourier basis.

For each input column, evaluates `2 * degree` Fourier features at frequencies 2π · 2^d for d ∈ {0, ..., degree - 1}, concatenated across all columns.

Wraps `pyrox.nn._features.fourier_features` per input dimension.

Attributes:

- `degrees` (`tuple[int, ...]`): Number of dyadic frequencies per input column. A column with `degree = 0` contributes no features. Marked static so the loop over columns unrolls at trace time.
- `rescale` (`bool`): If `True`, divide each `(cos_d, sin_d)` pair by `d + 1` to bias the prior toward lower frequencies.
- `pyrox_name` (`str | None`): Optional scope-name override.

Source code in src/pyrox/nn/_bnf.py
class FourierFeatures(PyroxModule):
    r"""Per-input dyadic-frequency Fourier basis.

    For each input column, evaluates ``2 * degree`` Fourier features at
    frequencies :math:`2\pi \cdot 2^d` for :math:`d \in \{0, \dots,
    \text{degree} - 1\}`. Concatenated across all columns.

    Wraps :func:`pyrox.nn._features.fourier_features` per input
    dimension.

    Attributes:
        degrees: Number of dyadic frequencies per input column, as a
            Python ``tuple[int, ...]``. A column with ``degree = 0``
            contributes no features. Marked ``static`` so the loop
            over columns unrolls at trace time.
        rescale: If ``True``, divide each ``(cos_d, sin_d)`` pair by
            ``d + 1`` to bias the prior toward lower frequencies.
        pyrox_name: Optional scope-name override.
    """

    degrees: tuple[int, ...] = eqx.field(static=True)
    rescale: bool = eqx.field(static=True, default=False)
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N F"]:
        feats = []
        for col_idx, d in enumerate(self.degrees):
            if d <= 0:
                continue
            feats.append(fourier_features(x[:, col_idx], d, rescale=self.rescale))
        if not feats:
            return jnp.zeros((x.shape[0], 0), dtype=x.dtype)
        return jnp.concatenate(feats, axis=-1)

pyrox.nn.SeasonalFeatures

Bases: PyroxModule

Period-and-harmonic cos/sin basis on a scalar time axis.

For each period τ_p with H_p harmonics, emits `2 * H_p` cos/sin columns. Total output width is 2 Σ_p H_p.

Wraps `pyrox.nn._features.seasonal_features`. Periods and harmonics are kept as Python tuples (static) so the inner shape structure is known at trace time.

Attributes:

- `periods` (`tuple[float, ...]`): Period values.
- `harmonics` (`tuple[int, ...]`): Harmonics per period.
- `rescale` (`bool`): If `True`, divide each `(cos, sin)` pair by its within-period harmonic index.
- `pyrox_name` (`str | None`): Optional scope-name override.

Source code in src/pyrox/nn/_bnf.py
class SeasonalFeatures(PyroxModule):
    r"""Period-and-harmonic cos/sin basis on a scalar time axis.

    For each period :math:`\tau_p` with :math:`H_p` harmonics, emits
    ``2 * H_p`` cos/sin columns. Total output width is :math:`2 \sum_p
    H_p`.

    Wraps :func:`pyrox.nn._features.seasonal_features`. Periods and
    harmonics are kept as Python tuples (static) so the inner shape
    structure is known at trace time.

    Attributes:
        periods: Period values, ``tuple[float, ...]``.
        harmonics: Harmonics per period, ``tuple[int, ...]``.
        rescale: If ``True``, divide each ``(cos, sin)`` pair by its
            within-period harmonic index.
        pyrox_name: Optional scope-name override.
    """

    periods: tuple[float, ...] = eqx.field(static=True)
    harmonics: tuple[int, ...] = eqx.field(static=True)
    rescale: bool = eqx.field(static=True, default=False)
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, t: Float[Array, " N"]) -> Float[Array, "N F"]:
        return seasonal_features(t, self.periods, self.harmonics, rescale=self.rescale)

pyrox.nn.InteractionFeatures

Bases: PyroxModule

Element-wise products on selected pairs of input columns.

Wraps `pyrox.nn._features.interaction_features`.

Attributes:

- `pairs` (`tuple[tuple[int, int], ...]`): Index pairs. An empty tuple produces an `(N, 0)` output. Static so the count `K` is known at trace time.
- `pyrox_name` (`str | None`): Optional scope-name override.

Source code in src/pyrox/nn/_bnf.py
class InteractionFeatures(PyroxModule):
    r"""Element-wise products on selected pairs of input columns.

    Wraps :func:`pyrox.nn._features.interaction_features`.

    Attributes:
        pairs: Index pairs, ``tuple[tuple[int, int], ...]``. Empty
            tuple produces an ``(N, 0)`` output. Static so the count
            ``K`` is known at trace time.
        pyrox_name: Optional scope-name override.
    """

    pairs: tuple[tuple[int, int], ...] = eqx.field(static=True)
    pyrox_name: str | None = None

    @pyrox_method
    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N K"]:
        if not self.pairs:
            return jnp.zeros((x.shape[0], 0), dtype=x.dtype)
        return interaction_features(x, jnp.asarray(self.pairs, dtype=jnp.int32))

pyrox.nn.BayesianNeuralField

Bases: PyroxModule

The full Bayesian Neural Field architecture.

A spatiotemporal MLP with:

  1. A learned per-input log-scale adjustment (Logistic(0, 1) prior).
  2. Four feature blocks concatenated into `h_0`: rescaled inputs, Fourier features, seasonal features, interaction products.
  3. Per-block `softplus(feature_gain)` modulation.
  4. A depth-`L` MLP whose layers are h_{ℓ+1} = σ_α(g_ℓ · W_ℓ h_ℓ / √|h_ℓ|), where σ_α = sig(β) · elu + (1 - sig(β)) · tanh is a learned mixed activation.
  5. A final linear layer scaled by `softplus(output_gain)`.

All weights, biases, gains, scales, and the activation logit carry independent Logistic(0, 1) priors registered via `PyroxModule.pyrox_sample`.

The 1/√(fan-in) pre-normalization is the standard NTK-scaling trick: it makes the layer-wise prior predictive a fan-in-independent Gaussian process in the infinite-width limit (Lee et al., 2018).

Attributes:

- `input_scales` (`tuple[float, ...]`): Per-input fixed scale (typically the training-data inter-quartile range). Static.
- `fourier_degrees` (`tuple[int, ...]`): Per-input number of dyadic Fourier frequencies. Static; use `0` to skip a column.
- `interactions` (`tuple[tuple[int, int], ...]`): Pair-index list for interaction features. Static; empty for none.
- `seasonality_periods` (`tuple[float, ...]`): Periods for seasonal features. Static. The time variable is taken from input column `time_col`.
- `num_seasonal_harmonics` (`tuple[int, ...]`): Harmonics per period. Static.
- `width` (`int`): Hidden layer width.
- `depth` (`int`): Number of hidden MLP layers.
- `time_col` (`int`): Index of the time column inside `x` used for seasonal features (default 0).
- `pyrox_name` (`str | None`): Optional scope-name override.

Source code in src/pyrox/nn/_bnf.py
class BayesianNeuralField(PyroxModule):
    r"""The full Bayesian Neural Field architecture.

    A spatiotemporal MLP with:

    1. A learned per-input log-scale adjustment (Logistic(0, 1) prior).
    2. Four feature blocks concatenated into ``h_0``: rescaled inputs,
       Fourier features, seasonal features, interaction products.
    3. Per-block ``softplus(feature_gain)`` modulation.
    4. A depth-``L`` MLP whose layers are
       :math:`h_{\ell+1} = \sigma_\alpha\bigl(g_\ell \cdot W_\ell\, h_\ell
       / \sqrt{\lvert h_\ell \rvert}\bigr)`, where :math:`\sigma_\alpha
       = \mathrm{sig}(\beta) \cdot \mathrm{elu} + (1 - \mathrm{sig}(\beta))
       \cdot \mathrm{tanh}` is a learned mixed activation.
    5. A final linear layer scaled by ``softplus(output_gain)``.

    All weights, biases, gains, scales, and the activation logit carry
    independent :math:`\mathrm{Logistic}(0, 1)` priors registered via
    :meth:`PyroxModule.pyrox_sample`.

    The :math:`1/\sqrt{\text{fan-in}}` pre-normalization is the
    standard NTK-scaling trick — it makes the layer-wise prior
    predictive a fan-in-independent Gaussian process in the
    infinite-width limit (Lee et al., 2018).

    Attributes:
        input_scales: Per-input fixed scale (typically training-data
            inter-quartile range). Static ``tuple[float, ...]``.
        fourier_degrees: Per-input number of dyadic Fourier
            frequencies. Static ``tuple[int, ...]``; use ``0`` to skip
            a column.
        interactions: Pair-index list for interaction features. Static
            ``tuple[tuple[int, int], ...]``; empty for none.
        seasonality_periods: Periods for seasonal features. Static
            ``tuple[float, ...]``. The time variable is taken from
            input column ``time_col``.
        num_seasonal_harmonics: Harmonics per period. Static
            ``tuple[int, ...]``.
        width: Hidden layer width.
        depth: Number of hidden MLP layers.
        time_col: Index of the time column inside ``x`` used for
            seasonal features (default 0).
        pyrox_name: Optional scope-name override.
    """

    input_scales: tuple[float, ...] = eqx.field(static=True)
    fourier_degrees: tuple[int, ...] = eqx.field(static=True)
    interactions: tuple[tuple[int, int], ...] = eqx.field(static=True)
    seasonality_periods: tuple[float, ...] = eqx.field(static=True)
    num_seasonal_harmonics: tuple[int, ...] = eqx.field(static=True)
    width: int = eqx.field(static=True)
    depth: int = eqx.field(static=True)
    time_col: int = eqx.field(static=True, default=0)
    pyrox_name: str | None = None

    @staticmethod
    def _logistic_prior(shape: tuple[int, ...]) -> dist.Distribution:
        """Independent Logistic(0, 1) prior over an array of given shape."""
        if not shape:
            return dist.Logistic(0.0, 1.0)
        return dist.Logistic(jnp.zeros(shape), 1.0).to_event(len(shape))

    @pyrox_method
    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, " N"]:
        d_in = len(self.input_scales)
        input_scales = jnp.asarray(self.input_scales, dtype=jnp.float32)

        # 1. Input rescaling: x / (input_scales * exp(log_scale_adjustment)).
        log_scale_adjustment = self.pyrox_sample(
            "log_scale_adjustment",
            self._logistic_prior((d_in,)),
        )
        scaled_x = x / (input_scales * jnp.exp(log_scale_adjustment))

        # 2. Build the four feature blocks.
        feature_blocks: list[Float[Array, "N F_block"]] = [scaled_x]

        # Fourier per input dim (only for degrees > 0).
        for col_idx, d in enumerate(self.fourier_degrees):
            if d > 0:
                feature_blocks.append(
                    fourier_features(scaled_x[:, col_idx], d, rescale=True)
                )

        # Seasonal on the time column.
        if self.seasonality_periods and any(self.num_seasonal_harmonics):
            feature_blocks.append(
                seasonal_features(
                    x[:, self.time_col],
                    self.seasonality_periods,
                    self.num_seasonal_harmonics,
                    rescale=True,
                )
            )

        # Interaction products.
        if self.interactions:
            feature_blocks.append(
                interaction_features(
                    scaled_x, jnp.asarray(self.interactions, dtype=jnp.int32)
                )
            )

        # 3. Per-block softplus(feature_gain) modulation.
        gated_blocks: list[Float[Array, "N F_block"]] = []
        for b_idx, block in enumerate(feature_blocks):
            if block.shape[-1] == 0:
                continue
            gain = self.pyrox_sample(
                f"feature_gain_{b_idx}",
                self._logistic_prior(()),
            )
            gated_blocks.append(block * jax.nn.softplus(gain))
        h = jnp.concatenate(gated_blocks, axis=-1)

        # 4. Mixed elu/tanh activation, learned mix weight.
        logit_activation_weight = self.pyrox_sample(
            "logit_activation_weight",
            self._logistic_prior(()),
        )
        alpha = jax.nn.sigmoid(logit_activation_weight)

        def activation(z: Float[Array, ...]) -> Float[Array, ...]:
            return alpha * jax.nn.elu(z) + (1.0 - alpha) * jnp.tanh(z)

        # 5. Hidden MLP layers.
        for layer_idx in range(self.depth):
            fan_in = h.shape[-1]
            W = self.pyrox_sample(
                f"layer_{layer_idx}_W",
                self._logistic_prior((fan_in, self.width)),
            )
            b = self.pyrox_sample(
                f"layer_{layer_idx}_b",
                self._logistic_prior((self.width,)),
            )
            layer_gain = self.pyrox_sample(
                f"layer_{layer_idx}_gain",
                self._logistic_prior(()),
            )
            h = h / jnp.sqrt(fan_in)
            h = activation(jax.nn.softplus(layer_gain) * (h @ W + b))

        # 6. Output linear, scaled by softplus(output_gain).
        fan_in = h.shape[-1]
        W_out = self.pyrox_sample(
            "output_W",
            self._logistic_prior((fan_in, 1)),
        )
        b_out = self.pyrox_sample(
            "output_b",
            self._logistic_prior((1,)),
        )
        output_gain = self.pyrox_sample(
            "output_gain",
            self._logistic_prior(()),
        )
        h = h / jnp.sqrt(fan_in)
        return (jax.nn.softplus(output_gain) * (h @ W_out + b_out)).squeeze(-1)
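A single hidden layer of the loop above can be sketched in plain NumPy: fan-in pre-scaling, a softplus-transformed gain, and the mixed elu/tanh activation. Here Logistic(0, 1) draws stand in for the NumPyro sample sites, and all variable names are illustrative rather than pyrox API:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def elu(z):
    return np.where(z > 0, z, np.expm1(z))

def softplus(z):
    return np.log1p(np.exp(z))

rng = np.random.default_rng(0)
fan_in, width = 16, 32
h = rng.standard_normal((4, fan_in))      # batch of 4 feature rows

beta = rng.logistic()                     # stands in for logit_activation_weight
alpha = sigmoid(beta)                     # mixing weight between elu and tanh
W = rng.logistic(size=(fan_in, width))    # Logistic(0, 1) prior draws
b = rng.logistic(size=width)
gain = rng.logistic()

h = h / np.sqrt(fan_in)                   # NTK 1/sqrt(fan-in) pre-scaling
z = softplus(gain) * (h @ W + b)
h_next = alpha * elu(z) + (1.0 - alpha) * np.tanh(z)
```

Because `h` is divided by `sqrt(fan_in)` before the matmul, the pre-activation variance is independent of the layer width, which is exactly the property the NTK-scaling remark above relies on.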

Pure-JAX feature helpers

pyrox.nn.fourier_features(x, max_degree, *, rescale=False)

Cos/sin Fourier basis at dyadic frequencies.

For each input element and each degree d ∈ {0, ..., D - 1}, evaluates

.. math::

\phi_{d, \cos}(x) = \cos(2\pi \cdot 2^d \cdot x), \qquad
\phi_{d, \sin}(x) = \sin(2\pi \cdot 2^d \cdot x).

Returns the columns concatenated as `[cos_0, ..., cos_{D-1}, sin_0, ..., sin_{D-1}]`, matching Google's bayesnf layout.

Parameters:

- `x` (`Float[Array, " N"]`, required): Length-`N` input vector.
- `max_degree` (`int`, required): Number of dyadic frequencies `D`. Output has `2 * max_degree` columns.
- `rescale` (`bool`, default `False`): If `True`, divide each `(cos_d, sin_d)` pair by `d + 1` to bias the prior toward lower-frequency basis functions.

Returns:

- `Float[Array, "N two_max_degree"]`: Array of shape `(N, 2 * max_degree)`.

Source code in src/pyrox/nn/_features.py
def fourier_features(
    x: Float[Array, " N"],
    max_degree: int,
    *,
    rescale: bool = False,
) -> Float[Array, "N two_max_degree"]:
    r"""Cos/sin Fourier basis at dyadic frequencies.

    For each input element and each degree :math:`d \in \{0, \dots,
    D-1\}`, evaluates

    .. math::

        \phi_{d, \cos}(x) = \cos(2\pi \cdot 2^d \cdot x), \qquad
        \phi_{d, \sin}(x) = \sin(2\pi \cdot 2^d \cdot x).

    Returns the columns concatenated as ``[cos_0, ..., cos_{D-1},
    sin_0, ..., sin_{D-1}]``, matching Google's bayesnf layout.

    Args:
        x: Length-``N`` input vector.
        max_degree: Number of dyadic frequencies ``D``. Output has
            ``2 * max_degree`` columns.
        rescale: If ``True``, divide each ``(cos_d, sin_d)`` pair by
            ``d + 1`` to bias the prior toward lower-frequency basis
            functions.

    Returns:
        Array of shape ``(N, 2 * max_degree)``.
    """
    degrees = jnp.arange(max_degree)
    # `repeat` builds (N, D) frequencies without an explicit reshape.
    z = repeat(x, "n -> n d", d=max_degree) * (2.0 * jnp.pi * 2.0**degrees)
    feats = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
    if rescale:
        denom = jnp.concatenate([degrees + 1, degrees + 1])
        feats = feats / denom
    return feats
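The cos-block/sin-block layout is easy to verify by hand. The einops-free NumPy mirror below (a sketch, not the pyrox implementation) evaluates two inputs whose phases are exact multiples of π/2:

```python
import numpy as np

def fourier_features_np(x, max_degree, *, rescale=False):
    # Layout matches the JAX version: [cos_0..cos_{D-1}, sin_0..sin_{D-1}].
    degrees = np.arange(max_degree)
    z = x[:, None] * (2.0 * np.pi * 2.0**degrees)  # (N, D) dyadic phases
    feats = np.concatenate([np.cos(z), np.sin(z)], axis=-1)
    if rescale:
        feats = feats / np.concatenate([degrees + 1, degrees + 1])
    return feats

F = fourier_features_np(np.array([0.0, 0.25]), 2)
# x = 0    -> cos = [1, 1], sin = [0, 0]
# x = 0.25 -> phases [pi/2, pi], so cos ~ [0, -1], sin ~ [1, 0]
```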

pyrox.nn.seasonal_features(x, periods, harmonics, *, rescale=False)

Cos/sin features at multiples of 2π / τ_p.

For each period τ_p with H_p harmonics, evaluates

.. math::

\phi_{p, h, \cos}(x) = \cos(2\pi h x / \tau_p), \qquad
\phi_{p, h, \sin}(x) = \sin(2\pi h x / \tau_p),

for h = 1, ..., H_p. Returns the cos columns concatenated with the sin columns, each of length F = Σ_p H_p.

`periods` and `harmonics` are Python sequences (tuples, lists, or 0-d JAX arrays wrapped at the call site). Keeping them as Python values lets the function run cleanly under `jax.jit` and `lax.scan` without triggering a concretization error.

Parameters:

- `x` (`Float[Array, " N"]`, required): Time/index input, shape `(N,)`.
- `periods` (`Sequence[float]`, required): Period values.
- `harmonics` (`Sequence[int]`, required): Harmonics per period.
- `rescale` (`bool`, default `False`): If `True`, divide each `(cos, sin)` pair by its within-period harmonic index, biasing the prior toward longer-wavelength modes within each period.

Returns:

- `Float[Array, "N two_F"]`: Array of shape `(N, 2 * F)`.

Source code in src/pyrox/nn/_features.py
def seasonal_features(
    x: Float[Array, " N"],
    periods: Sequence[float],
    harmonics: Sequence[int],
    *,
    rescale: bool = False,
) -> Float[Array, "N two_F"]:
    r"""Cos/sin features at multiples of :math:`2\pi / \tau_p`.

    For each period :math:`\tau_p` with :math:`H_p` harmonics, evaluates

    .. math::

        \phi_{p, h, \cos}(x) = \cos(2\pi h x / \tau_p), \qquad
        \phi_{p, h, \sin}(x) = \sin(2\pi h x / \tau_p),

    for :math:`h = 1, \dots, H_p`. Returns the cos columns concatenated
    with the sin columns, length :math:`F = \sum_p H_p` each.

    ``periods`` and ``harmonics`` are **Python sequences** (tuples,
    lists, or 0-d JAX arrays wrapped at the call site). Keeping them as
    Python values lets the function run cleanly under ``jax.jit`` and
    ``lax.scan`` without triggering a concretization error.

    Args:
        x: Time/index input, shape ``(N,)``.
        periods: Period values.
        harmonics: Harmonics per period.
        rescale: If ``True``, divide each ``(cos, sin)`` pair by its
            within-period harmonic index, biasing the prior toward
            longer-wavelength modes within each period.

    Returns:
        Array of shape ``(N, 2 * F)``.
    """
    _, freq_list = seasonal_frequencies(periods, harmonics)
    if not freq_list:
        return jnp.zeros((x.shape[0], 0), dtype=x.dtype)
    freqs = jnp.asarray(freq_list, dtype=jnp.float32)
    z = repeat(x, "n -> n f", f=freqs.shape[0]) * (2.0 * jnp.pi * freqs)
    feats = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
    if rescale:
        # Rescale by within-period harmonic index (1, 2, ..., H_p).
        h_within_list: list[float] = []
        for n_h in harmonics:
            h_within_list.extend(range(1, int(n_h) + 1))
        h_within = jnp.asarray(h_within_list, dtype=jnp.float32)
        denom = jnp.concatenate([h_within, h_within])
        feats = feats / denom
    return feats
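The basis is compact enough to inline in NumPy. The sketch below flattens frequencies exactly as `seasonal_frequencies` does (h / τ_p) and evaluates two time points; the daily-plus-weekly period choice is illustrative only:

```python
import numpy as np

periods, harmonics = (24.0, 7.0), (2, 1)  # e.g. daily + weekly on an hourly axis
freqs = np.array(
    [h / p for p, n in zip(periods, harmonics) for h in range(1, n + 1)]
)  # [1/24, 1/12, 1/7]

t = np.array([0.0, 6.0])
z = t[:, None] * (2.0 * np.pi * freqs)    # (N, F) phases, F = sum(harmonics)
feats = np.concatenate([np.cos(z), np.sin(z)], axis=-1)
```

At `t = 6` the 1/24 frequency sits at phase π/2 and the 1/12 frequency at phase π, so the corresponding cos/sin entries land on 0, -1, and 1 exactly.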

pyrox.nn.seasonal_frequencies(periods, harmonics)

Flatten (period, harmonic_count) pairs into Python lists.

For each period τ_p with H_p harmonics, emits frequencies f_{p, h} = h / τ_p for h = 1, ..., H_p. The total length is F = Σ_p H_p.

Inputs are Python sequences, not JAX arrays, so this helper runs at trace time and never triggers a concretization error under `jax.jit`. Most callers won't use it directly; it's exposed for symmetry with `seasonal_features`.

Parameters:

- `periods` (`Sequence[float]`, required): Period values.
- `harmonics` (`Sequence[int]`, required): Number of harmonics per period.

Returns:

- `tuple[list[int], list[float]]`: `(period_index, frequency)`, two Python lists of length F = Σ_p H_p.

Source code in src/pyrox/nn/_features.py
def seasonal_frequencies(
    periods: Sequence[float],
    harmonics: Sequence[int],
) -> tuple[list[int], list[float]]:
    r"""Flatten ``(period, harmonic_count)`` pairs into Python lists.

    For each period :math:`\tau_p` with :math:`H_p` harmonics, emits
    frequencies :math:`f_{p, h} = h / \tau_p` for :math:`h = 1, \dots,
    H_p`. The total length is :math:`F = \sum_p H_p`.

    Inputs are **Python sequences**, not JAX arrays, so this helper
    runs at trace time and never triggers a concretization error under
    ``jax.jit``. Most callers won't use it directly; it's exposed for
    symmetry with :func:`seasonal_features`.

    Args:
        periods: Period values.
        harmonics: Number of harmonics per period.

    Returns:
        ``(period_index, frequency)``: two Python lists of length
        :math:`F = \sum_p H_p`.
    """
    period_index: list[int] = []
    freqs: list[float] = []
    for p_idx, (period, n_h) in enumerate(zip(periods, harmonics, strict=True)):
        for h in range(1, int(n_h) + 1):
            period_index.append(p_idx)
            freqs.append(float(h) / float(period))
    return period_index, freqs

pyrox.nn.interaction_features(x, pairs)

Element-wise products on selected pairs of input columns.

For each pair (i, j) and each row n, computes x_{n, i} · x_{n, j}.

Parameters:

- `x` (`Float[Array, "N D"]`, required): Input matrix, shape `(N, D)`.
- `pairs` (`Int[Array, "K 2"]`, required): Index pairs, shape `(K, 2)`. Empty `pairs` yields an `(N, 0)` output.

Returns:

- `Float[Array, "N K"]`: Array of shape `(N, K)` of pairwise products.

Source code in src/pyrox/nn/_features.py
def interaction_features(
    x: Float[Array, "N D"],
    pairs: Int[Array, "K 2"],
) -> Float[Array, "N K"]:
    r"""Element-wise products on selected pairs of input columns.

    For each pair :math:`(i, j)` and each row :math:`n`, computes
    :math:`x_{n, i} \cdot x_{n, j}`.

    Args:
        x: Input matrix, shape ``(N, D)``.
        pairs: Index pairs, shape ``(K, 2)``. Empty pairs yield an
            ``(N, 0)`` output.

    Returns:
        Array of shape ``(N, K)`` of pairwise products.
    """
    if pairs.shape[0] == 0:
        return jnp.zeros((x.shape[0], 0), dtype=x.dtype)
    # x[:, pairs] has shape (N, K, 2); reduce the last axis with prod.
    selected = x[:, pairs]
    return jnp.prod(selected, axis=-1)
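The gather-and-product step is plain fancy indexing, identical in NumPy. A minimal hand-checkable sketch:

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
pairs = np.array([[0, 1], [1, 2]])

# x[:, pairs] gathers shape (N, K, 2); the product over the last axis
# yields one interaction column per pair.
out = np.prod(x[:, pairs], axis=-1)
# → [[ 2.,  6.],
#    [20., 30.]]
```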

pyrox.nn.standardize(x, mu, std)

Affine standardize: `(x - mu) / std`.

Broadcasts `mu` and `std` against `x` per the JAX broadcasting rules. `std` is not clamped; pass a positive value or guard upstream.

Source code in src/pyrox/nn/_features.py
def standardize(
    x: Float[Array, "*shape"],
    mu: Float[Array, "*shape"],
    std: Float[Array, "*shape"],
) -> Float[Array, "*shape"]:
    """Affine standardize: ``(x - mu) / std``.

    Broadcasts ``mu`` and ``std`` against ``x`` per the JAX broadcasting
    rules. ``std`` is *not* clamped; pass a positive value or guard
    upstream.
    """
    return (x - mu) / std

pyrox.nn.unstandardize(z, mu, std)

Inverse of `standardize`: `z * std + mu`.

Source code in src/pyrox/nn/_features.py
def unstandardize(
    z: Float[Array, "*shape"],
    mu: Float[Array, "*shape"],
    std: Float[Array, "*shape"],
) -> Float[Array, "*shape"]:
    """Inverse of :func:`standardize`: ``z * std + mu``."""
    return z * std + mu
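Because the two helpers are exact affine inverses whenever `std` is positive, standardizing with sample statistics and mapping back recovers the input. A short NumPy roundtrip sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(100, 2))
mu, std = x.mean(axis=0), x.std(axis=0)

z = (x - mu) / std        # standardize: zero mean, unit scale per column
x_back = z * std + mu     # unstandardize: exact inverse
```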