XArrayDataset#
import autoroot
import typing as tp
from dataclasses import dataclass
import numpy as np
import einops
import xarray_dataclasses as xrdataclass
from oceanbench._src.datasets.base import XRDAPatcher
%load_ext autoreload
%autoreload 2
This tutorial walks through some of the nice features of the custom XRDAPatcher class.
This class slices through an xr.DataArray, where the user explicitly specifies the patch dimensions and the strides.
The slices are preallocated, so we can then fetch any slice at will.
This is very similar to a torch.utils.data Dataset, except that we work with xr.DataArray objects directly.
There have been previous attempts at this, e.g. xBatcher.
However, we found that API cumbersome and non-intuitive.
This is our attempt to design an API that we are comfortable with and find easy to use.
Below, we have outlined a few use-cases that users may be interested in. These use cases are:
Chunking a 1-Dimensional Time Series
Patch-ify a 2D Grid
Cube-ify a 3D Volume
Cube-ify a 2D+T Spatio-Temporal Field
Reconstructing Multiple Variables
Choosing Specific Dimensions for Reconstructions
We will walk through each of these and highlight how this can be achieved with the custom XRDAPatcher
class.
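Every example below follows the same basic pattern, sketched here with placeholder values (each piece appears in the cases that follow; this is an overview, not a new API):
# given some xr.DataArray `da`
patcher = XRDAPatcher(
    da=da,
    patches={"time": 30},    # size of each patch along each dimension
    strides={"time": 30},    # step between consecutive patches
    check_full_scan=True,    # optionally check that the patches scan the full array
)
len(patcher)                                  # number of precomputed slices
items = [item.data for item in patcher]       # iterate over patches, torch.utils.data style
rec = patcher.reconstruct(items, dims_labels=["time"], weight=np.ones(30))   # stitch patches back together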
Case I: Chunking a 1D Time Series#
TIME = tp.Literal["time"]
@dataclass
class TimeAxis:
data: xrdataclass.Data[TIME, tp.Literal["datetime64[ns]"]]
name: xrdataclass.Name[str] = "time"
long_name: xrdataclass.Attr[str] = "Date"
@dataclass
class Variable1D:
data: xrdataclass.Data[tuple[TIME], np.float32]
time: xrdataclass.Coordof[TimeAxis] = 0
name: xrdataclass.Attr[str] = "var"
t = np.arange(1, 360+1, 1)
rng = np.random.RandomState(seed=123)
ts = np.sin(t)
ts = Variable1D(data=ts, time=t, name="var")
da = xrdataclass.asdataarray(ts)
da
# da.plot()
In this first example, we will do non-overlapping chunking: a 30-day window with a 30-day stride. This gives us exactly 12 patches (roughly 12 months).
patches = {"time": 30}
strides = {"time": 30}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
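Since the slices are precomputed, individual patches can be pulled out at will. Below is a small sketch, assuming the patcher supports integer indexing in the torch.utils.data style mentioned above (the later cases iterate over it directly):
item = xrda_batches[0]   # the first 30-day patch, as an xr.DataArray
print(item.sizes)        # expect a single "time" dimension of size 30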
In this example, we will incorporate overlapping windows: a 30-day window with a 15-day stride, i.e. a 15-day overlap between consecutive patches. We can already do the mental calculation because it is quite simple:
If the dimensions divide nicely, we won't have any problems.
However, oftentimes they don't, so we have to use the floor operator to ensure we get integers.
In that case our method can give an (optional) warning that lets the user know there is an issue.
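As a worked example of that calculation (plain arithmetic, not a call into the library):
# patches along one dimension: floor((N - patch) / stride) + 1
N, patch, stride = 360, 30, 15
print((N - patch) // stride + 1)   # 23 overlapping 30-day windows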
patches = {"time": 30}
strides = {"time": 15}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
Case II: Patchify a 2D Grid#
TIME = tp.Literal["time"]
X = tp.Literal["x"]
Y = tp.Literal["y"]
@dataclass
class TimeAxis:
data: xrdataclass.Data[TIME, tp.Literal["datetime64[ns]"]]
name: xrdataclass.Name[str] = "time"
long_name: xrdataclass.Attr[str] = "Date"
@dataclass
class XAxis:
data: xrdataclass.Data[X, np.float32]
name: xrdataclass.Name[str] = "x"
@dataclass
class YAxis:
data: xrdataclass.Data[Y, np.float32]
name: xrdataclass.Name[str] = "y"
@dataclass
class Variable2D:
data: xrdataclass.Data[tuple[X, Y], np.float32]
x: xrdataclass.Coordof[XAxis] = 0
y: xrdataclass.Coordof[YAxis] = 0
name: xrdataclass.Attr[str] = "var"
x = np.linspace(-1, 1, 128)
y = np.linspace(-2, 2, 128)
rng = np.random.RandomState(seed=123)
data = rng.randn(x.shape[0], y.shape[0])
grid = Variable2D(data=data, x=x, y=y, name="var")
da = xrdataclass.asdataarray(grid)
da
# da.T.plot.imshow()
We will use an [8,8] patch with no overlap, i.e. [8,8] strides.
patches = {"x": 8, "y": 8}
strides = {"x": 8, "y": 8}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
We will use an [8,8] patch with overlap at the boundaries, i.e. [2,2] strides.
patches = {"x": 8, "y": 8}
strides = {"x": 2, "y": 2}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
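For reference, the sizes printed above can be checked by hand with the same sliding-window count (plain arithmetic, not part of the API):
# 128 points per dimension with 8x8 patches
print(((128 - 8) // 8 + 1) ** 2)   # 256 patches with [8,8] strides
print(((128 - 8) // 2 + 1) ** 2)   # 3721 patches with [2,2] strides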
Case III: Cube-ify a 3D Volume#
TIME = tp.Literal["time"]
X = tp.Literal["x"]
Y = tp.Literal["y"]
Z = tp.Literal["z"]
@dataclass
class TimeAxis:
data: xrdataclass.Data[TIME, tp.Literal["datetime64[ns]"]]
name: xrdataclass.Name[str] = "time"
long_name: xrdataclass.Attr[str] = "Date"
@dataclass
class XAxis:
data: xrdataclass.Data[X, np.float32]
name: xrdataclass.Name[str] = "x"
@dataclass
class YAxis:
data: xrdataclass.Data[Y, np.float32]
name: xrdataclass.Name[str] = "y"
@dataclass
class ZAxis:
data: xrdataclass.Data[Z, np.float32]
name: xrdataclass.Name[str] = "z"
@dataclass
class Variable3D:
data: xrdataclass.Data[tuple[X, Y, Z], np.float32]
x: xrdataclass.Coordof[XAxis] = 0
y: xrdataclass.Coordof[YAxis] = 0
z: xrdataclass.Coordof[ZAxis] = 0
name: xrdataclass.Attr[str] = "var"
x = np.linspace(-1, 1, 128)
y = np.linspace(-2, 2, 128)
z = np.linspace(-5, 5, 128)
rng = np.random.RandomState(seed=123)
data = rng.randn(x.shape[0], y.shape[0], z.shape[0])
grid = Variable3D(data=data, x=x, y=y, z=z, name="var")
da = xrdataclass.asdataarray(grid)
da
We will use an [8,8,8] cube with no overlap, i.e. [8,8,8] strides.
patches = {"x": 8, "y": 8, "z": 8}
strides = {"x": 8, "y": 8, "z": 8}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
We will use an [8,8,8] cube with overlap at the boundaries, i.e. [2,2,2] strides.
patches = {"x": 8, "y": 8, "z": 8}
strides = {"x": 2, "y": 2, "z": 2}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
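In 3D the same per-dimension count is cubed, so the number of patches grows quickly (again just hand arithmetic):
print(((128 - 8) // 8 + 1) ** 3)   # 4096 cubes with [8,8,8] strides
print(((128 - 8) // 2 + 1) ** 3)   # 226981 cubes with [2,2,2] strides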
Case IV: Cube-ify a 2D+T Spatio-Temporal Field#
TIME = tp.Literal["time"]
X = tp.Literal["x"]
Y = tp.Literal["y"]
Z = tp.Literal["z"]
@dataclass
class TimeAxis:
data: xrdataclass.Data[TIME, tp.Literal["datetime64[ns]"]]
name: xrdataclass.Name[str] = "time"
long_name: xrdataclass.Attr[str] = "Date"
@dataclass
class XAxis:
data: xrdataclass.Data[X, np.float32]
name: xrdataclass.Name[str] = "x"
@dataclass
class YAxis:
data: xrdataclass.Data[Y, np.float32]
name: xrdataclass.Name[str] = "y"
@dataclass
class ZAxis:
data: xrdataclass.Data[Z, np.float32]
name: xrdataclass.Name[str] = "z"
@dataclass
class Variable2DT:
data: xrdataclass.Data[tuple[TIME, X, Y], np.float32]
x: xrdataclass.Coordof[XAxis] = 0
y: xrdataclass.Coordof[YAxis] = 0
time: xrdataclass.Coordof[TimeAxis] = 0
name: xrdataclass.Attr[str] = "var"
x = np.linspace(-1, 1, 200)
y = np.linspace(-2, 2, 200)
t = np.arange(1, 360+1, 1)
rng = np.random.RandomState(seed=123)
data = rng.randn(t.shape[0], x.shape[0], y.shape[0])
grid = Variable2DT(data=data, x=x, y=y, time=t, name="var")
da = xrdataclass.asdataarray(grid)
da
Now, this is a rather big field.
Let's say we want to use some ML method with a CNN to learn how to predict …
However, ingesting the full field at once would be very difficult.
So instead, we will use a patch size that is standard for many CNNs, [64,64].
In addition, we will use a temporal window of 15 days, so the patch will be [15,64,64].
As in the above examples, we will also allow overlap at the spatial borders by using [4,4] strides.
And lastly, we will use a 5-day stride in time, so the final strides will be [5,4,4].
patches = {"x": 64, "y": 64, "time": 15}
strides = {"x": 4, "y": 4, "time": 5}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
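The printed size can be checked by hand, one dimension at a time:
n_time = (360 - 15) // 5 + 1    # 70
n_space = (200 - 64) // 4 + 1   # 35 along x and along y
print(n_time * n_space * n_space)   # 85750 patches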
All of a sudden we have a LOT of data if we do things in a patch-wise manner: more than 85K samples! However, we know from statistics that this may not be the greatest idea because there is a lot of overlap between samples. So we can be clever and use a training dataset with less overlap, and then create a different dataset for predictions where we reduce the strides considerably, so that we can take a weighted average over the overlapping predictions.
patches = {"x": 64, "y": 64, "time": 15}
strides = {"x": 1, "y": 1, "time": 1}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
So this covers our use case, because we can take a weighted average of all of the overlapping predictions!
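As a sketch of what that could look like, following the reconstruct calls used in Cases V and VI below (the model and its predictions here are hypothetical, and with unit strides the list of patches is very large, so this is illustrative rather than something to run as-is):
# hypothetical: one prediction of shape (15, 64, 64) per patch
# predictions = [model(item.data) for item in xrda_batches]
dims_labels = ["time", "x", "y"]
weight = np.ones((15, 64, 64))   # uniform weights; a tapered weight would down-weight patch borders
# rec_da = xrda_batches.reconstruct(predictions, dims_labels=dims_labels, weight=weight)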
Case V: Reconstructing Multiple Variables#
In this example, we look at how we can do reconstructions with multiple variables. This may occur when we have used different methods to make predictions and we want to reconstruct all of them.
Another example is when we have some sort of latent variable representation and we would like to reconstruct each of the latent variables.
t = np.arange(1, 360+1, 1)
rng = np.random.RandomState(seed=123)
ts = np.sin(t)
ts = Variable1D(data=ts, time=t, name="var")
da = xrdataclass.asdataarray(ts)
patches = {"time": 30}
strides = {"time": 30}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
# extract each patch as a plain array
all_batches = list(map(lambda x: x.data, xrda_batches))
# mimic a 5-dimensional latent representation by repeating each patch along a new trailing axis
all_batches_latent = list(map(lambda x: einops.repeat(x, "... -> ... N", N=5), all_batches))
# label the axes of each batch: the original "time" axis plus the new latent axis "z"
dims_labels = ["time", "z"]
# uniform weight over the 30-day patch
weight = np.ones((patches["time"]))
rec_da = xrda_batches.reconstruct(all_batches_latent, dims_labels=dims_labels, weight=weight)
rec_da
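Assuming the reconstructed output is an xr.DataArray labelled with the dimension names we passed in, we can quickly check that the latent dimension survived:
print(rec_da.dims)    # expected: ("time", "z")
print(rec_da.sizes)   # expected: time=360, z=5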
Case VI: Choosing a Specific Dimension for Reconstruction#
x = np.linspace(-1, 1, 50)
y = np.linspace(-2, 2, 50)
t = np.arange(1, 30+1, 1)
rng = np.random.RandomState(seed=123)
data = rng.randn(t.shape[0], x.shape[0], y.shape[0])
grid = Variable2DT(data=data, x=x, y=y, time=t, name="var")
da = xrdataclass.asdataarray(grid)
da
This time the field is much smaller: 30 time steps on a 50×50 grid.
We will take [10,10] spatial patches with a 5-day temporal window, so each patch is [5,10,10].
For the strides we use [1,8,8]: a 1-day step in time and an 8-point step in each spatial direction.
patches = {"x": 10, "y": 10, "time": 5}
strides = {"x": 8, "y": 8, "time": 1}
domain_limits = None  # {"lat": slice(-10, 10)}
check_full_scan = True
xrda_batches = XRDAPatcher(
da=da,
patches=patches,
strides=strides,
check_full_scan=check_full_scan
)
print(xrda_batches)
print(f"Dataset(size): {len(xrda_batches)}")
Here, we reconstruct just the time series. In this case, we take the mean over the spatial dimensions x,y so that only the time series remains.
all_batches = list(map(lambda x: x.mean(dim=["x", "y"]).data, xrda_batches))
But we still want to reconstruct! We can pass these through the reconstruction, paying careful attention to the dimensions we wish to reconstruct: dims_labels names the remaining axes of each batch, and the weight array matches the patch size along those dimensions.
dims_labels = ["time"]
weight = np.ones((patches["time"]))
rec_da = xrda_batches.reconstruct(all_batches, dims_labels=dims_labels, weight=weight)
rec_da
Here, we reconstruct just the x,y patches. In this case, we take the mean over the temporal dimension so that only the spatial patches remain.
all_batches = list(map(lambda x: x.mean(dim=["time"]).data, xrda_batches))
Again, we still want to reconstruct! So, like above, we can simply pass the correct dimensions to the reconstruction.
dims_labels = ["x", "y"]
weight = np.ones((patches["x"], patches["y"]))
rec_da = xrda_batches.reconstruct(all_batches, dims_labels=dims_labels, weight=weight)
rec_da