Variational Inference

How to think about modern 4DVar formulations

CNRS
MEOM

In this section, we will look at how we can infer the latent variables by exploiting the observed variables. We are assuming that we cannot directly measure the latent variables and instead we can only measure the observations. Following the notation throughout this book, the latent variables are $\boldsymbol{z}$ and the observed variables are $\boldsymbol{y}$.

Our goal is to perform inference, i.e. we are interested in estimating some hidden state given some observations. We can define the posterior using Bayes' rule

$$p(\boldsymbol{z}|\boldsymbol{y}) = \frac{p(\boldsymbol{z},\boldsymbol{y})}{p(\boldsymbol{y})}$$

where the numerator is the joint distribution of the latent variable and the observation and the denominator is the marginal likelihood, or evidence, of the observations. The joint distribution is often easy to evaluate because we can generally factor it into conditional distributions, i.e.

$$p(\boldsymbol{z},\boldsymbol{y}) = p(\boldsymbol{y}|\boldsymbol{z})\,p(\boldsymbol{z})$$

However, the marginal likelihood needs to be calculated by integrating out the latent variables.

$$p(\boldsymbol{y}) = \int p(\boldsymbol{z},\boldsymbol{y})\,d\boldsymbol{z}$$

This integral is generally intractable because it requires integrating over every possible configuration of the latent variables, which is rarely feasible in closed form or affordable numerically. So we need alternative methods to approximate the posterior.

We can introduce a variational distribution from a family of possible distributions $\mathcal{Q}$, whereby we pick the best candidate $q^*(\boldsymbol{z})\in\mathcal{Q}$ that fits the true posterior $p(\boldsymbol{z}|\boldsymbol{y})$. In general, we want a distribution that is easy to calculate, e.g. Gaussian, Bernoulli, etc., so that we can exploit conjugacy when calculating quantities within the loss function. We could also employ a parameterized variational distribution whose parameters we would need to find given the observations, i.e. $q(\boldsymbol{z};\boldsymbol{\phi})$.

To measure the similarity between our approximate posterior and the true posterior, we will use the Kullback-Leibler (KL) divergence, an asymmetric measure of discrepancy between distributions (not a true distance metric). This is given by

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ \log \frac{q(\boldsymbol{z})}{p(\boldsymbol{z}|\boldsymbol{y})} \right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log \frac{p(\boldsymbol{z}|\boldsymbol{y})}{q(\boldsymbol{z})} \right]$$

In our case, we would like to find the best candidate distribution $q^*$ such that it minimizes the KL divergence

$$q^*(\boldsymbol{z}) = \underset{q}{\text{argmin}} \hspace{2mm} D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right]$$

In the above equation, we don't have access to the true posterior, so we will use Bayes' rule for the posterior that we outlined earlier. Plugging its right-hand side into our KL divergence minimization problem gives

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log p(\boldsymbol{z},\boldsymbol{y}) + \log q(\boldsymbol{z}) \right] + \log p(\boldsymbol{y})$$

Inside the expectation we have the negative log joint distribution and the log of the variational distribution, while the term outside is the intractable log marginal likelihood we referenced earlier.

Let's look at the KLD measure again with the posterior.

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log \frac{p(\boldsymbol{z}|\boldsymbol{y})}{q(\boldsymbol{z})} \right]$$

First, we will use log rules to expand the ratio

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log p(\boldsymbol{z}|\boldsymbol{y}) + \log q(\boldsymbol{z}) \right]$$

Now, let's plug in the RHS of Bayes' rule for the posterior outlined earlier.

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log \frac{p(\boldsymbol{z},\boldsymbol{y})}{p(\boldsymbol{y})} + \log q(\boldsymbol{z}) \right]$$

Again, we use log rules to expand this term

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log p(\boldsymbol{z},\boldsymbol{y}) + \log p(\boldsymbol{y}) + \log q(\boldsymbol{z}) \right]$$

Now, we isolate the marginal likelihood term from the rest of the equation and we get

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log p(\boldsymbol{z},\boldsymbol{y}) + \log q(\boldsymbol{z}) \right] + \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ \log p(\boldsymbol{y}) \right]$$

We can remove the expectation on the rightmost term because there is no dependency on the latent variable.

$$D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right] = \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})} \left[ - \log p(\boldsymbol{z},\boldsymbol{y}) + \log q(\boldsymbol{z}) \right] + \log p(\boldsymbol{y})$$

We can rearrange this equation by negating the expectation term and isolating it on the left-hand side. This gives us

$$\text{ELBO}(q) := \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})}\left[ \log p(\boldsymbol{z},\boldsymbol{y}) - \log q(\boldsymbol{z}) \right] = \log p(\boldsymbol{y}) - D_{KL}\left[ q(\boldsymbol{z}) \,||\, p(\boldsymbol{z}|\boldsymbol{y})\right]$$

which is known as the evidence lower bound (ELBO). Since the KL divergence is non-negative, this quantity is a lower bound on the log evidence, $\log p(\boldsymbol{y})$. Because the evidence does not depend on $q$, maximizing the ELBO simultaneously i) tightens the bound on the evidence and ii) minimizes the KL divergence between our variational distribution and the true posterior.
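
To make this concrete, here is a minimal sketch (Python with NumPy/SciPy; the conjugate-Gaussian toy model, the observation value, and the variational parameters are all assumptions made purely for illustration) of a Monte Carlo estimate of the ELBO, $\mathbb{E}_{q}[\log p(\boldsymbol{z},\boldsymbol{y}) - \log q(\boldsymbol{z})]$, checked against the exact log evidence, which this toy model admits in closed form.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Toy conjugate model (assumed for illustration): z ~ N(0, 1), y | z ~ N(z, 0.5^2)
y_obs = 1.2
prior = stats.norm(0.0, 1.0)
lik_scale = 0.5

def log_joint(z, y):
    """log p(z, y) = log p(y | z) + log p(z)"""
    return stats.norm(z, lik_scale).logpdf(y) + prior.logpdf(z)

def elbo_mc(mu, sigma, n_samples=50_000):
    """Monte Carlo estimate of E_q[log p(z, y) - log q(z)] with q(z) = N(mu, sigma^2)."""
    q = stats.norm(mu, sigma)
    z = q.rvs(size=n_samples, random_state=rng)
    return np.mean(log_joint(z, y_obs) - q.logpdf(z))

# The evidence is available in closed form here, so we can verify ELBO <= log p(y).
log_evidence = stats.norm(0.0, np.sqrt(1.0 + lik_scale**2)).logpdf(y_obs)
print(f"ELBO ~ {elbo_mc(0.5, 0.8):.3f} <= log p(y) = {log_evidence:.3f}")
```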

3 Perspectives of the ELBO

There are three main ways to look at the ELBO, depending upon the literature and application: the likelihood (data fidelity) perspective, the flow (volume correction) perspective, and the variational free energy perspective. In all three cases, we first unpack the ELBO by expanding the joint distribution via the factorization outlined above. This gives us

$$\text{ELBO}(q) := \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})}\left[ \underbrace{\log p(\boldsymbol{y}|\boldsymbol{z})}_{\text{Likelihood}} + \underbrace{\log p(\boldsymbol{z})}_{\text{Prior}} - \underbrace{\log q(\boldsymbol{z})}_{\text{Variational Dist.}} \right]$$

Below, we outline each of the perspectives.

Data Fidelity + Prior

If we group the prior term and the variational distribution together, we get

$$\text{ELBO}(q) := \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})}\left[ \log p(\boldsymbol{y}|\boldsymbol{z}) \right] - D_{KL}\left[ q(\boldsymbol{z})\,||\,p(\boldsymbol{z}) \right]$$

The first term is the reconstruction loss, which measures the expectation of the log-likelihood with respect to the variational distribution. The second term is the KL divergence between the variational distribution and the prior. This formulation is commonly found with latent variable models (LVMs) and Variational Autoencoders (VAEs) [Kingma & Welling, 2013].
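
As a hedged illustration of this grouping (the Gaussian decoder, the standard-normal prior, the single-sample reconstruction term, and all the toy values below are assumptions made for the sketch, not part of any particular library), a VAE-style training loss is typically the negative of this ELBO: a reconstruction term plus a closed-form KL between a diagonal Gaussian $q$ and the prior.

```python
import numpy as np

def kl_diag_gaussian_to_std_normal(mu, log_var):
    """Closed-form KL[ N(mu, diag(exp(log_var))) || N(0, I) ], summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

def negative_elbo(y, y_recon, mu, log_var, obs_var=1.0):
    # Reconstruction term: -E_q[log p(y|z)] approximated with a single sample z,
    # where y_recon = decoder(z); constants independent of the parameters are dropped.
    recon = 0.5 * np.sum((y - y_recon) ** 2) / obs_var
    return recon + kl_diag_gaussian_to_std_normal(mu, log_var)

# Toy usage with made-up numbers
print(negative_elbo(y=np.array([1.0, 2.0]), y_recon=np.array([0.9, 2.2]),
                    mu=np.array([0.1, -0.3]), log_var=np.array([-1.0, -0.5])))
```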

Volume Correction

This perspective is more in line with the idea of using transformed distributions. If we group the variational distribution and the likelihood term, we get

$$\text{ELBO}(q) := \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})}\left[ \log p(\boldsymbol{z}) \right] + \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})}\left[ \log \frac{p(\boldsymbol{y}|\boldsymbol{z})}{q(\boldsymbol{z})} \right]$$

The first term is the reparameterized probability, i.e. the expectation of the prior under the transformed distribution. The second term is the volume correction factor, or likelihood contribution. This formulation was (re-)introduced in the SurVAE Flows paper [Nielsen et al., 2020], which showcased generalized flows with bijective, surjective, and stochastic transformations.

Variational Free Energy

Lastly, we have the Variational Free Energy (VFE) formulation, which is commonly motivated by free-energy principles and, in part, by the Gibbs inequality. If we group the prior and the likelihood terms back into the joint distribution, we get

$$\text{ELBO}(q) := \mathbb{E}_{\boldsymbol{z}\sim q(\boldsymbol{z})}\left[ \log p(\boldsymbol{y},\boldsymbol{z}) \right] + \mathcal{H} \left[ q(\boldsymbol{z}) \right]$$

The first term is the (negative) energy term, i.e. the variational expectation of the log joint distribution. The second term is the entropy of the variational distribution, which acts as a regularizer on the overall complexity of the distribution; the negative of the whole expression is the variational free energy. This formulation is common in the Bayesian Learning Rule (BLR) literature [Khan & Rue, 2021; Kıral et al., 2023] as well as in sparse Gaussian processes [Bauer et al., 2016].


Variational Distribution

We defined the variational distribution as $q(\boldsymbol{z})$ (or, when conditioned on observations, $q(\boldsymbol{z}|\boldsymbol{x})$). However, there are many types of variational distributions we can impose. For example, we have some of the following:

  • Delta, $q(\boldsymbol{z})=\delta(\boldsymbol{z}-\hat{\boldsymbol{z}})$
  • Gaussian, $\mathcal{N}(\boldsymbol{\mu},\boldsymbol{\Sigma})$
  • Laplace approximation
  • Mixture Distribution, $\sum_{k}^{K}\pi_k \mathbb{P}_k$
  • Bijective Transform (Flow), $q(\boldsymbol{z}|\tilde{\boldsymbol{z}})$
  • Stochastic Transform (Encoder, Amortized), $q(\boldsymbol{z}|\boldsymbol{x})$
  • Conditional, $q(\boldsymbol{z}|\boldsymbol{x},\boldsymbol{y})$

Below, we go through each of them and outline some potential strengths and weaknesses.


Delta Distribution

This is probably the distribution with the fewest parameters. We set the covariance matrix to zero, i.e. $\boldsymbol{\Sigma_\theta}:=\mathbf{0}$, and we let all of the mass rest on a single point estimate, $\boldsymbol{\mu_\theta}:=\hat{\boldsymbol{z}}$.

$$q(\boldsymbol{z}) = \delta(\boldsymbol{z} - \hat{\boldsymbol{z}})$$

Note: although this is the most trivial variational distribution, it is the most widely used in optimization algorithms because it is equivalent to MAP estimation (or MLE when there is no prior), as shown in [Wang & Blei, 2012].
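
To illustrate the point, here is a small sketch (Python/SciPy, with a toy Gaussian model assumed purely for the example): with a delta variational distribution, the expected log joint collapses to the log joint evaluated at $\hat{z}$, so maximizing the (degenerate) ELBO reduces to MAP estimation.

```python
import numpy as np
from scipy import optimize, stats

# Toy model (assumed): z ~ N(0, 1), y | z ~ N(z, 0.5^2). With q(z) = delta(z - z_hat),
# the expected log joint is just log p(z_hat, y), so maximizing it is MAP estimation.
y_obs = 1.2

def neg_log_joint(z):
    return -(stats.norm(0.0, 1.0).logpdf(z) + stats.norm(z, 0.5).logpdf(y_obs))

z_map = optimize.minimize_scalar(neg_log_joint).x
print(z_map, "vs analytic MAP", y_obs / (1.0 + 0.5**2))  # both ~0.96
```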


Simple, $q(\boldsymbol{z})$

This is the simplest case, where we assume that a simple parametric distribution, typically a Gaussian, is sufficient to describe the posterior.

$$q(\boldsymbol{z}) = \mathcal{N}(\boldsymbol{z}|\boldsymbol{\mu_\theta},\boldsymbol{\Sigma_\theta})$$

If we parameterize the Gaussian with a full mean vector and a full covariance matrix, we end up with:

$$\boldsymbol{\mu_\theta}:=\boldsymbol{\mu} \in \mathbb{R}^D, \hspace{5mm} \boldsymbol{\Sigma_\theta}:=\boldsymbol{\Sigma} \in \mathbb{R}^{D\times D}$$

For very high dimensional problems, this is a lot of parameters to learn. We can make various simplifications (or complications) here. For example, we can fix the mean, $\boldsymbol{\mu}$, to be zero. The majority of the changes, however, will come from the covariance. Here are a few modifications.

Full Covariance

This is when we parameterize our covariance to be a full covariance matrix, $\boldsymbol{\Sigma_\theta} := \boldsymbol{\Sigma}$. This is easily the most expensive and the most complex of the Gaussian types.

Lower Cholesky

We can also parameterize our covariance via a lower triangular matrix, i.e. $\boldsymbol{\Sigma_\theta} := \mathbf{L}\mathbf{L}^\top$, where $\mathbf{L}$ is the Cholesky factor satisfying $\mathbf{LL}^\top = \boldsymbol{\Sigma}$. This roughly halves the number of free parameters compared to the full covariance ($D(D+1)/2$ instead of $D^2$). It also has desirable, computationally attractive properties when parameterizing covariance matrices, e.g. positive definiteness is guaranteed by construction when the diagonal of $\mathbf{L}$ is positive.
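
A minimal sketch of this parameterization (NumPy; the exp-diagonal trick is one common convention, not the only choice, and the toy dimension is assumed): we store only the $D(D+1)/2$ free entries, keep the diagonal positive, and rebuild a positive-definite covariance.

```python
import numpy as np

def build_covariance(free_params, dim):
    """Map D*(D+1)/2 unconstrained parameters to a positive-definite covariance Sigma = L L^T."""
    L = np.zeros((dim, dim))
    L[np.tril_indices(dim)] = free_params
    # exponentiate the diagonal so it is strictly positive (one common convention)
    L[np.diag_indices(dim)] = np.exp(L[np.diag_indices(dim)])
    return L @ L.T

theta = np.random.default_rng(0).normal(size=3 * (3 + 1) // 2)  # D = 3 -> 6 free parameters
print(build_covariance(theta, dim=3))
```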

Diagonal Covariance

We can parameterize our covariance matrix to be diagonal, i.e. $\boldsymbol{\Sigma_\theta} := \text{diag}(\boldsymbol{\sigma})$. This is a very drastic simplification of our model which limits its expressivity. However, there are immense computational benefits. For example, a $D$-dimensional multivariate Gaussian random variable with a diagonal covariance is the same as the product of $D$ univariate Gaussians.

$$q(\boldsymbol{z}) = \mathcal{N}\left(\boldsymbol{\mu_\theta}, \text{diag}(\boldsymbol{\sigma_\theta})\right) = \prod_{d}^D \mathcal{N}(\mu_d, \sigma_d )$$

This is also known as the mean-field approximation and it is a very common starting point in practical VI algorithms.
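
A quick numerical check of this factorization (NumPy/SciPy, with made-up toy values): the log-density under a diagonal-covariance Gaussian equals the sum of the univariate Gaussian log-densities.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
mu = rng.normal(size=4)
sigma = rng.uniform(0.5, 2.0, size=4)
z = rng.normal(size=4)

# Multivariate Gaussian with diagonal covariance ...
full = stats.multivariate_normal(mean=mu, cov=np.diag(sigma**2)).logpdf(z)
# ... versus the product (sum in log-space) of D univariate Gaussians
factored = np.sum(stats.norm(mu, sigma).logpdf(z))
print(np.allclose(full, factored))  # True
```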

Low Rank Multivariate Normal

Another parameterization is a low-rank matrix plus a diagonal matrix, i.e. $\boldsymbol{\Sigma_\theta} := \mathbf{W}\mathbf{W}^\top + \mathbf{D}$, where $\mathbf{W} \in \mathbb{R}^{D\times d}$ with $d \ll D$ and $\mathbf{D} \in \mathbb{R}^{D\times D}$ is diagonal. We assume that the correlations can be captured in a low-dimensional subspace, which might be appropriate for some applications. This allows for some computationally efficient schemes that make use of the Woodbury identity and the matrix determinant lemma, as sketched below.
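
Here is a sketch of one such scheme (NumPy, with toy sizes assumed): the matrix determinant lemma lets us evaluate $\log|\mathbf{W}\mathbf{W}^\top + \mathbf{D}|$ through a small $d \times d$ determinant instead of a $D \times D$ one.

```python
import numpy as np

rng = np.random.default_rng(0)
big_dim, low_rank = 1000, 5
W = rng.normal(size=(big_dim, low_rank))
d_diag = rng.uniform(0.5, 2.0, size=big_dim)  # diagonal of D

# Matrix determinant lemma: log|W W^T + D| = log|I_d + W^T D^{-1} W| + log|D|
small = np.eye(low_rank) + (W.T / d_diag) @ W            # only a d x d matrix
logdet_fast = np.linalg.slogdet(small)[1] + np.sum(np.log(d_diag))

logdet_full = np.linalg.slogdet(W @ W.T + np.diag(d_diag))[1]  # O(D^3) reference
print(np.allclose(logdet_fast, logdet_full))  # True
```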

Orthogonal Decoupled

One interesting approach is to map the variational parameters via a subspace parameterization [Salimbeni et al., 2018]. For example, we can define the mean and covariance like so:

$$\begin{aligned} \boldsymbol{\mu_\theta} &= \boldsymbol{\Psi}_{\boldsymbol{\mu}} \mathbf{a} \\ \boldsymbol{\Sigma_\theta} &= \boldsymbol{\Psi}_{\boldsymbol{\Sigma}} \mathbf{A} \boldsymbol{\Psi}_{\boldsymbol{\Sigma}}^\top + \mathbf{I} \end{aligned}$$

This is a bit of a spin-off of the low-rank multivariate normal approach. However, this method provides a separate low-rank parameterization for both the mean and the covariance. The authors argue that we can then put more computational effort into the mean (computationally cheap) and less into the covariance (computationally intensive).


Laplace Approximation

Here, we take a second-order Taylor expansion of the log posterior around its mode, which yields a Gaussian approximation

$$q_{\boldsymbol{\theta}}(\boldsymbol{x}) = p(\hat{\boldsymbol{x}}|\boldsymbol{y})\exp\left(-\frac{1}{2}(\boldsymbol{x} - \hat{\boldsymbol{x}})^\top\mathbf{S}^{-1}(\boldsymbol{x}-\hat{\boldsymbol{x}}) \right) \propto \mathcal{N}(\boldsymbol{x}|\hat{\boldsymbol{x}},\mathbf{S})$$

where:

$$\hat{\boldsymbol{x}} = \underset{\boldsymbol{x}}{\text{argmax}} \hspace{1mm} \log p(\boldsymbol{x}|\boldsymbol{y})$$
$$\mathbf{S}^{-1} = - \boldsymbol{\nabla}_{\boldsymbol{x}}\boldsymbol{\nabla}_{\boldsymbol{x}}\log p(\boldsymbol{x}|\boldsymbol{y})\Big|_{\boldsymbol{x}=\hat{\boldsymbol{x}}}$$

This method was popularized by [Kass et al., 1991; MacKay, 1992].
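
A hedged one-dimensional sketch (SciPy; the Poisson-Gaussian toy model and the finite-difference curvature estimate are assumptions made so the posterior is genuinely non-Gaussian): we find the mode numerically and take the curvature of the negative log posterior as the inverse variance.

```python
import numpy as np
from scipy import optimize, stats

# Toy model (assumed): z ~ N(0, 1), y | z ~ Poisson(exp(z))
y_obs = 3

def neg_log_post(z):
    z = np.atleast_1d(z)[0]
    # unnormalized negative log posterior: -log p(y|z) - log p(z)
    return -(stats.poisson(np.exp(z)).logpmf(y_obs) + stats.norm(0.0, 1.0).logpdf(z))

z_hat = optimize.minimize(neg_log_post, x0=np.array([0.0])).x[0]

# S^{-1} from the curvature at the mode, via a central finite difference
eps = 1e-4
curv = (neg_log_post(z_hat + eps) - 2 * neg_log_post(z_hat) + neg_log_post(z_hat - eps)) / eps**2
print(f"Laplace approximation: N(mu={z_hat:.3f}, var={1.0 / curv:.3f})")
```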


Mixture Distribution

The principle behind this is that a simple base distribution, e.g. a Gaussian, is not expressive enough, whereas a mixture of simple distributions, e.g. a mixture of Gaussians, will be more expressive. So the idea is to choose a simple base distribution and replicate it $K$ times. Then, we take a normalized weighted summation of the components to produce our mixture distribution.

$$q(\boldsymbol{z}) = \sum_{k}^K\pi_k \mathbb{P}_k$$

where $0 \leq \pi_k \leq 1$ and $\sum_{k}^K\pi_k=1$. For example, we can use a Gaussian distribution

$$p_{\boldsymbol\theta}(\boldsymbol{z}) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$$

where $\boldsymbol{\theta} = \{\pi_k, \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k \}_{k=1}^K$ are potentially learned parameters. The mixture distribution will then be

$$q_{\boldsymbol \theta}(\mathbf{z}) = \sum_{k}^K \pi_k \mathcal{N}(\mathbf{z} |\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$

Again, we are free to parameterize the covariances as flexibly or as restrictively as we like. For example, we can have full, Cholesky, low-rank, or diagonal covariances. In addition, we can tie some of the parameters together, e.g. use the same covariance matrix for every $k^\text{th}$ component, $\boldsymbol{\Sigma}_k=\boldsymbol{\Sigma}$. Even for VAEs, using a mixture as the prior distribution gives a noticeable improvement over the standard Gaussian prior.

Note: in principle, a mixture distribution is very powerful and, with enough components, can approximate essentially any (e.g. univariate) distribution. However, as with most problems, the difficulty is estimating the best parameters from observations alone.
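
For completeness, a small sketch of evaluating $\log q(\boldsymbol{z})$ for a one-dimensional Gaussian mixture (SciPy; the weights, means, and scales are made-up toy values), using logsumexp for numerical stability:

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp

# Toy mixture parameters (assumed for illustration)
weights = np.array([0.3, 0.5, 0.2])
means = np.array([-2.0, 0.0, 3.0])
scales = np.array([0.5, 1.0, 0.8])

def mixture_logpdf(z):
    """log q(z) = logsumexp_k [ log pi_k + log N(z | mu_k, sigma_k^2) ]"""
    comp_logpdf = stats.norm(means, scales).logpdf(z)
    return logsumexp(np.log(weights) + comp_logpdf)

print(mixture_logpdf(0.7))
```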


Reparameterized


Gaussian


Bijective Transformation (Flow)

It may be that the variational distribution, $q$, is not sufficiently expressive even with a complex Gaussian parameterization and/or a mixture distribution. Another option is to use a bijective transformation to map samples from a simple base distribution, e.g. a Gaussian, to a more complex distribution for our variational parameter, $\boldsymbol{z}$.

$$\mathbf{z} = \boldsymbol{T_\phi}(\tilde{\mathbf{z}})$$

We hope that the resulting variational distribution, $q(\boldsymbol{z})$, acts as a better approximation to the posterior. Because our transformation is bijective, we can invert it to map the variational parameter, $\boldsymbol{z}$, back to the simple base distribution, such that the change-of-variables formula gives

$$q(\mathbf{z}) = p(\tilde{\mathbf{z}})\left|\det\boldsymbol{\nabla}_\mathbf{z}\boldsymbol{T_\phi}^{-1}(\mathbf{z})\right|, \hspace{4mm} \tilde{\mathbf{z}} = \boldsymbol{T_\phi}^{-1}(\mathbf{z})$$

where $p(\tilde{\mathbf{z}})$ is the base density and $\left|\det\boldsymbol{\nabla}_\mathbf{z}\,\cdot\,\right|$ is the Jacobian determinant of the inverse transformation, $\boldsymbol{T_\phi}^{-1}$.
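
A minimal one-dimensional sketch of this change of variables (NumPy/SciPy, with a simple affine transformation assumed in place of a learned flow):

```python
import numpy as np
from scipy import stats

# Affine "flow" (assumed): z = T(z_tilde) = a * z_tilde + b, with base z_tilde ~ N(0, 1)
a, b = 2.0, -1.0
base = stats.norm(0.0, 1.0)

def q_logpdf(z):
    z_tilde = (z - b) / a                 # inverse transformation T^{-1}(z)
    log_det_jac = -np.log(np.abs(a))      # log |d T^{-1} / d z|
    return base.logpdf(z_tilde) + log_det_jac

# Sanity check: for an affine map the result is exactly N(b, a^2)
print(np.isclose(q_logpdf(0.5), stats.norm(b, abs(a)).logpdf(0.5)))  # True
```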


Stochastic Transformation (Encoder, Amortization)

Another type of transformation is a stochastic transformation, given by $q(\boldsymbol{z}|\mathbf{x})$. In this case, we assume some (typically nonlinear) parametric function maps the observations to the parameters of the distribution. For example, a Gaussian distribution with a mean and variance parameterized by neural networks

$$q(\mathbf{z}|\mathbf{x}) = \mathcal{N}\left(\boldsymbol{\mu_\phi}(\mathbf{x}), \boldsymbol{\sigma_\phi}(\mathbf{x})\right)$$

or more appropriately

$$q(\mathbf{z}|\mathbf{x}) = \mathcal{N}\left(\boldsymbol{\mu}, \text{diag}(\exp (\boldsymbol{\sigma}^2_{\log}) )\right), \hspace{4mm} (\boldsymbol{\mu}, \boldsymbol{\sigma}^2_{\log}) = \text{NN}_{\boldsymbol \theta}(\mathbf{x})$$

It can be very difficult to construct a variational distribution that is complicated enough to cover the whole posterior. So often, we use a variational distribution that is conditioned on the observations, i.e. $q(\boldsymbol{z}|\mathbf{x})$. This is known as an encoder because we encode the observations into the parameters of the variational distribution; the mapping is amortized across all observations, so we learn a single set of network weights rather than separate variational parameters per data point.
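
A minimal amortized-encoder sketch (NumPy; the layer sizes, the random untrained weights, and the single hidden layer are all assumptions made purely for illustration): a small network maps an observation $\mathbf{x}$ to the mean and log-variance of $q(\boldsymbol{z}|\mathbf{x})$, and we draw a sample via the reparameterization trick.

```python
import numpy as np

rng = np.random.default_rng(0)

# Tiny one-hidden-layer "encoder" with random (untrained) weights, toy sizes assumed
x_dim, h_dim, z_dim = 8, 16, 2
W1, b1 = rng.normal(scale=0.1, size=(h_dim, x_dim)), np.zeros(h_dim)
W2, b2 = rng.normal(scale=0.1, size=(2 * z_dim, h_dim)), np.zeros(2 * z_dim)

def encode(x):
    """Map an observation x to the parameters (mu, log_var) of q(z | x)."""
    h = np.tanh(W1 @ x + b1)
    out = W2 @ h + b2
    return out[:z_dim], out[z_dim:]

def sample_q(x):
    """Reparameterization: z = mu + sigma * eps with eps ~ N(0, I)."""
    mu, log_var = encode(x)
    return mu + np.exp(0.5 * log_var) * rng.normal(size=z_dim)

print(sample_q(rng.normal(size=x_dim)))
```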


Non-Parametric

  • Kernels & Stein
References
  1. Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv. 10.48550/ARXIV.1312.6114
  2. Nielsen, D., Jaini, P., Hoogeboom, E., Winther, O., & Welling, M. (2020). SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows. arXiv. 10.48550/ARXIV.2007.02731
  3. Khan, M. E., & Rue, H. (2021). The Bayesian Learning Rule. arXiv. 10.48550/ARXIV.2107.04562
  4. Kıral, E. M., Möllenhoff, T., & Khan, M. E. (2023). The Lie-Group Bayesian Learning Rule. arXiv. 10.48550/ARXIV.2303.04397
  5. Bauer, M., van der Wilk, M., & Rasmussen, C. E. (2016). Understanding Probabilistic Sparse Gaussian Process Approximations. arXiv. 10.48550/ARXIV.1606.04820