
4DVarNet


Observations

We assume that we have the following representation:

$$\mathbf{y}_{\text{obs}} = \mathbf{x}_{\text{gt}} + \boldsymbol{\epsilon}$$

We assume the observations are the true state, $\mathbf{x}_\text{gt}$, corrupted by some noise, $\boldsymbol{\epsilon}$.
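
As a minimal sketch, assuming a hypothetical 1D sinusoidal state and a chosen noise level (`noise_std`), such observations could be simulated as:

import jax
import jax.numpy as jnp

# hypothetical 1D example: corrupt a ground-truth state with Gaussian noise
key = jax.random.PRNGKey(0)
x_gt = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 100))  # stand-in true state
noise_std = 0.1                                       # assumed noise level
y_obs = x_gt + noise_std * jax.random.normal(key, x_gt.shape)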


Prior

We assume that we have some way to generate samples of an initial state:

$$\mathbf{x}_\text{init} \sim P_\text{init}$$

e.g., we have an inexpensive physical model from which we can generate samples. We could also use a cheap emulator of the physical model to draw samples.


Prior vs. Observations

We need to relate the prior, $\mathbf{x}_\text{init}$, and the observations, $\mathbf{y}_\text{obs}$.

We assume that there is some function, $\boldsymbol{f}$, that maps the observations, $\mathbf{y}_\text{obs}$, and the prior, $\mathbf{x}_\text{init}$, to the true state, $\mathbf{x}_\text{gt}$.

$$\mathbf{x}_\text{gt} = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}; \boldsymbol{\theta})$$



Data (Unsupervised)

The first case is when we do not have any realization of the ground truth.

$$\mathcal{D}_{\text{unsupervised}} = \left\{ \mathbf{x}_\text{init}^{(i)},\; \mathbf{y}_\text{obs}^{(i)}\right\}_{i=1}^N$$

In this scenario, we do not have access to any realization of the true state, $\mathbf{x}_\text{gt}$.


Data (Supervised)

The second case assumes that we do have realizations of the true state, in addition to the prior and the observations.

$$\mathcal{D}_{\text{supervised}} = \left\{ \mathbf{x}_\text{init}^{(i)},\; \mathbf{x}_\text{gt}^{(i)},\; \mathbf{y}_\text{obs}^{(i)}\right\}_{i=1}^N$$

Note: This is very useful in the case of pre-training a model on emulated data.


Minimization Problem

We define a residual function, $\boldsymbol{g}$, which measures the mismatch between the output of $\boldsymbol{f}$ and the true state:

$$\boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}, \mathbf{x}_\text{gt}; \boldsymbol{\theta}) = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}; \boldsymbol{\theta}) - \mathbf{x}_\text{gt}$$

Minimization Strategy

We are interested in finding the fixed-point solution such that

$$\mathbf{x}^{(k+1)} = \boldsymbol{f}(\mathbf{x}^{(k)}, \mathbf{y}_\text{obs})$$

In this task, we look to minimize the residual function with respect to the state, $\mathbf{x}$, while also finding the best parameters, $\boldsymbol{\theta}$.

$$\mathbf{x}^*(\boldsymbol{\theta}) = \argmin_{\mathbf{x}} \; \boldsymbol{g}(\mathbf{x}, \mathbf{y}_\text{obs}, \mathbf{x}_\text{gt}; \boldsymbol{\theta})$$

Loss Function (Generic)

Given $N$ samples of data pairs, $\mathcal{D} = \left\{ \mathbf{x}_\text{init}^{(i)}, \mathbf{y}_\text{obs}^{(i)}\right\}_{i=1}^N$, we have an energy function, $\boldsymbol{U}$, which represents the generic inverse problem.

$$\boldsymbol{U}(\mathbf{x}_\text{init}^{(i)}, \mathbf{y}_\text{obs}^{(i)}) = \mathcal{L}_\text{Data}(\mathbf{x}_\text{init}^{(i)}, \mathbf{y}_\text{obs}^{(i)}) + \lambda \mathcal{R}(\mathbf{x}_\text{init}^{(i)})$$
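
A minimal sketch of this generic energy in JAX, assuming a quadratic data term and a stand-in roughness regularizer for $\mathcal{R}$ (the 4DVarNet prior term appears in the next section):

import jax.numpy as jnp

def energy(x, y_obs, lam=0.1):
    # data fidelity term
    loss_data = jnp.mean((x - y_obs) ** 2)
    # stand-in regularizer: penalize roughness of the state
    reg = jnp.mean(jnp.diff(x) ** 2)
    return loss_data + lam * reg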

Loss Function (4D Variational)

In the 4DVarNet setting, the energy combines an observation term restricted to the observed domain, $\boldsymbol{\Omega}$, with a prior term given by a learned projection operator, $\boldsymbol{\phi}$ (this matches the `energy_fn` pseudo-code below):

$$\boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\Omega}; \boldsymbol{\theta}) = ||\mathbf{y} - \mathbf{x}||^2_{\boldsymbol{\Omega}} + \lambda\, ||\mathbf{x} - \boldsymbol{\phi}(\mathbf{x}; \boldsymbol{\theta})||^2_2$$

Problem Setting

Let’s take some observations, $\mathbf{y}$, which are sparse and incomplete. We are interested in finding a state, $\mathbf{x}$, that best matches the observations and fills in the missing values.

This is a learning problem: we minimize the reconstruction error over the observed data (a training-loop sketch follows the definitions below). It is given by:

$$\theta^* = \argmin_{\theta}\sum_{i=1}^N || \mathbf{y}_i - \boldsymbol{I}(\boldsymbol{U}(\mathbf{x};\theta), \mathbf{y}_i, \Omega_i)||_{\Omega_i}^2$$

where:

  • $||\cdot||_\Omega^2$ - the L2 norm evaluated on the subdomain, $\Omega$.
  • $\boldsymbol{U}(\cdot)$ - the energy function.
  • $\boldsymbol{I}(\cdot)$ - the solution to the interpolation problem.
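
A minimal sketch of this bi-level training loop, assuming a hypothetical inner solver `interpolate` (playing the role of $\boldsymbol{I}$) and using optax for the outer update of $\boldsymbol{\theta}$; `params`, `x_init`, `y_obs`, and `mask` are assumed to be defined elsewhere:

import jax
import jax.numpy as jnp
import optax

def outer_loss(params, x_init, y_obs, mask):
    # inner problem: solve the interpolation with the current parameters
    x_star = interpolate(x_init, y_obs, mask, params)  # hypothetical solver I
    # reconstruction error restricted to the observed domain
    return jnp.sum(mask * (y_obs - x_star) ** 2)

# one outer optimization step on the parameters
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
grads = jax.grad(outer_loss)(params, x_init, y_obs, mask)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)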

Interpolation Problem, $\boldsymbol{I}$

The interpolation problem is to find the best state, $\mathbf{x}$, given the observations, $\mathbf{y}$, on a subdomain, $\boldsymbol{\Omega}$. This involves minimizing some energy function, $\boldsymbol{U}$.

$$\boldsymbol{I}(\boldsymbol{U}, \mathbf{y}, \boldsymbol{\Omega}) \coloneqq \mathbf{x}^* = \argmin_{\mathbf{x}} \boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\theta}, \boldsymbol{\Omega})$$

4DVarNet uses a fixed-point method to solve this problem.


Domain

The first term in the loss function is the observation term defined as:

$$||\cdot||_{\Omega}^2$$

This is the evaluation of the quadratic norm restricted to the domain, $\Omega$.
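
A one-line sketch of this masked norm in JAX, assuming the subdomain $\Omega$ is encoded as a binary mask:

import jax.numpy as jnp

def masked_sq_norm(x, mask):
    # quadratic norm restricted to the observed domain (mask == 1)
    return jnp.sum(mask * x ** 2)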

Pseudo-Code


# do some computation
x = ... 

# fill the solution with observations
x[mask] = y[mask]

Operators

ODE/PDE

Constrained

$$\boldsymbol{\psi}(\mathbf{x}; \boldsymbol{\theta}) = \cdots$$


Optimization

Fixed Point Algorithm

We are interested in minimizing the energy function subject to matching the observations on the observed domain.

$$\mathbf{x}^* = \argmin_{\mathbf{x}} \boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}) \quad \text{s.t.} \quad \mathbf{x}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})$$

Pseudo-Code

# initialize x
x = x_init

# loop through the number of iterations
for k in range(n_iterations):

    # apply the fixed-point map (e.g. the projection operator)
    x = phi_fn(x)

    # overwrite the observed domain with the known observations
    x[mask] = y[mask]
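
Since JAX arrays are immutable, a runnable version of this loop would replace the in-place assignment with `jnp.where`; a minimal sketch, assuming `phi_fn` is the fixed-point map:

import jax.numpy as jnp

def fixed_point_solve(x_init, y, mask, phi_fn, n_iterations=10):
    x = x_init
    for k in range(n_iterations):
        # apply the fixed-point map
        x = phi_fn(x)
        # overwrite the observed domain with the observations
        x = jnp.where(mask, y, x)
    return x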

Projection-Based Iterative Update

  • DINEOF, Alvera-Azcarate et al. (2016)
  • DINCAE, Barth et al. (2020)

Projection

We use our function, $\boldsymbol{\psi}$, to map the state forward one iteration, $k$.

$$\tilde{\mathbf{x}}^{(k+1)} = \boldsymbol{\psi}(\mathbf{x}^{(k)}; \boldsymbol{\theta})$$

Update Observed Domain

We update the state, $\mathbf{x}$, wherever we have observations, $\mathbf{y}$. This is given by the domain, $\boldsymbol{\Omega}$ (i.e. a mask).

$$\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})$$

Update Unobserved Domain

$$\mathbf{x}^{(k+1)}(\bar{\boldsymbol{\Omega}}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol{\Omega}})$$

Pseudo-code

def update(x: Array, y: Array, mask: Array, params: pytree, phi_fn: Callable) -> Array:

    # observed domain: keep the observations; unobserved domain: use the projection
    x = mask * y + phi_fn(x, params) * (1 - mask)

    return x
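
For a fixed number of iterations, this update can be rolled out with `jax.lax.fori_loop`; a sketch assuming `phi_fn`, `params`, `mask`, `y`, `x_init`, and `n_iterations` are defined as above:

import jax

x_final = jax.lax.fori_loop(
    0, n_iterations,
    lambda k, x: mask * y + phi_fn(x, params) * (1 - mask),
    x_init,
)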

Gradient-Based Iterative Update

Let $\boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}) : \mathbb{R}^D \rightarrow \mathbb{R}$ be the energy function.

Gradient Step

$$\tilde{\mathbf{x}}^{(k+1)} = \mathbf{x}^{(k)} - \lambda \boldsymbol{\nabla}_{\mathbf{x}^{(k)}}\boldsymbol{U}(\mathbf{x}^{(k)}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta})$$

Update Observed Domain

$$\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})$$

Update Unobserved Domain

$$\mathbf{x}^{(k+1)}(\bar{\boldsymbol{\Omega}}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol{\Omega}})$$

Pseudo-Code

def energy_fn(
    x: Array[Batch, Dims], 
    y: Array[Batch, Dims], 
    mask: Array[Batch, Dims], 
    params: pytree,
    alpha_prior: float = 0.01,
    alpha_obs: float = 0.99
    ) -> float:

    # observation term, restricted to the observed domain
    loss_obs = jnp.mean(mask * (x - y) ** 2)

    # prior term, given by the projection operator phi
    loss_prior = jnp.mean((phi(x, params) - x) ** 2)

    total_loss = alpha_obs * loss_obs + alpha_prior * loss_prior

    return total_loss

def update(
    x: Array[Batch, Dims], 
    y: Array[Batch, Dims], 
    mask: Array[Batch, Dims], 
    params: pytree, 
    energy_fn: Callable,
    alpha: float
    ) -> Array[Batch, Dims]:

    # gradient step on the energy (variational cost)
    x = x - alpha * jax.grad(energy_fn)(x, y, mask, params)

    return x
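
Putting the pieces together, one gradient-based solve alternates the gradient step with the observed-domain reset; a minimal sketch built on the two functions above:

import jax.numpy as jnp

def gradient_solve(x_init, y, mask, params, alpha=0.1, n_iterations=10):
    x = x_init
    for k in range(n_iterations):
        # gradient step on the energy function
        x = update(x, y, mask, params, energy_fn, alpha)
        # reset the observed domain to the observations
        x = jnp.where(mask, y, x)
    return x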

NN-Interpolator Iterative Update

Let $\boldsymbol{NN}(\mathbf{x}, \mathbf{y}; \boldsymbol{\theta})$ be an arbitrary neural network.

Gradient

$$\dot{\mathbf{x}}^{(k)} = \lambda \boldsymbol{\nabla}_{\mathbf{x}^{(k)}}\boldsymbol{U}(\mathbf{x}^{(k)}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta})$$

Gradient Step

$$\tilde{\mathbf{x}}^{(k+1)} = \mathbf{x}^{(k)} - \boldsymbol{NN}\left( \mathbf{x}^{(k)}, \dot{\mathbf{x}}^{(k)}; \boldsymbol{\theta}\right)$$

Update Observed Domain

$$\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})$$

Update Unobserved Domain

$$\mathbf{x}^{(k+1)}(\bar{\boldsymbol{\Omega}}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol{\Omega}})$$

Example:

$$\boldsymbol{NN}\left( \mathbf{x}^{(k)}, \tilde{\mathbf{x}}^{(k+1)}; \boldsymbol{\theta}\right) = \tilde{\boldsymbol{NN}}\left( \mathbf{x}^{(k)} - \tilde{\mathbf{x}}^{(k+1)}; \boldsymbol{\theta}\right)$$

LSTM

Pseudo-Code

def update(
    x: Array[Batch, Dims], 
    y: Array[Batch, Dims], 
    mask: Array[Batch, Dims], 
    params: pytree, 
    hidden_params: Tuple[Array[Dims]],
    alpha: float,
    energy_fn: Callable,
    rnn_fn: Callable,
    activation_fn: Callable
    ) -> Tuple[Array[Batch, Dims], Tuple[Array[Dims]]]:

    # gradient of the variational cost
    x_g = alpha * jax.grad(energy_fn)(x, y, mask, params)

    # LSTM maps the gradient to an update, carrying its hidden state
    g, hidden_params = rnn_fn(x_g, hidden_params)

    x = x - activation_fn(g)

    return x, hidden_params

CNN

def update(
    x: Array[Batch, Dims], 
    y: Array[Batch, Dims], 
    mask: Array[Batch, Dims], 
    x_g: Array[Batch, Dims],
    params: pytree, 
    alpha: float,
    energy_fn: Callable,
    nn_fn: Callable,
    activation_fn: Callable = jnp.tanh
    ) -> Array[Batch, Dims]:

    # gradient of the variational cost
    x_g_new = alpha * jax.grad(energy_fn)(x, y, mask, params)

    # stack the new gradient with the previous one along the feature axis
    x_g = jnp.concatenate([x_g_new, x_g], axis=-1)

    # CNN maps the stacked gradients to an update
    g = nn_fn(x_g)

    x = x - activation_fn(g)

    return x

Stochastic Transformation

Deterministic

$$\mathbf{x}^{(k+1)} = \boldsymbol{f}(\mathbf{x}^{(k)}, \mathbf{y}_\text{obs})$$

Loss

$$\mathcal{L} = ||\mathbf{x} - \boldsymbol{f}(\mathbf{x}, \mathbf{y})||_2^2$$

Probabilistic

$$p(\mathbf{x}^{(k+1)}|\mathbf{x}^{(k)}) = \mathcal{N}\left(\mathbf{x}^{(k+1)} \,|\, \boldsymbol{\mu}_{\boldsymbol \theta}(\mathbf{x}^{(k)}), \boldsymbol{\sigma}^2_{\boldsymbol \theta}(\mathbf{x}^{(k)})\right)$$

Loss

$$-\log p(y|\mathbf{x}) = \frac{1}{2}\log \boldsymbol{\sigma}^2_{\boldsymbol \theta}(\mathbf{x}) + \frac{(y - \boldsymbol{\mu}_{\boldsymbol \theta}(\mathbf{x}))^2}{2\boldsymbol{\sigma}^2_{\boldsymbol \theta}(\mathbf{x})} + \text{constant}$$

where:

  • $y \in \mathbb{R}$
  • $\mathbf{x} \in \mathbb{R}^{D}$

The multivariate case is analogous (a code sketch of the scalar case follows the definitions below):

$$-\log p(\mathbf{y}|\mathbf{x}) = \frac{1}{2} \log\det \boldsymbol{\Sigma}_{\boldsymbol \theta}(\mathbf{x}) + \frac{1}{2}||\mathbf{y} - \boldsymbol{\mu}_{\boldsymbol \theta}(\mathbf{x})||^2_{\boldsymbol{\Sigma}_{\boldsymbol \theta}^{-1}(\mathbf{x})} + \text{constant}$$

where:

  • $\mathbf{y} \in \mathbb{R}^{D_y}$
  • $\mathbf{x} \in \mathbb{R}^{D_x}$
  • $\boldsymbol{\mu}_{\boldsymbol \theta}: \mathbb{R}^{D_x} \rightarrow \mathbb{R}^{D_y}$
  • $\boldsymbol{\Sigma}_{\boldsymbol \theta}: \mathbb{R}^{D_x} \rightarrow \mathbb{R}^{D_y \times D_y}$
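
A sketch of the scalar negative log-likelihood above in JAX (dropping the constant), assuming the network outputs a mean and a log-variance for numerical stability:

import jax.numpy as jnp

def gaussian_nll(y, mu, log_var):
    # 0.5 * log(sigma^2) + (y - mu)^2 / (2 * sigma^2), up to a constant
    return 0.5 * log_var + (y - mu) ** 2 / (2.0 * jnp.exp(log_var))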

Other Perspectives

Explicit vs Implicit Model

Explicit Model

$$\mathbf{x}_\text{gt} = \boldsymbol{f}(\mathbf{x}_\text{init})$$

Penalize the Loss

$$\mathcal{L} = \mathcal{L}_\text{model}(\mathbf{x}_\text{gt}, \tilde{\mathbf{x}}_\text{init}) + \mathcal{L}_\text{data}(\mathbf{x}_\text{gt}(\boldsymbol{\Omega}), \mathbf{y}_\text{obs}(\boldsymbol{\Omega}))$$

Conditional Explicit Model

We assume some function maps both the observations and the initial condition to the solution.

$$\mathbf{x}_\text{gt} = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs})$$

The difference here is that we do not add any extra terms to the loss because the model is explicit.

Implicit Model

We assume

$$\boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) = 0$$

where:

$$\boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) - \mathbf{x}_\text{init}$$

This is the same as a fixed-point solution.

$$\begin{aligned} \mathbf{x}_\text{gt} &= \mathbf{x}_\text{init}\\ \mathbf{x}_\text{gt} &= \boldsymbol{f}(\mathbf{x}_\text{gt}, \mathbf{y}_\text{obs}) \end{aligned}$$
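
A sketch of checking this fixed-point condition numerically, assuming some mapping `f` and a tolerance `tol`:

import jax.numpy as jnp

def is_fixed_point(x, y_obs, f, tol=1e-6):
    # the residual g(x, y) = f(x, y) - x vanishes at a fixed point
    residual = f(x, y_obs) - x
    return jnp.linalg.norm(residual) < tol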

Literature

  • Learning Latent Dynamics for Partially Observed Chaotic Systems
  • Joint Interpolation and Representation Learning for Irregularly Sampled Satellite-Derived Geophysics
  • Learning Variational Data Assimilation Models and Solvers
  • Intercomparison of Data-Driven and Learning-Based Interpolations of Along-Track Nadir and Wide-Swath SWOT Altimetry Observations
  • Variational Deep Learning for the Identification and Reconstruction of Chaotic and Stochastic Dynamical Systems from Noisy and Partial Observations - Nguyen et al. (2020) - Paper | Code