Skip to content

Spherical Harmonic Transform

SphericalHarmonicTransform

Bases: Module

Full 2D Spherical Harmonic Transform (SHT).

Wraps a :class:SphericalGrid2D with forward/inverse methods. The Associated Legendre Polynomial matrices are precomputed once at construction time (scipy call) and stored as JAX arrays.

Mathematical Formulation

Expansion in orthonormal spherical harmonics:

u(θ, φ) = Σₗ Σₘ û(l, m) · P̃ₗᵐ(cos θ) · e^{imφ}

Forward SHT:

Step 1 (FFT in longitude):
    ûₘ(θⱼ) = Σₖ u(θⱼ, φₖ) · e^{−imφₖ}
Step 2 (Legendre transform in colatitude):
    û(l, m) = Σⱼ wⱼ · P̃ₗᵐ(cos θⱼ) · ûₘ(θⱼ)

Inverse SHT:

Step 1 (synthesis in θ):
    ûₘ(θⱼ) = Σₗ û(l, m) · P̃ₗᵐ(cos θⱼ)
Step 2 (IFFT in φ):
    u(θⱼ, φₖ) = IFFTₖ(ûₘ(θⱼ))

Attributes

grid : SphericalGrid2D Underlying lat-lon grid carrying the ALP matrices.

Notes

The _P_lm attribute lives on :attr:grid; this class is a thin, more-semantic wrapper around :meth:SphericalGrid2D.transform.

Examples

import jax.numpy as jnp grid = SphericalGrid2D.from_N_L(Nx=32, Ny=16) sht = SphericalHarmonicTransform(grid=grid) PHI, THETA = grid.X u = jnp.sin(THETA) * jnp.cos(PHI) # = Re Y₁¹(θ, φ) up to a constant u_hat = sht.to_spectral(u) u_back = sht.from_spectral(u_hat) # ≈ u

Source code in spectraldiffx/_src/spherical/harmonics.py
class SphericalHarmonicTransform(eqx.Module):
    """Full 2D Spherical Harmonic Transform (SHT).

    Wraps a :class:`SphericalGrid2D` with forward/inverse methods.  The
    Associated Legendre Polynomial matrices are precomputed once at
    construction time (scipy call) and stored as JAX arrays.

    Mathematical Formulation
    ------------------------
    Expansion in orthonormal spherical harmonics:

        u(θ, φ) = Σₗ Σₘ û(l, m) · P̃ₗᵐ(cos θ) · e^{imφ}

    Forward SHT:

        Step 1 (FFT in longitude):
            ûₘ(θⱼ) = Σₖ u(θⱼ, φₖ) · e^{−imφₖ}
        Step 2 (Legendre transform in colatitude):
            û(l, m) = Σⱼ wⱼ · P̃ₗᵐ(cos θⱼ) · ûₘ(θⱼ)

    Inverse SHT:

        Step 1 (synthesis in θ):
            ûₘ(θⱼ) = Σₗ û(l, m) · P̃ₗᵐ(cos θⱼ)
        Step 2 (IFFT in φ):
            u(θⱼ, φₖ) = IFFTₖ(ûₘ(θⱼ))

    Attributes
    ----------
    grid : SphericalGrid2D
        Underlying lat-lon grid carrying the ALP matrices.

    Notes
    -----
    The ``_P_lm`` attribute lives on :attr:`grid`; this class is a thin,
    more-semantic wrapper around :meth:`SphericalGrid2D.transform`.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> grid = SphericalGrid2D.from_N_L(Nx=32, Ny=16)
    >>> sht = SphericalHarmonicTransform(grid=grid)
    >>> PHI, THETA = grid.X
    >>> u = jnp.sin(THETA) * jnp.cos(PHI)  # = Re Y₁¹(θ, φ) up to a constant
    >>> u_hat = sht.to_spectral(u)
    >>> u_back = sht.from_spectral(u_hat)  # ≈ u
    """

    grid: SphericalGrid2D

    def to_spectral(self, u: Float[Array, "Nlat Nlon"]) -> Complex[Array, "Nl Nm"]:
        """Forward SHT: physical u(θ, φ) → spectral û(l, m).

            û(l, m) = Σⱼ wⱼ · P̃ₗᵐ(cos θⱼ) · FFTφ{u(θⱼ, ·)}[m]

        Parameters
        ----------
        u : Float[Array, "Nlat Nlon"]
            Real-valued physical field on the (Nlat, Nlon) lat-lon grid.

        Returns
        -------
        Complex[Array, "Nl Nm"]
            Spectral coefficients.  Rows index harmonic degree l;
            columns index zonal wavenumber m in FFT order.
        """
        return self.grid.transform(u, inverse=False)

    def from_spectral(
        self, u_hat: Complex[Array, "Nl Nm"]
    ) -> Float[Array, "Nlat Nlon"]:
        """Inverse SHT: spectral û(l, m) → physical u(θ, φ)."""
        return self.grid.transform(u_hat, inverse=True)

    def to_spectral_1d(self, u_col: Float[Array, "Nlat"]) -> Float[Array, "Nl"]:
        """1D forward Discrete Legendre Transform (zonal mean, m=0)."""
        # m=0 is at fft_idx=0
        P0 = self.grid._P_lm[0]  # (Nl, Nlat)
        w = self.grid._weights_y  # (Nlat,)
        return P0 @ (w * u_col)

    def from_spectral_1d(self, c: Float[Array, "Nl"]) -> Float[Array, "Nlat"]:
        """1D inverse DLT (zonal mean, m=0)."""
        P0 = self.grid._P_lm[0]  # (Nl, Nlat)
        return P0.T @ c

Functions

to_spectral(u)

Forward SHT: physical u(θ, φ) → spectral û(l, m).

û(l, m) = Σⱼ wⱼ · P̃ₗᵐ(cos θⱼ) · FFTφ{u(θⱼ, ·)}[m]
Parameters

u : Float[Array, "Nlat Nlon"] Real-valued physical field on the (Nlat, Nlon) lat-lon grid.

Returns

Complex[Array, "Nl Nm"] Spectral coefficients. Rows index harmonic degree l; columns index zonal wavenumber m in FFT order.

Source code in spectraldiffx/_src/spherical/harmonics.py
def to_spectral(self, u: Float[Array, "Nlat Nlon"]) -> Complex[Array, "Nl Nm"]:
    """Forward SHT: physical u(θ, φ) → spectral û(l, m).

        û(l, m) = Σⱼ wⱼ · P̃ₗᵐ(cos θⱼ) · FFTφ{u(θⱼ, ·)}[m]

    Parameters
    ----------
    u : Float[Array, "Nlat Nlon"]
        Real-valued physical field on the (Nlat, Nlon) lat-lon grid.

    Returns
    -------
    Complex[Array, "Nl Nm"]
        Spectral coefficients.  Rows index harmonic degree l;
        columns index zonal wavenumber m in FFT order.
    """
    return self.grid.transform(u, inverse=False)

from_spectral(u_hat)

Inverse SHT: spectral û(l, m) → physical u(θ, φ).

Source code in spectraldiffx/_src/spherical/harmonics.py
def from_spectral(
    self, u_hat: Complex[Array, "Nl Nm"]
) -> Float[Array, "Nlat Nlon"]:
    """Inverse SHT: spectral û(l, m) → physical u(θ, φ)."""
    return self.grid.transform(u_hat, inverse=True)

to_spectral_1d(u_col)

1D forward Discrete Legendre Transform (zonal mean, m=0).

Source code in spectraldiffx/_src/spherical/harmonics.py
def to_spectral_1d(self, u_col: Float[Array, "Nlat"]) -> Float[Array, "Nl"]:
    """1D forward Discrete Legendre Transform (zonal mean, m=0)."""
    # m=0 is at fft_idx=0
    P0 = self.grid._P_lm[0]  # (Nl, Nlat)
    w = self.grid._weights_y  # (Nlat,)
    return P0 @ (w * u_col)

from_spectral_1d(c)

1D inverse DLT (zonal mean, m=0).

Source code in spectraldiffx/_src/spherical/harmonics.py
def from_spectral_1d(self, c: Float[Array, "Nl"]) -> Float[Array, "Nlat"]:
    """1D inverse DLT (zonal mean, m=0)."""
    P0 = self.grid._P_lm[0]  # (Nl, Nlat)
    return P0.T @ c