- Jax-ify
- Don't Reinvent the Wheel
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain
# Plot styling: restore seaborn's defaults first, then switch to the "talk"
# context with a reduced font scale. Order matters: reset, then configure.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit types in JAX for the rest of the session.
jax.config.update("jax_enable_x64", True)
# IPython/Jupyter magics (not valid in a plain Python script): inline figures
# and automatic reloading of edited modules before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2
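Let's start with a simple 2D linear advection scheme. This PDE is defined as:

$$
\frac{\partial u}{\partial t} + c\,\frac{\partial u}{\partial x} + c\,\frac{\partial u}{\partial y} = 0
$$

where $u(x, y, t)$ is the advected field and $c$ is the constant advection speed, taken to be the same in both spatial directions.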
Domain
# Number of grid points requested along x and y. Defined once and reused
# below so the resolution cannot drift between the two places it appears.
nx, ny = 81, 81
# Rectangular domain spanning [0, 2] x [0, 2].
domain = Domain.from_numpoints(xmin=(0, 0), xmax=(2.0, 2.0), N=(nx, ny))

# Summarise the resulting discretisation.
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
Size: (80, 80)
nDims: 2
Grid Size: (80, 80, 2)
Cell Volume: 0.0006250000000000001
def init_hat(domain):
    """Build a 2D "top-hat" initial condition on ``domain``.

    The field equals 1 everywhere except on a rectangular index block,
    where it equals 2. The block covers indices ``int(0.5 / d)`` up to and
    including ``int(1 / d)`` along each axis, with ``d`` that axis' spacing.

    Args:
        domain: Object exposing ``dx`` (per-axis grid spacing) and ``size``
            (per-axis number of grid points).

    Returns:
        NumPy array of shape ``(size[0], size[1])``.
    """
    # Each axis uses its own spacing. Reusing dx[0] for y would misplace the
    # hat on grids whose x and y spacings differ.
    dx, dy = domain.dx[0], domain.dx[1]
    nx, ny = domain.size[0], domain.size[1]

    u = np.ones((nx, ny))
    # NOTE(review): the index arithmetic assumes the grid origin is at 0
    # along both axes -- confirm against the Domain definition.
    u[int(0.5 / dx) : int(1 / dx + 1), int(0.5 / dy) : int(1 / dy + 1)] = 2
    return u
def fin_bump(x):
    """Evaluate a bump profile at the scalar ``x``.

    Returns 0 when ``x <= 0`` or ``x >= 1``; otherwise returns
    ``100 * exp(-1 / (x - x**2))``.
    """
    outside_support = x <= 0 or x >= 1
    return 0 if outside_support else 100 * np.exp(-1.0 / (x - np.power(x, 2.0)))
def init_smooth(domain):
    """Build a smooth 2D bump initial condition on ``domain``.

    The value at index ``(ix, iy)`` is
    ``fin_bump(ix * dx / 1.5) * fin_bump(iy * dy / 1.5) + 1.0``.

    Args:
        domain: Object exposing ``dx`` (per-axis grid spacing) and ``size``
            (per-axis number of grid points).

    Returns:
        NumPy float array of shape ``(size[0], size[1])``.
    """
    # Each axis uses its own spacing. Reusing dx[0] for y would distort the
    # bump on grids whose x and y spacings differ.
    dx, dy = domain.dx[0], domain.dx[1]
    nx, ny = domain.size[0], domain.size[1]

    # The field is separable, so evaluate one 1D profile per axis
    # (nx + ny calls to fin_bump instead of nx * ny) and combine them with
    # an outer product.
    # NOTE(review): coordinates are taken as index * spacing, i.e. the grid
    # origin is assumed to be 0 -- confirm against the Domain definition.
    profile_x = np.array([fin_bump(ix * dx / 1.5) for ix in range(nx)], dtype=float)
    profile_y = np.array([fin_bump(iy * dy / 1.5) for iy in range(ny)], dtype=float)
    return np.outer(profile_x, profile_y) + 1.0
# Display the grid spacing along the first axis.
domain.dx[0]
0.025
# Initialise the field with the smooth bump profile.
# Alternative (top-hat) initial condition:
# u_init = init_hat(domain)
u_init = init_smooth(domain)
# Convert the NumPy initial condition to a JAX array.
u = jnp.asarray(u_init)
# Display the field's shape.
u.shape
(80, 80)
# Show the initial field as an image with a colorbar.
fig, ax = plt.subplots()
pts = ax.imshow(u, cmap="Reds")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
from matplotlib import cm

# Show the initial field as a 3D surface; the last axis of the grid array
# holds the two coordinate components used for the horizontal axes.
grid = domain.grid
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(grid[..., 0], grid[..., 1], u, cmap=cm.coolwarm)
plt.colorbar(surf, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
Steps:
- Calculate the RHS
- Apply the Boundary Conditions
Boundary Conditions
def bc_fn(u: Array) -> Array:
    """Return a copy of the 2D field ``u`` with all four edges set to 1.0.

    The first and last rows and the first and last columns are overwritten;
    interior values are left as they are.
    """
    everything = slice(None)
    edges = (
        (0, everything),   # first row
        (-1, everything),  # last row
        (everything, 0),   # first column
        (everything, -1),  # last column
    )
    for edge in edges:
        u = u.at[edge].set(1.0)
    return u
# Apply the boundary conditions to the initial field as a quick check.
out = bc_fn(u)
Dynamical System (RHS)
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators.functional import advection
class LinearAdvection2D(DynamicalSystem):
    """2D linear advection system exposing the right-hand side for time stepping."""

    @staticmethod
    def equation_of_motion(t: float, u: Array, args) -> Array:
        """Compute the time tendency of the field ``u``.

        Args:
            t: Current time; not used by this right-hand side.
            u: Current state of the field.
            args: Tuple ``(c, domain)`` holding the advection speed and the
                spatial domain.

        Returns:
            The negated advection term, computed after the boundary
            conditions have been applied to ``u``.
        """
        speed, grid = args
        bounded = bc_fn(u)
        # The same speed is used for both directions, and only the spacing
        # of the first axis is passed as the step size.
        advective_term = advection.advection_2D(
            u=bounded, a=speed, b=speed, step_size=grid.dx[0]
        )
        return -advective_term
# Advection speed.
c = 0.5
# Evaluate the right-hand side once on the initial field (time value 0).
out = LinearAdvection2D.equation_of_motion(0, u, (c, domain))
# Show the resulting tendency as an image with a colorbar.
fig, ax = plt.subplots()
pts = ax.imshow(out, cmap="Reds")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
Time Stepping
# TEMPORAL DISCRETIZATION
# Start and end of the integration window.
tmin = 0.0
tmax = 1.0
# Number of time points at which the solution is saved.
num_save = 50
CFL Condition
# CFL condition
def cfl_cond(dx, c, sigma):
    """Return the time step ``sigma * dx / c`` for a given Courant number.

    Args:
        dx: Grid spacing.
        c: Advection speed.
        sigma: Courant (CFL) number; must satisfy ``sigma <= 1.0``.

    Returns:
        The time step ``(sigma * dx) / c``.

    Raises:
        AssertionError: If ``sigma <= 1.0`` does not hold.
    """
    # Raise explicitly rather than via a bare ``assert`` so the check is not
    # stripped under ``python -O``. The exception type is kept so existing
    # callers see the same failure mode.
    if not sigma <= 1.0:
        raise AssertionError(f"CFL number sigma must be <= 1.0, got {sigma}")
    return (sigma * dx) / c
# temporal parameters
# Advection speed and Courant number; the time step follows from the CFL
# relation using the spacing of the first axis.
c = 0.5
sigma = 0.2
dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)
# INITIAL STATE as a JAX array
u = jnp.asarray(u_init)
# Time domain for the model and the num_save evenly spaced save times.
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)
# DYNAMICAL SYSTEM
dyn_model = LinearAdvection2D(t_domain=t_domain, saveat=saveat)
Solver
# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()
# Initialise the field with the smooth bump profile.
# Alternative (top-hat) initial condition:
# u_init = init_hat(domain)
u_init = init_smooth(domain)
# integration
# Integrate between the first and last save times with fixed step dt; the
# advection speed and domain reach the right-hand side through `args`.
sol = dfx.diffeqsolve(
terms=dfx.ODETerm(dyn_model.equation_of_motion),
solver=solver,
t0=ts.min(),
t1=ts.max(),
dt0=dt,
y0=u_init.squeeze(),
saveat=saveat,
args=(c, domain),
stepsize_controller=stepsize_controller,
)
Analysis
# Package the saved solution as a labelled (time, x, y) DataArray, using the
# domain's per-axis coordinates and the solver's save times, and record the
# run parameters as attributes.
da_sol = xr.DataArray(
data=np.asarray(sol.ys),
dims=["time", "x", "y"],
coords={
"x": (["x"], np.asarray(domain.coords[0])),
"y": (["y"], np.asarray(domain.coords[1])),
"time": (["time"], np.asarray(sol.ts)),
},
attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
# Display the DataArray summary.
da_sol
Loading...
# Side-by-side colormesh plots of the first (left) and last (right) saved
# time; each slice is transposed before plotting.
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
da_sol.isel(time=0).T.plot.pcolormesh(ax=ax[0], cmap="RdBu_r")
da_sol.isel(time=-1).T.plot.pcolormesh(ax=ax[1], cmap="RdBu_r")
plt.tight_layout()
plt.show()
# Side-by-side 3D surfaces of the first (left) and last (right) saved time.
fig, ax = plt.subplots(ncols=2, subplot_kw={"projection": "3d"}, figsize=(10, 6))
# Shared colour limits taken over the whole solution so both panels are
# directly comparable.
vmin = da_sol.min()
vmax = da_sol.max()
cbar_kwargs = dict(shrink=0.3, aspect=5)
# xarray's own colorbar is disabled; one is added by hand with the shared
# sizing options instead.
pts = da_sol.isel(time=0).T.plot.surface(
ax=ax[0], vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
)
plt.colorbar(pts, **cbar_kwargs)
pts = da_sol.isel(time=-1).T.plot.surface(
ax=ax[1], vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
)
plt.colorbar(pts, **cbar_kwargs)
plt.tight_layout()
plt.show()