
Spherical Harmonic Solvers

SphericalPoissonSolver

Bases: Module

Spectral Poisson solver on the sphere: ∇²φ = f.

In SHT-coefficient space the mode-by-mode inversion is

φ̂(l, m) = −f̂(l, m) · [l(l+1)/R²]⁻¹    (l ≥ 1)

The l=0 mode is set to zero when zero_mean=True (default) since ∇² annihilates constants on the sphere.
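To make the inversion formula concrete, here is a minimal sketch that works directly on a 1D array of coefficients indexed by degree l. The helper name `invert_laplacian_coeffs` is hypothetical, and the code only mirrors the formula; it does not call spectraldiffx. It then applies the Laplacian back in coefficient space (multiplying by −l(l+1)/R²) to confirm that f̂ is recovered.

```python
import jax.numpy as jnp


def invert_laplacian_coeffs(f_hat, l, R, zero_mean=True):
    """Hypothetical helper: mode-by-mode inverse of the sphere Laplacian.

    f_hat : coefficients indexed by degree l along axis 0
    l     : degrees, same length as f_hat
    """
    eigenval = l * (l + 1) / R**2                      # -∇² eigenvalue for degree l
    denom = jnp.where(eigenval == 0.0, 1.0, eigenval)  # avoid dividing by zero at l = 0
    phi_hat = -f_hat / denom
    if zero_mean:
        phi_hat = jnp.where(l == 0.0, 0.0, phi_hat)    # gauge fix: drop the constant mode
    return phi_hat


R = 2.0
l = jnp.arange(6.0)
f_hat = jnp.array([0.0, 1.0, -0.5, 0.25, 0.0, 2.0])    # zero l = 0 entry: zero-mean source
phi_hat = invert_laplacian_coeffs(f_hat, l, R)

# Apply ∇² in coefficient space and compare with the source.
residual = -l * (l + 1) / R**2 * phi_hat - f_hat
print(jnp.max(jnp.abs(residual)))  # ≈ 0 (round-off)
```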

Attributes

grid : SphericalGrid1D or SphericalGrid2D
    Underlying spherical grid.

Examples

import jax.numpy as jnp
grid = SphericalGrid2D.from_N_L(Nx=32, Ny=16)
solver = SphericalPoissonSolver(grid=grid)
PHI, THETA = grid.X

# The Laplacian of cos(θ) is −2 cos(θ)/R², so use that as the Poisson right-hand side:
R = grid.Ly / jnp.pi
f = -2.0 * jnp.cos(THETA) / R**2
phi = solver.solve(f)  # ≈ cos(θ) up to an additive constant
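The example above can be reproduced end to end for the axisymmetric (1D) case without the library. The sketch below is an illustration, not spectraldiffx's DLT: it assumes θ is colatitude, so μ = cos θ and cos θ = P₁(μ), and it uses NumPy's Gauss-Legendre quadrature as a stand-in Legendre transform. The function name `poisson_axisymmetric` is hypothetical. The pipeline is the same one `solve` follows: forward transform, divide each degree by −l(l+1)/R², inverse transform.

```python
import numpy as np
from numpy.polynomial import legendre


def poisson_axisymmetric(f_func, R=1.0, N=32):
    """Hypothetical zonal Poisson solve via a Gauss-Legendre transform.

    f_func maps mu = cos(theta) to source values; returns (mu, phi).
    """
    mu, w = legendre.leggauss(N)          # quadrature nodes and weights on [-1, 1]
    L = N - 1
    V = legendre.legvander(mu, L)         # V[i, l] = P_l(mu_i)
    l = np.arange(L + 1)

    # Forward transform: f_hat_l = (2l + 1)/2 * ∫ f P_l dmu, evaluated by quadrature
    f_hat = (2 * l + 1) / 2 * (V.T @ (w * f_func(mu)))

    # Mode-by-mode inversion using ∇² P_l = -l(l+1)/R² P_l; l = 0 set to zero (gauge)
    eig = l * (l + 1) / R**2
    phi_hat = np.where(l == 0, 0.0, -f_hat / np.where(eig == 0, 1.0, eig))

    return mu, V @ phi_hat                # inverse transform (synthesis at the nodes)


R = 2.0
mu, phi = poisson_axisymmetric(lambda m: -2.0 * m / R**2, R=R)
print(np.max(np.abs(phi - mu)))  # ≈ 0 (round-off): phi ≈ cos(theta)
```

Because cos θ has zero mean over the sphere, zeroing the l = 0 mode leaves the answer unchanged here, so the additive constant in the example comment is zero in this case.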

Source code in spectraldiffx/_src/spherical/solvers.py
class SphericalPoissonSolver(eqx.Module):
    """Spectral Poisson solver on the sphere:  ∇²φ = f.

    In SHT-coefficient space the mode-by-mode inversion is

        φ̂(l, m) = −f̂(l, m) · [l(l+1)/R²]⁻¹    (l ≥ 1)

    The l=0 mode is set to zero when ``zero_mean=True`` (default) since
    ∇² annihilates constants on the sphere.

    Attributes
    ----------
    grid : SphericalGrid1D or SphericalGrid2D
        Underlying spherical grid.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> grid = SphericalGrid2D.from_N_L(Nx=32, Ny=16)
    >>> solver = SphericalPoissonSolver(grid=grid)
    >>> PHI, THETA = grid.X
    >>> # Laplacian of cos(θ) is −2 cos(θ)/R², so Poisson RHS is that:
    >>> R = grid.Ly / jnp.pi
    >>> f = -2.0 * jnp.cos(THETA) / R**2
    >>> phi = solver.solve(f)  # ≈ cos(θ) up to an additive constant
    """

    grid: SphericalGrid1D | SphericalGrid2D

    def solve(
        self,
        f: Num[Array, ...],
        zero_mean: bool = True,
        spectral: bool = False,
    ) -> Float[Array, ...]:
        """Solve ∇²φ = f on the sphere.

        Parameters
        ----------
        f : Num[Array, ...]
            Source field.  Shape ``(N,)`` for 1D or ``(Nlat, Nlon)`` for 2D.
        zero_mean : bool
            If ``True``, pin the l=0 mode of φ to zero (gauge fix).  Required
            for well-posedness of Poisson on the sphere.
        spectral : bool
            If ``True``, ``f`` is already a DLT/SHT coefficient array.

        Returns
        -------
        Float[Array, ...]
            Solution in physical space (same shape as ``f``).
        """
        R = _sphere_radius(self.grid)
        f_hat = f if spectral else self.grid.transform(f)
        l = self.grid.l
        eigenval = l * (l + 1) / (R**2)

        if isinstance(self.grid, SphericalGrid1D):
            # Guard the l=0 division (eigenvalue 0). With zero_mean=True that mode is
            # zeroed below; otherwise it keeps -f̂(l=0), which vanishes only if mean(f) = 0.
            denom = jnp.where(eigenval == 0.0, 1.0, eigenval)
            phi_hat = -f_hat / denom
            if zero_mean:
                phi_hat = jnp.where(l == 0.0, 0.0, phi_hat)
        else:
            denom = jnp.where(eigenval[:, None] == 0.0, 1.0, eigenval[:, None])
            phi_hat = -f_hat / denom
            if zero_mean:
                phi_hat = jnp.where(l[:, None] == 0.0, 0.0, phi_hat)

        return self.grid.transform(phi_hat, inverse=True)
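In the 2D branch the eigenvalue vector is indexed with `[:, None]`, which implies a coefficient layout with degree l along axis 0 and order m along axis 1. The sketch below assumes that layout (a hypothetical `(Nl, Nm)` array, not taken from the grid API) and shows how the trailing axis lets one eigenvalue divide every order m of a degree. It also shows what the guard leaves in the l = 0 row when the zero-mean mask is not applied.

```python
import jax.numpy as jnp

# Hypothetical 2D coefficient layout: rows indexed by degree l, columns by order m.
Nl, Nm = 4, 3
R = 1.0
l = jnp.arange(Nl, dtype=jnp.float32)
f_hat = jnp.ones((Nl, Nm))

eigenval = l * (l + 1) / R**2                                          # shape (Nl,)
denom = jnp.where(eigenval[:, None] == 0.0, 1.0, eigenval[:, None])    # (Nl, 1) broadcasts over m
phi_raw = -f_hat / denom                               # l = 0 row is -f_hat[0] / 1, not a solution
phi_hat = jnp.where(l[:, None] == 0.0, 0.0, phi_raw)   # zero_mean=True: l = 0 row zeroed

print(phi_raw[0])     # the l = 0 row still equals -f_hat[0, :]
print(phi_hat[:, 0])  # 0, then -R²/(l(l+1)) for l = 1, 2, 3
```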

Functions

solve(f, zero_mean=True, spectral=False)

Solve ∇²φ = f on the sphere.

Parameters

f : Num[Array, ...]
    Source field. Shape (N,) for 1D or (Nlat, Nlon) for 2D.
zero_mean : bool
    If True, pin the l=0 mode of φ to zero (gauge fix). Required for well-posedness of Poisson on the sphere.
spectral : bool
    If True, f is already a DLT/SHT coefficient array.

Returns

Float[Array, ...]
    Solution in physical space (same shape as f).
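The well-posedness remark has a concrete meaning: integrating ∇²φ over the closed sphere gives zero, so ∇²φ = f can only be solved exactly when f has zero spherical mean, and φ is then fixed only up to a constant, which `zero_mean` pins. The sketch below, for an axisymmetric field and using Gauss-Legendre weights in μ = cos θ (an assumption for illustration; `sphere_mean` is a hypothetical helper), computes that mean and removes it to obtain a compatible source.

```python
import numpy as np
from numpy.polynomial import legendre

mu, w = legendre.leggauss(16)   # Gauss-Legendre nodes in mu = cos(theta); weights sum to 2


def sphere_mean(values):
    # Axisymmetric field: (1/4π) ∬ f dΩ = (1/2) ∫_{-1}^{1} f dmu
    return 0.5 * np.sum(w * values)


f = 1.0 + 3.0 * mu**2            # spherical mean is 1 + 3 * (1/3) = 2
print(sphere_mean(f))            # ≈ 2.0
f_compatible = f - sphere_mean(f)
print(sphere_mean(f_compatible)) # ≈ 0: a solvable Poisson right-hand side
```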


SphericalHelmholtzSolver

Bases: Module

Spectral Helmholtz solver on the sphere: (∇² − α) φ = f.

In SHT-coefficient space:

φ̂(l, m) = −f̂(l, m) / [l(l+1)/R² + α]

Non-singular for α > 0; for α = 0 this reduces to Poisson and the l=0 gauge (zero-mean) is enforced by default.
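A minimal coefficient-space sketch of the Helmholtz inversion follows, with a hypothetical helper name. Since l(l+1)/R² + α > 0 for every degree when α > 0, no guard is needed and the l = 0 coefficient is a genuine part of the solution; the residual check applies (∇² − α) back as multiplication by −(l(l+1)/R² + α).

```python
import jax.numpy as jnp


def invert_helmholtz_coeffs(f_hat, l, R, alpha):
    """Hypothetical helper: mode-by-mode inverse of (∇² − α) on the sphere."""
    denom = l * (l + 1) / R**2 + alpha   # strictly positive for all l when alpha > 0
    return -f_hat / denom


R, alpha = 1.0, 4.0
l = jnp.arange(5.0)
f_hat = jnp.array([3.0, -1.0, 0.5, 0.0, 2.0])
phi_hat = invert_helmholtz_coeffs(f_hat, l, R, alpha)

# Apply (∇² − α) in coefficient space: multiply by -(l(l+1)/R² + alpha).
residual = -(l * (l + 1) / R**2 + alpha) * phi_hat - f_hat
print(phi_hat[0])                   # l = 0 term: -3 / 4
print(jnp.max(jnp.abs(residual)))   # ≈ 0 (round-off)
```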

Attributes

grid : SphericalGrid1D or SphericalGrid2D
    Underlying spherical grid.

Examples

import jax.numpy as jnp
grid = SphericalGrid2D.from_N_L(Nx=32, Ny=16)
solver = SphericalHelmholtzSolver(grid=grid)
PHI, THETA = grid.X
R = grid.Ly / jnp.pi
alpha = 4.0

# For φ = cos θ: (∇² − α) φ = (−2/R² − α) cos θ
f = (-2.0 / R**2 - alpha) * jnp.cos(THETA)
phi = solver.solve(f, alpha=alpha, zero_mean=False)  # ≈ cos(θ)
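The example passes `zero_mean=False` for a reason: in the `solve` source for this class, `zero_mean=True` zeroes the l = 0 coefficient whatever the value of α. When α > 0 that coefficient is part of the true solution. A constant φ = c satisfies (∇² − α)c = −αc, so the sketch below feeds in a constant source and compares the two settings. It works on a hypothetical 1D coefficient array whose l = 0 entry stores the field's mean (an assumed normalization; any linear normalization gives the same conclusion).

```python
import jax.numpy as jnp

R, alpha, c = 1.0, 4.0, 0.3
l = jnp.arange(4.0)

# A constant source f = -alpha * c lives entirely in the l = 0 coefficient.
f_hat = jnp.zeros(4).at[0].set(-alpha * c)

denom = l * (l + 1) / R**2 + alpha
phi_kept = -f_hat / denom                         # like zero_mean=False: l = 0 entry is c
phi_zeroed = jnp.where(l == 0.0, 0.0, phi_kept)   # like zero_mean=True: constant part lost

print(phi_kept[0])     # ≈ 0.3, the exact solution φ = c
print(phi_zeroed[0])   # 0.0, so the solution is wrong by c
```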

Source code in spectraldiffx/_src/spherical/solvers.py
class SphericalHelmholtzSolver(eqx.Module):
    """Spectral Helmholtz solver on the sphere:  (∇² − α) φ = f.

    In SHT-coefficient space:

        φ̂(l, m) = −f̂(l, m) / [l(l+1)/R² + α]

    Non-singular for α > 0; for α = 0 this reduces to Poisson and the
    l=0 gauge (zero-mean) is enforced by default.

    Attributes
    ----------
    grid : SphericalGrid1D or SphericalGrid2D
        Underlying spherical grid.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> grid = SphericalGrid2D.from_N_L(Nx=32, Ny=16)
    >>> solver = SphericalHelmholtzSolver(grid=grid)
    >>> PHI, THETA = grid.X
    >>> R = grid.Ly / jnp.pi
    >>> alpha = 4.0
    >>> # For φ = cos θ: (∇² − α) φ = (−2/R² − α) cos θ
    >>> f = (-2.0 / R**2 - alpha) * jnp.cos(THETA)
    >>> phi = solver.solve(f, alpha=alpha, zero_mean=False)  # ≈ cos(θ)
    """

    grid: SphericalGrid1D | SphericalGrid2D

    def solve(
        self,
        f: Num[Array, ...],
        alpha: float = 0.0,
        zero_mean: bool = True,
        spectral: bool = False,
    ) -> Float[Array, ...]:
        """Solve (∇² − α) φ = f on the sphere.

        Parameters
        ----------
        f : Num[Array, ...]
            Source field (1D ``(N,)`` or 2D ``(Nlat, Nlon)``).
        alpha : float
            Helmholtz parameter (≥ 0).  α=0 falls back to Poisson.
        zero_mean : bool
            If ``True``, zero the l=0 mode of φ (gauge fix); needed when
            α = 0.  The mode is zeroed for any α, so pass ``False`` when
            α > 0 to keep the mean, whose coefficient is −f̂(l=0)/α.
        spectral : bool
            If ``True``, ``f`` is already a DLT/SHT coefficient array.

        Returns
        -------
        Float[Array, ...]
            Solution in physical space.
        """
        if alpha < 0:
            raise ValueError(f"alpha must be >= 0, got {alpha}")
        R = _sphere_radius(self.grid)
        f_hat = f if spectral else self.grid.transform(f)
        l = self.grid.l
        eigenval = l * (l + 1) / (R**2)

        if isinstance(self.grid, SphericalGrid1D):
            denom = eigenval + alpha
            denom_safe = jnp.where(denom == 0.0, 1.0, denom)
            phi_hat = -f_hat / denom_safe
            if zero_mean:
                phi_hat = jnp.where(l == 0.0, 0.0, phi_hat)
        else:
            denom = eigenval[:, None] + alpha
            denom_safe = jnp.where(denom == 0.0, 1.0, denom)
            phi_hat = -f_hat / denom_safe
            if zero_mean:
                phi_hat = jnp.where(l[:, None] == 0.0, 0.0, phi_hat)

        return self.grid.transform(phi_hat, inverse=True)
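`solve` checks `alpha` with a Python `if`, so `alpha` must be a concrete number when the code is traced; if `solve` (or a wrapper around it) is jitted with `alpha` as a traced argument, JAX raises a concretization error. One option, sketched here on a hypothetical standalone function `helmholtz_coeffs` that mirrors the coefficient step, is to mark `alpha` static with `jax.jit(..., static_argnames=...)`. `equinox.filter_jit` reaches the same result for a Python float, since it holds non-array arguments static.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="alpha")
def helmholtz_coeffs(f_hat, l, R, alpha=0.0):
    # Hypothetical stand-in for solve()'s coefficient step. Because alpha is static,
    # this Python-level check sees a concrete float rather than a tracer.
    if alpha < 0:
        raise ValueError(f"alpha must be >= 0, got {alpha}")
    denom = l * (l + 1) / R**2 + alpha
    return -f_hat / jnp.where(denom == 0.0, 1.0, denom)


l = jnp.arange(4.0)
f_hat = jnp.ones(4)
out = helmholtz_coeffs(f_hat, l, 1.0, alpha=2.0)
print(out)  # -1 / (l(l+1) + 2) for l = 0..3
```

The trade-off of a static `alpha` is one compilation per distinct value, which suits a few fixed Helmholtz parameters (for example, one per implicit time step) but not a sweep over many values.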

Functions

solve(f, alpha=0.0, zero_mean=True, spectral=False)

Solve (∇² − α) φ = f on the sphere.

Parameters

f : Num[Array, ...]
    Source field (1D (N,) or 2D (Nlat, Nlon)).
alpha : float
    Helmholtz parameter (≥ 0). α=0 falls back to Poisson.
zero_mean : bool
    If True, zero the l=0 mode of φ (gauge fix); needed when α = 0. The mode is zeroed for any α, so pass False when α > 0 to keep the mean, whose coefficient is −f̂(l=0)/α.
spectral : bool
    If True, f is already a DLT/SHT coefficient array.

Returns

Float[Array, ...]
    Solution in physical space.
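The statement that α = 0 falls back to Poisson can be checked at the coefficient level. The sketch below reimplements both coefficient steps as they appear in the source listings (the function name `helmholtz_branch` is hypothetical) and compares them for a zero-mean source, both at α = 0 and for a small positive α.

```python
import jax.numpy as jnp

R = 1.5
l = jnp.arange(6.0)
f_hat = jnp.array([0.0, 1.0, 2.0, -1.0, 0.5, 0.25])   # zero l = 0 entry: zero-mean source
eig = l * (l + 1) / R**2


def helmholtz_branch(alpha, zero_mean=True):
    # Mirrors the Helmholtz coefficient step: guard zero denominators, optionally zero l = 0.
    denom = eig + alpha
    phi_hat = -f_hat / jnp.where(denom == 0.0, 1.0, denom)
    return jnp.where(l == 0.0, 0.0, phi_hat) if zero_mean else phi_hat


# Poisson coefficient step with zero_mean=True.
poisson = jnp.where(l == 0.0, 0.0, -f_hat / jnp.where(eig == 0.0, 1.0, eig))

exact_match = jnp.allclose(helmholtz_branch(0.0), poisson)               # alpha = 0: identical
near_match = jnp.allclose(helmholtz_branch(1e-6), poisson, rtol=1e-4)    # small alpha: close
print(exact_match, near_match)
```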
