variational autoencoders

One of the increasingly popular subfields of machine learning I've started working in is generative modeling. A canonical early application of modern machine learning is training a model that classifies whether an image shows a cat or a dog. While these classifiers came to work incredibly well, what they couldn't do was generate new images of cats or dogs. For that task, the field turned to deep generative modeling, best known through methods like Generative Adversarial Networks (GANs) and Variational Auto-encoders (VAEs).

Deep generative modeling tries to model some data generating process, which we'll call $p_\theta(x, z)$. This process, in a simple case, involves sampling some base distribution, $p(z)$, and then conditionally generating the observed sample, $x \sim p_\theta(x|z)$. We refer to $z$ as the latent variable and $x$ as the observation. In general, you can have as many latent variables as you want. Some models, especially manually constructed ones, also specify the relationship between the latent variables and the observed variables, as in medical diagnostics models, for example.
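To make that concrete, here is a tiny sketch of sampling from such a process. Everything about it (the Gaussian prior, the linear "decoder" $W$, the noise scale) is a toy choice of mine, not something the framework prescribes; it just shows the two-step recipe of sampling $z$ and then generating $x$ conditioned on it.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, obs_dim = 2, 5
W = rng.standard_normal((obs_dim, latent_dim))  # toy "decoder" weights

def sample_observation():
    # Step 1: sample the latent variable from a simple base distribution, p(z) = N(0, I).
    z = rng.standard_normal(latent_dim)
    # Step 2: conditionally generate the observation, x ~ p(x | z); here a noisy linear map.
    x = W @ z + 0.1 * rng.standard_normal(obs_dim)
    return z, x

z, x = sample_observation()
print("latent z:", z)
print("observation x:", x)
```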

In our case, if $p_\theta$ successfully captured the generative process for images of cats and we could sample from it, then we would have a generative model for cat pics. In this note, we'll focus on the Variational Auto-encoder (VAE) framework, which differs from GANs in how $p_\theta$ is learned. Right now, VAE models also claim to trade off generative quality for latent interpretability. This is not all that true in my experience: only some VAE models offer interpretable latents, and those generate much worse cat images than a GAN could (more on this in another post). Anyways, the reason a VAE might allow the latents to be interpretable is that it also learns an amortized inference artifact, $q_\phi(z|x)$, which is a variational approximation to the true posterior, $p_\theta(z|x)$ (hence the name). GANs don't care about this thing at all, and only focus on matching the distribution over generated $x$'s as closely as possible to the empirical distribution over seen cat images.

What this means is that $z$ maps to something we can understand before ever seeing a corresponding $x$. For example, one coordinate of $z$, say $z_1$, might be a variable that controls the weight of the cat, so we might expect any cat sampled from $p_\theta(x|z)$ with a large value for $z_1$ to be chonkier than those generated from a small value. GANs, on the other hand, ignore this completely, and only use the randomness of $z$ to inject randomness into the output.

This sounds great, but how is it actually accomplished? As mentioned, our goal is to learn a parametrized model, $p_\theta(x|z)$, and inference network, $q_\phi(z|x)$. As a reasonable starting point, we would like to do so while maximizing the evidence, $p_\theta(x)$, i.e. we want to maximize the chances of our model generating the actual cat images we know exist. Unfortunately, we wanted to be able to conditionally generate cats using only $z$, and we don't know the values of the latent variables, $z$, given only the observed cat images, $x$, in our dataset. In order to obtain and compute $p_\theta(x)$, we must marginalize over all latent variables $z$, giving us $p_\theta(x) = \int p_\theta(x,z)\,dz$. I'll just casually mention, by the way, that models defined this way are called latent variable models, for the very intuitive reason that we don't actually know the hidden, unobserved variables corresponding to the evidence.
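To see why $p_\theta(x)$ is awkward to compute directly, here is the naive Monte Carlo version of that integral, continuing the toy model from the snippet above: draw latents from the prior and average $p_\theta(x|z)$. This is purely illustrative; for anything high-dimensional, almost no prior samples explain a given $x$ and the estimate becomes uselessly noisy.

```python
from scipy.special import logsumexp
from scipy.stats import norm

def log_marginal_estimate(x, n_samples=10_000):
    # Draw latents from the prior p(z) = N(0, I).
    zs = rng.standard_normal((n_samples, latent_dim))
    # log p(x | z_i) under the toy Gaussian observation model from the previous snippet.
    log_px_given_z = norm.logpdf(x, loc=zs @ W.T, scale=0.1).sum(axis=1)
    # log p(x) ~= log[(1/N) * sum_i p(x | z_i)], computed stably in log-space.
    return logsumexp(log_px_given_z) - np.log(n_samples)

print("estimated log p(x):", log_marginal_estimate(x))
```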

As it turns out, by using what we have to target the maximization of the log-evidence (computers work in log-space when dealing with probabilities to prevent numerical underflow), we arrive at a nice, computable objective known as the evidence lower bound (ELBO). For the record, I am not a fan of the name. Anyways, here is the mathematical ELBO, computed using the things we have:

\begin{align} \log[p_\theta(x)] &= \log [\int p_\theta(x,z)dz]\nonumber \\ &= \log [\int q_\phi(z|x) \frac{p_\theta(x,z)}{q_\phi(z|x)} dz]\nonumber \\ &=\log (\mathbb{E}_{q_\phi(z|x)}[\frac{p_\theta(x,z)}{q_\phi(z|x)}]) \nonumber \\ &\geq \mathbb{E}_{q_\phi(z|x)}[\log(\frac{p_\theta(x,z)}{q_\phi(z|x)})] \end{align}
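In practice, the expectation on the last line is itself approximated by sampling: draw a few $z$'s from $q_\phi(z|x)$ and average $\log p_\theta(x,z) - \log q_\phi(z|x)$. Here is what that looks like on the same toy model, with a hand-picked Gaussian $q$ standing in for a learned encoder (the real thing would output the mean and scale from a network):

```python
def elbo_estimate(x, q_mean, q_std, n_samples=64):
    # Sample z ~ q(z | x), the (here hand-picked) approximate posterior.
    zs = q_mean + q_std * rng.standard_normal((n_samples, latent_dim))
    # log p(x, z) = log p(x | z) + log p(z) under the toy model.
    log_px_given_z = norm.logpdf(x, loc=zs @ W.T, scale=0.1).sum(axis=1)
    log_pz = norm.logpdf(zs).sum(axis=1)
    # log q(z | x) for the same samples.
    log_qz = norm.logpdf(zs, loc=q_mean, scale=q_std).sum(axis=1)
    # ELBO ~= E_q[log p(x, z) - log q(z | x)], a lower bound on log p(x).
    return np.mean(log_px_given_z + log_pz - log_qz)

print("ELBO estimate:", elbo_estimate(x, np.zeros(latent_dim), np.ones(latent_dim)))
```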

The derivation above looks closely related to importance sampling and uses Jensen's inequality for the final flourish. It turns out that this bound actually has some nice properties for all the goals we have. By rearranging terms, we see that maximizing the ELBO maximizes the log-evidence while jointly minimizing the KL-divergence between the variational posterior and the true posterior:
\begin{align} \mathrm{ELBO} &= \mathbb{E}_{q_\phi(z|x)}[\log(\frac{p_\theta(x,z)}{q_\phi(z|x)})] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(x,z)) - \log(q_\phi(z|x))] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(z|x)p_\theta(x)) - \log(q_\phi(z|x))] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(z|x)) + \log(p_\theta(x)) - \log(q_\phi(z|x))] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(z|x)) - \log(q_\phi(z|x))] + \log(p_\theta(x))\nonumber \\ &= -\mathrm{KL}({q_\phi(z|x)}~||~{p_\theta(z|x)}) + \log(p_\theta(x)) \end{align}

Rearranging again, we can also get to a nice interpretation of the objective from the auto-encoding perspective. Maximizing the ELBO also maximizes a 'reconstruction' term, $\mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(x|z))]$, while minimizing the KL-divergence between the approximate posterior and the prior as a form of regularization:
\begin{align} \mathrm{ELBO} &= \mathbb{E}_{q_\phi(z|x)}[\log(\frac{p_\theta(x,z)}{q_\phi(z|x)})] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(x,z)) - \log(q_\phi(z|x))] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(x|z)p(z)) - \log(q_\phi(z|x))] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(x|z)) + \log(p(z)) - \log(q_\phi(z|x))] \nonumber \\ &= \mathbb{E}_{q_\phi(z|x)}[\log(p_\theta(x|z))] - \mathrm{KL}({q_\phi(z|x)}~||~{p(z)}) \end{align}
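This last form is the one people actually implement. Below is a minimal PyTorch-style sketch of the per-batch negative ELBO, assuming the common (but by no means required) choices of a standard Gaussian prior, a diagonal-Gaussian encoder $q_\phi(z|x)$ trained with the reparameterization trick, and a Bernoulli decoder over pixel values; the layer sizes and names here are mine, not anything canonical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=16, hidden=256):
        super().__init__()
        # Encoder: the amortized inference network q_phi(z | x); outputs a mean and log-variance.
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.enc_mu = nn.Linear(hidden, z_dim)
        self.enc_logvar = nn.Linear(hidden, z_dim)
        # Decoder: the generative model p_theta(x | z); outputs Bernoulli logits per pixel.
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, x_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps keeps the sample differentiable w.r.t. phi.
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        logits = self.dec(z)
        # Reconstruction term E_q[log p_theta(x | z)]: for a Bernoulli decoder this is minus the BCE.
        recon = -F.binary_cross_entropy_with_logits(logits, x, reduction="sum")
        # KL(q_phi(z | x) || p(z)) in closed form for two diagonal Gaussians.
        kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1.0 - logvar)
        # The loss we minimize is the negative ELBO, averaged over the batch.
        return (kl - recon) / x.shape[0]
```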

So, by computing this term for mini-batches of actual cat images and running gradient-based optimization, we will surely arrive at some nice parameters that give us both the generative model and an approximate inference artifact. There are a lot of details remaining which make this such an interesting research direction, but at a high level, I do believe VAEs are really promising because of the intuition behind their approach.
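For completeness, here is roughly what that optimization loop looks like, using a placeholder batch of random "images" since wiring up an actual cat dataset is beside the point here:

```python
model = TinyVAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1_000):
    # Stand-in mini-batch; in practice this comes from a DataLoader over real images scaled to [0, 1].
    x_batch = torch.rand(128, 784)
    loss = model(x_batch)  # average negative ELBO for the batch
    opt.zero_grad()
    loss.backward()
    opt.step()
```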