Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect
Despite being impactful on a variety of problems and applications, generative adversarial nets (GANs) are remarkably difficult to train. This issue is formally analyzed by Arjovsky & Bottou (2017), who also propose an alternative direction to avoid the caveats in the minmax two-player training of GANs. The corresponding algorithm, called Wasserstein GAN (WGAN), hinges on the 1-Lipschitz continuity of the discriminator. In this paper, we propose a novel approach for enforcing the Lipschitz continuity in the training procedure of WGANs. Our approach seamlessly connects WGAN with one of the recent semi-supervised learning methods. As a result, it gives rise not only to more photo-realistic samples than the previous methods but also to state-of-the-art semi-supervised learning results. In particular, our approach achieves an inception score of more than 5.0 with only 1,000 CIFAR-10 images and is the first to exceed an accuracy of 90% on the CIFAR-10 dataset using only 4,000 labeled images, to the best of our knowledge.
|Xiang Wei, Boqing Gong† (†Equal contribution), Zixia Liu, Wei Lu, Liqiang Wang|
|Department of Computer Science, University of Central Florida, Orlando, FL, USA 32816|
|School of Software Engineering, Beijing Jiaotong University, Beijing, China 100044|
|Tencent AI Lab, Bellevue, WA, USA 98004|
We have witnessed a great surge of interest in deep generative networks in recent years (Kingma & Welling, 2013; Goodfellow et al., 2014; Li et al., 2015). The central idea therein is to feed a random vector to a (e.g., feedforward) neural network and then take the output as the desired sample. This sampling procedure is very efficient, without the need for any Markov chains.
In order to train such a deep generative network, two broad categories of methods have been proposed. The first is to use stochastic variational inference (Kingma & Welling, 2013; Rezende et al., 2014; Kingma et al., 2014) to optimize a lower bound of the data likelihood. The other is to use the samples as a proxy to minimize the distribution divergence between the model and the real data, through a two-player game (Goodfellow et al., 2014; Salimans et al., 2016), maximum mean discrepancy (Li et al., 2015; Dziugaite et al., 2015; Li et al., 2017b), f-divergence (Nowozin et al., 2016; Nock et al., 2017), or, most recently, the Wasserstein distance (Arjovsky et al., 2017; Gulrajani et al., 2017).
Without a doubt, the generative adversarial networks (GANs) among them (Goodfellow et al., 2014) have had the biggest impact thus far on a variety of problems and applications (Radford et al., 2015; Denton et al., 2015; Im et al., 2016; Isola et al., 2016; Springenberg, 2015; Sutskever et al., 2015; Odena, 2016; Zhu et al., 2017). GANs learn the generative network (generator) by playing a two-player game between the generator and an auxiliary discriminator network. While the generator is no different from other deep generative models in the sense that it translates a random vector into a desired sample, it is impossible to calculate the sample likelihood from it. Instead, the discriminator serves to evaluate the quality of the generated samples by checking how difficult it is to differentiate them from real data points.
However, it is remarkably difficult to train GANs without good heuristics (Goodfellow et al., 2014; Salimans et al., 2016; Radford et al., 2015), which may not generalize across different network architectures or application domains. The training dynamics are often unstable, and the generated samples can collapse to limited modes. These issues are formally analyzed by Arjovsky & Bottou (2017), who also propose an alternative direction (Arjovsky et al., 2017) to avoid the caveats in the minmax two-player training of GANs. The corresponding algorithm, namely Wasserstein GAN (WGAN), shows not only superior performance over GANs but also a nice correlation between the sample quality and the value function, which GANs lack.
1.1 Background: WGAN and the Improved Training of WGAN
WGAN (Arjovsky et al., 2017) aims to learn a generator network $G$ that maps any random vector $z \sim p(z)$ to a sample $G(z)$, such that the Wasserstein distance $W(\mathbb{P}_r, \mathbb{P}_g)$ is minimized between the distribution $\mathbb{P}_g$ of the generated samples and the real distribution $\mathbb{P}_r$ underlying the observed data points; in other words, $\min_G W(\mathbb{P}_r, \mathbb{P}_g)$. The Wasserstein distance is shown to be a more sensible cost function for learning distributions supported by low-dimensional manifolds than the other popular distribution divergences and distances, for example, the Jensen-Shannon (JS) divergence implicitly employed in GANs (Goodfellow et al., 2014).
Due to the Kantorovich-Rubinstein duality (Villani, 2008) for calculating the Wasserstein distance, the value function of WGAN is then written as
$$\min_G \max_{D \in \mathcal{D}} \; \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] - \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})], \qquad (1)$$
where $\mathcal{D}$ is the set of 1-Lipschitz functions. Analogous to GANs, we still call $D$ the “discriminator” although it is actually a real-valued function and not a classifier at all. Arjovsky et al. (2017) specify this family of functions by neural networks and then use weight clipping to enforce the Lipschitz continuity. However, as the authors note, the networks’ capacities become limited due to the weight clipping, and gradient vanishing problems can arise in the training.
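To make the weight-clipping mechanism concrete, here is a minimal NumPy sketch; the clipping threshold $c = 0.01$ is the default reported by Arjovsky et al. (2017), and the parameter layout is purely illustrative:

```python
import numpy as np

def clip_weights(params, c=0.01):
    """Project every weight tensor onto [-c, c] after each update,
    the original WGAN's way of (crudely) bounding the Lipschitz constant."""
    return [np.clip(w, -c, c) for w in params]

# Illustrative parameter list for a tiny two-layer critic.
params = [np.array([[0.5, -0.3], [0.02, -0.8]]), np.array([0.4, -0.002])]
clipped = clip_weights(params)
```

Every weight now lies in $[-0.01, 0.01]$, which caps the critic's capacity as the text notes.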
Improved training of WGAN.
Gulrajani et al. (2017) give more concrete examples to illustrate the perils of the weight clipping and propose an alternative way of imposing the Lipschitz continuity. In particular, they introduce a gradient penalty term by noting that the differentiable discriminator is 1-Lipschitz if and only if the norm of its gradients is at most 1 everywhere,
$$GP|_{\hat{x}} = \mathbb{E}_{\hat{x}}\big[\big(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1\big)^2\big], \qquad (2)$$
where $\hat{x}$ is uniformly sampled from the straight line between a pair of data points sampled from the model distribution $\mathbb{P}_g$ and the real one $\mathbb{P}_r$, respectively. A similar regularization is used by Kodali et al. (2017).
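As a sanity check of the penalty's behavior, the following NumPy sketch evaluates the gradient-penalty term on a toy linear critic $D(x) = w \cdot x$, whose input gradient is $w$ everywhere, so no autodiff is needed; the batch size and dimensions are arbitrary assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear critic D(x) = w . x; its gradient w.r.t. x is the constant w.
w = np.array([3.0, 4.0])
def grad_D(x):
    return w

x_real = rng.normal(size=(8, 2))   # a batch of "real" points
x_fake = rng.normal(size=(8, 2))   # a batch of "generated" points
eps = rng.uniform(size=(8, 1))
x_hat = eps * x_real + (1.0 - eps) * x_fake  # uniform samples on the connecting lines

# Gradient penalty: E[(||grad D(x_hat)||_2 - 1)^2]
gp = np.mean([(np.linalg.norm(grad_D(x)) - 1.0) ** 2 for x in x_hat])
```

Here $\|w\|_2 = 5$, so every sampled point contributes $(5-1)^2 = 16$ to the penalty; a 1-Lipschitz critic would drive this term toward zero.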
Unlike the weight clipping, however, one can by no means penalize everywhere using this term within a finite number of training iterations. As a result, the gradient penalty takes effect only at the sampled points $\hat{x}$, leaving significant parts of the support domain unexamined. In particular, consider the observed data points $x \sim \mathbb{P}_r$ and the underlying manifold that supports the real distribution $\mathbb{P}_r$. At the beginning of training, the generated samples $\tilde{x}$, and hence the sampled points $\hat{x}$, could be distant from this manifold. The Lipschitz continuity over the manifold is therefore not enforced until the generative distribution $\mathbb{P}_g$ becomes close enough to the real one $\mathbb{P}_r$, if it ever does.
1.2 Our Approach and Contributions
In light of the above pros and cons, we propose to improve the improved training of WGAN by additionally laying the Lipschitz continuity condition over the manifold of the real data $\mathbb{P}_r$. Moreover, instead of focusing on one particular data point at a time, we devise a regularization over a pair of data points drawn near the manifold, following the most basic definition of the 1-Lipschitz continuity. In particular, we perturb each real data point $x$ twice and use a Lipschitz constant to bound the difference between the discriminator’s responses to the two perturbed points $x'$ and $x''$.
Figure 1 illustrates our main idea. The gradient penalty often fails to check the continuity of the region near the real data, around which the discriminator function can freely violate the 1-Lipschitz continuity. We alleviate this issue by explicitly checking the continuity condition at the two perturbed versions $x'$ and $x''$ near any observed real data point $x$.
In this paper, we make the following contributions. (1) We propose an alternative mechanism for enforcing the Lipschitz continuity over the family of discriminators by resorting to the basic definition of the Lipschitz continuity. It effectively improves the gradient penalty method (Gulrajani et al., 2017) and gives rise to generators with more photo-realistic samples and higher inception scores (Salimans et al., 2016). (2) Our approach is very data efficient in the sense of being less prone to overfitting, even for very small training sets. We observe no obvious overfitting even when the model is trained on only 1,000 images of CIFAR-10 (Krizhevsky & Hinton, 2009). (3) Our approach can be seamlessly integrated with GANs into a competitive semi-supervised training technique (Chapelle et al., 2009), thanks to the fact that both inject noise into the real data points.
As a result, we are able to report state-of-the-art results for generative models on CIFAR-10 in terms of the inception score, as well as state-of-the-art semi-supervised learning results on CIFAR-10 using only 4,000 labeled images; in particular, the latter are significantly better than all existing GAN-based semi-supervised learning results, to the best of our knowledge.
We first review the definition of the Lipschitz continuity and then discuss how to use it to regularize the training of WGAN. We then arrive at an approach that can be seamlessly integrated with the semi-supervised learning method of Laine & Aila (2016). By bringing together the best of the two worlds, we report better semi-supervised learning results than both (Laine & Aila, 2016) and existing GAN-based methods.
2.1 Improving the improved training of WGAN
Let $d(\cdot, \cdot)$ denote the metric used in this paper, for both the input space and the discriminator’s output space. A discriminator $D$ is Lipschitz continuous if there exists a real constant $M$ such that, for all $x', x''$,
$$d(D(x'), D(x'')) \leq M \, d(x', x''). \qquad (3)$$
Immediately, we can add the following soft consistency term (CT) to the value function of WGAN in order to penalize violations of the inequality in eq. 3,
$$CT|_{x', x''} = \mathbb{E}_{x', x''}\big[\max\big(0, \; d(D(x'), D(x'')) - M' \, d(x', x'')\big)\big]. \qquad (4)$$
Here we face the same snag as in (Gulrajani et al., 2017): it is impractical to substitute all possible pairs $(x', x'')$ into eq. 4. What pairs, and which regions of the input space, should we check for eq. 4? Arguably, it is fairly safe to limit our scope to the manifold that supports the real data distribution $\mathbb{P}_r$ and its surrounding regions, mainly for two reasons. First, we keep the gradient penalty term and improve it with the proposed consistency term in our overall approach. While the former enforces the continuity over the points sampled between the real and generated points, the latter complements the former by focusing on the region around the real data manifold instead. Second, the distribution $\mathbb{P}_g$ of the generative model is by design desired to be as close as possible to $\mathbb{P}_r$. We use the constant $M$ in eq. 3 and a different constant $M'$ in eq. 4 to reflect the fact that the continuity will be checked only sparsely, at finitely many data points, in practice.
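The hinge form of the consistency term in eq. 4 can be sketched in a few lines of NumPy; the metric and the example numbers are illustrative assumptions:

```python
import numpy as np

def consistency_term(d1, d2, input_dist, M_prime=0.2):
    """Soft penalty for violating d(D(x'), D(x'')) <= M' * d'(x', x'').
    d1, d2: critic outputs at the two points; input_dist: their input distance."""
    return float(np.maximum(0.0, np.abs(d1 - d2) - M_prime * input_dist))

# The critic answers 0.9 and 0.1 for two points at input distance 1.0:
# violation = |0.9 - 0.1| - 0.2 * 1.0 = 0.6
ct = consistency_term(0.9, 0.1, 1.0)
```

Pairs that already satisfy the Lipschitz bound contribute zero, so only violations are penalized.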
Perturbing the real data.
To this end, the very first version we tried was to directly add Gaussian noise to each real data point $x$, resulting in a pair of perturbed points $x'$ and $x''$. However, as noted by Arjovsky et al. (2017) and Wu et al. (2016), we found that the samples from the generator become blurry due to the Gaussian noise used in the training. We also tested dropout noise applied to the input $x$ and found that the resulting MNIST samples are cut off here and there.
The success comes after we perturb the hidden layers of the discriminator using dropout, as opposed to the input $x$. When the dropout rate is small, the perturbed discriminator’s output can be considered as the output of the clean discriminator in response to a “virtual” data point $x'$ that is not far from $x$. Thus, we denote by $D(x')$ the discriminator output after applying dropout to its hidden layers. In the same manner, we find a second virtual point $x''$ around $x$ by applying the (stochastic) dropout to the hidden layers of the discriminator again, and denote by $D(x'')$ the corresponding output.
Note, however, that it becomes impossible to compute the distance $d(x', x'')$ between the two virtual data points. In this work, we assume it is bounded by a constant and absorb that constant into $M'$. Accordingly, we tune $M'$ in our experiments to account for this unknown constant; the best results are obtained with $M'$ between 0 and 0.2. For consistency, we use a single value of $M'$ to report all the results in this paper.
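The two “virtual” points are obtained simply by running the same input through the dropout-perturbed critic twice. A minimal one-hidden-layer sketch follows; the architecture and dropout rate are placeholders, not the networks used in our experiments:

```python
import numpy as np

rng = np.random.default_rng(1)

def critic_with_dropout(x, W, rate=0.1):
    """One ReLU hidden layer with (inverted) dropout; each call returns a
    stochastic output, read as the clean critic's response D(x') to a
    virtual point x' near x."""
    h = np.maximum(0.0, W @ x)          # hidden activations
    mask = rng.random(h.shape) > rate   # fresh dropout mask per call
    return float((h * mask / (1.0 - rate)).sum())

W = rng.normal(size=(16, 4))
x = rng.normal(size=4)
d_x1 = critic_with_dropout(x, W)  # D(x')
d_x2 = critic_with_dropout(x, W)  # D(x''), a second stochastic pass
```

Two passes over the same $x$ generally differ because each draws a fresh dropout mask, which is exactly the pair of responses the consistency term compares.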
A consistency regularization.
Our final consistency regularization takes the following form,
$$CT|_{x', x''} = \mathbb{E}_{x \sim \mathbb{P}_r}\big[\max\big(0, \; d(D(x'), D(x'')) + 0.1 \, d(D_{-}(x'), D_{-}(x'')) - M'\big)\big], \qquad (5)$$
where $D(x')$ is the output of the discriminator given the input $x$ after we apply dropout to its hidden layers. We envision this as equivalent to passing a “virtual” data point $x'$ through the clean discriminator. We find that it slightly improves the performance to further control the second-to-last layer of the discriminator, i.e., the term $d(D_{-}(x'), D_{-}(x''))$ above, where $D_{-}(\cdot)$ denotes the second-to-last layer’s activations.
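A sketch of the consistency regularization just described, combining the output difference with a second-to-last-layer term; the 0.1 feature weighting, the Euclidean metric, and the example numbers are illustrative assumptions:

```python
import numpy as np

def ct_with_features(d1, d2, f1, f2, M_prime=0.2, feat_weight=0.1):
    """max(0, |D(x') - D(x'')| + feat_weight * ||D_(x') - D_(x'')|| - M'),
    where f1, f2 are second-to-last-layer activations of the two dropout passes."""
    gap = np.abs(d1 - d2) + feat_weight * np.linalg.norm(f1 - f2)
    return float(max(0.0, gap - M_prime))

# |0.5 - 0.1| + 0.1 * ||(1, 0) - (0, 0)|| - 0.2 = 0.4 + 0.1 - 0.2 = 0.3
val = ct_with_features(0.5, 0.1, np.array([1.0, 0.0]), np.array([0.0, 0.0]))
```

Controlling the penultimate features as well as the scalar output makes the critic smooth in its representation, not just its final score.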
This new consistency regularization enforces the Lipschitz continuity over the data manifold and its surrounding regions, effectively complementing and improving the gradient penalty used in the improved training of WGAN. Putting them together, our new objective function for updating the weights of the discriminator is
$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda_1 \, GP|_{\hat{x}} + \lambda_2 \, CT|_{x', x''}. \qquad (6)$$
Algorithm 1 shows the complete algorithm for learning a WGAN in this paper. For the hyper-parameters, we borrow $\lambda_1 = 10$ from Gulrajani et al. (2017) and keep the weights fixed across all our experiments, no matter which dataset. The other hyper-parameter is $M'$ in eq. 5. As stated previously, taking values between 0 and 0.2 gives rise to about the same results in our experiments.
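The overall discriminator objective above can be sketched as follows; $\lambda_1 = 10$ follows Gulrajani et al. (2017), while $\lambda_2 = 2$ is an illustrative placeholder, and all inputs are per-batch quantities:

```python
import numpy as np

def critic_loss(d_fake, d_real, gp, ct, lam1=10.0, lam2=2.0):
    """Wasserstein terms plus gradient penalty and consistency term.
    lam2 is an assumed value for illustration, not a reported hyper-parameter."""
    return float(np.mean(d_fake) - np.mean(d_real) + lam1 * gp + lam2 * ct)

# 1.0 - 0.0 + 10 * 0.5 + 2 * 0.1 = 6.2
loss = critic_loss(np.array([1.0, 1.0]), np.array([0.0, 0.0]), gp=0.5, ct=0.1)
```

Minimizing this loss pushes the critic to separate real from generated samples while both penalties keep it (approximately) 1-Lipschitz.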
2.2 A seamless connection with a semi-supervised learning method
In this section, we extend the WGAN to a semi-supervised learning approach by drawing insights from two related works.
Following (Salimans et al., 2016), we modify the output layer of the discriminator such that it has $K+1$ output neurons, where $K$ is the number of classes of interest and the $(K+1)$-th neuron is reserved for contrasting the generated samples with the real data using the Wasserstein distance in the WGAN. We use a $(K+1)$-way softmax as the activation function of the last layer.
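The $(K+1)$-way head can be sketched as follows; with $K = 10$ classes as in CIFAR-10, the extra neuron is the real-vs-generated output, and the zero logits are placeholders:

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

K = 10                        # classes of interest
logits = np.zeros(K + 1)      # K class neurons + 1 neuron for real vs. generated
probs = softmax(logits)       # (K+1)-way softmax over the last layer
```

With all-zero logits the head outputs a uniform distribution over the $K+1$ neurons, confirming the shapes line up.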
Following (Laine & Aila, 2016), we add our consistency regularization to the objective function of the semi-supervised training, in order to take advantage of the effect of temporal ensembling.
Figure 2 shows the framework for training the discriminator; its objective function (eq. 7) consists of the first three terms of (Salimans et al., 2016) plus our consistency regularization, which is calculated after we apply dropout to the discriminator. The last term essentially leads to a temporal self-ensembling scheme that benefits the semi-supervised learning. Please see (Laine & Aila, 2016) for more insightful discussions of it.
The generator loss matches the expected features of the generated samples and the real data points,
$$L_G = \Big\| \mathbb{E}_{x \sim \mathbb{P}_r} \mathbf{f}(x) - \mathbb{E}_{z \sim p(z)} \mathbf{f}(G(z)) \Big\|_2^2,$$
where $\mathbf{f}(\cdot)$ denotes the activations of an intermediate layer of the discriminator (Salimans et al., 2016).
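Concretely, feature matching compares batch means of an intermediate discriminator layer; a NumPy sketch, where random features stand in for real activations:

```python
import numpy as np

rng = np.random.default_rng(2)

def feature_matching_loss(feat_real, feat_fake):
    """|| E[f(x)] - E[f(G(z))] ||_2^2, estimated over mini-batches of
    intermediate discriminator features f(.)."""
    return float(np.sum((feat_real.mean(axis=0) - feat_fake.mean(axis=0)) ** 2))

feat_real = rng.normal(size=(64, 128))   # stand-in features of real images
feat_fake = rng.normal(size=(64, 128))   # stand-in features of generated images
loss = feature_matching_loss(feat_real, feat_fake)
```

The loss is zero exactly when the batch-mean features coincide, which is the generator's target under this criterion.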
3 Experimental results
We conduct experiments on the prevalent MNIST (LeCun et al., 1998) and CIFAR-10 (Krizhevsky & Hinton, 2009) datasets. The code is available at https://github.com/biuyq/CT-GAN to facilitate the reproducibility of our results.
The MNIST dataset provides 70,000 handwritten digits in total, 10,000 of which are typically held out for testing. Following (Gulrajani et al., 2017), we use only 1,000 of the training examples to train the WGAN for a fair comparison with it. We use all 60,000 training examples in the semi-supervised learning experiments and reveal the labels of 100 of them (10 per class). This setup is the same as in (Rasmus et al., 2015). No data augmentation is used. Please see Appendix A for the network architectures of the generator and discriminator, respectively.
Figure 3 shows the generated samples with improved training of WGAN by the gradient penalty (GP-WGAN) and ours with the consistency regularization (CT-GAN), respectively, after 50,000 generator iterations. It is clear that our approach gives rise to more realistic samples than GP-WGAN. The contrasts of our samples between the foreground and the background are in general sharper than those of GP-WGAN.
We find that our approach is less prone to overfitting. To demonstrate this point, we show the convergence curves of the discriminator’s value functions by GP-WGAN and our CT-GAN in Figure 4. The red curves are evaluated on the training set and the blue ones are on the test set. We can see that the results on the test set become saturated pretty early in GP-WGAN, while ours can consistently decrease the costs on both the training and the test sets. This observation also holds for the CIFAR-10 dataset (cf. Appendix E).
Semi-supervised learning results.
We compare our semi-supervised learning results with those of several competitive methods in Table 1. Our approach is among the best on this MNIST dataset.
CIFAR-10 (Krizhevsky & Hinton, 2009) contains 50,000 natural training images of size 32×32. We use it to test two networks for the generative model: a small CNN and a ResNet (cf. Appendix A for the network structures). For the former, we use only 1,000 images to train the model; we use the whole training set to learn the ResNet.
Additionally, we also draw the histograms of the discriminator’s weights in Figure 6 after training it using GP-WGAN and CT-GAN, respectively. It is interesting to see that ours controls the weights within a smaller and more symmetric range than those by GP-WGAN, partially explaining why our approach is less prone to overfitting.
| Method | Supervised IS | Unsupervised IS | Accuracy (%) |
| --- | --- | --- | --- |
| SteinGANs (Wang & Liu, 2016) | 6.35 | – | – |
| DCGANs (Radford et al., 2015) | 6.58 | – | |
| Improved GANs (Salimans et al., 2016) | – | – | |
| AC-GANs (Odena et al., 2016) | – | – | |
| GP-WGAN (Gulrajani et al., 2017) | 91.85 | | |
| SGANs (Huang et al., 2016) | – | – | |
| ALI (Dumoulin et al., 2016) | – | – | |
| BEGAN (Berthelot et al., 2017) | – | 5.62 | – |
| EGAN-Ent-VI (Dai et al., 2017) | – | – | |
| DFM (Warde-Farley & Bengio, 2016) | – | – | |
Comparison of the inception scores.
Finally, we compare our approach with GP-WGAN on the whole training set for both the unsupervised and supervised generation tasks using ResNet. For model selection, we use the first 50,000 samples to compute the inception scores (Salimans et al., 2016), then choose the best model, and finally report the “test” score on another 50,000 samples. The experiment follows the setup of (Odena et al., 2016). From the comparison results in Table 2, we conclude that our proposed model achieves the highest inception score on the CIFAR-10 dataset, to the best of our knowledge. Some generated samples are shown in Figure 7.
For the small CNN based generator, our CT-GAN achieves a higher inception score than GP-WGAN; notably, it exceeds 5.0 despite being trained on only 1,000 CIFAR-10 images.
| Method | Test error (%) |
| --- | --- |
| Ladder (Rasmus et al., 2015) | |
| VAT (Miyato et al., 2017) | 10.55 |
| TE (Laine & Aila, 2016) | |
| Teacher-Student (Tarvainen & Valpola, 2017) | |
| CatGANs (Springenberg, 2015) | |
| Improved GANs (Salimans et al., 2016) | |
| ALI (Dumoulin et al., 2016) | |
| CLS-GAN (Qi, 2017) | |
| Triple GAN (Li et al., 2017a) | |
| Improved semi-GAN (Kumar et al., 2017) | |
For the semi-supervised learning approach, we follow the standard training/test split of the dataset but use only 4,000 labels in training. A regular data augmentation, flipping the images horizontally and randomly translating them within [-2, 2] pixels, is used in our paper (no ZCA whitening). We report the semi-supervised learning results in Table 3. The means and standard errors are obtained by running the experiments for five rounds. Compared with several very competitive methods, ours achieves state-of-the-art results. Notably, our CT-GAN outperforms all the GAN-based methods by a large margin. Please see Appendix A for the network architectures and Appendix C for the ablation study of our algorithm.
In this paper, we present a consistency term derived from the Lipschitz inequality, which boosts the performance of GAN models. The proposed term is demonstrated to be an effective way to ease the overfitting problem when the amount of data is limited. Experiments show that our model obtains state-of-the-art accuracy and inception scores on the CIFAR-10 dataset for both the semi-supervised learning task and the learning of generative models.
This work is partially supported by NSF IIS-1741431, IIS-1566511, and ONR N00014-18-1-2121. B.G. and L.W. would also like to thank Adobe Research and NVIDIA, and Amazon, respectively, for their gift donations.
- Arjovsky & Bottou (2017) Martin Arjovsky and Léon Bottou. Towards principled methods for training generative adversarial networks. arXiv preprint arXiv:1701.04862, 2017.
- Arjovsky et al. (2017) Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein GAN. arXiv preprint arXiv:1701.07875, 2017.
- Berthelot et al. (2017) David Berthelot, Tom Schumm, and Luke Metz. BEGAN: Boundary equilibrium generative adversarial networks. arXiv preprint arXiv:1703.10717, 2017.
- Chapelle et al. (2009) Olivier Chapelle, Bernhard Schölkopf, and Alexander Zien. Semi-supervised learning. IEEE Transactions on Neural Networks, 20(3):542–542, 2009.
- Dai et al. (2017) Zihang Dai, Amjad Almahairi, Philip Bachman, Eduard Hovy, and Aaron Courville. Calibrating energy-based generative adversarial networks. arXiv preprint arXiv:1702.01691, 2017.
- Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on, pp. 248–255. IEEE, 2009.
- Denton et al. (2015) Emily L Denton, Soumith Chintala, Rob Fergus, et al. Deep generative image models using a laplacian pyramid of adversarial networks. In Advances in neural information processing systems, pp. 1486–1494, 2015.
- Dumoulin et al. (2016) Vincent Dumoulin, Ishmael Belghazi, Ben Poole, Alex Lamb, Martin Arjovsky, Olivier Mastropietro, and Aaron Courville. Adversarially learned inference. arXiv preprint arXiv:1606.00704, 2016.
- Dziugaite et al. (2015) Gintare Karolina Dziugaite, Daniel M Roy, and Zoubin Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. arXiv preprint arXiv:1505.03906, 2015.
- Goodfellow et al. (2014) Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in neural information processing systems, pp. 2672–2680, 2014.
- Gulrajani et al. (2017) Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville. Improved training of Wasserstein GANs. arXiv preprint arXiv:1704.00028, 2017.
- Huang et al. (2016) Xun Huang, Yixuan Li, Omid Poursaeed, John Hopcroft, and Serge Belongie. Stacked generative adversarial networks. arXiv preprint arXiv:1612.04357, 2016.
- Im et al. (2016) Daniel Jiwoong Im, Chris Dongjoo Kim, Hui Jiang, and Roland Memisevic. Generating images with recurrent adversarial networks. arXiv preprint arXiv:1602.05110, 2016.
- Isola et al. (2016) Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, and Alexei A Efros. Image-to-image translation with conditional adversarial networks. arXiv preprint arXiv:1611.07004, 2016.
- Kingma & Welling (2013) Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
- Kingma et al. (2014) Diederik P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pp. 3581–3589, 2014.
- Kodali et al. (2017) Naveen Kodali, Jacob Abernethy, James Hays, and Zsolt Kira. On convergence and stability of gans. arXiv preprint arXiv:1705.07215, 2017.
- Krizhevsky & Hinton (2009) Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. 2009.
- Kumar et al. (2017) Abhishek Kumar, Prasanna Sattigeri, and P Thomas Fletcher. Improved semi-supervised learning with gans using manifold invariances. arXiv preprint arXiv:1705.08850, 2017.
- Laine & Aila (2016) Samuli Laine and Timo Aila. Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242, 2016.
- LeCun et al. (1998) Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
- Li et al. (2017a) Chongxuan Li, Kun Xu, Jun Zhu, and Bo Zhang. Triple generative adversarial nets. arXiv preprint arXiv:1703.02291, 2017a.
- Li et al. (2017b) Chun-Liang Li, Wei-Cheng Chang, Yu Cheng, Yiming Yang, and Barnabás Póczos. MMD GAN: Towards deeper understanding of moment matching network. arXiv preprint arXiv:1705.08584, 2017b.
- Li et al. (2015) Yujia Li, Kevin Swersky, and Rich Zemel. Generative moment matching networks. In Proceedings of the 32nd International Conference on Machine Learning (ICML-15), pp. 1718–1727, 2015.
- Miyato et al. (2017) Takeru Miyato, Shin-ichi Maeda, Masanori Koyama, and Shin Ishii. Virtual adversarial training: a regularization method for supervised and semi-supervised learning. arXiv preprint arXiv:1704.03976, 2017.
- Nock et al. (2017) Richard Nock, Zac Cranko, Aditya Krishna Menon, Lizhen Qu, and Robert C Williamson. f-GANs in an information geometric nutshell. arXiv preprint arXiv:1707.04385, 2017.
- Nowozin et al. (2016) Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems, pp. 271–279, 2016.
- Odena (2016) Augustus Odena. Semi-supervised learning with generative adversarial networks. arXiv preprint arXiv:1606.01583, 2016.
- Odena et al. (2016) Augustus Odena, Christopher Olah, and Jonathon Shlens. Conditional image synthesis with auxiliary classifier gans. arXiv preprint arXiv:1610.09585, 2016.
- Qi (2017) Guo-Jun Qi. Loss-sensitive generative adversarial networks on lipschitz densities. arXiv preprint arXiv:1701.06264, 2017.
- Radford et al. (2015) Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434, 2015.
- Rasmus et al. (2015) Antti Rasmus, Mathias Berglund, Mikko Honkala, Harri Valpola, and Tapani Raiko. Semi-supervised learning with ladder networks. In Advances in Neural Information Processing Systems, pp. 3546–3554, 2015.
- Rezende et al. (2014) Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. arXiv preprint arXiv:1401.4082, 2014.
- Salimans et al. (2016) Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training gans. In Advances in Neural Information Processing Systems, pp. 2234–2242, 2016.
- Springenberg (2015) Jost Tobias Springenberg. Unsupervised and semi-supervised learning with categorical generative adversarial networks. arXiv preprint arXiv:1511.06390, 2015.
- Sutskever et al. (2015) Ilya Sutskever, Rafal Jozefowicz, Karol Gregor, Danilo Rezende, Tim Lillicrap, and Oriol Vinyals. Towards principled unsupervised learning. arXiv preprint arXiv:1511.06440, 2015.
- Tarvainen & Valpola (2017) Antti Tarvainen and Harri Valpola. Weight-averaged consistency targets improve semi-supervised deep learning results. arXiv preprint arXiv:1703.01780, 2017.
- Villani (2008) Cédric Villani. Optimal transport: old and new, volume 338. Springer Science & Business Media, 2008.
- Wang & Liu (2016) Dilin Wang and Qiang Liu. Learning to draw samples: With application to amortized mle for generative adversarial learning. arXiv preprint arXiv:1611.01722, 2016.
- Warde-Farley & Bengio (2016) David Warde-Farley and Yoshua Bengio. Improving generative adversarial networks with denoising feature matching. 2016.
- Wu et al. (2016) Yuhuai Wu, Yuri Burda, Ruslan Salakhutdinov, and Roger Grosse. On the quantitative analysis of decoder-based generative models. arXiv preprint arXiv:1611.04273, 2016.
- Yu et al. (2015) Fisher Yu, Ari Seff, Yinda Zhang, Shuran Song, Thomas Funkhouser, and Jianxiong Xiao. LSUN: Construction of a large-scale image dataset using deep learning with humans in the loop. arXiv preprint arXiv:1506.03365, 2015.
- Zhu et al. (2017) Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A Efros. Unpaired image-to-image translation using cycle-consistent adversarial networks. arXiv preprint arXiv:1703.10593, 2017.
Appendix A Network architectures
Table 4 (MNIST) and Table 5 (CIFAR-10) detail the network architectures used in our classification-purpose CT-GAN. The classifiers are the same as the ones widely used in most semi-supervised networks (Laine & Aila, 2016), except that we apply weight normalization rather than batch normalization. For the generators, we follow the network structures in GP-WGAN (Gulrajani et al., 2017), but we use lower-dimensional noise (50D) as the input to the generator for CIFAR-10, so that the training does not focus on reproducing the complicated images and instead shifts to the classifier.
|Classifier C||Generator G|
|Input: Labels , 28*28 Images ,||Input: Noise 100|
|Gaussian noise 0.3, MLP 1000, ReLU||MLP 500, Softplus, Batch norm|
|Gaussian noise 0.5, MLP 500, ReLU||MLP 500, Softplus, Batch norm|
|Gaussian noise 0.5, MLP 250, ReLU||MLP 784, Sigmoid, Weight norm|
|Gaussian noise 0.5, MLP 250, ReLU|
|Gaussian noise 0.5, MLP 250, ReLU|
|Gaussian noise 0.5, MLP 10, Softmax|
|Classifier C||Generator G|
|Input: Labels , 32*32*3 Colored Image ,||Input: Noise 50|
|0.2 Dropout||MLP 8192, ReLU, Batch norm|
|3*3 conv. 128, Pad =1, Stride =1, lReLU, Weight norm||Reshape 512*4*4|
|3*3 conv. 128, Pad =1, Stride =1, lReLU, Weight norm||5*5 deconv. 256*8*8,|
|3*3 conv. 128, Pad =1, Stride =2, lReLU, Weight norm||ReLU, Batch norm|
|3*3 conv. 256, Pad =1, Stride =1, lReLU, Weight norm|
|3*3 conv. 256, Pad =1, Stride =1, lReLU, Weight norm||5*5 deconv. 128*16*16,|
|3*3 conv. 256, Pad =1, Stride =2, lReLU, Weight norm||ReLU, Batch norm|
|3*3 conv. 512, Pad =0, Stride =1, lReLU, Weight norm|
|3*3 conv. 256, Pad =0, Stride =1, lReLU, Weight norm||5*5 deconv. 3*32*32,|
|3*3 conv. 128, Pad =0, Stride =1, lReLU, Weight norm||Tanh, Weight norm|
|MLP 10, Weight norm, Softmax|
Appendix B Hyper-parameters and other training details
For the semi-supervised learning experiments, we keep the weight of the consistency regularization in Eq. (7) fixed in all our experiments. For CIFAR-10, the number of training epochs is set to 1,000 with a constant learning rate of 0.0003. For MNIST, the number of training epochs is set to 300 with a constant learning rate of 0.003. The other hyper-parameters are exactly the same as in the improved GAN (Salimans et al., 2016).
For the experiments on the generative models, to ensure a fair comparison with the existing results, we keep the network structures and hyper-parameters the same as in the improved training of WGAN (Gulrajani et al., 2017), except that we add three dropout layers to some of the hidden layers, as shown in Tables 6, 7, and 8.
|Input: 1*28*28 Image||Input: Noise 128|
|5*5 conv. 64, Pad = same, Stride = 2, lReLU||MLP 4096, ReLU|
|0.5 Dropout||Reshape 256*4*4|
|5*5 conv. 128, Pad = same, Stride = 2, lReLU||5*5 deconv. 128*8*8|
|0.5 Dropout||ReLU, Cut 128*7*7|
|5*5 conv. 256, Pad = same, Stride = 2, lReLU||5*5 deconv. 64*14*14|
|Reshape 256*4*4 (D_)||5*5 deconv. 1*28*28|
|MLP 1 (D)||Sigmoid|
|Input: 3*32*32 Image ,||Input: Noise 128|
|5*5 conv. 128, Pad = same, Stride = 2, lReLU||MLP 8192, ReLU, Batch norm|
|0.5 Dropout||Reshape 512*4*4|
|5*5 conv. 256, Pad = same, Stride = 2, lReLU||5*5 deconv. 256*8*8|
|0.5 Dropout||ReLU, Batch norm|
|5*5 conv. 512, Pad = same, Stride = 2, lReLU||5*5 deconv. 128*16*16|
|0.5 Dropout||ReLU, Batch norm|
|Reshape 512*4*4 (D_)||5*5 deconv. 3*32*32|
|MLP 1 (D)||Tanh|
|Input: 3*32*32 Image||Input: Noise 128|
|[3*3]*2 Residual Block, Resample = DOWN||MLP 2048|
|[3*3]*2 Residual Block, Resample = DOWN||[3*3]*2 Residual Block, Resample = UP|
|128*8*8 0.2 Dropout||128*8*8|
|[3*3]*2 Residual Block, Resample = None||[3*3]*2 Residual Block, Resample = UP|
|128*8*8 0.5 Dropout||128*16*16|
|[3*3]*2 Residual Block, Resample = None||[3*3]*2 Residual Block, Resample = UP|
|128*8*8 0.5 Dropout||128*32*32|
|ReLU, Global mean pool (D_)||3*3 conv. 3*32*32|
|MLP 1 (D)||Tanh|
Appendix C Ablation study of our approach to SSL
We run an ablation study of our approach to semi-supervised learning (SSL). Thanks to the dual effect of the proposed consistency term (CT), we are able to connect GAN with the temporal ensembling (TE) method (Laine & Aila, 2016) for SSL. Our superior results thus benefit from both of them, as verified by the ablation study detailed in Table 9.
If we remove the CT term, the test error goes up to , signifying the effectiveness of the CT regularization.
If we remove GAN from our approach, it almost reduces to TE; in fact, all the settings here are the same as TE except that we use an extra regularization ( in CT) over the second-to-last layer. We can see that the error is still significantly larger than that of our overall method.
We use weight normalization as in Salimans et al. (2016), which is a core constituent of our approach. Batch normalization would actually invalidate the feature matching in Salimans et al. (2016).
Finally, we also test the version without the regularization over the second-to-last layer and observe a slight drop in performance.
|w/ batch norm||–|
|w/o regularization over the second-to-last layer||10.70±0.24|
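To make the dual role of the consistency term concrete, the sketch below computes a CT-style penalty in NumPy: two dropout-perturbed passes of a discriminator on the same real input, with the disagreement of both the scalar output and the second-to-last-layer features penalized above a margin. The toy discriminator, the fixed weights, and the constants are all illustrative stand-ins, not our trained networks.

```python
import numpy as np

rng = np.random.default_rng(1)

def noisy_discriminator(x):
    """Stand-in for one dropout-perturbed forward pass: returns the
    second-to-last-layer features and the scalar critic output, both
    randomized by the dropout mask."""
    w = np.linspace(-1.0, 1.0, x.shape[-1])       # fixed toy weights
    mask = (rng.random(x.shape) >= 0.5) * 2.0     # inverted dropout, rate 0.5
    feat = x * mask                               # "second-to-last layer"
    return feat, feat @ w

def ct_penalty(x, m_prime=0.0, lam2=0.1):
    """Consistency-style term: penalize disagreement between two
    perturbed passes of the discriminator on the same real batch x."""
    f1, d1 = noisy_discriminator(x)
    f2, d2 = noisy_discriminator(x)
    dist = np.abs(d1 - d2) + lam2 * np.linalg.norm(f1 - f2, axis=-1)
    return np.maximum(0.0, dist - m_prime).mean()
```

Dropping the `lam2` term corresponds to the "w/o regularization over the second-to-last layer" row of the ablation.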
Appendix D Examining the 1-Lipschitz continuity
Norm of gradient.
In our experiments, we find that although GP-WGAN (Gulrajani et al., 2017) applies a Lipschitz constraint in the form of a gradient penalty over inputs sampled between a real data point and a generated one, its actual effect on the norm of the gradient at the real data points is not as good as that of our CT-GAN model, as illustrated in Figure 1. We empirically verify this fact in Figure 8, which shows the norms of the gradients of the discriminator with respect to the real data points. The closer the norms are to 1, the better the 1-Lipschitz continuity is preserved. Figure 8 further demonstrates that our consistency (CT) regularization is able to improve GP-WGAN (Gulrajani et al., 2017).
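Networks aside, the quantity plotted in Figure 8 is simply the gradient norm of the critic at real samples. A minimal numerical sketch, assuming a toy critic (the real experiment differentiates the trained discriminator network instead):

```python
import numpy as np

def grad_norm(D, x, eps=1e-5):
    """Estimate ||grad_x D(x)|| by central finite differences."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e.flat[i] = eps
        g.flat[i] = (D(x + e) - D(x - e)) / (2 * eps)
    return np.linalg.norm(g)

# Illustrative critic: D(x) = mean(x) has gradient 1/n per coordinate,
# so its gradient norm at any x is 1/sqrt(n), well below 1 for n = 16.
D = lambda x: x.mean()
x_real = np.ones(16)
norm = grad_norm(D, x_real)
```

For a 1-Lipschitz critic this norm is bounded by 1 everywhere; values close to 1 at the real data indicate the constraint is tight there.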
Definition of the Lipschitz continuity.
Additionally, we check how well the 1-Lipschitz continuity is satisfied according to its vanilla definition (cf. Eq. (3)). Figures 9 and 10 plot the CTs of Eq. (4) and Eq. (5), respectively, over different iterations of training using 1,000 CIFAR-10 images. Figure 10 shows the actual CTs used to train the generative model. Figure 9 is drawn as follows. Every 100 iterations, we randomly pick 64 real examples and split them into two subsets of the same size. We compute the ratio in the Lipschitz definition for all the cross pairs, where one example is taken from the first subset and the other from the second, and plot the maximum of these ratios in Figure 9. We can see that the CT-GAN curve converges to below a certain value much faster than GP-WGAN's. The 1-Lipschitz continuity is thus better maintained by CT-GAN than by GP-WGAN over the whole course of training on the manifold of the real data.
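The procedure for Figure 9 can be sketched as follows; the toy critic and random batch are illustrative placeholders for the trained discriminator and the 64 sampled CIFAR-10 images.

```python
import numpy as np

rng = np.random.default_rng(2)

def max_lipschitz_ratio(D, batch):
    """Empirical check of the vanilla Lipschitz definition: split the
    batch into two halves and take the maximum of
    |D(x) - D(x')| / ||x - x'|| over all cross pairs."""
    half = len(batch) // 2
    first, second = batch[:half], batch[half:]
    best = 0.0
    for x in first:
        for xp in second:
            denom = np.linalg.norm(x - xp)
            if denom > 0:
                best = max(best, abs(D(x) - D(xp)) / denom)
    return best

# Toy critic D(x) = sum(x) has Lipschitz constant sqrt(n) (Cauchy-Schwarz),
# so the empirical ratio can never exceed sqrt(8) for 8-dimensional inputs.
D = lambda x: x.sum()
batch = rng.normal(size=(64, 8))
ratio = max_lipschitz_ratio(D, batch)
```

A critic satisfying the 1-Lipschitz constraint would keep this maximum ratio at or below 1 on the real data.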
Appendix E GP-WGAN with dropout
We run another experiment by adding the same dropout layers used in our method to the networks of GP-WGAN, denoted by GP-WGAN+dropout. This helps disentangle the contribution of our approach in preventing overfitting: is it due to the CT regularization, or merely due to the dropout? The experiment is done by removing the CT term from our value function for training WGAN using 1,000 CIFAR-10 images while keeping the dropout layers.
The inception score of GP-WGAN+dropout is , and the generated samples are shown in Figure 11(a). We also plot the curve of CT in the training procedure of GP-WGAN+dropout and contrast it with the curve of our CT-GAN in Figure 11(b). It is clear that GP-WGAN+dropout outperforms GP-WGAN (), and our method () outperforms GP-WGAN by a large margin.
Finally, we plot the convergence curves of the discriminators' negative cost functions learned by GP-WGAN, GP-WGAN+dropout, and our CT-GAN in Figure 12. We can see that dropout reduces the overfitting of GP-WGAN, but it is not as effective as our CT-GAN.
Appendix F Experiments on large dataset
In this section, we further present experimental results on the large-scale ImageNet (Deng et al., 2009) and LSUN bedroom (Yu et al., 2015) datasets. The experimental setup (e.g., network architecture, learning rates) is exactly the same as in the GP-WGAN work (Gulrajani et al., 2017). We refer readers to (Gulrajani et al., 2017) or our code on GitHub (https://github.com/biuyq/CT-GAN) for the details of the setup.
Our experiment on ImageNet is conducted using images of the size . After generator iterations, the inception score of the proposed CT-GAN is 10.27±0.15, whereas GP-WGAN's is 9.85±0.17. In addition, the inception scores of GP-WGAN and our CT-GAN at each generator iteration are compared in Figure 13. We observe that the inception score of our proposed CT-GAN becomes higher than GP-WGAN's after the early generator iterations. Finally, Figure 14 shows the samples generated by GP-WGAN and CT-GAN, respectively.
For the LSUN bedroom dataset, we show some results in Figure 15: samples generated by our CT-GAN generator after 20k training iterations.