import autoroot # noqa: F401, I001
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import diffrax as dfx
import xarray as xr
import numpy as np
from jaxsw import L96State, Lorenz96, rhs_lorenz_96
# Reset seaborn to its defaults, then use the "talk" context with fonts
# scaled down to 70%.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# IPython magics (not plain Python): render figures inline and auto-reload
# imported modules (e.g. jaxsw) before executing each cell.
%matplotlib inline
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
Lorenz 96
- Equation of Motion
- Observation Operator
- Integrate
# Forcing constant, number of state variables and initial-noise amplitude.
F, ndim, noise = 8, 50, 0.01

# Build an initial state and evaluate the Lorenz-96 tendency on its values.
state = L96State.init_state(ndim=ndim, noise=noise)
x = state.x
x_dot = state_dot = rhs_lorenz_96(x=x, F=F)

# The tendency must have the same shape as the state it was computed from.
assert x.shape == x_dot.shape
Model
# Integration window.
t0, t1 = 0.0, 30.0

# Initial state together with the parameter container (it carries `F`).
state_init, params = L96State.init_state_and_params(ndim=ndim, F=F)

# Model instance, constructed with advection=True.
advection = True
l96_model = Lorenz96(advection=advection)

# One right-hand-side evaluation at t=0 to check the output shape.
state_dot = l96_model.equation_of_motion(t=0, state=state_init, args=params)
state_dot.x.shape
(50,)
Time Stepping
# Step size and integration window.
dt = 0.01
t0 = 0.0
t1 = 30.0
# Save grid: `dt`-spaced times covering [t0, t1] *inclusive*. `jnp.arange`
# would stop one step short of t1, so use linspace with an explicit count.
n_steps = int(round((t1 - t0) / dt))
ts = jnp.linspace(t0, t1, n_steps + 1)
# Save only at `ts`. SaveAt's `t0`/`t1` arguments are boolean flags ("also
# save at the solve's start/end time"), not times. Passing t1=30.0 (truthy)
# would append a duplicate of the final time, and `ts` already contains both
# endpoints.
saveat = dfx.SaveAt(ts=ts)
saveat
SaveAt(
subs=SubSaveAt(
t0=0.0,
t1=30.0,
ts=f32[3000],
steps=False,
fn=<function save_y>
),
dense=False,
solver_state=False,
controller_state=False,
made_jump=False
)
# Explicit Euler with a fixed step of `dt`.
solver, stepsize_controller = dfx.Euler(), dfx.ConstantStepSize()

# Integrate the model from the first to the last requested save time,
# recording the state at every time in `saveat`.
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(l96_model.equation_of_motion),
    solver=solver,
    y0=state_init,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    args=params,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
sol
Solution(
t0=f32[],
t1=f32[],
ts=f32[3001],
ys=L96State(x=f32[3001,50]),
interpolation=None,
stats={
'max_steps':
i32[],
'num_accepted_steps':
i32[],
'num_rejected_steps':
i32[],
'num_steps':
i32[]
},
result=i32[],
solver_state=None,
controller_state=None,
made_jump=None
)
Analysis
np.shape(sol.ys.x)  # (number of saved times, ndim)
(3001, 50)
# Wrap the saved trajectory in a labelled DataArray: one row per saved time,
# one column per state variable, with the forcing kept as metadata.
trajectory = np.asarray(sol.ys.x)
da_sol = xr.DataArray(
    data=trajectory,
    dims=["time", "dimension"],
    coords={
        "dimension": (["dimension"], np.arange(trajectory.shape[1])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={"ode": "lorenz_96", "F": params.F},
)
da_sol
Loading...
# Heatmap of the trajectory: time along x, state variable index along y.
fig, ax = plt.subplots(figsize=(5, 3))
# Draw into the axes created above explicitly instead of relying on
# matplotlib's implicit "current axes" (the other plots here also pass `ax`).
da_sol.T.plot.imshow(ax=ax, cmap="viridis")
ax.set_xlabel("Time")
ax.set_ylabel("Dimension")
plt.tight_layout()
plt.show()
# 3-D surface of the same trajectory: (time, dimension) -> state value.
fig, ax = plt.subplots(figsize=(5, 3), subplot_kw={"projection": "3d"})
pts = da_sol.T.plot.surface(ax=ax, add_colorbar=False, cmap="viridis")
# A colorbar built from a mappable takes its colormap from that mappable
# (`pts`), so a `cmap` argument here would be silently ignored. Attach the
# colorbar explicitly to this figure and axes.
cbar_kwargs = dict(shrink=0.3, aspect=5)
fig.colorbar(pts, ax=ax, **cbar_kwargs)
ax.set_xlabel("Time")
ax.set_ylabel("Dimension")
plt.tight_layout()
plt.show()
Batch of Trajectories
# A batch of `batchsize` initial states built with the same forcing F.
batchsize = 100
state_batch, params = L96State.init_state_and_params(
    ndim=ndim, batchsize=batchsize, F=F
)
# Map the single-state tendency over axis 0 of `x` (the batch axis), while
# broadcasting the scalar forcing F unchanged to every member (in_axes=None).
fn_batched = jax.vmap(rhs_lorenz_96, in_axes=(0, None))
state_dot_batch = fn_batched(state_batch.x, F)
x_dot = state_dot_batch
# One tendency per batch member and state variable.
assert x_dot.shape == state_batch.x.shape
# Same fixed-step explicit Euler scheme as the single-trajectory run.
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


def integrate(state):
    """Integrate one trajectory starting from ``state`` over the ``ts`` window.

    Reads the notebook-level ``l96_model``, ``solver``, ``stepsize_controller``,
    ``saveat``, ``ts``, ``dt`` and ``params``; returns the diffrax Solution.
    """
    return dfx.diffeqsolve(
        terms=dfx.ODETerm(l96_model.equation_of_motion),
        solver=solver,
        t0=ts.min(),
        t1=ts.max(),
        dt0=dt,
        y0=state,
        saveat=saveat,
        args=params,
        stepsize_controller=stepsize_controller,
    )


# Solve all initial states of the batch in a single vectorised call. Every
# field of the returned Solution gains a leading batch axis (including `ts`).
sol = jax.vmap(integrate)(state_batch)
# Labelled view of the batch with dims (realization, time, dimension).
# Convert the JAX output to a NumPy array first, consistent with the
# single-trajectory DataArray, so xarray stores a plain ndarray.
trajectories = np.asarray(sol.ys.x)
n_realizations, _, n_dims = trajectories.shape
da_sol = xr.DataArray(
    data=trajectories,
    dims=["realization", "time", "dimension"],
    coords={
        "realization": (["realization"], np.arange(n_realizations)),
        "dimension": (["dimension"], np.arange(n_dims)),
        # vmap batches `ts` as well; every member uses the same `saveat`
        # grid, so the first row is representative.
        "time": (["time"], np.asarray(sol.ts[0])),
    },
    attrs={"ode": "lorenz_96", "F": params.F},
)
da_sol
Loading...
# Heatmaps of the first three realizations, one per row.
fig, ax = plt.subplots(nrows=3, figsize=(5, 8))
for i, axis in enumerate(ax):
    member = da_sol.isel(realization=i)
    member.T.plot.imshow(ax=axis, cmap="viridis")
    axis.set_xlabel("Time")
    axis.set_ylabel("Dimension")
plt.tight_layout()
plt.show()