import autoroot
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from jaxtyping import Array
import einops
import finitediffx as fdx
from jaxsw._src.operators.functional import grid as F_grid
from jaxsw._src.boundaries import functional as F_bc
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
jax.config.update("jax_enable_x64", True)
%matplotlib inline
%load_ext autoreload
%autoreload 21-Dimensional¶
# Simple 1D test field: the integers 1..10.
u = jnp.arange(1, 11)
u
# Out: Array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int64)

# Padding
# There are many ways we may want to pad an array. Some examples include:
# - Symmetric boundaries
# - Wrap for periodic conditions
# Constant padding; other modes to try: "linear_ramp", "reflect", "wrap", "symmetric".
mode = "constant"
constant_values = (100, 100)  # (before, after) fill values
# pad_width=(1, 1): one ghost cell on each side of the 1D array.
u_pad = jnp.pad(u, pad_width=(1, 1), mode=mode, constant_values=constant_values)
u_pad
# Out: Array([100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100], dtype=int64)

# Boundary Conditions
# Periodic Boundary Conditions
# "wrap" fills each ghost cell with the value from the opposite end of the array.
mode = "wrap"  # other modes: "linear_ramp", "reflect", "symmetric"
u_periodic = jnp.pad(u, pad_width=(1, 1), mode=mode)
u_periodic
# Out: Array([10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1], dtype=int64)

# Same padding via the library helper.
u_periodic = F_bc.apply_periodic_pad_1D(u)
u_periodic
# Out: Array([10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1], dtype=int64)

jnp.gradient(u_periodic)
# Out: Array([-9., -4., 1., 1., 1., 1., 1., 1., 1., 1., -4., -9.], dtype=float64)

# Neumann Boundaries
# Zero-gradient ghosts: copy the nearest interior value into each ghost cell.
# NOTE: for a one-cell pad this is equivalent to jnp.pad(u, (1, 1), mode="edge");
# it is spelled out manually here for illustration.
u_neumann = jnp.pad(u, pad_width=(1, 1), mode="constant")
u_neumann = u_neumann.at[0].set(u_neumann[1])
u_neumann = u_neumann.at[-1].set(u_neumann[-2])
u_neumann
# Out: Array([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10], dtype=int64)

# Same padding via the library helper.
u_neumann = F_bc.apply_neumann_pad_1D(u)
u_neumann
# Out: Array([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10], dtype=int64)

jnp.gradient(u_neumann)
# Out: Array([0. , 0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.5, 0. ], dtype=float64)

# Dirichlet Boundaries
# EDGES
# Pad with placeholder values; the ghost cells are overwritten just below.
u_dirichlet = jnp.pad(u, pad_width=(1, 1), mode="empty")
# Each ghost cell is the negated neighbouring interior value, so
# (ghost + neighbour) / 2 == 0 midway between them.
u_dirichlet = u_dirichlet.at[0].set(-u_dirichlet[1])
u_dirichlet = u_dirichlet.at[-1].set(-u_dirichlet[-2])
u_dirichlet
# Out: Array([ -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -10], dtype=int64)

# Same padding via the library helper.
u_dirichlet = F_bc.apply_dirichlet_pad_edge_1D(u)
u_dirichlet
# Out: Array([ -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -10], dtype=int64)

jnp.gradient(u_dirichlet)
# Out: Array([ 2. , 1.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
#              1. , -9.5, -20. ], dtype=float64)

# FACES
# Pad with placeholder values; the ghost cells are overwritten just below.
u_dirichlet = jnp.pad(u, pad_width=(1, 1), mode="empty")
# The ghost cells themselves hold the boundary value 0.
# (.set(0) is cast to the array's dtype, so no explicit asarray is needed.)
u_dirichlet = u_dirichlet.at[0].set(0)
u_dirichlet = u_dirichlet.at[-1].set(0)
u_dirichlet
# Out: Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], dtype=int64)

# Same padding via the library helper.
u_dirichlet = F_bc.apply_dirichlet_pad_face_1D(u)
u_dirichlet
# Out: Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], dtype=int64)

jnp.gradient(u_dirichlet)
# Out: Array([ 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
#              1. , -4.5, -10. ], dtype=float64)

# Two-Dimensional
u = jnp.arange(1, 6)
# Repeat the 1D profile along a new trailing axis, giving a 5x5 array with
# u[i, j] == i + 1 (see the transposed display below).
u = einops.repeat(u, "Nx -> Nx Ny", Ny=5)
u.T
# Out:
# Array([[1, 2, 3, 4, 5],
#        [1, 2, 3, 4, 5],
#        [1, 2, 3, 4, 5],
#        [1, 2, 3, 4, 5],
#        [1, 2, 3, 4, 5]], dtype=int64)

# Padding
# There are many ways we may want to pad an array. Some examples include:
# - Symmetric boundaries
# - Wrap for periodic conditions
# - Ghost points
mode = "constant"  # other modes: "linear_ramp", "reflect", "wrap", "symmetric"
# NOTE: `u` is an integer array, so a NaN fill value cannot be represented and
# is silently cast (the recorded output showed 0). Use an explicit integer fill;
# cast `u` to a float dtype first if NaN-marked ghost cells are wanted.
constant_values = 0  # or per-axis values, e.g. ((100, 100), (100, 100))
u_pad = jnp.pad(
    u, pad_width=((1, 1), (1, 1)), mode=mode, constant_values=constant_values
)
u_pad
# Out:
# Array([[0, 0, 0, 0, 0, 0, 0],
#        [0, 1, 1, 1, 1, 1, 0],
#        [0, 2, 2, 2, 2, 2, 0],
#        [0, 3, 3, 3, 3, 3, 0],
#        [0, 4, 4, 4, 4, 4, 0],
#        [0, 5, 5, 5, 5, 5, 0],
#        [0, 0, 0, 0, 0, 0, 0]], dtype=int64)

# Boundary Conditions
# Periodic Boundary Conditions
# "wrap" on both axes: each ghost row/column is copied from the opposite side.
mode = "wrap"  # other modes: "linear_ramp", "reflect", "symmetric"
u_periodic = jnp.pad(u, pad_width=((1, 1), (1, 1)), mode=mode)
u_periodic
# Out:
# Array([[5, 5, 5, 5, 5, 5, 5],
#        [1, 1, 1, 1, 1, 1, 1],
#        [2, 2, 2, 2, 2, 2, 2],
#        [3, 3, 3, 3, 3, 3, 3],
#        [4, 4, 4, 4, 4, 4, 4],
#        [5, 5, 5, 5, 5, 5, 5],
#        [1, 1, 1, 1, 1, 1, 1]], dtype=int64)

# Same padding via the library helper.
u_periodic = F_bc.apply_periodic_pad_2D(u)
u_periodic
# Out:
# Array([[5, 5, 5, 5, 5, 5, 5],
#        [1, 1, 1, 1, 1, 1, 1],
#        [2, 2, 2, 2, 2, 2, 2],
#        [3, 3, 3, 3, 3, 3, 3],
#        [4, 4, 4, 4, 4, 4, 4],
#        [5, 5, 5, 5, 5, 5, 5],
#        [1, 1, 1, 1, 1, 1, 1]], dtype=int64)

jnp.gradient(u_periodic, axis=0)
# Out:
# Array([[-4. , -4. , -4. , -4. , -4. , -4. , -4. ],
#        [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5, -1.5],
#        [ 1. ,  1. ,  1. ,  1. ,  1. ,  1. ,  1. ],
#        [ 1. ,  1. ,  1. ,  1. ,  1. ,  1. ,  1. ],
#        [ 1. ,  1. ,  1. ,  1. ,  1. ,  1. ,  1. ],
#        [-1.5, -1.5, -1.5, -1.5, -1.5, -1.5, -1.5],
#        [-4. , -4. , -4. , -4. , -4. , -4. , -4. ]], dtype=float64)

jnp.gradient(u_periodic, axis=1)
# Out:
# Array([[0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.]], dtype=float64)

# Neumann Boundaries
# F_bc.apply_dirichlet_x??
# F_bc.apply_dirichlet_y??

# Zero-pad, then overwrite the ghost cells using F_bc.apply_neumann_x followed
# by F_bc.apply_neumann_y. (This pad used to be fused into the comment line
# above, so the helpers were being applied to the stale 1D `u_neumann`.)
u_neumann = jnp.pad(u, pad_width=((1, 1), (1, 1)), mode="constant")
u_neumann = F_bc.apply_neumann_y(F_bc.apply_neumann_x(u_neumann))
u_neumann
# Out:
# Array([[1, 1, 1, 1, 1, 1, 1],
#        [1, 1, 1, 1, 1, 1, 1],
#        [2, 2, 2, 2, 2, 2, 2],
#        [3, 3, 3, 3, 3, 3, 3],
#        [4, 4, 4, 4, 4, 4, 4],
#        [5, 5, 5, 5, 5, 5, 5],
#        [5, 5, 5, 5, 5, 5, 5]], dtype=int64)

# Same padding via the combined library helper.
u_neumann = F_bc.apply_neumann_pad_2D(u)
u_neumann
# Out:
# Array([[1, 1, 1, 1, 1, 1, 1],
#        [1, 1, 1, 1, 1, 1, 1],
#        [2, 2, 2, 2, 2, 2, 2],
#        [3, 3, 3, 3, 3, 3, 3],
#        [4, 4, 4, 4, 4, 4, 4],
#        [5, 5, 5, 5, 5, 5, 5],
#        [5, 5, 5, 5, 5, 5, 5]], dtype=int64)

jnp.gradient(u_neumann, axis=0)
# Out:
# Array([[0. , 0. , 0. , 0. , 0. , 0. , 0. ],
#        [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
#        [1. , 1. , 1. , 1. , 1. , 1. , 1. ],
#        [1. , 1. , 1. , 1. , 1. , 1. , 1. ],
#        [1. , 1. , 1. , 1. , 1. , 1. , 1. ],
#        [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
#        [0. , 0. , 0. , 0. , 0. , 0. , 0. ]], dtype=float64)

jnp.gradient(u_neumann, axis=1)
# Out:
# Array([[0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 0., 0., 0., 0.]], dtype=float64)

# Dirichlet Boundaries
# Attempt at Dirichlet ghosts via odd symmetric reflection.
# NOTE: with mode="symmetric" and reflect_type="odd", each ghost value is
# 2 * edge - edge == edge. The result is therefore the same edge-copy padding
# as the Neumann case above, not negated ghost cells (compare the output).
# The explicit constructions below build the Dirichlet ghosts directly.
jnp.pad(u, pad_width=((1, 1), (1, 1)), mode="symmetric", reflect_type="odd")
# Out:
# Array([[1, 1, 1, 1, 1, 1, 1],
#        [1, 1, 1, 1, 1, 1, 1],
#        [2, 2, 2, 2, 2, 2, 2],
#        [3, 3, 3, 3, 3, 3, 3],
#        [4, 4, 4, 4, 4, 4, 4],
#        [5, 5, 5, 5, 5, 5, 5],
#        [5, 5, 5, 5, 5, 5, 5]], dtype=int64)

# Edges
# Pad with placeholder values; the ghost cells are overwritten just below.
u_dirichlet = jnp.pad(u, pad_width=((1, 1), (1, 1)), mode="empty")
# Fill the ghosts with F_bc.apply_dirichlet_x_edge, then F_bc.apply_dirichlet_y_edge.
u_dirichlet = F_bc.apply_dirichlet_y_edge(F_bc.apply_dirichlet_x_edge(u_dirichlet))
u_dirichlet
# Out:
# Array([[ 1, -1, -1, -1, -1, -1,  1],
#        [-1,  1,  1,  1,  1,  1, -1],
#        [-2,  2,  2,  2,  2,  2, -2],
#        [-3,  3,  3,  3,  3,  3, -3],
#        [-4,  4,  4,  4,  4,  4, -4],
#        [-5,  5,  5,  5,  5,  5, -5],
#        [ 5, -5, -5, -5, -5, -5,  5]], dtype=int64)

# Same padding via the combined library helper.
u_dirichlet = F_bc.apply_dirichlet_pad_edge_2D(u)
u_dirichlet
# Out:
# Array([[ 1, -1, -1, -1, -1, -1,  1],
#        [-1,  1,  1,  1,  1,  1, -1],
#        [-2,  2,  2,  2,  2,  2, -2],
#        [-3,  3,  3,  3,  3,  3, -3],
#        [-4,  4,  4,  4,  4,  4, -4],
#        [-5,  5,  5,  5,  5,  5, -5],
#        [ 5, -5, -5, -5, -5, -5,  5]], dtype=int64)

jnp.gradient(u_dirichlet, axis=0)
# Out:
# Array([[ -2. ,   2. ,   2. ,   2. ,   2. ,   2. ,  -2. ],
#        [ -1.5,   1.5,   1.5,   1.5,   1.5,   1.5,  -1.5],
#        [ -1. ,   1. ,   1. ,   1. ,   1. ,   1. ,  -1. ],
#        [ -1. ,   1. ,   1. ,   1. ,   1. ,   1. ,  -1. ],
#        [ -1. ,   1. ,   1. ,   1. ,   1. ,   1. ,  -1. ],
#        [  4.5,  -4.5,  -4.5,  -4.5,  -4.5,  -4.5,   4.5],
#        [ 10. , -10. , -10. , -10. , -10. , -10. ,  10. ]], dtype=float64)

jnp.gradient(u_dirichlet, axis=1)
# Out:
# Array([[ -2.,  -1.,   0.,   0.,   0.,   1.,   2.],
#        [  2.,   1.,   0.,   0.,   0.,  -1.,  -2.],
#        [  4.,   2.,   0.,   0.,   0.,  -2.,  -4.],
#        [  6.,   3.,   0.,   0.,   0.,  -3.,  -6.],
#        [  8.,   4.,   0.,   0.,   0.,  -4.,  -8.],
#        [ 10.,   5.,   0.,   0.,   0.,  -5., -10.],
#        [-10.,  -5.,   0.,   0.,   0.,   5.,  10.]], dtype=float64)

# FACES
# Pad with placeholder values; the ghost cells are overwritten just below.
u_dirichlet = jnp.pad(u, pad_width=((1, 1), (1, 1)), mode="empty")
# Fill the ghosts with F_bc.apply_dirichlet_x_face, then F_bc.apply_dirichlet_y_face.
u_dirichlet = F_bc.apply_dirichlet_y_face(F_bc.apply_dirichlet_x_face(u_dirichlet))
u_dirichlet
# Out:
# Array([[0, 0, 0, 0, 0, 0, 0],
#        [0, 1, 1, 1, 1, 1, 0],
#        [0, 2, 2, 2, 2, 2, 0],
#        [0, 3, 3, 3, 3, 3, 0],
#        [0, 4, 4, 4, 4, 4, 0],
#        [0, 5, 5, 5, 5, 5, 0],
#        [0, 0, 0, 0, 0, 0, 0]], dtype=int64)

# Same padding via the combined library helper.
u_dirichlet = F_bc.apply_dirichlet_pad_face_2D(u)
u_dirichlet
# Out:
# Array([[0, 0, 0, 0, 0, 0, 0],
#        [0, 1, 1, 1, 1, 1, 0],
#        [0, 2, 2, 2, 2, 2, 0],
#        [0, 3, 3, 3, 3, 3, 0],
#        [0, 4, 4, 4, 4, 4, 0],
#        [0, 5, 5, 5, 5, 5, 0],
#        [0, 0, 0, 0, 0, 0, 0]], dtype=int64)

jnp.gradient(u_dirichlet, axis=0)
# Out:
# Array([[ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
#        [ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
#        [ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
#        [ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
#        [ 0.,  1.,  1.,  1.,  1.,  1.,  0.],
#        [ 0., -2., -2., -2., -2., -2.,  0.],
#        [ 0., -5., -5., -5., -5., -5.,  0.]], dtype=float64)

jnp.gradient(u_dirichlet, axis=1)
# Out:
# Array([[ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
#        [ 1. ,  0.5,  0. ,  0. ,  0. , -0.5, -1. ],
#        [ 2. ,  1. ,  0. ,  0. ,  0. , -1. , -2. ],
#        [ 3. ,  1.5,  0. ,  0. ,  0. , -1.5, -3. ],
#        [ 4. ,  2. ,  0. ,  0. ,  0. , -2. , -4. ],
#        [ 5. ,  2.5,  0. ,  0. ,  0. , -2.5, -5. ],
#        [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ]], dtype=float64)