Skip to content

Time Integration

Pure functional integrators, diffrax-based solvers, and convenience wrappers for time-stepping PDEs on Arakawa C-grids.

Explicit Runge-Kutta (Pure Functional)

finitevolx.euler_step(state, rhs_fn, dt)

Forward Euler: y_{n+1} = y_n + dt * f(y_n).

Parameters:

Name Type Description Default
state PyTree

Current state (arbitrary JAX pytree).

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function f(state) -> tendency.

required
dt float

Timestep.

required

Returns:

Type Description
PyTree

Updated state after one Euler step.

Source code in finitevolx/_src/timestepping/explicit_rk.py
def euler_step(state: PyTree, rhs_fn: Callable[[PyTree], PyTree], dt: float) -> PyTree:
    """Advance ``state`` by a single forward-Euler step.

    The update ``y_{n+1} = y_n + dt * f(y_n)`` is applied leaf-by-leaf over
    the pytree.

    Parameters
    ----------
    state : PyTree
        State at the start of the step (any JAX pytree).
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``; its output must have the same tree
        structure as ``state``.
    dt : float
        Step size.

    Returns
    -------
    PyTree
        State at the end of the step.
    """

    def _advance(leaf, tendency):
        return leaf + dt * tendency

    return jax.tree.map(_advance, state, rhs_fn(state))

finitevolx.heun_step(state, rhs_fn, dt)

Heun (RK2) predictor-corrector.

.. math::

k_1 = f(y_n)
k_2 = f(y_n + dt \cdot k_1)
y_{n+1} = y_n + (dt/2)(k_1 + k_2)

Order 2, SSP with C = 1.

Parameters:

Name Type Description Default
state PyTree

Current state.

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function.

required
dt float

Timestep.

required

Returns:

Type Description
PyTree

Updated state after one Heun step.

Source code in finitevolx/_src/timestepping/explicit_rk.py
def heun_step(state: PyTree, rhs_fn: Callable[[PyTree], PyTree], dt: float) -> PyTree:
    """Advance ``state`` by one step of Heun's two-stage method (RK2).

    An Euler predictor is followed by a trapezoidal corrector:

    .. math::

        k_1 = f(y_n)
        k_2 = f(y_n + dt \\cdot k_1)
        y_{n+1} = y_n + (dt/2)(k_1 + k_2)

    Order 2, SSP with C = 1.

    Parameters
    ----------
    state : PyTree
        State at the start of the step.
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``.
    dt : float
        Step size.

    Returns
    -------
    PyTree
        State at the end of the step.
    """

    def _predict(leaf, slope):
        return leaf + dt * slope

    def _correct(leaf, slope_a, slope_b):
        return leaf + 0.5 * dt * (slope_a + slope_b)

    # Slope at the start of the step, then at the Euler-predicted end point.
    slope_start = rhs_fn(state)
    slope_end = rhs_fn(jax.tree.map(_predict, state, slope_start))
    return jax.tree.map(_correct, state, slope_start, slope_end)

finitevolx.rk3_ssp_step(state, rhs_fn, dt)

3rd-order Strong-Stability-Preserving Runge-Kutta (Shu-Osher form).

.. math::

y^{(1)} &= y_n + dt \cdot f(y_n)
y^{(2)} &= \tfrac{3}{4} y_n + \tfrac{1}{4} y^{(1)}
            + \tfrac{1}{4} dt \cdot f(y^{(1)})
y_{n+1} &= \tfrac{1}{3} y_n + \tfrac{2}{3} y^{(2)}
            + \tfrac{2}{3} dt \cdot f(y^{(2)})

Order 3, SSP coefficient C = 1 (optimal). Preserves monotonicity, positivity, and TVD properties.

Parameters:

Name Type Description Default
state PyTree

Current state.

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function.

required
dt float

Timestep.

required

Returns:

Type Description
PyTree

Updated state after one SSP-RK3 step.

Source code in finitevolx/_src/timestepping/explicit_rk.py
def rk3_ssp_step(
    state: PyTree, rhs_fn: Callable[[PyTree], PyTree], dt: float
) -> PyTree:
    """Advance ``state`` by one SSP-RK3 step written in Shu-Osher form.

    Each stage is a convex combination of the starting state and a forward
    Euler update of the previous stage:

    .. math::

        y^{(1)} &= y_n + dt \\cdot f(y_n)
        y^{(2)} &= \\tfrac{3}{4} y_n + \\tfrac{1}{4} y^{(1)}
                    + \\tfrac{1}{4} dt \\cdot f(y^{(1)})
        y_{n+1} &= \\tfrac{1}{3} y_n + \\tfrac{2}{3} y^{(2)}
                    + \\tfrac{2}{3} dt \\cdot f(y^{(2)})

    Order 3, SSP coefficient C = 1 (optimal).  Preserves monotonicity,
    positivity, and TVD properties.

    Parameters
    ----------
    state : PyTree
        State at the start of the step.
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``.
    dt : float
        Step size.

    Returns
    -------
    PyTree
        State at the end of the step.
    """

    def _stage_one(y_n, slope):
        return y_n + dt * slope

    def _stage_two(y_n, first, slope):
        return 0.75 * y_n + 0.25 * first + 0.25 * dt * slope

    def _final(y_n, second, slope):
        return (1.0 / 3.0) * y_n + (2.0 / 3.0) * second + (2.0 / 3.0) * dt * slope

    first_stage = jax.tree.map(_stage_one, state, rhs_fn(state))
    second_stage = jax.tree.map(_stage_two, state, first_stage, rhs_fn(first_stage))
    return jax.tree.map(_final, state, second_stage, rhs_fn(second_stage))

finitevolx.rk4_step(state, rhs_fn, dt)

Classic 4th-order Runge-Kutta.

.. math::

k_1 &= f(y_n)
k_2 &= f(y_n + (dt/2) k_1)
k_3 &= f(y_n + (dt/2) k_2)
k_4 &= f(y_n + dt \cdot k_3)
y_{n+1} &= y_n + (dt/6)(k_1 + 2 k_2 + 2 k_3 + k_4)

Order 4, 4 stages, not SSP.

Parameters:

Name Type Description Default
state PyTree

Current state.

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function.

required
dt float

Timestep.

required

Returns:

Type Description
PyTree

Updated state after one RK4 step.

Source code in finitevolx/_src/timestepping/explicit_rk.py
def rk4_step(state: PyTree, rhs_fn: Callable[[PyTree], PyTree], dt: float) -> PyTree:
    """Advance ``state`` by one step of the classical four-stage RK4 scheme.

    .. math::

        k_1 &= f(y_n)
        k_2 &= f(y_n + (dt/2) k_1)
        k_3 &= f(y_n + (dt/2) k_2)
        k_4 &= f(y_n + dt \\cdot k_3)
        y_{n+1} &= y_n + (dt/6)(k_1 + 2 k_2 + 2 k_3 + k_4)

    Order 4, 4 stages, not SSP.

    Parameters
    ----------
    state : PyTree
        State at the start of the step.
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``.
    dt : float
        Step size.

    Returns
    -------
    PyTree
        State at the end of the step.
    """

    def _offset(scale, slope):
        # Intermediate stage value: the starting state displaced along ``slope``.
        return jax.tree.map(lambda y, f: y + scale * f, state, slope)

    def _blend(y, f1, f2, f3, f4):
        return y + (dt / 6.0) * (f1 + 2.0 * f2 + 2.0 * f3 + f4)

    slope_1 = rhs_fn(state)
    slope_2 = rhs_fn(_offset(0.5 * dt, slope_1))
    slope_3 = rhs_fn(_offset(0.5 * dt, slope_2))
    slope_4 = rhs_fn(_offset(dt, slope_3))
    return jax.tree.map(_blend, state, slope_1, slope_2, slope_3, slope_4)

Multistep Methods (Pure Functional)

finitevolx.ab2_step(state, rhs_fn, dt, rhs_nm1)

2nd-order Adams-Bashforth.

.. math::

y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1})

Only one RHS evaluation per step (efficiency advantage over RK2). Requires one previous RHS evaluation rhs_nm1 = f(y_{n-1}).

Parameters:

Name Type Description Default
state PyTree

Current state y_n.

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function f(state) -> tendency.

required
dt float

Timestep.

required
rhs_nm1 PyTree

RHS evaluated at the previous step, f(y_{n-1}).

required

Returns:

Type Description
tuple[PyTree, PyTree, PyTree]

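(new_state, rhs_n, rhs_nm1) — rhs_n is the tendency f(y_n) evaluated in this call; pass it as rhs_nm1 in the next call. The returned rhs_nm1 is the input argument handed back unchanged (AB2 does not need it again; it is returned so the tuple has the same shape as ab3_step's).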

Source code in finitevolx/_src/timestepping/multistep.py
def ab2_step(
    state: PyTree, rhs_fn: Callable[[PyTree], PyTree], dt: float, rhs_nm1: PyTree
) -> tuple[PyTree, PyTree, PyTree]:
    """Advance ``state`` by one second-order Adams-Bashforth step.

    .. math::

        y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1})

    The tendency is evaluated once per step; the tendency from the previous
    step, ``rhs_nm1 = f(y_{n-1})``, must be supplied by the caller.

    Parameters
    ----------
    state : PyTree
        Current state y_n.
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``.
    dt : float
        Step size.
    rhs_nm1 : PyTree
        Tendency evaluated at the previous step, ``f(y_{n-1})``.

    Returns
    -------
    tuple[PyTree, PyTree, PyTree]
        ``(new_state, rhs_n, rhs_nm1)``.  ``rhs_n`` is the tendency computed
        in this call; ``rhs_nm1`` is the argument handed back unchanged.
        The caller shifts these by one level for the next call.
    """

    def _extrapolate(y, f_now, f_prev):
        return y + (dt / 2.0) * (3.0 * f_now - f_prev)

    tendency_now = rhs_fn(state)
    advanced = jax.tree.map(_extrapolate, state, tendency_now, rhs_nm1)
    return advanced, tendency_now, rhs_nm1

finitevolx.ab3_step(state, rhs_fn, dt, rhs_nm1, rhs_nm2)

3rd-order Adams-Bashforth.

.. math::

y_{n+1} = y_n + (dt/12)(23 f_n - 16 f_{n-1} + 5 f_{n-2})

One RHS evaluation per step. Requires two previous RHS evaluations.

Parameters:

Name Type Description Default
state PyTree

Current state y_n.

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function.

required
dt float

Timestep.

required
rhs_nm1 PyTree

RHS at step n-1.

required
rhs_nm2 PyTree

RHS at step n-2.

required

Returns:

Type Description
tuple[PyTree, PyTree, PyTree]

(new_state, rhs_n, rhs_nm1) — thread rhs_n as the new rhs_nm1 and the old rhs_nm1 as the new rhs_nm2 in subsequent calls.

Source code in finitevolx/_src/timestepping/multistep.py
def ab3_step(
    state: PyTree,
    rhs_fn: Callable[[PyTree], PyTree],
    dt: float,
    rhs_nm1: PyTree,
    rhs_nm2: PyTree,
) -> tuple[PyTree, PyTree, PyTree]:
    """Advance ``state`` by one third-order Adams-Bashforth step.

    .. math::

        y_{n+1} = y_n + (dt/12)(23 f_n - 16 f_{n-1} + 5 f_{n-2})

    The tendency is evaluated once per step; the two preceding tendencies
    must be supplied by the caller.

    Parameters
    ----------
    state : PyTree
        Current state y_n.
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``.
    dt : float
        Step size.
    rhs_nm1 : PyTree
        Tendency at step n-1.
    rhs_nm2 : PyTree
        Tendency at step n-2.

    Returns
    -------
    tuple[PyTree, PyTree, PyTree]
        ``(new_state, rhs_n, rhs_nm1)``.  On the next call pass ``rhs_n`` as
        ``rhs_nm1`` and the returned ``rhs_nm1`` as ``rhs_nm2``.
    """

    def _extrapolate(y, f_now, f_prev, f_prev2):
        return y + (dt / 12.0) * (23.0 * f_now - 16.0 * f_prev + 5.0 * f_prev2)

    tendency_now = rhs_fn(state)
    advanced = jax.tree.map(_extrapolate, state, tendency_now, rhs_nm1, rhs_nm2)
    return advanced, tendency_now, rhs_nm1

finitevolx.leapfrog_raf_step(state, state_nm1, rhs_fn, dt, alpha=0.05)

Leapfrog with Robert-Asselin filter.

.. math::

\tilde{y}_{n+1} &= y_{n-1} + 2 \Delta t \, f(y_n)
\bar{y}_n &= y_n + \alpha (y_{n-1} - 2 y_n + \tilde{y}_{n+1})

The filtered middle value \bar{y}_n damps the spurious computational mode inherent to the three-level leapfrog scheme.

Parameters:

Name Type Description Default
state PyTree

Current state y_n.

required
state_nm1 PyTree

Previous state y_{n-1}.

required
rhs_fn Callable[[PyTree], PyTree]

Right-hand-side function.

required
dt float

Timestep.

required
alpha float

Robert-Asselin filter coefficient (default 0.05). Typical range 0.01-0.1; larger values damp the computational mode more aggressively but introduce additional dissipation.

0.05

Returns:

Type Description
tuple[PyTree, PyTree]

(y_{n+1}, filtered_y_n) — use y_{n+1} as the new current state and filtered_y_n as the new state_nm1 at the next step.

Source code in finitevolx/_src/timestepping/multistep.py
def leapfrog_raf_step(
    state: PyTree,
    state_nm1: PyTree,
    rhs_fn: Callable[[PyTree], PyTree],
    dt: float,
    alpha: float = 0.05,
) -> tuple[PyTree, PyTree]:
    """Take one leapfrog step and Robert-Asselin-filter the middle time level.

    .. math::

        \\tilde{y}_{n+1} &= y_{n-1} + 2 \\Delta t \\, f(y_n)
        \\bar{y}_n &= y_n + \\alpha (y_{n-1} - 2 y_n + \\tilde{y}_{n+1})

    The filtered middle value damps the spurious computational mode inherent
    to the three-level leapfrog scheme.

    Parameters
    ----------
    state : PyTree
        Current state y_n.
    state_nm1 : PyTree
        Previous state y_{n-1}.
    rhs_fn : Callable[[PyTree], PyTree]
        Tendency function ``f(state)``.
    dt : float
        Step size.
    alpha : float, optional
        Robert-Asselin filter coefficient (default 0.05).  Typical range
        0.01-0.1; larger values damp the computational mode more aggressively
        but introduce additional dissipation.

    Returns
    -------
    tuple[PyTree, PyTree]
        ``(y_{n+1}, filtered_y_n)``.  For the next step use ``y_{n+1}`` as
        ``state`` and ``filtered_y_n`` as ``state_nm1``.
    """

    def _leap(previous, slope):
        return previous + 2.0 * dt * slope

    def _smooth(previous, current, following):
        # Nudge the centre level by alpha times the discrete second difference.
        return current + alpha * (previous - 2.0 * current + following)

    advanced = jax.tree.map(_leap, state_nm1, rhs_fn(state))
    smoothed = jax.tree.map(_smooth, state_nm1, state, advanced)
    return advanced, smoothed

IMEX (Pure Functional)

finitevolx.imex_ssp2_step(state, rhs_explicit, rhs_implicit, implicit_solve, dt)

IMEX-SSP2(2,2,2) time step.

The explicit part is SSP with C = 1; the implicit part is A-stable and L-stable (SDIRK with gamma = 1 - 1/sqrt(2)).

Algorithm (two stages)::

Stage 1:
    Y_1 = y_n + gamma * dt * f_I(Y_1)
    -> solved via implicit_solve(y_n, gamma * dt)

Stage 2:
    Y_2_star = y_n + dt * f_E(y_n)
    Y_2 = Y_2_star + gamma * dt * f_I(Y_2)
    -> solved via implicit_solve(Y_2_star, gamma * dt)

Update:
    y_{n+1} = y_n + (dt/2) * [f_E(y_n) + f_E(Y_2)]
                  + (dt/2) * [f_I(Y_1) + f_I(Y_2)]

Parameters:

Name Type Description Default
state PyTree

Current state y_n.

required
rhs_explicit Callable[[PyTree], PyTree]

Explicit (non-stiff) right-hand side, e.g. advection.

required
rhs_implicit Callable[[PyTree], PyTree]

Implicit (stiff) right-hand side, e.g. vertical diffusion.

required
implicit_solve Callable[[PyTree, float], PyTree]

Solves Y - gamma * dt * f_I(Y) = rhs for Y given (rhs, gamma * dt). For vertical diffusion this is typically a tridiagonal (TDMA) solve along columns.

required
dt float

Timestep.

required

Returns:

Type Description
PyTree

Updated state after one IMEX-SSP2 step.

Source code in finitevolx/_src/timestepping/imex.py
def imex_ssp2_step(
    state: PyTree,
    rhs_explicit: Callable[[PyTree], PyTree],
    rhs_implicit: Callable[[PyTree], PyTree],
    implicit_solve: Callable[[PyTree, float], PyTree],
    dt: float,
) -> PyTree:
    """IMEX-SSP2(2,2,2) time step.

    The explicit part is the two-stage SSP scheme with C = 1.  The implicit
    part is the two-stage SDIRK scheme with diagonal ``gamma`` and
    off-diagonal coefficient ``1 - 2*gamma``; with
    ``gamma = 1 - 1/sqrt(2)`` it is A-stable and L-stable.  ``gamma`` is read
    from the module constant ``_GAMMA`` (assumed to hold that value — defined
    outside this function).

    Algorithm (two stages)::

        Stage 1:
            Y_1 = y_n + gamma * dt * f_I(Y_1)
            -> solved via implicit_solve(y_n, gamma * dt)

        Stage 2:
            Y_2_star = y_n + dt * f_E(Y_1)
                           + (1 - 2*gamma) * dt * f_I(Y_1)
            Y_2 = Y_2_star + gamma * dt * f_I(Y_2)
            -> solved via implicit_solve(Y_2_star, gamma * dt)

        Update:
            y_{n+1} = y_n + (dt/2) * [f_E(Y_1) + f_E(Y_2)]
                          + (dt/2) * [f_I(Y_1) + f_I(Y_2)]

    When ``f_I`` is identically zero and ``implicit_solve`` returns its
    first argument, ``Y_1 = y_n`` and the step reduces to Heun's method.

    Parameters
    ----------
    state : PyTree
        Current state y_n.
    rhs_explicit : Callable[[PyTree], PyTree]
        Explicit (non-stiff) right-hand side, e.g. advection.
    rhs_implicit : Callable[[PyTree], PyTree]
        Implicit (stiff) right-hand side, e.g. vertical diffusion.
    implicit_solve : Callable[[PyTree, float], PyTree]
        Solves ``Y - gamma * dt * f_I(Y) = rhs`` for ``Y`` given ``(rhs,
        gamma * dt)``.  For vertical diffusion this is typically a tridiagonal
        (TDMA) solve along columns.
    dt : float
        Timestep.

    Returns
    -------
    PyTree
        Updated state after one IMEX-SSP2 step.
    """
    gamma_dt = _GAMMA * dt
    # Off-diagonal implicit coefficient a_21 = 1 - 2*gamma, pre-scaled by dt.
    # Dropping this term leaves an implicit part whose amplification factor
    # exceeds one in magnitude for stiff modes.
    coupling_dt = (1.0 - 2.0 * _GAMMA) * dt

    # Stage 1: purely implicit solve starting from y_n.
    y1 = implicit_solve(state, gamma_dt)
    fe_1 = rhs_explicit(y1)
    fi_1 = rhs_implicit(y1)

    # Stage 2: explicit predictor plus the stage-1 implicit coupling, followed
    # by the implicit correction.
    y2_star = jax.tree.map(
        lambda y, fe, fi: y + dt * fe + coupling_dt * fi,
        state,
        fe_1,
        fi_1,
    )
    y2 = implicit_solve(y2_star, gamma_dt)
    fe_2 = rhs_explicit(y2)
    fi_2 = rhs_implicit(y2)

    # Final update: equal-weight average of both stages for each part.
    return jax.tree.map(
        lambda y, fe1, fe2, fi1, fi2: (
            y + 0.5 * dt * (fe1 + fe2) + 0.5 * dt * (fi1 + fi2)
        ),
        state,
        fe_1,
        fe_2,
        fi_1,
        fi_2,
    )

Split-Explicit (Pure Functional)

finitevolx.split_explicit_step(state_3d, state_2d, rhs_3d, rhs_2d, couple_2d_to_3d, dt_slow, n_substeps)

Split-explicit barotropic/baroclinic time step.

Algorithm::

1. Subcycle 2D barotropic with n_substeps Forward-Euler steps
   (dt_fast = dt_slow / n_substeps), accumulating a time-average.
2. Couple the time-averaged 2D state into the slow RHS.
3. Advance the 3D baroclinic state with one Forward-Euler step
   using dt_slow.

Parameters:

Name Type Description Default
state_3d PyTree

3D baroclinic state (slow mode).

required
state_2d PyTree

2D barotropic state (fast mode).

required
rhs_3d Callable[[PyTree, PyTree], PyTree]

Slow RHS: rhs_3d(state_3d, state_2d_avg) -> tendency_3d.

required
rhs_2d Callable[[float, PyTree, PyTree], PyTree]

Fast RHS: rhs_2d(t_sub, state_2d, state_3d) -> tendency_2d. The first argument is the sub-step time offset from the beginning of the slow step.

required
couple_2d_to_3d Callable[[PyTree, PyTree], PyTree]

Coupling function: couple_2d_to_3d(state_3d, state_2d_avg) -> state_3d_corrected. Applied after the slow step to ensure consistency between the 2D and 3D solutions.

required
dt_slow float

Slow (baroclinic) timestep.

required
n_substeps int

Number of fast (barotropic) substeps per slow step.

required

Returns:

Type Description
tuple[PyTree, PyTree]

(new_state_3d, new_state_2d) after the split-explicit step.

Source code in finitevolx/_src/timestepping/split_explicit.py
def split_explicit_step(
    state_3d: PyTree,
    state_2d: PyTree,
    rhs_3d: Callable[[PyTree, PyTree], PyTree],
    rhs_2d: Callable[[float, PyTree, PyTree], PyTree],
    couple_2d_to_3d: Callable[[PyTree, PyTree], PyTree],
    dt_slow: float,
    n_substeps: int,
) -> tuple[PyTree, PyTree]:
    """Take one split-explicit step of a coupled fast-2D / slow-3D system.

    Algorithm::

        1. Advance the 2D state through ``n_substeps`` forward-Euler
           substeps of size ``dt_slow / n_substeps`` with the 3D state held
           fixed, summing the state reached after every substep.
        2. Divide that sum by ``n_substeps`` to get the fast-state average.
        3. Advance the 3D state with one forward-Euler step of size
           ``dt_slow`` using ``rhs_3d(state_3d, average)``.
        4. Pass the new 3D state and the average through
           ``couple_2d_to_3d``.

    Parameters
    ----------
    state_3d : PyTree
        3D baroclinic state (slow mode).
    state_2d : PyTree
        2D barotropic state (fast mode).
    rhs_3d : Callable[[PyTree, PyTree], PyTree]
        Slow RHS: ``rhs_3d(state_3d, state_2d_avg) -> tendency_3d``.
    rhs_2d : Callable[[float, PyTree, PyTree], PyTree]
        Fast RHS: ``rhs_2d(t_sub, state_2d, state_3d) -> tendency_2d``.
        ``t_sub`` is the substep index times the fast step size, i.e. the
        offset from the start of the slow step.
    couple_2d_to_3d : Callable[[PyTree, PyTree], PyTree]
        Coupling function: ``couple_2d_to_3d(state_3d, state_2d_avg) ->
        state_3d_corrected``, applied after the slow step.
    dt_slow : float
        Slow (baroclinic) timestep.
    n_substeps : int
        Number of fast (barotropic) substeps per slow step.

    Returns
    -------
    tuple[PyTree, PyTree]
        ``(new_state_3d, new_state_2d)``; the second entry is the 2D state
        after the final substep (not the average).

    Raises
    ------
    ValueError
        If ``n_substeps`` is less than 1.
    """
    if n_substeps < 1:
        raise ValueError(f"n_substeps must be >= 1, got {n_substeps}")

    fast_dt = dt_slow / n_substeps

    def _substep(carry, _unused):
        fast_state, running_sum, step_index = carry
        tendency = rhs_2d(step_index * fast_dt, fast_state, state_3d)
        advanced = jax.tree.map(lambda y, f: y + fast_dt * f, fast_state, tendency)
        running_sum = jax.tree.map(lambda s, y: s + y, running_sum, advanced)
        return (advanced, running_sum, step_index + 1), None

    # Carry = (fast state, sum of post-substep states, substep counter).
    initial_carry = (state_2d, jax.tree.map(jax.numpy.zeros_like, state_2d), 0)
    (fast_final, fast_sum, _), _ = jax.lax.scan(
        _substep,
        initial_carry,
        None,
        length=n_substeps,
    )

    fast_mean = jax.tree.map(lambda s: s / n_substeps, fast_sum)

    slow_tendency = rhs_3d(state_3d, fast_mean)
    slow_advanced = jax.tree.map(
        lambda y, f: y + dt_slow * f, state_3d, slow_tendency
    )

    return couple_2d_to_3d(slow_advanced, fast_mean), fast_final

Semi-Lagrangian (Pure Functional)

finitevolx.semi_lagrangian_step(field, u, v, dx, dy, dt, interp_order=1, bc='periodic')

Advect a 2D scalar field using semi-Lagrangian backtracking.

Algorithm::

1. Compute departure points: x_dep = x_i - u * dt,
                              y_dep = y_j - v * dt
2. Interpolate ``field`` at the departure points.
3. Return interpolated values as the new field.

Parameters:

Name Type Description Default
field Array[Ny, Nx]

Scalar field to advect.

required
u Array[Ny, Nx]

Velocity component in the x direction (second array axis) at the same grid points as field, in physical units (m/s).

required
v Array[Ny, Nx]

Velocity component in the y direction (first array axis) at the same grid points as field, in physical units (m/s).

required
dx float

Grid spacing in x (m).

required
dy float

Grid spacing in y (m).

required
dt float

Timestep (s).

required
interp_order int

Interpolation order passed to :func:jax.scipy.ndimage.map_coordinates. Currently JAX only supports order <= 1. Default 1.

1
bc str

Boundary handling: "periodic" (wrap) or "edge" (Neumann-like clamp). Default "periodic".

'periodic'

Returns:

Type Description
Array[Ny, Nx]

Advected field.

Source code in finitevolx/_src/timestepping/semi_lagrangian.py
def semi_lagrangian_step(
    field: jax.Array,
    u: jax.Array,
    v: jax.Array,
    dx: float,
    dy: float,
    dt: float,
    interp_order: int = 1,
    bc: str = "periodic",
) -> jax.Array:
    """Advect a 2D scalar field by semi-Lagrangian backtracking.

    Algorithm::

        1. Trace each grid point back along the velocity:
               x_dep = x_i - u * dt,  y_dep = y_j - v * dt
           (done in index space, i.e. displacements are divided by dx / dy).
        2. Interpolate ``field`` at the departure points.
        3. Return the interpolated values as the new field.

    Parameters
    ----------
    field : Array[Ny, Nx]
        Scalar field to advect.
    u, v : Array[Ny, Nx]
        Velocity components along x (second axis) and y (first axis) at the
        same grid points as ``field``, in **physical units** (m/s).
    dx, dy : float
        Grid spacing in x and y (m).
    dt : float
        Timestep (s).
    interp_order : int, optional
        Interpolation order passed to :func:`jax.scipy.ndimage.map_coordinates`.
        Only 0 and 1 are accepted.  Default 1.
    bc : str, optional
        Boundary handling: ``"periodic"`` (interpolation mode ``"wrap"``) or
        ``"edge"`` (interpolation mode ``"nearest"``).  Default
        ``"periodic"``.

    Returns
    -------
    Array[Ny, Nx]
        Advected field.

    Raises
    ------
    ValueError
        If ``bc`` or ``interp_order`` is not one of the accepted values.
    """
    if bc not in {"periodic", "edge"}:
        raise ValueError(f"bc must be 'periodic' or 'edge', got {bc!r}")
    if interp_order not in {0, 1}:
        raise ValueError(
            f"interp_order must be 0 or 1 (JAX limitation), got {interp_order}"
        )

    # Arrival points: the grid's own (row, column) indices.
    n_rows, n_cols = field.shape
    row_axis = jnp.arange(n_rows, dtype=field.dtype)
    col_axis = jnp.arange(n_cols, dtype=field.dtype)
    row_idx, col_idx = jnp.meshgrid(row_axis, col_axis, indexing="ij")

    # Departure points, ordered (row, column) to match the field's axes.
    departure = [row_idx - v * dt / dy, col_idx - u * dt / dx]

    boundary_mode = {"periodic": "wrap", "edge": "nearest"}[bc]
    return jax.scipy.ndimage.map_coordinates(
        field,
        departure,
        order=interp_order,
        mode=boundary_mode,
    )

Diffrax Solvers

finitevolx.ForwardEulerDfx

Bases: AbstractSolver

Forward Euler via the diffrax AbstractSolver interface.

Order 1, 1 stage. Included for completeness; prefer :class:diffrax.Euler for production use.

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class ForwardEulerDfx(dfx.AbstractSolver):
    """Forward Euler implemented against the diffrax ``AbstractSolver`` API.

    Order 1, 1 stage.  Included for completeness; prefer :class:`diffrax.Euler`
    for production use.
    """

    term_structure: ClassVar[Any] = dfx.AbstractTerm
    interpolation_cls: ClassVar[Any] = dfx.LocalLinearInterpolation

    def order(self, terms):
        """Return the order of accuracy reported to diffrax (1)."""
        return 1

    def init(self, terms, t0, t1, y0, args):
        """Return the initial solver state; this scheme keeps none."""
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        """Advance ``y0`` from ``t0`` to ``t1`` with one forward-Euler update.

        Returns the 5-tuple ``(y1, None, dense_info, None, RESULTS.successful)``
        where ``dense_info`` holds the step's end points under the keys
        ``"y0"`` and ``"y1"``.
        """
        del solver_state, made_jump  # neither is used by this scheme
        step_size = t1 - t0
        tendency = terms.vf(t0, y0, args)

        def _advance(leaf, slope):
            return leaf + step_size * slope

        y_end = jax.tree.map(_advance, y0, tendency)
        endpoints = {"y0": y0, "y1": y_end}
        return y_end, None, endpoints, None, RESULTS.successful

    def func(self, terms, t0, y0, args):
        """Evaluate the vector field of ``terms`` at ``(t0, y0)``."""
        return terms.vf(t0, y0, args)

finitevolx.RK2Heun

Bases: AbstractERK

Heun's method (RK2) via Butcher tableau.

Order 2, 2 stages, SSP with C = 1.

Butcher tableau::

0   |
1   | 1
----+--------
    | 1/2  1/2
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class RK2Heun(dfx.AbstractERK):
    """Heun's two-stage method (RK2) expressed as a diffrax Butcher tableau.

    Order 2, 2 stages, SSP with C = 1.

    Butcher tableau::

        0   |
        1   | 1
        ----+--------
            | 1/2  1/2
    """

    # ``c`` and ``a_lower`` list only the stages after the first (whose
    # abscissa is 0 and whose row is empty); the error weights are all zero.
    tableau: ClassVar[dfx.ButcherTableau] = dfx.ButcherTableau(
        a_lower=(np.ones(1),),
        b_sol=np.full(2, 0.5),
        b_error=np.zeros(2),
        c=np.ones(1),
    )
    interpolation_cls: ClassVar[Any] = (
        dfx.ThirdOrderHermitePolynomialInterpolation.from_k
    )

finitevolx.RK3SSP

Bases: AbstractERK

3rd-order Strong-Stability-Preserving Runge-Kutta.

Order 3, 3 stages, SSP coefficient C = 1 (optimal). Preserves monotonicity, positivity, and TVD properties.

Butcher tableau::

0   |
1   | 1
1/2 | 1/4  1/4
----+---------------
    | 1/6  1/6  2/3
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class RK3SSP(dfx.AbstractERK):
    """Three-stage, third-order SSP Runge-Kutta as a diffrax Butcher tableau.

    Order 3, 3 stages, SSP coefficient C = 1 (optimal).  Preserves
    monotonicity, positivity, and TVD properties.

    Butcher tableau::

        0   |
        1   | 1
        1/2 | 1/4  1/4
        ----+---------------
            | 1/6  1/6  2/3
    """

    # ``c`` and ``a_lower`` list only the stages after the first (whose
    # abscissa is 0 and whose row is empty); the error weights are all zero.
    tableau: ClassVar[dfx.ButcherTableau] = dfx.ButcherTableau(
        a_lower=(
            np.ones(1),
            np.full(2, 0.25),
        ),
        b_sol=np.array([1.0, 1.0, 4.0]) / 6.0,
        b_error=np.zeros(3),
        c=np.array([1.0, 0.5]),
    )
    interpolation_cls: ClassVar[Any] = (
        dfx.ThirdOrderHermitePolynomialInterpolation.from_k
    )

finitevolx.RK4Classic

Bases: AbstractERK

Classic 4th-order Runge-Kutta.

Order 4, 4 stages, not SSP.

Butcher tableau::

0   |
1/2 | 1/2
1/2 | 0    1/2
1   | 0    0    1
----+------------------
    | 1/6  1/3  1/3  1/6
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class RK4Classic(dfx.AbstractERK):
    """The classical four-stage RK4 scheme as a diffrax Butcher tableau.

    Order 4, 4 stages, not SSP.

    Butcher tableau::

        0   |
        1/2 | 1/2
        1/2 | 0    1/2
        1   | 0    0    1
        ----+------------------
            | 1/6  1/3  1/3  1/6
    """

    # ``c`` and ``a_lower`` list only the stages after the first (whose
    # abscissa is 0 and whose row is empty); the error weights are all zero.
    tableau: ClassVar[dfx.ButcherTableau] = dfx.ButcherTableau(
        a_lower=(
            np.array([0.5]),
            np.array([0.0, 0.5]),
            np.array([0.0, 0.0, 1.0]),
        ),
        b_sol=np.array([1.0, 2.0, 2.0, 1.0]) / 6.0,
        b_error=np.zeros(4),
        c=np.array([0.5, 0.5, 1.0]),
    )
    interpolation_cls: ClassVar[Any] = (
        dfx.ThirdOrderHermitePolynomialInterpolation.from_k
    )

finitevolx.SSP_RK2

Bases: AbstractERK

2nd-order SSP Runge-Kutta (same as Heun).

Order 2, 2 stages, SSP coefficient C = 1.

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class SSP_RK2(dfx.AbstractERK):
    """Two-stage, second-order SSP Runge-Kutta (the same tableau as Heun).

    Order 2, 2 stages, SSP coefficient C = 1.

    Butcher tableau::

        0   |
        1   | 1
        ----+--------
            | 1/2  1/2
    """

    # ``c`` and ``a_lower`` list only the stage after the first (whose
    # abscissa is 0 and whose row is empty); the error weights are all zero.
    tableau: ClassVar[dfx.ButcherTableau] = dfx.ButcherTableau(
        a_lower=(np.ones(1),),
        b_sol=np.full(2, 0.5),
        b_error=np.zeros(2),
        c=np.ones(1),
    )
    interpolation_cls: ClassVar[Any] = (
        dfx.ThirdOrderHermitePolynomialInterpolation.from_k
    )

finitevolx.SSP_RK104

Bases: AbstractERK

4th-order SSP Runge-Kutta with 10 stages (Ketcheson 2008).

Order 4, 10 stages, SSP coefficient C = 6. Highest SSP coefficient achievable at 4th-order accuracy.

Reference: Ketcheson (2008). Highly efficient strong stability-preserving Runge-Kutta methods with low-storage implementations.

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class SSP_RK104(dfx.AbstractERK):
    """Ten-stage, fourth-order SSP Runge-Kutta (Ketcheson 2008).

    Order 4, 10 stages, SSP coefficient C = 6.  Highest SSP coefficient
    achievable at 4th-order accuracy.

    Reference: Ketcheson (2008). Highly efficient strong stability-preserving
    Runge-Kutta methods with low-storage implementations.
    """

    tableau: ClassVar[dfx.ButcherTableau] = dfx.ButcherTableau(
        # Abscissae of stages 2..10 (the first stage, c = 0, is omitted).
        # Each entry equals the sum of the matching ``a_lower`` row:
        # 1/6, 1/3, 1/2, 2/3, 1/3, 1/2, 2/3, 5/6, 1.
        c=np.array([1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 6.0]) / 6.0,
        b_sol=np.full(10, 1.0 / 10),
        b_error=np.zeros(10),
        # Row r (r = 1..9) has r entries.
        #   Rows 1-4: every entry is 1/6.
        #   Rows 5-9: 1/15 in columns 0-4, then 1/6 in the later columns
        #   (convex combination at stage 6: 3/5*y0 + 2/5*y5 maps 1/6 -> 1/15).
        a_lower=tuple(
            np.array(
                [1.0 / 6] * row
                if row < 5
                else [1.0 / 15] * 5 + [1.0 / 6] * (row - 5)
            )
            for row in range(1, 10)
        ),
    )
    interpolation_cls: ClassVar[Any] = (
        dfx.ThirdOrderHermitePolynomialInterpolation.from_k
    )

finitevolx.IMEX_SSP2

Bases: AbstractRungeKutta, AbstractImplicitSolver

IMEX-SSP2(2,2,2) solver using diffrax MultiTerm.

Splits the ODE as dy/dt = f_E(t,y) + f_I(t,y). The explicit part is SSP with C = 1; the implicit part is A-stable (SDIRK with gamma = 1 - 1/sqrt(2)).

Usage::

explicit_term = dfx.ODETerm(advection_rhs)
implicit_term = dfx.ODETerm(diffusion_rhs)
terms = dfx.MultiTerm(explicit_term, implicit_term)
solver = IMEX_SSP2()
sol = dfx.diffeqsolve(terms, solver, ...)

Explicit tableau::

0     |
1     | 1
------+---------
      | 1/2  1/2

Implicit tableau (SDIRK, gamma = 1 - 1/sqrt(2))::

gamma     | gamma
1-gamma   | 1-2*gamma  gamma
----------+-------------------
          | 1/2        1/2
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class IMEX_SSP2(dfx.AbstractRungeKutta, dfx.AbstractImplicitSolver):
    """IMEX-SSP2(2,2,2) solver using diffrax MultiTerm.

    Splits the ODE as ``dy/dt = f_E(t,y) + f_I(t,y)``.  The explicit part
    is SSP with C = 1; the implicit part is A-stable (SDIRK with
    gamma = 1 - 1/sqrt(2)).

    Usage::

        explicit_term = dfx.ODETerm(advection_rhs)
        implicit_term = dfx.ODETerm(diffusion_rhs)
        terms = dfx.MultiTerm(explicit_term, implicit_term)
        solver = IMEX_SSP2()
        sol = dfx.diffeqsolve(terms, solver, ...)

    Explicit tableau::

        0     |
        1     | 1
        ------+---------
              | 1/2  1/2

    Implicit tableau (SDIRK, gamma = 1 - 1/sqrt(2))::

        gamma     | gamma
        1-gamma   | 1-2*gamma  gamma
        ----------+-------------------
                  | 1/2        1/2

    The second implicit abscissa is the row sum
    ``(1 - 2*gamma) + gamma = 1 - gamma``, matching ``c`` in the implicit
    ``ButcherTableau`` below.
    """

    # One Butcher tableau per term of the MultiTerm, in the same order:
    # explicit first, implicit second.
    tableau: ClassVar[dfx.MultiButcherTableau] = dfx.MultiButcherTableau(
        # Explicit part
        # ``c`` and ``a_lower`` omit the first stage; ``b_error`` is all
        # zeros, so no embedded error estimate is supplied.
        dfx.ButcherTableau(
            c=np.array([1.0]),
            b_sol=np.array([0.5, 0.5]),
            b_error=np.zeros(2),
            a_lower=(np.array([1.0]),),
        ),
        # Implicit part (SDIRK)
        # c1 = gamma (first stage), c[0] = a_lower[0][0] + a_diagonal[1]
        #    = (1 - 2*gamma) + gamma = 1 - gamma
        # NOTE(review): ``_GAMMA`` is defined elsewhere in the module;
        # presumably 1 - 1/sqrt(2) as stated in the docstring -- confirm.
        dfx.ButcherTableau(
            c=np.array([1.0 - _GAMMA]),
            b_sol=np.array([0.5, 0.5]),
            b_error=np.zeros(2),
            a_lower=(np.array([1.0 - 2.0 * _GAMMA]),),
            a_diagonal=np.array([_GAMMA, _GAMMA]),
            a_predictor=(np.array([1.0]),),
            c1=_GAMMA,
        ),
    )
    # Dense output is built from the stage derivatives ``k``.
    interpolation_cls: ClassVar[Any] = (
        dfx.ThirdOrderHermitePolynomialInterpolation.from_k
    )

    # Expects exactly two terms wrapped in a MultiTerm.
    term_structure: ClassVar[Any] = dfx.MultiTerm[
        tuple[dfx.AbstractTerm, dfx.AbstractTerm]
    ]
    calculate_jacobian: ClassVar[Any] = dfx.CalculateJacobian.first_stage

    # Root finder for the implicit stages, capped at ``root_find_max_steps``
    # iterations; its tolerances are taken from the step size controller.
    root_finder: Any = dfx.with_stepsize_controller_tols(dfx.VeryChord)()
    root_find_max_steps: int = 10

Manual Solver Interfaces

finitevolx.AB2Solver

Bases: Module

Adams-Bashforth 2nd-order solver (equinox Module).

Maintains f_prev as part of the solver state. Not compatible with diffrax.diffeqsolve; use the manual init / step interface.

Usage::

solver = AB2Solver()
solver, y = solver.init(rhs_fn, t0, y0, dt)
for n in range(1, n_steps):
    y, solver = solver.step(rhs_fn, t0 + n * dt, y, dt)
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class AB2Solver(eqx.Module):
    """Adams-Bashforth 2nd-order solver (equinox Module).

    Maintains ``f_prev`` as part of the solver state.  Not compatible with
    ``diffrax.diffeqsolve``; use the manual ``init`` / ``step`` interface.

    ``init`` already advances the state from ``t0`` to ``t0 + dt``, so the
    first call to ``step`` must be made at time ``t0 + dt``.

    Usage::

        solver = AB2Solver()
        solver, y = solver.init(rhs_fn, t0, y0, dt)  # y is at t0 + dt
        for n in range(1, n_steps):
            y, solver = solver.step(rhs_fn, t0 + n * dt, y, dt)
    """

    # RHS evaluated at the previous time level; ``None`` until ``init`` runs.
    f_prev: PyTree | None = None

    def init(self, rhs_fn: Callable, t0: float, y0, dt: float):
        """Bootstrap with an RK2 step, returning ``(updated_solver, y1)``.

        Stores ``f_prev = f(t0, y0)`` (i.e. the RHS at the start of the
        bootstrap step) so that the first AB2 step uses the correct history.

        Parameters
        ----------
        rhs_fn : Callable
            Right-hand side called as ``rhs_fn(t, y)``.
        t0 : float
            Time of the initial state.
        y0 : PyTree
            Initial state.
        dt : float
            Timestep.

        Returns
        -------
        tuple
            ``(updated_solver, y1)`` where ``y1`` is the state at ``t0 + dt``.
        """
        k1 = rhs_fn(t0, y0)
        k2 = rhs_fn(t0 + dt, jax.tree.map(lambda y, f: y + dt * f, y0, k1))
        y1 = jax.tree.map(lambda y, f1, f2: y + 0.5 * dt * (f1 + f2), y0, k1, k2)
        return eqx.tree_at(lambda s: s.f_prev, self, k1), y1

    def step(self, rhs_fn: Callable, t: float, y, dt: float):
        """AB2 step: ``y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1})``.

        Parameters
        ----------
        rhs_fn : Callable
            Right-hand side called as ``rhs_fn(t, y)``.
        t : float
            Time of the current state ``y``.
        y : PyTree
            Current state.
        dt : float
            Timestep.

        Returns
        -------
        tuple
            ``(y_next, new_solver)``; note the order is the reverse of
            ``init``.

        Raises
        ------
        RuntimeError
            If no history is stored, i.e. ``init`` has not been called.
        """
        # Without history the tree map below would fail with an opaque
        # structure/type error; fail early with an actionable message.
        if self.f_prev is None:
            raise RuntimeError(
                "AB2Solver.step() called before init(): f_prev is not set."
            )
        f_curr = rhs_fn(t, y)
        y_next = jax.tree.map(
            lambda yi, fi, fi_1: yi + (dt / 2.0) * (3.0 * fi - fi_1),
            y,
            f_curr,
            self.f_prev,
        )
        # The RHS just evaluated becomes the history for the next step.
        new_solver = eqx.tree_at(lambda s: s.f_prev, self, f_curr)
        return y_next, new_solver

init(rhs_fn, t0, y0, dt)

Bootstrap with an RK2 step, returning (updated_solver, y1).

Stores f_prev = f(t0, y0) (i.e. the RHS at the start of the bootstrap step) so that the first AB2 step uses the correct history.

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
def init(self, rhs_fn: Callable, t0: float, y0, dt: float):
    """Take one Heun (RK2) step to seed the AB2 history.

    The tendency at the initial state, ``f(t0, y0)``, is saved as
    ``f_prev`` so the first Adams-Bashforth step has the history it needs.

    Returns
    -------
    tuple
        ``(updated_solver, y1)`` with ``y1`` the state after one timestep.
    """
    f0 = rhs_fn(t0, y0)
    # Euler predictor, then the tendency at the predicted state.
    y_pred = jax.tree.map(lambda y, f: y + dt * f, y0, f0)
    f_pred = rhs_fn(t0 + dt, y_pred)
    # Trapezoidal corrector.
    y1 = jax.tree.map(
        lambda y, f1, f2: y + 0.5 * dt * (f1 + f2), y0, f0, f_pred
    )
    seeded = eqx.tree_at(lambda s: s.f_prev, self, f0)
    return seeded, y1

step(rhs_fn, t, y, dt)

AB2 step: y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1}).

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
def step(self, rhs_fn: Callable, t: float, y, dt: float):
    """AB2 step: ``y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1})``.

    Parameters
    ----------
    rhs_fn : Callable
        Right-hand side called as ``rhs_fn(t, y)``.
    t : float
        Time of the current state ``y``.
    y : PyTree
        Current state.
    dt : float
        Timestep.

    Returns
    -------
    tuple
        ``(y_next, new_solver)``.

    Raises
    ------
    RuntimeError
        If no history is stored, i.e. ``init`` has not been called.
    """
    # Without history the tree map below would fail with an opaque
    # structure/type error; fail early with an actionable message.
    if self.f_prev is None:
        raise RuntimeError(
            "AB2Solver.step() called before init(): f_prev is not set."
        )
    f_curr = rhs_fn(t, y)
    y_next = jax.tree.map(
        lambda yi, fi, fi_1: yi + (dt / 2.0) * (3.0 * fi - fi_1),
        y,
        f_curr,
        self.f_prev,
    )
    # The RHS just evaluated becomes the history for the next step.
    new_solver = eqx.tree_at(lambda s: s.f_prev, self, f_curr)
    return y_next, new_solver

finitevolx.LeapfrogRAFSolver

Bases: Module

Leapfrog with Robert-Asselin filter (equinox Module).

Three-level scheme: y_{n+1} = y_{n-1} + 2 dt f(y_n), with the RAF applied to the middle level to damp the computational mode.

Usage::

solver = LeapfrogRAFSolver(alpha=0.05)
solver, y1 = solver.init(rhs_fn, t0, y0, dt)
y_curr = y1
for n in range(1, n_steps):
    y_curr, solver = solver.step(rhs_fn, t0 + n * dt, y_curr, dt)
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class LeapfrogRAFSolver(eqx.Module):
    """Leapfrog with Robert-Asselin filter (equinox Module).

    Three-level scheme: ``y_{n+1} = y_{n-1} + 2 dt f(y_n)``, with the RAF
    applied to the middle level to damp the computational mode.

    Usage::

        solver = LeapfrogRAFSolver(alpha=0.05)
        solver, y1 = solver.init(rhs_fn, t0, y0, dt)
        y_curr = y1
        for n in range(1, n_steps):
            y_curr, solver = solver.step(rhs_fn, t0 + n * dt, y_curr, dt)
    """

    # Robert-Asselin filter coefficient.
    alpha: float = 0.05
    # (Filtered) state at the previous time level; ``None`` until ``init``.
    y_prev: PyTree | None = None

    def init(self, rhs_fn: Callable, t0: float, y0, dt: float):
        """Bootstrap with an RK2 step, returning ``(updated_solver, y1)``.

        The initial state ``y0`` is stored as ``y_prev`` so the first
        leapfrog step can reach back one level.
        """
        k1 = rhs_fn(t0, y0)
        k2 = rhs_fn(t0 + dt, jax.tree.map(lambda y, f: y + dt * f, y0, k1))
        y1 = jax.tree.map(lambda y, f1, f2: y + 0.5 * dt * (f1 + f2), y0, k1, k2)
        return eqx.tree_at(lambda s: s.y_prev, self, y0), y1

    def step(self, rhs_fn: Callable, t: float, y_curr, dt: float):
        """Leapfrog + RAF step.

        Parameters
        ----------
        rhs_fn : Callable
            Right-hand side called as ``rhs_fn(t, y)``.
        t : float
            Time of the current state ``y_curr``.
        y_curr : PyTree
            Current (middle-level) state.
        dt : float
            Timestep.

        Returns
        -------
        tuple
            ``(y_next, new_solver)``; ``new_solver.y_prev`` holds the
            filtered ``y_curr``.

        Raises
        ------
        RuntimeError
            If no previous level is stored, i.e. ``init`` has not been called.
        """
        # Without a previous level the tree maps below would fail with an
        # opaque structure/type error; fail early with an actionable message.
        if self.y_prev is None:
            raise RuntimeError(
                "LeapfrogRAFSolver.step() called before init(): "
                "y_prev is not set."
            )
        f_curr = rhs_fn(t, y_curr)

        # Leapfrog
        y_next = jax.tree.map(
            lambda yp, fc: yp + 2.0 * dt * fc,
            self.y_prev,
            f_curr,
        )

        # Robert-Asselin filter on the middle level
        y_curr_filtered = jax.tree.map(
            lambda yp, yc, yn: yc + self.alpha * (yp - 2.0 * yc + yn),
            self.y_prev,
            y_curr,
            y_next,
        )

        # The filtered middle level becomes the previous level next time.
        new_solver = eqx.tree_at(lambda s: s.y_prev, self, y_curr_filtered)
        return y_next, new_solver

init(rhs_fn, t0, y0, dt)

Bootstrap with an RK2 step, returning (updated_solver, y1).

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
def init(self, rhs_fn: Callable, t0: float, y0, dt: float):
    """Take one Heun (RK2) step to obtain the second time level.

    The initial state ``y0`` is kept as ``y_prev`` for the first leapfrog
    step.

    Returns
    -------
    tuple
        ``(updated_solver, y1)`` with ``y1`` the state after one timestep.
    """
    f0 = rhs_fn(t0, y0)
    # Euler predictor, then the tendency at the predicted state.
    y_pred = jax.tree.map(lambda y, f: y + dt * f, y0, f0)
    f_pred = rhs_fn(t0 + dt, y_pred)
    # Trapezoidal corrector.
    y1 = jax.tree.map(
        lambda y, f1, f2: y + 0.5 * dt * (f1 + f2), y0, f0, f_pred
    )
    seeded = eqx.tree_at(lambda s: s.y_prev, self, y0)
    return seeded, y1

step(rhs_fn, t, y_curr, dt)

Leapfrog + RAF step.

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
def step(self, rhs_fn: Callable, t: float, y_curr, dt: float):
    """Leapfrog + RAF step.

    Parameters
    ----------
    rhs_fn : Callable
        Right-hand side called as ``rhs_fn(t, y)``.
    t : float
        Time of the current state ``y_curr``.
    y_curr : PyTree
        Current (middle-level) state.
    dt : float
        Timestep.

    Returns
    -------
    tuple
        ``(y_next, new_solver)``; ``new_solver.y_prev`` holds the filtered
        ``y_curr``.

    Raises
    ------
    RuntimeError
        If no previous level is stored, i.e. ``init`` has not been called.
    """
    # Without a previous level the tree maps below would fail with an
    # opaque structure/type error; fail early with an actionable message.
    if self.y_prev is None:
        raise RuntimeError(
            "LeapfrogRAFSolver.step() called before init(): y_prev is not set."
        )
    f_curr = rhs_fn(t, y_curr)

    # Leapfrog
    y_next = jax.tree.map(
        lambda yp, fc: yp + 2.0 * dt * fc,
        self.y_prev,
        f_curr,
    )

    # Robert-Asselin filter on the middle level
    y_curr_filtered = jax.tree.map(
        lambda yp, yc, yn: yc + self.alpha * (yp - 2.0 * yc + yn),
        self.y_prev,
        y_curr,
        y_next,
    )

    # The filtered middle level becomes the previous level next time.
    new_solver = eqx.tree_at(lambda s: s.y_prev, self, y_curr_filtered)
    return y_next, new_solver

finitevolx.SplitExplicitRKSolver

Bases: Module

Split-explicit barotropic/baroclinic solver.

Uses Forward-Euler substeps for the fast (2D barotropic) mode and Forward-Euler for the slow (3D baroclinic) mode, with time-averaging of the barotropic solution.

Parameters:

Name Type Description Default
n_substeps int

Number of barotropic substeps per baroclinic step.

50
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class SplitExplicitRKSolver(eqx.Module):
    """Split-explicit solver for coupled fast (2D) and slow (3D) modes.

    The fast (barotropic) state is sub-cycled with Forward-Euler steps and
    averaged over the sub-cycle; the slow (baroclinic) state then takes a
    single Forward-Euler step driven by that average.

    Parameters
    ----------
    n_substeps : int
        Number of barotropic substeps per baroclinic step.
    """

    n_substeps: int = 50

    def step(
        self,
        rhs_slow: Callable,
        rhs_fast: Callable,
        t: float,
        y_3d,
        y_2d,
        dt_slow: float,
    ):
        """Advance both states by one slow timestep.

        Parameters
        ----------
        rhs_slow : Callable[[float, PyTree, PyTree], PyTree]
            Slow RHS: ``rhs_slow(t, y_3d, y_2d_avg) -> tendency_3d``.
        rhs_fast : Callable[[float, PyTree, PyTree], PyTree]
            Fast RHS: ``rhs_fast(t_sub, y_2d, y_3d) -> tendency_2d``.
        t : float
            Current time.
        y_3d : PyTree
            3D (slow) state.
        y_2d : PyTree
            2D (fast) state.
        dt_slow : float
            Slow timestep.

        Returns
        -------
        tuple[PyTree, PyTree]
            ``(new_y_3d, new_y_2d)``.
        """
        dt_fast = dt_slow / self.n_substeps

        def _advance_fast(carry, _unused):
            fast_state, running_sum, k = carry
            # The slow state is held fixed throughout the sub-cycle.
            tendency = rhs_fast(t + k * dt_fast, fast_state, y_3d)
            fast_next = jax.tree.map(
                lambda y, f: y + dt_fast * f, fast_state, tendency
            )
            # Accumulate the post-substep states for the time average.
            running_sum_next = jax.tree.map(
                lambda s, y: s + y, running_sum, fast_next
            )
            return (fast_next, running_sum_next, k + 1), None

        initial_carry = (y_2d, jax.tree.map(jnp.zeros_like, y_2d), 0)
        (y_2d_final, y_2d_total, _), _ = jax.lax.scan(
            _advance_fast,
            initial_carry,
            None,
            length=self.n_substeps,
        )

        # Mean of the fast state over the substeps.
        y_2d_mean = jax.tree.map(lambda s: s / self.n_substeps, y_2d_total)

        # Single Forward-Euler step for the slow state.
        slow_tendency = rhs_slow(t, y_3d, y_2d_mean)
        y_3d_next = jax.tree.map(
            lambda y, f: y + dt_slow * f, y_3d, slow_tendency
        )

        return y_3d_next, y_2d_final

step(rhs_slow, rhs_fast, t, y_3d, y_2d, dt_slow)

Take one split-explicit step.

Parameters:

Name Type Description Default
rhs_slow Callable[[float, PyTree, PyTree], PyTree]

Slow RHS: rhs_slow(t, y_3d, y_2d_avg) -> tendency_3d.

required
rhs_fast Callable[[float, PyTree, PyTree], PyTree]

Fast RHS: rhs_fast(t_sub, y_2d, y_3d) -> tendency_2d.

required
t float

Current time.

required
y_3d PyTree

3D (slow) state.

required
y_2d PyTree

2D (fast) state.

required
dt_slow float

Slow timestep.

required

Returns:

Type Description
tuple[PyTree, PyTree]

(new_y_3d, new_y_2d).

Source code in finitevolx/_src/timestepping/diffrax_solvers.py
def step(
    self,
    rhs_slow: Callable,
    rhs_fast: Callable,
    t: float,
    y_3d,
    y_2d,
    dt_slow: float,
):
    """Advance both states by one slow timestep.

    Parameters
    ----------
    rhs_slow : Callable[[float, PyTree, PyTree], PyTree]
        Slow RHS: ``rhs_slow(t, y_3d, y_2d_avg) -> tendency_3d``.
    rhs_fast : Callable[[float, PyTree, PyTree], PyTree]
        Fast RHS: ``rhs_fast(t_sub, y_2d, y_3d) -> tendency_2d``.
    t : float
        Current time.
    y_3d : PyTree
        3D (slow) state.
    y_2d : PyTree
        2D (fast) state.
    dt_slow : float
        Slow timestep.

    Returns
    -------
    tuple[PyTree, PyTree]
        ``(new_y_3d, new_y_2d)``.
    """
    dt_fast = dt_slow / self.n_substeps

    def _advance_fast(carry, _unused):
        fast_state, running_sum, k = carry
        # The slow state is held fixed throughout the sub-cycle.
        tendency = rhs_fast(t + k * dt_fast, fast_state, y_3d)
        fast_next = jax.tree.map(
            lambda y, f: y + dt_fast * f, fast_state, tendency
        )
        # Accumulate the post-substep states for the time average.
        running_sum_next = jax.tree.map(
            lambda s, y: s + y, running_sum, fast_next
        )
        return (fast_next, running_sum_next, k + 1), None

    initial_carry = (y_2d, jax.tree.map(jnp.zeros_like, y_2d), 0)
    (y_2d_final, y_2d_total, _), _ = jax.lax.scan(
        _advance_fast,
        initial_carry,
        None,
        length=self.n_substeps,
    )

    # Mean of the fast state over the substeps.
    y_2d_mean = jax.tree.map(lambda s: s / self.n_substeps, y_2d_total)

    # Single Forward-Euler step for the slow state.
    slow_tendency = rhs_slow(t, y_3d, y_2d_mean)
    y_3d_next = jax.tree.map(lambda y, f: y + dt_slow * f, y_3d, slow_tendency)

    return y_3d_next, y_2d_final

finitevolx.SemiLagrangianSolver

Bases: AbstractSolver

Semi-Lagrangian advection solver for diffrax.

Traces characteristic curves backward in time and interpolates the old field at departure points. Unconditionally stable (CFL > 1 allowed).

The terms.vf(t, y, args) must return (u, v) velocity components in grid index units per second (i.e. physical velocity divided by grid spacing).

Parameters:

Name Type Description Default
interpolation_order int

0 = nearest-neighbour, 1 = linear (diffusive, monotone). JAX currently only supports order <= 1.

1
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
class SemiLagrangianSolver(dfx.AbstractSolver):
    """Semi-Lagrangian advection solver for diffrax.

    Traces characteristic curves backward in time and interpolates the old
    field at departure points.  Unconditionally stable (CFL > 1 allowed).

    The ``terms.vf(t, y, args)`` must return ``(u, v)`` velocity components
    in **grid index units per second** (i.e. physical velocity divided by
    grid spacing).

    The state ``y`` must be a 2D array indexed ``[y, x]``; departure points
    that leave the grid are wrapped (``mode="wrap"``).

    Parameters
    ----------
    interpolation_order : int
        0 = nearest-neighbour, 1 = linear (diffusive, monotone).
        JAX currently only supports ``order <= 1``.
    """

    term_structure: ClassVar[Any] = dfx.AbstractTerm
    # ``step`` only reports ``y0``/``y1`` in ``dense_info`` (there are no
    # stage derivatives ``k``), so dense output must be linear in time.  A
    # Hermite ``from_k`` interpolant would be constructed without its
    # required ``k`` argument and fail when saving at intermediate times.
    interpolation_cls: ClassVar[Any] = dfx.LocalLinearInterpolation
    interpolation_order: int = 1

    def order(self, terms):
        """Report the spatial interpolation order as the solver order."""
        return self.interpolation_order

    def init(self, terms, t0, t1, y0, args):
        """Return the (empty) solver state; the scheme keeps no history."""
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        """Advect ``y0`` from ``t0`` to ``t1`` along back-trajectories.

        Returns the diffrax step tuple
        ``(y1, y_error, dense_info, solver_state, result)`` with no error
        estimate and no solver state.
        """
        del solver_state, made_jump
        dt = t1 - t0

        # Velocity in grid-index units/second
        u, v = terms.vf(t0, y0, args)

        # Arrival points are the grid nodes themselves, in index coordinates.
        ny, nx = y0.shape
        y_coords, x_coords = jnp.meshgrid(
            jnp.arange(ny, dtype=y0.dtype),
            jnp.arange(nx, dtype=y0.dtype),
            indexing="ij",
        )

        # Departure points: one straight backward step using the velocity
        # evaluated at ``t0``.
        x_dep = x_coords - u * dt
        y_dep = y_coords - v * dt

        # Sample the old field at the departure points.
        y1 = jax.scipy.ndimage.map_coordinates(
            y0, [y_dep, x_dep], order=self.interpolation_order, mode="wrap"
        )

        dense_info = dict(y0=y0, y1=y1)
        return y1, None, dense_info, None, RESULTS.successful

    def func(self, terms, t0, y0, args):
        """Evaluate the vector field, i.e. the ``(u, v)`` velocity pair."""
        return terms.vf(t0, y0, args)

Convenience Wrapper

finitevolx.solve_ocean_pde(rhs_fn, solver, y0, t0, t1, dt0, saveat=None, bc_fn=None, args=None)

Integrate an ocean PDE using diffrax.

Parameters:

Name Type Description Default
rhs_fn Callable

Right-hand side rhs_fn(t, y, args) -> dy/dt.

required
solver AbstractSolver

Time integration scheme (e.g. RK3SSP(), RK4Classic()).

required
y0 PyTree

Initial condition.

required
t0 float

Start time.

required
t1 float

End time.

required
dt0 float

Initial (or fixed) timestep.

required
saveat SaveAt

Output saving specification. Defaults to saving the final state.

None
bc_fn Callable

Boundary condition function bc_fn(dydt) -> dydt_corrected applied to the tendency after each RHS evaluation.

None
args PyTree

Static arguments forwarded to rhs_fn.

None

Returns:

Type Description
Solution

Solution object containing the saved states.

Source code in finitevolx/_src/timestepping/_solve.py
def solve_ocean_pde(
    rhs_fn: Callable,
    solver: dfx.AbstractSolver,
    y0: PyTree,
    t0: float,
    t1: float,
    dt0: float,
    saveat: dfx.SaveAt | None = None,
    bc_fn: Callable | None = None,
    args: PyTree = None,
) -> dfx.Solution:
    """Integrate an ocean PDE from ``t0`` to ``t1`` with diffrax.

    Thin wrapper around ``diffrax.diffeqsolve`` that builds the ``ODETerm``
    and optionally post-processes every tendency with a boundary-condition
    function.

    Parameters
    ----------
    rhs_fn : Callable
        Right-hand side ``rhs_fn(t, y, args) -> dy/dt``.
    solver : diffrax.AbstractSolver
        Time integration scheme (e.g. ``RK3SSP()``, ``RK4Classic()``).
    y0 : PyTree
        Initial condition.
    t0 : float
        Start time.
    t1 : float
        End time.
    dt0 : float
        Timestep passed to ``diffeqsolve`` as ``dt0``.
    saveat : diffrax.SaveAt, optional
        Output saving specification.  When omitted, only the state at
        ``t1`` is saved.
    bc_fn : Callable, optional
        Boundary condition function ``bc_fn(dydt) -> dydt_corrected``
        applied to the tendency after each RHS evaluation.
    args : PyTree, optional
        Extra arguments forwarded to ``rhs_fn``.

    Returns
    -------
    diffrax.Solution
        Solution object containing the saved states.
    """

    def _tendency_with_bc(t, y, args_):
        # Apply the boundary condition to the raw tendency.
        return bc_fn(rhs_fn(t, y, args_))

    vector_field = rhs_fn if bc_fn is None else _tendency_with_bc

    if saveat is None:
        saveat = dfx.SaveAt(t1=True)

    return dfx.diffeqsolve(
        dfx.ODETerm(vector_field),
        solver,
        t0=t0,
        t1=t1,
        dt0=dt0,
        y0=y0,
        args=args,
        saveat=saveat,
    )