Bi-Level Optimization

How to think about modern 4DVar formulations

CNRS
MEOM

In earlier sections, we identified some estimation problems that we encounter: for example, state estimation and parameter estimation. But how can we estimate the state and the parameters simultaneously? Plug-and-Play priors are one way to handle this: we can simply pre-train our model on historical observations and then apply it zero-shot.

This is where bi-level optimization comes into play: there is a substantial body of theory on how to learn the state and the parameters jointly.


Formulation

$$
\begin{aligned}
\text{Outer-Level Objective}: && \boldsymbol{\theta}^* &= \underset{\boldsymbol{\theta}}{\text{argmin }} \mathcal{L}(\boldsymbol{\theta},\mathbf{x}^*(\boldsymbol{\theta})) \\
\text{Inner-Level Objective}: && \mathbf{x}^*(\boldsymbol{\theta}) &= \underset{\mathbf{x}}{\text{argmin }} \mathcal{U}(\mathbf{x},\boldsymbol{\theta})
\end{aligned}
$$
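To make this concrete, here is a minimal sketch in JAX of the two objectives for a toy problem. The quadratic inner objective, the observation vector `y_obs`, and the held-out truth `x_true` are all hypothetical stand-ins, not part of any real 4DVar setup:

```python
import jax
import jax.numpy as jnp

# Inner objective U(x, theta): a data-fidelity term plus a
# theta-weighted quadratic prior (hypothetical stand-in for a 4DVar cost).
def inner_objective(x, theta, y_obs):
    data_fit = 0.5 * jnp.sum((x - y_obs) ** 2)
    prior = 0.5 * theta * jnp.sum(x ** 2)
    return data_fit + prior

# Outer objective L(theta, x*(theta)): how well the inner solution
# matches some held-out "truth" (hypothetical).
def outer_objective(theta, x_star, x_true):
    return 0.5 * jnp.sum((x_star - x_true) ** 2)
```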

Argmin Differentiation

$$
\frac{d \mathcal{L}(\boldsymbol{\theta},\mathbf{x}^*(\boldsymbol{\theta}))}{d \boldsymbol{\theta}} = \frac{\partial \mathcal{L}(\boldsymbol{\theta},\mathbf{x}^*(\boldsymbol{\theta}))}{\partial \boldsymbol{\theta}} + \left(\frac{\partial \mathbf{x}^*(\boldsymbol{\theta})}{\partial \boldsymbol{\theta}}\right)^\top \frac{\partial \mathcal{L}(\boldsymbol{\theta},\mathbf{x}^*(\boldsymbol{\theta}))}{\partial \mathbf{x}}
$$

So the question is: how do we get the derivative of the inner solution with respect to the parameters $\boldsymbol{\theta}$?

$$
\frac{\partial \mathbf{x}^*(\boldsymbol{\theta})}{\partial \boldsymbol{\theta}}
$$

Unrolling

We assume that the iterate at the end of the unrolling approximates the solution of the minimization problem:

$$
x_k(\theta) \approx x^*(\theta)
$$

We define a loss function that evaluates the outer objective, $f$, at this iterate:

$$
\mathcal{L}_k(\theta) := f(\theta, x_k(\theta))
$$

We can now define the unrolling process. Choose an initial value, $x_0(\theta)$, and define the unrolling step as

$$
x_{k+1}(\theta) = x_k(\theta) - \lambda \nabla_x\mathcal{U}(\theta,x_k(\theta))
$$

Note: we can choose whatever algorithm we want for this inner update. For example, we can choose SGD or even Adam.
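As a minimal sketch, continuing the hypothetical toy problem above, we can unroll $k$ gradient-descent steps in JAX and backpropagate straight through them with `jax.grad`:

```python
def unroll(theta, x0, y_obs, k=50, lr=0.1):
    # k explicit gradient-descent steps on the inner objective;
    # lax.scan keeps the loop jit-friendly and differentiable.
    def step(x, _):
        x_new = x - lr * jax.grad(inner_objective)(x, theta, y_obs)
        return x_new, None
    x_k, _ = jax.lax.scan(step, x0, None, length=k)
    return x_k

def outer_loss_unrolled(theta, x0, y_obs, x_true):
    x_k = unroll(theta, x0, y_obs)  # x_k(theta) ~ x*(theta)
    return outer_objective(theta, x_k, x_true)

# Differentiate through the entire unrolled trajectory:
hypergrad = jax.grad(outer_loss_unrolled)(
    0.5, jnp.zeros(3), jnp.ones(3), 0.8 * jnp.ones(3))
```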

Similarly, we can also choose an algorithm for how we update the parameters:

$$
\theta_{n+1} = \theta_n - \mu \nabla_\theta \mathcal{L}_k(\theta_n)
$$

Note: we assume that the optimal parameter for the unrolled problem, $\theta_k^*$, is a good choice for the true parameter-estimation problem, $\theta_n$.
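Putting the two levels together, a sketch of the full loop under the same hypothetical setup (plain gradient descent with step size $\mu =$ `1e-2`):

```python
theta = 0.5                        # initial parameter guess
x0 = jnp.zeros(3)
y_obs, x_true = jnp.ones(3), 0.8 * jnp.ones(3)

for n in range(100):
    # One outer gradient step on L_k(theta); the inner unroll
    # happens inside outer_loss_unrolled.
    g = jax.grad(outer_loss_unrolled)(theta, x0, y_obs, x_true)
    theta = theta - 1e-2 * g
```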


Implicit Differentiation

We focus on the argmin differentiation problem

$$
\mathbf{x}^*(\boldsymbol{\theta}) = \underset{\mathbf{x}}{\text{argmin }} \mathcal{U}(\mathbf{x},\boldsymbol{\theta})
$$

Assumptions:

  • Strongly convex in $x$
  • Smooth

Implicit Function Theorem

This states that $x^*(\theta)$ is the unique solution of

$$
\nabla_x \mathcal{U}(\theta, x^*(\theta)) = 0
$$

Note: unrolling just differentiates through the iterates $x_k(\theta)$; implicit differentiation instead works directly with the condition that characterizes the solution itself.

Implications: This holds for all $\theta$'s, so we can differentiate both sides with respect to $\theta$!

Goal: estimate $\nabla_\theta x^*(\theta)$.

Result: A Linear System!

$$
\partial_x \partial_\theta \mathcal{U}(\theta,x^*(\theta)) + \partial^2_x\mathcal{U}(\theta,x^*(\theta))\, \nabla_\theta x^*(\theta) = 0
$$

We can write this compactly, defining $A(\theta) := \partial^2_x\mathcal{U}(\theta,x^*(\theta))$ and $B(\theta) := \partial_x \partial_\theta \mathcal{U}(\theta,x^*(\theta))$:

$$
B(\theta) + A(\theta)\, \nabla_\theta x^*(\theta) = 0
$$
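For a problem as small as the toy example above, we can form $A$ and $B$ explicitly and solve this linear system directly. This is a sketch for illustration only; for realistic state dimensions, $A$ would never be materialized:

```python
def implicit_jacobian(theta, x_star, y_obs):
    # A(theta): Hessian of U wrt x at the solution, shape [Dx, Dx].
    A = jax.hessian(inner_objective, argnums=0)(x_star, theta, y_obs)
    # B(theta): derivative of grad_x U wrt theta; theta is scalar here,
    # so B has shape [Dx] ([Dx, Dtheta] in general).
    B = jax.jacobian(jax.grad(inner_objective, argnums=0), argnums=1)(
        x_star, theta, y_obs)
    # Solve A @ dxdtheta = -B for the Jacobian of x* wrt theta.
    return jnp.linalg.solve(A, -B)
```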

In Theory: we can find the Jacobian $\nabla_\theta x^*(\theta)$ by solving this linear system.

In Practice: we don't need the full Hessian. We just need Hessian-vector products! Looking at the sizes, we notice that

$$
A(\theta) \in \mathbb{R}^{D_x \times D_x}, \qquad B(\theta) \in \mathbb{R}^{D_x \times D_\theta}
$$

so $A(\theta)$ is far too large to form or invert when the state is high-dimensional.

Observe: let's look at the original loss function. Using the chain rule, we get:

$$
\nabla \mathcal{L}(\theta) = \partial_\theta f(\theta, x^*(\theta)) + \nabla_\theta x^*(\theta)^\top\, \partial_x f(\theta, x^*(\theta))
$$

And now, let's look at the linear system we want to solve:

$$
\nabla_\theta x^*(\theta) = - \left[A(\theta)\right]^{-1}B(\theta)
$$

which is awful to compute directly, since it requires inverting the $D_x \times D_x$ matrix $A(\theta)$. Plugging this back into the gradient, we get:

$$
\nabla \mathcal{L}(\theta) = \partial_\theta f(\theta, x^*(\theta)) - \left(\left[A(\theta)\right]^{-1}B(\theta)\right)^\top \partial_x f(\theta, x^*(\theta))
$$

However, looking at the sizes, the order of operations matters: $A(\theta)$ is $[D_x \times D_x]$, $B(\theta)$ is $[D_x \times D_\theta]$, and $\partial_x f$ is a $[D_x]$ vector.

This is hard: forming $[A(\theta)]^{-1}B(\theta)$ first means solving a $[D_x \times D_x]$ system with $D_\theta$ right-hand sides.

This is easy: since $A(\theta)$ is a symmetric Hessian, we can instead solve $A(\theta)\,v = \partial_x f$ first, a single system with one $[D_x]$ vector.

The remaining product, $B(\theta)^\top v$, is simply a vjp.


Note: we can also solve the linear system iteratively (e.g., with gradient descent or conjugate gradients) using Hessian-vector products instead of explicit Hessians.
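A matrix-free sketch in JAX, continuing the hypothetical toy problem: solve $A(\theta)\,v = \partial_x f$ by conjugate gradients using only Hessian-vector products, then finish with a single vjp against $B(\theta)$:

```python
from jax.scipy.sparse.linalg import cg

def hypergradient(theta, x_star, y_obs, x_true):
    # Right-hand side: gradient of the outer objective wrt x at x*.
    dLdx = jax.grad(outer_objective, argnums=1)(theta, x_star, x_true)

    # Hessian-vector product v -> A(theta) @ v, never forming A.
    grad_x = lambda x: jax.grad(inner_objective)(x, theta, y_obs)
    hvp = lambda v: jax.jvp(grad_x, (x_star,), (v,))[1]

    # Solve A v = dL/dx with conjugate gradients (A is symmetric).
    v, _ = cg(hvp, dLdx)

    # vjp against B(theta) = d_theta grad_x U, pulling v back to theta.
    _, vjp_fn = jax.vjp(
        lambda t: jax.grad(inner_objective)(x_star, t, y_obs), theta)
    (Btv,) = vjp_fn(v)

    # Total hypergradient: d_theta f - B^T A^{-1} d_x f.
    dLdtheta = jax.grad(outer_objective, argnums=0)(theta, x_star, x_true)
    return dLdtheta - Btv
```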


Computational Cost

The cost is almost the same in terms of computations.

| Unrolling | Implicit Differentiation |
| --- | --- |
| $k$ steps forward (unrolling) | $k$ steps of optimization to find $x^*(\theta)$ |
| $k$ steps backward (backprop through the unroll) | $k$ steps for the linear-system solve ($Ax - b$) to get $\nabla_\theta x^*(\theta)$ |

The memory cost, however, favors implicit differentiation: we don't have to store the intermediate iterates of the unroll for backpropagation.

Approximate Soln or Gradient

$$
\nabla\mathcal{L}_k(\theta) \approx \nabla\mathcal{L}(\theta)
$$

Do we approximate the solution (unrolling) or approximate the gradient (implicit differentiation)?

Warm Starts

We can warm-start the linear-system solve in

$$
\nabla_\theta x^*(\theta) = -[A(\theta)]^{-1}B(\theta)
$$

by starting from an already-good solution, e.g. the one from the previous outer iteration. (This is a warm start rather than preconditioning; preconditioning transforms the system itself.)
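A small sketch, reusing the `cg`-based solver above: `jax.scipy.sparse.linalg.cg` accepts an initial guess `x0`, so the previous outer iteration's solution `v_prev` (hypothetical) can seed the next solve:

```python
def solve_warm(hvp, dLdx, v_prev):
    # Seeding CG with the previous solution usually means far fewer
    # iterations, since theta changes little between outer steps.
    v, _ = cg(hvp, dLdx, x0=v_prev)
    return v
```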

Strongly Convex Solution?

Not really... (Michael's work suggests that overparameterized systems converge faster!)