QG Inversion Schemes
import autoroot
import jax
import jax.numpy as jnp
from jax.config import config
import numpy as np
import numba as nb
import pandas as pd
import equinox as eqx
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Array, Float
import wandb
# Plot styling for the figures below.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Enable float64 in JAX: the round-trip checks below reach ~1e-15 errors,
# which float32 cannot resolve.
# NOTE(review): `from jax.config import config` (imported above) is not
# available in newer JAX releases; `jax.config.update(...)` is the portable
# spelling -- confirm against the installed version.
config.update("jax_enable_x64", True)
# IPython-only magics: inline figures and automatic reloading of edited
# modules (this cell does not run as a plain Python script).
%matplotlib inline
%load_ext autoreload
%autoreload 2
Read input SSH
# NOTE(review): this first dataset is immediately overwritten by the
# open_dataset call below and never used; drop it or pick one source.
ds = xr.open_dataset(
"/Users/eman/code_projects/data/scratch/NATL60_GULFSTREAM_degraded.nc"
)
# SSH dataset actually used below (the filename suggests one year, 2013).
# xarray's time decoding is disabled; the `time` coordinate is parsed with
# pandas instead.
ds = xr.open_dataset(
"/Users/eman/code_projects/data/scratch/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc",
decode_times=False,
).assign_coords(time=lambda ds: pd.to_datetime(ds.time))
# ds = ds.coarsen(lon=3,lat=3).mean()
ds
Loading...
from jaxsw._src.domain.base import Domain
import jaxsw._src.models.qg.qg_r as qg
# Grid coordinates and the first time slice of SSH. The transpose presumably
# puts lon on axis 0 and lat on axis 1 (matching nx = len(lon)) -- confirm
# the dim order of ds.ssh.
lon = ds.lon.values
lat = ds.lat.values
ssh = ds.ssh[0].values.T
# Grid spacings and (presumably) the Coriolis parameter on the grid.
dx, dy, f = qg.lat_lon_deltas(lon, lat)
nx, ny = len(lon), len(lat)
# Collapse dx and dy into a single scalar spacing (the mean over both fields),
# presumably because the DST solver used below assumes a uniform grid.
dx = dy = jnp.mean(jnp.asarray([dx, dy]))
# Reference Coriolis parameter: the domain mean of f.
f0 = np.asarray(np.mean(f))
dt = 600 # 10 mins / 600 s
c1 = 2.7 # 25_000 #1.5
g = 9.91
tol = 1e-15
n_iterations = 144
# NOTE(review): dt and n_iterations are not used by the cells shown, and c1
# is reassigned to 1.5 before the inversion, so the c1 printed here is not
# the value actually used.
print(f0, c1, g)
# NOTE(review): duplicates the nx, ny assignment above.
nx = lon.size
ny = lat.size
8.96745305945707e-05 2.7 9.91
# Quick look at the input SSH field (transposed back for display).
plt.figure()
plt.pcolormesh(ssh.T)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x294deca90>
def enforce_boundaries_helmholtz(u, u_bc, kappa):
    """Pin the outer ring of ``u`` to ``-kappa`` times the matching ``u_bc`` values.

    The first/last row and first/last column of ``u`` are overwritten;
    interior entries are left as they are. ``u`` must support JAX-style
    functional updates (``u.at[...].set``), so a new array is returned
    rather than ``u`` being modified in place.

    Args:
        u: Field (at least 2-D) whose boundary ring is replaced, e.g. PV.
        u_bc: Field of the same shape supplying the boundary values, e.g. psi.
        kappa: Scalar factor; boundary entries become ``-kappa * u_bc``.

    Returns:
        The updated array.
    """
    scale = -kappa
    for side in (0, -1):
        # Corners are written twice, but always with the same u_bc value.
        u = u.at[side, :].set(scale * u_bc[side, :])
        u = u.at[:, side].set(scale * u_bc[:, side])
    return u
from jaxsw._src.utils.dst_solver import inverse_elliptical_dst_solver
# Constants for the PV operator; c1 here overrides the 2.7 set earlier.
g = 9.91
c1 = 1.5
# Factor applied to psi on the boundary ring by enforce_boundaries_helmholtz.
kappa = (f0 / c1) ** 2
ssh = jnp.copy(ssh)
# NOTE(review): ssh_bv is not used below.
ssh_bv = jnp.copy(ssh)
# ssh -> psi
psi = qg.ssh_to_streamfn(ssh, f0)
# psi -> pv
q_ref = qg.streamfn_to_pvort(psi, dx, dy, f0=f0, c1=c1, accuracy=1)
q_ref = enforce_boundaries_helmholtz(q_ref, psi, kappa=kappa)
# boundary-only case: keep psi on the outer ring and zero the interior; its PV
# is the part of q_ref contributed by the boundary values.
psi_bv = psi.at[1:-1, 1:-1].set(0.0)
q_bv = qg.streamfn_to_pvort(psi_bv, dx, dy, f0=f0, c1=c1, accuracy=1)
q_bv = enforce_boundaries_helmholtz(q_bv, psi_bv, kappa=kappa)
# remove the boundary contribution: what remains is the interior PV of a
# problem with zero boundary values (assumes streamfn_to_pvort is linear in psi).
q_in = q_ref[1:-1, 1:-1] - q_bv[1:-1, 1:-1]
# do the inversion: presumably solves the Helmholtz problem for the interior
# psi with zero Dirichlet boundaries via a discrete sine transform -- see the
# solver's docs.
inv = inverse_elliptical_dst_solver(q_in, nx, ny, dx, dy, kappa)
# Re-attach the known boundary values of psi to the solved interior.
psi_rec = psi.at[1:-1, 1:-1].set(inv)
ssh_rec = qg.streamfn_to_ssh(psi_rec, f0=f0)
# Round-trip check: maximum absolute SSH reconstruction error.
print(np.max(np.abs(ssh_rec - ssh)))
1.4432899320127035e-15
# Input fields: SSH, streamfunction, and PV (boundary ring set by
# enforce_boundaries_helmholtz).
fig, ax = plt.subplots(ncols=3, figsize=(12, 3))
pts = ax[0].imshow(ssh.T, cmap="viridis")
ax[0].set(title=r"$\eta$")
plt.colorbar(pts)
pts = ax[1].imshow(psi.T, cmap="viridis")
ax[1].set(title=r"$\psi$")
plt.colorbar(pts)
pts = ax[2].imshow(q_ref.T, cmap="viridis")
ax[2].set(title=r"$q_I$")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
# Reconstruction check: original SSH, reconstructed SSH, absolute error.
fig, ax = plt.subplots(ncols=3, figsize=(12, 3))
pts = ax[0].imshow(ssh.T, cmap="viridis")
plt.colorbar(pts)
pts = ax[1].imshow(ssh_rec.T, cmap="viridis")
plt.colorbar(pts)
pts = ax[2].imshow(jnp.abs(ssh_rec.T - ssh.T), cmap="Reds")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
from typing import Optional
Conjugate Gradient
# Re-read the grid for the conjugate-gradient experiment. Unlike the DST cell,
# dx and dy are NOT averaged here (the line below is commented out), so they
# keep whatever shape qg.lat_lon_deltas returns.
lon = ds.lon.values
lat = ds.lat.values
ssh = jnp.asarray(ds.ssh[0].values.T)
dx, dy, f = qg.lat_lon_deltas(lon, lat)
nx, ny = len(lon), len(lat)
# dx = dy = jnp.mean(jnp.asarray([dx, dy]))
f0 = np.asarray(np.mean(f))
dt = 600 # 10 mins / 600 s
c1 = 2.7 # 25_000 #1.5
g = 9.91
tol = 1e-15
n_iterations = 144
# NOTE(review): as in the earlier cell, dt and n_iterations are unused, c1 is
# overridden to 1.5 below, and nx, ny are assigned twice.
print(f0, c1, g)
nx = lon.size
ny = lat.size
8.96745305945707e-05 2.7 9.91
# Same constants and forward transforms as the DST cell: ssh -> psi -> q,
# with the boundary ring of q set to -kappa * psi.
g = 9.91
c1 = 1.5
kappa = (f0 / c1) ** 2
ssh = jnp.copy(ssh)
# ssh -> psi
psi = qg.ssh_to_streamfn(ssh, f0)
# psi -> pv
q = qg.streamfn_to_pvort(psi, dx, dy, f0=f0, c1=c1, accuracy=1)
q = enforce_boundaries_helmholtz(q, psi, kappa=kappa)
# Plot SSH, streamfunction and PV side by side.
fig, ax = plt.subplots(ncols=3, figsize=(12, 3))
pts = ax[0].imshow(ssh.T, cmap="viridis")
plt.colorbar(pts)
pts = ax[1].imshow(psi.T, cmap="viridis")
plt.colorbar(pts)
pts = ax[2].imshow(q.T, cmap="viridis")
plt.colorbar(pts)
plt.tight_layout()
plt.show()
from jaxopt import linear_solve
from jaxsw._src.utils.linear_solver import conjugate_gradient, steepest_descent
def pv_to_streamfn(
    q: Array,
    psi_bc,
    dx,
    dy,
    f0: float = 1e-5,
    c1: float = 2.7,
    tol: float = 1e-5,
    maxiters: int = 100,
    accuracy: int = 1,
) -> Array:
    """Invert potential vorticity for the streamfunction with conjugate gradient.

    The outer ring (first/last row and column) of the result is copied from
    ``psi_bc``; only the interior is solved for.

    Why it is set up this way: an earlier version fed CG a matvec whose
    boundary rows were overwritten with ``-kappa * psi_bc`` regardless of the
    input. That operator is affine rather than linear, so CG cannot solve it.
    Instead, psi is split into an interior part (zero on the boundary) and a
    boundary part (zero inside). Assuming ``qg.streamfn_to_pvort`` is linear
    in psi, the boundary part's PV is moved to the right-hand side, and CG
    solves a truly linear problem for the interior unknowns.

    Args:
        q: PV on the full grid. Only ``q[1:-1, 1:-1]`` is used; the boundary
            entries are ignored.
        psi_bc: Streamfunction on the full grid. Its outer ring supplies the
            Dirichlet boundary values; its interior is ignored.
        dx: Grid spacing forwarded to ``qg.streamfn_to_pvort``.
        dy: Grid spacing forwarded to ``qg.streamfn_to_pvort``.
        f0: Forwarded to ``qg.streamfn_to_pvort``.
        c1: Forwarded to ``qg.streamfn_to_pvort``.
        tol: Tolerance forwarded to the CG solver.
        maxiters: Iteration cap forwarded to the CG solver as ``maxiter``.
        accuracy: Finite-difference accuracy forwarded to
            ``qg.streamfn_to_pvort``.

    Returns:
        The streamfunction on the full grid (same shape as ``psi_bc``).

    NOTE(review): CG additionally needs a symmetric definite operator. That
    holds for a uniform-spacing Helmholtz stencil, but is not guaranteed when
    ``dx``/``dy`` are spatially varying arrays -- confirm for your grid.
    """
    psi_bc = jnp.asarray(psi_bc)
    interior = (slice(1, -1), slice(1, -1))

    def _pvort(psi_full):
        # Forward PV operator on the full grid.
        return qg.streamfn_to_pvort(
            psi_full, dx, dy, f0=f0, c1=c1, accuracy=accuracy
        )

    # PV contributed by the known boundary values alone (interior zeroed).
    psi_boundary = psi_bc.at[interior].set(0.0)
    rhs = jnp.asarray(q)[interior] - _pvort(psi_boundary)[interior]

    def matvec(psi_in):
        # Zero-pad the interior unknowns to the full grid (homogeneous
        # Dirichlet boundary), apply the PV operator, and keep the interior.
        # This map is linear in psi_in, as CG requires.
        padded = jnp.zeros_like(psi_bc).at[interior].set(psi_in)
        return _pvort(padded)[interior]

    psi_in = linear_solve.solve_cg(matvec=matvec, b=rhs, tol=tol, maxiter=maxiters)
    # Paste the solved interior back inside the known boundary ring.
    return psi_bc.at[interior].set(psi_in)
def enforce_boundaries_zero(u):
    """Return ``u`` with its outer ring (first/last row and column) set to zero.

    Uses JAX-style functional updates (``u.at[...].set``), so a new array is
    returned and interior values are left untouched.
    """
    for side in (0, -1):
        u = u.at[side, :].set(0.0).at[:, side].set(0.0)
    return u
def enforce_boundaries_psi(u, u_bc):
    """Copy the outer ring of ``u_bc`` into ``u`` (Dirichlet boundary values).

    The first/last row and first/last column of ``u`` take the matching
    entries of ``u_bc``; interior values are kept. Uses JAX-style functional
    updates (``u.at[...].set``), so a new array is returned.
    """
    result = u
    for side in (0, -1):
        # Corners are written twice, but always with the same u_bc value.
        result = result.at[side, :].set(u_bc[side, :])
        result = result.at[:, side].set(u_bc[:, side])
    return result
# calculate stream function
# Invert q with conjugate gradient, using psi as the boundary data.
# NOTE(review): tol is forwarded to the CG solver; 1e-15 may be unattainable
# even in float64, in which case the solve presumably stops at maxiters.
tol = 1e-15
maxiters = 10_000
psi_rec = pv_to_streamfn(
q=q,
psi_bc=psi,
dx=dx,
dy=dy,
f0=f0,
c1=c1,
tol=tol,
maxiters=maxiters,
accuracy=1,
)
# Alternative (disabled): solve on the interior only, then paste the result
# back into psi.
# psi_rec = pv_to_streamfn(
# q=q_in, psi_bc=psi[1:-1,1:-1], dx=dx[1:-1,1:-1], dy=dy[1:-1,1:-1],
# f0=f0,c1=c1,
# tol=tol, maxiters=maxiters, accuracy=1)
# psi_rec = psi.at[1:-1,1:-1].set(psi_rec)
ssh_rec = qg.streamfn_to_ssh(psi_rec, f0=f0)
# Round-trip check. The recorded output for this cell (~8.1e3) shows the solve
# did not recover ssh when it was run.
print(np.max(np.abs(ssh_rec - ssh)))
8095.946362564549
# Reconstruction check for the CG solve: original SSH, reconstructed SSH,
# absolute error.
fig, ax = plt.subplots(ncols=3, figsize=(12, 3))
pts = ax[0].imshow(ssh.T, cmap="viridis")
plt.colorbar(pts)
pts = ax[1].imshow(ssh_rec.T, cmap="viridis")
plt.colorbar(pts)
pts = ax[2].imshow(jnp.abs(ssh_rec.T - ssh.T), cmap="Reds")
plt.colorbar(pts)
plt.tight_layout()
plt.show()