Generating Trajectories and Missing Data

How to generate Trajectories and create missing data

CNRS
MEOM
import os, sys

jaxsw_path = "/Users/eman/code_projects/jaxsw"
sys.path.append(jaxsw_path)
import autoroot  # noqa: F401, I001
import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

from jaxsw._src.models.lorenz63 import L63Params, L63State, Lorenz63, rhs_lorenz_63

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

Lorenz 63

  • Equation of Motion
  • Observation Operator
  • Integrate

Equation of Motion

dxdt=σ(yx)dydt=x(ρz)ydzdt=xyβz\begin{aligned} \frac{dx}{dt} &= \sigma (y - x) \\ \frac{dy}{dt} &= x (\rho - z) - y \\ \frac{dz}{dt} &= xy - \beta z \end{aligned}

where (σ,ρ,β)(\sigma,\rho,\beta) are hyperparameters.

Simulation

batchsize = 1_000
sigma, rho, beta = 10, 28, 2.667
key = jrandom.PRNGKey(42)
params = L63Params(sigma=10, rho=28, beta=2.667)
state_init = L63State.init_state(
    batchsize=batchsize, noise=0.01, key=key
)

l63_model = Lorenz63()

state_init.x.shape
(1000, 1)

Time Stepping

Warmup

We need to do a warmup simulation to get the dynamical system to a steady state.

dt = 0.005
t0 = 0.0
t1 = 10.0
dt_save = 0.01
# observe_every = 10
ts = jnp.arange(t0, t1, dt_save)
num_tsteps = len(ts)

saveat = dfx.SaveAt(
    ts=ts,
)

# Euler, Constant StepSize
solver = dfx.Tsit5()
stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

integration_fn = lambda state: dfx.diffeqsolve(
    terms=dfx.ODETerm(l63_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=state,
    saveat=saveat,
    args=params,
    stepsize_controller=stepsize_controller,
)

# integration
sol = jax.vmap(integration_fn)(state_init)

sol.ys.x.shape
(1000, 1000, 1)

Simulation

from tqdm.autonotebook import trange
for i in trange(10):
    # integration
    state_init = eqx.tree_at(lambda x: x.x, state_init, sol.ys.x[:, -1])
    state_init = eqx.tree_at(lambda x: x.y, state_init, sol.ys.y[:, -1])
    state_init = eqx.tree_at(lambda x: x.z, state_init, sol.ys.z[:, -1])

    sol = jax.vmap(integration_fn)(state_init)
Loading...
sol.ts[0, :200].shape, sol.ys.x[:200].shape
((200,), (200, 1000, 1))
ds_sol = xr.Dataset(
    {
        "x": (("realization", "time"), sol.ys.x[:, :200].squeeze()),
        "y": (("realization", "time"), sol.ys.y[:, :200].squeeze()),
        "z": (("realization", "time"), sol.ys.z[:, :200].squeeze()),
    },
    coords={
        "time": (["time"], sol.ts[0, :200].squeeze()),
    },
    attrs={
        "ode": "lorenz_63",
        "sigma": params.sigma,
        "beta": params.beta,
        "rho": params.rho,
    },
)

ds_sol
Loading...
fig, ax = plt.subplots(nrows=3, figsize=(4, 8))

for i in range(3):
    ds_sol.x.sel(
        realization=i,
    ).plot(ax=ax[i])
    ds_sol.y.sel(
        realization=i,
    ).plot(ax=ax[i])
    ds_sol.z.sel(
        realization=i,
    ).plot(ax=ax[i])

    ax[i].set_xlabel("Time")
    ax[i].set_ylabel("Values")
    ax[i].set_title(f"Trajectory: {i}")


plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 400x800 with 3 Axes>
fig, ax = plt.subplots(ncols=3, subplot_kw={"projection": "3d"})
for i in range(3):
    ax[i].plot(
        ds_sol.x.sel(
            realization=i,
        ),
        ds_sol.y.sel(
            realization=i,
        ),
        ds_sol.z.sel(
            realization=i,
        ),
        lw=1.0,
        color="blue",
        label="Trajectory",
    )
    ax[i].set_xlabel("X")
    ax[i].set_ylabel("Y")
    ax[i].set_zlabel("Z")
# ax.scatter3D(out.x[::10], out.y[::10], out.z[::10], lw=0.5, color="red", label="y")


plt.legend()
plt.tight_layout()
plt.show()
<Figure size 640x480 with 3 Axes>

Generating Observations

ds_sol = ds_sol.to_array(dim="component", name="simulation").to_dataset()
ds_sol
Loading...

Missing Time

sample_step = 20
ds_sol["observations"] = xr.full_like(ds_sol.simulation, np.nan)
ds_sol["observations"].loc[..., ::sample_step] = ds_sol["simulation"].loc[
    ..., ::sample_step
]
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 7))

for i in range(0, 3):
    ds_sol.simulation.sel(realization=i + 100, component="x").plot(
        ax=ax[i, 0], color="black"
    )
    ds_sol.observations.sel(realization=i + 100, component="x").plot.scatter(
        ax=ax[i, 0], color="red"
    )

    ds_sol.simulation.sel(realization=i + 100, component="y").plot(
        ax=ax[i, 1], color="black"
    )
    ds_sol.observations.sel(realization=i + 100, component="y").plot.scatter(
        ax=ax[i, 1], color="red"
    )

    ds_sol.simulation.sel(realization=i + 100, component="z").plot(
        ax=ax[i, 2], color="black"
    )
    ds_sol.observations.sel(realization=i + 100, component="z").plot.scatter(
        ax=ax[i, 2], color="red"
    )


plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 1200x700 with 9 Axes>
fig, ax = plt.subplots(ncols=3, subplot_kw={"projection": "3d"})
for i in range(3):
    ax[i].plot(
        ds_sol.simulation.sel(realization=i + 100, component="x"),
        ds_sol.simulation.sel(realization=i + 100, component="y"),
        ds_sol.simulation.sel(realization=i + 100, component="z"),
        lw=1.0,
        color="blue",
        label="Trajectory",
    )
    ax[i].scatter(
        ds_sol.observations.sel(realization=i + 100, component="x"),
        ds_sol.observations.sel(realization=i + 100, component="y"),
        ds_sol.observations.sel(realization=i + 100, component="z"),
        lw=1.0,
        color="red",
        label="Observations",
    )
    ax[i].set_xlabel("X")
    ax[i].set_ylabel("Y")
    ax[i].set_zlabel("Z")


plt.legend()
plt.tight_layout()
plt.show()
<Figure size 640x480 with 3 Axes>

Adding Noise

# def add_noise(da, sigma=2**.5):
#     return da  + np.random.randn(*da.shape) * sigma

sigma = 2.0

ds_sol["observations_noise"] = ds_sol["observations"] + sigma * np.random.randn(
    *ds_sol["observations"].shape
)
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 7))

for i in range(3):
    ds_sol.simulation.sel(realization=i, component="x").plot(ax=ax[i, 0], color="black")
    ds_sol.observations_noise.sel(realization=i, component="x").plot.scatter(
        ax=ax[i, 0], color="red"
    )

    ds_sol.simulation.sel(realization=i, component="y").plot(ax=ax[i, 1], color="black")
    ds_sol.observations_noise.sel(realization=i, component="y").plot.scatter(
        ax=ax[i, 1], color="red"
    )

    ds_sol.simulation.sel(realization=i, component="z").plot(ax=ax[i, 2], color="black")
    ds_sol.observations_noise.sel(realization=i, component="z").plot.scatter(
        ax=ax[i, 2], color="red"
    )


plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 1200x700 with 9 Axes>
fig, ax = plt.subplots(ncols=3, subplot_kw={"projection": "3d"})
for i in range(3):
    ax[i].plot(
        ds_sol.simulation.sel(realization=i, component="x"),
        ds_sol.simulation.sel(realization=i, component="y"),
        ds_sol.simulation.sel(realization=i, component="z"),
        lw=1.0,
        color="blue",
        label="Trajectory",
    )
    ax[i].scatter(
        ds_sol.observations_noise.sel(realization=i, component="x"),
        ds_sol.observations_noise.sel(realization=i, component="y"),
        ds_sol.observations_noise.sel(realization=i, component="z"),
        lw=1.0,
        color="red",
        label="Observations",
    )
    ax[i].set_xlabel("X")
    ax[i].set_ylabel("Y")
    ax[i].set_zlabel("Z")


plt.legend()
plt.tight_layout()
plt.show()
<Figure size 640x480 with 3 Axes>

Initial Interpolation

import scipy
from tqdm.autonotebook import trange

variable_name = "observations_interp"

ds_sol[variable_name] = xr.full_like(ds_sol.observations_noise, np.nan)


for i in trange(ds_sol.simulation.shape[1]):
    dataarray = ds_sol["observations_noise"].sel(realization=i)

    # interpolate nans
    dataarray = dataarray.interpolate_na(
        dim="time", method="slinear", fill_value="extrapolate"
    )

    # set
    ds_sol[variable_name].loc[:, i, :] = dataarray
Loading...
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 7))

for i in range(3):
    ds_sol.simulation.sel(realization=i, component="x").plot(ax=ax[i, 0], color="black")
    ds_sol.observations_interp.sel(realization=i, component="x").plot(
        ax=ax[i, 0], color="red", linestyle="--"
    )

    ds_sol.simulation.sel(realization=i, component="y").plot(ax=ax[i, 1], color="black")
    ds_sol.observations_interp.sel(realization=i, component="y").plot(
        ax=ax[i, 1], color="red", linestyle="--"
    )

    ds_sol.simulation.sel(realization=i, component="z").plot(ax=ax[i, 2], color="black")
    ds_sol.observations_interp.sel(realization=i, component="z").plot(
        ax=ax[i, 2], color="red", linestyle="--"
    )


plt.legend()
plt.tight_layout()
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
<Figure size 1200x700 with 9 Axes>
fig, ax = plt.subplots(ncols=3, subplot_kw={"projection": "3d"})
for i in range(3):
    ax[i].plot(
        ds_sol.simulation.sel(realization=i, component="x"),
        ds_sol.simulation.sel(realization=i, component="y"),
        ds_sol.simulation.sel(realization=i, component="z"),
        lw=1.0,
        color="blue",
        label="Trajectory",
    )
    ax[i].plot(
        ds_sol.observations_interp.sel(realization=i, component="x"),
        ds_sol.observations_interp.sel(realization=i, component="y"),
        ds_sol.observations_interp.sel(realization=i, component="z"),
        lw=1.0,
        linestyle="--",
        color="red",
        label="Observations",
    )
    ax[i].set_xlabel("X")
    ax[i].set_ylabel("Y")
    ax[i].set_zlabel("Z")


plt.legend()
plt.tight_layout()
plt.show()
<Figure size 640x480 with 3 Axes>

Save Simulation

!mkdir ./data
ds_sol.to_netcdf("./data/sim_l63.nc")
mkdir: ./data: File exists