- JAX-ify
- Don't reinvent the wheel
 
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
jax.config.update("jax_enable_x64", True)
%matplotlib inline
%load_ext autoreload
%autoreload 2Let's start with a simple 2D Linear Advection scheme. This PDE is defined as:
Domain
# Number of grid points requested along each axis.
nx, ny = 81, 81
domain = Domain.from_numpoints(xmin=(0, 0), xmax=(2.0, 2.0), N=(nx, ny))
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
# Output:
# Size: (80, 80)
# nDims: 2
# Grid Size: (80, 80, 2)
# Cell Volume: 0.0006250000000000001
# NOTE(review): N=(81, 81) yields a size of (80, 80) above; confirm whether
# `from_numpoints` counts nodes while `size` counts cells.
def init_hat(domain):
    """Return a 'hat' (step) initial condition on ``domain``.

    The field is 1 everywhere and 2 on the index ranges
    ``[int(0.5 / dx), int(1 / dx + 1))`` along x and
    ``[int(0.5 / dy), int(1 / dy + 1))`` along y. Assuming ``x = i * dx``
    (grid starting at 0), this is roughly the square ``[0.5, 1] x [0.5, 1]``.

    Args:
        domain: object exposing ``dx`` (per-axis grid spacing) and ``size``
            (number of points per axis).

    Returns:
        A NumPy array of shape ``(domain.size[0], domain.size[1])``.
    """
    # Each axis uses its own spacing, so non-square grids work too.
    dx, dy = float(domain.dx[0]), float(domain.dx[1])
    nx, ny = domain.size[0], domain.size[1]
    u = np.ones((nx, ny))
    u[int(0.5 / dx) : int(1 / dx + 1), int(0.5 / dy) : int(1 / dy + 1)] = 2
    return u


def fin_bump(x):
    """Smooth, compactly supported bump on (0, 1), scaled by 100.

    Returns ``100 * exp(-1 / (x - x**2))`` for ``0 < x < 1`` and ``0``
    otherwise. The exponent tends to -inf at both ends of the support, so the
    bump goes to 0 smoothly there.
    """
    if x <= 0 or x >= 1:
        return 0
    return 100 * np.exp(-1.0 / (x - np.power(x, 2.0)))


def init_smooth(domain):
    """Return a smooth bump initial condition on ``domain``.

    ``u[ix, iy] = fin_bump(x / 1.5) * fin_bump(y / 1.5) + 1`` with
    ``x = ix * dx`` and ``y = iy * dy``. The bump is therefore supported on
    ``(0, 1.5) x (0, 1.5)`` on top of a background value of 1.

    Args:
        domain: object exposing ``dx`` (per-axis grid spacing) and ``size``
            (number of points per axis).

    Returns:
        A NumPy float array of shape ``(domain.size[0], domain.size[1])``.
    """
    # Each axis uses its own spacing, so non-square grids work too.
    dx, dy = float(domain.dx[0]), float(domain.dx[1])
    nx, ny = domain.size[0], domain.size[1]
    bump = np.vectorize(fin_bump, otypes=[float])
    x = np.arange(nx) * dx
    y = np.arange(ny) * dy
    # The profile is separable: evaluate the 1D bump once per axis and take
    # the outer product instead of calling fin_bump nx * ny times.
    return np.outer(bump(x / 1.5), bump(y / 1.5)) + 1.0


# Grid spacing along x (notebook display cell, which printed 0.025):
# domain.dx[0]

# Initialize the field with the smooth bump profile.
# Alternative step profile: u_init = init_hat(domain)
u_init = init_smooth(domain)
u = jnp.asarray(u_init)
# u.shape -> (80, 80)

fig, ax = plt.subplots()
# The field is indexed u[ix, iy]; transpose so x runs horizontally, and put
# the origin at the bottom-left like the xarray plots further down.
pts = ax.imshow(u.T, cmap="Reds", origin="lower")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
from matplotlib import cm

# 3D surface of the initial field over the domain grid: the first component of
# the last grid axis goes on the plot's x axis, the second on its y axis.
grid = domain.grid
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
surf = ax.plot_surface(
    grid[..., 0],
    grid[..., 1],
    u,
    cmap=cm.coolwarm,
)
plt.colorbar(surf, aspect=5, shrink=0.5)
plt.tight_layout()
plt.show()
Steps:
- Calculate the right-hand side (RHS)
- Apply the boundary conditions
 
Boundary Conditions
def bc_fn(u: Array, value: float = 1.0) -> Array:
    """Apply constant (Dirichlet) boundary values on all four edges of a 2D field.

    Args:
        u: 2D JAX array.
        value: value written to the first/last row and the first/last column.
            Defaults to 1.0, the background level used by the initial
            conditions.

    Returns:
        A new array with the edges set to ``value`` (``.at[...].set`` does not
        modify ``u`` in place).
    """
    u = u.at[0, :].set(value)
    u = u.at[-1, :].set(value)
    u = u.at[:, 0].set(value)
    u = u.at[:, -1].set(value)
    return u


out = bc_fn(u)

# Dynamical System (RHS)
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators.functional import advection


class LinearAdvection2D(DynamicalSystem):
    """2D linear advection with the same constant speed ``c`` along x and y."""

    @staticmethod
    def equation_of_motion(t: float, u: Array, args) -> Array:
        """Return du/dt, for use as the vector field of a diffrax ``ODETerm``.

        Args:
            t: current time (unused; the right-hand side does not depend on t).
            u: 2D field.
            args: tuple ``(c, domain)`` of advection speed and spatial domain.

        Returns:
            The negated output of ``advection.advection_2D``, evaluated on
            ``u`` after the boundary values have been enforced.
        """
        c, domain = args
        # Enforce the boundary values before differencing.
        u = bc_fn(u)
        # NOTE(review): only dx along axis 0 is passed as the step size for
        # both directions, which assumes a uniform grid (dx == dy) — confirm.
        rhs = advection.advection_2D(u=u, a=c, b=c, step_size=domain.dx[0])
        return -rhs


c = 0.5
out = LinearAdvection2D.equation_of_motion(0, u, (c, domain))

fig, ax = plt.subplots()
# Transpose so x runs horizontally, with the origin at the bottom-left.
pts = ax.imshow(out.T, cmap="Reds", origin="lower")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
Time Stepping
# TEMPORAL DISCRETIZATION
# initialize temporal domain
tmin = 0.0
tmax = 1.0
# number of snapshots saved between tmin and tmax
num_save = 50

# CFL Condition
# CFL condition
def cfl_cond(dx, c, sigma):
    """Return the time step allowed by the CFL condition, ``sigma * dx / |c|``.

    Args:
        dx: grid spacing.
        c: advection speed (either sign; its magnitude sets the limit).
        sigma: Courant number, required to satisfy ``0 < sigma <= 1``.

    Returns:
        The time step ``dt`` such that ``|c| * dt / dx == sigma``.

    Raises:
        ValueError: if ``sigma`` is outside ``(0, 1]`` or ``c`` is zero.
    """
    # Explicit checks rather than ``assert``, which is stripped under ``python -O``.
    if not 0.0 < sigma <= 1.0:
        raise ValueError(f"Courant number sigma must be in (0, 1], got {sigma}")
    if c == 0:
        raise ValueError("advection speed c must be non-zero")
    # The stability limit depends on the speed's magnitude, not its direction.
    return (sigma * dx) / abs(c)


# temporal parameters
c = 0.5
sigma = 0.2
dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)

# INITIAL CONDITION
u = jnp.asarray(u_init)
# TIME DOMAIN AND OUTPUT TIMES
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
# num_save evenly spaced snapshot times in [tmin, tmax]
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)
# DYNAMICAL SYSTEM
dyn_model = LinearAdvection2D(t_domain=t_domain, saveat=saveat)

# Solver
# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()
# Initial condition: smooth bump (use init_hat(domain) for a step profile).
u_init = init_smooth(domain)
# integration from the first to the last save time with the CFL time step
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(dyn_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=u_init.squeeze(),
    saveat=saveat,
    args=(c, domain),
    stepsize_controller=stepsize_controller,
)

# Analysis
# Wrap the saved snapshots (time, x, y) in a labelled xarray DataArray.
da_sol = xr.DataArray(
    data=np.asarray(sol.ys),
    dims=["time", "x", "y"],
    coords={
        "x": (["x"], np.asarray(domain.coords[0])),
        "y": (["y"], np.asarray(domain.coords[1])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
da_sol
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
# Left: first saved snapshot; right: last. Transposed so x runs horizontally.
for panel, time_idx in zip(ax, (0, -1)):
    da_sol.isel(time=time_idx).T.plot.pcolormesh(ax=panel, cmap="RdBu_r")
plt.tight_layout()
plt.show()
fig, ax = plt.subplots(ncols=2, subplot_kw={"projection": "3d"}, figsize=(10, 6))
# Shared colour limits so both snapshots are directly comparable. The xarray
# reductions return 0-d DataArrays; pass plain floats to matplotlib.
vmin = float(da_sol.min())
vmax = float(da_sol.max())
cbar_kwargs = dict(shrink=0.3, aspect=5)
# Left panel: first saved snapshot; right panel: last saved snapshot.
for panel, time_idx in zip(ax, (0, -1)):
    pts = da_sol.isel(time=time_idx).T.plot.surface(
        ax=panel, vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
    )
    plt.colorbar(pts, **cbar_kwargs)
plt.tight_layout()
plt.show()