4DVarNet#
Observations
We assume that we have the following representation: the observations, $y$, are noisy, partial measurements of the true state, $x$, available only on an observed sub-domain, $\Omega$,

$$ y = x\big|_{\Omega} + \varepsilon $$

where $\varepsilon$ is the observation noise.
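A minimal sketch of how such masked, noisy observations can be represented as arrays (the grid size, noise level, and sampling rate are made up for illustration; the `mask`/`y` convention matches the pseudo-code later in these notes):

```python
import jax
import jax.numpy as jnp

# hypothetical true state on a small 2D grid (unknown in practice)
x_true = jax.random.normal(jax.random.PRNGKey(0), (64, 64))
# observed sub-domain Omega: roughly 20% of the grid points are observed
mask = jax.random.bernoulli(jax.random.PRNGKey(1), p=0.2, shape=(64, 64))
# noisy observations, defined only on Omega (zero elsewhere)
noise = 0.01 * jax.random.normal(jax.random.PRNGKey(2), (64, 64))
y = jnp.where(mask, x_true + noise, 0.0)
```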
Prior
We have some way to generate samples.
e.g., an inexpensive physical model from which we can generate samples. We could also use a cheap emulator of the physical model to draw samples.
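As a toy illustration (the diffusion-like update below is a made-up stand-in, not any particular geophysical model), drawing prior samples from a cheap model might look like:

```python
import jax
import jax.numpy as jnp

def cheap_model_step(x, dt=0.1):
    # made-up inexpensive "physics": a diffusion-like update on a 1D field
    return x + dt * (jnp.roll(x, 1) + jnp.roll(x, -1) - 2.0 * x)

def sample_state(key, n_steps=50, dim=64):
    # integrate the cheap model from a random initial condition to draw one sample
    x = jax.random.normal(key, (dim,))
    for _ in range(n_steps):
        x = cheap_model_step(x)
    return x

# a small ensemble of prior samples
keys = jax.random.split(jax.random.PRNGKey(0), 8)
prior_samples = jnp.stack([sample_state(k) for k in keys])
```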
Prior vs. Observations
We need to relate the prior to the observations, $y$.

We assume that there is some function, $\Phi$, which encodes the prior: a state consistent with the prior is (approximately) a fixed point of it, i.e. $x \approx \Phi(x)$.
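As a minimal, hypothetical parameterization of such an operator (a single learned layer standing in for the convolutional architectures typically used), $\Phi$ might be sketched as:

```python
import jax.numpy as jnp

def phi(x, params):
    # hypothetical prior operator: maps a state estimate towards one that is
    # consistent with the prior; x ~ phi(x) should hold for "physical" states
    return jnp.tanh(x @ params["W"] + params["b"])
```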
Data (Unsupervised)
The first case is when we do not have any realization of the ground truth.
In this scenario, we do not have access to the true state; we only have the observations, $y$, on the observed sub-domain, $\Omega$.
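In that case the only training signal comes from the observations themselves; a sketch (assuming `x_hat` is the solver output and `mask` marks $\Omega$):

```python
import jax.numpy as jnp

def unsupervised_loss(x_hat, y, mask):
    # reconstruction error evaluated only on the observed sub-domain, Omega
    return jnp.sum(mask * (x_hat - y) ** 2) / jnp.sum(mask)
```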
Data (Supervised)
The second case assumes that, while we do not have realizations of the true state of the real system, we do have simulated states from our prior model that we can treat as ground truth, along with the corresponding simulated observations.
Note: This is very useful in the case of pre-training a model on emulated data.
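With (emulated) ground truth available, the reconstruction can be scored over the whole domain, not just $\Omega$; a minimal sketch:

```python
import jax.numpy as jnp

def supervised_loss(x_hat, x_true):
    # with simulated/emulated ground truth, score the reconstruction everywhere,
    # including the unobserved domain
    return jnp.mean((x_hat - x_true) ** 2)
```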
Minimization Problem
Minimization Strategy
We are interested in finding the fixed-point solution, i.e. the state at which applying the iterative update no longer changes the estimate.

In this task, we are looking to minimize the variational cost with respect to the input state, $x$.
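In symbols (the operator $\mathcal{K}$ below is introduced here only as a placeholder for whichever iterative update is used: projection-based, gradient-based, or learned):

$$ \tilde{x} \;=\; \arg\min_{x} \; U_{\Phi}(x; y, \Omega), \qquad x^{(k+1)} = \mathcal{K}\big(x^{(k)}\big), \qquad \tilde{x} = \mathcal{K}\big(\tilde{x}\big) $$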
Loss Function (Generic)
Given an energy function, $U$, and a solver that produces a reconstruction from it, the loss measures how well that reconstruction matches the available data.
Loss Function (4D Variational)

For the 4D variational case the energy is (cf. the `energy_fn` pseudo-code below):

$$ U_{\Phi}(x; y, \Omega) \;=\; \lambda_{1}\,\big\|x - y\big\|^{2}_{\Omega} \;+\; \lambda_{2}\,\big\|x - \Phi(x)\big\|^{2} $$

with an observation term restricted to $\Omega$ and a prior term penalizing departures from a fixed point of $\Phi$.
Problem Setting#
Let’s take some observations, $y$, defined on the observed sub-domain, $\Omega$.
This is a learning problem over the reconstruction error for the observed data (a compact sketch in code follows the definitions below). It is given by:

$$ \hat{\theta} \;=\; \arg\min_{\theta} \; \big\| y - \tilde{x} \big\|^{2}_{\Omega} \quad \text{s.t.} \quad \tilde{x} \;=\; \arg\min_{x} \; U_{\Phi}(x; y, \Omega) $$

where:

- $\|\cdot\|^{2}_{\Omega}$: the L2 norm evaluated on the subdomain, $\Omega$
- $U_{\Phi}$: the energy function, with parameters $\theta$
- $\tilde{x}$: the solution to the interpolation problem
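A compact, hypothetical end-to-end sketch of this bi-level problem, with an inner gradient-based solve of the variational cost nested inside the outer reconstruction loss (the operator `phi`, the shapes, and the weights are all illustrative, not the authors' implementation):

```python
import jax
import jax.numpy as jnp

def phi(x, params):
    # hypothetical prior operator (a single learned layer as a stand-in for a NN)
    return jnp.tanh(x @ params["W"] + params["b"])

def energy(x, y, mask, params, lmbda=0.5):
    # variational cost: observation term on Omega + prior (fixed-point) term
    obs = jnp.sum(mask * (x - y) ** 2)
    prior = jnp.sum((x - phi(x, params)) ** 2)
    return obs + lmbda * prior

def solver(y, mask, params, n_iterations=10, alpha=0.1):
    # inner problem: gradient descent on the variational cost, starting from the observations
    x = jnp.where(mask, y, 0.0)
    for _ in range(n_iterations):
        x = x - alpha * jax.grad(energy)(x, y, mask, params)
    return x

def outer_loss(params, y, mask):
    # learning problem: reconstruction error of the interpolated state on the observed domain
    x_tilde = solver(y, mask, params)
    return jnp.sum(mask * (x_tilde - y) ** 2) / jnp.sum(mask)

# training follows gradients of the outer loss w.r.t. the prior parameters
grad_fn = jax.grad(outer_loss)
```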
Interpolation Problem, $\tilde{x}$#

This tries to solve the interpolation problem of finding the best state, $\tilde{x}$, i.e. the minimizer of the energy $U_{\Phi}$ given the observations.

They use a fixed-point method to solve this problem.
Domain#
The first term in the loss function is the observation term, defined as:

$$ \big\|x - y\big\|^{2}_{\Omega} \;=\; \sum_{i \in \Omega} \big(x_{i} - y_{i}\big)^{2} $$

This is the evaluation of the quadratic norm restricted to the domain, $\Omega$, where observations are available.
Pseudo-Code
```python
# do some computation
x = ...
# fill the solution with observations
x[mask] = y[mask]
```
Operators#
ODE/PDE#
Constrained#
Optimization#
Fixed Point Algorithm#
We are interested in minimizing this function.
Pseudo-Code
```python
# initialize x
x = x_init
# loop through the number of iterations
for k in range(n_iterations):
    # update via the fixed-point (prior) operator
    x = fn(x)
    # update via known observations
    x[mask] = y[mask]
```
Projection-Based Iterative Update#
DINEOF, Alvera-Azcarate et al. (2016)
DINCAE, Barth et al. (2020)
Projection

We will use our prior operator, $\Phi$, to project the current state estimate, $x \mapsto \Phi(x)$.

Update Observed Domain

We will update the state on the observed domain, $\Omega$, directly with the observations, $y$.

Update Unobserved Domain

On the unobserved domain we keep the projected values, $\Phi(x)$.
Pseudo-code
```python
def update(x: Array, y: Array, mask: Array, params: pytree, phi_fn: Callable):
    # observations on the observed domain, projection phi(x) on the unobserved domain
    x = mask * y + phi_fn(x, params) * (1 - mask)
    return x
```
Gradient-Based Iterative Update#
Let $U(x; y, \Omega)$ be the variational cost defined above.

Gradient Step

At each iteration we take a gradient step on the energy, $x \leftarrow x - \alpha\,\nabla_{x} U(x; y, \Omega)$.

Update Observed Domain

The gradient of the observation term pulls the state towards the observations, $y$, on $\Omega$.

Update Unobserved Domain

The gradient of the prior term pulls the state towards a fixed point of $\Phi$ on the unobserved domain.
Pseudo-Code
```python
def energy_fn(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    params: pytree,
    alpha_prior: float = 0.01,
    alpha_obs: float = 0.99,
) -> float:
    # observation term, restricted to the observed domain via the mask
    loss_obs = jnp.mean(mask * (x - y) ** 2)
    # prior (fixed-point) term; phi is the prior operator defined elsewhere
    loss_prior = jnp.mean((phi(x, params) - x) ** 2)
    total_loss = alpha_obs * loss_obs + alpha_prior * loss_prior
    return total_loss


def update(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    params: pytree,
    energy_fn: Callable,
    alpha: float,
) -> Array[Batch, Dims]:
    # gradient step on the variational cost w.r.t. the state
    x = x - alpha * jax.grad(energy_fn)(x, y, mask, params)
    return x
```
NN-Interpolator Iterative Update#
Let $U(x; y, \Omega)$ be the variational cost as before; instead of applying the raw gradient step, a trainable solver learns the update.

Projection

The prior operator, $\Phi$, still defines the variational cost being minimized.

Gradient Step

The gradient of the variational cost, $\nabla_{x} U$, is fed to a neural network (e.g. an LSTM or a CNN) whose output is the update applied to the state.

Update Observed Domain

As before, the observation term drives the state towards $y$ on $\Omega$.

Update Unobserved Domain

As before, the prior term acts on the unobserved domain; the difference is that the step itself is now learned.

Example:
LSTM#
Pseudo-Code
```python
def update(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    params: pytree,
    hidden_params: Tuple[Array[Dims]],
    alpha: float,
    energy_fn: Callable,
    rnn_fn: Callable,
    activation_fn: Callable,
) -> Tuple[Array[Batch, Dims], Tuple[Array[Dims]]]:
    # gradient of the variational cost w.r.t. the state
    x_g = alpha * jax.grad(energy_fn)(x, y, mask, params)
    # the LSTM maps the gradient to a learned update; its hidden state is
    # carried across solver iterations
    g, hidden_params = rnn_fn(x_g, hidden_params)
    x = x - activation_fn(g)
    return x, hidden_params
```
CNN#
```python
def update(
    x: Array[Batch, Dims],
    y: Array[Batch, Dims],
    mask: Array[Batch, Dims],
    x_g: Array[Batch, Dims],
    params: pytree,
    alpha: float,
    energy_fn: Callable,
    nn_fn: Callable,
    activation_fn: Callable = jnp.tanh,
) -> Tuple[Array[Batch, Dims], Array[Batch, Dims]]:
    # gradient of the variational cost w.r.t. the state
    x_g_new = alpha * jax.grad(energy_fn)(x, y, mask, params)
    # stack the new and previous gradients along the feature axis for the CNN
    g = nn_fn(jnp.concatenate([x_g_new, x_g], axis=-1))
    x = x - activation_fn(g)
    # return the new gradient so it can be passed to the next iteration
    return x, x_g_new
```
Stochastic Transformation#
Deterministic#
Loss
Probabilistic#
Loss
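As a rough, assumed illustration of the two cases: a deterministic transformation is typically scored with a squared error, while a probabilistic one predicts a mean and a (log-)variance and is scored with a Gaussian negative log-likelihood:

```python
import jax.numpy as jnp

def deterministic_loss(x_hat, x_true):
    # squared-error loss for a deterministic reconstruction
    return jnp.mean((x_hat - x_true) ** 2)

def gaussian_nll(mu, log_var, x_true):
    # negative log-likelihood for a Gaussian predictive distribution with
    # predicted mean `mu` and log-variance `log_var`
    return jnp.mean(0.5 * (log_var + (x_true - mu) ** 2 / jnp.exp(log_var)))
```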
Other Perspectives#
Explicit vs Implicit Model#
Explicit Model#
Penalize the Loss

We assume some function maps the observations directly to the solution; the prior is then enforced by adding penalty terms to the loss.
Conditional Explicit Model#
We assume some function maps both the observations and the initial condition to the solution.
The difference here is that we do not add any extra terms to the loss, because it is explicit.
Implicit Model#
We assume the solution is defined implicitly, as the solution of

$$ \tilde{x} \;=\; f_{\theta}\big(\tilde{x}, y\big) $$

where $f_{\theta}$ is a learned map taking the current state estimate and the observations.

This is the same as a fixed-point solution.
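A small sketch contrasting the two views (the networks `f` and `f_implicit` and their parameters are hypothetical placeholders):

```python
import jax.numpy as jnp

def f(y, params):
    # hypothetical explicit network: maps observations directly to the solution
    return jnp.tanh(y @ params["W"] + params["b"])

def f_implicit(x, y, params):
    # hypothetical implicit network: maps (current state, observations) to a new state
    return jnp.tanh(x @ params["W"] + y @ params["V"] + params["b"])

def explicit_solution(y, mask, params):
    # explicit model: one forward pass gives the solution
    return f(mask * y, params)

def implicit_solution(y, mask, params, n_iterations=20):
    # implicit model: the solution is the fixed point of f_implicit, reached by iteration
    x = mask * y
    for _ in range(n_iterations):
        x = f_implicit(x, mask * y, params)
    return x
```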
Literature#
- Learning Latent Dynamics for Partially Observed Chaotic Systems
- Joint Interpolation and Representation Learning for Irregularly Sampled Satellite-Derived Geophysics
- Learning Variational Data Assimilation Models and Solvers
- Intercomparison of Data-Driven and Learning-Based Interpolations of Along-Track Nadir and Wide-Swath SWOT Altimetry Observations
- Variational Deep Learning for the Identification and Reconstruction of Chaotic and Stochastic Dynamical Systems from Noisy and Partial Observations - Nguyen et al. (2020) - Paper | Code