Skip to content

Vorticity Operators

Relative vorticity and Jacobian operators on Arakawa C-grids.

finitevolx.Vorticity2D

Bases: Module

2-D vorticity and PV-flux operators.

Parameters:

Name Type Description Default
grid ArakawaCGrid2D
required
Source code in finitevolx/_src/operators/vorticity.py
class Vorticity2D(eqx.Module):
    """2-D vorticity and PV-flux operators.

    Holds a :class:`Difference2D` (used for the curl) and an
    :class:`Interpolation2D` (used to move fields between T-, U-, V- and
    X-points), both constructed from the same grid.

    Parameters
    ----------
    grid : ArakawaCGrid2D
        Grid on which the difference and interpolation helpers are built.
    """

    grid: ArakawaCGrid2D
    diff: Difference2D
    interp: Interpolation2D

    def __init__(self, grid: ArakawaCGrid2D) -> None:
        self.grid = grid
        self.diff = Difference2D(grid=grid)
        self.interp = Interpolation2D(grid=grid)

    def relative_vorticity(
        self,
        u: Float[Array, "Ny Nx"],
        v: Float[Array, "Ny Nx"],
    ) -> Float[Array, "Ny Nx"]:
        """Relative vorticity at X-points (corners).

        zeta[j+1/2, i+1/2] = dv_dx[j+1/2, i+1/2] - du_dy[j+1/2, i+1/2]
                            = (v[j+1/2, i+1] - v[j+1/2, i]) / dx
                            - (u[j+1, i+1/2] - u[j, i+1/2]) / dy

        Parameters
        ----------
        u : Float[Array, "Ny Nx"]
            x-velocity at U-points.
        v : Float[Array, "Ny Nx"]
            y-velocity at V-points.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Relative vorticity at X-points, as returned by ``self.diff.curl``.
        """
        return self.diff.curl(u, v)

    def potential_vorticity(
        self,
        u: Float[Array, "Ny Nx"],
        v: Float[Array, "Ny Nx"],
        h: Float[Array, "Ny Nx"],
        f: Float[Array, "Ny Nx"],
    ) -> Float[Array, "Ny Nx"]:
        """Potential vorticity at X-points (corners).

        q[j+1/2, i+1/2] = (zeta[j+1/2, i+1/2] + f_on_q[j+1/2, i+1/2])
                         / h_on_q[j+1/2, i+1/2]

        ``f`` and ``h`` are moved from T-points to X-points with
        ``self.interp.T_to_X``.

        Parameters
        ----------
        u : Float[Array, "Ny Nx"]
            x-velocity at U-points.
        v : Float[Array, "Ny Nx"]
            y-velocity at V-points.
        h : Float[Array, "Ny Nx"]
            Layer thickness at T-points.
        f : Float[Array, "Ny Nx"]
            Coriolis parameter at T-points.

        Returns
        -------
        Float[Array, "Ny Nx"]
            Potential vorticity at X-points.  Only the ``[1:-1, 1:-1]`` slice
            is computed, then wrapped with ``interior(..., h)``.  X-points
            whose interpolated thickness is exactly zero are set to NaN.
        """
        zeta = self.relative_vorticity(u, v)  # zeta at X-points
        f_on_q = self.interp.T_to_X(f)  # f interpolated to X-points
        h_on_q = self.interp.T_to_X(h)  # h interpolated to X-points
        # q[j+1/2, i+1/2] = (zeta + f) / h  at X-points
        num = zeta[1:-1, 1:-1] + f_on_q[1:-1, 1:-1]
        den = h_on_q[1:-1, 1:-1]
        zero_h = den == 0
        # "Double where": divide by a dummy 1 where den == 0 so the discarded
        # branch never evaluates x / 0.  Forward values are unchanged, but
        # jax.grad no longer produces 0 * inf = NaN cotangents that would
        # otherwise poison the gradients of the inputs.
        safe_den = jnp.where(zero_h, 1.0, den)
        return interior(jnp.where(zero_h, jnp.nan, num / safe_den), h)

    def pv_flux_energy_conserving(
        self,
        q: Float[Array, "Ny Nx"],
        u: Float[Array, "Ny Nx"],
        v: Float[Array, "Ny Nx"],
    ) -> tuple:
        """Energy-conserving PV flux.

        Interpolate q to each face first, then multiply by the velocity
        already living on that face.

        qu[j, i+1/2] = q_on_u[j, i+1/2] * u[j, i+1/2]
        qv[j+1/2, i] = q_on_v[j+1/2, i] * v[j+1/2, i]

        Parameters
        ----------
        q : Float[Array, "Ny Nx"]
            Potential vorticity at X-points.
        u : Float[Array, "Ny Nx"]
            x-velocity at U-points.
        v : Float[Array, "Ny Nx"]
            y-velocity at V-points.

        Returns
        -------
        tuple
            (qu at U-points, qv at V-points)
        """
        q_on_u = self.interp.X_to_U(q)  # q_on_u[j, i+1/2] = avg in y
        q_on_v = self.interp.X_to_V(q)  # q_on_v[j+1/2, i] = avg in x
        # qu[j, i+1/2] = q_on_u[j, i+1/2] * u[j, i+1/2]
        qu = interior(q_on_u[1:-1, 1:-1] * u[1:-1, 1:-1], u)
        # qv[j+1/2, i] = q_on_v[j+1/2, i] * v[j+1/2, i]
        qv = interior(q_on_v[1:-1, 1:-1] * v[1:-1, 1:-1], v)
        return qu, qv

    def pv_flux_enstrophy_conserving(
        self,
        q: Float[Array, "Ny Nx"],
        u: Float[Array, "Ny Nx"],
        v: Float[Array, "Ny Nx"],
    ) -> tuple:
        """Enstrophy-conserving PV flux.

        Interpolate each velocity component to X-points (corners), multiply
        by q there, then interpolate the products back to the faces.

        qu[j, i+1/2] = X_to_U(q * U_to_X(u))
        qv[j+1/2, i] = X_to_V(q * V_to_X(v))

        Parameters
        ----------
        q : Float[Array, "Ny Nx"]
            Potential vorticity at X-points.
        u : Float[Array, "Ny Nx"]
            x-velocity at U-points.
        v : Float[Array, "Ny Nx"]
            y-velocity at V-points.

        Returns
        -------
        tuple
            (qu at U-points, qv at V-points)
        """
        u_on_q = self.interp.U_to_X(u)  # u_on_q[j+1/2, i+1/2]
        v_on_q = self.interp.V_to_X(v)  # v_on_q[j+1/2, i+1/2]
        # Multiply at corners
        # qu_at_q[j+1/2, i+1/2] = q[j+1/2, i+1/2] * u_on_q[j+1/2, i+1/2]
        qu_at_q = interior(q[1:-1, 1:-1] * u_on_q[1:-1, 1:-1], q)
        # qv_at_q[j+1/2, i+1/2] = q[j+1/2, i+1/2] * v_on_q[j+1/2, i+1/2]
        qv_at_q = interior(q[1:-1, 1:-1] * v_on_q[1:-1, 1:-1], q)
        # Interpolate back to faces
        qu = self.interp.X_to_U(qu_at_q)  # qu[j, i+1/2]
        qv = self.interp.X_to_V(qv_at_q)  # qv[j+1/2, i]
        return qu, qv

    def pv_flux_arakawa_lamb(
        self,
        q: Float[Array, "Ny Nx"],
        u: Float[Array, "Ny Nx"],
        v: Float[Array, "Ny Nx"],
        alpha: float = 1.0 / 3.0,
    ) -> tuple:
        """Arakawa-Lamb PV flux: weighted blend of energy and enstrophy.

        flux = alpha * energy_conserving + (1 - alpha) * enstrophy_conserving

        The blend is linear and component-wise, applied to the outputs of
        :meth:`pv_flux_energy_conserving` and
        :meth:`pv_flux_enstrophy_conserving`.

        Parameters
        ----------
        q : Float[Array, "Ny Nx"]
            Potential vorticity at X-points.
        u : Float[Array, "Ny Nx"]
            x-velocity at U-points.
        v : Float[Array, "Ny Nx"]
            y-velocity at V-points.
        alpha : float
            Weight of the energy-conserving flux.  ``1.0`` gives the pure
            energy-conserving flux, ``0.0`` the pure enstrophy-conserving
            flux.  The default of 1/3 is the weighting this module uses for
            its Arakawa-Lamb scheme.

        Returns
        -------
        tuple
            (qu at U-points, qv at V-points)
        """
        qu_e, qv_e = self.pv_flux_energy_conserving(q, u, v)
        qu_s, qv_s = self.pv_flux_enstrophy_conserving(q, u, v)
        # Weighted blend
        qu = alpha * qu_e + (1.0 - alpha) * qu_s
        qv = alpha * qv_e + (1.0 - alpha) * qv_s
        return qu, qv

potential_vorticity(u, v, h, f)

Potential vorticity at X-points (corners).

q[j+1/2, i+1/2] = (zeta[j+1/2, i+1/2] + f_on_q[j+1/2, i+1/2]) / h_on_q[j+1/2, i+1/2]

Parameters:

Name Type Description Default
u Float[Array, 'Ny Nx']

x-velocity at U-points.

required
v Float[Array, 'Ny Nx']

y-velocity at V-points.

required
h Float[Array, 'Ny Nx']

Layer thickness at T-points.

required
f Float[Array, 'Ny Nx']

Coriolis parameter at T-points.

required

Returns:

Type Description
Float[Array, 'Ny Nx']

Potential vorticity at X-points.

Source code in finitevolx/_src/operators/vorticity.py
def potential_vorticity(
    self,
    u: Float[Array, "Ny Nx"],
    v: Float[Array, "Ny Nx"],
    h: Float[Array, "Ny Nx"],
    f: Float[Array, "Ny Nx"],
) -> Float[Array, "Ny Nx"]:
    """Potential vorticity at X-points (corners).

    q[j+1/2, i+1/2] = (zeta[j+1/2, i+1/2] + f_on_q[j+1/2, i+1/2])
                     / h_on_q[j+1/2, i+1/2]

    ``zeta`` comes from ``self.relative_vorticity``.  ``f`` and ``h`` are
    moved from T-points to X-points with ``self.interp.T_to_X``.

    Parameters
    ----------
    u : Float[Array, "Ny Nx"]
        x-velocity at U-points.
    v : Float[Array, "Ny Nx"]
        y-velocity at V-points.
    h : Float[Array, "Ny Nx"]
        Layer thickness at T-points.
    f : Float[Array, "Ny Nx"]
        Coriolis parameter at T-points.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Potential vorticity at X-points.  Only the ``[1:-1, 1:-1]`` slice is
        computed, then wrapped with ``interior(..., h)``.  X-points whose
        interpolated thickness is exactly zero are set to NaN.
    """
    zeta = self.relative_vorticity(u, v)  # zeta at X-points
    f_on_q = self.interp.T_to_X(f)  # f interpolated to X-points
    h_on_q = self.interp.T_to_X(h)  # h interpolated to X-points
    # q[j+1/2, i+1/2] = (zeta + f) / h  at X-points
    num = zeta[1:-1, 1:-1] + f_on_q[1:-1, 1:-1]
    den = h_on_q[1:-1, 1:-1]
    zero_h = den == 0
    # "Double where": divide by a dummy 1 where den == 0 so the discarded
    # branch never evaluates x / 0.  Forward values are unchanged, but
    # jax.grad no longer produces 0 * inf = NaN cotangents that would
    # otherwise poison the gradients of the inputs.
    safe_den = jnp.where(zero_h, 1.0, den)
    return interior(jnp.where(zero_h, jnp.nan, num / safe_den), h)

pv_flux_arakawa_lamb(q, u, v, alpha=1.0 / 3.0)

Arakawa-Lamb PV flux: weighted blend of energy and enstrophy.

flux = alpha * energy_conserving + (1 - alpha) * enstrophy_conserving

Parameters:

Name Type Description Default
q Float[Array, 'Ny Nx']

Potential vorticity at X-points.

required
u Float[Array, 'Ny Nx']

x-velocity at U-points.

required
v Float[Array, 'Ny Nx']

y-velocity at V-points.

required
alpha float

Blending weight. Default 1/3 gives Arakawa-Lamb scheme.

1.0 / 3.0

Returns:

Type Description
tuple

(qu at U-points, qv at V-points)

Source code in finitevolx/_src/operators/vorticity.py
def pv_flux_arakawa_lamb(
    self,
    q: Float[Array, "Ny Nx"],
    u: Float[Array, "Ny Nx"],
    v: Float[Array, "Ny Nx"],
    alpha: float = 1.0 / 3.0,
) -> tuple:
    """Blend the energy- and enstrophy-conserving PV fluxes (Arakawa-Lamb).

    Each component of the result is

        alpha * energy_conserving + (1 - alpha) * enstrophy_conserving

    where both fluxes come from the corresponding ``pv_flux_*`` methods.

    Parameters
    ----------
    q : Float[Array, "Ny Nx"]
        Potential vorticity at X-points.
    u : Float[Array, "Ny Nx"]
        x-velocity at U-points.
    v : Float[Array, "Ny Nx"]
        y-velocity at V-points.
    alpha : float
        Weight given to the energy-conserving flux (default 1/3, the
        Arakawa-Lamb weighting used by this module).

    Returns
    -------
    tuple
        (qu at U-points, qv at V-points)
    """
    qu_energy, qv_energy = self.pv_flux_energy_conserving(q, u, v)
    qu_enstrophy, qv_enstrophy = self.pv_flux_enstrophy_conserving(q, u, v)
    beta = 1.0 - alpha  # complementary weight for the enstrophy flux
    return (
        alpha * qu_energy + beta * qu_enstrophy,
        alpha * qv_energy + beta * qv_enstrophy,
    )

pv_flux_energy_conserving(q, u, v)

Energy-conserving PV flux.

Interpolate q and velocity independently to faces, then multiply.

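$$(qu)_{j,\,i+1/2} = q^{(u)}_{j,\,i+1/2}\,u_{j,\,i+1/2}, \qquad (qv)_{j+1/2,\,i} = q^{(v)}_{j+1/2,\,i}\,v_{j+1/2,\,i},$$

where $q^{(u)}$ (`q_on_u`) and $q^{(v)}$ (`q_on_v`) are q interpolated from X-points to U- and V-points, respectively.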

Parameters:

Name Type Description Default
q Float[Array, 'Ny Nx']

Potential vorticity at X-points.

required
u Float[Array, 'Ny Nx']

x-velocity at U-points.

required
v Float[Array, 'Ny Nx']

y-velocity at V-points.

required

Returns:

Type Description
tuple

(qu at U-points, qv at V-points)

Source code in finitevolx/_src/operators/vorticity.py
def pv_flux_energy_conserving(
    self,
    q: Float[Array, "Ny Nx"],
    u: Float[Array, "Ny Nx"],
    v: Float[Array, "Ny Nx"],
) -> tuple:
    """Energy-conserving PV flux (interpolate first, then multiply).

    q is averaged from X-points onto each face independently, and the
    face-averaged value multiplies the velocity already stored there:

        qu[j, i+1/2] = q_on_u[j, i+1/2] * u[j, i+1/2]
        qv[j+1/2, i] = q_on_v[j+1/2, i] * v[j+1/2, i]

    Parameters
    ----------
    q : Float[Array, "Ny Nx"]
        Potential vorticity at X-points.
    u : Float[Array, "Ny Nx"]
        x-velocity at U-points.
    v : Float[Array, "Ny Nx"]
        y-velocity at V-points.

    Returns
    -------
    tuple
        (qu at U-points, qv at V-points)
    """
    interp = self.interp
    inner = (slice(1, -1), slice(1, -1))  # same as [1:-1, 1:-1]

    q_u_face = interp.X_to_U(q)  # y-average of q onto U-points
    q_v_face = interp.X_to_V(q)  # x-average of q onto V-points

    flux_u = interior(q_u_face[inner] * u[inner], u)
    flux_v = interior(q_v_face[inner] * v[inner], v)
    return flux_u, flux_v

pv_flux_enstrophy_conserving(q, u, v)

Enstrophy-conserving PV flux.

Interpolate each velocity component to X-points (corners), multiply it by q there, then interpolate the products back to the faces.

$$(qu)_{j,\,i+1/2} = \mathrm{X\_to\_U}\big(q \cdot \mathrm{U\_to\_X}(u)\big), \qquad (qv)_{j+1/2,\,i} = \mathrm{X\_to\_V}\big(q \cdot \mathrm{V\_to\_X}(v)\big)$$

Parameters:

Name Type Description Default
q Float[Array, 'Ny Nx']

Potential vorticity at X-points.

required
u Float[Array, 'Ny Nx']

x-velocity at U-points.

required
v Float[Array, 'Ny Nx']

y-velocity at V-points.

required

Returns:

Type Description
tuple

(qu at U-points, qv at V-points)

Source code in finitevolx/_src/operators/vorticity.py
def pv_flux_enstrophy_conserving(
    self,
    q: Float[Array, "Ny Nx"],
    u: Float[Array, "Ny Nx"],
    v: Float[Array, "Ny Nx"],
) -> tuple:
    """Enstrophy-conserving PV flux (multiply at corners, then interpolate).

    Velocities are brought to the X-points where q lives, multiplied by q
    there, and the products are averaged back onto the faces:

        qu[j, i+1/2] = X_to_U(q * U_to_X(u))
        qv[j+1/2, i] = X_to_V(q * V_to_X(v))

    Parameters
    ----------
    q : Float[Array, "Ny Nx"]
        Potential vorticity at X-points.
    u : Float[Array, "Ny Nx"]
        x-velocity at U-points.
    v : Float[Array, "Ny Nx"]
        y-velocity at V-points.

    Returns
    -------
    tuple
        (qu at U-points, qv at V-points)
    """
    interp = self.interp
    inner = (slice(1, -1), slice(1, -1))  # same as [1:-1, 1:-1]

    # Velocity components averaged onto X-points.
    u_at_corner = interp.U_to_X(u)
    v_at_corner = interp.V_to_X(v)

    # Products formed at X-points, wrapped with interior(..., q).
    flux_u_corner = interior(q[inner] * u_at_corner[inner], q)
    flux_v_corner = interior(q[inner] * v_at_corner[inner], q)

    # Corner products averaged onto U- and V-faces.
    return interp.X_to_U(flux_u_corner), interp.X_to_V(flux_v_corner)

relative_vorticity(u, v)

Relative vorticity at X-points (corners).

zeta[j+1/2, i+1/2] = dv_dx[j+1/2, i+1/2] - du_dy[j+1/2, i+1/2] = (v[j+1/2, i+1] - v[j+1/2, i]) / dx - (u[j+1, i+1/2] - u[j, i+1/2]) / dy

Parameters:

Name Type Description Default
u Float[Array, 'Ny Nx']

x-velocity at U-points.

required
v Float[Array, 'Ny Nx']

y-velocity at V-points.

required

Returns:

Type Description
Float[Array, 'Ny Nx']

Relative vorticity at X-points.

Source code in finitevolx/_src/operators/vorticity.py
def relative_vorticity(
    self,
    u: Float[Array, "Ny Nx"],
    v: Float[Array, "Ny Nx"],
) -> Float[Array, "Ny Nx"]:
    """Relative vorticity on the corner (X) points of the C-grid.

    zeta[j+1/2, i+1/2] = (v[j+1/2, i+1] - v[j+1/2, i]) / dx
                        - (u[j+1, i+1/2] - u[j, i+1/2]) / dy

    Parameters
    ----------
    u : Float[Array, "Ny Nx"]
        x-velocity at U-points.
    v : Float[Array, "Ny Nx"]
        y-velocity at V-points.

    Returns
    -------
    Float[Array, "Ny Nx"]
        Relative vorticity at X-points, exactly as produced by the
        difference operator's ``curl``.
    """
    curl = self.diff.curl
    return curl(u, v)

finitevolx.Vorticity3D

Bases: Module

3-D vorticity operators (horizontal plane per z-level).

Parameters:

Name Type Description Default
grid ArakawaCGrid3D
required
Source code in finitevolx/_src/operators/vorticity.py
class Vorticity3D(eqx.Module):
    """3-D vorticity operators, applied independently on each z-level.

    Parameters
    ----------
    grid : ArakawaCGrid3D
        Grid whose horizontal spacings ``dx`` and ``dy`` are used on every
        level.
    """

    grid: ArakawaCGrid3D

    def relative_vorticity(
        self,
        u: Float[Array, "Nz Ny Nx"],
        v: Float[Array, "Nz Ny Nx"],
    ) -> Float[Array, "Nz Ny Nx"]:
        """Relative vorticity at X-points on every z-level.

        zeta[k, j+1/2, i+1/2] = (v[k, j+1/2, i+1] - v[k, j+1/2, i]) / dx
                               - (u[k, j+1, i+1/2] - u[k, j, i+1/2]) / dy

        Parameters
        ----------
        u : Float[Array, "Nz Ny Nx"]
            x-velocity at U-points.
        v : Float[Array, "Nz Ny Nx"]
            y-velocity at V-points.

        Returns
        -------
        Float[Array, "Nz Ny Nx"]
            Relative vorticity at X-points, with z-ghost slices zeroed.
        """
        spacing_x = self.grid.dx
        spacing_y = self.grid.dy

        def _curl_on_level(u_level, v_level):
            # 2-D curl of one horizontal slice.
            return _curl_2d(u_level, v_level, spacing_x, spacing_y)

        # filter_vmap maps the per-level curl over the leading axis.
        zeta = eqx.filter_vmap(_curl_on_level)(u, v)
        # Zero the z-ghost slices so the output follows the 3-D ghost-ring
        # convention.
        return zero_z_ghosts(zeta)

relative_vorticity(u, v)

Relative vorticity at X-points over all z-levels.

zeta[k, j+1/2, i+1/2] = (v[k, j+1/2, i+1] - v[k, j+1/2, i]) / dx - (u[k, j+1, i+1/2] - u[k, j, i+1/2]) / dy

Parameters:

Name Type Description Default
u Float[Array, 'Nz Ny Nx']

x-velocity at U-points.

required
v Float[Array, 'Nz Ny Nx']

y-velocity at V-points.

required

Returns:

Type Description
Float[Array, 'Nz Ny Nx']

Relative vorticity at X-points.

Source code in finitevolx/_src/operators/vorticity.py
def relative_vorticity(
    self,
    u: Float[Array, "Nz Ny Nx"],
    v: Float[Array, "Nz Ny Nx"],
) -> Float[Array, "Nz Ny Nx"]:
    """Relative vorticity at X-points on every z-level.

    zeta[k, j+1/2, i+1/2] = (v[k, j+1/2, i+1] - v[k, j+1/2, i]) / dx
                           - (u[k, j+1, i+1/2] - u[k, j, i+1/2]) / dy

    Parameters
    ----------
    u : Float[Array, "Nz Ny Nx"]
        x-velocity at U-points.
    v : Float[Array, "Nz Ny Nx"]
        y-velocity at V-points.

    Returns
    -------
    Float[Array, "Nz Ny Nx"]
        Relative vorticity at X-points, with z-ghost slices zeroed.
    """
    spacing_x = self.grid.dx
    spacing_y = self.grid.dy

    def _curl_on_level(u_level, v_level):
        # 2-D curl of one horizontal slice.
        return _curl_2d(u_level, v_level, spacing_x, spacing_y)

    # filter_vmap maps the per-level curl over the leading axis.
    zeta = eqx.filter_vmap(_curl_on_level)(u, v)
    # Zero the z-ghost slices so the output follows the 3-D ghost-ring
    # convention.
    return zero_z_ghosts(zeta)

Jacobian

finitevolx.arakawa_jacobian(f, g, dx, dy)

Arakawa (1966) discretization of J(f, g).

Computes the Jacobian J(f, g) = ∂f/∂x·∂g/∂y − ∂f/∂y·∂g/∂x using the energy- and enstrophy-conserving three-term Arakawa scheme on a collocated grid. The inputs must include a one-point boundary halo on each side so that the returned interior array has shape (..., Ny-2, Nx-2) (i.e. Ny_i = Ny - 2, Nx_i = Nx - 2).

Parameters:

Name Type Description Default
f Float[Array, '... Ny Nx']

First scalar field (including one halo cell on each side).

required
g Float[Array, '... Ny Nx']

Second scalar field (same shape as f).

required
dx float

Grid spacing in the x-direction (last array axis).

required
dy float

Grid spacing in the y-direction (second-to-last array axis).

required

Returns:

Type Description
Float[Array, '... Ny_i Nx_i']

Jacobian evaluated on the interior grid points, where Ny_i = Ny - 2 and Nx_i = Nx - 2. Boundary points are consumed by the stencil and are not included in the output.

Notes

The Arakawa scheme averages three discrete forms:

  • J⁺⁺ (standard centred Jacobian form):

 Jpp[j,i] = ( (f[j,i+1] - f[j,i-1]) * (g[j+1,i] - g[j-1,i])
            - (f[j+1,i] - f[j-1,i]) * (g[j,i+1] - g[j,i-1]) ) / (4 dx dy)
  • J⁺× (flux form, ∂x(f ∂g/∂y) − ∂y(f ∂g/∂x)):

 Jpx[j,i] = ( f[j,i+1] * (g[j+1,i+1] - g[j-1,i+1])
            - f[j,i-1] * (g[j+1,i-1] - g[j-1,i-1])
            - f[j+1,i] * (g[j+1,i+1] - g[j+1,i-1])
            + f[j-1,i] * (g[j-1,i+1] - g[j-1,i-1]) ) / (4 dx dy)
  • J×⁺ (flux form, ∂y(g ∂f/∂x) − ∂x(g ∂f/∂y)):

 Jxp[j,i] = ( g[j+1,i] * (f[j+1,i+1] - f[j+1,i-1])
            - g[j-1,i] * (f[j-1,i+1] - f[j-1,i-1])
            - g[j,i+1] * (f[j+1,i+1] - f[j-1,i+1])
            + g[j,i-1] * (f[j+1,i-1] - f[j-1,i-1]) ) / (4 dx dy)

Together: J = (Jpp + Jpx + Jxp) / 3

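Pointwise, this triple average satisfies J(f, f) = 0 and J(f, g) = −J(g, f). Under periodic (or otherwise suitable) boundary conditions, it also conserves energy (∫∫ f·J dA = 0), enstrophy (∫∫ g·J dA = 0), and the domain integral ∫∫ J(f, g) dA = 0 at the discrete level.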

The function is JAX-compatible and jit-able. Batch dimensions (...) are supported via standard broadcasting.

Examples:

>>> import jax.numpy as jnp
>>> from finitevolx import arakawa_jacobian
>>> Ny, Nx = 12, 10
>>> x = jnp.linspace(0, 1, Nx)
>>> y = jnp.linspace(0, 1, Ny)
>>> dx, dy = x[1] - x[0], y[1] - y[0]
>>> X, Y = jnp.meshgrid(x, y)
>>> J = arakawa_jacobian(X, Y, float(dx), float(dy))
>>> J.shape
(10, 8)
Source code in finitevolx/_src/operators/jacobian.py
def arakawa_jacobian(
    f: Float[Array, "... Ny Nx"],
    g: Float[Array, "... Ny Nx"],
    dx: float,
    dy: float,
) -> Float[Array, "... Ny_i Nx_i"]:
    """Arakawa (1966) discretization of J(f, g).

    Evaluates J(f, g) = ∂f/∂x·∂g/∂y − ∂f/∂y·∂g/∂x on a collocated grid as the
    average of three second-order stencils.  The inputs must carry a
    one-point halo on each side of the last two axes, so the output has
    shape ``(..., Ny-2, Nx-2)`` (i.e. ``Ny_i = Ny - 2``, ``Nx_i = Nx - 2``).

    Parameters
    ----------
    f : Float[Array, "... Ny Nx"]
        First scalar field (including one halo cell on each side).
    g : Float[Array, "... Ny Nx"]
        Second scalar field (same shape as *f*).
    dx : float
        Grid spacing in the x-direction (last array axis).
    dy : float
        Grid spacing in the y-direction (second-to-last array axis).

    Returns
    -------
    Float[Array, "... Ny_i Nx_i"]
        Jacobian on the interior points, where ``Ny_i = Ny - 2`` and
        ``Nx_i = Nx - 2``.  Halo points are consumed by the stencil and are
        not part of the output.

    Notes
    -----
    With E/W = i±1 and N/S = j±1, the three stencils are:

    * J⁺⁺ (centred Jacobian form):

      .. code-block:: text

         Jpp[j,i] = ( (f[j,i+1] - f[j,i-1]) * (g[j+1,i] - g[j-1,i])
                    - (f[j+1,i] - f[j-1,i]) * (g[j,i+1] - g[j,i-1]) ) / (4 dx dy)

    * J⁺× (flux form, ∂x(f ∂g/∂y) − ∂y(f ∂g/∂x)):

      .. code-block:: text

         Jpx[j,i] = ( f[j,i+1] * (g[j+1,i+1] - g[j-1,i+1])
                    - f[j,i-1] * (g[j+1,i-1] - g[j-1,i-1])
                    - f[j+1,i] * (g[j+1,i+1] - g[j+1,i-1])
                    + f[j-1,i] * (g[j-1,i+1] - g[j-1,i-1]) ) / (4 dx dy)

    * J×⁺ (flux form, ∂y(g ∂f/∂x) − ∂x(g ∂f/∂y)):

      .. code-block:: text

         Jxp[j,i] = ( g[j+1,i] * (f[j+1,i+1] - f[j+1,i-1])
                    - g[j-1,i] * (f[j-1,i+1] - f[j-1,i-1])
                    - g[j,i+1] * (f[j+1,i+1] - f[j-1,i+1])
                    + g[j,i-1] * (f[j+1,i-1] - f[j-1,i-1]) ) / (4 dx dy)

    Combined: ``J = (Jpp + Jpx + Jxp) / 3``.  The three shared ``1/4``
    factors and the ``1/3`` average fold into a single ``1 / (12 dx dy)``.

    Pointwise, the stencil satisfies J(f, f) = 0 and J(f, g) = −J(g, f).
    Following Arakawa (1966), the domain sums of f·J, g·J and J vanish
    (energy, enstrophy and mean conservation) under periodic or otherwise
    suitable boundary conditions.

    Only slicing and elementwise arithmetic are used, so the function is
    ``jit``-able and leading batch dimensions (``...``) pass straight
    through.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from finitevolx import arakawa_jacobian
    >>> Ny, Nx = 12, 10
    >>> x = jnp.linspace(0, 1, Nx)
    >>> y = jnp.linspace(0, 1, Ny)
    >>> dx, dy = x[1] - x[0], y[1] - y[0]
    >>> X, Y = jnp.meshgrid(x, y)
    >>> J = arakawa_jacobian(X, Y, float(dx), float(dy))
    >>> J.shape
    (10, 8)
    """

    def _neighbours(a):
        # Eight shifted views of the interior: (E, W, N, S, NE, NW, SE, SW),
        # where E/W step along the last axis and N/S along the second-to-last.
        return (
            a[..., 1:-1, 2:],  # E  (j,   i+1)
            a[..., 1:-1, :-2],  # W  (j,   i-1)
            a[..., 2:, 1:-1],  # N  (j+1, i)
            a[..., :-2, 1:-1],  # S  (j-1, i)
            a[..., 2:, 2:],  # NE (j+1, i+1)
            a[..., 2:, :-2],  # NW (j+1, i-1)
            a[..., :-2, 2:],  # SE (j-1, i+1)
            a[..., :-2, :-2],  # SW (j-1, i-1)
        )

    f_e, f_w, f_n, f_s, f_ne, f_nw, f_se, f_sw = _neighbours(f)
    g_e, g_w, g_n, g_s, g_ne, g_nw, g_se, g_sw = _neighbours(g)

    # Centred Jacobian form (scaled by 4 dx dy).
    jac_pp = (f_e - f_w) * (g_n - g_s) - (f_n - f_s) * (g_e - g_w)

    # Flux form: f times y-differences / x-differences of g.
    jac_px = (
        f_e * (g_ne - g_se)
        - f_w * (g_nw - g_sw)
        - f_n * (g_ne - g_nw)
        + f_s * (g_se - g_sw)
    )

    # Flux form: g times x-differences / y-differences of f.
    jac_xp = (
        g_n * (f_ne - f_nw)
        - g_s * (f_se - f_sw)
        - g_e * (f_ne - f_se)
        + g_w * (f_nw - f_sw)
    )

    return (jac_pp + jac_px + jac_xp) / (12.0 * dx * dy)