Isotropic PSD#

import sys, os

# Hard-coded location of the oceanbench checkout (adjust for your machine).
oceanbench_root = "/gpfswork/rech/cli/uvo53rl/projects/oceanbench"

# Make the checkout importable. Guarded so that re-running this cell does not
# append a duplicate entry to sys.path each time.
if oceanbench_root not in sys.path:
    sys.path.append(oceanbench_root)
import autoroot
import typing as tp
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import numpy as np
import numba as nb
import pandas as pd
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from omegaconf import OmegaConf
import hydra
import metpy
from sklearn.pipeline import Pipeline
from jejeqx._src.transforms.dataframe.spatial import Spherical2Cartesian
from jejeqx._src.transforms.dataframe.temporal import TimeDelta
from jejeqx._src.transforms.dataframe.scaling import MinMaxDF


sns.reset_defaults()
sns.set_context(context="poster", font_scale=0.7)
jax.config.update("jax_enable_x64", False)

%matplotlib inline
%load_ext autoreload
%autoreload 2

Processing Chain#

Part I:

  • Open Dataset

  • Validate Coordinates + Variables

  • Decode Time

  • Select Region

  • Sort by Time

Part II: Regrid

Part III:

  • Interpolate NaNs

  • Add Units

  • Spatial Rescale

  • Time Rescale

Part IV: Metrics

Data#

# Download the NATL60 Gulf Stream reference file (-nc: skip if it already exists).
# !wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
# !cat configs/postprocess.yaml
# # load config
# config_dm = OmegaConf.load('./configs/postprocess.yaml')

# # instantiate
# ds = hydra.utils.instantiate(config_dm.NATL60_GF_1Y1D)
# ds

Reference Dataset#

For the reference dataset, we will look at the NATL60 NEMO simulation of the Gulf Stream.

%%time

# load config
config_dm = OmegaConf.load("./configs/postprocess.yaml")

# instantiate
ds_natl60 = hydra.utils.instantiate(config_dm.NATL60_GF_FULL).compute()
ds_natl60

Prediction Datasets - NADIR#

%%time

# load config
results_config = OmegaConf.load("./configs/results_dc20a_nadir.yaml")

# instantiate
ds_duacs = hydra.utils.instantiate(results_config.DUACS_NADIR.data).compute()
ds_miost = hydra.utils.instantiate(results_config.MIOST_NADIR.data).compute()
ds_nerf_siren = hydra.utils.instantiate(results_config.NERF_SIREN_NADIR.data).compute()
ds_nerf_ffn = hydra.utils.instantiate(results_config.NERF_FFN_NADIR.data).compute()
ds_nerf_mlp = hydra.utils.instantiate(results_config.NERF_MLP_NADIR.data).compute()

Regridding#

from oceanbench._src.geoprocessing.gridding import grid_to_regular_grid
%%time

ds_duacs = grid_to_regular_grid(
    src_grid_ds=ds_duacs.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_miost = grid_to_regular_grid(
    src_grid_ds=ds_miost.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_siren = grid_to_regular_grid(
    src_grid_ds=ds_nerf_siren.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_ffn = grid_to_regular_grid(
    src_grid_ds=ds_nerf_ffn.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_mlp = grid_to_regular_grid(
    src_grid_ds=ds_nerf_mlp.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)

Preprocess Chain#

%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_natl60 = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_natl60.pint.dequantify()
)
ds_duacs = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_duacs.pint.dequantify()
)
ds_miost = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_miost.pint.dequantify()
)
ds_nerf_siren = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_nerf_siren.pint.dequantify()
)
ds_nerf_ffn = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_nerf_ffn.pint.dequantify()
)
ds_nerf_mlp = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_nerf_mlp.pint.dequantify()
)

Power Spectrum (Isotropic)#

%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_natl60_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_natl60.pint.dequantify()
)
ds_duacs_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_duacs.pint.dequantify()
)
ds_miost_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_miost.pint.dequantify()
)
ds_nerf_siren_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_nerf_siren.pint.dequantify()
)
ds_nerf_ffn_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_nerf_ffn.pint.dequantify()
)
ds_nerf_mlp_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_nerf_mlp.pint.dequantify()
)
from jejeqx._src.viz.xarray.psd import PlotPSDIsotropic, PlotPSDScoreIsotropic

# Output file for the NADIR isotropic PSD figure.
_fig_path = "./figures/dc20a/psd/isotropic/dc20a_psd_isotropic_nadir.png"

psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig(figsize=(8, 7))

# One curve per dataset: (PSD dataset, legend label, colour), drawn in this order.
# NOTE(review): freq_scale=1e3 with units="km" presumably rescales the frequency
# axis for display in km — confirm against PlotPSDIsotropic.plot_both.
for _psd, _label, _color in (
    (ds_natl60_psd, "NATL60", "black"),
    (ds_duacs_psd, "DUACS", "tab:green"),
    (ds_miost_psd, "MIOST", "tab:red"),
    (ds_nerf_ffn_psd, "NERF (FFN)", "tab:blue"),
    (ds_nerf_siren_psd, "NERF (SIREN)", "tab:olive"),
    (ds_nerf_mlp_psd, "NERF (MLP)", "tab:cyan"),
):
    psd_iso_plot.plot_both(
        _psd.ssh,
        freq_scale=1e3,
        units="km",
        label=_label,
        color=_color,
    )

# set custom bounds
psd_iso_plot.ax.set_xlim((10 ** (-3) - 0.00025, 10 ** (-1) + 0.025))
psd_iso_plot.ax.set_ylabel("PSD [SSH]")
plt.tight_layout()
# savefig raises FileNotFoundError when the folder is missing; create it first.
os.makedirs(os.path.dirname(_fig_path), exist_ok=True)
plt.gcf().savefig(_fig_path)
plt.show()

PSD Isotropic Score#

%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_psd_duacs_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_duacs.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_miost_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_miost.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_nerf_mlp_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_nerf_mlp.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_nerf_ffn_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_nerf_ffn.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_nerf_siren_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_nerf_siren.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)
from oceanbench._src.metrics.utils import find_intercept_1D, find_intercept_2D

# Output file for the NADIR isotropic PSD-score figure.
_fig_path = "./figures/dc20a/psd_score/isotropic/dc20a_psd_isotropic_score_nadir.png"

psd_iso_plot = PlotPSDScoreIsotropic()
psd_iso_plot.init_fig(figsize=(8, 7))

# One score curve per method: (score dataset, name, line colour, threshold colour).
# Every curve uses the same 0.50 threshold.
for _score, _name, _color, _thresh_color in (
    (ds_psd_duacs_score, "DUACS", "green", "tab:green"),
    (ds_psd_miost_score, "MIOST", "red", "tab:red"),
    (ds_psd_nerf_mlp_score, "NERF (MLP)", "cyan", "tab:cyan"),
    (ds_psd_nerf_ffn_score, "NERF (FFN)", "blue", "tab:blue"),
    (ds_psd_nerf_siren_score, "NERF (SIREN)", "olive", "tab:olive"),
):
    psd_iso_plot.plot_score(
        _score.ssh,
        freq_scale=1e3,
        units="km",
        name=_name,
        color=_color,
        threshhold=0.50,  # (sic) keyword name as passed to plot_score
        threshhold_color=_thresh_color,
    )

# set custom bounds
# psd_iso_plot.ax.set_xlim((10**(-3) - 0.00025, 10**(-1) +0.025))
psd_iso_plot.ax.set_ylabel("PSD Score [SSH]")
plt.legend()
plt.tight_layout()
# savefig raises FileNotFoundError when the folder is missing; create it first.
os.makedirs(os.path.dirname(_fig_path), exist_ok=True)
plt.gcf().savefig(_fig_path)
plt.show()

Prediction Datasets - SWOT#

%%time

# load config
config_dm = OmegaConf.load("./configs/postprocess.yaml")

# instantiate
ds_natl60 = hydra.utils.instantiate(config_dm.NATL60_GF_FULL).compute()
ds_natl60
%%time

# load config
results_config = OmegaConf.load("./configs/results_dc20a_swot.yaml")

# instantiate
ds_duacs = hydra.utils.instantiate(results_config.DUACS_SWOT.data).compute()
ds_miost = hydra.utils.instantiate(results_config.MIOST_SWOT.data).compute()
ds_nerf_siren = hydra.utils.instantiate(results_config.NERF_SIREN_SWOT.data).compute()
ds_nerf_ffn = hydra.utils.instantiate(results_config.NERF_FFN_SWOT.data).compute()
ds_nerf_mlp = hydra.utils.instantiate(results_config.NERF_MLP_SWOT.data).compute()

Regridding#

from oceanbench._src.geoprocessing.gridding import grid_to_regular_grid
%%time

ds_duacs = grid_to_regular_grid(
    src_grid_ds=ds_duacs.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_miost = grid_to_regular_grid(
    src_grid_ds=ds_miost.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_siren = grid_to_regular_grid(
    src_grid_ds=ds_nerf_siren.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_ffn = grid_to_regular_grid(
    src_grid_ds=ds_nerf_ffn.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_mlp = grid_to_regular_grid(
    src_grid_ds=ds_nerf_mlp.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)

Preprocess Chain#

%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_natl60 = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_natl60.pint.dequantify()
)
ds_duacs = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_duacs.pint.dequantify()
)
ds_miost = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_miost.pint.dequantify()
)
ds_nerf_siren = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_nerf_siren.pint.dequantify()
)
ds_nerf_ffn = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_nerf_ffn.pint.dequantify()
)
ds_nerf_mlp = hydra.utils.instantiate(psd_config.psd_preprocess_chain)(
    ds_nerf_mlp.pint.dequantify()
)

Power Spectrum (Isotropic)#

%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_natl60_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_natl60.pint.dequantify()
)
ds_duacs_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_duacs.pint.dequantify()
)
ds_miost_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_miost.pint.dequantify()
)
ds_nerf_siren_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_nerf_siren.pint.dequantify()
)
ds_nerf_ffn_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_nerf_ffn.pint.dequantify()
)
ds_nerf_mlp_psd = hydra.utils.instantiate(psd_config.psd_isotropic_chain)(
    ds_nerf_mlp.pint.dequantify()
)
from jejeqx._src.viz.xarray.psd import PlotPSDIsotropic, PlotPSDScoreIsotropic

# Output file for the SWOT isotropic PSD figure.
_fig_path = "./figures/dc20a/psd/isotropic/dc20a_psd_isotropic_swot.png"

psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig(figsize=(8, 7))

# One curve per dataset: (PSD dataset, legend label, colour), drawn in this order.
# NOTE(review): freq_scale=1e3 with units="km" presumably rescales the frequency
# axis for display in km — confirm against PlotPSDIsotropic.plot_both.
for _psd, _label, _color in (
    (ds_natl60_psd, "NATL60", "black"),
    (ds_duacs_psd, "DUACS", "tab:green"),
    (ds_miost_psd, "MIOST", "tab:red"),
    (ds_nerf_ffn_psd, "NERF (FFN)", "tab:blue"),
    (ds_nerf_siren_psd, "NERF (SIREN)", "tab:olive"),
    (ds_nerf_mlp_psd, "NERF (MLP)", "tab:cyan"),
):
    psd_iso_plot.plot_both(
        _psd.ssh,
        freq_scale=1e3,
        units="km",
        label=_label,
        color=_color,
    )

# set custom bounds
psd_iso_plot.ax.set_xlim((10 ** (-3) - 0.00025, 10 ** (-1) + 0.025))
psd_iso_plot.ax.set_ylabel("PSD [SSH]")
plt.tight_layout()
# savefig raises FileNotFoundError when the folder is missing; create it first.
os.makedirs(os.path.dirname(_fig_path), exist_ok=True)
plt.gcf().savefig(_fig_path)
plt.show()

PSD Isotropic Score#

%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_psd_duacs_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_duacs.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_miost_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_miost.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_nerf_mlp_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_nerf_mlp.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_nerf_ffn_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_nerf_ffn.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)

ds_psd_nerf_siren_score = hydra.utils.instantiate(
    psd_config.psd_isotropic_score,
    da=ds_nerf_siren.pint.dequantify(),
    da_ref=ds_natl60.pint.dequantify(),
)
from oceanbench._src.metrics.utils import find_intercept_1D, find_intercept_2D

# Output file for the SWOT isotropic PSD-score figure.
_fig_path = "./figures/dc20a/psd_score/isotropic/dc20a_psd_isotropic_score_swot.png"

psd_iso_plot = PlotPSDScoreIsotropic()
psd_iso_plot.init_fig(figsize=(8, 7))

# One score curve per method: (score dataset, name, line colour, threshold colour).
# Every curve uses the same 0.50 threshold.
for _score, _name, _color, _thresh_color in (
    (ds_psd_duacs_score, "DUACS", "green", "tab:green"),
    (ds_psd_miost_score, "MIOST", "red", "tab:red"),
    (ds_psd_nerf_mlp_score, "NERF (MLP)", "cyan", "tab:cyan"),
    (ds_psd_nerf_ffn_score, "NERF (FFN)", "blue", "tab:blue"),
    (ds_psd_nerf_siren_score, "NERF (SIREN)", "olive", "tab:olive"),
):
    psd_iso_plot.plot_score(
        _score.ssh,
        freq_scale=1e3,
        units="km",
        name=_name,
        color=_color,
        threshhold=0.50,  # (sic) keyword name as passed to plot_score
        threshhold_color=_thresh_color,
    )

# set custom bounds
# psd_iso_plot.ax.set_xlim((10**(-3) - 0.00025, 10**(-1) +0.025))
psd_iso_plot.ax.set_ylabel("PSD Score [SSH]")
plt.legend()
plt.tight_layout()
# savefig raises FileNotFoundError when the folder is missing; create it first.
os.makedirs(os.path.dirname(_fig_path), exist_ok=True)
plt.gcf().savefig(_fig_path)
plt.show()