Boundary Conditions

J. Emmanuel Johnson, Takaya Uchida
import autoroot
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from jaxtyping import Array
import einops
import finitediffx as fdx
from jaxsw._src.operators.functional import grid as F_grid
from jaxsw._src.boundaries import functional as F_bc


# Plot styling for any figures produced by this notebook.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit types before any arrays are created, so arrays default to
# int64/float64 (as seen in the printed outputs).
jax.config.update("jax_enable_x64", True)

# IPython magics: render matplotlib figures inline and automatically reload
# imported modules (autoreload mode 2) while developing jaxsw.
%matplotlib inline
%load_ext autoreload
%autoreload 2

One-Dimensional

# Ten-point 1D test signal: 1, 2, ..., 10.
u = jnp.arange(10) + 1
u
Array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int64)

Padding

There are many ways we may want to pad an array. Some examples include:

  • Symmetric Boundaries
  • Wrap for periodic conditions
# Constant padding: one ghost cell on each side, filled with a fixed value.
# Other modes worth trying: "linear_ramp", "reflect", "wrap", "symmetric".
mode = "constant"
constant_values = (100, 100)  # (before, after) fill values
u_pad = jnp.pad(u, pad_width=1, mode=mode, constant_values=constant_values)
u_pad
Array([100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100], dtype=int64)

Boundary Conditions

Periodic Boundary Conditions

# Periodic BC: each ghost cell takes the value from the opposite end of the array.
# Other modes worth trying: "linear_ramp", "reflect", "symmetric".
mode = "wrap"
u_periodic = jnp.pad(u, pad_width=1, mode=mode)
u_periodic
Array([10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1], dtype=int64)
# Same periodic padding via the jaxsw helper; the printed result matches the
# jnp.pad(mode="wrap") version.
u_periodic = F_bc.apply_periodic_pad_1D(u)
u_periodic
Array([10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1], dtype=int64)
# Central differences in the interior, one-sided differences at the two ends;
# the ghost cells are included, so entries 0 and -1 differentiate across the pad.
jnp.gradient(u_periodic)
Array([-9., -4., 1., 1., 1., 1., 1., 1., 1., 1., -4., -9.], dtype=float64)

Neumann Boundaries

# Neumann (zero-gradient) BC: each ghost cell copies its nearest interior
# value, which is exactly what jnp.pad's "edge" mode produces.
u_neumann = jnp.pad(u, pad_width=1, mode="edge")
u_neumann
Array([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10], dtype=int64)
# Same Neumann padding via the jaxsw helper; the printed result matches the
# manual edge-copy version.
u_neumann = F_bc.apply_neumann_pad_1D(u)
u_neumann
Array([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10], dtype=int64)
# The copied ghost values make the one-sided differences at both ends zero.
jnp.gradient(u_neumann)
Array([0. , 0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.5, 0. ], dtype=float64)

Dirichlet Boundaries

EDGES

# Dirichlet BC at the cell edges: each ghost cell holds the negated value of
# its interior neighbour, so the average of ghost and neighbour is zero.
u_dirichlet = jnp.concatenate([-u[:1], u, -u[-1:]])

u_dirichlet
Array([ -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -10], dtype=int64)
# Same edge-Dirichlet padding via the jaxsw helper; the printed result matches
# the manual version.
u_dirichlet = F_bc.apply_dirichlet_pad_edge_1D(u)
u_dirichlet
Array([ -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -10], dtype=int64)
# Interior values are unchanged; the last entries jump because they
# difference across the negated ghost values.
jnp.gradient(u_dirichlet)
Array([ 2. , 1.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , -9.5, -20. ], dtype=float64)

FACES

# Dirichlet BC at the cell faces: the ghost cells themselves are set to zero.
u_dirichlet = jnp.pad(u, pad_width=1, mode="constant", constant_values=0)

u_dirichlet
Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], dtype=int64)
# Same face-Dirichlet padding via the jaxsw helper; the printed result matches
# the manual zero-ghost version.
u_dirichlet = F_bc.apply_dirichlet_pad_face_1D(u)
u_dirichlet
Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], dtype=int64)
# Differences at the ends run into the zero ghost cells.
jnp.gradient(u_dirichlet)
Array([ 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , -4.5, -10. ], dtype=float64)

Two-Dimensional

# 2D field with u[i, j] = i + 1: the 1..5 ramp varies along axis 0 and is
# constant along axis 1 (shown transposed).
u = einops.repeat(jnp.arange(1, 6), "Nx -> Nx Ny", Ny=5)

u.T
Array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]], dtype=int64)

Padding

There are many ways we may want to pad an array. Some examples include:

  • Symmetric Boundaries
  • Wrap for periodic conditions
  • Ghost Points
mode = "constant"  # alternatives: "linear_ramp", "reflect", "wrap", "symmetric"
constant_values = jnp.nan  # or per-side values, e.g. ((100, 100), (100, 100))
# NaN only exists for floating-point dtypes. `u` holds integers, so padding it
# directly casts jnp.nan to the integer dtype and the ghost cells come out as 0
# instead of NaN. Cast to the default float dtype first so the padding is visible.
u_pad = jnp.pad(
    u.astype(float),
    pad_width=((1, 1), (1, 1)),
    mode=mode,
    constant_values=constant_values,
)
u_pad
Array([[0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 1, 0], [0, 2, 2, 2, 2, 2, 0], [0, 3, 3, 3, 3, 3, 0], [0, 4, 4, 4, 4, 4, 0], [0, 5, 5, 5, 5, 5, 0], [0, 0, 0, 0, 0, 0, 0]], dtype=int64)

Boundary Conditions

Periodic Boundary Conditions

# Periodic BC in 2D: ghost rows/columns wrap around from the opposite side.
# Other modes worth trying: "linear_ramp", "reflect", "symmetric".
mode = "wrap"
u_periodic = jnp.pad(u, pad_width=1, mode=mode)
u_periodic
Array([[5, 5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1, 1]], dtype=int64)
# Same 2D periodic padding via the jaxsw helper; the printed result matches
# the jnp.pad(mode="wrap") version.
u_periodic = F_bc.apply_periodic_pad_2D(u)
u_periodic
Array([[5, 5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1, 1]], dtype=int64)
# Derivative along axis 0 (down the rows), where the 1..5 ramp varies.
jnp.gradient(u_periodic, axis=0)
Array([[-4. , -4. , -4. , -4. , -4. , -4. , -4. ], [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5, -1.5], [ 1. , 1. , 1. , 1. , 1. , 1. , 1. ], [ 1. , 1. , 1. , 1. , 1. , 1. , 1. ], [ 1. , 1. , 1. , 1. , 1. , 1. , 1. ], [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5, -1.5], [-4. , -4. , -4. , -4. , -4. , -4. , -4. ]], dtype=float64)
# Derivative along axis 1; zero everywhere because u is constant along this axis.
jnp.gradient(u_periodic, axis=1)
Array([[0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.]], dtype=float64)

Neumann Boundaries

# Inspect the helpers with: F_bc.apply_neumann_x?? / F_bc.apply_neumann_y??
# Zero-pad one ghost cell on every side, then apply the jaxsw Neumann helpers
# along x first and y second (the output shows ghost cells copying their
# interior neighbours).
u_neumann = jnp.pad(u, pad_width=1, mode="constant")
u_neumann = F_bc.apply_neumann_x(u_neumann)
u_neumann = F_bc.apply_neumann_y(u_neumann)
u_neumann
Array([[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [5, 5, 5, 5, 5, 5, 5]], dtype=int64)
# Same 2D Neumann padding via the jaxsw helper; the printed result matches the
# manual pad + apply_neumann_x/apply_neumann_y version.
u_neumann = F_bc.apply_neumann_pad_2D(u)
u_neumann
Array([[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [5, 5, 5, 5, 5, 5, 5]], dtype=int64)
# Derivative along axis 0; the copied ghost rows make the first and last rows zero.
jnp.gradient(u_neumann, axis=0)
Array([[0. , 0. , 0. , 0. , 0. , 0. , 0. ], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [1. , 1. , 1. , 1. , 1. , 1. , 1. ], [1. , 1. , 1. , 1. , 1. , 1. , 1. ], [1. , 1. , 1. , 1. , 1. , 1. , 1. ], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0. , 0. , 0. , 0. , 0. , 0. , 0. ]], dtype=float64)
# Derivative along axis 1; zero everywhere because u is constant along this axis.
jnp.gradient(u_neumann, axis=1)
Array([[0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0.]], dtype=float64)

Dirichlet Boundaries

# NOTE: "symmetric" mode repeats the edge value as the first ghost cell, so the
# odd reflection gives 2*edge - edge == edge. This reproduces the Neumann
# (edge-copy) padding, not a Dirichlet one; compare with the Neumann output.
jnp.pad(u, pad_width=1, mode="symmetric", reflect_type="odd")
Array([[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [5, 5, 5, 5, 5, 5, 5]], dtype=int64)

Edges

# Dirichlet BC at cell edges in 2D: pad with placeholder ghost values, then
# let the jaxsw helpers overwrite the ghost cells along x and then along y.
u_dirichlet = jnp.pad(u, pad_width=1, mode="empty")
u_dirichlet = F_bc.apply_dirichlet_x_edge(u_dirichlet)
u_dirichlet = F_bc.apply_dirichlet_y_edge(u_dirichlet)

u_dirichlet
Array([[ 1, -1, -1, -1, -1, -1, 1], [-1, 1, 1, 1, 1, 1, -1], [-2, 2, 2, 2, 2, 2, -2], [-3, 3, 3, 3, 3, 3, -3], [-4, 4, 4, 4, 4, 4, -4], [-5, 5, 5, 5, 5, 5, -5], [ 5, -5, -5, -5, -5, -5, 5]], dtype=int64)
# Same 2D edge-Dirichlet padding via the jaxsw helper; the printed result
# matches the manual x-then-y version.
u_dirichlet = F_bc.apply_dirichlet_pad_edge_2D(u)
u_dirichlet
Array([[ 1, -1, -1, -1, -1, -1, 1], [-1, 1, 1, 1, 1, 1, -1], [-2, 2, 2, 2, 2, 2, -2], [-3, 3, 3, 3, 3, 3, -3], [-4, 4, 4, 4, 4, 4, -4], [-5, 5, 5, 5, 5, 5, -5], [ 5, -5, -5, -5, -5, -5, 5]], dtype=int64)
# Derivative along axis 0; the outer rows/columns difference across the
# negated ghost values.
jnp.gradient(u_dirichlet, axis=0)
Array([[ -2. , 2. , 2. , 2. , 2. , 2. , -2. ], [ -1.5, 1.5, 1.5, 1.5, 1.5, 1.5, -1.5], [ -1. , 1. , 1. , 1. , 1. , 1. , -1. ], [ -1. , 1. , 1. , 1. , 1. , 1. , -1. ], [ -1. , 1. , 1. , 1. , 1. , 1. , -1. ], [ 4.5, -4.5, -4.5, -4.5, -4.5, -4.5, 4.5], [ 10. , -10. , -10. , -10. , -10. , -10. , 10. ]], dtype=float64)
# Derivative along axis 1; zero away from the ghost columns because u is
# constant along this axis.
jnp.gradient(u_dirichlet, axis=1)
Array([[ -2., -1., 0., 0., 0., 1., 2.], [ 2., 1., 0., 0., 0., -1., -2.], [ 4., 2., 0., 0., 0., -2., -4.], [ 6., 3., 0., 0., 0., -3., -6.], [ 8., 4., 0., 0., 0., -4., -8.], [ 10., 5., 0., 0., 0., -5., -10.], [-10., -5., 0., 0., 0., 5., 10.]], dtype=float64)

FACES

# Dirichlet BC at cell faces in 2D: pad with placeholder ghost values, then
# let the jaxsw helpers overwrite the ghost cells along x and then along y.
u_dirichlet = jnp.pad(u, pad_width=1, mode="empty")
u_dirichlet = F_bc.apply_dirichlet_x_face(u_dirichlet)
u_dirichlet = F_bc.apply_dirichlet_y_face(u_dirichlet)

u_dirichlet
Array([[0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 1, 0], [0, 2, 2, 2, 2, 2, 0], [0, 3, 3, 3, 3, 3, 0], [0, 4, 4, 4, 4, 4, 0], [0, 5, 5, 5, 5, 5, 0], [0, 0, 0, 0, 0, 0, 0]], dtype=int64)
# Same 2D face-Dirichlet padding via the jaxsw helper; the printed result
# matches the manual x-then-y version (all ghost cells are zero).
u_dirichlet = F_bc.apply_dirichlet_pad_face_2D(u)
u_dirichlet
Array([[0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 1, 0], [0, 2, 2, 2, 2, 2, 0], [0, 3, 3, 3, 3, 3, 0], [0, 4, 4, 4, 4, 4, 0], [0, 5, 5, 5, 5, 5, 0], [0, 0, 0, 0, 0, 0, 0]], dtype=int64)
# Derivative along axis 0; the last rows drop sharply into the zero ghost row.
jnp.gradient(u_dirichlet, axis=0)
Array([[ 0., 1., 1., 1., 1., 1., 0.], [ 0., 1., 1., 1., 1., 1., 0.], [ 0., 1., 1., 1., 1., 1., 0.], [ 0., 1., 1., 1., 1., 1., 0.], [ 0., 1., 1., 1., 1., 1., 0.], [ 0., -2., -2., -2., -2., -2., 0.], [ 0., -5., -5., -5., -5., -5., 0.]], dtype=float64)
# Derivative along axis 1; non-zero only next to the zero ghost columns,
# because u is constant along this axis.
jnp.gradient(u_dirichlet, axis=1)
Array([[ 0. , 0. , 0. , 0. , 0. , 0. , 0. ], [ 1. , 0.5, 0. , 0. , 0. , -0.5, -1. ], [ 2. , 1. , 0. , 0. , 0. , -1. , -2. ], [ 3. , 1.5, 0. , 0. , 0. , -1.5, -3. ], [ 4. , 2. , 0. , 0. , 0. , -2. , -4. ], [ 5. , 2.5, 0. , 0. , 0. , -2.5, -5. ], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. ]], dtype=float64)