Geo encoders
Geophysical inputs usually arrive as longitude/latitude in degrees, while downstream neural-field and GP features typically want periodic encodings or unit-sphere coordinates. The geo encoders in pyrox.nn package those preprocessing steps as parameter-free equinox layers, with matching pure-JAX helper functions, so they compose like any other layer (for example in eqx.nn.Sequential).
The canonical spherical-harmonic pipeline is:
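```python
import equinox as eqx
import jax.numpy as jnp

from pyrox.nn import (
    Cartesian3DEncoder,
    Deg2Rad,
    SphericalHarmonicEncoder,
)

encoder = eqx.nn.Sequential(
    [
        Deg2Rad(),
        Cartesian3DEncoder(input_unit="radians"),
        SphericalHarmonicEncoder(l_max=8, input_mode="cartesian"),
    ]
)

# Example (N, 2) input: longitude/latitude pairs in degrees.
lonlat_deg = jnp.array([[0.0, 0.0], [-45.0, 30.0], [120.0, -60.0]])
features = encoder(lonlat_deg)  # (N, 81)
```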
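The 81 in the shape comment is the number of real spherical harmonics up to degree l_max = 8, assuming the encoder returns every (l, m) pair up to l_max (the shapes in the API examples are consistent with this). Degree l contributes 2l + 1 orders (m = -l, …, l), so the total is

$$
\sum_{l=0}^{l_{\max}} (2l + 1) = (l_{\max} + 1)^2, \qquad l_{\max} = 8 \;\Rightarrow\; 9^2 = 81.
$$

The same count gives the 16 columns for l_max = 3 and the single column for l_max = 0 in the examples.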
Cartesian3DEncoder uses the same axis convention expected by pyrox.gp.SphericalHarmonicInducingFeatures, so the NN and GP spherical paths line up. For temporal complements, see fourier_features and seasonal_features.
Stateless encoder layers
pyrox.nn.Deg2Rad
Bases: Module
Element-wise degrees-to-radians conversion.
Stateless equinox.Module wrapper around deg2rad, with no
learnable parameters and no sample sites.
Example

```python
>>> import jax.numpy as jnp
>>> Deg2Rad()(jnp.array([0.0, 90.0, 180.0]))
Array([0.       , 1.5707964, 3.1415927], dtype=float32)
>>> # Composes with other encoders in eqx.nn.Sequential.
>>> import equinox as eqx
>>> pipeline = eqx.nn.Sequential([Deg2Rad(), Cartesian3DEncoder()])
```
Source code in src/pyrox/nn/_layers.py
pyrox.nn.LonLatScale
Bases: Module
Affine-rescale lon/lat columns.
Values inside the given ranges map into [-1, 1]; out-of-range
values are not clipped. The default ranges assume lonlat is
in degrees.
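Assuming the lower bound maps to -1 and the upper bound to +1 (as the examples below show), the rescaling of a column with range (a, b) is the affine map below. The worked value shows what "not clipped" means in practice: a longitude of 20 in a (-10, 10) grid lands at 2.

$$
s(v) = 2\,\frac{v - a}{b - a} - 1,
\qquad (a, b) = (-10, 10),\; v = 20 \;\Rightarrow\; s = 2 \cdot \frac{30}{20} - 1 = 2.
$$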
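Attributes:

| Name | Type | Description |
|---|---|---|
| lon_range | tuple[float, float] | (lower, upper) longitude interval mapped onto [-1, 1]. |
| lat_range | tuple[float, float] | (lower, upper) latitude interval mapped onto [-1, 1]. |
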
Example

```python
>>> import jax.numpy as jnp
>>> LonLatScale()(jnp.array([[0.0, 0.0]]))
Array([[0., 0.]], dtype=float32)
>>> # Regional grid in degrees:
>>> LonLatScale(lon_range=(-10.0, 10.0), lat_range=(40.0, 60.0))(
...     jnp.array([[0.0, 50.0]])
... )
Array([[0., 0.]], dtype=float32)
```
Source code in src/pyrox/nn/_layers.py
pyrox.nn.Cartesian3DEncoder
Bases: Module
Lift lon/lat coordinates onto the unit sphere S².
Stateless wrapper around lonlat_to_cartesian3d. Uses the
same axis convention as
pyrox.gp.SphericalHarmonicInducingFeatures.
Attributes:

| Name | Type | Description |
|---|---|---|
| input_unit | Literal['degrees', 'radians'] | Whether the input is in degrees or radians. |

Example

```python
>>> import jax.numpy as jnp
>>> # Prime meridian / equator → +x
>>> Cartesian3DEncoder()(jnp.array([[0.0, 0.0]]))
Array([[1., 0., 0.]], dtype=float32)
>>> # Degrees input:
>>> Cartesian3DEncoder(input_unit="degrees")(jnp.array([[0.0, 90.0]]))[:, 2]
Array([1.], dtype=float32)
```
Source code in src/pyrox/nn/_layers.py
pyrox.nn.CyclicEncoder
Bases: Module
Encode periodic inputs as concatenated cos/sin features.
Stateless wrapper around cyclic_encode.
Example

```python
>>> import jax.numpy as jnp
>>> CyclicEncoder()(jnp.array([0.0, jnp.pi]))[:, 0]
Array([ 1., -1.], dtype=float32)
>>> # 2-D input: each column encoded independently.
>>> CyclicEncoder()(jnp.zeros((3, 2))).shape
(3, 4)
```
Source code in src/pyrox/nn/_layers.py
pyrox.nn.SphericalHarmonicEncoder
Bases: Module
Real spherical-harmonic features on the unit sphere.
Stateless wrapper that evaluates
pyrox._basis.real_spherical_harmonics on either already-Cartesian
(xyz) inputs (input_mode='cartesian') or lon/lat pairs
(input_mode='lonlat', assumed to be in radians).
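Attributes:

| Name | Type | Description |
|---|---|---|
| l_max | int | Maximum harmonic degree (must be ≥ 0). The output has (l_max + 1)^2 features. |
| input_mode | Literal['cartesian', 'lonlat'] | 'cartesian' for xyz inputs of shape (N, 3); 'lonlat' for lon/lat pairs of shape (N, 2) in radians. |
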
Example

```python
>>> import jax.numpy as jnp
>>> xyz = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> SphericalHarmonicEncoder(l_max=3)(xyz).shape
(2, 16)
>>> # Lon/lat input mode (radians):
>>> lonlat = jnp.array([[0.0, 0.0], [0.5 * jnp.pi, 0.0]])
>>> SphericalHarmonicEncoder(l_max=3, input_mode="lonlat")(lonlat).shape
(2, 16)
```
Source code in src/pyrox/nn/_layers.py
Pure-JAX helper functions
pyrox.nn.deg2rad(x)
Convert degrees to radians element-wise.
Example

```python
>>> import jax.numpy as jnp
>>> deg2rad(jnp.array([0.0, 90.0, 180.0]))
Array([0.       , 1.5707964, 3.1415927], dtype=float32)
>>> deg2rad(jnp.array([[45.0, -90.0], [270.0, 360.0]])).shape
(2, 2)
```
Source code in src/pyrox/nn/_geo.py
pyrox.nn.lonlat_scale(lonlat, *, lon_range=(-180.0, 180.0), lat_range=(-90.0, 90.0))
Affine-rescale lon/lat columns.
Values inside the given ranges map into [-1, 1]; out-of-range
values are not clipped and map outside [-1, 1] linearly. The
default ranges assume lonlat is in degrees; pass matching
lon_range / lat_range in whatever unit you use.
Integer inputs are promoted to float32 before the affine step
so (lonlat - lower) / (upper - lower) is not computed in
integer arithmetic (which would silently round the output to
-1 / 0 / 1).
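A minimal sketch of the behaviour described above, assuming the affine map sends the lower bound to -1 and the upper bound to +1 (as the examples below show). lonlat_scale_sketch is a hypothetical name, not pyrox's implementation; its integer-promotion branch mirrors the note on integer inputs.

```python
import jax.numpy as jnp


def lonlat_scale_sketch(lonlat, lon_range=(-180.0, 180.0), lat_range=(-90.0, 90.0)):
    """Hypothetical re-implementation of the documented lonlat_scale behaviour."""
    lonlat = jnp.asarray(lonlat)
    # Promote non-float (e.g. integer) inputs to float32 before the affine step.
    if not jnp.issubdtype(lonlat.dtype, jnp.floating):
        lonlat = lonlat.astype(jnp.float32)
    lower = jnp.array([lon_range[0], lat_range[0]], dtype=lonlat.dtype)
    upper = jnp.array([lon_range[1], lat_range[1]], dtype=lonlat.dtype)
    # Map [lower, upper] affinely onto [-1, 1]; no clipping, so outliers land outside.
    return 2.0 * (lonlat - lower) / (upper - lower) - 1.0


# Integer input: promoted to float32, corners map to -1 / 0 / 1.
corners = lonlat_scale_sketch(jnp.array([[-180, -90], [0, 0], [180, 90]]))
# Out-of-range longitude (20 in a (-10, 10) grid) maps linearly to 2.
outlier = lonlat_scale_sketch(
    jnp.array([[20.0, 50.0]]), lon_range=(-10.0, 10.0), lat_range=(40.0, 60.0)
)
```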
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lonlat | Num[Array, 'N 2'] | Longitude/latitude matrix of shape (N, 2). | required |
| lon_range | tuple[float, float] | (lower, upper) longitude interval mapped onto [-1, 1]. | (-180.0, 180.0) |
| lat_range | tuple[float, float] | (lower, upper) latitude interval mapped onto [-1, 1]. | (-90.0, 90.0) |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N 2'] | Rescaled lon/lat array of shape (N, 2). |

Example

```python
>>> import jax.numpy as jnp
>>> lonlat = jnp.array([[-180.0, -90.0], [0.0, 0.0], [180.0, 90.0]])
>>> lonlat_scale(lonlat)
Array([[-1., -1.],
       [ 0.,  0.],
       [ 1.,  1.]], dtype=float32)
>>> # Custom domain (e.g. a regional grid in degrees)
>>> lonlat_scale(
...     jnp.array([[0.0, 50.0]]),
...     lon_range=(-10.0, 10.0),
...     lat_range=(40.0, 60.0),
... )
Array([[0., 0.]], dtype=float32)
```
Source code in src/pyrox/nn/_geo.py
pyrox.nn.lonlat_to_cartesian3d(lonlat, *, input_unit='radians')
Lift lon/lat coordinates onto the unit sphere.
Uses the standard parameterization

$$
x = \cos\phi \cos\lambda, \qquad y = \cos\phi \sin\lambda, \qquad z = \sin\phi,
$$

where λ is longitude and ϕ is latitude. This matches the axis
convention expected by
pyrox.gp.SphericalHarmonicInducingFeatures, so the NN and
GP spherical paths line up.
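Written out in JAX, the parameterization takes a few lines. This is a hypothetical sketch (lonlat_to_xyz_sketch is not a pyrox name) meant to make the axis convention concrete; the test points reproduce the +x, +y and +z cases from the examples below and check that every output lies on the unit sphere.

```python
import jax.numpy as jnp


def lonlat_to_xyz_sketch(lonlat, input_unit="radians"):
    """Hypothetical sketch of the lon/lat -> unit-sphere map (not pyrox's code)."""
    lonlat = jnp.asarray(lonlat, dtype=jnp.float32)
    if input_unit == "degrees":
        lonlat = jnp.deg2rad(lonlat)
    lon, lat = lonlat[:, 0], lonlat[:, 1]  # lambda, phi
    return jnp.stack(
        [jnp.cos(lat) * jnp.cos(lon), jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)],
        axis=-1,
    )


# Rows: prime meridian/equator, 90° E/equator, north pole, an arbitrary point.
pts = jnp.array([[0.0, 0.0], [90.0, 0.0], [0.0, 90.0], [-45.0, 30.0]])
xyz = lonlat_to_xyz_sketch(pts, input_unit="degrees")
```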
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lonlat | Float[Array, 'N 2'] | Longitude/latitude matrix of shape (N, 2). | required |
| input_unit | Literal['degrees', 'radians'] | Whether lonlat is in degrees or radians. | 'radians' |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N 3'] | Unit Cartesian coordinates of shape (N, 3). |

Example

```python
>>> import jax.numpy as jnp
>>> # Prime meridian / equator → +x
>>> lonlat_to_cartesian3d(jnp.array([[0.0, 0.0]]))
Array([[1., 0., 0.]], dtype=float32)
>>> # 90° east / equator → +y
>>> lonlat_to_cartesian3d(jnp.array([[90.0, 0.0]]), input_unit="degrees")
Array([[...e-08,  1.0000000e+00,  0.0000000e+00]], dtype=float32)
>>> # North pole → +z
>>> lonlat_to_cartesian3d(
...     jnp.array([[0.0, 0.5 * jnp.pi]])
... )[:, 2]
Array([1.], dtype=float32)
```
Source code in src/pyrox/nn/_geo.py
pyrox.nn.cyclic_encode(angles)
Encode periodic inputs as concatenated cos/sin features.
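The layout can be sketched as follows, assuming the features are concatenated as [cos(angles), sin(angles)] along the last axis (consistent with the cosine-first column and the (3, 4) shape for a (3, 2) input in the examples below). cyclic_encode_sketch is a hypothetical name. The final check shows the point of the encoding: angles one full turn apart map to the same features, so values near either end of the period end up close together.

```python
import jax.numpy as jnp


def cyclic_encode_sketch(angles):
    """Hypothetical sketch: concatenate cos and sin of each column (angles in radians)."""
    angles = jnp.asarray(angles)
    if angles.ndim == 1:
        angles = angles[:, None]  # treat a vector as a single (N, 1) column
    # Assumed layout: all cosines first, then all sines -> shape (N, 2 * D).
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


theta = jnp.array([0.0, jnp.pi / 2, jnp.pi])
feats = cyclic_encode_sketch(theta)
wrapped = cyclic_encode_sketch(theta + 2 * jnp.pi)  # one full turn later
```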
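Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| angles | Float[Array, ' N'] \| Float[Array, 'N D'] | Angle vector of shape (N,) or matrix of shape (N, D), in radians. | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N F'] | Features of shape (N, F), with F = 2 for a vector input and F = 2D for an (N, D) input, laid out as [cos(angles), sin(angles)]. |
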
Example

```python
>>> import jax.numpy as jnp
>>> cyclic_encode(jnp.array([0.0, jnp.pi]))
Array([[ 1.0000000e+00,  0.0000000e+00],
       [-1.0000000e+00, -8.7422777e-08]], dtype=float32)
>>> # Multi-dimensional input: each column is encoded independently.
>>> cyclic_encode(jnp.zeros((3, 2))).shape
(3, 4)
```
Source code in src/pyrox/nn/_geo.py
pyrox.nn.spherical_harmonic_encode(lonlat, l_max, *, input_unit='radians')
Lift lon/lat to S² and evaluate real spherical harmonics.
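For intuition about what the features are, here is a hand-written sketch of the two lowest degrees (l_max = 1, so 4 columns). It assumes orthonormal real harmonics in the common (l, m) ordering with m running from -l to l; pyrox's own ordering, signs and normalization are not documented here and may differ, and real_sh_lmax1_sketch is a hypothetical name. Under this convention the degree-1 features are just the unit-sphere coordinates (x, y, z) scaled by a constant.

```python
import jax.numpy as jnp


def real_sh_lmax1_sketch(lonlat):
    """Hypothetical sketch: orthonormal real spherical harmonics for l = 0, 1.

    Assumed column ordering (l, m): (0, 0), (1, -1), (1, 0), (1, 1).
    pyrox's ordering, sign and normalization conventions may differ.
    """
    lon, lat = lonlat[:, 0], lonlat[:, 1]  # radians
    x = jnp.cos(lat) * jnp.cos(lon)
    y = jnp.cos(lat) * jnp.sin(lon)
    z = jnp.sin(lat)
    c0 = 0.5 / jnp.sqrt(jnp.pi)  # Y_00 = 1 / (2 sqrt(pi)), constant on the sphere
    c1 = jnp.sqrt(3.0 / (4.0 * jnp.pi))  # shared degree-1 constant
    return jnp.stack([jnp.full_like(x, c0), c1 * y, c1 * z, c1 * x], axis=-1)


lonlat = jnp.array([[0.0, 0.0], [1.5707964, 0.0], [0.3, -1.1]])
Y = real_sh_lmax1_sketch(lonlat)  # (N, (1 + 1) ** 2) = (3, 4)
```

A quick sanity check on any such basis is Unsöld's theorem: at every point, the sum over m of Y_lm² equals (2l + 1) / (4π), independent of location.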
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lonlat | Float[Array, 'N 2'] | Longitude/latitude matrix of shape (N, 2). | required |
| l_max | int | Maximum harmonic degree. | required |
| input_unit | Literal['degrees', 'radians'] | Whether lonlat is in degrees or radians. | 'radians' |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N M'] | Real spherical-harmonic features of shape (N, M), where M = (l_max + 1)^2. |

Example

```python
>>> import jax.numpy as jnp
>>> lonlat = jnp.array([[0.0, 0.0], [1.5707964, 0.0]])
>>> spherical_harmonic_encode(lonlat, l_max=3).shape
(2, 16)
>>> # Pairs with the GP side for a consistent basis:
>>> spherical_harmonic_encode(lonlat, l_max=0).shape
(2, 1)
```