2D Diffusion

J. Emmanuel JohnsonTakaya Uchida

This tutorial comes from the following resources:

  • 12 Steps to Navier-Stokes - 2D Diffusion -ipynb
  • Practical Numerical Methods in Python - 2D Heat - github
  • Differential Equations in the Earth Systems - Heat 2d - github
  • Devito CFD Tutorials - 2D Diffusion - ipynb
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import numba as nb
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from jaxsw._src.domain.base import Domain
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators import functional as F

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
jax.config.update("jax_enable_x64", True)

%matplotlib inline
%load_ext autoreload
%autoreload 2

Let's start with a simple 2D Diffusion scheme. This PDE is defined as:

ut=ν2ux2+ν2uy2\begin{aligned} \frac{\partial u}{\partial t} &= \nu\frac{\partial^2 u}{\partial x^2} + \nu\frac{\partial^2 u}{\partial y^2} \end{aligned}

where kk is the diffusivity. Here, we are advised to a 2nd order accurate central difference scheme in space and 1st order temporal scheme (Euler).

Domain

nx, ny = 101, 101
xmin, ymin = 0.0, 0.0
xmax, ymax = 2.0, 2.0
domain = Domain.from_numpoints(xmin=(xmin, ymin), xmax=(xmax, ymax), N=(nx, ny))

print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
Size: (100, 100)
nDims: 2
Grid Size: (100, 100, 2)
Cell Volume: 0.0004

Initial Conditions

IC[u],IC[v]={2for x,y(0.5,1)×(0.5,1)1everywhere else\begin{aligned} \mathcal{IC}[u],\mathcal{IC}[v] &= \begin{cases} 2 && \text{for }x,y \in (0.5, 1)\times(0.5,1) \\ 1 && \text{everywhere else} \end{cases} \end{aligned}
def init_u0(domain):
    """Initial condition from grid"""
    u = jnp.ones(domain.size, dtype=jnp.float64)
    u = u.at[
        int(0.5 / domain.dx[0]) : int(1 / domain.dx[0] + 1),
        int(0.5 / domain.dx[1]) : int(1 / domain.dx[1] + 1),
    ].set(2.0)
    return u
domain.size
(100, 100)
u_init = init_u0(domain)
grid = domain.grid
u_init.shape
(100, 100)
from matplotlib import cm

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(
    grid[..., 0],
    grid[..., 1],
    u_init,
    cmap=cm.coolwarm,
    # vmin=u_init.min(), vmax=u.max()+0.1*u.max()
)
plt.colorbar(surf, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()

# from matplotlib import cm

# fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
# surf = ax.plot_surface(grid[...,0], grid[...,1], u_init_smooth,
#                        cmap=cm.coolwarm,
#                        vmin=u.min(), vmax=u.max()+0.1*u.max())
# plt.colorbar(surf, shrink=0.5, aspect=5)
# plt.tight_layout()
# plt.show()
<Figure size 640x480 with 2 Axes>

Boundary Conditions

BC[u](x)=BC[v](x)=1xΩ\begin{aligned} \mathcal{BC}[u](\mathbf{x}) = \mathcal{BC}[v](\mathbf{x})&= 1 && && \mathbf{x}\in\partial\Omega \end{aligned}
def bc_fn(u: Array) -> Array:
    u = u.at[0, :].set(1.0)
    u = u.at[-1, :].set(1.0)
    u = u.at[:, 0].set(1.0)
    u = u.at[:, -1].set(1.0)
    return u

Equation of Motion

Because we are doing advection, we will use backwards difference for each of the terms.

D2x[u]:=2ux2D2y[u]:=2uy2\begin{aligned} D^-{2}_x[u] &:= \frac{\partial^2 u}{\partial x^2} \\ D^-{2}_y[u] &:= \frac{\partial^2 u}{\partial y^2} \end{aligned}

where D2D^2 is the 2nd order accurate central difference method.

from typing import Optional
from jaxsw._src.operators.functional import diffusion


class Diffusion2D(DynamicalSystem):
    @staticmethod
    def equation_of_motion(t: float, u: Array, args):
        # apply boundary conditions
        u = bc_fn(u)

        nu, domain = args

        u_rhs = diffusion.diffusion_2D(u=u, diffusivity=nu, step_size=domain.dx)

        return u_rhs
# SPATIAL DISCRETIZATION
u_init = init_u0(domain)

nu = 0.05
args = nu, domain
out = Diffusion2D.equation_of_motion(0, u_init, args)


out.min(), out.max()
(Array(-250., dtype=float64), Array(125., dtype=float64))
out.shape, u_init.shape
((100, 100), (100, 100))
from matplotlib import cm

fig, ax = plt.subplots(ncols=2, figsize=(7, 3))
surf = ax[0].pcolormesh(
    domain.grid[..., 0], domain.grid[..., 1], u_init, cmap=cm.coolwarm
)
plt.colorbar(surf, shrink=0.5, aspect=5)

surf = ax[1].pcolormesh(domain.grid[..., 0], domain.grid[..., 1], out, cmap=cm.coolwarm)
plt.colorbar(surf, shrink=0.5, aspect=5)

plt.tight_layout()
plt.show()
<Figure size 700x300 with 4 Axes>
from matplotlib import cm

fig, ax = plt.subplots(ncols=2, figsize=(15, 5), subplot_kw={"projection": "3d"})
surf = ax[0].plot_surface(grid[..., 0], grid[..., 1], u_init, cmap=cm.coolwarm)
plt.colorbar(surf, shrink=0.5, aspect=5)
surf = ax[1].plot_surface(grid[..., 0], grid[..., 1], out, cmap=cm.coolwarm)
plt.colorbar(surf, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
<Figure size 1500x500 with 4 Axes>

Time Stepping

# TEMPORAL DISCRETIZATION
# initialize temporal domain
tmin = 0.0
tmax = 2.0
num_save = 50

CFD Condition

# temporal parameters
sigma = 0.2
nu = 0.075
dt = sigma * domain.dx[0] * domain.dx[1] / nu
print(f"Step Size (dt): {dt:.4f}")
Step Size (dt): 0.0011
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)
ts = jnp.linspace(tmin, tmax, num_save)
saveat = dfx.SaveAt(ts=ts)

# DYNAMICAL SYSTEM
dyn_model = Diffusion2D(t_domain=t_domain, saveat=saveat)

Integration

# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()

# SPATIAL DISCRETIZATION
u_init = init_u0(domain).squeeze()


# integration
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(dyn_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=u_init,
    saveat=saveat,
    args=args,
    stepsize_controller=stepsize_controller,
)

Analysis

da_sol = xr.DataArray(
    data=np.asarray(sol.ys),
    dims=["time", "x", "y"],
    coords={
        "x": (["x"], np.asarray(domain.coords[0])),
        "y": (["y"], np.asarray(domain.coords[1])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"pde": "linear_convection", "nu": nu, "sigma": sigma},
)
da_sol
Loading...
fig, ax = plt.subplots(ncols=2, figsize=(10, 4))


da_sol.isel(time=0).T.plot.pcolormesh(ax=ax[0], cmap="RdBu_r")
da_sol.isel(time=-1).T.plot.pcolormesh(ax=ax[1], cmap="RdBu_r")

plt.tight_layout()
plt.show()
<Figure size 1000x400 with 4 Axes>
fig, ax = plt.subplots(ncols=2, subplot_kw={"projection": "3d"}, figsize=(10, 6))

vmin = None  # da_sol.min()
vmax = None  # da_sol.max()

cbar_kwargs = dict(shrink=0.3, aspect=5)

pts = da_sol.isel(time=0).T.plot.surface(
    ax=ax[0], vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
)
plt.colorbar(pts, **cbar_kwargs)
pts = da_sol.isel(time=-1).T.plot.surface(
    ax=ax[1], vmin=vmin, vmax=vmax, cmap="coolwarm", add_colorbar=False
)
plt.colorbar(pts, **cbar_kwargs)
plt.tight_layout()
plt.show()
<Figure size 1000x600 with 4 Axes>