In the previous tutorial, we broke up all of the pieces into abstract concepts. We also looked at how to jax-ify everything by writing functions for all of the bits and pieces, ranging from the domain definition to the stepper. In this tutorial, we will repeat the same procedure but define some pytree objects to operate on. Often, we need more than just the array values: there is a lot of auxiliary information, including details about the domain.
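Here is a rough idea of what "a pytree that carries auxiliary information" means. The sketch below uses plain JAX and a hypothetical `Field1D` class, which is not the jaxsw `Field`. The array values are a pytree leaf that JAX transforms. The grid spacing is static auxiliary data that simply rides along.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Field1D:
    """Hypothetical container (not jaxsw's Field): values are a leaf, dx is static aux data."""

    def __init__(self, values, dx):
        self.values = values
        self.dx = dx

    def tree_flatten(self):
        # children are the leaves JAX transforms; aux_data is carried along as-is
        return (self.values,), (self.dx,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])


field = Field1D(values=jnp.ones(5), dx=0.04)

# tree_map only touches the leaves; dx is passed through untouched
doubled = jax.tree_util.tree_map(lambda v: 2.0 * v, field)


@jax.jit
def integral(f: Field1D):
    # the whole object passes through jit; f.dx is a plain Python float here
    return jnp.sum(f.values) * f.dx


print(doubled.values, doubled.dx)
print(integral(field))
```

Putting `dx` in the auxiliary data means `jit` treats it as a compile-time constant. Changing it triggers a recompile, which is usually fine for a fixed grid.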
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import jaxdf
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from jaxtyping import Float, Array
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
jax.config.update("jax_enable_x64", True)
%matplotlib inline
%load_ext autoreload
%autoreload 2
Recall: PDE - 1D Linear Convection
Let's continue from the previous tutorial. Recall that we are working with the 1D linear convection equation:
$$\frac{\partial u}{\partial t} + c\,\frac{\partial u}{\partial x} = 0 \qquad (1)$$
For the PDE (1), we use a backward difference in space and a forward difference in time.
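Putting these two choices together gives the fully discrete update. This is a sketch in standard notation, which is an assumption since it is not spelled out here. $u_i^n$ is the value at grid point $x_i$ and time $t_n$, and $\Delta x$ and $\Delta t$ are the grid spacings:

$$\frac{u_i^{n+1} - u_i^n}{\Delta t} + c\,\frac{u_i^n - u_{i-1}^n}{\Delta x} = 0 \quad\Longrightarrow\quad u_i^{n+1} = u_i^n - \sigma\left(u_i^n - u_{i-1}^n\right), \qquad \sigma = \frac{c\,\Delta t}{\Delta x}$$

The dimensionless number $\sigma$ is the Courant number. It is the same `sigma` that the CFL condition uses later to pick the time step.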
Spatial Domain
We define our spatial discretization as:
Where:
- the space of coordinates, e.g. Euclidean or spherical, and the size of the field, e.g. scalar or vector
- the domain and its discretization
For this problem, our domain bounds are $x \in [0, 2]$, and our time step is set by the CFL condition (see the Temporal Domain section).
from jaxsw._src.domain.base import Domain
There are a few ways to initialize the domain, for example:
- Define the number of values and the step
- Define the start/end points and the number of values
xmin = 0.0
xmax = 2.0
nx = 51
domain = Domain.from_numpoints(xmin=(xmin,), xmax=(xmax,), N=(nx,))
print(f"Size: {domain.size}")
print(f"nDims: {domain.ndim}")
print(f"Grid Size: {domain.grid.shape}")
print(f"Cell Volume: {domain.cell_volume}")
print(f"dx: {domain.dx}")
Size: (50,)
nDims: 1
Grid Size: (50, 1)
Cell Volume: 0.04
dx: (0.04,)
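The two initialization options listed above differ only in which quantities you fix. Here is a rough sketch in plain JAX. The helper names are hypothetical and not part of jaxsw. Also note that jaxsw's own point-count convention may differ: the output above reports 50 points for `N=51`.

```python
import jax.numpy as jnp


def grid_from_numpoints(xmin: float, xmax: float, n: int):
    """Hypothetical helper: fix the end points and the number of points, derive dx."""
    x = jnp.linspace(xmin, xmax, n)
    dx = (xmax - xmin) / (n - 1)
    return x, dx


def grid_from_step(xmin: float, dx: float, n: int):
    """Hypothetical helper: fix the start point, the step, and the number of points."""
    return xmin + dx * jnp.arange(n), dx


x_a, dx_a = grid_from_numpoints(0.0, 2.0, 51)
x_b, dx_b = grid_from_step(0.0, dx_a, 51)
print(dx_a)
print(bool(jnp.allclose(x_a, x_b)))
```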
Initial Condition
We said that the initial condition is actually a function that operates on a discretized domain.
In practical terms, we need to initialize our state, u, based on the domain and its discretization.
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.fields.base import Field
def init_u0(domain):
    """Initial condition from grid"""
    u = jnp.ones_like(domain.grid, dtype=jnp.float64)
    # raise u to 2 on the slice of grid points between roughly x = 0.5 and x = 1
    u = u.at[int(0.5 / domain.dx[0]) : int(1 / domain.dx[0] + 1)].set(2.0)
    return u
u_init = init_u0(domain)
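To see which grid points `init_u0` sets to 2, here is the index arithmetic with dx = 0.04, the value printed above. It uses a plain array of 50 ones as a stand-in for a jaxsw `Domain` grid.

```python
import jax.numpy as jnp

dx = 0.04
start = int(0.5 / dx)   # 0.5 / 0.04 = 12.5, truncated to 12
stop = int(1 / dx + 1)  # 26, the exclusive end of the slice

# stand-in for jnp.ones_like(domain.grid) with 50 points
u = jnp.ones((50, 1))
u = u.at[start:stop].set(2.0)

print(start, stop)             # 12 26
print(int(jnp.sum(u == 2.0)))  # 14 grid points in the "hat"
```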
Boundary Conditions
The boundary conditions are another function that depends on the domain and its discretization.
In practical terms, we want to apply some function that modifies our state, u, at the coordinates along the boundaries of the domain. In our case, the boundary values are constant (u = 1 at both ends), but we can easily think of strategies where they are an actual function, e.g. of space or time.
def bc_fn(u: Float[Array, "D"]) -> Float[Array, "D"]:
    # constant (Dirichlet) values at both ends of the domain
    u = u.at[0].set(1.0)
    u = u.at[-1].set(1.0)
    return u
u_out = bc_fn(u_init)
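The constant boundary values above are the simplest case. Here is a hypothetical boundary condition that is an actual function. It is not used later in this tutorial. The left value follows a time-varying sinusoidal inflow, and the right end uses a zero-gradient outflow that copies its neighbour.

```python
import jax
import jax.numpy as jnp


def bc_fn_time(u: jax.Array, t: float) -> jax.Array:
    """Hypothetical BC: sinusoidal inflow at the left end, zero-gradient outflow at the right."""
    u = u.at[0].set(1.0 + 0.5 * jnp.sin(2.0 * jnp.pi * t))  # Dirichlet value that varies in time
    u = u.at[-1].set(u[-2])  # copy the neighbour, i.e. du/dx = 0 at the right end
    return u


u = jnp.linspace(0.0, 1.0, 5)
out = bc_fn_time(u, t=0.25)
print(out)  # left value ~1.5, right value copies u[-2] = 0.75
```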
Differential Operators
We identified the differential operators on the RHS of the equation. In this case, we use a backward finite difference in space because this is a convection problem: with c > 0, information travels from left to right, so the upwind neighbour is on the left. Every differential operator has (hyper)parameters because every discretization involves (ad hoc) decisions, e.g. the finite difference scheme (backward, forward, central) and the stencil order.
The 1st-order backward difference is defined as:
$$\left.\frac{\partial u}{\partial x}\right|_i \approx \frac{u_i - u_{i-1}}{\Delta x}$$
Previously, we showcased how to use a stencil operator to handle finite differences instead of messy slicing operations (tutorial on this soon!). Here, we use `fdx.difference` from finitediffx.
du_dx = fdx.difference(
u_init, axis=0, step_size=domain.dx[0], method="backward", accuracy=1
)
fig, ax = plt.subplots()
ax.plot(u_init, label="u (initial condition)")
ax.plot(-du_dx[..., 0], label="-du/dx (backward difference)")
plt.legend()
plt.show()
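The equivalence between a stencil (convolution) and slicing is easy to check directly. The sketch below computes the first-order backward difference both ways with plain `jax.numpy`. It illustrates the idea only and is not how finitediffx implements it internally.

```python
import jax.numpy as jnp

dx = 0.04
u = jnp.ones(50).at[12:26].set(2.0)  # same "hat" as init_u0, as a 1D array

# slicing: (u_i - u_{i-1}) / dx for i = 1..N-1
du_slice = (u[1:] - u[:-1]) / dx

# stencil: convolve with [1, -1] / dx (convolution flips the kernel);
# "valid" keeps only positions where the stencil fits, i.e. the same N-1 points
stencil = jnp.array([1.0, -1.0]) / dx
du_conv = jnp.convolve(u, stencil, mode="valid")

print(du_slice.shape, du_conv.shape)          # (49,) (49,)
print(bool(jnp.allclose(du_slice, du_conv)))  # True
```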
Right Hand Side
This is the final piece: it glues the differential operators and the boundary conditions together.
We put all of this inside a single function, `equation_of_motion`, which defines our RHS.
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain
from jaxsw._src.operators.functional import advection
c = 1.0
class LinearAdvection1D(DynamicalSystem):
    @staticmethod
    def equation_of_motion(t: float, u: Array, args) -> Array:
        # u = bc_fn(u)
        c, domain = args
        rhs = advection.advection_1D(u, a=c, step_size=domain.dx[0], axis=0, accuracy=1)
        # rhs = fd_backwards(u, step_size=domain.dx[0])
        # rhs = fd_backwards_kernel(u, step_size=domain.dx[0], stencil=stencil, nodes=nodes)
        return -rhs
out = LinearAdvection1D.equation_of_motion(0, u_init, (c, domain))
out.shape
(50, 1)
fig, ax = plt.subplots()
ax.plot(u_init)
ax.plot(-out[..., 0])
plt.show()
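Because the RHS is a pure function of `(t, u, args)`, JAX transformations compose with it. The sketch below uses only `jax.numpy`, and its helpers are hypothetical stand-ins for the jaxsw operators. It applies the boundary condition first, as the commented-out `bc_fn` line suggests. It then `jit`s the RHS and `vmap`s it over several wave speeds at once.

```python
import jax
import jax.numpy as jnp


def bc_fn(u):
    # constant Dirichlet values at both ends, as in the tutorial
    return u.at[0].set(1.0).at[-1].set(1.0)


def rhs(t, u, c, dx):
    """Hypothetical RHS (not jaxsw's advection operator): apply the BC, then -c * (u_i - u_{i-1}) / dx."""
    u = bc_fn(u)
    du_dx = jnp.zeros_like(u).at[1:].set((u[1:] - u[:-1]) / dx)
    return -c * du_dx


u0 = jnp.ones(50).at[12:26].set(2.0)
cs = jnp.array([0.5, 1.0, 2.0])

# batch over c only; t, u, and dx are shared (in_axes=None)
batched_rhs = jax.jit(jax.vmap(rhs, in_axes=(None, None, 0, None)))
out = batched_rhs(0.0, u0, cs, 0.04)
print(out.shape)  # (3, 50)
```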
Time Stepping
In the previous tutorial, we talked about the step function, but let's pause for a moment and rethink it. The step function is actually a combination of two things:
- The RHS which includes the spatial derivatives.
- The time stepping scheme
We have already taken care of the spatial derivatives in the previous step. Looking again at equation (1) and its fully discretized form, we can write the intermediate "half step" between them, where only space is discretized:
$$\frac{d u_i}{d t} = -c\,\frac{u_i - u_{i-1}}{\Delta x}$$
where the RHS is as defined in equation (7). We have not written out the time-stepping scheme (forward Euler). Instead, we offload it to another library called diffrax.
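For reference, the forward Euler scheme we hand off to diffrax is just $u^{n+1} = u^n + \Delta t\,\mathrm{RHS}(u^n)$ applied repeatedly. Below is a minimal hand-rolled sketch with `jax.lax.scan`. The helper is hypothetical and is not diffrax's implementation. The left boundary node is held fixed.

```python
import jax
import jax.numpy as jnp


def rhs(u, c, dx):
    """Hypothetical RHS: -c * backward difference; node 0 gets zero tendency, so it stays fixed."""
    return -c * jnp.zeros_like(u).at[1:].set((u[1:] - u[:-1]) / dx)


def euler_integrate(u0, c, dx, dt, n_steps: int):
    """Hypothetical forward Euler loop: u^{n+1} = u^n + dt * RHS(u^n), via lax.scan."""

    def step(u, _):
        u_next = u + dt * rhs(u, c, dx)
        return u_next, u_next  # (new carry, value stacked into the trajectory)

    return jax.lax.scan(step, u0, xs=None, length=n_steps)


u0 = jnp.ones(50).at[12:26].set(2.0)
u_final, traj = euler_integrate(u0, c=1.0, dx=0.04, dt=0.008, n_steps=10)
print(traj.shape)  # (10, 50)
```

Compared with this loop, diffrax adds output saving at arbitrary times (`SaveAt`), adaptive step-size controllers, and a large family of solvers behind the same `diffeqsolve` call.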
Temporal Domain
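We define our temporal discretization as a uniform grid of time points, $t_n = t_0 + n\,\Delta t$, where the time coordinates are bounded by $t \in [t_0, T]$. For our problem, $t \in [0, 0.5]$. Recall that the time step is calculated from the CFL condition:
$$\Delta t = \frac{\sigma\,\Delta x}{c}, \qquad \sigma \le 1$$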
# SPATIAL DISCRETIZATION
u_init = init_u0(domain)
# TEMPORAL DISCRETIZATION
# initialize temporal domain
t0 = 0.0
tmax = 0.5
# CFL condition
def cfl_cond(dx, c, sigma):
    assert sigma <= 1.0
    return (sigma * dx) / c
# temporal parameters
c = 1.0
sigma = 0.2
dt = cfl_cond(dx=domain.dx[0], c=c, sigma=sigma)
t_domain = TimeDomain(tmin=0.0, tmax=0.5, dt=dt)
ts = jnp.linspace(t0, tmax, 25)
saveat = dfx.SaveAt(ts=ts)
# DYNAMICAL SYSTEM
dyn_model = LinearAdvection1D(
t_domain=t_domain,
saveat=saveat,
)
# Euler, Constant StepSize
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()
sol = dfx.diffeqsolve(
terms=dfx.ODETerm(dyn_model.equation_of_motion),
solver=solver,
t0=ts.min(),
t1=ts.max(),
dt0=dt,
y0=u_init,
saveat=saveat,
args=(c, domain),
stepsize_controller=stepsize_controller,
)
sol.ys.shape, sol.ts.shape
((25, 50, 1), (25,))
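The CFL arithmetic used above is easy to check by hand. This quick sketch uses the values from this tutorial (dx = 0.04, c = 1, sigma = 0.2):

```python
import math

dx, c, sigma = 0.04, 1.0, 0.2

dt = sigma * dx / c             # CFL condition: dt = sigma * dx / c
courant = c * dt / dx           # recovers sigma
n_steps = math.ceil(0.5 / dt)   # whole Euler steps needed to cover t in [0, 0.5]

print(dt, courant, n_steps)
```

Note that 0.5 / dt = 62.5 is not a whole number, so the final step cannot be a full `dt`.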
Analysis
da_sol = xr.DataArray(
data=np.asarray(sol.ys).squeeze(),
dims=["time", "x"],
coords={
"x": (["x"], np.asarray(domain.coords[0])),
"time": (["time"], np.asarray(sol.ts)),
},
attrs={"pde": "linear_convection", "c": c, "sigma": sigma},
)
da_sol
fig, ax = plt.subplots()
da_sol.T.plot.pcolormesh(ax=ax, cmap="gray_r")
plt.show()
fig, ax = plt.subplots()
for i in range(0, len(da_sol.time), 5):
    da_sol.isel(time=i).plot.line(ax=ax, color="gray")
plt.show()
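As we can see, there seems to be some dissipation. Could it be the solver? A likely culprit is the numerical diffusion of the first-order backward (upwind) difference combined with forward Euler, but it is something to explore. Either way, I think the most important point is how easy it is to just try things without getting bogged down by the coding details!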
At least that is what I hope to convey!
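One way to make the dissipation question concrete is a von Neumann analysis of the scheme used here: forward Euler in time, backward difference in space. This is a standard textbook analysis, not something computed elsewhere in this tutorial. Substituting a Fourier mode $u_i^n = G^n e^{\mathrm{i} k x_i}$ into the update gives the per-step amplification factor $G(\theta) = 1 - \sigma(1 - e^{-\mathrm{i}\theta})$, where $\theta = k\,\Delta x$. The same scheme's modified equation has a leading-order numerical diffusion coefficient of $\tfrac{1}{2} c\,\Delta x\,(1 - \sigma)$. A quick sketch:

```python
import jax.numpy as jnp


def amplification(sigma: float, theta):
    """Per-step amplification factor G(theta) of the upwind / forward-Euler scheme."""
    return 1.0 - sigma * (1.0 - jnp.exp(-1j * theta))


theta = jnp.linspace(0.0, jnp.pi, 201)  # resolvable wavenumbers, k * dx in [0, pi]

for sigma in (0.2, 1.0, 1.2):
    g = jnp.abs(amplification(sigma, theta))
    # |G| < 1 for theta > 0 means those modes are damped every step (dissipation);
    # |G| > 1 anywhere means the scheme blows up (sigma > 1 violates the CFL condition)
    print(sigma, float(g.max()), float(g[-1]))

# leading-order numerical diffusion from the modified equation
dx, c, sigma = 0.04, 1.0, 0.2
nu_num = 0.5 * c * dx * (1.0 - sigma)
print(nu_num)
```

With sigma = 0.2, the shortest resolvable wave keeps only about 60% of its amplitude each step. With sigma = 1 the scheme reduces to an exact shift with no damping. This points to the combination of spatial stencil and Courant number, rather than the ODE solver alone.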