# Plotting maps with OceanBench
# If the kernel was started inside the `notebooks` directory, move up one
# level (IPython `%cd` magic) so the rest of the notebook runs from the parent.
from pathlib import Path
if Path('.').absolute().name == 'notebooks':
    %cd ..
/raid/localscratch/qfebvre/oceanbench
import yaml
# NOTE(review): `inspect` is not used in the visible cells.
import inspect
from IPython.display import Markdown, display
from omegaconf import OmegaConf
import hvplot
import hvplot.xarray
# Render hvplot figures with the matplotlib backend.
hvplot.extension('matplotlib')
import hydra
def pprint_cfg(cfg):
    """Display an OmegaConf config as a fenced YAML block in the notebook output.

    The config is converted to plain containers (interpolations are not
    resolved) and dumped with ``indent=2``, letting PyYAML choose between
    block and flow style per node (``default_flow_style=None``).
    """
    yaml_text = yaml.dump(
        OmegaConf.to_container(cfg),
        default_flow_style=None,
        indent=2,
    )
    fenced = "```yaml\n\n" + yaml_text + "\n\n```"
    display(Markdown(fenced))
def get_cfg(cfg_path, v=True):
    """Compose and return the Hydra config named *cfg_path*.

    The config directory is passed to ``hydra.initialize`` as ``'../config'``
    (how that relative path is resolved is up to Hydra).

    Args:
        cfg_path: Config name to compose, e.g. ``'task/<task>/leaderboard'``.
        v: When truthy, also render the composed config with ``pprint_cfg``.

    Returns:
        The config object returned by ``hydra.compose``.
    """
    with hydra.initialize('../config', version_base='1.3'):
        composed = hydra.compose(cfg_path)
    if v:
        pprint_cfg(composed)
    return composed
# Task and method whose leaderboard outputs are explored below.
task = 'osse_gf_nadir'
method = '4dvarnet'
# Load the task's leaderboard config without printing it, select the method
# (presumably consumed by interpolations in the config - confirm), and build
# the configured outputs through Hydra.
lb_cfg = get_cfg(f'task/{task}/leaderboard', v=False)
lb_cfg.method = method
lb = hydra.utils.call(lb_cfg.outputs)
# Build the reference-input object from the config and call it; the result
# (an xarray DataArray 'ssh', see the output below) is the cell output.
hydra.utils.call(lb_cfg.outputs.eval_ds.dataarrays.ref.inp)()
<xarray.DataArray 'ssh' (time: 365, lat: 781, lon: 1721)> [490596865 values with dtype=float64] Coordinates: * lon (lon) float64 -79.0 -78.95 -78.9 -78.85 -78.8 ... 6.85 6.9 6.95 7.0 * lat (lat) float64 26.0 26.05 26.1 26.15 26.2 ... 64.85 64.9 64.95 65.0 * time (time) datetime64[ns] 2012-10-01T12:00:00 ... 2013-09-30T12:00:00 Attributes: units: m standard_name: sea_surface_height long_name: Sea Surface Height
# Names of the map plots available on the leaderboard outputs.
list(lb.plots.maps)
['ssh_study',
'ssh_ref',
'ke_study',
'ke_ref',
'vort_r_study',
'vort_r_ref',
'strain_study',
'strain_ref']
# Dataset fed to the map plots (holds `ref` and `study` variables, per the
# displayed output below).
plot_ds = lb.plots.build_ds.maps()
def trim(plots, ncols=4):
    """Arrange map overlays in an ``ncols``-wide grid, decorating only the outer edges.

    Each item of *plots* must be an overlay of exactly two layers, unpacked as
    ``img`` and ``contour`` (presumably the image and contour layers returned by
    ``lb.plots.maps[...]`` - confirm if other plots are passed). On each image layer:

    * the colorbar is kept on the last column, and also on the very last plot
      so that a partially filled final row still shows one;
    * x ticks/label are kept only when no plot sits directly below;
    * y ticks/label are kept only on the first column.

    Args:
        plots: Non-empty sized sequence of two-layer overlays.
        ncols: Number of plots per row; must be >= 1.

    Returns:
        The ``+``-combined layout of all plots, with an empty sublabel format
        and ``vspace=0.1``, ``hspace=0.1``.

    Raises:
        ValueError: If *ncols* < 1 or *plots* is empty.
    """
    # Local imports: in this notebook ``functools``/``operator`` are only
    # imported further down, after this function is first defined.
    import functools
    import operator

    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")
    n_plots = len(plots)
    if n_plots == 0:
        raise ValueError("trim() needs at least one plot")

    trimmed_plots = []
    for i, p in enumerate(plots):
        img, contour = dict(p.items()).values()
        first_col = i % ncols == 0
        last_col = i % ncols == ncols - 1
        # The plot below index i would be i + ncols; keep x decorations when it doesn't exist.
        nothing_below = i >= n_plots - ncols
        img = img.opts(
            colorbar=last_col or i == n_plots - 1,
            xticks=None if nothing_below else False,
            xlabel=None if nothing_below else '',
            ylabel=None if first_col else '',
            yticks=None if first_col else False,
        )
        # Rebuild the overlay and re-apply the original overlay's options.
        trimmed_plots.append((img * contour).opts(**p.opts.get().kwargs))
    return (
        functools.reduce(operator.add, trimmed_plots)
        .opts(sublabel_format='')
        .opts(vspace=0.1, hspace=0.1)
    )
def plot_both(quantity, board=None, ds=None, study_title=None, ref_title='NATL60'):
    """Lay out the study map next to the reference map for *quantity*.

    Uses the ``f'{quantity}_study'`` and ``f'{quantity}_ref'`` entries of
    ``board.plots.maps``, each called on *ds*.

    Args:
        quantity: Map prefix, e.g. ``'ssh'``, ``'ke'``, ``'vort_r'``, ``'strain'``.
        board: Leaderboard outputs object; defaults to the notebook global ``lb``.
        ds: Dataset passed to the plot callables; defaults to the notebook
            global ``plot_ds``.
        study_title: Title of the study panel; defaults to the notebook global
            ``method``.
        ref_title: Title of the reference panel (``'NATL60'`` by default).

    Returns:
        The two panels combined with ``+``, with an empty sublabel format.
    """
    # Defaults are looked up at call time, matching the original behaviour;
    # pass them explicitly to avoid depending on globals that later loops rebind.
    board = lb if board is None else board
    ds = plot_ds if ds is None else ds
    study_title = method if study_title is None else study_title

    study = board.plots.maps[f'{quantity}_study'](ds).opts(title=study_title)
    ref = board.plots.maps[f'{quantity}_ref'](ds).opts(title=ref_title)
    return (study + ref).opts(sublabel_format='')
# Display the plotting dataset.
plot_ds
<xarray.Dataset> Dimensions: (time: 42, lat: 200, lon: 200) Coordinates: * time (time) datetime64[ns] 2012-10-22 2012-10-23 ... 2012-12-02 * lon (lon) float64 -65.0 -64.95 -64.9 -64.85 ... -55.15 -55.1 -55.05 * lat (lat) float64 33.0 33.05 33.1 33.15 33.2 ... 42.8 42.85 42.9 42.95 Data variables: ref (time, lat, lon) float64 nan 0.6639 0.6693 ... -0.2023 -0.2091 study (time, lat, lon) float64 nan 0.6685 0.6754 ... -0.1738 -0.1766
# Study (current `method`) vs. NATL60 reference maps for each quantity.
# NOTE(review): each pair is presumably its own notebook cell, since only a
# cell's last expression is displayed.
quantity = 'ssh'
plot_both(quantity).opts(title=f"Quantity: {quantity} Task: {task}")
quantity = 'ke'
plot_both(quantity).opts(title=f"Quantity: {quantity} Task: {task}")
quantity = 'vort_r'
plot_both(quantity).opts(title=f"Quantity: {quantity} Task: {task}")
quantity = 'strain'
plot_both(quantity).opts(title=f"Quantity: {quantity} Task: {task}")
# Methods listed in the leaderboard's map-plot results config.
list(lb_cfg.results.outputs.plots.maps.methods)
['duacs',
'bfn',
'miost',
'dymost',
'nerf_ffn',
'nerf_siren',
'nerf_mlp',
'4dvarnet']
# Build the `{quantity}_study` map of every configured method on the same task.
# The method list is evaluated once before the loop, so rebinding `lb_cfg`
# inside the loop does not change the iteration.
# NOTE: the loop rebinds the notebook globals `method`, `lb_cfg`, `lb` and
# `plot_ds`; afterwards they refer to the last method in the list.
quantity = 'vort_r'
plots=[]
for method in list(lb_cfg.results.outputs.plots.maps.methods):
    print(method)
    # Reload the leaderboard config (without printing it) for this method.
    lb_cfg = get_cfg(f'task/{task}/leaderboard', v=False)
    lb_cfg.method = method
    lb = hydra.utils.call(lb_cfg.outputs)
    plot_ds = lb.plots.build_ds.maps()
    plots.append(lb.plots.maps[f'{quantity}_study'](plot_ds).opts(title=method))
duacs
bfn
miost
dymost
nerf_ffn
nerf_siren
nerf_mlp
4dvarnet
import functools
import operator
# Combine all per-method plots with `+` into one layout, with an empty
# sublabel format.
functools.reduce(operator.add, plots).opts(sublabel_format='')
def trim(plots, ncols=4):
    """Arrange map overlays in an ``ncols``-wide grid, decorating only the outer edges.

    Each item of *plots* must be an overlay of exactly two layers, unpacked as
    ``img`` and ``contour`` (presumably the image and contour layers returned by
    ``lb.plots.maps[...]`` - confirm if other plots are passed). On each image layer:

    * the colorbar is kept on the last column, and also on the very last plot
      so that a partially filled final row still shows one;
    * x ticks/label are kept only when no plot sits directly below;
    * y ticks/label are kept only on the first column.

    Args:
        plots: Non-empty sized sequence of two-layer overlays.
        ncols: Number of plots per row; must be >= 1.

    Returns:
        The ``+``-combined layout of all plots, with an empty sublabel format
        and ``vspace=0.1``, ``hspace=0.1``.

    Raises:
        ValueError: If *ncols* < 1 or *plots* is empty.
    """
    # Local imports keep the function self-contained instead of relying on
    # notebook-global imports.
    import functools
    import operator

    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")
    n_plots = len(plots)
    if n_plots == 0:
        raise ValueError("trim() needs at least one plot")

    trimmed_plots = []
    for i, p in enumerate(plots):
        img, contour = dict(p.items()).values()
        first_col = i % ncols == 0
        last_col = i % ncols == ncols - 1
        # The plot below index i would be i + ncols; keep x decorations when it doesn't exist.
        nothing_below = i >= n_plots - ncols
        img = img.opts(
            colorbar=last_col or i == n_plots - 1,
            xticks=None if nothing_below else False,
            xlabel=None if nothing_below else '',
            ylabel=None if first_col else '',
            yticks=None if first_col else False,
        )
        # Rebuild the overlay and re-apply the original overlay's options.
        trimmed_plots.append((img * contour).opts(**p.opts.get().kwargs))
    return (
        functools.reduce(operator.add, trimmed_plots)
        .opts(sublabel_format='')
        .opts(vspace=0.1, hspace=0.1)
    )
# Grid of the per-method maps, 4 plots per row.
trim(plots, ncols=4).opts(title="Different methods reconstructions on a single task")
# Same quantity and a single method, reconstructed on several tasks.
# NOTE: the loop rebinds the notebook globals `task`, `lb_cfg`, `lb` and
# `plot_ds`; afterwards they refer to the last task in the list ('ose_gf').
quantity = 'vort_r'
method = '4dvarnet'
plots=[]
for task in [
    'osse_gf_nadir',
    'osse_gf_nadirswot',
    'osse_gf_nadir_sst',
    'ose_gf'
]:
    print(task)
    # Reload the task's leaderboard config (without printing it).
    lb_cfg = get_cfg(f'task/{task}/leaderboard', v=False)
    lb_cfg.method = method
    lb = hydra.utils.call(lb_cfg.outputs)
    plot_ds = lb.plots.build_ds.maps()
    plots.append(lb.plots.maps[f'{quantity}_study'](plot_ds).opts(title=task))
osse_gf_nadir
osse_gf_nadirswot
osse_gf_nadir_sst
ose_gf
# Grid of the per-task maps, 4 plots per row.
trim(plots, ncols=4).opts(title="4dvarnet method reconstructions on different tasks")