1D Linear Convection#

# !pip install --upgrade jax jaxlib
# !pip install jaxtyping diffrax xarray FiniteDiffX jaxdf
import typing as tp
import numpy as np
import xarray as xr
import jax
import jax.numpy as jnp
import diffrax as dfx
import finitediffx as fdx
import matplotlib.pyplot as plt
import seaborn as sns
from jaxtyping import Float, Array

# Plot styling: reset seaborn to matplotlib defaults, then use the "talk"
# context with fonts scaled down to 0.7.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX so the jnp.float64 arrays used below are not
# downcast to float32.
jax.config.update("jax_enable_x64", True)

# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline

Problem#

Let’s continue from the previous tutorial. Recall that we are working with the 1D linear convection equation:

\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0 \tag{1} \]

For the PDE (1), we discretize space with a backward difference and time with a forward (explicit Euler) difference.

Geometry#

JaxDF - API#

from jaxdf.geometry import Domain

# Grid resolution: number of points and spacing along the single axis.
nx = 51
dx = 0.04

# initialize domain (1D: N and dx each hold one entry per dimension)
domain = Domain(N=(nx,), dx=(dx,))

# Inspect the basic geometric properties of the domain.
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
print(f"dx: {domain.dx}")
print(f"Type: {type(domain)}")

Initial Condition#

def init_u0(domain):
    """Build the top-hat initial condition on the grid of ``domain``.

    The returned array has the shape of ``domain.grid`` and is 1.0
    everywhere, except along the first axis from index ``int(0.5 / dx)``
    up to (but excluding) ``int(1 / dx + 1)``, where it is 2.0.
    """
    step = domain.dx[0]
    first = int(0.5 / step)
    stop = int(1 / step + 1)
    baseline = jnp.ones_like(domain.grid, dtype=jnp.float64)
    return baseline.at[first:stop].set(2.0)
# Evaluate the top-hat initial condition on the domain grid.
u_init = init_u0(domain)

# Show the array type produced by init_u0 (a JAX array, not a jaxdf field).
print(type(u_init))

Equation of Motion#

jaxdf - API#

from jaxdf.discretization import FiniteDifferences, OnGrid
from jaxdf.operators import gradient


def equation_of_motion(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` using jaxdf's built-in ``gradient``.

    Args:
        t: Current time; unused, kept for the diffrax ``f(t, y, args)``
            signature used by ``dfx.ODETerm`` below.
        u: Raw grid values of the field on the global ``domain``.
        args: The convection speed ``c``.

    Returns:
        Grid values of ``-c * du/dx``.
    """
    c = args
    # initialize spatial discretization: wrap the raw grid values in a
    # FiniteDifferences field defined on the global `domain`.

    u = FiniteDifferences.from_grid(u, domain)

    # NOTE(review): assumes the FiniteDifferences field allows setting
    # `accuracy` after construction — confirm against the installed jaxdf.
    u.accuracy = 2

    # NOTE(review): the text calls for a backward difference in space, but
    # this is a staggered (stagger=[1]) gradient at accuracy 2 — confirm the
    # stencil (e.g. with get_fd_coefficients) matches the intended scheme.
    u_rhs = -c * gradient(u, stagger=[1])

    return u_rhs.on_grid
from jaxdf.operators.differential import get_fd_coefficients

# Wrap the initial condition in a FiniteDifferences field.
u = FiniteDifferences.from_grid(u_init, domain)

# Inspect the stencil coefficients for a first-order (order=1) derivative
# with stagger=1; the bare name below displays them in the notebook.
coeffs = get_fd_coefficients(u, order=1, stagger=1)
coeffs
# Convection speed.
c = 1.0

# initialize grid
u_init = init_u0(domain)

# RHS of equation of motion
out = equation_of_motion(0, u_init, c)

From Scratch#

from jaxdf.discretization import FiniteDifferences, FourierSeries
from jaxdf.operators import gradient


def equation_of_motion_scratch(t: Array, u: Array, args: tuple):
    """Return ``-c * du/dx`` using a hand-picked FiniteDiffX stencil.

    The spatial derivative is a first-order backward difference along axis
    0, using the grid spacing of the global ``domain``. ``t`` is unused and
    ``args`` carries the convection speed ``c``.
    """
    speed = args
    spacing = domain.dx[0]
    du_dx = fdx.difference(
        u, axis=0, accuracy=1, method="backward", step_size=spacing
    )
    return -speed * du_dx
# RHS of equation of motion (from-scratch backward difference)
out_scratch = equation_of_motion_scratch(0, u_init, c)

Custom Difference Operator#

from jaxdf import operator
import equinox as eqx
from jaxtyping import Float, Array


class FDParams(eqx.Module):
    """Settings forwarded to ``fdx.difference`` by the custom operator.

    Every field is declared with ``eqx.static_field()``, i.e. kept as static
    metadata of the module rather than as a traced pytree leaf.
    """

    # Axis along which the difference is taken.
    axis: int = eqx.static_field()
    # Accuracy order passed to fdx.difference.
    accuracy: int = eqx.static_field()
    # Stencil direction passed to fdx.difference (default "backward").
    method: str = eqx.static_field()

    def __init__(self, axis=0, accuracy=1, method="backward"):
        self.axis = axis
        self.accuracy = accuracy
        self.method = method


class Field(eqx.Module):
    """Minimal container pairing grid values with their Domain.

    NOTE: this class is redefined (with an ``__init__`` and a ``values``
    property) in the following cell, and the ``difference`` overload in this
    cell is annotated with jaxdf's ``OnGrid``, not with this class.
    """

    # Grid values of the field.
    u: Array
    # Domain on which the values are defined.
    domain: Domain


@operator
def difference(u: OnGrid, *, params: tp.Optional[FDParams] = None):
    """Finite difference of a jaxdf ``OnGrid`` field via ``fdx.difference``.

    Args:
        u: Field whose grid values are differentiated; its ``domain.dx``
            supplies the step size along ``params.axis``.
        params: Stencil settings; a default ``FDParams()`` when omitted.

    Returns:
        Tuple of (copy of ``u`` holding the differenced values, the
        parameters that were used).
    """
    params = FDParams() if params is None else params

    spacing = u.domain.dx[params.axis]
    differenced = fdx.difference(
        u.on_grid,
        axis=params.axis,
        accuracy=params.accuracy,
        method=params.method,
        step_size=spacing,
    )
    return u.replace_params(differenced), params
from jaxdf import operator
import equinox as eqx
from jaxtyping import Float, Array


class FDParams(eqx.Module):
    """Settings forwarded to ``fdx.difference`` (redefinition, same fields).

    Every field is declared with ``eqx.static_field()``, i.e. kept as static
    metadata of the module rather than as a traced pytree leaf.
    """

    # Axis along which the difference is taken.
    axis: int = eqx.static_field()
    # Accuracy order passed to fdx.difference.
    accuracy: int = eqx.static_field()
    # Stencil direction passed to fdx.difference (default "backward").
    method: str = eqx.static_field()

    def __init__(self, axis=0, accuracy=1, method="backward"):
        self.axis = axis
        self.accuracy = accuracy
        self.method = method


class Field(eqx.Module):
    """Grid values plus their Domain, built from an initializer function."""

    # Grid values of the field.
    u: Array
    # Domain on which the values are defined.
    domain: Domain

    def __init__(self, domain, init_fn: tp.Callable):
        # Evaluate the initializer (e.g. init_u0) on the domain.
        self.u = init_fn(domain)
        self.domain = domain

    @property
    def values(self):
        """Alias for the stored grid values ``u``."""
        return self.u


@operator
def difference(u: Field, *, params: tp.Optional[FDParams] = None):
    """Finite difference of a custom ``Field`` via ``fdx.difference``.

    Args:
        u: Field whose ``values`` are differentiated; its ``domain.dx``
            supplies the step size along ``params.axis``.
        params: Stencil settings; a default ``FDParams()`` when omitted.

    Returns:
        Tuple of (copy of ``u`` whose ``u`` leaf holds the derivative, the
        parameters that were used).
    """
    if params is None:
        params = FDParams()

    stencil_kwargs = {
        "axis": params.axis,
        "accuracy": params.accuracy,
        "method": params.method,
        "step_size": u.domain.dx[params.axis],
    }
    derivative = fdx.difference(u.values, **stencil_kwargs)

    # Functional update: build a copy of `u` with its `u` leaf replaced.
    return eqx.tree_at(lambda field: field.u, u, derivative), params
# initialize Field
u_field = Field(domain, init_u0)

# Spatial derivative du/dx with default FDParams (not yet scaled by -c).
u_rhs = difference(u_field)

u_rhs
# Same call with explicit parameters, equal to the FDParams defaults.
params = FDParams(axis=0, accuracy=1, method="backward")

u_rhs = difference(u_field, params=params)
u_rhs
def equation_of_motion_custom(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` using the custom ``difference`` operator.

    Args:
        t: Current time; unused, kept for the diffrax ``f(t, y, args)``
            signature.
        u: Raw grid values of the field on the global ``domain``.
        args: The convection speed ``c``.

    Returns:
        Grid values of ``-c * du/dx``.
    """
    c = args

    # Wrap the raw grid values in a jaxdf OnGrid field on the global domain.
    u = OnGrid.from_grid(u, domain)

    # params=None makes `difference` fall back to its default FDParams
    # (axis=0, accuracy=1, method="backward").
    u_rhs = -c * difference(u, params=None)

    return u_rhs.on_grid


out_custom = equation_of_motion_custom(0, u_init, c)
# Compare the initial condition with the three evaluations of -c * du/dx.
fig, ax = plt.subplots()

# [..., 0] drops the trailing grid axis — assumes the arrays have shape
# (nx, 1) like domain.grid; TODO confirm.
ax.plot(domain.spatial_axis[0], u_init[..., 0], label="Initial Condition")
ax.plot(domain.spatial_axis[0], out[..., 0], label="JaxDF")
ax.plot(domain.spatial_axis[0], out_scratch[..., 0], label="Scratch")
ax.plot(domain.spatial_axis[0], out_custom[..., 0], label="Custom")

plt.legend()
plt.show()

Time Stepping#

# temporal parameters
# c: convection speed; sigma: Courant number used to derive dt below.
c = 1.0
sigma = 0.2


# CFL condition
def cfl_cond(dx, c, sigma):
    """Return the time step for which ``|c| * dt / dx == sigma`` (CFL).

    Args:
        dx: Grid spacing.
        c: Convection speed; only its magnitude matters.
        sigma: Target Courant number; must satisfy ``0 < sigma <= 1`` for
            the explicit upwind scheme to respect the CFL limit.

    Returns:
        ``sigma * dx / |c|``.

    Raises:
        ValueError: If ``sigma`` is outside ``(0, 1]`` or ``c`` is zero.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not 0.0 < sigma <= 1.0:
        raise ValueError(f"sigma must be in (0, 1], got {sigma}")
    if c == 0:
        raise ValueError("c must be non-zero to derive a CFL time step")
    # Use |c| so a leftward-moving wave does not produce a negative dt.
    return (sigma * dx) / abs(c)


dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)

t0 = 0.0
t1 = 0.5
# NOTE: jnp.arange excludes its end point, so the last saved time is the
# largest multiple of dt strictly below t1 (0.496 with the values above).
ts = jnp.arange(t0, t1, dt)
saveat = dfx.SaveAt(ts=ts)
# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


def integration_fn(u, f):
    """Integrate ``du/dt = f(t, u, c)`` with explicit Euler over ``ts``.

    Args:
        u: Initial grid values.
        f: Right-hand side with the diffrax signature ``f(t, y, args)``;
            it receives the global convection speed ``c`` as ``args``.

    Returns:
        The diffrax solution, saved at every time in ``ts``.
    """
    return dfx.diffeqsolve(
        terms=dfx.ODETerm(f),
        solver=solver,
        t0=ts.min(),
        t1=ts.max(),
        dt0=dt,
        y0=u,
        saveat=saveat,
        args=c,
        stepsize_controller=stepsize_controller,
    )


sol = integration_fn(u_init, equation_of_motion)
sol_scratch = integration_fn(u_init, equation_of_motion_scratch)
sol_custom = integration_fn(u_init, equation_of_motion_custom)

Analysis#

# Collect the three solutions into one xarray Dataset (despite the `da_`
# prefix, this is a Dataset, not a DataArray) with dims ("time", "x").
# squeeze() drops singleton axes of sol.ys — presumably the trailing grid
# axis, leaving (n_times, nx); TODO confirm.
da_sol = xr.Dataset(
    {
        "jaxdf": (("time", "x"), np.asarray(sol.ys).squeeze()),
        "scratch": (("time", "x"), np.asarray(sol_scratch.ys).squeeze()),
        "custom": (("time", "x"), np.asarray(sol_custom.ys).squeeze()),
    },
    coords={
        "x": (["x"], np.asarray(domain.spatial_axis[0])),
        # All three runs share the same SaveAt, so one time axis suffices.
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
da_sol
fig, ax = plt.subplots(nrows=3, figsize=(5, 7))

# Space-time diagrams of each solution, transposed to (x, time) for plotting.
da_sol.jaxdf.T.plot.pcolormesh(ax=ax[0], cmap="gray_r")
da_sol.scratch.T.plot.pcolormesh(ax=ax[1], cmap="gray_r")
da_sol.custom.T.plot.pcolormesh(ax=ax[2], cmap="gray_r")

ax[0].set_title("JaxDF")
ax[1].set_title("Scratch")
ax[2].set_title("Custom")

# No plt.legend(): the pcolormesh panels have no labelled artists, so a
# legend call would only emit a "No artists with labels" warning.
plt.tight_layout()
plt.show()
fig, ax = plt.subplots(nrows=3, figsize=(5, 7))

# Overlay every 5th saved snapshot of each solution.
for i in range(0, len(da_sol.time), 5):
    da_sol.jaxdf.isel(time=i).plot.line(ax=ax[0], color="gray")
    da_sol.scratch.isel(time=i).plot.line(ax=ax[1], color="gray")
    da_sol.custom.isel(time=i).plot.line(ax=ax[2], color="gray")

# xarray titles each line plot with its selected time slice, which would
# leave only the last snapshot's time on each overlaid panel; label the
# panels by method instead, matching the space-time figure.
ax[0].set_title("JaxDF")
ax[1].set_title("Scratch")
ax[2].set_title("Custom")

plt.tight_layout()
plt.show()