2D Linear Convection

J. Emmanuel Johnson, Takaya Uchida
  • Jax-ify
  • Don't Reinvent the Wheel
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain

# Plot styling: reset seaborn to its defaults, then use the "talk" context
# with a reduced font scale.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Turn on JAX's 64-bit mode (JAX defaults to 32-bit floats).
jax.config.update("jax_enable_x64", True)

# Notebook magics: render figures inline and auto-reload edited modules.
%matplotlib inline
%load_ext autoreload
%autoreload 2

Let's start with a simple 2D Linear Advection scheme. This PDE is defined as:

$$\frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} + c\frac{\partial u}{\partial y} = 0$$

Domain

# Number of points requested along each axis.
nx, ny = 81, 81
# Square domain [0, 2] x [0, 2]. The point counts are taken from nx/ny so
# that changing them above actually changes the grid.
domain = Domain.from_numpoints(xmin=(0, 0), xmax=(2.0, 2.0), N=(nx, ny))
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
Size: (80, 80)
nDims: 2
Grid Size: (80, 80, 2)
Cell Volume: 0.0006250000000000001
def init_hat(domain):
    """Build a 2D "hat" (box) initial condition.

    The field is 1 everywhere except on a rectangular index range, where it
    is 2. The range is computed from the grid spacing so that it covers
    roughly 0.5 <= x <= 1 and 0.5 <= y <= 1, assuming the first grid point
    sits at coordinate 0 along each axis.

    Args:
        domain: object exposing ``dx`` (per-axis grid spacing) and ``size``
            (per-axis number of points), each indexable by axis.

    Returns:
        NumPy array of shape ``(domain.size[0], domain.size[1])``.
    """
    # Use the spacing of each axis; the y spacing must come from dx[1] so
    # the box is placed correctly on grids with dx != dy.
    dx, dy = domain.dx[0], domain.dx[1]
    nx, ny = domain.size[0], domain.size[1]

    u = np.ones((nx, ny))

    # Raise the box region to 2 (end index is inclusive of the point at 1.0).
    u[int(0.5 / dx) : int(1 / dx + 1), int(0.5 / dy) : int(1 / dy + 1)] = 2

    return u
def fin_bump(x):
    """Evaluate a bump function supported on the open interval (0, 1).

    Returns 0 for ``x <= 0`` or ``x >= 1``; otherwise returns
    ``100 * exp(-1 / (x - x**2))``. Expects a scalar ``x``.
    """
    outside_support = x <= 0 or x >= 1
    if outside_support:
        return 0

    # x - x^2 is strictly positive inside (0, 1), so the exponent is finite.
    denominator = x - np.power(x, 2.0)
    return 100 * np.exp(-1.0 / denominator)


def init_smooth(domain):
    """Build a smooth 2D initial condition from a product of bump functions.

    The value at index ``(ix, iy)`` is
    ``fin_bump(ix * dx / 1.5) * fin_bump(iy * dy / 1.5) + 1.0``, i.e. a
    smooth bump on a background of 1. Coordinates are taken as
    ``index * spacing``, which assumes the first grid point sits at 0.

    Args:
        domain: object exposing ``dx`` (per-axis grid spacing) and ``size``
            (per-axis number of points), each indexable by axis.

    Returns:
        NumPy array of shape ``(domain.size[0], domain.size[1])``.
    """
    # Use the spacing of each axis; the y spacing must come from dx[1] so
    # the bump is placed correctly on grids with dx != dy.
    dx, dy = domain.dx[0], domain.dx[1]
    nx, ny = domain.size[0], domain.size[1]

    # The field is separable, so evaluate the scalar bump once per grid line
    # (nx + ny calls) instead of once per grid point (nx * ny calls).
    bump_x = np.array([fin_bump(ix * dx / 1.5) for ix in range(nx)], dtype=float)
    bump_y = np.array([fin_bump(iy * dy / 1.5) for iy in range(ny)], dtype=float)

    return np.outer(bump_x, bump_y) + 1.0
# Display the grid spacing along the first axis.
domain.dx[0]
0.025
# Build the initial field: a smooth bump on a background of 1.
# Alternative (box-shaped) initial condition:
# u_init = init_hat(domain)
u_init = init_smooth(domain)

# Convert the NumPy field to a JAX array and display its shape.
u = jnp.asarray(u_init)
u.shape
(80, 80)
# Heat map of the initial field with a colorbar.
fig, ax = plt.subplots()
pts = ax.imshow(u, cmap="Reds")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
<Figure size 640x480 with 2 Axes>
from matplotlib import cm

# Last axis of the grid indexes the coordinate component (0 and 1 are used below).
grid = domain.grid

# 3D surface of the initial field over the two grid coordinates.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(grid[..., 0], grid[..., 1], u, cmap=cm.coolwarm)
plt.colorbar(surf, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
<Figure size 640x480 with 2 Axes>

Steps:

  1. Calculate the RHS
  2. Apply the Boundary Conditions

Boundary Conditions

def bc_fn(u: Array) -> Array:
    """Return a copy of the 2D field ``u`` with all four edges set to 1.0.

    The first and last rows and the first and last columns are overwritten;
    interior values are left unchanged.
    """
    boundary_value = 1.0
    everything = slice(None)

    # (row index, column index) selectors for the four edges, applied in the
    # order: first row, last row, first column, last column.
    edges = (
        (0, everything),
        (-1, everything),
        (everything, 0),
        (everything, -1),
    )
    for edge in edges:
        u = u.at[edge].set(boundary_value)
    return u
# Apply the boundary conditions once to the initial field.
out = bc_fn(u)

Dynamical System (RHS)

from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators.functional import advection
class LinearAdvection2D(DynamicalSystem):
    """2D linear advection with fixed-value boundaries.

    The right-hand side is the negated output of ``advection.advection_2D``
    evaluated on the field after ``bc_fn`` has been applied.
    """

    @staticmethod
    def equation_of_motion(t: float, u: Array, args) -> Array:
        """Return the time derivative of ``u``.

        Args:
            t: current time (not used by this right-hand side).
            u: current 2D field.
            args: ``(c, domain)`` pair. ``c`` is passed as both the ``a`` and
                ``b`` coefficients of ``advection.advection_2D``, and
                ``domain.dx[0]`` is passed as its ``step_size``.

        Returns:
            The negated result of ``advection.advection_2D``.
        """
        speed, grid_domain = args

        # Impose the boundary values before computing the advection term.
        u_bc = bc_fn(u)

        advective_term = advection.advection_2D(
            u=u_bc, a=speed, b=speed, step_size=grid_domain.dx[0]
        )
        return -advective_term
# Advection speed (used for both spatial directions).
c = 0.5
# Evaluate the right-hand side once on the initial field at t = 0.
out = LinearAdvection2D.equation_of_motion(0, u, (c, domain))
# Heat map of the right-hand side evaluated on the initial field.
fig, ax = plt.subplots()
pts = ax.imshow(out, cmap="Reds")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
<Figure size 640x480 with 2 Axes>

Time Stepping

# TEMPORAL DISCRETIZATION
# initialize temporal domain
tmin = 0.0
tmax = 1.0
# number of evenly spaced times at which the solution is saved
num_save = 50

CFL Condition

# CFL condition
def cfl_cond(dx, c, sigma):
    """Return the time step ``sigma * dx / c`` given by the CFL condition.

    Args:
        dx: grid spacing.
        c: advection speed.
        sigma: CFL factor; must be at most 1.

    Returns:
        The time step ``(sigma * dx) / c``.

    Raises:
        ValueError: if ``sigma`` is not ``<= 1.0`` (this also rejects NaN).
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with optimizations (-O).
    if not (sigma <= 1.0):
        raise ValueError(f"sigma must be <= 1.0, got {sigma}")
    return (sigma * dx) / c
# temporal parameters
c = 0.5  # advection speed
sigma = 0.2  # CFL factor passed to cfl_cond (must be <= 1)
# time step from the CFL condition, using the spacing of the first axis
dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)
# INITIAL STATE


u = jnp.asarray(u_init)

# time domain and the evenly spaced times at which the solution is saved
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)

# DYNAMICAL SYSTEM
dyn_model = LinearAdvection2D(t_domain=t_domain, saveat=saveat)

Solver

# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()

# rebuild the initial field: a smooth bump on a background of 1
# Alternative (box-shaped) initial condition:
# u_init = init_hat(domain)
u_init = init_smooth(domain)

# integration
# Integrate from the first to the last save time with the fixed step dt,
# passing (c, domain) through to the right-hand side as `args`.
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(dyn_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=u_init.squeeze(),
    saveat=saveat,
    args=(c, domain),
    stepsize_controller=stepsize_controller,
)

Analysis

# Wrap the saved solution in a labelled DataArray with dims (time, x, y),
# attaching the grid coordinates, the save times and the run parameters.
da_sol = xr.DataArray(
    data=np.asarray(sol.ys),
    dims=["time", "x", "y"],
    coords={
        "x": (["x"], np.asarray(domain.coords[0])),
        "y": (["y"], np.asarray(domain.coords[1])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
da_sol
Loading...
# Side-by-side maps of the first (left) and last (right) saved time steps.
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))


# Transpose so that x runs along the horizontal axis of each panel.
da_sol.isel(time=0).T.plot.pcolormesh(ax=ax[0], cmap="RdBu_r")
da_sol.isel(time=-1).T.plot.pcolormesh(ax=ax[1], cmap="RdBu_r")

plt.tight_layout()
plt.show()
<Figure size 1000x400 with 4 Axes>
# 3D surfaces of the first (left) and last (right) saved time steps.
fig, ax = plt.subplots(ncols=2, subplot_kw={"projection": "3d"}, figsize=(10, 6))

# Shared color limits taken over the whole solution so both panels are comparable.
vmin = da_sol.min()
vmax = da_sol.max()

cbar_kwargs = dict(shrink=0.3, aspect=5)

# Colorbars are added manually so they can use the shared cbar_kwargs.
pts = da_sol.isel(time=0).T.plot.surface(
    ax=ax[0], vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
)
plt.colorbar(pts, **cbar_kwargs)
pts = da_sol.isel(time=-1).T.plot.surface(
    ax=ax[1], vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
)
plt.colorbar(pts, **cbar_kwargs)
plt.tight_layout()
plt.show()
<Figure size 1000x600 with 4 Axes>