Conditioning
A conditioner is a layer c(h, z) → y that transforms an inner activation h based on a context vector z. `pyrox.nn` ships three concrete conditioners that cover the main patterns in the literature (concatenation, feature-wise modulation and hypernetworks) behind one consistent API, plus Bayesian variants and a composite that wraps an inner network with per-layer conditioning.
Decision rubric
In the table, C is the conditioned width (`num_features`, or `target_out` for `HyperLinear`), C_in is the target layer's input width (`target_in`) and K is `cond_dim`.

| Pattern | Use when | Parameters per layer | Where it shows up |
|---|---|---|---|
| `ConcatConditioner` | You want a cheap baseline and `cond_dim` is small. | (C + K) · C + C | DiffeqMLP-style CNF vector fields |
| `AffineModulation` (FiLM) | Feature-wise modulation is enough and you want a generator with no C² term, which stays cheap for wide layers. | 2C · K + 2C | π-GAN, Modulated SIREN, conditional INRs |
| `HyperLinear` | The full target weight matrix must depend on z, as in NIF or MetaSDF. | (K + 1) · (C · C_in + C) | NIF (Pan et al. 2023), MetaSDF, Ha et al. hypernets |
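The parameter counts above follow from treating each generator (or, for `ConcatConditioner`, the projection) as a single biased `Linear`, as the class docstrings below describe. A small arithmetic sketch; the helper functions are hypothetical and not part of `pyrox.nn`:

```python
def concat_params(C: int, K: int) -> int:
    # Linear(C + K -> C): weight (C, C + K) plus bias (C,)
    return (C + K) * C + C


def film_params(C: int, K: int) -> int:
    # Linear(K -> 2C) emitting (raw_beta, raw_gamma): weight (2C, K) plus bias (2C,)
    return 2 * C * K + 2 * C


def hyper_params(C: int, C_in: int, K: int) -> int:
    # Linear(K -> C * C_in + C) emitting a flat (W, b): weight plus bias
    n_out = C * C_in + C
    return K * n_out + n_out


# Example: 32-wide layers conditioned on a 4-dimensional context
C, C_in, K = 32, 32, 4
for name, n in [
    ("concat", concat_params(C, K)),
    ("film", film_params(C, K)),
    ("hyper", hyper_params(C, C_in, K)),
]:
    print(f"{name:>6}: {n} parameters per layer")
```

For a 32-wide layer with `cond_dim=4` this prints 1184, 320 and 5280: FiLM has no C² term, while `HyperLinear` pays for a full C × C_in weight matrix per context dimension.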
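The Bayesian variants place `Normal(0, prior_std)` priors on the generator weights and biases only (for `BayesianConcatConditioner`, on the projection). Because the target weights are generated rather than sampled, posterior cost scales with the generator size, not with the target network; this is the main motivation for doing Bayesian amortised inference through a generator.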
End-to-end use
```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

from pyrox.nn import SIREN, AffineModulation, ConditionedINR, HyperSIREN

key = jr.key(0)
k_siren, k_cond, k_nif = jr.split(key, 3)  # independent keys for each init

# 1) FiLM-modulate every hidden layer of a SIREN
inner = SIREN.init(2, 32, 1, depth=4, key=k_siren)
wrapped = ConditionedINR.init(
    inner, conditioner_cls=AffineModulation, cond_dim=4, key=k_cond
)
y = wrapped(jnp.ones((10, 2)), jnp.ones((10, 4)))  # (10, 1)


# 2) Full NIF stack: ParameterNet -> per-layer HyperLinear -> ShapeNet (SIREN)
class IdentityNet(eqx.Module):
    """Trivial ParameterNet: passes mu (shape (cond_dim,)) through as z."""

    def __call__(self, mu):
        return mu


nif = HyperSIREN(
    in_features=2, hidden_features=32, out_features=1,
    depth=5, cond_dim=3, parameter_net=IdentityNet(), key=k_nif,
)
y = nif(jnp.ones((10, 2)), jnp.ones((3,)))  # (10, 1)
```
For a hands-on walkthrough see the Conditional Neural Fields notebook.
Protocol
pyrox.nn.AbstractConditioner
Bases: PyroxModule
Duck-typed protocol for (h, z) -> y conditioning layers.

Concrete subclasses share the contract `__call__(h, z) -> Array`, where `h.shape[-1] == num_features` and `z.shape[-1] == cond_dim`. There is no `abstractmethod` enforcement: subclasses simply implement `__call__` decorated with `pyrox_method`.
Attributes:
| Name | Type | Description |
|---|---|---|
| `num_features` | `int` | Output channel count, matching `h.shape[-1]`. |
| `cond_dim` | `int` | Latent / context dimension, matching `z.shape[-1]`. |
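Because the protocol is duck-typed, anything with `num_features`, `cond_dim` and a `__call__(h, z)` that preserves the feature axis fits the shape contract. The NumPy class below is a hypothetical illustration of that contract only: it does not subclass `AbstractConditioner`, does not use `pyrox_method`, and registers no NumPyro sites.

```python
import numpy as np


class AdditiveShift:
    """Hypothetical conditioner: y = h + z @ W, a z-dependent bias on h."""

    def __init__(self, num_features: int, cond_dim: int, seed: int = 0):
        self.num_features = num_features  # must equal h.shape[-1]
        self.cond_dim = cond_dim  # must equal z.shape[-1]
        rng = np.random.default_rng(seed)
        self.W = 0.1 * rng.standard_normal((cond_dim, num_features))

    def __call__(self, h: np.ndarray, z: np.ndarray) -> np.ndarray:
        if h.shape[-1] != self.num_features or z.shape[-1] != self.cond_dim:
            raise ValueError("shape contract violated")
        # z of shape (K,) broadcasts across every row of h; (N, K) is per row
        return h + z @ self.W


cond = AdditiveShift(num_features=8, cond_dim=4)
print(cond(np.ones((5, 8)), np.ones(4)).shape)       # (5, 8)
print(cond(np.ones((5, 8)), np.ones((5, 4))).shape)  # (5, 8)
```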
Source code in src/pyrox/nn/_conditioning.py
Concrete conditioners
pyrox.nn.ConcatConditioner
Bases: AbstractConditioner
Concatenate `h` and `z`, then apply a single `Linear`.

Cheapest, and the most expressive in principle, but its parameter count grows linearly with `cond_dim`: `(num_features + cond_dim) * num_features + num_features` (the last term is the bias). No init ceremony is required; the layer uses the `eqx.nn.Linear` defaults.
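Splitting the projection's weight into the columns that see `h` and those that see `z` shows what the layer computes: `W @ [h; z] + b = W_h @ h + W_z @ z + b`. A NumPy sketch, assuming the `(out, in)` weight layout that `eqx.nn.Linear` uses; the arrays are random stand-ins, not a trained layer.

```python
import numpy as np

rng = np.random.default_rng(0)
C, K, N = 8, 4, 5  # num_features, cond_dim, batch size

W = rng.standard_normal((C, C + K))  # (out, in) layout, as in eqx.nn.Linear
b = rng.standard_normal(C)
h = rng.standard_normal((N, C))
z = rng.standard_normal((N, K))

# What the conditioner does: concatenate on the feature axis, then one affine map
y_concat = np.concatenate([h, z], axis=-1) @ W.T + b

# The same map split into an h-part and a z-part
W_h, W_z = W[:, :C], W[:, C:]
y_split = h @ W_h.T + z @ W_z.T + b

print(y_concat.shape)                  # (5, 8)
print(np.allclose(y_concat, y_split))  # True
print(W.size + b.size)                 # (C + K) * C + C = 104
```

The split form makes explicit that, within a single layer, the context acts as a learned, context-dependent bias `W_z @ z`.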
Attributes:
| Name | Type | Description |
|---|---|---|
| `proj` | `Linear` | Linear projection from `num_features + cond_dim` to `num_features`. |
| `num_features` | `int` | Output channels. |
| `cond_dim` | `int` | Context dimension. |
| `pyrox_name` | `str \| None` | Optional explicit scope name for NumPyro. |
Example
```python
>>> import jax.random as jr, jax.numpy as jnp
>>> from pyrox.nn import ConcatConditioner
>>> layer = ConcatConditioner.init(num_features=8, cond_dim=4, key=jr.key(0))
>>> y = layer(jnp.ones((5, 8)), jnp.ones((5, 4)))
>>> y.shape
(5, 8)
```
Source code in src/pyrox/nn/_conditioning.py
init(num_features, cond_dim, *, key, pyrox_name=None)
classmethod
Build a `ConcatConditioner` with the default `eqx.nn.Linear` init.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_features` | `int` | Output channel count. | *required* |
| `cond_dim` | `int` | Context dimension. | *required* |
| `key` | `PRNGKeyArray` | PRNG key for the projection's init. | *required* |
| `pyrox_name` | `str \| None` | Optional explicit scope name. | `None` |
Returns:
| Type | Description |
|---|---|
| `ConcatConditioner` | Initialised `ConcatConditioner`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in src/pyrox/nn/_conditioning.py
pyrox.nn.AffineModulation
Bases: AbstractConditioner
Feature-wise Linear Modulation (FiLM): y = γ(z) ⊙ h + β(z).

A single `eqx.nn.Linear` of output size `2 * num_features` produces the concatenated `(raw_β, raw_γ)` from the context vector. The two halves are split on the feature axis with `einops.rearrange` (no raw `jnp.split`), then γ is passed through the chosen activation:

- `"one_plus_tanh"` (default): `γ = 1 + tanh(raw_γ)`. γ equals 1 whenever `raw_γ = 0`, so the layer is the identity when the generator's output is zero. This is the choice that gives FiLM its "does nothing until trained" property.
- `"exp"`: `γ = exp(raw_γ)`. Strictly positive, as required for use in a bijection. In this mode `log_det` returns `sum(raw_γ, axis=-1)`, the closed-form log-Jacobian of an element-wise scale.
- `"softplus"`: `γ = softplus(raw_γ)`. Strictly positive, and slower to leave the prior than `exp`.
- `"identity"`: `γ = raw_γ`. No sign guarantee; rarely useful.
Attributes:
| Name | Type | Description |
|---|---|---|
| `generator` | `Linear` | Linear map from `cond_dim` to `2 * num_features`, producing `(raw_β, raw_γ)`. |
| `num_features` | `int` | Output channels. |
| `cond_dim` | `int` | Context dimension. |
| `gamma_activation` | `GammaActivation` | Parameterisation of γ. |
| `pyrox_name` | `str \| None` | Optional explicit scope name for NumPyro. |
Example
```python
>>> import jax.random as jr, jax.numpy as jnp
>>> from pyrox.nn import AffineModulation
>>> film = AffineModulation.init(num_features=8, cond_dim=4, key=jr.key(0))
>>> y = film(jnp.ones((5, 8)), jnp.ones((5, 4)))
>>> y.shape
(5, 8)
```
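The FiLM forward pass and the γ parameterisations listed above can be sketched in plain NumPy. This is an illustration under stated assumptions, not pyrox's implementation: the generator output is assumed to be ordered `(raw_β, raw_γ)` along the last axis, as the text says, β is used without an activation, and the split uses slicing instead of `einops.rearrange`.

```python
import numpy as np

GAMMA_ACTIVATIONS = {
    "one_plus_tanh": lambda r: 1.0 + np.tanh(r),
    "exp": np.exp,
    "softplus": lambda r: np.logaddexp(0.0, r),  # log(1 + e^r), overflow-safe
    "identity": lambda r: r,
}


def film(h, z, W, b, gamma_activation="one_plus_tanh"):
    """y = gamma(z) * h + beta(z), with (raw_beta, raw_gamma) = W @ z + b."""
    C = h.shape[-1]
    raw = z @ W.T + b  # (..., 2C)
    raw_beta, raw_gamma = raw[..., :C], raw[..., C:]
    gamma = GAMMA_ACTIVATIONS[gamma_activation](raw_gamma)
    return gamma * h + raw_beta


rng = np.random.default_rng(0)
C, K = 8, 4
h = rng.standard_normal((5, C))
z = rng.standard_normal((5, K))

# With a zero generator, raw_beta = raw_gamma = 0, so "one_plus_tanh" gives
# gamma = 1 and beta = 0: the layer is exactly the identity.
W0, b0 = np.zeros((2 * C, K)), np.zeros(2 * C)
print(np.allclose(film(h, z, W0, b0), h))  # True

# A random generator modulates h; "exp" keeps gamma strictly positive.
W, b = rng.standard_normal((2 * C, K)), rng.standard_normal(2 * C)
print(film(h, z, W, b, "exp").shape)  # (5, 8)
```

The zero-generator check is the precise sense in which `"one_plus_tanh"` starts as the identity: both halves of the generator output must be zero.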
Source code in src/pyrox/nn/_conditioning.py
init(num_features, cond_dim, *, key, gamma_activation='one_plus_tanh', pyrox_name=None)
classmethod
Build `AffineModulation` with the default `Linear` generator, whose output size is `2 * num_features`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_features` | `int` | Output channel count. | *required* |
| `cond_dim` | `int` | Context dimension. | *required* |
| `key` | `PRNGKeyArray` | PRNG key for the generator's init. | *required* |
| `gamma_activation` | `GammaActivation` | Parameterisation of γ. | `'one_plus_tanh'` |
| `pyrox_name` | `str \| None` | Optional explicit scope name. | `None` |
Returns:
| Type | Description |
|---|---|
| `AffineModulation` | Initialised `AffineModulation`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in src/pyrox/nn/_conditioning.py
log_det(z)
Sum of log γ across the feature axis.

Only valid when `gamma_activation="exp"`, the only parameterisation for which `log γ = raw_γ` exactly. For other modes this raises `NotImplementedError`; callers that need a generic Jacobian must compute it manually.
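For `gamma_activation="exp"`, the map `h ↦ γ ⊙ h + β` has the diagonal Jacobian `diag(γ)`, so `log|det J| = Σ log γ = Σ raw_γ`. The NumPy check below builds that Jacobian explicitly for one context vector and compares it with the closed form; `raw_gamma` is a random stand-in for the generator's γ half.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 6
raw_gamma = rng.standard_normal(C)  # stand-in for the generator's gamma half
gamma = np.exp(raw_gamma)           # "exp" parameterisation

# Jacobian of h -> gamma * h + beta with respect to h (beta does not depend on h)
J = np.diag(gamma)
sign, logabsdet = np.linalg.slogdet(J)

closed_form = raw_gamma.sum(axis=-1)  # the quantity log_det returns in "exp" mode
print(sign, np.isclose(logabsdet, closed_form))  # 1.0 True
```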
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `z` | `Float[Array, '*batch K'] \| Float[Array, ' K']` | Context array of shape `(*batch, K)` or `(K,)`. | *required* |
Returns:
| Type | Description |
|---|---|
| `Float[Array, ' *batch']` | Log-determinant of the diagonal scaling, shape `(*batch,)`; a scalar for 1-D `z`. |
Raises:
| Type | Description |
|---|---|
| `NotImplementedError` | If `gamma_activation` is not `"exp"`. |
Source code in src/pyrox/nn/_conditioning.py
pyrox.nn.FiLM = AffineModulation
module-attribute
pyrox.nn.HyperLinear
Bases: AbstractConditioner
Generate a target Linear's (W, b) from z, then apply.
A single `eqx.nn.Linear` of output size `target_out * target_in + target_out` produces the flat parameter vector for an ad-hoc linear layer; `W` and `b` are split out with `einops.rearrange`.

The forward dispatches on `z.ndim`:

- `z.shape == (K,)`: shared path. One `(W, b)` is generated and reused across every row of `x`. Cheap (one small affine map plus one matmul).
- `z.shape == (N, K)`: per-sample path. A separate `(W, b)` is generated for each row and applied with `einops.einsum`. Applying them costs about `N * C * C_in` multiply-adds per call, and the generator must emit `N` full weight matrices, so memory also grows with `N * C * C_in`.
The generator weight scale is multiplied by init_scale so the
generated W magnitude starts small and the composite is near-zero
at init. Default init_scale=0.1 matches NIF (Pan et al. 2023).
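Both paths can be sketched in NumPy. The flat generator output is assumed here to be ordered as `W` (row-major, shape `(target_out, target_in)`) followed by `b`; that ordering, and the zero generator bias, are illustrative assumptions rather than a statement about pyrox's `rearrange` pattern. `init_scale` multiplies the generator weights, as described above.

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C, K, N = 4, 8, 3, 6  # target_in, target_out, cond_dim, rows of x
init_scale = 0.1
n_flat = C * C_in + C       # flat size of (W, b)

G_W = init_scale * rng.standard_normal((n_flat, K))  # generator weight
G_b = np.zeros(n_flat)                               # generator bias (zero here)


def unflatten(flat):
    """Split (..., C*C_in + C) into W (..., C, C_in) and b (..., C)."""
    W = flat[..., : C * C_in].reshape(*flat.shape[:-1], C, C_in)
    return W, flat[..., C * C_in :]


def hyper_linear(x, z):
    W, b = unflatten(z @ G_W.T + G_b)
    if z.ndim == 1:
        # Shared path: one (W, b) reused for every row of x
        return x @ W.T + b
    # Per-sample path: row n of x uses its own W[n], b[n]
    return np.einsum("noi,ni->no", W, x) + b


x = rng.standard_normal((N, C_in))
z = rng.standard_normal(K)
y_shared = hyper_linear(x, z)
y_per = hyper_linear(x, np.tile(z, (N, 1)))  # the same z repeated on every row
print(y_shared.shape, y_per.shape)           # (6, 8) (6, 8)
print(np.allclose(y_shared, y_per))          # True
```

Repeating one context on every row makes the two paths agree, which is a quick way to see that the per-sample path is the shared path applied row by row.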
Attributes:
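| Name | Type | Description |
|---|---|---|
| `generator` | `Linear` | Linear map from `cond_dim` to `target_out * target_in + target_out`. |
| `target_in` | `int` | Input dimension of the generated layer. |
| `target_out` | `int` | Output dimension of the generated layer. |
| `cond_dim` | `int` | Context dimension. |
| `num_features` | `int` | Alias for `target_out`. |
| `pyrox_name` | `str \| None` | Optional explicit scope name for NumPyro. |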
Example
```python
>>> import jax.random as jr, jax.numpy as jnp
>>> from pyrox.nn import HyperLinear
>>> hyper = HyperLinear.init(
...     target_in=4, target_out=8, cond_dim=3, key=jr.key(0)
... )
>>> y_shared = hyper(jnp.ones((6, 4)), jnp.ones((3,)))
>>> y_persample = hyper(jnp.ones((6, 4)), jnp.ones((6, 3)))
>>> (y_shared.shape, y_persample.shape)
((6, 8), (6, 8))
```
Source code in src/pyrox/nn/_conditioning.py
init(target_in, target_out, cond_dim, *, key, init_scale=0.1, pyrox_name=None)
classmethod
Build a `HyperLinear` with a small-magnitude generator init.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target_in` | `int` | Input dimension of the generated layer. | *required* |
| `target_out` | `int` | Output dimension of the generated layer. | *required* |
| `cond_dim` | `int` | Context dimension. | *required* |
| `key` | `PRNGKeyArray` | PRNG key for generator init. | *required* |
| `init_scale` | `float` | Multiplicative factor on the generator weights so the generated `W` starts small. | `0.1` |
| `pyrox_name` | `str \| None` | Optional explicit scope name. | `None` |
Returns:
| Type | Description |
|---|---|
| `HyperLinear` | Initialised `HyperLinear`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any of … |
Source code in src/pyrox/nn/_conditioning.py
Bayesian variants
pyrox.nn.BayesianConcatConditioner
Bases: AbstractConditioner
`ConcatConditioner` with Normal priors on the projection.

Registers two NumPyro sample sites, `{scope}.proj_W` and `{scope}.proj_b`, both under `Normal(0, prior_std)`. These are the only two sites per forward call; nothing is sampled for the inner activation `h` or the context `z`.
Attributes:
| Name | Type | Description |
|---|---|---|
| `num_features` | `int` | Output channels. |
| `cond_dim` | `int` | Context dimension. |
| `prior_std` | `float` | Scale of the Normal priors. |
| `pyrox_name` | `str \| None` | Optional explicit scope name for NumPyro. |
Source code in src/pyrox/nn/_conditioning.py
init(num_features, cond_dim, *, prior_std=1.0, pyrox_name=None)
classmethod
Build a `BayesianConcatConditioner`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_features` | `int` | Output channels. | *required* |
| `cond_dim` | `int` | Context dimension. | *required* |
| `prior_std` | `float` | Scale of the Normal priors. | `1.0` |
| `pyrox_name` | `str \| None` | Optional explicit scope name. | `None` |
Source code in src/pyrox/nn/_conditioning.py
pyrox.nn.BayesianAffineModulation
Bases: AbstractConditioner
`AffineModulation` with Normal priors on the FiLM generator.

Registers two sites, `{scope}.gen_W` and `{scope}.gen_b`, under `Normal(0, prior_std)`. The γ activation is fixed at construction (default `"one_plus_tanh"`), so the prior over the raw generator output induces a well-defined prior over γ and β.
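To make "a well-defined prior over γ and β" concrete, the NumPy sketch below draws generator parameters from `Normal(0, prior_std)` (the role of the `gen_W` / `gen_b` sites) and pushes one fixed context vector through the default `"one_plus_tanh"` activation. It is plain Monte Carlo, not NumPyro, and the `(2C, K)` weight layout and `(raw_β, raw_γ)` ordering are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
C, K, prior_std, S = 8, 4, 1.0, 2000  # features, cond_dim, prior scale, draws
z = rng.standard_normal(K)            # one fixed context vector

# S prior draws of the generator parameters
gen_W = prior_std * rng.standard_normal((S, 2 * C, K))
gen_b = prior_std * rng.standard_normal((S, 2 * C))

raw = np.einsum("sok,k->so", gen_W, z) + gen_b  # (S, 2C)
beta, raw_gamma = raw[:, :C], raw[:, C:]
gamma = 1.0 + np.tanh(raw_gamma)                # "one_plus_tanh"

# gamma is confined to [0, 2]; beta is an unbounded Gaussian
print(gamma.min() >= 0.0, gamma.max() <= 2.0)  # True True
```

Under this prior, β given z is Gaussian with variance `prior_std² · (‖z‖² + 1)`, while γ is squashed into a bounded interval around 1.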
Attributes:
| Name | Type | Description |
|---|---|---|
| `num_features` | `int` | Output channels. |
| `cond_dim` | `int` | Context dimension. |
| `gamma_activation` | `GammaActivation` | Parameterisation of γ. |
| `prior_std` | `float` | Scale of the Normal priors. |
| `pyrox_name` | `str \| None` | Optional explicit scope name. |
Source code in src/pyrox/nn/_conditioning.py
init(num_features, cond_dim, *, gamma_activation='one_plus_tanh', prior_std=1.0, pyrox_name=None)
classmethod
Build a `BayesianAffineModulation`.
Source code in src/pyrox/nn/_conditioning.py
pyrox.nn.BayesianHyperLinear
Bases: AbstractConditioner
`HyperLinear` with Normal priors on the generator only.

Two sites: `{scope}.gen_W` and `{scope}.gen_b`. The target weights `(W_target, b_target)` are generated, not sampled, so the cost of Bayesian inference scales with the generator size, `(cond_dim + 1) * (target_out * target_in + target_out)` parameters including the bias, not with the target-network size. This is the architectural advantage of doing Bayesian amortised inference via hypernetworks.
Attributes:
| Name | Type | Description |
|---|---|---|
| `target_in` | `int` | Input dimension of the generated layer. |
| `target_out` | `int` | Output dimension of the generated layer. |
| `cond_dim` | `int` | Context dimension. |
| `num_features` | `int` | Alias for `target_out`. |
| `prior_std` | `float` | Scale of the Normal priors on the generator. |
| `pyrox_name` | `str \| None` | Optional explicit scope name. |
Source code in src/pyrox/nn/_conditioning.py
init(target_in, target_out, cond_dim, *, prior_std=1.0, pyrox_name=None)
classmethod
Build a `BayesianHyperLinear`.
Source code in src/pyrox/nn/_conditioning.py
Spectral hyper-conditioning
HyperFourierFeatures is the conditional analogue of RBFFourierFeatures: instead of sampling the random Fourier features' (W, b, lengthscale) from a fixed prior, a user-supplied parameter network produces them from the context vector. ConditionedRFFNet adds a learnable linear readout — the conditional analogue of RandomKitchenSinks.
pyrox.nn.HyperFourierFeatures
Bases: PyroxModule
Random Fourier features with (W, b, log_lengthscale) from a parameter net.
The unconditional counterpart, `pyrox.nn.RBFFourierFeatures`, samples its frequencies and lengthscale from priors. This layer instead amortises them over a context vector `z` through a user-supplied `parameter_net`:

$$
\begin{aligned}
(W(z),\, b(z),\, \log\ell(z)) &= \text{unflatten}\bigl(\text{parameter\_net}(z)\bigr) \\
\phi(x; z) &= \sqrt{1/n_{\text{features}}}\;\bigl[\cos\bigl(W(z)^\top x / \ell(z) + b(z)\bigr),\; \sin\bigl(W(z)^\top x / \ell(z) + b(z)\bigr)\bigr]
\end{aligned}
$$

Two execution modes are supported:

- Shared mode (`z.ndim == 1`): the parameter net runs once and the generated features are reused across all rows of `x`, the same efficiency trick as the shared path of `HyperLinear`.
- Per-sample mode (`z.ndim == 2`): a distinct `(W, b, log_lengthscale)` is generated per row of `z` via `jax.vmap` and applied with `einops.einsum`. This is substantially more expensive in compute and memory because the Fourier parameters are no longer shared across rows of `x`, but it is required when each row of `x` needs its own context.
The flat output of parameter_net(z) must have size
in_features * n_features + n_features + 1 (frequencies, phases,
log-lengthscale). init does not invoke parameter_net —
a misshapen output surfaces only on the first call.
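The unflatten step and the shared-mode feature map can be sketched in NumPy. The ordering of the flat vector (frequencies, then phases, then log-lengthscale) follows the text above; the row-major `(in_features, n_features)` layout of `W` is an assumption, and a random vector stands in for `parameter_net(z)`.

```python
import numpy as np


def unflatten(flat, in_features, n_features):
    """Split a parameter-net output into (W, b, log_lengthscale)."""
    n_W = in_features * n_features
    if flat.shape[-1] != n_W + n_features + 1:
        raise ValueError("misshapen parameter_net output")
    W = flat[:n_W].reshape(in_features, n_features)
    b = flat[n_W : n_W + n_features]
    return W, b, flat[-1]


def fourier_features(x, W, b, log_ell):
    """phi(x) = sqrt(1/n) [cos(W^T x / ell + b), sin(W^T x / ell + b)]."""
    n_features = W.shape[1]
    proj = x @ W / np.exp(log_ell) + b  # (N, n_features)
    return np.sqrt(1.0 / n_features) * np.concatenate(
        [np.cos(proj), np.sin(proj)], axis=-1
    )


in_features, n_features = 1, 16
rng = np.random.default_rng(0)
flat = rng.standard_normal(in_features * n_features + n_features + 1)  # 33 values
W, b, log_ell = unflatten(flat, in_features, n_features)

x = rng.standard_normal((5, in_features))
phi = fourier_features(x, W, b, log_ell)  # shared mode: one (W, b, ell) for all rows
print(phi.shape)  # (5, 32)
```

Each row of φ has unit norm, since every cos/sin pair contributes `1 / n_features` to the squared norm.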
Attributes:
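| Name | Type | Description |
|---|---|---|
| `parameter_net` | `PyroxModule \| Module` | Callable mapping `(cond_dim,)` to `(in_features * n_features + n_features + 1,)`. |
| `in_features` | `int` | Coordinate dimension (`x.shape[-1]`). |
| `n_features` | `int` | Number of frequency pairs; output dim is `2 * n_features`. |
| `cond_dim` | `int` | Context dimension expected by `parameter_net`. |
| `pyrox_name` | `str \| None` | Optional explicit scope name. |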
Example
```python
>>> import jax.random as jr, jax.numpy as jnp
>>> import equinox as eqx
>>> from pyrox.nn import HyperFourierFeatures
>>> key = jr.key(0)
>>> # Parameter net: (cond_dim=2,) -> (1*16 + 16 + 1 = 33,)
>>> pnet = eqx.nn.MLP(in_size=2, out_size=33, width_size=32, depth=2, key=key)
>>> hff = HyperFourierFeatures.init(
...     parameter_net=pnet, in_features=1, n_features=16, cond_dim=2,
... )
>>> y = hff(jnp.ones((5, 1)), jnp.ones((2,)))
>>> y.shape
(5, 32)
```
Source code in src/pyrox/nn/_conditioning.py
init(*, parameter_net, in_features, n_features, cond_dim, pyrox_name=None)
classmethod
Build `HyperFourierFeatures`.
parameter_net is not invoked at construction time, so
Bayesian / numpyro-aware parameter nets that rely on
pyrox_sample work without needing a seed handler at init.
The expected output size is in_features * n_features +
n_features + 1; a mismatch surfaces as a shape error on the
first __call__.
Source code in src/pyrox/nn/_conditioning.py
pyrox.nn.ConditionedRFFNet
Bases: PyroxModule
Conditional analogue of `pyrox.nn.RandomKitchenSinks`.

Composes a `HyperFourierFeatures` feature map with a learnable linear readout. The full forward is

$$
y(x; z) = \phi(x; z)\, \beta + b_{\text{out}}
$$

where $\phi(x; z)$ is the `HyperFourierFeatures` output and `(beta, b_out)` are the readout's deterministic weights. For a Bayesian variant, wrap the readout in a `DenseReparameterization` and put the priors there; this composite stays minimal.
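A NumPy sketch of the per-sample forward `y = φ(x; z) β + b_out`. Random per-row Fourier parameters stand in for what `parameter_net` would produce from each row of a 2-D `z`, the readout weights are random stand-ins, and the `(in_features, n_features)` layout of each `W` is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
N, in_features, n_features, out_features = 10, 1, 32, 1

# Per-sample Fourier parameters: one (W, b, log_ell) per row, as z.ndim == 2 implies
W = rng.standard_normal((N, in_features, n_features))
b = rng.uniform(0.0, 2 * np.pi, (N, n_features))
log_ell = rng.standard_normal(N)

x = rng.standard_normal((N, in_features))
proj = np.einsum("ni,nif->nf", x, W) / np.exp(log_ell)[:, None] + b
phi = np.sqrt(1.0 / n_features) * np.concatenate([np.cos(proj), np.sin(proj)], -1)

# Deterministic linear readout shared across rows: y = phi @ beta + b_out
beta = rng.standard_normal((2 * n_features, out_features))
b_out = np.zeros(out_features)
y = phi @ beta + b_out
print(phi.shape, y.shape)  # (10, 64) (10, 1)
```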
Attributes:
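| Name | Type | Description |
|---|---|---|
| `feat` | `HyperFourierFeatures` | A `HyperFourierFeatures` feature map. |
| `readout` | `Linear` | Linear readout from `2 * n_features` to `out_features`. |
| `pyrox_name` | `str \| None` | Optional explicit scope name. |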
Example
```python
>>> import jax.random as jr, jax.numpy as jnp
>>> import equinox as eqx
>>> from pyrox.nn import ConditionedRFFNet, HyperFourierFeatures
>>> k_pnet, k_readout = jr.split(jr.key(0))
>>> pnet = eqx.nn.MLP(
...     in_size=4, out_size=1 * 32 + 32 + 1, width_size=32, depth=2, key=k_pnet,
... )
>>> feat = HyperFourierFeatures.init(
...     parameter_net=pnet, in_features=1, n_features=32, cond_dim=4,
... )
>>> net = ConditionedRFFNet.init(feat=feat, out_features=1, key=k_readout)
>>> y = net(jnp.zeros((10, 1)), jnp.zeros((10, 4)))
>>> y.shape
(10, 1)
```
Source code in src/pyrox/nn/_conditioning.py
init(*, feat, out_features, key, pyrox_name=None)
classmethod
Build `ConditionedRFFNet` with a default linear readout.
Source code in src/pyrox/nn/_conditioning.py
Composites
pyrox.nn.ConditionedINR
Bases: PyroxModule
Wrap an inner network's per-layer activations with conditioners.
Given an inner network exposing a `layers` sequence (true for `pyrox.nn.SIREN` and any module that holds a list of callables named `layers`), `ConditionedINR` runs the inner forward and inserts a conditioner after each non-readout layer. Here `c` is the context vector and `h_i` the activations:

```text
h_0 = layer_0(x)
h_0 = cond_0(h_0, c)
h_1 = layer_1(h_0)
h_1 = cond_1(h_1, c)
...
y   = layer_{L-1}(h_{L-2})   # readout, not conditioned
```
The mode="input" shortcut concatenates c to the input
before running inner — useful for inner networks that don't
expose a layers sequence (e.g. plain eqx.nn.MLP instances).
Conditioners must be AbstractConditioner instances whose
num_features matches the corresponding inner layer's output
width. The composite forward registers the union of the inner
network's sample sites and the per-layer conditioners' sites — no
site clashes because each conditioner gets a distinct pyrox_name.
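The feature-mode loop described above can be written out directly. A NumPy sketch with hypothetical `layers` (a small sine-activated MLP) and FiLM-style conditioners; it mirrors the control flow only and registers no NumPyro sites. The `mode="input"` shortcut is shown as a concatenation of the context onto `x` before an unconditioned network.

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, hidden, out_f, K, N = 2, 16, 1, 4, 10


def make_layer(n_in, n_out, act):
    W = rng.standard_normal((n_out, n_in)) / np.sqrt(n_in)
    b = np.zeros(n_out)
    return lambda h: act(h @ W.T + b)


def make_film(C):
    G = 0.1 * rng.standard_normal((2 * C, K))

    def cond(h, c):
        raw = c @ G.T  # (N, 2C) -> (beta, raw_gamma)
        return (1.0 + np.tanh(raw[..., C:])) * h + raw[..., :C]

    return cond


layers = [
    make_layer(in_f, hidden, np.sin),
    make_layer(hidden, hidden, np.sin),
    make_layer(hidden, out_f, lambda a: a),  # readout, not conditioned
]
conditioners = [make_film(hidden) for _ in layers[:-1]]  # one per non-readout layer


def conditioned_forward(x, c):
    h = x
    for layer, cond in zip(layers[:-1], conditioners):
        h = cond(layer(h), c)  # h_i = cond_i(layer_i(h_{i-1}), c)
    return layers[-1](h)       # readout


x = rng.standard_normal((N, in_f))
c = rng.standard_normal((N, K))
print(conditioned_forward(x, c).shape)  # (10, 1)

# mode="input": concatenate the context onto the input of an unconditioned network
input_net = [make_layer(in_f + K, hidden, np.sin), make_layer(hidden, out_f, lambda a: a)]
h = np.concatenate([x, c], axis=-1)
for layer in input_net:
    h = layer(h)
print(h.shape)  # (10, 1)
```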
Attributes:
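| Name | Type | Description |
|---|---|---|
| `inner` | `PyroxModule \| Module` | Inner network with a `layers` sequence. |
| `conditioners` | `list[AbstractConditioner]` | Per-layer conditioner list; in `"feature"` mode, one per non-readout layer of `inner`. |
| `cond_dim` | `int` | Context dimension shared by all conditioners. |
| `mode` | `ConditionedMode` | `"feature"` (per-layer conditioning) or `"input"` (concatenate the context to the input). |
| `pyrox_name` | `str \| None` | Optional explicit scope for NumPyro. |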
Example
```python
>>> import jax.random as jr, jax.numpy as jnp
>>> from pyrox.nn import SIREN, AffineModulation, ConditionedINR
>>> k_inner, k_cond = jr.split(jr.key(0))
>>> inner = SIREN.init(2, 32, 1, depth=4, key=k_inner)
>>> wrapped = ConditionedINR.init(
...     inner, conditioner_cls=AffineModulation, cond_dim=4, key=k_cond
... )
>>> y = wrapped(jnp.zeros((10, 2)), jnp.zeros((10, 4)))
>>> y.shape
(10, 1)
```
Source code in src/pyrox/nn/_conditioning.py
init(inner, *, conditioner_cls, cond_dim, key, mode='feature', pyrox_name=None, **conditioner_kwargs)
classmethod
Build a `ConditionedINR` around `inner`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inner` | `PyroxModule \| Module` | Inner network. Must expose a `layers` sequence in `"feature"` mode. | *required* |
| `conditioner_cls` | `type[AbstractConditioner]` | Conditioner class instantiated once per layer, e.g. `AffineModulation`. | *required* |
| `cond_dim` | `int` | Context dimension passed to each conditioner. | *required* |
| `key` | `PRNGKeyArray` | PRNG key, split internally for each conditioner. | *required* |
| `mode` | `ConditionedMode` | `"feature"` or `"input"`. | `'feature'` |
| `pyrox_name` | `str \| None` | Optional explicit scope name. | `None` |
| `**conditioner_kwargs` | `object` | Extra kwargs forwarded to each conditioner at construction. | `{}` |
Returns:
| Type | Description |
|---|---|
| `ConditionedINR` | Initialised `ConditionedINR`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in src/pyrox/nn/_conditioning.py
pyrox.nn.HyperSIREN(in_features, hidden_features, out_features, *, depth, cond_dim, parameter_net, key, first_omega=30.0, hidden_omega=30.0, c=6.0, init_scale=0.1)
NIF-style ShapeNet/ParameterNet composite (Pan, Brunton, Kutz — JMLR 2023).
Builds a SIREN shape-net of the requested topology, then constructs a parallel list of `HyperLinear` generators, one per SIREN layer, whose `init_scale` is calibrated per Sitzmann regime so that the expected magnitude of each generated `W` matches the half-width given by Sitzmann's `pyrox.nn._layers._siren_W_limit` at init. Without this calibration the ShapeNet's pre-activation variance is wrong and training is unstable.

The user-supplied `parameter_net` runs once on `mu` per forward call to produce the latent `z`; `z` then drives every per-layer `HyperLinear`. `parameter_net` must be callable with signature `(P,) -> (cond_dim,)`.
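For reference, the uniform-init bounds from the SIREN paper (Sitzmann et al. 2020) are `1 / fan_in` for the first layer and `sqrt(c / fan_in) / ω₀` for hidden layers. The helper below is a hypothetical restatement of those published bounds, evaluated for the topology used in the end-to-end example; pyrox's private `_siren_W_limit` may differ in detail.

```python
import math


def siren_w_limit(fan_in: int, omega: float, c: float = 6.0, first: bool = False) -> float:
    """Half-width of the uniform SIREN weight init (Sitzmann et al. 2020)."""
    if first:
        return 1.0 / fan_in
    return math.sqrt(c / fan_in) / omega


# Topology from the HyperSIREN example: in_features=2, hidden_features=32
print(siren_w_limit(2, omega=30.0, first=True))  # 0.5
print(round(siren_w_limit(32, omega=30.0), 5))   # 0.01443
```

The two regimes differ by more than an order of magnitude here, which is why a single `init_scale` shared across all generated layers would not reproduce the SIREN pre-activation statistics.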
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Coordinate dimension of the SIREN. | *required* |
| `hidden_features` | `int` | Hidden width. | *required* |
| `out_features` | `int` | Output dimension. | *required* |
| `depth` | `int` | SIREN depth (must be ≥ 2). | *required* |
| `cond_dim` | `int` | Latent dimension produced by `parameter_net`. | *required* |
| `parameter_net` | `Module` | User-supplied callable `(P,) -> (cond_dim,)`. | *required* |
| `key` | `PRNGKeyArray` | PRNG key, split internally for the SIREN init and the hyper generators. | *required* |
| `first_omega` | `float` | First-layer frequency ω₀. | `30.0` |
| `hidden_omega` | `float` | Hidden-layer frequency ω₀. | `30.0` |
| `c` | `float` | SIREN Theorem-1 constant. | `6.0` |
| `init_scale` | `float` | Multiplicative factor applied on top of the per-regime calibration. | `0.1` |
Returns:
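| Type | Description |
|---|---|
| `_GeneratedSiren` | A composite called as `nif(x, mu)`: it runs `parameter_net(mu)` once to obtain `z`, then evaluates the generated SIREN at `x`. |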
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in src/pyrox/nn/_conditioning.py