Grid Operations

How can I use the Grid Operations to deal with staggered grids?

Authors
Affiliations
J. Emmanuel Johnson
CNRS
MEOM
Takaya Uchida
FSU
import autoroot
import jax
import jax.numpy as jnp
import numpy as np
import kernex as kex
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from jaxtyping import Array
import einops
import finitediffx as fdx
from jaxsw._src.operators.functional import grid as F_grid
from jaxsw._src.operators.functional import cgrid as C_grid
from jaxsw._src.boundaries import functional as F_bc
from jaxsw._src.domain.base import Domain
from jaxsw._src.fields.base import Field
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Plotting style: reset seaborn to its defaults, then use a compact "talk" context.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable 64-bit floats in JAX (the recorded outputs below report dtype=float64).
jax.config.update("jax_enable_x64", True)

# IPython magics: draw figures inline and auto-reload edited modules
# (e.g. the jaxsw sources) so library changes are picked up without a restart.
%matplotlib inline
%load_ext autoreload
%autoreload 2

1D Arakawa C-Grid

+ -- ⋅ -- +
u -- u̅ -- u
+ -- ⋅ -- +
# Schematic shapes only: on the 1D C-grid both the quantity of interest and the
# u-velocity are 1D arrays of length Nx.
# NOTE(review): this is illustrative; executing it binds q and u to Ellipsis.
# The real fields are constructed in the following cell.
# QOI
q: Array["Nx"] = ...
# U-Velocity
u: Array["Nx"] = ...
import typing as tp

ncols = 5

# Quantity of interest: unstaggered 1D domain with nodes at x = 0, 1, ..., ncols.
q_domain = Domain(xmin=(0,), xmax=(ncols,), dx=(1,))
q = Field(q_domain.grid[..., 0], q_domain)

# U-velocity: same extent and spacing, but staggered to the right of the nodes.
u_domain = Domain(xmin=(0,), xmax=(ncols,), dx=(1,), stagger=("right",))
values = u_domain.grid[..., 0]
u = Field(values, u_domain)

# Both fields end up with the same number of points.
q.values.shape, u.values.shape
((6,), (6,))
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Plot the node-centred q points and the right-staggered u points on one row
# so the offset between the two grids is visible.
fig, ax = plt.subplots(figsize=(4, 2))

# Each artist needs a label: plt.legend() on unlabeled artists draws an empty
# legend and warns "No artists with labels found to put in legend".
ax.scatter(
    q.values,
    jnp.ones_like(q.values),
    marker="x",
    color="black",
    zorder=2,
    label="Quantity of Interest",
)
ax.scatter(
    u.values,
    jnp.ones_like(u.values),
    marker=">",
    color="tab:blue",
    zorder=2,
    label="U-Velocity",
)
ax.xaxis.set_major_locator(MultipleLocator(1))

ax.set(xlim=[-0.5, ncols + 0.5])
# Grid lines sit below the markers (zorder 1 < 2).
ax.grid(which="major", zorder=1)
plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 400x200 with 1 Axes>
# Reflect-pad q by two points on each side: "reflect" mirrors about the end
# values without repeating them, e.g. [0..5] -> [2, 1, 0, ..., 5, 4, 3].
jnp.pad(jnp.squeeze(q.values), pad_width=(2, 2), mode="reflect")
Array([2., 1., 0., 1., 2., 3., 4., 5., 4., 3.], dtype=float64)
# Coordinate grid of the u-domain: (points, spatial dims).
np.shape(u.domain.grid)
(6, 1)
# Show the node-centred q values, then interpolate them onto the
# right-staggered u points and compare against the true u coordinates.
# NOTE(review): in the recorded output the last interpolated value is 5.0, not
# 5.5, so the right-most point is not a true midpoint (boundary handling
# inside grid_operator) — confirm this is intended.
print(jnp.squeeze(q.values), q.values.shape, sep=" ")
q_on_u = F_grid.grid_operator(q, ("right",))
(q_on_u.values, u.values)
[0. 1. 2. 3. 4. 5.] (6,)
(6,)
(Array([0.5, 1.5, 2.5, 3.5, 4.5, 5. ], dtype=float64), Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5], dtype=float64))
# Coordinate grid of the q-domain: (points, spatial dims).
np.shape(q.domain.grid)
(6, 1)
# Move u from its right-staggered points back onto the q nodes with a "left" shift.
# NOTE(review): the recorded output starts at 0.5 rather than 0.0, so the
# left-most node is not recovered exactly (boundary handling) — confirm.
u_on_q = F_grid.grid_operator(u, ("left",))
(u_on_q.values, q.values)
(6,)
(Array([0.5, 1. , 2. , 3. , 4. , 5. ], dtype=float64), Array([0., 1., 2., 3., 4., 5.], dtype=float64))
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Compare the q nodes with u after it has been moved onto the q grid.
fig, ax = plt.subplots(figsize=(4, 2))

# Labels are required so plt.legend() has entries to show (otherwise it warns
# "No artists with labels found to put in legend").
ax.scatter(
    q.values,
    jnp.ones_like(q.values),
    marker="x",
    color="black",
    zorder=2,
    label="Quantity of Interest",
)
ax.scatter(
    u_on_q.values,
    np.ones_like(u_on_q.values),
    marker=">",
    color="tab:blue",
    zorder=2,
    label="U-Velocity on Q",
)
ax.xaxis.set_major_locator(MultipleLocator(1))

ax.set(xlim=[-0.5, ncols + 0.5])
# Grid lines sit below the markers (zorder 1 < 2).
ax.grid(which="major", zorder=1)
plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 400x200 with 1 Axes>

Nodes to Edges

This happens when we want to move a field defined on the nodes onto the top-down edges. For example, we may want to use the quantity of interest defined on the nodes to estimate the velocities defined on the top-down edges.

# Schematic shapes only: conceptually there are Nx nodes and Nx-1 edges between them.
# NOTE(review): the staggered u built earlier with stagger=("right",) has the
# same length as q (both (6,) in this notebook), so "Nx-1" describes the
# idealized layout, not that array. Also, executing this cell would rebind
# q and u to Ellipsis.
# QOI
q: Array["Nx"] = ...
# U-Velocity
u: Array["Nx-1"] = ...
# Nodes -> edges: compare q interpolated onto the u points with u itself.
fig, ax = plt.subplots(figsize=(4, 2))

# Labels are required so plt.legend() has entries to show (otherwise it warns
# "No artists with labels found to put in legend").
ax.scatter(
    q_on_u.values,
    np.ones_like(q_on_u.values),
    marker="x",
    color="black",
    zorder=2,
    label="Quantity of Interest on U",
)
ax.scatter(
    u.values,
    np.ones_like(u.values),
    marker=">",
    color="tab:blue",
    zorder=2,
    label="U-Velocity",
)
ax.xaxis.set_major_locator(MultipleLocator(1))

ax.set(xlim=[-0.5, ncols + 0.5])
# Grid lines sit below the markers (zorder 1 < 2).
ax.grid(which="major", zorder=1)
plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 400x200 with 1 Axes>

2D Arakawa C-Grid

In this example, we will look at the classic Arakawa C-Grid, which is a standard choice for finite difference methods on gridded domains. All of the variables live on staggered domains, which means that we will have to apply some small transformations to move each variable onto the other domains.

In this example, we'll look at two variables along with their corresponding velocities. The inspiration comes from the quasi-geostrophic (QG) equations, where we have

$$\begin{aligned} \partial_t q + \vec{u}\cdot\boldsymbol{\nabla} q &= 0 \\ q &= \boldsymbol{\nabla}_H^2 \psi \end{aligned}$$

where we have four variables we need to handle in total:

  • $q$ - potential vorticity on the cell faces
  • $u$ - zonal velocity on the Top-Bottom cell edges
  • $v$ - meridional velocity on the East-West cell edges
  • $\psi$ - stream function on the cell nodes.
n_rows = 3
n_cols = 5

# All four domains cover the same rectangle [0, n_rows] x [0, n_cols] with unit
# spacing and differ only in their per-axis stagger. Each domain's coordinate
# grid is split into its first (index 0) and second (index 1) components so
# the point locations can be plotted and moved between grids below.

# Stream Function Domain - cell corners (no stagger)
psi_domain = Domain(xmin=(0, 0), xmax=(n_rows, n_cols), dx=(1, 1))
psi_x = Field(psi_domain.grid[..., 0], psi_domain)
psi_y = Field(psi_domain.grid[..., 1], psi_domain)

# Q Domain - cell faces (staggered along both axes)
q_domain = Domain(
    xmin=(0.0, 0.0), xmax=(n_rows, n_cols), dx=(1, 1), stagger=("right", "right")
)
q_x = Field(q_domain.grid[..., 0], q_domain)
q_y = Field(q_domain.grid[..., 1], q_domain)

# U Velocity - staggered along the first axis only
u_domain = Domain(
    xmin=(0, 0), xmax=(n_rows, n_cols), dx=(1, 1), stagger=("right", None)
)
u_x = Field(u_domain.grid[..., 0], u_domain)
u_y = Field(u_domain.grid[..., 1], u_domain)

# V Velocity - staggered along the second axis only
v_domain = Domain(
    xmin=(0, 0), xmax=(n_rows, n_cols), dx=(1, 1), stagger=(None, "right")
)
v_x = Field(v_domain.grid[..., 0], v_domain)
v_y = Field(v_domain.grid[..., 1], v_domain)

psi_x.values.shape
(4, 6)
fig, ax = plt.subplots(figsize=(5, 4))

# (first-coordinate field, second-coordinate field, scatter style) for every
# variable on its native C-grid location, in legend order.
_layers = (
    (
        psi_x,
        psi_y,
        dict(marker="x", color="black", zorder=2, label="Stream Function"),
    ),
    (q_x, q_y, dict(marker=".", color="tab:red", zorder=2, label="Vorticity")),
    (
        u_x,
        u_y,
        dict(marker=">", color="tab:blue", zorder=2, label="Zonal Velocity"),
    ),
    (
        v_x,
        v_y,
        dict(marker="^", color="tab:green", zorder=2, label="Meridional Velocity"),
    ),
)
for _fx, _fy, _style in _layers:
    ax.scatter(_fx.values.ravel(), _fy.values.ravel(), **_style)

# Major ticks at integers (with grid lines), minor ticks at half-integers.
for _axis in (ax.xaxis, ax.yaxis):
    _axis.set_major_locator(MultipleLocator(1))
    _axis.set_minor_locator(MultipleLocator(0.5))

ax.set(xlim=[-1.5, n_rows + 1.5], ylim=[-1.5, n_cols + 1.5])
ax.grid(which="major", zorder=1)
plt.legend(fontsize=10, edgecolor="black", framealpha=1.0, facecolor="lightgray")
plt.tight_layout()
plt.show()
<Figure size 500x400 with 1 Axes>

This example may seem complicated, and possibly unnecessary, because we have two variables plus the velocities. That happens in this contrived example, but in many cases we only have one variable and the velocities, e.g. the shallow water equations. In that case we have a choice: we can place the main variable on either the cell faces or the cell nodes.

Transformations to Variable

In this first case, we will look at all transformations that can get us to the vorticity from any other variable on this grid.

$$\begin{aligned} \psi &\rightarrow q \\ u &\rightarrow q \\ v &\rightarrow q \end{aligned}$$
# Move every variable onto the vorticity (q) points. Per axis, the stagger
# tuple says which way to shift: u already shares q's first-axis stagger,
# v shares q's second-axis stagger, and psi must shift along both axes.
u_x_on_q, u_y_on_q = (
    F_grid.grid_operator(field, (None, "right")) for field in (u_x, u_y)
)
v_x_on_q, v_y_on_q = (
    F_grid.grid_operator(field, ("right", None)) for field in (v_x, v_y)
)
psi_x_on_q, psi_y_on_q = (
    F_grid.grid_operator(field, ("right", "right")) for field in (psi_x, psi_y)
)
fig, ax = plt.subplots(figsize=(5, 4))

# Every variable after interpolation onto the q points. q itself uses
# zorder=3, presumably so it is drawn on top where the points coincide.
_layers = (
    (
        psi_x_on_q,
        psi_y_on_q,
        dict(marker="x", color="black", zorder=2, label="Stream Function"),
    ),
    (q_x, q_y, dict(marker=".", color="tab:red", zorder=3, label="Vorticity")),
    (
        u_x_on_q,
        u_y_on_q,
        dict(marker=">", color="tab:blue", zorder=2, label="Zonal Velocity"),
    ),
    (
        v_x_on_q,
        v_y_on_q,
        dict(marker="^", color="tab:green", zorder=2, label="Meridional Velocity"),
    ),
)
for _fx, _fy, _style in _layers:
    ax.scatter(_fx.values.ravel(), _fy.values.ravel(), **_style)

# Major ticks at integers (with grid lines), minor ticks at half-integers.
for _axis in (ax.xaxis, ax.yaxis):
    _axis.set_major_locator(MultipleLocator(1))
    _axis.set_minor_locator(MultipleLocator(0.5))

ax.set(xlim=[-1.5, n_rows + 1.5], ylim=[-1.5, n_cols + 1.5])
ax.grid(which="major", zorder=1)
plt.legend(fontsize=10, edgecolor="black", framealpha=1.0, facecolor="lightgray")
plt.tight_layout()
plt.show()
<Figure size 500x400 with 1 Axes>

Transforms to PSI Variable

In this case, we will look at all transformations that can get us to the stream function from any other variable on this grid.

$$\begin{aligned} q &\rightarrow \psi \\ u &\rightarrow \psi \\ v &\rightarrow \psi \end{aligned}$$
# Move every variable back onto the unstaggered stream-function (psi) nodes:
# each staggered axis is undone with a "left" shift.
u_x_on_psi, u_y_on_psi = (
    F_grid.grid_operator(field, ("left", None)) for field in (u_x, u_y)
)
v_x_on_psi, v_y_on_psi = (
    F_grid.grid_operator(field, (None, "left")) for field in (v_x, v_y)
)
q_x_on_psi, q_y_on_psi = (
    F_grid.grid_operator(field, ("left", "left")) for field in (q_x, q_y)
)
fig, ax = plt.subplots(figsize=(5, 4))

# Every variable after interpolation onto the psi nodes. The vorticity layer
# uses zorder=3, presumably so it is drawn on top where points coincide.
_layers = (
    (
        psi_x,
        psi_y,
        dict(marker="x", color="black", zorder=2, label="Stream Function"),
    ),
    (
        q_x_on_psi,
        q_y_on_psi,
        dict(marker=".", color="tab:red", zorder=3, label="Vorticity"),
    ),
    (
        u_x_on_psi,
        u_y_on_psi,
        dict(marker=">", color="tab:blue", zorder=2, label="Zonal Velocity"),
    ),
    (
        v_x_on_psi,
        v_y_on_psi,
        dict(marker="^", color="tab:green", zorder=2, label="Meridional Velocity"),
    ),
)
for _fx, _fy, _style in _layers:
    ax.scatter(_fx.values.ravel(), _fy.values.ravel(), **_style)

# Major ticks at integers (with grid lines), minor ticks at half-integers.
for _axis in (ax.xaxis, ax.yaxis):
    _axis.set_major_locator(MultipleLocator(1))
    _axis.set_minor_locator(MultipleLocator(0.5))

ax.set(xlim=[-1.5, n_rows + 1.5], ylim=[-1.5, n_cols + 1.5])
ax.grid(which="major", zorder=1)
plt.legend(fontsize=10, edgecolor="black", framealpha=1.0, facecolor="lightgray")
plt.tight_layout()
plt.show()
<Figure size 500x400 with 1 Axes>

Transforms to U-Velocity

In this case, we will look at all transformations that can get us to the u-velocity from any other variable on this grid.

$$\begin{aligned} q &\rightarrow u \\ \psi &\rightarrow u \\ v &\rightarrow u \end{aligned}$$
# Move every variable onto the zonal-velocity (u) points, which are staggered
# along the first axis only.
psi_x_on_u, psi_y_on_u = (
    F_grid.grid_operator(field, ("right", None)) for field in (psi_x, psi_y)
)
v_x_on_u, v_y_on_u = (
    F_grid.grid_operator(field, ("right", "left")) for field in (v_x, v_y)
)
q_x_on_u, q_y_on_u = (
    F_grid.grid_operator(field, (None, "left")) for field in (q_x, q_y)
)
fig, ax = plt.subplots(figsize=(5, 4))

# Every variable after interpolation onto the u points. The vorticity layer
# uses zorder=3, presumably so it is drawn on top where points coincide.
_layers = (
    (
        psi_x_on_u,
        psi_y_on_u,
        dict(marker="x", color="black", zorder=2, label="Stream Function"),
    ),
    (
        q_x_on_u,
        q_y_on_u,
        dict(marker=".", color="tab:red", zorder=3, label="Vorticity"),
    ),
    (
        u_x,
        u_y,
        dict(marker=">", color="tab:blue", zorder=2, label="Zonal Velocity"),
    ),
    (
        v_x_on_u,
        v_y_on_u,
        dict(marker="^", color="tab:green", zorder=2, label="Meridional Velocity"),
    ),
)
for _fx, _fy, _style in _layers:
    ax.scatter(_fx.values.ravel(), _fy.values.ravel(), **_style)

# Major ticks at integers (with grid lines), minor ticks at half-integers.
for _axis in (ax.xaxis, ax.yaxis):
    _axis.set_major_locator(MultipleLocator(1))
    _axis.set_minor_locator(MultipleLocator(0.5))

ax.set(xlim=[-1.5, n_rows + 1.5], ylim=[-1.5, n_cols + 1.5])
ax.grid(which="major", zorder=1)
plt.legend(fontsize=10, edgecolor="black", framealpha=1.0, facecolor="lightgray")
plt.tight_layout()
plt.show()
<Figure size 500x400 with 1 Axes>

Transforms to V-Velocity

In this case, we will look at all transformations that can get us to the v-velocity from any other variable on this grid.

$$\begin{aligned} q &\rightarrow v \\ \psi &\rightarrow v \\ u &\rightarrow v \end{aligned}$$
# Move every variable onto the meridional-velocity (v) points, which are
# staggered along the second axis only.
psi_x_on_v, psi_y_on_v = (
    F_grid.grid_operator(field, (None, "right")) for field in (psi_x, psi_y)
)
u_x_on_v, u_y_on_v = (
    F_grid.grid_operator(field, ("left", "right")) for field in (u_x, u_y)
)
q_x_on_v, q_y_on_v = (
    F_grid.grid_operator(field, ("left", None)) for field in (q_x, q_y)
)
fig, ax = plt.subplots(figsize=(5, 4))

# Every variable after interpolation onto the v points. The vorticity layer
# uses zorder=3, presumably so it is drawn on top where points coincide.
_layers = (
    (
        psi_x_on_v,
        psi_y_on_v,
        dict(marker="x", color="black", zorder=2, label="Stream Function"),
    ),
    (
        q_x_on_v,
        q_y_on_v,
        dict(marker=".", color="tab:red", zorder=3, label="Vorticity"),
    ),
    (
        u_x_on_v,
        u_y_on_v,
        dict(marker=">", color="tab:blue", zorder=2, label="Zonal Velocity"),
    ),
    (
        v_x,
        v_y,
        dict(marker="^", color="tab:green", zorder=2, label="Meridional Velocity"),
    ),
)
for _fx, _fy, _style in _layers:
    ax.scatter(_fx.values.ravel(), _fy.values.ravel(), **_style)

# Major ticks at integers (with grid lines), minor ticks at half-integers.
for _axis in (ax.xaxis, ax.yaxis):
    _axis.set_major_locator(MultipleLocator(1))
    _axis.set_minor_locator(MultipleLocator(0.5))

ax.set(xlim=[-1.5, n_rows + 1.5], ylim=[-1.5, n_cols + 1.5])
ax.grid(which="major", zorder=1)
plt.legend(fontsize=10, edgecolor="black", framealpha=1.0, facecolor="lightgray")
plt.tight_layout()
plt.show()
<Figure size 500x400 with 1 Axes>