4DVarNet#
Observations
We assume that the observations have the following representation:

$$
\mathbf{y}_\text{obs} = \mathbf{x}_\text{gt} + \epsilon
$$

That is, they are the true state, \(\mathbf{x}_\text{gt}\), corrupted by some noise, \(\epsilon\).
Prior
We have some way to generate samples of the state.
For example, we may have an inexpensive physical model from which we can generate samples, or a cheap emulator of the physical model used to draw samples.
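As a purely illustrative sketch (the Lorenz-63 system and all function names here are assumptions, not part of the original setup), drawing prior samples from a cheap model could look like:

import jax
import jax.numpy as jnp

def lorenz63(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    # right-hand side of the Lorenz-63 system (a cheap physical model)
    x, y, z = state
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

def sample_prior_trajectory(x0, n_steps=1000, dt=0.01):
    # generate one prior sample by integrating the model with Euler steps
    def step(state, _):
        state = state + dt * lorenz63(state)
        return state, state
    _, trajectory = jax.lax.scan(step, x0, None, length=n_steps)
    return trajectory

# draw a few prior samples from perturbed initial conditions
key = jax.random.PRNGKey(0)
x0s = jnp.array([1.0, 1.0, 1.0]) + 0.1 * jax.random.normal(key, (5, 3))
samples = jax.vmap(sample_prior_trajectory)(x0s)  # shape (5, 1000, 3)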
Prior vs. Observations
We need to relate the prior, \(\mathbf{x}_\text{init}\), and the observations, \(\mathbf{y}_\text{obs}\), to the true state, \(\mathbf{x}_\text{gt}\).
We assume that there is some function, \(\boldsymbol{f}\), that maps the observations, \(\mathbf{y}_\text{obs}\), and prior, \(\mathbf{x}_\text{init}\), to the true state, \(\mathbf{x}_\text{gt}\).
Note:
Data (Unsupervised)
The first case is when we do not have any realization of the ground truth.
In this scenario, we do not have access to the true state, \(\mathbf{x}_\text{gt}\); we only have the observations, \(\mathbf{y}_\text{obs}\), and the prior, \(\mathbf{x}_\text{init}\).
Data (Supervised)
The second case assumes that we do have realizations of the true state, \(\mathbf{x}_\text{gt}\) (for example from a simulation), in addition to the prior and the observations.
Note: This is very useful in the case of pre-training a model on emulated data.
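As a rough sketch of how the two regimes change the training objective (the notation \(\hat{\mathbf{x}}\) for the reconstructed state and the exact norms are assumptions), the losses could be written as:

$$
\begin{aligned}
\text{unsupervised:} \qquad \mathcal{L}(\boldsymbol{\theta}) &= ||\mathbf{y}_\text{obs} - \hat{\mathbf{x}}||_\Omega^2 \\
\text{supervised:} \qquad \mathcal{L}(\boldsymbol{\theta}) &= ||\mathbf{x}_\text{gt} - \hat{\mathbf{x}}||_2^2
\end{aligned}
$$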
Minimization Problem
Minimization Strategy
We are interested in finding the fixed-point solution such that

$$
\mathbf{x}^* = \boldsymbol{f}\left(\mathbf{y}_\text{obs}, \mathbf{x}^*\right)
$$
In this task, we look to minimize the above function with respect to the inputs, \(\mathbf{x}_\text{init}\), but also with respect to the parameters, \(\boldsymbol{\theta}\).
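One plausible way to write this as a bi-level problem (a sketch; the loss \(\mathcal{L}\) and the energy function \(\boldsymbol{U}\) are made precise in the following sections) is:

$$
\min_{\boldsymbol{\theta}} \; \mathcal{L}\left(\mathbf{x}^*(\boldsymbol{\theta})\right)
\quad \text{s.t.} \quad
\mathbf{x}^*(\boldsymbol{\theta}) = \arg\min_{\mathbf{x}} \; \boldsymbol{U}\left(\mathbf{x}, \mathbf{y}_\text{obs}, \boldsymbol{\Omega}, \boldsymbol{\theta}\right)
$$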
Loss Function (Generic)
Given \(N\) samples of data pairs, \(\mathcal{D} = \left\{ \mathbf{x}_\text{init}^{(i)}, \mathbf{y}_\text{obs}^{(i)}\right\}_{i=1}^N\), we have an energy function, \(\boldsymbol{U}\), which represents the generic inverse problem.
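A common generic form (a sketch; the observation operator \(\boldsymbol{\mathcal{H}}\), the weight \(\lambda\), and the regularizer \(\boldsymbol{R}_{\boldsymbol\theta}\) are notation introduced here for illustration) is:

$$
\boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}) = ||\mathbf{y} - \boldsymbol{\mathcal{H}}(\mathbf{x})||_\Omega^2 + \lambda \, \boldsymbol{R}_{\boldsymbol\theta}(\mathbf{x})
$$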
Loss Function (4D Variational)
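Consistent with the energy_fn pseudo-code further below, a sketch of the variational cost used here is a weighted sum of an observation term and a prior term built from the projection operator \(\boldsymbol{\phi}_{\boldsymbol\theta}\):

$$
\boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}) = \alpha_\text{obs} \, ||\mathbf{x} - \mathbf{y}||_\Omega^2 + \alpha_\text{prior} \, ||\mathbf{x} - \boldsymbol{\phi}_{\boldsymbol\theta}(\mathbf{x})||_2^2
$$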
Problem Setting#
Let’s take some observations, \(\mathbf{y}\), which are sparse and incomplete. We are interested in finding a state, \(\mathbf{x}\), that best matches the observations and fills in the missing observations.
This is a learning problem of the reconstruction error for the observed data. It is given by:

$$
\mathcal{L}(\boldsymbol{\theta}) = ||\mathbf{y} - \boldsymbol{I}(\mathbf{y}, \boldsymbol{\Omega}; \boldsymbol{\theta})||_\Omega^2
$$
where:
\(||\cdot||_\Omega^2\) - L2 norm evaluated on the subdomain \(\Omega\).
\(\boldsymbol{U}(\cdot)\) - the energy function
\(\boldsymbol{I}(\cdot)\) - the solution to the interpolation problem
Interpolation Problem, \(\boldsymbol{I}\)#
The interpolation problem consists of finding the best state, \(\mathbf{x}\), given the observations, \(\mathbf{y}\), on a subdomain, \(\boldsymbol{\Omega}\). This involves minimizing some energy function, \(\boldsymbol{U}\).
A fixed-point method is used to solve this problem.
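In the notation above (with \(\boldsymbol{\psi}\) denoting one iteration of the solver, introduced here for illustration):

$$
\boldsymbol{I}(\mathbf{y}, \boldsymbol{\Omega}; \boldsymbol{\theta}) = \arg\min_{\mathbf{x}} \; \boldsymbol{U}(\mathbf{x}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}),
\qquad
\mathbf{x}^{(k+1)} = \boldsymbol{\psi}\left(\mathbf{x}^{(k)}; \mathbf{y}, \boldsymbol{\Omega}\right)
$$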
Domain#
The first term in the loss function is the observation term, defined as:

$$
||\mathbf{x} - \mathbf{y}||_\Omega^2 = \sum_{i \in \Omega} \left(\mathbf{x}_i - \mathbf{y}_i\right)^2
$$

This is the evaluation of the quadratic norm restricted to the domain, \(\Omega\).
Pseudo-Code
# do some computation
x = ...
# fill the solution with observations
x[mask] = y[mask]
Operators#
ODE/PDE#
Constrained#
Optimization#
Fixed Point Algorithm#
We are interested in minimizing the energy function, \(\boldsymbol{U}\), with respect to the state. A simple scheme is to repeatedly apply an update function and then re-impose the known observations.
Pseudo-Code
# initialize the state with the prior
x = x_init

# loop through the number of iterations
for k in range(n_iterations):
    # apply the fixed-point update
    x = fn(x)
    # re-impose the known observations on the observed domain
    x[mask] = y[mask]
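A minimal runnable sketch of this loop (the moving-average smoother standing in for fn is purely a hypothetical choice):

import jax
import jax.numpy as jnp

def smoother(x):
    # hypothetical fixed-point update: average each point with its neighbours
    return (jnp.roll(x, 1) + x + jnp.roll(x, -1)) / 3.0

def fixed_point_solve(x_init, y, mask, n_iterations=50):
    x = x_init
    for k in range(n_iterations):
        # apply the fixed-point update
        x = smoother(x)
        # re-impose the known observations on the observed domain
        x = jnp.where(mask, y, x)
    return x

# toy example: a smooth signal observed at 30% of the grid points
key = jax.random.PRNGKey(0)
t = jnp.linspace(0.0, 2.0 * jnp.pi, 100)
y = jnp.sin(t)
mask = jax.random.bernoulli(key, 0.3, shape=t.shape)
x_rec = fixed_point_solve(jnp.zeros_like(t), y, mask)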
Projection-Based Iterative Update#
DINEOF, Alvera-Azcarate et al. (2016)
DINCAE, Barth et al. (2020)
Projection
We use our function, \(\boldsymbol{\phi}\), to map the state from one iteration, \(k\), to a new candidate, \(\boldsymbol{\phi}(\mathbf{x}^{(k)})\).
Update Observed Domain
Wherever we have observations, \(\mathbf{y}\), we overwrite the state, \(\mathbf{x}\), with them. This is given by the mask defined by \(\Omega\).
Update Unobserved Domain
On the unobserved domain we keep the projected values, which gives the combined update

$$
\mathbf{x}^{(k+1)} = \boldsymbol{\Omega} \odot \mathbf{y} + (1 - \boldsymbol{\Omega}) \odot \boldsymbol{\phi}\left(\mathbf{x}^{(k)}\right)
$$

where \(\boldsymbol{\Omega}\) is treated as a binary mask.
Pseudo-Code

def update(x: Array, y: Array, mask: Array, params: pytree, phi_fn: Callable):
    # keep the observations on the observed domain and the projection elsewhere
    x = mask * y + phi_fn(x, params) * (1 - mask)
    return x
Gradient-Based Iterative Update#
Let \(\boldsymbol{U}(\mathbf{x}, \mathbf{y},\boldsymbol{\Omega},\boldsymbol{\theta}) : \mathbb{R}^D \rightarrow \mathbb{R}\) be the energy function.
Gradient Step
Each iteration takes a gradient step on the energy function:

$$
\mathbf{x}^{(k+1)} = \mathbf{x}^{(k)} - \alpha \nabla_{\mathbf{x}} \boldsymbol{U}\left(\mathbf{x}^{(k)}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}\right)
$$

Update Observed Domain
As before, the observed domain can then be re-filled with the observations, \(\mathbf{y}\).
Update Unobserved Domain
The unobserved domain keeps the gradient update.
Pseudo-Code
def energy_fn(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    params: pytree,
    alpha_prior: float = 0.01,
    alpha_obs: float = 0.99,
) -> float:
    # observation term, restricted to the observed domain via the mask
    loss_obs = jnp.mean(mask * (x - y) ** 2)
    # prior term, penalizing the distance between the state and its projection phi
    loss_prior = jnp.mean((phi(x, params) - x) ** 2)
    # weighted sum of the two terms
    total_loss = alpha_obs * loss_obs + alpha_prior * loss_prior
    return total_loss
def update(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    params: pytree,
    energy_fn: Callable,
    alpha: float,
) -> Array[Batch, Dims]:
    # take a gradient step on the energy function with respect to the state
    x = x - alpha * jax.grad(energy_fn)(x, y, mask, params)
    return x
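A short usage sketch (iterating the update; re-imposing the observations after each step is an assumption carried over from the fixed-point algorithm above):

# run a fixed number of gradient iterations starting from the prior
x = x_init
for k in range(n_iterations):
    x = update(x, y, mask, params, energy_fn, alpha=0.1)
    # optionally re-impose the observations on the observed domain
    x = mask * y + (1 - mask) * x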
NN-Interpolator Iterative Update#
Let \(\boldsymbol{NN}(\mathbf{x}, \mathbf{y};\boldsymbol{\theta})\) be an arbitrary NN function.
Projection
Gradient Step
As in the pseudo-code below, the gradient of the variational cost is fed to the NN, which produces the update:

$$
\mathbf{x}^{(k+1)} = \mathbf{x}^{(k)} - \boldsymbol{NN}\left(\alpha \nabla_{\mathbf{x}} \boldsymbol{U}\left(\mathbf{x}^{(k)}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta}\right)\right)
$$
Update Observed Domain
Update Unobserved Domain
Example:
LSTM#
Pseudo-Code
def update(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    params: pytree,
    hidden_params: Tuple[Array[Dims]],
    alpha: float,
    energy_fn: Callable,
    rnn_fn: Callable,
    activation_fn: Callable,
) -> Array[Batch, Dims]:
    # gradient of the variational cost with respect to the state
    x_g = alpha * jax.grad(energy_fn)(x, y, mask, params)
    # the LSTM maps the gradient (and its hidden state) to a learned update
    g = rnn_fn(x_g, hidden_params)
    # apply the learned update
    x = x - activation_fn(g)
    return x
CNN#
def update(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    x_g: Array[Batch, Dims],
    params: pytree,
    alpha: float,
    energy_fn: Callable,
    nn_fn: Callable,
    activation_fn: Callable = lambda x: jnp.tanh(x),
) -> Array[Batch, Dims]:
    # gradient of the variational cost with respect to the state
    x_g_new = alpha * jax.grad(energy_fn)(x, y, mask, params)
    # stack the new gradient with the previous one along the feature axis
    x_g = jnp.concatenate([x_g_new, x_g], axis=-1)
    # the CNN maps the gradient history to a learned update
    g = nn_fn(x_g)
    # apply the learned update
    x = x - activation_fn(g)
    return x
Stochastic Transformation#
Deterministic#
Loss
For a deterministic transformation, \(y = \boldsymbol{f}_{\boldsymbol\theta}(\mathbf{x})\), we can use the squared error:

$$
\mathcal{L}(\boldsymbol{\theta}) = \left(y - \boldsymbol{f}_{\boldsymbol\theta}(\mathbf{x})\right)^2
$$

where:
\(y \in \mathbb{R}\)
\(\mathbf{x} \in \mathbb{R}^{D}\)
Probabilistic#
Loss
For a probabilistic transformation, \(\mathbf{y} \sim \mathcal{N}\left(\boldsymbol{\mu}_{\boldsymbol\theta}(\mathbf{x}), \boldsymbol{\Sigma}_{\boldsymbol\theta}(\mathbf{x})\right)\), we can use the negative log-likelihood:

$$
\mathcal{L}(\boldsymbol{\theta}) = -\log \mathcal{N}\left(\mathbf{y}; \boldsymbol{\mu}_{\boldsymbol\theta}(\mathbf{x}), \boldsymbol{\Sigma}_{\boldsymbol\theta}(\mathbf{x})\right)
$$

where:
\(\mathbf{y} \in \mathbb{R}^{D_y}\)
\(\mathbf{x} \in \mathbb{R}^{D_x}\)
\(\boldsymbol{\mu}_{\boldsymbol \theta}:\mathbb{R}^{D_x} \rightarrow \mathbb{R}^{D_y}\)
\(\boldsymbol{\Sigma}_{\boldsymbol \theta}:\mathbb{R}^{D_x} \rightarrow \mathbb{R}^{D_y \times D_y}\)
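A minimal sketch of the probabilistic loss (assuming, for simplicity, a diagonal covariance so that \(\boldsymbol{\Sigma}_{\boldsymbol\theta}\) reduces to a vector of variances; mu_fn and logvar_fn are hypothetical networks):

import jax.numpy as jnp

def gaussian_nll(y, mu, log_var):
    # negative log-likelihood of a diagonal Gaussian, summed over dimensions
    return 0.5 * jnp.sum(log_var + (y - mu) ** 2 / jnp.exp(log_var) + jnp.log(2.0 * jnp.pi))

def loss_fn(params, x, y, mu_fn, logvar_fn):
    # hypothetical networks predicting the mean and log-variance of y given x
    mu = mu_fn(x, params)
    log_var = logvar_fn(x, params)
    return gaussian_nll(y, mu, log_var)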
Other Perspectives#
Explicit vs Implicit Model#
Explicit Model#
We assume some function maps the initial condition to the solution:

$$
\mathbf{x}_\text{gt} = \boldsymbol{g}(\mathbf{x}_\text{init})
$$

Penalize the Loss
Since this mapping does not use the observations directly, the mismatch, \(||\mathbf{x}_\text{gt} - \boldsymbol{g}(\mathbf{x}_\text{init})||_2^2\), is added as an extra penalty term in the loss.
Conditional Explicit Model#
We assume some function maps both the observations and the initial condition to the solution:

$$
\mathbf{x}_\text{gt} = \boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs})
$$

The difference here is that we do not add any extra terms to the loss because the mapping is explicit.
Implicit Model#
We assume

$$
\begin{aligned}
\mathbf{x}_\text{gt} - \boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) = 0
\end{aligned}
$$

where \(\boldsymbol{g}\) implicitly relates the initial condition, the observations, and the true state.
This is the same as a fixed-point solution.
Literature#
Learning Latent Dynamics for Partially Observed Chaotic Systems
Joint Interpolation and Representation Learning for Irregularly Sampled Satellite-Derived Geophysics
Learning Variational Data Assimilation Models and Solvers
Intercomparison of Data-Driven and Learning-Based Interpolations of Along-Track Nadir and Wide-Swath SWOT Altimetry Observations
Variational Deep Learning for the Identification and Reconstruction of Chaotic and Stochastic Dynamical Systems from Noisy and Partial Observations, Nguyen et al. (2020)