SpatioTemporal Fields#

import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import numpy as np
import numba as nb
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import metpy
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from omegaconf import OmegaConf
import hydra
from sklearn import pipeline
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from jejeqx._src.transforms.dataframe.spatial import Spherical2Cartesian
from jejeqx._src.transforms.dataframe.temporal import TimeDelta
from jejeqx._src.transforms.dataframe.scaling import MinMaxDF

# Reset seaborn styling to its defaults, then apply the "talk" plotting context
# with a reduced font scale for every figure produced below.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Explicitly keep JAX's 64-bit mode switched off.
jax.config.update("jax_enable_x64", False)

# Set XLA_PYTHON_CLIENT_PREALLOCATE=false in the kernel environment; per the JAX
# documentation this turns off up-front GPU memory preallocation.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

# Render matplotlib figures inline and enable the autoreload extension in mode 2.
%matplotlib inline
%load_ext autoreload
%autoreload 2

Recap Formulation#

We are interested in learning non-linear functions \(\boldsymbol{f}\).

\[ \begin{aligned} \boldsymbol{f}(\mathbf{x}) &= \mathbf{w}^\top\boldsymbol{\phi}(\mathbf{x})+\mathbf{b} \end{aligned} \]

where \(\boldsymbol{\phi}(\cdot)\) is a basis function. Neural Fields typically try to learn this basis function via a series of composite functions of the form

\[ \boldsymbol{\phi}(\mathbf{x}) = \boldsymbol{\phi}_L\circ\boldsymbol{\phi}_{L-1} \circ\cdots\circ \boldsymbol{\phi}_2\circ\boldsymbol{\phi}_{1}(\mathbf{x}) \]

Problems#

Here, we will demonstrate a problem that a naive network has.

Data#

# Optional one-off download of the reference dataset file (kept commented out):
# !wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
from pathlib import Path
# Sanity check: evaluates to True only if the dataset file exists at this path.
Path(
    "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"
).is_file()
# @dataclass
# class Subset:
#     _target_: str = "slice"
#     _args_: List = field(default_factory=lambda :["2013-01-01", "2013-01-10"])
from dataclasses import dataclass, field
from typing import List, Dict


# @dataclass
# class SSHDM:
#     _target_: str = "jejeqx._src.datamodules.natl60.SSHSTNATL60"
#     batch_size: int = 10_000
#     shuffle: bool = False
#     split_method: str = "random"
#     train_size: float = 0.80
#     spatial_coords: List = field(default_factory=lambda : ["x", "y", "z"])
#     temporal_coords: List = field(default_factory=lambda: ["time"])
#     variables: List = field(default_factory=lambda : ["ssh"])
#     coarsen: Dict = field(default_factory=lambda : {"lon": 4, "lat": 4})
#     directory: str = "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/"


@dataclass
class SSHDM:
    """Structured config describing the datamodule used in this notebook.

    The instance is converted with ``OmegaConf.structured`` and handed to
    ``hydra.utils.instantiate``, which builds the class named by ``_target_``.
    The meaning of each remaining field is defined by that target class, not
    here — NOTE(review): confirm field semantics against ``AlongTrackDM``.
    """

    # Import path of the datamodule class that Hydra instantiates.
    _target_: str = "jejeqx._src.datamodules.coords.AlongTrackDM"
    batch_size: int = 10_000
    shuffle: bool = False
    train_size: float = 0.80
    # Mutable defaults go through default_factory so instances never share them.
    spatial_coords: List = field(default_factory=lambda: ["lat", "lon"])
    temporal_coords: List = field(default_factory=lambda: ["time"])
    variables: List = field(default_factory=lambda: ["ssh"])
    coarsen: Dict = field(default_factory=lambda: {"lon": 2, "lat": 2})
    decode_times: bool = False
    resample: str = "1D"
    # Location of the NetCDF file to read.
    paths: str = "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"
    # paths: str = "/gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/test/dc_ref/NATL60-CJM165_GULFSTREAM*"


# Spatial transform: a two-step scikit-learn pipeline applied to the coordinates.
# Step 1 runs Spherical2Cartesian (radius 1.0, units "degrees"); step 2 runs
# MinMaxDF on the columns "x", "y", "z" with arguments -1 and 1.
# NOTE(review): presumably lat/lon -> unit-sphere x/y/z, then rescaling to
# [-1, 1] — confirm against the jejeqx transformers.
spatial_transforms = Pipeline(
    [
        ("cartesian3d", Spherical2Cartesian(radius=1.0, units="degrees")),
        ("spatialminmax", MinMaxDF(["x", "y", "z"], -1, 1)),
    ]
)

# Temporal transform: TimeDelta with reference "2012-10-01", then MinMaxDF on
# the "time" column with arguments -1 and 1.
# NOTE(review): presumably seconds elapsed since the reference date, rescaled
# to [-1, 1] — confirm against the jejeqx transformers.
temporal_transforms = Pipeline(
    [
        ("timedelta", TimeDelta("2012-10-01", 1, "s")),
        ("timeminmax", MinMaxDF(["time"], -1, 1)),
    ]
)
# Time window handed to the datamodule as its `select` argument.
select = {"time": slice("2013-01-01", "2013-10-30")}

# Turn the dataclass defaults into an OmegaConf structured config.
config_dm = OmegaConf.structured(SSHDM())

# Build the datamodule named by `_target_`; the selection and the two
# transform pipelines are passed as extra keyword arguments on top of the
# config fields.
dm = hydra.utils.instantiate(
    config_dm,
    select=select,
    spatial_transform=spatial_transforms,
    temporal_transform=temporal_transforms,
)

# Prepare the datamodule; `ds_train` and `ds_test` are read from it below.
dm.setup()


# Take the first 32 training samples and unpack the three entries every batch
# in this notebook carries: "spatial", "temporal" and "data".
init = dm.ds_train[:32]
x_init, t_init, y_init = init["spatial"], init["temporal"], init["data"]
# Quick look at the value ranges and shapes of the transformed coordinates.
x_init.min(), x_init.max(), x_init.shape, t_init.min(), t_init.max(), t_init.shape
# Sizes of the train and test splits.
len(dm.ds_train), len(dm.ds_test)
# Load the data as an xarray object; predictions are attached to it later
# as `ssh_siren`.
xrda = dm.load_xrds()
# import geoviews as gv
# import geoviews.feature as gf
# from cartopy import crs

# gv.extension('bokeh', 'matplotlib')
xrda
# import geoviews as gv
# import geoviews.feature as gf
# from cartopy import crs

# gv.extension('bokeh', 'matplotlib')
# dataset = gv.Dataset(xrda_obs)
# ensemble1 = dataset.to(gv.Image, ['lon', 'lat'], "ssh")
# gv.output(ensemble1.opts(cmap='viridis', colorbar=True, fig_size=200, backend='matplotlib') * gf.coastline(),
#           backend='matplotlib')
# dataset = gv.Dataset(xrda)
# ensemble1 = dataset.to(gv.Image, ['lon', 'lat'], "ssh")
# ensemble2 = dataset.to(gv.Image, ['lon', 'lat'], "ssh_lmsiren")
# gv.output(ensemble1.opts(cmap='viridis', colorbar=True, fig_size=200, backend='matplotlib') * gf.coastline(),
#           backend='matplotlib')
# List the files in the checkpoint directory.
!ls /gpfswork/rech/cli/uvo53rl/checkpoints/nerfs/siren/natl60/full/
import joblib

# Saved training config and model checkpoint of the pre-trained network.
model_config_file = (
    "/gpfswork/rech/cli/uvo53rl/checkpoints/nerfs/siren/natl60/full/config.pkl"
)
checkpoint_file = "/gpfswork/rech/cli/uvo53rl/checkpoints/nerfs/siren/natl60/full/checkpoint_model.ckpt"

# NOTE: joblib.load unpickles the file, which can execute arbitrary code —
# only load config files from a trusted location.
old_config = joblib.load(model_config_file)

# Re-create the model from the stored config. The checkpoint itself is only
# loaded later through `trainer.load_model(checkpoint_file)`.
model = hydra.utils.instantiate(old_config["model"])
old_config["model"]
eqx.pprint(model)

Optimizer (+ Learning Rate)#

import optax

# Number of epochs used if the optional fine-tuning cell is run.
num_epochs = 2_000


@dataclass
class Optimizer:
    """Structured config for the optimizer; Hydra calls ``optax.adam`` with it."""

    # Import path of the optax factory that Hydra instantiates.
    _target_: str = "optax.adam"
    learning_rate: float = 1e-4


# FINETUNE!
@dataclass
class Scheduler:
    """Structured config for ``optax.warmup_cosine_decay_schedule``.

    Only the config object is built from this class in the active code; the
    lines that would instantiate the schedule (and supply ``decay_steps``) are
    commented out in this notebook.
    """

    # Import path of the optax schedule factory.
    _target_: str = "optax.warmup_cosine_decay_schedule"
    init_value: float = 0.0
    peak_value: float = 1e-2
    warmup_steps: int = 500
    end_value: float = 1e-6


# Structured configs for the schedule and the optimizer. Only the optimizer
# config is instantiated in the active code.
scheduler_config = OmegaConf.structured(Scheduler())
optim_config = OmegaConf.structured(Optimizer())

# Build the optax optimizer named by `Optimizer._target_`.
optimizer = hydra.utils.instantiate(optim_config)

# num_steps_per_epoch = len(dm.ds_train)

# scheduler = hydra.utils.instantiate(
#     scheduler_config,
#     decay_steps=int(num_epochs * num_steps_per_epoch)
# )

# optimizer = optax.chain(optimizer, optax.scale_by_schedule(scheduler))

Trainer Module#

import glob
import os
from pathlib import Path

from jejeqx._src.trainers.base import TrainerModule
from jejeqx._src.trainers.callbacks import wandb_model_artifact
from jejeqx._src.losses import psnr


class RegressorTrainer(TrainerModule):
    """Trainer that regresses ``batch["data"]`` from stacked coordinates.

    Every batch is expected to be a mapping with the keys ``"spatial"``,
    ``"temporal"`` and (except for prediction) ``"data"``. The spatial and
    temporal arrays are concatenated with ``jnp.hstack`` and the model is
    applied per sample through ``jax.vmap``.
    """

    def __init__(self, model, optimizer, **kwargs):
        """Forward everything to ``TrainerModule`` with the logger disabled.

        ``pl_logger`` is always passed as ``None``, so it must not also be
        supplied through ``kwargs``.
        """
        super().__init__(model=model, optimizer=optimizer, pl_logger=None, **kwargs)

    @property
    def model(self):
        """Return the current model, which is stored as ``self.state.params``."""
        return self.state.params

    @property
    def model_batch(self):
        """Return the model vmapped over the leading axis of two inputs.

        NOTE(review): the step functions below call the model with a single
        stacked input, so this two-argument mapping looks like a leftover of
        an older ``model(x, t)`` signature — confirm before relying on it.
        """
        return jax.vmap(self.state.params, in_axes=(0, 0))

    def create_functions(self):
        """Build the train, eval, test and predict step functions.

        Returns:
            A tuple ``(train_step, eval_step, test_step, predict_step)``.
        """

        def forward(model, batch):
            # Stack space and time into one input vector per sample.
            x, t = batch["spatial"], batch["temporal"]
            return jax.vmap(model)(jnp.hstack([x, t]))

        def mse_loss(model, batch):
            pred = forward(model, batch)
            return jnp.mean((batch["data"] - pred) ** 2)

        # Gradients are only needed for training, so the differentiated
        # version is kept separate from the plain loss used for evaluation.
        mse_loss_and_grad = eqx.filter_value_and_grad(mse_loss)

        def train_step(state, batch):
            loss, grads = mse_loss_and_grad(state.params, batch)
            state = state.update_state(state, grads)
            metrics = {"loss": loss, "psnr": psnr(loss)}
            return state, loss, metrics

        def eval_step(model, batch):
            loss = mse_loss(model, batch)
            return {"loss": loss, "psnr": psnr(loss)}

        def test_step(model, batch):
            # Reuse the single forward pass for both the returned predictions
            # and the loss.
            out = forward(model, batch)
            loss = jnp.mean((batch["data"] - out) ** 2)
            return out, {"loss": loss, "psnr": psnr(loss)}

        def predict_step(model, batch):
            # Prediction batches need no "data" entry.
            return forward(model, batch)

        return train_step, eval_step, test_step, predict_step

    def on_training_end(
        self,
    ):
        """Save the model, log it as an artifact and finalize the logger.

        Does nothing when ``self.pl_logger`` is falsy.
        """
        if self.pl_logger:
            save_dir = Path(self.log_dir).joinpath(self.save_name)
            self.save_model(save_dir)
            wandb_model_artifact(self)
            self.pl_logger.finalize("success")
# Settings forwarded to the trainer as keyword arguments.
seed = 123
debug = False
enable_progress_bar = False
log_dir = "./"

# Wrap the re-created model and the optimizer in the regression trainer.
trainer = RegressorTrainer(
    model,
    optimizer,
    seed=seed,
    debug=debug,
    enable_progress_bar=enable_progress_bar,
    log_dir=log_dir,
)

# `train_more` gates the optional fine-tuning cells further down.
# `save_new` is only referenced by commented-out code in this notebook.
train_more = False
save_new = True
%%time
# Baseline: evaluate on the test dataloader before any checkpoint is loaded.


out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# trainer.load_model("./checkpoints/checkpoint_model_stlmsiren_ssh_more.ckpt")
# trainer.load_model("./checkpoints/checkpoint_natl60_model_rff.ckpt")
# Load the model checkpoint stored at `checkpoint_file` into the trainer.
trainer.load_model(checkpoint_file)
%%time
# Evaluate again with the loaded checkpoint. `out` from this run is the one
# used for the plots and analysis below unless fine-tuning is enabled.


out, metrics = trainer.test_model(dm.test_dataloader())
metrics
%%time
# Optional fine-tuning; skipped while `train_more` is False.
if train_more:
    metrics = trainer.train_model(dm, num_epochs=num_epochs)
%%time

# Re-evaluate only if the model was fine-tuned above.
if train_more:
    out, metrics = trainer.test_model(dm.test_dataloader())
    print(metrics)
# if save_new:
#     # trainer.save_model("./checkpoints/check point_model_stlmsiren_ssh_more.ckpt")
#     trainer.save_model("./checkpoints/checkpoint_natl60_model_rff.ckpt")
# # # trainer.save_state("checkpoint_state.ckpt")
# Convert the test outputs to a dataframe, then to xarray, and attach the
# resulting `ssh` variable to the reference data as `ssh_siren`.
xrda["ssh_siren"] = dm.data_to_df(out).to_xarray().ssh
# Side-by-side comparison of the first time step.
fig, ax = plt.subplots(ncols=2, figsize=(8, 3))

xrda.ssh.isel(time=0).plot.pcolormesh(ax=ax[0], cmap="viridis")
ax[0].set(title="Original")

xrda.ssh_siren.isel(time=0).plot.pcolormesh(ax=ax[1], cmap="viridis")
ax[1].set(title="SIREN")

plt.tight_layout()
plt.show()

Analysis#

import common_utils as cutils
# Derive the analysis quantities for the prediction and for the reference.
# NOTE: `ds_rff` holds the SIREN prediction despite its name.
ds_rff = cutils.calculate_physical_quantities(xrda.ssh_siren)
ds_natl60 = cutils.calculate_physical_quantities(xrda.ssh)
import holoviews as hv

# Use the matplotlib backend for holoviews.
hv.extension("matplotlib")


# Quantity to animate and its colormap; alternatives are kept in the comments.
variable = "ssh"  # "vort_r" # "ke" #
cmap = "viridis"  # "RdBu_r" # "YlGnBu_r" #

# Put reference and prediction side by side in one dataset.
ssh_ds = xr.Dataset(
    {
        "NATL60": ds_natl60[variable],
        "SIREN": ds_rff[variable],
    }
)

# Replace both variables with the plain `.data.magnitude` arrays.
# NOTE(review): presumably this strips a unit wrapper — confirm.
ssh_ds["NATL60"] = (("time", "lat", "lon"), ds_natl60[variable].data.magnitude)
ssh_ds["SIREN"] = (("time", "lat", "lon"), ds_rff[variable].data.magnitude)

# Animate time indices 25 to 54 only.
to_plot_ds = ssh_ds.isel(time=slice(25, 55, 1)).transpose("time", "lat", "lon")

# Shared color limits: the 0.5% and 99.5% quantiles over both variables.
clim = (
    to_plot_ds[["NATL60", "SIREN"]]
    .to_array()
    .pipe(lambda da: (da.quantile(0.005).item(), da.quantile(0.995).item()))
)

# One QuadMesh panel per variable, laid out in two columns.
images = (
    hv.Layout(
        [
            hv.Dataset(to_plot_ds)
            .to(hv.QuadMesh, ["lon", "lat"], v)
            .relabel(v)
            .options(cmap=cmap, clim=clim)
            for v in to_plot_ds
        ]
    )
    .cols(2)
    .opts(sublabel_format="")
)

# Render the layout as a GIF at 2 frames per second.
hv.output(images, holomap="gif", fps=2, dpi=300)
# Plot the analysis variables of reference and prediction at time index 5.
fig, ax = cutils.plot_analysis_vars(
    [
        ds_natl60.isel(time=5),
        ds_rff.isel(time=5),
    ],
    figsize=(12, 25),
)
plt.show()

Simple Stats#

def rmse_da(da, da_ref, dim):
    """Return the root-mean-square difference of ``da`` and ``da_ref``.

    The squared difference is averaged over ``dim`` before the square root
    is taken, so every dimension not listed in ``dim`` is kept.
    """
    squared_error = (da - da_ref) ** 2
    mean_squared_error = squared_error.mean(dim=dim)
    return mean_squared_error ** 0.5


def nrmse_da(da, da_ref, dim):
    """Return one minus the RMSE normalized by the reference's RMS.

    Both the error and the normalizer are reduced over ``dim``. The
    normalizer is the root-mean-square of ``da_ref`` itself, with no mean
    removed. The result is the bare ``.data.magnitude`` of the ratio
    subtracted from 1.0, so the inputs must carry data exposing a
    ``magnitude`` attribute.
    """
    error_rms = rmse_da(da=da, da_ref=da_ref, dim=dim)
    reference_rms = (da_ref**2).mean(dim=dim) ** 0.5
    relative_error = error_rms / reference_rms
    return 1.0 - relative_error.data.magnitude
import pandas as pd

# Collect one row per (model, variable) pair and build the frame once at the
# end. This avoids re-copying the frame on every iteration and gives the rows
# a unique index, in iteration order.
result_rows = []

for imodel, iname in zip([ds_rff], ["SIREN"]):
    for ivar in imodel:
        # Score reduced over all dimensions -> a single number per variable.
        error_space = nrmse_da(imodel[ivar], ds_natl60[ivar], ["lat", "lon", "time"])
        # Score reduced over time only, summarized by its spread over the
        # remaining dimensions.
        error_time = nrmse_da(imodel[ivar], ds_natl60[ivar], ["time"]).std()

        result_rows.append([iname, ivar, error_space.item(), error_time.item()])

results_df = pd.DataFrame(
    data=result_rows,
    columns=["model", "variable", "nrmse (mu)", "nrmse (std)"],
)
results_df
ds_psd_natl60 = cutils.calculate_isotropic_psd(ds_natl60)
ds_psd_rff = cutils.calculate_isotropic_psd(ds_rff)
%matplotlib inline

fig, ax = cutils.plot_analysis_psd_iso(
    [
        ds_psd_natl60,
        ds_psd_rff,
    ],
    [
        "NATL60",
        "RFE",
    ],
)
plt.show()
ds_scores = cutils.calculate_isotropic_psd_score(ds_rff, ds_natl60)
%matplotlib inline

cutils.plot_analysis_psd_iso_score([ds_scores], ["RFE"], ["k"])
plt.show()
ds_psd_natl60 = cutils.calculate_spacetime_psd(ds_natl60)
ds_psd_rff = cutils.calculate_spacetime_psd(ds_rff)
%matplotlib inline

fig, ax = cutils.plot_analysis_psd_spacetime(
    [
        ds_psd_natl60,
        ds_psd_rff,
    ],
    [
        "NATL60",
        "RFE",
    ],
)
plt.show()
ds_psd_rff = cutils.calculate_spacetime_psd_score(ds_rff, ds_natl60)
%matplotlib inline

_ = cutils.plot_analysis_psd_spacetime_score([ds_psd_rff], ["rff"])