This tutorial comes from the following resources:
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import numba as nb
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators import functional as F
# Plot styling: reset seaborn to its defaults, then use the "talk" context
# scaled down so labels fit the small figures below.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX; init_u0 below asks for float64 arrays,
# which JAX would otherwise silently downcast to float32.
jax.config.update("jax_enable_x64", True)
# IPython magics: render figures inline and auto-reload edited modules.
%matplotlib inline
%load_ext autoreload
%autoreload 2
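Let's start with a simple 2D diffusion scheme. This PDE is defined as:
$$\frac{\partial u}{\partial t} = \nu \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right)$$
where $\nu$ is the diffusivity. Here, we are advised to use a 2nd order accurate central difference scheme in space and a 1st order temporal scheme (Euler).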
Domain¶
# Spatial domain on [0, 2] x [0, 2], built from N=(101, 101) points.
xmin, ymin = 0.0, 0.0
xmax, ymax = 2.0, 2.0
nx, ny = 101, 101
domain = Domain.from_numpoints(
    xmin=(xmin, ymin),
    xmax=(xmax, ymax),
    N=(nx, ny),
)
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
Size: (100, 100)
nDims: 2
Grid Size: (100, 100, 2)
Cell Volume: 0.0004
def init_u0(domain, *, lower=0.5, upper=1.0, value=2.0, background=1.0):
    """Build the "hat" initial condition on ``domain``.

    The field equals ``background`` everywhere except a square patch,
    spanning roughly ``[lower, upper]`` along both axes, which is set to
    ``value``. The defaults reproduce the original hat: 1 everywhere and
    2 on [0.5, 1] x [0.5, 1].

    Args:
        domain: object exposing ``size`` (array shape) and ``dx``
            (per-axis step sizes), e.g. a jaxsw ``Domain``.
        lower: coordinate where the patch starts (both axes).
        upper: coordinate where the patch ends (both axes).
        value: field value inside the patch.
        background: field value outside the patch.

    Returns:
        Array of shape ``domain.size``. It is float64 when
        ``jax_enable_x64`` is on, otherwise JAX downcasts it to float32.
    """
    dx, dy = domain.dx[0], domain.dx[1]
    # Coordinates map to indices as coordinate / step size. This assumes the
    # domain starts at 0 (xmin = ymin = 0, as in this notebook). The "+ 1"
    # makes the stop index include the cell at ``upper``.
    x_slice = slice(int(lower / dx), int(upper / dx + 1))
    y_slice = slice(int(lower / dy), int(upper / dy + 1))
    u = jnp.full(domain.size, background, dtype=jnp.float64)
    return u.at[x_slice, y_slice].set(value)
# Shape of the field arrays on this domain (init_u0 uses it as the array shape).
domain.size
(100, 100)
# Grid coordinates (last axis holds x, y) and the initial hat field.
grid = domain.grid
u_init = init_u0(domain)
u_init.shape
(100, 100)
from matplotlib import cm

# 3D surface view of the initial "hat" condition.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
grid_x, grid_y = grid[..., 0], grid[..., 1]
surf = ax.plot_surface(grid_x, grid_y, u_init, cmap=cm.coolwarm)
plt.colorbar(surf, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
# from matplotlib import cm
# fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
# surf = ax.plot_surface(grid[...,0], grid[...,1], u_init_smooth,
# cmap=cm.coolwarm,
# vmin=u.min(), vmax=u.max()+0.1*u.max())
# plt.colorbar(surf, shrink=0.5, aspect=5)
# plt.tight_layout()
# plt.show()
def bc_fn(u: Array) -> Array:
    """Pin the first/last rows and first/last columns of ``u`` to 1.0.

    Works on a copy via ``.at[...].set``; the input array is not modified.
    """
    everything = slice(None)
    edges = (
        (0, everything),
        (-1, everything),
        (everything, 0),
        (everything, -1),
    )
    for edge in edges:
        u = u.at[edge].set(1.0)
    return u
Equation of Motion¶
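Because we are doing diffusion, we will use a central difference for each of the second-derivative terms:
$$\frac{\partial u}{\partial t} = \nu \left( \delta_x^2 u + \delta_y^2 u \right), \qquad \delta_x^2 u_{i,j} = \frac{u_{i+1,j} - 2u_{i,j} + u_{i-1,j}}{\Delta x^2}$$
where $\delta^2$ is the 2nd order accurate central difference method ($\delta_y^2$ is defined analogously along $y$).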
from typing import Optional
from jaxsw._src.operators.functional import diffusion
class Diffusion2D(DynamicalSystem):
    """2D diffusion model whose tendency comes from ``diffusion.diffusion_2D``."""

    @staticmethod
    def equation_of_motion(t: float, u: Array, args):
        """Return the right-hand side du/dt for the state ``u``.

        ``args`` is unpacked as ``(nu, domain)``: ``nu`` is passed as the
        diffusivity, and ``domain.dx`` as the finite-difference step sizes.
        The boundary values from ``bc_fn`` are imposed on ``u`` before the
        operator is evaluated. ``t`` is unused.
        """
        nu, domain = args
        u_bc = bc_fn(u)
        # NOTE(review): the BCs are applied only to the input state. Unless
        # diffusion_2D returns zero on the boundary, the stored state's edges
        # can drift during time stepping. Confirm against the saved solution.
        return diffusion.diffusion_2D(u=u_bc, diffusivity=nu, step_size=domain.dx)
# SPATIAL DISCRETIZATION
# Evaluate the right-hand side once on the initial condition as a sanity check.
nu = 0.05
u_init = init_u0(domain)
args = (nu, domain)
out = Diffusion2D.equation_of_motion(0, u_init, args)
(out.min(), out.max())
(Array(-250., dtype=float64), Array(125., dtype=float64))
# The RHS must have the same shape as the state it is added to.
out.shape, u_init.shape
((100, 100), (100, 100))
from matplotlib import cm

# Side-by-side 2D view of the initial state and its right-hand side.
fig, ax = plt.subplots(ncols=2, figsize=(7, 3))
for panel, field, title in zip(ax, (u_init, out), ("u (initial)", "RHS du/dt")):
    surf = panel.pcolormesh(
        domain.grid[..., 0], domain.grid[..., 1], field, cmap=cm.coolwarm
    )
    panel.set_title(title)
    # Pass ax explicitly so each colorbar sits next to its own panel.
    # Older matplotlib took space from the current axes instead.
    plt.colorbar(surf, ax=panel, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
from matplotlib import cm

# Side-by-side 3D view of the initial state and its right-hand side.
fig, ax = plt.subplots(ncols=2, figsize=(15, 5), subplot_kw={"projection": "3d"})
for panel, field, title in zip(ax, (u_init, out), ("u (initial)", "RHS du/dt")):
    surf = panel.plot_surface(grid[..., 0], grid[..., 1], field, cmap=cm.coolwarm)
    panel.set_title(title)
    # Pass ax explicitly so each colorbar sits next to its own panel.
    # Older matplotlib took space from the current axes instead.
    plt.colorbar(surf, ax=panel, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
Time Stepping¶
# TEMPORAL DISCRETIZATION
# Integration window [tmin, tmax] and the number of snapshots to keep.
tmin, tmax = 0.0, 2.0
num_save = 50
CFL Condition¶
# temporal parameters
# sigma = nu * dt / (dx * dy) is the stability number; dt is solved from it.
nu = 0.075
sigma = 0.2
dt = (sigma * domain.dx[0] * domain.dx[1]) / nu
print(f"Step Size (dt): {dt:.4f}")
Step Size (dt): 0.0011
# Snapshot times: num_save evenly spaced points in [tmin, tmax].
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
# DYNAMICAL SYSTEM
dyn_model = Diffusion2D(t_domain=t_domain, saveat=saveat)
Integration¶
# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()
# SPATIAL DISCRETIZATION
u_init = init_u0(domain).squeeze()
# Rebuild args from the current `nu`, the one dt was derived from.
# Otherwise the solver would use the stale value captured when `args`
# was first built for the RHS sanity check.
args = (nu, domain)
# integration
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(dyn_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=u_init,
    saveat=saveat,
    args=args,
    stepsize_controller=stepsize_controller,
)
Analysis¶
# Wrap the saved snapshots in a labelled DataArray (time, x, y).
# Record the diffusivity actually passed to the solver (first element of args).
nu_used = args[0]
da_sol = xr.DataArray(
    data=np.asarray(sol.ys),
    dims=["time", "x", "y"],
    coords={
        "x": (["x"], np.asarray(domain.coords[0])),
        "y": (["y"], np.asarray(domain.coords[1])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "diffusion", "nu": nu_used, "sigma": sigma},
)
da_sol
Loading...
# First and last saved snapshots, side by side.
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
for panel, t_index in zip(ax, (0, -1)):
    # Transpose to (y, x) so that x is drawn along the horizontal axis.
    da_sol.isel(time=t_index).T.plot.pcolormesh(ax=panel, cmap="RdBu_r")
plt.tight_layout()
plt.show()
# 3D surfaces of the first and last saved snapshots.
fig, ax = plt.subplots(ncols=2, subplot_kw={"projection": "3d"}, figsize=(10, 6))
vmin = None  # e.g. da_sol.min() to share one colour scale across panels
vmax = None  # e.g. da_sol.max()
cbar_kwargs = dict(shrink=0.3, aspect=5)
for panel, t_index in zip(ax, (0, -1)):
    pts = da_sol.isel(time=t_index).T.plot.surface(
        ax=panel, vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
    )
    # Pass ax explicitly so each colorbar sits next to its own panel.
    # Older matplotlib took space from the current axes instead.
    plt.colorbar(pts, ax=panel, **cbar_kwargs)
plt.tight_layout()
plt.show()