
Geo encoders

Geophysical inputs usually arrive as longitude/latitude in degrees, while downstream neural-field and GP features typically want periodic encodings or unit-sphere coordinates. The geo encoders in pyrox.nn make those preprocessing steps first-class and composable.

The canonical spherical-harmonic pipeline is:

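import equinox as eqx
import jax.numpy as jnp

from pyrox.nn import (
    Cartesian3DEncoder,
    Deg2Rad,
    SphericalHarmonicEncoder,
)

# eqx.nn.Sequential calls each layer as layer(x, key=...); these stateless
# encoders take no key argument, so each one is wrapped in eqx.nn.Lambda.
encoder = eqx.nn.Sequential(
    [
        eqx.nn.Lambda(Deg2Rad()),
        eqx.nn.Lambda(Cartesian3DEncoder(input_unit="radians")),
        eqx.nn.Lambda(SphericalHarmonicEncoder(l_max=8, input_mode="cartesian")),
    ]
)
lonlat_deg = jnp.array([[0.0, 0.0], [-73.9, 40.7]])  # (N, 2) lon/lat in degrees
features = encoder(lonlat_deg)  # (N, 81), i.e. (l_max + 1) ** 2 features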

Cartesian3DEncoder uses the same axis convention expected by pyrox.gp.SphericalHarmonicInducingFeatures, so the NN and GP spherical paths line up. For temporal complements, see fourier_features and seasonal_features.

Encoder layers

pyrox.nn.Deg2Rad

Bases: Module

Element-wise degrees-to-radians conversion.

Stateless equinox.Module wrapper around deg2rad, with no learnable parameters and no sample sites.

Example

>>> import jax.numpy as jnp
>>> Deg2Rad()(jnp.array([0.0, 90.0, 180.0]))
Array([0.       , 1.5707964, 3.1415927], dtype=float32)
>>> # Composes with other encoders in eqx.nn.Sequential.
>>> import equinox as eqx
>>> pipeline = eqx.nn.Sequential([Deg2Rad(), Cartesian3DEncoder()])

Source code in src/pyrox/nn/_layers.py
class Deg2Rad(eqx.Module):
    """Element-wise degrees-to-radians conversion.

    Stateless ``equinox.Module`` wrapper around :func:`deg2rad` — no
    learnable parameters and no sample sites.

    Example:
        >>> import jax.numpy as jnp
        >>> Deg2Rad()(jnp.array([0.0, 90.0, 180.0]))
        Array([0.       , 1.5707964, 3.1415927], dtype=float32)
        >>> # Composes with other encoders in eqx.nn.Sequential.
        >>> import equinox as eqx
        >>> pipeline = eqx.nn.Sequential([Deg2Rad(), Cartesian3DEncoder()])
    """

    def __call__(self, x: Float[Array, ...]) -> Float[Array, ...]:
        return deg2rad(x)

pyrox.nn.LonLatScale

Bases: Module

Affine-rescale lon/lat columns.

Values inside the given ranges map into [-1, 1]; out-of-range values are not clipped. The default ranges assume lonlat is in degrees.

Attributes:

lon_range (tuple[float, float]): (min, max) longitude domain; must satisfy min < max.
lat_range (tuple[float, float]): (min, max) latitude domain; must satisfy min < max.

Example

>>> import jax.numpy as jnp
>>> LonLatScale()(jnp.array([[0.0, 0.0]]))
Array([[0., 0.]], dtype=float32)
>>> # Regional grid in degrees:
>>> LonLatScale(lon_range=(-10.0, 10.0), lat_range=(40.0, 60.0))(
...     jnp.array([[0.0, 50.0]])
... )
Array([[0., 0.]], dtype=float32)

Source code in src/pyrox/nn/_layers.py
class LonLatScale(eqx.Module):
    """Affine-rescale lon/lat columns.

    Values inside the given ranges map into ``[-1, 1]``; out-of-range
    values are *not* clipped. The default ranges assume ``lonlat`` is
    in degrees.

    Attributes:
        lon_range: ``(min, max)`` longitude domain (must satisfy
            ``min < max``).
        lat_range: ``(min, max)`` latitude domain (must satisfy
            ``min < max``).

    Example:
        >>> import jax.numpy as jnp
        >>> LonLatScale()(jnp.array([[0.0, 0.0]]))
        Array([[0., 0.]], dtype=float32)
        >>> # Regional grid in degrees:
        >>> LonLatScale(lon_range=(-10.0, 10.0), lat_range=(40.0, 60.0))(
        ...     jnp.array([[0.0, 50.0]])
        ... )
        Array([[0., 0.]], dtype=float32)
    """

    lon_range: tuple[float, float] = eqx.field(static=True, default=(-180.0, 180.0))
    lat_range: tuple[float, float] = eqx.field(static=True, default=(-90.0, 90.0))

    def __post_init__(self) -> None:
        _validate_range(self.lon_range, name="lon_range")
        _validate_range(self.lat_range, name="lat_range")

    def __call__(self, lonlat: Num[Array, "N 2"]) -> Float[Array, "N 2"]:
        return lonlat_scale(
            lonlat,
            lon_range=self.lon_range,
            lat_range=self.lat_range,
        )
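The rescaling is the per-column affine map 2 (v - min) / (max - min) - 1. Below is a minimal pure-JAX sketch of that map that does not import pyrox. It also includes a hypothetical inverse, `lonlat_unscale`, which is not part of pyrox.nn and is meant for mapping model outputs back to the original units. The sketch assumes the same (min, max) ranges are used in both directions. The third sample point lies outside the regional box and shows that values are extrapolated linearly rather than clipped.

```python
import jax.numpy as jnp


def scale(lonlat, lon_range=(-180.0, 180.0), lat_range=(-90.0, 90.0)):
    # Same affine map as described for lonlat_scale: [min, max] -> [-1, 1] per column.
    lonlat = jnp.asarray(lonlat, dtype=jnp.float32)
    lower = jnp.array([lon_range[0], lat_range[0]], dtype=jnp.float32)
    upper = jnp.array([lon_range[1], lat_range[1]], dtype=jnp.float32)
    return 2.0 * (lonlat - lower) / (upper - lower) - 1.0


def lonlat_unscale(scaled, lon_range=(-180.0, 180.0), lat_range=(-90.0, 90.0)):
    # Hypothetical inverse (not a pyrox API): [-1, 1] -> [min, max] per column.
    lower = jnp.array([lon_range[0], lat_range[0]], dtype=jnp.float32)
    upper = jnp.array([lon_range[1], lat_range[1]], dtype=jnp.float32)
    return (scaled + 1.0) * 0.5 * (upper - lower) + lower


pts = jnp.array([[0.0, 50.0], [10.0, 60.0], [15.0, 65.0]])
region = dict(lon_range=(-10.0, 10.0), lat_range=(40.0, 60.0))
scaled = scale(pts, **region)
# Rows map to (0, 0) at the centre, (1, 1) at the upper corner,
# and (1.5, 1.5) for the out-of-range point, which is not clipped.
restored = lonlat_unscale(scaled, **region)
```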

pyrox.nn.Cartesian3DEncoder

Bases: Module

Lift lon/lat coordinates onto the unit sphere S^2.

Stateless wrapper around lonlat_to_cartesian3d. Uses the same axis convention as pyrox.gp.SphericalHarmonicInducingFeatures.

Attributes:

input_unit (Literal['degrees', 'radians']): Whether the input is in "degrees" or "radians". Defaults to "radians".

Example

>>> import jax.numpy as jnp
>>> # Prime meridian / equator → +x
>>> Cartesian3DEncoder()(jnp.array([[0.0, 0.0]]))
Array([[1., 0., 0.]], dtype=float32)
>>> # Degrees input:
>>> Cartesian3DEncoder(input_unit="degrees")(jnp.array([[0.0, 90.0]]))[:, 2]
Array([1.], dtype=float32)

Source code in src/pyrox/nn/_layers.py
class Cartesian3DEncoder(eqx.Module):
    """Lift lon/lat coordinates onto the unit sphere :math:`S^2`.

    Stateless wrapper around :func:`lonlat_to_cartesian3d`. Uses the
    same axis convention as
    :class:`pyrox.gp.SphericalHarmonicInducingFeatures`.

    Attributes:
        input_unit: Whether the input is in ``"degrees"`` or
            ``"radians"``.

    Example:
        >>> import jax.numpy as jnp
        >>> # Prime meridian / equator → +x
        >>> Cartesian3DEncoder()(jnp.array([[0.0, 0.0]]))
        Array([[1., 0., 0.]], dtype=float32)
        >>> # Degrees input:
        >>> Cartesian3DEncoder(input_unit="degrees")(jnp.array([[0.0, 90.0]]))[:, 2]
        Array([1.], dtype=float32)
    """

    input_unit: Literal["degrees", "radians"] = eqx.field(
        static=True, default="radians"
    )

    def __post_init__(self) -> None:
        _validate_input_unit(self.input_unit)

    def __call__(self, lonlat: Float[Array, "N 2"]) -> Float[Array, "N 3"]:
        return lonlat_to_cartesian3d(lonlat, input_unit=self.input_unit)
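Because the lift uses x = cos(lat) cos(lon), y = cos(lat) sin(lon), z = sin(lat), it can be inverted with arctan2 and arcsin. The sketch below reimplements the forward map in pure JAX and adds a hypothetical `cartesian3d_to_lonlat` helper, which is not a pyrox API. It is useful, for example, when decoding predicted unit vectors back to coordinates. The sketch assumes radian inputs with longitude in (-π, π].

```python
import jax.numpy as jnp


def lonlat_to_xyz(lonlat_rad):
    # Forward map with the documented convention:
    # x = cos(lat) cos(lon), y = cos(lat) sin(lon), z = sin(lat).
    lon, lat = lonlat_rad[:, 0], lonlat_rad[:, 1]
    return jnp.stack(
        [jnp.cos(lat) * jnp.cos(lon), jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)],
        axis=-1,
    )


def cartesian3d_to_lonlat(xyz):
    # Hypothetical inverse (not a pyrox API). arctan2 returns lon in (-pi, pi];
    # clipping z guards arcsin against values a hair outside [-1, 1].
    lon = jnp.arctan2(xyz[:, 1], xyz[:, 0])
    lat = jnp.arcsin(jnp.clip(xyz[:, 2], -1.0, 1.0))
    return jnp.stack([lon, lat], axis=-1)


lonlat = jnp.deg2rad(jnp.array([[-120.0, 35.0], [10.0, -60.0], [179.0, 0.0]]))
xyz = lonlat_to_xyz(lonlat)
roundtrip = cartesian3d_to_lonlat(xyz)
```

At the poles cos(lat) = 0, so longitude is undefined and a round trip cannot recover the original longitude there.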

pyrox.nn.CyclicEncoder

Bases: Module

Encode periodic inputs as concatenated cos/sin features.

Stateless wrapper around cyclic_encode.

Example

>>> import jax.numpy as jnp
>>> CyclicEncoder()(jnp.array([0.0, jnp.pi]))[:, 0]
Array([ 1., -1.], dtype=float32)
>>> # 2-D input: each column encoded independently.
>>> CyclicEncoder()(jnp.zeros((3, 2))).shape
(3, 4)

Source code in src/pyrox/nn/_layers.py
class CyclicEncoder(eqx.Module):
    """Encode periodic inputs as concatenated cos/sin features.

    Stateless wrapper around :func:`cyclic_encode`.

    Example:
        >>> import jax.numpy as jnp
        >>> CyclicEncoder()(jnp.array([0.0, jnp.pi]))[:, 0]
        Array([ 1., -1.], dtype=float32)
        >>> # 2-D input: each column encoded independently.
        >>> CyclicEncoder()(jnp.zeros((3, 2))).shape
        (3, 4)
    """

    def __call__(
        self,
        angles: Float[Array, " N"] | Float[Array, "N D"],
    ) -> Float[Array, "N F"]:
        return cyclic_encode(angles)
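The point of the cos/sin encoding is that angles which are close on the circle stay close after encoding, even across the 0/2π seam. The raw values 359° and 1° are 358 units apart, but their encodings are only a short chord apart. This pure-JAX sketch mirrors the documented [cos, sin] layout for a vector input and does not call pyrox. For two unit vectors at angles a and b, the chord length is 2 |sin((a - b) / 2)|.

```python
import jax.numpy as jnp


def encode(angles):
    # Mirrors the documented layout for a vector input: (N,) -> (N, 2) as [cos, sin].
    a = angles[:, None]
    return jnp.concatenate([jnp.cos(a), jnp.sin(a)], axis=-1)


deg = jnp.array([359.0, 1.0, 180.0])
feats = encode(jnp.deg2rad(deg))

raw_gap = jnp.abs(deg[0] - deg[1])               # 358 apart as raw numbers
seam_gap = jnp.linalg.norm(feats[0] - feats[1])  # chord for a 2-degree arc
far_gap = jnp.linalg.norm(feats[0] - feats[2])   # chord for a 179-degree arc
# Chord between unit vectors at angles a and b: 2 * |sin((a - b) / 2)|.
expected_seam = 2.0 * jnp.abs(jnp.sin(jnp.deg2rad(358.0) / 2.0))
```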

pyrox.nn.SphericalHarmonicEncoder

Bases: Module

Real spherical-harmonic features on the unit sphere.

Stateless wrapper that evaluates pyrox._basis.real_spherical_harmonics on either Cartesian unit-sphere inputs (input_mode='cartesian') or lon/lat pairs (input_mode='lonlat', assumed to be in radians).

Attributes:

l_max (int): Maximum harmonic degree; must be >= 0. The output has (l_max + 1) ** 2 features.
input_mode (Literal['cartesian', 'lonlat']): "cartesian" (the default) for (N, 3) unit-sphere inputs, or "lonlat" for (N, 2) lon/lat pairs in radians.

Example

>>> import jax.numpy as jnp
>>> xyz = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> SphericalHarmonicEncoder(l_max=3)(xyz).shape
(2, 16)
>>> # Lon/lat input mode (radians):
>>> lonlat = jnp.array([[0.0, 0.0], [0.5 * jnp.pi, 0.0]])
>>> SphericalHarmonicEncoder(l_max=3, input_mode="lonlat")(lonlat).shape
(2, 16)

Source code in src/pyrox/nn/_layers.py
class SphericalHarmonicEncoder(eqx.Module):
    """Real spherical-harmonic features on the unit sphere.

    Stateless wrapper that evaluates
    :func:`pyrox._basis.real_spherical_harmonics` on either already-
    cartesian inputs (``input_mode='cartesian'``) or lon/lat pairs
    (``input_mode='lonlat'``, assumed in radians).

    Attributes:
        l_max: Maximum harmonic degree (must be ``>= 0``). The output
            has ``(l_max + 1) ** 2`` features.
        input_mode: ``"cartesian"`` for ``(N, 3)`` unit-sphere inputs
            or ``"lonlat"`` for ``(N, 2)`` lon/lat pairs in radians.

    Example:
        >>> import jax.numpy as jnp
        >>> xyz = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
        >>> SphericalHarmonicEncoder(l_max=3)(xyz).shape
        (2, 16)
        >>> # Lon/lat input mode (radians):
        >>> lonlat = jnp.array([[0.0, 0.0], [0.5 * jnp.pi, 0.0]])
        >>> SphericalHarmonicEncoder(l_max=3, input_mode="lonlat")(lonlat).shape
        (2, 16)
    """

    l_max: int = eqx.field(static=True)
    input_mode: Literal["cartesian", "lonlat"] = eqx.field(
        static=True, default="cartesian"
    )

    def __post_init__(self) -> None:
        if self.l_max < 0:
            raise ValueError(f"l_max must be >= 0; got {self.l_max}.")
        if self.input_mode not in {"cartesian", "lonlat"}:
            raise ValueError(
                f"input_mode must be 'cartesian' or 'lonlat'; got {self.input_mode!r}."
            )

    @property
    def num_features(self) -> int:
        return (self.l_max + 1) ** 2

    def __call__(
        self,
        x: Float[Array, "N 3"] | Float[Array, "N 2"],
    ) -> Float[Array, "N M"]:
        if self.input_mode == "cartesian":
            if x.ndim != 2 or x.shape[-1] != 3:
                raise ValueError(
                    "x must be (N, 3) when input_mode='cartesian'; "
                    f"got shape {x.shape}."
                )
            unit_xyz = x
        else:
            if x.ndim != 2 or x.shape[-1] != 2:
                raise ValueError(
                    f"x must be (N, 2) when input_mode='lonlat'; got shape {x.shape}."
                )
            unit_xyz = lonlat_to_cartesian3d(x, input_unit="radians")
        return real_spherical_harmonics(unit_xyz, l_max=self.l_max)
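The (l_max + 1) ** 2 feature count returned by num_features follows from the fact that degree l contributes 2l + 1 real harmonics (orders m = -l, ..., l). Summing over l = 0, ..., l_max gives (l_max + 1)^2, so feature width grows quadratically in l_max. For example, l_max=8 in the canonical pipeline yields 81 columns. The sketch only checks this count. It makes no assumption about how pyrox orders the features.

```python
def num_sh_features(l_max: int) -> int:
    # Degree `deg` contributes 2 * deg + 1 real harmonics (orders -deg..deg).
    if l_max < 0:
        raise ValueError("l_max must be >= 0")
    return sum(2 * deg + 1 for deg in range(l_max + 1))


counts = {l_max: num_sh_features(l_max) for l_max in (0, 1, 3, 8)}
```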

Pure-JAX helper functions

pyrox.nn.deg2rad(x)

Convert degrees to radians element-wise.

Example

>>> import jax.numpy as jnp
>>> deg2rad(jnp.array([0.0, 90.0, 180.0]))
Array([0.       , 1.5707964, 3.1415927], dtype=float32)
>>> deg2rad(jnp.array([[45.0, -90.0], [270.0, 360.0]])).shape
(2, 2)

Source code in src/pyrox/nn/_geo.py
def deg2rad(x: Float[Array, ...]) -> Float[Array, ...]:
    r"""Convert degrees to radians element-wise.

    Example:
        >>> import jax.numpy as jnp
        >>> deg2rad(jnp.array([0.0, 90.0, 180.0]))
        Array([0.       , 1.5707964, 3.1415927], dtype=float32)
        >>> deg2rad(jnp.array([[45.0, -90.0], [270.0, 360.0]])).shape
        (2, 2)
    """
    return x * (jnp.pi / 180.0)

pyrox.nn.lonlat_scale(lonlat, *, lon_range=(-180.0, 180.0), lat_range=(-90.0, 90.0))

Affine-rescale lon/lat columns.

Values inside the given ranges map into [-1, 1]; out-of-range values are not clipped and map outside [-1, 1] linearly. The default ranges assume lonlat is in degrees; pass matching lon_range / lat_range in whatever unit you use.

Integer inputs are promoted to float32 before the affine step so (lonlat - lower) / (upper - lower) is not computed in integer arithmetic (which would silently round the output to -1 / 0 / 1).

Parameters:

lonlat (Num[Array, 'N 2']): Longitude/latitude matrix of shape (N, 2). Required.
lon_range (tuple[float, float]): (min, max) longitude domain; must satisfy min < max. Default: (-180.0, 180.0).
lat_range (tuple[float, float]): (min, max) latitude domain; must satisfy min < max. Default: (-90.0, 90.0).

Returns:

Float[Array, 'N 2']: Rescaled lon/lat array of shape (N, 2).

Example

>>> import jax.numpy as jnp
>>> lonlat = jnp.array([[-180.0, -90.0], [0.0, 0.0], [180.0, 90.0]])
>>> lonlat_scale(lonlat)
Array([[-1., -1.],
       [ 0.,  0.],
       [ 1.,  1.]], dtype=float32)
>>> # Custom domain (e.g. a regional grid in degrees)
>>> lonlat_scale(
...     jnp.array([[0.0, 50.0]]),
...     lon_range=(-10.0, 10.0),
...     lat_range=(40.0, 60.0),
... )
Array([[0., 0.]], dtype=float32)

Source code in src/pyrox/nn/_geo.py
def lonlat_scale(
    lonlat: Num[Array, "N 2"],
    *,
    lon_range: tuple[float, float] = (-180.0, 180.0),
    lat_range: tuple[float, float] = (-90.0, 90.0),
) -> Float[Array, "N 2"]:
    """Affine-rescale lon/lat columns.

    Values inside the given ranges map into ``[-1, 1]``; out-of-range
    values are not clipped and map outside ``[-1, 1]`` linearly. The
    default ranges assume ``lonlat`` is in degrees; pass matching
    ``lon_range`` / ``lat_range`` in whatever unit you use.

    Integer inputs are promoted to ``float32`` before the affine step
    so ``(lonlat - lower) / (upper - lower)`` is not computed in
    integer arithmetic (which would silently round the output to
    ``-1 / 0 / 1``).

    Args:
        lonlat: Longitude/latitude matrix of shape ``(N, 2)``.
        lon_range: ``(min, max)`` longitude domain (must satisfy
            ``min < max``).
        lat_range: ``(min, max)`` latitude domain (must satisfy
            ``min < max``).

    Returns:
        Rescaled lon/lat array of shape ``(N, 2)``.

    Example:
        >>> import jax.numpy as jnp
        >>> lonlat = jnp.array([[-180.0, -90.0], [0.0, 0.0], [180.0, 90.0]])
        >>> lonlat_scale(lonlat)
        Array([[-1., -1.],
               [ 0.,  0.],
               [ 1.,  1.]], dtype=float32)
        >>> # Custom domain (e.g. a regional grid in degrees)
        >>> lonlat_scale(
        ...     jnp.array([[0.0, 50.0]]),
        ...     lon_range=(-10.0, 10.0),
        ...     lat_range=(40.0, 60.0),
        ... )
        Array([[0., 0.]], dtype=float32)
    """
    _validate_lonlat_shape(lonlat)
    _validate_range(lon_range, name="lon_range")
    _validate_range(lat_range, name="lat_range")

    lonlat = _promote_to_floating(lonlat)
    lower = jnp.asarray([lon_range[0], lat_range[0]], dtype=lonlat.dtype)
    upper = jnp.asarray([lon_range[1], lat_range[1]], dtype=lonlat.dtype)
    return 2.0 * (lonlat - lower) / (upper - lower) - 1.0

pyrox.nn.lonlat_to_cartesian3d(lonlat, *, input_unit='radians')

Lift lon/lat coordinates onto the unit sphere.

Uses the standard parameterization

$$x = \cos(\phi)\cos(\lambda), \quad y = \cos(\phi)\sin(\lambda), \quad z = \sin(\phi),$$

where λ is longitude and ϕ is latitude. This matches the axis convention expected by pyrox.gp.SphericalHarmonicInducingFeatures, so the NN and GP spherical paths line up.

Parameters:

lonlat (Float[Array, 'N 2']): Longitude/latitude matrix of shape (N, 2). Required.
input_unit (Literal['degrees', 'radians']): Whether lonlat is in "degrees" or "radians". Default: 'radians'.

Returns:

Float[Array, 'N 3']: Unit Cartesian coordinates of shape (N, 3).

Example

>>> import jax.numpy as jnp
>>> # Prime meridian / equator → +x
>>> lonlat_to_cartesian3d(jnp.array([[0.0, 0.0]]))
Array([[1., 0., 0.]], dtype=float32)
>>> # 90° east / equator → +y
>>> lonlat_to_cartesian3d(jnp.array([[90.0, 0.0]]), input_unit="degrees")
Array([[...e-08, 1.0000000e+00, 0.0000000e+00]], dtype=float32)
>>> # North pole → +z
>>> lonlat_to_cartesian3d(
...     jnp.array([[0.0, 0.5 * jnp.pi]])
... )[:, 2]
Array([1.], dtype=float32)

Source code in src/pyrox/nn/_geo.py
def lonlat_to_cartesian3d(
    lonlat: Float[Array, "N 2"],
    *,
    input_unit: Literal["degrees", "radians"] = "radians",
) -> Float[Array, "N 3"]:
    r"""Lift lon/lat coordinates onto the unit sphere.

    Uses the standard parameterization

    .. math::

        x = \cos(\phi)\cos(\lambda), \quad
        y = \cos(\phi)\sin(\lambda), \quad
        z = \sin(\phi),

    where ``lon = λ`` and ``lat = ϕ``. This matches the axis
    convention expected by
    :class:`pyrox.gp.SphericalHarmonicInducingFeatures`, so the NN and
    GP spherical paths line up.

    Args:
        lonlat: Longitude/latitude matrix of shape ``(N, 2)``.
        input_unit: Whether ``lonlat`` is in ``"degrees"`` or
            ``"radians"``.

    Returns:
        Unit Cartesian coordinates of shape ``(N, 3)``.

    Example:
        >>> import jax.numpy as jnp
        >>> # Prime meridian / equator → +x
        >>> lonlat_to_cartesian3d(jnp.array([[0.0, 0.0]]))
        Array([[1., 0., 0.]], dtype=float32)
        >>> # 90° east / equator → +y
        >>> lonlat_to_cartesian3d(jnp.array([[90.0, 0.0]]), input_unit="degrees")
        Array([[...e-08, 1.0000000e+00, 0.0000000e+00]], dtype=float32)
        >>> # North pole → +z
        >>> lonlat_to_cartesian3d(
        ...     jnp.array([[0.0, 0.5 * jnp.pi]])
        ... )[:, 2]
        Array([1.], dtype=float32)
    """
    _validate_lonlat_shape(lonlat)
    _validate_input_unit(input_unit)

    angles = deg2rad(lonlat) if input_unit == "degrees" else lonlat
    lon = angles[:, 0]
    lat = angles[:, 1]
    cos_lat = jnp.cos(lat)
    return jnp.stack(
        [
            cos_lat * jnp.cos(lon),
            cos_lat * jnp.sin(lon),
            jnp.sin(lat),
        ],
        axis=-1,
    )
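A worked instance of the parameterization, evaluated directly with jax.numpy rather than through pyrox. At lon = 45°, lat = 30°, the formula gives x = cos 30° cos 45° ≈ 0.6124, y = cos 30° sin 45° ≈ 0.6124, and z = sin 30° = 0.5. The result always has unit norm, because cos²ϕ (cos²λ + sin²λ) + sin²ϕ = 1.

```python
import jax.numpy as jnp

lon, lat = jnp.deg2rad(45.0), jnp.deg2rad(30.0)
xyz = jnp.array([
    jnp.cos(lat) * jnp.cos(lon),  # x
    jnp.cos(lat) * jnp.sin(lon),  # y
    jnp.sin(lat),                 # z
])
norm = jnp.linalg.norm(xyz)  # 1 up to float32 rounding
```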

pyrox.nn.cyclic_encode(angles)

Encode periodic inputs as concatenated cos/sin features.

Parameters:

angles (Float[Array, ' N'] | Float[Array, 'N D']): Angle vector (N,) or matrix (N, D) in radians. Required.

Returns:

Float[Array, 'N F']: (N, 2) for vector input or (N, 2 * D) for matrix input, laid out as [cos_0, ..., cos_{D-1}, sin_0, ..., sin_{D-1}].

Example

>>> import jax.numpy as jnp
>>> cyclic_encode(jnp.array([0.0, jnp.pi]))
Array([[ 1.0000000e+00,  0.0000000e+00],
       [-1.0000000e+00, -8.7422777e-08]], dtype=float32)
>>> # Multi-dimensional input: each column is encoded independently.
>>> cyclic_encode(jnp.zeros((3, 2))).shape
(3, 4)

Source code in src/pyrox/nn/_geo.py
def cyclic_encode(
    angles: Float[Array, " N"] | Float[Array, "N D"],
) -> Float[Array, "N F"]:
    """Encode periodic inputs as concatenated cos/sin features.

    Args:
        angles: Angle vector ``(N,)`` or matrix ``(N, D)`` in radians.

    Returns:
        ``(N, 2)`` for vector input or ``(N, 2 * D)`` for matrix input,
        laid out as ``[cos_0, ..., cos_{D-1}, sin_0, ..., sin_{D-1}]``.

    Example:
        >>> import jax.numpy as jnp
        >>> cyclic_encode(jnp.array([0.0, jnp.pi]))
        Array([[ 1.0000000e+00,  0.0000000e+00],
               [-1.0000000e+00, -8.7422777e-08]], dtype=float32)
        >>> # Multi-dimensional input: each column is encoded independently.
        >>> cyclic_encode(jnp.zeros((3, 2))).shape
        (3, 4)
    """
    if angles.ndim == 1:
        promoted = angles[:, None]
    elif angles.ndim == 2:
        promoted = angles
    else:
        raise ValueError(f"angles must be (N,) or (N, D); got shape {angles.shape}.")
    return jnp.concatenate([jnp.cos(promoted), jnp.sin(promoted)], axis=-1)
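For matrix input, all cosines come first and all sines follow, so column d's pair sits at indices d and D + d rather than side by side. This pure-JAX sketch reproduces that layout and recovers the angles with arctan2. Angles outside (-π, π] would come back wrapped into that interval. The `encode` helper is a local stand-in for cyclic_encode, written so the example runs without pyrox.

```python
import jax.numpy as jnp


def encode(angles):
    # Documented layout: (N, D) -> (N, 2D) = [cos_0..cos_{D-1}, sin_0..sin_{D-1}].
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


angles = jnp.array([[0.5, -2.0], [3.0, 1.0]])  # N=2 rows, D=2 columns, radians
feats = encode(angles)
d = angles.shape[1]
cos_block, sin_block = feats[:, :d], feats[:, d:]
# arctan2(sin, cos) inverts the encoding for angles in (-pi, pi].
recovered = jnp.arctan2(sin_block, cos_block)
```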

pyrox.nn.spherical_harmonic_encode(lonlat, l_max, *, input_unit='radians')

Lift lon/lat to S^2 and evaluate real spherical harmonics.

Parameters:

lonlat (Float[Array, 'N 2']): Longitude/latitude matrix of shape (N, 2). Required.
l_max (int): Maximum harmonic degree. Required.
input_unit (Literal['degrees', 'radians']): Whether lonlat is in "degrees" or "radians". Default: 'radians'.

Returns:

Float[Array, 'N M']: Real spherical-harmonic features of shape (N, (l_max + 1)^2).

Example

>>> import jax.numpy as jnp
>>> lonlat = jnp.array([[0.0, 0.0], [1.5707964, 0.0]])
>>> spherical_harmonic_encode(lonlat, l_max=3).shape
(2, 16)
>>> # Pairs with the GP side for a consistent basis:
>>> spherical_harmonic_encode(lonlat, l_max=0).shape
(2, 1)

Source code in src/pyrox/nn/_geo.py
def spherical_harmonic_encode(
    lonlat: Float[Array, "N 2"],
    l_max: int,
    *,
    input_unit: Literal["degrees", "radians"] = "radians",
) -> Float[Array, "N M"]:
    """Lift lon/lat to :math:`S^2` and evaluate real spherical harmonics.

    Args:
        lonlat: Longitude/latitude matrix of shape ``(N, 2)``.
        l_max: Maximum harmonic degree.
        input_unit: Whether ``lonlat`` is in ``"degrees"`` or
            ``"radians"``.

    Returns:
        Real spherical-harmonic features of shape ``(N, (l_max + 1)^2)``.

    Example:
        >>> import jax.numpy as jnp
        >>> lonlat = jnp.array([[0.0, 0.0], [1.5707964, 0.0]])
        >>> spherical_harmonic_encode(lonlat, l_max=3).shape
        (2, 16)
        >>> # Pairs with the GP side for a consistent basis:
        >>> spherical_harmonic_encode(lonlat, l_max=0).shape
        (2, 1)
    """
    unit_xyz = lonlat_to_cartesian3d(lonlat, input_unit=input_unit)
    return real_spherical_harmonics(unit_xyz, l_max=l_max)