import autoroot
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from jaxtyping import Array
import einops
import finitediffx as fdx
from jaxsw._src.operators.functional import grid as F_grid
from jaxsw._src.boundaries import functional as F_bc
# Plot styling: restore seaborn defaults, then use the "talk" context at a reduced font scale.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit types in JAX (the arrays below print as int64 / float64).
jax.config.update("jax_enable_x64", True)
# IPython magics: inline plots, and reload edited modules before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2
1-Dimensional¶
# 1D test signal: the integers 1, 2, ..., 10.
u = 1 + jnp.arange(10)
u
Array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int64)
Padding¶
There are many ways we may want to pad an array. Some examples include:
- Symmetric Boundaries
- Wrap for periodic conditions
# Constant padding: one ghost cell on each side, filled with 100.
mode = "constant"  # other options: "linear_ramp", "reflect", "wrap", "symmetric"
constant_values = (100, 100)  # (before, after) fill values
u_pad = jnp.pad(
    u,
    pad_width=(1, 1),  # one cell before, one cell after
    mode=mode,
    constant_values=constant_values,
)
u_pad
Array([100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100], dtype=int64)
Boundary Conditions¶
Periodic Boundary Conditions¶
# Periodic ghost cells: "wrap" fills each end with the value from the opposite end.
mode = "wrap"  # other options: "linear_ramp", "reflect", "symmetric"
u_periodic = jnp.pad(u, pad_width=1, mode=mode)
u_periodic
Array([10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1], dtype=int64)
# Same periodic padding via the library helper; the output matches the manual jnp.pad version.
u_periodic = F_bc.apply_periodic_pad_1D(u)
u_periodic
Array([10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1], dtype=int64)
# Finite-difference derivative of the padded signal (central differences in the
# interior, one-sided at the two ends). The 10 -> 1 jump across the wrap shows up near both ends.
jnp.gradient(u_periodic)
Array([-9., -4., 1., 1., 1., 1., 1., 1., 1., 1., -4., -9.], dtype=float64)
Neumann Boundaries¶
# Neumann ghost cells by hand: pad with zeros, then copy each boundary value
# outward into its ghost cell.
u_neumann = jnp.pad(u, pad_width=1, mode="constant")
u_neumann = u_neumann.at[0].set(u_neumann[1]).at[-1].set(u_neumann[-2])
u_neumann
Array([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10], dtype=int64)
# Same Neumann padding via the library helper; the output matches the manual version.
u_neumann = F_bc.apply_neumann_pad_1D(u)
u_neumann
Array([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10], dtype=int64)
# Derivative of the Neumann-padded signal: the copied ghost values make it 0 at the padded ends.
jnp.gradient(u_neumann)
Array([0. , 0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.5, 0. ], dtype=float64)
Dirichlet Boundaries¶
EDGES
# Edge-type Dirichlet by hand: each ghost cell holds the negated adjacent
# interior value, so the average of each ghost/interior pair is zero.
u_dirichlet = jnp.pad(u, pad_width=1, mode="empty")
u_dirichlet = u_dirichlet.at[0].set(-u_dirichlet[1]).at[-1].set(-u_dirichlet[-2])
u_dirichlet
Array([ -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -10], dtype=int64)
# Same edge-type Dirichlet padding via the library helper; the output matches the manual version.
u_dirichlet = F_bc.apply_dirichlet_pad_edge_1D(u)
u_dirichlet
Array([ -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -10], dtype=int64)
# Derivative of the edge-Dirichlet signal; the large right-end values come from
# the sign flip between the last interior value (10) and its ghost (-10).
jnp.gradient(u_dirichlet)
Array([ 2. , 1.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. , -9.5, -20. ], dtype=float64)
FACES
# Face-type Dirichlet by hand: the ghost cells themselves hold the boundary
# value 0 (created in the array's own dtype).
u_dirichlet = jnp.pad(u, pad_width=1, mode="empty")
u_dirichlet = (
    u_dirichlet.at[0].set(jnp.zeros((), dtype=u_dirichlet.dtype))
    .at[-1].set(jnp.zeros((), dtype=u_dirichlet.dtype))
)
u_dirichlet
Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], dtype=int64)
# Same face-type Dirichlet padding via the library helper; the output matches the manual version.
u_dirichlet = F_bc.apply_dirichlet_pad_face_1D(u)
u_dirichlet
Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], dtype=int64)
# Derivative of the face-Dirichlet signal; the drop to the 0 ghost cell dominates the right end.
jnp.gradient(u_dirichlet)
Array([ 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. , -4.5, -10. ], dtype=float64)
Two-Dimensional¶
# 5x5 field that increases along axis 0 (Nx) and is constant along axis 1 (Ny).
u = einops.repeat(jnp.arange(1, 6), "Nx -> Nx Ny", Ny=5)
# Shown transposed, so each printed row is the x-profile 1..5.
u.T
Array([[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]], dtype=int64)
Padding¶
There are many ways we may want to pad an array. Some examples include:
- Symmetric Boundaries
- Wrap for periodic conditions
- Ghost Points
# Constant padding with one ghost cell on every side.
# Other options for `mode`: "linear_ramp", "reflect", "wrap", "symmetric".
mode = "constant"
# NaN marks the ghost cells as "not yet set". An integer array cannot hold NaN,
# and padding it directly silently converts the fill value (it came out as 0).
# So promote the array to a floating dtype first, which keeps the NaN visible.
constant_values = jnp.nan  # or e.g. ((100, 100), (100, 100)) for per-side values
u_pad = jnp.pad(
    u.astype(jnp.result_type(u, float)),
    pad_width=((1, 1), (1, 1)),
    mode=mode,
    constant_values=constant_values,
)
u_pad
Array([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0],
[0, 2, 2, 2, 2, 2, 0],
[0, 3, 3, 3, 3, 3, 0],
[0, 4, 4, 4, 4, 4, 0],
[0, 5, 5, 5, 5, 5, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=int64)
Boundary Conditions¶
Periodic Boundary Conditions¶
# Periodic ghost cells on both axes: "wrap" copies values from the opposite side.
mode = "wrap"  # other options: "linear_ramp", "reflect", "symmetric"
u_periodic = jnp.pad(u, pad_width=1, mode=mode)  # scalar width applies to every side of every axis
u_periodic
Array([[5, 5, 5, 5, 5, 5, 5],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[1, 1, 1, 1, 1, 1, 1]], dtype=int64)
# Library helper for 2D periodic padding; the output matches the manual jnp.pad version.
u_periodic = F_bc.apply_periodic_pad_2D(u)
u_periodic
Array([[5, 5, 5, 5, 5, 5, 5],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[1, 1, 1, 1, 1, 1, 1]], dtype=int64)
# Derivative along axis 0 (Nx); the 5 -> 1 wrap-around jump affects the first and last two rows.
jnp.gradient(u_periodic, axis=0)
Array([[-4. , -4. , -4. , -4. , -4. , -4. , -4. ],
[-1.5, -1.5, -1.5, -1.5, -1.5, -1.5, -1.5],
[ 1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[ 1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[ 1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[-1.5, -1.5, -1.5, -1.5, -1.5, -1.5, -1.5],
[-4. , -4. , -4. , -4. , -4. , -4. , -4. ]], dtype=float64)
# Derivative along axis 1 (Ny); u is constant along this axis, so the result is zero everywhere.
jnp.gradient(u_periodic, axis=1)
Array([[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]], dtype=float64)
Neumann Boundaries¶
# Neumann ghost cells built step by step: pad with zeros, then let the library
# helpers overwrite the ghost cells along x and then along y. (In the output,
# each ghost cell equals its interior neighbour.)
u_neumann = jnp.pad(u, pad_width=((1, 1), (1, 1)), mode="constant")
u_neumann = F_bc.apply_neumann_x(u_neumann)
u_neumann = F_bc.apply_neumann_y(u_neumann)
u_neumann
Array([[1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[5, 5, 5, 5, 5, 5, 5]], dtype=int64)
# Library helper for 2D Neumann padding; same result as the step-by-step version.
u_neumann = F_bc.apply_neumann_pad_2D(u)
u_neumann
Array([[1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[5, 5, 5, 5, 5, 5, 5]], dtype=int64)
# Derivative along axis 0 (Nx); the copied ghost rows make it drop to 0 at the padded ends.
jnp.gradient(u_neumann, axis=0)
Array([[0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
[1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
[0. , 0. , 0. , 0. , 0. , 0. , 0. ]], dtype=float64)
# Derivative along axis 1 (Ny); zero everywhere because u does not vary along y.
jnp.gradient(u_neumann, axis=1)
Array([[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]], dtype=float64)
Dirichlet Boundaries¶
# NOTE: "symmetric" padding with reflect_type="odd" sets each ghost cell to
# 2*edge - mirrored_value. For a width-1 pad the mirrored value is the edge
# itself, so every ghost cell equals its edge value. That is a Neumann-style copy
# (the output is identical to the Neumann padding), not a Dirichlet condition.
jnp.pad(u, pad_width=((1, 1), (1, 1)), mode="symmetric", reflect_type="odd")
Array([[1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5],
[5, 5, 5, 5, 5, 5, 5]], dtype=int64)
Edges
# Allocate one ghost cell on every side; the ghost values are overwritten below.
u_dirichlet = jnp.pad(u, pad_width=1, mode="empty")
# Apply the edge-type Dirichlet condition along x first, then along y.
# The corner ghost cells are touched by both passes.
u_dirichlet = F_bc.apply_dirichlet_x_edge(u_dirichlet)
u_dirichlet = F_bc.apply_dirichlet_y_edge(u_dirichlet)
u_dirichlet
Array([[ 1, -1, -1, -1, -1, -1, 1],
[-1, 1, 1, 1, 1, 1, -1],
[-2, 2, 2, 2, 2, 2, -2],
[-3, 3, 3, 3, 3, 3, -3],
[-4, 4, 4, 4, 4, 4, -4],
[-5, 5, 5, 5, 5, 5, -5],
[ 5, -5, -5, -5, -5, -5, 5]], dtype=int64)
# Library helper for 2D edge-type Dirichlet padding; the output matches the step-by-step version.
u_dirichlet = F_bc.apply_dirichlet_pad_edge_2D(u)
u_dirichlet
Array([[ 1, -1, -1, -1, -1, -1, 1],
[-1, 1, 1, 1, 1, 1, -1],
[-2, 2, 2, 2, 2, 2, -2],
[-3, 3, 3, 3, 3, 3, -3],
[-4, 4, 4, 4, 4, 4, -4],
[-5, 5, 5, 5, 5, 5, -5],
[ 5, -5, -5, -5, -5, -5, 5]], dtype=int64)
# Derivative along axis 0 (Nx) of the edge-Dirichlet field.
jnp.gradient(u_dirichlet, axis=0)
Array([[ -2. , 2. , 2. , 2. , 2. , 2. , -2. ],
[ -1.5, 1.5, 1.5, 1.5, 1.5, 1.5, -1.5],
[ -1. , 1. , 1. , 1. , 1. , 1. , -1. ],
[ -1. , 1. , 1. , 1. , 1. , 1. , -1. ],
[ -1. , 1. , 1. , 1. , 1. , 1. , -1. ],
[ 4.5, -4.5, -4.5, -4.5, -4.5, -4.5, 4.5],
[ 10. , -10. , -10. , -10. , -10. , -10. , 10. ]], dtype=float64)
# Derivative along axis 1 (Ny); it is non-zero only near the sign-flipped ghost columns.
jnp.gradient(u_dirichlet, axis=1)
Array([[ -2., -1., 0., 0., 0., 1., 2.],
[ 2., 1., 0., 0., 0., -1., -2.],
[ 4., 2., 0., 0., 0., -2., -4.],
[ 6., 3., 0., 0., 0., -3., -6.],
[ 8., 4., 0., 0., 0., -4., -8.],
[ 10., 5., 0., 0., 0., -5., -10.],
[-10., -5., 0., 0., 0., 5., 10.]], dtype=float64)
FACES
# Allocate one ghost cell on every side; the ghost values are overwritten below.
u_dirichlet = jnp.pad(u, pad_width=1, mode="empty")
# Apply the face-type Dirichlet condition along x first, then along y.
u_dirichlet = F_bc.apply_dirichlet_x_face(u_dirichlet)
u_dirichlet = F_bc.apply_dirichlet_y_face(u_dirichlet)
u_dirichlet
Array([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0],
[0, 2, 2, 2, 2, 2, 0],
[0, 3, 3, 3, 3, 3, 0],
[0, 4, 4, 4, 4, 4, 0],
[0, 5, 5, 5, 5, 5, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=int64)
# Library helper for 2D face-type Dirichlet padding; the output matches the step-by-step version.
u_dirichlet = F_bc.apply_dirichlet_pad_face_2D(u)
u_dirichlet
Array([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0],
[0, 2, 2, 2, 2, 2, 0],
[0, 3, 3, 3, 3, 3, 0],
[0, 4, 4, 4, 4, 4, 0],
[0, 5, 5, 5, 5, 5, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=int64)
# Derivative along axis 0 (Nx) of the face-Dirichlet field; the zero ghost row dominates the last rows.
jnp.gradient(u_dirichlet, axis=0)
Array([[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., 1., 1., 1., 1., 1., 0.],
[ 0., -2., -2., -2., -2., -2., 0.],
[ 0., -5., -5., -5., -5., -5., 0.]], dtype=float64)
# Derivative along axis 1 (Ny); it is non-zero only next to the zero ghost columns.
jnp.gradient(u_dirichlet, axis=1)
Array([[ 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 1. , 0.5, 0. , 0. , 0. , -0.5, -1. ],
[ 2. , 1. , 0. , 0. , 0. , -1. , -2. ],
[ 3. , 1.5, 0. , 0. , 0. , -1.5, -3. ],
[ 4. , 2. , 0. , 0. , 0. , -2. , -4. ],
[ 5. , 2.5, 0. , 0. , 0. , -2.5, -5. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. ]], dtype=float64)