NN API¶
The pyrox.nn subpackage ships uncertainty-aware neural network layers in three families:

- Dense / Bayesian-linear layers (pyrox.nn._layers) — twelve layers covering reparameterization, Flipout, NCP, MC-Dropout, and several random-feature variants.
- Bayesian Neural Field stack (pyrox.nn._bnf) — five layers that together implement the BNF architecture (Saad et al., Nat. Commun. 2024).
- Pure-JAX feature helpers (pyrox.nn._features) — pandas-free building blocks the BNF layers wrap.
Dense / Bayesian-linear layers¶
pyrox.nn.DenseReparameterization
¶
Bases: PyroxModule
Bayesian dense layer via the reparameterization trick.
Samples weight and bias from learned Gaussian posteriors at every forward pass. Registers NumPyro sample sites so the KL between the variational posterior and the prior is tracked by the ELBO.
.. math::
W \sim \mathcal{N}(\mu_W, \sigma_W^2), \quad
b \sim \mathcal{N}(\mu_b, \sigma_b^2), \quad
y = x W + b.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `bias` | `bool` | Whether to include a bias term. |
| `prior_scale` | `float` | Scale of the isotropic Gaussian prior on weights and bias. The prior mean is zero. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
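The sampling scheme above can be sketched in plain NumPy. This is an illustrative stand-in for the layer's NumPyro sample sites, not pyrox internals; the `mu_*`/`sigma_*` names are hypothetical variational parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 3, 2

# Hypothetical variational parameters (in pyrox these live in the guide).
mu_W = np.zeros((in_features, out_features))
sigma_W = 0.1 * np.ones((in_features, out_features))
mu_b = np.zeros(out_features)
sigma_b = 0.1 * np.ones(out_features)

def forward(x, rng):
    # Reparameterization trick: W = mu + sigma * eps with eps ~ N(0, I),
    # so gradients flow through mu and sigma.
    W = mu_W + sigma_W * rng.standard_normal(mu_W.shape)
    b = mu_b + sigma_b * rng.standard_normal(mu_b.shape)
    return x @ W + b

x = np.ones((4, in_features))
y1, y2 = forward(x, rng), forward(x, rng)
print(y1.shape)             # (4, 2)
print(np.allclose(y1, y2))  # False: each forward pass resamples the weights
```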
pyrox.nn.DenseFlipout
¶
Bases: PyroxModule
Bayesian dense layer with Flipout sign-flip structure.
Samples weight from the prior and applies per-example Rademacher sign flips to the weight perturbation (Wen et al., 2018). Under a NumPyro guide that learns the posterior mean, the sign flips decorrelate gradient estimates across minibatch examples.
In model mode (no guide) this is equivalent to
:class:DenseReparameterization — the Flipout variance reduction
activates when a guide provides a posterior centered at a learned
mean.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `bias` | `bool` | Whether to include a bias term. |
| `prior_scale` | `float` | Scale of the isotropic Gaussian prior. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
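The sign-flip structure can be illustrated in NumPy, under the assumption of a shared weight perturbation `dW` around a learned mean (the setting where Flipout's variance reduction applies); none of these names are pyrox internals:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_in, d_out = 8, 3, 2

W_mean = rng.standard_normal((d_in, d_out))    # posterior mean (learned by a guide)
dW = 0.05 * rng.standard_normal((d_in, d_out)) # shared weight perturbation

x = rng.standard_normal((n, d_in))

# Per-example Rademacher sign flips decorrelate the perturbation across
# minibatch examples (Wen et al., 2018) without resampling dW per example.
r = rng.choice([-1.0, 1.0], size=(n, d_in))
s = rng.choice([-1.0, 1.0], size=(n, d_out))

y = x @ W_mean + ((x * r) @ dW) * s
print(y.shape)  # (8, 2)
```

With all signs set to +1 this reduces to an ordinary dense layer with weights `W_mean + dW`.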
pyrox.nn.DenseVariational
¶
Bases: PyroxModule
Dense layer with a user-supplied prior factory.
Provides flexibility over the weight prior by accepting a callable
that builds the prior distribution given the layer shape. The
model samples from the prior; the posterior is handled by a NumPyro
guide (e.g., AutoNormal).
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `make_prior` | `Callable[..., Any]` | Callable that builds the prior distribution given the layer shape. |
| `bias` | `bool` | Whether to include a bias term. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.MCDropout
¶
Bases: Module
Always-on dropout for Monte Carlo uncertainty estimation.
Unlike standard dropout, :class:MCDropout stays active at
inference time — repeated forward passes with different keys
produce a distribution of outputs whose spread approximates
predictive uncertainty (Gal & Ghahramani, 2016).
Not a :class:PyroxModule — no NumPyro sites are registered.
The stochasticity comes from the explicit PRNG key argument.
Attributes:

| Name | Type | Description |
|---|---|---|
| `rate` | `float` | Dropout probability in [0, 1). |
Source code in src/pyrox/nn/_layers.py
__call__(x, *, key)
¶
Apply dropout, scaling survivors by 1 / (1 - rate).
Source code in src/pyrox/nn/_layers.py
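The behavior is easy to sketch in NumPy (an illustration of the technique, not the pyrox implementation, which takes a JAX PRNG key rather than a NumPy generator):

```python
import numpy as np

rng = np.random.default_rng(2)
rate = 0.5

def mc_dropout(x, rate, rng):
    # Always-on dropout: survivors are scaled by 1 / (1 - rate) so the
    # expected output equals the input.
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)

x = np.ones((5, 4))
# Repeated stochastic passes give a distribution of outputs whose spread
# approximates predictive uncertainty (Gal & Ghahramani, 2016).
samples = np.stack([mc_dropout(x, rate, rng) for _ in range(2000)])
print(samples.mean())  # ≈ 1.0: the inverted-dropout scaling preserves the mean
```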
pyrox.nn.DenseNCP
¶
Bases: PyroxModule
Noise Contrastive Prior dense layer (Hafner et al., 2019).
Decomposes a dense layer into a prior-regularized backbone plus a scaled stochastic perturbation:
.. math::
y = \underbrace{x W_d + b_d}_{\text{backbone}}
+ \underbrace{\sigma \cdot (x W_s + b_s)}_{\text{perturbation}},
where all weights are pyrox_sample sites with Gaussian priors
and :math:\sigma has a LogNormal prior. The backbone carries
the bulk of the signal; the perturbation branch adds calibrated
uncertainty that can be trained via a noise contrastive objective.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `out_features` | `int` | Output dimension. |
| `init_scale` | `float` | Initial value for the perturbation scale :math:\sigma. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
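The backbone-plus-perturbation decomposition above can be sketched with plain prior draws in NumPy (in the real layer all weights are pyrox_sample sites and :math:\sigma is LogNormal-distributed; the fixed `sigma` here is a simplification):

```python
import numpy as np

rng = np.random.default_rng(3)
d_in, d_out = 3, 2

# Prior draws for both branches (illustrative; pyrox registers these as sites).
W_d, b_d = rng.standard_normal((d_in, d_out)), rng.standard_normal(d_out)
W_s, b_s = rng.standard_normal((d_in, d_out)), rng.standard_normal(d_out)
sigma = 0.1  # perturbation scale, fixed here for clarity

x = rng.standard_normal((4, d_in))
backbone = x @ W_d + b_d              # carries the bulk of the signal
perturbation = sigma * (x @ W_s + b_s)  # adds calibrated uncertainty
y = backbone + perturbation
print(y.shape)  # (4, 2)
```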
pyrox.nn.NCPContinuousPerturb
¶
Bases: Module
Input perturbation for the Noise Contrastive Prior pattern.
Adds Gaussian noise scaled by a learned positive scale to the input:
.. math::
\tilde{x} = x + \sigma \epsilon, \qquad
\epsilon \sim \mathcal{N}(0, I).
Place before a deterministic network to inject input uncertainty;
pair with :class:DenseNCP at the output for the full NCP
architecture (Hafner et al., 2019).
Not a :class:PyroxModule — stochasticity comes from the
explicit PRNG key.
Attributes:
| Name | Type | Description |
|---|---|---|
scale |
float | Float[Array, '']
|
Perturbation scale :math: |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.RBFFourierFeatures
¶
Bases: PyroxModule
SSGP-style RFF layer with RBF spectral density.
Both the spectral frequencies :math:W and the lengthscale
:math:\ell are pyrox_sample sites — :math:W has a
standard normal prior (the RBF spectral density) and :math:\ell
has a LogNormal prior. Under SVI, the guide learns a posterior
over both; under a seed handler, they are drawn from the prior.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of frequency pairs (output dimension is 2 * n_features). |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
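The [cos, sin] feature map underlying this layer can be checked in NumPy: with enough frequency pairs, the inner product of the features approximates the RBF kernel. A sketch of the math only, with a fixed lengthscale rather than a sampled one:

```python
import numpy as np

rng = np.random.default_rng(4)
d, D, ell = 2, 4000, 1.0  # input dim, frequency pairs, lengthscale

# Spectral frequencies: standard normal is the RBF spectral density.
W = rng.standard_normal((d, D))

def phi(x):
    proj = x @ W / ell
    # [cos, sin] pairs with 1/sqrt(D) normalization, so phi(x)·phi(y) ≈ k(x, y).
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1) / np.sqrt(D)

x = np.array([[0.0, 0.0]])
y = np.array([[1.0, 0.0]])
approx = float(phi(x) @ phi(y).T)
exact = float(np.exp(-0.5 / ell**2))  # RBF kernel at distance 1
print(approx, exact)  # close for large D
```

The Monte Carlo error shrinks as 1/sqrt(D), which is what the orthogonal variant below improves on.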
pyrox.nn.RBFCosineFeatures
¶
Bases: PyroxModule
Cosine-bias variant of random Fourier features for the RBF kernel.
Uses the single-cosine feature map with a bias term:
.. math::
\phi(x) = \sqrt{2 / D}\,\cos(x W / \ell + b)
where :math:W \sim \mathcal{N}(0, I) and
:math:b \sim \mathrm{Uniform}(0, 2\pi). This variant produces
n_features-dimensional output (half the dimension of the
[cos, sin] variant in :class:RBFFourierFeatures) and is
commonly used in Random Kitchen Sinks implementations.
All parameters (:math:W, :math:b, :math:\ell) are
pyrox_sample sites.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of random features (= output dimension). |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.MaternFourierFeatures
¶
Bases: PyroxModule
SSGP-style RFF layer with Matern spectral density.
Spectral frequencies :math:W have a StudentT(df=2\nu) prior
(the Matern spectral density). The smoothness :math:\nu controls
the regularity: nu=0.5 (Laplace), nu=1.5 (Matern-3/2),
nu=2.5 (Matern-5/2).
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of frequency pairs. |
| `nu` | `float` | Smoothness parameter :math:\nu. |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.LaplaceFourierFeatures
¶
Bases: PyroxModule
SSGP-style RFF layer with Laplace (Matern-1/2) spectral density.
Spectral frequencies :math:W have a Cauchy prior (Student-t
with df = 1).
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of frequency pairs. |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.ArcCosineFourierFeatures
¶
Bases: PyroxModule
Random features for the arc-cosine kernel (Cho & Saul, 2009).
The arc-cosine kernel of order :math:p corresponds to an
infinite-width single-layer ReLU network. The random feature map
is:
.. math::
\phi(x) = \sqrt{2 / D}\,\max(0,\, x W / \ell)^p
where :math:W \sim \mathcal{N}(0, I).
order=0 gives the Heaviside (step) feature; order=1 gives
the ReLU feature (the most common); order=2 gives the squared
ReLU feature.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension. |
| `n_features` | `int` | Number of random features (= output dimension). |
| `order` | `int` | Kernel order (0, 1, or 2). |
| `init_lengthscale` | `float` | Prior location for the lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.RandomKitchenSinks
¶
Bases: PyroxModule
Random Kitchen Sinks: RFF + a learned linear head.
Composes any RFF layer (:class:RBFFourierFeatures,
:class:MaternFourierFeatures, :class:LaplaceFourierFeatures)
with a trainable linear projection:
.. math::
y = \phi(x)\, \beta + b
The linear head (beta, bias) is registered via
pyrox_sample with Normal priors.
Attributes:

| Name | Type | Description |
|---|---|---|
| `rff` | `RBFFourierFeatures \| MaternFourierFeatures \| LaplaceFourierFeatures` | The underlying RFF feature layer. |
| `init_beta` | `Float[Array, 'D_rff D_out']` | Initial linear weights. |
| `init_bias` | `Float[Array, ' D_out']` | Initial bias vector. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
init(rff, out_features)
classmethod
¶
Construct from a pre-built RFF layer with zero-initialized head.
Source code in src/pyrox/nn/_layers.py
Wave-4 spectral layers (#41)¶
pyrox.nn.VariationalFourierFeatures
¶
Bases: PyroxModule
VSSGP — RFF with a learnable variational posterior over frequencies.
Standard RFF (e.g. :class:RBFFourierFeatures) treats the spectral
frequencies :math:W as a frozen prior draw; VSSGP (Gal & Turner,
2015) treats :math:W as a latent with a learnable mean-field
posterior, recovering spectral uncertainty on top of the
feature-space uncertainty.
Prior: :math:p(W) = \mathcal{N}(0, I) (RBF spectral density in
lengthscale-1 units). The lengthscale is itself a sampled site
(LogNormal(log init_lengthscale, 1)) so that frequencies are
rescaled to the physical kernel.
Under SVI, attach an :class:~numpyro.infer.autoguide.AutoNormal to
learn the posterior on W; under prior-only seeds, behaves
identically to :class:RBFFourierFeatures.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension :math:D. |
| `n_features` | `int` | Number of frequency pairs (output dimension is 2 * n_features). |
| `init_lengthscale` | `float` | Prior location for the kernel lengthscale. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
pyrox.nn.OrthogonalRandomFeatures
¶
Bases: Module
Orthogonal Random Features (Yu et al., 2016) — variance-reduced RFF.
Frequencies are drawn from blocks of Haar-orthogonal matrices scaled by
independent chi-distributed magnitudes, giving the same RBF kernel
approximation as plain :class:RBFFourierFeatures in expectation but
with provably lower variance for finite n_features.
Frozen at construction time — no priors, no SVI on W. The frequency
matrix is built once from a key and stored as a static array.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension :math:D. |
| `n_features` | `int` | Number of feature pairs (subject to a block-structure constraint with in_features; see source). |
| `lengthscale` | `Float[Array, '']` | Fixed kernel lengthscale (no prior; pass a value). |
| `W` | `Float[Array, 'D_in D_orf']` | Pre-built frequency matrix of shape (D_in, D_orf). |
Note
For learnable lengthscale or full Bayesian treatment of the
frequencies, prefer :class:VariationalFourierFeatures.
Source code in src/pyrox/nn/_layers.py
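One block of the Yu et al. (2016) construction can be sketched in NumPy: take the Q factor of a Gaussian matrix (Haar-orthogonal) and rescale its rows by chi-distributed magnitudes so each row keeps the marginal norm of an i.i.d. Gaussian frequency. This illustrates the construction only; pyrox's block stacking and key handling may differ:

```python
import numpy as np

rng = np.random.default_rng(5)
d = 4  # input dimension; one block yields d frequencies

# Haar-orthogonal block via QR of a Gaussian matrix.
G = rng.standard_normal((d, d))
Q, _ = np.linalg.qr(G)

# Chi(d)-distributed row magnitudes restore the N(0, I) row norms.
S = np.linalg.norm(rng.standard_normal((d, d)), axis=1)
W = Q * S[:, None]

# The frequency directions stay exactly orthogonal, which is the source
# of the variance reduction over i.i.d. Gaussian frequencies.
R = W / np.linalg.norm(W, axis=1, keepdims=True)
print(np.round(R @ R.T, 6))  # ≈ identity
```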
pyrox.nn.HSGPFeatures
¶
Bases: PyroxModule
Hilbert-Space Gaussian Process feature layer (Riutort-Mayol et al., 2023).
A deterministic Laplacian-eigenfunction basis on the bounded box
:math:[-L, L]^D plus learnable per-basis amplitudes with a
kernel-spectral-density prior:
.. math::
\hat{f}(x) = \sum_{j=1}^{M} \alpha_j\,\sqrt{S(\sqrt{\lambda_j})}\,\phi_j(x),
\quad \alpha_j \sim \mathcal{N}(0, 1).
This is the NN-side dual of :class:pyrox.gp.FourierInducingFeatures
— same basis, different prior wiring. As M and L grow, the
induced GP converges to the kernel passed in.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | `int` | Input dimension :math:D. |
| `num_basis_per_dim` | `tuple[int, ...]` | Per-axis number of 1D eigenfunctions; total basis count is their product. |
| `L` | `tuple[float, ...]` | Per-axis box half-width. |
| `kernel` | `Kernel` | A stationary kernel from :mod:pyrox.gp. |
| `pyrox_name` | `str \| None` | Explicit scope name for NumPyro site registration. |
Source code in src/pyrox/nn/_layers.py
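The construction can be sketched in 1D NumPy following Riutort-Mayol et al. (2023): Dirichlet eigenfunctions of the Laplacian on [-L, L], weighted by the square root of the RBF spectral density. The kernel/spectral-density wiring in pyrox may differ; this is the math only:

```python
import numpy as np

rng = np.random.default_rng(6)
M, L, ell = 32, 3.0, 0.5  # basis size, box half-width, RBF lengthscale

j = np.arange(1, M + 1)
sqrt_lam = j * np.pi / (2.0 * L)  # sqrt eigenvalues of the Laplacian on [-L, L]
# 1D RBF spectral density S(w) = sqrt(2*pi) * ell * exp(-ell^2 w^2 / 2).
S = np.sqrt(2.0 * np.pi) * ell * np.exp(-0.5 * (ell * sqrt_lam) ** 2)

def phi(x):
    # Dirichlet eigenfunctions phi_j(x) = sin(sqrt(lam_j) (x + L)) / sqrt(L).
    return np.sin(sqrt_lam[None, :] * (x[:, None] + L)) / np.sqrt(L)

alpha = rng.standard_normal(M)       # N(0, 1) per-basis amplitudes
x = np.linspace(-2.0, 2.0, 5)
f = phi(x) @ (np.sqrt(S) * alpha)    # one prior draw of the approximate GP
print(f.shape)  # (5,)
```

For these M, L, and lengthscale the implied prior variance phi(x) diag(S) phi(x)ᵀ at x = 0 is already close to the exact unit RBF variance, which is the sense in which the approximation converges as M and L grow.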
Bayesian Neural Field stack¶
pyrox.nn.Standardization
¶
Bases: PyroxModule
Apply a fixed-coefficient affine standardization.
.. math::
\tilde x \;=\; \frac{x - \mu}{\sigma}.
Both mu and std are static (fit-time) constants, not
learned. Use :func:pyrox.preprocessing.fit_standardization to
construct from a pandas DataFrame.
Attributes:

| Name | Type | Description |
|---|---|---|
| `mu` | `Float[Array, ' D']` | Per-feature mean, shape (D,). |
| `std` | `Float[Array, ' D']` | Per-feature standard deviation, shape (D,). |
| `pyrox_name` | `str \| None` | Optional override for the per-instance scope name. |
Source code in src/pyrox/nn/_bnf.py
pyrox.nn.FourierFeatures
¶
Bases: PyroxModule
Per-input dyadic-frequency Fourier basis.
For each input column, evaluates 2 * degree Fourier features at
frequencies :math:2\pi \cdot 2^d for :math:d \in \{0, \dots,
\text{degree} - 1\}. Concatenated across all columns.
Wraps :func:pyrox.nn._features.fourier_features per input
dimension.
Attributes:

| Name | Type | Description |
|---|---|---|
| `degrees` | `tuple[int, ...]` | Number of dyadic frequencies per input column, as a Python tuple. |
| `rescale` | `bool` | If True, rescale the feature columns (see source for the scheme). |
| `pyrox_name` | `str \| None` | Optional scope-name override. |
Source code in src/pyrox/nn/_bnf.py
pyrox.nn.SeasonalFeatures
¶
Bases: PyroxModule
Period-and-harmonic cos/sin basis on a scalar time axis.
For each period :math:\tau_p with :math:H_p harmonics, emits
2 * H_p cos/sin columns. Total output width is :math:2 \sum_p
H_p.
Wraps :func:pyrox.nn._features.seasonal_features. Periods and
harmonics are kept as Python tuples (static) so the inner shape
structure is known at trace time.
Attributes:

| Name | Type | Description |
|---|---|---|
| `periods` | `tuple[float, ...]` | Period values, one per seasonal component. |
| `harmonics` | `tuple[int, ...]` | Harmonics per period, same length as periods. |
| `rescale` | `bool` | If True, rescale the feature columns (see source for the scheme). |
| `pyrox_name` | `str \| None` | Optional scope-name override. |
Source code in src/pyrox/nn/_bnf.py
pyrox.nn.InteractionFeatures
¶
Bases: PyroxModule
Element-wise products on selected pairs of input columns.
Wraps :func:pyrox.nn._features.interaction_features.
Attributes:

| Name | Type | Description |
|---|---|---|
| `pairs` | `tuple[tuple[int, int], ...]` | Index pairs, one (i, j) per product column. |
| `pyrox_name` | `str \| None` | Optional scope-name override. |
Source code in src/pyrox/nn/_bnf.py
pyrox.nn.BayesianNeuralField
¶
Bases: PyroxModule
The full Bayesian Neural Field architecture.
A spatiotemporal MLP with:

- A learned per-input log-scale adjustment (Logistic(0, 1) prior).
- Four feature blocks concatenated into h_0: rescaled inputs, Fourier features, seasonal features, interaction products.
- Per-block softplus(feature_gain) modulation.
- A depth-L MLP whose layers are :math:h_{\ell+1} = \sigma_\alpha\bigl(g_\ell \cdot W_\ell\, h_\ell / \sqrt{\lvert h_\ell \rvert}\bigr), where :math:\sigma_\alpha = \mathrm{sig}(\beta) \cdot \mathrm{elu} + (1 - \mathrm{sig}(\beta)) \cdot \mathrm{tanh} is a learned mixed activation.
- A final linear layer scaled by softplus(output_gain).
All weights, biases, gains, scales, and the activation logit carry
independent :math:\mathrm{Logistic}(0, 1) priors registered via
:meth:PyroxModule.pyrox_sample.
The :math:1/\sqrt{\text{fan-in}} pre-normalization is the
standard NTK-scaling trick — it makes the layer-wise prior
predictive a fan-in-independent Gaussian process in the
infinite-width limit (Lee et al., 2018).
Attributes:

| Name | Type | Description |
|---|---|---|
| `input_scales` | `tuple[float, ...]` | Per-input fixed scale (typically training-data inter-quartile range). Static tuple. |
| `fourier_degrees` | `tuple[int, ...]` | Per-input number of dyadic Fourier frequencies. Static tuple. |
| `interactions` | `tuple[tuple[int, int], ...]` | Pair-index list for interaction features. Static tuple. |
| `seasonality_periods` | `tuple[float, ...]` | Periods for seasonal features. Static tuple. |
| `num_seasonal_harmonics` | `tuple[int, ...]` | Harmonics per period. Static tuple. |
| `width` | `int` | Hidden layer width. |
| `depth` | `int` | Number of hidden MLP layers. |
| `time_col` | `int` | Index of the time column inside the input matrix. |
| `pyrox_name` | `str \| None` | Optional scope-name override. |
Source code in src/pyrox/nn/_bnf.py
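A single hidden-layer update with the mixed activation can be sketched in NumPy. The `g`, `beta`, and `W` values here are placeholders for what are Logistic(0, 1) sample sites in the real layer:

```python
import numpy as np

rng = np.random.default_rng(7)
width = 16

def elu(z):
    return np.where(z > 0, z, np.expm1(z))

def mixed_activation(z, beta):
    # sig(beta) interpolates between elu and tanh; beta is learned.
    a = 1.0 / (1.0 + np.exp(-beta))
    return a * elu(z) + (1.0 - a) * np.tanh(z)

def bnf_layer(h, W, g, beta):
    # NTK-style 1/sqrt(fan-in) pre-normalization, then gain and activation:
    # h_{l+1} = sigma_alpha(g * W h / sqrt(|h|)).
    return mixed_activation(g * (h @ W) / np.sqrt(h.shape[-1]), beta)

h = rng.standard_normal((4, width))
W = rng.standard_normal((width, width))  # Logistic(0, 1) site in the real layer
h_next = bnf_layer(h, W, g=1.0, beta=0.0)
print(h_next.shape)  # (4, 16)
```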
Pure-JAX feature helpers¶
pyrox.nn.fourier_features(x, max_degree, *, rescale=False)
¶
Cos/sin Fourier basis at dyadic frequencies.
For each input element and each degree :math:d \in \{0, \dots,
D-1\}, evaluates
.. math::
\phi_{d, \cos}(x) = \cos(2\pi \cdot 2^d \cdot x), \qquad
\phi_{d, \sin}(x) = \sin(2\pi \cdot 2^d \cdot x).
Returns the columns concatenated as [cos_0, ..., cos_{D-1},
sin_0, ..., sin_{D-1}], matching Google's bayesnf layout.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, ' N']` | Length-N input vector. | required |
| `max_degree` | `int` | Number of dyadic frequencies D. | required |
| `rescale` | `bool` | If True, rescale the feature columns (see source for the scheme). | `False` |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N two_max_degree']` | Array of shape (N, 2 * max_degree). |
Source code in src/pyrox/nn/_features.py
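A minimal NumPy version of the dyadic basis, matching the cos-block-then-sin-block layout described above (without the optional rescaling):

```python
import numpy as np

def fourier_features(x, max_degree):
    # Frequencies 2*pi * 2^d for d = 0..max_degree-1; columns are
    # [cos_0, ..., cos_{D-1}, sin_0, ..., sin_{D-1}].
    freqs = 2.0 * np.pi * 2.0 ** np.arange(max_degree)
    angles = x[:, None] * freqs[None, :]
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=1)

x = np.array([0.0, 0.25, 0.5])
F = fourier_features(x, max_degree=3)
print(F.shape)  # (3, 6)
print(F[0])     # x = 0: all cosines are 1, all sines are 0
```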
pyrox.nn.seasonal_features(x, periods, harmonics, *, rescale=False)
¶
Cos/sin features at multiples of :math:2\pi / \tau_p.
For each period :math:\tau_p with :math:H_p harmonics, evaluates
.. math::
\phi_{p, h, \cos}(x) = \cos(2\pi h x / \tau_p), \qquad
\phi_{p, h, \sin}(x) = \sin(2\pi h x / \tau_p),
for :math:h = 1, \dots, H_p. Returns the cos columns concatenated
with the sin columns, length :math:F = \sum_p H_p each.
periods and harmonics are Python sequences (tuples,
lists, or 0-d JAX arrays wrapped at the call site). Keeping them as
Python values lets the function run cleanly under jax.jit and
lax.scan without triggering a concretization error.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, ' N']` | Time/index input, shape (N,). | required |
| `periods` | `Sequence[float]` | Period values. | required |
| `harmonics` | `Sequence[int]` | Harmonics per period. | required |
| `rescale` | `bool` | If True, rescale the feature columns (see source for the scheme). | `False` |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N two_F']` | Array of shape (N, 2F) with :math:F = \sum_p H_p. |
Source code in src/pyrox/nn/_features.py
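A plain-NumPy sketch of the basis (cos columns concatenated with sin columns, no rescaling), using a weekly and a yearly period as example values:

```python
import numpy as np

def seasonal_features(x, periods, harmonics):
    # For each period tau with H harmonics, angles 2*pi*h*x/tau for h=1..H;
    # all cos columns first, then all sin columns.
    cols_cos, cols_sin = [], []
    for tau, H in zip(periods, harmonics):
        for h in range(1, H + 1):
            ang = 2.0 * np.pi * h * x / tau
            cols_cos.append(np.cos(ang))
            cols_sin.append(np.sin(ang))
    return np.stack(cols_cos + cols_sin, axis=1)

t = np.arange(4.0)
S = seasonal_features(t, periods=(7.0, 365.25), harmonics=(2, 1))
print(S.shape)  # (4, 6): F = 2 + 1 harmonics, 2F columns
```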
pyrox.nn.seasonal_frequencies(periods, harmonics)
¶
Flatten (period, harmonic_count) pairs into Python lists.
For each period :math:\tau_p with :math:H_p harmonics, emits
frequencies :math:f_{p, h} = h / \tau_p for :math:h = 1, \dots,
H_p. The total length is :math:F = \sum_p H_p.
Inputs are Python sequences, not JAX arrays, so this helper
runs at trace time and never triggers a concretization error under
jax.jit. Most callers won't use it directly; it's exposed for
symmetry with :func:seasonal_features.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `periods` | `Sequence[float]` | Period values. | required |
| `harmonics` | `Sequence[int]` | Number of harmonics per period. | required |

Returns:

| Type | Description |
|---|---|
| `list[int]` | Flattened per-frequency harmonic numbers. |
| `list[float]` | Frequencies :math:f_{p, h} = h / \tau_p, length :math:F. |
Source code in src/pyrox/nn/_features.py
pyrox.nn.interaction_features(x, pairs)
¶
Element-wise products on selected pairs of input columns.
For each pair :math:(i, j) and each row :math:n, computes
:math:x_{n, i} \cdot x_{n, j}.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, 'N D']` | Input matrix, shape (N, D). | required |
| `pairs` | `Int[Array, 'K 2']` | Index pairs, shape (K, 2). | required |

Returns:

| Type | Description |
|---|---|
| `Float[Array, 'N K']` | Array of shape (N, K). |
Source code in src/pyrox/nn/_features.py
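The per-pair products are a one-liner with fancy indexing; a NumPy sketch of the same computation:

```python
import numpy as np

def interaction_features(x, pairs):
    # One product column x[:, i] * x[:, j] per (i, j) pair.
    pairs = np.asarray(pairs)
    return x[:, pairs[:, 0]] * x[:, pairs[:, 1]]

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
out = interaction_features(x, [(0, 1), (1, 2)])
print(out)  # [[ 2.  6.] [20. 30.]]
```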
pyrox.nn.standardize(x, mu, std)
¶
Affine standardize: (x - mu) / std.
Broadcasts mu and std against x per the JAX broadcasting
rules. std is not clamped; pass a positive value or guard
upstream.
Source code in src/pyrox/nn/_features.py
pyrox.nn.unstandardize(z, mu, std)
¶
Inverse of :func:standardize: z * std + mu.
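The pair of helpers is an exact affine round trip, easily checked in NumPy (mirroring the documented formulas, with no clamping of std):

```python
import numpy as np

def standardize(x, mu, std):
    # (x - mu) / std; std must be positive, no guard is applied.
    return (x - mu) / std

def unstandardize(z, mu, std):
    # Exact inverse of standardize.
    return z * std + mu

x = np.array([[1.0, 10.0], [3.0, 30.0]])
mu, std = x.mean(axis=0), x.std(axis=0)
z = standardize(x, mu, std)
print(np.allclose(unstandardize(z, mu, std), x))  # True: exact round trip
```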