Prescribed Generative Adversarial Networks
Generative adversarial networks (GANs) are a powerful approach to unsupervised learning. They have achieved state-of-the-art performance in the image domain. However, GANs are limited in two ways. They often learn distributions with low support, a phenomenon known as mode collapse, and they do not guarantee the existence of a probability density, which makes evaluating generalization using predictive log-likelihood impossible. In this paper, we develop the prescribed GAN (PresGAN) to address these shortcomings. PresGANs add noise to the output of a density network and optimize an entropy-regularized adversarial loss. The added noise renders tractable approximations of the predictive log-likelihood and stabilizes the training procedure. The entropy regularizer encourages PresGANs to capture all the modes of the data distribution. Fitting PresGANs involves computing the intractable gradients of the entropy regularization term; PresGANs sidestep this intractability using unbiased stochastic estimates. We evaluate PresGANs on several datasets and find that they mitigate mode collapse and generate samples with high perceptual quality. We further find that PresGANs reduce the gap in performance in terms of predictive log-likelihood between traditional GANs and variational auto-encoders (VAEs).¹

¹ Code: The code for this paper can be found at https://github.com/adjidieng/PresGANs.
Keywords: generative adversarial networks, entropy regularization, log-likelihood evaluation, mode collapse, diverse image generation, deep generative models
1 Introduction

Generative adversarial networks (GANs) (goodfellow2014generative) are a family of generative models that have shown great promise. They achieve state-of-the-art performance in the image domain, for example in image generation (karras2019style; brock2018large), image super-resolution (ledig2017photo), and image translation (isola2017image).
GANs learn densities by defining a sampling procedure. A latent variable $z$ is sampled from a prior $p(z)$, and a sample $x$ is generated by taking the output of a neural network $G_\theta(z)$ with parameters $\theta$, called a generator, that takes $z$ as input. The density $p_\theta(x)$ implied by this sampling procedure is implicit and undefined (mohamed2016learning). However, GANs effectively learn the parameters $\theta$ by introducing a classifier $D_\phi$, a deep neural network with parameters $\phi$ called the discriminator, that distinguishes between generated samples $x \sim p_\theta(x)$ and real data $x$ with distribution $p_d(x)$. The parameters $\theta$ and $\phi$ are learned jointly by optimizing the GAN objective,

$$\mathcal{L}_{\text{GAN}}(\theta, \phi) = \mathbb{E}_{p_d(x)}\big[\log D_\phi(x)\big] + \mathbb{E}_{p(z)}\Big[\log\big(1 - D_\phi(G_\theta(z))\big)\Big]. \tag{1}$$
GANs iteratively maximize the loss in Eq. 1 with respect to $\phi$ and minimize it with respect to $\theta$.
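To make the objective concrete, here is a minimal Monte Carlo sketch of the GAN objective on 1-D toy data. It assumes a standard normal prior and plain Python callables for the generator and discriminator; `gan_objective`, the toy generator $G(z) = 2z + 1$, and the logistic discriminator are hypothetical illustrations, not the architectures used in the paper. The discriminator would ascend this value while the generator descends it.

```python
import math
import random


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def gan_objective(data, generator, discriminator, num_latent, rng):
    """Monte Carlo estimate of E_{p_d}[log D(x)] + E_{p(z)}[log(1 - D(G(z)))].

    `data` holds samples from the data distribution; the prior p(z) is
    assumed here to be a 1-D standard normal.
    """
    real_term = sum(math.log(discriminator(x)) for x in data) / len(data)
    fake_term = 0.0
    for _ in range(num_latent):
        z = rng.gauss(0.0, 1.0)                      # z ~ p(z)
        fake_term += math.log(1.0 - discriminator(generator(z)))
    return real_term + fake_term / num_latent


# Toy usage: hypothetical 1-D generator G(z) = 2z + 1 and a logistic discriminator.
rng = random.Random(0)
data = [rng.gauss(1.0, 2.0) for _ in range(1000)]
value = gan_objective(data, lambda z: 2.0 * z + 1.0, lambda x: sigmoid(0.3 * x), 1000, rng)
print(round(value, 3))
```

A discriminator that always outputs 0.5 gives exactly $2\log 0.5 = -\log 4$, the value of the objective at the classical GAN equilibrium.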
In practice, the minimax procedure described above is stopped when the generator produces realistic images. This is problematic because high perceptual quality does not necessarily correlate with goodness of fit to the target density. For example, memorizing the training data is a trivial way to achieve high perceptual quality. Fortunately, GANs do not merely memorize the training data (zhang2017discrimination; arora2017generalization).
However, GANs are able to produce images indistinguishable from real images while still failing to fully capture the target distribution (brock2018large; karras2019style). Indeed, GANs suffer from an issue known as mode collapse. When mode collapse happens, the generative distribution $p_\theta(x)$ is degenerate and of low support (arora2017generalization; arora2018gans). Mode collapse causes GANs, as density estimators, to fail both qualitatively and quantitatively. Qualitatively, mode collapse causes a lack of diversity in the generated samples, which is problematic for certain applications of GANs, e.g., data augmentation. Quantitatively, mode collapse causes poor generalization to new data, because the supports of the learned distribution and of the data distribution no longer match. Using annealed importance sampling with a kernel density estimate of the likelihood, wu2016quantitative report significantly worse log-likelihood scores for GANs than for variational auto-encoders (VAEs). Similarly poor generalization performance was reported by grover2018flow.
A natural way to prevent mode collapse in GANs is to maximize the entropy of the generator (belghazi2018mine). Unfortunately, the entropy of GANs is unavailable, because the existence of the generative density $p_\theta(x)$ is not guaranteed (mohamed2016learning; arjovsky2017wasserstein).
In this paper, we propose a method to alleviate mode collapse in GANs, resulting in a new family of GANs called prescribed GANs (PresGANs). PresGANs prevent mode collapse by explicitly maximizing the entropy of the generator. This is done by augmenting the loss in Eq. 1 with the negative entropy of the generator, such that minimizing the augmented loss with respect to $\theta$ corresponds to fitting the data while also maximizing the entropy of the generative distribution. The existence of the generative density is guaranteed by adding noise to the output of a density network (mackay1995bayesian; diggle1984monte). This process defines the generative distribution $p_\theta(x)$ not as an implicit distribution, as in standard GANs, but as an infinite mixture of well-defined densities, as in continuous VAEs (kingma2013auto; rezende2014stochastic). The generative distribution of PresGANs is therefore very flexible.
Although the entropy of the generative distribution of PresGANs is well-defined, it is intractable. However, fitting a PresGAN to data only involves computing the gradients of the entropy, not the entropy itself. PresGANs use unbiased Monte Carlo estimates of these gradients.
An illustrative example. To demonstrate how PresGANs alleviate mode collapse, we form a target distribution by organizing a uniform mixture of two-dimensional Gaussians on a ring. We draw samples from this target distribution. We first fit a GAN, setting the hyperparameters so that the GAN suffers from mode collapse.² We then use the same settings for a PresGAN to assess whether it can correct the collapsing behavior of the GAN. Figure 1 shows the collapsing behavior of the GAN, which misses modes of the target distribution. The PresGAN, on the other hand, recovers all the modes. Section 5 provides details about the settings of this synthetic experiment.

² A GAN can perfectly fit this distribution when the right hyperparameters are chosen.
Contributions. This paper contributes to the literature on the two main open problems in the study of GANs: preventing mode collapse and evaluating log-likelihood.
How can we perform entropy regularization of the generator of a GAN so as to effectively prevent mode collapse? We achieve this by adding noise to the output of the generator; this ensures the existence of a density and makes its entropy well-defined. We then regularize the GAN loss to encourage densities with high entropy. During training, we form unbiased estimators of the (intractable) gradients of the entropy regularizer. We show in two sets of experiments that this prevents mode collapse, as expected (see Section 5). The first experiment follows the current standard for measuring mode collapse in the GAN literature, which is to report the number of modes recovered by the GAN on MNIST (10 modes) and StackedMNIST (1,000 modes) and the Kullback-Leibler (KL) divergence between the true label distribution and the one induced by the GAN. The second experiment sheds light on another way mode collapse can occur in GANs: when the data is imbalanced.
How can we measure log-likelihood in GANs? Evaluating log-likelihood for GANs allows assessing how they generalize to new data. Existing measures focus on sample quality, which is not a measure of generalization. This inability to measure predictive log-likelihood for GANs has restricted their use to domains where one can use perceptual quality measures (e.g., the image domain). Existing methods for evaluating log-likelihood for GANs either use a proxy to log-likelihood (sanchez2019out), define the likelihood of the generator only at test time, which creates a mismatch between training and testing (wu2016quantitative), or assume invertibility of the generator of the GAN (grover2018flow). Adding noise to the output of the generator immediately makes predictive log-likelihood evaluation tractable via importance sampling.
Outline. The rest of the paper is organized as follows. In Section 2 we set the notation and provide desiderata for deep generative modeling. In Section 3 we describe PresGANs and how we compute their predictive log-likelihood to assess generalization. In Section 4 we discuss related work. We then assess the performance of PresGANs in terms of mode collapse, sample quality, and log-likelihood in Section 5. Finally, we conclude and discuss key findings in Section 6.
In this paper, we characterize a deep generative model (DGM) by its generative process and by the loss used to fit its parameters. We denote by $p_\theta(x)$ the generative distribution induced by the generative process; it is parameterized by a deep neural network with parameters $\theta$. The loss, which we denote by $\mathcal{L}(\theta, \phi)$, often requires an additional set of parameters $\phi$ that help learn the model parameters $\theta$. We next describe choices for $p_\theta(x)$ and $\mathcal{L}(\theta, \phi)$ and then specify desiderata for deep generative modeling.
The generative distribution. Recent DGMs define the generative distribution either as an implicit distribution or as an infinite mixture (goodfellow2014generative; kingma2013auto; rezende2014stochastic).
Implicit generative models define a density using a sampling procedure. This is the approach of GANs (goodfellow2014generative). A latent variable $z$ is sampled from a prior $p(z)$, usually a standard Gaussian or a uniform distribution, and a sample is generated by taking the output of a neural network $G_\theta(z)$ that takes $z$ as input. The density $p_\theta(x)$ implied by this sampling procedure is undefined. Any measure that relies on an analytic form of the density is therefore unavailable, e.g., the log-likelihood or the entropy.
An alternative way to define the generative distribution is the approach of VAEs (kingma2013auto; rezende2014stochastic). They define $p_\theta(x)$ as an infinite mixture,

$$p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz. \tag{2}$$
Here the mixing distribution is the prior $p(z)$. The conditional distribution $p_\theta(x \mid z)$ is an exponential family distribution, such as a Gaussian or a Bernoulli, parameterized by a neural network with parameters $\theta$. Although both the prior $p(z)$ and $p_\theta(x \mid z)$ are simple tractable distributions, the generative distribution $p_\theta(x)$ is highly flexible, albeit intractable. Because $p_\theta(x)$ in Eq. 2 is well-defined, the log-likelihood and the entropy are also well-defined (although they may be analytically intractable).
The loss function. Fitting the models defined above requires defining a learning procedure by specifying a loss function. GANs introduce a classifier $D_\phi(x)$, a deep neural network parameterized by $\phi$, to discriminate between samples from the data distribution $p_d(x)$ and the generative distribution $p_\theta(x)$. The auxiliary parameters $\phi$ are learned jointly with the model parameters $\theta$ by optimizing the loss in Eq. 1. This training procedure leads to high sample quality but often suffers from mode collapse (arora2017generalization; arora2018gans).
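An alternative approach to learning $\theta$ is via maximum likelihood. This requires a well-defined density $p_\theta(x)$ such as the one in Eq. 2. Although well-defined, $p_\theta(x)$ is intractable, making it difficult to learn the parameters $\theta$ by maximum likelihood. VAEs instead introduce a recognition network, a neural network with parameters $\phi$ that takes data $x$ as input and outputs a distribution $q_\phi(z \mid x)$ over the latent variables $z$, and maximize a lower bound on $\log p_\theta(x)$ with respect to both $\theta$ and $\phi$,

$$\mathcal{L}_{\text{ELBO}}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x, z) - \log q_\phi(z \mid x)\big] = \log p_\theta(x) - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big). \tag{3}$$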
Here $\mathrm{KL}(\cdot \,\|\, \cdot)$ denotes the KL divergence. Maximizing $\mathcal{L}_{\text{ELBO}}$ with respect to $\phi$ is equivalent to minimizing this KL, which leads to issues such as latent variable collapse (bowman2015generating; dieng2018avoiding). Furthermore, optimizing Eq. 3 may lead to blurriness in the generated samples because of a property of the reverse KL divergence known as zero-forcing (minka2005divergence).
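A small numerical check of the lower bound may help. In a hypothetical 1-D linear-Gaussian model, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(wz, s^2)$, chosen only because everything is available in closed form, the ELBO equals $\log p(x)$ exactly when $q$ is the true posterior and is strictly smaller otherwise, the gap being the KL term. The function names are illustrative.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a 1-D Gaussian N(x; mean, var)."""
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)


def elbo_mc(x, w, s2, m, v, num_samples, rng):
    """Monte Carlo ELBO: average of log p(x, z) - log q(z | x) with z ~ q = N(m, v)."""
    total = 0.0
    for _ in range(num_samples):
        z = m + math.sqrt(v) * rng.gauss(0.0, 1.0)   # reparameterized draw from q(z | x)
        total += log_normal(x, w * z, s2) + log_normal(z, 0.0, 1.0) - log_normal(z, m, v)
    return total / num_samples


def elbo_exact(x, w, s2, m, v):
    """Closed-form ELBO for p(z) = N(0, 1), p(x | z) = N(w z, s2), q(z | x) = N(m, v)."""
    expected_loglik = -0.5 * math.log(2.0 * math.pi * s2) - ((x - w * m) ** 2 + w ** 2 * v) / (2.0 * s2)
    expected_logprior = -0.5 * math.log(2.0 * math.pi) - (m ** 2 + v) / 2.0
    entropy_q = 0.5 * math.log(2.0 * math.pi * math.e * v)
    return expected_loglik + expected_logprior + entropy_q


def log_marginal(x, w, s2):
    """Exact log p(x): integrating out z gives p(x) = N(0, w^2 + s2)."""
    return log_normal(x, 0.0, w ** 2 + s2)


x, w, s2 = 1.3, 2.0, 0.5
post_var = s2 / (s2 + w ** 2)          # the exact posterior p(z | x) is Gaussian here
post_mean = w * x / (s2 + w ** 2)
print(elbo_exact(x, w, s2, post_mean, post_var), log_marginal(x, w, s2))  # equal: KL = 0
print(elbo_exact(x, w, s2, 0.0, 1.0))  # q = prior: strictly below log p(x)
```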
Desiderata. We now outline three desiderata for DGMs.
High sample quality. A DGM whose parameters have been fitted using real data should generate new data with the same qualitative precision as the data it was trained with. For example, if a DGM is trained on a dataset composed of human faces, it should generate data with all the features that make up a face, at the same resolution as the training data.
High sample diversity. High sample quality alone is not enough. For example, a degenerate DGM that is only able to produce one single sample is not desirable, even if the sample quality is perfect. Therefore we require sample diversity: a DGM should ideally capture all modes of the data distribution.
Tractable predictive log-likelihood. DGMs are density estimators, and as such we should evaluate how they generalize to new data. High sample quality and diversity are not measures of generalization. We therefore require tractable predictive log-likelihood as a desideratum for deep generative modeling.
We next introduce a new family of GANs that fulfills all the desiderata.
3 Prescribed Generative Adversarial Networks
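PresGANs generate data following the generative distribution in Eq. 2. Note that this generative process is the same as for standard VAEs (kingma2013auto; rezende2014stochastic). In particular, PresGANs set the prior $p(z)$ and the likelihood $p_\theta(x \mid z)$ to be Gaussians,

$$p(z) = \mathcal{N}(z; 0, I), \qquad p_\theta(x \mid z) = \mathcal{N}\big(x; \mu_\theta(z), \Sigma_\theta(z)\big). \tag{4}$$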
The mean $\mu_\theta(z)$ and covariance $\Sigma_\theta(z)$ of the conditional $p_\theta(x \mid z)$ are given by a neural network that takes $z$ as input.
In general, both the mean $\mu_\theta(z)$ and the covariance $\Sigma_\theta(z)$ can be functions of $z$. For simplicity, and to speed up the learning procedure, we set the covariance matrix to be diagonal with elements independent of $z$, i.e., $\Sigma_\theta(z) = \mathrm{diag}(\sigma^2)$, and we learn the vector $\sigma$ together with the parameters of the mean. From now on, we parameterize the mean with $\eta$, write $\mu_\eta(z)$, and define $\theta = (\eta, \sigma)$ as the parameters of the generative distribution.
To fit the model parameters $\theta$, PresGANs optimize an adversarial loss similarly to GANs. In doing so, they keep GANs' ability to generate samples with high perceptual quality. Unlike that of GANs, the entropy of the generative distribution of PresGANs is well-defined, and therefore PresGANs can prevent mode collapse by adding an entropy regularizer to Eq. 1. Furthermore, because PresGANs define a density over their generated samples, we can measure how they generalize to new data using predictive log-likelihood. We describe the entropy regularization in Section 3.1 and how to approximate the predictive log-likelihood in Section 3.3.
3.1 Avoiding mode collapse via entropy regularization
One of the major issues that GANs face is mode collapse, where the generator tends to model only some parts or modes of the data distribution (arora2017generalization; arora2018gans). PresGANs mitigate this problem by explicitly maximizing the entropy of the generative distribution,

$$\mathcal{L}_{\text{PresGAN}}(\theta, \phi) = \mathcal{L}_{\text{GAN}}(\theta, \phi) - \lambda\, \mathcal{H}\big(p_\theta(x)\big). \tag{5}$$
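Here $\mathcal{H}(p_\theta(x))$ denotes the entropy of the generative distribution. It is defined as

$$\mathcal{H}\big(p_\theta(x)\big) = -\mathbb{E}_{p_\theta(x)}\big[\log p_\theta(x)\big]. \tag{6}$$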
The loss $\mathcal{L}_{\text{GAN}}(\theta, \phi)$ in Eq. 5 can be that of any of the existing GAN variants. In Section 5 we explore the standard deep convolutional GAN (DCGAN) (radford2015unsupervised) and the more recent StyleGAN (karras2019style) architectures.
The constant $\lambda$ in Eq. 5 is a hyperparameter that controls the strength of the entropy regularization. In the extreme case when $\lambda = 0$, the loss function of PresGAN coincides with the loss of a GAN whose implicit generative distribution has been replaced with the infinite mixture in Eq. 2. In the other extreme, when $\lambda \to \infty$, optimizing $\mathcal{L}_{\text{PresGAN}}(\theta, \phi)$ corresponds to fitting a maximum entropy generator that ignores the data. For intermediate values of $\lambda$, the first term of $\mathcal{L}_{\text{PresGAN}}(\theta, \phi)$ encourages the generator to fit the data distribution, whereas the second term encourages it to cover all of the modes of the data distribution.
3.2 Fitting Prescribed Generative Adversarial Networks
We fit PresGANs following the same adversarial procedure used in GANs. That is, we alternate between updating the parameters of the generative distribution $\theta$ and the parameters of the discriminator $\phi$. The full procedure is given in Algorithm 1. We now describe each part in detail.
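Fitting the generator. We fit the generator using stochastic gradient descent. This requires computing the gradients of the PresGAN loss with respect to $\theta$,

$$\nabla_\theta \mathcal{L}_{\text{PresGAN}}(\theta, \phi) = \nabla_\theta \mathcal{L}_{\text{GAN}}(\theta, \phi) - \lambda\, \nabla_\theta \mathcal{H}\big(p_\theta(x)\big). \tag{7}$$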
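We form stochastic estimates of $\nabla_\theta \mathcal{L}_{\text{GAN}}(\theta, \phi)$ based on reparameterization (kingma2013auto; rezende2014stochastic; titsias2014doubly); this requires differentiating Eq. 1. Specifically, we introduce a noise variable $\epsilon$ to reparameterize the conditional from Eq. 4,³

$$x = h_\theta(z, \epsilon) = \mu_\eta(z) + \sigma \odot \epsilon, \tag{8}$$

³ With this reparameterization we use the notation $h_\theta(z, \epsilon)$ instead of $G_\theta(z)$ to denote a sample from the generative distribution.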
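where $\theta = (\eta, \sigma)$ and $\epsilon \sim \mathcal{N}(0, I)$. Here $\mu_\eta(z)$ and $\sigma$ denote the mean and standard deviation of the conditional $p_\theta(x \mid z)$, respectively. We now write the first term of Eq. 7 as an expectation with respect to the latent variable $z$ and the noise variable $\epsilon$ and push the gradient into the expectation,

$$\nabla_\theta \mathcal{L}_{\text{GAN}}(\theta, \phi) = \mathbb{E}_{p(z)p(\epsilon)}\Big[\nabla_\theta \log\big(1 - D_\phi(h_\theta(z, \epsilon))\big)\Big]. \tag{9}$$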
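In practice we estimate Eq. 9 using one sample from $p(z)$ and one sample from $p(\epsilon)$,

$$\nabla_\theta \mathcal{L}_{\text{GAN}}(\theta, \phi) \approx \nabla_\theta \log\big(1 - D_\phi(h_\theta(z, \epsilon))\big), \qquad z \sim p(z),\ \epsilon \sim p(\epsilon). \tag{10}$$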
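The one-sample reparameterization gradient can be sketched in a scalar toy setting, assuming a hypothetical linear mean $\mu_\eta(z) = az + b$ (so $\eta = (a, b)$) and a hypothetical logistic discriminator $D(x) = \mathrm{sigmoid}(cx + d)$. The gradient flows through $x = h_\theta(z, \epsilon)$ by the chain rule, with $\partial x / \partial a = z$, $\partial x / \partial b = 1$, and $\partial x / \partial \sigma = \epsilon$; in a real model an autodiff framework would compute this for you.

```python
import math
import random


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def h(a, b, sigma, z, eps):
    """Reparameterized sample x = mu_eta(z) + sigma * eps, with a hypothetical mean a z + b."""
    return a * z + b + sigma * eps


def one_sample_generator_grad(a, b, sigma, c, d, rng):
    """One-sample reparameterization gradient of log(1 - D(h(z, eps))) w.r.t. (a, b, sigma).

    D(x) = sigmoid(c x + d) is a hypothetical logistic discriminator.
    Returns the gradient tuple and the (z, eps) draw that was used.
    """
    z = rng.gauss(0.0, 1.0)      # z ~ p(z)
    eps = rng.gauss(0.0, 1.0)    # eps ~ N(0, 1)
    x = h(a, b, sigma, z, eps)
    # d/dx log(1 - sigmoid(c x + d)) = -c * sigmoid(c x + d)
    dfdx = -c * sigmoid(c * x + d)
    # Chain rule through h: dx/da = z, dx/db = 1, dx/dsigma = eps
    return (dfdx * z, dfdx, dfdx * eps), (z, eps)


grads, (z, eps) = one_sample_generator_grad(0.8, 0.1, 0.5, 1.5, -0.2, random.Random(0))
print(grads)
```

Returning the $(z, \epsilon)$ draw makes it easy to compare the pathwise gradient against finite differences of the same sample.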
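The second term in Eq. 7, corresponding to the gradient of the entropy, is intractable. We estimate it using the same approach as titsias2018unbiased. We first use the reparameterization in Eq. 8 to express the gradient of the entropy as an expectation,

$$\begin{aligned}
\nabla_\theta \mathcal{H}\big(p_\theta(x)\big) &= -\nabla_\theta\, \mathbb{E}_{p(z)p(\epsilon)}\Big[\log p_\theta(x)\big|_{x = h_\theta(z, \epsilon)}\Big] \\
&= -\mathbb{E}_{p(z)p(\epsilon)}\Big[\nabla_x \log p_\theta(x)\big|_{x = h_\theta(z, \epsilon)}\, \nabla_\theta h_\theta(z, \epsilon)\Big],
\end{aligned}$$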
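where we have used the score function identity $\mathbb{E}_{p_\theta(x)}\big[\nabla_\theta \log p_\theta(x)\big] = 0$ on the second line; it cancels the term that arises from differentiating $\log p_\theta(x)$ directly with respect to $\theta$. We form a one-sample estimator of the gradient of the entropy as

$$\nabla_\theta \mathcal{H}\big(p_\theta(x)\big) \approx -\nabla_x \log p_\theta(x)\big|_{x = h_\theta(z, \epsilon)}\, \nabla_\theta h_\theta(z, \epsilon), \qquad z \sim p(z),\ \epsilon \sim p(\epsilon). \tag{11}$$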
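In Eq. 11, the gradient with respect to the reparameterization transformation, $\nabla_\theta h_\theta(z, \epsilon)$, is tractable and can be obtained via back-propagation. We now derive $\nabla_x \log p_\theta(x)$,

$$\nabla_x \log p_\theta(x) = \frac{\nabla_x p_\theta(x)}{p_\theta(x)} = \int \nabla_x \log p_\theta(x \mid z)\, \frac{p_\theta(x \mid z)\, p(z)}{p_\theta(x)}\, dz = \mathbb{E}_{p_\theta(z \mid x)}\big[\nabla_x \log p_\theta(x \mid z)\big].$$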
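While this expression is still intractable, we can estimate it. One way is to use self-normalized importance sampling with a proposal learned using moment matching with an encoder (dieng2019reweighted). However, this would lead to a biased (albeit asymptotically unbiased) estimate of the entropy gradient. In this paper, we form an unbiased estimate of $\nabla_x \log p_\theta(x)$ using $M$ samples $z^{(1)}, \ldots, z^{(M)}$ from the posterior,

$$\nabla_x \log p_\theta(x) \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_x \log p_\theta\big(x \mid z^{(m)}\big), \qquad z^{(m)} \sim p_\theta(z \mid x). \tag{12}$$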
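The posterior-expectation identity and its Monte Carlo estimate can be sanity-checked in a hypothetical 1-D linear-Gaussian generator, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(wz, s^2)$, where $p(x) = \mathcal{N}(0, w^2 + s^2)$ has a known score and the posterior can be sampled exactly. This only checks the estimator; in PresGANs the posterior samples come from HMC instead.

```python
import math
import random


def score_estimate(x, w, s2, num_samples, rng):
    """Monte Carlo estimate of grad_x log p(x) = E_{p(z|x)}[grad_x log p(x|z)].

    Hypothetical 1-D model: p(z) = N(0, 1), p(x|z) = N(w z, s2).
    The posterior p(z|x) is Gaussian here, so it can be sampled exactly.
    """
    post_var = s2 / (s2 + w ** 2)
    post_mean = w * x / (s2 + w ** 2)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(post_mean, math.sqrt(post_var))   # z ~ p(z | x)
        total += -(x - w * z) / s2                      # grad_x log N(x; w z, s2)
    return total / num_samples


def exact_score(x, w, s2):
    """p(x) = N(0, w^2 + s2), so grad_x log p(x) = -x / (w^2 + s2)."""
    return -x / (w ** 2 + s2)


print(score_estimate(1.3, 2.0, 0.5, 20000, random.Random(0)), exact_score(1.3, 2.0, 0.5))
```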
We obtain these samples using Hamiltonian Monte Carlo (HMC) (neal2011mcmc). Crucially, in order to speed up the algorithm, we initialize the HMC sampler at stationarity. That is, we initialize the HMC sampler with the sample $z$ that was used to produce the generated sample $x = h_\theta(z, \epsilon)$ in Eq. 8, which by construction is an exact sample from $p_\theta(z \mid x)$. This implies that only a few HMC iterations suffice to get good estimates of the gradient (titsias2018unbiased). We also found this holds empirically; for example, in Section 5 we use only a few burn-in iterations and HMC samples to form the Monte Carlo estimate in Eq. 12.
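A minimal sketch of this warm-started HMC step, again in the hypothetical 1-D model $x = wz + \sigma\epsilon$. The step size of 0.1, the 10 leapfrog steps, and the default burn-in and sample counts of 2 are illustrative placeholders, not the settings used in the paper. Because the chain starts at the $z$ that generated $x$, every state it visits is already distributed as $p(z \mid x)$, so the score estimate is unbiased even with very short chains.

```python
import math
import random


def hmc_posterior(x, w, s2, z_init, num_burnin, num_samples, step_size, num_leapfrog, rng):
    """HMC targeting p(z | x) ∝ N(x; w z, s2) N(z; 0, 1) in a hypothetical 1-D model.

    Returns the post-burn-in states of a chain started at z_init.
    """
    def log_joint(z):
        return -0.5 * (x - w * z) ** 2 / s2 - 0.5 * z ** 2

    def grad_log_joint(z):
        return w * (x - w * z) / s2 - z

    z = z_init
    samples = []
    for it in range(num_burnin + num_samples):
        p0 = rng.gauss(0.0, 1.0)                            # fresh momentum
        z_new = z
        p = p0 + 0.5 * step_size * grad_log_joint(z_new)    # leapfrog: half momentum step
        for step in range(num_leapfrog):
            z_new += step_size * p                          # full position step
            if step < num_leapfrog - 1:
                p += step_size * grad_log_joint(z_new)      # full momentum step
        p += 0.5 * step_size * grad_log_joint(z_new)        # final half momentum step
        log_accept = (log_joint(z_new) - 0.5 * p ** 2) - (log_joint(z) - 0.5 * p0 ** 2)
        if rng.random() < math.exp(min(0.0, log_accept)):   # Metropolis correction
            z = z_new
        if it >= num_burnin:
            samples.append(z)
    return samples


def score_from_generated_sample(w, sigma, rng, num_burnin=2, num_samples=2):
    """Generate x = w z + sigma eps, then estimate grad_x log p(x) with HMC started at z."""
    z = rng.gauss(0.0, 1.0)
    x = w * z + sigma * rng.gauss(0.0, 1.0)
    s2 = sigma ** 2
    # z is an exact draw from p(z | x), so the chain starts at stationarity
    zs = hmc_posterior(x, w, s2, z, num_burnin, num_samples, 0.1, 10, rng)
    return x, sum(-(x - w * zm) / s2 for zm in zs) / len(zs)


x, score = score_from_generated_sample(2.0, 0.7, random.Random(1))
print(x, score, -x / (2.0 ** 2 + 0.7 ** 2))  # one noisy estimate vs. the exact score
```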
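In particular, combining Eqs. 10 to 12, the gradient with respect to the generator's parameters $\eta$ is unbiasedly approximated by

$$\nabla_\eta \mathcal{L}_{\text{PresGAN}}(\theta, \phi) \approx \nabla_\eta \log\big(1 - D_\phi(h_\theta(z, \epsilon))\big) - \frac{\lambda}{M} \sum_{m=1}^{M} \left(\frac{x - \mu_\eta(z^{(m)})}{\sigma^2}\right)^{\!\top} \nabla_\eta \mu_\eta(z),$$

and the gradient estimator with respect to the standard deviation $\sigma$ is

$$\nabla_\sigma \mathcal{L}_{\text{PresGAN}}(\theta, \phi) \approx \nabla_\sigma \log\big(1 - D_\phi(h_\theta(z, \epsilon))\big) - \frac{\lambda}{M} \sum_{m=1}^{M} \frac{x - \mu_\eta(z^{(m)})}{\sigma^2} \odot \epsilon,$$

where $x = h_\theta(z, \epsilon)$, the division by $\sigma^2$ is element-wise, and $z^{(m)} \sim p_\theta(z \mid x)$ are the HMC samples. These gradients are used in a stochastic optimization algorithm to fit the generative distribution of PresGAN.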
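The entropy part of these estimators can be checked against closed form in a hypothetical 1-D generator $x = wz + \sigma\epsilon$, for which $\mathcal{H}(p_\theta(x)) = \tfrac{1}{2}\log\big(2\pi e (w^2 + \sigma^2)\big)$, so $\partial\mathcal{H}/\partial w = w/(w^2+\sigma^2)$ and $\partial\mathcal{H}/\partial\sigma = \sigma/(w^2+\sigma^2)$. The sketch draws exact posterior samples in place of HMC, so it only illustrates that the estimator $-\nabla_x \log p_\theta(x)\, \nabla_\theta h_\theta(z, \epsilon)$ is unbiased; the full PresGAN update then subtracts $\lambda$ times these estimates from the adversarial gradients.

```python
import math
import random


def entropy_grad_estimate(w, sigma, num_posterior, rng):
    """One-sample estimate of (dH/dw, dH/dsigma) for x = w z + sigma eps, z, eps ~ N(0, 1).

    Uses grad H ≈ -score(x) * dh/dtheta, with the score estimated from
    exact posterior samples (a stand-in for HMC in this toy model).
    """
    z = rng.gauss(0.0, 1.0)
    eps = rng.gauss(0.0, 1.0)
    x = w * z + sigma * eps                      # x = h_theta(z, eps)
    s2 = sigma ** 2
    post_var = s2 / (s2 + w ** 2)
    post_mean = w * x / (s2 + w ** 2)
    score = 0.0
    for _ in range(num_posterior):
        zm = rng.gauss(post_mean, math.sqrt(post_var))   # z^(m) ~ p(z | x)
        score += -(x - w * zm) / s2                      # grad_x log p(x | z^(m))
    score /= num_posterior
    # dh/dw = z and dh/dsigma = eps
    return -score * z, -score * eps


def exact_entropy_grad(w, sigma):
    """Gradient of H = 0.5 * log(2 pi e (w^2 + sigma^2)) with respect to (w, sigma)."""
    total = w ** 2 + sigma ** 2
    return w / total, sigma / total


rng = random.Random(0)
n = 40000
sum_w = sum_s = 0.0
for _ in range(n):
    g_w, g_s = entropy_grad_estimate(1.5, 0.7, 5, rng)
    sum_w += g_w
    sum_s += g_s
print(sum_w / n, sum_s / n, exact_entropy_grad(1.5, 0.7))
```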
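Fitting the discriminator. Since the entropy term in Eq. 5 does not depend on $\phi$, optimizing the discriminator of a PresGAN is analogous to optimizing the discriminator of a GAN,

$$\nabla_\phi \mathcal{L}_{\text{PresGAN}}(\theta, \phi) = \nabla_\phi \mathcal{L}_{\text{GAN}}(\theta, \phi).$$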
To prevent the discriminator from getting stuck in a bad local optimum where it can perfectly distinguish between real and generated data by relying on the added noise, we apply the same amount of noise to the real data as to the generated data. That is, when we train the discriminator, we corrupt the real data according to

$$x' = x + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$
where $\sigma$ is the standard deviation of the generative distribution and $x$ denotes the real data. We then let the discriminator distinguish between $x'$ and $h_\theta(z, \epsilon)$ from Eq. 8.
This data noising procedure is a form of instance noise (sonderby2016amortised). However, instead of using a fixed annealing schedule for the noise variance as in sonderby2016amortised, we let $\sigma$ be part of the parameters of the generative distribution and fit it using gradient descent, with the gradient estimator with respect to $\sigma$ derived earlier in this section.
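The data-noising step amounts to one extra line per discriminator minibatch. A minimal sketch, assuming minibatches stored as lists of equal-length feature vectors and a per-dimension noise scale `sigma`; the helper name `noise_real_batch` is hypothetical. In a real training loop `sigma` would be the generator's current learned value, so the noise on real data tracks the noise on generated data.

```python
import random


def noise_real_batch(batch, sigma, rng):
    """Corrupt real data with the generator's current noise scale: x' = x + sigma * eps (element-wise)."""
    return [[xi + si * rng.gauss(0.0, 1.0) for xi, si in zip(x, sigma)] for x in batch]


rng = random.Random(0)
real = [[0.0, 1.0, 2.0] for _ in range(5000)]
sigma = [0.1, 0.5, 1.0]            # hypothetical current value of the learned sigma
noisy = noise_real_batch(real, sigma, rng)
```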
Stability. Data noising stabilizes the training procedure and prevents the discriminator from perfectly distinguishing between real and generated samples using the background noise. We refer the reader to ferenc2016instance for a detailed exposition.
When fitting PresGANs, data noising is not enough to stabilize training, because learning the variance with gradient descent introduces two failure modes. The first failure mode is when the variance gets very large, leading to a generator that can completely fool the discriminator: because of data noising, the discriminator cannot distinguish between real and generated samples when the variance of the noise is large.
The second failure mode is when $\sigma$ gets very small, which makes the gradient of the entropy dominate the overall gradient of the generator, since the entropy-gradient estimator scales with $1/\sigma^2$. This is problematic because the learning signal from the discriminator is lost.
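To stabilize training and avoid the two failure modes discussed above, we truncate the variance of the generative distribution, $\sigma^2_{\text{low}} \leq \sigma^2 \leq \sigma^2_{\text{high}}$ (we apply this truncation element-wise). The limits $\sigma^2_{\text{low}}$ and $\sigma^2_{\text{high}}$ are hyperparameters.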
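Because Section 5 fits the log-variance of the noise with Adam, one way to implement the truncation is as a projection after each optimizer step. A sketch under that assumption; `var_min` and `var_max` stand for the two limits, and the helper name is hypothetical. Clamping in log space keeps the parameter unconstrained between updates while guaranteeing the variance stays in range.

```python
import math


def truncate_log_variance(log_var, var_min, var_max):
    """Clamp each element of a log-variance vector so that var_min <= exp(log_var) <= var_max."""
    lo, hi = math.log(var_min), math.log(var_max)
    return [min(max(v, lo), hi) for v in log_var]


clipped = truncate_log_variance([math.log(1e-6), 0.0, math.log(100.0)], 1e-3, 1.0)
```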
3.3 Enabling tractable predictive log-likelihood approximation
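Replacing the implicit generative distribution of GANs with the infinite mixture distribution defined in Eq. 2 has the advantage that the predictive log-likelihood can be tractably approximated. Consider an unseen datapoint $x^*$. We estimate its log marginal likelihood $\log p_\theta(x^*)$ using importance sampling,

$$\log p_\theta(x^*) \approx \log \left[\frac{1}{S} \sum_{s=1}^{S} \frac{p_\theta\big(x^* \mid z^{(s)}\big)\, p\big(z^{(s)}\big)}{r\big(z^{(s)} \mid x^*\big)}\right],$$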
where we draw $S$ samples $z^{(1)}, \ldots, z^{(S)}$ from a proposal distribution $r(z \mid x^*)$.
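The estimator is usually computed in log space with a log-sum-exp to avoid underflow. A minimal sketch in the hypothetical 1-D model $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z) = \mathcal{N}(wz, s^2)$, where the exact $\log p(x)$ is known; the Gaussian proposal here (the exact posterior with its variance inflated by 1.5) is an illustrative choice, not the paper's proposal. With the exact posterior as proposal every weight equals $p(x)$, so the estimate is exact.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a 1-D Gaussian N(x; mean, var)."""
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def is_log_likelihood(x, w, s2, prop_mean, prop_var, num_samples, rng):
    """Importance-sampling estimate of log p(x) for p(z) = N(0, 1), p(x | z) = N(w z, s2),
    using the Gaussian proposal r(z | x) = N(prop_mean, prop_var)."""
    log_weights = []
    for _ in range(num_samples):
        z = rng.gauss(prop_mean, math.sqrt(prop_var))           # z_s ~ r(z | x)
        log_weights.append(log_normal(x, w * z, s2) + log_normal(z, 0.0, 1.0)
                           - log_normal(z, prop_mean, prop_var))
    return log_sum_exp(log_weights) - math.log(num_samples)     # log of the average weight


x, w, s2 = 1.3, 2.0, 0.5
post_var = s2 / (s2 + w ** 2)
post_mean = w * x / (s2 + w ** 2)
exact = log_normal(x, 0.0, w ** 2 + s2)
estimate = is_log_likelihood(x, w, s2, post_mean, 1.5 * post_var, 5000, random.Random(0))
print(estimate, exact)
```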
There are different ways to form a good proposal $r(z \mid x^*)$, and we discuss several alternatives in Section 7.1 of the appendix. In this paper, we take the following approach. We define the proposal as a Gaussian distribution,

$$r(z \mid x^*) = \mathcal{N}\big(z; \mu_r, \Sigma_r\big).$$
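We set the mean parameter $\mu_r$ to the maximum a posteriori solution, i.e., $\mu_r = \arg\max_z \big(\log p_\theta(x^* \mid z) + \log p(z)\big)$. We initialize this maximization algorithm using the mean of a pre-fitted encoder, $q_\gamma(z \mid x^*)$. The encoder is fitted by minimizing the reverse KL divergence between $q_\gamma(z \mid x)$ and the true posterior $p_\theta(z \mid x)$ using the training data. This KL is

$$\mathrm{KL}\big(q_\gamma(z \mid x) \,\|\, p_\theta(z \mid x)\big) = \log p_\theta(x) - \mathbb{E}_{q_\gamma(z \mid x)}\Big[\log\big(p_\theta(x \mid z)\, p(z)\big) - \log q_\gamma(z \mid x)\Big]. \tag{20}$$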
Because the generative distribution is fixed at test time, minimizing the KL here is equivalent to maximizing the second term in Eq. 20, which is the evidence lower bound (ELBO) objective of VAEs.
We set the proposal covariance $\Sigma_r$ to an overdispersed version⁴ of the encoder's covariance matrix, which is diagonal. In particular, to obtain $\Sigma_r$ we multiply the elements of the encoder's covariance by a constant factor. In Section 5 we set this factor to a fixed constant.

⁴ In general, overdispersed proposals lead to better importance sampling estimates.
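The proposal construction can be sketched in the same hypothetical 1-D model: gradient ascent on $\log p(x^* \mid z) + \log p(z)$ started from an encoder mean, and a variance equal to the encoder variance times an overdispersion factor. The encoder outputs (0.2 and 0.1) and the factor 1.2 below are placeholders, not values from the paper; in this Gaussian toy model the MAP point coincides with the posterior mean, which makes the result easy to check.

```python
def map_latent(x, w, s2, z_init, learning_rate=0.05, num_steps=500):
    """Gradient ascent on log p(x | z) + log p(z) for the hypothetical model
    p(z) = N(0, 1), p(x | z) = N(w z, s2); converges if learning_rate < 2 / (w**2 / s2 + 1)."""
    z = z_init
    for _ in range(num_steps):
        z += learning_rate * (w * (x - w * z) / s2 - z)   # gradient of the log joint
    return z


def build_proposal(x, w, s2, encoder_mean, encoder_var, overdispersion):
    """Gaussian proposal: MAP mean initialized at the encoder mean, inflated encoder variance."""
    mean = map_latent(x, w, s2, encoder_mean)
    return mean, overdispersion * encoder_var


mean, var = build_proposal(1.3, 2.0, 0.5, encoder_mean=0.2, encoder_var=0.1, overdispersion=1.2)
print(mean, var)
```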
4 Related Work

GANs (goodfellow2014generative) have been extended in multiple ways, using alternative distance metrics and optimization methods (see, e.g., li2015generative; dziugaite2015training; nowozin2016f; arjovsky2017wasserstein; ravuri2018learning; genevay2017learning) or using ideas from VAEs (makhzani2015adversarial; mescheder2017adversarial; dumoulin2016adversarially; donahue2016adversarial; tolstikhin2017wasserstein; ulyanov2018takes; rosca2017variational).
Other extensions aim at improving the sample diversity of GANs. For example, srivastava2017veegan use a reconstructor network that reverses the action of the generator. lin2018pacgan use multiple observations (either real or generated) as an input to the discriminator to prevent mode collapse. azadi2018discriminator and turner2018metropolis use sampling mechanisms to correct errors of the generative distribution. xiao2018bourgan rely on identifying the geometric structure of the data embodied under a specific distance metric. Other works have combined adversarial learning with maximum likelihood (grover2018flow; yin2019semi); however, the low sample quality induced by maximum likelihood still occurs. Finally, cao2018improving introduce a regularizer for the discriminator to encourage diverse activation patterns in the discriminator across different samples. In contrast to these works, PresGANs regularize the entropy of the generator to prevent mode collapse.
The idea of entropy regularization has been widely applied in many problems that involve estimation of unknown probability distributions. Examples include approximate Bayesian inference, where the variational objective contains an entropy penalty (jordan1998learning; bishop2006pattern; wainwright2008graphical; blei2017variational); reinforcement learning, where entropy regularization encourages more uncertain and exploratory policies (schulman2015trust; mnih2016asynchronous); statistical learning, where entropy regularization prevents an inferred probability distribution from collapsing to a deterministic solution (freund1997decision; soofi2000principal; jaynes2003probability); and optimal transport (rigollet2018entropic). More recently, kumar2019maximum developed maximum-entropy generators for energy-based models using mutual information as a proxy for entropy.
Another body of related work is about how to quantitatively evaluate GANs. Inception scores measure the sample quality of GANs and are used extensively in the GAN literature (salimans2016improved; heusel2017gans; binkowski2018demystifying). However, sample quality measures only assess the quality of GANs as data generators and not as density estimators. Density estimators are evaluated for generalization to new data. Predictive log-likelihood is a measure of goodness of fit that has been used to assess generalization, for example in VAEs. Finding ways to evaluate predictive log-likelihood for GANs has been an open problem, because GANs do not define a density on the generated samples. wu2016quantitative use a kernel density estimate (parzen1962estimation) and estimate the log-likelihood with annealed importance sampling (neal2001annealed). balaji2018entropic show that an optimal transport GAN with entropy regularization can be viewed as a generative model that maximizes a variational lower bound on average sample likelihoods, which relates to the approach of VAEs (kingma2013auto). sanchez2019out propose EvalGAN, a method to estimate the likelihood. Given an observation $x^*$, EvalGAN first finds the closest observation $\tilde{x}$ that the GAN is able to generate, and then estimates the likelihood $p(x^*)$ by approximating the proportion of latent samples $z \sim p(z)$ that lead to generated samples close to $\tilde{x}$. EvalGAN requires selecting an appropriate distance metric for each problem and evaluates GANs trained with the usual implicit generative distribution. Finally, grover2018flow assume invertibility of the generator to make log-likelihood tractable.
5 Empirical Study
Here we demonstrate PresGANs' ability to prevent mode collapse and generate high-quality samples. We also evaluate their predictive performance as measured by log-likelihood.
5.1 An Illustrative Example
In this section, we fit a GAN to a toy synthetic dataset with several modes. We choose the hyperparameters such that the GAN collapses. We then apply these same hyperparameters to fit a PresGAN on the same synthetic dataset. This experiment demonstrates the PresGAN's ability to correct the mode collapse problem of a GAN.
We form the target distribution by organizing a uniform mixture of $K$ two-dimensional Gaussians on a ring of radius $r$, where each Gaussian has the same standard deviation. We slice the circle into $K$ parts, and the locations of the centers of the mixture components are determined as follows. Consider the $k$-th mixture component. Its coordinates in the 2D space are

$$\big(x_k, y_k\big) = \left(r \cos\frac{2\pi k}{K},\ r \sin\frac{2\pi k}{K}\right), \qquad k = 1, \ldots, K.$$
We draw samples from the target distribution and fit a GAN and a PresGAN.
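The ring-shaped target can be generated in a few lines. A sketch under stated assumptions: the number of components, radius, and component standard deviation used in the paper's experiment are not reproduced here, so the values in the usage line (8 components, radius 2, standard deviation 0.02) are placeholders, and the helper names are hypothetical.

```python
import math
import random


def ring_centers(num_components, radius):
    """Component means r * (cos(2 pi k / K), sin(2 pi k / K)) for k = 1, ..., K."""
    return [(radius * math.cos(2.0 * math.pi * k / num_components),
             radius * math.sin(2.0 * math.pi * k / num_components))
            for k in range(1, num_components + 1)]


def sample_ring(num_samples, num_components, radius, std, rng):
    """Draw points from a uniform mixture of isotropic 2-D Gaussians centered on the ring."""
    centers = ring_centers(num_components, radius)
    points = []
    for _ in range(num_samples):
        cx, cy = centers[rng.randrange(num_components)]    # uniform mixture weights
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points


# Placeholder settings (not the paper's values): 8 components, radius 2, std 0.02.
samples = sample_ring(1000, 8, 2.0, 0.02, random.Random(0))
```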
We set the dimension of the latent variables used as the input to the generators to . We let both the generators and the discriminators have three fully connected layers with tanh activations and hidden units in each layer. We set the minibatch size to and use Adam for optimization (kingma2014adam), with a learning rate of and for the discriminator and the generator respectively. The Adam hyperparameters are and . We take one step to optimize the generator for each step of the discriminator. We pick a random minibatch at each iteration and run both the GAN and the PresGAN for epochs.
For PresGAN we set the burn-in and the number of HMC samples to . We choose a standard number of leapfrog steps and set the HMC learning rate to . The acceptance rate is fixed at . The log-variance of the noise of the generative distribution of PresGAN is initialized at . We put a threshold on the variance to a minimum value of and a maximum value of . The regularization parameter $\lambda$ is . We fit the log-variance using Adam with a learning rate of .
Figure 1 demonstrates how the PresGAN alleviates mode collapse. The distribution learned by the regular GAN misses modes of the target distribution, whereas the PresGAN recovers all of them.
5.2 Assessing mode collapse
In this section we evaluate PresGANs' ability to mitigate mode collapse on real datasets. We run two sets of experiments. In the first set of experiments we adopt the current experimental protocol for assessing mode collapse in the GAN literature. That is, we use the MNIST and StackedMNIST datasets, for which we know the true number of modes, and report two metrics: the number of modes recovered by the PresGAN and the KL divergence between the label distribution induced by the PresGAN and the true label distribution. In the second set of experiments we demonstrate that mode collapse can happen in GANs even when the number of modes is as low as 10, if the data is imbalanced.
Increased number of modes. We consider the MNIST and StackedMNIST datasets. MNIST is a dataset of hand-written digits,⁵ in which each image corresponds to a digit. There are 60,000 training digits and 10,000 digits in the test set. MNIST has 10 modes, one for each digit. StackedMNIST is formed by concatenating triplets of randomly chosen MNIST digits along the color channel to form images of size 28 × 28 × 3 (Metz2017). We keep the same size as the original MNIST: 60,000 training digits and 10,000 test digits. The total number of modes in StackedMNIST is 1,000, corresponding to the number of possible triplets.

⁵ See http://yann.lecun.com/exdb/mnist.
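The StackedMNIST construction and its 1,000 modes can be sketched as follows. Images are represented as nested Python lists only to keep the example self-contained; the helper name `stack_triplet` and the mode index $100 d_1 + 10 d_2 + d_3$ for the digit triplet $(d_1, d_2, d_3)$ are illustrative assumptions rather than the paper's code.

```python
import itertools
import random


def stack_triplet(images, labels, rng):
    """Stack three randomly chosen grayscale digits as the three color channels of one image.

    `images` is a list of H x W grayscale images (lists of rows) and `labels` their digits.
    Returns an H x W image of (c1, c2, c3) pixel triples and its mode index in {0, ..., 999}.
    """
    i, j, k = (rng.randrange(len(images)) for _ in range(3))
    height, width = len(images[i]), len(images[i][0])
    stacked = [[(images[i][r][c], images[j][r][c], images[k][r][c]) for c in range(width)]
               for r in range(height)]
    mode = 100 * labels[i] + 10 * labels[j] + labels[k]
    return stacked, mode


# Each ordered triplet of digits is a distinct mode: 10 ** 3 = 1000 modes in total.
num_modes = len({100 * a + 10 * b + c for a, b, c in itertools.product(range(10), repeat=3)})
```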
We consider DCGAN as the base architecture and, following radford2015unsupervised, we resize the spatial resolution of images to 64 × 64 pixels.
[Tables: mode-collapse metrics on MNIST and StackedMNIST for the compared methods, including PresGAN (this paper).]
To measure the degree of mode collapse we form two diversity metrics, following srivastava2017veegan. Both of these metrics require fitting a classifier to the training data. Once the classifier has been fit, we sample images from the generator. The first diversity metric is the number of modes captured, measured by the number of classes that are captured by the classifier. We say that a class $k$ has been captured if there is at least one generated sample for which the probability of being assigned to class $k$ is the largest. The second diversity metric is the KL divergence between two discrete distributions: the empirical average of the (soft) output of the classifier on generated images, and the empirical average of the (soft) output of the classifier on real images from the test set. We choose the number of generated images to match the number of test samples on each dataset, that is, 10,000 for both MNIST and StackedMNIST. We expect the KL divergence to be zero if the distribution of the generated samples is indistinguishable from that of the test samples.
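Both diversity metrics can be computed from the classifier's soft outputs alone. A minimal sketch, assuming each output is a list of class probabilities; the KL direction used here, KL(generated ‖ test), is an assumption since the text does not state it, and the small `eps` only guards against taking the log of zero.

```python
import math


def modes_captured(probs):
    """Number of classes that are the argmax for at least one generated sample."""
    return len({max(range(len(p)), key=p.__getitem__) for p in probs})


def average_distribution(probs):
    """Empirical average of the classifier's soft outputs."""
    n = len(probs)
    return [sum(p[c] for p in probs) / n for c in range(len(probs[0]))]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) between discrete distributions; eps guards against log(0)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


generated = [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.8, 0.1, 0.1]]
test = [[0.5, 0.3, 0.2], [0.1, 0.2, 0.7]]
print(modes_captured(generated))  # 2: class 2 is never the most likely class
print(kl_divergence(average_distribution(generated), average_distribution(test)))
```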
We compare the two mode collapse metrics described above against DCGAN (radford2015unsupervised), the base architecture of PresGAN for this experiment. We also compare against other methods that aim at alleviating mode collapse in GANs, namely VEEGAN (srivastava2017veegan) and PacGAN (lin2018pacgan). For PresGAN we set the entropy regularization parameter to . We chose the variance thresholds to be and .
Tables 5.2 and 5.2 show the number of captured modes and the KL divergence for each method. The results are averaged across runs. All methods capture all the modes of MNIST. This is not the case on StackedMNIST, where PresGAN is the only method that captures all the modes. Finally, the proportion of PresGAN samples in each mode is closer to the true proportion in the data, as evidenced by lower KL divergence scores.
We also study the impact of entropy regularization by varying the hyperparameter λ from to . Table 5.2 illustrates the results. Unsurprisingly, when there is no entropy regularization, i.e., when λ = 0, mode collapse occurs. This is also the case when the level of regularization is not strong enough. There is a whole range of values of λ for which mode collapse does not occur. Finally, when λ is too high for the data and architecture under study, mode collapse can still occur: the entropy regularization term then dominates the loss in Eq. 5, and in turn the generator does not fit the data as well. This is also evidenced by the higher KL divergence score for the largest values of λ than for intermediate ones.
Increased data imbalance. We now show that mode collapse can occur in GANs when the data is imbalanced, even when the number of modes of the data distribution is small. We follow dieng2018learning and consider a perfectly balanced version of MNIST as well as nine imbalanced versions. To construct the balanced dataset we used 5,000 training examples per class, totaling 50,000 training examples. We refer to this original balanced dataset as D. Each additional training set leaves only training examples for each class , and for the rest. (See the Appendix for all the class distributions.)
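The imbalanced training sets can be built by subsampling each class of the balanced set. In the sketch below, `make_imbalanced`, the set of reduced classes, and the reduced count `n_small` are hypothetical parameters; the actual per-class counts of each imbalanced set are listed in the appendix (Table 6) and are not reproduced here.

```python
import random
from collections import defaultdict

def make_imbalanced(labels, n_full, n_small, minority_classes, seed=0):
    """Return indices of a class-imbalanced subset of a labeled dataset.

    Classes in minority_classes keep n_small examples; every other class keeps
    n_full. Both counts and the choice of minority classes are hypothetical
    parameters for illustration.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    keep = []
    for y in sorted(by_class):
        n = n_small if y in minority_classes else n_full
        keep.extend(rng.sample(by_class[y], n))  # sample without replacement
    return keep
```

With `minority_classes` empty and `n_full=5000`, this yields the balanced 50,000-example set; growing the minority set produces progressively more imbalanced variants.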
We used the same classifier trained on the unmodified MNIST but fit each method on each of the new MNIST distributions. We chose for PresGAN. Figure 2 illustrates the results in terms of both metrics, number of modes and KL divergence. DCGAN, VEEGAN, and PacGAN face mode collapse as the level of imbalance increases. This is not the case for PresGAN, which is robust to imbalance and captures all the modes.
5.3 Assessing sample quality
In this section we assess the ability of PresGANs to generate samples of high perceptual quality. We rely on the perceptual quality of generated samples and on FID scores (heusel2017gans). We also consider two different GAN architectures, the standard DCGAN and the more recent StyleGAN, to show that PresGANs are robust to the choice of the underlying GAN architecture.
DCGAN. We use DCGAN (radford2015unsupervised) as the base architecture and build PresGAN on top of it. We consider four datasets: MNIST, StackedMNIST, CIFAR-10, and CelebA. CIFAR-10 (krizhevsky2009learning) is a well-studied dataset of 32×32 color images, each classified into one of the following categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. CelebA (liu2015deep) is a large-scale face attributes dataset. Following radford2015unsupervised, we resize all images to 64×64 pixels. We use the default DCGAN settings. We refer the reader to the code we used for DCGAN, which was taken from https://github.com/pytorch/examples/tree/master/dcgan. We set the seed to for reproducibility.
There are hyperparameters specific to PresGAN: the noise and HMC hyperparameters. We set the learning rate for the noise parameters σ to and constrain their values to be between and for all datasets. We initialize σ to . We set the burn-in and the number of HMC samples to . We choose a standard number of leapfrog steps and set the HMC learning rate to . The target acceptance rate is fixed at . We found that different values of λ worked better for different datasets: we used for CIFAR-10 and CelebA and for MNIST and StackedMNIST.
We found the performance of PresGAN to be robust to the choice of most of these hyperparameters. However, the initialization of σ and its learning rate play a role in the quality of the generated samples. The settings given above for σ worked well for all datasets.
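The HMC step that samples the latent variables can be illustrated with a generic sampler. The sketch below is a stand-in built on assumptions, not the PresGAN implementation: it targets any differentiable log-density given as a Python callable (in PresGAN this would be the unnormalized log-posterior of the latent variables given an observation), uses leapfrog integration with a Metropolis correction, and, during burn-in only, multiplies the step size by 1.02 or 0.98 to move the acceptance probability toward `target_accept`. That adaptation rule and the default values are heuristics chosen here.

```python
import math
import random

def hmc(log_prob, grad_log_prob, z0, n_samples, n_burnin, n_leapfrog=5,
        step_size=0.1, target_accept=0.67, seed=0):
    """Minimal Hamiltonian Monte Carlo over a list-of-floats state z.

    Returns (samples, final_step_size). During burn-in, the step size is
    multiplied by 1.02 or 0.98 to push the acceptance probability toward
    target_accept (a heuristic assumed here).
    """
    rng = random.Random(seed)
    z = list(z0)
    samples = []
    for it in range(n_burnin + n_samples):
        p0 = [rng.gauss(0.0, 1.0) for _ in z]  # fresh momentum each iteration
        z_new = list(z)
        # Leapfrog: half momentum step, alternating full steps, final half step.
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p0, grad_log_prob(z_new))]
        for step in range(n_leapfrog):
            z_new = [zi + step_size * pi for zi, pi in zip(z_new, p)]
            scale = step_size if step < n_leapfrog - 1 else 0.5 * step_size
            p = [pi + scale * gi for pi, gi in zip(p, grad_log_prob(z_new))]
        # Metropolis correction on the total energy (potential + kinetic).
        h_old = -log_prob(z) + 0.5 * sum(pi * pi for pi in p0)
        h_new = -log_prob(z_new) + 0.5 * sum(pi * pi for pi in p)
        accept_prob = math.exp(min(0.0, h_old - h_new))
        if rng.random() < accept_prob:
            z = z_new
        if it < n_burnin:
            step_size *= 1.02 if accept_prob > target_accept else 0.98
        else:
            samples.append(list(z))
    return samples, step_size
```

The step size is only adapted during burn-in; after that it is held fixed, so the retained samples come from a valid Markov chain whose stationary distribution is the target.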
Table 5.3 shows the FID scores for DCGAN and PresGAN across the four datasets. These scores suggest that PresGAN generates images of high visual quality. We attribute the lower FID scores of PresGAN to its capturing more modes than DCGAN. Indeed, when the generated images account for more modes, the FID sufficient statistics (the mean and covariance of the Inception-v3 pool3 layer features) of the generated data get closer to those of the empirical data distribution.
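For reference, the FID of heusel2017gans fits a Gaussian to the Inception-v3 pool3 features of the real images (mean μ_r, covariance Σ_r) and of the generated images (μ_g, Σ_g), and measures the Fréchet distance between the two Gaussians. The symbol names below are ours:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

A generator that drops modes can shift μ_g away from μ_r and shrink Σ_g relative to Σ_r, which increases the first and second terms respectively.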
We also report the FID for VEEGAN and PacGAN in Table 5.3. VEEGAN achieves better FID scores than DCGAN on all datasets except CelebA. This is likely because VEEGAN collapses less than DCGAN, as evidenced by Table 5.2 and Table 5.2. PacGAN achieves better FID scores than both DCGAN and VEEGAN on all datasets except StackedMNIST, where it achieves a significantly worse FID score. Finally, PresGAN outperforms all of these methods on the FID metric on all datasets, signaling its ability to mitigate mode collapse while preserving sample quality.
Besides the FID scores, we also assess the visual quality of the generated images. In Section 7.3 of the appendix, we show randomly generated (not cherry-picked) images from DCGAN, VEEGAN, PacGAN, and PresGAN. For PresGAN, we show the mean of the conditional distribution of x given z. The samples generated by PresGAN have high visual quality; in fact, their quality is comparable to or better than that of the DCGAN samples.
StyleGAN. We now consider a more recent GAN architecture, StyleGAN (karras2019style), and a higher-resolution image dataset, FFHQ. FFHQ is a diverse dataset of faces from Flickr (see https://github.com/NVlabs/ffhq-dataset) introduced by karras2019style. The dataset contains 70,000 high-quality PNG images with considerable variation in terms of age, ethnicity, and image background. We use a resolution of pixels.
StyleGAN feeds multiple sources of noise to the generator. In particular, it adds Gaussian noise after each convolutional layer before evaluating the nonlinearity. Building PresGAN on top of StyleGAN therefore requires sampling all noise variables through HMC at each training step. To speed up training, we only sample the noise variables corresponding to the input latent code and condition on all the other Gaussian noise variables. In addition, for simplicity, we do not follow the progressive growing of the networks of karras2019style.
For this experiment, we use the same HMC hyperparameters as in the previous experiments but restrict the variance of the generative distribution to be . We set for this experiment.
Figure 3 shows cherry-picked images generated from StyleGAN and PresGAN. PresGAN maintains perceptual quality as good as that of the base architecture. In addition, we observed that StyleGAN tends to produce some redundant images (these are not shown in Figure 3), something we did not observe with PresGAN. This lack of diversity was also reflected in the FID scores, which were for StyleGAN and for PresGAN. These results suggest that entropy regularization effectively reduces mode collapse while preserving sample quality.
5.4 Assessing held-out predictive log-likelihood
In this section we evaluate the generalization of PresGANs using predictive log-likelihood. We use the DCGAN architecture to build PresGAN and evaluate the log-likelihood on two benchmark datasets, MNIST and CIFAR-10. We use images of size .
We compare the generalization performance of PresGAN against the VAE (kingma2013auto; rezende2014stochastic) by controlling for the architecture and the evaluation procedure. In particular, we fit a VAE that has the same decoder architecture as the PresGAN. We form the VAE encoder by using the same architecture as the DCGAN discriminator and removing the output layer. We use linear maps to obtain the mean and the log-variance of the approximate posterior.
To measure how PresGANs compare to traditional GANs in terms of log-likelihood, we also fit a PresGAN with λ = 0, i.e., without entropy regularization.
Evaluation. We control for the evaluation procedure and follow the procedure described in Section 3.3 for all methods. We use samples to form the importance sampling estimator. Since the pixel values are normalized in , we use a truncated Gaussian likelihood for evaluation. Specifically, for each pixel of the test image, we divide the Gaussian likelihood by the probability (under the generative model) that the pixel is within the interval . We use the truncated Gaussian likelihood at test time only.
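The truncated likelihood can be computed per pixel by subtracting the log of the Gaussian mass inside the valid interval from the Gaussian log-density. In the sketch below, the interval endpoints `low` and `high` are arguments because the normalization range is not restated here; for images scaled to [-1, 1] (an assumption, e.g. with a tanh output layer) one would pass -1 and 1. The per-pixel standard deviations `stds` and the function names are likewise an assumed interface.

```python
import math

def std_normal_cdf(t):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))

def truncated_gaussian_logpdf(x, mean, std, low, high):
    """Log-density of N(mean, std^2) renormalized to the interval [low, high]."""
    log_gauss = -0.5 * math.log(2.0 * math.pi * std ** 2) - (x - mean) ** 2 / (2.0 * std ** 2)
    # Probability that the untruncated Gaussian falls inside [low, high].
    mass = std_normal_cdf((high - mean) / std) - std_normal_cdf((low - mean) / std)
    return log_gauss - math.log(mass)

def image_loglik(pixels, means, stds, low, high):
    """Sum of per-pixel truncated log-likelihoods (pixels independent given z)."""
    return sum(truncated_gaussian_logpdf(x, m, s, low, high)
               for x, m, s in zip(pixels, means, stds))
```

Because the mass is at most one, the truncated log-likelihood is never lower than the untruncated one, and it integrates to one over the interval.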
Settings. For PresGAN, we use the same HMC hyperparameters as in the previous experiments. We constrain the variance of the generative distribution using and . We use the default DCGAN values for the remaining hyperparameters, including the optimization settings. For the CIFAR-10 experiment, we choose . We set all learning rates to . We set the dimension of the latent variables to . We ran both the VAE and the PresGAN for a maximum of epochs. For MNIST, we use the same settings as for CIFAR-10 but use and ran all methods for a maximum of epochs.
Results. Table 5.4 summarizes the results. Here GAN denotes the PresGAN fitted using λ = 0. The VAE outperforms both the GAN and the PresGAN on both MNIST and CIFAR-10. This is unsurprising, given that VAEs are fitted to maximize a lower bound on the log-likelihood. The GAN's performance on CIFAR-10 is particularly poor, suggesting that it suffered from mode collapse. The PresGAN, which mitigates mode collapse, achieves significantly better performance than the GAN on CIFAR-10. To further analyze generalization, we also report the log-likelihood on the training set in Table 5.4. The difference between the training log-likelihood and the test log-likelihood is very small for all methods.
We introduced the PresGAN, a variant of GANs that addresses two of their limitations. PresGANs prevent mode collapse and are amenable to predictive log-likelihood evaluation. PresGANs model data by adding noise to the output of a density network and optimize an entropy-regularized adversarial loss. The added noise stabilizes training, renders approximation of predictive log-likelihoods tractable, and enables unbiased estimators for the gradients of the entropy of the generative distribution. We evaluated PresGANs on several image datasets. We found they effectively prevent mode collapse and generate samples of high perceptual quality. We further found that PresGANs reduce the gap in performance between GANs and VAEs in terms of predictive log-likelihood.
We found that the level of entropy regularization λ plays an important role in mode collapse. We leave as future work the task of finding the optimal λ. We now discuss some insights we drew from our empirical study in Section 5.
Implicit distributions and sample quality. It has traditionally been observed that GANs generate samples with higher perceptual quality than VAEs. To understand why, consider the two ways in which GANs and VAEs differ: the generative distribution and the objective function. VAEs use prescribed generative distributions and optimize (a bound on) the likelihood, whereas GANs use implicit generative distributions and optimize an adversarial loss. Our results in Section 5 suggest that the implicit generators of traditional GANs are not the key to high sample quality; rather, the key is the adversarial loss. The evidence is that PresGANs use the same prescribed generative distributions as VAEs and achieve similar or sometimes better sample quality than GANs.
Mode collapse, diversity, and imbalanced data. The current literature on measuring mode collapse in GANs focuses only on showing that mode collapse happens when the number of modes in the data distribution is high. Our results show that mode collapse can happen not only when the number of modes of the data distribution is high, but also when the data is imbalanced, even if the number of modes is low. Imbalanced data are ubiquitous. Therefore, mitigating mode collapse in GANs is important for diverse data generation.
GANs and generalization. The main method to evaluate generalization for density estimators is predictive log-likelihood. Our results agree with the current literature that GANs do not generalize as well as VAEs, which are specifically trained to maximize (a lower bound on) the log-likelihood. However, our results show that entropy-regularized adversarial learning can reduce the gap in generalization performance between GANs and VAEs. Methods that regularize GANs with the maximum likelihood objective achieve good generalization performance compared with VAEs, but they sacrifice sample quality in doing so (grover2018flow). In fact, we also experienced this tension between sample quality and high log-likelihood in practice.
Why is there such a gap in generalization, as measured by predictive log-likelihood, between GANs and VAEs? In our empirical study in Section 5 we controlled for the architecture and the evaluation procedure, which left us comparing maximum likelihood against adversarial learning. Our results suggest that mode collapse alone does not explain the gap in generalization performance between GANs and VAEs. Indeed, Table 5.4 shows that even on MNIST, where mode collapse does not happen, the VAE achieves a significantly better log-likelihood than a GAN.
We looked more closely at the encoder fitted at test time to evaluate the log-likelihood for both the VAE and the GAN (not shown in this paper). We found that the encoder implied by a fitted GAN is very underdispersed compared with the encoder implied by a fitted VAE. Underdispersed proposals have a negative impact on importance sampling estimates of log-likelihood. We tried to produce a more overdispersed proposal using the procedure described in Section 3.3. However, we leave as future work learning overdispersed proposals for GANs for the purpose of log-likelihood evaluation.
We thank Ian Goodfellow, Andriy Mnih, Aaron Van den Oord, and Laurent Dinh for their comments. Francisco J. R. Ruiz is supported by the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No. 706760. Adji B. Dieng is supported by a Google PhD Fellowship.
7.1 Other Ways to Compute Predictive Log-Likelihood
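Here we discuss different ways to obtain a proposal in order to approximate the predictive log-likelihood. For a test instance $x^*$, we estimate the marginal log-likelihood using importance sampling,
$$\log p_\theta(x^*) \approx \log \frac{1}{S} \sum_{s=1}^{S} \frac{p_\theta(x^* \mid z^{(s)})\, p(z^{(s)})}{r(z^{(s)} \mid x^*)},$$
where we draw the $S$ samples $z^{(1)}, \ldots, z^{(S)}$ from a proposal distribution $r(z \mid x^*)$. We next discuss different ways to form the proposal $r(z \mid x^*)$.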
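The importance sampling estimator can be checked on a toy model whose marginal is available in closed form. The sketch assumes scalar variables, z ~ N(0, 1), x | z ~ N(z, s2), and a Gaussian proposal; none of these choices come from the paper, and the function names are hypothetical. With the exact posterior as the proposal every importance weight equals p(x), so the estimate is exact. With a severely underdispersed proposal (posterior variance shrunk by a factor of 10^6), the estimate typically lands far below the true value.

```python
import math
import random

def log_normal(x, mean, var):
    """Log-density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def log_mean_exp(values):
    """Numerically stable log((1/S) * sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

def is_log_marginal(x, log_joint, q_mean, q_var, n_samples, rng):
    """Importance-sampling estimate of log p(x) with a Gaussian proposal N(q_mean, q_var)."""
    log_w = []
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_w.append(log_joint(x, z) - log_normal(z, q_mean, q_var))
    return log_mean_exp(log_w)

# Toy model (an assumption for illustration): z ~ N(0, 1), x | z ~ N(z, s2).
s2, x_test = 0.5, 0.8

def log_joint(x, z):
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, s2)

exact = log_normal(x_test, 0.0, 1.0 + s2)  # closed-form marginal
post_mean, post_var = x_test / (1 + s2), s2 / (1 + s2)
rng = random.Random(0)
matched = is_log_marginal(x_test, log_joint, post_mean, post_var, 200, rng)
under = is_log_marginal(x_test, log_joint, post_mean, 1e-6 * post_var, 200, rng)
```

The log-mean-exp trick subtracts the largest log weight before exponentiating, which avoids underflow when the log weights are large in magnitude, as they are for high-dimensional images.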
One way to obtain the proposal is to set it to a Gaussian distribution whose mean and variance are computed using samples from an HMC algorithm whose stationary distribution is the posterior of the latent variables given the test instance. That is, the mean and variance of the proposal are set to the empirical mean and variance of the HMC samples.
The procedure above requires running an HMC sampler, and thus it may be slow. We can accelerate it with a better initialization of the HMC chain. Indeed, the second way to evaluate the log-likelihood also uses the HMC sampler, but initializes it with a mapping from the observed space to the latent space. This mapping is a network whose parameters can be learned at test time using generated data. In particular, we first sample pairs of latent variables and observations from the fitted generator of PresGAN, and then fit the network to map each observation to its latent variable by maximum likelihood. Once the mapping is fitted, we use it to initialize the HMC chain.
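When the mapping from observations to latents has a Gaussian likelihood with fixed variance, fitting it by maximum likelihood on generated pairs reduces to least-squares regression. The sketch below assumes a scalar toy generative model, z ~ N(0, 1) and x | z ~ N(z, s2), and a linear mapping z ≈ a·x + b; in PresGAN the mapping is a neural network and the generator is learned, so the names and model here are hypothetical.

```python
import random

def sample_pairs(n, s2, rng):
    """Draw (z, x) pairs from the toy model z ~ N(0, 1), x | z ~ N(z, s2)."""
    pairs = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)
        pairs.append((z, rng.gauss(z, s2 ** 0.5)))
    return pairs

def fit_linear_map(pairs):
    """Least-squares fit of z ~= a * x + b, i.e. Gaussian maximum likelihood."""
    n = len(pairs)
    mean_z = sum(z for z, _ in pairs) / n
    mean_x = sum(x for _, x in pairs) / n
    cov_xz = sum((x - mean_x) * (z - mean_z) for z, x in pairs) / n
    var_x = sum((x - mean_x) ** 2 for _, x in pairs) / n
    a = cov_xz / var_x
    return a, mean_z - a * mean_x

# Fit on generated pairs, then use the prediction to start an HMC chain.
rng = random.Random(0)
a, b = fit_linear_map(sample_pairs(20000, 0.5, rng))
z_init = a * 0.8 + b  # hypothetical starting point for a test point x = 0.8
```

In this linear-Gaussian case the regression recovers the posterior mean map x / (1 + s2), so the chain starts near the bulk of the posterior instead of at an arbitrary point.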
A third way to obtain the proposal is to learn an encoder network jointly with the rest of the PresGAN parameters. This is done by letting the discriminator distinguish between pairs of real data and encoded latent variables and pairs of generated data and the latent variables that produced them, rather than discriminating only between data and samples from the generative distribution. These types of discriminator networks have been used to learn a richer latent space for GANs (donahue2016adversarial; dumoulin2016adversarially). In such cases, we can use the encoder network to define the proposal, either by setting the proposal equal to the encoder distribution or by initializing the HMC sampler at the encoder mean.
The use of an encoder network is appealing, but it requires a discriminator that takes pairs of observations and latent variables. The approach we follow in the paper also uses an encoder network but keeps the discriminator the same as for the base DCGAN. We found this approach to work better in practice. In more detail, we use an encoder network $q(z \mid x^*)$; however, the encoder is fitted at test time by maximizing the variational ELBO, $\mathbb{E}_{q(z \mid x^*)}\left[\log p_\theta(x^* \mid z) + \log p(z) - \log q(z \mid x^*)\right]$. We set the proposal equal to the fitted encoder distribution. (Alternatively, the encoder can be used to initialize a sampler.)
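Fitting the proposal at test time by maximizing the ELBO can be illustrated in a setting where the ELBO has a closed form. The sketch below is a simplification built on assumptions: a scalar toy model z ~ N(0, 1), x | z ~ N(z, s2), and, instead of an encoder network, a separate Gaussian q(z) = N(m, v) for a single test point, optimized by gradient ascent on (m, log v). The function names and learning-rate settings are hypothetical. Because the exact posterior is Gaussian here, the optimum recovers it and the ELBO equals log p(x).

```python
import math

def gaussian_elbo(x, m, v, s2):
    """Closed-form ELBO for q(z) = N(m, v) under z ~ N(0, 1), x | z ~ N(z, s2)."""
    e_loglik = -0.5 * math.log(2 * math.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)
    e_logprior = -0.5 * math.log(2 * math.pi) - (m ** 2 + v) / 2
    entropy = 0.5 * math.log(2 * math.pi * math.e * v)
    return e_loglik + e_logprior + entropy

def fit_proposal(x, s2, lr=0.05, n_steps=2000):
    """Gradient ascent on the ELBO over (m, log v) for a single test point x."""
    m, log_v = 0.0, 0.0
    for _ in range(n_steps):
        v = math.exp(log_v)
        grad_m = (x - m) / s2 - m
        # d ELBO / d v = -1/(2 s2) - 1/2 + 1/(2 v); multiply by v for d / d log v.
        grad_log_v = 0.5 - v / (2 * s2) - v / 2
        m += lr * grad_m
        log_v += lr * grad_log_v
    return m, math.exp(log_v)
```

Optimizing log v rather than v keeps the variance positive without explicit constraints.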
7.2 Assessing mode collapse under increased data imbalance
In the main paper we show that mode collapse can happen not only when the number of modes increases, as studied in the GAN literature, but also when the data is imbalanced. We consider a perfectly balanced version of MNIST using 5,000 training examples per class, totaling 50,000 training examples. We refer to this original balanced dataset as D. We build nine additional training sets from this balanced dataset. Each additional training set D leaves only training examples for each class . See Table 6 for all the class distributions.
7.3 Sample quality
Here we show some sample images generated by DCGAN and PresGAN, together with real images from each dataset. These images were not cherry-picked; we randomly selected samples from all models. For PresGAN, we show the mean of the generator distribution, conditioned on the latent variable z. In general, we observed that the best image quality is achieved by the entropy-regularized PresGAN.