## Joint Distribution
$$
p(y_{1:T}, u_{1:T}, \theta) = p(\theta)\, p(u_0) \prod_{t=1}^{T} p(y_t \mid u_t, \theta)
$$

## Posterior
The posterior over the states and parameters is proportional to the joint distribution:

$$
p(u_{1:T}, \theta \mid y_{1:T}) \propto p(\theta)\, p(u_0) \prod_{t=1}^{T} p(y_t \mid u_t, \theta)
$$

## Loss Function
Taking the negative log of the posterior gives the loss function:

$$
J(u, \theta) = -\sum_{t=1}^{T} \log p(y_t \mid u_t, \theta) - \log p(z_0 \mid \theta) - \log p(\theta)
$$

So each of the terms is:
$$
\begin{aligned}
\text{Data Likelihood:} &\quad p(y_t \mid u_t, \theta) = \mathcal{N}(y_t \mid u_t, \Sigma_{uu}) \\
\text{Prior State:} &\quad p(z_t \mid \theta) = \mathcal{N}(z_t \mid \mu_z, \Sigma_z) \\
\text{Prior Parameters:} &\quad p(\theta) = \mathcal{N}(\theta \mid \mu_\theta, \Sigma_\theta)
\end{aligned}
$$
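To make these terms concrete, here is a minimal sketch of the resulting negative log-posterior, assuming Gaussian densities with the covariances above and an observation at every time step. All of the argument names are placeholders for illustration, not a fixed API.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn

def neg_log_posterior(u, theta, y_obs, mu_z, Sigma_z, mu_theta, Sigma_theta, Sigma_uu):
    """J(u, theta): negative log-posterior under the Gaussian assumptions above.

    u: (T, Du) state trajectory; y_obs: (T, Dy) observations (here Dy == Du).
    """
    # data likelihood: sum_t log N(y_t | u_t, Sigma_uu)
    log_lik = jnp.sum(mvn.logpdf(y_obs, u, Sigma_uu))
    # prior on the (initial) state: log N(u_0 | mu_z, Sigma_z)
    log_prior_state = mvn.logpdf(u[0], mu_z, Sigma_z)
    # prior on the parameters: log N(theta | mu_theta, Sigma_theta)
    log_prior_theta = mvn.logpdf(theta, mu_theta, Sigma_theta)
    return -(log_lik + log_prior_state + log_prior_theta)
```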
## Algorithm
### Data
$$
\text{Observations:} \quad \mathcal{D} = \{ y_t \}_{t=1}^{T}, \qquad y_t \in \mathbb{R}^{D_y}
$$
### Dynamical Model
$$
\partial_t u = f(u, t, \theta)
$$
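As a concrete (purely illustrative) example of such a right-hand side, take the Lorenz-63 system, where the parameters are $\theta = (\sigma, \rho, \beta)$; any ODE with the signature `f(u, t, theta)` works.

```python
import jax.numpy as jnp

def f(u, t, theta):
    """Example dynamics f(u, t, theta): the Lorenz-63 system."""
    sigma, rho, beta = theta
    x, y, z = u
    return jnp.array([
        sigma * (y - x),
        x * (rho - z) - y,
        x * y - beta * z,
    ])
```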
### Adjoint Model
$$
\partial_t \lambda = f^*(\lambda, t, u, y)
$$

```python
def init_f_adjoint(f) -> Callable:
    def f_adjoint(u, lam):
        # vector-Jacobian product: J_f(u)^T @ lam
        _, vjp_fun = jax.vjp(f, u)
        return vjp_fun(lam)[0]
    return f_adjoint
```
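A quick usage sketch, assuming the example `f` above with the time and parameters frozen so the adjoint only depends on the state (`theta`, `u`, and `lam` are placeholders). Note that this only gives the $J_f^\top(u)\,\lambda$ part of the adjoint; the data-misfit term is added separately during the backward sweep.

```python
# freeze t and theta so the adjoint only depends on the state
f_adjoint = init_f_adjoint(lambda u: f(u, 0.0, theta))

# lam has the same shape as u; this evaluates J_f(u)^T @ lam
lam_dot = f_adjoint(u, lam)
```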
### Initial Condition
We need an initial condition for the state, $u_0$, and for the parameters, $\theta$.
$$
\begin{aligned}
\text{Initial State:} &\quad u_0 \in \mathbb{R}^{D_u} \\
\text{Initial Parameters:} &\quad \theta_0 \in \mathbb{R}^{D_\theta}
\end{aligned}
$$
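For example (the values here are placeholders; the parameter guess is just the classic Lorenz-63 setting from the example above):

```python
import jax.numpy as jnp

# initial state guess u_0 in R^{Du} (here Du = 3, matching the Lorenz-63 example)
u_0 = jnp.zeros(3)

# initial parameter guess theta_0 in R^{Dtheta}, e.g. the prior mean mu_theta
params = jnp.array([10.0, 28.0, 8.0 / 3.0])
```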
### Integrate Forward Model
We need the list of time steps that match the observations.
$$
\mathbf{T}^+ = [t_0, t_1, t_2, \ldots, t_T]
$$

```python
# time step
dt: float = 0.01

# list of time steps
time_steps: Array["Nt"] = np.arange(0, num_time_steps, dt)
```
We need to pass this through a solver to get our states.
$$
\mathbf{U} = \text{ODESolve}(f, u_0, \theta, \mathbf{T}^+)
$$

```python
state_sol: PyTree = ode_solve(f, u_0, params, time_steps)
```
So we essentially get a matrix containing the model output at all of the time points.
$$
\mathbf{U} = [u_1, u_2, \ldots, u_T] \in \mathbb{R}^{T \times D_u}
$$

```python
# extract the state trajectory
u_sol: Array["Nt Du"] = state_sol.u
```
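The `ode_solve` call is left abstract here. One possible stand-in uses `jax.experimental.ode.odeint` and returns the trajectory in a small container; the `StateSolution` structure is an assumption for illustration, not a fixed interface.

```python
from typing import NamedTuple
import jax
from jax.experimental.ode import odeint

class StateSolution(NamedTuple):
    t: jax.Array   # (Nt,)     time steps
    u: jax.Array   # (Nt, Du)  state trajectory

def ode_solve(f, u_0, params, time_steps) -> StateSolution:
    # odeint expects f(u, t, *args); integrates from time_steps[0] through time_steps[-1]
    u_traj = odeint(f, u_0, time_steps, params)
    return StateSolution(t=time_steps, u=u_traj)
```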
### Integrate Backward Adjoint Model
Now we need to do the opposite: run the solver in reverse using the adjoint model.
First, we need to initialize the reversed sequence of time steps.
$$
\mathbf{T}^- = [t_T, t_{T-1}, t_{T-2}, \ldots, t_2, t_1, t_0]
$$

```python
# list of reversed time steps
time_steps_reverse: Array["Nt"] = time_steps[::-1]
```
$$
\begin{aligned}
\partial_t \lambda &= f^*(\lambda, t, u, y) \\
f^*(u_t, \lambda, y, \theta) &= J_f^\top(u_t)\, \lambda - J_h^\top(u_t)\, C_{yy}^{-1}\left( h(u_t) - y_t \right)
\end{aligned}
$$

Now, we iterate through each of these time steps:
$$
\lambda_{t+1} = f^*(u_t, \lambda_t, y, \theta)
$$
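As a sketch, the backward pass can be written as a plain Euler sweep over the reversed time steps. Everything here is an assumption for illustration: the observation operator `h`, its Jacobian `Jh`, the inverse observation covariance `Cyy_inv`, an observation at every time step, and the zero terminal condition.

```python
import jax.numpy as jnp

def integrate_adjoint(f_adjoint, u_sol, y_obs, time_steps, h, Jh, Cyy_inv):
    """Euler sweep of the adjoint equation, backwards in time (illustrative only)."""
    lam = jnp.zeros_like(u_sol[-1])              # terminal condition: lambda(t_T) = 0
    lams = [lam]
    for k in range(len(time_steps) - 1, 0, -1):  # t_T -> t_1
        u_t = u_sol[k]
        dt = time_steps[k - 1] - time_steps[k]   # negative: stepping back in time
        # adjoint dynamics: J_f^T(u_t) lam - J_h^T(u_t) C_yy^{-1} (h(u_t) - y_t)
        rhs = f_adjoint(u_t, lam) - Jh(u_t).T @ Cyy_inv @ (h(u_t) - y_obs[k])
        lam = lam + dt * rhs
        lams.append(lam)
    return jnp.stack(lams[::-1])                 # lam_sol[0] corresponds to t_0
```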
### Loss Function
#### State
First, we have the gradient of the loss function with respect to the initial state:
$$
\nabla_{u_0} \mathcal{L}(u_0, \theta, \lambda_0) = \Sigma_{zz}^{-1}(u_0 - \mu_z) - \lambda_0
$$

So a single gradient step would be
$$
u_0^{k+1} = u_0^{k} - \alpha \mathbf{B}\, \nabla_{u_0} \mathcal{L}(u_0^{k}, \theta, \lambda_0)
$$
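In code, a single update of the initial condition might look like the following sketch, where `Sigma_zz_inv`, `mu_z`, `alpha`, and the preconditioner `B` are placeholders, and `lam_sol` is the output of the backward sweep above.

```python
# gradient of the loss w.r.t. the initial state: Sigma_zz^{-1}(u_0 - mu_z) - lambda_0
grad_u0 = Sigma_zz_inv @ (u_0 - mu_z) - lam_sol[0]

# preconditioned gradient-descent step
u_0 = u_0 - alpha * B @ grad_u0
```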
#### Parameters
Second, we have the gradient of the loss function with respect to the parameters:
$$
\nabla_{\theta} \mathcal{L}(u_0, \theta, \lambda_0) = \Sigma_{\theta\theta}^{-1} \sum_{t=1}^{T} J_f^\top(\theta_t)\, \lambda_{t+1}
$$

So a single gradient step would be
$$
\theta^{k+1} = \theta^{k} - \alpha \mathbf{B}\, \nabla_{\theta} \mathcal{L}(u_0^{k}, \theta, \lambda_0)
$$
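A corresponding sketch for the parameter update, where the $J_f^\top(\theta_t)\,\lambda_{t+1}$ terms are accumulated with vector-Jacobian products. The names are placeholders and assume the example `f`, the stored trajectory `u_sol`, and the adjoint solution `lam_sol` from above.

```python
import jax
import jax.numpy as jnp

def grad_theta(u_sol, lam_sol, params, time_steps, Sigma_theta_inv):
    """Accumulate Sigma_theta^{-1} sum_t J_f^T(theta; u_t) lam_{t+1} (illustrative only)."""
    total = jnp.zeros_like(params)
    for t in range(len(time_steps) - 1):
        # vjp of the dynamics w.r.t. the parameters along the stored trajectory
        _, vjp_theta = jax.vjp(lambda th: f(u_sol[t], time_steps[t], th), params)
        total = total + vjp_theta(lam_sol[t + 1])[0]
    return Sigma_theta_inv @ total

# preconditioned gradient-descent step on the parameters
params = params - alpha * B @ grad_theta(u_sol, lam_sol, params, time_steps, Sigma_theta_inv)
```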