1D Linear Convection#
# !pip install --upgrade jax jaxlib
# !pip install jaxtyping diffrax xarray FiniteDiffX jaxdf
import typing as tp
import numpy as np
import xarray as xr
import jax
import jax.numpy as jnp
import diffrax as dfx
import finitediffx as fdx
import matplotlib.pyplot as plt
import seaborn as sns
from jaxtyping import Float, Array
# Plotting style for all figures in this notebook.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX; the initial condition below requests jnp.float64,
# which would otherwise be truncated to float32.
jax.config.update("jax_enable_x64", True)
# IPython magic (notebook-only, not valid in a plain .py script): render figures inline.
%matplotlib inline
Problem#
Let’s continue from the previous tutorial. Recall that we are working with the 1D linear convection equation:
\[
\frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0 \tag{1}
\]
For the PDE (1), we discretize space with a backward difference and time with a forward difference (explicit Euler).
Geometry#
JaxDF - API#
from jaxdf.geometry import Domain
nx = 51  # number of grid points
dx = 0.04  # grid spacing
# initialize domain; N and dx are given as per-dimension tuples (1D here)
domain = Domain(N=(nx,), dx=(dx,))
# Inspect the geometry jaxdf derived from N and dx.
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
print(f"dx: {domain.dx}")
print(f"Type: {type(domain)}")
Initial Condition#
def init_u0(domain, *, x_start=0.5, x_end=1.0, background=1.0, amplitude=2.0):
    """Build the square-wave ("hat") initial condition on ``domain``'s grid.

    The field equals ``background`` everywhere except on the grid indices
    ``int(x_start / dx)`` through ``int(x_end / dx)`` (inclusive), where it
    equals ``amplitude``. Positions are measured as ``index * dx`` from the
    first grid point, independently of the domain's own coordinate origin.

    Args:
        domain: object exposing ``grid`` (an array whose shape the result
            copies) and ``dx`` (per-axis spacing; only ``dx[0]`` is used).
        x_start: lower bound of the raised region (offset from the first point).
        x_end: upper bound of the raised region (offset from the first point).
        background: value outside the raised region.
        amplitude: value inside the raised region.

    Returns:
        A ``jnp.float64`` array shaped like ``domain.grid`` (float32 unless
        ``jax_enable_x64`` is on). Indexing applies to the leading axis.
    """
    step = domain.dx[0]
    # int() truncates, reproducing the classic ``u[int(.5/dx):int(1/dx + 1)]`` slice
    start = int(x_start / step)
    stop = int(x_end / step + 1)  # +1 makes the upper index inclusive
    u = jnp.full_like(domain.grid, background, dtype=jnp.float64)
    return u.at[start:stop].set(amplitude)
# Evaluate the initial condition on the domain's grid.
u_init = init_u0(domain)
# Show the concrete array type produced by jax.numpy.
print(type(u_init))
Equation of Motion#
JaxDF - API#
from jaxdf.discretization import FiniteDifferences, OnGrid
from jaxdf.operators import gradient
def equation_of_motion(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` of the linear convection PDE, via jaxdf.

    Args:
        t: current time (unused; required by the ``diffrax.ODETerm`` signature).
        u: grid values of the solution on the module-level ``domain``.
        args: the wave speed ``c`` (passed as ``args=c`` to ``diffeqsolve``).

    Returns:
        Grid values of ``-c * du/dx``.
    """
    c = args
    # Set the stencil accuracy at construction time rather than assigning
    # ``u.accuracy`` afterwards: jaxdf fields may be immutable (equinox
    # modules) depending on the installed version, in which case attribute
    # assignment after construction raises.
    u = FiniteDifferences(u, domain, accuracy=2)
    # stagger=[1] requests a shifted stencil along the single spatial axis
    u_rhs = -c * gradient(u, stagger=[1])
    return u_rhs.on_grid
from jaxdf.operators.differential import get_fd_coefficients
# Wrap the initial grid values in a jaxdf FiniteDifferences field (accuracy not set here).
u = FiniteDifferences.from_grid(u_init, domain)
# Presumably the stencil coefficients jaxdf uses for a first derivative with stagger=1.
coeffs = get_fd_coefficients(u, order=1, stagger=1)
# Bare expression: displayed as the cell output in a notebook.
coeffs
c = 1.0  # wave speed
# initialize grid
u_init = init_u0(domain)
# RHS of equation of motion at t=0
out = equation_of_motion(0, u_init, c)
From Scratch#
from jaxdf.discretization import FiniteDifferences, FourierSeries
from jaxdf.operators import gradient
def equation_of_motion_scratch(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` using a first-order backward difference.

    Args:
        t: current time (unused; required by the ``diffrax.ODETerm`` signature).
        u: grid values of the solution; differentiated along axis 0 with the
            spacing of the module-level ``domain``.
        args: the wave speed ``c`` (passed as ``args=c`` to ``diffeqsolve``).

    Returns:
        Array of ``-c * du/dx`` with the same shape as ``u``.
    """
    c = args
    # first-order backward stencil: the upwind direction when c > 0
    du_dx = fdx.difference(
        u, axis=0, accuracy=1, method="backward", step_size=domain.dx[0]
    )
    return -c * du_dx
# RHS of equation of motion
out_scratch = equation_of_motion_scratch(0, u_init, c)
Custom Difference Operator#
from jaxdf import operator
import equinox as eqx
from jaxtyping import Float, Array
class FDParams(eqx.Module):
    """Settings for a finite-difference stencil, consumed by ``difference``.

    All attributes are equinox static fields, i.e. metadata of the module
    rather than array leaves. Construct as ``FDParams()`` for the defaults or
    override any of ``axis``, ``accuracy``, ``method`` positionally or by name.
    """

    # axis along which the difference is taken
    axis: int = eqx.static_field(default=0)
    # accuracy order forwarded to ``fdx.difference``
    accuracy: int = eqx.static_field(default=1)
    # stencil direction forwarded to ``fdx.difference``
    method: str = eqx.static_field(default="backward")
class Field(eqx.Module):
    """Minimal equinox container pairing grid values ``u`` with their ``Domain``.

    NOTE: this class is redefined later in the notebook with a custom
    ``__init__`` that builds ``u`` from an initializer function.
    """

    # grid values of the field
    u: Array
    # geometry the values are defined on
    domain: Domain
@operator
def difference(u: OnGrid, *, params: tp.Optional[FDParams] = None):
    """Finite-difference derivative of an ``OnGrid`` field via ``fdx.difference``.

    Args:
        u: field whose grid values (``u.on_grid``) are differentiated.
        params: stencil settings; ``None`` means ``FDParams()``.

    Returns:
        The pair ``(derivative_field, params_used)``.
    """
    fd = FDParams() if params is None else params
    # differentiate the raw grid values along the requested axis,
    # using that axis's spacing from the field's domain
    derivative = fdx.difference(
        u.on_grid,
        axis=fd.axis,
        accuracy=fd.accuracy,
        method=fd.method,
        step_size=u.domain.dx[fd.axis],
    )
    # rebuild the field around the new values
    return u.replace_params(derivative), fd
from jaxdf import operator
import equinox as eqx
from jaxtyping import Float, Array
class FDParams(eqx.Module):
    """Settings for a finite-difference stencil, consumed by ``difference``.

    All attributes are equinox static fields, i.e. metadata of the module
    rather than array leaves. Construct as ``FDParams()`` for the defaults or
    override any of ``axis``, ``accuracy``, ``method`` positionally or by name.
    """

    # axis along which the difference is taken
    axis: int = eqx.static_field(default=0)
    # accuracy order forwarded to ``fdx.difference``
    accuracy: int = eqx.static_field(default=1)
    # stencil direction forwarded to ``fdx.difference``
    method: str = eqx.static_field(default="backward")
class Field(eqx.Module):
    """Equinox container holding grid values ``u`` and the ``domain`` they live on.

    Args:
        domain: geometry stored on the instance and handed to ``init_fn``.
        init_fn: callable mapping ``domain`` to the initial grid values.
    """

    u: Array
    domain: Domain

    def __init__(self, domain, init_fn: tp.Callable):
        self.domain = domain
        self.u = init_fn(domain)

    @property
    def values(self):
        """The stored grid values (same object as ``u``)."""
        return self.u
@operator
def difference(u: Field, *, params: tp.Optional[FDParams] = None):
    """Finite-difference derivative of a custom ``Field`` via ``fdx.difference``.

    Args:
        u: field whose ``values`` are differentiated.
        params: stencil settings; ``None`` means ``FDParams()``.

    Returns:
        The pair ``(derivative_field, params_used)``, where the derivative
        field is a copy of ``u`` with ``u.u`` replaced.
    """
    fd = FDParams() if params is None else params
    derivative = fdx.difference(
        u.values,
        axis=fd.axis,
        accuracy=fd.accuracy,
        method=fd.method,
        step_size=u.domain.dx[fd.axis],
    )
    # equinox modules are immutable: tree_at returns a copy with ``u`` swapped
    return eqx.tree_at(lambda field: field.u, u, derivative), fd
# initialize Field from the domain and the initial-condition function
u_field = Field(domain, init_u0)
# du/dx of the field with default FDParams (not yet scaled by -c)
u_rhs = difference(u_field)
u_rhs
# same call with explicit stencil settings (these equal the FDParams defaults)
params = FDParams(axis=0, accuracy=1, method="backward")
u_rhs = difference(u_field, params=params)
u_rhs
def equation_of_motion_custom(t: Array, u: Array, args: float):
    """Right-hand side ``-c * du/dx`` using the custom ``difference`` operator.

    Args:
        t: current time (unused; required by the ``diffrax.ODETerm`` signature).
        u: raw grid values on the module-level ``domain``. They are wrapped
            into an ``OnGrid`` field here, so callers pass plain arrays.
        args: the wave speed ``c`` (passed as ``args=c`` to ``diffeqsolve``).

    Returns:
        Grid values of ``-c * du/dx``.
    """
    c = args
    # wrap raw values so the call matches the ``difference(u: OnGrid, ...)`` definition
    u = OnGrid.from_grid(u, domain)
    # params=None -> the operator falls back to FDParams() (axis 0, first-order backward)
    u_rhs = -c * difference(u, params=None)
    return u_rhs.on_grid
out_custom = equation_of_motion_custom(0, u_init, c)

# Overlay the initial condition and each RHS evaluation on one set of axes.
fig, ax = plt.subplots()
for curve, label in (
    (u_init, "Initial Condition"),
    (out, "JaxDF"),
    (out_scratch, "Scratch"),
    (out_custom, "Custom"),
):
    ax.plot(domain.spatial_axis[0], curve[..., 0], label=label)
plt.legend()
plt.show()
Time Stepping#
# temporal parameters
c = 1.0  # wave speed
sigma = 0.2  # Courant number handed to cfl_cond (dt = sigma * dx / c)
# CFL condition
def cfl_cond(dx, c, sigma):
    """Return the time step ``dt = sigma * dx / |c|`` for a target Courant number.

    Args:
        dx: grid spacing.
        c: wave speed; only its magnitude affects the step size.
        sigma: target Courant number ``|c| * dt / dx``; must lie in ``(0, 1]``
            for the explicit upwind scheme to satisfy the CFL condition.

    Returns:
        The time step ``dt``.

    Raises:
        ValueError: if ``sigma`` is outside ``(0, 1]`` or ``c`` is zero.
    """
    # explicit raise instead of ``assert``: asserts are stripped under ``python -O``
    if not 0.0 < sigma <= 1.0:
        raise ValueError(f"sigma must be in (0, 1] for stability, got {sigma}")
    if c == 0:
        raise ValueError("wave speed c must be non-zero to derive a CFL time step")
    return (sigma * dx) / abs(c)
dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)
t0 = 0.0
t1 = 0.5
# NOTE: ``jnp.arange`` excludes its stop value, so the last saved (and
# integrated-to) time is the largest multiple of dt below t1 (about 0.496 for
# dt = 0.008), not t1 itself.
ts = jnp.arange(t0, t1, dt)
saveat = dfx.SaveAt(ts=ts)
# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


def integration_fn(u, f):
    """Integrate ``du/dt = f(t, u, c)`` over the saved time grid ``ts``.

    Reads the module-level ``solver``, ``ts``, ``dt``, ``saveat``, ``c`` and
    ``stepsize_controller`` at call time.

    Args:
        u: initial grid values (``y0``).
        f: vector field with the diffrax signature ``f(t, y, args)``.

    Returns:
        The ``diffrax`` solution object, with states saved at ``ts``.
    """
    return dfx.diffeqsolve(
        terms=dfx.ODETerm(f),
        solver=solver,
        t0=ts.min(),
        t1=ts.max(),
        dt0=dt,
        y0=u,
        saveat=saveat,
        args=c,
        stepsize_controller=stepsize_controller,
    )


sol = integration_fn(u_init, equation_of_motion)
sol_scratch = integration_fn(u_init, equation_of_motion_scratch)
sol_custom = integration_fn(u_init, equation_of_motion_custom)
Analysis#
# Gather the three trajectories into one labelled dataset; squeeze() drops
# singleton axes of the saved states so each variable is (time, x).
da_sol = xr.Dataset(
    data_vars={
        name: (("time", "x"), np.asarray(solution.ys).squeeze())
        for name, solution in (
            ("jaxdf", sol),
            ("scratch", sol_scratch),
            ("custom", sol_custom),
        )
    },
    coords={
        "x": (["x"], np.asarray(domain.spatial_axis[0])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
da_sol
# One space-time panel per scheme (each DataArray transposed to (x, time)).
# No legend: pcolormesh artists carry no labels, so plt.legend() would only
# emit a "No artists with labels" warning.
fig, ax = plt.subplots(nrows=3, figsize=(5, 7))
for panel, (name, title) in zip(
    ax, (("jaxdf", "JaxDF"), ("scratch", "Scratch"), ("custom", "Custom"))
):
    da_sol[name].T.plot.pcolormesh(ax=panel, cmap="gray_r")
    panel.set_title(title)
plt.tight_layout()
plt.show()
# Overlay every 5th saved snapshot for each scheme, one panel per scheme.
fig, ax = plt.subplots(nrows=3, figsize=(5, 7))
for panel, (name, title) in zip(
    ax, (("jaxdf", "JaxDF"), ("scratch", "Scratch"), ("custom", "Custom"))
):
    for i in range(0, len(da_sol.time), 5):
        da_sol[name].isel(time=i).plot.line(ax=panel, color="gray")
    # xarray titles each line plot with its time slice, which would leave the
    # panel labelled with only the last snapshot's time; name the scheme instead.
    panel.set_title(title)
plt.tight_layout()
plt.show()