This repository is inspired by Blackjax, in particular its interface. I have also contributed to Blackjax, and some of what I learned here made its way in there. This was an exercise for me in implementing an inference algorithm in JAX. It also let me ramble a bit about variational inference and how Bayesian neural networks fit into the variational inference framework.
- Be sure to check out the amazing Blackjax library!
- Be sure to check out the seminal paper by (Blundell et al., 2015), which also inspired this repository.
A probabilistic model is an approximation of nature. In particular, it approximates the process by which our observed data was created. While this approximation may be causally and scientifically inaccurate, it can still be useful depending on the goals of the practitioner. For instance, finding associations through our model can be useful for prediction even when the underlying generative assumptions don't mirror reality. From the perspective of probabilistic modeling, our data $x$ is a set of observed random variables produced by a generative process that also involves unobserved, hidden variables.
Hidden variables are quantities which we believe played a role in generating our data but, unlike our data, their particular values are unobserved. The goal of inference is to use our observed data to uncover the likely values of the hidden variables in our model, in the form of posterior distributions over those hidden variables. Hidden variables are partitioned into two categories: global hidden variables, which are shared across all data points (e.g. the weights of a Bayesian neural network), and local hidden variables, which are specific to individual data points. We denote the global hidden variables by $\theta$.
The reason for this lengthy preface is that many resources on variational inference (the focus of this repository) speak in terms of both local and global hidden variables, and the two are treated differently during inference.
We have a probabilistic model of our data, i.e. a joint distribution over the observed data and the hidden variables:

$$p(x, \theta) = p(x \mid \theta) \, p(\theta)$$
We'd like to compute the posterior

$$p(\theta \mid x) = \frac{p(x, \theta)}{p(x)} = \frac{p(x \mid \theta) \, p(\theta)}{\int p(x \mid \theta) \, p(\theta) \, d\theta}$$
The normalizing constant involves a large multi-dimensional integral which is generally intractable. Variational inference turns this "integration problem" into an "optimization problem", which is much easier (computing derivatives is fairly easy; integration is very hard).
From now on I will suppress the local hidden variables and work only with the global hidden variables $\theta$.
The basic premise of variational inference is to first propose a variational family of distributions $q_\gamma(\theta)$ over the hidden variables, indexed by variational parameters $\gamma$, and then find the member of that family closest to the true posterior in the sense of the KL divergence:

$$\gamma^* = \underset{\gamma}{\text{argmin}} \; \text{KL}\left( q_\gamma(\theta) \, \| \, p(\theta \mid x) \right)$$
We will see that we cannot directly do this, because it ends up involving the computation of the "evidence" $p(x)$, which is exactly the intractable normalizing constant we were trying to avoid.
So instead we optimize the ELBO (evidence lower bound), which is equivalent to the KL divergence up to a constant. Expanding the KL divergence,

$$\text{KL}\left( q_\gamma(\theta) \, \| \, p(\theta \mid x) \right) = \mathbb{E}_{q_\gamma(\theta)}\left[ \text{log} \, q_\gamma(\theta) \right] - \mathbb{E}_{q_\gamma(\theta)}\left[ \text{log} \, p(x, \theta) \right] + \text{log} \, p(x)$$

the ELBO is simply this expression without the intractable evidence $\text{log} \, p(x)$, negated so that maximizing the ELBO corresponds to minimizing the KL divergence:

$$\text{ELBO}(\gamma) = \mathbb{E}_{q_\gamma(\theta)}\left[ \text{log} \, p(x, \theta) \right] - \mathbb{E}_{q_\gamma(\theta)}\left[ \text{log} \, q_\gamma(\theta) \right]$$
Further manipulation of the ELBO gives intuitive insight into how it leads to a good approximate posterior:

$$\text{ELBO}(\gamma) = \mathbb{E}_{q_\gamma(\theta)}\left[ \text{log} \, p(x \mid \theta) \right] - \text{KL}\left( q_\gamma(\theta) \, \| \, p(\theta) \right)$$
The expected data log likelihood term encourages $q_\gamma(\theta)$ to place its mass on values of $\theta$ that explain the observed data, while the KL term regularizes $q_\gamma(\theta)$ toward the prior $p(\theta)$.
This repository implements a particular form of variational inference, often referred to as mean-field variational inference. But be careful! The formulation of the mean-field family, and how one optimizes the variational parameters, depends on whether the variational distribution is over the local hidden variables or the global hidden variables. For the local hidden variable formulation, see (Margossian et al., 2023). In the global variable case, mean-field variational inference typically refers to selecting the following variational family (of distributions) over the global hidden variables ((Coker et al., 2021) & (Foong et al., 2020)):

$$q_\gamma(\theta) = \prod_{i=1}^{d} \mathcal{N}\left( \theta_i \mid \mu_i, \sigma_i^2 \right), \quad \gamma = \{ \mu_i, \sigma_i \}_{i=1}^{d}$$
In other words, the family of multivariate Gaussians with diagonal covariance (also called a fully factorized Gaussian). Some have questioned the expressivity of the mean-field family, and whether it can capture the complex dependencies of a high-dimensional target posterior. For instance, (Foong et al., 2020) examine the failure modes of mean-field variational inference in shallow neural networks. On the other hand, (Farquhar et al.) argue that for large neural networks, the mean-field family is sufficient.
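As a concrete sketch, a fully factorized Gaussian is only a few lines of JAX. The function names (`sample_q`, `log_q`) are made up for illustration, and the softplus parameterization of the standard deviation (one common choice, used e.g. by Blundell et al., 2015, to keep $\sigma$ positive during unconstrained optimization) is an assumption, not necessarily what this repository does:

```python
import jax
import jax.numpy as jnp

def sample_q(key, mu, rho):
    """Draw one sample from the mean-field Gaussian q_gamma(theta)."""
    sigma = jax.nn.softplus(rho)          # sigma > 0 for any real rho
    eps = jax.random.normal(key, mu.shape)
    return mu + sigma * eps

def log_q(theta, mu, rho):
    """Log density of the fully factorized Gaussian at theta."""
    sigma = jax.nn.softplus(rho)
    # Diagonal covariance => the log density is a sum of 1-D Gaussian terms.
    return jnp.sum(jax.scipy.stats.norm.logpdf(theta, mu, sigma))

key = jax.random.PRNGKey(0)
theta = sample_q(key, jnp.zeros(3), jnp.zeros(3))
```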
We would like to take the gradient of the ELBO with respect to the variational parameters $\gamma$, so that we can maximize it with stochastic gradient methods:

$$\nabla_\gamma \, \text{ELBO}(\gamma) = \nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ \text{log} \, p(x, \theta) - \text{log} \, q_\gamma(\theta) \right]$$

The difficulty is that the expectation is taken with respect to a distribution that itself depends on $\gamma$.
Monte Carlo integration allows us to form an unbiased approximation of the expectation of a function by sampling from the distribution over which the expectation is taken:

$$\mathbb{E}_{p(x)}\left[ f(x) \right] \approx \frac{1}{S} \sum_{s=1}^{S} f(x^{(s)}), \quad x^{(s)} \sim p(x)$$
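For instance, here is a minimal JAX sketch of Monte Carlo integration, estimating $\mathbb{E}[x^2] = 1$ under a standard normal (the helper name `mc_expectation` is invented for illustration):

```python
import jax
import jax.numpy as jnp

def mc_expectation(key, f, sampler, num_samples=100_000):
    """Monte Carlo estimate of E_p[f(x)] given a sampler for p."""
    samples = sampler(key, num_samples)   # draw x^(s) ~ p(x)
    return jnp.mean(f(samples))           # average f over the samples

key = jax.random.PRNGKey(0)
# E_{N(0,1)}[x^2] is the variance of a standard normal, i.e. 1.0.
estimate = mc_expectation(
    key,
    f=lambda x: x ** 2,
    sampler=lambda k, n: jax.random.normal(k, (n,)),
)
# The estimator is unbiased; any single estimate carries Monte Carlo
# error on the order of 1/sqrt(num_samples).
```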
Let us expand $\nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \right]$ for a function $f$ that does not depend on $\gamma$ (the dependence of the ELBO integrand on $\gamma$ contributes a term whose expectation is zero, since $\mathbb{E}_{q_\gamma(\theta)}\left[ \nabla_\gamma \, \text{log} \, q_\gamma(\theta) \right] = 0$):

$$\nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \right] = \nabla_\gamma \int q_\gamma(\theta) \, f(\theta) \, d\theta = \int \nabla_\gamma \, q_\gamma(\theta) \, f(\theta) \, d\theta \tag{3}$$

where in the last equality we exchanged the order of differentiation and integration. The right-hand side of $(3)$ is no longer an expectation with respect to $q_\gamma(\theta)$, so Monte Carlo does not directly apply. The score function (log-derivative) trick recovers an expectation:

- Note the identity: $\nabla_\gamma \, q_\gamma(\theta) = q_\gamma(\theta) \, \nabla_\gamma \, \text{log} \, q_\gamma(\theta)$
- Plug the new expression for $\nabla_\gamma \, q_\gamma(\theta)$ into $(3)$, then factor and rearrange to get

$$\nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \right] = \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \, \nabla_\gamma \, \text{log} \, q_\gamma(\theta) \right]$$
We can use Monte Carlo to approximate this last expectation, yielding the score function (or REINFORCE) gradient estimator:

$$\nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \right] \approx \frac{1}{S} \sum_{s=1}^{S} f(\theta^{(s)}) \, \nabla_\gamma \, \text{log} \, q_\gamma(\theta^{(s)}), \quad \theta^{(s)} \sim q_\gamma(\theta)$$
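Here is a small JAX sketch of the score function estimator on a toy problem where the true gradient is known analytically. The name `score_function_grad` and the toy target $f(\theta) = \theta^2$ are illustrative choices, not part of this repository:

```python
import jax
import jax.numpy as jnp

# Toy problem: f(theta) = theta^2 and q_gamma = N(mu, sigma^2).
# Analytically E_q[theta^2] = mu^2 + sigma^2, so d/dmu = 2*mu.
def score_function_grad(key, mu, log_sigma, num_samples=200_000):
    sigma = jnp.exp(log_sigma)
    theta = mu + sigma * jax.random.normal(key, (num_samples,))  # theta ~ q

    f = theta ** 2                          # f(theta^(s))
    # For a Gaussian, grad_mu log q(theta) = (theta - mu) / sigma^2.
    score_mu = (theta - mu) / sigma ** 2
    return jnp.mean(f * score_mu)           # (1/S) sum f * grad log q

key = jax.random.PRNGKey(0)
grad_mu = score_function_grad(key, mu=1.5, log_sigma=0.0)
# grad_mu fluctuates around the true value 2 * 1.5 = 3.0, but with
# noticeably high variance, which motivates the reparameterization trick.
```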
Let's recall what we want to do: as before, we want to estimate the gradient of the ELBO with respect to the variational parameters $\gamma$.
We just saw how to build a Monte Carlo estimator for this, but the variance of that estimator can be unusably high. The reparameterization trick yields an estimator with much lower variance; however, it requires stronger assumptions (i.e. specific forms of the variational distribution). The idea of reparameterization is that we can come up with an equivalent representation of a random quantity, and this new representation allows us to do cool and good things. So suppose we can express a sample $\theta \sim q_\gamma(\theta)$ as a deterministic transformation $\theta = t_\gamma(\epsilon)$ of a noise variable $\epsilon \sim p(\epsilon)$, where $p(\epsilon)$ does not depend on $\gamma$.
Now why is this useful? Well, let us go through taking the derivative again. Since $p(\epsilon)$ does not depend on $\gamma$, the gradient moves freely inside the expectation:

$$\nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \right] = \nabla_\gamma \, \mathbb{E}_{p(\epsilon)}\left[ f(t_\gamma(\epsilon)) \right] = \mathbb{E}_{p(\epsilon)}\left[ \nabla_\gamma \, f(t_\gamma(\epsilon)) \right]$$
We can use Monte Carlo to approximate this expectation by sampling the noise variable:

$$\nabla_\gamma \, \mathbb{E}_{q_\gamma(\theta)}\left[ f(\theta) \right] \approx \frac{1}{S} \sum_{s=1}^{S} \nabla_\gamma \, f(t_\gamma(\epsilon^{(s)})), \quad \epsilon^{(s)} \sim p(\epsilon)$$
We use the reparameterization trick in this repository because we can; namely, we choose a fully factorized Gaussian variational family, which is straightforward to reparameterize.
From the reparameterization trick we have

$$\theta = t_\gamma(\epsilon) = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I), \quad \gamma = \{ \mu, \sigma \}$$
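A small JAX sketch of the reparameterization gradient on a toy problem with a known answer ($f(\theta) = \theta^2$ under $q_\gamma = \mathcal{N}(\mu, \sigma^2)$, where the true gradient with respect to $\mu$ is $2\mu$). The function name is invented for illustration; the point is that `jax.grad` differentiates straight through the transformation $\theta = \mu + \sigma \epsilon$:

```python
import jax
import jax.numpy as jnp

def reparam_objective(mu, log_sigma, eps):
    """Monte Carlo estimate of E_q[f(theta)] via theta = t_gamma(eps)."""
    theta = mu + jnp.exp(log_sigma) * eps   # theta = t_gamma(eps)
    return jnp.mean(theta ** 2)             # (1/S) sum f(t_gamma(eps^(s)))

key = jax.random.PRNGKey(0)
eps = jax.random.normal(key, (10_000,))     # eps^(s) ~ p(eps), fixed noise
# Differentiate the Monte Carlo objective w.r.t. mu (argnums=0).
grad_mu = jax.grad(reparam_objective, argnums=0)(1.5, 0.0, eps)
# grad_mu is close to 2 * mu = 3.0 even with modest sample sizes,
# illustrating the lower variance of the reparameterization estimator.
```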
Sometimes, given the form chosen for $q_\gamma(\theta)$ and the prior $p(\theta)$, parts of the ELBO can be computed analytically; for instance, the KL divergence between two Gaussians has a closed form, leaving only the expected data log likelihood to be estimated with Monte Carlo.
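For example, when $q_\gamma$ is a mean-field Gaussian and the prior is a standard normal (an illustrative prior choice, not necessarily this repository's), the KL term factorizes over dimensions and is available in closed form:

```python
import jax.numpy as jnp

def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ).

    The KL factorizes over dimensions:
        sum_i 0.5 * (sigma_i^2 + mu_i^2 - 1 - log sigma_i^2)
    """
    return 0.5 * jnp.sum(sigma**2 + mu**2 - 1.0 - 2.0 * jnp.log(sigma))

# When mu = 0 and sigma = 1, q equals the prior and the KL is exactly zero.
kl_zero = kl_to_standard_normal(jnp.zeros(4), jnp.ones(4))
```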