AlongTrack Evaluation#
In this notebook, we look at how to evaluate alongtrack data. In the OSSE setting in previous notebooks, we saw that we can evaluate our method on the alongtracks and then compare our learned field with the field from the simulation. In the OSE setting, we don’t have the original field. So we need to evaluate the field directly on the leave-one-out alongtrack. We demonstrate some simple statistics in the real and spectral space that we can perform.
import autoroot
import typing as tp
from dataclasses import dataclass
import functools as ft
import numpy as np
import pandas as pd
import xarray as xr
import einops
from metpy.units import units
import pint_xarray
import xarray_dataclasses as xrdataclass
from oceanbench._src.geoprocessing.validation import validate_latlon, validate_time, decode_cf_time, validate_ssh
from oceanbench._src.preprocessing.alongtrack import alongtrack_ssh
from oceanbench._src.geoprocessing.subset import where_slice
from oceanbench._src.preprocessing.alongtrack import remove_swath_dimension
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
%load_ext autoreload
%autoreload 2
Data#
!ls "/gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/results"
leaderboard.csv OSE_ssh_mapping_DUACS.nc
OSE_ssh_mapping_4dvarNet_2022.nc OSE_ssh_mapping_DYMOST.nc
OSE_ssh_mapping_4dvarNet.nc OSE_ssh_mapping_MIOST.nc
OSE_ssh_mapping_BASELINE.nc results.csv
OSE_ssh_mapping_BFN.nc
!ls /gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/test
dt_gulfstream_c2_phy_l3_20161201-20180131_285-315_23-53.nc
from oceanbench._src.geoprocessing.validation import validate_latlon, validate_time
from oceanbench._src.preprocessing.alongtrack import alongtrack_ssh
from oceanbench._src.geoprocessing.subset import where_slice
def preprocess_nadir(da):
    """Prepare one raw along-track (nadir) file for the evaluation.

    Intended as the ``preprocess`` callback of ``xr.open_mfdataset``. The
    input is sorted by time and loaded, ``longitude``/``latitude`` are renamed
    to ``lon``/``lat``, the project validators are applied, and the data are
    cropped to the study box and to the year 2017 before ``alongtrack_ssh``
    is called.

    Args:
        da: dataset with a ``time`` coordinate and ``longitude``/``latitude``
            variables (presumably an ``xr.Dataset`` -- confirm with caller).

    Returns:
        Whatever ``alongtrack_ssh`` returns for the cropped dataset.
    """
    # chronological order, evaluated eagerly
    ordered = da.sortby("time").compute()

    # harmonise the coordinate names, then run the validators
    renamed = ordered.rename({"longitude": "lon", "latitude": "lat"})
    validated = validate_time(validate_latlon(renamed))

    # spatial crop to the study box
    cropped = where_slice(validated, "lon", -64.975, -55.007)
    cropped = where_slice(cropped, "lat", 33.025, 42.9917)

    # temporal crop to the evaluation year
    year_2017 = cropped.sel(time=slice("2017-01-01", "2017-12-31"))

    # derive SSH from the along-track variables
    return alongtrack_ssh(year_2017)
# Along-track file that serves as the evaluation reference.
files_nadir_dc21a = [
    "/gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/test/dt_gulfstream_c2_phy_l3_20161201-20180131_285-315_23-53.nc",
]

# Open with per-file preprocessing, then order by time and load into memory.
ds_nadir = xr.open_mfdataset(
    files_nadir_dc21a,
    engine="netcdf4",
    combine="nested",
    concat_dim="time",
    preprocess=preprocess_nadir,
)
ds_nadir = ds_nadir.sortby("time")
ds_nadir = ds_nadir.compute()
ds_nadir
<xarray.Dataset>
Dimensions: (time: 49859)
Coordinates:
* time (time) datetime64[ns] 2017-01-01T08:08:42.012641024 ... 2...
lon (time) float64 -62.61 -62.61 -62.62 ... -62.09 -62.1 -62.1
lat (time) float64 33.03 33.09 33.15 33.2 ... 42.83 42.89 42.95
Data variables:
cycle (time) float64 88.0 88.0 88.0 88.0 ... 101.0 101.0 101.0
track (time) float64 701.0 701.0 701.0 701.0 ... 353.0 353.0 353.0
dac (time) float32 -0.1647 -0.1648 -0.165 ... 0.03 0.0314 0.0327
lwe (time) float32 0.003 0.003 0.003 ... -0.029 -0.029 -0.029
mdt (time) float32 0.593 0.592 0.591 ... -0.165 -0.164 -0.163
ocean_tide (time) float64 -0.3407 -0.3413 -0.342 ... -0.1686 -0.1693
sla_filtered (time) float32 -0.136 -0.16 -0.18 -0.194 ... 0.105 0.102 0.1
sla_unfiltered (time) float32 -0.151 -0.119 -0.158 ... 0.081 0.097 0.114
ssh (time) float32 0.439 0.47 0.43 0.404 ... -0.055 -0.038 -0.02
Attributes: (12/44)
Conventions: CF-1.6
Metadata_Conventions: Unidata Dataset Discovery v1.0
cdm_data_type: Swath
comment: Sea surface height measured by altimeter...
contact: servicedesk.cmems@mercator-ocean.eu
creator_email: servicedesk.cmems@mercator-ocean.eu
... ...
summary: SSALTO/DUACS Delayed-Time Level-3 sea su...
time_coverage_duration: P23H43M4.754863S
time_coverage_end: 2016-01-01T23:11:03Z
time_coverage_resolution: P1S
time_coverage_start: 2015-12-31T23:27:58Z
title: DT Cryosat-2 Global Ocean Along track SS...
%matplotlib inline
# Scatter the first two weeks of along-track SSH.
fig, ax = plt.subplots()
sub_ds = ds_nadir.sel(time=slice("2017-01-01", "2017-01-15"))
pts = ax.scatter(sub_ds.lon, sub_ds.lat, c=sub_ds.ssh, s=0.1)

ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
# axis limits follow the extent of the full dataset, not only the plotted subset
ax.set_xlim([ds_nadir.lon.min(), ds_nadir.lon.max()])
ax.set_ylim([ds_nadir.lat.min(), ds_nadir.lat.max()])

plt.colorbar(pts, label="Sea Surface Height [m]")
plt.tight_layout()
plt.show()
Field#
Now, we will load a field from a method that interpolated the data.
# !ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/natl60/
!ls /gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/results
leaderboard.csv OSE_ssh_mapping_DUACS.nc
OSE_ssh_mapping_4dvarNet_2022.nc OSE_ssh_mapping_DYMOST.nc
OSE_ssh_mapping_4dvarNet.nc OSE_ssh_mapping_MIOST.nc
OSE_ssh_mapping_BASELINE.nc results.csv
OSE_ssh_mapping_BFN.nc
def preprocess_field(da):
    """Prepare one gridded SSH map file for the evaluation.

    Intended as the ``preprocess`` callback of ``xr.open_mfdataset``. The
    project validators are applied in the order lat/lon, time, SSH, and the
    result is restricted to the year 2017 and to the box
    lon in [-65, -55], lat in [33, 43].

    Args:
        da: gridded dataset indexed by ``time``, ``lon`` and ``lat`` after
            validation (presumably an ``xr.Dataset`` -- confirm with caller).

    Returns:
        The validated dataset restricted to the region and period above.
    """
    checked = validate_ssh(validate_time(validate_latlon(da)))

    # region + time period of the study
    region_and_period = dict(
        time=slice("2017-01-01", "2017-12-31"),
        lon=slice(-65., -55.),
        lat=slice(33., 43.),
    )
    return checked.sel(**region_and_period)
Gridding#
file_DUACS = "/gpfswork/rech/yrf/commun/data_challenges/dc21a_ose/test/results/OSE_ssh_mapping_DUACS.nc"
%%time
ds_field = xr.open_mfdataset(
file_DUACS,
decode_times=True,
preprocess=preprocess_field,
combine="nested",
engine="netcdf4",
concat_dim="time"
)
ds_field = ds_field.sortby("time").compute()
ds_field
CPU times: user 23.7 ms, sys: 8.99 ms, total: 32.7 ms
Wall time: 67.8 ms
<xarray.Dataset>
Dimensions: (lat: 40, lon: 40, time: 365)
Coordinates:
* lat (lat) float64 33.12 33.38 33.62 33.88 ... 42.12 42.38 42.62 42.88
* lon (lon) float64 -64.88 -64.62 -64.38 -64.12 ... -55.62 -55.38 -55.12
* time (time) datetime64[ns] 2017-01-01 2017-01-02 ... 2017-12-31
Data variables:
ssh (time, lat, lon) float64 0.817 0.8275 0.8267 ... 0.08649 0.0524
Attributes:
FileType: GRID_DOTS
OriginalName: dt_upd_global_merged_msla_h_20170101_20170101_20190823.nc
CreatedBy: ballarm@node036.sis.cnes.fr
CreatedOn: 23-AUG-2019 11:21:19:000000
title: SSALTO/DUACS - DT MSLA - Merged Product - Up-to-date Globa...
history: 2019/08/23 11:21:19 ballarm@node036.sis.cnes.fr Import dep...
fig, ax = plt.subplots()
# Quick look at the gridded SSH map at time index 10.
ds_field.ssh.isel(time=10).plot.pcolormesh(ax=ax, cmap="viridis")
plt.show()
Regridding#
Now, we will need to regrid our field to the alongtrack data.
from oceanbench._src.geoprocessing.gridding import grid_to_coord_based
ds_field
<xarray.Dataset>
Dimensions: (lat: 40, lon: 40, time: 365)
Coordinates:
* lat (lat) float64 33.12 33.38 33.62 33.88 ... 42.12 42.38 42.62 42.88
* lon (lon) float64 -64.88 -64.62 -64.38 -64.12 ... -55.62 -55.38 -55.12
* time (time) datetime64[ns] 2017-01-01 2017-01-02 ... 2017-12-31
Data variables:
ssh (time, lat, lon) float64 0.817 0.8275 0.8267 ... 0.08649 0.0524
Attributes:
FileType: GRID_DOTS
OriginalName: dt_upd_global_merged_msla_h_20170101_20170101_20190823.nc
CreatedBy: ballarm@node036.sis.cnes.fr
CreatedOn: 23-AUG-2019 11:21:19:000000
title: SSALTO/DUACS - DT MSLA - Merged Product - Up-to-date Globa...
history: 2019/08/23 11:21:19 ballarm@node036.sis.cnes.fr Import dep...ds_nadir["ssh_interp"] = grid_to_coord_based(
ds_field.transpose("lon", "lat", "time"),
ds_nadir,
data_vars=["ssh"],
)["ssh"]
# np.isfinite(ds_nadir_gridded.ssh.isel(time=6)).plot.imshow()
ds_results = ds_nadir.copy()
# drop all nans
ds_results = ds_results[["ssh", "ssh_interp"]].dropna(dim="time")
ds_results
<xarray.Dataset>
Dimensions: (time: 47191)
Coordinates:
* time (time) datetime64[ns] 2017-01-01T08:08:43.899440896 ... 2017-...
lon (time) float64 -62.62 -62.63 -62.64 ... -59.19 -59.2 -59.21
lat (time) float64 33.15 33.2 33.26 33.32 ... 33.3 33.24 33.18 33.13
Data variables:
ssh (time) float32 0.43 0.404 0.364 0.358 ... 0.632 0.612 0.724
ssh_interp (time) float64 0.4091 0.4189 0.43 ... 0.7218 0.7214 0.7208
Attributes: (12/44)
Conventions: CF-1.6
Metadata_Conventions: Unidata Dataset Discovery v1.0
cdm_data_type: Swath
comment: Sea surface height measured by altimeter...
contact: servicedesk.cmems@mercator-ocean.eu
creator_email: servicedesk.cmems@mercator-ocean.eu
... ...
summary: SSALTO/DUACS Delayed-Time Level-3 sea su...
time_coverage_duration: P23H43M4.754863S
time_coverage_end: 2016-01-01T23:11:03Z
time_coverage_resolution: P1S
time_coverage_start: 2015-12-31T23:27:58Z
title: DT Cryosat-2 Global Ocean Along track SS...
We should shrink the area slightly (by 0.25 on each side) so that the boundaries of the domain are excluded. This makes the comparison fairer.
# Trim 0.25 (in lon/lat units) from each side of the box used in
# preprocess_nadir so that points on the domain boundary are excluded.
ds_results = where_slice(ds_results, "lon", -64.975 + 0.25, -55.007 - 0.25)
ds_results = where_slice(ds_results, "lat", 33.025 + 0.25, 42.9917 - 0.25)
ds_results
<xarray.Dataset>
Dimensions: (time: 44546)
Coordinates:
* time (time) datetime64[ns] 2017-01-01T08:08:46.729640704 ... 2017-...
lon (time) float64 -62.64 -62.65 -62.66 ... -59.17 -59.18 -59.19
lat (time) float64 33.32 33.38 33.43 33.49 ... 33.41 33.36 33.3
Data variables:
ssh (time) float32 0.358 0.382 0.457 0.463 ... 0.66 0.655 0.658
ssh_interp (time) float64 0.4415 0.4534 0.4847 ... 0.7216 0.7223 0.7221
Attributes: (12/44)
Conventions: CF-1.6
Metadata_Conventions: Unidata Dataset Discovery v1.0
cdm_data_type: Swath
comment: Sea surface height measured by altimeter...
contact: servicedesk.cmems@mercator-ocean.eu
creator_email: servicedesk.cmems@mercator-ocean.eu
... ...
summary: SSALTO/DUACS Delayed-Time Level-3 sea su...
time_coverage_duration: P23H43M4.754863S
time_coverage_end: 2016-01-01T23:11:03Z
time_coverage_resolution: P1S
time_coverage_start: 2015-12-31T23:27:58Z
title: DT Cryosat-2 Global Ocean Along track SS...
%matplotlib inline
# Three panels over the same two-week window: observed along-track SSH, the
# field interpolated onto the track, and their difference.
fig, ax = plt.subplots(ncols=3, figsize=(18, 5))

# Identical extent on every panel so that they can be compared by eye.
panel_limits = dict(xlim=[-65., -55.], ylim=[33., 43.])

# True SSH
sub_ds_ssh = ds_results.ssh.sel(time=slice("2017-01-01", "2017-01-15"))
pts = ax[0].scatter(
    sub_ds_ssh.lon, sub_ds_ssh.lat, c=sub_ds_ssh, s=0.1,
    cmap="viridis",
)
ax[0].set(xlabel="Longitude", ylabel="Latitude", **panel_limits)
# attach each colorbar explicitly to the panel it describes
plt.colorbar(pts, ax=ax[0], label="Sea Surface Height [m]")

# Predicted SSH
sub_ds_ssh_interp = ds_results.ssh_interp.sel(time=slice("2017-01-01", "2017-01-15"))
pts = ax[1].scatter(
    sub_ds_ssh_interp.lon, sub_ds_ssh_interp.lat, c=sub_ds_ssh_interp, s=0.1,
    cmap="viridis",
)
ax[1].set(xlabel="Longitude", ylabel="Latitude", **panel_limits)
plt.colorbar(pts, ax=ax[1], label="Sea Surface Height [m]")

# Difference (predicted - true)
# NOTE(review): the difference is signed while "Reds" is a sequential
# colormap; consider a diverging colormap with symmetric limits.
diff = sub_ds_ssh_interp - sub_ds_ssh
pts = ax[2].scatter(diff.lon, diff.lat, c=diff, s=0.1, cmap="Reds")
ax[2].set(xlabel="Longitude", ylabel="Latitude", **panel_limits)
plt.colorbar(pts, ax=ax[2], label="Difference [m]")

plt.tight_layout()
plt.show()
AlongTrack Segments#
TODO: Need to explain this better.
import oceanbench._src.preprocessing.alongtrack as atrack_process
import oceanbench._src.metrics.power_spectrum as psd_calc

# Along-track sampling parameters (units as given in the comments).
velocity = 6.77  # [km/s]
delta_t = 0.9434  # [s]
delta_x = velocity * delta_t  # [km]
length_scale = 1000  # [km]

# Cut the along-track series into segments. The constants defined above are
# passed through (instead of re-typed literals) so that the segmentation and
# the later PSD calls, which use ``delta_x``, cannot drift apart.
ds_segments = atrack_process.select_track_segments(
    ds_results,
    "ssh", "ssh_interp",
    velocity=velocity,
    delta_t=delta_t,
    length_scale=length_scale,
    segment_overlapping=0.25,
)
ds_segments
<xarray.Dataset>
Dimensions: (segment: 214, track_val: 156)
Coordinates:
* segment (segment) int64 0 1 2 3 4 5 6 7 ... 207 208 209 210 211 212 213
lat (segment) float64 -63.19 -61.24 -64.05 ... -56.85 -57.69 -60.55
lon (segment) float64 37.8 38.31 37.66 38.31 ... 38.16 38.29 37.71
* track_val (track_val) int64 0 1 2 3 4 5 6 ... 149 150 151 152 153 154 155
Data variables:
ssh_interp (segment, track_val) float64 0.4415 0.4534 ... 0.01457 0.01301
ssh (segment, track_val) float32 0.358 0.382 0.457 ... -0.007 0.018
Attributes:
delta_x: 6.386818
velocity: 6.77
length_scale: 1000
nperseg: 156
nperseg = ds_segments.nperseg
nperseg
156
Real Statistics#
Statistics#
# psd_calc.psd_welch_score(ds_segments, "ssh", "ssh_interp", delta_x=delta_x, nperseg=nperseg).score
Spectral Statistics#
Power Spectrum#
# Welch spectrum of the observed along-track SSH
ds_psd = psd_calc.psd_welch(
    ds_segments,
    variable="ssh",
    delta_x=delta_x,
    nperseg=nperseg,
).ssh

# Welch spectrum of the map interpolated onto the track
ds_psd_interp = psd_calc.psd_welch(
    ds_segments,
    variable="ssh_interp",
    delta_x=delta_x,
    nperseg=nperseg,
).ssh_interp

# Spectrum of the error between the two
ds_psd_error = psd_calc.psd_welch_error(
    ds_segments,
    variable_ref="ssh",
    variable="ssh_interp",
    delta_x=delta_x,
    nperseg=nperseg,
).error

# Spectral score of the interpolated map against the observations
ds_psd_score = psd_calc.psd_welch_score(
    ds_segments,
    variable_ref="ssh",
    variable="ssh_interp",
    delta_x=delta_x,
    nperseg=nperseg,
).score

ds_psd
<xarray.DataArray 'ssh' (wavenumber: 79)>
array([1.02373037e+01, 7.14277725e+01, 4.22874489e+01, 2.21006603e+01,
8.25816345e+00, 2.71950841e+00, 1.40674925e+00, 6.88535631e-01,
3.68134439e-01, 2.17998356e-01, 1.30187139e-01, 8.60557631e-02,
6.66696206e-02, 5.22705801e-02, 4.14400324e-02, 3.09867822e-02,
2.78125107e-02, 2.32093651e-02, 2.11372841e-02, 1.81758646e-02,
1.80478077e-02, 1.65228639e-02, 1.55075006e-02, 1.27541861e-02,
1.12466970e-02, 1.24145634e-02, 1.20474407e-02, 1.09734135e-02,
1.00068841e-02, 1.06369033e-02, 1.12501457e-02, 1.16684642e-02,
9.71117616e-03, 9.58298892e-03, 9.61237121e-03, 9.55455657e-03,
9.76762362e-03, 1.02562634e-02, 9.34584532e-03, 8.92794598e-03,
9.51405521e-03, 9.44814924e-03, 9.20096785e-03, 9.09277238e-03,
8.38163868e-03, 8.41708854e-03, 8.41432996e-03, 7.14715524e-03,
7.98082165e-03, 8.83140601e-03, 7.99081381e-03, 8.02045316e-03,
7.78246671e-03, 8.94837547e-03, 8.09417851e-03, 8.00068583e-03,
7.34653696e-03, 8.02976545e-03, 7.93727767e-03, 7.48019805e-03,
7.49919470e-03, 7.87938386e-03, 7.77373137e-03, 7.88705889e-03,
8.31951480e-03, 7.64988083e-03, 7.48078013e-03, 7.67526543e-03,
8.13236553e-03, 7.90031720e-03, 6.98836194e-03, 7.61722075e-03,
7.75263691e-03, 7.75971264e-03, 7.80881057e-03, 7.37115275e-03,
7.29903765e-03, 7.00414460e-03, 3.10720433e-03], dtype=float32)
Coordinates:
* wavenumber (wavenumber) float64 0.0 0.001004 0.002007 ... 0.07728 0.07829
Figure#
class PlotPSDIsotropic:
    """Draw 1D power spectra on a matplotlib axis with log-log scaling.

    Call ``init_fig`` first to create (or adopt) the target axis, then one of
    the ``plot_*`` methods. Every ``plot_*`` method expects ``da`` to be a
    labelled 1D array whose first dimension holds the wavenumber coordinate
    and which exposes ``.dims`` and ``.name`` (presumably an
    ``xr.DataArray`` -- confirm with callers). Extra keyword arguments are
    forwarded to ``Axes.plot``.
    """

    # Secondary (wavelength) x-axis created by ``plot_both``; None until then.
    secax = None

    def init_fig(self, ax=None, figsize=None):
        """Create a new figure/axis, or adopt ``ax`` when one is given.

        Args:
            ax: existing matplotlib axis to draw on; a new figure is created
                when it is None.
            figsize: size of the new figure, ``(5, 4)`` by default. Ignored
                when ``ax`` is given.
        """
        if ax is None:
            figsize = (5, 4) if figsize is None else figsize
            self.fig, self.ax = plt.subplots(figsize=figsize)
        else:
            self.ax = ax
            # the figure that owns ``ax``, which need not be the current one
            self.fig = ax.figure

    def _refresh_legend(self):
        """Draw the legend only when at least one artist carries a label."""
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()

    def plot_wavenumber(self, da, freq_scale=1.0, units=None, **kwargs):
        """Plot ``da`` against its wavenumber coordinate times ``freq_scale``.

        Args:
            da: spectrum to draw (see the class docstring).
            freq_scale: multiplicative factor applied to the wavenumber.
            units: length unit shown in the x label, omitted when None.
            **kwargs: forwarded to ``Axes.plot`` (e.g. ``label``, ``color``).
        """
        if units is not None:
            xlabel = f"Wavenumber [cycles/{units}]"
        else:
            xlabel = "Wavenumber"

        dim = da.dims[0]
        self.ax.plot(da[dim] * freq_scale, da, **kwargs)
        self.ax.set(
            yscale="log", xscale="log",
            xlabel=xlabel,
            ylabel=f"PSD [{da.name}]",
            xlim=[10**(-3) - 0.00025, 10**(-1) + 0.025],
        )
        self._refresh_legend()
        self.ax.grid(which="both", alpha=0.5)

    def plot_wavelength(self, da, freq_scale=1.0, units=None, **kwargs):
        """Plot ``da`` against wavelength, i.e. ``1 / (wavenumber * freq_scale)``.

        The x axis is inverted so that wavelengths decrease to the right.

        Args:
            da: spectrum to draw (see the class docstring).
            freq_scale: multiplicative factor applied to the wavenumber.
            units: length unit shown in the x label, omitted when None.
            **kwargs: forwarded to ``Axes.plot``.
        """
        if units is not None:
            xlabel = f"Wavelength [{units}]"
        else:
            xlabel = "Wavelength"

        dim = da.dims[0]
        self.ax.plot(1 / (da[dim] * freq_scale), da, **kwargs)
        self.ax.set(
            yscale="log", xscale="log",
            xlabel=xlabel,
            ylabel=f"PSD [{da.name}]",
        )
        self.ax.xaxis.set_major_formatter("{x:.0f}")
        self.ax.invert_xaxis()
        self._refresh_legend()
        self.ax.grid(which="both", alpha=0.5)

    def plot_both(self, da, freq_scale=1.0, units=None, **kwargs):
        """Plot against wavenumber and add a wavelength axis on top.

        The secondary axis is created once per target axis and reused by
        later calls, so overlaying several spectra does not stack axes.

        Args:
            da: spectrum to draw (see the class docstring).
            freq_scale: multiplicative factor applied to the wavenumber.
            units: length unit shown in both x labels, omitted when None.
            **kwargs: forwarded to ``Axes.plot``.
        """
        if units is not None:
            xlabel = f"Wavelength [{units}]"
        else:
            xlabel = "Wavelength"

        self.plot_wavenumber(da=da, units=units, freq_scale=freq_scale, **kwargs)

        # (re)create the top axis only if the current axis does not own one
        if self.secax is None or self.secax not in self.ax.child_axes:
            self.secax = self.ax.secondary_xaxis(
                "top",
                # the small offset keeps the transform finite at x == 0
                functions=(lambda x: 1 / (x + 1e-20), lambda x: 1 / (x + 1e-20)),
            )
            self.secax.xaxis.set_major_formatter("{x:.0f}")
        self.secax.set(xlabel=xlabel)
# Overlay the spectrum of the observations and of the interpolated map.
psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig()
psd_iso_plot.plot_both(ds_psd, freq_scale=1, units="km", label="NADIR", color="black")
psd_iso_plot.plot_both(ds_psd_interp, freq_scale=1, units="km", label="DUACS", color="tab:green")

# custom bounds and a common y label for both curves
psd_iso_plot.ax.set(
    xlim=(10**(-3) - 0.00025, 10**(-1) +0.025),
    ylabel="PSD [SSH]",
)
plt.show()
Power Spectrum Score#
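The score compares the power spectrum of the reconstruction error with the power spectrum of the reference signal:
\[ \text{PSD}_{score} = 1 - \frac{\mathcal{F}(\eta_{ref} - \eta_{interp})}{\mathcal{F}(\eta_{ref})} \]
where \(\mathcal{F}\) is the power spectrum transformation, \(\eta_{ref}\) is the along-track SSH and \(\eta_{interp}\) is the interpolated SSH.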
from oceanbench._src.metrics.utils import find_intercept_1D, find_intercept_2D
Resolved Scale#
We can choose an arbitrary threshold that defines the “resolved scale” for the quantity of interest. Here we use 0.5, which is the minimum signal-to-noise ratio that we deem trustworthy.
# Wavelength (1 / wavenumber) at which the PSD score crosses the 0.5 level.
find_intercept_1D(
    level=0.5,
    x=ds_psd_score.values,
    y=1. / ds_psd_score.wavenumber,
    fill_value="extrapolate",
    kind="slinear",
)
array(152.67922533)
Figure#
class PlotPSDScoreIsotropic(PlotPSDIsotropic):
    """Plot a PSD score curve and mark the scale where it crosses a threshold."""

    def _add_score(
        self,
        da,
        freq_scale=1.0,
        units=None,
        threshhold: float = 0.5,
        threshhold_color="k",
        name="",
    ):
        """Annotate the current axis with the threshold crossing of ``da``.

        Switches the y axis to a linear [0, 1] range, draws dashed guide
        lines at the crossing and adds a labelled marker reading
        ``"<name>: <wavelength> <units>"``.

        Args:
            da: score curve; its first dimension is used as the wavenumber
                coordinate (presumably an ``xr.DataArray`` -- confirm).
            freq_scale: multiplicative factor applied to the wavenumber.
            units: length unit appended to the label, omitted when None.
            threshhold: score level that defines the resolved scale.
            threshhold_color: colour of the guide lines and of the marker.
            name: prefix of the marker label.
        """
        dim = da.dims[0]
        self.ax.set(ylabel="PSD Score", yscale="linear")
        self.ax.set_ylim((0, 1.0))
        self.ax.set_xlim((
            10**(-3) - 0.00025,
            10**(-1) + 0.025,
        ))

        # find_intercept_1D is given 1 / wavenumber as ``y``, so dividing
        # freq_scale by its result gives the crossing on the plotted x axis;
        # the small offset avoids a division by zero at wavenumber 0
        resolved_scale = freq_scale / find_intercept_1D(
            x=da.values, y=1. / (da[dim].values + 1e-15), level=threshhold
        )

        self.ax.vlines(
            x=resolved_scale,
            ymin=0, ymax=threshhold,
            color=threshhold_color,
            linewidth=2, linestyle="--",
        )
        self.ax.hlines(
            y=threshhold,
            xmin=np.ma.min(np.ma.masked_invalid(da[dim].values * freq_scale)),
            xmax=resolved_scale, color=threshhold_color,
            linewidth=2, linestyle="--",
        )

        # leave the unit out instead of printing the literal "None"
        units_text = "" if units is None else f"{units} "
        label = f"{name}: {1/resolved_scale:.0f} {units_text}"
        self.ax.scatter(
            resolved_scale, threshhold,
            color=threshhold_color, marker=".",
            linewidth=5, label=label,
            zorder=3,
        )
        # the marker was added after the curve was drawn, so rebuild the
        # legend to make its label visible
        self.ax.legend()

    def plot_score(
        self,
        da,
        freq_scale=1.0,
        units=None,
        threshhold: float = 0.5,
        threshhold_color="k",
        name="",
        **kwargs
    ):
        """Draw the score curve with both x axes, then annotate the crossing.

        Args:
            da: score curve (see ``_add_score``).
            freq_scale: multiplicative factor applied to the wavenumber.
            units: length unit used in the axis labels and the marker label.
            threshhold: score level that defines the resolved scale.
            threshhold_color: colour of the guide lines and of the marker.
            name: prefix of the marker label.
            **kwargs: forwarded to ``plot_both``.
        """
        self.plot_both(da=da, freq_scale=freq_scale, units=units, **kwargs)
        self._add_score(
            da=da,
            freq_scale=freq_scale,
            units=units,
            threshhold=threshhold,
            threshhold_color=threshhold_color,
            name=name,
        )
# Score curve of the DUACS map with the 0.5 crossing highlighted.
psd_iso_plot = PlotPSDScoreIsotropic()
psd_iso_plot.init_fig()
psd_iso_plot.plot_score(
    ds_psd_score,
    name="DUACS",
    units="km",
    freq_scale=1,
    threshhold=0.5,
    threshhold_color="tab:blue",
    color="black",
)

plt.legend()
plt.show()
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.