1D Diffusion

J. Emmanuel Johnson, Takaya Uchida
  • Jax-ify
  • Don't Reinvent the Wheel
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain

# Plot styling: reset seaborn to its defaults, then use the "talk" context
# scaled down to 70% font size.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX (init_u0 below requests float64 arrays).
jax.config.update("jax_enable_x64", True)

# IPython/Jupyter magics: render figures inline and auto-reload edited
# modules (e.g. jaxsw) before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

Let's start with a simple 1D diffusion equation, defined as:

$$
\frac{\partial u}{\partial t} = \nu\frac{\partial^2 u}{\partial x^2}
$$

Here, we use a 2nd-order accurate central difference scheme in space and a 1st-order (forward Euler) scheme in time.

Domain

# Build the spatial domain on [xmin, xmax] from nx via Domain.from_numpoints.
nx = 41
xmin, xmax = 0.0, 2.0
domain = Domain.from_numpoints(xmin=(xmin,), xmax=(xmax,), N=(nx,))

# Report the resulting domain properties, one per line.
for _label, _value in (
    ("Size", domain.size),
    ("nDims", domain.ndim),
    ("Grid Size", domain.grid.shape),
    ("Cell Volume", domain.cell_volume),
):
    print(f"{_label}: {_value}")
Size: (40,)
nDims: 1
Grid Size: (40, 1)
Cell Volume: 0.05

Initial Conditions

$$
\mathcal{IC}[u] = \begin{cases} 2 & \text{for } x \in (0.5, 1) \\ 1 & \text{everywhere else} \end{cases}
$$
def init_u0(domain, *, lo: float = 0.5, hi: float = 1.0, value: float = 2.0):
    """Initial condition from grid: a top-hat raised above a field of ones.

    Returns an array of ones shaped like ``domain.grid``. Along the first axis,
    entries with index ``round(lo / dx)`` through ``round(hi / dx)``
    (inclusive) are set to ``value``, where ``dx = domain.dx[0]``.

    Args:
        domain: object exposing ``grid`` (array) and ``dx`` (sequence of grid
            spacings); only ``dx[0]`` is used.
        lo: lower bound of the raised region, measured from index 0 in units
            of ``dx``. The grid origin is NOT subtracted, so this assumes
            ``xmin == 0`` as in this notebook.
        hi: upper bound of the raised region (same convention as ``lo``).
        value: value assigned inside the raised region.

    Returns:
        A float64 array (float32 if JAX x64 mode is disabled) shaped like
        ``domain.grid``.
    """
    dx = float(domain.dx[0])
    # round() rather than int(): int() truncates, so a ratio that lands just
    # below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would shift the
    # index by one.
    start = round(lo / dx)
    stop = round(hi / dx) + 1  # +1 so the index of `hi` itself is included
    u = jnp.ones_like(domain.grid, dtype=jnp.float64)
    return u.at[start:stop].set(value)
# Evaluate the initial condition on the grid and plot it against x.
u_init = init_u0(domain=domain)

fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(jnp.squeeze(domain.grid), u_init)
plt.show()
<Figure size 500x300 with 1 Axes>

Equation of Motion

Because we are solving a diffusion problem, we use a 2nd-order central difference scheme for the second spatial derivative:

$$
D_x[u] := \frac{\partial^2 u}{\partial x^2}
$$

where $D_x$ denotes the central finite difference approximation of this derivative.

from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators.functional import diffusion
from typing import Optional


class Diffusion1D(DynamicalSystem):
    """Dynamical system whose tendency is the 1D diffusion term nu * u_xx."""

    @staticmethod
    def equation_of_motion(t: float, u: Array, args):
        """Return the right-hand side du/dt for the state ``u``.

        Args:
            t: current time; not used by this right-hand side.
            u: current state array.
            args: a ``(nu, domain)`` pair. ``nu`` is the diffusivity and
                ``domain.dx[0]`` is used as the finite-difference step size.

        Returns:
            The output of ``diffusion.diffusion_1D`` for these inputs.
        """
        diffusivity, grid_domain = args
        return diffusion.diffusion_1D(
            u=u,
            diffusivity=diffusivity,
            step_size=grid_domain.dx[0],
        )
# Sanity check: evaluate the right-hand side once on the initial condition.
nu = 0.3
u_init = init_u0(domain)

out = Diffusion1D.equation_of_motion(t=0, u=u_init, args=(nu, domain))

# Range of the tendency (displayed as the cell output).
(out.min(), out.max())
(Array(-120., dtype=float64), Array(120., dtype=float64))
# Overlay the initial state and its tendency (the RHS output) on one axis.
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(domain.grid.squeeze(), u_init, domain.grid.squeeze(), out)
plt.show()
<Figure size 500x300 with 1 Axes>

Time Stepping

# TEMPORAL DISCRETIZATION
# Integration window [tmin, tmax] and the number of snapshots to store.
tmin, tmax = 0.0, 0.2
num_save = 25

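CFL Condition

For the explicit (forward Euler) scheme with central differences, stability requires the diffusion number $\sigma = \nu\,\Delta t/\Delta x^2 \le 1/2$. We therefore pick $\sigma$ and set $\Delta t = \sigma\,\Delta x^2/\nu$.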

# temporal parameters
c = 1.0  # not used by the diffusion model; only recorded as run metadata
sigma = 0.2  # diffusion number: nu * dt / dx**2
nu = 0.2  # diffusivity
# Forward Euler + 2nd-order central differences is stable only when
# nu * dt / dx**2 <= 1/2; fail fast rather than silently blowing up.
if not 0.0 < sigma <= 0.5:
    raise ValueError(f"sigma={sigma} violates the stability bound 0 < sigma <= 0.5")
dt = sigma * domain.dx[0] ** 2 / nu


# Time domain and the output times at which the solution is stored.
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)

# Solver configuration: forward Euler with a fixed step of size dt.
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()

# DYNAMICAL SYSTEM
dyn_model = Diffusion1D(t_domain=t_domain, saveat=saveat)

# Fresh initial condition; squeezed to drop singleton axes before solving.
u_init = init_u0(domain)

# Integrate from the first to the last save time (ts is increasing, so these
# are its minimum and maximum).
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(dyn_model.equation_of_motion),
    solver=solver,
    t0=ts[0],
    t1=ts[-1],
    dt0=dt,
    y0=jnp.squeeze(u_init),
    saveat=saveat,
    args=(nu, domain),
    stepsize_controller=stepsize_controller,
)

Analysis

# Wrap the saved trajectory in a labelled xarray DataArray with dims
# (time, x), using the domain's first coordinate axis and the save times.
da_sol = xr.DataArray(
    data=np.asarray(sol.ys),
    dims=["time", "x"],
    coords={
        "x": (["x"], np.asarray(domain.coords[0])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    # Record the run parameters so the array is self-describing. nu and dt
    # define this diffusion run; "c" and "sigma" keep their existing keys.
    attrs={
        "pde": "diffusion",
        "c": c,
        "sigma": sigma,
        "nu": float(nu),
        "dt": float(dt),
    },
)
da_sol
Loading...
# Space-time diagram of the solution: transposed to (x, time) so that, with
# xarray's default axis inference, time runs along the horizontal axis.
fig, ax = plt.subplots()
da_sol.transpose("x", "time").plot.pcolormesh(ax=ax, cmap="gray_r")
plt.show()
<Figure size 640x480 with 2 Axes>
# Overlay every 5th saved snapshot of u(x) in gray on a single axis.
fig, ax = plt.subplots()
for snapshot in da_sol.isel(time=slice(None, None, 5)):
    snapshot.plot.line(ax=ax, color="gray")
plt.show()
<Figure size 640x480 with 1 Axes>