
Spectral Elliptic Solvers

Helmholtz, Poisson, and Laplace solvers for rectangular domains. See the theory page for the mathematical background.

Layer 0 — Pure Functions

Periodic (FFT)

1D

solve_helmholtz_fft_1d(rhs, dx, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f in 1-D with periodic BCs using the FFT.

Spectral algorithm:

  1. Forward FFT: f̂ = FFT(f) [N]
  2. Eigenvalues: Λ_k (FD2 or PS, see approximation) [N]
  3. Spectral division: ψ̂[k] = f̂[k] / (Λ_k − λ) [N]
  4. Inverse FFT: ψ = Re(IFFT(ψ̂)) [N]

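When lambda_ == 0 the k=0 mode has Λ_0 = 0, making the denominator singular. This is handled by setting ψ̂[0] = 0 (zero-mean gauge). The periodic Poisson problem is solvable only when f has zero mean; any nonzero mean in rhs is silently discarded.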

Parameters

  - rhs (Float[Array, " N"]): Right-hand side on the periodic domain.
  - dx (float): Grid spacing.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): "fd2" uses finite-difference eigenvalues (exact inverse of the 3-point stencil). "spectral" uses continuous Laplacian eigenvalues (spectral accuracy for smooth solutions). Default: "fd2".

Returns

Float[Array, " N"] Solution ψ (real), same shape as rhs.
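The four-step algorithm above can be sketched in NumPy as follows. This is not the library's code: `numpy.fft` stands in for `jax.numpy.fft`, the name `helmholtz_fft_1d_sketch` is hypothetical, and the eigenvalue formulas are this sketch's own assumptions (FD2: the symbol −(4/dx²) sin²(πk/N) of the periodic 3-point stencil; "spectral": −(2π·fftfreq)²), not copied from `fft_eigenvalues` / `fft_eigenvalues_ps`.

```python
import numpy as np

def helmholtz_fft_1d_sketch(f, dx, lam=0.0, approximation="fd2"):
    """Solve (d²/dx² - lam) psi = f on a periodic grid (NumPy sketch)."""
    n = f.shape[0]
    f_hat = np.fft.fft(f)                                   # 1. forward FFT
    if approximation == "fd2":
        # 2. symbol of the periodic stencil (psi[i-1] - 2 psi[i] + psi[i+1]) / dx²
        eig = -4.0 / dx**2 * np.sin(np.pi * np.arange(n) / n) ** 2
    else:
        # 2. continuous eigenvalues -k², wavenumbers in FFT order
        eig = -(2.0 * np.pi * np.fft.fftfreq(n, d=dx)) ** 2
    denom = eig - lam
    is_null = denom == 0.0                                  # k = 0 when lam == 0
    psi_hat = f_hat / np.where(is_null, 1.0, denom)         # 3. spectral division
    psi_hat = np.where(is_null, 0.0, psi_hat)               # zero-mean gauge
    return np.real(np.fft.ifft(psi_hat))                    # 4. inverse FFT

# Usage: the FD2 solution satisfies the periodic 3-point stencil to round-off.
rng = np.random.default_rng(0)
f = rng.standard_normal(32)
psi = helmholtz_fft_1d_sketch(f, dx=0.1, lam=2.0)
lap = (np.roll(psi, 1) - 2 * psi + np.roll(psi, -1)) / 0.1**2
print(np.abs(lap - 2.0 * psi - f).max())    # round-off level

# "spectral" recovers a smooth periodic solution exactly: psi = sin(x) for f = -sin(x).
x = 2 * np.pi * np.arange(32) / 32
psi_ps = helmholtz_fft_1d_sketch(-np.sin(x), dx=2 * np.pi / 32, approximation="spectral")
print(np.abs(psi_ps - np.sin(x)).max())     # round-off level
```

Replacing the null denominator with 1.0 before dividing, then zeroing the result, mirrors the `denom_safe` pattern in the source below; under JAX this keeps inf/NaN out of both branches of `where`, which matters when differentiating through the solver.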

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_fft_1d(
    rhs: Float[Array, " N"],
    dx: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve (∇² − λ)ψ = f in 1-D with periodic BCs using the FFT.

    Spectral algorithm:

    1. Forward FFT:  f̂ = FFT(f)                                    [N]
    2. Eigenvalues:  Λ_k (FD2 or PS, see *approximation*)          [N]
    3. Spectral division:  ψ̂[k] = f̂[k] / (Λ_k − λ)              [N]
    4. Inverse FFT:  ψ = Re(IFFT(ψ̂))                              [N]

    When ``lambda_ == 0`` the k=0 mode has Λ_0 = 0, making the denominator
    singular.  This is handled by setting ψ̂[0] = 0 (zero-mean gauge).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side on the periodic domain.
    dx : float
        Grid spacing.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        ``"fd2"`` uses finite-difference eigenvalues (exact inverse of the
        3-point stencil).  ``"spectral"`` uses continuous Laplacian
        eigenvalues (spectral accuracy for smooth solutions).
        Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ (real), same shape as *rhs*.
    """
    (N,) = rhs.shape
    rhs_hat = jnp.fft.fft(rhs)
    eig = _eig_1d(fft_eigenvalues, fft_eigenvalues_ps, N, dx, N * dx, approximation)
    denom = eig - lambda_
    is_null = denom == 0.0
    denom_safe = jnp.where(is_null, 1.0, denom)
    psi_hat = rhs_hat / denom_safe
    psi_hat = jnp.where(is_null, 0.0, psi_hat)
    return jnp.real(jnp.fft.ifft(psi_hat))

solve_poisson_fft_1d(rhs, dx, *, approximation='fd2')

Solve ∇²ψ = f in 1-D with periodic BCs using FFT.

Convenience wrapper around solve_helmholtz_fft_1d with lambda_=0.

Parameters

  - rhs (Float[Array, " N"]): Right-hand side on the periodic domain.
  - dx (float): Grid spacing.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Zero-mean solution ψ (real), same shape as rhs.
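The zero-mean convention reflects two properties of the periodic FD2 Laplacian: its output always sums to zero (so only the zero-mean part of f can be matched), and it cannot see constants (so ψ is only defined up to one). A small NumPy check, independent of the library; `periodic_laplacian` is a hypothetical helper:

```python
import numpy as np

def periodic_laplacian(v, dx):
    """Periodic 3-point FD2 Laplacian."""
    return (np.roll(v, 1) - 2 * v + np.roll(v, -1)) / dx**2

rng = np.random.default_rng(1)
dx = 0.5
psi = rng.standard_normal(16)
lap = periodic_laplacian(psi, dx)

# The sum telescopes, so the output always has zero mean:
print(lap.sum())                                             # ~0 (round-off)
# Adding a constant to psi leaves its Laplacian unchanged:
print(np.allclose(periodic_laplacian(psi + 3.0, dx), lap))   # True
```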

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_fft_1d(
    rhs: Float[Array, " N"],
    dx: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve ∇²ψ = f in 1-D with periodic BCs using FFT.

    Convenience wrapper around :func:`solve_helmholtz_fft_1d` with ``lambda_=0``.

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side on the periodic domain.
    dx : float
        Grid spacing.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Zero-mean solution ψ (real), same shape as *rhs*.
    """
    return solve_helmholtz_fft_1d(rhs, dx, lambda_=0.0, approximation=approximation)

2D

solve_helmholtz_fft(rhs, dx, dy, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with periodic BCs using the 2-D FFT.

Spectral algorithm:

  1. Forward 2-D FFT: f̂ = FFT2(f) [Ny, Nx]
  2. Eigenvalue matrix: Λ[j,i] = Λ_j^y + Λ_i^x − λ [Ny, Nx] where Λ^x, Λ^y are FD2 or PS eigenvalues (see approximation).
  3. Spectral division: ψ̂[j,i] = f̂[j,i] / Λ[j,i] [Ny, Nx]
  4. Inverse 2-D FFT: ψ = Re(IFFT2(ψ̂)) [Ny, Nx]

When lambda_ == 0 the (0,0) mode is singular (Λ[0,0] = 0). This is handled by setting ψ̂[0,0] = 0 (zero-mean gauge).

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side on the periodic domain.
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ (real), same shape as rhs.
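In 2-D the only new ingredient is step 2, where broadcasting two 1-D eigenvalue arrays produces the (Ny, Nx) matrix Λ. A NumPy sketch with hypothetical names (`helmholtz_fft_2d_sketch`, `periodic_laplacian_2d`), FD2 eigenvalues only, checked against a periodic 5-point stencil. Unlike the source below, which tests only entry (0, 0), this sketch masks every zero entry; for λ ≥ 0 the two agree.

```python
import numpy as np

def helmholtz_fft_2d_sketch(f, dx, dy, lam=0.0):
    """FD2 sketch of (Laplacian - lam) psi = f on a doubly periodic grid; f is (Ny, Nx)."""
    ny, nx = f.shape
    eigx = -4.0 / dx**2 * np.sin(np.pi * np.arange(nx) / nx) ** 2
    eigy = -4.0 / dy**2 * np.sin(np.pi * np.arange(ny) / ny) ** 2
    # Broadcasting builds Λ[j, i] = Λ_j^y + Λ_i^x - lam with shape (Ny, Nx)
    eig2d = eigy[:, None] + eigx[None, :] - lam
    is_null = eig2d == 0.0          # for lam >= 0 this is only (0, 0), when lam == 0
    f_hat = np.fft.fft2(f)
    psi_hat = np.where(is_null, 0.0, f_hat / np.where(is_null, 1.0, eig2d))
    return np.real(np.fft.ifft2(psi_hat))

def periodic_laplacian_2d(psi, dx, dy):
    """Periodic 5-point Laplacian; axis 0 is y, axis 1 is x."""
    d2x = (np.roll(psi, 1, axis=1) - 2 * psi + np.roll(psi, -1, axis=1)) / dx**2
    d2y = (np.roll(psi, 1, axis=0) - 2 * psi + np.roll(psi, -1, axis=0)) / dy**2
    return d2x + d2y

rng = np.random.default_rng(0)
f = rng.standard_normal((24, 32))
psi = helmholtz_fft_2d_sketch(f, dx=0.1, dy=0.2, lam=1.5)
residual = periodic_laplacian_2d(psi, 0.1, 0.2) - 1.5 * psi - f
print(np.abs(residual).max())       # round-off level
```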

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_fft(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve (∇² − λ)ψ = f with periodic BCs using the 2-D FFT.

    Spectral algorithm:

    1. Forward 2-D FFT:  f̂ = FFT2(f)                        [Ny, Nx]
    2. Eigenvalue matrix:
           Λ[j,i] = Λ_j^y + Λ_i^x − λ                      [Ny, Nx]
       where Λ^x, Λ^y are FD2 or PS eigenvalues (see *approximation*).
    3. Spectral division:  ψ̂[j,i] = f̂[j,i] / Λ[j,i]       [Ny, Nx]
    4. Inverse 2-D FFT:  ψ = Re(IFFT2(ψ̂))                  [Ny, Nx]

    When ``lambda_ == 0`` the (0,0) mode is singular (Λ[0,0] = 0).
    This is handled by setting ψ̂[0,0] = 0 (zero-mean gauge).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side on the periodic domain.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ (real), same shape as *rhs*.
    """
    Ny, Nx = rhs.shape
    rhs_hat = jnp.fft.fft2(rhs)
    eigx = _eig_1d(fft_eigenvalues, fft_eigenvalues_ps, Nx, dx, Nx * dx, approximation)
    eigy = _eig_1d(fft_eigenvalues, fft_eigenvalues_ps, Ny, dy, Ny * dy, approximation)
    eig2d = eigy[:, None] + eigx[None, :] - lambda_
    is_null = eig2d[0, 0] == 0.0
    eig2d_safe = eig2d.at[0, 0].set(jnp.where(is_null, 1.0, eig2d[0, 0]))
    psi_hat = rhs_hat / eig2d_safe
    psi_hat = psi_hat.at[0, 0].set(
        jnp.where(is_null, jnp.zeros_like(psi_hat[0, 0]), psi_hat[0, 0])
    )
    return jnp.real(jnp.fft.ifft2(psi_hat))

solve_poisson_fft(rhs, dx, dy, *, approximation='fd2')

Solve ∇²ψ = f with periodic BCs using the 2-D FFT.

Convenience wrapper around solve_helmholtz_fft with lambda_=0.

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side on the periodic domain.
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Zero-mean solution ψ (real), same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_fft(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve ∇²ψ = f with periodic BCs using the 2-D FFT.

    Convenience wrapper around :func:`solve_helmholtz_fft` with ``lambda_=0``.

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side on the periodic domain.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Zero-mean solution ψ (real), same shape as *rhs*.
    """
    return solve_helmholtz_fft(rhs, dx, dy, lambda_=0.0, approximation=approximation)

3D

solve_helmholtz_fft_3d(rhs, dx, dy, dz, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with periodic BCs using the 3-D FFT.

Parameters

  - rhs (Float[Array, "Nz Ny Nx"]): Right-hand side on the triply periodic domain.
  - dx, dy, dz (float): Grid spacings in x, y and z.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ (real), same shape as rhs.
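The 3-D solver repeats the 2-D pattern with one more broadcast axis (`eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :]`). The same idea generalises to any rank by reshaping each 1-D eigenvalue array so it varies along one axis only. A NumPy sketch with hypothetical names (`fd2_eigenvalue_grid`, `helmholtz_fftn_sketch`), FD2 eigenvalues only; spacings are listed in array-axis order, i.e. (dz, dy, dx):

```python
import numpy as np

def fd2_eigenvalue_grid(shape, spacings, lam=0.0):
    """Sum of the 1-D periodic FD2 eigenvalues over every axis, minus lam."""
    eig = np.full(shape, -float(lam))
    for axis, (n, h) in enumerate(zip(shape, spacings)):
        e = -4.0 / h**2 * np.sin(np.pi * np.arange(n) / n) ** 2
        view = [1] * len(shape)
        view[axis] = n                      # e.g. (1, n, 1) for axis 1 of a 3-D grid
        eig = eig + e.reshape(view)
    return eig

def helmholtz_fftn_sketch(f, spacings, lam=0.0):
    """Periodic FD2 Helmholtz solve on an array of any rank (NumPy sketch)."""
    eig = fd2_eigenvalue_grid(f.shape, spacings, lam)
    is_null = eig == 0.0
    psi_hat = np.where(is_null, 0.0, np.fft.fftn(f) / np.where(is_null, 1.0, eig))
    return np.real(np.fft.ifftn(psi_hat))

rng = np.random.default_rng(0)
f = rng.standard_normal((8, 10, 12))
h = (0.3, 0.2, 0.1)                         # (dz, dy, dx), ordered like the axes (z, y, x)
psi = helmholtz_fftn_sketch(f, h, lam=0.5)
# Periodic 7-point Laplacian, one second difference per axis
lap = sum((np.roll(psi, 1, axis=a) - 2 * psi + np.roll(psi, -1, axis=a)) / h[a] ** 2
          for a in range(3))
print(np.abs(lap - 0.5 * psi - f).max())    # round-off level
```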

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_fft_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve (∇² − λ)ψ = f with periodic BCs using the 3-D FFT.

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side on the triply periodic domain.
    dx, dy, dz : float
        Grid spacings.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ (real), same shape as *rhs*.
    """
    Nz, Ny, Nx = rhs.shape
    rhs_hat = jnp.fft.fftn(rhs)
    eigx = _eig_1d(fft_eigenvalues, fft_eigenvalues_ps, Nx, dx, Nx * dx, approximation)
    eigy = _eig_1d(fft_eigenvalues, fft_eigenvalues_ps, Ny, dy, Ny * dy, approximation)
    eigz = _eig_1d(fft_eigenvalues, fft_eigenvalues_ps, Nz, dz, Nz * dz, approximation)
    eig3d = eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_
    is_null = eig3d[0, 0, 0] == 0.0
    eig3d_safe = eig3d.at[0, 0, 0].set(jnp.where(is_null, 1.0, eig3d[0, 0, 0]))
    psi_hat = rhs_hat / eig3d_safe
    psi_hat = psi_hat.at[0, 0, 0].set(
        jnp.where(is_null, jnp.zeros_like(psi_hat[0, 0, 0]), psi_hat[0, 0, 0])
    )
    return jnp.real(jnp.fft.ifftn(psi_hat))

solve_poisson_fft_3d(rhs, dx, dy, dz, *, approximation='fd2')

Solve ∇²ψ = f with periodic BCs using the 3-D FFT.

Parameters

  - rhs (Float[Array, "Nz Ny Nx"]): Right-hand side on the triply periodic domain.
  - dx, dy, dz (float): Grid spacings in x, y and z.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Zero-mean solution ψ (real), same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_fft_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve ∇²ψ = f with periodic BCs using the 3-D FFT.

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side on the triply periodic domain.
    dx, dy, dz : float
        Grid spacings.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Zero-mean solution ψ (real), same shape as *rhs*.
    """
    return solve_helmholtz_fft_3d(
        rhs, dx, dy, dz, lambda_=0.0, approximation=approximation
    )

Dirichlet, Regular Grid (DST-I)

1D

solve_helmholtz_dst1_1d(rhs, dx, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f in 1-D with Dirichlet BCs on a regular grid (DST-I).

Parameters

  - rhs (Float[Array, " N"]): Right-hand side at interior grid points.
  - dx (float): Grid spacing.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Solution ψ, same shape as rhs.
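The DST-I solve can be sketched with a dense sine matrix in place of the library's fast `dstn`/`idstn`, whose normalisation this sketch does not assume. With N interior points the domain length is L = (N+1)·dx, matching the `(N + 1) * dx` argument in the source below. The eigenvalue formulas and the name `helmholtz_dst1_1d_sketch` are this sketch's assumptions; the dense transform costs O(N²) and is meant for illustration only.

```python
import numpy as np

def helmholtz_dst1_1d_sketch(f, dx, lam=0.0, approximation="fd2"):
    """Dense-matrix sketch of the DST-I Dirichlet solve; f holds the N interior values."""
    n = f.shape[0]
    m = np.arange(1, n + 1)                        # sine modes 1..N (also interior indices 1..N)
    S = np.sin(np.pi * np.outer(m, m) / (n + 1))   # symmetric DST-I matrix, S @ S = (N+1)/2 * I
    if approximation == "fd2":
        eig = -4.0 / dx**2 * np.sin(np.pi * m / (2 * (n + 1))) ** 2
    else:
        eig = -(np.pi * m / ((n + 1) * dx)) ** 2   # domain length L = (N+1) dx
    psi_hat = (S @ f) / (eig - lam)                # every eigenvalue is < 0: no null mode
    return (2.0 / (n + 1)) * (S @ psi_hat)

rng = np.random.default_rng(0)
f = rng.standard_normal(20)
dx = 0.05
psi = helmholtz_dst1_1d_sketch(f, dx, lam=3.0)
p = np.pad(psi, 1)                                 # psi = 0 at both boundary points
residual = (p[:-2] - 2 * psi + p[2:]) / dx**2 - 3.0 * psi - f
print(np.abs(residual).max())                      # round-off level

# With "spectral", a single sine mode is recovered exactly (here L = 16 * (1/16) = 1).
x = np.arange(1, 16) / 16
psi_ps = helmholtz_dst1_1d_sketch(-np.pi**2 * np.sin(np.pi * x), 1 / 16, approximation="spectral")
print(np.abs(psi_ps - np.sin(np.pi * x)).max())    # round-off level
```

Because all Dirichlet eigenvalues are strictly negative, no null-mode masking is needed for λ ≥ 0, which is consistent with the plain division in the source below.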

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dst1_1d(
    rhs: Float[Array, " N"],
    dx: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve (∇² − λ)ψ = f in 1-D with Dirichlet BCs on a regular grid (DST-I).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side at interior grid points.
    dx : float
        Grid spacing.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ, same shape as *rhs*.
    """
    (N,) = rhs.shape
    rhs_hat = dstn(rhs, type=1, axes=[0])
    eig = _eig_1d(
        dst1_eigenvalues, dst1_eigenvalues_ps, N, dx, (N + 1) * dx, approximation
    )
    psi_hat = rhs_hat / (eig - lambda_)
    return idstn(psi_hat, type=1, axes=[0])

solve_poisson_dst1_1d(rhs, dx, *, approximation='fd2')

Solve ∇²ψ = f in 1-D with Dirichlet BCs on a regular grid (DST-I).

Parameters

  - rhs (Float[Array, " N"]): Right-hand side at interior grid points.
  - dx (float): Grid spacing.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dst1_1d(
    rhs: Float[Array, " N"],
    dx: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve ∇²ψ = f in 1-D with Dirichlet BCs on a regular grid (DST-I).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side at interior grid points.
    dx : float
        Grid spacing.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dst1_1d(rhs, dx, lambda_=0.0, approximation=approximation)

2D

solve_helmholtz_dst(rhs, dx, dy, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with homogeneous Dirichlet BCs using DST-I.

The input rhs lives on the interior grid (boundary values are implicitly zero: ψ = 0 on all four edges).

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side at interior grid points.
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ at interior grid points, same shape as rhs.
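Because the 2-D transform is separable, it can be sketched by applying a 1-D DST-I matrix along each axis (Sy @ f @ Sx). The implicit ψ = 0 boundary can then be checked by padding the interior solution with a ring of zeros and applying the 5-point stencil. Dense matrices stand in for the fast transform, and the helper names are hypothetical:

```python
import numpy as np

def dst1_matrix(n):
    """Symmetric dense DST-I matrix with S @ S = (n+1)/2 * I."""
    m = np.arange(1, n + 1)
    return np.sin(np.pi * np.outer(m, m) / (n + 1))

def fd2_dirichlet_eigs(n, h):
    """FD2 eigenvalues of the 3-point stencil with psi = 0 one spacing beyond each end."""
    return -4.0 / h**2 * np.sin(np.pi * np.arange(1, n + 1) / (2 * (n + 1))) ** 2

def helmholtz_dst1_2d_sketch(f, dx, dy, lam=0.0):
    ny, nx = f.shape
    Sy, Sx = dst1_matrix(ny), dst1_matrix(nx)
    f_hat = Sy @ f @ Sx                    # DST-I along y (rows) and along x (columns)
    eig2d = fd2_dirichlet_eigs(ny, dy)[:, None] + fd2_dirichlet_eigs(nx, dx)[None, :] - lam
    scale = (2.0 / (ny + 1)) * (2.0 / (nx + 1))
    return scale * (Sy @ (f_hat / eig2d) @ Sx)

rng = np.random.default_rng(0)
f = rng.standard_normal((12, 16))
dx, dy = 0.1, 0.15
psi = helmholtz_dst1_2d_sketch(f, dx, dy)
p = np.pad(psi, 1)                         # zero ring: psi = 0 on all four edges
lap = ((p[1:-1, :-2] - 2 * psi + p[1:-1, 2:]) / dx**2
       + (p[:-2, 1:-1] - 2 * psi + p[2:, 1:-1]) / dy**2)
print(np.abs(lap - f).max())               # round-off level
```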

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dst(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve (∇² − λ)ψ = f with homogeneous Dirichlet BCs using DST-I.

    The input *rhs* lives on the **interior** grid (boundary values are
    implicitly zero: ψ = 0 on all four edges).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side at interior grid points.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ at interior grid points, same shape as *rhs*.
    """
    Ny, Nx = rhs.shape
    rhs_hat = dstn(rhs, type=1, axes=[0, 1])
    eigx = _eig_1d(
        dst1_eigenvalues, dst1_eigenvalues_ps, Nx, dx, (Nx + 1) * dx, approximation
    )
    eigy = _eig_1d(
        dst1_eigenvalues, dst1_eigenvalues_ps, Ny, dy, (Ny + 1) * dy, approximation
    )
    eig2d = eigy[:, None] + eigx[None, :] - lambda_
    psi_hat = rhs_hat / eig2d
    return idstn(psi_hat, type=1, axes=[0, 1])

solve_poisson_dst(rhs, dx, dy, *, approximation='fd2')

Solve ∇²ψ = f with homogeneous Dirichlet BCs using DST-I.

Convenience wrapper around solve_helmholtz_dst with lambda_=0.

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side at interior grid points.
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dst(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve ∇²ψ = f with homogeneous Dirichlet BCs using DST-I.

    Convenience wrapper around :func:`solve_helmholtz_dst` with ``lambda_=0``.

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side at interior grid points.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dst(rhs, dx, dy, lambda_=0.0, approximation=approximation)

3D

solve_helmholtz_dst1_3d(rhs, dx, dy, dz, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with Dirichlet BCs on a regular 3-D grid (DST-I).

Parameters

  - rhs (Float[Array, "Nz Ny Nx"]): Right-hand side at interior grid points.
  - dx, dy, dz (float): Grid spacings in x, y and z.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ at interior grid points, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dst1_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve (∇² − λ)ψ = f with Dirichlet BCs on a regular 3-D grid (DST-I).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side at interior grid points.
    dx, dy, dz : float
        Grid spacings.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ at interior grid points, same shape as *rhs*.
    """
    Nz, Ny, Nx = rhs.shape
    rhs_hat = dstn(rhs, type=1, axes=[0, 1, 2])
    eigx = _eig_1d(
        dst1_eigenvalues, dst1_eigenvalues_ps, Nx, dx, (Nx + 1) * dx, approximation
    )
    eigy = _eig_1d(
        dst1_eigenvalues, dst1_eigenvalues_ps, Ny, dy, (Ny + 1) * dy, approximation
    )
    eigz = _eig_1d(
        dst1_eigenvalues, dst1_eigenvalues_ps, Nz, dz, (Nz + 1) * dz, approximation
    )
    eig3d = eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_
    psi_hat = rhs_hat / eig3d
    return idstn(psi_hat, type=1, axes=[0, 1, 2])

solve_poisson_dst1_3d(rhs, dx, dy, dz, *, approximation='fd2')

Solve ∇²ψ = f with Dirichlet BCs on a regular 3-D grid (DST-I).

Parameters

  - rhs (Float[Array, "Nz Ny Nx"]): Right-hand side at interior grid points.
  - dx, dy, dz (float): Grid spacings in x, y and z.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dst1_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve ∇²ψ = f with Dirichlet BCs on a regular 3-D grid (DST-I).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side at interior grid points.
    dx, dy, dz : float
        Grid spacings.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dst1_3d(
        rhs, dx, dy, dz, lambda_=0.0, approximation=approximation
    )

Dirichlet, Staggered Grid (DST-II)

1D

solve_helmholtz_dst2_1d(rhs, dx, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f in 1-D with Dirichlet BCs on a staggered grid (DST-II).

Parameters

  - rhs (Float[Array, " N"]): Right-hand side at cell-centred grid points.
  - dx (float): Grid spacing.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Solution ψ, same shape as rhs.
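On the staggered grid the N unknowns sit at cell centres x_j = (j + 1/2)·dx, the walls are at x = 0 and x = L = N·dx (matching the `N * dx` argument in the source below), and the sine modes are sin(πm(j + 1/2)/N). A dense-matrix sketch under those assumptions, with the hypothetical name `helmholtz_dst2_1d_sketch`. It is checked with sign-flipped ghost cells (ψ_{-1} = −ψ_0, ψ_N = −ψ_{N−1}), which put ψ = 0 half a spacing outside each end:

```python
import numpy as np

def helmholtz_dst2_1d_sketch(f, dx, lam=0.0, approximation="fd2"):
    """Dense-matrix sketch of the staggered DST-II Dirichlet solve; f holds N cell-centre values."""
    n = f.shape[0]
    m = np.arange(1, n + 1)                                   # sine modes 1..N
    S = np.sin(np.pi * np.outer(m, np.arange(n) + 0.5) / n)   # S[m-1, j] = sin(pi m (j + 1/2) / N)
    if approximation == "fd2":
        eig = -4.0 / dx**2 * np.sin(np.pi * m / (2 * n)) ** 2
    else:
        eig = -(np.pi * m / (n * dx)) ** 2                    # domain length L = N dx
    psi_hat = (S @ f) / (eig - lam)
    # Rows of S are orthogonal with squared norm N/2, except N for the last mode (m = N)
    w = np.full(n, 2.0 / n)
    w[-1] = 1.0 / n
    return S.T @ (w * psi_hat)

rng = np.random.default_rng(0)
f = rng.standard_normal(16)
dx = 0.1
psi = helmholtz_dst2_1d_sketch(f, dx, lam=1.0)
g = np.concatenate([[-psi[0]], psi, [-psi[-1]]])              # sign-flipped ghost cells
residual = (g[:-2] - 2 * psi + g[2:]) / dx**2 - 1.0 * psi - f
print(np.abs(residual).max())                                 # round-off level

x = (np.arange(16) + 0.5) / 16                                # cell centres for L = 1
psi_ps = helmholtz_dst2_1d_sketch(-np.pi**2 * np.sin(np.pi * x), 1 / 16, approximation="spectral")
print(np.abs(psi_ps - np.sin(np.pi * x)).max())               # round-off level
```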

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dst2_1d(
    rhs: Float[Array, " N"],
    dx: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve (∇² − λ)ψ = f in 1-D with Dirichlet BCs on a staggered grid (DST-II).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side at cell-centred grid points.
    dx : float
        Grid spacing.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ, same shape as *rhs*.
    """
    (N,) = rhs.shape
    rhs_hat = dstn(rhs, type=2, axes=[0])
    eig = _eig_1d(dst2_eigenvalues, dst2_eigenvalues_ps, N, dx, N * dx, approximation)
    psi_hat = rhs_hat / (eig - lambda_)
    return idstn(psi_hat, type=2, axes=[0])

solve_poisson_dst2_1d(rhs, dx, *, approximation='fd2')

Solve ∇²ψ = f in 1-D with Dirichlet BCs on a staggered grid (DST-II).

Parameters

  - rhs (Float[Array, " N"]): Right-hand side at cell-centred grid points.
  - dx (float): Grid spacing.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dst2_1d(
    rhs: Float[Array, " N"],
    dx: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve ∇²ψ = f in 1-D with Dirichlet BCs on a staggered grid (DST-II).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side at cell-centred grid points.
    dx : float
        Grid spacing.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dst2_1d(rhs, dx, lambda_=0.0, approximation=approximation)

2D

solve_helmholtz_dst2(rhs, dx, dy, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with Dirichlet BCs on a staggered grid (DST-II).

The input rhs lives on a cell-centred (staggered) grid; boundary values ψ = 0 are located half a grid spacing outside the first and last rows/columns.

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side at cell-centred grid points.
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ at cell-centred grid points, same shape as rhs.
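The statement that the boundary values ψ = 0 lie half a grid spacing outside the first and last rows/columns can be checked numerically. The sketch below applies a dense 1-D DST-II along each axis (hypothetical helpers `dst2_pair`, `helmholtz_dst2_2d_sketch`, FD2 eigenvalues only). It then evaluates a 5-point stencil whose ghost rows and columns are the negated edge values, which places the zero half a cell beyond each edge:

```python
import numpy as np

def dst2_pair(n):
    """Dense DST-II matrix S[m-1, j] = sin(pi m (j + 1/2) / n) and its inverse."""
    m = np.arange(1, n + 1)
    S = np.sin(np.pi * np.outer(m, np.arange(n) + 0.5) / n)
    w = np.full(n, 2.0 / n)
    w[-1] = 1.0 / n                  # the m = n row has squared norm n instead of n/2
    return S, S.T * w                # S.T * w scales column m-1 of S.T by w[m-1]

def fd2_staggered_eigs(n, h):
    return -4.0 / h**2 * np.sin(np.pi * np.arange(1, n + 1) / (2 * n)) ** 2

def helmholtz_dst2_2d_sketch(f, dx, dy, lam=0.0):
    ny, nx = f.shape
    Sy, Sy_inv = dst2_pair(ny)
    Sx, Sx_inv = dst2_pair(nx)
    f_hat = Sy @ f @ Sx.T            # transform along y (axis 0), then along x (axis 1)
    eig2d = fd2_staggered_eigs(ny, dy)[:, None] + fd2_staggered_eigs(nx, dx)[None, :] - lam
    return Sy_inv @ (f_hat / eig2d) @ Sx_inv.T

def staggered_dirichlet_laplacian(psi, dx, dy):
    """5-point Laplacian with sign-flipped ghost cells, i.e. psi = 0 half a cell outside."""
    p = np.pad(psi, 1)
    p[0, 1:-1] = -psi[0]
    p[-1, 1:-1] = -psi[-1]
    p[1:-1, 0] = -psi[:, 0]
    p[1:-1, -1] = -psi[:, -1]
    return ((p[1:-1, :-2] - 2 * psi + p[1:-1, 2:]) / dx**2
            + (p[:-2, 1:-1] - 2 * psi + p[2:, 1:-1]) / dy**2)

rng = np.random.default_rng(0)
f = rng.standard_normal((10, 14))
psi = helmholtz_dst2_2d_sketch(f, dx=0.1, dy=0.2, lam=0.5)
residual = staggered_dirichlet_laplacian(psi, 0.1, 0.2) - 0.5 * psi - f
print(np.abs(residual).max())        # round-off level
```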

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dst2(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve (∇² − λ)ψ = f with Dirichlet BCs on a staggered grid (DST-II).

    The input *rhs* lives on a cell-centred (staggered) grid; boundary
    values ψ = 0 are located half a grid spacing outside the first and
    last rows/columns.

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side at cell-centred grid points.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ at cell-centred grid points, same shape as *rhs*.
    """
    Ny, Nx = rhs.shape
    rhs_hat = dstn(rhs, type=2, axes=[0, 1])
    eigx = _eig_1d(
        dst2_eigenvalues, dst2_eigenvalues_ps, Nx, dx, Nx * dx, approximation
    )
    eigy = _eig_1d(
        dst2_eigenvalues, dst2_eigenvalues_ps, Ny, dy, Ny * dy, approximation
    )
    eig2d = eigy[:, None] + eigx[None, :] - lambda_
    psi_hat = rhs_hat / eig2d
    return idstn(psi_hat, type=2, axes=[0, 1])

solve_poisson_dst2(rhs, dx, dy, *, approximation='fd2')

Solve ∇²ψ = f with Dirichlet BCs on a staggered grid (DST-II).

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side at cell-centred grid points.
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dst2(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve ∇²ψ = f with Dirichlet BCs on a staggered grid (DST-II).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side at cell-centred grid points.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dst2(rhs, dx, dy, lambda_=0.0, approximation=approximation)

3D

solve_helmholtz_dst2_3d(rhs, dx, dy, dz, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with Dirichlet BCs on a staggered 3-D grid (DST-II).

Parameters

  - rhs (Float[Array, "Nz Ny Nx"]): Right-hand side at cell-centred grid points.
  - dx, dy, dz (float): Grid spacings in x, y and z.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dst2_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve (∇² − λ)ψ = f with Dirichlet BCs on a staggered 3-D grid (DST-II).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side at cell-centred grid points.
    dx, dy, dz : float
        Grid spacings.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    Nz, Ny, Nx = rhs.shape
    rhs_hat = dstn(rhs, type=2, axes=[0, 1, 2])
    eigx = _eig_1d(
        dst2_eigenvalues, dst2_eigenvalues_ps, Nx, dx, Nx * dx, approximation
    )
    eigy = _eig_1d(
        dst2_eigenvalues, dst2_eigenvalues_ps, Ny, dy, Ny * dy, approximation
    )
    eigz = _eig_1d(
        dst2_eigenvalues, dst2_eigenvalues_ps, Nz, dz, Nz * dz, approximation
    )
    eig3d = eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_
    psi_hat = rhs_hat / eig3d
    return idstn(psi_hat, type=2, axes=[0, 1, 2])

solve_poisson_dst2_3d(rhs, dx, dy, dz, *, approximation='fd2')

Solve ∇²ψ = f with Dirichlet BCs on a staggered 3-D grid (DST-II).

Parameters

  - rhs (Float[Array, "Nz Ny Nx"]): Right-hand side at cell-centred grid points.
  - dx, dy, dz (float): Grid spacings in x, y and z.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dst2_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve ∇²ψ = f with Dirichlet BCs on a staggered 3-D grid (DST-II).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side at cell-centred grid points.
    dx, dy, dz : float
        Grid spacings.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dst2_3d(
        rhs, dx, dy, dz, lambda_=0.0, approximation=approximation
    )

Neumann, Regular Grid (DCT-I)

1D

solve_helmholtz_dct1_1d(rhs, dx, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f in 1-D with Neumann BCs on a regular grid (DCT-I).

The k=0 eigenvalue is zero. For λ=0 (Poisson), the null mode is projected out (zero-mean gauge).

Parameters

  - rhs (Float[Array, " N"]): Right-hand side (including boundary grid points).
  - dx (float): Grid spacing.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Solution ψ, same shape as rhs.
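For the regular-grid Neumann case the N points include both boundaries, so there are N − 1 intervals and L = (N−1)·dx, matching the `(N - 1) * dx` argument in the source below. The cosine modes cos(πkj/(N−1)) satisfy mirror ghost conditions ψ_{-1} = ψ_1 and ψ_N = ψ_{N−2}. A dense-matrix FD2 sketch with hypothetical names; the unnormalised DCT-I with trapezoid end-point weights is this sketch's convention, not necessarily the library's. In this convention the k = 0 coefficient is a trapezoidal sum, so the gauge fixes the trapezoid-weighted mean of ψ to zero.

```python
import numpy as np

def helmholtz_dct1_1d_sketch(f, dx, lam=0.0):
    """Dense-matrix FD2 sketch of the DCT-I Neumann solve; f includes both boundary points."""
    n = f.shape[0]
    M = n - 1                                    # intervals, so L = (N-1) dx
    k = np.arange(n)
    C = np.cos(np.pi * np.outer(k, k) / M)       # symmetric: C[k, j] = cos(pi k j / M)
    c = np.ones(n)
    c[0] = c[-1] = 0.5                           # trapezoid end-point weights
    eig = -4.0 / dx**2 * np.sin(np.pi * k / (2 * M)) ** 2
    denom = eig - lam
    is_null = denom == 0.0                       # the constant mode k = 0 when lam == 0
    f_hat = C @ (c * f)                          # forward DCT-I (unnormalised)
    psi_hat = np.where(is_null, 0.0, f_hat / np.where(is_null, 1.0, denom))
    return (2.0 / M) * (C @ (c * psi_hat))       # inverse DCT-I

def trapezoid_mean(v):
    """Trapezoid-weighted mean over the N grid points."""
    w = np.ones(v.shape[0])
    w[0] = w[-1] = 0.5
    return (w * v).sum() / (v.shape[0] - 1)

rng = np.random.default_rng(0)
f = rng.standard_normal(17)
dx = 0.1
psi = helmholtz_dct1_1d_sketch(f, dx)
g = np.pad(psi, 1, mode="reflect")               # mirror ghosts: psi[-1] = psi[1], psi[N] = psi[N-2]
lap = (g[:-2] - 2 * psi + g[2:]) / dx**2
print(np.abs(lap - (f - trapezoid_mean(f))).max())   # round-off level
print(abs(trapezoid_mean(psi)))                      # round-off level
```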

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dct1_1d(
    rhs: Float[Array, " N"],
    dx: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve (∇² − λ)ψ = f in 1-D with Neumann BCs on a regular grid (DCT-I).

    The k=0 eigenvalue is zero.  For λ=0 (Poisson), the null mode is
    projected out (zero-mean gauge).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side (including boundary grid points).
    dx : float
        Grid spacing.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ, same shape as *rhs*.
    """
    (N,) = rhs.shape
    rhs_hat = dctn(rhs, type=1, axes=[0])
    eig = _eig_1d(
        dct1_eigenvalues, dct1_eigenvalues_ps, N, dx, (N - 1) * dx, approximation
    )
    denom = eig - lambda_
    is_null = denom == 0.0
    denom_safe = jnp.where(is_null, 1.0, denom)
    psi_hat = rhs_hat / denom_safe
    psi_hat = jnp.where(is_null, 0.0, psi_hat)
    return idctn(psi_hat, type=1, axes=[0])

solve_poisson_dct1_1d(rhs, dx, *, approximation='fd2')

Solve ∇²ψ = f in 1-D with Neumann BCs on a regular grid (DCT-I).

Parameters

  - rhs (Float[Array, " N"]): Right-hand side (including boundary grid points).
  - dx (float): Grid spacing.
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Zero-mean solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dct1_1d(
    rhs: Float[Array, " N"],
    dx: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve ∇²ψ = f in 1-D with Neumann BCs on a regular grid (DCT-I).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side (including boundary grid points).
    dx : float
        Grid spacing.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Zero-mean solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dct1_1d(rhs, dx, lambda_=0.0, approximation=approximation)

2D

solve_helmholtz_dct1(rhs, dx, dy, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with Neumann BCs on a regular grid (DCT-I).

When lambda_ == 0 the (0,0) mode is singular (Λ[0,0] = 0, corresponding to the constant null mode of the Neumann Laplacian). This is handled by setting ψ̂[0,0] = 0 (zero-mean gauge).

Parameters

  - rhs (Float[Array, "Ny Nx"]): Right-hand side (including boundary grid points).
  - dx (float): Grid spacing in x.
  - dy (float): Grid spacing in y.
  - lambda_ (float): Helmholtz parameter λ. Default: 0.0 (Poisson).
  - approximation ({"fd2", "spectral"}): Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dct1(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve (∇² − λ)ψ = f with Neumann BCs on a regular grid (DCT-I).

    When ``lambda_ == 0`` the (0,0) mode is singular (Λ[0,0] = 0,
    corresponding to the constant null mode of the Neumann Laplacian).
    This is handled by setting ψ̂[0,0] = 0 (zero-mean gauge).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side (including boundary grid points).
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    Ny, Nx = rhs.shape
    rhs_hat = dctn(rhs, type=1, axes=[0, 1])
    eigx = _eig_1d(
        dct1_eigenvalues, dct1_eigenvalues_ps, Nx, dx, (Nx - 1) * dx, approximation
    )
    eigy = _eig_1d(
        dct1_eigenvalues, dct1_eigenvalues_ps, Ny, dy, (Ny - 1) * dy, approximation
    )
    eig2d = eigy[:, None] + eigx[None, :] - lambda_
    is_null = eig2d[0, 0] == 0.0
    eig2d_safe = eig2d.at[0, 0].set(jnp.where(is_null, 1.0, eig2d[0, 0]))
    psi_hat = rhs_hat / eig2d_safe
    psi_hat = psi_hat.at[0, 0].set(jnp.where(is_null, 0.0, psi_hat[0, 0]))
    return idctn(psi_hat, type=1, axes=[0, 1])
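The 2-D eigenvalue array is the broadcast sum eigy[:, None] + eigx[None, :] because the tensor product of two 1-D DCT-I modes is an eigenvector of the five-point stencil. The NumPy sketch below checks this for one mode. It assumes the same FD2 closed form Λ_k = −(4/h²) sin²(πk / (2(n − 1))) and mirrored ghost points. The helpers fd2_eig_dct1 and neumann_lap_2d are hypothetical. np.pad with mode="reflect" produces exactly the mirrored ghosts ψ_{−1} = ψ_1.

```python
import numpy as np


def fd2_eig_dct1(n, h):
    """Assumed FD2 eigenvalues for regular-grid Neumann BCs (DCT-I)."""
    k = np.arange(n)
    return -4.0 / h**2 * np.sin(np.pi * k / (2 * (n - 1))) ** 2


def neumann_lap_2d(psi, dx, dy):
    """Five-point FD2 Laplacian with mirrored ghosts (ghost row -1 equals row 1)."""
    p = np.pad(psi, 1, mode="reflect")
    d2x = (p[1:-1, :-2] - 2.0 * psi + p[1:-1, 2:]) / dx**2
    d2y = (p[:-2, 1:-1] - 2.0 * psi + p[2:, 1:-1]) / dy**2
    return d2x + d2y


Ny, Nx, dx, dy = 9, 13, 0.1, 0.25
eig2d = fd2_eig_dct1(Ny, dy)[:, None] + fd2_eig_dct1(Nx, dx)[None, :]

# Tensor-product DCT-I mode with wavenumbers (ky, kx) = (3, 5)
ky, kx = 3, 5
mode = np.outer(np.cos(np.pi * ky * np.arange(Ny) / (Ny - 1)),
                np.cos(np.pi * kx * np.arange(Nx) / (Nx - 1)))
```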

solve_poisson_dct1(rhs, dx, dy, *, approximation='fd2')

Solve ∇²ψ = f with Neumann BCs on a regular grid (DCT-I).

Parameters

rhs : Float[Array, "Ny Nx"]
    Right-hand side (including boundary grid points).
dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Zero-mean solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dct1(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve ∇²ψ = f with Neumann BCs on a regular grid (DCT-I).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side (including boundary grid points).
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Zero-mean solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dct1(rhs, dx, dy, lambda_=0.0, approximation=approximation)

3D

solve_helmholtz_dct1_3d(rhs, dx, dy, dz, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with Neumann BCs on a regular 3-D grid (DCT-I).

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Right-hand side (including boundary grid points).
dx, dy, dz : float
    Grid spacings.
lambda_ : float
    Helmholtz parameter λ. Default: 0.0 (Poisson).
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dct1_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve (∇² − λ)ψ = f with Neumann BCs on a regular 3-D grid (DCT-I).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side (including boundary grid points).
    dx, dy, dz : float
        Grid spacings.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    Nz, Ny, Nx = rhs.shape
    rhs_hat = dctn(rhs, type=1, axes=[0, 1, 2])
    eigx = _eig_1d(
        dct1_eigenvalues, dct1_eigenvalues_ps, Nx, dx, (Nx - 1) * dx, approximation
    )
    eigy = _eig_1d(
        dct1_eigenvalues, dct1_eigenvalues_ps, Ny, dy, (Ny - 1) * dy, approximation
    )
    eigz = _eig_1d(
        dct1_eigenvalues, dct1_eigenvalues_ps, Nz, dz, (Nz - 1) * dz, approximation
    )
    eig3d = eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_
    is_null = eig3d[0, 0, 0] == 0.0
    eig3d_safe = eig3d.at[0, 0, 0].set(jnp.where(is_null, 1.0, eig3d[0, 0, 0]))
    psi_hat = rhs_hat / eig3d_safe
    psi_hat = psi_hat.at[0, 0, 0].set(jnp.where(is_null, 0.0, psi_hat[0, 0, 0]))
    return idctn(psi_hat, type=1, axes=[0, 1, 2])
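Every FD2 Neumann eigenvalue is ≤ 0, and only the all-zero wavenumber gives exactly 0. For λ ≥ 0 the denominator can therefore vanish only at index (0, 0, 0), and only when λ = 0. The sketch below confirms this for the broadcast 3-D array, reusing the assumed DCT-I closed form. The helpers fd2_eig_dct1 and eig3d_dct1 are hypothetical. The sketch also illustrates a caveat that follows from the source above: only the (0, 0, 0) entry is guarded. A negative λ equal to some other Λ therefore leaves a zero denominator elsewhere, and the division would produce non-finite values.

```python
import numpy as np


def fd2_eig_dct1(n, h):
    """Assumed FD2 eigenvalues for regular-grid Neumann BCs (DCT-I)."""
    k = np.arange(n)
    return -4.0 / h**2 * np.sin(np.pi * k / (2 * (n - 1))) ** 2


def eig3d_dct1(shape, spacings, lambda_):
    """Broadcast 1-D eigenvalues to (Nz, Ny, Nx) and shift by lambda_."""
    (Nz, Ny, Nx), (dz, dy, dx) = shape, spacings
    eigz = fd2_eig_dct1(Nz, dz)
    eigy = fd2_eig_dct1(Ny, dy)
    eigx = fd2_eig_dct1(Nx, dx)
    return eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_


shape, spacings = (5, 6, 7), (0.3, 0.2, 0.1)
poisson = eig3d_dct1(shape, spacings, lambda_=0.0)
helmholtz = eig3d_dct1(shape, spacings, lambda_=2.0)
# A negative lambda_ chosen to coincide with the (0, 0, 1) eigenvalue
resonant = eig3d_dct1(shape, spacings, lambda_=poisson[0, 0, 1])
```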

solve_poisson_dct1_3d(rhs, dx, dy, dz, *, approximation='fd2')

Solve ∇²ψ = f with Neumann BCs on a regular 3-D grid (DCT-I).

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Right-hand side (including boundary grid points).
dx, dy, dz : float
    Grid spacings.
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Zero-mean solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dct1_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve ∇²ψ = f with Neumann BCs on a regular 3-D grid (DCT-I).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side (including boundary grid points).
    dx, dy, dz : float
        Grid spacings.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Zero-mean solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dct1_3d(
        rhs, dx, dy, dz, lambda_=0.0, approximation=approximation
    )

Neumann, Staggered Grid (DCT-II)

1D

solve_helmholtz_dct2_1d(rhs, dx, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f in 1-D with Neumann BCs on a staggered grid (DCT-II).

The k=0 eigenvalue is zero. For λ=0 (Poisson), the null mode is projected out (zero-mean gauge).

Parameters

rhs : Float[Array, " N"]
    Right-hand side at cell-centred grid points.
dx : float
    Grid spacing.
lambda_ : float
    Helmholtz parameter λ. Default: 0.0 (Poisson).
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dct2_1d(
    rhs: Float[Array, " N"],
    dx: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve (∇² − λ)ψ = f in 1-D with Neumann BCs on a staggered grid (DCT-II).

    The k=0 eigenvalue is zero.  For λ=0 (Poisson), the null mode is
    projected out (zero-mean gauge).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side at cell-centred grid points.
    dx : float
        Grid spacing.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Solution ψ, same shape as *rhs*.
    """
    (N,) = rhs.shape
    rhs_hat = dctn(rhs, type=2, axes=[0])
    eig = _eig_1d(dct2_eigenvalues, dct2_eigenvalues_ps, N, dx, N * dx, approximation)
    denom = eig - lambda_
    is_null = denom == 0.0
    denom_safe = jnp.where(is_null, 1.0, denom)
    psi_hat = rhs_hat / denom_safe
    psi_hat = jnp.where(is_null, 0.0, psi_hat)
    return idctn(psi_hat, type=2, axes=[0])
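On the staggered grid the unknowns sit at cell centres x_j = (j + ½)dx. This is consistent with the domain length N·dx passed to _eig_1d above. The NumPy sketch below checks that the DCT-II cosines are eigenvectors of the resulting FD2 matrix. It assumes ghost values ψ_{−1} = ψ_0 and ψ_N = ψ_{N−1}, and the closed form Λ_k = −(4/dx²) sin²(πk / (2N)). The library's dct2_eigenvalues is not shown in this excerpt.

```python
import numpy as np

N, dx = 16, 0.05
j = np.arange(N)                                  # cell centres at (j + 1/2) dx
k = np.arange(N)
V = np.cos(np.pi * np.outer(j + 0.5, k) / N)      # column k: DCT-II basis vector
eig = -4.0 / dx**2 * np.sin(np.pi * k / (2 * N)) ** 2

# 3-point stencil with ghosts psi[-1] = psi[0] and psi[N] = psi[N-1]
A = -2.0 * np.eye(N) + np.eye(N, k=1) + np.eye(N, k=-1)
A[0, 0] = A[-1, -1] = -1.0
A /= dx**2
```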

solve_poisson_dct2_1d(rhs, dx, *, approximation='fd2')

Solve ∇²ψ = f in 1-D with Neumann BCs on a staggered grid (DCT-II).

Parameters

rhs : Float[Array, " N"]
    Right-hand side at cell-centred grid points.
dx : float
    Grid spacing.
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, " N"] Zero-mean solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dct2_1d(
    rhs: Float[Array, " N"],
    dx: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, " N"]:
    """Solve ∇²ψ = f in 1-D with Neumann BCs on a staggered grid (DCT-II).

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Right-hand side at cell-centred grid points.
    dx : float
        Grid spacing.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, " N"]
        Zero-mean solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dct2_1d(rhs, dx, lambda_=0.0, approximation=approximation)
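On the staggered grid, the zeroed DCT-II coefficient is proportional to the plain arithmetic mean. The gauge therefore gives an exactly zero-mean ψ, and a RHS with nonzero mean is effectively replaced by f − mean(f). The dense-matrix NumPy sketch below illustrates both points. It expands f in the orthogonal DCT-II cosines instead of calling a fast transform. The name solve_poisson_neumann_stag_sketch is hypothetical, and the eigenvalue closed form is an assumption.

```python
import numpy as np


def solve_poisson_neumann_stag_sketch(rhs, dx):
    """Hypothetical dense-matrix stand-in for a DCT-II Neumann Poisson solve."""
    N = rhs.size
    k = np.arange(N)
    V = np.cos(np.pi * np.outer(np.arange(N) + 0.5, k) / N)  # orthogonal columns
    coef = V.T @ rhs / np.sum(V**2, axis=0)    # expansion coefficients; coef[0] = mean
    eig = -4.0 / dx**2 * np.sin(np.pi * k / (2 * N)) ** 2
    psi_hat = np.zeros(N)
    psi_hat[1:] = coef[1:] / eig[1:]           # k = 0 dropped: zero-mean gauge
    return V @ psi_hat


N = 32
dx = 1.0 / N
f = np.random.default_rng(0).standard_normal(N)   # generic RHS with nonzero mean
psi = solve_poisson_neumann_stag_sketch(f, dx)
```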

2D

solve_helmholtz_dct(rhs, dx, dy, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with homogeneous Neumann BCs using DCT-II.

When lambda_ == 0 the (0,0) mode is singular (Λ[0,0] = 0, corresponding to the constant null mode of the Neumann Laplacian). This is handled by setting ψ̂[0,0] = 0 (zero-mean gauge).

Parameters

rhs : Float[Array, "Ny Nx"]
    Right-hand side.
dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
lambda_ : float
    Helmholtz parameter λ. Default: 0.0 (Poisson).
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dct(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve (∇² − λ)ψ = f with homogeneous Neumann BCs using DCT-II.

    When ``lambda_ == 0`` the (0,0) mode is singular (Λ[0,0] = 0,
    corresponding to the constant null mode of the Neumann Laplacian).
    This is handled by setting ψ̂[0,0] = 0 (zero-mean gauge).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    Ny, Nx = rhs.shape
    rhs_hat = dctn(rhs, type=2, axes=[0, 1])
    eigx = _eig_1d(
        dct2_eigenvalues, dct2_eigenvalues_ps, Nx, dx, Nx * dx, approximation
    )
    eigy = _eig_1d(
        dct2_eigenvalues, dct2_eigenvalues_ps, Ny, dy, Ny * dy, approximation
    )
    eig2d = eigy[:, None] + eigx[None, :] - lambda_
    is_null = eig2d[0, 0] == 0.0
    eig2d_safe = eig2d.at[0, 0].set(jnp.where(is_null, 1.0, eig2d[0, 0]))
    psi_hat = rhs_hat / eig2d_safe
    psi_hat = psi_hat.at[0, 0].set(jnp.where(is_null, 0.0, psi_hat[0, 0]))
    return idctn(psi_hat, type=2, axes=[0, 1])

solve_poisson_dct(rhs, dx, dy, *, approximation='fd2')

Solve ∇²ψ = f with homogeneous Neumann BCs using DCT-II.

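The Poisson problem has a one-dimensional null space (constant solutions). This function fixes the gauge by setting the domain mean of ψ to zero (ψ̂[0,0] = 0). Because Λ[0,0] = 0, an exact solution exists only when f has zero mean; otherwise the mean of f is discarded along with that mode.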

Parameters

rhs : Float[Array, "Ny Nx"]
    Right-hand side.
dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Ny Nx"] Zero-mean solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dct(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Ny Nx"]:
    """Solve ∇²ψ = f with homogeneous Neumann BCs using DCT-II.

    The Poisson problem has a one-dimensional null space (constant solutions).
    This function fixes the gauge by forcing the domain-mean of ψ to zero.

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Zero-mean solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dct(rhs, dx, dy, lambda_=0.0, approximation=approximation)

3D

solve_helmholtz_dct2_3d(rhs, dx, dy, dz, lambda_=0.0, *, approximation='fd2')

Solve (∇² − λ)ψ = f with Neumann BCs on a staggered 3-D grid (DCT-II).

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Right-hand side at cell-centred grid points.
dx, dy, dz : float
    Grid spacings.
lambda_ : float
    Helmholtz parameter λ. Default: 0.0 (Poisson).
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_dct2_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    lambda_: float = 0.0,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve (∇² − λ)ψ = f with Neumann BCs on a staggered 3-D grid (DCT-II).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side at cell-centred grid points.
    dx, dy, dz : float
        Grid spacings.
    lambda_ : float
        Helmholtz parameter λ.  Default: 0.0 (Poisson).
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution ψ, same shape as *rhs*.
    """
    Nz, Ny, Nx = rhs.shape
    rhs_hat = dctn(rhs, type=2, axes=[0, 1, 2])
    eigx = _eig_1d(
        dct2_eigenvalues, dct2_eigenvalues_ps, Nx, dx, Nx * dx, approximation
    )
    eigy = _eig_1d(
        dct2_eigenvalues, dct2_eigenvalues_ps, Ny, dy, Ny * dy, approximation
    )
    eigz = _eig_1d(
        dct2_eigenvalues, dct2_eigenvalues_ps, Nz, dz, Nz * dz, approximation
    )
    eig3d = eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_
    is_null = eig3d[0, 0, 0] == 0.0
    eig3d_safe = eig3d.at[0, 0, 0].set(jnp.where(is_null, 1.0, eig3d[0, 0, 0]))
    psi_hat = rhs_hat / eig3d_safe
    psi_hat = psi_hat.at[0, 0, 0].set(jnp.where(is_null, 0.0, psi_hat[0, 0, 0]))
    return idctn(psi_hat, type=2, axes=[0, 1, 2])

solve_poisson_dct2_3d(rhs, dx, dy, dz, *, approximation='fd2')

Solve ∇²ψ = f with Neumann BCs on a staggered 3-D grid (DCT-II).

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Right-hand side at cell-centred grid points.
dx, dy, dz : float
    Grid spacings.
approximation : {"fd2", "spectral"}
    Eigenvalue type. Default: "fd2".

Returns

Float[Array, "Nz Ny Nx"] Zero-mean solution ψ, same shape as rhs.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_dct2_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    *,
    approximation: Approximation = "fd2",
) -> Float[Array, "Nz Ny Nx"]:
    """Solve ∇²ψ = f with Neumann BCs on a staggered 3-D grid (DCT-II).

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side at cell-centred grid points.
    dx, dy, dz : float
        Grid spacings.
    approximation : {"fd2", "spectral"}
        Eigenvalue type.  Default: ``"fd2"``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Zero-mean solution ψ, same shape as *rhs*.
    """
    return solve_helmholtz_dct2_3d(
        rhs, dx, dy, dz, lambda_=0.0, approximation=approximation
    )

Mixed Per-Axis BCs (2D/3D)

2D

solve_helmholtz_2d(rhs, dx, dy, bc_x='periodic', bc_y='periodic', lambda_=0.0, *, bc_x_values=(None, None), bc_y_values=(None, None))

Solve (∇² − λ)ψ = f in 2-D with per-axis boundary conditions.

Supports any combination of boundary conditions on each axis:

  • "periodic" — periodic (FFT)
  • "dirichlet" — homogeneous Dirichlet on regular (vertex) grid (DST-I)
  • "dirichlet_stag" — homogeneous Dirichlet on staggered (cell) grid (DST-II)
  • "neumann" — homogeneous Neumann on regular grid (DCT-I)
  • "neumann_stag" — homogeneous Neumann on staggered grid (DCT-II)
  • ("dirichlet_stag", "neumann_stag") — Dirichlet left + Neumann right, staggered (DST-IV)
  • ("neumann_stag", "dirichlet_stag") — Neumann left + Dirichlet right, staggered (DCT-IV)
  • ("dirichlet", "neumann") — Dirichlet left + Neumann right, regular (DST-III)
  • ("neumann", "dirichlet") — Neumann left + Dirichlet right, regular (DCT-III)
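Each entry above pairs a boundary condition with the transform whose basis diagonalizes the 3-point FD2 stencil under that condition. The NumPy sketch below checks six of the pairings numerically. For each pairing, it compares the eigenvalues of a dense FD2 matrix against a closed form. The ghost-point rules and the closed forms are assumptions worked out for this sketch; they are not read from the library's eigenvalue helpers. The regular-grid mixed pairs (DST-III, DCT-III) are left out because their grid layout is not stated here. The names fd2_matrix and closed_form_eig are hypothetical.

```python
import numpy as np


def fd2_matrix(n, h, left, right):
    """Dense 3-point FD2 Laplacian on n unknowns with ghost values folded in.

    Assumed ghost rules (one per end):
      "dirichlet"       ghost node holds 0 (unknowns are interior nodes)
      "neumann"         ghost mirrors the neighbour (unknowns include the boundary node)
      "dirichlet_stag"  ghost = -(adjacent cell): psi vanishes on the boundary face
      "neumann_stag"    ghost = +(adjacent cell): dpsi/dn vanishes on the face
    """
    A = -2.0 * np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
    for end, nbr, bc in ((0, 1, left), (n - 1, n - 2, right)):
        if bc == "neumann":
            A[end, nbr] += 1.0
        elif bc == "dirichlet_stag":
            A[end, end] -= 1.0
        elif bc == "neumann_stag":
            A[end, end] += 1.0
    return A / h**2


def closed_form_eig(n, h, left, right):
    """Assumed FD2 eigenvalues -(4/h^2) sin^2(theta_k) for each BC pair."""
    k = np.arange(n)
    theta = {
        ("dirichlet", "dirichlet"): np.pi * (k + 1) / (2 * (n + 1)),        # DST-I
        ("neumann", "neumann"): np.pi * k / (2 * (n - 1)),                  # DCT-I
        ("dirichlet_stag", "dirichlet_stag"): np.pi * (k + 1) / (2 * n),    # DST-II
        ("neumann_stag", "neumann_stag"): np.pi * k / (2 * n),              # DCT-II
        ("dirichlet_stag", "neumann_stag"): np.pi * (2 * k + 1) / (4 * n),  # DST-IV
        ("neumann_stag", "dirichlet_stag"): np.pi * (2 * k + 1) / (4 * n),  # DCT-IV
    }[(left, right)]
    return -4.0 / h**2 * np.sin(theta) ** 2


CASES = [
    ("dirichlet", "dirichlet"),
    ("neumann", "neumann"),
    ("dirichlet_stag", "dirichlet_stag"),
    ("neumann_stag", "neumann_stag"),
    ("dirichlet_stag", "neumann_stag"),
    ("neumann_stag", "dirichlet_stag"),
]
n, h = 12, 0.1
max_err = {}
for left, right in CASES:
    # The regular-grid Neumann matrix is not symmetric, so use eigvals, not eigvalsh.
    numeric = np.sort(np.linalg.eigvals(fd2_matrix(n, h, left, right)).real)
    exact = np.sort(closed_form_eig(n, h, left, right))
    max_err[(left, right)] = np.max(np.abs(numeric - exact))
```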

Transforms are applied as sequential 1-D transforms along each axis. When mixing FFT (complex) with DST/DCT (real), the real and imaginary parts are transformed separately along the non-periodic axis (PoisFFT approach).

Parameters

rhs : Float[Array, "Ny Nx"]
    Right-hand side.
dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
bc_x : BoundaryCondition
    Boundary condition along the x-axis (columns).
bc_y : BoundaryCondition
    Boundary condition along the y-axis (rows).
lambda_ : float
    Helmholtz parameter. Default: 0.0 (Poisson).
bc_x_values : tuple[Array | None, Array | None], keyword-only
    (left, right) inhomogeneous boundary values for the x-axis. Each is an array of shape (Ny,) or None (homogeneous). For Dirichlet: prescribed ψ values. For Neumann: prescribed ∂ψ/∂n (outward normal derivative).
bc_y_values : tuple[Array | None, Array | None], keyword-only
    (bottom, top) inhomogeneous boundary values for the y-axis. Each is an array of shape (Nx,) or None (homogeneous).

Returns

Float[Array, "Ny Nx"] Solution ψ, same shape as rhs.

Notes

When using jax.jit, bc_x and bc_y must be marked as static, since they are Python objects used for dispatch:

solve_jit = jax.jit(solve_helmholtz_2d, static_argnames=("bc_x", "bc_y"))

Inhomogeneous BC support uses FD2 eigenvalues only (the default).

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_2d(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    bc_x: BoundaryCondition = "periodic",
    bc_y: BoundaryCondition = "periodic",
    lambda_: float = 0.0,
    *,
    bc_x_values: tuple[Float[Array, " Ny"] | None, Float[Array, " Ny"] | None] = (
        None,
        None,
    ),
    bc_y_values: tuple[Float[Array, " Nx"] | None, Float[Array, " Nx"] | None] = (
        None,
        None,
    ),
) -> Float[Array, "Ny Nx"]:
    """Solve (nabla^2 - lambda)psi = f in 2-D with per-axis boundary conditions.

    Supports any combination of boundary conditions on each axis:

    * ``"periodic"`` — periodic (FFT)
    * ``"dirichlet"`` — homogeneous Dirichlet on regular (vertex) grid (DST-I)
    * ``"dirichlet_stag"`` — homogeneous Dirichlet on staggered (cell) grid (DST-II)
    * ``"neumann"`` — homogeneous Neumann on regular grid (DCT-I)
    * ``"neumann_stag"`` — homogeneous Neumann on staggered grid (DCT-II)
    * ``("dirichlet_stag", "neumann_stag")`` — Dirichlet left + Neumann right, staggered (DST-IV)
    * ``("neumann_stag", "dirichlet_stag")`` — Neumann left + Dirichlet right, staggered (DCT-IV)
    * ``("dirichlet", "neumann")`` — Dirichlet left + Neumann right, regular (DST-III)
    * ``("neumann", "dirichlet")`` — Neumann left + Dirichlet right, regular (DCT-III)

    Transforms are applied as sequential 1-D transforms along each axis.
    When mixing FFT (complex) with DST/DCT (real), the real and imaginary
    parts are transformed separately along the non-periodic axis (PoisFFT
    approach).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    bc_x : BoundaryCondition
        Boundary condition along the x-axis (columns).
    bc_y : BoundaryCondition
        Boundary condition along the y-axis (rows).
    lambda_ : float
        Helmholtz parameter.  Default: 0.0 (Poisson).
    bc_x_values : tuple[Array | None, Array | None], keyword-only
        ``(left, right)`` inhomogeneous boundary values for x-axis.
        Each is an array of shape ``(Ny,)`` or ``None`` (homogeneous).
        For Dirichlet: prescribed psi values.
        For Neumann: prescribed dpsi/dn (outward normal).
    bc_y_values : tuple[Array | None, Array | None], keyword-only
        ``(bottom, top)`` inhomogeneous boundary values for y-axis.
        Each is an array of shape ``(Nx,)`` or ``None`` (homogeneous).

    Returns
    -------
    Float[Array, "Ny Nx"]
        Solution psi, same shape as *rhs*.

    Notes
    -----
    When using ``jax.jit``, ``bc_x`` and ``bc_y`` must be marked as static
    since they are Python objects used for dispatch::

        solve_jit = jax.jit(solve_helmholtz_2d, static_argnames=("bc_x", "bc_y"))

    Inhomogeneous BC support uses FD2 eigenvalues only (the default).
    """
    # Apply inhomogeneous BC RHS modification if any values are non-None.
    _has_inhomogeneous = any(v is not None for v in (*bc_x_values, *bc_y_values))
    if _has_inhomogeneous:
        rhs = modify_rhs_2d(rhs, bc_x, bc_y, dx, dy, bc_x_values, bc_y_values)

    Ny, Nx = rhs.shape
    fam_x, type_x, eig_fn_x, null_x = _lookup_bc(bc_x)
    fam_y, type_y, eig_fn_y, null_y = _lookup_bc(bc_y)

    eigx = eig_fn_x(Nx, dx)
    eigy = eig_fn_y(Ny, dy)
    eig2d = eigy[:, None] + eigx[None, :] - lambda_

    _has_null_mode = null_x and null_y
    if _has_null_mode:
        is_null = eig2d[0, 0] == 0.0
        eig2d_safe = eig2d.at[0, 0].set(jnp.where(is_null, 1.0, eig2d[0, 0]))
    else:
        eig2d_safe = eig2d

    x_periodic = fam_x == "fft"
    y_periodic = fam_y == "fft"

    if x_periodic and y_periodic:
        rhs_hat = jnp.fft.fft2(rhs)
    elif not x_periodic and not y_periodic:
        temp = _forward_1d(rhs, fam_x, type_x, axis=1)
        rhs_hat = _forward_1d(temp, fam_y, type_y, axis=0)
    elif x_periodic:
        temp = jnp.fft.fft(rhs, axis=1)
        rhs_hat = _forward_1d(temp.real, fam_y, type_y, axis=0) + 1j * _forward_1d(
            temp.imag, fam_y, type_y, axis=0
        )
    else:
        temp = _forward_1d(rhs, fam_x, type_x, axis=1)
        rhs_hat = jnp.fft.fft(temp, axis=0)

    psi_hat = rhs_hat / eig2d_safe
    if _has_null_mode:
        psi_hat = psi_hat.at[0, 0].set(
            jnp.where(is_null, jnp.zeros_like(psi_hat[0, 0]), psi_hat[0, 0])
        )

    if x_periodic and y_periodic:
        psi = jnp.real(jnp.fft.ifft2(psi_hat))
    elif not x_periodic and not y_periodic:
        temp = _inverse_1d(psi_hat, fam_y, type_y, axis=0)
        psi = _inverse_1d(temp, fam_x, type_x, axis=1)
    elif x_periodic:
        temp = _inverse_1d(psi_hat.real, fam_y, type_y, axis=0) + 1j * _inverse_1d(
            psi_hat.imag, fam_y, type_y, axis=0
        )
        psi = jnp.real(jnp.fft.ifft(temp, axis=1))
    else:
        temp = jnp.fft.ifft(psi_hat, axis=0)
        psi = _inverse_1d(temp.real, fam_x, type_x, axis=1)

    return psi
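When one axis is periodic, the mixed solve reduces to an FFT along that axis, followed by a real transform applied separately to the real and imaginary parts. The minimal NumPy sketch below handles periodic x with regular-grid Dirichlet y. It assumes rhs holds only the interior rows, with the zero Dirichlet rows one spacing beyond each end. A dense DST-I matrix stands in for a fast transform, and the function name is hypothetical. NumPy's matmul accepts complex input directly, so the real/imaginary split here only mirrors the approach described for solve_helmholtz_2d.

```python
import numpy as np


def solve_periodic_x_dirichlet_y_sketch(rhs, dx, dy, lambda_=0.0):
    """Hypothetical sketch: FFT along x (axis 1), dense DST-I along y (axis 0)."""
    Ny, Nx = rhs.shape
    m = np.arange(1, Ny + 1)
    S = np.sin(np.pi * np.outer(m, m) / (Ny + 1))  # DST-I matrix, S @ S = (Ny+1)/2 I
    eigx = -4.0 / dx**2 * np.sin(np.pi * np.arange(Nx) / Nx) ** 2
    eigy = -4.0 / dy**2 * np.sin(np.pi * m / (2 * (Ny + 1))) ** 2
    eig2d = eigy[:, None] + eigx[None, :] - lambda_  # < 0 for lambda_ >= 0: no null mode

    # Forward: complex FFT along x, then the real transform on real and imag parts.
    temp = np.fft.fft(rhs, axis=1)
    rhs_hat = (2.0 / (Ny + 1)) * (S @ temp.real + 1j * (S @ temp.imag))

    psi_hat = rhs_hat / eig2d

    # Inverse: real transform on each part along y, then inverse FFT along x.
    temp = S @ psi_hat.real + 1j * (S @ psi_hat.imag)
    return np.fft.ifft(temp, axis=1).real


rng = np.random.default_rng(1)
f = rng.standard_normal((10, 16))
dx, dy, lam = 0.2, 0.1, 0.5
psi = solve_periodic_x_dirichlet_y_sketch(f, dx, dy, lambda_=lam)
```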

solve_poisson_2d(rhs, dx, dy, bc_x='periodic', bc_y='periodic', *, bc_x_values=(None, None), bc_y_values=(None, None))

Solve ∇²ψ = f in 2-D with per-axis boundary conditions.

Convenience wrapper around solve_helmholtz_2d with lambda_=0.

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_2d(
    rhs: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    bc_x: BoundaryCondition = "periodic",
    bc_y: BoundaryCondition = "periodic",
    *,
    bc_x_values: tuple[Float[Array, " Ny"] | None, Float[Array, " Ny"] | None] = (
        None,
        None,
    ),
    bc_y_values: tuple[Float[Array, " Nx"] | None, Float[Array, " Nx"] | None] = (
        None,
        None,
    ),
) -> Float[Array, "Ny Nx"]:
    """Solve nabla^2 psi = f in 2-D with per-axis boundary conditions.

    Convenience wrapper around :func:`solve_helmholtz_2d` with ``lambda_=0``.
    """
    return solve_helmholtz_2d(
        rhs,
        dx,
        dy,
        bc_x=bc_x,
        bc_y=bc_y,
        lambda_=0.0,
        bc_x_values=bc_x_values,
        bc_y_values=bc_y_values,
    )

3D

solve_helmholtz_3d(rhs, dx, dy, dz, bc_x='periodic', bc_y='periodic', bc_z='periodic', lambda_=0.0, *, bc_x_values=(None, None), bc_y_values=(None, None), bc_z_values=(None, None))

Solve (∇² − λ)ψ = f in 3-D with per-axis boundary conditions.

Supports any combination of boundary conditions on each axis:

  • "periodic" — periodic (FFT)
  • "dirichlet" — homogeneous Dirichlet on regular grid (DST-I)
  • "dirichlet_stag" — homogeneous Dirichlet on staggered grid (DST-II)
  • "neumann" — homogeneous Neumann on regular grid (DCT-I)
  • "neumann_stag" — homogeneous Neumann on staggered grid (DCT-II)
  • ("dirichlet_stag", "neumann_stag") — Dirichlet left + Neumann right, staggered (DST-IV)
  • ("neumann_stag", "dirichlet_stag") — Neumann left + Dirichlet right, staggered (DCT-IV)
  • ("dirichlet", "neumann") — Dirichlet left + Neumann right, regular (DST-III)
  • ("neumann", "dirichlet") — Neumann left + Dirichlet right, regular (DCT-III)

Transforms are applied as sequential 1-D transforms along each axis. Real (DST/DCT) axes are transformed first, periodic (FFT) axes last. This avoids complex intermediate handling — the IFFT on the inverse pass recovers real data before the real inverse transforms are applied.
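The ordering argument can be checked with a forward and inverse pass on a 3-D array whose z- and y-axes are non-periodic and whose x-axis is periodic. In this NumPy sketch, dense orthonormal DCT-II matrices stand in for the real transforms; this is an assumption, since the library uses its own transform routines. The spectral division is left out, so the round trip should return the input. Dividing by real eigenvalues that are equal for wavenumbers k and −k, as FD2 periodic eigenvalues are, preserves the Hermitian symmetry that keeps the IFFT output real.

```python
import numpy as np


def dct2_matrix(n):
    """Orthonormal DCT-II matrix; row k is the k-th cosine basis vector."""
    j = np.arange(n)
    C = np.sqrt(2.0 / n) * np.cos(np.pi * np.outer(j, j + 0.5) / n)
    C[0] /= np.sqrt(2.0)
    return C


Nz, Ny, Nx = 4, 5, 8
rhs = np.random.default_rng(2).standard_normal((Nz, Ny, Nx))
Cz, Cy = dct2_matrix(Nz), dct2_matrix(Ny)   # z (axis 0) and y (axis 1) are real axes

# Forward pass: real transforms first, so they act on real data ...
data = np.einsum("kz,zyx->kyx", Cz, rhs)
data = np.einsum("ly,kyx->klx", Cy, data)
# ... and the FFT along the periodic x-axis (axis 2) is the only complex step.
data = np.fft.fft(data, axis=2)

# (Spectral division would go here.)

# Inverse pass: IFFT first; the result is real up to round-off ...
back = np.fft.ifft(data, axis=2)
imag_residue = np.max(np.abs(back.imag))
back = back.real
# ... so the real inverse transforms (C^T for an orthonormal C) see real input.
back = np.einsum("ly,klx->kyx", Cy, back)
back = np.einsum("kz,kyx->zyx", Cz, back)
```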

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Right-hand side.
dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
dz : float
    Grid spacing in z.
bc_x : BoundaryCondition
    Boundary condition along the x-axis (columns).
bc_y : BoundaryCondition
    Boundary condition along the y-axis (rows).
bc_z : BoundaryCondition
    Boundary condition along the z-axis (depth).
lambda_ : float
    Helmholtz parameter. Default: 0.0 (Poisson).
bc_x_values : tuple[Array | None, Array | None], keyword-only
    (left, right) face arrays of shape (Nz, Ny) or None.
bc_y_values : tuple[Array | None, Array | None], keyword-only
    (bottom, top) face arrays of shape (Nz, Nx) or None.
bc_z_values : tuple[Array | None, Array | None], keyword-only
    (back, front) face arrays of shape (Ny, Nx) or None.

Returns

Float[Array, "Nz Ny Nx"] Solution ψ, same shape as rhs.

Notes

When using jax.jit, bc_x, bc_y, and bc_z must be marked as static, since they are Python objects used for dispatch:

solve_jit = jax.jit(
    solve_helmholtz_3d, static_argnames=("bc_x", "bc_y", "bc_z")
)

Inhomogeneous BC support uses FD2 eigenvalues only (the default).

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_helmholtz_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    bc_x: BoundaryCondition = "periodic",
    bc_y: BoundaryCondition = "periodic",
    bc_z: BoundaryCondition = "periodic",
    lambda_: float = 0.0,
    *,
    bc_x_values: tuple[Float[Array, "Nz Ny"] | None, Float[Array, "Nz Ny"] | None] = (
        None,
        None,
    ),
    bc_y_values: tuple[Float[Array, "Nz Nx"] | None, Float[Array, "Nz Nx"] | None] = (
        None,
        None,
    ),
    bc_z_values: tuple[Float[Array, "Ny Nx"] | None, Float[Array, "Ny Nx"] | None] = (
        None,
        None,
    ),
) -> Float[Array, "Nz Ny Nx"]:
    """Solve (nabla^2 - lambda)psi = f in 3-D with per-axis boundary conditions.

    Supports any combination of boundary conditions on each axis:

    * ``"periodic"`` — periodic (FFT)
    * ``"dirichlet"`` — homogeneous Dirichlet on regular grid (DST-I)
    * ``"dirichlet_stag"`` — homogeneous Dirichlet on staggered grid (DST-II)
    * ``"neumann"`` — homogeneous Neumann on regular grid (DCT-I)
    * ``"neumann_stag"`` — homogeneous Neumann on staggered grid (DCT-II)
    * ``("dirichlet_stag", "neumann_stag")`` — Dirichlet left + Neumann right, staggered (DST-IV)
    * ``("neumann_stag", "dirichlet_stag")`` — Neumann left + Dirichlet right, staggered (DCT-IV)
    * ``("dirichlet", "neumann")`` — Dirichlet left + Neumann right, regular (DST-III)
    * ``("neumann", "dirichlet")`` — Neumann left + Dirichlet right, regular (DCT-III)

    Transforms are applied as sequential 1-D transforms along each axis.
    Real (DST/DCT) axes are transformed first, periodic (FFT) axes last.
    This avoids complex intermediate handling — the IFFT on the inverse pass
    recovers real data before the real inverse transforms are applied.

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Right-hand side.
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    dz : float
        Grid spacing in z.
    bc_x : BoundaryCondition
        Boundary condition along the x-axis (columns).
    bc_y : BoundaryCondition
        Boundary condition along the y-axis (rows).
    bc_z : BoundaryCondition
        Boundary condition along the z-axis (depth).
    lambda_ : float
        Helmholtz parameter.  Default: 0.0 (Poisson).
    bc_x_values : tuple[Array | None, Array | None], keyword-only
        ``(left, right)`` face arrays of shape ``(Nz, Ny)`` or ``None``.
    bc_y_values : tuple[Array | None, Array | None], keyword-only
        ``(bottom, top)`` face arrays of shape ``(Nz, Nx)`` or ``None``.
    bc_z_values : tuple[Array | None, Array | None], keyword-only
        ``(back, front)`` face arrays of shape ``(Ny, Nx)`` or ``None``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Solution psi, same shape as *rhs*.

    Notes
    -----
    When using ``jax.jit``, ``bc_x``, ``bc_y``, and ``bc_z`` must be marked
    as static since they are Python objects used for dispatch::

        solve_jit = jax.jit(
            solve_helmholtz_3d, static_argnames=("bc_x", "bc_y", "bc_z")
        )

    Inhomogeneous BC support uses FD2 eigenvalues only (the default).
    """
    # Apply inhomogeneous BC RHS modification if any values are non-None.
    _has_inhomogeneous = any(
        v is not None for v in (*bc_x_values, *bc_y_values, *bc_z_values)
    )
    if _has_inhomogeneous:
        rhs = modify_rhs_3d(
            rhs,
            bc_x,
            bc_y,
            bc_z,
            dx,
            dy,
            dz,
            bc_x_values,
            bc_y_values,
            bc_z_values,
        )

    Nz, Ny, Nx = rhs.shape
    fam_x, type_x, eig_fn_x, null_x = _lookup_bc(bc_x)
    fam_y, type_y, eig_fn_y, null_y = _lookup_bc(bc_y)
    fam_z, type_z, eig_fn_z, null_z = _lookup_bc(bc_z)

    # 1-D eigenvalues → 3-D broadcast
    eigx = eig_fn_x(Nx, dx)  # [Nx]
    eigy = eig_fn_y(Ny, dy)  # [Ny]
    eigz = eig_fn_z(Nz, dz)  # [Nz]
    eig3d = eigz[:, None, None] + eigy[None, :, None] + eigx[None, None, :] - lambda_

    # Null-mode handling: only when ALL axes have a null mode.
    _has_null = null_x and null_y and null_z
    if _has_null:
        is_null = eig3d[0, 0, 0] == 0.0
        eig3d_safe = eig3d.at[0, 0, 0].set(jnp.where(is_null, 1.0, eig3d[0, 0, 0]))
    else:
        eig3d_safe = eig3d

    # Classify axes: real (DST/DCT) first, periodic (FFT) last.
    all_axes = [
        (2, fam_x, type_x),  # x = axis 2
        (1, fam_y, type_y),  # y = axis 1
        (0, fam_z, type_z),  # z = axis 0
    ]
    real_axes = [(a, f, t) for a, f, t in all_axes if f != "fft"]
    periodic_axes = [(a, f, t) for a, f, t in all_axes if f == "fft"]

    # --- Forward transform: real axes first, then periodic ---
    data = rhs
    for axis, family, type_ in real_axes + periodic_axes:
        data = _forward_1d(data, family, type_, axis)

    # --- Spectral division ---
    psi_hat = data / eig3d_safe
    if _has_null:
        psi_hat = psi_hat.at[0, 0, 0].set(
            jnp.where(is_null, jnp.zeros_like(psi_hat[0, 0, 0]), psi_hat[0, 0, 0])
        )

    # --- Inverse transform: periodic first, .real, then real axes ---
    data = psi_hat
    for axis, family, type_ in periodic_axes:
        data = _inverse_1d(data, family, type_, axis)
    if periodic_axes:
        data = data.real
    for axis, family, type_ in real_axes:
        data = _inverse_1d(data, family, type_, axis)

    return data

solve_poisson_3d(rhs, dx, dy, dz, bc_x='periodic', bc_y='periodic', bc_z='periodic', *, bc_x_values=(None, None), bc_y_values=(None, None), bc_z_values=(None, None))

Solve ∇²ψ = f in 3-D with per-axis boundary conditions.

Convenience wrapper around solve_helmholtz_3d with lambda_=0.
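The lambda_=0 path through solve_helmholtz_3d is the only case in which the all-axes null-mode correction actually takes effect, because eig3d[0, 0, 0] equals −lambda_. Below is a minimal NumPy sketch of that branch for a triply periodic box. It assumes the FD2 eigenvalues Λ_k = −(4/Δ²) sin²(πk/N) on each axis, which is the standard symbol of the periodic 3-point stencil; the library's eig_fn helpers are not reproduced. The 1-D eigenvalues are broadcast to [Nz, Ny, Nx] as in the eig3d line. The (0,0,0) mode gets a placeholder divisor and is then zeroed, and the result is checked against the 7-point Laplacian. poisson_3d_periodic and lap7 are hypothetical helpers.

```python
import numpy as np


def fd2_periodic_eigs(n: int, d: float) -> np.ndarray:
    # Eigenvalues of (psi[i-1] - 2 psi[i] + psi[i+1]) / d**2 with periodic wrap
    k = np.arange(n)
    return -4.0 / d**2 * np.sin(np.pi * k / n) ** 2


def poisson_3d_periodic(rhs, dx, dy, dz):
    # Hypothetical helper: FD2 Poisson solve on a triply periodic grid
    nz, ny, nx = rhs.shape
    # 1-D eigenvalues -> 3-D broadcast, axis order [z, y, x]
    eig3d = (
        fd2_periodic_eigs(nz, dz)[:, None, None]
        + fd2_periodic_eigs(ny, dy)[None, :, None]
        + fd2_periodic_eigs(nx, dx)[None, None, :]
    )
    eig_safe = eig3d.copy()
    eig_safe[0, 0, 0] = 1.0  # placeholder divisor to avoid 0/0 at the null mode
    psi_hat = np.fft.fftn(rhs) / eig_safe
    psi_hat[0, 0, 0] = 0.0  # zero-mean gauge
    return np.fft.ifftn(psi_hat).real


def lap7(psi, dx, dy, dz):
    # 7-point FD2 Laplacian with periodic wrap-around on every axis
    out = (np.roll(psi, 1, 2) - 2 * psi + np.roll(psi, -1, 2)) / dx**2
    out += (np.roll(psi, 1, 1) - 2 * psi + np.roll(psi, -1, 1)) / dy**2
    out += (np.roll(psi, 1, 0) - 2 * psi + np.roll(psi, -1, 0)) / dz**2
    return out


rng = np.random.default_rng(0)
f = rng.standard_normal((8, 10, 12))  # [Nz, Ny, Nx]
f -= f.mean()  # solvability: a periodic Poisson RHS must have zero mean
psi = poisson_3d_periodic(f, dx=0.1, dy=0.2, dz=0.3)
residual = lap7(psi, 0.1, 0.2, 0.3) - f
```

Subtracting the mean of f matters. Because the (0,0,0) coefficient is discarded, the solve returns a ψ with ∇²ψ = f − mean(f), which equals f only for a compatible (zero-mean) right-hand side.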

Source code in spectraldiffx/_src/fourier/solvers.py
def solve_poisson_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    dx: float,
    dy: float,
    dz: float,
    bc_x: BoundaryCondition = "periodic",
    bc_y: BoundaryCondition = "periodic",
    bc_z: BoundaryCondition = "periodic",
    *,
    bc_x_values: tuple[Float[Array, "Nz Ny"] | None, Float[Array, "Nz Ny"] | None] = (
        None,
        None,
    ),
    bc_y_values: tuple[Float[Array, "Nz Nx"] | None, Float[Array, "Nz Nx"] | None] = (
        None,
        None,
    ),
    bc_z_values: tuple[Float[Array, "Ny Nx"] | None, Float[Array, "Ny Nx"] | None] = (
        None,
        None,
    ),
) -> Float[Array, "Nz Ny Nx"]:
    """Solve nabla^2 psi = f in 3-D with per-axis boundary conditions.

    Convenience wrapper around :func:`solve_helmholtz_3d` with ``lambda_=0``.
    """
    return solve_helmholtz_3d(
        rhs,
        dx,
        dy,
        dz,
        bc_x=bc_x,
        bc_y=bc_y,
        bc_z=bc_z,
        lambda_=0.0,
        bc_x_values=bc_x_values,
        bc_y_values=bc_y_values,
        bc_z_values=bc_z_values,
    )

Inhomogeneous BC Helpers

modify_rhs_1d(rhs, bc, dx, bc_values=(0.0, 0.0))

Apply boundary source terms for inhomogeneous BCs to a 1-D RHS.

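Adds boundary source terms to the first and last entries of the right-hand side, so that the homogeneous-BC solvers can handle non-zero boundary values. The correction is derived from the FD2 (3-point) stencil, so it is valid only together with FD2 eigenvalues. If both boundary values are zero, rhs is returned unchanged.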

Parameters

rhs : Float[Array, " N"]
    Original right-hand side.
bc : BoundaryCondition
    Boundary condition type.
dx : float
    Grid spacing.
bc_values : tuple[float, float]
    (left_value, right_value). For Dirichlet BCs these are the prescribed ψ values; for Neumann BCs they are the prescribed ∂ψ/∂n values (outward normal derivative).

Returns

Float[Array, " N"] Modified RHS with boundary source terms incorporated.
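The 3-point stencil itself shows why the correction is tied to FD2. Take a regular Dirichlet grid with N interior points and prescribed values ψ_L and ψ_R at the boundary nodes one spacing outside. The stencil row at i = 0 reads (ψ_L − 2ψ_0 + ψ_1)/dx² = f_0. Moving the known ψ_L to the right-hand side gives the homogeneous row with f_0 − ψ_L/dx². The NumPy sketch below implements this ghost-point argument. add_dirichlet_sources is a hypothetical helper. The grid placement and sign convention are assumptions, not a copy of the library's private _rhs_correction_1d.

```python
import numpy as np


def dst1_matrix(n: int) -> np.ndarray:
    # Unnormalised DST-I: S[j, k] = sin(pi (j+1)(k+1) / (n+1)); S @ S = (n+1)/2 * I
    j = np.arange(1, n + 1)
    return np.sin(np.pi * np.outer(j, j) / (n + 1))


def solve_dirichlet_1d(rhs, dx):
    # Homogeneous-Dirichlet FD2 solve via DST-I (boundary values implicitly zero)
    n = rhs.size
    S = dst1_matrix(n)
    eig = -4.0 / dx**2 * np.sin(np.pi * np.arange(1, n + 1) / (2 * (n + 1))) ** 2
    return S @ ((S @ rhs) / eig) * 2.0 / (n + 1)


def add_dirichlet_sources(rhs, dx, left, right):
    # Hypothetical helper: move the known boundary values to the RHS of the stencil
    rhs = rhs.copy()
    rhs[0] -= left / dx**2
    rhs[-1] -= right / dx**2
    return rhs


n, L = 9, 1.0
dx = L / (n + 1)
x = dx * np.arange(1, n + 1)  # interior nodes; boundary nodes sit at 0 and L
f = np.zeros(n)  # Laplace problem: psi'' = 0 with psi(0) = 1, psi(L) = 3
psi = solve_dirichlet_1d(add_dirichlet_sources(f, dx, 1.0, 3.0), dx)
# FD2 is exact for linear profiles, so psi should equal 1 + 2x at the interior nodes
```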

Source code in spectraldiffx/_src/fourier/solvers.py
def modify_rhs_1d(
    rhs: Float[Array, " N"],
    bc: BoundaryCondition,
    dx: float,
    bc_values: tuple[float, float] = (0.0, 0.0),
) -> Float[Array, " N"]:
    """Apply boundary source terms for inhomogeneous BCs to a 1-D RHS.

    Modifies the right-hand side at boundary-adjacent grid points to
    account for non-zero boundary values.  This exploits the FD2 stencil
    structure and must only be used with FD2 eigenvalues.

    Parameters
    ----------
    rhs : Float[Array, " N"]
        Original right-hand side.
    bc : BoundaryCondition
        Boundary condition type.
    dx : float
        Grid spacing.
    bc_values : tuple[float, float]
        ``(left_value, right_value)``.  For Dirichlet BCs these are the
        prescribed psi values; for Neumann BCs these are the prescribed
        dpsi/dn (outward normal derivative) values.

    Returns
    -------
    Float[Array, " N"]
        Modified RHS with boundary source terms incorporated.
    """
    left_val, right_val = bc_values
    if left_val == 0.0 and right_val == 0.0:
        return rhs
    left_corr, right_corr = _rhs_correction_1d(bc, dx, left_val, right_val)
    rhs = rhs.at[0].add(left_corr)
    rhs = rhs.at[-1].add(right_corr)
    return rhs

modify_rhs_2d(rhs, bc_x, bc_y, dx, dy, bc_x_values=(None, None), bc_y_values=(None, None))

Apply boundary source terms for inhomogeneous BCs to a 2-D RHS.

Parameters

rhs : Float[Array, "Ny Nx"]
    Original right-hand side.
bc_x, bc_y : BoundaryCondition
    Boundary condition types for the x and y axes.
dx, dy : float
    Grid spacings.
bc_x_values : tuple[Array | None, Array | None]
    (left, right) boundary values for the x-axis. Each is an array of shape (Ny,) or None (homogeneous).
bc_y_values : tuple[Array | None, Array | None]
    (bottom, top) boundary values for the y-axis. Each is an array of shape (Nx,) or None (homogeneous).

Returns

Float[Array, "Ny Nx"] Modified RHS.

Raises

ValueError If non-None boundary values are provided for a periodic axis.
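In 2-D each edge correction is applied to a whole boundary-adjacent row or column, so each of the four corner points receives two corrections, one from each adjacent edge. The sketch below checks this using the same ghost-point assumptions as the 1-D case: a regular grid, boundary nodes one spacing outside the interior, and a correction of −value/spacing². These formulas are not necessarily the library's _BC_RHS_FORMULAS. The test function is ψ = x + 2y, which the 5-point stencil reproduces exactly. The solve is a dense np.linalg.solve on a Kronecker-sum Laplacian rather than a DST, so the check does not depend on any transform.

```python
import numpy as np


def lap1d_dirichlet(n: int, d: float) -> np.ndarray:
    # FD2 second-derivative matrix on n interior points, zero Dirichlet boundaries
    return (
        np.diag(-2.0 * np.ones(n))
        + np.diag(np.ones(n - 1), 1)
        + np.diag(np.ones(n - 1), -1)
    ) / d**2


def exact(xx, yy):
    # Harmonic and linear, so the 5-point stencil reproduces it exactly
    return xx + 2.0 * yy


nx, ny, dx, dy = 6, 5, 0.2, 0.25
x = dx * np.arange(1, nx + 1)  # interior x-nodes; boundary nodes at 0 and (nx+1)*dx
y = dy * np.arange(1, ny + 1)

# Edge arrays in the modify_rhs_2d layout: x-edges have shape (Ny,), y-edges (Nx,)
xl, xr = exact(0.0, y), exact((nx + 1) * dx, y)
yb, yt = exact(x, 0.0), exact(x, (ny + 1) * dy)

rhs = np.zeros((ny, nx))  # Laplace: f = 0
rhs[:, 0] -= xl / dx**2  # left column
rhs[:, -1] -= xr / dx**2  # right column
rhs[0, :] -= yb / dy**2  # bottom row: corners now hold two corrections
rhs[-1, :] -= yt / dy**2  # top row

# 2-D operator on row-major [Ny, Nx] data: kron(I_y, Dx) + kron(Dy, I_x)
A = np.kron(np.eye(ny), lap1d_dirichlet(nx, dx)) + np.kron(
    lap1d_dirichlet(ny, dy), np.eye(nx)
)
psi = np.linalg.solve(A, rhs.ravel()).reshape(ny, nx)

Y, X = np.meshgrid(y, x, indexing="ij")
error = np.abs(psi - exact(X, Y)).max()
```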

Source code in spectraldiffx/_src/fourier/solvers.py
def modify_rhs_2d(
    rhs: Float[Array, "Ny Nx"],
    bc_x: BoundaryCondition,
    bc_y: BoundaryCondition,
    dx: float,
    dy: float,
    bc_x_values: tuple[Float[Array, " Ny"] | None, Float[Array, " Ny"] | None] = (
        None,
        None,
    ),
    bc_y_values: tuple[Float[Array, " Nx"] | None, Float[Array, " Nx"] | None] = (
        None,
        None,
    ),
) -> Float[Array, "Ny Nx"]:
    """Apply boundary source terms for inhomogeneous BCs to a 2-D RHS.

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Original right-hand side.
    bc_x, bc_y : BoundaryCondition
        Boundary condition types for x and y axes.
    dx, dy : float
        Grid spacings.
    bc_x_values : tuple[Array | None, Array | None]
        ``(left, right)`` boundary values for the x-axis.
        Each is an array of shape ``(Ny,)`` or ``None`` (homogeneous).
    bc_y_values : tuple[Array | None, Array | None]
        ``(bottom, top)`` boundary values for the y-axis.
        Each is an array of shape ``(Nx,)`` or ``None`` (homogeneous).

    Returns
    -------
    Float[Array, "Ny Nx"]
        Modified RHS.

    Raises
    ------
    ValueError
        If non-None boundary values are provided for a periodic axis.
    """
    _validate_no_periodic_inhomogeneous(bc_x, bc_x_values, "x")
    _validate_no_periodic_inhomogeneous(bc_y, bc_y_values, "y")

    xl, xr = bc_x_values
    yb, yt = bc_y_values

    # x-axis: modify columns 0 and -1
    if xl is not None:
        bc_left_type = _bc_type_for_side(bc_x, "left")
        left_corr = _BC_RHS_FORMULAS[bc_left_type](1.0, 0.0, dx)[0]
        rhs = rhs.at[:, 0].add(left_corr * xl)
    if xr is not None:
        bc_right_type = _bc_type_for_side(bc_x, "right")
        right_corr = _BC_RHS_FORMULAS[bc_right_type](0.0, 1.0, dx)[1]
        rhs = rhs.at[:, -1].add(right_corr * xr)

    # y-axis: modify rows 0 and -1
    if yb is not None:
        bc_bottom_type = _bc_type_for_side(bc_y, "left")
        bottom_corr = _BC_RHS_FORMULAS[bc_bottom_type](1.0, 0.0, dy)[0]
        rhs = rhs.at[0, :].add(bottom_corr * yb)
    if yt is not None:
        bc_top_type = _bc_type_for_side(bc_y, "right")
        top_corr = _BC_RHS_FORMULAS[bc_top_type](0.0, 1.0, dy)[1]
        rhs = rhs.at[-1, :].add(top_corr * yt)

    return rhs

modify_rhs_3d(rhs, bc_x, bc_y, bc_z, dx, dy, dz, bc_x_values=(None, None), bc_y_values=(None, None), bc_z_values=(None, None))

Apply boundary source terms for inhomogeneous BCs to a 3-D RHS.

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Original right-hand side.
bc_x, bc_y, bc_z : BoundaryCondition
    Boundary condition types for each axis.
dx, dy, dz : float
    Grid spacings.
bc_x_values : tuple[Array | None, Array | None]
    (left, right) face arrays of shape (Nz, Ny).
bc_y_values : tuple[Array | None, Array | None]
    (bottom, top) face arrays of shape (Nz, Nx).
bc_z_values : tuple[Array | None, Array | None]
    (back, front) face arrays of shape (Ny, Nx).

Returns

Float[Array, "Nz Ny Nx"] Modified RHS.

Raises

ValueError If non-None boundary values are provided for a periodic axis.

Source code in spectraldiffx/_src/fourier/solvers.py
def modify_rhs_3d(
    rhs: Float[Array, "Nz Ny Nx"],
    bc_x: BoundaryCondition,
    bc_y: BoundaryCondition,
    bc_z: BoundaryCondition,
    dx: float,
    dy: float,
    dz: float,
    bc_x_values: tuple[Float[Array, "Nz Ny"] | None, Float[Array, "Nz Ny"] | None] = (
        None,
        None,
    ),
    bc_y_values: tuple[Float[Array, "Nz Nx"] | None, Float[Array, "Nz Nx"] | None] = (
        None,
        None,
    ),
    bc_z_values: tuple[Float[Array, "Ny Nx"] | None, Float[Array, "Ny Nx"] | None] = (
        None,
        None,
    ),
) -> Float[Array, "Nz Ny Nx"]:
    """Apply boundary source terms for inhomogeneous BCs to a 3-D RHS.

    Parameters
    ----------
    rhs : Float[Array, "Nz Ny Nx"]
        Original right-hand side.
    bc_x, bc_y, bc_z : BoundaryCondition
        Boundary condition types for each axis.
    dx, dy, dz : float
        Grid spacings.
    bc_x_values : tuple[Array | None, Array | None]
        ``(left, right)`` face arrays of shape ``(Nz, Ny)``.
    bc_y_values : tuple[Array | None, Array | None]
        ``(bottom, top)`` face arrays of shape ``(Nz, Nx)``.
    bc_z_values : tuple[Array | None, Array | None]
        ``(back, front)`` face arrays of shape ``(Ny, Nx)``.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Modified RHS.

    Raises
    ------
    ValueError
        If non-None boundary values are provided for a periodic axis.
    """
    _validate_no_periodic_inhomogeneous(bc_x, bc_x_values, "x")
    _validate_no_periodic_inhomogeneous(bc_y, bc_y_values, "y")
    _validate_no_periodic_inhomogeneous(bc_z, bc_z_values, "z")

    xl, xr = bc_x_values
    yb, yt = bc_y_values
    zb, zf = bc_z_values

    # x-axis faces: [:, :, 0] and [:, :, -1]
    if xl is not None:
        bc_left_type = _bc_type_for_side(bc_x, "left")
        corr = _BC_RHS_FORMULAS[bc_left_type](1.0, 0.0, dx)[0]
        rhs = rhs.at[:, :, 0].add(corr * xl)
    if xr is not None:
        bc_right_type = _bc_type_for_side(bc_x, "right")
        corr = _BC_RHS_FORMULAS[bc_right_type](0.0, 1.0, dx)[1]
        rhs = rhs.at[:, :, -1].add(corr * xr)

    # y-axis faces: [:, 0, :] and [:, -1, :]
    if yb is not None:
        bc_bottom_type = _bc_type_for_side(bc_y, "left")
        corr = _BC_RHS_FORMULAS[bc_bottom_type](1.0, 0.0, dy)[0]
        rhs = rhs.at[:, 0, :].add(corr * yb)
    if yt is not None:
        bc_top_type = _bc_type_for_side(bc_y, "right")
        corr = _BC_RHS_FORMULAS[bc_top_type](0.0, 1.0, dy)[1]
        rhs = rhs.at[:, -1, :].add(corr * yt)

    # z-axis faces: [0, :, :] and [-1, :, :]
    if zb is not None:
        bc_back_type = _bc_type_for_side(bc_z, "left")
        corr = _BC_RHS_FORMULAS[bc_back_type](1.0, 0.0, dz)[0]
        rhs = rhs.at[0, :, :].add(corr * zb)
    if zf is not None:
        bc_front_type = _bc_type_for_side(bc_z, "right")
        corr = _BC_RHS_FORMULAS[bc_front_type](0.0, 1.0, dz)[1]
        rhs = rhs.at[-1, :, :].add(corr * zf)

    return rhs

Layer 1 — Module Classes

Periodic (FFT)

SpectralHelmholtzSolver1D

Bases: Module

1D Helmholtz/Poisson solver with periodic BCs using FFT.

Solves (d²/dx² − α)φ = f on a periodic 1-D domain [0, L) using continuous Fourier wavenumbers:

φ̂_k = −f̂_k / (k² + α)

where k = 2πm/L are the Fourier wavenumbers from grid.k.

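For α = 0 (Poisson), the k=0 mode is singular; zero_mean=True (the default) projects it out by setting φ̂_0 = 0. The projection tests k² = 0 rather than the denominator, so it also removes the mean when α > 0. Pass zero_mean=False to keep the k=0 coefficient −f̂_0/α. With α = 0 and zero_mean=False, the singular mode is divided by a placeholder value of 1 instead of being removed, so the mean of the result is not meaningful.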

Attributes

grid : FourierGrid1D 1-D Fourier grid providing wavenumbers and FFT methods.
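The following NumPy sketch reproduces the same recipe. It assumes a uniform grid x_j = jL/N, so that grid.k corresponds to 2π·fftfreq(N, d=L/N); this is an assumption about FourierGrid1D, which is not shown here. helmholtz_1d_periodic is a hypothetical stand-in for solve. Because the wavenumbers are the continuous ones, a single resolved Fourier mode is recovered to round-off, rather than with the O(dx²) error an FD2 symbol would give.

```python
import numpy as np


def helmholtz_1d_periodic(f, L, alpha=0.0, zero_mean=True):
    # Hypothetical stand-in for SpectralHelmholtzSolver1D.solve
    n = f.size
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=L / n)  # continuous wavenumbers, FFT order
    denom = k**2 + alpha
    denom_safe = np.where(denom == 0.0, 1.0, denom)
    phi_hat = -np.fft.fft(f) / denom_safe  # phi_hat = -f_hat / (k^2 + alpha)
    if zero_mean:
        phi_hat = np.where(k**2 == 0.0, 0.0, phi_hat)
    return np.fft.ifft(phi_hat).real


L, n, alpha = 2.0 * np.pi, 32, 0.5
x = L * np.arange(n) / n
phi_true = np.sin(3.0 * x)
f = -(9.0 + alpha) * phi_true  # (d^2/dx^2 - alpha) applied to sin(3x)
phi = helmholtz_1d_periodic(f, L, alpha)
```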

Source code in spectraldiffx/_src/fourier/solvers.py
class SpectralHelmholtzSolver1D(eqx.Module):
    """1D Helmholtz/Poisson solver with periodic BCs using FFT.

    Solves ``(d²/dx² − α)φ = f`` on a periodic 1-D domain [0, L) using
    continuous Fourier wavenumbers:

        φ̂_k = −f̂_k / (k² + α)

    where k = 2πm/L are the Fourier wavenumbers from ``grid.k``.

    For α = 0 (Poisson), the k=0 mode is singular; ``zero_mean=True``
    projects it out (sets φ̂_0 = 0).

    Attributes
    ----------
    grid : FourierGrid1D
        1-D Fourier grid providing wavenumbers and FFT methods.
    """

    grid: FourierGrid1D

    def solve(
        self,
        f: Array,
        alpha: float = 0.0,
        zero_mean: bool = True,
        spectral: bool = False,
    ) -> Array:
        """Solve (d²/dx² − α)φ = f.

        Parameters
        ----------
        f : Float[Array, "N"]
            Source term in physical space.
        alpha : float
            Helmholtz parameter (α ≥ 0).  Default: 0.0 (Poisson).
        zero_mean : bool
            Force the mean (k=0 mode) to zero.  Default: True.
        spectral : bool
            If True, *f* is already in spectral space.

        Returns
        -------
        Float[Array, "N"]
            Solution φ in physical space.
        """
        f_hat = f if spectral else self.grid.transform(f)
        k2 = self.grid.k**2
        denom = k2 + alpha  # k^2 + alpha  [N]

        denom_safe = jnp.where(denom == 0.0, 1.0, denom)
        phi_hat = -f_hat / denom_safe  # phi_hat = -f_hat/(k^2+alpha)  [N]

        if zero_mean:
            phi_hat = jnp.where(k2 == 0.0, 0.0, phi_hat)

        return self.grid.transform(phi_hat, inverse=True).real

Functions

solve(f, alpha=0.0, zero_mean=True, spectral=False)

Solve (d²/dx² − α)φ = f.

Parameters

f : Float[Array, "N"]
    Source term in physical space.
alpha : float
    Helmholtz parameter (α ≥ 0). Default: 0.0 (Poisson).
zero_mean : bool
    Force the mean (k=0 mode) to zero. Default: True.
spectral : bool
    If True, f is already in spectral space. Default: False.

Returns

Float[Array, "N"] Solution φ in physical space.

SpectralHelmholtzSolver2D

Bases: Module

2D Helmholtz/Poisson solver with periodic BCs using FFT.

Solves (∇² − α)φ = f on a doubly periodic domain using continuous Fourier wavenumbers:

φ̂[j,i] = −f̂[j,i] / (kx_i² + ky_j² + α)

where |k|² = kx² + ky² is provided by grid.K2.

For α = 0 (Poisson), the (0,0) mode is singular; zero_mean=True projects it out.

Attributes

grid : FourierGrid2D 2-D Fourier grid providing wavenumber meshgrid and FFT methods.
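The 2-D sketch below mirrors the solve method. It assumes grid.K2 is the [Ny, Nx] array kx² + ky², built with "ij" indexing from 2π·fftfreq wavenumbers. helmholtz_2d_periodic is a hypothetical stand-in. The sketch also shows a consequence of testing K2 == 0 rather than the denominator: with α > 0 and the default zero_mean=True, the well-defined constant mode −f̂[0,0]/α is still discarded.

```python
import numpy as np


def helmholtz_2d_periodic(f, Lx, Ly, alpha=0.0, zero_mean=True):
    # Hypothetical stand-in for SpectralHelmholtzSolver2D.solve
    ny, nx = f.shape
    kx = 2.0 * np.pi * np.fft.fftfreq(nx, d=Lx / nx)
    ky = 2.0 * np.pi * np.fft.fftfreq(ny, d=Ly / ny)
    KY, KX = np.meshgrid(ky, kx, indexing="ij")  # [Ny, Nx] layout
    K2 = KX**2 + KY**2
    denom = K2 + alpha
    phi_hat = -np.fft.fft2(f) / np.where(denom == 0.0, 1.0, denom)
    if zero_mean:
        phi_hat = np.where(K2 == 0.0, 0.0, phi_hat)
    return np.fft.ifft2(phi_hat).real


Lx = Ly = 2.0 * np.pi
ny, nx, alpha = 24, 32, 2.0
y = Ly * np.arange(ny) / ny
x = Lx * np.arange(nx) / nx
Y, X = np.meshgrid(y, x, indexing="ij")
phi_true = 1.0 + np.cos(X) * np.cos(2.0 * Y)  # has a non-zero mean
# (lap - alpha) phi_true = -alpha - (1 + 4 + alpha) cos(x) cos(2y)
f = -alpha - (5.0 + alpha) * np.cos(X) * np.cos(2.0 * Y)

phi_default = helmholtz_2d_periodic(f, Lx, Ly, alpha)  # zero_mean=True drops the 1
phi_full = helmholtz_2d_periodic(f, Lx, Ly, alpha, zero_mean=False)  # keeps k = 0
```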

Source code in spectraldiffx/_src/fourier/solvers.py
class SpectralHelmholtzSolver2D(eqx.Module):
    """2D Helmholtz/Poisson solver with periodic BCs using FFT.

    Solves ``(∇² − α)φ = f`` on a doubly periodic domain using continuous
    Fourier wavenumbers:

        φ̂[j,i] = −f̂[j,i] / (kx_i² + ky_j² + α)

    where |k|² = kx² + ky² is provided by ``grid.K2``.

    For α = 0 (Poisson), the (0,0) mode is singular; ``zero_mean=True``
    projects it out.

    Attributes
    ----------
    grid : FourierGrid2D
        2-D Fourier grid providing wavenumber meshgrid and FFT methods.
    """

    grid: FourierGrid2D

    def solve(
        self,
        f: Array,
        alpha: float = 0.0,
        zero_mean: bool = True,
        spectral: bool = False,
    ) -> Array:
        """Solve (∇² − α)φ = f.

        Parameters
        ----------
        f : Array [Ny, Nx]
            Source term in physical space.
        alpha : float
            Helmholtz parameter (α ≥ 0).  Default: 0.0 (Poisson).
        zero_mean : bool
            Force the (0,0) mode to zero.  Default: True.
        spectral : bool
            If True, *f* is already in spectral space.

        Returns
        -------
        Array [Ny, Nx]
            Solution φ in physical space.
        """
        f_hat = f if spectral else self.grid.transform(f)
        K2 = self.grid.K2  # kx^2 + ky^2  [Ny, Nx]
        denom = K2 + alpha

        denom_safe = jnp.where(denom == 0.0, 1.0, denom)
        phi_hat = -f_hat / denom_safe  # -f_hat / (|k|^2 + alpha)

        if zero_mean:
            phi_hat = jnp.where(K2 == 0.0, 0.0, phi_hat)

        return self.grid.transform(phi_hat, inverse=True).real

Functions

solve(f, alpha=0.0, zero_mean=True, spectral=False)

Solve (∇² − α)φ = f.

Parameters

f : Array [Ny, Nx]
    Source term in physical space.
alpha : float
    Helmholtz parameter (α ≥ 0). Default: 0.0 (Poisson).
zero_mean : bool
    Force the (0,0) mode to zero. Default: True.
spectral : bool
    If True, f is already in spectral space. Default: False.

Returns

Array [Ny, Nx] Solution φ in physical space.

SpectralHelmholtzSolver3D

Bases: Module

3D Helmholtz/Poisson solver with periodic BCs using FFT.

Solves (∇² − α)φ = f on a triply periodic domain using continuous Fourier wavenumbers:

φ̂[l,j,i] = −f̂[l,j,i] / (kx_i² + ky_j² + kz_l² + α)

where |k|² = kx² + ky² + kz² is provided by grid.K2.

For α = 0 (Poisson), the (0,0,0) mode is singular; zero_mean=True projects it out.

Attributes

grid : FourierGrid3D 3-D Fourier grid providing wavenumber meshgrid and FFT methods.

Source code in spectraldiffx/_src/fourier/solvers.py
class SpectralHelmholtzSolver3D(eqx.Module):
    """3D Helmholtz/Poisson solver with periodic BCs using FFT.

    Solves ``(∇² − α)φ = f`` on a triply periodic domain using continuous
    Fourier wavenumbers:

        φ̂[l,j,i] = −f̂[l,j,i] / (kx_i² + ky_j² + kz_l² + α)

    where |k|² = kx² + ky² + kz² is provided by ``grid.K2``.

    For α = 0 (Poisson), the (0,0,0) mode is singular; ``zero_mean=True``
    projects it out.

    Attributes
    ----------
    grid : FourierGrid3D
        3-D Fourier grid providing wavenumber meshgrid and FFT methods.
    """

    grid: FourierGrid3D

    def solve(
        self,
        f: Array,
        alpha: float = 0.0,
        zero_mean: bool = True,
        spectral: bool = False,
    ) -> Array:
        """Solve (∇² − α)φ = f.

        Parameters
        ----------
        f : Array [Nz, Ny, Nx]
            Source term in physical space.
        alpha : float
            Helmholtz parameter (α ≥ 0).  Default: 0.0 (Poisson).
        zero_mean : bool
            Force the (0,0,0) mode to zero.  Default: True.
        spectral : bool
            If True, *f* is already in spectral space.

        Returns
        -------
        Array [Nz, Ny, Nx]
            Solution φ in physical space.
        """
        f_hat = f if spectral else self.grid.transform(f)
        K2 = self.grid.K2  # kx^2 + ky^2 + kz^2  [Nz, Ny, Nx]
        denom = K2 + alpha

        denom_safe = jnp.where(denom == 0.0, 1.0, denom)
        phi_hat = -f_hat / denom_safe  # -f_hat / (|k|^2 + alpha)

        if zero_mean:
            phi_hat = jnp.where(K2 == 0.0, 0.0, phi_hat)

        return self.grid.transform(phi_hat, inverse=True).real

Functions

solve(f, alpha=0.0, zero_mean=True, spectral=False)

Solve (∇² − α)φ = f.

Parameters

f : Array [Nz, Ny, Nx]
    Source term in physical space.
alpha : float
    Helmholtz parameter (α ≥ 0). Default: 0.0 (Poisson).
zero_mean : bool
    Force the (0,0,0) mode to zero. Default: True.
spectral : bool
    If True, f is already in spectral space. Default: False.

Returns

Array [Nz, Ny, Nx] Solution φ in physical space.

Dirichlet, Regular Grid (DST-I)

DirichletHelmholtzSolver2D

Bases: Module

2D Helmholtz/Poisson/Laplace solver with homogeneous Dirichlet BCs.

Solves (∇² − α)ψ = f, where ψ = 0 on all four edges, using the DST-I spectral method (see solve_helmholtz_dst).

The input rhs contains values at the Ny × Nx interior grid points; boundary values are implicitly zero.

With alpha=0.0 (default), solves the Poisson equation ∇²ψ = f.

Attributes

dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
alpha : float
    Helmholtz parameter α ≥ 0. Default: 0.0 (Poisson/Laplace).
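Below is a NumPy sketch of the DST-I approach on the Ny × Nx interior grid. It assumes the standard FD2 Dirichlet eigenvalues −(4/Δ²) sin²(π(k+1)/(2(N+1))); the internals of solve_helmholtz_dst are not reproduced, and helmholtz_dst1_2d is hypothetical. The DST-I is written as an explicit sine matrix to keep the sketch dependency-free. scipy.fft.dst(type=1) computes the same transform, up to normalisation, in O(N log N). Because α ≥ 0 and every Dirichlet eigenvalue is negative, no divisor is ever zero, so no null-mode handling is needed.

```python
import numpy as np


def dst1_matrix(n):
    # Symmetric sine matrix S[j, k] = sin(pi (j+1)(k+1) / (n+1)); S @ S = (n+1)/2 * I
    j = np.arange(1, n + 1)
    return np.sin(np.pi * np.outer(j, j) / (n + 1))


def dirichlet_eigs(n, d):
    # FD2 eigenvalues for homogeneous Dirichlet on n interior points
    return -4.0 / d**2 * np.sin(np.pi * np.arange(1, n + 1) / (2 * (n + 1))) ** 2


def helmholtz_dst1_2d(rhs, dx, dy, alpha=0.0):
    # Hypothetical stand-in for solve_helmholtz_dst
    ny, nx = rhs.shape
    Sx, Sy = dst1_matrix(nx), dst1_matrix(ny)
    rhs_hat = Sy @ rhs @ Sx  # forward DST-I along y and x
    eig = dirichlet_eigs(ny, dy)[:, None] + dirichlet_eigs(nx, dx)[None, :] - alpha
    psi_hat = rhs_hat / eig  # all entries strictly negative
    return (Sy @ psi_hat @ Sx) * (2.0 / (ny + 1)) * (2.0 / (nx + 1))


def lap5_dirichlet(psi, dx, dy):
    # 5-point Laplacian with psi = 0 on the surrounding boundary nodes
    p = np.pad(psi, 1)
    return (p[1:-1, :-2] - 2 * psi + p[1:-1, 2:]) / dx**2 + (
        p[:-2, 1:-1] - 2 * psi + p[2:, 1:-1]
    ) / dy**2


rng = np.random.default_rng(1)
f = rng.standard_normal((7, 9))  # [Ny, Nx] interior values
alpha = 3.0
psi = helmholtz_dst1_2d(f, dx=0.1, dy=0.15, alpha=alpha)
residual = lap5_dirichlet(psi, 0.1, 0.15) - alpha * psi - f
```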

Source code in spectraldiffx/_src/fourier/solvers.py
class DirichletHelmholtzSolver2D(eqx.Module):
    """2D Helmholtz/Poisson/Laplace solver with homogeneous Dirichlet BCs.

    Solves ``(∇² − α)ψ = f`` where ψ = 0 on all four edges, using the
    DST-I spectral method (see :func:`solve_helmholtz_dst`).

    The input *rhs* contains values at the Ny × Nx **interior** grid
    points; boundary values are implicitly zero.

    With ``alpha=0.0`` (default), solves the Poisson equation ∇²ψ = f.

    Attributes
    ----------
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    alpha : float
        Helmholtz parameter α ≥ 0.  Default: 0.0 (Poisson/Laplace).
    """

    dx: float
    dy: float
    alpha: float = 0.0

    def __call__(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        """Solve (∇² − α)ψ = rhs.

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side at interior grid points.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Solution ψ at interior grid points.
        """
        return solve_helmholtz_dst(rhs, self.dx, self.dy, self.alpha)

Functions

__call__(rhs)

Solve (∇² − α)ψ = rhs.

Parameters

rhs : Float[Array, "Ny Nx"] Right-hand side at interior grid points.

Returns

Float[Array, "Ny Nx"] Solution ψ at interior grid points.

Dirichlet, Staggered Grid (DST-II)

StaggeredDirichletHelmholtzSolver2D

Bases: Module

2D Helmholtz/Poisson solver with Dirichlet BCs on a staggered grid.

Solves (∇² − α)ψ = f, where ψ = 0 at the cell edges (half a grid spacing outside the first and last cell centres), using the DST-II spectral method (see solve_helmholtz_dst2).

Attributes

dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
alpha : float
    Helmholtz parameter α ≥ 0. Default: 0.0 (Poisson).
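On the staggered grid, the half-cell offset changes only the boundary rows of the FD2 operator. This sketch assumes the usual linear-extrapolation ghost point: ψ on the face is the average of the first cell and its ghost, so ψ_{-1} = −ψ_0, and the end rows carry −3 on the diagonal. The NumPy check below confirms that the DST-II basis vectors sin(πk(j + ½)/N), k = 1…N, diagonalise that operator. It illustrates why DST-II applies here and is not the library's eigenvalue code.

```python
import numpy as np

n, dx = 8, 0.125
# FD2 operator on n cell centres with psi = 0 on the faces half a spacing outside.
# Ghost values psi[-1] = -psi[0] and psi[n] = -psi[n-1] turn -2 into -3 on the end rows.
A = (
    np.diag(-2.0 * np.ones(n))
    + np.diag(np.ones(n - 1), 1)
    + np.diag(np.ones(n - 1), -1)
) / dx**2
A[0, 0] = A[-1, -1] = -3.0 / dx**2

j = np.arange(n)
k = np.arange(1, n + 1)
V = np.sin(np.pi * np.outer(j + 0.5, k) / n)  # DST-II basis vectors as columns
lam = -4.0 / dx**2 * np.sin(np.pi * k / (2 * n)) ** 2  # matching FD2 eigenvalues
```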

Source code in spectraldiffx/_src/fourier/solvers.py
class StaggeredDirichletHelmholtzSolver2D(eqx.Module):
    """2D Helmholtz/Poisson solver with Dirichlet BCs on a staggered grid.

    Solves ``(∇² − α)ψ = f`` where ψ = 0 at the cell edges (half a grid
    spacing outside the first and last cell centres), using the DST-II
    spectral method (see :func:`solve_helmholtz_dst2`).

    Attributes
    ----------
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    alpha : float
        Helmholtz parameter α ≥ 0.  Default: 0.0 (Poisson).
    """

    dx: float
    dy: float
    alpha: float = 0.0

    def __call__(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        """Solve (∇² − α)ψ = rhs.

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side at cell-centred grid points.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Solution ψ at cell-centred grid points.
        """
        return solve_helmholtz_dst2(rhs, self.dx, self.dy, self.alpha)

Functions

__call__(rhs)

Solve (∇² − α)ψ = rhs.

Parameters

rhs : Float[Array, "Ny Nx"] Right-hand side at cell-centred grid points.

Returns

Float[Array, "Ny Nx"] Solution ψ at cell-centred grid points.

Neumann, Staggered Grid (DCT-II)

NeumannHelmholtzSolver2D

Bases: Module

2D Helmholtz/Poisson/Laplace solver with homogeneous Neumann BCs.

Solves (∇² − α)ψ = f, where ∂ψ/∂n = 0 on all four edges, using the DCT-II spectral method (see solve_helmholtz_dct).

With alpha=0.0 (default), solves the Poisson equation ∇²ψ = f. The Poisson null space (the constant mode) is removed by enforcing a zero mean: ψ̂[0,0] = 0.

Attributes

dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
alpha : float
    Helmholtz parameter α ≥ 0. Default: 0.0 (Poisson/Laplace).
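Below is a NumPy sketch of the DCT-II solve. It assumes the standard staggered Neumann closure, in which the ghost value equals the first cell (as np.pad(..., mode="edge") produces), with eigenvalues −(4/Δ²) sin²(πk/(2N)), k = 0…N−1. poisson_dct2_2d and the orthonormal matrix form are illustrative, not the library's implementation. The sketch also shows what zeroing ψ̂[0,0] implies. The discarded coefficient is proportional to the mean of f, so the returned ψ has zero mean and satisfies ∇²ψ = f − mean(f), even when f violates the Neumann compatibility condition.

```python
import numpy as np


def dct2_ortho_matrix(n):
    # Orthonormal DCT-II: Q[k, j] = w_k cos(pi k (j + 1/2) / n), with Q @ Q.T = I
    k = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    Q = np.sqrt(2.0 / n) * np.cos(np.pi * k * (j + 0.5) / n)
    Q[0] /= np.sqrt(2.0)
    return Q


def neumann_eigs(n, d):
    # Staggered-Neumann FD2 eigenvalues; k = 0 is the constant null mode
    return -4.0 / d**2 * np.sin(np.pi * np.arange(n) / (2 * n)) ** 2


def poisson_dct2_2d(rhs, dx, dy):
    # Hypothetical stand-in for solve_helmholtz_dct with alpha = 0
    ny, nx = rhs.shape
    Qx, Qy = dct2_ortho_matrix(nx), dct2_ortho_matrix(ny)
    rhs_hat = Qy @ rhs @ Qx.T
    eig = neumann_eigs(ny, dy)[:, None] + neumann_eigs(nx, dx)[None, :]
    eig[0, 0] = 1.0  # placeholder divisor to avoid 0/0
    psi_hat = rhs_hat / eig
    psi_hat[0, 0] = 0.0  # zero-mean gauge
    return Qy.T @ psi_hat @ Qx


def lap5_neumann(psi, dx, dy):
    # Ghost = first cell, so the normal derivative vanishes on the faces
    p = np.pad(psi, 1, mode="edge")
    return (p[1:-1, :-2] - 2 * psi + p[1:-1, 2:]) / dx**2 + (
        p[:-2, 1:-1] - 2 * psi + p[2:, 1:-1]
    ) / dy**2


rng = np.random.default_rng(2)
f = rng.standard_normal((6, 10))  # deliberately NOT zero-mean
psi = poisson_dct2_2d(f, dx=0.1, dy=0.2)
```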

Source code in spectraldiffx/_src/fourier/solvers.py
class NeumannHelmholtzSolver2D(eqx.Module):
    """2D Helmholtz/Poisson/Laplace solver with homogeneous Neumann BCs.

    Solves ``(∇² − α)ψ = f`` where ∂ψ/∂n = 0 on all four edges, using
    the DCT-II spectral method (see :func:`solve_helmholtz_dct`).

    With ``alpha=0.0`` (default), solves the Poisson equation ∇²ψ = f.
    The Poisson null space (constant mode) is removed by enforcing
    zero-mean: ψ̂[0,0] = 0.

    Attributes
    ----------
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    alpha : float
        Helmholtz parameter α ≥ 0.  Default: 0.0 (Poisson/Laplace).
    """

    dx: float
    dy: float
    alpha: float = 0.0

    def __call__(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        """Solve (∇² − α)ψ = rhs.

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Solution ψ (zero-mean gauge when α = 0).
        """
        return solve_helmholtz_dct(rhs, self.dx, self.dy, self.alpha)

Functions

__call__(rhs)

Solve (∇² − α)ψ = rhs.

Parameters

rhs : Float[Array, "Ny Nx"] Right-hand side.

Returns

Float[Array, "Ny Nx"] Solution ψ (zero-mean gauge when α = 0).

Neumann, Regular Grid (DCT-I)

RegularNeumannHelmholtzSolver2D

Bases: Module

2D Helmholtz/Poisson solver with Neumann BCs on a regular grid.

Solves (∇² − α)ψ = f, where ∂ψ/∂n = 0 on all four edges, using the DCT-I spectral method (see solve_helmholtz_dct1). The rhs includes the boundary grid points; for α = 0 the solution is returned in the zero-mean gauge.

Attributes

dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
alpha : float
    Helmholtz parameter α ≥ 0. Default: 0.0 (Poisson).
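Here the boundary nodes are themselves unknowns, which is why rhs includes the boundary grid points. This sketch assumes the standard mirror ghost point for ∂ψ/∂n = 0 (ψ_{-1} = ψ_1), which is an assumption about the discretisation rather than the library's code. Under that closure, each end row of the FD2 operator has 2 instead of 1 on its single off-diagonal. The NumPy check below shows that the DCT-I cosines cos(πkj/(N−1)), k = 0…N−1, are its eigenvectors. The k = 0 cosine is the constant null mode that the zero-mean gauge removes when α = 0.

```python
import numpy as np

n, dx = 9, 0.1
# FD2 operator on n nodes that include both boundary nodes. The mirror ghosts
# psi[-1] = psi[1] and psi[n] = psi[n-2] double the single neighbour in the end rows.
A = (
    np.diag(-2.0 * np.ones(n))
    + np.diag(np.ones(n - 1), 1)
    + np.diag(np.ones(n - 1), -1)
) / dx**2
A[0, 1] = A[-1, -2] = 2.0 / dx**2

j = np.arange(n)
k = np.arange(n)
V = np.cos(np.pi * np.outer(j, k) / (n - 1))  # DCT-I basis vectors as columns
lam = -4.0 / dx**2 * np.sin(np.pi * k / (2 * (n - 1))) ** 2  # lam[0] = 0: null mode
```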

Source code in spectraldiffx/_src/fourier/solvers.py
class RegularNeumannHelmholtzSolver2D(eqx.Module):
    """2D Helmholtz/Poisson solver with Neumann BCs on a regular grid.

    Solves ``(∇² − α)ψ = f`` where ∂ψ/∂n = 0 on all four edges, using
    the DCT-I spectral method (see :func:`solve_helmholtz_dct1`).

    Attributes
    ----------
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    alpha : float
        Helmholtz parameter α ≥ 0.  Default: 0.0 (Poisson).
    """

    dx: float
    dy: float
    alpha: float = 0.0

    def __call__(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        """Solve (∇² − α)ψ = rhs.

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side (including boundary grid points).

        Returns
        -------
        Float[Array, "Ny Nx"]
            Solution ψ (zero-mean gauge when α = 0).
        """
        return solve_helmholtz_dct1(rhs, self.dx, self.dy, self.alpha)

Functions

__call__(rhs)

Solve (∇² − α)ψ = rhs.

Parameters

rhs : Float[Array, "Ny Nx"] Right-hand side (including boundary grid points).

Returns

Float[Array, "Ny Nx"] Solution ψ (zero-mean gauge when α = 0).

Mixed Per-Axis BCs

MixedBCHelmholtzSolver2D

Bases: Module

2D Helmholtz/Poisson solver with per-axis boundary conditions.

Solves (∇² − α)ψ = f, where each axis can have a different boundary condition type (see solve_helmholtz_2d).

Examples

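Channel flow (periodic in x, Dirichlet walls in y):

import jax.numpy as jnp
# module path taken from the source listing; a public re-export may also exist
from spectraldiffx._src.fourier.solvers import MixedBCHelmholtzSolver2D

solver = MixedBCHelmholtzSolver2D(
    dx=0.1,
    dy=0.1,
    bc_x="periodic",
    bc_y="dirichlet",
)
rhs = jnp.ones((32, 64))  # [Ny, Nx]
psi = solver(rhs)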

Attributes

dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
bc_x : BoundaryCondition
    Boundary condition along the x-axis.
bc_y : BoundaryCondition
    Boundary condition along the y-axis.
alpha : float
    Helmholtz parameter. Default: 0.0 (Poisson).
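Below is a NumPy sketch of the channel-flow case from the example. It follows the axis-classification order used in the 3-D source: forward with the real transform first, then the FFT; inverse with the inverse FFT, then .real, then the real transform. It assumes FD2 eigenvalues on both axes and treats bc_y="dirichlet" as a regular-grid DST-I. The library's mapping of "dirichlet" to a grid placement is not shown in this section, and channel_helmholtz is hypothetical. The Dirichlet eigenvalues are strictly negative, so the two axes never share a null mode and no gauge fix is needed, even at α = 0.

```python
import numpy as np


def channel_helmholtz(rhs, dx, dy, alpha=0.0):
    # Hypothetical sketch: periodic x (FFT) + homogeneous Dirichlet y (DST-I)
    ny, nx = rhs.shape
    jy = np.arange(1, ny + 1)
    Sy = np.sin(np.pi * np.outer(jy, jy) / (ny + 1))  # DST-I matrix along y
    eig_y = -4.0 / dy**2 * np.sin(np.pi * jy / (2 * (ny + 1))) ** 2
    eig_x = -4.0 / dx**2 * np.sin(np.pi * np.arange(nx) / nx) ** 2  # FD2, periodic
    eig = eig_y[:, None] + eig_x[None, :] - alpha  # strictly negative
    # Forward: real axis first, then the periodic axis
    data = np.fft.fft(Sy @ rhs, axis=1)
    psi_hat = data / eig
    # Inverse: periodic axis first, take .real, then the real axis
    data = np.fft.ifft(psi_hat, axis=1).real
    return (Sy @ data) * 2.0 / (ny + 1)


def lap_channel(psi, dx, dy):
    # 5-point Laplacian: periodic wrap in x, psi = 0 on the walls in y
    d2x = (np.roll(psi, 1, axis=1) - 2 * psi + np.roll(psi, -1, axis=1)) / dx**2
    p = np.pad(psi, ((1, 1), (0, 0)))
    d2y = (p[:-2] - 2 * psi + p[2:]) / dy**2
    return d2x + d2y


rng = np.random.default_rng(3)
f = rng.standard_normal((7, 16))  # [Ny, Nx]
psi = channel_helmholtz(f, dx=0.1, dy=0.1)  # alpha = 0: Poisson
```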

Source code in spectraldiffx/_src/fourier/solvers.py
class MixedBCHelmholtzSolver2D(eqx.Module):
    """2D Helmholtz/Poisson solver with per-axis boundary conditions.

    Solves ``(nabla^2 - alpha)psi = f`` where each axis can have a different
    boundary condition type (see :func:`solve_helmholtz_2d`).

    Examples
    --------
    Channel flow (periodic in x, Dirichlet walls in y)::

        solver = MixedBCHelmholtzSolver2D(
            dx=0.1,
            dy=0.1,
            bc_x="periodic",
            bc_y="dirichlet",
        )
        psi = solver(rhs)

    Attributes
    ----------
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    bc_x : BoundaryCondition
        Boundary condition along the x-axis.
    bc_y : BoundaryCondition
        Boundary condition along the y-axis.
    alpha : float
        Helmholtz parameter.  Default: 0.0 (Poisson).
    """

    dx: float
    dy: float
    bc_x: BoundaryCondition = eqx.field(static=True, default="periodic")
    bc_y: BoundaryCondition = eqx.field(static=True, default="periodic")
    alpha: float = 0.0

    def __call__(
        self,
        rhs: Float[Array, "Ny Nx"],
        *,
        bc_x_values: tuple[Float[Array, " Ny"] | None, Float[Array, " Ny"] | None] = (
            None,
            None,
        ),
        bc_y_values: tuple[Float[Array, " Nx"] | None, Float[Array, " Nx"] | None] = (
            None,
            None,
        ),
    ) -> Float[Array, "Ny Nx"]:
        """Solve (nabla^2 - alpha)psi = rhs.

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side.
        bc_x_values : tuple[Array | None, Array | None], keyword-only
            ``(left, right)`` inhomogeneous boundary values for x-axis.
        bc_y_values : tuple[Array | None, Array | None], keyword-only
            ``(bottom, top)`` inhomogeneous boundary values for y-axis.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Solution psi.
        """
        return solve_helmholtz_2d(
            rhs,
            self.dx,
            self.dy,
            bc_x=self.bc_x,
            bc_y=self.bc_y,
            lambda_=self.alpha,
            bc_x_values=bc_x_values,
            bc_y_values=bc_y_values,
        )

Functions

__call__(rhs, *, bc_x_values=(None, None), bc_y_values=(None, None))

Solve (∇² − α)ψ = rhs.

Parameters

rhs : Float[Array, "Ny Nx"]
    Right-hand side.
bc_x_values : tuple[Array | None, Array | None], keyword-only
    (left, right) inhomogeneous boundary values for the x-axis, each of shape (Ny,).
bc_y_values : tuple[Array | None, Array | None], keyword-only
    (bottom, top) inhomogeneous boundary values for the y-axis, each of shape (Nx,).

Returns

Float[Array, "Ny Nx"]
    Solution ψ.
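Inhomogeneous boundary values are commonly handled by lifting: the known boundary term in the stencil is moved into the right-hand side, leaving a homogeneous problem for the transform solver. This section does not say where solve_helmholtz_2d places bc_y_values on the grid, so the sketch below assumes the FD2 ghost-point convention (bottom sits one spacing below row 0, top one spacing above row -1). lift_dirichlet_y and dirichlet_fd2_matrix are hypothetical helpers; the dense matrix is used only to check the lifting, not as a recommended solver.

```python
import numpy as np


def lift_dirichlet_y(rhs, dy, bottom=None, top=None):
    """Fold inhomogeneous Dirichlet values into the RHS (hypothetical helper).

    Assumes row 0 of the FD2 stencil reads (psi[1] - 2 psi[0] + bottom) / dy**2;
    moving the known bottom / dy**2 term (and likewise top) to the right-hand
    side leaves a problem with homogeneous Dirichlet values.
    """
    rhs = np.array(rhs, dtype=float, copy=True)
    if bottom is not None:
        rhs[0] -= np.asarray(bottom) / dy**2
    if top is not None:
        rhs[-1] -= np.asarray(top) / dy**2
    return rhs


def dirichlet_fd2_matrix(n, dy):
    # Tridiagonal 3-point Laplacian on n interior points with zero ghost values
    off = np.ones(n - 1)
    return (np.diag(-2.0 * np.ones(n)) + np.diag(off, 1) + np.diag(off, -1)) / dy**2


# Round-trip check: ghost rows carry the boundary values when forming rhs
rng = np.random.default_rng(1)
ny, nx, dy, alpha = 12, 5, 0.1, 0.3
psi_true = rng.standard_normal((ny, nx))  # (Ny, Nx)
bottom, top = rng.standard_normal(nx), rng.standard_normal(nx)  # shape (Nx,) each
padded = np.vstack([bottom, psi_true, top])
rhs = (padded[2:] - 2 * psi_true + padded[:-2]) / dy**2 - alpha * psi_true
a = dirichlet_fd2_matrix(ny, dy) - alpha * np.eye(ny)
psi = np.linalg.solve(a, lift_dirichlet_y(rhs, dy, bottom=bottom, top=top))
```

After lifting, any homogeneous-Dirichlet solver (such as a sine-transform solve) can replace the dense np.linalg.solve call.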

MixedBCHelmholtzSolver3D

Bases: Module

3D Helmholtz/Poisson solver with per-axis boundary conditions.

Solves (∇² − α)ψ = f, where each axis can have a different boundary-condition type (see solve_helmholtz_3d).
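For the boundary-layer configuration used in the example (periodic x and y, Neumann z), a transform sketch in NumPy looks like this. It is illustrative only and assumes FD2 eigenvalues plus a cell-centred z grid with mirrored ghost cells; reading "neumann_stag" that way is an assumption here. The z basis is the DCT-II, applied as an explicit cosine matrix. With α = 0 the constant mode is singular, and the sketch sets it to zero (zero-mean gauge), in the same way the 1D periodic solver handles its k=0 mode. Function names are hypothetical.

```python
import numpy as np


def solve_periodic_xy_neumann_z(rhs, dx, dy, dz, alpha=0.0):
    """Solve (∇² − α)ψ = rhs on an (Nz, Ny, Nx) grid using FD2 eigenvalues.

    Assumed grid: x and y periodic; z cell-centred with mirrored ghost cells,
    so the z eigenvectors are the DCT-II modes cos(pi k (j + 1/2) / Nz).
    """
    nz, ny, nx = rhs.shape
    # Cosine basis c[k, j]; row k = 0 has squared norm Nz, rows k >= 1 have Nz/2
    c = np.cos(np.pi * np.outer(np.arange(nz), np.arange(nz) + 0.5) / nz)
    weights = np.full(nz, 2.0 / nz)
    weights[0] = 1.0 / nz
    # Forward: 2-D FFT over (y, x), then the cosine transform along z
    rhs_hat = np.tensordot(c, np.fft.fft2(rhs, axes=(1, 2)), axes=(1, 0))
    # FD2 eigenvalues per axis; the 3-D eigenvalue is their sum
    lam_x = -4.0 / dx**2 * np.sin(np.pi * np.arange(nx) / nx) ** 2
    lam_y = -4.0 / dy**2 * np.sin(np.pi * np.arange(ny) / ny) ** 2
    lam_z = -4.0 / dz**2 * np.sin(np.pi * np.arange(nz) / (2 * nz)) ** 2
    denom = lam_z[:, None, None] + lam_y[None, :, None] + lam_x[None, None, :] - alpha
    # With alpha == 0 the (0, 0, 0) mode is singular: set it to zero (zero-mean gauge)
    singular = denom == 0.0
    psi_hat = np.where(singular, 0.0, rhs_hat / np.where(singular, 1.0, denom))
    # Inverse: weighted transpose of the cosine matrix along z, then inverse 2-D FFT
    psi = np.tensordot(c.T, weights[:, None, None] * psi_hat, axes=(1, 0))
    return np.fft.ifft2(psi, axes=(1, 2)).real


def fd2_laplacian_3d(p, dx, dy, dz):
    # Periodic 3-point stencils in x and y; mirrored ghost cells in z
    lap_x = (np.roll(p, -1, axis=2) - 2 * p + np.roll(p, 1, axis=2)) / dx**2
    lap_y = (np.roll(p, -1, axis=1) - 2 * p + np.roll(p, 1, axis=1)) / dy**2
    pz = np.concatenate([p[:1], p, p[-1:]], axis=0)
    lap_z = (pz[2:] - 2 * p + pz[:-2]) / dz**2
    return lap_x + lap_y + lap_z


# Round-trip check on a zero-mean field (the gauge returned when alpha == 0)
rng = np.random.default_rng(2)
psi_true = rng.standard_normal((6, 8, 10))  # (Nz, Ny, Nx)
psi_true -= psi_true.mean()
psi = solve_periodic_xy_neumann_z(fd2_laplacian_3d(psi_true, 1.0, 1.5, 0.5), 1.0, 1.5, 0.5)
```

A Poisson problem with no Dirichlet axis determines ψ only up to a constant, so some gauge choice like the one above is unavoidable.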

Examples

Atmospheric boundary layer (periodic in x and y, Neumann in z):

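import jax.numpy as jnp
# Module path taken from the "Source code in" note; a public re-export may also exist.
from spectraldiffx._src.fourier.solvers import MixedBCHelmholtzSolver3D

solver = MixedBCHelmholtzSolver3D(
    dx=1000.0,
    dy=1000.0,
    dz=50.0,
    bc_x="periodic",
    bc_y="periodic",
    bc_z="neumann_stag",
)
rhs = jnp.zeros((32, 64, 64))  # shape (Nz, Ny, Nx); placeholder forcing
psi = solver(rhs)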

Attributes

dx : float
    Grid spacing in x.
dy : float
    Grid spacing in y.
dz : float
    Grid spacing in z.
bc_x : BoundaryCondition
    Boundary condition along the x-axis. Default: "periodic".
bc_y : BoundaryCondition
    Boundary condition along the y-axis. Default: "periodic".
bc_z : BoundaryCondition
    Boundary condition along the z-axis. Default: "periodic".
alpha : float
    Helmholtz parameter α, passed to solve_helmholtz_3d as lambda_. Default: 0.0 (Poisson).

Source code in spectraldiffx/_src/fourier/solvers.py
class MixedBCHelmholtzSolver3D(eqx.Module):
    """3D Helmholtz/Poisson solver with per-axis boundary conditions.

    Solves ``(nabla^2 - alpha)psi = f`` where each axis can have a different
    boundary condition type (see :func:`solve_helmholtz_3d`).

    Examples
    --------
    Atmospheric boundary layer (periodic x/y, Neumann z)::

        solver = MixedBCHelmholtzSolver3D(
            dx=1000.0,
            dy=1000.0,
            dz=50.0,
            bc_x="periodic",
            bc_y="periodic",
            bc_z="neumann_stag",
        )
        psi = solver(rhs)

    Attributes
    ----------
    dx : float
        Grid spacing in x.
    dy : float
        Grid spacing in y.
    dz : float
        Grid spacing in z.
    bc_x : BoundaryCondition
        Boundary condition along the x-axis.
    bc_y : BoundaryCondition
        Boundary condition along the y-axis.
    bc_z : BoundaryCondition
        Boundary condition along the z-axis.
    alpha : float
        Helmholtz parameter.  Default: 0.0 (Poisson).
    """

    dx: float
    dy: float
    dz: float
    bc_x: BoundaryCondition = eqx.field(static=True, default="periodic")
    bc_y: BoundaryCondition = eqx.field(static=True, default="periodic")
    bc_z: BoundaryCondition = eqx.field(static=True, default="periodic")
    alpha: float = 0.0

    def __call__(
        self,
        rhs: Float[Array, "Nz Ny Nx"],
        *,
        bc_x_values: tuple[
            Float[Array, "Nz Ny"] | None, Float[Array, "Nz Ny"] | None
        ] = (None, None),
        bc_y_values: tuple[
            Float[Array, "Nz Nx"] | None, Float[Array, "Nz Nx"] | None
        ] = (None, None),
        bc_z_values: tuple[
            Float[Array, "Ny Nx"] | None, Float[Array, "Ny Nx"] | None
        ] = (None, None),
    ) -> Float[Array, "Nz Ny Nx"]:
        """Solve (nabla^2 - alpha)psi = rhs.

        Parameters
        ----------
        rhs : Float[Array, "Nz Ny Nx"]
            Right-hand side.
        bc_x_values : tuple[Array | None, Array | None], keyword-only
            ``(left, right)`` face arrays of shape ``(Nz, Ny)``.
        bc_y_values : tuple[Array | None, Array | None], keyword-only
            ``(bottom, top)`` face arrays of shape ``(Nz, Nx)``.
        bc_z_values : tuple[Array | None, Array | None], keyword-only
            ``(back, front)`` face arrays of shape ``(Ny, Nx)``.

        Returns
        -------
        Float[Array, "Nz Ny Nx"]
            Solution psi.
        """
        return solve_helmholtz_3d(
            rhs,
            self.dx,
            self.dy,
            self.dz,
            bc_x=self.bc_x,
            bc_y=self.bc_y,
            bc_z=self.bc_z,
            lambda_=self.alpha,
            bc_x_values=bc_x_values,
            bc_y_values=bc_y_values,
            bc_z_values=bc_z_values,
        )

Functions

__call__(rhs, *, bc_x_values=(None, None), bc_y_values=(None, None), bc_z_values=(None, None))

Solve (∇² − α)ψ = rhs.

Parameters

rhs : Float[Array, "Nz Ny Nx"]
    Right-hand side.
bc_x_values : tuple[Array | None, Array | None], keyword-only
    (left, right) face arrays of shape (Nz, Ny).
bc_y_values : tuple[Array | None, Array | None], keyword-only
    (bottom, top) face arrays of shape (Nz, Nx).
bc_z_values : tuple[Array | None, Array | None], keyword-only
    (back, front) face arrays of shape (Ny, Nx).

Returns

Float[Array, "Nz Ny Nx"]
    Solution ψ.
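Each face array is named by the axis it bounds and spans the other two axes, as the type annotations in the signature show. A hypothetical pre-flight check (not part of spectraldiffx) makes that mapping explicit and catches transposed faces before they reach the solver:

```python
import numpy as np

# Which face each entry of a bc_*_values pair names, following the docstring
FACE_NAMES = {
    "bc_x_values": ("left", "right"),
    "bc_y_values": ("bottom", "top"),
    "bc_z_values": ("back", "front"),
}


def expected_face_shapes(nz, ny, nx):
    # Each face array spans the two axes it does not bound
    return {
        "bc_x_values": (nz, ny),
        "bc_y_values": (nz, nx),
        "bc_z_values": (ny, nx),
    }


def face_shape_errors(rhs_shape, **bc_values):
    """Return messages for face arrays whose shape does not match rhs_shape.

    Hypothetical helper for illustration; None entries are skipped.
    """
    expected = expected_face_shapes(*rhs_shape)
    errors = []
    for name, pair in bc_values.items():
        for side, face in zip(FACE_NAMES[name], pair):
            if face is not None and np.shape(face) != expected[name]:
                errors.append(f"{name} {side}: got {np.shape(face)}, expected {expected[name]}")
    return errors


rhs = np.zeros((4, 5, 6))  # (Nz, Ny, Nx)
ok = face_shape_errors(rhs.shape, bc_z_values=(np.zeros((5, 6)), None))
bad = face_shape_errors(rhs.shape, bc_y_values=(np.zeros((5, 6)), None))
```

Here bad reports that the bottom face was given with shape (5, 6) where (Nz, Nx) = (4, 6) was expected.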
