OSE NADIR#
import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import numpy as np
import numba as nb
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import metpy
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from omegaconf import OmegaConf
import hydra
from sklearn import pipeline
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from jejeqx._src.transforms.dataframe.spatial import Spherical2Cartesian
from jejeqx._src.transforms.dataframe.temporal import TimeDelta
from jejeqx._src.transforms.dataframe.scaling import MinMaxDF
# Seaborn plotting style for all figures in this notebook.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Stop JAX/XLA from preallocating a large fraction of GPU memory at startup;
# memory is allocated on demand instead.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# Run JAX in 32-bit mode (64-bit precision explicitly disabled).
jax.config.update("jax_enable_x64", False)
%matplotlib inline
# Automatically reload imported modules before executing each cell.
%load_ext autoreload
%autoreload 2
Recap Formulation#
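We are interested in learning non-linear functions \(\boldsymbol{f}\) of the form
\[
\boldsymbol{f}(\mathbf{x}) = \mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x}) + \mathbf{b},
\]
where the \(\boldsymbol{\phi}(\cdot)\) is a basis function. Neural Fields typically try to learn this basis function via a series of composite functions of the form
\[
\boldsymbol{\phi}(\mathbf{x}) = \boldsymbol{\phi}_L \circ \boldsymbol{\phi}_{L-1} \circ \cdots \circ \boldsymbol{\phi}_1(\mathbf{x}).
\]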
Problems#
Here, we will demonstrate a problem that a naive network has.
Sparse Observations#
In the previous examples, we demonstrated how NeRFs perform when we have a clean simulation. However, in many real problems, we do not have access to such clean, complete data.
For this example, we look at the case where we have very sparse observations, as is the case with satellite altimetry data (e.g. nadir altimeters or SWOT). In this case, we only observe the sea level anomaly along the satellite ground tracks.
# !wget -nc "https://s3.us-east-1.wasabisys.com/melody/osse_data/data/gridded_data_swot_wocorr/dataset_nadir_0d_swot.nc"
# List the training directory; SSHDM.paths below globs dt_gulfstream_* files in it.
!ls /gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/train/
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Dict
@dataclass
class SSHDM:
    """Hydra/OmegaConf config for the along-track *training* data module.

    Instantiated via ``hydra.utils.instantiate``; ``_target_`` names the
    data-module class and the remaining fields are its keyword arguments.
    """

    _target_: str = "jejeqx._src.datamodules.coords.AlongTrackDM"
    # Glob over the Gulf Stream along-track training files.
    paths: str = (
        "/gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/train/"
        "dt_gulfstream_*"
    )
    batch_size: int = 5000
    shuffle: bool = True
    train_size: float = 0.8  # fraction of samples kept for training
    # subset_size: float = 0.50
    # Each instance gets a fresh copy of these coordinate/variable lists.
    spatial_coords: List = field(default_factory=["lat", "lon"].copy)
    temporal_coords: List = field(default_factory=["time"].copy)
    variables: List = field(default_factory=["sla_unfiltered"].copy)
    time_units: str = "seconds since 2016-12-01"
# Training window used to subset the along-track files.
select = {"time": slice("2016-12-01", "2018-01-31")}

# Spatial inputs: lat/lon in degrees -> x/y/z on a unit sphere, then x, y, z
# rescaled by MinMaxDF (arguments -1, 1; presumably the target range).
spatial_transforms = Pipeline(
    steps=[
        ("cartesian3d", Spherical2Cartesian(radius=1.0, units="degrees")),
        ("spatialminmax", MinMaxDF(["x", "y", "z"], -1, 1)),
    ]
)

# Temporal inputs: TimeDelta relative to 2016-12-01 with unit "s" (matching
# SSHDM.time_units), then rescaled by MinMaxDF(..., -1, 1).
temporal_transforms = Pipeline(
    steps=[
        ("timedelta", TimeDelta("2016-12-01", 1, "s")),
        ("timeminmax", MinMaxDF(["time"], -1, 1)),
    ]
)

config_dm = OmegaConf.structured(SSHDM())
dm = hydra.utils.instantiate(
    config_dm,
    select=select,
    spatial_transform=spatial_transforms,
    temporal_transform=temporal_transforms,
)
dm.setup()
# Grab the first 32 training samples to sanity-check ranges/shapes and to
# test the model's output shape below.
init = dm.ds_train[:32]
x_init, t_init, y_init = init["spatial"], init["temporal"], init["data"]
# MinMaxDF(..., -1, 1) above presumably rescales these to [-1, 1] — check here.
x_init.min(), x_init.max(), x_init.shape, t_init.min(), t_init.max(), t_init.shape
# Size of the training split (presumably the number of samples).
len(dm.ds_train)
# xrda_obs = dm.load_xrds()
# xrda_obs
# fig, ax = plt.subplots(ncols=1, figsize=(5,4))
# xrda_obs.ssh_obs.isel(time=1).plot.pcolormesh(ax=ax, cmap="viridis")
# ax.set(title="Original")
# plt.tight_layout()
# plt.show()
# import geoviews as gv
# import geoviews.feature as gf
# from cartopy import crs
# gv.extension('bokeh', 'matplotlib')
# xrda_obs
# dataset = gv.Dataset(xrda_obs)
# ensemble1 = dataset.to(gv.Image, ['lon', 'lat'], "ssh_obs")
# gv.output(ensemble1.opts(cmap='viridis', colorbar=True, fig_size=200, backend='matplotlib') * gf.coastline(),
# backend='matplotlib')
Model#
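The input data is a coordinate vector, \(\mathbf{x}_\phi \in \mathbb{R}^{D_\phi}\), of the spatio-temporal coordinates,
where \(D_\phi = [\text{x}, \text{y}, \text{z}, \text{t}]\): the latitude/longitude converted to 3D Cartesian coordinates on the unit sphere, plus time. So we are interested in learning a function, \(\boldsymbol{f}\), such that we can input a coordinate vector and output a scalar value: the sea level anomaly at that location and time.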
SIREN Layer#
where \(\mathbf{s}\) is the modulation
from jejeqx._src.nets.nerfs.ffn import RFFLayer
model_name = "rff"

# Five stacked RFF layers: 4 inputs (x, y, z, time) -> 256 -> 256 -> 256
# -> 256 -> 1 output, each with its own fixed PRNG seed.
model = eqx.nn.Sequential(
    [
        RFFLayer(
            in_dim=d_in,
            num_features=256,
            out_dim=d_out,
            key=jrandom.PRNGKey(layer_seed),
        )
        for d_in, d_out, layer_seed in (
            (4, 256, 42),
            (256, 256, 123),
            (256, 256, 23),
            (256, 256, 81),
            (256, 1, 32),
        )
    ]
)

# Sanity check: a per-sample forward pass on the init batch must produce
# predictions with the same shape as the targets.
out = jax.vmap(model)(jnp.hstack([x_init, t_init]))
assert out.shape == y_init.shape
Optimizer (+ Learning Rate)#
For this, we will use a simple Adam optimizer with a learning_rate
of 1e-4. From many studies, it appears that a lower learning rate works well with these methods because there is a lot of data. In addition, a bigger batch_size
is also desirable. We will set the num_epochs
to 1_000,
which should be good enough for this dataset. Obviously, more epochs and a better learning rate scheduler would give better results, but this will be sufficient for this demo.
import optax
# Number of passes over the training set.
num_epochs = 1000


@dataclass
class Optimizer:
    """Hydra config for a plain Adam optimizer (``optax.adam``)."""

    _target_: str = "optax.adam"
    learning_rate: float = 0.0001
# @dataclass
# class Scheduler:
# _target_: str = "optax.warmup_exponential_decay_schedule"
# init_value: float = 0.0
# peak_value: float = 1e-2
# warmup_steps: int = 100
# end_value: float = 1e-5
# decay_rate: float = 0.1
# FINETUNE!
@dataclass
class Scheduler:
    """Hydra config for an ``optax.warmup_cosine_decay_schedule``.

    ``decay_steps`` is not set here; it must be supplied at instantiation.
    """

    _target_: str = "optax.warmup_cosine_decay_schedule"
    init_value: float = 0.0  # value at step 0 (start of warmup)
    peak_value: float = 0.01  # value reached at the end of warmup
    warmup_steps: int = 500
    end_value: float = 0.000001  # final value after the cosine decay
optim_config = OmegaConf.structured(Optimizer())
optimizer = hydra.utils.instantiate(optim_config)
# The scheduler config is built but not used: its wiring below is commented out.
scheduler_config = OmegaConf.structured(Scheduler())
# NOTE(review): if re-enabled, len(dm.ds_train) appears to count samples
# (ds_train is sliced per sample), so steps per epoch would likely be
# len(dm.ds_train) // batch_size — confirm before using it for decay_steps.
# NOTE(review): chaining scale_by_schedule after optax.adam(1e-4) multiplies
# both factors (effective peak step ~1e-4 * 1e-2); passing the schedule as
# adam's learning_rate is probably what is intended.
# num_steps_per_epoch = len(dm.ds_train)
# scheduler = hydra.utils.instantiate(
#     scheduler_config,
#     decay_steps=int(num_epochs * num_steps_per_epoch)
# )
# optimizer = optax.chain(optimizer, optax.scale_by_schedule(scheduler))
Trainer Module#
import glob
import os
from pathlib import Path
from jejeqx._src.trainers.base import TrainerModule
from jejeqx._src.trainers.callbacks import wandb_model_artifact
from jejeqx._src.losses import psnr
class RegressorTrainer(TrainerModule):
    """Trainer for a coordinate-based regressor on along-track SSH data.

    Every batch is a mapping with keys ``"spatial"``, ``"temporal"`` and
    ``"data"``. The model is applied sample-by-sample (``jax.vmap``) to the
    horizontal concatenation ``[spatial, temporal]`` and fit with a mean
    squared error loss; metrics report the loss and ``psnr(loss)``.
    """

    def __init__(self, model, optimizer, **kwargs):
        # pl_logger is always passed as None here, so the branch in
        # on_training_end is skipped unless the base class sets a logger.
        super().__init__(model=model, optimizer=optimizer, pl_logger=None, **kwargs)

    @property
    def model(self):
        """Current model, stored as the params of the train state."""
        return self.state.params

    @property
    def model_batch(self):
        """Return a batched forward function ``f(x, t=None)``.

        The model takes a single stacked coordinate vector (as in the loss
        below), so ``x`` and ``t`` are concatenated per sample before the
        ``vmap``. Passing only ``x`` treats it as the already-stacked input.
        """
        model = self.state.params

        def batched_forward(x, t=None):
            inputs = x if t is None else jnp.hstack([x, t])
            return jax.vmap(model)(inputs)

        return batched_forward

    def create_functions(self):
        """Build and return ``(train_step, eval_step, test_step, predict_step)``."""

        def forward(model, batch):
            # Per-sample prediction on the stacked [spatial, temporal] coordinates.
            inputs = jnp.hstack([batch["spatial"], batch["temporal"]])
            return jax.vmap(model)(inputs)

        def mse(model, batch):
            return jnp.mean((batch["data"] - forward(model, batch)) ** 2)

        # Gradients are only needed for training; eval/test call `mse`
        # directly instead of paying for an unused backward pass.
        mse_value_and_grad = eqx.filter_value_and_grad(mse)

        def metrics_from(loss):
            return {"loss": loss, "psnr": psnr(loss)}

        def train_step(state, batch):
            loss, grads = mse_value_and_grad(state.params, batch)
            # NOTE(review): `state` is passed both as receiver and argument;
            # kept as in the original — confirm against the train-state API.
            state = state.update_state(state, grads)
            return state, loss, metrics_from(loss)

        def eval_step(model, batch):
            return metrics_from(mse(model, batch))

        def test_step(model, batch):
            pred = forward(model, batch)
            # Reuse the prediction instead of running a second forward pass.
            loss = jnp.mean((batch["data"] - pred) ** 2)
            return pred, metrics_from(loss)

        def predict_step(model, batch):
            return forward(model, batch)

        return train_step, eval_step, test_step, predict_step

    def on_training_end(self):
        """Save the model under ``log_dir/save_name`` and finalize the logger.

        Runs only when ``self.pl_logger`` is truthy; also calls
        ``wandb_model_artifact(self)``.
        """
        if self.pl_logger:
            save_dir = Path(self.log_dir).joinpath(self.save_name)
            self.save_model(save_dir)
            wandb_model_artifact(self)
            self.pl_logger.finalize("success")
# Trainer settings.
seed, debug = 123, False
enable_progress_bar = False
log_dir = "./"

trainer = RegressorTrainer(
    model,
    optimizer,
    seed=seed,
    debug=debug,
    enable_progress_bar=enable_progress_bar,
    log_dir=log_dir,
)

# Notebook switches for the training and checkpoint-saving cells.
train_more, save_more = True, True
%%time
# Baseline: test metrics of the freshly initialized model.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# Alternative rff checkpoints from earlier experiments:
#   ./checkpoints/checkpoint_model_rff_ssh_nadir.ckpt
#   ./checkpoints/checkpoint_model_rff_ssh_swot.ckpt
#   ./checkpoints/checkpoint_natl60_model_rff.ckpt
checkpoint_to_load = {
    "rff": "./checkpoints/checkpoint_model_rff_ssh_ose_year.ckpt",
    "siren": "./checkpoints/checkpoint_model_siren_ssh_swot.ckpt",
}.get(model_name)
if checkpoint_to_load is not None:
    trainer.load_model(checkpoint_to_load)
%%time
# Test metrics after the checkpoint-loading cell above.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
%%time
metrics = trainer.train_model(dm, num_epochs=num_epochs)
%%time
# Test metrics after the training cell above.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
if save_more:
    # Alternative rff target: ./checkpoints/checkpoint_model_rff_ssh_swot.ckpt
    checkpoint_to_save = {
        "rff": "./checkpoints/checkpoint_model_rff_ssh_ose_year.ckpt",
        # NOTE(review): siren is saved to ..._ose.ckpt, while the load cell
        # reads ..._swot.ckpt — confirm this asymmetry is intended.
        "siren": "./checkpoints/checkpoint_model_siren_ssh_ose.ckpt",
    }.get(model_name)
    if checkpoint_to_save is not None:
        trainer.save_model(checkpoint_to_save)
Evaluation#
# List the held-out test directory; SSHDMEVAL.paths below globs dt_gulfstream_* files in it.
!ls /gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/test/
from dataclasses import dataclass, field
from typing import List, Dict
@dataclass
class SSHDMEVAL:
    """Hydra/OmegaConf config for the along-track *evaluation* data module.

    Same target class as the training config, but pointed at the held-out
    test files, with a larger batch size and ``evaluation=True``.
    """

    _target_: str = "jejeqx._src.datamodules.coords.AlongTrackDM"
    # Glob over the Gulf Stream along-track test files.
    paths: str = (
        "/gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/test/"
        "dt_gulfstream_*"
    )
    batch_size: int = 10000
    evaluation: bool = True
    # subset_size: float = None
    # Each instance gets a fresh copy of these coordinate/variable lists.
    spatial_coords: List = field(default_factory=["lat", "lon"].copy)
    temporal_coords: List = field(default_factory=["time"].copy)
    variables: List = field(default_factory=["sla_unfiltered"].copy)
    time_units: str = "seconds since 2016-12-01"
# Evaluate on calendar year 2017 of the held-out tracks, reusing the
# training data module's transforms so inputs are transformed identically.
select = {"time": slice("2017-01-01", "2017-12-31")}
config_dm = OmegaConf.structured(SSHDMEVAL())
dm_eval = hydra.utils.instantiate(
    config_dm,
    select=select,
    spatial_transform=dm.spatial_transform,
    temporal_transform=dm.temporal_transform,
)
dm_eval.setup()
%%time
# Metrics on the held-out along-track test files for 2017.
out, metrics = trainer.test_model(dm_eval.test_dataloader())
metrics
Gridded Prediction#
We will predict on a regular lon/lat/time grid covering the same region, for the period:
01-January-2017 :--> 01-January-2018
from dataclasses import dataclass, field
from typing import List, Dict
from jejeqx._src.types.xrdata import Bounds, Period
@dataclass
class SSHDMEVAL:
    """Hydra/OmegaConf config for the regular-grid prediction data module.

    The ``Bounds``/``Period`` defaults are built with ``default_factory`` so
    every config instance gets its own object rather than sharing one
    class-level instance. This also avoids the dataclass mutable-default
    check (Python 3.11+ rejects unhashable default values).
    """

    _target_: str = "jejeqx._src.datamodules.coords.EvalGridDM"
    # Presumably (min, max, step) in degrees — confirm against Bounds.
    lon_limits: Bounds = field(default_factory=lambda: Bounds(-65, -55, 0.1))
    lat_limits: Bounds = field(default_factory=lambda: Bounds(33, 43, 0.1))
    # Presumably (start, end, step, unit): daily steps over 2017 — confirm against Period.
    time_limits: Period = field(
        default_factory=lambda: Period("2017-01-01", "2018-01-01", 1, "D")
    )
    spatial_coords: List = field(default_factory=lambda: ["lat", "lon"])
    temporal_coords: List = field(default_factory=lambda: ["time"])
%%time
# Build the evaluation grid, reusing the training data module's spatial and
# temporal transforms so grid coordinates are transformed like training inputs.
config_dm = OmegaConf.structured(SSHDMEVAL())
dm_eval = hydra.utils.instantiate(
    config_dm,
    spatial_transform=dm.spatial_transform,
    temporal_transform=dm.temporal_transform,
)
dm_eval.setup()
# Number of grid points to predict.
len(dm_eval.ds_predict)
%%time
# Predict on every grid point, then convert the predictions to an xarray
# object via the data module (it exposes an `ssh` variable, used below).
out = trainer.predict_model(dm_eval.predict_dataloader())
xrda = dm_eval.data_to_df(out).to_xarray()
xrda
fig, ax = plt.subplots()
# Time index to display (the analysis plots below use the same index, 5).
itime = 5
xrda.ssh.isel(time=itime).plot.pcolormesh(ax=ax, cmap="viridis", robust=False)
ax.set(title="Random Feature Expansions")
plt.tight_layout()
plt.show()
import common_utils as cutils
# Derived quantities computed from the predicted SSH field.
ds_rff = cutils.calculate_physical_quantities(xrda.ssh)
ds_rff
# Maps of the derived variables at time index 5.
fig, ax = cutils.plot_analysis_vars([ds_rff.isel(time=5)])
plt.show()
# Isotropic power spectral density.
ds_psd_rff = cutils.calculate_isotropic_psd(ds_rff)
fig, ax = cutils.plot_analysis_psd_iso([ds_psd_rff], ["RFE"])
plt.show()
# Space-time power spectral density.
# NOTE(review): this overwrites ds_psd_rff (the isotropic PSD above); use a
# distinct name if both are needed later.
ds_psd_rff = cutils.calculate_spacetime_psd(ds_rff)
fig, ax = cutils.plot_analysis_psd_spacetime([ds_psd_rff], ["RFE"])
plt.show()