In earlier sections, we identified some estimation problems that we encountered: for example, state estimation and parameter estimation. But how do we estimate the state and the parameters simultaneously? Plug-and-Play priors are one way to handle this: we simply pre-train our model on historical observations and then apply it zero-shot. This is where bi-level optimization comes into play: there is a lot of theory on how to learn the state and the parameters jointly.
Outer-level objective:

$$
\theta^* = \underset{\theta}{\operatorname{argmin}}\; L(\theta, x^*(\theta))
$$

Inner-level objective:

$$
x^*(\theta) = \underset{x}{\operatorname{argmin}}\; U(x, \theta)
$$
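To make the two levels concrete, here is a minimal JAX sketch of a toy bilevel problem; the quadratic forms and the observation vector `y_obs` are illustrative assumptions, not part of the notes.

```python
import jax.numpy as jnp

# Toy bilevel pair (illustrative): the inner problem fits a state x to the
# parameters theta, and the outer problem scores that state against data.
y_obs = jnp.array([1.0, -2.0])

def U(x, theta):
    """Inner-level objective, minimized over the state x."""
    return 0.5 * jnp.sum((x - theta) ** 2)

def L(theta, x):
    """Outer-level objective, evaluated at the inner solution x*(theta)."""
    return 0.5 * jnp.sum((x - y_obs) ** 2)
```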
## Argmin Differentiation
Applying the chain rule to the outer objective gives

$$
\frac{\partial L(\theta, x^*(\theta))}{\partial \theta}
= \frac{\partial L(\theta, x^*(\theta))}{\partial x}\,\frac{\partial x^*(\theta)}{\partial \theta}
$$

So the question is: how do we get the derivative of the inner solution with respect to the parameters $\theta$?

$$
\frac{\partial x^*(\theta)}{\partial \theta}
$$
### Unrolling
We assume the iterate at the end of the unroll approximates the solution of the inner minimization problem,

$$
x_k(\theta) \approx x^*(\theta)
$$

We define a loss function which takes this value:

$$
L_k(\theta) := f(\theta, x_k(\theta))
$$

Now we can define the unrolling process. Choose an initial value, $x_0(\theta)$, and define the unrolling step as

$$
x_{k+1}(\theta) = x_k(\theta) - \lambda \nabla_x U(\theta, x_k(\theta))
$$

Note: we can choose whatever algorithm we want for this inner step; for example, SGD or even Adam.

Similarly, we can choose an algorithm for how we find the parameters (see the sketch below):

$$
\theta_{n+1} = \theta_n - \mu \nabla_\theta L_k(\theta_n)
$$

Note: we assume that the gradient of the unrolled loss $L_k$ is a good stand-in for the gradient of the true outer loss when updating $\theta_n$.
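Here is a minimal JAX sketch of this scheme on the toy quadratic problem; the choice of $U$, $f$, `y_obs`, the step sizes, and the step counts are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Toy problem (illustrative): inner U(x, theta) = 0.5||x - theta||^2,
# so x*(theta) = theta; the outer loss f is a data misfit against y_obs.
y_obs = jnp.array([1.0, -2.0])

def U(x, theta):
    return 0.5 * jnp.sum((x - theta) ** 2)

def f(theta, x):
    return 0.5 * jnp.sum((x - y_obs) ** 2)

def x_k(theta, k=50, lam=0.1):
    """Unroll k gradient-descent steps on the inner objective."""
    x = jnp.zeros_like(theta)                    # x_0(theta)
    for _ in range(k):
        x = x - lam * jax.grad(U)(x, theta)      # x_{k+1} = x_k - lam * grad_x U
    return x

def L_k(theta):
    return f(theta, x_k(theta))

# Outer loop: backprop through the unroll gives grad_theta L_k.
theta = jnp.zeros(2)
for _ in range(100):
    theta = theta - 0.5 * jax.grad(L_k)(theta)   # theta_{n+1} = theta_n - mu * grad L_k
print(theta)                                     # -> close to y_obs = [1, -2]
```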
### Implicit Differentiation
We focus on the argmin differentiation problem

$$
x^*(\theta) = \underset{x}{\operatorname{argmin}}\; U(x, \theta)
$$

Assumptions:

- $U$ is strongly convex in $x$
- $U$ is smooth
**Implicit Function Theorem**

This states that $x^*(\theta)$ is the unique solution of

$$
\nabla_x U(\theta, x^*(\theta)) = 0
$$

Note: unrolling only ever evaluates $\nabla_x U$ at the iterates $x_k(\theta)$; here the job is to characterize the solution itself!

Implication: this optimality condition holds for all $\theta$!

Goal: estimate $\nabla_\theta x^*(\theta)$.
Result: a linear system! Since the optimality condition holds for all $\theta$, we can differentiate it with respect to $\theta$:

$$
\frac{\partial^2 U}{\partial \theta\, \partial x}(\theta, x^*(\theta))
+ \frac{\partial^2 U}{\partial x^2}(\theta, x^*(\theta))\, \nabla_\theta x^*(\theta) = 0
$$

We can simplify this to

$$
B(\theta) + A(\theta)\, \nabla_\theta x^*(\theta) = 0
$$
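For a small problem we can form this system explicitly and solve it directly; a sketch with the same toy quadratic inner objective (illustrative), for which the true Jacobian $\nabla_\theta x^*(\theta)$ is the identity:

```python
import jax
import jax.numpy as jnp

# Toy inner objective (illustrative): grad_x U = x - theta, so x*(theta) = theta.
def U(x, theta):
    return 0.5 * jnp.sum((x - theta) ** 2)

grad_x_U = jax.grad(U, argnums=0)

theta = jnp.array([1.0, -2.0])
x_star = theta                                        # analytic argmin of this U

A = jax.jacfwd(grad_x_U, argnums=0)(x_star, theta)    # A(theta) = d^2 U / dx^2
B = jax.jacfwd(grad_x_U, argnums=1)(x_star, theta)    # B(theta) = d^2 U / (dtheta dx)

print(-jnp.linalg.solve(A, B))                        # grad_theta x*(theta) -> identity
```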
In theory: we can find the Jacobian by solving this linear system.

In practice: we don't need the full Hessian; Hessian-vector products suffice! Looking at the sizes of the terms, $A(\theta) = \partial_x^2 U$ is a $D_x \times D_x$ matrix and $B(\theta)$ is $D_x \times D_\theta$, so materializing them is expensive, but applying them to vectors is cheap.
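A sketch of such a Hessian-vector product in JAX: a forward-mode derivative of the inner gradient computes $A(\theta)v$ without ever building $A(\theta)$ (the quartic term is an illustrative addition so that the Hessian is non-trivial).

```python
import jax
import jax.numpy as jnp

def U(x, theta):
    # Quartic term added so the Hessian depends on x (illustrative).
    return 0.5 * jnp.sum((x - theta) ** 2) + 0.25 * jnp.sum(x ** 4)

def hvp(x, theta, v):
    """Compute A(theta) v = (d^2 U / dx^2) v without building the Hessian."""
    grad_x = lambda x_: jax.grad(U, argnums=0)(x_, theta)
    return jax.jvp(grad_x, (x,), (v,))[1]

x, theta, v = jnp.array([0.5, -0.5]), jnp.array([1.0, -2.0]), jnp.array([1.0, 0.0])
print(hvp(x, theta, v))                            # matrix-free product ...
print(jax.hessian(U, argnums=0)(x, theta) @ v)     # ... matches the dense check
```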
Observe: let's look at the original loss function. Using the chain rule we get

$$
\nabla L(\theta) = \partial_\theta f(\theta, x^*(\theta)) + \nabla_\theta x^*(\theta)^\top \partial_x f(\theta, x^*(\theta))
$$

And now, let's look at the linear system we want to solve:

$$
\nabla_\theta x^*(\theta) = -[A(\theta)]^{-1} B(\theta)
$$

which is awful to form explicitly. Plugging it back into the gradient, we get

$$
\nabla L(\theta) = \partial_\theta f(\theta, x^*(\theta)) - \left([A(\theta)]^{-1} B(\theta)\right)^\top \partial_x f(\theta, x^*(\theta))
$$

However, looking at the sizes, we notice:

- This is hard: forming $[A(\theta)]^{-1} B(\theta)$ means solving the $(D_x \times D_x)$ system once per column of $B(\theta)$, i.e. $D_\theta$ times.
- This is easy: solving $A(\theta)^\top v = \partial_x f$ is a single $(D_x \times D_x)$ system with a $D_x$-dimensional right-hand side.

The remaining piece, $B(\theta)^\top v$, is simply a vjp.
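Putting the pieces together, here is a sketch of this "solve once, then vjp" recipe: conjugate gradients solves the linear system using only Hessian-vector products, and a single vjp supplies $B(\theta)^\top v$. The toy $U$, $f$, and `y_obs` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Toy problem (illustrative): x*(theta) = theta, so grad L(theta) = theta - y_obs.
y_obs = jnp.array([1.0, -2.0])

def U(x, theta):
    return 0.5 * jnp.sum((x - theta) ** 2)

def f(theta, x):
    return 0.5 * jnp.sum((x - y_obs) ** 2)

grad_x_U = jax.grad(U, argnums=0)

def hypergrad(theta, x_star):
    gx = jax.grad(f, argnums=1)(theta, x_star)            # partial_x f
    # A(theta) is a symmetric Hessian, so A^T v = A v; supply it to CG as an HVP.
    Av = lambda v: jax.jvp(lambda x: grad_x_U(x, theta), (x_star,), (v,))[1]
    v, _ = jax.scipy.sparse.linalg.cg(Av, gx)             # solve A v = partial_x f
    # One vjp through grad_x U with respect to theta gives B(theta)^T v.
    _, pullback = jax.vjp(lambda th: grad_x_U(x_star, th), theta)
    (Bt_v,) = pullback(v)
    gtheta = jax.grad(f, argnums=0)(theta, x_star)        # partial_theta f
    return gtheta - Bt_v                                  # grad L(theta)

theta = jnp.array([0.5, 0.5])
print(hypergrad(theta, x_star=theta))                     # -> theta - y_obs
```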
Note: we can also solve the linear system iteratively (e.g. with gradient descent or conjugate gradients) using Hessian-vector products instead of explicit Hessians.
### Computational Cost
The cost is almost the same in terms of computation:

| Unrolling | Implicit Differentiation |
|---|---|
| $k$ steps forward (unrolling) | $k$ steps for the inner optimization, $x^*(\theta)$ |
| $k$ steps backward (backprop through the unroll) | $k$ steps to solve the linear system, $A(\theta)v = b$ |

The memory cost is better for implicit differentiation, because we don't have to store the intermediate iterates of the unroll for backpropagation!
### Approximate Solution or Gradient
$$
\nabla L_k(\theta) \approx \nabla L(\theta)
$$

The trade-off: do we approximate the solution (unrolling) or approximate the gradient (implicit differentiation)?
### Warm Starts
We can warm-start the linear-system solve behind

$$
\nabla_\theta x^*(\theta) = -[A(\theta)]^{-1} B(\theta)
$$

by starting from an already good solution, e.g. the solve from the previous outer iteration. (Note: this is warm-starting rather than preconditioning; preconditioning transforms the system to improve its conditioning, while a warm start only changes the initial iterate.)
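A tiny sketch of the warm start via CG's `x0` argument; the identity stand-in for the HVP operator and the right-hand sides are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

Av = lambda v: v                 # stand-in for the HVP operator A(theta)v (illustrative)

# Across outer iterations theta changes slowly, so successive right-hand
# sides are close and the previous solution is a good initial iterate.
v = jnp.zeros(2)
for b in [jnp.array([1.0, -2.0]), jnp.array([1.1, -1.9])]:
    v, _ = cg(Av, b, x0=v)       # warm start: begin from the previous solution
```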
### Strongly Convex Solution
Not really... (Michael's work suggests that overparameterized systems converge faster!)