Convenience Wrappers

High-level functions for common elliptic PDE inversions on C-grids. Each wrapper dispatches to spectral, CG, or capacitance solvers depending on the method argument.

Streamfunction from Vorticity

finitevolx.streamfunction_from_vorticity(zeta, dx, dy, bc='dst', lambda_=0.0, method='spectral', mask=None, capacitance_solver=None, preconditioner=None)

Invert the vorticity–streamfunction relation ∇²ψ − λψ = ζ.

Solves the Poisson (λ = 0) or Helmholtz (λ ≠ 0) equation to recover the streamfunction from relative vorticity.
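The next sketch shows roughly what the "spectral" path does with bc="dst". It is a minimal JAX Dirichlet Poisson/Helmholtz solve with three assumptions. It uses the standard 5-point Laplacian. It sets ψ = 0 on a ring of ghost points just outside the array. It uses an explicit sine-transform (DST-I) matrix for readability, where a real solver would use an FFT-based transform. The helper names are hypothetical, and this is not finitevolx's implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision keeps the residual check tight


def dst1_matrix(n):
    # S[k, j] = sin(pi * k * j / (n + 1)) for k, j = 1..n; S is symmetric and
    # S @ S = (n + 1) / 2 * I, so the inverse transform is 2 / (n + 1) * S.
    k = jnp.arange(1, n + 1)
    return jnp.sin(jnp.pi * k[:, None] * k[None, :] / (n + 1))


def dirichlet_eigenvalues(n, h):
    # Eigenvalues of the 1-D second difference (u[j+1] - 2u[j] + u[j-1]) / h**2
    # with u = 0 just outside both ends.
    k = jnp.arange(1, n + 1)
    return -4.0 / h**2 * jnp.sin(jnp.pi * k / (2 * (n + 1))) ** 2


def helmholtz_dirichlet(rhs, dx, dy, lambda_=0.0):
    """Hypothetical helper: solve lap(psi) - lambda_ * psi = rhs with psi = 0 on the boundary."""
    ny, nx = rhs.shape
    sx, sy = dst1_matrix(nx), dst1_matrix(ny)
    denom = dirichlet_eigenvalues(ny, dy)[:, None] + dirichlet_eigenvalues(nx, dx)[None, :] - lambda_
    rhs_hat = sy @ rhs @ sx          # forward sine transform along y, then x
    psi_hat = rhs_hat / denom        # strictly negative for lambda_ >= 0, so never zero
    return (2.0 / (ny + 1)) * (2.0 / (nx + 1)) * (sy @ psi_hat @ sx)


def laplacian_dirichlet(psi, dx, dy):
    # 5-point Laplacian; zero padding supplies the psi = 0 ghost points.
    p = jnp.pad(psi, 1)
    return ((p[1:-1, 2:] - 2 * psi + p[1:-1, :-2]) / dx**2
            + (p[2:, 1:-1] - 2 * psi + p[:-2, 1:-1]) / dy**2)


zeta = jax.random.normal(jax.random.PRNGKey(0), (20, 24))
psi_poisson = helmholtz_dirichlet(zeta, 0.5, 0.25)                 # lambda = 0: Poisson
psi_helmholtz = helmholtz_dirichlet(zeta, 0.5, 0.25, lambda_=2.0)  # lambda != 0: Helmholtz
```

Both cases share one code path. The only difference is the constant shift of the denominator by λ.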

Three solver methods are available:

  • "spectral": direct spectral solver (DST, DCT, or FFT, chosen by bc) for rectangular domains. This is the default.
  • "cg": preconditioned conjugate gradient for masked or irregular domains. Requires mask. Uses a spectral preconditioner by default, or a custom one passed via preconditioner.
  • "capacitance": capacitance-matrix method for masked domains. Requires a pre-built CapacitanceSolver passed via capacitance_solver.
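The argument requirements in the list above can be collected into a small pre-flight check. This is a hypothetical helper, not part of finitevolx. The method, bc, and mask messages mirror ValueError messages in the pv_inversion source on this page. The capacitance check and its message are assumptions.

```python
VALID_METHODS = ("spectral", "cg", "capacitance")
VALID_BCS = ("dst", "dct", "fft")


def check_solver_args(method="spectral", bc="dst", mask=None, capacitance_solver=None):
    """Hypothetical pre-flight check mirroring the documented argument requirements."""
    if method not in VALID_METHODS:
        raise ValueError(
            f"method must be 'spectral', 'cg', or 'capacitance'; got {method!r}"
        )
    if method == "spectral" and bc not in VALID_BCS:
        # bc is only consulted by the spectral path
        raise ValueError(f"bc must be 'fft', 'dst', or 'dct'; got {bc!r}")
    if method == "cg" and mask is None:
        raise ValueError("method='cg' requires a mask")
    if method == "capacitance" and capacitance_solver is None:
        raise ValueError("method='capacitance' requires a pre-built capacitance_solver")


def raises_value_error(**kwargs):
    # Convenience wrapper for the examples: True if check_solver_args rejects kwargs.
    try:
        check_solver_args(**kwargs)
    except ValueError:
        return True
    return False


print(raises_value_error(method="cg"))                  # missing mask
print(raises_value_error(method="spectral", bc="dct"))  # valid combination
```

Validating up front like this lets an invalid combination fail with a clear message before any transform or iteration runs.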

Parameters:

  • zeta (Float[Array, 'Ny Nx'], required): Relative vorticity (right-hand side).
  • dx (float, required): Grid spacing in x.
  • dy (float, required): Grid spacing in y.
  • bc ('dst', 'dct', or 'fft'; default "dst"): Boundary-condition type for the spectral solver (used by method="spectral"). "dst" (Dirichlet, ψ = 0 on the boundary) is the most common choice for streamfunction inversion.
  • lambda_ (float; default 0.0): Helmholtz parameter. Use 0.0 for the pure Poisson problem (streamfunction from vorticity). Non-zero values arise in QG PV inversion: (∇² − λ)ψ = q.
  • method ('spectral', 'cg', or 'capacitance'; default "spectral"): Solver method.
  • mask (Float[Array, 'Ny Nx'], ArakawaCGridMask, or None; default None): Domain mask. Required for method="cg". When an ArakawaCGridMask is passed, its psi staggering mask is extracted automatically.
  • capacitance_solver (CapacitanceSolver or None; default None): Pre-built capacitance solver. Required for method="capacitance".
  • preconditioner (callable or None; default None): Custom preconditioner for method="cg", with signature preconditioner(r: Array) -> Array. When None, an FFT-based spectral preconditioner is used automatically.
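The next sketch shows a masked solve of the kind method="cg" performs, using jax.scipy.sparse.linalg.cg. Its M argument accepts a callable shaped like the documented preconditioner(r: Array) -> Array. None of the assumptions below come from finitevolx:

  • mask is 1 on wet cells and 0 on dry cells.
  • ψ = 0 on dry cells and outside the array.
  • The operator is negated because JAX's cg expects a positive-definite operator.
  • A simple Jacobi (diagonal) preconditioner stands in for the default spectral one.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)


def laplacian(psi, dx, dy):
    # 5-point Laplacian with zero ghost points around the array.
    p = jnp.pad(psi, 1)
    return ((p[1:-1, 2:] - 2 * psi + p[1:-1, :-2]) / dx**2
            + (p[2:, 1:-1] - 2 * psi + p[:-2, 1:-1]) / dy**2)


def solve_masked_helmholtz(rhs, mask, dx, dy, lambda_=0.0):
    """Hypothetical helper: lap(psi) - lambda_*psi = rhs on wet cells, psi = 0 on dry cells."""
    def matvec(x):
        # Negated masked operator, plus an identity row on dry cells, so that it
        # is symmetric positive definite as jax's cg requires.
        xm = mask * x
        return -mask * (laplacian(xm, dx, dy) - lambda_ * xm) + (1.0 - mask) * x

    # Diagonal of matvec: 2/dx**2 + 2/dy**2 + lambda_ on wet cells, 1 on dry cells.
    diag = mask * (2.0 / dx**2 + 2.0 / dy**2 + lambda_) + (1.0 - mask)

    def preconditioner(r):
        # Jacobi preconditioner with the documented r -> Array signature.
        return r / diag

    x, _ = cg(matvec, -mask * rhs, M=preconditioner, tol=1e-10, maxiter=1000)
    return mask * x


mask = jnp.ones((16, 20)).at[5:9, 6:11].set(0.0)  # rectangular basin with an island
zeta = jax.random.normal(jax.random.PRNGKey(1), (16, 20))
psi = solve_masked_helmholtz(zeta, mask, 1.0, 1.0, lambda_=0.5)
```

Swapping in a better preconditioner changes only the preconditioner function. The matvec and the CG call stay the same.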

Returns:

  • Float[Array, 'Ny Nx']: Streamfunction ψ.

Source code in finitevolx/_src/solvers/elliptic.py
def streamfunction_from_vorticity(
    zeta: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    bc: str = "dst",
    lambda_: float = 0.0,
    method: str = "spectral",
    mask: _MaskLike | None = None,
    capacitance_solver: CapacitanceSolver | None = None,
    preconditioner: _PrecondLike | None = None,
) -> Float[Array, "Ny Nx"]:
    r"""Invert the vorticity–streamfunction relation ∇²ψ − λψ = ζ.

    Solves the Poisson (λ = 0) or Helmholtz (λ ≠ 0) equation to recover the
    streamfunction from relative vorticity.

    Three solver methods are available:

    * ``"spectral"`` — Direct spectral solver (DST/DCT/FFT) for rectangular
      domains.  Selected by *bc*.  Default.
    * ``"cg"`` — Preconditioned Conjugate Gradient for masked / irregular
      domains.  Requires *mask*.  Uses a spectral preconditioner by default,
      or a custom one via *preconditioner*.
    * ``"capacitance"`` — Capacitance matrix method for masked domains.
      Requires a pre-built :class:`CapacitanceSolver` via
      *capacitance_solver*.

    Parameters
    ----------
    zeta : Float[Array, "Ny Nx"]
        Relative vorticity (right-hand side).
    dx, dy : float
        Grid spacings.
    bc : {"dst", "dct", "fft"}
        Boundary-condition type for the spectral solver (used by
        ``method="spectral"``).
        ``"dst"`` (Dirichlet, ψ = 0 on boundary) is the most common choice
        for streamfunction inversion.
    lambda_ : float
        Helmholtz parameter.  Use 0.0 for the pure Poisson problem
        (streamfunction from vorticity).  Non-zero values arise in QG PV
        inversion: (∇² − λ)ψ = q.
    method : {"spectral", "cg", "capacitance"}
        Solver method.  Default: ``"spectral"``.
    mask : Float[Array, "Ny Nx"] or ArakawaCGridMask or None
        Domain mask.  Required for ``method="cg"``.  When an
        :class:`ArakawaCGridMask` is passed the ``psi`` staggering mask is
        extracted automatically.
    capacitance_solver : CapacitanceSolver or None
        Pre-built capacitance solver.  Required for
        ``method="capacitance"``.
    preconditioner : callable or None
        Custom preconditioner for ``method="cg"``.  Signature:
        ``preconditioner(r: Array) -> Array``.  When ``None``, a spectral
        preconditioner (FFT-based) is used automatically.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Streamfunction ψ.
    """
    return _solve_dispatch(
        zeta, dx, dy, lambda_, bc, method, mask, capacitance_solver, preconditioner
    )
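The capacitance option is easiest to picture as "fast rectangular solver plus a small dense correction". The sketch below is a textbook version under its own assumptions, and it is not the CapacitanceSolver class:

  • Dirichlet ψ = 0 on dry cells.
  • A dense DST-I as the stand-in fast solver.
  • Every dry cell used as a constraint point. Real implementations only need the dry cells next to wet ones.

The build step does one rectangular solve per constraint point and forms the capacitance matrix C = EᵀGE. That cost is why the solver is constructed once and reused.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _sine(n):
    k = jnp.arange(1, n + 1)
    return jnp.sin(jnp.pi * k[:, None] * k[None, :] / (n + 1))


def _eig(n, h):
    return -4.0 / h**2 * jnp.sin(jnp.pi * jnp.arange(1, n + 1) / (2 * (n + 1))) ** 2


def rect_solve(rhs, dx, dy):
    # Stand-in fast solver: Dirichlet Poisson on the full rectangle via a dense DST-I.
    ny, nx = rhs.shape
    sx, sy = _sine(nx), _sine(ny)
    hat = (sy @ rhs @ sx) / (_eig(ny, dy)[:, None] + _eig(nx, dx)[None, :])
    return 4.0 / ((ny + 1) * (nx + 1)) * (sy @ hat @ sx)


def laplacian(psi, dx, dy):
    p = jnp.pad(psi, 1)
    return ((p[1:-1, 2:] - 2 * psi + p[1:-1, :-2]) / dx**2
            + (p[2:, 1:-1] - 2 * psi + p[:-2, 1:-1]) / dy**2)


def build_capacitance(mask, dx, dy):
    """Hypothetical build step: responses G e_b and C = E^T G E, computed once per mask."""
    ny, nx = mask.shape
    idx = jnp.flatnonzero(mask.ravel() == 0)   # constraint points (dry cells)

    def response(i):
        unit = jnp.zeros(ny * nx).at[i].set(1.0).reshape(ny, nx)
        return rect_solve(unit, dx, dy).ravel()

    green = jax.vmap(response)(idx)            # row b holds G e_b, shape (n_dry, ny * nx)
    cap = green[:, idx].T                      # C[a, b] = (G e_b) evaluated at point a
    return idx, green, cap


def capacitance_solve(rhs, mask, dx, dy, prebuilt):
    idx, green, cap = prebuilt
    psi0 = rect_solve(rhs * mask, dx, dy).ravel()   # rectangle solution that ignores the land
    alpha = jnp.linalg.solve(cap, -psi0[idx])       # point sources that cancel psi on land
    psi = psi0 + green.T @ alpha                    # superpose the precomputed responses
    return psi.reshape(mask.shape) * mask


mask = jnp.ones((16, 20)).at[5:9, 6:11].set(0.0)
prebuilt = build_capacitance(mask, 1.0, 1.0)             # expensive, done once per mask
zeta = jax.random.normal(jax.random.PRNGKey(2), (16, 20))
psi = capacitance_solve(zeta, mask, 1.0, 1.0, prebuilt)  # cheap per right-hand side
```

Each solve after the build costs two fast rectangular solves and one small dense solve, whose size is the number of constraint points.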

Pressure from Divergence

finitevolx.pressure_from_divergence(div_u, dx, dy, bc='dct', method='spectral', mask=None, capacitance_solver=None, preconditioner=None)

Solve ∇²p = ∇·u for the pressure correction.

Used in pressure-projection methods (Chorin splitting) where the divergence of the provisional velocity field must be removed.
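The equation comes from the standard projection step. Write u* for the provisional velocity, ρ for a constant density, and Δt for the time step. Requiring the corrected velocity to be divergence-free then gives a Poisson equation for p:

```latex
\mathbf{u}^{n+1} = \mathbf{u}^{*} - \frac{\Delta t}{\rho}\,\nabla p,
\qquad
\nabla\cdot\mathbf{u}^{n+1} = 0
\quad\Longrightarrow\quad
\nabla^{2} p = \frac{\rho}{\Delta t}\,\nabla\cdot\mathbf{u}^{*}.
```

The form ∇²p = ∇·u solved here corresponds to folding the factor ρ/Δt into p or into div_u. That reading is an assumption: the scaling convention is not spelled out on this page.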

Solver selection follows the same three-method dispatch as streamfunction_from_vorticity.
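The next sketch is a minimal JAX version of one projection step on a doubly periodic C-grid (the bc="fft" case). It assumes u on east faces, v on north faces, p at cell centres, and Δt/ρ folded into p. The backward-difference divergence composed with the forward-difference gradient is exactly the 5-point Laplacian. An FFT solve with that operator's eigenvalues therefore removes the divergence to round-off. The helpers are hypothetical, not finitevolx functions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def divergence(u, v, dx, dy):
    # Cell-centre divergence from face velocities (periodic via jnp.roll).
    return (u - jnp.roll(u, 1, axis=1)) / dx + (v - jnp.roll(v, 1, axis=0)) / dy


def gradient(p, dx, dy):
    # Forward differences place dp/dx on east faces and dp/dy on north faces.
    return (jnp.roll(p, -1, axis=1) - p) / dx, (jnp.roll(p, -1, axis=0) - p) / dy


def poisson_periodic(rhs, dx, dy):
    ny, nx = rhs.shape
    # Eigenvalues of the periodic 5-point Laplacian for each Fourier mode.
    lam_x = (2 * jnp.cos(2 * jnp.pi * jnp.arange(nx) / nx) - 2) / dx**2
    lam_y = (2 * jnp.cos(2 * jnp.pi * jnp.arange(ny) / ny) - 2) / dy**2
    denom = lam_y[:, None] + lam_x[None, :]
    # The (0, 0) mode is singular: p is defined up to a constant, so pin its mean to zero.
    safe = jnp.where(denom == 0, 1.0, denom)
    p_hat = jnp.where(denom == 0, 0.0, jnp.fft.fft2(rhs) / safe)
    return jnp.real(jnp.fft.ifft2(p_hat))


def project(u, v, dx, dy):
    p = poisson_periodic(divergence(u, v, dx, dy), dx, dy)
    dpdx, dpdy = gradient(p, dx, dy)
    return u - dpdx, v - dpdy, p


key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
u_star = jax.random.normal(key_u, (32, 48))
v_star = jax.random.normal(key_v, (32, 48))
u_new, v_new, p = project(u_star, v_star, 1.0, 2.0)
```

The discrete divergence, gradient, and Laplacian must be consistent. If the Laplacian in the solve is not exactly the divergence of the gradient, the projected field is only approximately divergence-free.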

Parameters:

  • div_u (Float[Array, 'Ny Nx'], required): Divergence of the velocity field (right-hand side).
  • dx (float, required): Grid spacing in x.
  • dy (float, required): Grid spacing in y.
  • bc ('dct', 'dst', or 'fft'; default "dct"): Boundary-condition type for the spectral solver. "dct" (Neumann, ∂p/∂n = 0) is the standard choice for pressure with solid walls.
  • method ('spectral', 'cg', or 'capacitance'; default "spectral"): Solver method.
  • mask (Float[Array, 'Ny Nx'], ArakawaCGridMask, or None; default None): Domain mask. Required for method="cg".
  • capacitance_solver (CapacitanceSolver or None; default None): Pre-built capacitance solver. Required for method="capacitance".
  • preconditioner (callable or None; default None): Custom preconditioner for method="cg".

Returns:

  • Float[Array, 'Ny Nx']: Pressure field p.

Source code in finitevolx/_src/solvers/elliptic.py
def pressure_from_divergence(
    div_u: Float[Array, "Ny Nx"],
    dx: float,
    dy: float,
    bc: str = "dct",
    method: str = "spectral",
    mask: _MaskLike | None = None,
    capacitance_solver: CapacitanceSolver | None = None,
    preconditioner: _PrecondLike | None = None,
) -> Float[Array, "Ny Nx"]:
    r"""Solve ∇²p = ∇·u for the pressure correction.

    Used in pressure-projection methods (Chorin splitting) where the
    divergence of the provisional velocity field must be removed.

    Solver selection follows the same three-method dispatch as
    :func:`streamfunction_from_vorticity`.

    Parameters
    ----------
    div_u : Float[Array, "Ny Nx"]
        Divergence of the velocity field (right-hand side).
    dx, dy : float
        Grid spacings.
    bc : {"dct", "dst", "fft"}
        Boundary-condition type for the spectral solver.
        ``"dct"`` (Neumann, ∂p/∂n = 0) is the standard choice for
        pressure with solid walls.
    method : {"spectral", "cg", "capacitance"}
        Solver method.  Default: ``"spectral"``.
    mask : Float[Array, "Ny Nx"] or ArakawaCGridMask or None
        Domain mask.  Required for ``method="cg"``.
    capacitance_solver : CapacitanceSolver or None
        Pre-built capacitance solver.  Required for
        ``method="capacitance"``.
    preconditioner : callable or None
        Custom preconditioner for ``method="cg"``.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Pressure field p.
    """
    return _solve_dispatch(
        div_u, dx, dy, 0.0, bc, method, mask, capacitance_solver, preconditioner
    )

PV Inversion

finitevolx.pv_inversion(pv, dx, dy, lambda_, bc='dst', method='spectral', mask=None, capacitance_solver=None, preconditioner=None)

QG potential-vorticity inversion: solve (∇² − λ)ψ = q.

Supports batched and multi-layer PV fields. When lambda_ is a 1-D array of shape (nl,), each layer is solved with its own Helmholtz parameter (e.g. 1/Rd² per vertical mode from finitevolx.decompose_vertical_modes). In that case pv must have at least three dimensions with pv.shape[-3] == nl, and method="capacitance" raises a ValueError.

Solver selection follows the same three-method dispatch as streamfunction_from_vorticity.
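The per-layer path pairs each layer with its own λ through two nested vmaps. It must avoid Python branches on λ, because λ is a traced value inside vmap. The self-contained JAX sketch below follows the same pattern. A periodic 5-point Helmholtz solve stands in for the library's spectral solvers (an assumption). invert_layers and helmholtz_periodic are hypothetical names.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def laplacian_periodic(psi, dx, dy):
    # Periodic 5-point Laplacian over the last two axes (works on batched arrays).
    return ((jnp.roll(psi, -1, -1) - 2 * psi + jnp.roll(psi, 1, -1)) / dx**2
            + (jnp.roll(psi, -1, -2) - 2 * psi + jnp.roll(psi, 1, -2)) / dy**2)


def helmholtz_periodic(q, dx, dy, lam):
    """Hypothetical helper: solve lap(psi) - lam * psi = q on a doubly periodic grid."""
    ny, nx = q.shape
    kx = (2 * jnp.cos(2 * jnp.pi * jnp.arange(nx) / nx) - 2) / dx**2
    ky = (2 * jnp.cos(2 * jnp.pi * jnp.arange(ny) / ny) - 2) / dy**2
    denom = ky[:, None] + kx[None, :] - lam
    # lam is a tracer under vmap, so the singular mode (lam == 0, k == 0) is masked
    # with jnp.where instead of a Python `if lam == 0` branch.
    safe = jnp.where(denom == 0, 1.0, denom)
    psi_hat = jnp.where(denom == 0, 0.0, jnp.fft.fft2(q) / safe)
    return jnp.real(jnp.fft.ifft2(psi_hat))


def invert_layers(pv, dx, dy, lambdas):
    """pv: (..., nl, Ny, Nx); lambdas: (nl,). Returns psi with the same shape as pv."""
    nl = lambdas.shape[0]
    shape = pv.shape
    pv4 = pv.reshape(-1, nl, shape[-2], shape[-1])  # (batch, nl, Ny, Nx)
    # Inner vmap pairs layer i with lambdas[i]; outer vmap runs over the flattened batch.
    per_layer = jax.vmap(lambda q, lam: helmholtz_periodic(q, dx, dy, lam))
    out = jax.vmap(lambda layers: per_layer(layers, lambdas))(pv4)
    return out.reshape(shape)


lambdas = jnp.array([0.0, 4.0, 10.0])
q = jax.random.normal(jax.random.PRNGKey(4), (2, 3, 16, 16))
q = q - q.mean(axis=(-2, -1), keepdims=True)  # zero mean so the lambda = 0 layer is solvable
psi = invert_layers(q, 1.0, 1.0, lambdas)
```

dx and dy are captured by the lambdas rather than passed as keyword arguments, because jax.vmap maps keyword arguments over their leading axis.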

Parameters:

  • pv (Float[Array, '... Ny Nx'], required): Potential-vorticity field. Leading dimensions are batched.
  • dx (float, required): Grid spacing in x.
  • dy (float, required): Grid spacing in y.
  • lambda_ (float or Float[Array, ' nl'], required): Helmholtz parameter(s). A scalar for a single layer; an array of shape (nl,) for multi-layer inversion, in which case pv.shape[-3] must equal nl.
  • bc ('dst', 'dct', or 'fft'; default "dst"): Boundary-condition type (used by method="spectral").
  • method ('spectral', 'cg', or 'capacitance'; default "spectral"): Solver method. "capacitance" is supported only for scalar lambda_.
  • mask (Float[Array, 'Ny Nx'], ArakawaCGridMask, or None; default None): Domain mask. Required for method="cg".
  • capacitance_solver (CapacitanceSolver or None; default None): Pre-built capacitance solver. Required for method="capacitance".
  • preconditioner (callable or None; default None): Custom preconditioner for method="cg". When lambda_ is an array, the same custom preconditioner is applied to every layer.

Returns:

  • Float[Array, '... Ny Nx']: Streamfunction ψ, same shape as pv.

Source code in finitevolx/_src/solvers/elliptic.py
def pv_inversion(
    pv: Float[Array, "... Ny Nx"],
    dx: float,
    dy: float,
    lambda_: float | Float[Array, " nl"],
    bc: str = "dst",
    method: str = "spectral",
    mask: _MaskLike | None = None,
    capacitance_solver: CapacitanceSolver | None = None,
    preconditioner: _PrecondLike | None = None,
) -> Float[Array, "... Ny Nx"]:
    r"""QG potential-vorticity inversion: solve (∇² − λ)ψ = q.

    Supports batched / multi-layer PV fields.  When *lambda_* is a 1-D
    array of shape ``(nl,)``, each layer is solved with its own Helmholtz
    parameter (e.g. 1/Rd² per vertical mode from
    :func:`~finitevolx.decompose_vertical_modes`).

    Solver selection follows the same three-method dispatch as
    :func:`streamfunction_from_vorticity`.

    Parameters
    ----------
    pv : Float[Array, "... Ny Nx"]
        Potential-vorticity field.  Leading dimensions are batched.
    dx, dy : float
        Grid spacings.
    lambda_ : float or Float[Array, " nl"]
        Helmholtz parameter(s).  Scalar for a single layer; array of
        shape ``(nl,)`` for multi-layer inversion.
    bc : {"dst", "dct", "fft"}
        Boundary-condition type (for ``method="spectral"``).
    method : {"spectral", "cg", "capacitance"}
        Solver method.  Default: ``"spectral"``.
    mask : Float[Array, "Ny Nx"] or ArakawaCGridMask or None
        Domain mask.  Required for ``method="cg"``.
    capacitance_solver : CapacitanceSolver or None
        Pre-built capacitance solver.  Required for
        ``method="capacitance"``.
    preconditioner : callable or None
        Custom preconditioner for ``method="cg"``.

    Returns
    -------
    Float[Array, "... Ny Nx"]
        Streamfunction ψ, same shape as *pv*.
    """
    lam = jnp.asarray(lambda_)

    if lam.ndim == 0:
        # Scalar lambda: vmap over all leading dims if present
        if pv.ndim == 2:
            return _solve_dispatch(
                pv,
                dx,
                dy,
                float(lam),
                bc,
                method,
                mask,
                capacitance_solver,
                preconditioner,
            )
        # Flatten leading dims, solve each, reshape
        shape = pv.shape
        flat = pv.reshape(-1, shape[-2], shape[-1])

        def _solve_one(rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
            return _solve_dispatch(
                rhs,
                dx,
                dy,
                float(lam),
                bc,
                method,
                mask,
                capacitance_solver,
                preconditioner,
            )

        out = eqx.filter_vmap(_solve_one)(flat)
        return out.reshape(shape)

    # Array lambda: leading dim must match lam.shape[0]
    if pv.ndim < 3:
        raise ValueError(
            f"pv must have at least 3 dims when lambda_ is an array, "
            f"got shape {pv.shape}"
        )
    nl = lam.shape[0]
    if pv.shape[-3] != nl:
        raise ValueError(
            f"pv.shape[-3]={pv.shape[-3]} does not match lambda_ length {nl}"
        )

    # Solve each layer with its own lambda.
    # We call the Helmholtz solver directly (not _solve_dispatch) because
    # lam_i is a JAX tracer inside vmap and Python-level ``if lam == 0``
    # branches in _spectral_solve would fail.
    if method == "capacitance":
        raise ValueError(
            "method='capacitance' does not support array-valued lambda_; "
            "solve each layer separately or use method='spectral' or 'cg' "
            "for multi-layer problems."
        )

    elif method == "cg":
        mask_arr = _resolve_mask_arr(mask)
        if mask_arr is None:
            raise ValueError("method='cg' requires a mask")

        _precond = preconditioner

        def _solve_layer(
            rhs: Float[Array, "Ny Nx"], lam_i: float
        ) -> Float[Array, "Ny Nx"]:
            def _matvec(x: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
                return masked_laplacian(x, mask_arr, dx, dy, lambda_=lam_i)

            pc = (
                _precond
                if _precond is not None
                else make_spectral_preconditioner(dx, dy, lambda_=lam_i, bc="fft")
            )
            x, _info = solve_cg(_matvec, rhs * mask_arr, preconditioner=pc)
            return x * mask_arr

    elif method == "spectral":
        _helmholtz = _HELMHOLTZ_DISPATCH.get(bc)
        if _helmholtz is None:
            raise ValueError(f"bc must be 'fft', 'dst', or 'dct'; got {bc!r}")

        def _solve_layer(
            rhs: Float[Array, "Ny Nx"], lam_i: float
        ) -> Float[Array, "Ny Nx"]:
            return _helmholtz(rhs, dx, dy, lam_i)

    else:
        raise ValueError(
            f"method must be 'spectral', 'cg', or 'capacitance'; got {method!r}"
        )

    # Flatten any leading batch dims: (..., nl, Ny, Nx) -> (batch, nl, Ny, Nx)
    shape = pv.shape
    ny, nx = shape[-2], shape[-1]
    pv_4d = pv.reshape(-1, nl, ny, nx)

    # vmap over layer axis (pairing each layer with its lambda)
    _solve_layers = eqx.filter_vmap(_solve_layer, in_axes=(0, 0))

    # vmap over the (flattened) batch axis
    out_4d = eqx.filter_vmap(lambda batch: _solve_layers(batch, lam))(pv_4d)
    return out_4d.reshape(shape)
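A two-layer QG model shows where an array-valued lambda_ can come from. Take layer thicknesses H1 and H2, reduced gravity g′, and Coriolis parameter f0. The layer PV is q_i = ∇²ψ_i − (Aψ)_i, with stretching matrix A = [[F1, −F1], [−F2, F2]] and F_i = f0²/(g′H_i). Diagonalizing A = PΛP⁻¹ turns the coupled layers into independent modes obeying (∇² − λ_m)ψ_m = q_m. The barotropic mode has λ = 0 and the baroclinic mode has λ = F1 + F2 = 1/Rd². The sketch below is a hypothetical helper under these textbook assumptions, not finitevolx.decompose_vertical_modes.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def two_layer_modes(f0, g_prime, h1, h2):
    """Hypothetical helper: eigen-decompose the two-layer QG stretching matrix."""
    f1 = f0**2 / (g_prime * h1)
    f2 = f0**2 / (g_prime * h2)
    a = jnp.array([[f1, -f1], [-f2, f2]])
    # A is not symmetric, so use the general eigensolver (jnp.linalg.eig runs on CPU only).
    evals, evecs = jnp.linalg.eig(a)
    order = jnp.argsort(evals.real)          # barotropic (lambda = 0) first
    lambdas = evals.real[order]
    p = evecs.real[:, order]
    return a, lambdas, p, jnp.linalg.inv(p)


def layers_to_modes(field, p_inv):
    # field: (2, Ny, Nx) layer values -> (2, Ny, Nx) modal values
    return jnp.einsum("ml,lyx->myx", p_inv, field)


a, lambdas, p, p_inv = two_layer_modes(1e-4, 0.02, 500.0, 1500.0)
```

The resulting lambdas has shape (2,), and the modal PV keeps its layer axis at position -3. That matches what pv_inversion requires for array-valued lambda_. Multiplying the modal streamfunctions by p maps them back to layers.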