MIM: Mutual Information Machine

Micha Livne
University of Toronto
Vector Institute
mlivne@cs.toronto.edu
&Kevin Swersky
Google Research
kswersky@google.com
&David J. Fleet
University of Toronto
Vector Institute
fleet@cs.toronto.edu
September 2019
Abstract

We introduce the Mutual Information Machine (MIM), an autoencoder model for learning joint distributions over observations and latent states. The model formulation reflects two key design principles: 1) symmetry, to encourage the encoder and decoder to learn consistent factorizations of the same underlying distribution; and 2) mutual information, to encourage the learning of useful representations for downstream tasks. The objective comprises the Jensen-Shannon divergence between the encoding and decoding joint distributions, plus a mutual information term. We show that this objective can be bounded by a tractable cross-entropy loss between the true model and a parameterized approximation, and relate this to maximum likelihood estimation and variational autoencoders. Experiments show that MIM is capable of learning a latent representation with high mutual information, and good unsupervised clustering, while providing data log likelihood comparable to VAE (with a sufficiently expressive architecture).


1 Introduction

Mutual information is a natural indicator of the quality of a learned representation (hjelm2018learning), along with other characteristics, such as the compositionality of latent factors that are expected to be useful in downstream tasks, like transfer learning (DBLP:journals/corr/BengioTPPB17). Mutual information is, however, computationally difficult to estimate for continuous high-dimensional random variables. As such, it can be hard to optimize when learning latent variable models (Hjelm2018; Chen2016).

This paper formulates a new class of probabilistic autoencoder models motivated by two key design principles, namely, the maximization of mutual information, and the symmetry of the encoder-decoder components. Symmetry captures our desire for both the encoder and decoder to effectively and consistently model the underlying observation and latent domains. This is particularly useful for downstream tasks in which either one or both of the encoder and decoder play a central role. These properties are formulated in terms of the symmetric Jensen-Shannon divergence between the encoder and decoder, combined with an objective term to maximize mutual information. We refer to the resulting model as the mutual information machine, or MIM.

We contrast MIM with models trained using (approximate) maximum likelihood, the canonical example being the variational autoencoder, or VAE (Kingma2013; Rezende2014). The VAE comprises a probabilistic decoder and an approximate encoder, learned via optimization of an evidence-based lower bound (ELBO) on the log marginal data distribution. In contrast to MIM it is asymmetric in its formulation, and while often producing excellent representations, VAEs sometimes produce pathological results in which the encoder, or approximate posterior, conveys relatively little information between observations and latent states. This behavior, often referred to as posterior collapse, results in low mutual information between observations and inferred latent states (DBLP:journals/corr/BowmanVVDJB15; ChenKSDDSSA16; DBLP:journals/corr/abs-1901-03416; DBLP:journals/corr/OordKK16; DBLP:journals/corr/abs-1711-00937).

We formulate the MIM model, and a learning algorithm that minimizes an upper bound on the desired loss. The resulting objective can be viewed as a symmetrized form of KL divergence, thereby closely related to the asymmetric KL objective of the VAE. This also enables direct comparisons to the VAE in terms of posterior collapse, mutual information, data log likelihood, and clustering. Experiments show that MIM offers favourable mutual information, better clustering in the latent representation, and similar reconstruction, at the expense of sampling quality and data log likelihood when compared to a VAE with the same architecture. We also demonstrate that for a sufficiently powerful architecture, MIM can match sampling quality and log likelihood of a VAE with the same architecture.

2 Variational Autoencoders

VAE learning entails optimization of a variational lower bound on the log-marginal likelihood of the data, $\log p_{\theta}(x)$, to estimate the parameters $\theta$ of an approximate posterior over latent states, $q_{\theta}(z|x)$ (i.e., the encoder), and a corresponding decoder, $p_{\theta}(x|z)$ (Kingma2013; Rezende2014). A prior over the latent space, $\mathcal{P}(z)$, often assumed to be an isotropic Gaussian, serves as a prior for $z$ in the evidence lower bound (ELBO) on the marginal likelihood:

$\log p_{\theta}(x) \;\geq\; \mathbb{E}_{z\sim q_{\theta}(z|x)}\big[\log p_{\theta}(x|z)\big] \;-\; D_{\mathrm{KL}}\big(q_{\theta}(z|x)\,\|\,\mathcal{P}(z)\big)\,.$

Here, we use the notation $\mathcal{P}(z)$ and $\mathcal{P}(x)$ (for the unknown data distribution) to emphasize that these distributions are given; we can draw random samples from them, but not necessarily evaluate the log-likelihood of samples under them. In what follows we often refer to them as anchors to further emphasize their role.

With amortized posterior inference, we take the expectation over the observation distribution, $\mathcal{P}(x)$, to obtain the VAE objective:

$\max_{\theta}\;\; \mathbb{E}_{x\sim\mathcal{P}(x)}\Big[\, \mathbb{E}_{z\sim q_{\theta}(z|x)}\big[\log p_{\theta}(x|z)\big] \;-\; D_{\mathrm{KL}}\big(q_{\theta}(z|x)\,\|\,\mathcal{P}(z)\big) \Big]\,.$   (1)

Gradients of Eqn. (1) are estimated through MC sampling from $q_{\theta}(z|x)$ with reparameterization, yielding unbiased low-variance gradient estimates (Kingma2013; Rezende2014).

VAEs are normally thought of as maximizing a lower bound on the data log-likelihood; however, the objective can also be expressed as minimizing a divergence between two joint distributions over $x$ and $z$. To see this, we first subtract $\mathbb{E}_{x\sim\mathcal{P}(x)}\big[\log\mathcal{P}(x)\big]$, a constant with respect to $\theta$, from (1), which does not change the gradients of the objective with respect to $\theta$. We then negate the result, as we will be performing minimization. This yields a VAE loss

$\mathcal{L}_{\mathrm{VAE}}(\theta) \;=\; D_{\mathrm{KL}}\big(\, q_{\theta}(z|x)\,\mathcal{P}(x) \;\big\|\; p_{\theta}(x|z)\,\mathcal{P}(z) \,\big)\,.$   (2)

The VAE optimization is therefore equivalent to minimizing the KL divergence between an encoding distribution, $q_{\theta}(z|x)\,\mathcal{P}(x)$, and a decoding distribution, $p_{\theta}(x|z)\,\mathcal{P}(z)$.
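To make the objective concrete, here is a minimal sketch of the per-batch VAE loss in Eqn. (1); it is not the paper's released code. It assumes a hypothetical `encoder` returning the mean and log-variance of a diagonal Gaussian $q_{\theta}(z|x)$, a `decoder` returning Bernoulli logits for binarized data, and a standard Normal anchor $\mathcal{P}(z)$, so the KL term has a closed form.

```python
import torch
import torch.nn.functional as F

def vae_loss(x, encoder, decoder):
    """Negative ELBO (Eqn. 1) for one minibatch.

    Assumes `encoder(x) -> (mu, logvar)` parameterizes a diagonal Gaussian
    q(z|x), and `decoder(z) -> logits` parameterizes a Bernoulli p(x|z).
    The latent anchor P(z) is a standard Normal.
    """
    mu, logvar = encoder(x)
    eps = torch.randn_like(mu)                      # reparameterization
    z = mu + torch.exp(0.5 * logvar) * eps

    logits = decoder(z)
    recon = F.binary_cross_entropy_with_logits(     # -E_q[log p(x|z)]
        logits, x, reduction='none').sum(dim=-1)

    # KL(q(z|x) || N(0, I)) in closed form for diagonal Gaussians.
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=-1)

    return (recon + kl).mean()                      # minimize -ELBO
```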

3 Symmetry and Mutual Information

Figure 1: A MIM model learns two factorizations of a joint distribution: (a) the encoding factorization; (b) the decoding factorization; and (c) the estimated joint distribution (an undirected graphical model).

Our goal is to find a consistent encoder-decoder pair, representing a joint distribution over the observation and latent domains, with high mutual information between observations and latent states. By consistent, we mean that the encoding and decoding distributions, $q_{\mathcal{S}}(x,z) \equiv q_{\theta}(z|x)\,\mathcal{P}(x)$ and $p_{\mathcal{S}}(x,z) \equiv p_{\theta}(x|z)\,\mathcal{P}(z)$, define the same joint distribution. Figure 1 depicts this basic idea, in which the same joint distribution is shared by the encoding and decoding factorizations. Effectively, we estimate an undirected graphical model with two valid factorizations. We note that consistency is achievable in the VAE when the approximate posterior is capable of representing the true posterior under the decoding distribution, $p_{\theta}(z|x) \propto p_{\theta}(x|z)\,\mathcal{P}(z)$. In the general case, however, consistency is not usually achieved.

In contrast to the asymmetric divergence between encoding and decoding distributions in the VAE objective (2), here we consider a symmetric measure, namely, the well-known Jensen-Shannon divergence (JSD),

$D_{\mathrm{JS}}(\theta) \;=\; \frac{1}{2}\Big( D_{\mathrm{KL}}\big(q_{\mathcal{S}}\,\big\|\,\mathcal{M}_{\mathcal{S}}\big) \;+\; D_{\mathrm{KL}}\big(p_{\mathcal{S}}\,\big\|\,\mathcal{M}_{\mathcal{S}}\big) \Big)\,,$   (3)

where $\mathcal{M}_{\mathcal{S}}$ is an equally weighted mixture of the encoding and decoding distributions; i.e.,

$\mathcal{M}_{\mathcal{S}}(x,z) \;=\; \frac{1}{2}\big(\, q_{\mathcal{S}}(x,z) \;+\; p_{\mathcal{S}}(x,z) \,\big)\,.$   (4)

In addition to encoder-decoder consistency, to learn useful latent representations we also want high mutual information between $x$ and $z$. Indeed, the link between mutual information and representation learning has been explored in recent work (Hjelm2018; Chen2016; hjelm2018learning). Here, to emphasize high mutual information, we add a particular regularizer of the form

$R_{\mathcal{H}}(\theta) \;=\; \frac{1}{2}\Big( H\big(q_{\mathcal{S}}(x,z)\big) \;+\; H\big(p_{\mathcal{S}}(x,z)\big) \Big)\,.$   (5)

This is the average of the joint entropy over $x$ and $z$ according to the encoding and decoding distributions. It is related to mutual information by the identity $H(x,z) = H(x) + H(z) - I(x;z)$. That is, minimizing joint entropy encourages the minimization of the marginal entropies and the maximization of the mutual information. In addition to encouraging high mutual information, one can show that this particular regularizer has a deep connection to the JSD and the entropy of $\mathcal{M}_{\mathcal{S}}$, i.e.,

$D_{\mathrm{JS}}(\theta) \;+\; R_{\mathcal{H}}(\theta) \;=\; H\big(\mathcal{M}_{\mathcal{S}}\big)\,.$   (6)

The derivation for Eqn. (6) is given in Appendix A.1.
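The identity in Eqn. (6) is easy to verify numerically for small discrete distributions, where every entropy is an exact sum. The sketch below (an illustration added here, not part of the MIM training procedure) builds two arbitrary joint distributions standing in for the encoding and decoding distributions and checks that $D_{\mathrm{JS}} + R_{\mathcal{H}} = H(\mathcal{M}_{\mathcal{S}})$.

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize(a):
    return a / a.sum()

# Two arbitrary joint distributions over a 4x3 discrete (x, z) space,
# standing in for the encoding and decoding distributions.
q_enc = normalize(rng.random((4, 3)))
p_dec = normalize(rng.random((4, 3)))
m = 0.5 * (q_enc + p_dec)                      # mixture M_S, Eqn. (4)

def H(p):                                      # joint entropy
    return -(p * np.log(p)).sum()

def KL(p, q):
    return (p * np.log(p / q)).sum()

jsd = 0.5 * (KL(q_enc, m) + KL(p_dec, m))      # Eqn. (3)
r_h = 0.5 * (H(q_enc) + H(p_dec))              # Eqn. (5)

assert np.isclose(jsd + r_h, H(m))             # Eqn. (6)
print(jsd + r_h, H(m))
```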

4 Mutual Information Machine

The loss function in Eqn. (6) reflects our desire for model symmetry and high mutual information. Nevertheless, it is difficult to optimize directly, since we do not know how to evaluate $\mathcal{M}_{\mathcal{S}}$ in the general case (i.e., we do not have an exact closed-form expression for $\log\mathcal{M}_{\mathcal{S}}(x,z)$). As a consequence, we introduce parameterized approximate priors, $q_{\theta}(x)$ and $p_{\theta}(z)$, to derive tractable bounds on the penalized Jensen-Shannon divergence. This is similar in spirit to VAEs, which introduce a parameterized approximate posterior. These parameterized priors, together with the conditional encoder and decoder, $q_{\theta}(z|x)$ and $p_{\theta}(x|z)$, comprise a new pair of joint distributions, i.e.,

$q_{\theta}(x,z) \;=\; q_{\theta}(z|x)\,q_{\theta}(x)\,, \qquad p_{\theta}(x,z) \;=\; p_{\theta}(x|z)\,p_{\theta}(z)\,.$

These joint distributions allow us to formulate a new, tractable loss that bounds the penalized divergence in Eqn. (6). That is,

$\mathcal{L}_{\mathrm{CE}}(\theta) \;\geq\; H\big(\mathcal{M}_{\mathcal{S}}\big) \;=\; D_{\mathrm{JS}}(\theta) + R_{\mathcal{H}}(\theta)\,,$   (7)

where

$\mathcal{L}_{\mathrm{CE}}(\theta) \;=\; H_{\mathcal{M}_{\mathcal{S}}}\big(\mathcal{M}_{\theta}\big)\,, \qquad \mathcal{M}_{\theta}(x,z) \;=\; \tfrac{1}{2}\big( q_{\theta}(x,z) + p_{\theta}(x,z) \big)\,,$   (8)

and $H_{p}(q) = -\,\mathbb{E}_{p}\big[\log q\big]$ denotes the cross-entropy between $p$ and $q$.

We refer to $\mathcal{L}_{\mathrm{CE}}$ as the cross-entropy loss. It aims to match the model prior distributions to the anchors, while also minimizing $H(\mathcal{M}_{\mathcal{S}}) = D_{\mathrm{JS}}(\theta) + R_{\mathcal{H}}(\theta)$. A key advantage of this formulation is that the cross-entropy loss can be trained by Monte Carlo sampling from the anchor distributions with reparameterization (Kingma2013; Rezende2014).

At this stage it might seem odd to introduce a parametric prior, $p_{\theta}(z)$, over the latent space. Indeed, setting it directly to the anchor $\mathcal{P}(z)$ is certainly an option. Nevertheless, in order to achieve consistency between $q_{\theta}(x,z)$ and $p_{\theta}(x,z)$ it can be advantageous to allow $p_{\theta}(z)$ to vary. Essentially, we trade off latent prior fidelity for increased model consistency. We provide more insights about this in Appendix D.3.

One issue with $\mathcal{L}_{\mathrm{CE}}$ is that, while it will try to enforce consistency between the model and the anchored distributions, i.e., $q_{\theta}(x,z) \approx q_{\mathcal{S}}(x,z)$ and $p_{\theta}(x,z) \approx p_{\mathcal{S}}(x,z)$, it will not directly try to achieve model consistency, $q_{\theta}(x,z) \approx p_{\theta}(x,z)$. To remedy this, we bound $\mathcal{L}_{\mathrm{CE}}$ using Jensen's inequality, i.e.,

$\mathcal{L}_{\mathrm{CE}}(\theta) \;\leq\; \frac{1}{2}\Big( H_{\mathcal{M}_{\mathcal{S}}}\big(q_{\theta}(x,z)\big) + H_{\mathcal{M}_{\mathcal{S}}}\big(p_{\theta}(x,z)\big) \Big) \;\equiv\; \mathcal{L}_{\mathrm{MIM}}(\theta)\,.$   (9)

Equation (9) gives us the loss function for the Mutual Information Machine (MIM). It is an average of cross-entropy terms between the mixture distribution $\mathcal{M}_{\mathcal{S}}$ and the model encoding and decoding distributions, respectively. To see that this encourages model consistency, it can be shown that $\mathcal{L}_{\mathrm{MIM}}$ is equivalent to $\mathcal{L}_{\mathrm{CE}}$ plus a non-negative model consistency term; i.e.,

$\mathcal{L}_{\mathrm{MIM}}(\theta) \;=\; \mathcal{L}_{\mathrm{CE}}(\theta) \;+\; R_{\mathrm{MIM}}(\theta)\,.$   (10)

The non-negativity of $R_{\mathrm{MIM}}$ is a simple consequence of Jensen's inequality in (9). One can further show (see Appendix A.2) that $R_{\mathrm{MIM}}$ satisfies

$R_{\mathrm{MIM}}(\theta) \;=\; \mathbb{E}_{(x,z)\sim\mathcal{M}_{\mathcal{S}}}\!\left[\, \log\frac{\mathcal{M}_{\theta}(x,z)}{\sqrt{q_{\theta}(x,z)\,p_{\theta}(x,z)}} \,\right]$   (11)
$\;\geq\; 0\,.$   (12)

One can conclude from Eqn. (11) that $R_{\mathrm{MIM}}$ is zero only when the two joint model distributions, $q_{\theta}(x,z)$ and $p_{\theta}(x,z)$, are identical under fair samples from the joint sample distribution $\mathcal{M}_{\mathcal{S}}$. In practice we find that encouraging model consistency also helps to stabilize learning.

To understand the MIM objective in greater depth, we find it helpful to express $\mathcal{L}_{\mathrm{MIM}}$ as a sum of fundamental terms that provide some intuition for its expected behavior. In particular, as derived in Appendix A.3,

$\mathcal{L}_{\mathrm{MIM}}(\theta) \;=\; \frac{1}{2}\,R_{\mathcal{H}}(\theta) \;+\; \frac{1}{4}\Big( D_{\mathrm{KL}}\big(\mathcal{P}(x)\,\|\,q_{\theta}(x)\big) + D_{\mathrm{KL}}\big(\mathcal{P}(z)\,\|\,p_{\theta}(z)\big) \Big) \;+\; \frac{1}{4}\Big( H_{q_{\mathcal{S}}}\big(p_{\theta}(x,z)\big) + H_{p_{\mathcal{S}}}\big(q_{\theta}(x,z)\big) \Big)\,.$   (13)

The first term in Eqn. (13) encourages high mutual information between observations and latent states. The second shows that MIM directly encourages the model priors to match the anchor distributions. Indeed, the KL term between the data anchor and the model prior is the maximum likelihood objective. The third term encourages consistency between the model distributions and the anchored distributions, in effect fitting the model decoder to samples drawn from the anchored encoder (cf. VAE), and, via symmetry, fitting the model encoder to samples drawn from the anchored decoder (both with reparameterization). As such, MIM can be seen as simultaneously training and distilling a model distribution over the data into a latent variable model. The idea of distilling density models has been used in other domains, e.g., for parallelizing auto-regressive models (oord2017parallel).

In summary, the MIM loss provides an upper bound on the joint entropy of the observations and latent states under the mixture distribution $\mathcal{M}_{\mathcal{S}}$:

$\mathcal{L}_{\mathrm{MIM}}(\theta) \;\geq\; \mathcal{L}_{\mathrm{CE}}(\theta) \;\geq\; H\big(\mathcal{M}_{\mathcal{S}}\big) \;=\; D_{\mathrm{JS}}(\theta) + R_{\mathcal{H}}(\theta)\,.$   (14)

Through the MIM loss and the introduction of the parameterized model distribution $\mathcal{M}_{\theta}$, we are pushing down on the entropy of the anchored mixture distribution $\mathcal{M}_{\mathcal{S}}$, which is the sum of marginal entropies minus the mutual information. Minimizing the MIM bound yields consistency of the model encoder and decoder, and high mutual information under $\mathcal{M}_{\mathcal{S}}$ between observations and latent states.

5 Learning

Here we provide a detailed description of MIM learning, with algorithmic pseudo-code. In addition we offer practical considerations regarding the choice of prior parameterization and gradient estimation. The upper bound objective, $\mathcal{L}_{\mathrm{MIM}}$ in Eqn. (9), is expressed in terms of two cross-entropy terms. Given $N$ fair samples, $\{(x_i, z_i)\}_{i=1}^{N}$, drawn from the anchored (sample) distribution $\mathcal{M}_{\mathcal{S}}$ in (4), the empirical loss is

$\hat{\mathcal{L}}_{\mathrm{MIM}}(\theta) \;=\; -\,\frac{1}{2N}\sum_{i=1}^{N}\Big( \log q_{\theta}(x_i, z_i) + \log p_{\theta}(x_i, z_i) \Big)\,,$   (15)

where samples from $\mathcal{M}_{\mathcal{S}}$ comprise equal numbers of points from $q_{\mathcal{S}}(x,z)$ and $p_{\mathcal{S}}(x,z)$. Samples from the anchors, $\mathcal{P}(x)$ and $\mathcal{P}(z)$, are treated as external observations; i.e., we assume we can sample from them but not necessarily evaluate the density of points under the anchor distributions.

Algorithm 1 specifies the corresponding training procedure. The algorithm makes no assumptions on the form of the parameterized distributions (e.g., discrete, or continuous). In practice, for gradient-based optimization, we would like an unbiased gradient estimator without the need to accurately approximate the full expectations per se (i.e., in the cross entropy terms). This is particularly important when dealing with high dimensional data (e.g., images), where it is computationally expensive to estimate the value of the expectation. We next discuss practical considerations for the continuous case and the discrete case.

0:  Require: Samples from anchors $\mathcal{P}(x)$, $\mathcal{P}(z)$
1:  while not converged do
2:     $\mathcal{D}_{enc} \leftarrow \{(x_i, z_i)\;:\; x_i \sim \mathcal{P}(x),\; z_i \sim q_{\theta}(z|x_i)\}_{i=1}^{N/2}$
3:     $\mathcal{D}_{dec} \leftarrow \{(x_j, z_j)\;:\; z_j \sim \mathcal{P}(z),\; x_j \sim p_{\theta}(x|z_j)\}_{j=1}^{N/2}$
4:     $\mathcal{D} \leftarrow \mathcal{D}_{enc} \cup \mathcal{D}_{dec}$
5:     # See definition of $\hat{\mathcal{L}}_{\mathrm{MIM}}$ in Eq. (15)
6:     $\hat{\mathcal{L}}_{\mathrm{MIM}}(\theta; \mathcal{D}) \leftarrow -\frac{1}{2N}\sum_{(x_i, z_i)\in\mathcal{D}} \big( \log q_{\theta}(x_i, z_i) + \log p_{\theta}(x_i, z_i) \big)$
7:     # Minimize loss
8:     $\Delta\theta \propto -\nabla_{\theta}\,\hat{\mathcal{L}}_{\mathrm{MIM}}(\theta; \mathcal{D})$
9:  end while
Algorithm 1 MIM learning of parameters $\theta$
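The following is a minimal sketch of one step of Algorithm 1 for continuous data, under simplifying assumptions that are ours rather than the paper's: the latent prior is fixed to the anchor, $p_{\theta}(z) = \mathcal{P}(z) = \mathcal{N}(0, I)$, both conditionals are diagonal Gaussians, and `log_q_x` is a placeholder for whichever parameterization of $q_{\theta}(x)$ is used (Sec. 5.1). The loss line implements Eqn. (15), with $N$ equal to the total number of samples from both sources.

```python
import torch
from torch.distributions import Normal, Independent

def diag_gaussian(mu, logvar):
    return Independent(Normal(mu, torch.exp(0.5 * logvar)), 1)

def mim_step(x_data, encoder, decoder, log_q_x, z_dim, optimizer):
    """One step of Algorithm 1 with a standard Normal latent anchor."""
    n = x_data.shape[0]
    prior_z = diag_gaussian(torch.zeros(n, z_dim), torch.zeros(n, z_dim))

    # Samples from the anchored encoding distribution q(z|x) P(x).
    q_zx = diag_gaussian(*encoder(x_data))
    z_enc = q_zx.rsample()                          # reparameterized

    # Samples from the anchored decoding distribution p(x|z) P(z).
    z_dec = prior_z.sample()
    x_dec = diag_gaussian(*decoder(z_dec)).rsample()

    def joint_logs(x, z):
        # log q(x,z) = log q(z|x) + log q(x);  log p(x,z) = log p(x|z) + log p(z)
        log_q = diag_gaussian(*encoder(x)).log_prob(z) + log_q_x(x, z)
        log_p = diag_gaussian(*decoder(z)).log_prob(x) + prior_z.log_prob(z)
        return log_q, log_p

    log_q_e, log_p_e = joint_logs(x_data, z_enc)
    log_q_d, log_p_d = joint_logs(x_dec, z_dec)

    # Empirical MIM loss, Eqn. (15), averaged over both sample sources.
    loss = -0.25 * (log_q_e + log_p_e + log_q_d + log_p_d).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```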

5.1 MIM Parametric Priors

There are several effective ways to parameterize the priors. For the 1D experiments in Appendix D we model the parameterized priors using linear mixtures of isotropic Gaussians. With complex, high dimensional data one might also consider more powerful models (e.g., autoregressive or flow-based priors). Unfortunately, the use of complex models typically increases the required computational resources, and the training and inference time.

As an alternative, for image data, we make use of the VampPrior (DBLP:journals/corr/TomczakW17), which models the latent prior as a mixture of posteriors, i.e., $p_{\theta}(z) = \frac{1}{K}\sum_{k=1}^{K} q_{\theta}(z|u_k)$, with learnable pseudo-inputs $u_k$. This is effective and avoids the need for a separately parameterized prior model (see DBLP:journals/corr/TomczakW17 for details on the VampPrior's effect on gradient estimation).
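A minimal sketch of a VampPrior-style latent prior is shown below; the module layout is our assumption, not the released implementation. The prior log-density is a log-mean-exp over encoder posteriors evaluated at learnable pseudo-inputs (in the experiments of Sec. 6.3 the pseudo-inputs are initialized with training data samples rather than at random).

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent

class VampPrior(nn.Module):
    """p(z) = (1/K) sum_k q(z | u_k) with learnable pseudo-inputs u_k."""

    def __init__(self, encoder, num_pseudo, x_dim):
        super().__init__()
        self.encoder = encoder                      # shared with q(z|x)
        self.pseudo = nn.Parameter(torch.randn(num_pseudo, x_dim) * 0.01)

    def log_prob(self, z):
        mu, logvar = self.encoder(self.pseudo)      # (K, z_dim) each
        comps = Independent(Normal(mu, torch.exp(0.5 * logvar)), 1)
        log_probs = comps.log_prob(z.unsqueeze(1))  # (N, K): log q(z | u_k)
        K = self.pseudo.shape[0]
        return torch.logsumexp(log_probs, dim=1) - torch.log(
            torch.tensor(float(K)))
```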

0:  Require: Samples from anchors $\mathcal{P}(x)$, $\mathcal{P}(z)$
0:  Define $q_{\theta}(x) = \mathbb{E}_{z\sim\mathcal{P}(z)}\big[p_{\theta}(x|z)\big]$
1:  while not converged do
2:     # Sample encoding distribution
3:     $x_e \sim \mathcal{P}(x)\,, \quad z_e \sim q_{\theta}(z|x_e)$
4:     # Compute objective, approximate $q_{\theta}(x_e)$ with 1 sample and importance sampling
5:     $\log q_{\theta}(x_e) \approx \log p_{\theta}(x_e|z_e) + \log\mathcal{P}(z_e) - \log q_{\theta}(z_e|x_e)$
6:     $\hat{\mathcal{L}}_{e} \leftarrow -\tfrac{1}{2}\big( \log q_{\theta}(z_e|x_e) + \log q_{\theta}(x_e) + \log p_{\theta}(x_e|z_e) + \log p_{\theta}(z_e) \big)$
7:     # Sample decoding distribution
8:     $z_d \sim \mathcal{P}(z)\,, \quad x_d \sim p_{\theta}(x|z_d)$
9:     # Compute objective, approximate $q_{\theta}(x_d)$ with 1 sample
10:     $\log q_{\theta}(x_d) \approx \log p_{\theta}(x_d|z_d)$; $\quad \hat{\mathcal{L}}_{d} \leftarrow -\tfrac{1}{2}\big( \log q_{\theta}(z_d|x_d) + \log q_{\theta}(x_d) + \log p_{\theta}(x_d|z_d) + \log p_{\theta}(z_d) \big)$
11:     $\hat{\mathcal{L}}_{\mathrm{MIM}} \leftarrow \tfrac{1}{2}\big( \hat{\mathcal{L}}_{e} + \hat{\mathcal{L}}_{d} \big)$
12:     # Minimize loss
13:     $\Delta\theta \propto -\nabla_{\theta}\,\hat{\mathcal{L}}_{\mathrm{MIM}}$
14:  end while
Algorithm 2 MIM learning with marginal $q_{\theta}(x)$

Another useful model with high dimensional data, following DBLP:journals/corr/BornscheinSFB15, is to define $q_{\theta}(x)$ as the marginal of the decoding distribution; i.e.,

$q_{\theta}(x) \;=\; \mathbb{E}_{z\sim\mathcal{P}(z)}\big[\, p_{\theta}(x|z) \,\big]\,.$   (16)

Like the VampPrior, this entails no new parameters. It also helps to encourage consistency between the encoding and decoding distributions. In addition, it enables a direct empirical comparison between MIM and VAE, as we can then use identical parameterizations and architectures for both. During learning, when $q_{\theta}(x)$ is defined as the marginal (16), we evaluate it with a single sample and reparameterization. When $z$ is drawn directly from the latent prior,

$q_{\theta}(x) \;\approx\; p_{\theta}(x\,|\,z)\,, \qquad z\sim\mathcal{P}(z)\,.$

When $z$ is drawn from the encoder, given a sample observation $x\sim\mathcal{P}(x)$, we use importance sampling:

$q_{\theta}(x) \;=\; \mathbb{E}_{z\sim q_{\theta}(z|x)}\!\left[\frac{p_{\theta}(x|z)\,\mathcal{P}(z)}{q_{\theta}(z|x)}\right] \;\approx\; \frac{p_{\theta}(x|z)\,\mathcal{P}(z)}{q_{\theta}(z|x)}\,, \qquad z\sim q_{\theta}(z|x)\,.$

Algorithm 2 provides algorithm details with the marginal prior.
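In log space, the two single-sample estimates above can be written as follows; this sketch assumes diagonal-Gaussian conditionals and a standard Normal latent anchor, with hypothetical `encoder`/`decoder` networks that return a mean and log-variance.

```python
import torch
from torch.distributions import Normal, Independent

def diag_gaussian(mu, logvar):
    return Independent(Normal(mu, torch.exp(0.5 * logvar)), 1)

def log_marginal_q_x(x, z, encoder, decoder, prior_z, z_from_encoder):
    """Single-sample estimate of log q(x) = log E_{P(z)}[p(x|z)], Eqn. (16).

    If z ~ q(z|x), use the importance-sampling estimate
        log q(x) ~= log p(x|z) + log P(z) - log q(z|x);
    if z ~ P(z), use log q(x) ~= log p(x|z).
    """
    log_p_x_given_z = diag_gaussian(*decoder(z)).log_prob(x)
    if z_from_encoder:
        log_w = prior_z.log_prob(z) - diag_gaussian(*encoder(x)).log_prob(z)
        return log_p_x_given_z + log_w
    return log_p_x_given_z
```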

5.2 Gradient Estimation

Optimization is performed through minibatch stochastic gradient descent. To ensure unbiased gradient estimates of $\hat{\mathcal{L}}_{\mathrm{MIM}}$ we use the reparameterization trick (Kingma2013; Rezende2014) when taking expectations with respect to the continuous encoder and decoder distributions, $q_{\theta}(z|x)$ and $p_{\theta}(x|z)$. Reparameterization entails sampling an auxiliary variable $\epsilon$, with known $p(\epsilon)$, followed by a deterministic mapping from sample variates to the target random variable; that is, $z = g_{\theta}(\epsilon)$ for prior distributions and $z = g_{\theta}(\epsilon, x)$ for conditional distributions. In doing so we assume $p(\epsilon)$ is independent of the parameters $\theta$. It then follows that

$\nabla_{\theta}\,\mathbb{E}_{z\sim q_{\theta}(z|x)}\big[ f_{\theta}(z) \big] \;=\; \mathbb{E}_{\epsilon\sim p(\epsilon)}\big[ \nabla_{\theta}\, f_{\theta}\big( g_{\theta}(\epsilon, x) \big) \big]\,,$

where $f_{\theta}$ is the loss function with parameters $\theta$. It is common to let $p(\epsilon)$ be standard normal, $\epsilon\sim\mathcal{N}(0, I)$, and for $q_{\theta}(z|x)$ to be Gaussian with mean $\mu_{\theta}(x)$ and standard deviation $\sigma_{\theta}(x)$, in which case $g_{\theta}(\epsilon, x) = \mu_{\theta}(x) + \sigma_{\theta}(x)\odot\epsilon$. A more generic exact density model can be learned by mapping a known base distribution (e.g., Gaussian) to a target distribution with normalizing flows (Dinh2014; Dinh2016a; Rezende2015).

For discrete distributions, e.g., with discrete data, reparameterization is not readily applicable. There exist continuous relaxations that permit reparameterization (e.g., DBLP:journals/corr/MaddisonMT16; DBLP:journals/corr/TuckerMMS17), but current methods are rather involved, and require adaptation of the objective function or the optimization process. Here we simply use the REINFORCE algorithm Sutton:1999:PGM:3009657.3009806 for unbiased gradient estimates, as follows

$\nabla_{\theta}\,\mathbb{E}_{z\sim q_{\theta}(z|x)}\big[ f_{\theta}(z) \big] \;=\; \mathbb{E}_{z\sim q_{\theta}(z|x)}\big[\, f_{\theta}(z)\,\nabla_{\theta}\log q_{\theta}(z|x) \;+\; \nabla_{\theta} f_{\theta}(z) \,\big]\,.$   (17)

The derivation for Eqn. (17) is as follows:

$\nabla_{\theta}\,\mathbb{E}_{z\sim q_{\theta}(z|x)}\big[ f_{\theta}(z) \big] \;=\; \int \Big( f_{\theta}(z)\,\nabla_{\theta}\, q_{\theta}(z|x) \;+\; q_{\theta}(z|x)\,\nabla_{\theta} f_{\theta}(z) \Big)\, dz$
$\;=\; \int q_{\theta}(z|x)\,\Big( f_{\theta}(z)\,\nabla_{\theta}\log q_{\theta}(z|x) \;+\; \nabla_{\theta} f_{\theta}(z) \Big)\, dz \;=\; \mathbb{E}_{z\sim q_{\theta}(z|x)}\big[\, f_{\theta}(z)\,\nabla_{\theta}\log q_{\theta}(z|x) + \nabla_{\theta} f_{\theta}(z) \,\big]\,,$

for which the step from the first line to the second line makes use of the well-known identity $\nabla_{\theta}\, q_{\theta}(z|x) = q_{\theta}(z|x)\,\nabla_{\theta}\log q_{\theta}(z|x)$. This relation is essential as it enables a Monte Carlo approximation of the integral.
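As an illustration of Eqn. (17), the sketch below compares the score-function (REINFORCE) estimate against the exact gradient for a toy objective, $\mathbb{E}_{z\sim\mathrm{Bern}(\sigma(\theta))}[z]$, whose true gradient is $\sigma(\theta)\,(1-\sigma(\theta))$. The toy example is ours, not from the paper.

```python
import torch

torch.manual_seed(0)
theta = torch.tensor(0.3, requires_grad=True)
p = torch.sigmoid(theta)                           # Bernoulli parameter q(z=1)

n = 200_000
z = torch.bernoulli(torch.full((n,), p.item()))    # samples; no gradient path

# Score-function surrogate: its gradient is the REINFORCE estimate of
# d/dtheta E_{z~q}[f(z)] with f(z) = z (the second term of Eqn. (17)
# vanishes here because f does not depend on theta).
log_q = z * torch.log(p) + (1 - z) * torch.log(1 - p)
surrogate = (z * log_q).mean()
surrogate.backward()

exact = (p * (1 - p)).item()                       # d sigmoid(theta) / d theta
print(theta.grad.item(), exact)                    # close for large n
```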

5.3 Training Time

Training times of MIM models are comparable to training times for VAEs with comparable architectures. One important difference concerns the time required for sampling from the decoder during training. This is particularly significant for models like auto-regressive decoders (e.g., Kingma2016), for which sampling is very slow. In such cases, we find that we can also learn effectively with a sampling distribution that only includes samples from the encoding distribution, i.e., $\mathcal{M}_{\mathcal{S}} = q_{\theta}(z|x)\,\mathcal{P}(x)$, rather than the mixture. We refer to this particular MIM variant as asymmetric-MIM (or A-MIM). We use it in Sec. 6.3 when working with the PixelHVAE architecture (Kingma2016).

6 Experiments

In what follows we examine MIM empirically, with the VAE as a baseline. We consider synthetic datasets and well-known image datasets, namely MNIST (LeCun1998), Fashion-MNIST (DBLP:journals/corr/abs-1708-07747) and Omniglot (Lake2015). All models were trained using the Adam optimizer (DBLP:journals/corr/KingmaB14) with a fixed learning rate and a mini-batch size of 128. Following DBLP:journals/corr/abs-1711-00464, we anneal the loss to stabilize the optimization. To this end we linearly increase an annealing weight $\beta$ from 0 to 1, over a number of 'warm-up' epochs, in the following expression:

(18)

Training continues until the loss (i.e., with $\beta = 1$) on a held-out validation set has not improved for the same number of epochs as the warm-up period (defined per experiment). We have found the number of epochs to convergence of MIM learning to be between 2 and 5 times greater than for a VAE with the same architecture. (Code is available from https://github.com/seraphlabs-ca/MIM .)
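A minimal helper for the linear warm-up of the annealing weight is sketched below; how the weight enters the annealed loss follows Eqn. (18), and the warm-up length is a per-experiment hyper-parameter.

```python
def annealing_weight(epoch, warmup_epochs):
    """Linear warm-up: 0 at the start of training, 1 after `warmup_epochs`."""
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, epoch / float(warmup_epochs))

# Example: a 5-epoch warm-up.
print([annealing_weight(e, 5) for e in range(8)])   # 0.0, 0.2, ..., 1.0, 1.0
```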

6.1 Relation to VAE and Posterior Collapse

Before turning to empirical results, it is useful to briefly revisit similarities and differences between MIM and the canonical VAE formulation. To that end, one can show from Eqns. (1) and (2) that the VAE loss (up to a constant factor of one half, which does not affect the optimization) can be expressed in a form that bears similarity to the MIM loss in Eqn. (9). In particular, following the derivation in Appendix C,

$\tfrac{1}{2}\,\mathcal{L}_{\mathrm{VAE}}(\theta) \;=\; \frac{1}{2}\Big( H_{q_{\mathcal{S}}}\big(q_{\mathcal{S}}(x,z)\big) + H_{q_{\mathcal{S}}}\big(p_{\mathcal{S}}(x,z)\big) \Big) \;+\; I_{q_{\mathcal{S}}}(x;z) \;-\; H_{q_{\mathcal{S}}}(x) \;-\; H_{q_{\mathcal{S}}}(z)\,,$   (19)

where $q_{\mathcal{S}}(x,z) = q_{\theta}(z|x)\,\mathcal{P}(x)$ is the encoding distribution, as above. Like the MIM loss, the first term in Eqn. (19) is the average of two cross-entropy terms, between a sample distribution and the encoding and decoding distributions. Unlike the MIM loss, these terms are asymmetric, as the samples are drawn only from the encoding distribution. Also unlike the MIM loss, the VAE loss includes the last three terms in Eqn. (19), the sum of which comprises the negative joint entropy under the sample distribution $q_{\mathcal{S}}$.

While the MIM objective explicitly encourages high mutual information between observations and corresponding latent embeddings, this VAE loss includes a term, $I_{q_{\mathcal{S}}}(x;z)$, that encourages a reduction in the mutual information. We posit that this plays a significant role in the phenomenon often referred to as posterior collapse, in which the variance of the variational posterior grows large and the latent embedding conveys relatively little information about the observations (e.g., see (ChenKSDDSSA16) and others).

6.2 Posterior Collapse in Low Dimensional Data

To empirically support the expression in Eqn. (19), we begin with synthetic data comprising 2D observations, $x\in\mathbb{R}^2$, with a 2D latent space, $z\in\mathbb{R}^2$. In 2D one can easily visualize the model and measure quantitative properties of interest (e.g., mutual information). Observations are drawn from the anchor $\mathcal{P}(x)$, a Gaussian mixture model with five isotropic components with standard deviation 0.25 (Fig. 2, top row). The latent anchor $\mathcal{P}(z)$ is an isotropic standard Normal (Fig. 2, bottom row). The encoder and decoder conditional distributions are Gaussian, the means and variances of which are regressed from the input using two fully connected layers and tanh activations. The parameterized data prior, $q_{\theta}(x)$, is defined to be the marginal of the decoding distribution (16), and the model latent prior is defined to be $p_{\theta}(z) = \mathcal{P}(z)$, so the only model parameters are those of the encoder and decoder. We can thus learn models with MIM and VAE objectives, but with the same architecture and parameterization. We used a warm-up scheduler (Vaswani2017) for the learning rate, with a warm-up of 3 steps, and with each epoch comprising 10000 samples. Training and test sets are drawn independently from the GMM.
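A sketch of the synthetic data generation is shown below: a 5-component GMM over 2D observations with isotropic components of standard deviation 0.25. The component means are illustrative placeholders, as the paper does not list them.

```python
import numpy as np

def sample_gmm(n, rng, std=0.25):
    """Draw n 2D observations from a 5-component isotropic GMM anchor P(x)."""
    means = np.array([[0.0, 0.0], [2.0, 2.0], [-2.0, 2.0],
                      [2.0, -2.0], [-2.0, -2.0]])         # illustrative means
    labels = rng.integers(len(means), size=n)             # mixture components
    x = means[labels] + std * rng.standard_normal((n, 2))
    return x.astype(np.float32), labels

rng = np.random.default_rng(0)
x_train, y_train = sample_gmm(10_000, rng)   # one epoch of 10000 samples
x_test, y_test = sample_gmm(10_000, rng)     # independent test set
```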

Figure 2 depicts three models for VAE (odd columns) and MIM (even columns), with increasing numbers of hidden units (from left to right) to control model expressiveness. The top row depicts the observation space, where black contours are level sets of constant density, and red points are reconstructed samples, i.e., one point drawn from $p_{\theta}(x|z)$, where $z$ is drawn from the encoder $q_{\theta}(z|x)$ given a test point from $\mathcal{P}(x)$. In each case we also report the mutual information and the root-mean-squared reconstruction error, with MIM producing superior results.

Figure 2: VAE and MIM models with 2D inputs, a 2D latent space, and 5, 20 and 500 hidden units. Top row: Black contours depict level sets of constant density; red dots are reconstructed test points. Bottom row: Green contours are one-standard-deviation ellipses of the encoder posterior $q_{\theta}(z|x)$ for test points. Dashed black circles depict one standard deviation of the latent anchor $\mathcal{P}(z)$. The VAE predictive variance remains high, regardless of model expressiveness, an indication of various degrees of posterior collapse, while MIM produces lower predictive variance and lower reconstruction errors, consistent with high mutual information (see inset quantities).
Figure 3: Test performance for MIM (blue) and VAE (red) for the 2D GMM data (cf. Fig. 2), all as functions of the number of hidden units (on x-axis). Plots show (a) mutual information, (b) negative log-likelihood of test points, (c) test reconstruction error, and (d) K-NN mode classification performance. Each plot shows the mean and standard deviation of 10 experiments.

The bottom row of Fig. 2 depicts latent space behavior. The dashed black circle depicts one standard deviation of the latent anchor $\mathcal{P}(z)$. Each green curve depicts a one-standard-deviation ellipse of the encoder posterior $q_{\theta}(z|x)$ given a test point from $\mathcal{P}(x)$. For the weakest architecture (a), with 5 hidden units, VAE and MIM posterior variances are similar to the prior in one dimension, a sign of posterior collapse. As the number of hidden units increases (b,c), the VAE posterior variance remains large, preferring lower mutual information while matching the aggregated posterior to the prior. In contrast, the MIM encoder produces tight posteriors, and yields higher mutual information and lower reconstruction errors at the expense of somewhat worse data log likelihoods.

To quantify this behavior, Fig. 3 shows the mutual information, the average negative log-likelihood (NLL) of test points under the model, the mean reconstruction error of test points, and 5-NN classification performance (we experimented with 1-NN, 3-NN, 5-NN, and 10-NN and found the results to be consistent), predicting which of the 5 GMM components each test point was drawn from. The auxiliary classification task provides a proxy for representation quality. Following Hjelm2018, we estimate mutual information using the KSG estimator (PhysRevE.69.066138; DBLP:journals/corr/GaoOV16), based on 5-NN neighborhoods.
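For completeness, a compact reimplementation of the KSG estimator (PhysRevE.69.066138) with k = 5 neighbours is sketched below; it is our own reimplementation of the standard estimator, not the evaluation code used to produce the figures.

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.special import digamma

def ksg_mutual_information(x, z, k=5):
    """KSG estimator (Kraskov et al., 2004; estimator 1) of I(x; z) in nats."""
    n = x.shape[0]
    joint = np.concatenate([x, z], axis=1)

    # Distance to the k-th nearest neighbour in joint space (Chebyshev metric);
    # the first returned neighbour is the point itself, hence k + 1.
    eps = cKDTree(joint).query(joint, k=k + 1, p=np.inf)[0][:, -1]

    def count_strictly_within(points, radii):
        tree = cKDTree(points)
        return np.array([len(tree.query_ball_point(pt, r - 1e-12, p=np.inf)) - 1
                         for pt, r in zip(points, radii)])

    n_x = count_strictly_within(x, eps)
    n_z = count_strictly_within(z, eps)
    return (digamma(k) + digamma(n)
            - np.mean(digamma(n_x + 1) + digamma(n_z + 1)))
```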

Mutual information and classification accuracy for test data under the MIM model are higher than for VAE models. One can also see that mutual information is saturated for MIM, as it effectively learns an (approximate) invertible mapping. The encoder and decoder approach deterministic mappings, reflected in the near-zero reconstruction error. Interestingly, MIM learning finds near-invertible mappings with unconstrained architectures (demonstrated here for the 2D case), when the dimensionality of the latent representation and the observations is the same. (See Sec. D of the supplementary material for experiments on variants of MIM and VAE that tease apart the impact of specific terms of the respective objectives.)

Figure 4: MIM (blue) and VAE (red) for 20D GMM data (i.e., $x\in\mathbb{R}^{20}$), all as functions of the latent dimensionality, from 2 to 20 (on x-axis): (a) mutual information, (b) NLL, (c) reconstruction error, and (d) 5-NN classification. Plots depict the mean and standard deviation of 10 experiments. MIM learning produces higher mutual information and classification accuracy, with lower test reconstruction error, while VAE yields better data log likelihoods. The VAE suffers from increased collapse as the latent dimensionality grows.
Figure 5: MIM (blue) and VAE (red) for 20D Fashion-MNIST, with latent dimension between 2 and 20: (a) mutual information, (b) NLL, (c) reconstruction error, and (d) 5-NN classification. Plots depict the mean and standard deviation of 10 experiments. MIM opts for better mutual information, and yields better K-NN classification accuracy, at the expense of worse test log likelihood scores.

Next we consider synthetic 20D data from a 5-component GMM, with independent training and test sets, and with latent dimensionalities between 2 and 20. This ensures that the distribution is well modeled with a relatively simple architecture. This experiment extends the experiment in Fig. 3 by adding a bottleneck. The experimental setup was otherwise similar to that used in Fig. 3.

Results are shown in Fig. 4. MIM produces higher mutual information and better classification as the latent dimensionality increases. VAE mutual information and classification accuracy deteriorate with increasing latent dimensionality, due to stronger posterior collapse in higher-dimensional latent spaces. The test NLL scores for MIM are not as good as those for VAEs, in part because the MIM encoder produces very small posterior variance, approaching a deterministic encoder. Nevertheless, MIM produces lower test reconstruction errors. These results are consistent with those in Fig. 3.

To further investigate MIM learning in low dimensional data, we project 784D images from Fashion-MNIST onto a 20D linear subspace using PCA (capturing 78.5% of total variance), and repeat the experiment in Fig. 4. The training and validation sets had 50,000 and 10,000 images respectively. We trained for 200 epochs, well past convergence, and then selected the model with the lowest validation loss. Fig. 5 summarizes the results, with MIM producing high mutual information and classification accuracy, at all but very low latent dimensions. MIM and VAE yield similar test reconstruction errors, with VAE having better negative log likelihoods for test data.

We conclude that the VAE is prone to posterior collapse for a wide range of model expressiveness and latent dimensionality, with latent embeddings exhibiting low mutual information. In contrast, MIM is empirically robust to posterior collapse, and shows higher mutual information, converging to an encoder with small variance. As a result the learned marginal data likelihood for MIM is worse. In this regard, we note that several papers have described ways to mitigate posterior collapse in VAE learning, e.g., by lower bounding or annealing the KL divergence term in the VAE objective (DBLP:journals/corr/abs-1711-00464; DBLP:journals/corr/abs-1901-03416), or by limiting the expressiveness of the decoder (e.g., ChenKSDDSSA16). We posit that MIM does not suffer from this problem as a consequence of the objective design principles that encourage high mutual information between observations and the latent representation.

6.3 Image Data

convHVAE (S) convHVAE (VP)
Dataset MIM VAE MIM VAE
Fashion-MNIST
MNIST
Omniglot
PixelHVAE (S) PixelHVAE (VP)
A-MIM VAE A-MIM VAE
Fashion-MNIST
MNIST
Omniglot
Table 1: Test NLL (in nats) for high dimensional image data. Quantitative results based on 10 trials per condition. With a more powerful prior, MIM and VAE yield comparable results.
Figure 6: MIM and VAE learning with the PixelHVAE (VP) architecture, applied to Fashion-MNIST, MNIST, and Omniglot (left to right). The top three rows (from top to bottom) are test data samples, VAE reconstruction, and A-MIM reconstruction. Bottom: random samples from VAE and A-MIM. With a powerful enough prior, MIM offers samples which are comparable to VAE.

We next consider MIM and VAE learning with image data (Fashion-MNIST, MNIST, Omniglot). Unfortunately, with high dimensional data we cannot reliably compute mutual information (Hjelm2018). Instead, for model assessment we focus on negative log-likelihood, reconstruction, and the quality of random samples. In doing so we also explore multiple architectures, including the top performing models from DBLP:journals/corr/TomczakW17, namely, convHVAE (L = 2) and PixelHVAE (L = 2), with Standard (S) priors (i.e., $p(z) = \mathcal{N}(0, I)$, a standard Normal distribution, where $I$ is the identity matrix) and VampPrior (VP) priors (i.e., a mixture model of the encoder conditioned on optimized pseudo-inputs). The VP pseudo-inputs are initialized with training data samples. All the experiments below use the same experimental setup as in DBLP:journals/corr/TomczakW17, and the same latent dimensionality. Here we also demonstrate that a powerful prior (e.g., PixelHVAE (VP)) allows MIM to learn models with competitive sampling and NLL performance.

Sampling from an auto-regressive decoder (e.g., PixelHVAE) is very slow. To reduce training time, as discussed in Sec. 5.3, we learn with a sampling distribution comprising just the encoding distribution, i.e., $\mathcal{M}_{\mathcal{S}} = q_{\theta}(z|x)\,\mathcal{P}(x)$, rather than the mixture, a MIM variant we refer to as asymmetric-MIM (or A-MIM).

Table 1 reports test NLL scores. One can see that VAE models yield better NLL, but with a small gap for more expressive models (i.e., PixelHVAE (VP)). We also show qualitative results for the most expressive models (i.e., PixelHVAE). Fig. 6 depicts reconstruction and sampling for Fashion-MNIST, MNIST, and Omniglot, for the top performing model (PixelHVAE (VP)), with MIM and VAE being comparable. (Test data in the top row of Fig. 6 are binary, while reconstructions depict the probability of each pixel being 1, following DBLP:journals/corr/TomczakW17.) The top three rows depict data samples, VAE reconstructions, and A-MIM reconstructions, respectively. The bottom row depicts random samples. Note that, while MIM with a weak prior (Standard) suffers from poor sampling, increasing the expressiveness results in comparable samples and reconstruction. See Appendix E for additional results.

The poor NLL and hence poor sampling for MIM with a weak prior model can be explained by the tightly clustered latent representation (e.g., Fig. 2). A more expressive, learnable prior can capture such clusters more accurately, and as such, also produces good samples (e.g., VampPrior). In other words, while VAE opts for better NLL and sampling at the expense of lower mutual information, MIM provides higher mutual information at the expense of the NLL for a weak prior, and comparable NLL and sampling with more expressive priors. In Sec. 6.4 we probe the effect of higher mutual information on the quality of the learned representation.

6.4 Clustering and Classification

convHVAE (S) convHVAE (VP) PixelHVAE (S) PixelHVAE (VP)
Dataset MIM VAE MIM VAE A-MIM VAE A-MIM VAE
Fashion-MNIST
MNIST
Table 2: Test accuracy of a 5-NN classifier for high-dimensional image data. Quantitative results based on 10 trials per condition. Standard deviations are less than 0.01, and omitted from the table. MIM offers better unsupervised clustering of classes in the latent representation in all experiments but one.
Figure 7: MIM and VAE embeddings for Fashion-MNIST with the convHVAE architecture: (a) VAE (S), (b) MIM (S), (c) VAE (VP), (d) MIM (VP). MIM shows stronger disentanglement of classes.
Figure 8: MIM and VAE embeddings for MNIST with the convHVAE architecture: (a) VAE (S), (b) MIM (S), (c) VAE (VP), (d) MIM (VP). MIM shows stronger disentanglement of classes.
Figure 9: A-MIM and VAE embeddings for Fashion-MNIST with the PixelHVAE architecture: (a) VAE (S), (b) A-MIM (S), (c) VAE (VP), (d) A-MIM (VP). MIM shows stronger disentanglement of classes.
Figure 10: A-MIM and VAE embeddings for MNIST with the PixelHVAE architecture: (a) VAE (S), (b) A-MIM (S), (c) VAE (VP), (d) A-MIM (VP). MIM shows stronger disentanglement of classes.

Finally, following hjelm2018learning, we consider an auxiliary classification task as a further measure of the quality of the learned representations. We opted for K-NN classification, a non-parametric method which relies on semantic clustering in the latent space without any additional training. Given the representations learned in Sec. 6.3, a simple 5-NN classifier (we omit results for other values of K as we found them similar) was applied to test data to predict one of 10 classes for MNIST and Fashion-MNIST. Table 2 shows that in all but one case, MIM yields more accurate classification results. We attribute the performance difference to the higher mutual information of MIM representations, combined with the low entropy of the marginals. Figures 7, 8, 9, and 10 provide a qualitative visualization of the latent clustering, for which t-SNE (maaten2008visualizing) was used to project the latent space down to 2D for the Fashion-MNIST and MNIST data. One can see that MIM learning tends to cluster classes in the latent representation more tightly, while VAE clusters are more diffuse and overlapping, consistent with the results in Table 2.
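The classification probe and the visualization can be reproduced with standard tools. The sketch below assumes `z_train`/`z_test` hold latent codes (e.g., posterior means) with integer class labels, and is an illustration rather than the exact evaluation script.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.manifold import TSNE

def knn_probe(z_train, y_train, z_test, y_test, k=5):
    """Non-parametric probe: 5-NN classification accuracy on latent codes."""
    clf = KNeighborsClassifier(n_neighbors=k).fit(z_train, y_train)
    return clf.score(z_test, y_test)

def tsne_embed(z, seed=0):
    """2D t-SNE projection of latent codes for qualitative cluster plots."""
    return TSNE(n_components=2, random_state=seed).fit_transform(z)

# Example usage (with hypothetical arrays of latent codes and labels):
# acc = knn_probe(z_train, y_train, z_test, y_test)
# z_2d = tsne_embed(np.concatenate([z_train, z_test], axis=0))
```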

7 Related Work

Given the vast literature on generative models, here we only touch on the major bodies of work related to MIM.

VAEs Kingma2013 are widely used as latent variable models for representation learning. The VAE provides a strong sampling capability (e.g., DBLP:journals/corr/abs-1901-03416), considered as a proxy for representation quality, in addition to auxiliary tasks such as classification DBLP:journals/corr/BengioTPPB17. Nevertheless, it has been observed that a powerful decoder can suffer from posterior collapse (DBLP:journals/corr/BowmanVVDJB15; ChenKSDDSSA16; DBLP:journals/corr/abs-1901-03416; DBLP:journals/corr/OordKK16; DBLP:journals/corr/abs-1711-00937), where the decoder effectively ignores the encoder in some dimensions, and the learned representation has low mutual information with the observations. While several attempts to mitigate the problem have been proposed DBLP:journals/corr/abs-1711-00464; DBLP:journals/corr/abs-1901-03416, the root cause has not been identified.

As mentioned above, mutual information, together with disentanglement, is considered to be a cornerstone for useful representations (Hjelm2018; hjelm2018learning). Normalizing flows (Rezende2015; Dinh2014; Dinh2016a; Kingma2018; DBLP:journals/corr/abs-1902-00275) directly maximize mutual information by restricting the architecture to be invertible and tractable. This, however, requires the latent dimension to be the same as the dimension of the observations (i.e., no bottleneck). As a consequence, normalizing flows are not well suited to learning a concise representation of high dimensional data (e.g., images). Here, MIM often yields mappings that are approximately invertible, with high mutual information and low reconstruction errors.

The Bidirectional Helmholtz Machine DBLP:journals/corr/BornscheinSFB15 shares some of the same design principles as MIM, i.e., symmetry and encode/decoder consistency. However, their formulation models the joint density in terms of the geometric mean between the encoder and decoder, for which one must compute an expensive partition function. (pu2017adversarial) focus on minimizing symmetric KL, but must use an adversarial learning procedure, while MIM can be minimized directly.

GANs (NIPS2014_5423), which focus mainly on decoder properties, without a proper inference model, have been shown to minimize JSD between the data anchor and the model generative process (i.e., the marginal of the decoding distribution in MIM terms). In particular, prior work recognizes the importance of symmetry in learning generative models with reference to symmetric discriminators on $x$ and $z$ (Bang-BiGAN2018; DonahueKD16-BiGAN; dumoulin2016adversarially). In contrast, here we target JSD between the joint encoding and decoding distributions, together with a regularizer to encourage high mutual information.

8 Conclusions

We introduce a new representation learning framework, named the mutual information machine (MIM), that defines a generative model which directly targets high mutual information (i.e., between the observations and the latent representation) and symmetry (i.e., consistency of the encoding and decoding factorizations of the joint distribution). We derive a variational bound that enables the maximization of mutual information in the learned representation for high dimensional continuous data, without the need to compute it directly. We then provide a possible explanation for the phenomenon of posterior collapse, and demonstrate that MIM does not suffer from it. Empirical comparisons to VAEs show that MIM learning leads to higher mutual information and better clustering (and classification) in the latent representation, given the same architecture and parameterization. In addition, we show that MIM can provide reconstruction error similar to a deterministic auto-encoder when the dimensionality of the latent representation is equal to that of the observations. Such behaviour can potentially allow approximate invertibility when the dimensionalities differ, with a stochastic mapping that is defined through consistency and high mutual information.

In future work, we intend to focus on utilizing the high mutual information mapping provided by MIM, by exploiting the clustered latent representation to further improve the resulting generative model.

Acknowledgements

Many thanks to Ethan Fetaya, Jacob Goldberger, Roger Grosse, Chris Maddison, and Daniel Roy for interesting discussions and for their helpful comments. We are especially grateful to Sajad Nourozi for extensive discussions and for his help to empirically validate the formulation and experimental work. This work was financially supported in part by the Canadian Institute for Advanced Research (Program on Learning in Machines and Brains), and NSERC Canada.

References


Appendix


Appendix A Derivations for MIM Formulation

In what follows we provide detailed derivations of key elements of the formulation in the paper, namely, Equations (6), (10), (12), and (13). We also consider the relation between MIM based on the Jensen-Shannon divergence and the symmetric KL divergence.

A.1 JSD and Entropy Objectives

First we develop the relation in Eqn. (6) between the Jensen-Shannon divergence of the encoder and decoder, the average joint entropy of the encoder and decoder, and the joint entropy of the mixture distribution $\mathcal{M}_{\mathcal{S}}$.

The Jensen-Shannon divergence with respect to the encoding distribution $q_{\mathcal{S}}(x,z) = q_{\theta}(z|x)\,\mathcal{P}(x)$ and the decoding distribution $p_{\mathcal{S}}(x,z) = p_{\theta}(x|z)\,\mathcal{P}(z)$ is defined as

$D_{\mathrm{JS}}(\theta) \;=\; \tfrac{1}{2}\Big( D_{\mathrm{KL}}\big(q_{\mathcal{S}}\,\|\,\mathcal{M}_{\mathcal{S}}\big) + D_{\mathrm{KL}}\big(p_{\mathcal{S}}\,\|\,\mathcal{M}_{\mathcal{S}}\big) \Big)\,,$

where $\mathcal{M}_{\mathcal{S}} = \tfrac{1}{2}\big(q_{\mathcal{S}} + p_{\mathcal{S}}\big)$ is a mixture of the encoding and decoding distributions. Adding $R_{\mathcal{H}}(\theta) = \tfrac{1}{2}\big( H(q_{\mathcal{S}}) + H(p_{\mathcal{S}}) \big)$ to the JSD term gives

$D_{\mathrm{JS}}(\theta) + R_{\mathcal{H}}(\theta) \;=\; \tfrac{1}{2}\Big( \mathbb{E}_{q_{\mathcal{S}}}\!\big[ \log\tfrac{q_{\mathcal{S}}}{\mathcal{M}_{\mathcal{S}}} \big] + \mathbb{E}_{p_{\mathcal{S}}}\!\big[ \log\tfrac{p_{\mathcal{S}}}{\mathcal{M}_{\mathcal{S}}} \big] \Big) \;-\; \tfrac{1}{2}\Big( \mathbb{E}_{q_{\mathcal{S}}}\big[\log q_{\mathcal{S}}\big] + \mathbb{E}_{p_{\mathcal{S}}}\big[\log p_{\mathcal{S}}\big] \Big)$
$\;=\; -\,\tfrac{1}{2}\Big( \mathbb{E}_{q_{\mathcal{S}}}\big[\log\mathcal{M}_{\mathcal{S}}\big] + \mathbb{E}_{p_{\mathcal{S}}}\big[\log\mathcal{M}_{\mathcal{S}}\big] \Big) \;=\; -\,\mathbb{E}_{\mathcal{M}_{\mathcal{S}}}\big[\log\mathcal{M}_{\mathcal{S}}\big] \;=\; H\big(\mathcal{M}_{\mathcal{S}}\big)\,,$

which is Eqn. (6).

A.2 MIM Consistency

Here we discuss in greater detail how the learning algorithm encourages consistency between the encoder and decoder of a MIM model, beyond the fact that they are fit to the same sample distribution. To this end we expand on several properties of the model and the optimization procedure.

A.2.1 MIM Consistency Objective

In what follows we derive the form of the MIM consistency term, $R_{\mathrm{MIM}}(\theta)$, given in Eqn. (10). Recall that we define $\mathcal{M}_{\theta}(x,z) = \tfrac{1}{2}\big( q_{\theta}(x,z) + p_{\theta}(x,z) \big)$. We can show that $\mathcal{L}_{\mathrm{MIM}}$ is equivalent to $\mathcal{L}_{\mathrm{CE}}$ plus a regularizer by taking their difference:

$R_{\mathrm{MIM}}(\theta) \;=\; \mathcal{L}_{\mathrm{MIM}}(\theta) - \mathcal{L}_{\mathrm{CE}}(\theta) \;=\; \mathbb{E}_{(x,z)\sim\mathcal{M}_{\mathcal{S}}}\!\left[\, \log\mathcal{M}_{\theta}(x,z) \;-\; \tfrac{1}{2}\big( \log q_{\theta}(x,z) + \log p_{\theta}(x,z) \big) \right] \;=\; \mathbb{E}_{(x,z)\sim\mathcal{M}_{\mathcal{S}}}\!\left[\, \log\frac{\mathcal{M}_{\theta}(x,z)}{\sqrt{q_{\theta}(x,z)\,p_{\theta}(x,z)}} \right],$

where $R_{\mathrm{MIM}}(\theta)$ is non-negative, and is zero only when the encoding and decoding distributions are consistent (i.e., they represent the same joint distribution). To prove that $R_{\mathrm{MIM}}(\theta) \geq 0$ and to derive Eqn. (12), note that for every $(x,z)$,

$\log\mathcal{M}_{\theta}(x,z) \;=\; \log\Big( \tfrac{1}{2}\,q_{\theta}(x,z) + \tfrac{1}{2}\,p_{\theta}(x,z) \Big) \;\geq\; \tfrac{1}{2}\log q_{\theta}(x,z) + \tfrac{1}{2}\log p_{\theta}(x,z)\,,$

where the inequality follows from Jensen's inequality (concavity of the logarithm), and equality holds only when $q_{\theta}(x,z) = p_{\theta}(x,z)$ (i.e., the encoding and decoding distributions are consistent). Taking the expectation under $\mathcal{M}_{\mathcal{S}}$ gives $R_{\mathrm{MIM}}(\theta) \geq 0$.

A.2.2 Self-Correcting Gradient

One important property of the optimization follows directly from the difference between the gradient of the upper bound $\mathcal{L}_{\mathrm{MIM}}$ and the gradient of the cross-entropy loss $\mathcal{L}_{\mathrm{CE}}$. By moving the gradient operator into the expectation using reparameterization, one can express the gradient of $\mathcal{L}_{\mathrm{MIM}}$ in terms of the gradient of $\mathcal{L}_{\mathrm{CE}}$ and the regularization term in Eqn. (10). That is, with some manipulation one obtains

$\nabla_{\theta}\,\mathcal{L}_{\mathrm{MIM}}(\theta) \;=\; \nabla_{\theta}\,\mathcal{L}_{\mathrm{CE}}(\theta) \;+\; \mathbb{E}_{(x,z)\sim\mathcal{M}_{\mathcal{S}}}\!\left[ \frac{q_{\theta}(x,z) - p_{\theta}(x,z)}{2\,\big( q_{\theta}(x,z) + p_{\theta}(x,z) \big)}\; \nabla_{\theta}\Big( \log q_{\theta}(x,z) - \log p_{\theta}(x,z) \Big) \right]\,,$   (20)

which shows that, for any sample where a gap $p_{\theta}(x,z) > q_{\theta}(x,z)$ exists, the gradient applied to $\log q_{\theta}(x,z)$ grows with the gap, while correspondingly less weight is placed on the gradient applied to $\log p_{\theta}(x,z)$. The opposite is true when $q_{\theta}(x,z) > p_{\theta}(x,z)$. In both cases this behaviour encourages consistency between the encoder and decoder. Empirically, we find that the encoder and decoder become reasonably consistent early in the optimization process.

A.2.3 Numerical Stability

Instead of optimizing the upper bound $\mathcal{L}_{\mathrm{MIM}}$, one might consider direct optimization of $\mathcal{L}_{\mathrm{CE}}$. Earlier we discussed the importance of the consistency regularizer in $\mathcal{L}_{\mathrm{MIM}}$. Here we motivate the use of $\mathcal{L}_{\mathrm{MIM}}$ from a numerical perspective. In order to optimize $\mathcal{L}_{\mathrm{CE}}$ directly, one must convert $\log q_{\theta}(x,z)$ and $\log p_{\theta}(x,z)$ to $q_{\theta}(x,z)$ and $p_{\theta}(x,z)$ in order to evaluate $\log\mathcal{M}_{\theta}(x,z)$. Unfortunately, this has the potential to produce numerical errors, especially with 32-bit floating-point precision on GPUs. While various tricks can reduce numerical instability, we find that using the upper bound eliminates the problem while providing the additional benefits outlined above.

A.2.4 Tractability

A linear mixture, via the JSD, is not the only way one might combine the encoder and decoder in a symmetric fashion. An alternative to MIM, explored in (DBLP:journals/corr/BornscheinSFB15), is to use a product; i.e.,

$p_{\mathrm{prod}}(x,z) \;=\; \frac{1}{C}\,\sqrt{\, q_{\theta}(z|x)\,\mathcal{P}(x)\;\, p_{\theta}(x|z)\,\mathcal{P}(z) \,}\,,$   (21)

where $C$ is the partition function. One can then define the objective to be the cross-entropy as above, with a regularizer to encourage $C$ to be close to 1, and hence to encourage consistency between the encoder and decoder. This, however, requires a good approximation to the partition function. Our choice of the mixture $\mathcal{M}_{\theta}$ avoids the need for a good value approximation by using reparameterization, which results in unbiased, low-variance gradients, independent of the accuracy of the approximation of the value.

A.3 MIM Loss Decomposition

Here we show how to break down $\mathcal{L}_{\mathrm{MIM}}$ into the set of intuitive components given in Eqn. (13). To this end, first note the definition of $\mathcal{L}_{\mathrm{MIM}}$:

$\mathcal{L}_{\mathrm{MIM}}(\theta) \;=\; \frac{1}{2}\Big( H_{\mathcal{M}_{\mathcal{S}}}\big(q_{\theta}(x,z)\big) + H_{\mathcal{M}_{\mathcal{S}}}\big(p_{\theta}(x,z)\big) \Big)\,.$   (22)

We will focus on the first half of Eqn. (22) for now,

$H_{\mathcal{M}_{\mathcal{S}}}\big(q_{\theta}(x,z)\big) \;=\; \frac{1}{2}\Big( H_{q_{\mathcal{S}}}\big(q_{\theta}(x,z)\big) + H_{p_{\mathcal{S}}}\big(q_{\theta}(x,z)\big) \Big)\,.$   (23)

It is clearer to write out the first term of Eqn. (23),