import autoroot # noqa: F401, I001
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import diffrax as dfx
import xarray as xr
import numpy as np
from jaxsw import L96State, Lorenz96, rhs_lorenz_96
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
%matplotlib inline
%load_ext autoreload
%autoreload 2The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
# ## Lorenz 96
#
# - Equation of Motion
# - Observation Operator
# - Integrate
# Forcing constant passed to the Lorenz-96 right-hand side below.
F = 8

# initialize state
ndim = 50
noise = 0.01
state = L96State.init_state(ndim=ndim, noise=noise)

# rhs: evaluate the Lorenz-96 tendency directly on the raw state vector
x = state.x
state_dot = rhs_lorenz_96(x=x, F=F)
x_dot = state_dot
# the tendency must have the same shape as the state it was computed from
assert x_dot.shape == x.shape

# ## Model
t0 = 0.0
t1 = 30.0
# initialize state (and the parameter container passed as `args` below)
state_init, params = L96State.init_state_and_params(ndim=ndim, F=F)
# initialize model
# NOTE(review): `advection` presumably toggles the advection term of the
# tendency inside jaxsw.Lorenz96 -- confirm against that class.
advection = True
l96_model = Lorenz96(advection=advection)
# step through: a single evaluation of the equation of motion
state_dot = l96_model.equation_of_motion(t=0, state=state_init, args=params)
# same shape contract as the raw rhs check above
assert state_dot.x.shape == state_init.x.shape
state_dot.x.shape
# Output: (50,)

# ## Time Stepping
dt = 0.01
t0 = 0.0
t1 = 30.0
# Save grid: endpoint-inclusive, evenly spaced, no duplicate times.
# The previous version had two problems:
#   * `jnp.arange(t0, t1, dt)` stops one step short of t1 (last value 29.99),
#     so integrating to `ts.max()` never reached t1 = 30.0;
#   * `dfx.SaveAt(t0=..., t1=...)` takes *boolean* flags. Passing t0=0.0
#     (falsy) and t1=30.0 (truthy) made diffrax additionally save the final
#     time, which was again 29.99 -> a duplicated last time step (3000 + 1).
# `linspace` with `num_steps + 1` points keeps the same 3001-point output
# while covering [t0, t1] exactly once per time.
num_steps = int(round((t1 - t0) / dt))
ts = jnp.linspace(t0, t1, num_steps + 1)
saveat = dfx.SaveAt(ts=ts)
saveat
# Output recorded with the previous call
# `dfx.SaveAt(t0=t0, t1=t1, ts=jnp.arange(t0, t1, dt))` (re-run to refresh):
# SaveAt(
#   subs=SubSaveAt(
#     t0=0.0,
#     t1=30.0,
#     ts=f32[3000],
#     steps=False,
#     fn=<function save_y>
#   ),
#   dense=False,
#   solver_state=False,
#   controller_state=False,
#   made_jump=False
# )

# Euler, Constant StepSize
# Explicit Euler solver with a fixed step `dt0=dt` (no adaptive control).
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()
# integration: solve from ts.min() to ts.max(), saving at the times in `saveat`
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(l96_model.equation_of_motion),
    solver=solver,
    t0=ts.min(),
    t1=ts.max(),
    dt0=dt,
    y0=state_init,
    saveat=saveat,
    args=params,
    stepsize_controller=stepsize_controller,
)
sol
# Output:
# Solution(
#   t0=f32[],
#   t1=f32[],
#   ts=f32[3001],
#   ys=L96State(x=f32[3001,50]),
#   interpolation=None,
#   stats={
#     'max_steps': i32[],
#     'num_accepted_steps': i32[],
#     'num_rejected_steps': i32[],
#     'num_steps': i32[]
#   },
#   result=i32[],
#   solver_state=None,
#   controller_state=None,
#   made_jump=None
# )

# ## Analysis
sol.ys.x.shape
# Output: (3001, 50)

# Wrap the saved trajectory in a labelled (time, dimension) DataArray.
# The JAX array is converted to NumPy so xarray holds a plain ndarray.
da_sol = xr.DataArray(
    data=np.asarray(sol.ys.x),
    dims=["time", "dimension"],
    coords={
        "dimension": (["dimension"], np.arange(sol.ys.x.shape[-1])),
        "time": (["time"], np.asarray(sol.ts)),
    },
    attrs={
        "ode": "lorenz_96",
        "F": params.F,
    },
)
da_sol
# Output: rich HTML repr (rendered as "Loading..." in this export)
# Heatmap of the trajectory: after the transpose, time runs along x and the
# state dimension along y.
fig, ax = plt.subplots(figsize=(5, 3))
# Draw on the axes created above explicitly instead of relying on the
# implicit "current axes".
da_sol.T.plot.imshow(ax=ax, cmap="viridis")
ax.set_xlabel("Time")
ax.set_ylabel("Dimension")
plt.tight_layout()
plt.show()
# 3-D surface view of the same trajectory, with a compact colorbar.
fig, ax = plt.subplots(figsize=(5, 3), subplot_kw={"projection": "3d"})
pts = da_sol.T.plot.surface(ax=ax, add_colorbar=False, cmap="viridis")
# The colorbar takes its colormap from the mappable `pts`; a `cmap` kwarg is
# ignored by matplotlib when a mappable is given, so it is not passed here.
cbar_kwargs = dict(shrink=0.3, aspect=5)
fig.colorbar(pts, ax=ax, **cbar_kwargs)
ax.set_xlabel("Time")
ax.set_ylabel("Dimension")
plt.tight_layout()
plt.show()
# ## Batch of Trajectories
batchsize = 100
# NOTE(review): this rebinds `params` from the single-trajectory cell; the
# batch integration below closes over this new value.
state_batch, params = L96State.init_state_and_params(
    ndim=ndim, batchsize=batchsize, F=F
)
# Map the single-state rhs over axis 0 of x while sharing the scalar F.
# Presumably state_batch.x is (batchsize, ndim) -- confirm in jaxsw.
fn_batched = jax.vmap(rhs_lorenz_96, in_axes=(0, None))
state_dot_batch = fn_batched(state_batch.x, F)
x_dot = state_dot_batch
# the batched tendency must match the batched state shape
assert x_dot.shape == state_batch.x.shape

# Euler, Constant StepSize
# Same fixed-step Euler settings as the single-trajectory run.
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


def integrate(state):
    """Integrate one initial `state` over [ts.min(), ts.max()].

    Uses the module-level `solver`, `stepsize_controller`, `saveat`, `dt`
    and `params`, and returns the diffrax ``Solution``.
    """
    return dfx.diffeqsolve(
        terms=dfx.ODETerm(l96_model.equation_of_motion),
        solver=solver,
        t0=ts.min(),
        t1=ts.max(),
        dt0=dt,
        y0=state,
        saveat=saveat,
        args=params,
        stepsize_controller=stepsize_controller,
    )


# vmap over the batch of initial states; the returned Solution is batched too
# (e.g. `sol.ts` has one row per realization).
sol = jax.vmap(integrate)(state_batch)

# Labelled (realization, time, dimension) view of the batch.
da_sol = xr.DataArray(
    data=np.asarray(sol.ys.x),
    dims=["realization", "time", "dimension"],
    coords={
        "realization": (["realization"], np.arange(sol.ys.x.shape[0])),
        "dimension": (["dimension"], np.arange(sol.ys.x.shape[-1])),
        # all realizations share the same `saveat` grid, so row 0 suffices
        "time": (["time"], np.asarray(sol.ts[0])),
    },
    attrs={
        "ode": "lorenz_96",
        "F": params.F,
    },
)
da_sol
# Output: rich HTML repr (rendered as "Loading..." in this export)
# One heatmap per row for the first realizations (one per subplot).
fig, ax = plt.subplots(nrows=3, figsize=(5, 8))
# Iterate the axes themselves so the row count is stated only once above.
for i, axis in enumerate(ax):
    da_sol.isel(realization=i).T.plot.imshow(ax=axis, cmap="viridis")
    axis.set_xlabel("Time")
    axis.set_ylabel("Dimension")
plt.tight_layout()
plt.show()