Multigrid Helmholtz Solver

Geometric multigrid solver for variable-coefficient Helmholtz equations on masked Arakawa C-grids.

Solver

finitevolx.MultigridSolver

Bases: Module

Geometric multigrid V-cycle solver for the variable-coefficient Helmholtz equation::

div(c(x,y) grad u) - lambda u = rhs

Supports spatially varying coefficients c(x, y) on staggered faces and masked (irregular) domains.

V-Cycle Algorithm

The V-cycle is a recursive algorithm that visits coarser grids to correct low-frequency error that the smoother cannot resolve::

Level 0 (fine)    *---smooth---*-----------*---smooth---*
                       | restrict    prolong |
Level 1           .....*---smooth---*---smooth---*.....
                            | restrict  prolong |
Level 2 (coarse)  ..........*--bottom solve--*..........

At each level:

  1. Pre-smooth (nu_1 weighted Jacobi iterations): damp high-frequency error.
  2. Compute residual: r = rhs - A u.
  3. Restrict residual to the coarse grid (2x2 averaging).
  4. Recurse: solve for the error on the coarse grid.
  5. Prolongate the coarse correction back to the fine grid (bilinear interpolation).
  6. Post-smooth (nu_2 weighted Jacobi iterations): damp any high-frequency error introduced by the prolongation.

The recursion is statically unrolled at JAX trace time because each level has a different array shape. All integer parameters (n_levels, n_pre, etc.) are eqx.field(static=True), so the unrolled structure is visible to the XLA compiler.
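The division of labour between smoother and coarse grids can be seen in a small self-contained sketch (plain NumPy, 1-D Poisson; `weighted_jacobi` below is a hypothetical stand-in for the library's private smoother, not its actual implementation): a few weighted-Jacobi sweeps nearly eliminate a high-frequency error mode but barely touch a low-frequency one, which is exactly the error the coarser levels exist to remove.

```python
import numpy as np

def weighted_jacobi(u, rhs, h, omega, n_iter):
    """1-D Poisson smoother for A u = -u'' with Dirichlet ends.

    Jacobi update: u <- u + omega * D^{-1} (rhs - A u), with D = 2/h^2.
    Hypothetical standalone sketch, not the library's _weighted_jacobi.
    """
    for _ in range(n_iter):
        Au = np.zeros_like(u)
        Au[1:-1] = (2 * u[1:-1] - u[:-2] - u[2:]) / h**2
        u = u + omega * (h**2 / 2.0) * (rhs - Au)
        u[0] = u[-1] = 0.0  # Dirichlet boundaries
    return u

n = 64
h = 1.0 / n
x = np.linspace(0.0, 1.0, n + 1)
rhs = np.zeros(n + 1)            # solve A u = 0; the iterate IS the error
hi = np.sin(31 * np.pi * x)      # high-frequency error mode
lo = np.sin(np.pi * x)           # low-frequency error mode

hi_out = weighted_jacobi(hi.copy(), rhs, h, 0.9, 6)
lo_out = weighted_jacobi(lo.copy(), rhs, h, 0.9, 6)
```

Six sweeps shrink the high-frequency mode by several orders of magnitude while leaving the smooth mode almost unchanged; restricting that smooth remainder to a 2x coarser grid makes it oscillatory again, so the same cheap smoother can attack it there.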

Differentiation Modes

Three solve modes trade off backward-pass cost vs gradient accuracy:

  • __call__ — Implicit differentiation via jax.lax.custom_linear_solve(symmetric=True). The backward pass solves the adjoint system A^T v = dL/du with one multigrid call. Since A is symmetric, this costs the same as the forward pass. O(1) memory, exact gradients for the linear system being solved (gradient accuracy is limited by how well the V-cycles approximate A^{-1}, i.e. depends on n_cycles and smoother settings).

  • solve_onestep — One-step differentiation (Bolte, Pauwels & Vaiter, NeurIPS 2023). Runs K V-cycles, applies stop_gradient after K-1, then autodiffs through the last cycle only. O(1 V-cycle) memory, approximate gradients with error O(rho).

  • solve_unrolled — Unrolled differentiation via jax.lax.fori_loop. Backward replays all K iterations. O(K) memory, exact through-iteration gradients (reproduces the forward computation exactly, so gradient accuracy matches the forward solve accuracy).
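The trade-off can be checked on a toy problem (a NumPy sketch with illustrative names; a damped Richardson iteration stands in for the V-cycle): differentiating through enough iterations recovers the implicit-function-theorem gradient dL/d(rhs) = A^{-1} dL/du.

```python
import numpy as np

# Symmetric positive-definite 2x2 system A u = b, scalar loss L(u) = g @ u.
A = np.array([[2.0, 0.3], [0.3, 1.5]])
g = np.array([1.0, -2.0])
b = np.array([0.5, 1.0])

def iterate(b, K, omega=0.4):
    """Damped Richardson iteration, a stand-in for K V-cycles."""
    u = np.zeros(2)
    for _ in range(K):
        u = u + omega * (b - A @ u)
    return u

# Implicit (IFT) gradient: dL/db = A^{-T} g = A^{-1} g since A = A^T.
grad_implicit = np.linalg.solve(A, g)

# Unrolled gradient, obtained here by central finite differences
# through the converged iteration (K large enough that u ~ A^{-1} b).
K, eps = 200, 1e-6
grad_unrolled = np.array([
    (g @ iterate(b + eps * e, K) - g @ iterate(b - eps * e, K)) / (2 * eps)
    for e in np.eye(2)
])
```

Once the iteration has converged, the two gradients agree; the implicit mode gets there with a single extra solve instead of storing K iterates.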

Parameters:

Name Type Description Default
levels tuple of MultigridLevel

Precomputed level data, finest (index 0) to coarsest (index L-1).

required
n_levels int

Number of multigrid levels.

required
n_pre int

Pre-smoothing iterations per V-cycle (weighted Jacobi).

required
n_post int

Post-smoothing iterations per V-cycle (weighted Jacobi).

required
n_coarse int

Jacobi iterations on the coarsest grid (bottom solver).

required
omega float

Jacobi relaxation weight (typically 0.8-0.95).

required
n_cycles int

Number of V-cycles per solve.

required
Source code in finitevolx/_src/solvers/multigrid.py
class MultigridSolver(eqx.Module):
    r"""Geometric multigrid V-cycle solver for the variable-coefficient
    Helmholtz equation::

        div(c(x,y) grad u) - lambda u = rhs

    Supports spatially varying coefficients ``c(x, y)`` on staggered faces
    and masked (irregular) domains.

    V-Cycle Algorithm
    -----------------
    The V-cycle is a recursive algorithm that visits coarser grids to
    correct low-frequency error that the smoother cannot resolve::

        Level 0 (fine)    *---smooth---*-----------*---smooth---*
                               | restrict    prolong |
        Level 1           .....*---smooth---*---smooth---*.....
                                    | restrict  prolong |
        Level 2 (coarse)  ..........*--bottom solve--*..........

    At each level:

    1. **Pre-smooth** (nu_1 weighted Jacobi iterations): damp
       high-frequency error.
    2. **Compute residual**: ``r = rhs - A u``.
    3. **Restrict** residual to the coarse grid (2x2 averaging).
    4. **Recurse**: solve for the error on the coarse grid.
    5. **Prolongate** the coarse correction back to the fine grid
       (bilinear interpolation).
    6. **Post-smooth** (nu_2 weighted Jacobi iterations): damp any
       high-frequency error introduced by the prolongation.

    The recursion is **statically unrolled** at JAX trace time because
    each level has a different array shape.  All integer parameters
    (``n_levels``, ``n_pre``, etc.) are ``eqx.field(static=True)``,
    so the unrolled structure is visible to the XLA compiler.

    Differentiation Modes
    ---------------------
    Three solve modes trade off backward-pass cost vs gradient accuracy:

    * ``__call__`` — **Implicit differentiation** via
      ``jax.lax.custom_linear_solve(symmetric=True)``.  The backward pass
      solves the adjoint system ``A^T v = dL/du`` with one multigrid call.
      Since ``A`` is symmetric, this costs the same as the forward pass.
      O(1) memory, exact gradients for the linear system being solved
      (gradient accuracy is limited by how well the V-cycles approximate
      ``A^{-1}``, i.e. depends on ``n_cycles`` and smoother settings).

    * ``solve_onestep`` — **One-step differentiation** (Bolte, Pauwels &
      Vaiter, NeurIPS 2023).  Runs K V-cycles, applies ``stop_gradient``
      after K-1, then autodiffs through the last cycle only.  O(1 V-cycle)
      memory, approximate gradients with error O(rho).

    * ``solve_unrolled`` — **Unrolled differentiation** via
      ``jax.lax.fori_loop``.  Backward replays all K iterations.
      O(K) memory, exact through-iteration gradients (reproduces the
      forward computation exactly, so gradient accuracy matches the
      forward solve accuracy).

    Parameters
    ----------
    levels : tuple of MultigridLevel
        Precomputed level data, finest (index 0) to coarsest (index L-1).
    n_levels : int
        Number of multigrid levels.
    n_pre, n_post : int
        Pre- and post-smoothing iterations (weighted Jacobi).
    n_coarse : int
        Jacobi iterations on the coarsest grid (bottom solver).
    omega : float
        Jacobi relaxation weight (typically 0.8-0.95).
    n_cycles : int
        Number of V-cycles per solve.
    """

    levels: tuple[MultigridLevel, ...]
    n_levels: int = eqx.field(static=True)
    n_pre: int = eqx.field(static=True, default=6)
    n_post: int = eqx.field(static=True, default=6)
    n_coarse: int = eqx.field(static=True, default=50)
    omega: float = eqx.field(static=True, default=0.95)
    n_cycles: int = eqx.field(static=True, default=5)

    # -- Public API ----------------------------------------------------------

    def __call__(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        r"""Solve ``A u = rhs`` with implicit differentiation.

        The forward pass runs K V-cycles (identical to :meth:`solve_unrolled`).

        The backward pass uses ``jax.lax.custom_linear_solve`` with
        ``symmetric=True`` to compute gradients via the implicit function
        theorem (IFT) rather than unrolling through V-cycle iterations.

        For a scalar loss ``L(u)``, the gradient w.r.t. the RHS is::

            dL/d(rhs) = A^{-T} dL/du = A^{-1} dL/du   (since A = A^T)

        This adjoint solve is just another multigrid call — so the
        backward pass costs the same as the forward pass, with O(1) extra
        memory (no iteration history stored).

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side of the linear system.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Approximate solution ``u``.
        """
        level = self.levels[0]

        # _matvec defines the linear operator A for custom_linear_solve
        def _matvec(u: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
            return _apply_operator(u, level)

        # _solve provides the forward solve; custom_linear_solve will also
        # call it for the backward (adjoint) pass since symmetric=True
        def _solve(
            _matvec_fn: Callable,
            b: Float[Array, "Ny Nx"],
        ) -> Float[Array, "Ny Nx"]:
            return self._run_vcycles(b)

        return jax.lax.custom_linear_solve(_matvec, rhs, solve=_solve, symmetric=True)

    def solve_onestep(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        r"""Solve with one-step differentiation (Bolte et al., NeurIPS 2023).

        Runs K V-cycles to convergence, but autodiff only sees the **last**
        cycle.  The first K-1 cycles are wrapped in ``jax.lax.stop_gradient``
        so they contribute no backward-pass cost.

        The forward result is identical to :meth:`solve_unrolled`.  The
        gradient approximation error is O(rho) where rho is the per-cycle
        convergence rate (typically 0.1-0.3 for multigrid).

        Gradient structure::

            u_0 = 0
            u_1 = V(u_0, rhs)
            ...
            u_{K-1} = V(u_{K-2}, rhs)      <-- stop_gradient here
            u_K     = V(u_{K-1}, rhs)       <-- autodiff traces only this

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Approximate solution.

        References
        ----------
        Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation
        of iterative algorithms." https://arxiv.org/abs/2305.13768
        """
        u = jnp.zeros_like(rhs)

        # Run K-1 cycles with stop_gradient: the forward pass converges
        # normally, but no gradient graph is built for these iterations
        if self.n_cycles > 1:

            def _body(_: int, u: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
                return self.v_cycle(u, rhs)

            u = jax.lax.fori_loop(0, self.n_cycles - 1, _body, u)
            u = jax.lax.stop_gradient(u)

        # Final V-cycle: JAX autodiff traces through this one only.
        # The gradient cost is O(1 V-cycle) regardless of n_cycles.
        return self.v_cycle(u, rhs)

    def solve_unrolled(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        """Solve by unrolling all V-cycles through ``lax.fori_loop``.

        The backward pass differentiates through every iteration, storing
        intermediate states for replay.  This costs O(n_cycles) memory.

        Use this mode when you specifically need gradients through the
        iteration dynamics itself.  For most applications, prefer
        ``__call__`` (implicit differentiation) which gives exact gradients
        at O(1) memory cost.

        Parameters
        ----------
        rhs : Float[Array, "Ny Nx"]
            Right-hand side.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Approximate solution.
        """
        return self._run_vcycles(rhs)

    def v_cycle(
        self,
        u: Float[Array, "Ny Nx"],
        rhs: Float[Array, "Ny Nx"],
        level_idx: int = 0,
    ) -> Float[Array, "Ny Nx"]:
        """Execute a single multigrid V-cycle starting at *level_idx*.

        Algorithm::

            if coarsest level:
                return jacobi(u, rhs, n_coarse)   # bottom solve

            u = jacobi(u, rhs, n_pre)             # 1. pre-smooth
            r = rhs - A(u)                        # 2. compute residual
            r_c = restrict(r)                     # 3. restrict to coarse grid
            e_c = v_cycle(0, r_c, level+1)        # 4. recurse (solve A_c e_c = r_c)
            u = u + prolongate(e_c)               # 5. correct with coarse error
            u = jacobi(u, rhs, n_post)            # 6. post-smooth

        The recursion unrolls statically at JAX trace time because
        ``level_idx`` and ``n_levels`` are Python ints (static fields).

        Parameters
        ----------
        u : Float[Array, "Ny Nx"]
            Initial guess (typically zeros for the error equation on
            coarse grids, or the current iterate on the fine grid).
        rhs : Float[Array, "Ny Nx"]
            Right-hand side (original RHS on the fine grid, or the
            restricted residual on coarser grids).
        level_idx : int
            Current level (0 = finest, n_levels-1 = coarsest).

        Returns
        -------
        Float[Array, "Ny Nx"]
            Improved solution after one V-cycle.
        """
        level = self.levels[level_idx]

        # --- Coarsest level: iterated Jacobi as the bottom solver ---
        # The grid is small (typically 8x8 to 16x16), so many Jacobi
        # iterations are cheap and sufficient for convergence.
        if level_idx == self.n_levels - 1:
            return _weighted_jacobi(u, rhs, level, self.omega, self.n_coarse)

        # 1. Pre-smooth: damp high-frequency error on the current grid
        u = _weighted_jacobi(u, rhs, level, self.omega, self.n_pre)

        # 2. Compute residual: r = f - A(u)
        #    After smoothing, r contains mostly low-frequency error that
        #    the smoother cannot resolve at this grid resolution.
        r = (rhs - _apply_operator(u, level)) * level.mask

        # 3. Restrict residual to the coarse grid (2x coarser in each dim)
        coarse_level = self.levels[level_idx + 1]
        r_coarse = _restrict(r, level.mask, coarse_level.mask)

        # 4. Recurse: solve A_c * e_c = r_c on the coarse grid.
        #    Start from zero because we're solving for the *error*, not
        #    the solution itself.  On the coarse grid, the low-frequency
        #    residual becomes high-frequency and can be efficiently damped.
        e_coarse = self.v_cycle(jnp.zeros_like(r_coarse), r_coarse, level_idx + 1)

        # 5. Prolongate (interpolate) the coarse correction back to the
        #    fine grid and add to the current solution: u <- u + e_fine
        e_fine = _prolongate(e_coarse, coarse_level.mask, level.mask)
        u = (u + e_fine) * level.mask

        # 6. Post-smooth: clean up high-frequency error introduced by
        #    the coarse-to-fine interpolation
        u = _weighted_jacobi(u, rhs, level, self.omega, self.n_post)
        return u

    # -- Internal ------------------------------------------------------------

    def _run_vcycles(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        """Run *n_cycles* V-cycles from a zero initial guess.

        Uses ``jax.lax.fori_loop`` so that the iteration count does not
        increase the traced program size (unlike a Python for-loop, which
        would unroll each cycle into separate XLA operations).
        """

        def _body(_: int, u: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
            return self.v_cycle(u, rhs)

        u0 = jnp.zeros_like(rhs)
        return jax.lax.fori_loop(0, self.n_cycles, _body, u0)

__call__(rhs)

Solve A u = rhs with implicit differentiation.

The forward pass runs K V-cycles (identical to solve_unrolled).

The backward pass uses jax.lax.custom_linear_solve with symmetric=True to compute gradients via the implicit function theorem (IFT) rather than unrolling through V-cycle iterations.

For a scalar loss L(u), the gradient w.r.t. the RHS is::

dL/d(rhs) = A^{-T} dL/du = A^{-1} dL/du   (since A = A^T)

This adjoint solve is just another multigrid call — so the backward pass costs the same as the forward pass, with O(1) extra memory (no iteration history stored).
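The mechanism can be sketched with a toy diagonal operator (a hypothetical example, not the library's Helmholtz operator): with symmetric=True, jax.lax.custom_linear_solve reuses the supplied solve for the backward pass, so the gradient comes from one extra linear solve.

```python
import jax
import jax.numpy as jnp

# Toy diagonal SPD operator A x = d * x.  The exact inverse below stands
# in for the multigrid V-cycles; any sufficiently accurate solver works.
d = jnp.array([2.0, 4.0])

def matvec(x):
    return d * x

def solve(matvec_fn, b):
    return b / d  # exact inverse of the diagonal operator

def solution(b):
    return jax.lax.custom_linear_solve(matvec, b, solve=solve, symmetric=True)

b = jnp.array([1.0, 2.0])
u = solution(b)                  # A^{-1} b = [0.5, 0.5]

# Backward pass: dL/db = A^{-1} dL/du, computed by one extra call to `solve`.
grad_b = jax.grad(lambda b: jnp.sum(solution(b)))(b)   # A^{-1} @ ones
```

Because A is symmetric, no separate transpose_solve is needed; the same solver serves both directions.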

Parameters:

Name Type Description Default
rhs Float[Array, 'Ny Nx']

Right-hand side of the linear system.

required

Returns:

Type Description
Float[Array, 'Ny Nx']

Approximate solution u.

Source code in finitevolx/_src/solvers/multigrid.py
def __call__(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
    r"""Solve ``A u = rhs`` with implicit differentiation.

    The forward pass runs K V-cycles (identical to :meth:`solve_unrolled`).

    The backward pass uses ``jax.lax.custom_linear_solve`` with
    ``symmetric=True`` to compute gradients via the implicit function
    theorem (IFT) rather than unrolling through V-cycle iterations.

    For a scalar loss ``L(u)``, the gradient w.r.t. the RHS is::

        dL/d(rhs) = A^{-T} dL/du = A^{-1} dL/du   (since A = A^T)

    This adjoint solve is just another multigrid call — so the
    backward pass costs the same as the forward pass, with O(1) extra
    memory (no iteration history stored).

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side of the linear system.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Approximate solution ``u``.
    """
    level = self.levels[0]

    # _matvec defines the linear operator A for custom_linear_solve
    def _matvec(u: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        return _apply_operator(u, level)

    # _solve provides the forward solve; custom_linear_solve will also
    # call it for the backward (adjoint) pass since symmetric=True
    def _solve(
        _matvec_fn: Callable,
        b: Float[Array, "Ny Nx"],
    ) -> Float[Array, "Ny Nx"]:
        return self._run_vcycles(b)

    return jax.lax.custom_linear_solve(_matvec, rhs, solve=_solve, symmetric=True)

solve_onestep(rhs)

Solve with one-step differentiation (Bolte et al., NeurIPS 2023).

Runs K V-cycles to convergence, but autodiff only sees the last cycle. The first K-1 cycles are wrapped in jax.lax.stop_gradient so they contribute no backward-pass cost.

The forward result is identical to solve_unrolled. The gradient approximation error is O(rho) where rho is the per-cycle convergence rate (typically 0.1-0.3 for multigrid).

Gradient structure::

u_0 = 0
u_1 = V(u_0, rhs)
...
u_{K-1} = V(u_{K-2}, rhs)      <-- stop_gradient here
u_K     = V(u_{K-1}, rhs)       <-- autodiff traces only this
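The O(rho) error can be made concrete with a scalar model problem (an illustrative sketch, not library code): freezing u_{K-1} and differentiating only the last update gives a gradient whose error relative to the exact fixed-point gradient is exactly rho/a here.

```python
import numpy as np

# Scalar model problem: solve a*u = b by the damped fixed-point iteration
#   u <- u + omega*(b - a*u),   contraction factor rho = |1 - omega*a|.
a, b, omega, K = 2.0, 3.0, 0.45, 40
rho = abs(1.0 - omega * a)          # 0.1, a typical multigrid rate

u = 0.0
for _ in range(K):                  # forward: converges to u* = b/a = 1.5
    u = u + omega * (b - a * u)

# One-step gradient: freeze u_{K-1} (stop_gradient), differentiate only
#   u_K = u_{K-1} + omega*(b - a*u_{K-1})   =>   d u_K / d b = omega
grad_onestep = omega
grad_exact = 1.0 / a                # d u*/d b for the exact fixed point
err = abs(grad_onestep - grad_exact)   # equals rho / a, i.e. O(rho)
```

The faster the iteration converges (smaller rho), the more accurate the one-step gradient; for multigrid with rho ~ 0.1 the relative gradient error is of the same order.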

Parameters:

Name Type Description Default
rhs Float[Array, 'Ny Nx']

Right-hand side.

required

Returns:

Type Description
Float[Array, 'Ny Nx']

Approximate solution.

References

Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of iterative algorithms." https://arxiv.org/abs/2305.13768

Source code in finitevolx/_src/solvers/multigrid.py
def solve_onestep(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
    r"""Solve with one-step differentiation (Bolte et al., NeurIPS 2023).

    Runs K V-cycles to convergence, but autodiff only sees the **last**
    cycle.  The first K-1 cycles are wrapped in ``jax.lax.stop_gradient``
    so they contribute no backward-pass cost.

    The forward result is identical to :meth:`solve_unrolled`.  The
    gradient approximation error is O(rho) where rho is the per-cycle
    convergence rate (typically 0.1-0.3 for multigrid).

    Gradient structure::

        u_0 = 0
        u_1 = V(u_0, rhs)
        ...
        u_{K-1} = V(u_{K-2}, rhs)      <-- stop_gradient here
        u_K     = V(u_{K-1}, rhs)       <-- autodiff traces only this

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Approximate solution.

    References
    ----------
    Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation
    of iterative algorithms." https://arxiv.org/abs/2305.13768
    """
    u = jnp.zeros_like(rhs)

    # Run K-1 cycles with stop_gradient: the forward pass converges
    # normally, but no gradient graph is built for these iterations
    if self.n_cycles > 1:

        def _body(_: int, u: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
            return self.v_cycle(u, rhs)

        u = jax.lax.fori_loop(0, self.n_cycles - 1, _body, u)
        u = jax.lax.stop_gradient(u)

    # Final V-cycle: JAX autodiff traces through this one only.
    # The gradient cost is O(1 V-cycle) regardless of n_cycles.
    return self.v_cycle(u, rhs)

solve_unrolled(rhs)

Solve by unrolling all V-cycles through lax.fori_loop.

The backward pass differentiates through every iteration, storing intermediate states for replay. This costs O(n_cycles) memory.

Use this mode when you specifically need gradients through the iteration dynamics itself. For most applications, prefer __call__ (implicit differentiation) which gives exact gradients at O(1) memory cost.

Parameters:

Name Type Description Default
rhs Float[Array, 'Ny Nx']

Right-hand side.

required

Returns:

Type Description
Float[Array, 'Ny Nx']

Approximate solution.

Source code in finitevolx/_src/solvers/multigrid.py
def solve_unrolled(self, rhs: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
    """Solve by unrolling all V-cycles through ``lax.fori_loop``.

    The backward pass differentiates through every iteration, storing
    intermediate states for replay.  This costs O(n_cycles) memory.

    Use this mode when you specifically need gradients through the
    iteration dynamics itself.  For most applications, prefer
    ``__call__`` (implicit differentiation) which gives exact gradients
    at O(1) memory cost.

    Parameters
    ----------
    rhs : Float[Array, "Ny Nx"]
        Right-hand side.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Approximate solution.
    """
    return self._run_vcycles(rhs)

v_cycle(u, rhs, level_idx=0)

Execute a single multigrid V-cycle starting at level_idx.

Algorithm::

if coarsest level:
    return jacobi(u, rhs, n_coarse)   # bottom solve

u = jacobi(u, rhs, n_pre)             # 1. pre-smooth
r = rhs - A(u)                        # 2. compute residual
r_c = restrict(r)                     # 3. restrict to coarse grid
e_c = v_cycle(0, r_c, level+1)        # 4. recurse (solve A_c e_c = r_c)
u = u + prolongate(e_c)               # 5. correct with coarse error
u = jacobi(u, rhs, n_post)            # 6. post-smooth

The recursion unrolls statically at JAX trace time because level_idx and n_levels are Python ints (static fields).
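Steps 3 and 5 can be sketched in a few lines of NumPy (unmasked and piecewise-constant for brevity; the library's private transfer operators additionally handle masks, and its prolongation uses bilinear interpolation rather than the constant transfer shown here):

```python
import numpy as np

def restrict(r):
    """2x2 block averaging (fine -> coarse); masks omitted for simplicity."""
    ny, nx = r.shape
    return r.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))

def prolongate(e):
    """Piecewise-constant coarse -> fine transfer, a simplified stand-in
    for the library's bilinear prolongation."""
    return np.kron(e, np.ones((2, 2)))

r = np.arange(16.0).reshape(4, 4)   # a fine-grid residual
r_c = restrict(r)                   # shape (2, 2): each entry a 2x2 mean
e_f = prolongate(r_c)               # back to shape (4, 4)
```

Restriction halves each dimension; prolongation restores the fine shape so the coarse-grid error estimate can be added to the fine-grid iterate.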

Parameters:

Name Type Description Default
u Float[Array, 'Ny Nx']

Initial guess (typically zeros for the error equation on coarse grids, or the current iterate on the fine grid).

required
rhs Float[Array, 'Ny Nx']

Right-hand side (original RHS on the fine grid, or the restricted residual on coarser grids).

required
level_idx int

Current level (0 = finest, n_levels-1 = coarsest).

0

Returns:

Type Description
Float[Array, 'Ny Nx']

Improved solution after one V-cycle.

Source code in finitevolx/_src/solvers/multigrid.py
def v_cycle(
    self,
    u: Float[Array, "Ny Nx"],
    rhs: Float[Array, "Ny Nx"],
    level_idx: int = 0,
) -> Float[Array, "Ny Nx"]:
    """Execute a single multigrid V-cycle starting at *level_idx*.

    Algorithm::

        if coarsest level:
            return jacobi(u, rhs, n_coarse)   # bottom solve

        u = jacobi(u, rhs, n_pre)             # 1. pre-smooth
        r = rhs - A(u)                        # 2. compute residual
        r_c = restrict(r)                     # 3. restrict to coarse grid
        e_c = v_cycle(0, r_c, level+1)        # 4. recurse (solve A_c e_c = r_c)
        u = u + prolongate(e_c)               # 5. correct with coarse error
        u = jacobi(u, rhs, n_post)            # 6. post-smooth

    The recursion unrolls statically at JAX trace time because
    ``level_idx`` and ``n_levels`` are Python ints (static fields).

    Parameters
    ----------
    u : Float[Array, "Ny Nx"]
        Initial guess (typically zeros for the error equation on
        coarse grids, or the current iterate on the fine grid).
    rhs : Float[Array, "Ny Nx"]
        Right-hand side (original RHS on the fine grid, or the
        restricted residual on coarser grids).
    level_idx : int
        Current level (0 = finest, n_levels-1 = coarsest).

    Returns
    -------
    Float[Array, "Ny Nx"]
        Improved solution after one V-cycle.
    """
    level = self.levels[level_idx]

    # --- Coarsest level: iterated Jacobi as the bottom solver ---
    # The grid is small (typically 8x8 to 16x16), so many Jacobi
    # iterations are cheap and sufficient for convergence.
    if level_idx == self.n_levels - 1:
        return _weighted_jacobi(u, rhs, level, self.omega, self.n_coarse)

    # 1. Pre-smooth: damp high-frequency error on the current grid
    u = _weighted_jacobi(u, rhs, level, self.omega, self.n_pre)

    # 2. Compute residual: r = f - A(u)
    #    After smoothing, r contains mostly low-frequency error that
    #    the smoother cannot resolve at this grid resolution.
    r = (rhs - _apply_operator(u, level)) * level.mask

    # 3. Restrict residual to the coarse grid (2x coarser in each dim)
    coarse_level = self.levels[level_idx + 1]
    r_coarse = _restrict(r, level.mask, coarse_level.mask)

    # 4. Recurse: solve A_c * e_c = r_c on the coarse grid.
    #    Start from zero because we're solving for the *error*, not
    #    the solution itself.  On the coarse grid, the low-frequency
    #    residual becomes high-frequency and can be efficiently damped.
    e_coarse = self.v_cycle(jnp.zeros_like(r_coarse), r_coarse, level_idx + 1)

    # 5. Prolongate (interpolate) the coarse correction back to the
    #    fine grid and add to the current solution: u <- u + e_fine
    e_fine = _prolongate(e_coarse, coarse_level.mask, level.mask)
    u = (u + e_fine) * level.mask

    # 6. Post-smooth: clean up high-frequency error introduced by
    #    the coarse-to-fine interpolation
    u = _weighted_jacobi(u, rhs, level, self.omega, self.n_post)
    return u

Factory

finitevolx.build_multigrid_solver(mask, dx, dy, lambda_=0.0, coeff=None, n_levels=None, n_pre=6, n_post=6, n_coarse=50, omega=0.95, n_cycles=5)

Build a multigrid solver with precomputed level hierarchies.

This is an offline function (runs once on CPU with NumPy) that constructs the entire multigrid hierarchy:

  1. Mask coarsening: at each level, the cell mask is coarsened by 2x via 4-point averaging (threshold >= 0.5).
  2. Coefficient interpolation: the cell-centre coefficient c(x, y) is averaged to staggered face coefficients cx, cy at each level, then coarsened for the next level.
  3. Diagonal precomputation: the inverse diagonal D^{-1} of the Helmholtz operator is computed at each level for the Jacobi smoother.
  4. Grid spacing doubling: dx and dy double at each coarser level.

Grid hierarchy example (64x64, auto levels)::

Level 0:  64 x 64   (dx,    dy)     <- finest (solve here)
Level 1:  32 x 32   (2*dx,  2*dy)
Level 2:  16 x 16   (4*dx,  4*dy)
Level 3:   8 x  8   (8*dx,  8*dy)   <- coarsest (bottom solve)

The returned MultigridSolver is an immutable equinox.Module with frozen JAX arrays. All subsequent calls (forward solves, gradients, JIT compilation) use the precomputed hierarchy.
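The level auto-detection and divisibility rule can be sketched as follows (hypothetical helpers; the real logic lives in the private _compute_n_levels and the validation visible in the source below):

```python
def auto_n_levels(ny, nx, min_size=8):
    """Halve both dimensions until the next halving would drop below
    min_size; a sketch of the auto-detection described above."""
    n = 1
    while ny % 2 == 0 and nx % 2 == 0 and min(ny, nx) // 2 >= min_size:
        ny, nx = ny // 2, nx // 2
        n += 1
    return n

def check_divisible(ny, nx, n_levels):
    """Both dimensions must be divisible by 2**(n_levels - 1)."""
    factor = 2 ** (n_levels - 1)
    if ny % factor or nx % factor:
        raise ValueError(f"({ny}, {nx}) not divisible by {factor}")
    return factor

levels = auto_n_levels(64, 64)       # 64 -> 32 -> 16 -> 8: four levels
factor = check_divisible(64, 64, levels)
```

For the 64x64 example above this yields 4 levels with an 8x8 coarsest grid, matching the hierarchy table.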

Parameters:

Name Type Description Default
mask array, shape (Ny, Nx), or ArakawaCGridMask

Domain mask (1 = fluid, 0 = land). None is not accepted; pass np.ones((Ny, Nx)) for a rectangular domain. When an ArakawaCGridMask is passed, the psi staggering mask is extracted automatically.

required
dx float

Fine-grid spacings (metres or non-dimensional).

required
dy float

Fine-grid spacings (metres or non-dimensional).

required
lambda_ float

Helmholtz parameter (>= 0). Use 0.0 for pure Poisson (Laplacian only). For QG PV inversion, lambda_ = 1 / Rd**2.

0.0
coeff array, shape (Ny, Nx), or None

Spatially varying coefficient c(x, y) at cell centres. None -> constant coefficient = 1 everywhere (reduces to the standard constant-coefficient Helmholtz operator).

None
n_levels int or None

Number of multigrid levels. None -> auto-detect by halving until either dimension drops below 8. Both dimensions must be divisible by 2**(n_levels - 1); a ValueError is raised otherwise.

None
n_pre int

Number of pre-smoothing Jacobi iterations per V-cycle. More smoothing improves the convergence rate but increases cost per cycle. Default: 6.

6
n_post int

Number of post-smoothing Jacobi iterations per V-cycle. More smoothing improves the convergence rate but increases cost per cycle. Default: 6.

6
n_coarse int

Number of Jacobi iterations on the coarsest grid (bottom solver). The coarsest grid is small (typically 8x8), so this is cheap. Default: 50.

50
omega float

Jacobi relaxation weight (0 < omega < 1). Under-relaxation improves smoothing stability. Default: 0.95.

0.95
n_cycles int

Number of V-cycles applied per solve. 5 cycles typically reduce the residual by 3-5 orders of magnitude. Default: 5.

5

Returns:

Type Description
MultigridSolver

Ready-to-use solver (JIT-compilable equinox.Module).

Raises:

Type Description
ValueError

If grid dimensions are not divisible by 2**(n_levels - 1).

Source code in finitevolx/_src/solvers/multigrid.py
def build_multigrid_solver(
    mask: np.ndarray | Float[Array, "Ny Nx"] | ArakawaCGridMask,
    dx: float,
    dy: float,
    lambda_: float = 0.0,
    coeff: np.ndarray | Float[Array, "Ny Nx"] | None = None,
    n_levels: int | None = None,
    n_pre: int = 6,
    n_post: int = 6,
    n_coarse: int = 50,
    omega: float = 0.95,
    n_cycles: int = 5,
) -> MultigridSolver:
    r"""Build a multigrid solver with precomputed level hierarchies.

    This is an **offline** function (runs once on CPU with NumPy) that
    constructs the entire multigrid hierarchy:

    1. **Mask coarsening**: at each level, the cell mask is coarsened by
       2x via 4-point averaging (threshold >= 0.5).
    2. **Coefficient interpolation**: the cell-centre coefficient
       ``c(x, y)`` is averaged to staggered face coefficients ``cx``,
       ``cy`` at each level, then coarsened for the next level.
    3. **Diagonal precomputation**: the inverse diagonal ``D^{-1}`` of
       the Helmholtz operator is computed at each level for the Jacobi
       smoother.
    4. **Grid spacing doubling**: ``dx`` and ``dy`` double at each
       coarser level.

    Grid hierarchy example (64x64, auto levels)::

        Level 0:  64 x 64   (dx,    dy)     <- finest (solve here)
        Level 1:  32 x 32   (2*dx,  2*dy)
        Level 2:  16 x 16   (4*dx,  4*dy)
        Level 3:   8 x  8   (8*dx,  8*dy)   <- coarsest (bottom solve)

    The returned ``MultigridSolver`` is an immutable ``equinox.Module``
    with frozen JAX arrays.  All subsequent calls (forward solves,
    gradients, JIT compilation) use the precomputed hierarchy.

    Parameters
    ----------
    mask : array, shape (Ny, Nx), or ArakawaCGridMask
        Domain mask (1 = fluid, 0 = land).  ``None`` is *not* accepted;
        pass ``np.ones((Ny, Nx))`` for a rectangular domain.
        When an :class:`ArakawaCGridMask` is passed, the ``psi``
        staggering mask is extracted automatically.
    dx, dy : float
        Fine-grid spacings (metres or non-dimensional).
    lambda_ : float
        Helmholtz parameter (>= 0).  Use 0.0 for pure Poisson (Laplacian
        only).  For QG PV inversion, ``lambda_ = 1 / Rd**2``.
    coeff : array, shape (Ny, Nx), or None
        Spatially varying coefficient ``c(x, y)`` at cell centres.
        ``None`` -> constant coefficient = 1 everywhere (reduces to the
        standard constant-coefficient Helmholtz operator).
    n_levels : int or None
        Number of multigrid levels.  ``None`` -> auto-detect by halving
        until either dimension drops below 8.  Both dimensions must be
        divisible by ``2**(n_levels - 1)``; a ``ValueError`` is raised
        otherwise.
    n_pre, n_post : int
        Number of pre- and post-smoothing Jacobi iterations per V-cycle.
        More smoothing improves the convergence rate but increases cost
        per cycle.  Default: 6 each.
    n_coarse : int
        Number of Jacobi iterations on the coarsest grid (bottom solver).
        The coarsest grid is small (typically 8x8), so this is cheap.
        Default: 50.
    omega : float
        Jacobi relaxation weight (0 < omega < 1).  Under-relaxation
        improves smoothing stability.  Default: 0.95.
    n_cycles : int
        Number of V-cycles applied per solve.  5 cycles typically reduce
        the residual by 3-5 orders of magnitude.  Default: 5.

    Returns
    -------
    MultigridSolver
        Ready-to-use solver (JIT-compilable ``equinox.Module``).

    Raises
    ------
    ValueError
        If grid dimensions are not divisible by ``2**(n_levels - 1)``.
    """
    # --- Extract mask as a NumPy float64 array ---
    if isinstance(mask, ArakawaCGridMask):
        mask_np = np.asarray(mask.psi, dtype=np.float64)
    else:
        mask_np = np.asarray(mask, dtype=np.float64)

    ny, nx = mask_np.shape

    # --- Auto-detect or validate number of levels ---
    if n_levels is None:
        n_levels = _compute_n_levels(ny, nx)
    factor = 2 ** (n_levels - 1)
    if ny % factor != 0 or nx % factor != 0:
        raise ValueError(
            f"Grid shape ({ny}, {nx}) is not divisible by "
            f"2^(n_levels-1) = {factor}.  Choose a different n_levels or "
            f"pad the grid."
        )

    # --- Default coefficient: c(x,y) = 1 everywhere ---
    if coeff is None:
        coeff_np = np.ones_like(mask_np)
    else:
        coeff_np = np.asarray(coeff, dtype=np.float64)
    coeff_np = coeff_np * mask_np  # zero outside domain

    # --- Build level hierarchy (finest to coarsest) ---
    levels: list[MultigridLevel] = []
    cur_mask = mask_np
    cur_coeff = coeff_np
    cur_dx, cur_dy = float(dx), float(dy)

    for lev in range(n_levels):
        # Interpolate cell-centre coefficient to staggered face coefficients.
        # cx[j,i] = face coeff between (j,i) and (j,i+1), zero if either is land.
        # cy[j,i] = face coeff between (j,i) and (j+1,i), zero if either is land.
        cx, cy = _interpolate_coeff_to_faces(cur_coeff, cur_mask)

        # Precompute 1/diag(A) for the Jacobi smoother at this level
        inv_diag = _compute_inv_diagonal(cx, cy, cur_mask, cur_dx, cur_dy, lambda_)

        # Store as frozen JAX arrays
        levels.append(
            MultigridLevel(
                mask=jnp.array(cur_mask),
                cx=jnp.array(cx),
                cy=jnp.array(cy),
                dx=cur_dx,
                dy=cur_dy,
                lambda_=float(lambda_),
                inv_diagonal=jnp.array(inv_diag),
            )
        )

        # Coarsen mask and coefficient for the next (coarser) level
        if lev < n_levels - 1:
            next_mask = _restrict_mask(cur_mask)
            next_coeff = _restrict_coeff(cur_coeff, cur_mask, next_mask)
            cur_mask = next_mask
            cur_coeff = next_coeff
            cur_dx *= 2.0  # grid spacing doubles at each coarser level
            cur_dy *= 2.0

    return MultigridSolver(
        levels=tuple(levels),
        n_levels=n_levels,
        n_pre=n_pre,
        n_post=n_post,
        n_coarse=n_coarse,
        omega=omega,
        n_cycles=n_cycles,
    )

Preconditioner

finitevolx.make_multigrid_preconditioner(mg_solver)

Return a preconditioner closure that applies a single multigrid V-cycle.

The returned callable approximates A^{-1} r by running one V-cycle from a zero initial guess, which is sufficient as a preconditioner (it doesn't need to converge — it just needs to be a good approximation).

This is compatible with :func:`~finitevolx._src.solvers.iterative.solve_cg`: pass the returned closure as the preconditioner argument. CG then converges in very few iterations (typically 5-10 instead of hundreds) because multigrid captures both high- and low-frequency components of the inverse.

Parameters:

- ``mg_solver`` (MultigridSolver, required): a pre-built multigrid solver
  (from :func:`~finitevolx._src.solvers.multigrid.build_multigrid_solver`).

Returns:

- callable: ``preconditioner(r) -> approx_solution``, where ``r`` has shape
  ``(Ny, Nx)`` and the output has the same shape.
Examples:

>>> mg = build_multigrid_solver(mask, dx, dy, lambda_=10.0)
>>> precond = make_multigrid_preconditioner(mg)
>>> u, info = solve_cg(A, rhs, preconditioner=precond)
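The interplay between CG and the one-V-cycle preconditioner can be sketched with a generic preconditioned CG loop. The operator and preconditioner below are toy stand-ins (a diagonal system with its exact inverse), not finitevolx calls; in practice ``precond`` would be the closure returned by ``make_multigrid_preconditioner``:

```python
# Sketch of preconditioned CG: precond(r) need only approximate
# A^{-1} r, exactly the role one multigrid V-cycle plays.
import numpy as np

def pcg(apply_A, rhs, precond, tol=1e-10, max_iter=100):
    u = np.zeros_like(rhs)
    r = rhs - apply_A(u)
    z = precond(r)               # z ~ A^{-1} r
    p = z.copy()
    rz = np.vdot(r, z)
    for _ in range(max_iter):
        Ap = apply_A(p)
        alpha = rz / np.vdot(p, Ap)
        u += alpha * p
        r -= alpha * Ap
        if np.linalg.norm(r) < tol:
            break
        z = precond(r)
        rz_new = np.vdot(r, z)
        p = z + (rz_new / rz) * p
        rz = rz_new
    return u

# Toy SPD system: diagonal A, preconditioned by its exact inverse.
diag = np.array([4.0, 3.0, 2.0, 1.0])
u = pcg(lambda v: diag * v, np.ones(4), lambda r: r / diag)
```

The better ``precond`` approximates ``A^{-1}``, the fewer CG iterations are needed; with the exact inverse (as in this toy), CG converges in one step.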
Source code in finitevolx/_src/solvers/preconditioners.py
def make_multigrid_preconditioner(
    mg_solver: MultigridSolver,
) -> Callable[[Float[Array, "Ny Nx"]], Float[Array, "Ny Nx"]]:
    """Return a preconditioner closure that applies a single multigrid V-cycle.

    The returned callable approximates ``A^{-1} r`` by running one V-cycle
    from a zero initial guess, which is sufficient as a preconditioner
    (it doesn't need to converge — it just needs to be a good approximation).

    This is compatible with :func:`~finitevolx._src.solvers.iterative.solve_cg`:
    pass the returned closure as the ``preconditioner`` argument.  CG then
    converges in very few iterations (typically 5-10 instead of hundreds)
    because multigrid captures both high- and low-frequency components of
    the inverse.

    Parameters
    ----------
    mg_solver : MultigridSolver
        A pre-built multigrid solver
        (from :func:`~finitevolx._src.solvers.multigrid.build_multigrid_solver`).

    Returns
    -------
    callable
        ``preconditioner(r) -> approx_solution``, where ``r`` has shape
        ``(Ny, Nx)`` and the output has the same shape.

    Examples
    --------
    >>> mg = build_multigrid_solver(mask, dx, dy, lambda_=10.0)
    >>> precond = make_multigrid_preconditioner(mg)
    >>> u, info = solve_cg(A, rhs, preconditioner=precond)
    """

    def _preconditioner(r: Float[Array, "Ny Nx"]) -> Float[Array, "Ny Nx"]:
        return mg_solver.v_cycle(jnp.zeros_like(r), r)

    return _preconditioner
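The masked face interpolation used when building each level (``cx[j,i]`` between cells ``(j,i)`` and ``(j,i+1)``, zeroed if either neighbour is land) can be sketched as below. The arithmetic mean is an assumption for illustration; the library's ``_interpolate_coeff_to_faces`` may use a different average (e.g. harmonic):

```python
# Sketch of masked cell-centre -> face interpolation on a C-grid.
# A face coefficient is zeroed whenever either adjacent cell is land,
# so no flux crosses the land boundary.
import numpy as np

def faces_sketch(coeff, mask):
    # x-faces: between (j, i) and (j, i+1); shape (Ny, Nx-1)
    cx = 0.5 * (coeff[:, :-1] + coeff[:, 1:]) * mask[:, :-1] * mask[:, 1:]
    # y-faces: between (j, i) and (j+1, i); shape (Ny-1, Nx)
    cy = 0.5 * (coeff[:-1, :] + coeff[1:, :]) * mask[:-1, :] * mask[1:, :]
    return cx, cy

coeff = np.ones((4, 4))
mask = np.ones((4, 4))
mask[0, 0] = 0.0                  # one land cell in the corner
cx, cy = faces_sketch(coeff, mask)
# The two faces touching the land cell are zero; all others stay at 1.
```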