Free-Run QG with SSH

Authors: J. Emmanuel Johnson (CNRS, MEOM), Takaya Uchida (FSU)
import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import config
import numpy as np
import numba as nb
import pandas as pd
import equinox as eqx
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange, repeat, reduce
from tqdm.notebook import tqdm, trange
from jaxtyping import Array, Float

from jaxsw._src.operators.functional import advection as F_adv
from jaxsw._src.operators.functional import geostrophic as F_geos

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
config.update("jax_enable_x64", True)


%matplotlib inline
%load_ext autoreload
%autoreload 2

In this problem, we look at sea surface height (SSH) in relation to the Quasi-Geostrophic (QG) equations. These equations are a simplified form of the Navier-Stokes equations, obtained with approximations such as the hydrostatic approximation, a small aspect ratio, and a small Rossby number. They are expected to be a reasonable approximation at the mesoscale (100km - 10,000km). In this application, we will see how SSH propagates under the QG equations.

Equations

This is a very simplified system of equations:

$$
\begin{aligned}
\partial_t q &= -\det\boldsymbol{J}(\psi,q) - \beta\partial_x\psi \\
\psi &= \frac{g}{f_0}\eta \\
q &= \nabla^2 \psi - \frac{f_0^2}{c_1^2}\psi \\
u &= -\partial_y\psi \\
v &= \partial_x\psi \\
f &= 2\Omega\sin\theta_0 + \frac{1}{R}2\Omega\cos\theta_0\, y \\
f_0 &= \mathcal{E}[f] \\
L_R &= \frac{c_1}{f_0}
\end{aligned}
$$

where:

  • $\theta_0$ is the mean latitude
  • $f_0=2\Omega\sin\theta_0$ is the Coriolis parameter at the mean latitude
  • $\beta=\frac{1}{R}2\Omega\cos\theta_0$ is the $\beta$-plane parameter (the meridional gradient of $f$) at the mean latitude
  • $L_R$ is the Rossby deformation radius
  • $c_1$ is the phase speed that, together with $f_0$, sets $L_R$
  • $g$ is the gravitational acceleration
  • $\Omega$ is the angular velocity of the Earth's rotation
  • $R$ is the radius of the Earth

Source:

Read input SSH

# two alternative locations of the same file; the second assignment is the one used
# file = "/gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"
file = "/Users/eman/code_projects/data/scratch/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"


# ds = xr.open_dataset('/Users/eman/code_projects/data/scratch/NATL60_GULFSTREAM_degraded.nc')
# ds = xr.open_dataset('/Users/eman/code_projects/data/scratch/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc', decode_times=False).assign_coords(time=lambda ds: pd.to_datetime(ds.time))
# ds = xr.open_dataset("/Users/eman/code_projects/data/scratch/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.decoded.nc")
ds = xr.open_dataset(
    file,
    decode_times=False,
).assign_coords(time=lambda ds: pd.to_datetime(ds.time))
# ds = ds.coarsen(lon=3,lat=3).mean()
ds
ds.ssh.isel(time=0).plot.pcolormesh(cmap="viridis")
[Figure: SSH at the first time step]
lon = ds.lon.values
lat = ds.lat.values
ssh = ds.ssh[0].values.T
plt.figure()
plt.pcolormesh(ssh.T)
plt.colorbar()
[Figure: SSH at the first time step, plotted from the NumPy array]

Strategy

This is a slightly different problem from some of the previous ones.

Lat/Lon Domain. Our domain is given in latitude/longitude coordinates, while the QG equations are formulated on a Cartesian $(x, y)$ plane.

SSH. In this case, we have sea surface height "observations", but SSH is not one of the variables evolved by the QG equations. So we need to transform it into the QG variables: the stream function, $\psi$, and the potential vorticity, $q$.

Domain

Already, we have to do something slightly different from the previous tutorials. We are dealing with latitude/longitude coordinates, so we need to transform them to $x,y$ space, i.e., a local tangent plane. I have a specialized tutorial about how we do this in practice, which can be found here (TODO).

from jaxsw._src.domain.latlon import LatLonMeanDomain
lon = ds.lon.values
lat = ds.lat.values
ssh = jnp.asarray(ds.ssh[0].values.T)

domain = LatLonMeanDomain(lat=lat, lon=lon)

assert domain.size == (lon.shape[0], lat.shape[0])

domain.size, ssh.shape
((201, 201), (201, 201))
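The notebook delegates the coordinate transformation to `LatLonMeanDomain`. As a rough illustration of what a mean-latitude tangent plane involves, here is a minimal sketch assuming an equirectangular approximation centred on the mean latitude $\theta_0$, a mean Earth radius of 6371 km, and $\Omega = 7.2921\times10^{-5}$ rad/s. The helper `tangent_plane_params` and the example grid are hypothetical; this is not a description of how jaxsw implements the domain.

```python
import numpy as np

EARTH_RADIUS = 6.371e6  # mean Earth radius [m] (assumed constant)
OMEGA = 7.2921e-5       # Earth's rotation rate [rad/s]


def tangent_plane_params(lon, lat):
    """Hypothetical helper: mean grid spacing and beta-plane parameters.

    Local tangent plane centred on the mean latitude theta0:
      dx = R cos(theta0) dlon,  dy = R dlat,
      f0 = 2 Omega sin(theta0), beta = 2 Omega cos(theta0) / R
    """
    lat0 = np.deg2rad(np.mean(lat))
    dlon = np.deg2rad(np.mean(np.diff(lon)))
    dlat = np.deg2rad(np.mean(np.diff(lat)))
    dx = EARTH_RADIUS * np.cos(lat0) * dlon
    dy = EARTH_RADIUS * dlat
    f0 = 2 * OMEGA * np.sin(lat0)
    beta = 2 * OMEGA * np.cos(lat0) / EARTH_RADIUS
    return dx, dy, f0, beta


# Example on a hypothetical 0.05 degree grid spanning 33N-43N
lon = np.linspace(-65.0, -55.0, 201)
lat = np.linspace(33.0, 43.0, 201)
dx, dy, f0, beta = tangent_plane_params(lon, lat)
```

Because $\cos\theta_0 < 1$, the zonal spacing comes out smaller than the meridional one, which is the same ordering seen in `domain.dx` for the NATL60 grid.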

State

So here, we need to keep track of the following state variables:

$$
\begin{aligned}
\eta=\boldsymbol{\eta}(\vec{\mathbf{x}}) && \psi=\boldsymbol{\psi}(\vec{\mathbf{x}}) && q = \boldsymbol{q}(\vec{\mathbf{x}}) && \vec{\mathbf{x}}\in\Omega,\partial\Omega
\end{aligned}
$$

where $\eta$ is the sea surface height, $\psi$ is the stream function, and $q$ is the potential vorticity. All of them are defined on the same domain.

We also have some constants that we need to keep track of:

$$
\begin{aligned}
f_0\in\mathbb{R} && \beta\in\mathbb{R}
\end{aligned}
$$

where $f_0$ is the Coriolis parameter and $\beta$ is the $\beta$-plane parameter, both evaluated at the mean latitude.

Initial Conditions

These initial conditions are slightly different from the previous cases: 1) we will initialize with real data, and 2) we will have to "precalculate" some of the variables with the formulas listed above. So we will come back to this after we have defined all of the necessary functions.

Stream Function

Here, we can directly relate the stream function to the SSH:

$$\psi = \frac{g}{f_0}\eta$$
# F_geos.ssh_to_streamfn??
# F_geos.streamfn_to_ssh??
# forward transformation
psi = F_geos.ssh_to_streamfn(jnp.asarray(ssh), f0=domain.f0)

# inverse transform
ssh_ = F_geos.streamfn_to_ssh(psi, f0=domain.f0)

np.testing.assert_array_almost_equal(ssh, ssh_)
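The forward and inverse transforms are just a rescaling by $g/f_0$. A minimal sketch of that relation, assuming $g = 9.81$ m/s² (jaxsw may use a different constant) and using hypothetical function names rather than the `F_geos` API:

```python
import jax.numpy as jnp

GRAVITY = 9.81  # m/s^2 (assumed value)


def ssh_to_streamfn_sketch(eta, f0, g=GRAVITY):
    """psi = (g / f0) * eta  (hypothetical re-implementation)."""
    return (g / f0) * eta


def streamfn_to_ssh_sketch(psi, f0, g=GRAVITY):
    """eta = (f0 / g) * psi, the inverse of the map above."""
    return (f0 / g) * psi


eta = jnp.linspace(-0.5, 1.0, 7)  # SSH values in metres
psi = ssh_to_streamfn_sketch(eta, f0=9e-5)
eta_back = streamfn_to_ssh_sketch(psi, f0=9e-5)
```

With a mid-latitude $f_0 \approx 10^{-4}$ s⁻¹, an SSH anomaly of 1 m maps to a stream function of order $10^5$ m²/s.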

Potential Vorticity

$$q = \nabla^2\psi - \frac{f_0^2}{c_1^2}\psi$$
from jaxsw._src.boundaries.helmholtz import enforce_boundaries_helmholtz
# streamfn_to_pv??
c1 = 1.5
q = F_geos.streamfn_to_pvort(
    psi, dx=domain.dx_mean, dy=domain.dx_mean, f0=domain.f0, c1=c1, accuracy=1
)
# q = enforce_boundaries_helmholtz(q, psi, beta=(f0/c1)**2)
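To make the PV formula concrete, here is a minimal sketch of $q = \nabla^2\psi - (f_0/c_1)^2\psi$ with a second-order 5-point Laplacian. The function name is hypothetical, arrays are assumed to be indexed `[x, y]` like the notebook's, and setting the Laplacian to zero on the boundary is a simplifying assumption (jaxsw's `streamfn_to_pvort` and `enforce_boundaries_helmholtz` may treat the edges differently).

```python
import jax.numpy as jnp


def streamfn_to_pvort_sketch(psi, dx, dy, f0, c1):
    """Hypothetical q = lap(psi) - (f0/c1)^2 psi with a 5-point Laplacian.

    The Laplacian is evaluated at interior points only and left at 0 on the boundary.
    """
    d2x = (psi[2:, 1:-1] - 2 * psi[1:-1, 1:-1] + psi[:-2, 1:-1]) / dx**2
    d2y = (psi[1:-1, 2:] - 2 * psi[1:-1, 1:-1] + psi[1:-1, :-2]) / dy**2
    lap = jnp.zeros_like(psi).at[1:-1, 1:-1].set(d2x + d2y)
    return lap - (f0 / c1) ** 2 * psi


# psi = x^2 + y^2 has a constant Laplacian of 4; f0 = 0 removes the stretching term
x = jnp.arange(6.0) * 2.0
y = jnp.arange(5.0) * 3.0
X, Y = jnp.meshgrid(x, y, indexing="ij")
psi = X**2 + Y**2
q = streamfn_to_pvort_sketch(psi, dx=2.0, dy=3.0, f0=0.0, c1=1.0)
```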

Stream Function from Potential Vorticity

psi_rec = F_geos.pvort_to_streamfn(
    q,
    psi,
    dx=domain.dx_mean,
    dy=domain.dx_mean,
    f0=domain.f0,
    c1=c1,
    accuracy=1,
)

print(np.max(np.abs(psi_rec - psi)))

ssh_rec = F_geos.streamfn_to_ssh(psi_rec, f0=domain.f0)
# np.testing.assert_array_almost_equal(psi, psi_rec)

print(np.max(np.abs(ssh_rec - ssh)))
1.7462298274040222e-10
1.5543122344752192e-15
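Going back from $q$ to $\psi$ means solving the Helmholtz problem $(\nabla^2 - f_0^2/c_1^2)\psi = q$ with the boundary values of $\psi$ fixed. The sketch below does this with a type-I discrete sine transform (DST), a standard fast solver for Dirichlet problems on a uniform grid. It is a swapped-in technique shown for illustration and is not necessarily how `F_geos.pvort_to_streamfn` works; `invert_pv_dst` and `helmholtz_interior` are hypothetical names.

```python
import numpy as np
from scipy.fft import dstn, idstn


def helmholtz_interior(psi, dx, dy, k2):
    """(lap - k2) psi at interior points, 5-point stencil, arrays indexed [x, y]."""
    d2x = (psi[2:, 1:-1] - 2 * psi[1:-1, 1:-1] + psi[:-2, 1:-1]) / dx**2
    d2y = (psi[1:-1, 2:] - 2 * psi[1:-1, 1:-1] + psi[1:-1, :-2]) / dy**2
    return d2x + d2y - k2 * psi[1:-1, 1:-1]


def invert_pv_dst(q, psi_bc, dx, dy, f0, c1):
    """Hypothetical PV inversion with Dirichlet boundary values taken from psi_bc."""
    q, psi_bc = np.asarray(q), np.asarray(psi_bc)
    k2 = (f0 / c1) ** 2

    # Field carrying only the boundary values (zeros in the interior).
    b = np.zeros_like(psi_bc)
    b[0, :], b[-1, :] = psi_bc[0, :], psi_bc[-1, :]
    b[:, 0], b[:, -1] = psi_bc[:, 0], psi_bc[:, -1]

    # Move the known boundary contribution to the right-hand side.
    rhs = q[1:-1, 1:-1] - helmholtz_interior(b, dx, dy, k2)

    # Eigenvalues of the 1D second difference with zero Dirichlet ends.
    nx, ny = rhs.shape
    lam_x = (2 * np.cos(np.pi * np.arange(1, nx + 1) / (nx + 1)) - 2) / dx**2
    lam_y = (2 * np.cos(np.pi * np.arange(1, ny + 1) / (ny + 1)) - 2) / dy**2
    denom = lam_x[:, None] + lam_y[None, :] - k2  # strictly negative, never zero

    # DST-I diagonalizes the operator: divide in spectral space, then invert.
    interior = idstn(dstn(rhs, type=1) / denom, type=1)

    psi = b.copy()
    psi[1:-1, 1:-1] = interior
    return psi


# Round trip on arbitrary random test data
rng = np.random.default_rng(0)
psi_true = rng.standard_normal((12, 10))
dx, dy, f0, c1 = 2.0, 3.0, 1.0, 2.0
q = np.zeros_like(psi_true)
q[1:-1, 1:-1] = helmholtz_interior(psi_true, dx, dy, (f0 / c1) ** 2)
psi_rec = invert_pv_dst(q, psi_true, dx, dy, f0, c1)
```

Because the discrete operator is inverted exactly, the round trip recovers $\psi$ to round-off, which mirrors the small reconstruction errors printed above.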

RHS

We have the advection term, which is the dot product between the geostrophic velocities and the gradient of the potential vorticity:

$$
\begin{aligned}
\begin{pmatrix} u \\ v \end{pmatrix} \cdot \nabla q &= u \partial_x q + v \partial_y q
\end{aligned}
$$

If we plug in the actual $u,v$ terms, we get

$$
\begin{aligned}
\begin{pmatrix} u \\ v \end{pmatrix} \cdot \nabla q &= - \partial_y\psi\,\partial_x q + \partial_x\psi\,\partial_y q \\
&= \partial_x\psi\,\partial_y q - \partial_y\psi\,\partial_x q
\end{aligned}
$$

Note that there is a common shorthand for this term, the Jacobian determinant, which is denoted as:

$$
\begin{aligned}
\det J(\psi, q) &= \frac{\partial\psi}{\partial x}\frac{\partial q}{\partial y} - \frac{\partial\psi}{\partial y}\frac{\partial q}{\partial x} \\
&= \frac{\partial}{\partial x}\left(\psi\frac{\partial q}{\partial y}\right) - \frac{\partial}{\partial y}\left(\psi\frac{\partial q}{\partial x}\right) \\
&= \frac{\partial}{\partial y}\left(q\frac{\partial\psi}{\partial x}\right) - \frac{\partial}{\partial x}\left(q\frac{\partial\psi}{\partial y}\right)
\end{aligned}
$$

We can see that this is exactly equal to the above expression when we set $u=-\partial_y\psi$ and $v=\partial_x\psi$:

$$
\begin{aligned}
\det J(\psi, q) &= v\frac{\partial q}{\partial y} + u \frac{\partial q}{\partial x}
\end{aligned}
$$
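A quick numerical check of this identity, as a minimal sketch: with second-order central differences at interior points (hypothetical helpers `ddx` and `ddy`, arrays indexed `[x, y]`), the Jacobian form and the $u\,\partial_x q + v\,\partial_y q$ form give the same field. For $\psi = x^2/2$ and $q = y^2/2$ we have $\partial_x\psi = x$ and $\partial_y q = y$, so $\det J = xy$.

```python
import jax.numpy as jnp


def ddx(f, dx):
    """Central difference along x (axis 0), interior points only."""
    return (f[2:, 1:-1] - f[:-2, 1:-1]) / (2 * dx)


def ddy(f, dy):
    """Central difference along y (axis 1), interior points only."""
    return (f[1:-1, 2:] - f[1:-1, :-2]) / (2 * dy)


def jacobian_det(psi, q, dx, dy):
    """det J(psi, q) = dpsi/dx * dq/dy - dpsi/dy * dq/dx."""
    return ddx(psi, dx) * ddy(q, dy) - ddy(psi, dy) * ddx(q, dx)


def advection_uv(psi, q, dx, dy):
    """u dq/dx + v dq/dy with geostrophic u = -dpsi/dy, v = dpsi/dx."""
    u, v = -ddy(psi, dy), ddx(psi, dx)
    return u * ddx(q, dx) + v * ddy(q, dy)


x = jnp.arange(8.0) * 0.5
y = jnp.arange(6.0) * 0.25
X, Y = jnp.meshgrid(x, y, indexing="ij")
psi, q = 0.5 * X**2, 0.5 * Y**2  # det J = x * y
J = jacobian_det(psi, q, 0.5, 0.25)
```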

Upwind Scheme

This is an advection term, so we can use the upwind scheme to discretize it.

$$
\begin{aligned}
u\partial_x q &:= u^+ D_x^- q + u^- D_x^+ q \\
v\partial_y q &:= v^+ D_y^- q + v^- D_y^+ q
\end{aligned}
$$

where:

  • $u^+,v^+$ are the positive parts of the velocities (where $u>0$, $v>0$)
  • $u^-,v^-$ are the negative parts of the velocities (where $u<0$, $v<0$)
  • $D_x^-, D_y^-$ are backward finite difference operators
  • $D_x^+, D_y^+$ are forward finite difference operators

Source: Wikipedia | Blog

Plus/Minus

$$
\begin{aligned}
\text{u-velocity}: && u^+ &= \max(u,0) && u^- = \min(u,0) \\
\text{v-velocity}: && v^+ &= \max(v,0) && v^- = \min(v,0)
\end{aligned}
$$
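Putting the upwind formula and the plus/minus splitting together, here is a minimal first-order sketch written independently of jaxsw. The name `upwind_advection_2d` is hypothetical; `F_adv.advection_upwind_2D` may differ in details such as boundary handling. Only interior points are updated, matching the `rhs.at[1:-1, 1:-1]` convention used in the notebook.

```python
import jax.numpy as jnp


def upwind_advection_2d(q, u, v, dx, dy):
    """First-order upwind u dq/dx + v dq/dy; boundary rows/columns stay 0.

    Arrays are indexed [x, y], as in the rest of the notebook.
    """
    qc = q[1:-1, 1:-1]
    dq_dx_minus = (qc - q[:-2, 1:-1]) / dx  # backward difference D_x^-
    dq_dx_plus = (q[2:, 1:-1] - qc) / dx    # forward difference D_x^+
    dq_dy_minus = (qc - q[1:-1, :-2]) / dy  # backward difference D_y^-
    dq_dy_plus = (q[1:-1, 2:] - qc) / dy    # forward difference D_y^+

    uc, vc = u[1:-1, 1:-1], v[1:-1, 1:-1]
    u_plus, u_minus = jnp.maximum(uc, 0.0), jnp.minimum(uc, 0.0)
    v_plus, v_minus = jnp.maximum(vc, 0.0), jnp.minimum(vc, 0.0)

    adv = (u_plus * dq_dx_minus + u_minus * dq_dx_plus
           + v_plus * dq_dy_minus + v_minus * dq_dy_plus)
    return jnp.zeros_like(q).at[1:-1, 1:-1].set(adv)


# A linear field q = x + y is differentiated exactly by one-sided differences
x = jnp.arange(6.0) * 0.5
y = jnp.arange(5.0) * 0.25
X, Y = jnp.meshgrid(x, y, indexing="ij")
q = X + Y
ones = jnp.ones_like(q)
out_a = upwind_advection_2d(q, 2.0 * ones, -3.0 * ones, 0.5, 0.25)  # 2 - 3 = -1
out_b = upwind_advection_2d(q, -2.0 * ones, 3.0 * ones, 0.5, 0.25)  # -2 + 3 = 1
```

Switching the sign of the velocities changes which one-sided difference is used at each point, which is the whole point of upwinding: information is always taken from the direction the flow comes from.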
# upwind_2D??
def advection_term_upwind(q, psi, dx, dy, **kwargs) -> Array:
    # u,v schemes
    u, v = F_geos.uv_velocity(psi, dx=dx, dy=dy, accuracy=kwargs.get("accuracy", 1))

    # # Average Center Pixel
    # u = F_grid.center_average_2D(u, padding="same")
    # v = F_grid.center_average_2D(v, padding="same")

    # # Average East-West & North-South
    # u = F_grid.x_average_2D(u, padding="same")
    # v = F_grid.y_average_2D(v, padding="same")

    # print(u_avg.shape, v_avg.shape, q.shape)

    adv_term = F_adv.advection_upwind_2D(q, a=u, b=v, step_size=(dx, dy))

    rhs = jnp.zeros_like(q)

    # rhs = rhs.at[1:-1, 1:-1].set(udq_dx[1:-1, 1:-1] + vdq_dy[1:-1, 1:-1])
    rhs = rhs.at[1:-1, 1:-1].set(adv_term[1:-1, 1:-1])

    return rhs


rhs = advection_term_upwind(
    q, psi, domain.dx_mean, domain.dx_mean, method="central", way=-1
)

print(rhs.min(), rhs.max())
-1.3750228959774459e-07 1.3953795940515203e-07
fig, ax = plt.subplots()

pts = ax.pcolormesh(rhs.T)
plt.colorbar(pts)

plt.show()
[Figure: advection term of the RHS at $t=0$]

(Initial) State

Again, like all previous tutorials, we will use a nifty "state container" to help us keep track of the variables as we pass them through the spatial discretizations and the ODE solver.


domain.dx
(Array(4381.27743094, dtype=float64), Array(5559.92086515, dtype=float64))
from typing import Optional, NamedTuple, Callable, Union
from jaxtyping import Array
from jaxsw._src.domain.base import Domain
from jaxsw._src.domain.latlon import LatLonMeanDomain


class StateParams(NamedTuple):
    domain: Domain
    f0: float
    beta: float
    c1: float
    eta: Array


class State(NamedTuple):
    q: Array

    @classmethod
    def init_state(cls, da: xr.DataArray, c1: float = 1.5):
        da = da.transpose("lon", "lat")
        lon = da.lon.values
        lat = da.lat.values

        domain = LatLonMeanDomain(lat=lat, lon=lon)

        # initialize parameters
        eta = jnp.asarray(da.values)
        f0 = domain.f0
        beta = domain.beta

        # ssh --> stream function
        psi = F_geos.ssh_to_streamfn(eta, f0=f0)

        # stream function --> potential vorticity
        q = F_geos.streamfn_to_pvort(
            psi, dx=domain.dx_mean, dy=domain.dx_mean, f0=f0, c1=c1, accuracy=1
        )
        q = enforce_boundaries_helmholtz(q, psi, beta=(f0 / c1) ** 2)

        # initialize state parameters
        state_params = StateParams(c1=c1, domain=domain, f0=f0, beta=beta, eta=eta)

        return cls(q=q), state_params

    def update_state(state, **kwargs):
        return State(
            q=kwargs.get("q", state.q),
        )
c1 = 1.5
f0 = domain.f0
beta = domain.beta

state, params = State.init_state(ds.ssh[0], c1=c1)

assert state.q.shape == ds.ssh[0].shape
assert params.c1 == 1.5

state_new = State.update_state(state, q=state.q + 1)
np.testing.assert_array_equal(state_new.q, state.q + 1)

Equation of Motion

$$
\begin{aligned}
\partial_t q &= - \det\boldsymbol{J}(\psi,q) - \beta\partial_x\psi \\
\psi &= \frac{g}{f_0}\eta
\end{aligned}
$$

$$
\begin{aligned}
\text{SSH, PV:} && \eta^n,q^n &= \ldots \\
\text{SF:} && \psi^n_b &= f(\eta^n) \\
\text{SF:} && \psi^{n} &= \text{LinearSolve}(q^{n}, \psi^n_b) \\
\text{RHS:} && rhs^n &= \mathbf{rhs}(\psi^n, q^n) \\
\text{PV:} && q^{n+1} &= q^{n} + \Delta t \, rhs^n
\end{aligned}
$$
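The last line of this scheme is a forward Euler update. As a minimal sketch of that loop in plain JAX, assuming a generic `rhs_fn(q)` that stands in for the chain SSH → $\psi_b$ → LinearSolve → advection and $\beta$ terms (the notebook itself uses `dfx.Euler` through `DynamicalSystem`, whose internals are not shown here), `jax.lax.scan` can run the steps and keep a snapshot every few steps:

```python
import jax
import jax.numpy as jnp


def euler_integrate(rhs_fn, q0, dt, n_steps, save_every):
    """Forward Euler q^{n+1} = q^n + dt * rhs(q^n), saving every `save_every` steps.

    Returns the states after steps save_every, 2*save_every, ... (q0 is not saved).
    """

    def one_step(q, _):
        return q + dt * rhs_fn(q), None

    def save_block(q, _):
        q, _ = jax.lax.scan(one_step, q, None, length=save_every)
        return q, q

    n_saves = n_steps // save_every
    _, saved = jax.lax.scan(save_block, q0, None, length=n_saves)
    return saved


# Toy usage: dq/dt = -q stands in for the QG right-hand side
saved = euler_integrate(lambda q: -q, jnp.ones(3), dt=0.1, n_steps=10, save_every=5)
```

Each Euler step multiplies the toy state by $1 - \Delta t = 0.9$, so the two saved snapshots are $0.9^5$ and $0.9^{10}$.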
# from jaxsw._src.operators.functional.fd import jacobian
from jaxsw._src.domain.base import Domain
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.domain.time import TimeDomain


class QG(DynamicalSystem):
    @staticmethod
    def equation_of_motion(t: float, state: State, args):
        """Quasi-Geostrophic Equations

        Equation:
            ∂q/∂t + det J(Ψ,q) = -β ∂Ψ/∂x
            q = ∇²Ψ - (f₀²/c₁²) Ψ
            Ψ = (g/f₀) η
        """
        # parse params
        params = args
        dx = dy = params.domain.dx_mean
        f0, beta, c1, eta = params.f0, params.beta, params.c1, params.eta

        # print("Before:", state.q.min(), state.q.max())

        # parse state
        q = state.q

        # ssh -> stream function
        psi_bv = F_geos.ssh_to_streamfn(ssh=eta, f0=f0)

        # potential vorticity -> stream function
        psi = F_geos.pvort_to_streamfn(
            q, psi_bv, dx=dx, dy=dy, f0=f0, c1=c1, accuracy=1
        )

        # upwind scheme for advection
        q_rhs = -advection_term_upwind(
            q=q, psi=psi, dx=dx, dy=dy, way=-1, method="central"
        )
        #         q_rhs = -jacobian(p=psi, q=q, dx=dx, dy=dy)

        #         q_rhs += 10 * fdx.laplacian(q, step_size=(dx,dy))

        # beta term
        _, v = F_geos.uv_velocity(psi, dx=dx, dy=dy)

        q_rhs += -beta * v

        # update state
        state = State.update_state(state, q=q_rhs)

        # print("After:", state.q.min(), state.q.max())

        return state

    @staticmethod
    def ssh_from_state(state, params, axis: int = -1) -> Array:
        dx = dy = params.domain.dx_mean
        f0, c1, eta = params.f0, params.c1, params.eta
        q = state.q

        if q.ndim > 2:
            q = q[axis]

        psi_bv = F_geos.ssh_to_streamfn(ssh=eta, f0=f0)

        psi = F_geos.pvort_to_streamfn(
            q, psi_bv, dx=dx, dy=dy, f0=f0, c1=c1, accuracy=1
        )

        return F_geos.streamfn_to_ssh(psi, f0=f0)
# SPATIAL DISCRETIZATION
# initialize state
state_init, params = State.init_state(ds.ssh[0], c1=c1)

# right hand side
state_out = QG.equation_of_motion(0, state_init, params)
# plot_state_2D(state_out)
# plot_state_3D(state_out)
# plot_vectorfield_2D(state_out)
fig, ax = plt.subplots()

pts = ax.imshow(state_out.q.T, origin="lower")
plt.colorbar(pts)

plt.show()
[Figure: PV tendency $\partial_t q$ at $t=0$]

Time Stepping with Diffrax (All together)

dt_ = np.timedelta64(1, "h")
pd.to_timedelta(dt_).total_seconds()
3600.0
# TEMPORAL DISCRETIZATION
# initialize temporal domain


num_minutes = 30
dt = pd.to_timedelta(num_minutes, unit="minutes").total_seconds()
debug = False
num_days = 5
tmin = 0.0
tmax = pd.to_timedelta(num_days, unit="days").total_seconds()

num_hours_save = 12
dt_save = pd.to_timedelta(num_hours_save, unit="hours").total_seconds()

print(f"Step Size (dt): {int(dt/60)} minutes")
print(f"Tmax: {num_days} days")
print(f"Tmax: {tmax:,} seconds | {int(tmax/60/60/24)} day(s)")
Step Size (dt): 30 minutes
Tmax: 5 days
Tmax: 432,000.0 seconds | 5 day(s)
# time domain
t_domain = TimeDomain(tmin=tmin, tmax=tmax, dt=dt)


t_domain
TimeDomain(tmin=0.0, tmax=432000.0, dt=1800.0)
num_hours_save = 6
dt_save = pd.to_timedelta(num_hours_save, unit="hours").total_seconds()

ts = jnp.arange(tmin, tmax, dt_save)
saveat = dfx.SaveAt(ts=ts)
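A quick arithmetic check of this time grid, written with NumPy and the same settings (30-minute steps, 5 days, saves every 6 hours). Note that `arange` excludes its end point, so the last saved snapshot falls at 4.75 days rather than exactly 5 days; including the end point (for example with `jnp.arange(tmin, tmax + dt_save, dt_save)`) would add a 21st snapshot.

```python
import numpy as np

dt = 30 * 60            # 30-minute step [s]
tmax = 5 * 24 * 3600    # 5 days [s]
dt_save = 6 * 3600      # save every 6 hours [s]

n_steps = tmax // dt                    # number of Euler steps: 240
ts = np.arange(0.0, tmax, dt_save)      # save times; arange excludes tmax
last_save_days = ts[-1] / 86400         # 4.75
```

This matches the 20 saved states reported in `state_sol.q.shape` further down.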
# SOLVER
# explicit forward Euler time stepping
solver = dfx.Euler()
# max_steps = 10
# DYNAMICAL SYSTEM
dyn_model = QG(t_domain=t_domain, saveat=saveat, solver=solver)
%%time

state_init, params = State.init_state(ds.ssh[0], c1=c1)

state_sol = dyn_model.integrate(state_init, dt, params, max_steps=10_000)
CPU times: user 4.69 s, sys: 153 ms, total: 4.84 s
Wall time: 4.78 s
state_sol.q.shape
(20, 201, 201)
ssh_t = QG.ssh_from_state(state_sol, params, axis=-1)

ssh_t.min(), ssh_t.max()
(Array(-0.33095231, dtype=float64), Array(1.16310116, dtype=float64))
fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(10, 5))

#
ds.ssh[0].plot.pcolormesh(ax=ax[0, 0], cmap="viridis")
ds.ssh[0].plot.contour(ax=ax[0, 0], cmap="viridis")

ax[0, 0].set(title=f"SSH ($t=0$ days)")
ax[0, 1].contour(ds.ssh[num_days].values, cmap="viridis")
pts = ax[0, 1].imshow(ds.ssh[num_days].values, origin="lower", cmap="viridis")
plt.colorbar(pts)
ax[0, 1].set(title=f"True SSH ($t={num_days}$ days)")

ax[0, 2].contour(np.flipud(ssh_t.T), cmap="viridis")
pts = ax[0, 2].imshow(np.flipud(ssh_t.T), cmap="viridis")
ax[0, 2].set(title=f"QG Model ($t={num_days}$ days)")
plt.colorbar(pts)


# all difference panels use (lat, lon) arrays and origin="lower", like ax[0, 1]
pts = ax[1, 0].imshow(
    np.abs(ds.ssh[0].values - ds.ssh[num_days].values), origin="lower", cmap="Reds"
)
plt.colorbar(pts)
ax[1, 0].set(title=f"Diff SSH (T - $t_0$)")

pts = ax[1, 1].imshow(np.abs(ds.ssh[0].values - ssh_t.T), origin="lower", cmap="Reds")
plt.colorbar(pts)
ax[1, 1].set(title=f"Diff QG (T - $t_0$)")


pts = ax[1, 2].imshow(
    np.abs(ds.ssh[num_days].values - ssh_t.T), origin="lower", cmap="Reds"
)
plt.colorbar(pts)
ax[1, 2].set(title=f"Diff QG vs SSH (@T={num_days})")


plt.tight_layout()
plt.show()
[Figure: true SSH and QG model SSH at $t=0$ and $t=5$ days, with absolute differences]
References
  1. Pedlosky, J. (1987). Geophysical Fluid Dynamics. Springer New York. 10.1007/978-1-4612-4650-3
  2. Vallis, G. K. (2017). Atmospheric and Oceanic Fluid Dynamics. Cambridge University Press. 10.1017/9781107588417