- Jax-ify
- Don't Reinvent the Wheel
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
# Plotting defaults: reset seaborn, then use the "talk" context scaled down.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX (float32 otherwise); init_u0 requests float64 arrays.
jax.config.update("jax_enable_x64", True)
# IPython/Jupyter magics (not plain Python): inline figures and auto-reloading
# of edited modules between cell executions.
%matplotlib inline
%load_ext autoreload
%autoreload 2
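Let's start with a simple 1D diffusion equation. This PDE is defined as:

$$
\frac{\partial u}{\partial t} = \nu \frac{\partial^2 u}{\partial x^2}
$$

where $\nu$ is the diffusivity.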
Here, we are advised to use a 2nd-order accurate central difference scheme in space and a 1st-order (forward Euler) scheme in time.
Domain¶
# Build a 1D domain on [xmin, xmax] from a point count, then summarize it.
nx = 41
xmin, xmax = 0.0, 2.0
domain = Domain.from_numpoints(xmin=(xmin,), xmax=(xmax,), N=(nx,))

for label, value in (
    ("Size", domain.size),
    ("nDims", domain.ndim),
    ("Grid Size", domain.grid.shape),
    ("Cell Volume", domain.cell_volume),
):
    print(f"{label}: {value}")
Size: (40,)
nDims: 1
Grid Size: (40, 1)
Cell Volume: 0.05
def init_u0(domain, *, x_lo: float = 0.5, x_hi: float = 1.0, value: float = 2.0):
    """Build a top-hat initial condition on ``domain``.

    The field is 1 everywhere. The rows from index ``int(x_lo / dx)`` through
    ``int(x_hi / dx)`` (inclusive) are set to ``value``, where ``dx`` is
    ``domain.dx[0]``.

    Args:
        domain: object exposing ``grid`` (the output has the same shape) and
            ``dx`` (sequence of spacings; only ``dx[0]`` is read).
        x_lo: left edge of the hat, in the units of ``dx``.
        x_hi: right edge of the hat (inclusive), in the units of ``dx``.
        value: amplitude inside the hat.

    Returns:
        Array shaped like ``domain.grid`` (float64 when JAX x64 is enabled).
    """
    dx = domain.dx[0]
    u = jnp.ones_like(domain.grid, dtype=jnp.float64)
    # NOTE(review): indices are computed as x / dx, which assumes the grid's
    # first point sits at x = 0 — confirm if xmin != 0 is ever used.
    start = int(x_lo / dx)
    stop = int(x_hi / dx + 1)  # +1 makes the x_hi index inclusive
    return u.at[start:stop].set(value)
# Visualize the initial top-hat profile along the (squeezed) 1D grid.
u_init = init_u0(domain)
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(
    domain.grid.squeeze(),
    u_init,
)
plt.show()
Equation of Motion¶
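Because we are doing diffusion, we will use a 2nd-order central difference method for the spatial derivative:

$$
\frac{\partial u_i}{\partial t} = \nu \, D_{xx} u_i, \qquad D_{xx} u_i = \frac{u_{i+1} - 2u_i + u_{i-1}}{\Delta x^2}
$$

where $D_{xx}$ is the central finite difference operator.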
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators.functional import diffusion
from typing import Optional
class Diffusion1D(DynamicalSystem):
    """1D diffusion model whose RHS is ``diffusion.diffusion_1D`` applied to the state."""

    @staticmethod
    def equation_of_motion(t: float, u: Array, args):
        """Return the time tendency du/dt.

        Args:
            t: current time (not used by this RHS).
            u: state array.
            args: ``(nu, domain)`` pair; ``nu`` is passed as the diffusivity and
                ``domain.dx[0]`` as the finite-difference step size.
        """
        diffusivity, spatial_domain = args
        return diffusion.diffusion_1D(
            u=u,
            diffusivity=diffusivity,
            step_size=spatial_domain.dx[0],
        )
# SPATIAL DISCRETIZATION: evaluate du/dt once on the initial condition.
nu = 0.3
u_init = init_u0(domain)
out = Diffusion1D.equation_of_motion(t=0, u=u_init, args=(nu, domain))
(out.min(), out.max())
(Array(-120., dtype=float64), Array(120., dtype=float64))
# Overlay the initial condition and its diffusion tendency on one axis.
fig, ax = plt.subplots(figsize=(5, 3))
for series in (u_init, out):
    ax.plot(domain.grid.squeeze(), series)
plt.show()
Time Stepping¶
# TEMPORAL DISCRETIZATION
# Integration window [tmin, tmax] and how many snapshots to save in it.
tmin, tmax = 0.0, 0.2
num_save = 25
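CFL Condition¶

With forward Euler in time and central differences in space, the diffusion scheme is stable only if $\sigma = \nu \Delta t / \Delta x^2 \le 1/2$. We therefore fix $\sigma$ and set $\Delta t = \sigma \Delta x^2 / \nu$.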
# temporal parameters
c = 1.0  # not used by the diffusion solve; only recorded as output metadata
sigma = 0.2  # diffusion number nu * dt / dx**2
nu = 0.2
# Forward Euler + 2nd-order central differences for diffusion is stable only
# for nu * dt / dx**2 <= 1/2, so dt is derived from a target diffusion number.
if not 0.0 < sigma <= 0.5:
    raise ValueError(
        f"sigma={sigma} violates the explicit diffusion stability limit 0 < sigma <= 0.5"
    )
dt = sigma * domain.dx[0] ** 2 / nu

# TEMPORAL DOMAIN and the times at which the solution is saved
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)

# DYNAMICAL SYSTEM
dyn_model = Diffusion1D(t_domain=t_domain, saveat=saveat)

# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()

# initial condition
u_init = init_u0(domain)

# Integrate du/dt from the first to the last save time with fixed steps of dt;
# the state is flattened to 1D to match the (time, x) layout of sol.ys.
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(dyn_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=u_init.squeeze(),
    saveat=saveat,
    args=(nu, domain),
    stepsize_controller=stepsize_controller,
)
Analysis¶
# Wrap the saved trajectory in a labeled (time, x) DataArray for analysis.
da_sol = xr.DataArray(
    np.asarray(sol.ys),
    coords={
        "x": (["x"], np.asarray(domain.coords[0])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    dims=["time", "x"],
    attrs=dict(pde="diffusion", c=c, sigma=sigma),
)
da_sol
Loading...
# Space-time diagram: transposed so x runs vertically and time horizontally.
fig, ax = plt.subplots()
da_sol.transpose().plot.pcolormesh(ax=ax, cmap="gray_r")
plt.show()
# Overlay every 5th saved snapshot as a gray line.
fig, ax = plt.subplots()
for snapshot_index in range(0, da_sol.sizes["time"], 5):
    da_sol.isel(time=snapshot_index).plot.line(ax=ax, color="gray")
plt.show()