Stein Bridging: Enabling Mutual Reinforcement between Explicit and Implicit Generative Models

Qitian Wu, Rui Gao, Hongyuan Zha
Shanghai Jiao Tong University
University of Texas at Austin
The Chinese University of Hong Kong (Shenzhen)
Georgia Institute of Technology
Email: echo740@sjtu.edu.cn, rui.gao@mccombs.utexas.edu, zha@cc.gatech.edu. Part of the work was done when Qitian Wu and Rui Gao were visiting The Chinese University of Hong Kong (Shenzhen). Hongyuan Zha is on leave from Georgia Institute of Technology.
Abstract

Deep generative models are generally categorized into explicit models and implicit models. The former define an explicit density form whose normalizing constant is often unknown, while the latter, including generative adversarial networks (GANs), generate samples without explicitly defining a density function. In spite of substantial recent advances demonstrating the power of the two classes of generative models in many applications, both of them, when used alone, suffer from respective limitations and drawbacks. To mitigate these issues, we propose Stein Bridging, a novel joint training framework that connects an explicit density estimator and an implicit sample generator with Stein discrepancy. We show that the Stein Bridge induces new regularization schemes for both explicit and implicit models. Convergence analysis and extensive experiments demonstrate that Stein Bridging i) improves the stability and sample quality of GAN training, and ii) helps the density estimator seek more modes in the data and alleviates the mode-collapse issue. Additionally, we discuss several applications of Stein Bridging and useful tricks in practical implementation used in our experiments.

1 Introduction

Deep generative models, as a powerful unsupervised framework for learning the distribution of high-dimensional multi-modal data, have been extensively studied in recent literature. Typically, there are two types of generative models [15]. Explicit models define an explicit (unnormalized) density function, while implicit models learn to sample from the distribution without explicitly defining a density function.

Explicit models have wide applications in undirected graphical models [27, 41, 20, 34], random graph theory [39], energy-based reinforcement learning [19], etc. However, the unknown normalizing constant makes the model hard to train and sample from, and explicit models might not be able to capture the complex structure of true samples while maintaining tractability. In contrast, implicit models are more flexible in training and easy to sample from, and in particular, generative adversarial networks (GANs) have shown great power in learning representations of images, natural languages, graphs, etc. [15, 38, 2, 4]. Nevertheless, due to the minimax game between generator and discriminator/critic in GANs, the training process often suffers from instability and produces undesirable samples, often associated with missing modes in the data or generating extra modes outside the data. More discussion on related work is in Appendix A.

There are situations where we need both an explicit density and a flexible implicit sampler. For sample evaluation, it is not enough to merely distinguish real samples from fake ones; one may also want a fine-grained evaluation of generated samples, for which the energy values given by the explicit model can be a good metric [6]. Another situation is outlier detection. Implicit models often treat all observed samples (possibly mixed with corrupted samples) as true examples for training. To make up for this, explicit models could help to detect out-of-distribution samples via the estimated densities [49]. Also, when given insufficient observed samples, explicit models may fail to capture an accurate distribution, in which case the implicit model may help with data augmentation and facilitate training for density estimation. These situations motivate us to combine the best of both worlds in an effective way so that the two models compensate for and reinforce each other.

In this work, we aim at jointly learning explicit and implicit generative models. In our framework, an explicit energy model is used to estimate the unnormalized densities of true samples via minimizing a Stein discrepancy; in the meantime, an implicit generator model is exploited to minimize the Wasserstein metric (or Jensen-Shannon divergence) between distributions of true and generated samples. On top of these, another Stein discrepancy, acting as a bridge between implicit generated samples and explicit estimated densities, is introduced and pushes the two models to achieve a consensus. We show that the Stein bridge allows the two generative models to reinforce each other by imposing new regularizations on both models, which help the generator to output high-quality samples and facilitate the energy model to avoid mode-collapse. Moreover, we show that the joint training helps to stabilize GAN training via a convergence analysis. Extensive experiments on various tasks verify our theoretical findings as well as demonstrate the superiority of proposed methods compared with existing deep energy models and GAN-based models.

2 Background

In this section, we briefly provide some technical background used in our model.

2.1 Energy Model.

The energy model assigns each data point $x \in \mathbb{R}^d$ a scalar energy value $E_\phi(x)$, where $E_\phi(\cdot)$ is called the energy function and is parameterized by $\phi$. The model is expected to assign low energy to true samples according to a Gibbs distribution $p_\phi(x) = \exp\{-E_\phi(x)\}/Z_\phi$, where $Z_\phi$ is a normalizing constant dependent on $\phi$.

The energy function is often parametrized as a sum of multiple experts [21] and each expert can have various function forms depending on the distributions. If using sigmoid distribution, the energy function becomes (see section 2.1 in [25] for details)

(1)

where maps input to a feature vector and could be specified as a deep neural network, which corresponds to deep energy model [34].

The normalizing term is often hard to compute, making the training intractable, and various methods have been proposed to circumvent this term (see Appendix A).

2.2 Stein Discrepancy.

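Stein discrepancy [17, 30, 5, 37] measures the difference between two distributions. Assume $q$ to be a continuously differentiable density supported on $\mathcal{X} \subseteq \mathbb{R}^d$ and $f: \mathbb{R}^d \to \mathbb{R}^d$ a smooth vector function. Define $\mathcal{A}_q[f(x)] = \nabla_x \log q(x)\, f(x)^\top + \nabla_x f(x)$ as a Stein operator. If $f$ is in the Stein class of $q$ (satisfying some mild boundary conditions) then we have the following Stein identity property:

$$\mathbb{E}_{x \sim q}\big[\mathcal{A}_q[f(x)]\big] = 0.$$

Such property induces the Stein discrepancy between distributions $P$ and $Q$ with densities $p$ and $q$:

$$S(P, Q) = \sup_{f \in \mathcal{F}} \mathbb{E}_{x \sim p}\big[\mathrm{tr}\big(\mathcal{A}_q[f(x)]\big)\big], \tag{2}$$

where $f$ is what we call the Stein critic, which ranges over a function space $\mathcal{F}$; if $\mathcal{F}$ is large enough, then $S(P, Q) = 0$ if and only if $P = Q$. Note that in (2) we do not need the normalizing constant of $q$, since only $\nabla_x \log q$ appears, which enables Stein discrepancy to deal with unnormalized densities.

If $\mathcal{F}$ is a unit ball in a Reproducing Kernel Hilbert Space (RKHS) with a positive definite kernel function $k(\cdot, \cdot)$, then the supremum in (2) has a closed form (see [30] for more details), whose square is

$$S_K(P, Q) = \mathbb{E}_{x, x' \sim p}\big[u_q(x, x')\big], \tag{3}$$

where $u_q(x, x') = \nabla_x \log q(x)^\top k(x, x') \nabla_{x'} \log q(x') + \nabla_x \log q(x)^\top \nabla_{x'} k(x, x') + \nabla_x k(x, x')^\top \nabla_{x'} \log q(x') + \mathrm{tr}\big(\nabla_{x, x'} k(x, x')\big)$. Equation (3) gives the Kernel Stein Discrepancy (KSD). One can refer to a recent work [3] for some important properties of the (kernel) Stein discrepancy.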
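As an illustration of how the KSD can be estimated from samples, the following is a minimal NumPy sketch, not the authors' implementation. It assumes the standard closed form of [30] with an RBF kernel $k(x, x') = \exp(-\|x - x'\|^2 / (2h^2))$ and a V-statistic average over all sample pairs; the names `ksd_rbf`, `score` and `bandwidth` are hypothetical. Only the score $\nabla_x \log q$ enters, so the normalizing constant of $q$ is never needed.

```python
import numpy as np

def ksd_rbf(x, score, bandwidth=1.0):
    """V-statistic estimate of E_{x,x'~P}[u_q(x, x')] with an RBF kernel.

    x:      (n, d) array of samples from P
    score:  callable mapping (n, d) samples to grad_x log q(x), shape (n, d)
    """
    x = np.asarray(x, dtype=float)
    n, d = x.shape
    s = score(x)                                  # (n, d) scores of q at the samples
    diff = x[:, None, :] - x[None, :, :]          # (n, n, d), entry [i, j] = x_i - x_j
    sq = np.sum(diff ** 2, axis=-1)               # squared pairwise distances
    h2 = bandwidth ** 2
    k = np.exp(-sq / (2.0 * h2))                  # kernel matrix k(x_i, x_j)

    # s(x)^T k(x, x') s(x')
    term1 = (s @ s.T) * k
    # s(x)^T grad_{x'} k, using grad_{x'} k = k * (x - x') / h^2
    term2 = np.einsum("id,ijd->ij", s, diff) * k / h2
    # (grad_x k)^T s(x'), using grad_x k = -k * (x - x') / h^2
    term3 = -np.einsum("jd,ijd->ij", s, diff) * k / h2
    # tr(grad_x grad_{x'} k) = k * (d / h^2 - ||x - x'||^2 / h^4)
    term4 = k * (d / h2 - sq / h2 ** 2)
    return float(np.mean(term1 + term2 + term3 + term4))
```

For samples drawn from a standard Gaussian and the score function `lambda v: -v`, the estimate is close to zero, whereas shifting the samples away from the origin makes it clearly larger.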

2.3 Wasserstein Metric.

Wasserstein metric is suitable for measuring distances between two distributions with non-overlapping supports [2]. The Wasserstein-1 metric between distributions $P$ and $Q$ is defined as $W(P, Q) = \min_{\gamma} \mathbb{E}_{(x, y) \sim \gamma}[\|x - y\|]$, where the minimization is over all joint distributions $\gamma$ with marginals $P$ and $Q$. By Kantorovich-Rubinstein duality, it has a dual representation

$$W(P, Q) = \max_{D} \; \mathbb{E}_{x \sim P}[D(x)] - \mathbb{E}_{y \sim Q}[D(y)], \tag{4}$$

where the maximization is over all 1-Lipschitz continuous functions $D$.
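For intuition, in one dimension the primal definition has a simple sample-based form: with equally many samples from each distribution, an optimal coupling matches the sorted samples. The sketch below (hypothetical helper names, NumPy only) computes this quantity and compares it with the dual objective evaluated at one particular 1-Lipschitz critic, $D(x) = x$. It illustrates the definitions only and is not how the Wasserstein term is trained in practice.

```python
import numpy as np

def wasserstein1_1d(x, y):
    """W1 between two equal-size 1-D empirical distributions via sorted samples."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(x - y)))

def dual_value(x, y, critic):
    """Dual objective E_P[D(x)] - E_Q[D(y)] for a given critic D."""
    return float(np.mean(critic(x)) - np.mean(critic(y)))

rng = np.random.default_rng(0)
q_samples = rng.standard_normal(1000)
p_samples = q_samples + 2.0                      # P is Q shifted by 2

w1 = wasserstein1_1d(p_samples, q_samples)
# D(x) = x is 1-Lipschitz, so its dual value is a lower bound on W1.
lower_bound = dual_value(p_samples, q_samples, critic=lambda t: t)
```

For a pure shift the identity critic attains the maximum, so `lower_bound` and `w1` coincide up to floating-point error; for a generic pair of distributions any fixed 1-Lipschitz critic only gives a lower bound.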

2.4 Sobolev space and Sobolev dual norm.

Use to denote the canonical Hilbert space on equipped with an inner product . The Sobolev space is defined as the closure of , the set of smooth functions on with compact support, with respect to the norm . For , its Sobolev dual norm is defined by [9]

Therefore, can be viewed as a measure of smoothness, which measures the similarity (in terms of largest -norm) between and a subset of smooth functions in .

3 Proposed Model

In this section, we formulate our model, Stein Bridging, and highlight its regularization effects.

3.1 Model Formulation

We denote by $\mathbb{P}_{real}$ the underlying real distribution from which the data $\{x\}$ are sampled. We simultaneously learn two generative models, one explicit and one implicit, that represent estimates of $\mathbb{P}_{real}$. The explicit generative model $\mathbb{P}_E$ has an explicit probability density proportional to $\exp(-E(x))$, where $E$ is referred to as an energy function. The implicit generative model transforms an easy-to-sample random noise $z$ with distribution $P_0$ via a generator $G$ to a generated sample $\tilde{x} = G(z)$ with distribution $\mathbb{P}_G$. We use the Stein discrepancy $S(\mathbb{P}_{real}, \mathbb{P}_E)$ as a measure of closeness between the explicit density $\mathbb{P}_E$ and the real distribution $\mathbb{P}_{real}$, and use the Wasserstein metric $W(\mathbb{P}_{real}, \mathbb{P}_G)$ as a measure of closeness between the implicit distribution $\mathbb{P}_G$ and $\mathbb{P}_{real}$.

To jointly learn the two generative models $\mathbb{P}_E$ and $\mathbb{P}_G$, arguably the most straightforward approach is to minimize the sum of the Stein discrepancy and the Wasserstein metric:

$$\min_{E, G} \; W(\mathbb{P}_{real}, \mathbb{P}_G) + \lambda\, S(\mathbb{P}_{real}, \mathbb{P}_E),$$

where $\lambda \geq 0$ is a weight coefficient. However, this approach appears no different from learning the two generative models separately. To better train the model, we incorporate into the objective another term, called the Stein bridge, that measures the closeness between the explicit density $\mathbb{P}_E$ and the implicit distribution $\mathbb{P}_G$:

$$\min_{E, G} \; W(\mathbb{P}_{real}, \mathbb{P}_G) + \lambda_1 S(\mathbb{P}_{real}, \mathbb{P}_E) + \lambda_2 S(\mathbb{P}_G, \mathbb{P}_E), \tag{5}$$

where $\lambda_1, \lambda_2 \geq 0$ are weight coefficients. Although the Stein bridge might seem redundant mathematically, we show that it helps regularize the models in Section 3.2.

The Wasserstein term in (5) is implemented using its equivalent dual representation (4). The two Stein terms in (5) can be implemented using (2) with either a Stein critic parameterized by a neural network, or the Kernel Stein Discrepancy. To reduce the computational cost, the two Stein critics share their parameters, namely, kernels or neural networks. A scheme of our framework is presented in Fig. 1. We also discuss some related works that attempt to combine the best of both worlds (such as energy-based GAN, contrastive learning and cooperative learning) in Appendix A.3, and highlight the difference between our method and theirs in terms of the objective in Table 1.
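To make the structure of the joint objective concrete, the sketch below evaluates it on one-dimensional samples as the sum of three terms: a Wasserstein term between real and generated samples, a Stein term between real samples and the energy model, and the Stein bridge between generated samples and the energy model. It is a simplified illustration under stated assumptions, not the authors' training code: the Wasserstein term is computed by sorting equal-size samples instead of training a critic, both Stein terms use the KSD option with a shared RBF kernel, and all function and argument names are hypothetical. In the actual framework the energy function and generator are neural networks and the terms are optimized adversarially.

```python
import numpy as np

def w1_1d(x, y):
    """Sorted-sample Wasserstein-1 distance for equal-size 1-D samples."""
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))

def ksd_1d(x, score, h=1.0):
    """V-statistic KSD of 1-D samples x against a model with the given score."""
    s = score(x)                                   # d/dx log q(x) = -dE/dx
    diff = x[:, None] - x[None, :]                 # entry [i, j] = x_i - x_j
    k = np.exp(-diff ** 2 / (2.0 * h ** 2))        # shared RBF kernel
    u = (s[:, None] * s[None, :] * k               # s(x) k s(x')
         + s[:, None] * diff * k / h ** 2          # s(x) * d/dx' k
         - s[None, :] * diff * k / h ** 2          # d/dx k * s(x')
         + k * (1.0 / h ** 2 - diff ** 2 / h ** 4))  # d^2 k / dx dx'
    return float(np.mean(u))

def stein_bridging_objective(real, fake, score, lam1=1.0, lam2=1.0):
    """Wasserstein term + lam1 * Stein term + lam2 * Stein bridge."""
    wasserstein = w1_1d(real, fake)                # real vs. generated samples
    stein_real = ksd_1d(real, score)               # real samples vs. energy model
    stein_bridge = ksd_1d(fake, score)             # generated samples vs. energy model
    return wasserstein + lam1 * stein_real + lam2 * stein_bridge
```

With real samples from a standard Gaussian, the objective is small when the generated samples also follow a standard Gaussian and the energy is $E(x) = x^2/2$ (score $-x$), and it grows when either the generated samples or the energy model is shifted away from the data. Setting `lam2=0` recovers the separate-training objective.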

Remark. In general, we can also choose other statistical distances in (5) to measure closeness between probability distributions. For example, the Wasserstein metric can be replaced by other common choices for implicit generative models, such as Jensen-Shannon divergence used in the original GAN paper [15]. If the normalizing constant of is known or easy to calculate, one can replace the Stein discrepancy by the Kullback-Leibler divergence, which is equivalent to the maximum likelihood estimation. We present details for model specifications in various forms and training algorithm in Appendix D.2.

Figure 1: Model framework for Stein Bridging, which jointly trains an implicit sample generator and an explicit density estimator via a Stein bridge.
Model Objective
GAN
Energy Model
Energy-based GAN [50, 6]
Contrastive Learning [25, 22]
Cooperative Learning [48, 46]
Stein Bridging (ours)
Table 1: Comparison of objectives between different generative models, where , and denote general statistical distances between two distributions.

3.2 Regularization Effects by Virtue of The Stein Bridge

The intuitive motivation of the Stein bridge term in (5) is to push the two models to achieve a consensus. In this subsection, we theoretically show that the Stein bridge allows the two models to reinforce each other by imposing regularizations on the critics.

3.2.1 Kernel Sobolev dual norm regularization on the Wasserstein critic

We first show the regularization effect of the Stein bridge on the Wasserstein critic. Fixing the energy function , consider the max-min problem over the Wasserstein critic and the generator :

(6)

We define the kernel Sobolev dual norm as

which can be viewed as a kernel generalization of the Sobolev dual norm defined in Section 2.4, which reduces to the Sobolev dual norm when and being the Lebesgue measure.

Assuming that exhausts all probability distributions, we have the following result.

Theorem 1.

Formally, problem (6) is equivalent to

Note that if $D$ is an optimal Wasserstein critic in (4), so is $D + c$ for any constant $c \in \mathbb{R}$. This is consistent with the penalty term, since the penalties are identical for $D$ and $D + c$. According to Section 2.4, the regularization term would penalize the non-smoothness of the Wasserstein critic $D$, which is in the same spirit as gradient-based penalties (e.g., [18, 40]), but offers a new way of encouraging smoothness.

Another way to interpret the Sobolev dual norm penalty is by observing that if , and [43], then

where denotes the 2-Wasserstein metric. Therefore, the regularization ensures that would not change suddenly on the high-density region of , and the explicit model reinforces the learning of the Wasserstein critic.

3.2.2 Lipschitz Regularization on the Stein critic

We next investigate how the Stein bridge helps to regularize the Stein critic. Recall that the two Stein terms in (5) share the same Stein critic. Fixing the energy function , consider the max-min problem over the Stein critic and the generator :

(7)

Assuming that exhausts all probability distributions, we have the following result.

Theorem 2.

Problem (7) is equivalent to

where denotes the Lipschitz constant of the function .

Theorem 2 shows that the Stein bridge, together with the Wasserstein metric, acts as a smoothness regularizer on the Stein critic via a Lipschitz constraint. The regularization will penalize large variation of the values given by Stein operators on adjacent instances and further encourage the energy model to seek more modes in the data instead of focusing on a few dominant modes, thus helping to alleviate the mode-collapse issue. To the best of our knowledge, this suggests a novel regularization scheme for Stein-based GANs.

4 Convergence Analysis

In Section 3.2, we justify Stein Bridging by showing the regularization effects. In this section, we further show that it could help to stabilize GAN training with a local convergence guarantee. To this end, we first compare the behaviors of WGAN, likelihood- and entropy-regularized WGAN, and our Stein Bridging under SGD via an easy-to-comprehend toy example. Then we give a formal result that explains why introducing the density estimator could stabilize GAN training and help convergence.

4.1 Analysis of a Linear System

The training for the minimax game in GANs is difficult. When using traditional gradient methods, the training would suffer from some oscillatory behaviors [16, 29]. In order to better understand the optimization behaviors, we first study a one-dimensional linear system that provides some insights on this problem. Note that such a toy example (or a similar one) is also utilized by [14, 32] to shed light on the instability of WGAN training (our theoretical discussions focus on WGAN, and we also compare with the original GAN in the experiments). Consider a linear critic and generator . Then the Wasserstein GAN objective can be written as a constrained bilinear problem: , which could be further simplified as an unconstrained version (the behaviors could be generalized to multi-dimensional cases [14]):

(8)

Unfortunately, such a simple objective cannot guarantee convergence by traditional gradient methods like SGD with alternate updating (here, we adopt the most widely used alternate updating strategy; the simultaneous updating, i.e., and , would diverge in this case): , . Such optimization would suffer from an oscillatory behavior, i.e., the updated parameters go around the optimum point () forming a circle without converging to the center, which is shown in Fig. 2(a). A recent study [29] theoretically shows that such oscillation is due to the interaction term in (8).
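The oscillation can be reproduced numerically in a few lines. The sketch below assumes that (8) takes the standard unconstrained bilinear form $\min_\theta \max_\psi \psi\theta$ with optimum at the origin, and that the critic is updated before the generator in each alternate step; both choices, as well as the function name `run_bilinear` and its default step size, are assumptions made for this illustration.

```python
import numpy as np

def run_bilinear(eta=0.1, steps=1000, alternate=True, theta0=1.0, psi0=1.0):
    """Gradient descent-ascent on f(theta, psi) = psi * theta.

    Returns the distance of (theta, psi) from the optimum (0, 0) after each step.
    """
    theta, psi = theta0, psi0
    radii = np.empty(steps)
    for t in range(steps):
        if alternate:
            psi = psi + eta * theta      # critic ascent step
            theta = theta - eta * psi    # generator descent step, uses the new psi
        else:
            # simultaneous updates: both steps use the old values
            theta, psi = theta - eta * psi, psi + eta * theta
        radii[t] = np.hypot(theta, psi)
    return radii

alternate_radii = run_bilinear(alternate=True)
simultaneous_radii = run_bilinear(alternate=False)
```

Under these assumptions, the alternate updates preserve the quantity $\theta^2 + \psi^2 + \eta\theta\psi$ exactly, so the iterates stay on a closed curve around the optimum and never approach it. The simultaneous updates multiply the squared distance to the optimum by $1 + \eta^2$ at every step and therefore diverge.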

One solution to the instability of GAN training is to add (likelihood) regularization, which has been widely studied in recent literature [44, 28]. With the regularization term, the objective changes into where denotes the likelihood function and is a hyperparameter. A recent study [42] proves that when (likelihood-regularization), the extra term is equivalent to maximizing sample evidence, helping to stabilize GAN training; when (entropy-regularization), the extra term maximizes sample entropy, which encourages diversity of the generator. Here we consider a Gaussian likelihood function for generated sample , which is up to a constant, and then the objective becomes (see Appendix C.1 for details):

(9)
Figure 2: Numerical iterations for SGD training. (a) Comparison of WGAN, likelihood-regularized WGAN (WGAN+LR), variational annealing for WGAN+LR (WGAN+LR+VA), entropy-regularized WGAN (WGAN+ER) and our Stein Bridging. (b) Stein Bridging with different and .

The above system would converge with and diverge with in gradient-based optimization, as shown in Fig. 2(a). Another issue of likelihood-regularization is that the extra term changes the optimum point and makes the model converge to a biased distribution, as proved by [42]. In this case, one can verify that the optimum point becomes , resulting in a bias. To avoid this issue, [42] proposes to gradually decrease the regularization weight during training. However, such a method would also be stuck in oscillation when the weight gets close to zero, as is shown in Fig. 2(a).

Finally, let us consider our proposed model. We also simplify the density estimator as a basic energy model whose score function is . Then if we specify the two Stein discrepancies in (5) as KSD, we have the objective,

(10)

Interestingly, one can verify that for , the optimum point remains the same . Then we show that the optimization can guarantee convergence to .

Proposition 1.

Using alternate SGD for (10) geometrically decreases the square norm , for any with ,

(11)

In Fig. 2(a), we can see that Stein Bridging achieves a good convergence to the right optimum. Compared with (8), the objective (10) adds a new bilinear term , which acts like a connection between the generator and the estimator, and two other quadratic terms, which help to push the values to decrease through training. The added terms and the original terms in (10) cooperate to guarantee convergence to a unique optimum (see Appendix C.1 for more discussion).

We further generalize the analysis to multi-dimensional bilinear systems, which are extensively used by researchers for the analysis of GAN stability [16, 13, 29, 14]. For any bilinear system, with the added term where to the objective, we can prove that i) the optimum point remains the same as the original system (Proposition 2) and ii) using the alternate SGD algorithm for the new objective can guarantee convergence (Theorem 4). The results are given in Appendix C.3.

4.2 Local Convergence for a General Model

To study the convergence for Stein Bridging, we proceed to consider a general optimization objective

where , and and ( is a shared parameter set). Use to denote the optimum point of and , represent the optimum points of and respectively. Define and , where , , denote constraint sets for , , respectively. Function is -strongly convex, and is -strongly convex for and -strongly concave for (see Appendix D.4 for definition of strongly convex condition). Here we define , , and then we have the following theorem.

Theorem 3.

If is -strongly convex-concave and is -strongly convex, we can leverage the alternate SGD algorithm, i.e.

(12)
(13)

where , , , , and denotes the projection mapping to . Then we can achieve the convergence by using .

Theorem 3 shows that Stein Bridging could converge to at least a local optimum. Due to the unknown and intricate landscape of deep neural networks, global optimization and convergence analysis for GANs remain an unexplored problem. Although the strong convexity assumption cannot be guaranteed with deep neural networks, the optimization could converge to a stable point once there exists a local region that satisfies the strongly convex conditions. In the experiments, we will empirically compare the training stability of each method on various datasets to validate our theoretical discussions.
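The role of the strong convexity-concavity assumption can be seen on a toy problem. The sketch below runs alternate projected gradient steps on $f(\theta, \psi) = \frac{\mu}{2}\theta^2 + \theta\psi - \frac{\mu}{2}\psi^2$, which is $\mu$-strongly convex in $\theta$ and $\mu$-strongly concave in $\psi$ with saddle point $(0, 0)$. This quadratic is a stand-in chosen for illustration and is not the objective analyzed in Theorem 3; exact gradients replace stochastic ones, the constraint sets are taken to be intervals, and the names and default values are hypothetical.

```python
import numpy as np

def project(value, radius):
    """Euclidean projection of a scalar onto the interval [-radius, radius]."""
    return float(np.clip(value, -radius, radius))

def alternate_projected_gda(mu=1.0, eta=0.1, steps=200, radius=5.0,
                            theta0=3.0, psi0=-2.0):
    """Alternate projected gradient descent-ascent on the toy saddle problem.

    Returns the distance to the saddle point (0, 0) after every iteration.
    """
    theta, psi = theta0, psi0
    dist = np.empty(steps)
    for t in range(steps):
        # descent step in theta: df/dtheta = mu * theta + psi
        theta = project(theta - eta * (mu * theta + psi), radius)
        # ascent step in psi with the updated theta: df/dpsi = theta - mu * psi
        psi = project(psi + eta * (theta - mu * psi), radius)
        dist[t] = np.hypot(theta, psi)
    return dist

distances = alternate_projected_gda()
```

In contrast to the purely bilinear game, the quadratic terms make each alternate step a contraction here, and the distance to the saddle point shrinks geometrically. Setting `mu=0.0` removes the strong convexity-concavity and brings back the non-convergent cycling of the bilinear case.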

5 Experiments

In this section, we conduct experiments to verify the effectiveness of the proposed method from multiple perspectives (the code is available at https://github.com/echo740/SteinBridging). First, we select three tasks with different evaluation metrics in Sections 5.1, 5.2 and 5.3. Then we further discuss some applications of joint training as well as some useful tricks in Sections 5.4, 5.5 and 5.6.

We consider two synthetic datasets with mixtures of Gaussian distributions: Two-Circle and Two-Spiral. The first one is a mixture of 24 Gaussian components that lie on two circles. This dataset extends the 8-Gaussian-mixture scenario widely used in previous GAN papers and is more difficult, so we can use it to test the quality of generated samples and the mode coverage of the learned density. The second synthetic dataset is a mixture of 100 Gaussian components whose centers are densely arranged on two centrally symmetric spiral-shaped curves. This dataset can be used to examine the power of generative models on complicated data distributions. The ground-truth distributions and samples are shown in Fig. 3(a) and Fig. 4(a). Furthermore, we also apply the methods to the MNIST and CIFAR datasets, which require the model to deal with high-dimensional data. In each dataset, we use observed samples as input to the model and leverage them to train the generators and the estimators. The details for each dataset are reported in Appendix D.1.

In our experiments, we also replace the Wasserstein metric in (5) by JS divergence. To distinguish the different specifications, in this section we term the model Joint-W if it uses the Wasserstein metric and Joint-JS if it uses JS divergence. We consider several competitors. First, for implicit generative models, we consider vanilla GAN, WGAN-GP [18], likelihood-regularized GAN/WGAN-GP (GAN+LR/WGAN+LR for short), entropy-regularized GAN/WGAN-GP (GAN+ER/WGAN+ER for short) and a recently proposed variational annealing regularization [42] for GAN (GAN+VA/WGAN+VA for short) to compare the quality of generated samples. We employ the denoising auto-encoder to estimate the gradient for the regularization penalty, which is proposed by [1] and utilized by [42]. Second, for explicit density models, we consider the Deep Energy Model (DEM), which is optimized based on Stein discrepancy, and energy-based GAN (EGAN) [6]. Besides, we also compare with the Deep Directed Generative Model (DGM) [25], which adopts contrastive divergence to unite the sample generator and the density estimator. See Appendix A for a brief introduction to these methods and Appendix D.3 for implementation details of each method.

Figure 3: (a) True samples and (b)-(f) generated samples produced by the generators of different methods on the Two-Circle (top row) and Two-Spiral (bottom row) datasets.
Figure 4: (a) True densities and (b)-(f) estimated densities given by the estimators of different methods on the Two-Circle (top row) and Two-Spiral (bottom row) datasets.
MNIST (Conditional) MNIST (Unconditional) CIFAR-10 (Unconditional)
Method Score CEPC Method Score CEPC Method Score CEPC
DCGAN 8.43 0.168 WGAN-GP 7.71 0.256 WGAN-GP 6.80 0.153
DCGAN+LR 8.40 0.171 WGAN+LR 7.82 0.243 WGAN+LR 6.89 0.154
DCGAN+ER 8.33 0.179 WGAN+ER 7.75 0.252 WGAN+ER 6.99 0.156
DCGAN+VA 8.40 0.172 WGAN+VA 7.74 0.254 WGAN+VA 6.95 0.154
DGM 8.15 0.201 DGM 6.87 0.372 DGM 4.79 0.146
Joint-JS(ours) 8.53 0.156 Joint-W(ours) 7.90 0.231 Joint-W(ours) 7.11 0.151
Table 2: Inception scores (higher is better) and conditional entropies of predicted classes (CEPC; lower is better) on MNIST and CIFAR-10. We directly use the best result reported in their paper.

5.1 Sample Quality of Implicit Model

Calibrating the explicit density model with the implicit generator is expected to improve the quality of generated samples. In Fig. 3 and Fig. 4 we show the results of different generators on the Two-Circle and Two-Spiral datasets. As we can see, on Two-Circle, a large number of the samples generated by GAN, WGAN-GP and DGM (the worst one in this case) are located between two Gaussian components, and the boundary of each component is not distinguishable. Since the ground-truth densities of the regions between two components are very low, such generated samples are of low quality, which indicates that these models capture combinations of two dominant features (i.e., modes) in the data even though such combinations do not make sense in practice. By contrast, Joint-JS and Joint-W alleviate this issue, reduce the number of low-quality samples and produce more distinguishable boundaries between components. On Two-Spiral, similarly, the samples generated by GAN and WGAN-GP form a circle instead of two spirals, while the samples of DGM ‘link’ the two spirals. Joint-JS manages to focus more on the true high-density regions than GAN, and Joint-W provides the best results. To quantitatively measure the sample quality, we adopt two metrics: Maximum Mean Discrepancy (MMD) and High-quality Sample Rate (HSR). The detailed definitions are given in Appendix D.4 and we report the results in Table 5.

We visualize the generated digits/images on the MNIST/CIFAR-10 datasets in Fig. 9 and Fig. 10 and use the Inception Score and the conditional entropy of predicted classes (CEPC) to measure the sample quality (see Appendix D.4 for details). As shown in Table 2, Joint-W (resp. Joint-JS) is superior to WGAN-GP (resp. DCGAN), regularized WGAN (resp. DCGAN) and DGM. The CEPC characterizes how well a picture can be distinguished by a pre-trained classifier, i.e., the quality of the picture, so the results indicate that the proposed method gives higher-quality generated pictures.

Figure 5: Learning curves of Joint-W (resp. Joint-JS) compared with WGAN (resp. GAN or DCGAN) and its regularization-based variants. Panels: (a) Two-Circle, (b) Two-Spiral, (c) MNIST (conditional), (d) MNIST (unconditional).

5.2 Density Estimation of Explicit Model

Another advantage of joint learning is that the generator could help the density estimator to capture a more accurate distribution. As shown in Fig. 3, both Joint-JS and Joint-W manage to capture all Gaussian components while other methods miss some of the modes. In Fig. 4, Joint-JS and Joint-W closely fit the ground-truth distribution. By contrast, DEM misses one spiral while EGAN degrades to a uniform-like distribution. DGM manages to fit the two spirals but allocates high densities to regions that have low densities in the ground-truth distribution. To quantitatively measure the performance, we introduce three evaluation metrics: KL and JS divergence between the ground-truth and estimated densities, and the Area Under the Curve (AUC) of false-positive rate vs. true-positive rate, where we select points with true high (resp. low) densities as positive (resp. negative) examples. The detailed information and results are given in Appendix D.4 and Table 5, respectively. The values show that Joint-W and Joint-JS provide more accurate density estimation than the other competitors.

We also rank the generated digits (and true digits) on MNIST w.r.t. the densities given by the energy model in Fig. 11, Fig. 12 and Fig. 13. As depicted in the figures, the digits with high densities (or low densities) given by Joint-JS possess enough diversity (the thickness, the inclination angles and the shapes of the digits all vary). By contrast, all the digits with high densities given by DGM tend to be thin, and digits with low densities are very thick. Also, as for EGAN, digits with high (or low) densities appear to have the same inclination angle (for high densities, ‘1’ stays straight and ‘9’ leans to the left, while for low densities it is just the opposite). Such phenomenon indicates that DGM and EGAN tend to allocate high (or low) densities to data with certain modes and would miss some modes that possibly possess high densities in the ground-truth distributions. Our method overcomes this issue and manages to capture complicated distributions.

5.3 Enhancing the Stability of GAN

Our discussions and analysis show that joint training helps to stabilize GAN training. In Fig. 5 we present the learning curves of Joint-W (resp. Joint-JS) compared with WGAN (resp. GAN or DCGAN) and its regularization-based variants on different datasets. One can clearly see from the curves that joint training reduces the variance of the metric values, especially during the second half of training. Furthermore, we visualize the pictures generated from the same noise in adjacent epochs in Fig. 6. The results show that Joint-W outputs more stable generations in adjacent epochs, while the samples generated by WGAN-GP and WGAN+VA exhibit an obvious variation. In particular, some digits generated by WGAN-GP and WGAN+VA change from one class to another. Such phenomenon is quite similar to the oscillatory, non-convergent behavior in optimization that we discuss in Section 4.1.

Figure 6: Generated digits (resp. images) given by the same noise in adjacent training epochs on the MNIST (resp. CIFAR) dataset.
Figure 7: Joint-W with (a) noised data, (b) insufficient data and (c) ‘warm up’ iterations before joint training. Panel titles: (a) Noised Data, (b) Insufficient Data, (c) Warm up on CIFAR-10.

Another issue discussed in Section 4.1 is the bias of the model distribution for regularized GAN methods. To quantify this, we calculate and distances between the means of 50000 generated digits (resp. images) and 50000 true digits (resp. images) in MNIST (resp. CIFAR-10). The results are shown in Table 3. The smaller distances given by Joint-W indicate that it converges to a better local optimum with smaller bias from the original data distribution. Also, in Table 6 (resp. Table 7), we report the distances for digits (resp. images) in each class on MNIST (resp. CIFAR).

5.4 Detecting Out-of-Distribution Samples

The explicit model estimates densities for each sample, and one of its applications is to detect outliers in the input data. Here, we adopt CIFAR-10 to measure the ability of our estimator to distinguish in-distribution samples from (true/false) out-of-distribution samples. We consider four situations, and in each case we take the test images of CIFAR-10 as the positive set (expected to be allocated high densities) and construct a negative set (expected to be allocated low densities). We let the model output densities for images in the two sets, rank them according to the densities and plot the ROC curve of false-positive rate vs. true-positive rate in Fig. 8. In the first case, we flip each image in the positive set to form the negative set. Note that such flipped images are not out-of-distribution samples, so the model is expected to allocate high densities to them, i.e., the ROC curve should be close to the straight line from $(0, 0)$ to $(1, 1)$. The results show that Joint-W, EGAN and DEM give the expected results while DGM assigns lower densities to all flipped images, which means that it fails to capture the semantics in images. In the following three cases, we i) generate random noise, ii) average two images from different CIFAR classes, and iii) adopt the LSUN Bedroom dataset as the negative set, respectively. In these situations, the model is expected to distinguish the images in the two sets. The results in Fig. 8 show that DGM provides the best results while the performance of Joint-W is quite close to DGM and much better than DEM and EGAN.
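The ranking-based evaluation described above can be summarized by the area under the ROC curve, which equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one (ties counted as one half). The sketch below computes this quantity directly from two sets of scores; the function name `roc_auc` is hypothetical, and using the negative energy as the score is an assumption that relies only on the density being proportional to $\exp(-E(x))$, so that ranking by negative energy and ranking by density coincide.

```python
import numpy as np

def roc_auc(pos_scores, neg_scores):
    """Area under the ROC curve from scores of positive and negative samples.

    Computed as P(score_pos > score_neg) + 0.5 * P(score_pos == score_neg)
    over all positive-negative pairs.
    """
    pos = np.asarray(pos_scores, dtype=float)
    neg = np.asarray(neg_scores, dtype=float)
    greater = np.mean(pos[:, None] > neg[None, :])
    ties = np.mean(pos[:, None] == neg[None, :])
    return float(greater + 0.5 * ties)

def auc_from_energies(pos_energies, neg_energies):
    """Use negative energy as the score: lower energy means higher density."""
    return roc_auc(-np.asarray(pos_energies, dtype=float),
                   -np.asarray(neg_energies, dtype=float))
```

A value near 1 means the two sets are well separated by the estimated densities, while a value near 0.5 corresponds to the diagonal ROC curve expected when the negative set is not actually out of distribution, as in the flipped-image case.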

MNIST CIFAR
Method Dis Dis Dis Dis
WGAN-GP 13.80 0.93 80.98 1.72
WGAN+LR 12.91 0.86 82.96 1.81
WGAN+ER 12.26 0.77 72.28 1.59
WGAN+VA 12.38 0.78 69.01 1.53
DGM 12.12 0.79 179.30 3.95
Joint-W 11.82 0.73 64.23 1.41
Table 3: Distances between means of generated digits (resp. images) and ground-truth digits (resp. images) on MNIST (resp. CIFAR-10). The first two value columns are for MNIST and the last two for CIFAR.
Figure 8: ROC curves for evaluation of outlier detection on CIFAR-10.

5.5 Addressing Data Insufficiency and Noisy Data

We proceed to test the model performance in some extreme situations where the observed samples are mixed with noise or are insufficient in number. The results are presented in Fig. 7(a), where we add different ratios of random noise to the true samples in the Two-Circle dataset, and in Fig. 7(b), where we sample only a small amount of training data from the Two-Spiral dataset. The details are in Appendix D.1. The noise in the data impacts the performance of both WGAN and Joint-W, but the performance decline for Joint-W is less significant than that for WGAN, which indicates better robustness of joint training with respect to noisy data. In Fig. 7(b), when the sample size decreases from 2000 to 100, the AUC value of DEM declines dramatically, showing its dependency on sufficient training samples. By contrast, the AUC of Joint-W exhibits only a small decline when the sample size is more than 500 and suffers an obvious decline when it is less than 300. This phenomenon demonstrates the lower sensitivity of joint training to the observed sample size.

5.6 When to Start Joint Learning

In our experiments, we also observe an interesting phenomenon: the performance achieved at convergence is better if we start joint training after some iterations of independent training for the generator and the estimator. In other words, at the beginning, we could set the coefficients of the joint-training term in (5) to zero (or to some very small values) and, after some iterations, set them to their normal level. We report the inception scores on MNIST with different numbers of iterations for independent training in Fig. 7(c), where we can see that the score first goes up and then goes down as we increase the number of iterations for independent training. This phenomenon is quite similar to the 'warm up' trick used for training deep networks, where one uses small learning rates in the initial iterations and then increases them for further training. One intuitive reason behind this phenomenon is that, at the beginning, both the generator and the estimator are weak, and if we minimize the discrepancy between them at this point, they may constrain each other and get stuck in bad local optima. When they become strong enough after some training iterations, uniting them through joint training helps them compensate and reinforce each other, as discussed above.
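A delayed start of this kind amounts to a step schedule on the weight of the joint-training term. The following is a minimal sketch under stated assumptions: the function `joint_weight`, its parameters, and the idea of a single scalar weight are hypothetical illustrations, not the schedule used in the experiments.

```python
def joint_weight(step, warmup_steps, target_weight, initial_weight=0.0):
    """Weight applied to the joint-training term at a given iteration.

    During the first `warmup_steps` iterations the generator and the
    estimator are trained (almost) independently, so the weight is
    `initial_weight` (zero or a very small value). Afterwards it is
    switched to `target_weight`, the normal level.
    All names here are hypothetical.
    """
    if step < warmup_steps:
        return float(initial_weight)
    return float(target_weight)

# Example: independent training for 100 iterations, then joint training.
schedule = [joint_weight(t, warmup_steps=100, target_weight=1.0) for t in range(200)]
```

The number of warm-up iterations is a hyperparameter; the trend reported for Fig. 7(c) suggests that both too few and too many independent iterations hurt the final score.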

6 Conclusions

In this paper, we aim to unite the training of an implicit generative model (represented by a GAN) and an explicit generative model (represented by a deep energy-based model). Besides the two loss terms for the GAN and the energy-based model, we introduce a third loss characterized by the Stein discrepancy between the generator in the GAN and the energy-based model. Theoretically, we show that joint training could i) help stabilize GAN training and facilitate its convergence, and ii) enforce dual regularization effects on both models and help them escape from local optima in optimization. We also conduct extensive experiments on different tasks and application scenarios to verify our theoretical findings and to demonstrate the superiority of our method over various GAN models and deep energy-based models.

References

  • [1] G. Alain and Y. Bengio (2014) What regularized auto-encoders learn from the data-generating distribution. J. Mach. Learn. Res. 15 (1), pp. 3563–3593.
  • [2] M. Arjovsky, S. Chintala, and L. Bottou (2017) Wasserstein generative adversarial networks. In ICML, pp. 214–223.
  • [3] A. Barp, F. Briol, A. B. Duncan, M. A. Girolami, and L. W. Mackey (2019) Minimum Stein discrepancy estimators. CoRR abs/1906.08283.
  • [4] A. Brock, J. Donahue, and K. Simonyan (2019) Large scale GAN training for high fidelity natural image synthesis. In ICLR.
  • [5] K. Chwialkowski, H. Strathmann, and A. Gretton (2016) A kernel test of goodness of fit. In ICML, pp. 2606–2615.
  • [6] Z. Dai, A. Almahairi, P. Bachman, E. H. Hovy, and A. C. Courville (2017) Calibrating energy-based generative adversarial networks. In ICLR.
  • [7] C. Du, K. Xu, C. Li, J. Zhu, and B. Zhang (2018) Learning implicit generative models by teaching explicit ones. CoRR abs/1807.03870.
  • [8] Y. Du and I. Mordatch (2019) Implicit generation and generalization in energy-based models. CoRR abs/1903.08689.
  • [9] L. C. Evans (2010) Partial differential equations. Graduate Studies in Mathematics, American Mathematical Society.
  • [10] R. Gao, X. Chen, and A. J. Kleywegt (2017) Wasserstein distributional robustness and regularization in statistical learning. arXiv preprint arXiv:1712.06050.
  • [11] R. Gao and A. J. Kleywegt (2016) Distributionally robust stochastic optimization with Wasserstein distance. arXiv preprint arXiv:1604.02199.
  • [12] S. Geman and D. Geman (1984) Stochastic relaxation, Gibbs distributions, and the Bayesian restoration of images. IEEE Trans. Pattern Anal. Mach. Intell. 6 (6), pp. 721–741.
  • [13] I. Gemp and S. Mahadevan (2018) Global convergence to the equilibrium of GANs using variational inequalities. CoRR abs/1808.01531.
  • [14] G. Gidel, H. Berard, G. Vignoud, P. Vincent, and S. Lacoste-Julien (2019) A variational inequality perspective on generative adversarial networks. In ICLR.
  • [15] I. J. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. C. Courville, and Y. Bengio (2014) Generative adversarial nets. In NIPS, pp. 2672–2680.
  • [16] I. J. Goodfellow (2017) NIPS 2016 tutorial: generative adversarial networks. CoRR abs/1701.00160.
  • [17] J. Gorham and L. Mackey (2015) Measuring sample quality with Stein's method. In NIPS, pp. 226–234.
  • [18] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. C. Courville (2017) Improved training of Wasserstein GANs. In NIPS, pp. 5767–5777.
  • [19] T. Haarnoja, H. Tang, P. Abbeel, and S. Levine (2017) Reinforcement learning with deep energy-based policies. In ICML, pp. 1352–1361.
  • [20] G. E. Hinton, S. Osindero, and Y. W. Teh (2006) A fast learning algorithm for deep belief nets. Neural Computation 18 (7), pp. 1527–1554.
  • [21] G. E. Hinton (1999) Product of experts. In ICANN'99 Artificial Neural Networks.
  • [22] G. E. Hinton (2002) Training products of experts by minimizing contrastive divergence. Neural Computation 14 (8), pp. 1771–1800.
  • [23] T. Hu, Z. Chen, H. Sun, J. Bai, M. Ye, and G. Cheng (2018) Stein neural sampler. CoRR abs/1810.03545.
  • [24] A. Hyvärinen (2005) Estimation of non-normalized statistical models by score matching. J. Mach. Learn. Res. 6, pp. 695–709.
  • [25] T. Kim and Y. Bengio (2016) Deep directed generative models with energy-based probability estimation. CoRR abs/1606.03439.
  • [26] D. P. Kingma and M. Welling (2014) Auto-encoding variational Bayes. In ICLR.
  • [27] Y. LeCun, S. Chopra, R. Hadsell, M. Ranzato, and F. J. Huang (2006) A tutorial on energy-based learning. Predicting Structured Data, MIT Press.
  • [28] Y. Li and R. E. Turner (2018) Gradient estimators for implicit models. In ICLR.
  • [29] T. Liang and J. Stokes (2019) Interaction matters: A note on non-asymptotic local convergence of generative adversarial networks. In AISTATS, pp. 907–915.
  • [30] Q. Liu, J. D. Lee, and M. I. Jordan (2016) A kernelized Stein discrepancy for goodness-of-fit tests. In ICML, pp. 276–284.
  • [31] Q. Liu and D. Wang (2016) Stein variational gradient descent: A general purpose Bayesian inference algorithm. In NIPS, pp. 2370–2378.
  • [32] V. Nagarajan and J. Z. Kolter (2017) Gradient descent GAN optimization is locally stable. In NIPS, pp. 5585–5595.
  • [33] R. M. Neal (2011) MCMC using Hamiltonian dynamics. Handbook of Markov Chain Monte Carlo 2.
  • [34] J. Ngiam, Z. Chen, P. W. Koh, and A. Y. Ng (2011) Learning deep energy models. In ICML, pp. 1105–1112.
  • [35] A. Nguyen, J. Clune, Y. Bengio, A. Dosovitskiy, and J. Yosinski (2017) Plug & play generative networks: conditional iterative generation of images in latent space. In CVPR, pp. 3510–3520.
  • [36] E. Nijkamp, M. Hill, S. Zhu, and Y. Wu (2019) On learning non-convergent non-persistent short-run MCMC toward energy-based model. CoRR abs/1904.09770.
  • [37] C. J. Oates, M. Girolami, and N. Chopin (2017) Control functionals for Monte Carlo integration. Journal of the Royal Statistical Society, Series B.
  • [38] A. Radford, L. Metz, and S. Chintala (2016) Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR.
  • [39] G. Robins, P. Pattison, Y. Kalish, and D. Lusher (2007) An introduction to exponential random graph (p) models for social networks. Social Networks 29 (2), pp. 173–191.
  • [40] K. Roth, A. Lucchi, S. Nowozin, and T. Hofmann (2017) Stabilizing training of generative adversarial networks through regularization. In NIPS, pp. 2018–2028.
  • [41] R. Salakhutdinov and G. E. Hinton (2009) Deep Boltzmann machines. In AISTATS, pp. 448–455.
  • [42] C. Tao, S. Dai, L. Chen, K. Bai, J. Chen, C. Liu, R. Zhang, G. V. Bobashev, and L. Carin (2019) Variational annealing of GANs: A Langevin perspective. In ICML, pp. 6176–6185.
  • [43] C. Villani (2008) Optimal transport: old and new. Vol. 338, Springer Science & Business Media.
  • [44] D. Warde-Farley and Y. Bengio (2017) Improving generative adversarial networks with denoising feature matching. In ICLR.
  • [45] Y. N. Wu, S. C. Zhu, and X. Liu (2000) Equivalence of Julesz ensembles and FRAME models. International Journal of Computer Vision 38 (3), pp. 247–265.
  • [46] J. Xie, Y. Lu, R. Gao, and Y. N. Wu (2018) Cooperative learning of energy-based model and latent variable model via MCMC teaching. In AAAI, pp. 4292–4301.
  • [47] J. Xie, Y. Lu, S. Zhu, and Y. N. Wu (2016) A theory of generative convnet. In ICML, pp. 2635–2644.
  • [48] J. Xie, Y. Lu, S. Zhu, and Y. N. Wu (2016) Cooperative training of descriptor and generator networks. CoRR abs/1609.09408.
  • [49] S. Zhai, Y. Cheng, W. Lu, and Z. Zhang (2016) Deep structured energy based models for anomaly detection. In ICML, pp. 1100–1109.
  • [50] J. J. Zhao, M. Mathieu, and Y. LeCun (2017) Energy-based generative adversarial networks. In ICLR.
  • [51] S. C. Zhu, Y. N. Wu, and D. Mumford (1997) Minimax entropy principle and its application to texture modeling. Neural Computation 9 (8), pp. 1627–1660.

Appendix A Literature Reviews

We discuss related literature and shed light on the relationship between our work and others.

A.1 Explicit Generative Models

Explicit generative models fit each instance with a scalar density that is expected to explicitly capture the distribution behind the data. Such densities are often known only up to a constant and are called energy functions, which are common in undirected graphical models [27]. Hence, explicit generative models are also termed energy-based models. An early version of energy-based models is the FRAME (Filters, Random field, And Maximum Entropy) model [51, 45]. Later, some works leverage deep neural networks to model the energy function [34, 47] and pave the way for research on deep energy models (DEM) (e.g., [22, 25, 49, 19, 8, 36]). Apart from DEM, there are also other forms of deep explicit models based on restricted Boltzmann machines, such as deep belief networks [20] and deep Boltzmann machines [41].

The normalizing constant of the energy function requires an intractable integral over all possible instances, which makes the model hard to learn via Maximum Likelihood Estimation (MLE). To solve this issue, some works propose to approximate the constant by MCMC methods [12, 33]. However, MCMC requires an inner loop of sampling in each training step, which induces high computational costs. Another solution is to optimize a surrogate loss function. For example, contrastive divergence (CD) [22] measures how much the KL divergence can be improved by running a small number of Markov chain steps towards the intractable likelihood, while score matching (SM) [24] bypasses the constant by minimizing the distance between gradients of log-likelihoods. Moreover, the intractable normalizing constant makes the model hard to sample from. To obtain accurate samples from unnormalized densities, many studies propose to approximate the generation by diffusion-based processes, like generative flow [35] and variational gradient descent [31]. Also, a recent work [23] leverages Stein discrepancy to design a neural sampler for unnormalized densities. The fundamental disadvantage of explicit models is that energy-based learning has difficulty accurately capturing the distribution of true samples because real-world instances lie on a low-dimensional manifold [22].
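To make the role of the normalizing constant concrete, the standard energy-based form can be written out as follows. The notation here ($E_\phi$ for the energy, $Z_\phi$ for the normalizing constant, and the sign convention) is an assumption chosen for illustration and may differ from the notation used in the main text.

```latex
\begin{align}
  p_\phi(x) &= \frac{\exp\!\big(-E_\phi(x)\big)}{Z_\phi},
  \qquad
  Z_\phi = \int \exp\!\big(-E_\phi(x)\big)\,\mathrm{d}x, \\
  \nabla_x \log p_\phi(x)
  &= -\nabla_x E_\phi(x) - \nabla_x \log Z_\phi
   = -\nabla_x E_\phi(x).
\end{align}
```

Since $Z_\phi$ does not depend on $x$, it vanishes from the gradient with respect to $x$. This is why objectives built on $\nabla_x \log p_\phi$, such as score matching, avoid the intractable integral, whereas MLE requires it through $\log Z_\phi$.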

A.2 Implicit Generative Models

Implicit generative models focus on a mapping from random noise to generated samples. Such a mapping function is often called a generator and is more flexible than explicit models. Two typical implicit models are the Variational Auto-Encoder (VAE) [26] and Generative Adversarial Networks (GAN) [15]. VAE introduces a latent variable and attempts to maximize the variational lower bound on the likelihood of the joint distribution of the latent and observable variables, while GAN sets up an adversarial game between the generator and a discriminator (or critic in WGAN) that aims at discriminating between generated and true samples. In this paper, we focus on GAN and its variants (e.g., WGAN [2], WGAN-GP [18], DCGAN [38], etc.) as the implicit generative model, and we leave the discussion of VAE as future work.

Two important issues concerning GAN and its variants are instability of training and local optima. The typical local optima for GAN can be divided into two categories: mode-collapse (the model fails to capture all the modes in the data) and mode-redundance (the model generates modes that do not exist in the data). Recently, there have been many attempts to solve these issues from various perspectives. One perspective is regularization. Two typical regularization methods are likelihood-based and entropy-based regularization, with the prominent examples [44] and [28], which respectively leverage denoising feature matching and implicit gradient approximation to enforce the regularization constraints. The likelihood and entropy regularizations could respectively help the generator focus on the data distribution and encourage more diverse samples, and a recent work [42] uses Langevin dynamics to indicate that i) the entropy and likelihood regularizations are equivalent and share an opposite relationship mathematically, and ii) both regularizations would make the model converge to a surrogate point that is biased away from the original data distribution. [42] then proposes a variational annealing strategy to empirically unite the two regularizations and tackle the biased distributions.

To deal with the instability issue, some recent works take an optimization perspective and propose different algorithms to address the non-convergence of minimax game optimization (for instance, [13, 29, 14]). Moreover, the disadvantage of implicit models is the lack of explicit densities over instances, which prevents the black-box generator from characterizing the distribution behind the data.

A.3 Attempts to Combine Both of the Worlds

Recently, several studies have attempted to combine explicit and implicit generative models in different ways. For instance, [50] proposes energy-based GAN, which uses an energy model as the discriminator to distinguish generated and true samples. A similar idea is also used by [25] and [6], which let the discriminator estimate a scalar energy value for each sample. Such a discriminator is optimized to give high energy to generated samples and low energy to true samples, while the generator aims at generating samples with low energy. The fundamental difference is that [50] and [6] both aim at minimizing the discrepancy between the distributions of generated and true samples, while the motivation of [25] is to minimize the KL divergence between estimated densities and true samples. [25] adopts contrastive divergence (CD) to link MLE for the energy model over true data with the adversarial training of energy-based GAN. However, both the CD-based method and energy-based GAN limit the power of the generator and the discriminator. First, if the generated samples resemble true samples, then the gradients for the discriminator given by true and generated samples are opposite and cancel each other, so the training stops before the discriminator captures an accurate data distribution. Second, since the objective boils down to minimizing the KL divergence (for [25]) or Wasserstein distance (for [6]) between the model and true distributions, the issues concerning GAN (or WGAN), such as training instability and mode-collapse, also affect these methods.

Another way to combine the two is cooperative training. [48] (and its improved version [46]) uses the samples of the generator as the MCMC initialization for the energy-based model. The synthesized samples produced by finite-step MCMC are closer to the energy model, and the generator is optimized to make the finite-step MCMC revise its initial samples. Also, a recent work [7] proposes to regard the explicit model as a teacher net that guides the training of the implicit generator, as a student net, to produce samples that could overcome the mode-collapse issue. The main drawback of cooperative training is that it optimizes the discrepancy between the generator and the data distribution only indirectly, via the energy model as a 'mediator', so that once the energy model gets stuck in a local optimum (e.g., mode-collapse or mode-redundance), the training of the generator is affected as well. In other words, the training of the two models would constrain rather than compensate each other. In Table 1, we give a high-level comparison of the above-mentioned generative models w.r.t. their objectives. Different from other methods, our model considers three discrepancies simultaneously, as a triangle, to jointly train the generator and the estimator, enabling them to compensate and reinforce each other.

Appendix B Proofs of Results in Section 3.2

B.1 Proof of Theorem 1

Proof.

Fixing the Wasserstein critic , we are going to solve for

(14)

By definition, if there exists a -measure zero set with positive -measure, then . Hence, to solve (14), it suffices to consider distributions whose support belongs to . Since is dense in , we can restrict to those ’s that are absolutely continuous with respect to :

Set , the problem above becomes

For the minimization problem above, invoking Lagrangian duality gives

Applying the approximation to the minimization problem above yields

Consider a further approximation

(15)

By definition of , the inner infimum equals , and hence (15) equals

Formally, the gap between (14) and (15) is , which completes the proof. We note that a sufficient condition to make the above formal derivation hold is that is compact and the kernel is bounded, although this can be greatly weakened.

B.2 Proof for Theorem 2

Proof.

Essentially the result is a consequence of distributionally robust optimization with Wasserstein metric [11, 10]. Here we provide a simplified version for completeness. Consider

By assumption, using the definition of Wasserstein metric, we write the problem above as

where the minimization is over all joint distributions of with -marginal being . Using the law of total expectation, the problem above is equivalent to

where the minimization is over , all conditional distributions of given . If , then the minimal value equals , otherwise the minimal value equals . Hence the proof is completed. ∎

Appendix C Proofs and More Discussions in Section 4

C.1 Details for One-Dimensional Case

For the analysis of 1-dim regularized WGAN in section 3.1.1, we assume a Gaussian likelihood function for generated sample , which is up to a constant. Its parameter can be estimated by . Then since , we have . Like the case in WGAN, we consider . Assume and we have . Hence, for the analysis on likelihood- (and entropy-) regularized WGAN, we can study the following system:

When , the above objective degrades to (8); when (likelihood-regularization), the gradient of the regularization term pushes to shrink, which helps convergence; when (entropy-regularization), the added term forms an amplifying force on and leads to divergence.

Interestingly, the added terms in (10) and the original terms in WGAN are both necessary to guarantee convergence to the unique optimum points . If we remove the critic and optimize and with the remaining loss terms, the training would converge but not necessarily to (since the optimum points are not unique in this case). On the other hand, if we remove the estimator, the system degrades to (8) and would not converge to the unique optimum point . If we combine both worlds and optimize the three terms together, the training would converge to a unique global optimum .

C.2 Proof for Proposition 1

Proof.

Instead of directly studying the optimization for (10), we first prove the following problem will converge to the unique optimum,

(16)

Applying alternate SGD we have the following iterations:

Then we obtain the relationship between adjacent iterations:

We further calculate the eigenvalues for matrix and have the following equations (assume the eigenvalue as ):

One can verify that the solutions to the above equation satisfy .

Then we have the following relationship

where denotes the eigenvalue with the maximum absolute value of matrix . Hence, we have

We proceed to replace , and in (16) by , and respectively and conduct a change of variable: let and . Then we get the conclusion in the proposition.

C.3 Generalization to Bilinear Systems

Our analysis of the one-dimensional case suggests that we can add an affiliated variable to modify the objective and stabilize the training of general bilinear systems. Bilinear systems are of wide interest to researchers studying the stability of GAN training [16, 29, 14, 13]. The general bilinear function can be written as

(17)

where are both -dimensional vectors and the objective is which can be seen as a basic form of various GAN objectives. Unfortunately, if we directly use simultaneous (resp. alternate) SGD to optimize such objectives, we obtain divergence (resp. fluctuation). To solve the issue, some recent papers propose several optimization algorithms, like extrapolation from the past [14], crossing the curl [13] and consensus optimization [29]. Also, [29] shows that it is the interaction term, which generates non-zero values for and , that leads to such instability of training. Unlike previous works that focus on the algorithmic perspective, we propose to add new affiliated variables which modify the objective function and allow the SGD algorithm to achieve convergence without changing the optimum points.
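The divergence (resp. fluctuation) of simultaneous (resp. alternate) updates can be reproduced on the simplest possible instance. The sketch below is an illustration under stated assumptions, not the general system (17): it uses the scalar game of minimizing over `theta` and maximizing over `psi` the product `theta * psi`, with exact gradients instead of stochastic ones; all function and variable names are hypothetical.

```python
import math

def simultaneous_step(theta, psi, lr):
    """Both players update from the same (old) iterate."""
    # d/dtheta (theta * psi) = psi (descent); d/dpsi (theta * psi) = theta (ascent).
    return theta - lr * psi, psi + lr * theta

def alternating_step(theta, psi, lr):
    """The maximizing player reacts to the minimizer's fresh update."""
    theta_new = theta - lr * psi
    psi_new = psi + lr * theta_new
    return theta_new, psi_new

def run(step_fn, theta=1.0, psi=1.0, lr=0.1, n_steps=1000):
    """Iterate the update rule and return the final distance to the optimum (0, 0)."""
    for _ in range(n_steps):
        theta, psi = step_fn(theta, psi, lr)
    return math.hypot(theta, psi)

init_norm = math.hypot(1.0, 1.0)
sim_norm = run(simultaneous_step)  # squared norm grows by (1 + lr**2) per step
alt_norm = run(alternating_step)   # stays bounded but does not shrink to zero
```

For this toy game the simultaneous update multiplies the squared distance to the optimum by exactly `1 + lr**2` at every step, while the alternating update conserves `theta**2 + psi**2 - lr * theta * psi`, so its iterates cycle around the optimum without converging.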

Based on the minimax objective of (17), we add an affiliated -dimensional variable (corresponding to the estimator in our model) to the original system and tackle the following problem:

(18)

where , and is a non-negative constant. Theoretically, the new problem keeps the optimum points of (17) unchanged. Let

Proposition 2.

Assume the optimum point of are , then the optimum points of (18) would be where .

Proof.

The condition tells us that and . Then we derive the gradients for ,

(19)
(20)
(21)

Combining (20) and (21) we get . Hence, the optimum point of (18) is where . ∎

The advantage of the new problem is that it can be solved by the SGD algorithm with a theoretical convergence guarantee. We state the result in the following theorem.

Theorem 4.

For problem using alternate SGD algorithm, i.e.,

(22)

we can achieve convergence to where with at least linear rate of where , and (resp. ) denotes the maximum (resp. minimum) singular value of matrix .

To prove Theorem 3, we can prove a more general argument.

Lemma 1.

If we consider any first-order optimization method on (18), i.e.,