Lorenz 96

J. Emmanuel Johnson, Takaya Uchida
import autoroot  # noqa: F401, I001
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import diffrax as dfx
import xarray as xr
import numpy as np

from jaxsw import L96State, Lorenz96, rhs_lorenz_96

# Restore default rc settings, then apply seaborn's "talk" context with
# fonts scaled down to 70% so labels fit the small figures used below.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# IPython magics: render figures inline and reload edited modules
# (e.g. jaxsw) before each cell runs, without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Lorenz 96

  • Equation of Motion
  • Observation Operator
  • Integrate

Equation of Motion

$$\frac{dx_i}{dt} = (x_{i+1} - x_{i-2})\,x_{i-1} - x_i + F$$

where $F$ is the forcing constant; it is typically set to $F = 8$, which produces chaotic behaviour.

# Forcing constant; 8 is the usual choice that makes the system chaotic.
F = 8

# Random initial state of size `ndim`. `noise` is forwarded to the
# initializer (presumably the perturbation amplitude -- confirm in jaxsw).
ndim, noise = 50, 0.01
state = L96State.init_state(ndim=ndim, noise=noise)

# Evaluate the tendency dx/dt once and check it matches the state's shape.
x = state.x
state_dot = x_dot = rhs_lorenz_96(x=x, F=F)
assert x.shape == x_dot.shape

Model

# Initial state plus the parameter container (carrying the forcing F).
state_init, params = L96State.init_state_and_params(ndim=ndim, F=F)

# Build the Lorenz-96 model. `advection` is forwarded to the jaxsw model
# (presumably toggles the nonlinear advection term -- confirm in jaxsw).
advection = True
l96_model = Lorenz96(advection=advection)

# Evaluate the vector field once at t=0 as a smoke test; the result is a
# state object whose `.x` should have the same shape as the state.
state_dot = l96_model.equation_of_motion(t=0, state=state_init, args=params)

state_dot.x.shape
(50,)

Time Stepping

# Integration window and (fixed) step size.
dt = 0.01
t0 = 0.0
t1 = 30.0

# Uniform save grid t0, t0 + dt, ..., t1 that includes BOTH endpoints.
# linspace (rather than arange) guarantees t1 itself is on the grid and
# avoids floating-point drift in the number of samples.
n_steps = int(round((t1 - t0) / dt))
ts = jnp.linspace(t0, t1, n_steps + 1)

# In diffrax, SaveAt's `t0`/`t1` are boolean flags ("also save at the
# initial/final time"), not time values. Both endpoints are already in `ts`,
# so requesting only `ts` avoids saving the final time twice.
saveat = dfx.SaveAt(ts=ts)
saveat
SaveAt( subs=SubSaveAt( t0=0.0, t1=30.0, ts=f32[3000], steps=False, fn=<function save_y> ), dense=False, solver_state=False, controller_state=False, made_jump=False )
# Explicit Euler with a constant step size dt.
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()

# Integrate over the intended window [t0, t1]. Using t0/t1 directly rather
# than ts.min()/ts.max() keeps the solve ending at t1 even when the save
# grid stops short of it.
sol = dfx.diffeqsolve(
    terms=dfx.ODETerm(l96_model.equation_of_motion),
    solver=solver,
    t0=t0,
    t1=t1,
    dt0=dt,
    y0=state_init,
    saveat=saveat,
    args=params,
    stepsize_controller=stepsize_controller,
)
sol
Solution( t0=f32[], t1=f32[], ts=f32[3001], ys=L96State(x=f32[3001,50]), interpolation=None, stats={ 'max_steps': i32[], 'num_accepted_steps': i32[], 'num_rejected_steps': i32[], 'num_steps': i32[] }, result=i32[], solver_state=None, controller_state=None, made_jump=None )

Analysis

# Saved trajectory array: (number of saved times, ndim).
sol.ys.x.shape
(3001, 50)
# Wrap the single trajectory in a labelled (time, dimension) DataArray so it
# can be sliced and plotted with xarray; the forcing F is kept as metadata.
da_sol = xr.DataArray(
    np.asarray(sol.ys.x),
    dims=("time", "dimension"),
    coords=dict(
        dimension=("dimension", np.arange(sol.ys.x.shape[1])),
        time=("time", np.asarray(sol.ts)),
    ),
    attrs=dict(ode="lorenz_96", F=params.F),
)

da_sol
Loading...
# Heatmap of the trajectory: state index (y) against time (x).
fig, ax = plt.subplots(figsize=(5, 3))

# Draw on the axes created above explicitly instead of relying on the
# implicit "current axes" state.
da_sol.T.plot.imshow(ax=ax, cmap="viridis")

ax.set_xlabel("Time")
ax.set_ylabel("Dimension")

plt.tight_layout()
plt.show()
<Figure size 500x300 with 2 Axes>
# 3-D surface view of the same trajectory over (time, dimension).
fig, ax = plt.subplots(figsize=(5, 3), subplot_kw={"projection": "3d"})

pts = da_sol.T.plot.surface(ax=ax, add_colorbar=False, cmap="viridis")

# The colorbar takes its colormap from the mappable `pts`; a `cmap` kwarg
# here would be ignored, so only layout options are passed. Attach it to
# `ax` explicitly.
cbar_kwargs = dict(shrink=0.3, aspect=5)
plt.colorbar(pts, ax=ax, **cbar_kwargs)

ax.set_xlabel("Time")
ax.set_ylabel("Dimension")

plt.tight_layout()
plt.show()
<Figure size 500x300 with 2 Axes>

Batch of Trajectories

batchsize = 100

# Batched initial states (leading axis indexes the realization) plus the
# parameter container carrying the forcing F.
state_batch, params = L96State.init_state_and_params(
    ndim=ndim, batchsize=batchsize, F=F
)

# Vectorise the RHS over the leading (batch) axis of x; F is shared by all.
fn_batched = jax.vmap(rhs_lorenz_96, in_axes=(0, None))

state_dot_batch = fn_batched(state_batch.x, F)
x_dot = state_dot_batch

# The batched tendency must have the same shape as the batched state.
assert x_dot.shape == state_batch.x.shape
# Explicit Euler with a constant step, as for the single trajectory.
solver = dfx.Euler()
stepsize_controller = dfx.ConstantStepSize()


def integrate(state):
    """Integrate one Lorenz-96 initial state over [t0, t1].

    Uses the shared solver, step controller, save grid and params; returns
    the diffrax Solution for that single state.
    """
    return dfx.diffeqsolve(
        terms=dfx.ODETerm(l96_model.equation_of_motion),
        solver=solver,
        t0=t0,
        t1=t1,
        dt0=dt,
        y0=state,
        saveat=saveat,
        args=params,
        stepsize_controller=stepsize_controller,
    )


# vmap maps over the leading axis of every array in `state_batch`, so each
# realization is integrated independently with the same parameters.
sol = jax.vmap(integrate)(state_batch)
# Batched trajectories as a (realization, time, dimension) DataArray.
# Convert to NumPy first, as in the single-trajectory case, so xarray holds
# a plain ndarray rather than a JAX array.
x_batch = np.asarray(sol.ys.x)
n_realizations, _, n_dims = x_batch.shape
da_sol = xr.DataArray(
    data=x_batch,
    dims=["realization", "time", "dimension"],
    coords={
        "realization": (["realization"], np.arange(n_realizations)),
        "dimension": (["dimension"], np.arange(n_dims)),
        # vmap also batches `sol.ts`; every realization uses the same save
        # grid, so the first row is representative.
        "time": (["time"], np.asarray(sol.ts[0])),
    },
    attrs={
        "ode": "lorenz_96",
        "F": params.F,
    },
)

da_sol
Loading...
# One heatmap panel (dimension vs. time) per realization, for the first three.
fig, ax = plt.subplots(nrows=3, figsize=(5, 8))

for realization, panel in enumerate(ax):
    trajectory = da_sol.isel(realization=realization)
    trajectory.T.plot.imshow(ax=panel, cmap="viridis")
    panel.set_xlabel("Time")
    panel.set_ylabel("Dimension")

plt.tight_layout()
plt.show()
<Figure size 500x800 with 6 Axes>