Observations
We assume that we have the following representation:
$$
\mathbf{y}_{\text{obs}} = \mathbf{x}_{\text{gt}} + \epsilon
$$
We assume the observations are the true state, $\mathbf{x}_\text{gt}$, corrupted by some noise, $\epsilon$.
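As a minimal sketch of this observation model, noisy observations can be simulated from a known state. The text only says "some noise", so the additive Gaussian noise, the helper name `simulate_observations`, and the sine-wave state are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def simulate_observations(key, x_gt, noise_std=0.1):
    """Return y_obs = x_gt + eps, with eps ~ N(0, noise_std^2) (Gaussian noise is an assumption)."""
    eps = noise_std * jax.random.normal(key, x_gt.shape)
    return x_gt + eps


# toy "true state": one period of a sine wave (illustrative only)
x_gt = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 50))
y_obs = simulate_observations(jax.random.PRNGKey(0), x_gt, noise_std=0.1)
```

Setting `noise_std=0.0` recovers the true state exactly, which is a convenient sanity check.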
Prior
We have some way to generate samples.
$$
\mathbf{x}_\text{init} \sim P_\text{init}
$$
For example, we may have an inexpensive physical model from which we can generate samples. We could also use a cheap emulator of the physical model to draw samples.
Prior vs. Observations
We need to relate the prior, $\mathbf{x}_\text{init}$, and the observations, $\mathbf{y}_\text{obs}$.
We assume that there is some function, $\boldsymbol{f}$, that maps the observations, $\mathbf{y}_\text{obs}$, and prior, $\mathbf{x}_\text{init}$, to the true state, $\mathbf{x}_\text{gt}$.
$$
\mathbf{x}_\text{gt} = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}; \boldsymbol{\theta})
$$
Note:
Data (Unsupervised)
The first case is when we do not have any realization of the ground truth.
$$
\mathcal{D}_{\text{unsupervised}} = \left\{ \mathbf{x}_{\text{init}}^{(i)}, \mathbf{y}_{\text{obs}}^{(i)}\right\}
$$
In this scenario, we do not have
Data (Supervised)
The second case assumes that we do have realizations of the true state, in addition to our prior and observations.
$$
\mathcal{D}_{\text{supervised}} = \left\{ \mathbf{x}_{\text{init}}^{(i)},\; \mathbf{x}_{\text{gt}}^{(i)},\; \mathbf{y}_{\text{obs}}^{(i)}\right\}
$$
Note: This is very useful when pre-training a model on emulated data.
Minimization Problem
$$
\boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}, \mathbf{x}_\text{gt}; \boldsymbol{\theta}) = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}; \boldsymbol{\theta}) - \mathbf{x}_\text{gt}
$$
Minimization Strategy
We are interested in finding the fixed-point solution such that
$$
\mathbf{x}^{(k+1)} = \boldsymbol{f}(\mathbf{x}^{(k)}, \mathbf{y}_{\text{obs}})
$$
In this task, we are looking to minimize the above function with respect to the inputs, $\mathbf{x}_\text{init}$, but also with the best parameters, $\boldsymbol{\theta}$.
$$
\mathbf{x}^*(\boldsymbol{\theta}) = \argmin_{\mathbf{x}} \; \boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}, \mathbf{x}_\text{gt}; \boldsymbol{\theta})
$$
Loss Function (Generic)
Given $N$ samples of data pairs, $\mathcal{D} = \left\{ \mathbf{x}_{\text{init}}^{(i)}, \mathbf{y}_{\text{obs}}^{(i)}\right\}_{i=1}^N$, we have an energy function, $\boldsymbol{U}$, which represents the generic inverse problem.
$$
\boldsymbol{U}(\mathbf{x}_{\text{init}}^{(i)}, \mathbf{y}_{\text{obs}}^{(i)}) = \mathcal{L}_\text{Data}(\mathbf{x}_{\text{init}}^{(i)}, \mathbf{y}_{\text{obs}}^{(i)}) + \lambda \mathcal{R}(\mathbf{x}_{\text{init}}^{(i)})
$$
Loss Function (4D Variational)
Problem Setting
Let's take some observations, $\mathbf{y}$, which are sparse and incomplete. We are interested in finding a state, $\mathbf{x}$, that best matches the observations and fills in the missing observations.
This is a learning problem: we choose the parameters that minimize the reconstruction error on the observed data. It is given by:
$$
\theta^* = \argmin_{\theta}\sum_i^N || \mathbf{y}_i - \boldsymbol{I}(\boldsymbol{U}(\mathbf{x};\theta), \mathbf{y}_i, \Omega_i)||_{\Omega_i}^2
$$
where:
$||\cdot||_\Omega^2$ - L2 norm evaluated on the subdomain $\Omega$
$\boldsymbol{U}(\cdot)$ - the energy function
$\boldsymbol{I}(\cdot)$ - the solution to the interpolation problem
Interpolation Problem, $\boldsymbol{I}$
This tries to solve the interpolation problem of finding the best state, $\mathbf{x}$, given the observations, $\mathbf{y}$, on a subdomain, $\boldsymbol{\Omega}$. This involves minimizing some energy function, $\boldsymbol{U}$.
$$
\mathbf{x}^* = \argmin_{\mathbf{x}} \boldsymbol{U}(\mathbf{x},\mathbf{y},\boldsymbol{\theta}, \boldsymbol{\Omega}) \coloneqq \boldsymbol{I}()
$$
They use a fixed point method to solve this problem.
Domain
The first term in the loss function is the observation term, defined as:
$$
||\cdot ||_{\Omega}^2
$$
This is the quadratic norm evaluated only on the observed domain, $\Omega$.
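A minimal sketch of this restricted norm, assuming the domain $\Omega$ is stored as a 0/1 mask with 1 marking observed entries. The helper name `masked_sq_norm` and the toy arrays are hypothetical.

```python
import jax.numpy as jnp


def masked_sq_norm(residual, mask):
    """Squared L2 norm of `residual`, counting only entries where mask == 1."""
    return jnp.sum(mask * residual ** 2)


x = jnp.array([1.0, 2.0, 3.0, 4.0])     # current state estimate
y = jnp.array([1.5, 0.0, 2.0, 0.0])     # observations (unobserved slots hold 0)
mask = jnp.array([1.0, 0.0, 1.0, 0.0])  # 1 = observed, 0 = missing

value = masked_sq_norm(x - y, mask)     # (1 - 1.5)^2 + (3 - 2)^2 = 1.25
```

Entries outside the mask contribute nothing, so whatever placeholder value is stored in the unobserved slots of `y` does not affect the result.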
Pseudo-Code
# do some computation
x = ...
# fill the solution with observations
x[mask] = y[mask]
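The in-place assignment above works for NumPy arrays. JAX arrays are immutable, so under JAX the same fill is usually written functionally. A minimal sketch, where `fill_observed` is a hypothetical helper name and the mask is assumed to be boolean with `True` marking observed entries:

```python
import jax.numpy as jnp


def fill_observed(x, y, mask):
    """Return a copy of x whose observed entries (mask == True) are replaced by y."""
    return jnp.where(mask, y, x)


x = jnp.zeros(4)                                  # some computed state
y = jnp.array([1.0, 2.0, 3.0, 4.0])               # observations
mask = jnp.array([True, False, True, False])      # observed locations
x_filled = fill_observed(x, y, mask)              # [1., 0., 3., 0.]
```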
Operators
ODE/PDE
Constrained
$$
\boldsymbol{\psi}(\mathbf{x};\boldsymbol{\theta}) =
$$
Optimization
Fixed Point Algorithm
We are interested in minimizing this function.
$$
\mathbf{x}^* = \argmin_{\mathbf{x}} \boldsymbol{U}(\mathbf{x},\mathbf{y}, \Omega, \boldsymbol{\theta}) \text{ s.t. } \mathbf{y}_{} =
$$
Pseudo-Code
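# assumes NumPy arrays, which support in-place masked assignment
def fixed_point_solve(fn, x_init, y, mask, n_iterations):
    # initialize x (copy so the caller's array is left untouched)
    x = x_init.copy()
    # loop through number of iterations
    for k in range(n_iterations):
        # apply one fixed-point update
        x = fn(x)
        # update via known observations
        x[mask] = y[mask]
    return x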
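As a concrete toy instance of this loop, the sketch below uses a neighbour-averaging operator as the update function and observes only the two end points of a short 1D signal. The operator `smooth`, the helper name `fixed_point_interpolate`, and the data are assumptions chosen so that the fixed point is easy to reason about; the loop is written functionally so that it runs under JAX.

```python
import jax.numpy as jnp


def smooth(x):
    """Toy update fn: replace each entry by the average of its two neighbours (periodic)."""
    return 0.5 * (jnp.roll(x, 1) + jnp.roll(x, -1))


def fixed_point_interpolate(fn, x_init, y, mask, n_iterations):
    x = x_init
    for _ in range(n_iterations):
        x = fn(x)                  # fixed-point update
        x = jnp.where(mask, y, x)  # re-impose the known observations
    return x


# only the two end points are observed; the interior must be filled in
y = jnp.array([0.0, 0.0, 0.0, 0.0, 4.0])
mask = jnp.array([True, False, False, False, True])
x = fixed_point_interpolate(smooth, jnp.zeros(5), y, mask, n_iterations=200)
# x approaches the linear interpolation [0, 1, 2, 3, 4]
```

Each interior point ends up equal to the mean of its neighbours while the end points stay pinned to the observations, so the iteration settles on a straight line between them.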
Projection-Based Iterative Update
DINEOF, Alvera-Azcarate et al. (2016)
DINCAE, Barth et al. (2020)
Projection
We will use our function, $\boldsymbol{\psi}$, to advance the state by one iteration, $k$.
$$
\tilde{\mathbf{x}}^{(k+1)} = \boldsymbol{\psi}(\mathbf{x}^{(k)}; \boldsymbol{\theta})
$$
Update Observed Domain
We overwrite the state estimate, $\mathbf{x}$, wherever we have observations, $\mathbf{y}$. The observed locations are given by $\boldsymbol{\Omega}$ (i.e. a mask).
$$
\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})
$$
Update Unobserved Domain
$$
\mathbf{x}^{(k+1)}(\bar{\boldsymbol \Omega}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol \Omega})
$$
Pseudo-Code
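from typing import Any, Callable
from jax import Array

def update(x: Array, y: Array, mask: Array, params: Any, phi_fn: Callable) -> Array:
    # keep observations where mask == 1 and the projected state elsewhere
    return mask * y + phi_fn(x, params) * (1 - mask)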
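To make the projection step concrete, here is a sketch that uses a truncated SVD as the projection operator and iterates the update on a small rank-1 field with one missing entry. This is loosely in the spirit of EOF-based gap filling, but it is not the DINEOF algorithm as published; `svd_projection`, the use of `params` as the retained rank, and the toy data are all assumptions.

```python
import jax.numpy as jnp


def svd_projection(x, rank):
    """Hypothetical phi_fn: best rank-`rank` approximation of x via truncated SVD."""
    u, s, vt = jnp.linalg.svd(x, full_matrices=False)
    return (u[:, :rank] * s[:rank]) @ vt[:rank, :]


def update(x, y, mask, params, phi_fn):
    """Keep observations where mask == 1, use the projection elsewhere."""
    return mask * y + phi_fn(x, params) * (1 - mask)


# rank-1 toy field with one missing entry (mask == 0 at position [0, 0])
x_true = jnp.outer(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 2.0, 3.0, 4.0]))
mask = jnp.ones_like(x_true).at[0, 0].set(0.0)
y = mask * x_true  # the missing entry is stored as 0

x = y
for _ in range(20):
    x = update(x, y, mask, params=1, phi_fn=svd_projection)
# x[0, 0] moves from the placeholder 0 towards the true value 1
```

The observed entries are never modified, so only the gap is adjusted from one iteration to the next.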
Gradient-Based Iterative Update
Let $\boldsymbol{U}(\mathbf{x}, \mathbf{y},\boldsymbol{\Omega},\boldsymbol{\theta}) : \mathbb{R}^D \rightarrow \mathbb{R}$ be the energy function.
Gradient Step
$$
\tilde{\mathbf{x}}^{(k+1)} = \mathbf{x}^{(k)} - \lambda \boldsymbol{\nabla}_{\mathbf{x}^{(k)}}\boldsymbol{U}(\mathbf{x}^{(k)}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta})
$$
Update Observed Domain
$$
\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})
$$
Update Unobserved Domain
$$
\mathbf{x}^{(k+1)}(\bar{\boldsymbol \Omega}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol \Omega})
$$
Pseudo-Code
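import jax
import jax.numpy as jnp
from jax import Array
from typing import Any, Callable

def energy_fn(
    x: Array,  # shape (Batch, Dims)
    y: Array,  # shape (Batch, Dims)
    mask: Array,  # shape (Batch, Dims)
    params: Any,  # pytree of parameters for phi
    alpha_prior: float = 0.01,
    alpha_obs: float = 0.99,
) -> Array:
    # observation term: squared error on the observed entries only
    loss_obs = jnp.mean(mask * (x - y) ** 2)
    # prior term: squared distance between x and its projection phi(x);
    # phi is assumed to be defined elsewhere
    loss_prior = jnp.mean((phi(x, params) - x) ** 2)
    total_loss = alpha_obs * loss_obs + alpha_prior * loss_prior
    return total_loss

def update(
    x: Array,  # shape (Batch, Dims)
    y: Array,  # shape (Batch, Dims)
    mask: Array,  # shape (Batch, Dims)
    params: Any,
    energy_fn: Callable,
    alpha: float,
) -> Array:
    # one gradient-descent step on the energy with respect to x
    x = x - alpha * jax.grad(energy_fn)(x, y, mask, params)
    return x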
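A short usage sketch of the gradient-based update, assuming a toy prior operator `phi` (a weighted sum of the two neighbouring entries) and made-up observations. The operator, the `weight` parameter, and the step size are illustrative choices, not part of the original formulation.

```python
import jax
import jax.numpy as jnp


def phi(x, params):
    """Toy prior operator (an assumption): weighted sum of the two neighbours."""
    return params["weight"] * (jnp.roll(x, 1, axis=-1) + jnp.roll(x, -1, axis=-1))


def energy_fn(x, y, mask, params, alpha_prior=0.01, alpha_obs=0.99):
    loss_obs = jnp.mean(mask * (x - y) ** 2)          # data term on observed entries
    loss_prior = jnp.mean((phi(x, params) - x) ** 2)  # distance to the prior operator
    return alpha_obs * loss_obs + alpha_prior * loss_prior


def update(x, y, mask, params, energy_fn, alpha):
    """One gradient-descent step on the energy with respect to x."""
    return x - alpha * jax.grad(energy_fn)(x, y, mask, params)


params = {"weight": 0.5}
y = jnp.array([[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]])     # shape (Batch=1, Dims=6)
mask = jnp.array([[1.0, 0.0, 1.0, 0.0, 1.0, 0.0]])  # 1 = observed

x = jnp.zeros_like(y)
energy_start = energy_fn(x, y, mask, params)
for _ in range(100):
    x = update(x, y, mask, params, energy_fn, alpha=1.0)
energy_end = energy_fn(x, y, mask, params)  # lower than energy_start
```

Because this toy energy is quadratic in `x`, a sufficiently small fixed step size decreases it at every iteration; for a learned, non-quadratic `phi` the step size would need more care.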
NN-Interpolator Iterative Update
Let $\boldsymbol{NN}(\mathbf{x}, \mathbf{y};\boldsymbol{\theta})$ be an arbitrary NN function.
Projection
$$
\dot{\mathbf{x}}^{(k)} = \lambda \boldsymbol{\nabla}_{\mathbf{x}^{(k)}}\boldsymbol{U}(\mathbf{x}^{(k)}, \mathbf{y}, \boldsymbol{\Omega}, \boldsymbol{\theta})
$$
Gradient Step
$$
\tilde{\mathbf{x}}^{(k+1)} = \mathbf{x}^{(k)} - \boldsymbol{NN} \left( \mathbf{x}^{(k)}, \dot{\mathbf{x}}^{(k)}; \boldsymbol{\theta}\right)
$$
Update Observed Domain
$$
\mathbf{x}^{(k+1)}(\boldsymbol{\Omega}) = \mathbf{y}(\boldsymbol{\Omega})
$$
Update Unobserved Domain
$$
\mathbf{x}^{(k+1)}(\bar{\boldsymbol \Omega}) = \tilde{\mathbf{x}}^{(k+1)}(\bar{\boldsymbol \Omega})
$$
Example:
$$
\boldsymbol{NN} \left( \mathbf{x}^{(k)}, \tilde{\mathbf{x}}^{(k+1)}; \boldsymbol{\theta}\right) = \tilde{\boldsymbol{NN}} \left( \mathbf{x}^{(k)} - \tilde{\mathbf{x}}^{(k+1)}; \boldsymbol{\theta}\right)
$$
LSTM
Pseudo-Code
from typing import Any, Callable, Tuple
import jax
from jax import Array

def update(
    x: Array,  # shape (Batch, Dims)
    y: Array,  # shape (Batch, Dims)
    mask: Array,  # shape (Batch, Dims)
    params: Any,  # pytree of energy-function parameters
    hidden_params: Tuple[Array, ...],  # recurrent state/parameters passed to rnn_fn
    alpha: float,
    energy_fn: Callable,
    rnn_fn: Callable,
    activation_fn: Callable,
) -> Array:
    # gradient - variational cost
    x_g = alpha * jax.grad(energy_fn)(x, y, mask, params)
    # NN gradient update
    g = rnn_fn(x_g, hidden_params)
    x = x - activation_fn(g)
    return x
CNN
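from typing import Any, Callable
import jax
import jax.numpy as jnp
from jax import Array

activation_fn = jnp.tanh

def update(
    x: Array,  # shape (Batch, Dims)
    y: Array,  # shape (Batch, Dims)
    mask: Array,  # shape (Batch, Dims)
    x_g: Array,  # gradient from the previous iteration, shape (Batch, Dims)
    params: Any,
    alpha: float,
    energy_fn: Callable,
    nn_fn: Callable,
    activation_fn: Callable = jnp.tanh,
) -> Array:
    # gradient - variational cost
    x_g_new = alpha * jax.grad(energy_fn)(x, y, mask, params)
    # NN gradient update: stack the new and previous gradients along the
    # last axis (an assumption; nn_fn must map the result back to x's shape)
    x_g = jnp.concatenate([x_g_new, x_g], axis=-1)
    g = nn_fn(x_g)
    x = x - activation_fn(g)
    return x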
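The parameters of any of the iterative updates above can, in principle, be trained by differentiating the reconstruction loss through a fixed number of solver iterations. The sketch below illustrates that idea under strong simplifying assumptions: the learned quantities are only a step size and a prior weight (hypothetical names `step_size` and `prior_weight`) instead of an LSTM or CNN, and the prior is a fixed neighbour-averaging operator.

```python
import jax
import jax.numpy as jnp


def energy_fn(x, y, mask, params):
    """Toy variational cost: masked data term plus a weighted smoothness prior."""
    loss_obs = jnp.mean(mask * (x - y) ** 2)
    neighbours = 0.5 * (jnp.roll(x, 1, axis=-1) + jnp.roll(x, -1, axis=-1))
    loss_prior = jnp.mean((neighbours - x) ** 2)
    return loss_obs + params["prior_weight"] * loss_prior


def solve(params, x_init, y, mask, n_iterations=10):
    """Unrolled inner solver: a fixed number of gradient steps on the energy."""
    x = x_init
    for _ in range(n_iterations):
        x = x - params["step_size"] * jax.grad(energy_fn)(x, y, mask, params)
    return x


def reconstruction_loss(params, x_init, y, mask):
    """Outer objective: squared error on the observed entries of the solution."""
    x_star = solve(params, x_init, y, mask)
    return jnp.sum(mask * (y - x_star) ** 2)


params = {"step_size": 0.5, "prior_weight": 0.01}
y = jnp.array([[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]])
mask = jnp.array([[1.0, 0.0, 1.0, 0.0, 1.0, 0.0]])
x_init = jnp.zeros_like(y)

# gradient of the outer loss with respect to the parameters, through the solver
loss, grads = jax.value_and_grad(reconstruction_loss)(params, x_init, y, mask)
```

`grads` has the same dictionary structure as `params`, so it can be handed to any gradient-based optimizer. Swapping the plain gradient step inside `solve` for a learned update only changes what `params` contains.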
Deterministic
$$
\mathbf{x}^{(k+1)} = \boldsymbol{f}(\mathbf{x}^{(k)}, \mathbf{y}_\text{obs})
$$
Loss
$$
\mathcal{L} = ||\mathbf{x} - \boldsymbol{f}(\mathbf{x},\mathbf{y})||_2^2
$$
Probabilistic
$$
p(\mathbf{x}^{(k+1)}|\mathbf{x}^{(k)}) = \mathcal{N}(\mathbf{x}^{(k+1)}| \boldsymbol{\mu}_{\boldsymbol \theta}(\mathbf{x}^{(k)}), \boldsymbol{\sigma}^2_{\boldsymbol \theta}(\mathbf{x}^{(k)}))
$$
Loss
$$
-\log p(y|\mathbf{x}) = \frac{1}{2}\log \boldsymbol{\sigma}^2_{\boldsymbol \theta}(\mathbf{x}) + \frac{(y - \boldsymbol{\mu}_{\boldsymbol \theta}(\mathbf{x}))^2}{2\boldsymbol{\sigma}^2_{\boldsymbol \theta}(\mathbf{x})} + \text{constant}
$$
where:
$y \in \mathbb{R}$
$\mathbf{x} \in \mathbb{R}^{D}$
For a vector-valued output, this becomes:
$$
-\log p(\mathbf{y}|\mathbf{x}) = \frac{1}{2} \log|\det \boldsymbol{\Sigma}_{\boldsymbol \theta}(\mathbf{x}) | + \frac{1}{2}||\mathbf{y} - \boldsymbol{\mu}_{\boldsymbol \theta}(\mathbf{x})||^2_{\boldsymbol{\Sigma}^{-1}_{\boldsymbol \theta}(\mathbf{x})} + \text{constant}
$$
where:
$\mathbf{y} \in \mathbb{R}^{D_y}$
$\mathbf{x} \in \mathbb{R}^{D_x}$
$\boldsymbol{\mu}_{\boldsymbol \theta}:\mathbb{R}^{D_x} \rightarrow \mathbb{R}^{D_y}$
$\boldsymbol{\Sigma}_{\boldsymbol \theta}:\mathbb{R}^{D_x} \rightarrow \mathbb{R}^{D_y \times D_y}$
$||\mathbf{v}||^2_{\boldsymbol{\Sigma}^{-1}} = \mathbf{v}^\top \boldsymbol{\Sigma}^{-1} \mathbf{v}$
Other Perspectives
Explicit vs Implicit Model
Explicit Model
$$
\mathbf{x}_\text{gt} = \boldsymbol{f}(\mathbf{x}_\text{init})
$$
Penalize the Loss
$$
\mathcal{L} = \mathcal{L}_\text{model}(\mathbf{x}_\text{gt}, \tilde{\mathbf{x}}_\text{init}) + \mathcal{L}_\text{data}(\mathbf{x}_\text{gt}(\boldsymbol{\Omega}), \mathbf{y}_\text{obs}(\boldsymbol{\Omega}))
$$
Conditional Explicit Model
We assume some function maps both the observations and the initial condition to the solution.
$$
\mathbf{x}_\text{gt} = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs})
$$
The difference here is that we do not add any extra terms to the loss, because the observations enter the model explicitly.
Implicit Model
We assume
$$
\begin{aligned}
\mathbf{x}_\text{gt}\boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) = 0
\end{aligned}
$$
where:
$$
\boldsymbol{g}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) = \boldsymbol{f}(\mathbf{x}_\text{init}, \mathbf{y}_\text{obs}) - \mathbf{x}_\text{init}
$$
This is the same as a fixed-point solution.
$$
\begin{aligned}
\mathbf{x}_\text{gt} &= \mathbf{x}_\text{init}\\
\mathbf{x}_\text{gt} &= \boldsymbol{f}(\mathbf{x}_\text{gt}, \mathbf{y}_\text{obs})
\end{aligned}
$$
Literature
Learning Latent Dynamics for Partially Observed Chaotic Systems
Joint Interpolation and Representation Learning for Irregularly Sampled Satellite-Derived Geophysics
Learning Variational Data Assimilation Models and Solvers
Intercomparison of Data-Driven and Learning-Based Interpolations of Along-Track Nadir and Wide-Swath SWOT Altimetry Observations
Variational Deep Learning for the Identification and Reconstruction of Chaotic and Stochastic Dynamical Systems from Noisy and Partial Observations - Nguyen et al. (2020)