Chebyshev Helmholtz Solvers

ChebyshevHelmholtzSolver1D

Bases: Module

1D Chebyshev-collocation Helmholtz/Poisson solver with Dirichlet or Neumann BCs.

Solves the boundary-value problem on [−L, L]:

d²u/dx² − α·u = f(x),     x ∈ [−L, L]

with boundary conditions selected via bc_type:

Dirichlet: u(+L) = bc_right,       u(−L) = bc_left
Neumann:   u'(+L) = bc_right,     u'(−L) = bc_left

For α = 0 this reduces to Poisson.

Method — Boundary-Row Replacement

On Gauss–Lobatto nodes the endpoints x[0]=+L and x[N]=−L are collocation points, so we discretise as

A u = b,   A = D² − α·I,   b = f

and then overwrite rows 0 and N with the boundary equations:

Dirichlet : row 0 ← eᵀ₀,       b[0]  ← bc_right
            row N ← eᵀ_N,      b[N]  ← bc_left
Neumann   : row 0 ← D[0, :],   b[0]  ← bc_right
            row N ← D[N, :],   b[N]  ← bc_left

The resulting (N+1)×(N+1) linear system is solved with jnp.linalg.solve.
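The boundary-row replacement can be sketched independently of the library. The snippet below is a minimal illustration in plain NumPy, not the spectraldiffx implementation: cheb_diff_matrix is a hypothetical stand-in for the grid's D matrix (built with the standard Chebyshev collocation formula), and N = 32, L = 1 are arbitrary choices.

```python
import numpy as np

def cheb_diff_matrix(N, L=1.0):
    """Hypothetical helper: Gauss-Lobatto nodes on [-L, L] and the matrix D.

    Nodes are ordered x[0] = +L, ..., x[N] = -L, matching the text above.
    """
    k = np.arange(N + 1)
    x = np.cos(np.pi * k / N)                       # nodes on [-1, 1]
    c = np.where((k == 0) | (k == N), 2.0, 1.0) * (-1.0) ** k
    dX = x[:, None] - x[None, :] + np.eye(N + 1)    # identity avoids 0/0 on the diagonal
    D = (c[:, None] / c[None, :]) / dX
    D -= np.diag(D.sum(axis=1))                     # diagonal from the negative-sum trick
    return L * x, D / L                             # rescale nodes and derivative to [-L, L]

N, alpha = 32, 0.0
x, D = cheb_diff_matrix(N)

# Interior operator A = D^2 - alpha*I with right-hand side b = f.
A = D @ D - alpha * np.eye(N + 1)
b = -(np.pi**2) * np.sin(np.pi * x)

# Dirichlet rows: row 0 enforces u(+L) = bc_right, row N enforces u(-L) = bc_left.
bc_left, bc_right = 0.0, 0.0
A[0, :] = 0.0
A[0, 0] = 1.0
A[N, :] = 0.0
A[N, N] = 1.0
b[0], b[N] = bc_right, bc_left

u = np.linalg.solve(A, b)
max_err = np.max(np.abs(u - np.sin(np.pi * x)))     # compare with the analytic solution
print(f"max error: {max_err:.2e}")
```

NumPy arrays are mutable, so the rows are assigned in place here; with JAX arrays the same updates are written as A.at[0, :].set(...), as in the solver source.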

Gauss-node grids do not include the endpoints, so the boundary-row method does not apply to them; solve() checks the grid's node type and raises ValueError.
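The difference between the two node families is easy to check numerically. This is a minimal NumPy sketch assuming the standard definitions (Gauss–Lobatto nodes are the extrema of T_N, Gauss nodes are the roots of T_{N+1}); how ChebyshevGrid1D sizes and indexes its Gauss nodes is not shown in this page and may differ.

```python
import numpy as np

N, L = 8, 1.0
k = np.arange(N + 1)

# Gauss-Lobatto nodes: extrema of T_N, endpoints included.
x_gl = L * np.cos(np.pi * k / N)

# Gauss nodes (assumed definition): roots of T_{N+1}, all strictly inside (-L, L).
x_g = L * np.cos(np.pi * (2 * k + 1) / (2 * (N + 1)))

print(x_gl[0], x_gl[-1])        # first and last Gauss-Lobatto nodes sit on the boundary
print(x_g.max(), x_g.min())     # largest and smallest Gauss nodes stay inside the interval
```

With no node on the boundary there is no row of the system that corresponds to x = ±L, which is why the row-replacement approach needs the Gauss–Lobatto family.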

Pure Neumann + Poisson (α = 0) is only solvable up to a constant (constant nullspace of the discretisation); the solver pins the gauge inside the linear system by replacing one interior equation with the point constraint u[N//2] = 0, so the solve is well-posed. Shift the returned field by any constant if a different gauge is needed.
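The gauge pinning can be illustrated with the same kind of standalone sketch. The code below is an assumption-laden NumPy illustration rather than the library code: cheb_diff_matrix is a hypothetical helper standing in for the grid's D matrix, and N = 32 is an arbitrary even choice so that the pinned node x[N//2] is the midpoint x = 0.

```python
import numpy as np

def cheb_diff_matrix(N, L=1.0):
    """Hypothetical helper: Gauss-Lobatto nodes (x[0] = +L, ..., x[N] = -L) and D."""
    k = np.arange(N + 1)
    x = np.cos(np.pi * k / N)
    c = np.where((k == 0) | (k == N), 2.0, 1.0) * (-1.0) ** k
    dX = x[:, None] - x[None, :] + np.eye(N + 1)    # identity avoids 0/0 on the diagonal
    D = (c[:, None] / c[None, :]) / dX
    D -= np.diag(D.sum(axis=1))                     # rows of D sum to zero, so D @ 1 = 0
    return L * x, D / L

N = 32
x, D = cheb_diff_matrix(N)

# Pure-Neumann Poisson operator: D^2 with rows 0 and N replaced by rows of D.
A = D @ D
A[0, :] = D[0, :]
A[N, :] = D[N, :]
b = np.cos(np.pi * x)
b[0], b[N] = 0.0, 0.0                               # u'(+L) = u'(-L) = 0

# Constants lie in the nullspace, so A is singular before the gauge is pinned.
null_residual = np.max(np.abs(A @ np.ones(N + 1)))

# Pin the gauge: replace the interior equation at index N//2 with u[N//2] = 0.
mid = N // 2
A[mid, :] = 0.0
A[mid, mid] = 1.0
b[mid] = 0.0

u = np.linalg.solve(A, b)

# With N even, x[mid] = 0, so the gauged analytic solution is (1 - cos(pi x)) / pi^2.
u_exact = (1.0 - np.cos(np.pi * x)) / np.pi**2
max_err = np.max(np.abs(u - u_exact))
print(f"|A @ 1| before pinning: {null_residual:.1e}, max error: {max_err:.2e}")
```

To move to another gauge, for example zero mean, subtract the appropriate constant from u after the solve.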

Attributes

grid : ChebyshevGrid1D
    Must use 'gauss-lobatto' nodes.

Examples

Solve u″ = −π² sin(πx) with u(±1) = 0 (analytic solution u = sin(πx)):

import jax.numpy as jnp
grid = ChebyshevGrid1D.from_N_L(N=32, L=1.0)
solver = ChebyshevHelmholtzSolver1D(grid=grid)
x = grid.x
f = -(jnp.pi**2) * jnp.sin(jnp.pi * x)
u = solver.solve(f, alpha=0.0, bc_left=0.0, bc_right=0.0)

Neumann example — solve u″ = cos(πx) with u'(±1) = 0:

f = jnp.cos(jnp.pi * grid.x)
u = solver.solve(f, alpha=0.0, bc_type="neumann")

Source code in spectraldiffx/_src/chebyshev/solvers.py
class ChebyshevHelmholtzSolver1D(eqx.Module):
    """1D Chebyshev-collocation Helmholtz/Poisson solver with Dirichlet or Neumann BCs.

    Solves the boundary-value problem on [−L, L]:

        d²u/dx² − α·u = f(x),     x ∈ [−L, L]

    with boundary conditions selected via ``bc_type``:

        Dirichlet: u(+L) = bc_right,       u(−L) = bc_left
        Neumann:   u'(+L) = bc_right,     u'(−L) = bc_left

    For α = 0 this reduces to Poisson.

    Method — Boundary-Row Replacement
    ---------------------------------
    On Gauss–Lobatto nodes the endpoints x[0]=+L and x[N]=−L are collocation
    points, so we discretise as

        A u = b,   A = D² − α·I,   b = f

    and then overwrite rows 0 and N with the boundary equations:

        Dirichlet : row 0 ← eᵀ₀,       b[0]  ← bc_right
                    row N ← eᵀ_N,      b[N]  ← bc_left
        Neumann   : row 0 ← D[0, :],   b[0]  ← bc_right
                    row N ← D[N, :],   b[N]  ← bc_left

    The resulting (N+1)×(N+1) linear system is solved with :func:`jnp.linalg.solve`.

    Gauss-node grids do not include the endpoints, so the boundary-row
    method does not apply to them; ``solve`` checks the grid's node type
    and raises ``ValueError``.

    Pure Neumann + Poisson (α = 0) is only solvable up to a constant
    (constant nullspace of the discretisation); the solver pins the gauge
    inside the linear system by replacing one interior equation with the
    point constraint ``u[N//2] = 0``, so the solve is well-posed.  Shift
    the returned field by any constant if a different gauge is needed.

    Attributes
    ----------
    grid : ChebyshevGrid1D
        Must use ``'gauss-lobatto'`` nodes.

    Examples
    --------
    Solve u″ = −π² sin(πx) with u(±1) = 0 (analytic solution u = sin(πx)):

    >>> import jax.numpy as jnp
    >>> grid = ChebyshevGrid1D.from_N_L(N=32, L=1.0)
    >>> solver = ChebyshevHelmholtzSolver1D(grid=grid)
    >>> x = grid.x
    >>> f = -(jnp.pi**2) * jnp.sin(jnp.pi * x)
    >>> u = solver.solve(f, alpha=0.0, bc_left=0.0, bc_right=0.0)

    Neumann example — solve u″ = cos(πx) with u'(±1) = 0:

    >>> f = jnp.cos(jnp.pi * grid.x)
    >>> u = solver.solve(f, alpha=0.0, bc_type="neumann")
    """

    grid: ChebyshevGrid1D

    def solve(
        self,
        f: Num[Array, "Npts"],
        alpha: float = 0.0,
        bc_left: float = 0.0,
        bc_right: float = 0.0,
        bc_type: BCType = "dirichlet",
    ) -> Float[Array, "Npts"]:
        """Solve (d²/dx² − α) u = f on [−L, L] with Dirichlet or Neumann BCs.

        Parameters
        ----------
        f : Num[Array, "Npts"]
            Source term sampled at the N+1 Gauss–Lobatto nodes
            (ordered x[0]=+L, …, x[N]=−L).
        alpha : float
            Helmholtz parameter (≥ 0).  α=0 gives the Poisson equation.
        bc_left : float
            BC value at x = −L.  Dirichlet: u(−L); Neumann: u'(−L).
        bc_right : float
            BC value at x = +L.  Dirichlet: u(+L); Neumann: u'(+L).
        bc_type : {"dirichlet", "neumann"}
            Boundary-condition flavour.

        Returns
        -------
        Float[Array, "Npts"]
            Solution at the N+1 GL nodes.

        Raises
        ------
        ValueError
            If the grid uses Gauss nodes, the length of ``f`` is wrong,
            ``alpha < 0``, or ``bc_type`` is not ``"dirichlet"`` or
            ``"neumann"``.
        """
        if self.grid.node_type != "gauss-lobatto":
            raise ValueError(
                "ChebyshevHelmholtzSolver1D requires 'gauss-lobatto' nodes — "
                "the boundary-row method evaluates u (or u') at the endpoints "
                "x[0]=+L and x[N]=−L, which Gauss nodes exclude. Got "
                f"node_type='{self.grid.node_type}'."
            )
        if f.shape[0] != self.grid.N + 1:
            raise ValueError(
                f"f must have length N+1={self.grid.N + 1} (Gauss–Lobatto), "
                f"got length {f.shape[0]}."
            )
        if alpha < 0:
            raise ValueError(f"alpha must be >= 0, got {alpha}")
        if bc_type not in ("dirichlet", "neumann"):
            raise ValueError(
                f"bc_type must be 'dirichlet' or 'neumann', got {bc_type!r}"
            )

        D = self.grid.D
        N = self.grid.N

        # A = D² − α·I  (interior operator; boundary rows replaced below)
        A = D @ D - alpha * jnp.eye(N + 1)
        b = f

        if bc_type == "dirichlet":
            # Row 0 → u(+L) = bc_right, row N → u(−L) = bc_left
            A = A.at[0, :].set(0.0).at[0, 0].set(1.0)
            A = A.at[N, :].set(0.0).at[N, N].set(1.0)
        else:  # neumann
            # Row 0 → u'(+L) = D[0,:]·u, row N → u'(−L) = D[N,:]·u
            A = A.at[0, :].set(D[0, :])
            A = A.at[N, :].set(D[N, :])
        b = b.at[0].set(bc_right)
        b = b.at[N].set(bc_left)

        if bc_type == "neumann" and alpha == 0.0:
            # Pure-Neumann Poisson is rank-deficient (constant nullspace:
            # D²·1 = 0 and D·1 = 0, so A·1 = 0).  Pin a gauge inside the
            # linear system by replacing one interior equation with
            # u[middle] = 0.  This removes the singularity before the solve,
            # making it robust across RHS / grid sizes.  The user can shift
            # the result by any constant afterwards if a different gauge is
            # needed.
            mid = N // 2
            gauge_row = jnp.zeros(N + 1).at[mid].set(1.0)
            A = A.at[mid, :].set(gauge_row)
            b = b.at[mid].set(0.0)

        return jnp.linalg.solve(A, b)

Functions

solve(f, alpha=0.0, bc_left=0.0, bc_right=0.0, bc_type='dirichlet')

Solve (d²/dx² − α) u = f on [−L, L] with Dirichlet or Neumann BCs.

Parameters

f : Num[Array, "Npts"]
    Source term sampled at the N+1 Gauss–Lobatto nodes (ordered x[0]=+L, …, x[N]=−L).
alpha : float
    Helmholtz parameter (≥ 0). α=0 gives the Poisson equation.
bc_left : float
    BC value at x = −L. Dirichlet: u(−L); Neumann: u'(−L).
bc_right : float
    BC value at x = +L. Dirichlet: u(+L); Neumann: u'(+L).
bc_type : {"dirichlet", "neumann"}
    Boundary-condition flavour.

Returns

Float[Array, "Npts"]
    Solution at the N+1 GL nodes.

Raises

ValueError
    If the grid uses Gauss nodes, the length of f is wrong, alpha < 0, or bc_type is not "dirichlet" or "neumann".
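Because the nodes run from +L down to −L, bc_right belongs to index 0 and bc_left to index N, which is easy to get backwards. The sketch below checks that convention on a case with different values at the two ends: u = eˣ satisfies u″ − u = 0, so with α = 1 and f = 0 both the Dirichlet data (u(−1) = 1/e, u(+1) = e) and the Neumann data (the same values, since u′ = u) should reproduce eˣ. It is a NumPy illustration under stated assumptions, not the library code: cheb_diff_matrix and solve_sketch are hypothetical helpers, and solve_sketch omits the validation and the α = 0 gauge handling.

```python
import numpy as np

def cheb_diff_matrix(N, L=1.0):
    """Hypothetical helper: Gauss-Lobatto nodes (x[0] = +L, ..., x[N] = -L) and D."""
    k = np.arange(N + 1)
    x = np.cos(np.pi * k / N)
    c = np.where((k == 0) | (k == N), 2.0, 1.0) * (-1.0) ** k
    dX = x[:, None] - x[None, :] + np.eye(N + 1)    # identity avoids 0/0 on the diagonal
    D = (c[:, None] / c[None, :]) / dX
    D -= np.diag(D.sum(axis=1))                     # diagonal from the negative-sum trick
    return L * x, D / L

def solve_sketch(D, f, alpha, bc_left, bc_right, bc_type="dirichlet"):
    """Hypothetical NumPy analogue of solve() for alpha > 0 (no gauge handling)."""
    N = D.shape[0] - 1
    A = D @ D - alpha * np.eye(N + 1)
    b = np.array(f, dtype=float)                    # copy so the caller's f is untouched
    if bc_type == "dirichlet":
        A[0, :] = 0.0
        A[0, 0] = 1.0
        A[N, :] = 0.0
        A[N, N] = 1.0
    else:  # neumann
        A[0, :] = D[0, :]
        A[N, :] = D[N, :]
    b[0], b[N] = bc_right, bc_left                  # index 0 is x = +L, index N is x = -L
    return np.linalg.solve(A, b)

N, alpha = 24, 1.0
x, D = cheb_diff_matrix(N)
f = np.zeros(N + 1)                                 # u = exp(x) satisfies u'' - u = 0

u_dir = solve_sketch(D, f, alpha, bc_left=np.exp(-1.0), bc_right=np.exp(1.0))
u_neu = solve_sketch(D, f, alpha, bc_left=np.exp(-1.0), bc_right=np.exp(1.0),
                     bc_type="neumann")

err_dir = np.max(np.abs(u_dir - np.exp(x)))
err_neu = np.max(np.abs(u_neu - np.exp(x)))
print(f"Dirichlet error: {err_dir:.2e}, Neumann error: {err_neu:.2e}")
```

Swapping bc_left and bc_right in either call would instead approximate e⁻ˣ (Dirichlet) or −e⁻ˣ (Neumann), which makes the ordering mistake easy to spot.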

Source code in spectraldiffx/_src/chebyshev/solvers.py
def solve(
    self,
    f: Num[Array, "Npts"],
    alpha: float = 0.0,
    bc_left: float = 0.0,
    bc_right: float = 0.0,
    bc_type: BCType = "dirichlet",
) -> Float[Array, "Npts"]:
    """Solve (d²/dx² − α) u = f on [−L, L] with Dirichlet or Neumann BCs.

    Parameters
    ----------
    f : Num[Array, "Npts"]
        Source term sampled at the N+1 Gauss–Lobatto nodes
        (ordered x[0]=+L, …, x[N]=−L).
    alpha : float
        Helmholtz parameter (≥ 0).  α=0 gives the Poisson equation.
    bc_left : float
        BC value at x = −L.  Dirichlet: u(−L); Neumann: u'(−L).
    bc_right : float
        BC value at x = +L.  Dirichlet: u(+L); Neumann: u'(+L).
    bc_type : {"dirichlet", "neumann"}
        Boundary-condition flavour.

    Returns
    -------
    Float[Array, "Npts"]
        Solution at the N+1 GL nodes.

    Raises
    ------
    ValueError
        If the grid uses Gauss nodes, the length of ``f`` is wrong,
        ``alpha < 0``, or ``bc_type`` is not ``"dirichlet"`` or
        ``"neumann"``.
    """
    if self.grid.node_type != "gauss-lobatto":
        raise ValueError(
            "ChebyshevHelmholtzSolver1D requires 'gauss-lobatto' nodes — "
            "the boundary-row method evaluates u (or u') at the endpoints "
            "x[0]=+L and x[N]=−L, which Gauss nodes exclude. Got "
            f"node_type='{self.grid.node_type}'."
        )
    if f.shape[0] != self.grid.N + 1:
        raise ValueError(
            f"f must have length N+1={self.grid.N + 1} (Gauss–Lobatto), "
            f"got length {f.shape[0]}."
        )
    if alpha < 0:
        raise ValueError(f"alpha must be >= 0, got {alpha}")
    if bc_type not in ("dirichlet", "neumann"):
        raise ValueError(
            f"bc_type must be 'dirichlet' or 'neumann', got {bc_type!r}"
        )

    D = self.grid.D
    N = self.grid.N

    # A = D² − α·I  (interior operator; boundary rows replaced below)
    A = D @ D - alpha * jnp.eye(N + 1)
    b = f

    if bc_type == "dirichlet":
        # Row 0 → u(+L) = bc_right, row N → u(−L) = bc_left
        A = A.at[0, :].set(0.0).at[0, 0].set(1.0)
        A = A.at[N, :].set(0.0).at[N, N].set(1.0)
    else:  # neumann
        # Row 0 → u'(+L) = D[0,:]·u, row N → u'(−L) = D[N,:]·u
        A = A.at[0, :].set(D[0, :])
        A = A.at[N, :].set(D[N, :])
    b = b.at[0].set(bc_right)
    b = b.at[N].set(bc_left)

    if bc_type == "neumann" and alpha == 0.0:
        # Pure-Neumann Poisson is rank-deficient (constant nullspace:
        # D²·1 = 0 and D·1 = 0, so A·1 = 0).  Pin a gauge inside the
        # linear system by replacing one interior equation with
        # u[middle] = 0.  This removes the singularity before the solve,
        # making it robust across RHS / grid sizes.  The user can shift
        # the result by any constant afterwards if a different gauge is
        # needed.
        mid = N // 2
        gauge_row = jnp.zeros(N + 1).at[mid].set(1.0)
        A = A.at[mid, :].set(gauge_row)
        b = b.at[mid].set(0.0)

    return jnp.linalg.solve(A, b)