
Variational Inference


Motivations

Variational inference is the most scalable approximate inference method available to the machine learning community (as of 2019).

Tutorials

  • https://www.ritchievink.com/blog/2019/06/10/bayesian-inference-how-we-are-able-to-chase-the-posterior/
  • https://www.ritchievink.com/blog/2019/09/16/variational-inference-from-scratch/

ELBO - Derivation

Let's start with the marginal likelihood function.

\mathcal{P}(y| \theta)=\int_\mathcal{X} \mathcal{P}(y|\mathbf x, \theta) \cdot \mathcal{P}(\mathbf x) \cdot d\mathbf{x}

where we have effectively marginalized out the \mathbf f's. We already know that it is difficult to propagate the \mathbf x's through the nonlinear terms \mathbf K^{-1} and |\det \mathbf K| (see the previous doc for examples). So, following the VI strategy, we introduce a new variational distribution q(\mathbf x) to approximate the posterior distribution \mathcal{P}(\mathbf x|y). This distribution is normally chosen to be Gaussian:

q(\mathbf x) = \prod_{i=1}^{N}\mathcal{N}(\mathbf x_i|\mathbf \mu_i, \mathbf \Sigma_i)
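As a concrete (purely hypothetical) illustration of such a factorized Gaussian q(\mathbf x), here is a minimal numpy sketch with per-point means and diagonal covariances; the numbers are placeholders, not taken from any real model:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical variational parameters for N = 3 latent points in d = 2 dimensions:
# one mean vector and one diagonal covariance (stored as std. deviations) per point.
mu = np.array([[0.0, 0.0],
               [1.0, -0.5],
               [-2.0, 0.3]])            # shape (N, d)
sigma = np.full_like(mu, 0.5)            # shape (N, d), diagonal std. deviations

# One joint sample from q(x) = prod_i N(x_i | mu_i, diag(sigma_i^2))
x_sample = mu + sigma * rng.standard_normal(mu.shape)
print(x_sample.shape)  # (3, 2)
```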

So at this point, we are interested in finding a way to measure the difference between the approximate distribution q(\mathbf x) and the true posterior distribution \mathcal{P}(\mathbf x|y). Using some algebra, let's take the log of the marginal likelihood (evidence):

\log \mathcal{P}(y|\theta) = \log \int_\mathcal{X} \mathcal{P}(y|\mathbf x, \theta) \cdot \mathcal{P}(\mathbf x) \cdot d\mathbf x

So now we are going to use some of the tricks that show up in almost every derivation of the VI framework. The first is the identity trick: multiplying the integrand by \frac{q(\mathbf x)}{q(\mathbf x)} = 1, which lets us rewrite the expectation in terms of the new variational distribution q(\mathbf x). We get the following equation:

\log \mathcal{P}(y|\theta) = \log \int_\mathcal{X} \mathcal{P}(y|\mathbf x, \theta) \cdot \mathcal{P}(\mathbf x) \cdot \frac{q(\mathbf x)}{q(\mathbf x)} \cdot d\mathbf x

Now that we have introduced our new variational distribution, we can regroup the terms so that the expectation is taken with respect to q(\mathbf x):

\log \mathcal{P}(y|\theta) = \log \int_\mathcal{X} \mathcal{P}(y|\mathbf x, \theta) \cdot q(\mathbf x) \cdot \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \cdot d\mathbf x

Jensen's inequality states that for a convex function f, f(\mathbb{E}[x]) \leq \mathbb{E}[f(x)]. We would like to move the \log function inside of the integral, and Jensen's inequality lets us do this: if we let f(\cdot) = \log(\cdot), which is concave, the inequality flips and we get f(\mathbb{E}[x]) \geq \mathbb{E}[f(x)]. Matching each term of our expression to the inequality, we have

\log \mathbb{E}_{q(\mathbf x)} \left[ \mathcal{P}(y|\mathbf x, \theta) \cdot \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \right] \geq \mathbb{E}_{q(\mathbf x)} \left[\log \left( \mathcal{P}(y|\mathbf x, \theta) \cdot \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \right) \right]
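As a quick numerical sanity check of this direction of Jensen's inequality (a standalone sketch, not part of the original derivation): for any positive random variable, the log of the mean is at least the mean of the log.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.lognormal(mean=0.0, sigma=1.0, size=100_000)  # any positive random variable works

log_of_mean = np.log(x.mean())   # f(E[x]) with f = log
mean_of_log = np.log(x).mean()   # E[f(x)]

# Since log is concave, log(E[x]) >= E[log(x)]; here roughly 0.5 vs 0.0.
print(log_of_mean, mean_of_log)
```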

So now we finally have both sides of the inequality. Summarizing everything, we have the following relationship:

\log \mathcal{P}(y|\theta) = \log \int_\mathcal{X} \mathcal{P}(y|\mathbf x, \theta) \cdot q(\mathbf x) \cdot \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \cdot d\mathbf x
\log \mathcal{P}(y|\theta) \geq \int_\mathcal{X} q(\mathbf x) \log \left[ \mathcal{P}(y|\mathbf x, \theta) \cdot \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \right] d\mathbf x

I'm going to switch up the notation just to make things aesthetically easier. Let \mathcal{L}(\theta) denote \log \mathcal{P}(y|\theta) and let \mathcal{F}(q, \theta) denote the lower bound, so that \mathcal{F}(q, \theta) \leq \mathcal{L}(\theta). So basically:

\mathcal{L}(\theta) = \log \mathcal{P}(y|\theta) \geq \int_\mathcal{X} q(\mathbf x) \log \left[ \mathcal{P}(y|\mathbf x, \theta) \cdot \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \right] d\mathbf x = \mathcal{F}(q, \theta)
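Though not shown in the source, a standard companion identity makes the quality of this bound explicit: the slack between \mathcal{L}(\theta) and \mathcal{F}(q, \theta) is exactly the KL divergence from q(\mathbf x) to the true posterior,

\mathcal{L}(\theta) - \mathcal{F}(q, \theta) = \text{D}_\text{KL}\left[ q(\mathbf x) \,||\, \mathcal{P}(\mathbf x|y, \theta) \right] \geq 0

so maximizing \mathcal{F}(q, \theta) over q is equivalent to minimizing the KL to the true posterior; this is what "tightening the bound" refers to later on.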

With this simple change I can talk about each of the parts individually. Now using log rules we can break apart the likelihood and the quotient. The quotient will be needed for the KL divergence.

\mathcal{F}(q, \theta) = \underbrace{\int_\mathcal{X} q(\mathbf x) \cdot \log \mathcal{P}(y|\mathbf x, \theta) \cdot d\mathbf x}_{\mathbb{E}_{q(\mathbf{x})}\left[\log \mathcal{P}(y|\mathbf x, \theta)\right]} + \underbrace{\int_\mathcal{X} q(\mathbf x) \log \frac{\mathcal{P}(\mathbf x)}{q(\mathbf x)} \cdot d\mathbf x}_{-\text{D}_\text{KL}\left[ q(\mathbf x) || \mathcal{P}(\mathbf x) \right]}

The punchline, after these manipulations, is that we obtain an objective \mathcal{F}(q, \theta) that we can optimize:

\mathcal{F}(q, \theta)=\mathbb{E}_{q(\mathbf x)}\left[ \log \mathcal{P}(y|\mathbf x, \theta) \right] - \text{D}_\text{KL}\left[ q(\mathbf x) || \mathcal{P}(\mathbf x) \right]

where:

  • Approximate posterior distribution q(\mathbf x): the distribution we fit to be the best match to the true posterior \mathcal{P}(\mathbf x|y, \theta). This is what we want to calculate.
  • Reconstruction cost \mathbb{E}_{q(\mathbf x)}\left[ \log \mathcal{P}(y|\mathbf x, \theta) \right]: the expected log-likelihood, which measures how well samples from q(\mathbf x) are able to explain the data y.
  • Penalty \text{D}_\text{KL}\left[ q(\mathbf x) || \mathcal{P}(\mathbf x) \right]: ensures that the explanation of the data, q(\mathbf x), doesn't deviate too far from your prior beliefs \mathcal{P}(\mathbf x) (an Occam's razor constraint).

Source: VI Tutorial - Shakir Mohamed
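To make the two terms concrete, here is a minimal numerical sketch of the bound for a toy conjugate model (a one-dimensional Gaussian prior and likelihood, chosen here purely for illustration and not taken from the source). The reconstruction term is estimated by Monte Carlo with samples from q(x), while the penalty term uses the closed-form Gaussian KL:

```python
import numpy as np

rng = np.random.default_rng(0)

# --- Toy model (assumed purely for illustration; not from the source) ---
# Prior:       P(x)     = N(x | 0, 1)
# Likelihood:  P(y | x) = N(y | x, sigma_y^2)
# Variational: q(x)     = N(x | mu_q, sigma_q^2)
y, sigma_y = 1.5, 0.5
mu_q, sigma_q = 1.0, 0.4

def log_likelihood(x):
    """log N(y | x, sigma_y^2), evaluated element-wise."""
    return -0.5 * np.log(2 * np.pi * sigma_y**2) - 0.5 * (y - x) ** 2 / sigma_y**2

# Reconstruction term: Monte Carlo estimate of E_q[log P(y|x)]
x_samples = rng.normal(mu_q, sigma_q, size=100_000)
reconstruction = log_likelihood(x_samples).mean()

# Penalty term: closed-form KL[N(mu_q, sigma_q^2) || N(0, 1)]
kl = np.log(1.0 / sigma_q) + 0.5 * (sigma_q**2 + mu_q**2) - 0.5

elbo = reconstruction - kl

# Because the toy model is conjugate, the true log-evidence is available:
# log P(y) = log N(y | 0, 1 + sigma_y^2), so we can verify ELBO <= log P(y).
log_evidence = -0.5 * np.log(2 * np.pi * (1 + sigma_y**2)) - 0.5 * y**2 / (1 + sigma_y**2)

print(f"ELBO = {elbo:.3f}  <=  log P(y) = {log_evidence:.3f}")
```

The closer q(x) gets to the true posterior, the smaller the gap between the printed ELBO and the true log-evidence.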

If we optimize \mathcal{F} with respect to q(\mathbf x), the KL term is minimized and we are essentially left with the likelihood term. As we've seen before, that term is still problematic because we still have to propagate the \mathbf x's through the nonlinear portion, so by itself this formulation buys us nothing new. However, if we introduce some special structure into q(\mathbf f) by introducing sparsity, i.e. by augmenting the variable space with \mathbf u and \mathbf Z, we can bypass this problem. The second term is simple to calculate because both distributions are chosen to be Gaussian.
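For reference (a standard result, not derived in the source), the KL divergence between two multivariate Gaussians in d dimensions has the closed form

\text{D}_\text{KL}\left[ \mathcal{N}(\mathbf \mu_0, \mathbf \Sigma_0) \,||\, \mathcal{N}(\mathbf \mu_1, \mathbf \Sigma_1) \right] = \frac{1}{2}\left[ \text{tr}\left(\mathbf \Sigma_1^{-1}\mathbf \Sigma_0\right) + (\mathbf \mu_1 - \mathbf \mu_0)^\top \mathbf \Sigma_1^{-1} (\mathbf \mu_1 - \mathbf \mu_0) - d + \log\frac{\det \mathbf \Sigma_1}{\det \mathbf \Sigma_0} \right]

which is why the penalty term is cheap to evaluate whenever both q(\mathbf x) and \mathcal{P}(\mathbf x) are Gaussian.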

### Comments on q(x)

  • We have now transformed our problem from an integration problem to an optimization problem where we optimize for q(x) directly.
  • Many people tend to simplify q, but we could easily write in some dependency on the data, for example q(x|\mathcal{D}).
  • Convergence is easy to assess: we simply monitor the loss (the free energy) until it stops improving.
  • Typically q(x) is a Gaussian whose variational parameters are the mean and the variance. Practically speaking, we could freeze or unfreeze any of these parameters if we have some prior knowledge about our problem (see the sketch after this list).
  • Many people say 'tighten the bound' but they really just mean optimization: modifying the hyperparameters so that we get as close as possible to the true marginal likelihood.
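As a minimal sketch of that freezing idea (assuming a PyTorch-style setup; the class and parameter names are hypothetical), a diagonal Gaussian q can be parameterized by a learnable mean and log-variance, and either parameter can be frozen when prior knowledge fixes it:

```python
import torch
import torch.nn as nn

class DiagonalGaussianQ(nn.Module):
    """q(x) = N(mu, diag(exp(log_var))) with learnable variational parameters."""
    def __init__(self, dim):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(dim))       # variational mean
        self.log_var = nn.Parameter(torch.zeros(dim))  # variational log-variance

    def rsample(self, n_samples):
        # Reparameterization: x = mu + sigma * eps, eps ~ N(0, I),
        # so gradients flow back to mu and log_var.
        eps = torch.randn(n_samples, self.mu.shape[0])
        return self.mu + torch.exp(0.5 * self.log_var) * eps

q = DiagonalGaussianQ(dim=2)
q.log_var.requires_grad_(False)  # e.g. freeze the variance, learn only the mean
samples = q.rsample(100)         # samples used for the reconstruction term
```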

## Pros and Cons

### Why Variational Inference?

  • Applicable to all probabilistic models
  • Transforms a problem from integration to one of optimization
  • Convergence assessment
  • Principled and Scalable approach to model selection
  • Compact representation of posterior distribution
  • Faster to converge
  • Numerically stable
  • Modern Computing Architectures (GPUs)

### Why Not Variational Inference?

  • Approximate posterior only
  • Difficulty in optimization due to local minima
  • Under-estimates the variance of posterior
  • Limited theory and guarantees for variational methods

Resources

Lower Bound

Summaries

Presentations

Reviews

  • From EM to SVI
  • Variational Inference
  • VI - Review for Statisticians
  • Tutorial on VI
  • VI w/ Code
  • VI - Mean Field
  • VI Tutorial
  • GMM
  • VI in GMM
  • GMM Pyro | Pyro
  • GMM PyTorch | PyTorch | PyTorchy

Code

Extensions

From Scratch

  • Programming a Neural Network from Scratch - Ritchie Vink (2017) - blog
  • An Introduction to Probability and Computational Bayesian Statistics - Eric Ma - blog
  • Variational Inference from Scratch - Ritchie Vink (2019) - blog
  • Bayesian inference; How we are able to chase the Posterior - Ritchie Vink (2019) - blog
  • Algorithm Breakdown: Expectation Maximization - blog

Variational Inference

  • Variational Bayes and The Mean-Field Approximation - Keng (2017) - blog