Joint-contrastive inference and model-based deep learning

In this post I will discuss joint-contrastive variational inference, a new form of stochastic variational inference that is gaining traction in the machine learning community. In this and in the following posts, I will use the joint-contrastive inference framework to show that several commonly used deep learning methods are actually Bayesian inference methods in disguise. This post is mostly based on this paper from our lab about the connection between variational inference and model-based deep learning. In future posts I will cover several other papers that use the joint-contrastive framework.

In the last decade, deep neural networks have revolutionized machine learning with their surprising generalization properties. Deep learning methods have pushed forward the state of the art in most machine learning problems, such as image recognition and machine translation. The rise of deep learning also led several researchers to abandon Bayesian methods and to revert to the deterministic and maximum likelihood methods that were an integral part of the neural network tradition. This happened in my lab as well and led to great works such as this one about the connection between deep learning and the human visual system. However, there isn't any incompatibility between deep learning and Bayesian inference. Deep parametric models, stochastic gradient descent and backpropagation are just tools that can be used for constructing and training any kind of machine learning model. The key conceptual technology leading to the unification of Bayesian inference and deep learning is stochastic variational inference.

Conventional deep learning as model-free inference

Deep neural networks are trained to approximate complex functional relationships between paired data. Consider a meteorological example. Let $y$ be ground measurements of the state of the Earth's atmosphere and $x$ a set of images of our planet taken from a satellite. I will denote the empirical distribution of experimentally collected pairs $(x,y)$ as $k(x,y)$. We can train a deep neural network to recover the conditional probability of the ground measurements given the images. If our aim is to recover the full conditional distribution instead of a single point estimate, the loss has the following form:
$$\mathcal{L}[w] = -\mathbb{E}_{k(x,y)}\left[\log{\mathfrak{q}(y\vert g_w(x))}\right],$$
where $\mathfrak{q}(y\vert g_w(x))$ is a probability distribution parameterized by the output $g_w(x)$ of a neural network. A common choice for $\mathfrak{q}$ is a diagonal Gaussian parameterized by a vector of means and variances. This is a form of model-free probabilistic inference. The architecture of the deep network usually does not encode any information about the process that generated the data or the causal relationships between the variables. Instead, the functional association is learned directly from the data by leveraging a huge training set.
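As a minimal sketch of this loss (my own toy illustration, not code from the paper), here is the negative log-likelihood of a diagonal Gaussian; in a real pipeline `mu` and `log_var` would be the two halves of the network output $g_w(x)$, but plain arrays keep the example self-contained:

```python
import numpy as np

def diag_gaussian_nll(y, mu, log_var):
    """Negative log-likelihood of y under a diagonal Gaussian N(mu, exp(log_var))."""
    return 0.5 * np.sum(
        log_var + (y - mu) ** 2 / np.exp(log_var) + np.log(2 * np.pi),
        axis=-1,
    )

# Toy batch: four "ground measurements" with three components each.
rng = np.random.default_rng(0)
y = rng.normal(size=(4, 3))

nll_good = diag_gaussian_nll(y, mu=y, log_var=np.zeros((4, 3)))       # perfect mean
nll_bad = diag_gaussian_nll(y, mu=y + 5.0, log_var=np.zeros((4, 3)))  # biased mean
print(nll_good.mean() < nll_bad.mean())  # True
```

Minimizing this quantity over the network weights $w$, averaged over $k(x,y)$, is exactly the loss above.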

Bayesian statistics as model-based inference

What is Bayesian inference? To put it simply, it is a mathematically sound method for model-based probabilistic reasoning. Consider again the meteorological example. Let $z$ be the state of the Earth's atmosphere and $x$ a satellite picture. The generative model is given by a factorized joint distribution of the variables:
$$p(z,x) = p(x\vert z)\,p(z).$$
The most interesting feature of this expression is that it can be interpreted causally. The prior $p(z)$ describes our knowledge of the dynamics of the atmosphere as encoded by known fluid-dynamic equations. Analogously, $p(x\vert z)$ encapsulates our knowledge about the causal relation between the atmospheric state and the generation of the images. This source of knowledge can, for example, be encoded in a 3D simulator that generates CG images. Bayes' rule is a deceptively simple formula for combining the data $x$ and our knowledge of the process in order to infer the state of the latent variable $z$:
$$p(z\vert x) = \frac{p(x\vert z)\,p(z)}{p(x)}.$$
Does this mean we are done? We have the formula for the optimal model-based inference right here; who needs deep learning! Unfortunately it is not so easy. Firstly, the term $p(x)$ is given by a high-dimensional integral that we are usually not able to solve:
$$p(x) = \int p(x\vert z)\,p(z)\,\mathrm{d}z.$$
Furthermore, in our example we cannot even compute the probabilities $p(x\vert z)$ and $p(z)$, since they come from very complex models. In probabilistic inference jargon these distributions are said to be implicit: we can sample from them, but we cannot evaluate their probabilities.
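As a toy illustration of an implicit distribution (my own sketch, with arbitrary made-up weights), think of a stochastic simulator: a function that transforms $z$ and noise into $x$ without ever exposing a density:

```python
import numpy as np

rng = np.random.default_rng(1)

def simulate(z):
    """Implicit likelihood p(x | z): a nonlinear 'renderer' plus noise.

    We can draw samples of x given z, but the resulting density has no
    closed form here, so p(x | z) can be sampled yet not evaluated.
    """
    mixing = np.array([[1.0, -0.5], [0.3, 2.0]])  # arbitrary toy weights
    return np.tanh(z @ mixing) + 0.1 * rng.normal(size=2)

z = rng.normal(size=2)   # latent draw from a standard normal prior
x = simulate(z)          # a sample we can generate but cannot score
print(x.shape)  # (2,)
```

A 3D rendering engine producing satellite images from an atmospheric state is exactly this kind of object, only much bigger.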

Stochastic variational inference

Since deep learning techniques are centered around non-convex optimization, it is not surprising that variational inference is the major force behind the unification of Bayesian inference and deep learning. In fact, variational inference is a family of methods that turn Bayesian inference problems into (usually non-convex) optimization problems. In variational inference, an approximate Bayesian posterior distribution is obtained by minimizing a statistical divergence between the intractable posterior $p(z\vert x)$ and a parameterized variational approximation $q_w(z\vert x)$:
$$\mathcal{L}[w] = D\left(q_w(z\vert x)\Vert p(z\vert x)\right),$$
where $w$ is a set of parameters. Statistical divergences measure the dissimilarity between probability distributions. The most commonly used variational loss is the (reverse) KL divergence:
$$D_{KL}\left(q_w(z\vert x)\Vert p(z\vert x)\right) = \mathbb{E}_{q_w(z\vert x)}\left[\log{q_w(z\vert x)} - \log{p(z\vert x)}\right].$$
At first glance this expression seems to be intractable, since we cannot evaluate the true posterior $p(z\vert x)$. Fortunately, since $\log{p(z\vert x)} = \log{p(z,x)} - \log{p(x)}$, we can decompose the KL divergence as follows:
$$\log{p(x)} = \text{ELBO}[w] + D_{KL}\left(q_w(z\vert x)\Vert p(z\vert x)\right),$$
where the first term is the evidence lower bound (ELBO):
$$\text{ELBO}[w] = \mathbb{E}_{q_w(z\vert x)}\left[\log{p(z,x)} - \log{q_w(z\vert x)}\right].$$
Note that $\log{p(x)}$ does not depend on $w$. Consequently, maximizing the ELBO is equivalent to minimizing the reverse KL divergence between the variational approximation and the real posterior distribution.
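To see the bound in action, here is a small numpy sketch (my own toy example, not from the paper) on a conjugate linear-Gaussian model where the exact posterior is known; plugging it in as the variational distribution makes the Monte Carlo ELBO estimate match $\log{p(x)}$ exactly:

```python
import numpy as np

def log_normal(v, mean, var):
    """Log-density of a scalar Gaussian with the given mean and variance."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

# Toy conjugate model: z ~ N(0, 1), x | z ~ N(z, 1),
# so p(x) = N(0, 2) and the exact posterior is p(z | x) = N(x / 2, 1 / 2).
rng = np.random.default_rng(0)
x = 1.3  # a single observed data point

z = rng.normal(x / 2, np.sqrt(0.5), size=10_000)  # samples from q = exact posterior
log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
log_q = log_normal(z, x / 2, 0.5)
elbo = np.mean(log_joint - log_q)

log_evidence = log_normal(x, 0.0, 2.0)
print(np.isclose(elbo, log_evidence))  # True: the bound is tight when q is exact
```

With any other choice of $q$ the estimated ELBO falls strictly below $\log{p(x)}$, the gap being exactly the reverse KL divergence.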

Amortized inference

Maximizing the ELBO is somewhat wasteful since we need to re-optimize a variational posterior for each satellite image. This is where deep neural networks come into play. We can parameterize the whole variational conditional distribution using a deep network $g_w(x)$:
$$q_w(z\vert x) = \mathfrak{q}(z\vert g_w(x)),$$
where $\mathfrak{q}(z\vert g_w(x))$ is a distribution over $z$ parameterized by $g_w(x)$. Therefore, the deep network $g_w(x)$ maps each image $x$ to the variational posterior $q_w(z\vert x)$. In order to train $g_w(x)$, we can use an amortized variational loss obtained by averaging the negative ELBO over all the images in the training set:
$$\mathcal{L}_{AVI}[w] = \mathbb{E}_{k(x)}\!\left[\mathbb{E}_{q_w(z\vert x)}\left[\log{q_w(z\vert x)} - \log{p(z,x)}\right]\right].$$
This expression looks very similar to the kind of loss function that we are used to minimizing in deep learning. The main difference is that one of the variables is sampled from a generative model instead of being sampled from the empirical distribution of some dataset. In this sense, we can see that Bayesian variational inference is analogous to a form of model-based deep learning where we use a generative model instead of a set of paired data points.
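A concrete sketch of the amortized loss (again my own toy example): on the linear-Gaussian model $z \sim N(0,1)$, $x\vert z \sim N(z,1)$, a hypothetical "network" $g_w(x) = (w_0 x, w_1)$ outputs the mean and variance of $q_w(z\vert x)$, and a Monte Carlo estimate of the average negative ELBO is lowest at the parameters of the exact posterior:

```python
import numpy as np

def log_normal(v, mean, var):
    """Log-density of a scalar Gaussian with the given mean and variance."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

rng = np.random.default_rng(0)
xs = rng.normal(0.0, np.sqrt(2.0), size=2000)  # dataset drawn from the model marginal p(x)

def amortized_loss(w, n_mc=200):
    """Monte Carlo estimate of the amortized loss (average negative ELBO).

    The 'network' is the toy map g_w(x) = (w[0] * x, w[1]) giving the mean
    and variance of the variational posterior q_w(z | x).
    """
    mean, var = w[0] * xs[:, None], w[1]
    z = mean + np.sqrt(var) * rng.normal(size=(len(xs), n_mc))
    log_joint = log_normal(z, 0.0, 1.0) + log_normal(xs[:, None], z, 1.0)
    log_q = log_normal(z, mean, var)
    return -np.mean(log_joint - log_q)

# The exact posterior of this model is N(x / 2, 1 / 2),
# so w = (0.5, 0.5) should beat any other parameter setting.
loss_opt, loss_off = amortized_loss((0.5, 0.5)), amortized_loss((0.9, 0.2))
print(loss_opt < loss_off)  # True
```

In practice $g_w$ is a deep network and the gradient is taken through the samples (e.g. via reparameterization), but the structure of the objective is the same.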

Joint-contrastive variational inference

I will call this form of variational inference posterior-contrastive, since the divergence to be minimized measures the difference between the posterior $p(z\vert x)$ and the variational approximation of the posterior. This terminology was introduced in this neat paper and in this insightful blog post. Joint-contrastive variational inference was introduced in the adversarially learned inference (ALI) paper. Since there is already quite some material concerning the use of adversarial methods for variational inference, I will focus on non-adversarial methods. The loss of joint-contrastive variational inference is a divergence between the model joint distribution and a variational joint distribution:
$$\mathcal{L}[w] = D\left(q_w(z,x)\Vert p(z,x)\right).$$
Without further constraints, the minimization of this loss functional is not particularly useful, as the model joint $p(z,x)$ is usually tractable and does not need to be approximated. The key idea for approximating the intractable posterior $p(z\vert x)$ by minimizing this loss is to factorize the variational joint as the product of a variational posterior $q_w(z\vert x)$ and the sampling distribution of the data $k(x)$:
$$q_w(z,x) = q_w(z\vert x)\,k(x).$$
Given this factorization, minimizing the joint-contrastive loss approximates the model posterior with $q_w(z\vert x)$. Furthermore, if the generative model has free parameters, this minimization will also fit the model marginal $p(x)$ to the sampling distribution of the data.

Why is the joint-contrastive loss useful?

Now the reader might wonder why we should take the trouble to use a joint-contrastive loss when the posterior-contrastive one works just fine. There are two main reasons. First, the posterior-contrastive loss requires taking a divergence with respect to the intractable true posterior distribution. In the case of the reverse KL divergence we got lucky, since the intractable normalization constant $p(x)$ does not affect the gradient. Unfortunately, this magic does not happen with almost any other divergence. A common example is the forward KL divergence:
$$D_{KL}\left(p(z\vert x)\Vert q_w(z\vert x)\right) = \mathbb{E}_{p(z\vert x)}\left[\log{p(z\vert x)} - \log{q_w(z\vert x)}\right].$$
In this expression we need to sample the latent variable from the true posterior $p(z\vert x)$. Needless to say, if we could sample from the posterior we would not need variational inference in the first place. This is a pity, since the forward KL divergence has some very nice properties. For example, it is minimized by the exact marginal distributions even when the variational approximation is fully factorized (a mean-field approximation). The second reason for favoring a joint-contrastive variational loss is that it naturally leads to inference amortization.
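The mean-field property can be verified in one line. For a fixed $x$ and a fully factorized approximation $q_w(z\vert x) = \prod_i q_i(z_i)$, the forward KL divergence decomposes into a sum of marginal divergences plus a term that does not depend on the $q_i$:

$$D_{KL}\left(p(z\vert x)\Big\Vert \prod_i q_i(z_i)\right) = -\mathcal{H}\left[p(z\vert x)\right] - \sum_i \mathbb{E}_{p(z_i\vert x)}\left[\log{q_i(z_i)}\right] = \text{const} + \sum_i D_{KL}\left(p(z_i\vert x)\Vert q_i(z_i)\right),$$

which is minimized by setting each factor $q_i(z_i)$ equal to the exact marginal $p(z_i\vert x)$.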

Amortized inference revisited

In the previous section I presented the amortized inference loss $\mathcal{L}_{AVI}[w]$ without any theoretical motivation. I will now show that $\mathcal{L}_{AVI}[w]$ follows directly from a joint-contrastive inference loss. Consider the following reverse KL joint-contrastive loss:
$$D_{KL}\left(q_w(z,x)\Vert p(z,x)\right) = \mathbb{E}_{k(x)}\!\left[\mathbb{E}_{q_w(z\vert x)}\left[\log{q_w(z\vert x)} + \log{k(x)} - \log{p(z,x)}\right]\right] = \mathcal{L}_{AVI}[w] - \mathcal{H}_x,$$
where $\mathcal{H}_x$ is the differential entropy of the sampling distribution $k(x)$, which does not depend on the parameters $w$. Therefore, the reverse KL joint-contrastive loss has the same gradient as the amortized ELBO and it leads to the same optimization problem.

Forward amortized inference and simulation-based deep learning

We can now try to use the forward KL divergence in a joint-contrastive inference loss:
$$D_{KL}\left(p(z,x)\Vert q_w(z,x)\right) = \mathbb{E}_{p(z,x)}\left[\log{p(z,x)} - \log{q_w(z\vert x)} - \log{k(x)}\right].$$
Note that neither $\log{p(z,x)}$ nor $k(x)$ depend on $w$. Therefore, we can minimize the forward KL divergence by minimizing the following loss:
$$\mathcal{L}_{FAVI}[w] = -\mathbb{E}_{p(z,x)}\left[\log{q_w(z\vert x)}\right].$$
This loss function has a very intuitive interpretation as a form of simulation-based deep learning. The pairs $(z,x)$ are sampled from the generative model. In our example, the atmospheric state $z$ is obtained by integrating fluid-dynamic equations, while the images $x$ are generated by a 3D environment that takes the state of the atmosphere as input and outputs a picture of our planet. The network is then trained to predict the probability of the atmospheric state given an image. The generative model can be used for simulating as many data points as needed. Consequently, the network parameterizing $q_w(z\vert x)$ cannot overfit the training set.

The forward loss has a very important advantage over other forms of variational inference: in order to evaluate it, you only need to be able to sample from the generative model. Conversely, if you want to evaluate the reverse loss, you need to explicitly evaluate the probability $p(z,x)$. This can be extremely difficult if your generative model is not defined in terms of simple mathematical formulas, as in the case of our 3D simulator that generates images of our planet. This kind of generative model is said to be implicit, and such models are the focus of much research in machine learning.

Most existing variational inference methods that can be used with implicit models are based on adversarial training. Unfortunately, adversarial training only works when the generator is a differentiable program. You cannot just create a simulator in a 3D engine like Unity and perform adversarial inference on it; you would need to re-implement Unity in something like TensorFlow and make sure to only use differentiable functions! Conversely, you can use forward amortized inference with any kind of simulator straight away, without any extra implementation effort.
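The whole recipe fits in a few lines. In this sketch (my own toy example, with a deliberately black-box simulator) we draw $(z,x)$ pairs from the generative model and fit a Gaussian $q_w(z\vert x)$ whose mean is linear in $x$ by maximizing $\mathbb{E}_{p(z,x)}[\log{q_w(z\vert x)}]$, which for this family reduces to least squares:

```python
import numpy as np

rng = np.random.default_rng(0)

def simulator(z):
    """Black-box 'renderer': we only ever sample from it and never evaluate
    its density, so it could just as well be a 3D engine like Unity."""
    return z + rng.normal(size=z.shape)  # here, x | z ~ N(z, 1)

# Forward amortized inference: draw (z, x) pairs from the generative model...
z = rng.normal(size=50_000)  # prior z ~ N(0, 1)
x = simulator(z)

# ...then fit q_w(z | x) by maximizing E_{p(z,x)}[log q_w(z | x)].
# For a Gaussian q with mean a * x + b this is ordinary least squares.
a, b = np.polyfit(x, z, 1)

# The exact posterior mean of this linear-Gaussian model is x / 2,
# so the fit should recover a ≈ 0.5 and b ≈ 0.
print(a, b)
```

Nothing in the training loop ever differentiates through the simulator: it only needs samples, which is exactly why this recipe works with non-differentiable generative models.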