Wasserstein Auto-Encoders

Abstract

We propose the Wasserstein Auto-Encoder (WAE)—a new algorithm for building a generative model of the data distribution. WAE minimizes a penalized form of the Wasserstein distance between the model distribution and the target distribution, which leads to a different regularizer than the one used by the Variational Auto-Encoder (VAE) [1]. This regularizer encourages the encoded training distribution to match the prior. We compare our algorithm with several other techniques and show that it is a generalization of adversarial auto-encoders (AAE) [2]. Our experiments show that WAE shares many of the properties of VAEs (stable training, encoder-decoder architecture, nice latent manifold structure) while generating samples of better quality, as measured by the FID score.

1 Introduction

The field of representation learning was initially driven by supervised approaches, with impressive results using large labelled datasets. Unsupervised generative modeling, in contrast, used to be a domain governed by probabilistic approaches focusing on low-dimensional data. Recent years have seen a convergence of those two approaches. In the new field that formed at the intersection, variational auto-encoders (VAEs) [1] constitute one well-established approach, theoretically elegant yet with the drawback that they tend to generate blurry samples when applied to natural images. In contrast, generative adversarial networks (GANs) [3] turned out to be more impressive in terms of the visual quality of images sampled from the model, but come without an encoder, have been reported to be harder to train, and suffer from the “mode collapse” problem where the resulting model is unable to capture all the variability in the true data distribution. There has been a flurry of activity in assaying numerous configurations of GANs as well as combinations of VAEs and GANs. A unifying framework combining the best of GANs and VAEs in a principled way is yet to be discovered.

This work builds upon the theoretical analysis presented in [11]. Following [4] and [11], we approach generative modeling from the optimal transport (OT) point of view. The OT cost [5] is a way to measure a distance between probability distributions and provides a much weaker topology than many others, including $f$-divergences associated with the original GAN algorithms [6]. This is particularly important in applications, where data is usually supported on low-dimensional manifolds in the input space $\mathcal{X}$. As a result, stronger notions of distances (such as $f$-divergences, which capture the density ratio between distributions) often max out, providing no useful gradients for training. In contrast, OT was claimed to have a nicer behaviour [4], although it requires, in its GAN-like implementation, the addition of a constraint or a regularization term into the objective.
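
A toy calculation makes the “max out” effect concrete. The sketch below is our own illustration, not taken from [4] or [11]: two point masses at $0$ and at $\theta$ stand in for a data distribution and a model supported on non-overlapping low-dimensional sets. The Jensen-Shannon divergence (an $f$-divergence) equals $\log 2$ for every $\theta \neq 0$, so it carries no information about how far apart the supports are, whereas the 1-Wasserstein distance equals $|\theta|$ and keeps decreasing as the model moves toward the data. The function names are hypothetical.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence between discrete distributions given as {point: probability} dicts."""
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}

    def kl(a, b):
        # KL(a || b); points where a has zero mass contribute nothing.
        return sum(pa * math.log(pa / b[x]) for x, pa in a.items() if pa > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_between_point_masses(x0, x1):
    # The only coupling of two point masses is their product, so W_1 is just the distance.
    return abs(x0 - x1)

for theta in [2.0, 1.0, 0.5, 0.1]:
    jsd = js_divergence({0.0: 1.0}, {theta: 1.0})
    w1 = w1_between_point_masses(0.0, theta)
    print(f"theta={theta}: JS={jsd:.4f}  W1={w1:.4f}")
```

The JS column stays at $\log 2 \approx 0.6931$ for every $\theta$ listed, so its derivative with respect to $\theta$ is zero, while the $W_1$ column shrinks with $\theta$.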

In this work we aim at minimizing OT $W_c(P_X, P_G)$ between the true (but unknown) data distribution $P_X$ and a latent variable model $P_G$ specified by the prior distribution $P_Z$ of latent codes $Z \in \mathcal{Z}$ and the generative model $P_G(X|Z)$ of the data points $X \in \mathcal{X}$ given $Z$. Our main contributions are listed below (cf. also Figure 2):

  • A new family of regularized auto-encoders (the WAE objective in Equation 4 and the WAE-GAN and WAE-MMD algorithms), which we call Wasserstein Auto-Encoders (WAE), that minimize the optimal transport $W_c(P_X, P_G)$ for any cost function $c$. Similarly to VAE, the objective of WAE is composed of two terms: the $c$-reconstruction cost and a regularizer $\mathcal{D}_Z(P_Z, Q_Z)$ penalizing a discrepancy between two distributions in $\mathcal{Z}$: $P_Z$ and a distribution of encoded data points, i.e. $Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]$. When $c$ is the squared cost and $\mathcal{D}_Z$ is the GAN objective, WAE coincides with adversarial auto-encoders of [2].

  • Empirical evaluation of WAE on MNIST and CelebA datasets with squared cost $c(x, y) = \|x - y\|_2^2$. Our experiments show that WAE keeps the good properties of VAEs (stable training, encoder-decoder architecture, and a nice latent manifold structure) while generating samples of better quality, approaching those of GANs.

  • We propose and examine two different regularizers $\mathcal{D}_Z(P_Z, Q_Z)$. One is based on GANs and adversarial training in the latent space $\mathcal{Z}$. The other uses the maximum mean discrepancy, which is known to perform well when matching high-dimensional standard normal distributions [8]. Importantly, the second option leads to a fully adversary-free min-min optimization problem.

  • Finally, the theoretical considerations presented in [11] and used to derive the WAE objective might be interesting in their own right. In particular, the theorem restated in Section 2.3 shows that in the case of generative models, the primal form of $W_c(P_X, P_G)$ is equivalent to a problem involving the optimization of a probabilistic encoder $Q(Z|X)$.

Figures 1 and 2, panels (a) VAE and (b) WAE. Both VAE and WAE minimize two terms: the reconstruction cost and the regularizer penalizing discrepancy between $P_Z$ and the distribution induced by the encoder $Q$. VAE forces $Q(Z|X=x)$ to match $P_Z$ for all the different input examples $x$ drawn from $P_X$. This is illustrated in picture (a), where every single red ball is forced to match $P_Z$, depicted as the white shape. Red balls start intersecting, which leads to problems with reconstruction. In contrast, WAE forces the continuous mixture $Q_Z := \int Q(Z|X)\, dP_X$ to match $P_Z$, as depicted with the green ball in picture (b). As a result, latent codes of different examples get a chance to stay far away from each other, promoting a better reconstruction.

The paper is structured as follows. In Section 2 we review a novel auto-encoder formulation for OT between $P_X$ and the latent variable model $P_G$ derived in [11]. Relaxing the resulting constrained optimization problem, we arrive at the objective of Wasserstein auto-encoders. We propose two different regularizers, leading to the WAE-GAN and WAE-MMD algorithms. Section 3 discusses the related work. We present the experimental results in Section 4 and conclude by pointing out some promising directions for future work.

2 Proposed method

Our new method minimizes the optimal transport cost $W_c(P_X, P_G)$ based on the novel auto-encoder formulation derived in [11] (see the theorem in Section 2.3). In the resulting optimization problem the decoder tries to accurately reconstruct the encoded training examples as measured by the cost function $c$. The encoder tries to simultaneously achieve two conflicting goals: it tries to match the encoded distribution of training examples $Q_Z$ to the prior $P_Z$ as measured by any specified divergence $\mathcal{D}_Z(Q_Z, P_Z)$, while making sure that the latent codes provided to the decoder are informative enough to reconstruct the encoded training examples. This is schematically depicted in Figure 2.

2.1 Preliminaries and notations

We use calligraphic letters (e.g. $\mathcal{X}$) for sets, capital letters (e.g. $X$) for random variables, and lower case letters (e.g. $x$) for their values. We denote probability distributions with capital letters (e.g. $P(X)$) and corresponding densities with lower case letters (e.g. $p(x)$). In this work we will consider several measures of discrepancy between probability distributions $P_X$ and $P_G$. The class of $f$-divergences [9] is defined by $D_f(P_X \| P_G) := \int f\!\left(\frac{p_X(x)}{p_G(x)}\right) p_G(x)\, dx$, where $f\colon (0, \infty) \to \mathbb{R}$ is any convex function satisfying $f(1) = 0$. Classical examples include the Kullback-Leibler ($D_{KL}$) and Jensen-Shannon ($D_{JS}$) divergences.
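
For discrete distributions the integral in the definition becomes a sum, which makes it easy to check that different generators $f$ recover the classical divergences. The sketch below is a minimal illustration assuming discrete $P$ and $Q$ on a common support with $q(x) > 0$ everywhere; the generator $f(t) = t \log t$ yields $D_{KL}$ and $f(t) = \tfrac{1}{2}\big(t \log t - (t+1)\log\tfrac{t+1}{2}\big)$ yields $D_{JS}$. All names are ours.

```python
import math

def xlogx(t):
    # Convention 0 * log 0 = 0.
    return t * math.log(t) if t > 0 else 0.0

def f_divergence(p, q, f):
    """D_f(P || Q) = sum_x q(x) * f(p(x) / q(x)) for discrete P, Q on a common support."""
    assert all(qx > 0 for qx in q), "this sketch assumes q(x) > 0 everywhere"
    return sum(qx * f(px / qx) for px, qx in zip(p, q))

# Generators: both are convex on [0, inf) and satisfy f(1) = 0.
def f_kl(t):
    return xlogx(t)  # yields KL(P || Q)

def f_js(t):
    return 0.5 * (xlogx(t) - (t + 1) * math.log((t + 1) / 2))  # yields JS(P, Q)

p = [0.5, 0.5]
q = [0.9, 0.1]
print(f_divergence(p, q, f_kl), f_divergence(p, q, f_js))
```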

2.2 Optimal transport and its dual formulations

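A rich class of divergences between probability distributions is induced by the optimal transport (OT) problem [5]. Kantorovich’s formulation of the problem is given by

$$W_c(P_X, P_G) := \inf_{\Gamma \in \mathcal{P}(X \sim P_X,\, Y \sim P_G)} \mathbb{E}_{(X, Y) \sim \Gamma}\left[c(X, Y)\right], \tag{1}$$

where $c(x, y)\colon \mathcal{X} \times \mathcal{X} \to \mathbb{R}_+$ is any measurable cost function and $\mathcal{P}(X \sim P_X, Y \sim P_G)$ is the set of all joint distributions of $(X, Y)$ with marginals $P_X$ and $P_G$ respectively. A particularly interesting case is when $(\mathcal{X}, d)$ is a metric space and $c(x, y) = d^p(x, y)$ for $p \geq 1$. In this case $W_p$, the $p$-th root of $W_c$, is called the $p$-Wasserstein distance.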
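
For two empirical distributions with the same number $n$ of equally weighted points, the infimum in Kantorovich’s problem is attained at a coupling that is a permutation (a consequence of the Birkhoff-von Neumann theorem), so for very small $n$ it can be computed by brute force. The sketch below does exactly that with the Euclidean distance `math.dist` as $d$; it is exponential in $n$, is only meant to make the definition concrete, and the helper name is ours.

```python
import itertools
import math

def wasserstein_p_uniform(xs, ys, p=2, dist=math.dist):
    """Exact W_p between (1/n) sum_i delta_{xs[i]} and (1/n) sum_j delta_{ys[j]}.

    With uniform weights an optimal coupling can be taken to be a permutation,
    so we enumerate all n! matchings (only feasible for tiny n).
    """
    n = len(xs)
    if n != len(ys) or p < 1:
        raise ValueError("need equally sized samples and p >= 1")
    best_cost = min(
        sum(dist(xs[i], ys[j]) ** p for i, j in enumerate(perm)) / n
        for perm in itertools.permutations(range(n))
    )
    return best_cost ** (1.0 / p)  # W_p is the p-th root of W_c with c = d^p

xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ys = [(1.0, 1.0), (2.0, 0.0), (0.0, 2.0)]
print(wasserstein_p_uniform(xs, ys, p=1), wasserstein_p_uniform(xs, ys, p=2))
```

Translating a point cloud by a vector of length 1 gives $W_1 = W_2 = 1$, which is a quick sanity check of the helper.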

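When $c(x, y) = d(x, y)$ the following Kantorovich-Rubinstein duality holds:

$$W_1(P_X, P_G) = \sup_{f \in \mathcal{F}_L} \mathbb{E}_{X \sim P_X}[f(X)] - \mathbb{E}_{Y \sim P_G}[f(Y)], \tag{2}$$

where $\mathcal{F}_L$ is the class of all bounded 1-Lipschitz functions on $(\mathcal{X}, d)$.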
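
The duality can be sanity-checked numerically in one dimension, where $W_1$ between two equally sized samples is obtained by matching sorted points. Every 1-Lipschitz $f$ gives a lower bound $\mathbb{E}_{P_X}[f(X)] - \mathbb{E}_{P_G}[f(Y)] \le W_1(P_X, P_G)$, and for a pure shift $Y = X + \theta$ with $\theta > 0$ the function $f(x) = -x$ (bounded on the finite support of the samples) attains it. The Gaussian samples and the test functions below are arbitrary choices for illustration.

```python
import math
import random

def w1_1d(xs, ys):
    """W_1 between two 1-D empirical distributions with equal sample sizes (sorted matching)."""
    assert len(xs) == len(ys)
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

def dual_objective(f, xs, ys):
    # Sample estimate of E_{P_X}[f(X)] - E_{P_G}[f(Y)].
    return sum(map(f, xs)) / len(xs) - sum(map(f, ys)) / len(ys)

rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(500)]
ys = [rng.gauss(0.5, 2.0) for _ in range(500)]

# A few 1-Lipschitz test functions: each can only give a lower bound on W_1.
lipschitz_fns = [math.sin, abs, lambda t: -t, lambda t: max(-1.0, min(1.0, t))]
w1 = w1_1d(xs, ys)
print(w1, [dual_objective(f, xs, ys) for f in lipschitz_fns])

# For a pure shift, f(x) = -x closes the gap.
theta = 0.7
shifted = [x + theta for x in xs]
print(w1_1d(xs, shifted), dual_objective(lambda t: -t, xs, shifted))
```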

2.3 Application to generative models: Wasserstein auto-encoders

One way to look at modern generative models like VAEs and GANs is to postulate that they are trying to minimize certain discrepancy measures between the data distribution $P_X$ and the model $P_G$. Unfortunately, most of the standard divergences known in the literature, including those listed above, are hard or even impossible to compute, especially when $P_X$ is unknown and $P_G$ is parametrized by deep neural networks. Previous research provides several tricks to address this issue.

In the case of minimizing the KL-divergence $D_{KL}(P_X, P_G)$, or equivalently maximizing the marginal log-likelihood $\mathbb{E}_{P_X}[\log p_G(X)]$, the famous variational lower bound provides a theoretically grounded framework successfully employed by VAEs [1]. More generally, if the goal is to minimize the $f$-divergence $D_f(P_X, P_G)$ (with one example being $D_{KL}$), one can resort to its dual formulation and make use of $f$-GANs and the adversarial training [6]. Finally, the OT cost $W_c(P_X, P_G)$ is yet another option, which can be, thanks to the celebrated Kantorovich-Rubinstein duality, expressed as an adversarial objective as implemented by the Wasserstein-GAN [4].

In this work we will focus on latent variable models $P_G$ defined by a two-step procedure, where first a code $Z$ is sampled from a fixed distribution $P_Z$ on a latent space $\mathcal{Z}$ and then $Z$ is mapped to the image $X \in \mathcal{X} = \mathbb{R}^d$ with a (possibly random) transformation. This results in a density of the form

$$p_G(x) := \int_{\mathcal{Z}} p_G(x|z)\, p_z(z)\, dz, \quad \forall x \in \mathcal{X}, \tag{3}$$

assuming all involved densities are properly defined. For simplicity we will focus on non-random decoders, i.e. generative models $P_G(X|Z)$ deterministically mapping $Z$ to $X = G(Z)$ for a given map $G\colon \mathcal{Z} \to \mathcal{X}$. Similar results for random decoders can be found in [11].
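
The two-step procedure amounts to ancestral sampling. The sketch below assumes an isotropic Gaussian prior $P_Z = \mathcal{N}(0, \sigma_z^2 I)$ (the choice made in the experiments of Section 4) and uses a hypothetical fixed affine map `toy_decoder` as a stand-in for the neural-network decoder $G$.

```python
import random

def sample_from_model(decode, d_z, n, sigma_z=1.0, seed=0):
    """Draw n samples X = G(Z) with Z ~ N(0, sigma_z^2 I_{d_z}) (deterministic decoder)."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        z = [rng.gauss(0.0, sigma_z) for _ in range(d_z)]  # step 1: Z ~ P_Z
        samples.append(decode(z))                           # step 2: X = G(Z)
    return samples

# Hypothetical decoder G: R^2 -> R^3 (a fixed affine map instead of a neural network).
def toy_decoder(z):
    return [2.0 * z[0] + 1.0, z[0] - z[1], 0.5 * z[1]]

xs = sample_from_model(toy_decoder, d_z=2, n=5)
print(xs)
```

Because this $G$ maps a 2-dimensional $\mathcal{Z}$ into a 3-dimensional $\mathcal{X}$, every sample lies on a fixed 2-dimensional plane in $\mathcal{X}$: a small instance of a model supported on a low-dimensional set, for which $p_G$ does not exist as a density on $\mathcal{X}$.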

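It turns out that under this model, the OT cost takes a simpler form as the transportation plan factors through the map $G$: instead of finding a coupling $\Gamma$ between two random variables living in the $\mathcal{X}$ space, one distributed according to $P_X$ and the other one according to $P_G$, it is sufficient to find a conditional distribution $Q(Z|X)$ such that its $Z$ marginal $Q_Z(Z) := \mathbb{E}_{X \sim P_X}[Q(Z|X)]$ is identical to the prior distribution $P_Z$. This is the content of the theorem below, proved in [11]:

Theorem 1. For $P_G$ as defined above with deterministic $P_G(X|Z)$ and any function $G\colon \mathcal{Z} \to \mathcal{X}$ we have

$$\inf_{\Gamma \in \mathcal{P}(X \sim P_X,\, Y \sim P_G)} \mathbb{E}_{(X, Y) \sim \Gamma}\left[c(X, Y)\right] = \inf_{Q\colon Q_Z = P_Z} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}\left[c\big(X, G(Z)\big)\right],$$

where $Q_Z$ is the marginal distribution of $Z$ when $X \sim P_X$ and $Z \sim Q(Z|X)$.
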
This result allows us to optimize over random encoders $Q(Z|X)$ instead of optimizing over all couplings between $X$ and $Y$. Of course, both problems are still constrained. In order to implement a numerical solution we relax the constraint on $Q_Z$ by adding a penalty to the objective. This finally leads us to the WAE objective:

$$D_{\mathrm{WAE}}(P_X, P_G) := \inf_{Q(Z|X) \in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}\left[c\big(X, G(Z)\big)\right] + \lambda \cdot \mathcal{D}_Z(Q_Z, P_Z), \tag{4}$$

where $\mathcal{Q}$ is any nonparametric set of probabilistic encoders, $\mathcal{D}_Z$ is an arbitrary divergence between $Q_Z$ and $P_Z$, and $\lambda > 0$ is a hyperparameter. Similarly to VAE, we propose to use deep neural networks to parametrize both encoders $Q$ and decoders $G$. Note that, as opposed to VAEs, the WAE formulation allows for non-random encoders deterministically mapping inputs to their latent codes.
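
For a deterministic encoder the inner expectation over $Q(Z|X)$ collapses to a single code per example, so a minibatch estimate of the WAE objective is an average reconstruction cost plus $\lambda$ times a penalty computed on the batch of codes and a batch of prior samples. The sketch below assumes the squared cost and uses a deliberately crude placeholder `mean_matching_penalty` for $\mathcal{D}_Z$ (it only compares means, so it is not a proper divergence); all function names, the toy encoder and decoder, and the value of `lam` are hypothetical.

```python
def squared_cost(x, y):
    # c(x, y) = ||x - y||_2^2
    return sum((a - b) ** 2 for a, b in zip(x, y))

def mean_matching_penalty(codes, prior_samples):
    # Placeholder D_Z: squared distance between sample means (a squared MMD with a
    # linear kernel). It ignores everything beyond first moments.
    d = len(codes[0])
    mean_q = [sum(z[k] for z in codes) / len(codes) for k in range(d)]
    mean_p = [sum(z[k] for z in prior_samples) / len(prior_samples) for k in range(d)]
    return squared_cost(mean_q, mean_p)

def wae_objective(xs, encode, decode, cost, penalty, prior_samples, lam):
    """Minibatch estimate of E_{P_X}[c(X, G(Q(X)))] + lam * D_Z(Q_Z, P_Z), deterministic encoder."""
    codes = [encode(x) for x in xs]  # samples from Q_Z
    reconstruction = sum(cost(x, decode(z)) for x, z in zip(xs, codes)) / len(xs)
    return reconstruction + lam * penalty(codes, prior_samples)

# Hypothetical toy encoder R^2 -> R^1 and decoder R^1 -> R^2.
encode = lambda x: [x[0] - 2.0]
decode = lambda z: [z[0] + 2.0, 1.0]
xs = [[1.0, 2.0], [3.0, 0.0]]
prior = [[0.5], [-0.5]]
print(wae_objective(xs, encode, decode, squared_cost, mean_matching_penalty, prior, lam=2.0))
```

In training, this quantity would be minimized jointly over the parameters of the encoder and the decoder; the first term alone would give a plain auto-encoder, and the second pulls $Q_Z$ toward $P_Z$.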

We propose two different penalties $\mathcal{D}_Z(Q_Z, P_Z)$:

GAN-based $\mathcal{D}_Z$. The first option is to choose $\mathcal{D}_Z(Q_Z, P_Z) = D_{JS}(Q_Z, P_Z)$ and use the adversarial training to estimate it. Specifically, we introduce an adversary (discriminator) in the latent space $\mathcal{Z}$ trying to separate “true” points sampled from $P_Z$ and “fake” ones sampled from $Q_Z$ [3]. This results in the WAE-GAN algorithm. Even though WAE-GAN falls back to the min-max problem, we move the adversary from the input (pixel) space $\mathcal{X}$ to the latent space $\mathcal{Z}$. On top of that, $P_Z$ may have a nice shape with a single mode (for a Gaussian prior), in which case the task should be easier than matching an unknown, complex, and possibly multi-modal distribution as usually done in GANs. This is also a reason for our second penalty.
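
Before turning to the second penalty, the two WAE-GAN losses on a minibatch can be sketched as follows. This is our illustration assuming the standard GAN log-loss of [3], equal batch sizes for codes and prior samples, and a hypothetical linear-logistic adversary $D(z) = \mathrm{sigmoid}(w^\top z + b)$; in practice $D$ is a neural network and the two losses are optimized in alternation with stochastic gradients. Here `codes` are encoded training points (samples from $Q_Z$) and `prior_samples` are draws from $P_Z$.

```python
import math

def discriminator(z, w, b):
    """Hypothetical adversary D(z) = sigmoid(w . z + b): probability that z came from P_Z."""
    logit = sum(wi * zi for wi, zi in zip(w, z)) + b
    return 1.0 / (1.0 + math.exp(-logit))

def wae_gan_losses(xs, codes, prior_samples, decode, cost, w, b, lam):
    n = len(xs)
    # Adversary objective, to be MAXIMIZED over (w, b): label prior draws "true", codes "fake".
    adversary_objective = lam / n * sum(
        math.log(discriminator(z_p, w, b)) + math.log(1.0 - discriminator(z_q, w, b))
        for z_p, z_q in zip(prior_samples, codes)
    )
    # Encoder/decoder loss, to be MINIMIZED: reconstruction plus a term that rewards
    # codes the adversary mistakes for prior samples.
    autoencoder_loss = sum(
        cost(x, decode(z)) - lam * math.log(discriminator(z, w, b))
        for x, z in zip(xs, codes)
    ) / n
    return adversary_objective, autoencoder_loss

squared = lambda x, y: sum((a - c) ** 2 for a, c in zip(x, y))
xs = [[0.0], [1.0]]
codes = [[0.0], [1.0]]   # toy codes; with the identity decoder below, reconstruction is perfect
prior = [[0.3], [-0.2]]
print(wae_gan_losses(xs, codes, prior, lambda z: z, squared, w=[0.0], b=0.0, lam=1.0))
```

With $w = 0$ and $b = 0$ the adversary outputs $1/2$ everywhere, so its objective is $2\lambda \log \tfrac{1}{2}$ and the auto-encoder loss reduces to $\lambda \log 2$ on top of the reconstruction term.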

MMD-based $\mathcal{D}_Z$. For a positive-definite reproducing kernel $k\colon \mathcal{Z} \times \mathcal{Z} \to \mathbb{R}$ the following expression is called the maximum mean discrepancy (MMD):

$$\mathrm{MMD}_k(P_Z, Q_Z) = \left\| \int_{\mathcal{Z}} k(z, \cdot)\, dP_Z(z) - \int_{\mathcal{Z}} k(z, \cdot)\, dQ_Z(z) \right\|_{\mathcal{H}_k},$$

where $\mathcal{H}_k$ is the RKHS of real-valued functions mapping $\mathcal{Z}$ to $\mathbb{R}$. If $k$ is characteristic then $\mathrm{MMD}_k$ defines a metric and can be used as a divergence measure. We propose to use $\mathcal{D}_Z(P_Z, Q_Z) = \mathrm{MMD}_k(P_Z, Q_Z)$. Fortunately, the squared MMD has an unbiased U-statistic estimator, which can be used in conjunction with stochastic gradient descent (SGD) methods. This results in the WAE-MMD algorithm. It is well known that the maximum mean discrepancy performs well when matching high-dimensional standard normal distributions [8], so we expect this penalty to work especially well with the Gaussian prior $P_Z$.
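
The unbiased U-statistic estimates $\mathrm{MMD}_k^2$: it averages $k$ over pairs of distinct prior samples, over pairs of distinct codes, and subtracts twice the average cross-kernel value. The sketch below is a plain-Python illustration using the inverse multiquadric kernel adopted in Section 4, with $C = 2 d_z \sigma_z^2$ for $\sigma_z = 1$; the 2-dimensional latent space, the batch size of 200 and all names are illustrative assumptions.

```python
import random

def imq_kernel(a, b, c=1.0):
    # Inverse multiquadric kernel k(a, b) = C / (C + ||a - b||^2).
    return c / (c + sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def mmd2_unbiased(zs_p, zs_q, kernel):
    """Unbiased U-statistic estimate of MMD_k^2(P_Z, Q_Z) from samples zs_p ~ P_Z, zs_q ~ Q_Z."""
    n, m = len(zs_p), len(zs_q)
    # Diagonal terms k(z_i, z_i) are excluded from the within-sample averages; this is
    # what makes the estimator unbiased (it can therefore be slightly negative).
    within_p = sum(kernel(zs_p[i], zs_p[j]) for i in range(n) for j in range(n) if i != j)
    within_q = sum(kernel(zs_q[i], zs_q[j]) for i in range(m) for j in range(m) if i != j)
    cross = sum(kernel(a, b) for a in zs_p for b in zs_q)
    return within_p / (n * (n - 1)) + within_q / (m * (m - 1)) - 2.0 * cross / (n * m)

rng = random.Random(0)
d_z = 2

def gaussian_batch(n, shift=0.0):
    return [[rng.gauss(shift, 1.0) for _ in range(d_z)] for _ in range(n)]

kernel = lambda a, b: imq_kernel(a, b, c=2.0 * d_z)  # C = 2 * d_z * sigma_z^2 with sigma_z = 1
prior = gaussian_batch(200)
print(mmd2_unbiased(prior, gaussian_batch(200), kernel))             # same distribution
print(mmd2_unbiased(prior, gaussian_batch(200, shift=3.0), kernel))  # codes shifted away from the prior
```

The first value fluctuates around zero, while the second is clearly positive. In WAE-MMD this estimate, multiplied by $\lambda$, is added to the reconstruction loss and differentiated through the codes with respect to the encoder parameters, so no adversary is needed.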

3 Related work

Literature on auto-encoders. Classical unregularized auto-encoders minimize only the reconstruction cost. This results in different training points being encoded into non-overlapping zones chaotically scattered all across the $\mathcal{Z}$ space with “holes” in between where the decoder mapping has never been trained. Overall, the encoder trained in this way does not provide a useful representation and sampling from the latent space becomes hard [12].

Variational auto-encoders [1] minimize a variational bound on the KL-divergence $D_{KL}(P_X, P_G)$, which is composed of the reconstruction cost plus the regularizer $\mathbb{E}_{P_X}[D_{KL}(Q(Z|X), P_Z)]$. This regularizer captures how distinct the image by the encoder of each training example is from the prior $P_Z$, which does not guarantee that the overall encoded distribution $\mathbb{E}_{P_X}[Q(Z|X)]$ matches $P_Z$ like WAE does. Also, VAEs require non-degenerate Gaussian encoders and random decoders for which the term $\log p_G(x|z)$ can be computed and differentiated with respect to the parameters. Later [10] proposed a way to use VAE with non-Gaussian encoders. WAE minimizes the OT cost $W_c(P_X, P_G)$ and allows both probabilistic and deterministic encoder-decoder pairs of any kind.
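
When, as in the experiments of Section 4, $Q(Z|X=x) = \mathcal{N}(\mu(x), \mathrm{diag}(\sigma^2(x)))$ and $P_Z = \mathcal{N}(0, I)$, the per-example VAE regularizer has the closed form $\tfrac{1}{2}\sum_k \big(\sigma_k^2 + \mu_k^2 - 1 - \log \sigma_k^2\big)$. The sketch below (names are ours) evaluates it and shows why VAEs need non-degenerate encoders: as $\sigma \to 0$, i.e. as the encoder becomes deterministic, the term grows without bound, whereas the WAE penalty only constrains the aggregate $Q_Z$.

```python
import math

def kl_diag_gaussian_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) in closed form."""
    return 0.5 * sum(s * s + m * m - 1.0 - 2.0 * math.log(s) for m, s in zip(mu, sigma))

print(kl_diag_gaussian_to_standard_normal([0.0, 0.0], [1.0, 1.0]))    # encoder output equals the prior
print(kl_diag_gaussian_to_standard_normal([1.0, 0.0], [1.0, 1.0]))    # shifted mean
print(kl_diag_gaussian_to_standard_normal([0.0, 0.0], [1e-3, 1e-3]))  # nearly deterministic encoder
```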

When used with $c(x, y) = \|x - y\|_2^2$, WAE-GAN is equivalent to adversarial auto-encoders (AAE) proposed by [2]. Our theory thus suggests that AAEs minimize the 2-Wasserstein distance between $P_X$ and $P_G$. This provides the first theoretical justification for AAEs known to the authors. WAE generalizes AAE in two ways: first, it can use any cost function $c$ in the input space $\mathcal{X}$; second, it can use any discrepancy measure $\mathcal{D}_Z$ in the latent space $\mathcal{Z}$ (for instance MMD), not necessarily the adversarial one of WAE-GAN.

Literature on OT. [13] address computing the OT cost at large scale using SGD and sampling. They approach this task either through the dual formulation, or via a regularized version of the primal. They do not discuss any implications for generative modeling. Our approach is based on the primal form of OT; we arrive at regularizers which are very different, and our main focus is on generative modeling.

The WGAN [4] minimizes the 1-Wasserstein distance $W_1(P_X, P_G)$ for generative modeling. The authors approach this task from the dual form. Their algorithm comes without an encoder and cannot be readily applied to any other cost $W_c$, because the neat form of the Kantorovich-Rubinstein duality holds only for $W_1$. WAE approaches the same problem from the primal form, can be applied for any cost function $c$, and comes naturally with an encoder.

In order to compute the value of OT, either through the primal (Kantorovich) formulation or through the Kantorovich-Rubinstein dual, we need to handle non-trivial constraints, either on the coupling distribution $\Gamma$ or on the function $f$ being considered. Various approaches have been proposed in the literature to circumvent this difficulty. For $W_1$, [4] tried to implement the constraint in the dual formulation by clipping the weights of the neural network $f$. Later [7] proposed to relax the same constraint by penalizing the dual objective with a term $\lambda \cdot \mathbb{E}\,(\|\nabla f(X)\| - 1)^2$, motivated by the fact that $\|\nabla f(X)\|$ should not be greater than 1 if $f \in \mathcal{F}_L$. In a more general OT setting of $W_c$, [14] proposed to penalize the primal objective with the KL-divergence $\lambda \cdot D_{KL}(\Gamma, P \otimes Q)$ between the coupling distribution and the product of marginals. [13] showed that this entropic regularization drops the constraints on functions in the dual formulation, as opposed to the Kantorovich-Rubinstein dual. Finally, in the context of unbalanced optimal transport it has been proposed to relax the marginal constraint in the primal problem by regularizing the objective with $\lambda \cdot \big(D_f(\Gamma_X, P) + D_f(\Gamma_Y, Q)\big)$ [15], where $\Gamma_X$ and $\Gamma_Y$ are marginals of $\Gamma$. In this paper we propose to relax OT in a way similar to the unbalanced optimal transport, i.e. by adding additional divergences to the objective. However, we show that in the particular context of generative modeling, only one extra divergence is necessary.
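
The entropic regularization of [14] is usually solved with Sinkhorn’s matrix-scaling iterations. The sketch below is a textbook version for two discrete distributions `a` and `b` and a cost matrix; it is not part of WAE and is included only to illustrate the kind of regularized primal problem discussed above. Because the marginals are fixed, penalizing with $D_{KL}(\Gamma, a \otimes b)$ and with the negative entropy of $\Gamma$ leads to the same minimizer, which has the form $\mathrm{diag}(u)\,K\,\mathrm{diag}(v)$ with $K = e^{-C/\varepsilon}$; the regularization strength `eps` plays the role of $\lambda$. Names and iteration counts are our choices.

```python
import math

def sinkhorn(a, b, cost, eps=0.1, n_iters=1000):
    """Entropy-regularized OT between discrete distributions a (length n) and b (length m).

    Returns the coupling Gamma and its transport cost sum_ij Gamma_ij * cost_ij.
    """
    n, m = len(a), len(b)
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]  # Gibbs kernel
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns so that the marginals match a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    plan = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    transport_cost = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return plan, transport_cost

points_x, points_y = [0.0, 1.0], [0.0, 1.0]
cost = [[(x - y) ** 2 for y in points_y] for x in points_x]
plan_sharp, c_sharp = sinkhorn([0.5, 0.5], [0.5, 0.5], cost, eps=0.05)
plan_blurry, c_blurry = sinkhorn([0.5, 0.5], [0.5, 0.5], cost, eps=100.0)
print(plan_sharp, c_sharp)
print(plan_blurry, c_blurry)
```

A small `eps` recovers the unregularized optimal coupling (here, mass stays in place and the cost is nearly zero), while a large `eps` pushes the coupling toward the product of marginals, whose entries are all close to $1/4$.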

Figures 3 to 11: VAE (left column; Figures 3 to 5), WAE-MMD (middle column; Figures 6 to 8), and WAE-GAN (right column; Figures 9 to 11) trained on the MNIST dataset. In “test reconstructions” odd rows correspond to the real test points.

Literature on GANs. Many of the GAN variations (including $f$-GAN and WGAN) come without an encoder. Often it may be desirable to reconstruct the latent codes and use the learned manifold, in which case these models are not applicable.

There have been many other approaches trying to blend the adversarial training of GANs with auto-encoder architectures [17]. The approach proposed by [19] is perhaps the most relevant to our work. The authors use the discrepancy between $P_Z$ and the distribution $\mathbb{E}_{P_Z}[Q(Z|G(Z))]$ of auto-encoded noise vectors as the objective for the max-min game between the encoder and decoder respectively. While the authors showed that the saddle points correspond to $P_X = P_G$, they admit that encoders and decoders trained in this way have no incentive to be reciprocal. As a workaround they propose to include an additional reconstruction term to the objective. WAE does not necessarily lead to a min-max game, uses a different penalty, and has a clear theoretical foundation.

Several works used reproducing kernels in the context of GANs. [21] use MMD with a fixed kernel $k$ to match $P_X$ and $P_G$ directly in the input space $\mathcal{X}$. These methods have been criticised for requiring larger mini-batches during training: estimating $\mathrm{MMD}_k(P_X, P_G)$ requires a number of samples roughly proportional to the dimensionality of the input space $\mathcal{X}$ [23], which is typically large. [24] take a similar approach but further train $k$ adversarially so as to arrive at a meaningful loss function. WAE-MMD uses MMD to match $Q_Z$ to the prior $P_Z$ in the latent space $\mathcal{Z}$. Typically $\mathcal{Z}$ has far fewer dimensions than $\mathcal{X}$ and $P_Z$ is Gaussian, which allows us to use regular mini-batch sizes to accurately estimate MMD.

Figures 12 to 20: VAE (left column; Figures 12 to 14), WAE-MMD (middle column; Figures 15 to 17), and WAE-GAN (right column; Figures 18 to 20) trained on the CelebA dataset. In “test reconstructions” odd rows correspond to the real test points.

4 Experiments

In this section we empirically evaluate the proposed WAE model. We would like to test if WAE can simultaneously achieve (i) accurate reconstructions of data points, (ii) reasonable geometry of the latent manifold, and (iii) random samples of good (visual) quality. Importantly, the model should generalize well: requirements (i) and (ii) should be met on both training and test data. We trained WAE-GAN and WAE-MMD (Algorithms ? and ?) on two real-world datasets: MNIST [25] consisting of 70k images and CelebA [26] containing roughly 203k images.

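Experimental setup. In all reported experiments we used Euclidean latent spaces $\mathcal{Z} = \mathbb{R}^{d_z}$ for various $d_z$ depending on the complexity of the dataset, isotropic Gaussian prior distributions $P_Z = \mathcal{N}(\mathbf{0}, \sigma_z^2 \mathbf{I}_{d_z})$ over $\mathcal{Z}$, and the squared cost function $c(x, y) = \|x - y\|_2^2$ for data points $x, y \in \mathcal{X}$. We used deterministic encoder-decoder pairs, Adam [27], and convolutional deep neural network architectures for the encoder mapping $\mu_\phi\colon \mathcal{X} \to \mathcal{Z}$ and the decoder mapping $G_\theta\colon \mathcal{Z} \to \mathcal{X}$, similar to the DCGAN architectures reported by [28], with batch normalization [29]. We tried various values of the regularization coefficient $\lambda$ and found that a single value seems to work well across all datasets we considered; all reported experiments use this value.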
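A minimal sketch of the penalized objective these settings plug into: the average squared reconstruction cost over a minibatch plus the regularization coefficient times a penalty on the encoded codes. The names `encode`, `decode`, `latent_penalty`, and `lam` are hypothetical stand-ins (this is not the authors' code): `latent_penalty` represents whichever divergence between the encoded codes and the prior is used (adversarial in WAE-GAN, MMD in WAE-MMD), and `lam` plays the role of the coefficient tuned above.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def squared_cost(x: Sequence[float], y: Sequence[float]) -> float:
    """Squared Euclidean cost c(x, y) = ||x - y||_2^2."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def wae_minibatch_objective(
    xs: Sequence[Vector],
    encode: Callable[[Vector], Vector],  # hypothetical deterministic encoder
    decode: Callable[[Vector], Vector],  # hypothetical decoder
    latent_penalty: Callable[[List[Vector]], float],  # hypothetical latent divergence estimate
    lam: float,  # regularization coefficient (value not given here)
) -> float:
    """Average reconstruction cost plus lam times a penalty on the encoded minibatch."""
    codes = [encode(x) for x in xs]
    reconstruction = sum(squared_cost(x, decode(z)) for x, z in zip(xs, codes)) / len(xs)
    return reconstruction + lam * latent_penalty(codes)
```

In a real implementation `encode` and `decode` would be differentiable networks and the objective would be minimized with Adam; plain Python only shows how the terms combine.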

Since we are using deterministic encoders, choosing $d_z$ larger than the intrinsic dimensionality of the dataset would force the encoded distribution $Q_Z$ to live on a lower-dimensional manifold in $\mathcal{Z}$. This would make matching $Q_Z$ to $P_Z$ impossible if $P_Z$ is Gaussian (a Gaussian has full support on $\mathcal{Z}$) and may lead to numerical instabilities. We chose $d_z$ separately for MNIST and for CelebA, and these choices seem to work reasonably well.

We also report results for VAEs. The VAEs used the same latent spaces as discussed above and standard Gaussian priors $P_Z = \mathcal{N}(\mathbf{0}, \mathbf{I})$. We used Gaussian encoders $Q(Z|X) = \mathcal{N}\big(Z; \mu_\phi(X), \Sigma(X)\big)$ with mean $\mu_\phi$ and diagonal covariance $\Sigma$. For MNIST we used Bernoulli decoders parametrized by $G_\theta$, and for CelebA Gaussian decoders with mean $G_\theta(Z)$. The functions $\mu_\phi$, $\Sigma$, and $G_\theta$ were parametrized by deep nets with the same architectures as used in WAE.
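For comparison, a VAE baseline with a diagonal-covariance Gaussian encoder is usually trained with the reparameterization trick and a closed-form KL term against the standard Gaussian prior, and an MNIST-style Bernoulli decoder contributes a Bernoulli log-likelihood. The sketch below assumes the encoder outputs a mean and a log-variance per latent coordinate (a common parametrization that the text does not state); all function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], log_var: Sequence[float], rng: random.Random) -> List[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and covariance diag(exp(log_var))."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I))."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


def bernoulli_log_likelihood(x: Sequence[float], probs: Sequence[float], eps: float = 1e-7) -> float:
    """log p(x | z) for a Bernoulli decoder whose output probabilities are `probs`."""
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += xi * math.log(p) + (1.0 - xi) * math.log(1.0 - p)
    return total
```

The negative evidence lower bound for one data point is then the KL term minus the decoder log-likelihood evaluated at a reparameterized sample.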

WAE-GAN and WAE-MMD specifics. In WAE-GAN we used a discriminator composed of several fully connected layers with ReLU activations. We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of $Q_Z$ because of the kernel's fast tail decay. If the codes $\mu_\phi(x)$ of some training points end up far away from the support of $P_Z$ (which may happen in the early stages of training), the corresponding terms in the U-statistic quickly approach zero and provide no gradient for those outliers. This could be avoided by choosing the kernel bandwidth in a data-dependent manner; however, in that case the per-minibatch U-statistic would no longer provide an unbiased estimate of the gradient. Instead, we used the inverse multiquadratics kernel $k(x, y) = C / (C + \|x - y\|_2^2)$, which is also characteristic and has much heavier tails. In all experiments we used $C = 2 d_z \sigma_z^2$, which is the expected squared distance between two independent multivariate Gaussian vectors drawn from $P_Z$. This significantly improved the performance compared to the RBF kernel (even when the RBF bandwidth is set to the same value). Trained models are presented in Figures 11 and 20. Further details are presented in Supplementary Section 6.
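The WAE-MMD penalty and the kernel comparison above can be sketched in plain Python. The estimator below averages the within-sample kernel values over distinct pairs and the cross term over all pairs, which keeps it unbiased for the squared MMD. This is a hypothetical illustration, not the authors' implementation: the function names are assumptions, and the prior is taken to be an isotropic Gaussian with per-coordinate variance `sigma_z2` in `d_z` dimensions. A practical version would be written in an autodiff framework so that gradients reach the encoder.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = Sequence[float]
Kernel = Callable[[Vector, Vector], float]


def sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def imq_kernel(c: float) -> Kernel:
    """Inverse multiquadratics kernel k(a, b) = C / (C + ||a - b||^2)."""
    return lambda a, b: c / (c + sq_dist(a, b))


def rbf_kernel(sigma2: float) -> Kernel:
    """RBF kernel k(a, b) = exp(-||a - b||^2 / sigma2)."""
    return lambda a, b: math.exp(-sq_dist(a, b) / sigma2)


def mmd_estimate(prior_codes: List[Vector], encoded_codes: List[Vector], k: Kernel) -> float:
    """Minibatch estimate of MMD^2 between prior samples and encoded training points."""
    n = len(prior_codes)
    if n < 2 or len(encoded_codes) != n:
        raise ValueError("need two equally sized samples with at least 2 points each")
    pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
    within_prior = sum(k(prior_codes[i], prior_codes[j]) for i, j in pairs)
    within_encoded = sum(k(encoded_codes[i], encoded_codes[j]) for i, j in pairs)
    cross = sum(k(p, q) for p in prior_codes for q in encoded_codes)
    return (within_prior + within_encoded) / (n * (n - 1)) - 2.0 * cross / (n * n)


def sample_prior(n: int, d_z: int, sigma_z2: float, rng: random.Random) -> List[List[float]]:
    return [[rng.gauss(0.0, math.sqrt(sigma_z2)) for _ in range(d_z)] for _ in range(n)]


# Kernel scale from the text: C = 2 * d_z * sigma_z2 = E||Z - Z'||^2 for Z, Z' drawn from the prior.
d_z, sigma_z2 = 8, 1.0
imq = imq_kernel(2 * d_z * sigma_z2)
rbf = rbf_kernel(2 * d_z * sigma_z2)

# A code stuck far from the prior: the RBF value (and hence its gradient) is essentially zero,
# while the IMQ value is still clearly nonzero.
origin, outlier = [0.0] * d_z, [10.0] * d_z
print(rbf(origin, outlier), imq(origin, outlier))  # roughly 1.9e-22 and 0.0196
```

Here the squared distance to the outlier is 800 while the kernel scale is 16, so the RBF term is $e^{-50}$ whereas the IMQ term decays only like $C / \|x - y\|_2^2$.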

Random samples are generated by sampling $z \sim P_Z$ and decoding the resulting noise vectors into $G_\theta(z)$. As expected, in our experiments we observed that for both WAE-GAN and WAE-MMD the quality of samples strongly depends on how accurately $Q_Z$ matches $P_Z$. To see why, note that during training the decoder $G_\theta$ is presented only with encoded versions $\mu_\phi(X)$ of the data points $X \sim P_X$. In other words, the decoder is trained on samples from $Q_Z$, so there is no reason to expect good results when feeding it samples from $P_Z$. In our experiments we noticed that even slight differences between $Q_Z$ and $P_Z$ may affect the quality of samples.
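Sampling from a trained model, together with a crude check of how far a set of encoded codes is from the prior, might look as follows. `decode` is a hypothetical stand-in for the trained decoder, and `moment_gap` is only an illustrative diagnostic that is not the criterion used in the paper: matching per-coordinate means and variances is necessary for the encoded distribution to match an isotropic Gaussian prior, but far from sufficient.

```python
import math
import random
from typing import Callable, List, Sequence


def sample_prior(n: int, d_z: int, sigma_z2: float, rng: random.Random) -> List[List[float]]:
    """Draw n codes from the isotropic Gaussian prior N(0, sigma_z2 * I)."""
    sigma = math.sqrt(sigma_z2)
    return [[rng.gauss(0.0, sigma) for _ in range(d_z)] for _ in range(n)]


def generate(decode: Callable[[List[float]], List[float]], n: int, d_z: int,
             sigma_z2: float, seed: int = 0) -> List[List[float]]:
    """Random samples: draw noise vectors from the prior and decode them."""
    rng = random.Random(seed)
    return [decode(z) for z in sample_prior(n, d_z, sigma_z2, rng)]


def moment_gap(codes: Sequence[Sequence[float]], sigma_z2: float) -> float:
    """Largest per-coordinate deviation of the codes' mean from 0 or variance from sigma_z2."""
    n, d = len(codes), len(codes[0])
    gap = 0.0
    for j in range(d):
        column = [c[j] for c in codes]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n
        gap = max(gap, abs(mean), abs(var - sigma_z2))
    return gap
```

A large `moment_gap` on encoded training codes would signal exactly the mismatch described above, in which the decoder is queried in regions of latent space it never saw during training.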

In some cases WAE-GAN seems to achieve a better match between the encoded distribution and the prior and generates better samples than WAE-MMD. However, due to adversarial training WAE-GAN is highly unstable, while WAE-MMD trains very stably, much like VAE.

In order to quantitatively assess the quality of the generated images, we use the Fréchet Inception Distance (FID) introduced by [30] and report the results on CelebA in Table ?. These results confirm that the images sampled from WAE are of better quality than those sampled from VAE, and that WAE-GAN gets a slightly better score than WAE-MMD, which correlates with visual inspection of the images.
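For reference, FID fits a Gaussian to Inception-network activations of real images (mean $\mu_r$, covariance $\Sigma_r$) and of generated images (mean $\mu_g$, covariance $\Sigma_g$) and compares the two fits as below; this is the standard definition from [30], and lower values indicate that the statistics of the generated images are closer to those of the real ones.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```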

Test reconstructions and interpolations. We take random points $x$ from the held-out test set and report their auto-encoded versions $G_\theta(\mu_\phi(x))$. Next, pairs $(x, y)$ of different data points are sampled randomly from the held-out test set and encoded: $z_x = \mu_\phi(x)$, $z_y = \mu_\phi(y)$. We linearly interpolate between $z_x$ and $z_y$ with equally sized steps in the latent space and show the decoded images.
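The reconstruction and interpolation procedure can be sketched as follows, with `encode` and `decode` as hypothetical stand-ins for the trained deterministic encoder and decoder. The number of interpolation steps is an assumed parameter, since the text does not fix it.

```python
from typing import Callable, List, Sequence

Code = List[float]


def reconstruct(x: Code, encode: Callable[[Code], Code], decode: Callable[[Code], Code]) -> Code:
    """Auto-encoded version of a test point: decode(encode(x))."""
    return decode(encode(x))


def linear_interpolation(z_x: Sequence[float], z_y: Sequence[float], num_steps: int) -> List[Code]:
    """num_steps equally spaced latent points from z_x to z_y, endpoints included."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    ts = [i / (num_steps - 1) for i in range(num_steps)]
    return [[(1.0 - t) * a + t * b for a, b in zip(z_x, z_y)] for t in ts]


def interpolate_pair(x: Code, y: Code, encode: Callable[[Code], Code],
                     decode: Callable[[Code], Code], num_steps: int = 8) -> List[Code]:
    """Encode both test points, interpolate linearly in latent space, decode every step."""
    z_x, z_y = encode(x), encode(y)
    return [decode(z) for z in linear_interpolation(z_x, z_y, num_steps)]
```

Because the endpoints are included, the first and last decoded images are exactly the reconstructions of $x$ and $y$.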

5 Conclusion

Using the optimal transport cost, we have derived Wasserstein auto-encoders, a new family of algorithms for building generative models. We discussed their relations to other probabilistic modeling techniques. We conducted experiments using two particular implementations of the proposed method, showing that, in comparison to VAEs, the images sampled from the trained WAE models are of better quality, without compromising the stability of training or the quality of reconstruction. Future work will include further exploration of the criteria for matching the encoded distribution $Q_Z$ to the prior distribution $P_Z$, assaying the possibility of adversarially training the cost function $c$ in the input space $\mathcal{X}$, and a theoretical analysis of the dual formulations for WAE-GAN and WAE-MMD.

Acknowledgments

The authors are thankful to Mateo Rojas-Carulla, Arthur Gretton, and Fei Sha for stimulating discussions.

6 Further details on experiments

MNIST: We use mini-batches of size 100 and 4x4 convolutional filters. The reported models were trained for 100 epochs. The learning rate of Adam was decreased twice: after 30 epochs and again after the first 50 epochs.
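The learning-rate schedule described above is a piecewise-constant step schedule with milestones at epochs 30 and 50. A minimal sketch, assuming 0-indexed epochs; the three learning rates are left as parameters rather than filled in, and the function name is hypothetical.

```python
from typing import Tuple


def step_learning_rate(epoch: int, rates: Tuple[float, float, float],
                       milestones: Tuple[int, int] = (30, 50)) -> float:
    """Piecewise-constant schedule: rates[0] before the first milestone,
    rates[1] until the second milestone, rates[2] afterwards (epochs are 0-indexed)."""
    initial, after_first, after_second = rates
    first, second = milestones
    if not 0 < first < second:
        raise ValueError("milestones must be increasing and positive")
    if epoch < first:
        return initial
    if epoch < second:
        return after_first
    return after_second
```

In PyTorch, the same shape of schedule corresponds to `torch.optim.lr_scheduler.MultiStepLR` with `milestones=[30, 50]`, but only if both decays use the same multiplicative factor `gamma`.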

CelebA: We pre-processed the images by first taking 140x140 center crops and then resizing them to 64x64 resolution. We used mini-batches of size 100 and trained the models for various numbers of epochs (up to 250). All reported WAE models were trained for 55 epochs and VAE for 35 epochs (we ran VAE for another 60 epochs and confirmed that it had converged). The initial learning rate of Adam was set to a value often recommended in the literature, decreased after 30 epochs, and decreased again after the first 50 epochs. FID scores of Table ? were computed based on samples of 10k images.
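The CelebA preprocessing step can be sketched as computing a centered 140x140 crop box and then resizing the crop to 64x64. The example image size of 178x218 assumes the aligned-and-cropped CelebA release, which the text does not specify, and the resampling filter used for resizing is likewise unspecified; `center_crop_box` is a hypothetical helper.

```python
from typing import Tuple


def center_crop_box(width: int, height: int, crop: int = 140) -> Tuple[int, int, int, int]:
    """(left, upper, right, lower) box of a crop x crop square centered in the image."""
    if crop > width or crop > height:
        raise ValueError("crop size exceeds image size")
    left = (width - crop) // 2
    upper = (height - crop) // 2
    return left, upper, left + crop, upper + crop


# Assumed aligned CelebA image size (width x height).
box = center_crop_box(178, 218)
print(box)  # (19, 39, 159, 179)
```

With Pillow, this box follows the `(left, upper, right, lower)` convention of `Image.crop`, and the cropped image can then be passed to `Image.resize((64, 64))`.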

Footnotes

  1. Note that the same symbol is used for and , but only is a number and thus the above refers to the 1-Wasserstein distance.
  2. We noticed that the famous “log trick” (also called “non-saturating loss”) proposed by [3] leads to better results.

References

  1. Auto-encoding variational Bayes.
    D. P. Kingma and M. Welling. In ICLR, 2014.
  2. Adversarial autoencoders.
    A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. In ICLR, 2016.
  3. Generative adversarial nets.
    Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. In NIPS, pages 2672–2680, 2014.
  4. Wasserstein GAN.
    M. Arjovsky, S. Chintala, and L. Bottou. 2017.
  5. Topics in Optimal Transportation.
    C. Villani. AMS Graduate Studies in Mathematics, 2003.
  6. f-GAN: Training generative neural samplers using variational divergence minimization.
    Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. In NIPS, 2016.
  7. Improved training of Wasserstein GANs.
    I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. 2017.
  8. A kernel two-sample test.
    A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. Journal of Machine Learning Research, 13:723–773, 2012.
  9. Statistical Decision Theory.
    F. Liese and K.-J. Miescke. Springer, 2008.
  10. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks.
    L. Mescheder, S. Nowozin, and A. Geiger. 2017.
  11. From optimal transport to generative modeling: the VEGAN cookbook.
    O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. 2017.
  12. Representation learning: A review and new perspectives.
    Y. Bengio, A. Courville, and P. Vincent. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35, 2013.
  13. Stochastic optimization for large-scale optimal transport.
    A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. In Advances in Neural Information Processing Systems, pages 3432–3440, 2016.
  14. Sinkhorn distances: Lightspeed computation of optimal transport.
    M. Cuturi. In Advances in Neural Information Processing Systems, pages 2292–2300, 2013.
  15. Unbalanced optimal transport: geometry and Kantorovich formulation.
    Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer, and François-Xavier Vialard. arXiv preprint arXiv:1508.05216, 2015.
  16. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures.
    Matthias Liero, Alexander Mielke, and Giuseppe Savaré. arXiv preprint arXiv:1508.07941, 2015.
  17. Energy-based generative adversarial network.
    J. Zhao, M. Mathieu, and Y. LeCun. In ICLR, 2017.
  18. Adversarially learned inference.
    V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. In ICLR, 2017.
  19. It takes (only) two: Adversarial generator-encoder networks.
    D. Ulyanov, A. Vedaldi, and V. Lempitsky. 2017.
  20. BEGAN: Boundary equilibrium generative adversarial networks.
    D. Berthelot, T. Schumm, and L. Metz. 2017.
  21. Generative moment matching networks.
    Y. Li, K. Swersky, and R. Zemel. In ICML, 2015.
  22. Training generative neural networks via maximum mean discrepancy optimization.
    G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. In UAI, 2015.
  23. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives.
    S. J. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. In AISTATS, 2015.
  24. MMD GAN: Towards deeper understanding of moment matching network.
    C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. 2017.
  25. Gradient-based learning applied to document recognition.
    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. In Proceedings of the IEEE, volume 86(11), pages 2278–2324, 1998.
  26. Deep learning face attributes in the wild.
    Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. In Proceedings of International Conference on Computer Vision (ICCV), 2015.
  27. Adam: A method for stochastic optimization.
    D. P. Kingma and J. Ba. 2014.
  28. Unsupervised representation learning with deep convolutional generative adversarial networks.
    A. Radford, L. Metz, and S. Chintala. In ICLR, 2016.
  29. Batch normalization: Accelerating deep network training by reducing internal covariate shift.
    S. Ioffe and C. Szegedy. 2015.
  30. GANs trained by a two time-scale update rule converge to a Nash equilibrium.
    Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Günter Klambauer, and Sepp Hochreiter. arXiv preprint arXiv:1706.08500, 2017.