In this case, we try to directly parameterize the field.

Data Representation

$$\mathcal{D} =\{ (\mathbf{x}_n,t_n), \boldsymbol{y}_n \}_{n=1}^N$$
# get dataset (Coordinate based)
ds: xr.Dataset = ...

# convert to ML-ready tensor
x: Array["N"] = ds.x.values
t: Array["N"] = ds.t.values
y: Array["N D"] = ds.variable.values
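
As a rough illustration of the coordinate-based layout, the sketch below flattens a small gridded field into $N$ samples $((x_n, t_n), y_n)$. It uses plain Python lists rather than xarray; the grid values and the helper name `to_coordinate_samples` are hypothetical and only meant to show the reshaping.

```python
from itertools import product


def to_coordinate_samples(xs, ts, field):
    """Flatten a gridded field into coordinate-based samples.

    xs: list of spatial coordinates, length Nx
    ts: list of time coordinates, length Nt
    field: nested list with field[i][j] the value at (ts[i], xs[j])
    Returns three lists of length N = Nt * Nx: x, t, y.
    """
    x, t, y = [], [], []
    for (i, t_i), (j, x_j) in product(enumerate(ts), enumerate(xs)):
        x.append(x_j)
        t.append(t_i)
        y.append(field[i][j])
    return x, t, y


# toy grid: 2 time steps, 3 spatial points (values are made up)
xs = [0.0, 0.5, 1.0]
ts = [0.0, 1.0]
field = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]

x, t, y = to_coordinate_samples(xs, ts, field)
```

Each entry of `y` is now paired with exactly one `(x, t)` coordinate, which is the form the parameterized function consumes.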


We assume the observations are generated by some underlying function of the spatiotemporal coordinates, observed with additive Gaussian noise.

$$\boldsymbol{y}_n = \boldsymbol{f}(\mathbf{x}_n, t_n; \boldsymbol{\theta}) + \varepsilon_n, \qquad \varepsilon_n \sim \mathcal{N}(0,\sigma^2)$$

In other words, we can write this as

$$\boldsymbol{y}_n \sim \mathcal{N} \left(\boldsymbol{y}_n|\boldsymbol{f_\theta}(\mathbf{x}_n, t_n), \sigma^2 \right)$$
# initialize parameterized function
params: Params = ...
fn: Model = Model(params)

# apply function
y_pred: Array["N D"] = fn(x, t)

# initialize Gaussian conditional likelihood
sigma: float = ...
model: ProbModel = Gaussian(mean_fn=fn, variance=sigma**2)

# apply model N(y|f(x,t),sigma^2)
y_dist: Dist = model(x, t)
y_pred_mean: Array["N D"] = y_dist.mean()
y_pred_var: Array["N D"] = y_dist.variance()
y_pred_samples: Array["10 N D"] = y_dist.sample(N=10)
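
To make the Gaussian conditional likelihood concrete, here is a minimal sketch for a scalar observation. The linear `mean_fn`, the parameter values, and the helper names are assumptions made for illustration; they stand in for whatever parameterized function is actually used.

```python
import math
import random


def mean_fn(x, t, theta):
    """Hypothetical mean function: a linear model in (x, t)."""
    w_x, w_t, b = theta
    return w_x * x + w_t * t + b


def gaussian_log_prob(y, mean, sigma):
    """Log density of N(y | mean, sigma^2) for a scalar y."""
    return -0.5 * math.log(2.0 * math.pi * sigma**2) - 0.5 * ((y - mean) / sigma) ** 2


def gaussian_sample(mean, sigma, rng):
    """Draw one sample y = mean + sigma * eps with eps ~ N(0, 1)."""
    return mean + sigma * rng.gauss(0.0, 1.0)


theta = (2.0, -1.0, 0.5)   # assumed parameters
sigma = 0.1                # assumed noise standard deviation
rng = random.Random(0)

mu = mean_fn(0.3, 1.0, theta)                     # predictive mean at (x, t) = (0.3, 1.0)
y_draw = gaussian_sample(mu, sigma, rng)          # one draw from N(y | f(x, t), sigma^2)
log_p_at_mean = gaussian_log_prob(mu, mu, sigma)  # the density peaks at the mean
```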


We learn the parameterization by finding the parameters that minimize some loss function, $\mathcal{L}$.

$$\boldsymbol{\theta}^* = \underset{\boldsymbol{\theta}}{\text{argmin}}\hspace{2mm}\mathcal{L}(\boldsymbol{\theta})$$

We can minimize the negative log-likelihood of the data

$$\mathcal{L}(\boldsymbol{\theta};\mathcal{D}) = -\frac{1}{|\mathcal{D}|}\sum_{n\in\mathcal{D}}\log p(\boldsymbol{y}_n|\boldsymbol{f}(\mathbf{x}_n,t_n;\boldsymbol{\theta}),\sigma^2)$$
# calculate negative log-likelihood
loss: Array[""] = -model(x, t).log_prob(y_true).mean()

# create loss function
loss_fn: Callable = lambda rv_y, y_true: -rv_y.log_prob(y_true).mean()
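
For a Gaussian likelihood with fixed $\sigma$, the mean negative log-likelihood is the mean squared error scaled by $1/(2\sigma^2)$ plus a constant, which is why an MSE criterion can be used for training. The sketch below checks this identity numerically on made-up numbers; the function names are illustrative.

```python
import math


def gaussian_nll(y_true, y_pred, sigma):
    """Mean negative log-likelihood under N(y | y_pred, sigma^2)."""
    n = len(y_true)
    const = 0.5 * math.log(2.0 * math.pi * sigma**2)
    sq_err = sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred))
    return const + sq_err / (2.0 * sigma**2 * n)


def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred)) / len(y_true)


y_true = [1.0, 2.0, 3.0]
y_pred = [1.5, 2.0, 2.0]
sigma = 0.5

nll = gaussian_nll(y_true, y_pred, sigma)
# for fixed sigma: NLL = MSE / (2 sigma^2) + 0.5 * log(2 pi sigma^2)
nll_from_mse = mse(y_true, y_pred) / (2.0 * sigma**2) + 0.5 * math.log(2.0 * math.pi * sigma**2)
```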

To train, we can use any gradient-based optimizer, e.g. gradient descent with step size $\alpha$:

$$\boldsymbol{\theta}^{k+1} = \boldsymbol{\theta}^{k} - \alpha\boldsymbol{\nabla_\theta}\mathcal{L}(\boldsymbol{\theta}^{k})$$
# initialize criteria and training regime
loss: Loss = MSE()
optimizer: Optimizer = SGD(learning_rate=0.1)

# learn parameters
params: PyTree = fit_model(
    data=[[x,t], y],
    optimizer=optimizer, loss=loss
)
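
The update rule can be sketched end to end on a toy problem. The example below fits a hypothetical linear model $f(x, t; \theta) = w_x x + w_t t + b$ with hand-written gradients and plain gradient descent; it is not the `fit_model` routine referenced above, just an illustration of the same loop under those assumptions.

```python
def predict(x, t, theta):
    """Hypothetical linear model f(x, t; theta)."""
    w_x, w_t, b = theta
    return w_x * x + w_t * t + b


def mse_and_grad(theta, data):
    """Mean squared error and its gradient w.r.t. theta = (w_x, w_t, b)."""
    n = len(data)
    loss, g = 0.0, [0.0, 0.0, 0.0]
    for (x, t), y in data:
        r = predict(x, t, theta) - y
        loss += r * r / n
        g[0] += 2.0 * r * x / n
        g[1] += 2.0 * r * t / n
        g[2] += 2.0 * r / n
    return loss, g


def fit(data, theta, lr=0.1, steps=2000):
    """Plain gradient descent: theta <- theta - lr * grad."""
    for _ in range(steps):
        _, g = mse_and_grad(theta, data)
        theta = [p - lr * gi for p, gi in zip(theta, g)]
    return theta


# noise-free toy data generated from theta_true (an assumption for the demo)
theta_true = [2.0, -1.0, 0.5]
coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.5, 0.5)]
data = [((x, t), predict(x, t, theta_true)) for x, t in coords]

theta_hat = fit(data, theta=[0.0, 0.0, 0.0])
final_loss, _ = mse_and_grad(theta_hat, data)
```

In practice the gradient would come from automatic differentiation and the optimizer from a library, but the structure of the loop is the same.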


Once trained, the model can be evaluated at any coordinates in the domain:

$$\boldsymbol{y}_n = \boldsymbol{f}(\mathbf{x}_n, t_n; \boldsymbol{\theta}), \qquad \mathbf{x}\in\Omega_z\subseteq\mathbb{R}^{D_s}$$
# get coordinates for new domain
x_new: Array["M"] = ...
t_new: Array["M"] = ...

# apply model
y_pred: Array["M D"] = model(x_new, t_new, params)

There are several possible improvements:

  • Improved Loss function
  • Conditional Model, $\boldsymbol{f_\theta}(\mathbf{x},t,\mu)$
  • Heterogeneous Noise Model, $\boldsymbol{\varepsilon}(\mathbf{x},t)$
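
As one possible reading of the heterogeneous noise option, the sketch below lets the noise standard deviation depend on the coordinates and evaluates the resulting Gaussian negative log-likelihood. The log-linear form of `noise_std`, the mean function, and the data are assumptions chosen for illustration.

```python
import math


def noise_std(x, t, phi):
    """Hypothetical noise model: log-linear in (x, t), so sigma(x, t) > 0."""
    a_x, a_t, c = phi
    return math.exp(a_x * x + a_t * t + c)


def hetero_gaussian_nll(data, mean_fn, phi):
    """Mean NLL of N(y | f(x, t), sigma(x, t)^2) over the dataset."""
    total = 0.0
    for (x, t), y in data:
        s = noise_std(x, t, phi)
        total += 0.5 * math.log(2.0 * math.pi * s**2) + 0.5 * ((y - mean_fn(x, t)) / s) ** 2
    return total / len(data)


def mean_fn(x, t):
    """Assumed mean function for the demo."""
    return 2.0 * x - t


data = [((0.0, 0.0), 0.1), ((1.0, 0.0), 2.3), ((0.0, 1.0), -0.8)]

# with a_x = a_t = 0 the noise is constant: sigma = exp(c) = 0.5
nll_const = hetero_gaussian_nll(data, mean_fn, phi=(0.0, 0.0, math.log(0.5)))
# a nonzero a_x makes the noise grow with x
nll_hetero = hetero_gaussian_nll(data, mean_fn, phi=(0.5, 0.0, math.log(0.5)))
```

Setting the coordinate weights to zero recovers the constant-variance model, so the heterogeneous version is a strict generalization.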

Latent Interpolator


$$
\begin{aligned}
\text{Prior Dist.}: && \boldsymbol{z} &\sim p_\theta(\boldsymbol{z}) = \mathcal{N}(\mathbf{0}, \mathbf{1}) \\
\text{Conditional Likelihood Dist.}: && \boldsymbol{y} &\sim p_\theta(\boldsymbol{y}|\mathbf{x},t,\boldsymbol{z})= \mathcal{N}\left(\boldsymbol{y}|\boldsymbol{f_\theta}(\mathbf{x},t, \boldsymbol{z}), \sigma^2\right)
\end{aligned}
$$
# init prior model
mean: Array["Dz"] = zeros_like(...)
sigma: Array["Dz"] = ones_like(...)
prior_model: Dist = DiagGaussian(mean=mean, sigma=sigma)

z_samples: Array["N Dz"] = prior_model.sample(N=...)

# init mean function - nerf
mean_fn: Model = init_nerf(...)

x: Array["Ds"] = ...
t: Array[""] = ...
z: Array["Dz"] = z_samples[0]
y_pred: Array["Dy"] = mean_fn(x, t, z)

# init likelihood model w/ parameterized mean fn
sigma: float = ...
likelihood_model: Dist = CondGaussian(mean_fn=mean_fn, scale=sigma)

y_pred: Dist = likelihood_model(x, t, z)
y_pred_mean: Array["Dy"] = y_pred.mean()
y_pred_var: Array["Dy"] = y_pred.variance()
y_pred_samples: Array["N Dy"] = y_pred.sample(N=...)
log_prob: Array["Dy"] = y_pred.log_prob(y)
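
The generative story (draw $z$ from the prior, push it through the latent-conditioned mean function, add Gaussian noise) can be sketched in a few lines. The linear `mean_fn` below is a stand-in for the neural field, and all parameter values are assumptions.

```python
import math
import random


def sample_prior(dz, rng):
    """z ~ N(0, I) in dz dimensions."""
    return [rng.gauss(0.0, 1.0) for _ in range(dz)]


def mean_fn(x, t, z, theta):
    """Hypothetical latent-conditioned mean: linear in (x, t) plus a linear readout of z."""
    w_x, w_t, w_z = theta
    return w_x * x + w_t * t + sum(wi * zi for wi, zi in zip(w_z, z))


def log_likelihood(y, x, t, z, theta, sigma):
    """log N(y | f(x, t, z), sigma^2) for a scalar observation y."""
    mu = mean_fn(x, t, z, theta)
    return -0.5 * math.log(2.0 * math.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


rng = random.Random(0)
theta = (1.0, 0.5, [0.3, -0.2])   # assumed parameters, Dz = 2
sigma = 0.1

z = sample_prior(2, rng)                              # z ~ p(z)
mu = mean_fn(0.2, 1.0, z, theta)                      # f(x, t, z)
y = mu + sigma * rng.gauss(0.0, 1.0)                  # y ~ p(y | x, t, z)
log_p = log_likelihood(y, 0.2, 1.0, z, theta, sigma)  # log p(y | x, t, z)
```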


We are interested in finding the best latent variable, $\boldsymbol{z}$, that fits the data, $\mathcal{D}$. The posterior is given by:

$$p(z|y,x,t) = \frac{1}{Z}p(y|x,t,z)p(z)$$

We can write the criterion as the KL divergence between the variational distribution, $q$, and the posterior, $p(z|y)$.

$$q^*(\boldsymbol{z}) = \underset{q\in\mathcal{Q}}{\text{argmin}}\hspace{2mm} D_{KL}\left[ q(\boldsymbol{z})||p(\boldsymbol{z}|\boldsymbol{y}) \right]$$
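
The posterior in this KL divergence is not available in closed form, so it helps to expand the criterion with Bayes' rule. The short derivation below (with the dependence on $\mathbf{x}, t$ suppressed for brevity) shows that minimizing the KL divergence over $q$ is the same as maximizing a bound that only needs the prior and the likelihood, because $\log p(y)$ does not depend on $q$.

```latex
\begin{aligned}
D_{KL}\left[ q(z) \,||\, p(z|y) \right]
  &= \mathbb{E}_{z\sim q}\left[ \log q(z) - \log p(z|y) \right] \\
  &= \mathbb{E}_{z\sim q}\left[ \log q(z) - \log p(y|z) - \log p(z) \right] + \log p(y) \\
  &= -\mathrm{ELBO}(q) + \log p(y), \\
\mathrm{ELBO}(q)
  &= \mathbb{E}_{z\sim q}\left[ \log p(y|z) + \log p(z) - \log q(z) \right]
   = \mathbb{E}_{z\sim q}\left[ \log p(y|z) \right] - D_{KL}\left[ q(z) \,||\, p(z) \right].
\end{aligned}
```

Since the KL divergence is non-negative, $\mathrm{ELBO}(q) \leq \log p(y)$, with equality when $q$ equals the posterior.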

The general criterion is the evidence lower bound (ELBO), which is a lower bound on the log evidence, $\log p(y)$. Maximizing it is equivalent to minimizing the KL divergence between the variational distribution, $q$, and the posterior.

$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{z\sim q_\phi} \left[ \log p_\theta(y|z) + \log p_\theta(z) - \log q_\phi(z)\right]$$

However, we're going to use the flow-based form of the objective, which splits the ELBO into a prior term and a likelihood-ratio term:

$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{\boldsymbol{z}\sim q_\phi(\boldsymbol{z})} \left[ \log p_\theta(\boldsymbol{z})\right] + \mathbb{E}_{\boldsymbol{z}\sim q_\phi(\boldsymbol{z})} \left[ \log \frac{p_\theta(\boldsymbol{y}|\boldsymbol{z})}{q_\phi(\boldsymbol{z})}\right]$$


Forward Transformation

$$
\begin{aligned}
(y,x,t) &\sim p(\mathcal{D}) \\
\boldsymbol{z} &\sim q
\end{aligned}
$$

First, we need to sample from the variational distribution

$$\boldsymbol{z}\sim q_{\boldsymbol{\phi}}(\boldsymbol{z})$$
# init variational dist
var_dist: Dist = DiagGaussian(mean=..., sigma=...)

# sample from variational dist, z ~ q(z)
z_sample: Array["N Dz"] = var_dist.sample(N=...)

Now, we can calculate the log probability of the variational dist and the likelihood terms.

$$\mathbb{E}_{\boldsymbol{z}\sim q_\phi(\boldsymbol{z})} \left[ \log p_\theta(\boldsymbol{y}|\mathbf{x},t,\boldsymbol{z}) - \log q_\phi(\boldsymbol{z})\right]$$

# sample from data distribution
y_sample, x, t = Data(N=...)

# likelihood model, log p(y|x,t,z)
log_py: Array[""] = likelihood_model(x, t, z_sample).log_prob(y_sample)
# var dist, log q(z)
log_qz: Array[""] = var_dist.log_prob(z_sample)
# likelihood-ratio term of the objective, log p(y|x,t,z) - log q(z)
log_ratio: Array[""] = log_py - log_qz
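
The expectation over $q$ is usually estimated with samples. Below is a minimal Monte Carlo sketch on a conjugate toy model, assuming a scalar latent with prior $z \sim \mathcal{N}(0, 1)$ and likelihood $y \sim \mathcal{N}(z, \sigma^2)$ (the mean function is simply $f = z$, ignoring $\mathbf{x}, t$). This toy model is an assumption chosen because its posterior and evidence are known in closed form, which makes the estimate easy to check.

```python
import math
import random


def log_normal(v, mean, std):
    """Log density of N(v | mean, std^2)."""
    return -0.5 * math.log(2.0 * math.pi * std**2) - 0.5 * ((v - mean) / std) ** 2


def elbo_estimate(y, q_mean, q_std, sigma, n_samples, rng):
    """Monte Carlo estimate of E_q[log p(z) + log p(y|z) - log q(z)]."""
    total = 0.0
    for _ in range(n_samples):
        z = q_mean + q_std * rng.gauss(0.0, 1.0)   # z ~ q(z), reparameterized
        log_pz = log_normal(z, 0.0, 1.0)           # prior term
        log_py = log_normal(y, z, sigma)           # likelihood term
        log_qz = log_normal(z, q_mean, q_std)      # variational term
        total += log_pz + log_py - log_qz
    return total / n_samples


y, sigma = 1.0, 0.5
rng = random.Random(0)

# exact posterior and evidence of this conjugate toy model
post_var = 1.0 / (1.0 + 1.0 / sigma**2)
post_mean = post_var * y / sigma**2
log_evidence = log_normal(y, 0.0, math.sqrt(1.0 + sigma**2))

elbo_post = elbo_estimate(y, post_mean, math.sqrt(post_var), sigma, 2000, rng)
elbo_prior = elbo_estimate(y, 0.0, 1.0, sigma, 2000, rng)
```

When $q$ is the exact posterior the integrand equals $\log p(y)$ for every sample, so the bound is tight; any other $q$, such as the prior here, gives a smaller value.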

z_sample: Array["Dz"] = ...
# sample
y_sample: Array["Dy"] = likelihood_model(x, t, z_sample).sample(N=...)


All Together


Amortized Model

Data Representation

# initialize domain
domain: Domain = ...

# initialize values
y_values: Array["Dx Dy"] = ...

# initialize field
y: Field["Dx Dy"] = Field(y_values, domain)


In this case, we assume that there is some underlying generative model that can be inferred from the data. The decoder maps a latent variable to the field, written either with additive noise or as a Gaussian conditional likelihood, and the encoder maps the field back to the latent variable.

$$
\begin{aligned}
\text{Decoder}: && \boldsymbol{y} &= \boldsymbol{T_D}(\boldsymbol{z}; \boldsymbol{\theta}_d) + \varepsilon \\
\text{Decoder}: && \boldsymbol{y} &\sim p(\boldsymbol{y}|\boldsymbol{z};\boldsymbol{\theta}_d) = \mathcal{N}(\boldsymbol{y}|\boldsymbol{T_D}(\boldsymbol{z};\boldsymbol{\theta}_d),\sigma^2) \\
\text{Encoder}: && \boldsymbol{z} &= \boldsymbol{T_E}(\boldsymbol{y}; \boldsymbol{\theta}_e)
\end{aligned}
$$
# initialize decoder model
decoder_fn: Model = ...

# initialize Gaussian conditional likelihood, N(y|T_D(z),sigma^2)
sigma: float = ...
prob_model: CondModel = CondGaussian(mean_fn=decoder_fn, variance=sigma**2)

# apply decoder given a latent variable z
y_mean: Array["Dx Dy"] = prob_model.mean(context=z)
y_var: Array["Dx Dy"] = prob_model.variance(context=z)
y_samples: Array["N Dx Dy"] = prob_model.sample(context=z, N=10)

# calculate loss (negative log-likelihood)
loss: Array[""] = -prob_model.log_prob(context=z, x=y)

The encoder and decoder are constrained so that their composition reconstructs the input:

$$\boldsymbol{y} = \boldsymbol{T_D}\circ\boldsymbol{T_E}(\boldsymbol{y})$$

# initialize both
model: Model = EncoderDecoder(encoder=encoder_model, decoder=decoder_model)
params: Params = [encoder_params, decoder_params]

# apply model
y_pred: Array["Dx Dy"] = model(y, params)


The corresponding loss is the reconstruction error, scaled by the noise variance:

$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2\sigma^2}\frac{1}{N} \sum_{n=1}^N \left\| \boldsymbol{y}_n - \boldsymbol{T_D}\circ\boldsymbol{T_E}(\boldsymbol{y}_n) \right\|_2^2$$
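
A small sketch of the reconstruction loss, assuming the simplest possible encoder and decoder: a linear projection onto a single direction `w` with tied weights. These choices are illustrative only; any encoder/decoder pair can be plugged into the same loss.

```python
def encode(y, w):
    """Linear encoder T_E: project y onto the direction w (Dy -> 1)."""
    return sum(wi * yi for wi, yi in zip(w, y))


def decode(z, w):
    """Linear decoder T_D with tied weights: map z back to Dy dimensions."""
    return [wi * z for wi in w]


def reconstruction_loss(ys, w, sigma):
    """(1 / (2 sigma^2 N)) * sum_n ||y_n - T_D(T_E(y_n))||^2."""
    total = 0.0
    for y in ys:
        y_rec = decode(encode(y, w), w)
        total += sum((a - b) ** 2 for a, b in zip(y, y_rec))
    return total / (2.0 * sigma**2 * len(ys))


ys = [[1.0, 1.0], [2.0, 2.0], [-1.0, -1.0]]   # toy data on the line y1 = y2
w_good = [2 ** -0.5, 2 ** -0.5]               # unit vector along the data
w_bad = [1.0, 0.0]                            # unit vector along the first axis

loss_good = reconstruction_loss(ys, w_good, sigma=1.0)
loss_bad = reconstruction_loss(ys, w_bad, sigma=1.0)
```

The direction aligned with the data reconstructs it almost exactly, while the misaligned one loses the second component of every sample.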


# get observations
y_obs: Array["Dx Dy"] = ...

# apply model
y_pred: Array["Dx Dy"] = model(y_obs, params)

This model can be improved in several ways:

  • Use a simplified linear model - PCA
  • Improved Loss Function
  • Stochastic Encoder
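
To illustrate the PCA option, the sketch below computes the first principal direction of 2-D data in closed form from the $2\times 2$ covariance matrix and uses it as a linear encoder/decoder pair. Restricting to two dimensions is an assumption made to keep the example dependency-free; in higher dimensions one would use an eigendecomposition or SVD instead.

```python
import math


def pca_direction_2d(ys):
    """Mean and first principal direction of 2-D data (closed form)."""
    n = len(ys)
    m0 = sum(y[0] for y in ys) / n
    m1 = sum(y[1] for y in ys) / n
    c00 = sum((y[0] - m0) ** 2 for y in ys) / n
    c11 = sum((y[1] - m1) ** 2 for y in ys) / n
    c01 = sum((y[0] - m0) * (y[1] - m1) for y in ys) / n
    # orientation of the leading eigenvector of [[c00, c01], [c01, c11]]
    angle = 0.5 * math.atan2(2.0 * c01, c00 - c11)
    return (m0, m1), (math.cos(angle), math.sin(angle))


def pca_encode(y, mean, w):
    """Encoder: centered projection onto the principal direction."""
    return (y[0] - mean[0]) * w[0] + (y[1] - mean[1]) * w[1]


def pca_decode(z, mean, w):
    """Decoder: map the 1-D code back to 2-D."""
    return [mean[0] + z * w[0], mean[1] + z * w[1]]


ys = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0], [3.0, 7.0]]   # points on y1 = 2 * y0 + 1
mean, w = pca_direction_2d(ys)
ys_rec = [pca_decode(pca_encode(y, mean, w), mean, w) for y in ys]
max_err = max(abs(a - b) for y, r in zip(ys, ys_rec) for a, b in zip(y, r))
```

Because the toy data lie exactly on a line, a single component reconstructs them up to floating-point error.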