OSSE NADIR#

import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import numpy as np
import numba as nb
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import pandas as pd
import metpy
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from omegaconf import OmegaConf
import hydra
from sklearn import pipeline
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from jejeqx._src.transforms.dataframe.spatial import Spherical2Cartesian
from jejeqx._src.transforms.dataframe.temporal import TimeDelta
from jejeqx._src.transforms.dataframe.scaling import MinMaxDF

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Stop JAX/XLA from pre-allocating most of the GPU memory up front; with
# XLA_PYTHON_CLIENT_PREALLOCATE=false memory is allocated on demand instead.
# NOTE(review): must be set before the first JAX computation initialises the
# GPU backend to take effect — confirm nothing above has triggered it.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# Keep JAX in single precision (float32).
jax.config.update("jax_enable_x64", False)

%matplotlib inline
%load_ext autoreload
%autoreload 2

Recap Formulation#

We are interested in learning non-linear functions \(\boldsymbol{f}\).

\[ \begin{aligned} \boldsymbol{f}(\mathbf{x}) &= \mathbf{w}^\top\boldsymbol{\phi}(\mathbf{x})+\mathbf{b} \end{aligned} \]

where \(\boldsymbol{\phi}(\cdot)\) is a basis function. Neural Fields typically try to learn this basis function via a series of composite functions of the form

\[ \boldsymbol{\phi}(\mathbf{x}) = \boldsymbol{\phi}_L\circ\boldsymbol{\phi}_{L-1} \circ\cdots\circ \boldsymbol{\phi}_2\circ\boldsymbol{\phi}_{1}(\mathbf{x}) \]

Problems#

Here, we will demonstrate a problem that a naive network has.

Sparse Observations#

In the previous examples, we demonstrated how NeRFs perform when we have a clean simulation. However, in many real problems, we do not have access to such clean, complete data.

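For this example, we look at the case where we have very sparse observations, as with satellite altimetry data such as SWOT. In this case, the (nadir, along-track) observations only cover narrow lines beneath the satellite paths, so most of the domain is unobserved at any given time.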

!ls /gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/
# # load config
# config_dm = OmegaConf.load('./configs/natl60_obs.yaml')

# # instantiate
# # dm = hydra.utils.instantiate(config_dm.datamodule)
# dm = hydra.utils.instantiate(config_dm.alongtrack_scaled)
# # run setup
# dm.setup()

# # check cunits
# (
#     dm.ds_test[:]["spatial"].min(),
#     dm.ds_test[:]["spatial"].max(),
#     dm.ds_test[:]["temporal"].min(),
#     dm.ds_test[:]["temporal"].max(),
#     dm.ds_test[:]["data"].min(),
#     dm.ds_test[:]["data"].max(),
# )

# len(dm.ds_train)
# load the datamodule configs
config_dm = OmegaConf.load("./configs/natl60_obs.yaml")

# instantiate the along-track (nadir) observation datamodule
# dm = hydra.utils.instantiate(config_dm.datamodule)
dm = hydra.utils.instantiate(config_dm.alongtrack_scaled)
# run setup
dm.setup()

# Evaluation datamodule: reuses the observation datamodule's spatial/temporal
# transforms so evaluation coordinates are scaled the same way as training ones.
# dm = hydra.utils.instantiate(config_dm.datamodule)
dm_eval = hydra.utils.instantiate(
    config_dm.natl60_dc20a_eval,
    spatial_transform=dm.spatial_transform,
    temporal_transform=dm.temporal_transform,
)
# run setup
dm_eval.setup()

# check units: value ranges of the scaled test-set inputs and targets
(
    dm.ds_test[:]["spatial"].min(),
    dm.ds_test[:]["spatial"].max(),
    dm.ds_test[:]["temporal"].min(),
    dm.ds_test[:]["temporal"].max(),
    dm.ds_test[:]["data"].min(),
    dm.ds_test[:]["data"].max(),
)
# inspect the steps of the spatial transform pipeline
dm.spatial_transform.named_steps
# number of training samples
len(dm.ds_train)
xrda = dm.load_xrds()
xrda
# NOTE(review): the commented-out plot below refers to `xrda_obs`, which is not
# defined in this notebook (the loaded dataset is `xrda`).
# %matplotlib inline

# fig, ax = plt.subplots()

# sub_ds = xrda_obs.isel(time=slice(0,None))
# pts = ax.scatter(sub_ds.lon, sub_ds.lat, c=sub_ds.ssh, s=0.1)
# ax.set(
#     xlabel="Longitude",
#     ylabel="Latitude",
# )

# plt.colorbar(pts, label="Sea Surface Height [m]")
# plt.tight_layout()
# plt.show()
# A small batch of training samples, used below to smoke-test model shapes.
init = dm.ds_train[:32]
x_init, t_init, y_init = init["spatial"], init["temporal"], init["data"]
x_init.min(), x_init.max(), x_init.shape, t_init.min(), t_init.max(), t_init.shape

Model#

The input data is a coordinate vector, \(\mathbf{x}_\phi\), of the spatial coordinates.

\[ \mathbf{x}_\phi \in \mathbb{R}^{D_\phi} \]

where \(D_\phi\) is the number of spatial coordinates (e.g. \([\text{x}, \text{y}]\)). The model additionally takes the time coordinate \(t\) as a second input. So we are interested in learning a function, \(\boldsymbol{f}\), such that we can input a coordinate vector (and a time) and output a scalar/vector value at that location, here the sea surface height.

\[ \mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi, t; \boldsymbol{\theta}) \]
# load the model config
model_config = OmegaConf.load("./configs/model.yaml")

# instantiate the network from the `ffn` config entry
# (presumably the random-Fourier-feature model used throughout) — TODO confirm
model_ffn = hydra.utils.instantiate(model_config.ffn)

# smoke test: one (x, t) coordinate pair must map to one target-shaped output
out = model_ffn(x=x_init[0], t=t_init[0])
assert out.shape == y_init[0].shape

# smoke test (batched): vmap over the leading sample axis of both x and t
out_batch = jax.vmap(model_ffn, in_axes=(0, 0))(x_init, t_init)
assert out_batch.shape == y_init.shape

SIREN Layer#

\[ \boldsymbol{\phi}^{(\ell)}(\mathbf{x}) = \sin \left( \omega^{(\ell)}\left( \mathbf{w}^{(\ell)}\mathbf{x} + \mathbf{b}^{(\ell)} + \mathbf{s}^{(\ell)} \right)\right) \]

where \(\mathbf{s}\) is the modulation

\[ \mathbf{s}^{(\ell)} = \mathbf{w}_z^{(\ell)}\mathbf{z} + \mathbf{b}_z^{(\ell)} \]
# import joblib

# model_config_file = "/gpfswork/rech/cli/uvo53rl/checkpoints/nerfs/siren/nadir4/scratch/config.pkl"
# checkpoint_file = "/gpfswork/rech/cli/uvo53rl/checkpoints/nerfs/siren/nadir4/scratch/checkpoint_model.ckpt"

# old_config = joblib.load(model_config_file)

# model = hydra.utils.instantiate(old_config["model"])

Optimizer (+ Learning Rate)#

For this, we will use an AdamW optimizer (configured in `configs/optimizer.yaml`) combined with a warmup-cosine learning rate schedule. From many studies, it appears that a low learning rate (e.g. 1e-4) works well with these methods because there is a lot of data. In addition, a bigger batch_size is also desirable. We will set num_epochs to 250. Obviously, more epochs and a better learning rate scheduler would give better results, but this will be sufficient for this demo.

import math

import optax

num_epochs = 250

# load config
opt_config = OmegaConf.load("./configs/optimizer.yaml")

# instantiate
optimizer = hydra.utils.instantiate(opt_config.adamw)
scheduler_config = OmegaConf.load("./configs/lr_scheduler.yaml")

# Optax schedules are indexed by optimizer *update* count, i.e. one step per
# mini-batch, so the schedule horizon must be counted in batches. The original
# used `len(dm.ds_train)` (the number of samples), which made the cosine decay
# far longer than the actual training run.
# NOTE(review): assumes the datamodule exposes `batch_size`; if it does not,
# this falls back to 1, which reproduces the previous behaviour.
batch_size = getattr(dm, "batch_size", None) or 1
num_steps_per_epoch = math.ceil(len(dm.ds_train) / batch_size)

scheduler = hydra.utils.instantiate(
    scheduler_config.warmup_cosine, decay_steps=int(num_epochs * num_steps_per_epoch)
)

# The schedule rescales the updates produced by AdamW.
# NOTE(review): if the adamw config also sets a learning rate, the effective
# step size is lr * schedule(step) — confirm this is intended.
optimizer = optax.chain(optimizer, optax.scale_by_schedule(scheduler))
optimizer

Trainer Module#

import glob
import os
from pathlib import Path

from jejeqx._src.trainers.base import TrainerModule
from jejeqx._src.trainers.callbacks import wandb_model_artifact
from jejeqx._src.losses import psnr


class RegressorTrainer(TrainerModule):
    """Trainer for a coordinate regressor ``model(x, t) -> y``.

    Batches are dictionaries with ``"spatial"``, ``"temporal"`` and ``"data"``
    entries. The model is fitted with a mean-squared-error loss, and every step
    reports the MSE together with its PSNR.
    """

    def __init__(self, model, optimizer, *, pl_logger=None, **kwargs):
        # `pl_logger` used to be hard-wired to None, which made the logging
        # branch of `on_training_end` unreachable. It is now configurable, and
        # the default is unchanged.
        super().__init__(
            model=model, optimizer=optimizer, pl_logger=pl_logger, **kwargs
        )

    def create_functions(self):
        """Return the ``(train_step, eval_step, test_step, predict_step)`` callables."""

        def batch_mse(model, batch):
            """Mean squared error of ``model`` over one batch."""
            x, t, y = batch["spatial"], batch["temporal"], batch["data"]
            pred = jax.vmap(model, in_axes=(0, 0))(x, t)
            return jnp.mean((y - pred) ** 2)

        # Loss plus gradient w.r.t. the model's array leaves. Only training
        # needs the gradient, so eval/test call `batch_mse` directly and skip
        # the backward pass.
        mse_loss_and_grad = eqx.filter_value_and_grad(batch_mse)

        def train_step(state, batch):
            loss, grads = mse_loss_and_grad(state.params, batch)
            state = state.update_state(state, grads)
            metrics = {"loss": loss, "psnr": psnr(loss)}
            return state, loss, metrics

        def eval_step(model, batch):
            loss = batch_mse(model, batch)
            return {"loss": loss, "psnr": psnr(loss)}

        def test_step(model, batch):
            x, t, y = batch["spatial"], batch["temporal"], batch["data"]
            # Predict once and reuse the predictions for the loss, instead of
            # running a second forward pass inside the loss function.
            out = jax.vmap(model, in_axes=(0, 0))(x, t)
            loss = jnp.mean((y - out) ** 2)
            return out, {"loss": loss, "psnr": psnr(loss)}

        def predict_step(model, batch):
            x, t = batch["spatial"], batch["temporal"]
            return jax.vmap(model, in_axes=(0, 0))(x, t)

        return train_step, eval_step, test_step, predict_step

    def on_training_end(self):
        """Save the final model and publish it as an artifact when a logger is set."""
        if self.pl_logger:
            save_dir = Path(self.log_dir).joinpath(self.save_name)
            self.save_model(save_dir)
            wandb_model_artifact(self)
            self.pl_logger.finalize("success")
# Trainer settings.
seed = 123
debug = False
enable_progress_bar = False
log_dir = "./"

trainer = RegressorTrainer(
    model_ffn,
    optimizer,
    seed=seed,
    debug=debug,
    enable_progress_bar=enable_progress_bar,
    log_dir=log_dir,
)

# Set to True to continue training in the cells below and then save the
# result to the checkpoint path.
train_more = False
%%time

# Baseline test-set metrics of the freshly initialised model; the checkpoint
# is only loaded in the next cell.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# Best-effort: reuse a previously trained checkpoint if one is available,
# otherwise keep the freshly initialised model.
try:
    trainer.load_model("./checkpoints/checkpoint_model_rff_osse_nadir.ckpt")
    # trainer.load_model("./checkpoints/checkpoint_model_rff_ssh.ckpt")
except Exception as err:
    # The previous bare `except: pass` also swallowed KeyboardInterrupt /
    # SystemExit and hid why loading failed, so report the reason instead.
    print(f"Could not load checkpoint, continuing with the current model: {err!r}")
%%time

# Test-set metrics after attempting to load the saved checkpoint.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
%%time

# Optionally continue training. The `metrics` returned by `train_model` are
# immediately overwritten by the test-set metrics on the next line.
if train_more:
    metrics = trainer.train_model(dm, num_epochs=num_epochs)
out, metrics = trainer.test_model(dm.test_dataloader())

metrics
out.shape
# Save newly trained weights to the same path the checkpoint is loaded from.
if train_more:
    trainer.save_model("./checkpoints/checkpoint_model_rff_osse_nadir.ckpt")
# Summary table of test-set scores, with one row per model.
# Built directly rather than by concatenating onto an empty DataFrame. The
# resulting table is the same, and recent pandas versions deprecate (and may
# warn about) concatenation with empty frames.
all_metrics = pd.DataFrame(
    data=[["rff", metrics["loss"], metrics["psnr"]]],
    columns=["model", "MSE", "PSNR"],
)
# Reference dataset on the evaluation grid.
xrda = dm_eval.load_xrds()
%%time

# Predict on every evaluation point.
out, metrics = trainer.test_model(dm_eval.test_dataloader())
metrics
# NOTE(review): assigning the raw `.data` with dims ("time", "lat", "lon")
# assumes the DataFrame -> xarray conversion yields the same dimension order and
# coordinate values as `xrda`; confirm (assigning the DataArray itself would
# align by coordinates instead).
xrda["ssh_rff"] = (("time", "lat", "lon"), dm_eval.data_to_df(out).to_xarray().ssh.data)
xrda["ssh_rff"].attrs["standard_name"] = "Sea Surface Height"
# Keep a handle on the trained model.
ssh_fn_rff = trainer.model
# Side-by-side maps of the reference SSH and the Fourier-feature reconstruction
# on a single day.
fig, ax = plt.subplots(ncols=2, figsize=(8, 3))

itime = "2012-10-22"

# Both panels share the reference colour scale so they are directly comparable.
# The previous version also left a blank middle axis reserved for a
# commented-out "Naive MLP" panel.
ssh_ref = xrda.ssh.sel(time=itime)
vmin, vmax = float(ssh_ref.min()), float(ssh_ref.max())

ssh_ref.plot.pcolormesh(ax=ax[0], cmap="viridis", vmin=vmin, vmax=vmax)
ax[0].set(title="Original")

xrda.ssh_rff.sel(time=itime).plot.pcolormesh(
    ax=ax[1], cmap="viridis", vmin=vmin, vmax=vmax
)
ax[1].set(title="Fourier Features")

plt.tight_layout()
plt.show()
import typing as tp
from jejeqx._src.transforms.xarray.geostrophic import calculate_coriolis
from metpy.constants import earth_gravity

# Physical constants for the quasi-geostrophic diagnostics below.
# presumably a constant (f-plane) Coriolis parameter in 1/s — TODO confirm.
f0: Array = jnp.asarray(1e-5)
# Gravitational acceleration from MetPy (magnitude only, units stripped).
g: Array = jnp.asarray(earth_gravity.magnitude)
# presumably the phase speed c_1 in q = lap(psi) - (f0 / c_1)^2 psi — TODO confirm units.
c: Array = jnp.asarray(1.5)
# Alternative: a Coriolis parameter computed from the grid latitudes.
# f0: Array = jnp.asarray(calculate_coriolis(xrda.lat).data.magnitude)
# g: Array = jnp.asarray(earth_gravity.magnitude)


def create_streamfn(f: tp.Callable, f0: float = 1e-5, g: float = 9.81) -> tp.Callable:
    """Turn a sea-surface-height function into a geostrophic stream function.

    Args:
        f: callable ``f(x, t)`` returning the sea surface height ``eta``.
        f0: Coriolis parameter in the scaling factor.
        g: gravitational acceleration in the scaling factor.

    Returns:
        A callable ``psi(x, t) = (g / f0) * f(x, t)``.
    """
    scale = g / f0

    def psi(x: Array, t: Array) -> Array:
        return scale * f(x, t)

    return psi


def create_gradient_fn(f: tp.Callable) -> tp.Callable:
    """Build the spatial gradient of ``f(x, t)`` with respect to ``x``.

    Uses forward-mode differentiation of the first argument and squeezes all
    singleton axes, so a ``(1,)``-shaped output gives a ``(D,)`` gradient.
    """
    jacobian_wrt_x = jax.jacfwd(f, argnums=0)

    def grad_fn(x: Array, t: Array) -> Array:
        return jnp.squeeze(jacobian_wrt_x(x, t))

    return grad_fn


def uv_velocity(grad_psi: Array) -> tp.Tuple[Array, Array]:
    """Geostrophic velocities from the stream-function gradient.

    The last axis of ``grad_psi`` is split into two equal halves, read as
    ``(dpsi/dx, dpsi/dy)``. Returns ``u = -dpsi/dy`` and ``v = dpsi/dx``.

    NOTE(review): assumes the first spatial input coordinate is the x (zonal)
    direction — confirm against the datamodule's coordinate order.
    """
    first_half, second_half = jnp.split(grad_psi, 2, axis=-1)
    return -second_half, first_half


def create_laplacian_fn(f: tp.Callable) -> tp.Callable:
    """Build the spatial Laplacian of ``f(x, t)`` with respect to ``x``.

    The Laplacian is the trace of the Hessian of ``f`` w.r.t. ``x``. For a
    vector-valued ``f``, only the first output component is used (as before);
    scalar-valued ``f`` is now supported as well.

    Returns:
        A callable ``lap(x, t)`` that returns an array of shape ``(1,)``.
    """
    # Build the Hessian transform once rather than on every call.
    hessian_wrt_x = jax.hessian(f)

    def fn(x: Array, t: Array) -> Array:
        n_dims = jnp.shape(x)[-1]
        # The Hessian is (D, D) for scalar f and (K, D, D) for K outputs;
        # normalise to (K, D, D) and keep the first output's Hessian.
        hess = jnp.reshape(hessian_wrt_x(x, t), (-1, n_dims, n_dims))[0]
        return jnp.sum(jnp.diagonal(hess), keepdims=True)

    return fn


def create_pvort_fn(f: tp.Callable, f0: float = 1e-5, c: float = 1.5) -> tp.Callable:
    """Build the quasi-geostrophic potential vorticity of a stream function.

    Args:
        f: stream function ``psi(x, t)``.
        f0: Coriolis parameter in the stretching term.
        c: phase speed in the stretching term.

    Returns:
        A callable ``q(x, t) = laplacian(psi)(x, t) - (f0 / c) ** 2 * psi(x, t)``.
    """
    laplacian_fn = create_laplacian_fn(f)
    stretching_coeff = (f0 / c) ** 2

    def pvort(x: Array, t: Array) -> Array:
        relative_vorticity = laplacian_fn(x, t)
        stretching = stretching_coeff * f(x, t)
        return relative_vorticity - stretching

    return pvort


def create_advection_fn(
    f: tp.Callable, f0: float = 1e-5, c: float = 1.5
) -> tp.Callable:
    """Build the geostrophic advection of potential vorticity, ``u . grad(q)``.

    Args:
        f: stream function ``psi(x, t)``.
        f0: Coriolis parameter forwarded to the potential-vorticity term.
        c: phase speed forwarded to the potential-vorticity term.

    Returns:
        A callable ``adv(x, t) = u * dq/dx + v * dq/dy``, with
        ``u = -dpsi/dy`` and ``v = dpsi/dx``.
    """
    # Previously `f0`/`c` could not be set, so the defaults of
    # `create_pvort_fn` were always used. They are now forwarded, with the same
    # default values.
    pvort_fn = create_pvort_fn(f, f0=f0, c=c)
    grad_pvort_fn = create_gradient_fn(pvort_fn)
    grad_psi_fn = create_gradient_fn(f)

    def fn(x: Array, t: Array) -> Array:
        # spatial gradient of the potential vorticity
        dq_dx, dq_dy = jnp.split(grad_pvort_fn(x, t), 2, axis=-1)
        # geostrophic velocities from the stream-function gradient
        u, v = uv_velocity(grad_psi_fn(x, t))
        return u * dq_dx + v * dq_dy

    return fn
ssh_fn = trainer.model
# Pass the physical constants defined above explicitly, so derived quantities
# use the same f0 / g / c. The helpers' defaults differ slightly
# (g=9.81 vs MetPy's earth_gravity).
psi_fn = create_streamfn(ssh_fn, f0=f0, g=g)
grad_psi_fn = create_gradient_fn(psi_fn)
rvort_fn = create_laplacian_fn(psi_fn)
pvort_fn = create_pvort_fn(psi_fn, f0=f0, c=c)
# NOTE: relies on create_advection_fn's own f0/c defaults (1e-5 and 1.5,
# the same values as above).
rhs_fn = create_advection_fn(psi_fn)

# Evaluate every quantity at one sample. The model signature is model(x, t),
# so the second argument must be the *time* coordinate; previously the target
# values `y_init` were passed by mistake.
x_pt, t_pt = x_init[10], t_init[10]
eta = ssh_fn(x_pt, t_pt)
psi = psi_fn(x_pt, t_pt)
rvort = rvort_fn(x_pt, t_pt)
pvort = pvort_fn(x_pt, t_pt)
rhs = rhs_fn(x_pt, t_pt)
eta.shape, psi.shape, rvort.shape, pvort.shape, rhs.shape
eta, psi, rvort, pvort, rhs
def qg_loss_fn(f, f0, g, c):
    """Build the quasi-geostrophic PV-conservation residual of an SSH model.

    With ``psi = (g / f0) * eta`` and ``q = laplacian(psi) - (f0 / c) ** 2 * psi``,
    the residual is ``R = dq/dt + u * dq/dx + v * dq/dy``, where the geostrophic
    velocities are ``u = -dpsi/dy`` and ``v = dpsi/dx``.

    Args:
        f: SSH model ``eta(x, t)``.
        f0: Coriolis parameter.
        g: gravitational acceleration.
        c: phase speed in the stretching term.

    Returns:
        A callable ``residual_fn(x, t)`` that evaluates ``R`` at one point.
        Square and average it over a batch to obtain a physics loss.
    """
    # The previous draft passed an invalid keyword (`ff=`), referenced an
    # undefined `grad_fn`, and was syntactically incomplete.
    psi_fn = create_streamfn(f, f0=f0, g=g)
    pvort_fn = create_pvort_fn(psi_fn, f0=f0, c=c)
    grad_psi_fn = create_gradient_fn(psi_fn)
    grad_pvort_fn = create_gradient_fn(pvort_fn)
    # Time derivative of q: differentiate w.r.t. the second argument `t`.
    dpvort_dt_fn = jax.jacfwd(pvort_fn, argnums=1)

    def residual_fn(x, t):
        # geostrophic velocities
        u, v = uv_velocity(grad_psi_fn(x, t))
        # spatial gradient of the potential vorticity
        dq_dx, dq_dy = jnp.split(grad_pvort_fn(x, t), 2, axis=-1)
        # NOTE(review): assumes `t` holds a single time value (a scalar or shape
        # (1,)), so the Jacobian can be reshaped to match dq_dx.
        dq_dt = jnp.reshape(dpvort_dt_fn(x, t), jnp.shape(dq_dx))
        return dq_dt + u * dq_dx + v * dq_dy

    return residual_fn
\[ \mathcal{R}(\boldsymbol{\theta}) = \partial_t q + u \partial_x q + v\partial_y q, \qquad u = -\partial_y\psi,\quad v = \partial_x\psi \]

where:

\[\begin{split} \begin{aligned} \psi& = \frac{g}{f_0}\eta \\ q &= \nabla^2\psi - \frac{f_0^2}{c_1^2}\psi \end{aligned} \end{split}\]

Evaluation#

We will predict the whole field on the NATL60 reference grid (coarsened by a factor of 2 in longitude and latitude, at daily resolution) for the following time period:

22-October-2012 → 02-December-2012

from dataclasses import dataclass, field
from typing import List, Dict


@dataclass
class SSHDMEVAL:
    """Structured (OmegaConf) config for the NATL60 full-field evaluation datamodule.

    Passed to ``hydra.utils.instantiate``, which builds the class named in
    ``_target_`` with the remaining fields as arguments.
    """

    # import path of the datamodule class hydra instantiates
    _target_: str = "jejeqx._src.datamodules.coords.EvalCoordDM"
    # glob matching the NATL60 reference files (dc20a OSSE, Gulf Stream region)
    paths: str = "/gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/test/dc_ref/NATL60-CJM165_GULFSTREAM*"
    batch_size: int = 10_000
    # no shuffling, so predictions keep the dataset order
    shuffle: bool = False
    # NOTE(review): presumably unused when `evaluation=True` — confirm
    train_size: float = 0.80
    # presumably forwarded to xarray to decode the time coordinate — confirm
    decode_times: bool = True
    evaluation: bool = True
    spatial_coords: List = field(default_factory=lambda: ["lat", "lon"])
    temporal_coords: List = field(default_factory=lambda: ["time"])
    # SSH variable in the NATL60 files (read back later as `xrda.sossheig`)
    variables: List = field(default_factory=lambda: ["sossheig"])
    # presumably coarsens the grid by a factor of 2 in lon/lat — TODO confirm
    coarsen: Dict = field(default_factory=lambda: {"lon": 2, "lat": 2})
    # presumably resamples in time to daily values — TODO confirm
    resample: str = "1D"
%%time

# Evaluation window (an earlier, shorter window is kept commented out).
# select = {"time": slice("2012-10-22", "2012-11-22")}
select = {"time": slice("2012-10-22", "2012-12-02")}

# NOTE: this rebinds `config_dm`, which previously held the YAML datamodule config.
config_dm = OmegaConf.structured(SSHDMEVAL())

# Evaluation datamodule on the reference grid. It reuses the observation
# datamodule's coordinate transforms so inputs are scaled consistently.
dm_eval = hydra.utils.instantiate(
    config_dm,
    select=select,
    spatial_transform=dm.spatial_transform,
    temporal_transform=dm.temporal_transform,
)

dm_eval.setup()
print(f"Num Points: {len(dm_eval.ds_test):,}")
%%time

# Reference dataset for the evaluation window.
xrda = dm_eval.load_xrds()
%%time

# Predict on every evaluation point and report metrics against the reference.
out, metrics = trainer.test_model(dm_eval.test_dataloader())
metrics
# Convert the flat predictions back to a gridded field (the prediction column
# is accessed under the input variable name `sossheig`).
xrda["ssh_rff"] = dm_eval.data_to_df(out).to_xarray().sossheig
import common_utils as cutils
# Derived physical quantities for the reconstruction and the NATL60 reference.
ds_rff = cutils.calculate_physical_quantities(xrda.ssh_rff)
ds_natl60 = cutils.calculate_physical_quantities(xrda.sossheig)
# Compare both at the last time step.
fig, ax = cutils.plot_analysis_vars(
    [
        ds_natl60.isel(time=-1),
        ds_rff.isel(time=-1),
    ]
)
plt.show()
# Isotropic (space-only) power spectra of the reference and the reconstruction.
# Legend labels now consistently say "RFF", the model used throughout this
# notebook; they previously read "RFE" here and "SIREN" on the score plot.
ds_psd_natl60 = cutils.calculate_isotropic_psd(ds_natl60)
ds_psd_rff = cutils.calculate_isotropic_psd(ds_rff)
fig, ax = cutils.plot_analysis_psd_iso(
    [
        ds_psd_natl60,
        ds_psd_rff,
    ],
    [
        "NATL60",
        "RFF",
    ],
)
plt.show()
# Isotropic PSD score of the RFF reconstruction against NATL60.
ds_psd_scores = cutils.calculate_isotropic_psd_score(ds_rff, ds_natl60)
cutils.plot_analysis_psd_iso_score([ds_psd_scores], ["RFF"], ["k"])
plt.show()
for ivar in ds_psd_scores:
    # resolved scale is stored in metres; report km and degrees (~111 km/deg)
    resolved_spatial_scale = ds_psd_scores[ivar].attrs["resolved_scale_space"] / 1e3
    print(f"Wavelength [km]: {resolved_spatial_scale:.2f} [{ivar.upper()}]")
    print(f"Wavelength [degree]: {resolved_spatial_scale/111:.2f} [{ivar.upper()}]")
# Space-time power spectra.
ds_psd_natl60 = cutils.calculate_spacetime_psd(ds_natl60)
ds_psd_rff = cutils.calculate_spacetime_psd(ds_rff)
fig, ax = cutils.plot_analysis_psd_spacetime(
    [
        ds_psd_natl60,
        ds_psd_rff,
    ],
    [
        "NATL60",
        "RFF",
    ],
)
plt.show()
# Space-time PSD score. NOTE: this rebinds `ds_psd_rff` from the spectrum to
# the score dataset; the name is kept for compatibility with later cells.
ds_psd_rff = cutils.calculate_spacetime_psd_score(ds_rff, ds_natl60)
for ivar in ds_psd_rff:
    resolved_spatial_scale = ds_psd_rff[ivar].attrs["resolved_scale_space"] / 1e3
    print(f"Resolved Scale [km]: {resolved_spatial_scale:.2f} [{ivar.upper()}]")
    resolved_temporal_scale = ds_psd_rff[ivar].attrs["resolved_scale_time"]
    print(f"Resolved Scale [days]: {resolved_temporal_scale:.2f}  [{ivar.upper()}]")
_ = cutils.plot_analysis_psd_spacetime_score([ds_psd_rff], ["rff"])