Bases: Module
1D Chebyshev-collocation Helmholtz/Poisson solver with Dirichlet or Neumann BCs.
Solves the boundary-value problem on [−L, L]:
d²u/dx² − α·u = f(x), x ∈ [−L, L]
with boundary conditions selected via bc_type:
Dirichlet: u(+L) = bc_right, u(−L) = bc_left
Neumann: u'(+L) = bc_right, u'(−L) = bc_left
For α = 0 this reduces to Poisson.
Method — Boundary-Row Replacement
On Gauss–Lobatto nodes the endpoints x[0]=+L and x[N]=−L are collocation
points, so we discretise as
    A u = b,    A = D² − α·I,    b = f
and then overwrite rows 0 and N with the boundary equations:
    Dirichlet : row 0 ← eᵀ₀,       b[0] ← bc_right
                row N ← eᵀ_N,      b[N] ← bc_left
    Neumann   : row 0 ← D[0, :],   b[0] ← bc_right
                row N ← D[N, :],   b[N] ← bc_left
The resulting (N+1)×(N+1) linear system is solved with jnp.linalg.solve.
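The row replacement described above can be sketched in plain NumPy. This is an illustrative sketch, not the library's implementation: `cheb_diff_matrix` and `solve_row_replacement` are hypothetical helpers, the former standing in for `grid.x` and `grid.D`, and the standard Gauss–Lobatto nodes x_j = L·cos(πj/N) are assumed.

```python
import numpy as np


def cheb_diff_matrix(N, L=1.0):
    """Hypothetical helper: Gauss-Lobatto nodes on [-L, L] and the matrix D.

    Nodes are ordered x[0] = +L, ..., x[N] = -L, matching the text.
    """
    j = np.arange(N + 1)
    x = np.cos(np.pi * j / N)                  # nodes on [-1, 1]
    c = np.where((j == 0) | (j == N), 2.0, 1.0) * (-1.0) ** j
    dX = x[:, None] - x[None, :]               # dX[i, k] = x_i - x_k
    D = np.outer(c, 1.0 / c) / (dX + np.eye(N + 1))
    D -= np.diag(D.sum(axis=1))                # diagonal chosen so rows sum to 0
    return L * x, D / L                        # rescale from [-1, 1] to [-L, L]


def solve_row_replacement(f, D, alpha, bc_left, bc_right, bc_type):
    """Assemble A = D^2 - alpha*I, overwrite rows 0 and N, then solve."""
    N = D.shape[0] - 1
    A = D @ D - alpha * np.eye(N + 1)
    b = np.array(f, dtype=float)               # copy so f is left untouched
    if bc_type == "dirichlet":
        A[0, :] = 0.0
        A[0, 0] = 1.0                          # row 0 <- e_0^T
        A[N, :] = 0.0
        A[N, N] = 1.0                          # row N <- e_N^T
    else:                                      # "neumann"
        A[0, :] = D[0, :]                      # row 0 <- D[0, :]
        A[N, :] = D[N, :]                      # row N <- D[N, :]
    b[0] = bc_right                            # x[0] = +L
    b[N] = bc_left                             # x[N] = -L
    return np.linalg.solve(A, b)


x, D = cheb_diff_matrix(32)

# Dirichlet Poisson: u'' = -pi^2 sin(pi x), u(+-1) = 0, exact u = sin(pi x).
f_dir = -(np.pi**2) * np.sin(np.pi * x)
u_dir = solve_row_replacement(f_dir, D, 0.0, 0.0, 0.0, "dirichlet")
err_dirichlet = float(np.max(np.abs(u_dir - np.sin(np.pi * x))))

# Neumann Helmholtz (alpha = 1): exact u = cos(pi x), so u'(+-1) = 0.
f_neu = -(np.pi**2 + 1.0) * np.cos(np.pi * x)
u_neu = solve_row_replacement(f_neu, D, 1.0, 0.0, 0.0, "neumann")
err_neumann = float(np.max(np.abs(u_neu - np.cos(np.pi * x))))
```

The Neumann case here uses α = 1 so that the system is non-singular without any gauge constraint; the α = 0 case needs the extra step described in the text.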
Gauss-node grids do not include the endpoints, so this boundary-row
method is inapplicable; solve validates the grid and raises ValueError.
Pure Neumann + Poisson (α = 0) is only solvable up to a constant
(constant nullspace of the discretisation); the solver pins the gauge
inside the linear system by replacing one interior equation with the
point constraint u[N//2] = 0, so the solve is well-posed. Shift
the returned field by any constant if a different gauge is needed.
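The gauge pinning can be illustrated with a small NumPy sketch. It is written under assumptions rather than taken from the library: `cheb_diff_matrix` is a hypothetical stand-in for `grid.D`, and standard Gauss–Lobatto nodes are assumed. The sketch first checks that the constant vector lies in the nullspace of the pure-Neumann Poisson matrix, then replaces the middle row with the constraint u[N//2] = 0 and solves.

```python
import numpy as np


def cheb_diff_matrix(N, L=1.0):
    """Hypothetical helper: Gauss-Lobatto nodes (x[0]=+L, x[N]=-L) and D."""
    j = np.arange(N + 1)
    x = np.cos(np.pi * j / N)
    c = np.where((j == 0) | (j == N), 2.0, 1.0) * (-1.0) ** j
    dX = x[:, None] - x[None, :]
    D = np.outer(c, 1.0 / c) / (dX + np.eye(N + 1))
    D -= np.diag(D.sum(axis=1))
    return L * x, D / L


N = 32
x, D = cheb_diff_matrix(N)

# Pure-Neumann Poisson system: A = D^2 with Neumann rows at both ends.
A = D @ D
A[0, :] = D[0, :]
A[N, :] = D[N, :]
b = np.cos(np.pi * x)          # source term f = cos(pi x)
b[0] = 0.0                     # u'(+1) = 0
b[N] = 0.0                     # u'(-1) = 0

# A maps constants to zero (up to round-off), so A is singular as it stands.
nullspace_residual = float(np.max(np.abs(A @ np.ones(N + 1))))

# Pin the gauge: replace the middle equation with u[mid] = 0.
mid = N // 2
A[mid, :] = 0.0
A[mid, mid] = 1.0
b[mid] = 0.0
u = np.linalg.solve(A, b)

# For even N the middle node is x = 0, so the pinned solution is
# u = (1 - cos(pi x)) / pi^2.
u_exact = (1.0 - np.cos(np.pi * x)) / np.pi**2
max_err = float(np.max(np.abs(u - u_exact)))
```

Dropping one interior equation is harmless here because the data are compatible: the integral of cos(πx) over [−1, 1] equals u'(+1) − u'(−1) = 0.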
Attributes
grid : ChebyshevGrid1D
Must use 'gauss-lobatto' nodes.
Examples
Solve u″ = −π² sin(πx) with u(±1) = 0 (analytic solution u = sin(πx)):
import jax.numpy as jnp
grid = ChebyshevGrid1D.from_N_L(N=32, L=1.0)
solver = ChebyshevHelmholtzSolver1D(grid=grid)
x = grid.x
f = -(jnp.pi**2) * jnp.sin(jnp.pi * x)
u = solver.solve(f, alpha=0.0, bc_left=0.0, bc_right=0.0)
Neumann example — solve u″ = cos(πx) with u'(±1) = 0:
f = jnp.cos(jnp.pi * grid.x)
u = solver.solve(f, alpha=0.0, bc_type="neumann")
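For reference, the Neumann example has a closed-form solution that the returned field can be compared against. The derivation below is a worked sketch that assumes the standard Gauss–Lobatto nodes x_j = L·cos(πj/N), so that for N = 32 the pinned node x[N//2] sits at x = 0; the text above only states the ordering x[0] = +L, x[N] = −L.

```latex
\begin{aligned}
u''(x) &= \cos(\pi x), \qquad u'(\pm 1) = 0, \\
\int_{-1}^{1} \cos(\pi x)\,dx &= 0 = u'(1) - u'(-1)
  \quad\text{(compatibility holds)}, \\
u'(x) &= \tfrac{1}{\pi}\sin(\pi x) + C_1, \qquad u'(\pm 1) = C_1 = 0, \\
u(x) &= -\tfrac{1}{\pi^{2}}\cos(\pi x) + C_2, \\
x_{N/2} &= \cos\!\left(\tfrac{\pi}{2}\right) = 0
  \;\Rightarrow\; u(0) = 0
  \;\Rightarrow\; C_2 = \tfrac{1}{\pi^{2}}, \\
u(x) &= \frac{1 - \cos(\pi x)}{\pi^{2}}.
\end{aligned}
```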
Source code in spectraldiffx/_src/chebyshev/solvers.py
class ChebyshevHelmholtzSolver1D(eqx.Module):
    """1D Chebyshev-collocation Helmholtz/Poisson solver with Dirichlet or Neumann BCs.

    Solves the boundary-value problem on [−L, L]:

        d²u/dx² − α·u = f(x),   x ∈ [−L, L]

    with boundary conditions selected via ``bc_type``:

        Dirichlet:  u(+L)  = bc_right,  u(−L)  = bc_left
        Neumann:    u'(+L) = bc_right,  u'(−L) = bc_left

    For α = 0 this reduces to Poisson.

    Method — Boundary-Row Replacement
    ---------------------------------
    On Gauss–Lobatto nodes the endpoints x[0]=+L and x[N]=−L are collocation
    points, so we discretise as

        A u = b,    A = D² − α·I,    b = f

    and then overwrite rows 0 and N with the boundary equations:

        Dirichlet : row 0 ← eᵀ₀,       b[0] ← bc_right
                    row N ← eᵀ_N,      b[N] ← bc_left
        Neumann   : row 0 ← D[0, :],   b[0] ← bc_right
                    row N ← D[N, :],   b[N] ← bc_left

    The resulting (N+1)×(N+1) linear system is solved with :func:`jnp.linalg.solve`.

    Gauss-node grids do not include the endpoints, so this boundary-row
    method is inapplicable; ``solve`` validates the grid and raises
    ``ValueError``.

    Pure Neumann + Poisson (α = 0) is only solvable up to a constant
    (constant nullspace of the discretisation); the solver pins the gauge
    inside the linear system by replacing one interior equation with the
    point constraint ``u[N//2] = 0``, so the solve is well-posed. Shift
    the returned field by any constant if a different gauge is needed.

    Attributes
    ----------
    grid : ChebyshevGrid1D
        Must use ``'gauss-lobatto'`` nodes.

    Examples
    --------
    Solve u″ = −π² sin(πx) with u(±1) = 0 (analytic solution u = sin(πx)):

    >>> import jax.numpy as jnp
    >>> grid = ChebyshevGrid1D.from_N_L(N=32, L=1.0)
    >>> solver = ChebyshevHelmholtzSolver1D(grid=grid)
    >>> x = grid.x
    >>> f = -(jnp.pi**2) * jnp.sin(jnp.pi * x)
    >>> u = solver.solve(f, alpha=0.0, bc_left=0.0, bc_right=0.0)

    Neumann example — solve u″ = cos(πx) with u'(±1) = 0:

    >>> f = jnp.cos(jnp.pi * grid.x)
    >>> u = solver.solve(f, alpha=0.0, bc_type="neumann")
    """

    grid: ChebyshevGrid1D

    def solve(
        self,
        f: Num[Array, "Npts"],
        alpha: float = 0.0,
        bc_left: float = 0.0,
        bc_right: float = 0.0,
        bc_type: BCType = "dirichlet",
    ) -> Float[Array, "Npts"]:
        """Solve (d²/dx² − α) u = f on [−L, L] with Dirichlet or Neumann BCs.

        Parameters
        ----------
        f : Num[Array, "Npts"]
            Source term sampled at the N+1 Gauss–Lobatto nodes
            (ordered x[0]=+L, …, x[N]=−L).
        alpha : float
            Helmholtz parameter (≥ 0). α=0 gives the Poisson equation.
        bc_left : float
            BC value at x = −L. Dirichlet: u(−L); Neumann: u'(−L).
        bc_right : float
            BC value at x = +L. Dirichlet: u(+L); Neumann: u'(+L).
        bc_type : {"dirichlet", "neumann"}
            Boundary-condition flavour.

        Returns
        -------
        Float[Array, "Npts"]
            Solution at the N+1 GL nodes.

        Raises
        ------
        ValueError
            If the grid uses Gauss nodes, the length of ``f`` is wrong,
            ``alpha < 0``, or ``bc_type`` is not ``'dirichlet'`` or
            ``'neumann'``.
        """
        if self.grid.node_type != "gauss-lobatto":
            raise ValueError(
                "ChebyshevHelmholtzSolver1D requires 'gauss-lobatto' nodes — "
                "the boundary-row method evaluates u (or u') at the endpoints "
                "x[0]=+L and x[N]=−L, which Gauss nodes exclude. Got "
                f"node_type='{self.grid.node_type}'."
            )
        if f.shape[0] != self.grid.N + 1:
            raise ValueError(
                f"f must have length N+1={self.grid.N + 1} (Gauss–Lobatto), "
                f"got length {f.shape[0]}."
            )
        if alpha < 0:
            raise ValueError(f"alpha must be >= 0, got {alpha}")
        if bc_type not in ("dirichlet", "neumann"):
            raise ValueError(
                f"bc_type must be 'dirichlet' or 'neumann', got {bc_type!r}"
            )
        D = self.grid.D
        N = self.grid.N
        # A = D² − α·I (interior operator; boundary rows replaced below)
        A = D @ D - alpha * jnp.eye(N + 1)
        b = f
        if bc_type == "dirichlet":
            # Row 0 → u(+L) = bc_right, row N → u(−L) = bc_left
            A = A.at[0, :].set(0.0).at[0, 0].set(1.0)
            A = A.at[N, :].set(0.0).at[N, N].set(1.0)
        else:  # neumann
            # Row 0 → u'(+L) = D[0,:]·u, row N → u'(−L) = D[N,:]·u
            A = A.at[0, :].set(D[0, :])
            A = A.at[N, :].set(D[N, :])
        b = b.at[0].set(bc_right)
        b = b.at[N].set(bc_left)
        if bc_type == "neumann" and alpha == 0.0:
            # Pure-Neumann Poisson is rank-deficient (constant nullspace:
            # D²·1 = 0 and D·1 = 0, so A·1 = 0). Pin a gauge inside the
            # linear system by replacing one interior equation with
            # u[middle] = 0. This removes the singularity before the solve,
            # making it robust across RHS / grid sizes. The user can shift
            # the result by any constant afterwards if a different gauge is
            # needed.
            mid = N // 2
            gauge_row = jnp.zeros(N + 1).at[mid].set(1.0)
            A = A.at[mid, :].set(gauge_row)
            b = b.at[mid].set(0.0)
        return jnp.linalg.solve(A, b)
Functions
solve(f, alpha=0.0, bc_left=0.0, bc_right=0.0, bc_type='dirichlet')
Solve (d²/dx² − α) u = f on [−L, L] with Dirichlet or Neumann BCs.
Parameters
f : Num[Array, "Npts"]
Source term sampled at the N+1 Gauss–Lobatto nodes
(ordered x[0]=+L, …, x[N]=−L).
alpha : float
Helmholtz parameter (≥ 0). α=0 gives the Poisson equation.
bc_left : float
BC value at x = −L. Dirichlet: u(−L); Neumann: u'(−L).
bc_right : float
BC value at x = +L. Dirichlet: u(+L); Neumann: u'(+L).
bc_type : {"dirichlet", "neumann"}
Boundary-condition flavour.
Returns
Float[Array, "Npts"]
Solution at the N+1 GL nodes.
Raises
ValueError
If the grid uses Gauss nodes, the length of f is wrong,
alpha < 0, or bc_type is not 'dirichlet' or 'neumann'.
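Because the nodes run from x[0] = +L down to x[N] = −L, bc_right belongs to index 0 and bc_left to index N. The NumPy sketch below illustrates that convention on [−2, 2] with non-zero Dirichlet data. It is an assumption-laden illustration, not the library code: `cheb_diff_matrix` is a hypothetical helper standing in for the grid's `x` and `D`.

```python
import numpy as np


def cheb_diff_matrix(N, L=1.0):
    """Hypothetical helper: Gauss-Lobatto nodes (x[0]=+L, x[N]=-L) and D."""
    j = np.arange(N + 1)
    x = np.cos(np.pi * j / N)
    c = np.where((j == 0) | (j == N), 2.0, 1.0) * (-1.0) ** j
    dX = x[:, None] - x[None, :]
    D = np.outer(c, 1.0 / c) / (dX + np.eye(N + 1))
    D -= np.diag(D.sum(axis=1))
    return L * x, D / L


def dirichlet_solve(D, alpha, f, bc_left, bc_right):
    """Dirichlet boundary-row replacement on the system (D^2 - alpha*I) u = f."""
    N = D.shape[0] - 1
    A = D @ D - alpha * np.eye(N + 1)
    b = np.array(f, dtype=float)
    A[0, :] = 0.0
    A[0, 0] = 1.0
    A[N, :] = 0.0
    A[N, N] = 1.0
    b[0] = bc_right            # index 0 is the right endpoint x = +L
    b[N] = bc_left             # index N is the left endpoint  x = -L
    return np.linalg.solve(A, b)


L, N, alpha = 2.0, 32, 1.0
x, D = cheb_diff_matrix(N, L)
f = np.zeros(N + 1)            # u'' - u = 0

# u(-L) = exp(-L), u(+L) = exp(+L) selects u = exp(x).
u = dirichlet_solve(D, alpha, f, bc_left=np.exp(-L), bc_right=np.exp(L))
max_err = float(np.max(np.abs(u - np.exp(x))))

# Swapping the two boundary values selects the mirrored solution exp(-x).
u_swapped = dirichlet_solve(D, alpha, f, bc_left=np.exp(L), bc_right=np.exp(-L))
max_err_swapped = float(np.max(np.abs(u_swapped - np.exp(-x))))
```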