Time Integration
Pure functional integrators, diffrax-based solvers, and convenience wrappers for time-stepping PDEs on Arakawa C-grids.
Explicit Runge-Kutta (Pure Functional)
finitevolx.euler_step(state, rhs_fn, dt)
Forward Euler: y_{n+1} = y_n + dt * f(y_n).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state (arbitrary JAX pytree). | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function | required |
| dt | float | Timestep. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | Updated state after one Euler step. |
Source code in finitevolx/_src/timestepping/explicit_rk.py
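As a minimal sketch of the update above, the following standalone function applies one forward Euler step to a scalar state. It does not call finitevolx: `euler_step_sketch` and `decay_rhs` are illustrative names, and a plain Python float stands in for a JAX pytree.

```python
def euler_step_sketch(state, rhs_fn, dt):
    """One forward Euler step: y_{n+1} = y_n + dt * f(y_n)."""
    return state + dt * rhs_fn(state)


def decay_rhs(y):
    # Illustrative right-hand side for dy/dt = -y (an assumption for this sketch).
    return -y


# One step of size 0.1 starting from y = 1.
y1 = euler_step_sketch(1.0, decay_rhs, 0.1)
```

Because the step only reads `state` and returns a new value, it has the same pure functional shape as the documented signature.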
finitevolx.heun_step(state, rhs_fn, dt)
Heun (RK2) predictor-corrector.
.. math::
    k_1 &= f(y_n) \\
    k_2 &= f(y_n + dt \cdot k_1) \\
    y_{n+1} &= y_n + (dt/2)(k_1 + k_2)
Order 2, SSP with C = 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state. | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function. | required |
| dt | float | Timestep. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | Updated state after one Heun step. |
Source code in finitevolx/_src/timestepping/explicit_rk.py
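The predictor-corrector structure can be sketched on a scalar problem as follows. This is a standalone illustration rather than the finitevolx implementation; `heun_step_sketch` and the test problem dy/dt = -y are assumptions made for the example.

```python
def heun_step_sketch(state, rhs_fn, dt):
    """One Heun (RK2) step on a scalar state."""
    k1 = rhs_fn(state)                   # slope at the start of the step
    k2 = rhs_fn(state + dt * k1)         # slope at the forward Euler predictor
    return state + 0.5 * dt * (k1 + k2)  # corrector: average the two slopes


# dy/dt = -y from y = 1 with dt = 0.1 gives 1 - dt + dt**2 / 2.
y1 = heun_step_sketch(1.0, lambda y: -y, 0.1)
```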
finitevolx.rk3_ssp_step(state, rhs_fn, dt)
3rd-order Strong-Stability-Preserving Runge-Kutta (Shu-Osher form).
.. math::
    y^{(1)} &= y_n + dt \cdot f(y_n) \\
    y^{(2)} &= \tfrac{3}{4} y_n + \tfrac{1}{4} y^{(1)} + \tfrac{1}{4} dt \cdot f(y^{(1)}) \\
    y_{n+1} &= \tfrac{1}{3} y_n + \tfrac{2}{3} y^{(2)} + \tfrac{2}{3} dt \cdot f(y^{(2)})
Order 3, SSP coefficient C = 1 (optimal). Preserves monotonicity, positivity, and TVD properties.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state. | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function. | required |
| dt | float | Timestep. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | Updated state after one SSP-RK3 step. |
Source code in finitevolx/_src/timestepping/explicit_rk.py
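In Shu-Osher form, every stage is a convex combination of forward Euler steps, which is where the SSP property comes from. A minimal scalar sketch, assuming the illustrative name `rk3_ssp_step_sketch` and the test problem dy/dt = -y (neither is part of finitevolx):

```python
def rk3_ssp_step_sketch(state, rhs_fn, dt):
    """One SSP-RK3 step written as convex combinations of Euler steps."""
    # Stage 1: a full forward Euler step from y_n.
    y1 = state + dt * rhs_fn(state)
    # Stage 2: 3/4 of y_n plus 1/4 of an Euler step from y1.
    y2 = 0.75 * state + 0.25 * (y1 + dt * rhs_fn(y1))
    # Final: 1/3 of y_n plus 2/3 of an Euler step from y2.
    return state / 3.0 + (2.0 / 3.0) * (y2 + dt * rhs_fn(y2))


h = 0.1
y_new = rk3_ssp_step_sketch(1.0, lambda y: -y, h)
# For dy/dt = -y the step reproduces the cubic Taylor polynomial of exp(-h).
taylor3 = 1.0 - h + h**2 / 2.0 - h**3 / 6.0
```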
finitevolx.rk4_step(state, rhs_fn, dt)
Classic 4th-order Runge-Kutta.
.. math::
    k_1 &= f(y_n) \\
    k_2 &= f(y_n + (dt/2) k_1) \\
    k_3 &= f(y_n + (dt/2) k_2) \\
    k_4 &= f(y_n + dt \cdot k_3) \\
    y_{n+1} &= y_n + (dt/6)(k_1 + 2 k_2 + 2 k_3 + k_4)
Order 4, 4 stages, not SSP.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state. | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function. | required |
| dt | float | Timestep. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | Updated state after one RK4 step. |
Source code in finitevolx/_src/timestepping/explicit_rk.py
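The fourth-order convergence can be checked with a small standalone experiment. The sketch below assumes the illustrative names `rk4_step_sketch` and `integrate_decay` and the test problem dy/dt = -y on [0, 1]; it does not use finitevolx. Halving the step should shrink the error by a factor of roughly 16.

```python
import math


def rk4_step_sketch(state, rhs_fn, dt):
    """One classic RK4 step on a scalar state."""
    k1 = rhs_fn(state)
    k2 = rhs_fn(state + 0.5 * dt * k1)
    k3 = rhs_fn(state + 0.5 * dt * k2)
    k4 = rhs_fn(state + dt * k3)
    return state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def integrate_decay(n_steps):
    """Integrate dy/dt = -y from t = 0 to t = 1 with n_steps equal steps."""
    y, dt = 1.0, 1.0 / n_steps
    for _ in range(n_steps):
        y = rk4_step_sketch(y, lambda v: -v, dt)
    return y


err_coarse = abs(integrate_decay(10) - math.exp(-1.0))
err_fine = abs(integrate_decay(20) - math.exp(-1.0))
ratio = err_coarse / err_fine  # close to 2**4 for a 4th-order method
```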
Multistep Methods (Pure Functional)
finitevolx.ab2_step(state, rhs_fn, dt, rhs_nm1)
2nd-order Adams-Bashforth.
.. math::
y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1})
Only one RHS evaluation per step (efficiency advantage over RK2).
Requires one previous RHS evaluation rhs_nm1 = f(y_{n-1}).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state y_n. | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function | required |
| dt | float | Timestep. | required |
| rhs_nm1 | PyTree | RHS evaluated at the previous step, | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, PyTree] | |
Source code in finitevolx/_src/timestepping/multistep.py
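A minimal scalar sketch of the AB2 update follows. The name `ab2_step_sketch` and its two-element return value (new state plus the freshly evaluated RHS) are assumptions made for illustration; the documented function returns a three-element tuple whose contents are not reproduced here.

```python
def ab2_step_sketch(state, rhs_fn, dt, rhs_nm1):
    """One AB2 step: y_{n+1} = y_n + (dt/2) * (3 f_n - f_{n-1})."""
    rhs_n = rhs_fn(state)  # the only RHS evaluation in this step
    new_state = state + 0.5 * dt * (3.0 * rhs_n - rhs_nm1)
    # rhs_n is handed back so the caller can pass it as rhs_nm1 next time.
    return new_state, rhs_n


# With a constant RHS f = 2 the update reduces to y + dt * 2.
y_next, f_n = ab2_step_sketch(1.0, lambda y: 2.0, 0.1, 2.0)
```

Carrying the previous RHS between calls is what keeps the cost at one evaluation per step.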
finitevolx.ab3_step(state, rhs_fn, dt, rhs_nm1, rhs_nm2)
3rd-order Adams-Bashforth.
.. math::
y_{n+1} = y_n + (dt/12)(23 f_n - 16 f_{n-1} + 5 f_{n-2})
One RHS evaluation per step. Requires two previous RHS evaluations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state y_n. | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function. | required |
| dt | float | Timestep. | required |
| rhs_nm1 | PyTree | RHS at step n-1. | required |
| rhs_nm2 | PyTree | RHS at step n-2. | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree, PyTree] | |
Source code in finitevolx/_src/timestepping/multistep.py
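The three-term history can be sketched the same way on a scalar state. `ab3_step_sketch` and the order of its returned values are illustrative assumptions, not the documented return convention.

```python
def ab3_step_sketch(state, rhs_fn, dt, rhs_nm1, rhs_nm2):
    """One AB3 step: y_{n+1} = y_n + (dt/12) * (23 f_n - 16 f_{n-1} + 5 f_{n-2})."""
    rhs_n = rhs_fn(state)  # single RHS evaluation per step
    new_state = state + (dt / 12.0) * (
        23.0 * rhs_n - 16.0 * rhs_nm1 + 5.0 * rhs_nm2
    )
    # Shift the history: (f_n, f_{n-1}) become (f_{n-1}, f_{n-2}) next step.
    return new_state, rhs_n, rhs_nm1


# The weights sum to 12, so a constant RHS f = 3 gives y + dt * 3.
y_next, new_nm1, new_nm2 = ab3_step_sketch(1.0, lambda y: 3.0, 0.1, 3.0, 3.0)
```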
finitevolx.leapfrog_raf_step(state, state_nm1, rhs_fn, dt, alpha=0.05)
Leapfrog with Robert-Asselin filter.
.. math::
    \tilde{y}_{n+1} &= y_{n-1} + 2 \Delta t \, f(y_n) \\
    \bar{y}_n &= y_n + \alpha (y_{n-1} - 2 y_n + \tilde{y}_{n+1})
The filtered middle value \bar{y}_n damps the spurious computational
mode inherent to the three-level leapfrog scheme.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state y_n. | required |
| state_nm1 | PyTree | Previous state y_{n-1}. | required |
| rhs_fn | Callable[[PyTree], PyTree] | Right-hand-side function. | required |
| dt | float | Timestep. | required |
| alpha | float | Robert-Asselin filter coefficient (default 0.05). Typical range 0.01-0.1; larger values damp the computational mode more aggressively but introduce additional dissipation. | 0.05 |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | |
Source code in finitevolx/_src/timestepping/multistep.py
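The two formulas above can be sketched for a scalar state as shown below. The function name and the order of the returned pair (new state first, filtered current state second) are assumptions for illustration only.

```python
def leapfrog_raf_step_sketch(state, state_nm1, rhs_fn, dt, alpha=0.05):
    """Leapfrog step followed by the Robert-Asselin filter on the middle level."""
    # Leapfrog: centred difference across two time levels.
    state_np1 = state_nm1 + 2.0 * dt * rhs_fn(state)
    # Filter: nudge y_n towards the mean of its two neighbours.
    state_filtered = state + alpha * (state_nm1 - 2.0 * state + state_np1)
    return state_np1, state_filtered


# A pure 2*dt oscillation (0, 1, 0) with zero RHS: the filter pulls the
# middle value from 1.0 down to 0.9, which is how the mode is damped.
y_np1, y_n_filtered = leapfrog_raf_step_sketch(1.0, 0.0, lambda y: 0.0, 0.1)
```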
IMEX (Pure Functional)
finitevolx.imex_ssp2_step(state, rhs_explicit, rhs_implicit, implicit_solve, dt)
IMEX-SSP2(2,2,2) time step.
The explicit part is SSP with C = 1; the implicit part is A-stable and L-stable (SDIRK with gamma = 1 - 1/sqrt(2)).
Algorithm (two stages)::
    Stage 1:
        Y_1 = y_n + gamma * dt * f_I(Y_1)
            -> solved via implicit_solve(y_n, gamma * dt)
    Stage 2:
        Y_2_star = y_n + dt * f_E(y_n)
        Y_2 = Y_2_star + gamma * dt * f_I(Y_2)
            -> solved via implicit_solve(Y_2_star, gamma * dt)
    Update:
        y_{n+1} = y_n + (dt/2) * [f_E(y_n) + f_E(Y_2)]
                      + (dt/2) * [f_I(Y_1) + f_I(Y_2)]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Current state y_n. | required |
| rhs_explicit | Callable[[PyTree], PyTree] | Explicit (non-stiff) right-hand side, e.g. advection. | required |
| rhs_implicit | Callable[[PyTree], PyTree] | Implicit (stiff) right-hand side, e.g. vertical diffusion. | required |
| implicit_solve | Callable[[PyTree, float], PyTree] | Solves | required |
| dt | float | Timestep. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | Updated state after one IMEX-SSP2 step. |
Source code in finitevolx/_src/timestepping/imex.py
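Going by the stage equations above, `implicit_solve(rhs, a)` is called to find Y satisfying Y = rhs + a * f_I(Y). For a linear stiff term this has a closed form. The sketch below assumes a hypothetical scalar damping term f_I(y) = -k * y and a hypothetical helper `make_linear_implicit`; a real PDE would typically need a linear or nonlinear solver here.

```python
def make_linear_implicit(k):
    """Return (rhs_implicit, implicit_solve) for the stiff term f_I(y) = -k * y."""

    def rhs_implicit(y):
        return -k * y

    def implicit_solve(rhs, a):
        # Solve Y = rhs + a * f_I(Y) = rhs - a * k * Y, i.e. Y = rhs / (1 + a * k).
        return rhs / (1.0 + a * k)

    return rhs_implicit, implicit_solve


rhs_implicit, implicit_solve = make_linear_implicit(k=50.0)
y_stage = implicit_solve(1.0, 0.01)
# The returned stage value should satisfy the implicit equation it was solved for.
residual = y_stage - (1.0 + 0.01 * rhs_implicit(y_stage))
```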
Split-Explicit (Pure Functional)
finitevolx.split_explicit_step(state_3d, state_2d, rhs_3d, rhs_2d, couple_2d_to_3d, dt_slow, n_substeps)
Split-explicit barotropic/baroclinic time step.
Algorithm::
    1. Subcycle the 2D barotropic state with n_substeps Forward-Euler steps (dt_fast = dt_slow / n_substeps), accumulating a time-average.
    2. Couple the time-averaged 2D state into the slow RHS.
    3. Advance the 3D baroclinic state with one Forward-Euler step using dt_slow.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state_3d | PyTree | 3D baroclinic state (slow mode). | required |
| state_2d | PyTree | 2D barotropic state (fast mode). | required |
| rhs_3d | Callable[[PyTree, PyTree], PyTree] | Slow RHS: | required |
| rhs_2d | Callable[[float, PyTree, PyTree], PyTree] | Fast RHS: | required |
| couple_2d_to_3d | Callable[[PyTree, PyTree], PyTree] | Coupling function: | required |
| dt_slow | float | Slow (baroclinic) timestep. | required |
| n_substeps | int | Number of fast (barotropic) substeps per slow step. | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | |
Source code in finitevolx/_src/timestepping/split_explicit.py
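The three algorithm steps can be sketched on scalar states as follows. The callback argument orders used here (`rhs_2d(t, state_2d, state_3d)`, `couple_2d_to_3d(state_3d, avg_2d)`, `rhs_3d(state_3d, coupled)`), the averaging over substep end values, and the returned pair are all assumptions made for illustration; the finitevolx conventions may differ.

```python
def split_explicit_step_sketch(
    state_3d, state_2d, rhs_3d, rhs_2d, couple_2d_to_3d, dt_slow, n_substeps
):
    """Subcycle the fast mode with forward Euler, then take one slow Euler step."""
    dt_fast = dt_slow / n_substeps
    avg_2d = 0.0
    for k in range(n_substeps):
        # Step 1: forward Euler substep of the fast (2D) mode.
        state_2d = state_2d + dt_fast * rhs_2d(k * dt_fast, state_2d, state_3d)
        avg_2d += state_2d / n_substeps  # running time-average
    # Step 2: feed the time-averaged fast state into the slow RHS.
    coupled = couple_2d_to_3d(state_3d, avg_2d)
    # Step 3: one forward Euler step of the slow (3D) mode.
    state_3d = state_3d + dt_slow * rhs_3d(state_3d, coupled)
    return state_3d, state_2d


# Toy problem: the fast mode grows at unit rate and the slow RHS equals the
# averaged fast state, so the substep values are 0.25, 0.5, 0.75, 1.0.
new_3d, new_2d = split_explicit_step_sketch(
    0.0, 0.0,
    lambda y3, coupled: coupled,
    lambda t, y2, y3: 1.0,
    lambda y3, avg: avg,
    1.0, 4,
)
```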
Semi-Lagrangian (Pure Functional)
finitevolx.semi_lagrangian_step(field, u, v, dx, dy, dt, interp_order=1, bc='periodic')
Advect a 2D scalar field using semi-Lagrangian backtracking.
Algorithm::
    1. Compute departure points: x_dep = x_i - u * dt, y_dep = y_j - v * dt.
    2. Interpolate ``field`` at the departure points.
    3. Return interpolated values as the new field.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| field | Array[Ny, Nx] | Scalar field to advect. | required |
| u | Array[Ny, Nx] | Velocity components at the same grid points as | required |
| v | Array[Ny, Nx] | Velocity components at the same grid points as | required |
| dx | float | Grid spacing in x (m). | required |
| dy | float | Grid spacing in y (m). | required |
| dt | float | Timestep (s). | required |
| interp_order | int | Interpolation order passed to :func: | 1 |
| bc | str | Boundary handling: | 'periodic' |
Returns:
| Type | Description |
|---|---|
| Array[Ny, Nx] | Advected field. |
Source code in finitevolx/_src/timestepping/semi_lagrangian.py
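The backtrack-and-interpolate idea is easiest to see in one dimension. The sketch below is a pure-Python 1D analogue with linear interpolation and periodic wrap-around; `semi_lagrangian_1d_sketch` is a hypothetical helper, not the 2D finitevolx routine.

```python
import math


def semi_lagrangian_1d_sketch(field, u, dx, dt):
    """Advect a periodic 1D field by tracing each grid point back along u."""
    n = len(field)
    new_field = []
    for i in range(n):
        # Departure point in index units: (x_i - u_i * dt) / dx.
        x_dep = i - u[i] * dt / dx
        i0 = math.floor(x_dep)
        w = x_dep - i0               # linear interpolation weight in [0, 1)
        left = field[i0 % n]         # periodic wrap-around on both sides
        right = field[(i0 + 1) % n]
        new_field.append((1.0 - w) * left + w * right)
    return new_field


pulse = [0.0, 1.0, 0.0, 0.0]
# Courant number 1: the pulse moves exactly one cell to the right.
shifted = semi_lagrangian_1d_sketch(pulse, [1.0] * 4, 1.0, 1.0)
# Courant number 2.5: still well behaved, the pulse is split across two cells.
far = semi_lagrangian_1d_sketch(pulse, [2.5] * 4, 1.0, 1.0)
```

The second call takes a step with a Courant number above 1, which an explicit Eulerian scheme on the same grid could not do stably.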
Diffrax Solvers
finitevolx.ForwardEulerDfx
Bases: AbstractSolver
Forward Euler via the diffrax AbstractSolver interface.
Order 1, 1 stage. Included for completeness; prefer diffrax.Euler for production use.
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
finitevolx.RK2Heun
Bases: AbstractERK
Heun's method (RK2) via Butcher tableau.
Order 2, 2 stages, SSP with C = 1.
Butcher tableau::
    0   |
    1   | 1
    ----+----------
        | 1/2  1/2
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
finitevolx.RK3SSP
Bases: AbstractERK
3rd-order Strong-Stability-Preserving Runge-Kutta.
Order 3, 3 stages, SSP coefficient C = 1 (optimal). Preserves monotonicity, positivity, and TVD properties.
Butcher tableau::
    0   |
    1   | 1
    1/2 | 1/4  1/4
    ----+---------------
        | 1/6  1/6  2/3
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
finitevolx.RK4Classic
Bases: AbstractERK
Classic 4th-order Runge-Kutta.
Order 4, 4 stages, not SSP.
Butcher tableau::
    0   |
    1/2 | 1/2
    1/2 | 0    1/2
    1   | 0    0    1
    ----+--------------------
        | 1/6  1/3  1/3  1/6
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
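To make the link between a Butcher tableau and a time step concrete, the sketch below evaluates any explicit tableau on an autonomous scalar problem. It is a standalone illustration, not how diffrax or finitevolx apply tableaux; `explicit_rk_step` and the tuple layout of the tableau constants are assumptions. The stage times (left column) are omitted because the test problem does not depend on t.

```python
def explicit_rk_step(y, rhs_fn, dt, a, b):
    """One explicit RK step from lower-triangular rows `a` and weights `b`."""
    slopes = []
    for row in a:
        # Each stage value only uses slopes from earlier stages.
        y_stage = y + dt * sum(a_ij * k_j for a_ij, k_j in zip(row, slopes))
        slopes.append(rhs_fn(y_stage))
    return y + dt * sum(b_i * k_i for b_i, k_i in zip(b, slopes))


# Tableaux transcribed from the solver descriptions above.
HEUN = ([[], [1.0]], [0.5, 0.5])
RK3_SSP = ([[], [1.0], [0.25, 0.25]], [1 / 6, 1 / 6, 2 / 3])
RK4 = ([[], [0.5], [0.0, 0.5], [0.0, 0.0, 1.0]], [1 / 6, 1 / 3, 1 / 3, 1 / 6])


def decay(y):
    return -y


h = 0.1
y_heun = explicit_rk_step(1.0, decay, h, *HEUN)
y_rk3 = explicit_rk_step(1.0, decay, h, *RK3_SSP)
y_rk4 = explicit_rk_step(1.0, decay, h, *RK4)
```

For dy/dt = -y each result equals the Taylor polynomial of exp(-h) truncated at the order of the method.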
finitevolx.SSP_RK2
Bases: AbstractERK
2nd-order SSP Runge-Kutta (same as Heun).
Order 2, 2 stages, SSP coefficient C = 1.
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
finitevolx.SSP_RK104
Bases: AbstractERK
4th-order SSP Runge-Kutta with 10 stages (Ketcheson 2008).
Order 4, 10 stages, SSP coefficient C = 6. Highest SSP coefficient achievable at 4th-order accuracy.
Reference: Ketcheson (2008). Highly efficient strong stability-preserving Runge-Kutta methods with low-storage implementations.
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
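A common way to compare SSP schemes is the effective SSP coefficient, C divided by the number of stages, which measures the allowable step per RHS evaluation relative to forward Euler. The small calculation below uses the C values and stage counts quoted in this section; the dictionary layout is only for illustration.

```python
# (SSP coefficient C, number of stages) as quoted for each solver above.
ssp_schemes = {
    "RK2Heun": (1.0, 2),
    "RK3SSP": (1.0, 3),
    "SSP_RK104": (6.0, 10),
}

# Effective SSP coefficient: stable step size per RHS evaluation.
effective = {name: c / stages for name, (c, stages) in ssp_schemes.items()}
best = max(effective, key=effective.get)
```

By this measure the ten-stage scheme allows the largest stable step per RHS evaluation of the three.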
finitevolx.IMEX_SSP2
Bases: AbstractRungeKutta, AbstractImplicitSolver
IMEX-SSP2(2,2,2) solver using diffrax MultiTerm.
Splits the ODE as dy/dt = f_E(t,y) + f_I(t,y). The explicit part
is SSP with C = 1; the implicit part is A-stable (SDIRK with
gamma = 1 - 1/sqrt(2)).
Usage::
    import diffrax as dfx
    from finitevolx import IMEX_SSP2

    explicit_term = dfx.ODETerm(advection_rhs)
    implicit_term = dfx.ODETerm(diffusion_rhs)
    terms = dfx.MultiTerm(explicit_term, implicit_term)
    solver = IMEX_SSP2()
    sol = dfx.diffeqsolve(terms, solver, ...)
Explicit tableau::
    0     |
    1     | 1
    ------+----------
          | 1/2  1/2
Implicit tableau (SDIRK, gamma = 1 - 1/sqrt(2))::
    gamma     | gamma
    1         | 1-2*gamma  gamma
    ----------+------------------
              | 1/2        1/2
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
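The pair of tableaux can be read as the following two-stage update. This is a standalone scalar sketch for autonomous right-hand sides (so the stage times in the left column are not used); it does not go through diffrax. `imex_ssp2_tableau_step`, the linear test terms, and the closed-form `implicit_solve` are assumptions for illustration.

```python
import math

GAMMA = 1.0 - 1.0 / math.sqrt(2.0)


def imex_ssp2_tableau_step(y, f_e, f_i, implicit_solve, dt):
    """One IMEX step built from the explicit and implicit tableaux above.

    implicit_solve(rhs, a) is assumed to return Y with Y = rhs + a * f_i(Y).
    """
    # Stage 1: explicit row is empty, implicit diagonal entry is gamma.
    y1 = implicit_solve(y, GAMMA * dt)
    fe1, fi1 = f_e(y1), f_i(y1)
    # Stage 2: explicit a21 = 1, implicit a21 = 1 - 2*gamma, diagonal gamma.
    rhs2 = y + dt * fe1 + (1.0 - 2.0 * GAMMA) * dt * fi1
    y2 = implicit_solve(rhs2, GAMMA * dt)
    # Both parts use the weights (1/2, 1/2).
    return y + 0.5 * dt * (fe1 + f_e(y2)) + 0.5 * dt * (fi1 + f_i(y2))


def linear_solve(k):
    # Closed-form solve of Y = rhs - a * k * Y for the term f_i(y) = -k * y.
    return lambda rhs, a: rhs / (1.0 + a * k)


# No implicit part: the step reduces to Heun on dy/dt = -y.
y_heun = imex_ssp2_tableau_step(
    1.0, lambda y: -y, lambda y: 0.0, linear_solve(0.0), 0.1
)
# Very stiff implicit part only: the solution is damped towards zero.
k_stiff = 1.0e6
y_stiff = imex_ssp2_tableau_step(
    1.0, lambda y: 0.0, lambda y: -k_stiff * y, linear_solve(k_stiff), 0.1
)
```

The stiff case shows why the implicit half is useful: dt * k is 1e5 here, far beyond any explicit stability limit, yet the step still decays.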
Manual Solver Interfaces
finitevolx.AB2Solver
Bases: Module
Adams-Bashforth 2nd-order solver (equinox Module).
Maintains f_prev as part of the solver state. Not compatible with
diffrax.diffeqsolve; use the manual init / step interface.
Usage::
    from finitevolx import AB2Solver

    solver = AB2Solver()
    solver, y = solver.init(rhs_fn, t0, y0, dt)  # y is now y_1 at time t0 + dt
    for n in range(1, n_steps):
        y, solver = solver.step(rhs_fn, t0 + n * dt, y, dt)
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
init(rhs_fn, t0, y0, dt)
Bootstrap with an RK2 step, returning (updated_solver, y1).
Stores f_prev = f(t0, y0) (i.e. the RHS at the start of the
bootstrap step) so that the first AB2 step uses the correct history.
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
step(rhs_fn, t, y, dt)
AB2 step: y_{n+1} = y_n + (dt/2)(3 f_n - f_{n-1}).
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
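The bootstrap-then-step pattern can be sketched without the solver class as a plain loop. This is a standalone scalar illustration assuming a hypothetical helper `ab2_integrate_sketch` and an RHS of the form `rhs_fn(t, y)`; it mirrors the description above (a Heun step to obtain y_1 while keeping f(t0, y0) as history) but is not the finitevolx implementation.

```python
import math


def ab2_integrate_sketch(rhs_fn, t0, y0, dt, n_steps):
    """Integrate n_steps steps: one Heun bootstrap step, then AB2 steps."""
    # Bootstrap: Heun (RK2) step to y_1; keep f(t0, y0) as the AB2 history.
    f_prev = rhs_fn(t0, y0)
    y = y0 + 0.5 * dt * (f_prev + rhs_fn(t0 + dt, y0 + dt * f_prev))
    # AB2 steps start at n = 1 because y already holds y_1 at time t0 + dt.
    for n in range(1, n_steps):
        f_n = rhs_fn(t0 + n * dt, y)
        y = y + 0.5 * dt * (3.0 * f_n - f_prev)
        f_prev = f_n
    return y


# dy/dt = -y on [0, 1] with 100 steps, compared with exp(-1).
y_end = ab2_integrate_sketch(lambda t, y: -y, 0.0, 1.0, 0.01, 100)
error = abs(y_end - math.exp(-1.0))
```

Starting the loop at n = 1 keeps the time argument `t0 + n * dt` consistent with the state produced by the bootstrap step.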
finitevolx.LeapfrogRAFSolver
Bases: Module
Leapfrog with Robert-Asselin filter (equinox Module).
Three-level scheme: y_{n+1} = y_{n-1} + 2 dt f(y_n), with the RAF
applied to the middle level to damp the computational mode.
Usage::
    solver = LeapfrogRAFSolver(alpha=0.05)
    solver, y1 = solver.init(rhs_fn, t0, y0, dt)
    y_curr = y1
    for n in range(1, n_steps):
        y_curr, solver = solver.step(rhs_fn, t0 + n * dt, y_curr, dt)
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
init(rhs_fn, t0, y0, dt)
Bootstrap with an RK2 step, returning (updated_solver, y1).
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
step(rhs_fn, t, y_curr, dt)
Leapfrog + RAF step.
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
finitevolx.SplitExplicitRKSolver
Bases: Module
Split-explicit barotropic/baroclinic solver.
Uses Forward-Euler substeps for the fast (2D barotropic) mode and Forward-Euler for the slow (3D baroclinic) mode, with time-averaging of the barotropic solution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_substeps | int | Number of barotropic substeps per baroclinic step. | required |
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
step(rhs_slow, rhs_fast, t, y_3d, y_2d, dt_slow)
Take one split-explicit step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rhs_slow | Callable[[float, PyTree, PyTree], PyTree] | Slow RHS: | required |
| rhs_fast | Callable[[float, PyTree, PyTree], PyTree] | Fast RHS: | required |
| t | float | Current time. | required |
| y_3d | PyTree | 3D (slow) state. | required |
| y_2d | PyTree | 2D (fast) state. | required |
| dt_slow | float | Slow timestep. | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | |
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
finitevolx.SemiLagrangianSolver
Bases: AbstractSolver
Semi-Lagrangian advection solver for diffrax.
Traces characteristic curves backward in time and interpolates the old field at departure points. Unconditionally stable (CFL > 1 allowed).
The terms.vf(t, y, args) must return (u, v) velocity components
in grid index units per second (i.e. physical velocity divided by
grid spacing).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| interpolation_order | int | 0 = nearest-neighbour, 1 = linear (diffusive, monotone). JAX currently only supports | required |
Source code in finitevolx/_src/timestepping/diffrax_solvers.py
Convenience Wrapper
finitevolx.solve_ocean_pde(rhs_fn, solver, y0, t0, t1, dt0, saveat=None, bc_fn=None, args=None)
Integrate an ocean PDE using diffrax.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rhs_fn | Callable | Right-hand side | required |
| solver | AbstractSolver | Time integration scheme (e.g. | required |
| y0 | PyTree | Initial condition. | required |
| t0 | float | Start time. | required |
| t1 | float | End time. | required |
| dt0 | float | Initial (or fixed) timestep. | required |
| saveat | SaveAt | Output saving specification. Defaults to saving the final state. | None |
| bc_fn | Callable | Boundary condition function | None |
| args | PyTree | Static arguments forwarded to | None |
Returns:
| Type | Description |
|---|---|
| Solution | Solution object containing the saved states. |