Quick Start¶
This page walks through the most common use cases.
Enable 64-bit precision¶
JAX defaults to 32-bit floats. For spectral methods the extra precision matters:
1D Fourier Differentiation¶
Compute derivatives of a smooth periodic function on a uniform grid.
import jax.numpy as jnp
from spectraldiffx import FourierGrid1D, SpectralDerivative1D
# Build grid: N points on [0, 2π)
grid = FourierGrid1D.from_N_L(N=64, L=2 * jnp.pi)
deriv = SpectralDerivative1D(grid=grid)
x = grid.x
u = jnp.sin(2 * x) + 0.5 * jnp.cos(5 * x)
# First derivative
du_dx = deriv(u, order=1)
# Second derivative
d2u_dx2 = deriv(u, order=2)
Spectral accuracy means machine-precision errors for smooth periodic functions — the error does not grow with the derivative order.
2D Fourier Operations¶
import jax.numpy as jnp
from spectraldiffx import FourierGrid2D, SpectralDerivative2D
grid = FourierGrid2D.from_N_L(Nx=64, Ny=64, Lx=2 * jnp.pi, Ly=2 * jnp.pi)
deriv = SpectralDerivative2D(grid=grid)
X, Y = grid.X
u = jnp.sin(X) * jnp.cos(Y)
# Partial derivatives
du_dx, du_dy = deriv.gradient(u)
# Laplacian
lap_u = deriv.laplacian(u) # should equal -2 * sin(x) * cos(y)
Chebyshev Differentiation¶
For non-periodic problems on \([-1, 1]\), use the Chebyshev basis:
import jax.numpy as jnp
from spectraldiffx import ChebyshevGrid1D, ChebyshevDerivative1D
# Chebyshev-Gauss-Lobatto nodes on [-1, 1]
grid = ChebyshevGrid1D(N=32)
deriv = ChebyshevDerivative1D(grid=grid)
x = grid.x
u = jnp.exp(x) # smooth, non-periodic
du_dx = deriv(u, order=1) # ≈ exp(x)
Nodes cluster near the boundaries, which controls the Runge phenomenon and enables high accuracy without the need for a periodic extension.
Spectral Filtering¶
Suppress aliasing or high-frequency noise with a spectral filter:
from spectraldiffx import FourierGrid1D, SpectralFilter1D
grid = FourierGrid1D.from_N_L(N=128, L=2 * jnp.pi)
filt = SpectralFilter1D(grid=grid, filter_type="exponential", order=8)
u_filtered = filt(u)
Key Concepts¶
| Concept | Description |
|---|---|
FourierGrid*D |
Uniform periodic grids; stores wavenumbers and transform helpers |
SpectralDerivative*D |
Differentiation via multiplication in spectral space |
SpectralFilter*D |
Spectral-space low-pass filters (exponential, raised cosine) |
SpectralHelmholtzSolver*D |
Solve \((\nabla^2 - \alpha)u = f\) in spectral space |
ChebyshevGrid*D |
Gauss-Lobatto nodes on \([-1,1]\); DCT-based transforms |
SphericalGrid*D |
Gauss-Legendre grids for spherical harmonic transforms |
All objects are Equinox modules, which means they are JAX pytrees and work seamlessly inside jax.jit, jax.vmap, and jax.grad.