Plotting maps with OceanBench

Plotting maps with OceanBench

# Notebook setup: when the kernel starts inside the `notebooks/` directory,
# move one level up (presumably the repository root) so the relative paths
# used later in the notebook resolve from there. `%cd` is an IPython magic
# and echoes the new working directory.
from pathlib import Path
if Path('.').absolute().name == 'notebooks':
    %cd ..
/raid/localscratch/qfebvre/oceanbench
import yaml
import inspect
from IPython.display import Markdown, display
from omegaconf import OmegaConf
import hvplot
import hvplot.xarray
hvplot.extension('matplotlib')
import hydra

def pprint_cfg(cfg):
    """Display ``cfg`` in the notebook as a fenced YAML Markdown block.

    The config is first converted to plain Python containers with
    ``OmegaConf.to_container`` and then serialised with ``yaml.dump``
    (block/flow style chosen automatically, 2-space indent).
    """
    as_yaml = yaml.dump(
        OmegaConf.to_container(cfg),
        default_flow_style=None,
        indent=2,
    )
    fenced = "```yaml\n\n" + as_yaml + "\n\n```"
    display(Markdown(fenced))

def get_cfg(cfg_path, v=True):
    """Compose and return the Hydra config named ``cfg_path``.

    Hydra is initialised with ``config_path='../config'`` only for the
    duration of the composition.

    Args:
        cfg_path: config name handed to ``hydra.compose``.
        v: when truthy, pretty-print the composed config with ``pprint_cfg``.

    Returns:
        The composed config object.
    """
    with hydra.initialize('../config', version_base='1.3'):
        composed = hydra.compose(cfg_path)
    if v:
        pprint_cfg(composed)
    return composed
# Task / method pair whose leaderboard outputs are explored below.
task, method = 'osse_gf_nadir', '4dvarnet'

# Instantiate the leaderboard outputs for that pair (config printing disabled).
lb_cfg = get_cfg(cfg_path=f'task/{task}/leaderboard', v=False)
lb_cfg.method = method
lb = hydra.utils.call(lb_cfg.outputs)
# Build the reference-data loader from the config, call it, and display the result.
hydra.utils.call(lb_cfg.outputs.eval_ds.dataarrays.ref.inp)()
<xarray.DataArray 'ssh' (time: 365, lat: 781, lon: 1721)>
[490596865 values with dtype=float64]
Coordinates:
  * lon      (lon) float64 -79.0 -78.95 -78.9 -78.85 -78.8 ... 6.85 6.9 6.95 7.0
  * lat      (lat) float64 26.0 26.05 26.1 26.15 26.2 ... 64.85 64.9 64.95 65.0
  * time     (time) datetime64[ns] 2012-10-01T12:00:00 ... 2013-09-30T12:00:00
Attributes:
    units:          m
    standard_name:  sea_surface_height
    long_name:      Sea Surface Height
# Names of the available map plots (study / reference variants per quantity).
[*lb.plots.maps]
['ssh_study',
 'ssh_ref',
 'ke_study',
 'ke_ref',
 'vort_r_study',
 'vort_r_ref',
 'strain_study',
 'strain_ref']
# Build the dataset consumed by the map plots (it carries `ref` and `study`
# variables, as displayed further down).
plot_ds = lb.plots.build_ds.maps()

    
def trim(plots, ncols=4):
    """Lay out image*contour overlays in a grid with de-duplicated decorations.

    Each element of ``plots`` must be an overlay of exactly two items (an
    image followed by its contour), as unpacked below. In the resulting grid:

    * a colorbar is kept on the last plot of each row (including a partial
      final row);
    * x ticks/labels are kept only on plots with no plot below them;
    * y ticks/labels are kept only on the first column.

    Args:
        plots: non-empty sequence of two-item overlays.
        ncols: number of grid columns; it is also applied to the layout so the
            trimming logic and the actual wrapping agree.

    Returns:
        The combined layout with sub-labels removed and tight spacing.

    Raises:
        ValueError: if ``plots`` is empty.
    """
    # Local imports: this cell may run before the notebook imports these.
    import functools
    import operator

    if not plots:
        raise ValueError("trim() needs at least one plot")

    n_plots = len(plots)
    trimmed_plots = []
    for i, p in enumerate(plots):
        img, contour = dict(p.items()).values()

        first_col = i % ncols == 0
        # The last plot of every row, including a partial final row, keeps
        # the colorbar.
        row_end = i % ncols == ncols - 1 or i == n_plots - 1
        # No plot sits `ncols` positions later, so this one is the bottom of
        # its column.
        column_bottom = i >= n_plots - ncols

        img = img.opts(
            colorbar=row_end,
            xticks=None if column_bottom else False,
            xlabel=None if column_bottom else '',
            ylabel=None if first_col else '',
            yticks=None if first_col else False,
        )
        # Carry the original overlay's options over to the rebuilt overlay.
        trimmed_plots.append((img * contour).opts(**p.opts.get().kwargs))

    layout = functools.reduce(operator.add, trimmed_plots)
    if n_plots > 1:
        # `+` only builds a Layout from two or more elements; make it wrap
        # after `ncols` panels instead of the library default.
        layout = layout.cols(ncols)
    return layout.opts(sublabel_format='').opts(vspace=0.1, hspace=0.1)

def plot_both(quantity):
    """Return the study and reference maps of ``quantity`` combined in one layout.

    The study panel is titled with the notebook global ``method`` and the
    reference panel with ``'NATL60'``. Both are rendered from the notebook
    globals ``lb`` and ``plot_ds``. Panel sub-labels are disabled.
    """
    study_panel = lb.plots.maps[f'{quantity}_study'](plot_ds).opts(title=method)
    reference_panel = lb.plots.maps[f'{quantity}_ref'](plot_ds).opts(title='NATL60')
    combined = study_panel + reference_panel
    return combined.opts(sublabel_format='')
    
# Display the dataset backing the maps (`ref` and `study` variables).
plot_ds
<xarray.Dataset>
Dimensions:  (time: 42, lat: 200, lon: 200)
Coordinates:
  * time     (time) datetime64[ns] 2012-10-22 2012-10-23 ... 2012-12-02
  * lon      (lon) float64 -65.0 -64.95 -64.9 -64.85 ... -55.15 -55.1 -55.05
  * lat      (lat) float64 33.0 33.05 33.1 33.15 33.2 ... 42.8 42.85 42.9 42.95
Data variables:
    ref      (time, lat, lon) float64 nan 0.6639 0.6693 ... -0.2023 -0.2091
    study    (time, lat, lon) float64 nan 0.6685 0.6754 ... -0.1738 -0.1766
# For each derived quantity, show the study map next to the NATL60 reference.
# `quantity` is rebound before every call; each `plot_both(...)` line is a
# bare expression so the notebook displays it when it ends a cell.
quantity = 'ssh'
plot_both(quantity).opts(title=f"Quantity: {quantity}   Task: {task}")
quantity = 'ke'
plot_both(quantity).opts(title=f"Quantity: {quantity}   Task: {task}")
quantity = 'vort_r'
plot_both(quantity).opts(title=f"Quantity: {quantity}   Task: {task}")
quantity = 'strain'
plot_both(quantity).opts(title=f"Quantity: {quantity}   Task: {task}")
# Methods whose maps are registered in this task's leaderboard config.
[*lb_cfg.results.outputs.plots.maps.methods]
['duacs',
 'bfn',
 'miost',
 'dymost',
 'nerf_ffn',
 'nerf_siren',
 'nerf_mlp',
 '4dvarnet']
# Collect the `quantity` study map of every method listed for the current task.
quantity, plots = 'vort_r', []
for method in [*lb_cfg.results.outputs.plots.maps.methods]:
    print(method)
    # Recompose a fresh config for this method. Note that this rebinds the
    # notebook-level `lb_cfg`, `lb` and `plot_ds`; the loop iterable was
    # already evaluated, so rebinding `lb_cfg` does not affect iteration.
    lb_cfg = get_cfg(cfg_path=f'task/{task}/leaderboard', v=False)
    lb_cfg.method = method
    lb = hydra.utils.call(lb_cfg.outputs)
    plot_ds = lb.plots.build_ds.maps()
    plots.append(
        lb.plots.maps[f'{quantity}_study'](plot_ds).opts(title=method)
    )
duacs
bfn
miost
dymost
nerf_ffn
nerf_siren
nerf_mlp
4dvarnet
import functools
import operator
# Combine the per-method maps into one layout (left fold with `+`) and drop
# the automatic panel sub-labels.
functools.reduce(operator.add, plots).opts(sublabel_format='')
    
def trim(plots, ncols=4):
    """Lay out image*contour overlays in a grid with de-duplicated decorations.

    Each element of ``plots`` must be an overlay of exactly two items (an
    image followed by its contour), as unpacked below. In the resulting grid:

    * a colorbar is kept on the last plot of each row (including a partial
      final row);
    * x ticks/labels are kept only on plots with no plot below them;
    * y ticks/labels are kept only on the first column.

    Args:
        plots: non-empty sequence of two-item overlays.
        ncols: number of grid columns; it is also applied to the layout so the
            trimming logic and the actual wrapping agree.

    Returns:
        The combined layout with sub-labels removed and tight spacing.

    Raises:
        ValueError: if ``plots`` is empty.
    """
    # Local imports keep this definition self-contained regardless of cell order.
    import functools
    import operator

    if not plots:
        raise ValueError("trim() needs at least one plot")

    n_plots = len(plots)
    trimmed_plots = []
    for i, p in enumerate(plots):
        img, contour = dict(p.items()).values()

        first_col = i % ncols == 0
        # The last plot of every row, including a partial final row, keeps
        # the colorbar.
        row_end = i % ncols == ncols - 1 or i == n_plots - 1
        # No plot sits `ncols` positions later, so this one is the bottom of
        # its column.
        column_bottom = i >= n_plots - ncols

        img = img.opts(
            colorbar=row_end,
            xticks=None if column_bottom else False,
            xlabel=None if column_bottom else '',
            ylabel=None if first_col else '',
            yticks=None if first_col else False,
        )
        # Carry the original overlay's options over to the rebuilt overlay.
        trimmed_plots.append((img * contour).opts(**p.opts.get().kwargs))

    layout = functools.reduce(operator.add, trimmed_plots)
    if n_plots > 1:
        # `+` only builds a Layout from two or more elements; make it wrap
        # after `ncols` panels instead of the library default.
        layout = layout.cols(ncols)
    return layout.opts(sublabel_format='').opts(vspace=0.1, hspace=0.1)
# Grid of the per-method maps, with redundant ticks/labels/colorbars removed.
trim(plots, ncols=4).opts(title="Different methods reconstructions on a single task")
# Collect the `quantity` study map of a single method across several tasks.
quantity, method = 'vort_r', '4dvarnet'
plots = []
for task in ('osse_gf_nadir', 'osse_gf_nadirswot', 'osse_gf_nadir_sst', 'ose_gf'):
    print(task)
    # Recompose the leaderboard config for this task. This rebinds the
    # notebook-level `task`, `lb_cfg`, `lb` and `plot_ds`.
    lb_cfg = get_cfg(cfg_path=f'task/{task}/leaderboard', v=False)
    lb_cfg.method = method
    lb = hydra.utils.call(lb_cfg.outputs)
    plot_ds = lb.plots.build_ds.maps()
    plots.append(
        lb.plots.maps[f'{quantity}_study'](plot_ds).opts(title=task)
    )
osse_gf_nadir
osse_gf_nadirswot
osse_gf_nadir_sst
ose_gf
# Grid of the per-task maps for one method, with redundant decorations removed.
trim(plots, ncols=4).opts(title="4dvarnet method reconstructions on different tasks")