VAE Learning via Stein Variational Gradient Descent


A new method for learning variational autoencoders (VAEs) is developed, based on Stein variational gradient descent. A key advantage of this approach is that one need not make parametric assumptions about the form of the encoder distribution. Performance is further enhanced by integrating the proposed encoder with importance sampling. Excellent performance is demonstrated across multiple unsupervised and semi-supervised problems, including semi-supervised analysis of the ImageNet data, which shows that the model scales to large datasets.


1 Introduction

There has been significant recent interest in the variational autoencoder (VAE) [11], a generalization of the original autoencoder [34]. VAEs are typically trained by maximizing a variational lower bound of the data log-likelihood [2]. To compute the variational expression, one must be able to explicitly evaluate the associated distribution of latent features, i.e., the stochastic encoder must have an explicit analytic form. This requirement has motivated the design of encoders in which a neural network maps input data to the parameters of a simple distribution, e.g., Gaussian distributions have been widely utilized [1].

The Gaussian assumption may be too restrictive in some cases [29]. Consequently, recent work has considered normalizing flows [29], in which random variables from (for example) a Gaussian distribution are fed through a series of nonlinear functions to increase the complexity and representational power of the encoder. However, because of the need to explicitly evaluate the distribution within the variational expression used when learning, these nonlinear functions must be relatively simple, e.g., planar flows. Further, one may require many layers to achieve the desired representational power.

We present a new approach for training a VAE. We recognize that the need for an explicit form for the encoder distribution is only a consequence of the fact that learning is performed based on the variational lower bound. For inference (e.g., at test time), we do not need an explicit form for the distribution of latent features; we only require fast sampling from the encoder. Consequently, rather than directly employing the traditional variational lower bound, we seek to minimize the Kullback-Leibler (KL) distance between a sample-based approximation and the true joint posterior of the model parameters and latent codes. Learning then becomes a novel application of Stein variational gradient descent (SVGD) [15], constituting its first application to training VAEs. We extend SVGD with importance sampling [1], and also demonstrate its novel use in semi-supervised VAE learning.

The concepts developed here are demonstrated on a wide range of unsupervised and semi-supervised learning problems, including a large-scale semi-supervised analysis of the ImageNet dataset. These experimental results illustrate the advantage of SVGD-based VAE training, relative to traditional approaches. Moreover, the results demonstrate further improvements realized by integrating SVGD with importance sampling.

Independent work proposed similar models for unsupervised learning tasks, combining SVGD with VAEs [3] and with importance sampling [6].

2 Stein Learning of Variational Autoencoder (Stein VAE)

2.1 Review of VAE and Motivation for Use of SVGD

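Consider data $\mathcal{D} = \{x_n\}_{n=1}^N$, where $x_n$ are modeled via decoder $x_n|z_n \sim p(x|z_n;\theta)$. A prior $p(z)$ is placed on the latent codes. To learn parameters $\theta$, one typically is interested in maximizing the empirical expected log-likelihood, $\frac{1}{N}\sum_{n=1}^N \log p(x_n;\theta)$. A variational lower bound is often employed:

$$\mathcal{L}(\theta,\phi;x) = \mathbb{E}_{z|x;\phi}\left[\log\frac{p(x|z;\theta)\,p(z)}{q(z|x;\phi)}\right] = -\mathrm{KL}\big(q(z|x;\phi)\,\|\,p(z|x;\theta)\big) + \log p(x;\theta), \tag{1}$$

with $p(z|x;\theta) = \frac{p(x|z;\theta)p(z)}{p(x;\theta)}$, and where $\mathbb{E}_{z|x;\phi}[\cdot]$ is approximated by averaging over a finite number of samples drawn from encoder $q(z|x;\phi)$. Parameters $\theta$ and $\phi$ are typically iteratively optimized via stochastic gradient descent [11], seeking to maximize $\sum_{n=1}^N \mathcal{L}(\theta,\phi;x_n)$.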
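To make (Equation 1) concrete, here is a minimal sketch of the Monte Carlo ELBO estimate for a hypothetical one-dimensional linear-Gaussian model (prior $N(0,1)$, decoder $N(\theta z, 1)$) with a Gaussian encoder sampled by reparameterization. The model, the encoder family, and all numbers are illustrative assumptions, chosen because the exact posterior and log-evidence are available in closed form: when $q$ equals the posterior, every sample of the integrand equals $\log p(x;\theta)$, so the KL term in (Equation 1) vanishes, while any other $q$ gives a smaller bound.

```python
import math
import random

def log_normal(v, mean, var):
    """Log-density of a univariate Gaussian N(mean, var) evaluated at v."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)

def elbo_estimate(x, theta, m, s, num_samples, rng):
    """Monte Carlo estimate of (Equation 1) for a hypothetical toy model:
    p(z) = N(0, 1), p(x|z; theta) = N(theta * z, 1), encoder q(z|x) = N(m, s^2)."""
    total = 0.0
    for _ in range(num_samples):
        z = m + s * rng.gauss(0.0, 1.0)              # reparameterized draw z ~ q(z|x)
        total += (log_normal(x, theta * z, 1.0)      # log p(x|z; theta)
                  + log_normal(z, 0.0, 1.0)          # log p(z)
                  - log_normal(z, m, s * s))         # -log q(z|x): needs q in closed form
    return total / num_samples

rng = random.Random(0)
theta, x = 2.0, 1.5
post_var = 1.0 / (1.0 + theta ** 2)                  # exact posterior is Gaussian here
post_mean = theta * x * post_var
log_px = log_normal(x, 0.0, 1.0 + theta ** 2)        # exact log-evidence log p(x; theta)

elbo_exact_q = elbo_estimate(x, theta, post_mean, math.sqrt(post_var), 1000, rng)
elbo_rough_q = elbo_estimate(x, theta, 0.0, 1.0, 10000, rng)
print(log_px, elbo_exact_q, elbo_rough_q)
```

The gap between `log_px` and `elbo_rough_q` is the KL term of (Equation 1). Note that the estimator needs $\log q(z|x;\phi)$ in closed form, which is exactly the requirement the Stein-based approach removes.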

To evaluate the variational expression in (Equation 1), we require the ability to sample efficiently from $q(z|x;\phi)$, to approximate the expectation. We also require a closed form for this encoder, to evaluate $\log[p(x|z;\theta)p(z)/q(z|x;\phi)]$. In the proposed VAE learning framework, rather than maximizing the variational lower bound explicitly, we focus on the term $\mathrm{KL}(q(z|x;\phi)\|p(z|x;\theta))$, which we seek to minimize. This can be achieved by leveraging Stein variational gradient descent (SVGD) [15]. Importantly, for SVGD we need only be able to sample from $q(z|x;\phi)$, and we need not possess its explicit functional form.

In the above discussion, $\theta$ is treated as a parameter; below we treat it as a random variable, as was considered in the Appendix of [11]. Treatment of $\theta$ as a random variable allows for model averaging, and a point estimate of $\theta$ is revealed as a special case of the proposed method.

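The set of codes associated with all $x_n \in \mathcal{D}$ is represented $\mathcal{Z} = \{z_n\}_{n=1}^N$. The prior on $\{\theta,\mathcal{Z}\}$ is here represented as $p(\theta)\prod_{n=1}^N p(z_n)$. We desire the posterior $p(\theta,\mathcal{Z}|\mathcal{D})$. Consider the revised variational expression

$$\mathcal{L}_1(q;\mathcal{D}) = \mathbb{E}_{q(\theta,\mathcal{Z})}\left[\log\frac{p(\mathcal{D}|\mathcal{Z},\theta)\,p(\theta)\prod_{n=1}^N p(z_n)}{q(\theta,\mathcal{Z})}\right] = -\mathrm{KL}\big(q(\theta,\mathcal{Z})\,\|\,p(\theta,\mathcal{Z}|\mathcal{D})\big) + \log p(\mathcal{D};\mathcal{M}), \tag{2}$$

where $p(\mathcal{D};\mathcal{M})$ is the evidence for the underlying model $\mathcal{M}$. Because the evidence does not depend on $q$, learning $q(\theta,\mathcal{Z})$ such that $\mathcal{L}_1$ is maximized is equivalent to seeking $q(\theta,\mathcal{Z})$ that minimizes $\mathrm{KL}(q(\theta,\mathcal{Z})\|p(\theta,\mathcal{Z}|\mathcal{D}))$. By leveraging and generalizing SVGD, we will perform the latter.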

2.2 Stein Variational Gradient Descent (SVGD)

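Rather than explicitly specifying a form for $q(\theta,\mathcal{Z})$, we sequentially refine samples of $\theta$ and $\mathcal{Z}$, such that they are better matched to $p(\theta,\mathcal{Z}|\mathcal{D})$. We alternate between updating the samples of $\theta$ and samples of $\mathcal{Z}$, analogous to how $\theta$ and $\phi$ are updated alternately in traditional VAE optimization of (Equation 1). We first consider updating samples of $\theta$, with the samples of $\mathcal{Z}$ held fixed. Specifically, assume we have samples $\{\theta_j\}_{j=1}^M$ drawn from distribution $q(\theta)$, and samples $\{z_{jn}\}_{j=1}^M$ drawn from distribution $q(\mathcal{Z})$. We wish to transform $\{\theta_j\}_{j=1}^M$ by feeding them through a function, and the corresponding (implicit) transformed distribution from which they are drawn is denoted as $q_T(\theta)$. It is desired that, in a KL sense, $q_T(\theta)q(\mathcal{Z})$ is closer to $p(\theta,\mathcal{Z}|\mathcal{D})$ than was $q(\theta)q(\mathcal{Z})$. The following theorem is useful for defining how to best update $\{\theta_j\}_{j=1}^M$.

Theorem 1. Assume $\theta$ and $\mathcal{Z}$ are random variables drawn from distributions $q(\theta)$ and $q(\mathcal{Z})$, respectively. Consider the transformation $T(\theta) = \theta + \epsilon\,\psi(\theta;\mathcal{D})$ and let $q_T(\theta)$ represent the distribution of $\theta' = T(\theta)$. We have

$$\nabla_\epsilon\big(\mathrm{KL}(q_T\|p)\big)\big|_{\epsilon=0} = -\mathbb{E}_{\theta\sim q(\theta)}\big(\mathrm{trace}(\mathcal{A}_p(\theta;\mathcal{D}))\big),$$

where $q_T = q_T(\theta)q(\mathcal{Z})$, $p = p(\theta,\mathcal{Z}|\mathcal{D})$, $\mathcal{A}_p(\theta;\mathcal{D}) = \nabla_\theta\log\tilde{p}(\theta;\mathcal{D})\,\psi(\theta;\mathcal{D})^\top + \nabla_\theta\psi(\theta;\mathcal{D})$, $\log\tilde{p}(\theta;\mathcal{D}) = \mathbb{E}_{\mathcal{Z}\sim q(\mathcal{Z})}[\log p(\mathcal{D},\mathcal{Z},\theta)]$, and $p(\mathcal{D},\mathcal{Z},\theta) = p(\mathcal{D}|\mathcal{Z},\theta)p(\theta)p(\mathcal{Z})$.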

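The proof is provided in Appendix A. Following [15], we assume $\psi(\theta;\mathcal{D})$ lives in a reproducing kernel Hilbert space (RKHS) with kernel $k(\cdot,\cdot)$. Under this assumption, the solution for $\psi(\theta;\mathcal{D})$ that maximizes the decrease in the KL distance in Theorem 1 is

$$\psi^*(\cdot;\mathcal{D}) = \mathbb{E}_{q(\theta)}\big[k(\theta,\cdot)\,\nabla_\theta\log\tilde{p}(\theta;\mathcal{D}) + \nabla_\theta k(\theta,\cdot)\big]. \tag{3}$$
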
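Theorem 1 concerns updating samples from $q(\theta)$ assuming fixed $q(\mathcal{Z})$. Similarly, to update $q(\mathcal{Z})$ with $q(\theta)$ fixed, we employ a complementary form of Theorem 1 (omitted for brevity). In that case, we consider transformation $T(\mathcal{Z}) = \mathcal{Z} + \epsilon\,\psi(\mathcal{Z};\mathcal{D})$, with $\mathcal{Z}\sim q(\mathcal{Z})$, and function $\psi(\mathcal{Z};\mathcal{D})$ is also assumed to be in a RKHS.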
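Theorem 1 can be sanity-checked numerically in the simplest setting, with no latent codes and one-dimensional Gaussians. The sketch below is illustrative only: $q(\theta) = N(0,1)$, a hypothetical target $p = N(1.5, 1)$, and a hypothetical linear perturbation $\psi(\theta) = a\theta + c$, under which $q_T$ stays Gaussian and the KL has a closed form. It compares a finite-difference derivative of that KL at $\epsilon = 0$ with a Monte Carlo estimate of $-\mathbb{E}_q[\nabla_\theta\log p\,\psi + \nabla_\theta\psi]$; for this setup both equal $-\mu c = 0.6$.

```python
import math
import random

MU = 1.5  # hypothetical target p(theta) = N(MU, 1); the base distribution is q(theta) = N(0, 1)

def kl_gauss_to_target(m, s):
    """Closed-form KL(N(m, s^2) || N(MU, 1))."""
    return -math.log(s) + 0.5 * (s * s + (m - MU) ** 2) - 0.5

def kl_after_transform(eps, a, c):
    """KL(q_T || p) for theta ~ N(0, 1) pushed through T(theta) = theta + eps * psi(theta),
    with psi(theta) = a * theta + c, so that q_T = N(eps * c, (1 + eps * a)^2)."""
    return kl_gauss_to_target(eps * c, 1.0 + eps * a)

def stein_rhs(a, c, num_samples, rng):
    """Monte Carlo estimate of -E_q[ d/dtheta log p(theta) * psi(theta) + psi'(theta) ]."""
    total = 0.0
    for _ in range(num_samples):
        th = rng.gauss(0.0, 1.0)
        total += (MU - th) * (a * th + c) + a  # score of N(MU, 1) is MU - theta; psi' = a
    return -total / num_samples

a, c, step = 0.7, -0.4, 1e-5
lhs = (kl_after_transform(step, a, c) - kl_after_transform(-step, a, c)) / (2 * step)
rhs = stein_rhs(a, c, 200_000, random.Random(0))
print(lhs, rhs)  # both close to -MU * c
```

Only the score $\nabla_\theta\log p$ enters the right-hand side, which is why the normalizing constant of the posterior is never needed.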

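The expectations in Theorem 1 and (Equation 3) are approximated by samples $\theta^{(t)} = \{\theta_j^{(t)}\}_{j=1}^M$, with

$$\theta_j^{(t+1)} = \theta_j^{(t)} + \epsilon\,\psi(\theta_j^{(t)};\mathcal{D}), \qquad \psi(\theta_j^{(t)};\mathcal{D}) = \frac{1}{M}\sum_{j'=1}^M\Big[k_\theta(\theta_{j'}^{(t)},\theta_j^{(t)})\,\nabla_{\theta_{j'}^{(t)}}\log\tilde{p}(\theta_{j'}^{(t)};\mathcal{D}) + \nabla_{\theta_{j'}^{(t)}}k_\theta(\theta_{j'}^{(t)},\theta_j^{(t)})\Big], \tag{4}$$

with $\nabla_\theta\log\tilde{p}(\theta;\mathcal{D}) \approx \sum_{n=1}^N\frac{1}{M}\sum_{j=1}^M\nabla_\theta\log p(x_n|z_{jn},\theta) + \nabla_\theta\log p(\theta)$. A similar update of samples is manifested for the latent variables $\mathcal{Z}$:

$$z_{jn}^{(t+1)} = z_{jn}^{(t)} + \epsilon\,\psi(z_{jn}^{(t)};\mathcal{D}), \qquad \psi(z_{jn}^{(t)};\mathcal{D}) = \frac{1}{M}\sum_{j'=1}^M\Big[k_z(z_{j'n}^{(t)},z_{jn}^{(t)})\,\nabla_{z_{j'n}^{(t)}}\log\tilde{p}(z_{j'n}^{(t)};\mathcal{D}) + \nabla_{z_{j'n}^{(t)}}k_z(z_{j'n}^{(t)},z_{jn}^{(t)})\Big], \tag{5}$$

where $\nabla_{z_n}\log\tilde{p}(z_n;\mathcal{D}) \approx \frac{1}{M}\sum_{j=1}^M\nabla_{z_n}\log p(x_n|z_n,\theta_j) + \nabla_{z_n}\log p(z_n)$. The kernels used to update samples of $\theta$ and $z_n$ are in general different, denoted respectively $k_\theta(\cdot,\cdot)$ and $k_z(\cdot,\cdot)$, and $\epsilon$ is a small step size. For notational simplicity, $M$ is the same in (Equation 4) and (Equation 5), but in practice a different number of samples may be used for $\theta$ and $\mathcal{Z}$.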
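A minimal sketch of the particle update in (Equation 4)/(Equation 5), for scalar particles and a hypothetical target whose score is known in closed form ($N(2,1)$); in the actual method the score would come from the approximations to $\nabla\log\tilde{p}$ given above. It uses an RBF kernel whose bandwidth is the median pairwise distance, as described for the experiments in Section 5; the step size, particle count, initialization, and iteration count are assumptions for illustration.

```python
import math
import statistics

def median_bandwidth(particles):
    """Bandwidth h = median pairwise distance between current particles (1.0 if fewer than 2)."""
    dists = [abs(a - b) for i, a in enumerate(particles) for b in particles[i + 1:]]
    return max(statistics.median(dists), 1e-12) if dists else 1.0

def svgd_step(particles, grad_log_p, eps):
    """One SVGD update of the form in (Equation 4)/(Equation 5) for scalar particles,
    with k(x', x) = exp(-(x' - x)^2 / h)."""
    h = median_bandwidth(particles)
    scores = [grad_log_p(x) for x in particles]
    updated = []
    for x_j in particles:
        phi = 0.0
        for x_src, score in zip(particles, scores):
            diff = x_src - x_j
            k = math.exp(-diff * diff / h)          # k(x_src, x_j)
            phi += k * score - 2.0 * diff / h * k    # driving term + gradient of k wrt x_src
        updated.append(x_j + eps * phi / len(particles))
    return updated

def grad_log_target(x):
    """Score of a hypothetical target N(2, 1)."""
    return -(x - 2.0)

particles = [-4.0 + 2.0 * i / 49 for i in range(50)]  # deliberately poor initialization
for _ in range(1000):
    particles = svgd_step(particles, grad_log_target, eps=0.1)
print(statistics.fmean(particles), statistics.pstdev(particles))

# With one particle the kernel value is 1 and its gradient is 0,
# so the update is plain gradient ascent on log p (the M = 1 case).
single = svgd_step([0.5], grad_log_target, eps=0.1)
```

The kernel-weighted score pulls particles toward high-density regions, while the kernel-gradient term pushes them apart, which keeps the particle set from collapsing onto the mode. The SVGD reference [15] discusses a related median-based bandwidth heuristic; either way the choice is a heuristic.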

If $M=1$ for parameter $\theta$, indices $j$ and $j'$ are removed in (Equation 4). Learning then reduces to gradient descent and a point estimate for $\theta$, identical to the optimization procedure used for the traditional VAE expression in (Equation 1), but with the (multiple) samples associated with $\mathcal{Z}$ sequentially transformed via SVGD (and, importantly, without the need to assume a form for $q(z|x;\phi)$). Therefore, if only a point estimate of $\theta$ is desired, (Equation 1) can be optimized wrt $\theta$, while SVGD is applied for updating $\mathcal{Z}$.

2.3 Efficient Stochastic Encoder

At iteration $t$ of the above learning procedure, we realize a set of latent-variable (code) samples $\{z_{jn}^{(t)}\}_{j=1}^M$ for each $x_n\in\mathcal{D}$ under analysis. For large $N$, training may be computationally expensive. Further, the need to evolve (learn) samples $\{z_{j*}\}_{j=1}^M$ for each new test sample, $x_*$, is undesirable. We therefore develop a recognition model that efficiently computes samples of latent codes for a data sample of interest. The recognition model draws samples via $z_{jn} = f_\eta(x_n,\xi_{jn})$ with $\xi_{jn}\sim q_0(\xi)$. Distribution $q_0(\xi)$ is selected such that it may be easily sampled, e.g., isotropic Gaussian.

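After each iteration of updating the samples of $\mathcal{Z}$, we refine recognition model $f_\eta(x,\xi)$ to mimic the Stein sample dynamics. Assume recognition-model parameters $\eta_t$ have been learned thus far. Using $\eta_t$, latent codes for iteration $t$ are constituted as $z_{jn}^{(t)} = f_{\eta_t}(x_n,\xi_{jn})$, with $\xi_{jn}\sim q_0(\xi)$. These codes are computed for all data $x_n\in\mathcal{B}_t$, where $\mathcal{B}_t\subset\mathcal{D}$ is the minibatch of data at iteration $t$. The change in the codes is $\epsilon\,\psi(z_{jn}^{(t)};\mathcal{D})$, as defined in (Equation 5). We then update $\eta$ to match the refined codes, as

$$\eta_{t+1} = \arg\min_\eta \sum_{x_n\in\mathcal{B}_t}\sum_{j=1}^M \big\|f_\eta(x_n,\xi_{jn}) - z_{jn}^{(t+1)}\big\|^2. \tag{6}$$
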
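The analytic solution of (Equation 6) is intractable. We update $\eta$ with $K$ steps of gradient descent as $\eta_t^{(k)} = \eta_t^{(k-1)} - \delta\sum_{x_n\in\mathcal{B}_t}\sum_{j=1}^M\Delta\eta_{jn}^{(k-1)}$, where $\Delta\eta_{jn}^{(k-1)} = \partial_\eta f_\eta(x_n,\xi_{jn})\big(f_\eta(x_n,\xi_{jn}) - z_{jn}^{(t+1)}\big)\big|_{\eta=\eta_t^{(k-1)}}$, $\delta$ is a small step size, $\eta_t^{(0)} = \eta_t$, $\eta_{t+1} = \eta_t^{(K)}$, and $\partial_\eta f_\eta(x_n,\xi_{jn})$ is the transpose of the Jacobian of $f_\eta(x_n,\xi_{jn})$ wrt $\eta$. Note that the use of minibatches mitigates challenges of training with large training sets, $\mathcal{D}$.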
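The recognition-model refinement can be sketched as follows. To stay self-contained, $f_\eta$ is a hypothetical linear map $z = a x + b\xi + c$ rather than a neural network, codes are scalars, and the Stein displacement $\epsilon\psi(z;\mathcal{D})$ is replaced by a fixed stand-in shift of 0.2; only the $K$-step gradient update of (Equation 6), with the transposed Jacobian multiplying the residual, follows the text.

```python
def f_eta(eta, x, xi):
    """Hypothetical linear recognition model z = a*x + b*xi + c (stand-in for a neural network)."""
    a, b, c = eta
    return a * x + b * xi + c

def refine_recognition(eta, batch, z_next, delta, num_steps):
    """K gradient steps: eta <- eta - delta * sum_{n,j} J^T (f_eta(x_n, xi_jn) - z_next_jn)."""
    for _ in range(num_steps):
        grad = [0.0, 0.0, 0.0]
        for (x, xi), target in zip(batch, z_next):
            residual = f_eta(eta, x, xi) - target
            # The transposed Jacobian of f_eta wrt (a, b, c) is (x, xi, 1).
            grad[0] += x * residual
            grad[1] += xi * residual
            grad[2] += residual
        eta = tuple(e - delta * g for e, g in zip(eta, grad))
    return eta

# Minibatch of (x_n, xi_jn) pairs; scalar codes keep the example small.
batch = [(0.0, 1.0), (1.0, -1.0), (-1.0, 0.5), (0.5, 0.5), (2.0, 0.0), (-0.5, -1.5)]
eta_t = (1.0, 0.5, 0.0)
z_t = [f_eta(eta_t, x, xi) for x, xi in batch]
stein_shift = [0.2] * len(batch)  # hypothetical stand-in for eps * psi(z_jn; D)
z_next = [z + d for z, d in zip(z_t, stein_shift)]
eta_next = refine_recognition(eta_t, batch, z_next, delta=0.05, num_steps=200)
print(eta_next)  # approximately (1.0, 0.5, 0.2)
```

Because the shifted codes are exactly representable by this linear model, the fit recovers $\eta = (1.0, 0.5, 0.2)$; with a neural network, $K$ gradient steps generally give only an approximate match to the refined codes.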

The function $f_\eta(x,\xi)$ plays a role analogous to $q(z|x;\phi)$ in (Equation 1), in that it yields a means of efficiently drawing samples of latent codes $z$, given observed $x$; however, we do not impose an explicit functional form for the distribution of these samples.

3 Stein Variational Importance Weighted Autoencoder (Stein VIWAE)

3.1 Multi-sample importance-weighted KL divergence

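Recall the variational expression in (Equation 1) employed in conventional VAE learning. Recently, [1] showed that the multi-sample ($k$ samples) importance-weighted estimator

$$\mathcal{L}_k(x) = \mathbb{E}_{z^1,\ldots,z^k\sim q(z|x)}\left[\log\frac{1}{k}\sum_{i=1}^k\frac{p(x,z^i)}{q(z^i|x)}\right] \tag{7}$$

provides a tighter lower bound and a better proxy for the log-likelihood, where $z^1,\ldots,z^k$ are random variables sampled independently from $q(z|x)$. Recall from Theorem 1 that the KL divergence played a key role in the Stein-based learning of Section 2. (Equation 7) motivates replacement of the KL objective function with the multi-sample importance-weighted KL divergence

$$\mathrm{KL}_{q,p}^k(\Theta;\mathcal{D}) \triangleq -\mathbb{E}_{\Theta^1,\ldots,\Theta^k\sim q(\Theta)}\left[\log\frac{1}{k}\sum_{i=1}^k\frac{p(\Theta^i|\mathcal{D})}{q(\Theta^i)}\right], \tag{8}$$

where $\Theta = (\theta,\mathcal{Z})$, and $\Theta^1,\ldots,\Theta^k$ are independent samples from $q(\Theta) = q(\theta,\mathcal{Z})$. Note that the special case of $k=1$ recovers the standard KL divergence. Inspired by [1], the following theorem (proved in Appendix A) shows that increasing the number of samples $k$ never increases this divergence and, in the limit, drives it to zero, so that more samples provide a better approximation of the target distribution.

Theorem 2. For any natural number $k$, we have $\mathrm{KL}_{q,p}^k(\Theta;\mathcal{D}) \ge \mathrm{KL}_{q,p}^{k+1}(\Theta;\mathcal{D}) \ge 0$, and if $p(\Theta|\mathcal{D})/q(\Theta)$ is bounded, then $\lim_{k\to\infty}\mathrm{KL}_{q,p}^k(\Theta;\mathcal{D}) = 0$.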
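The behavior stated in Theorem 2 can be observed with a short Monte Carlo experiment. The sketch below assumes a one-dimensional $q = N(0,1)$ and a normalized target $p = N(1,1)$ (both hypothetical), estimates (Equation 8) for several $k$ using a numerically stable log-mean-exp, and shows the estimate falling from the ordinary $\mathrm{KL}(q\|p) = 0.5$ at $k=1$ toward zero.

```python
import math
import random

def log_normal(v, mean, var):
    """Log-density of N(mean, var) at v."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)

def iw_kl(k, num_reps, rng, mu_p=1.0):
    """Monte Carlo estimate of KL^k_{q,p} in (Equation 8) for q = N(0, 1), p = N(mu_p, 1)."""
    total = 0.0
    for _ in range(num_reps):
        log_w = []
        for _ in range(k):
            th = rng.gauss(0.0, 1.0)  # Theta^i ~ q
            log_w.append(log_normal(th, mu_p, 1.0) - log_normal(th, 0.0, 1.0))
        top = max(log_w)  # log-mean-exp, computed stably
        total += top + math.log(sum(math.exp(lw - top) for lw in log_w) / k)
    return -total / num_reps

rng = random.Random(0)
kl_by_k = {k: iw_kl(k, 10_000, rng) for k in (1, 5, 25)}
print(kl_by_k)  # k = 1 is close to KL(q || p) = 0.5; larger k gives smaller values
```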

We minimize (Equation 8) with a sample transformation based on a generalization of SVGD, and the recognition model (encoder) is trained in the same way as in Section 2.3. Specifically, we first draw samples $\{\theta_j^{1:k}\}_{j=1}^M$ and $\{z_{jn}^{1:k}\}_{j=1}^M$ from a simple distribution $q_0(\cdot)$, and convert these to approximate draws from $p(\theta^{1:k},\mathcal{Z}^{1:k}|\mathcal{D})$ by minimizing the multi-sample importance-weighted KL divergence via nonlinear functional transformation.

3.2 Importance-weighted SVGD for VAEs

The following theorem generalizes Theorem 1 to the multi-sample weighted KL divergence.

The proof and detailed definition are provided in Appendix A. The following corollaries generalize Theorem 1 and (Equation 3), respectively, via use of importance sampling.

Corollary ? and Corollary ? provide a means of updating multiple samples $\{\theta_j^{1:k}\}_{j=1}^M$ from $q(\theta)$ via $T(\theta^i) = \theta^i + \epsilon\,\psi(\theta^i;\mathcal{D})$. The expectation wrt $q(\mathcal{Z})$ is approximated via samples drawn from $q(\mathcal{Z})$. Similarly, we can employ a complementary form of Corollary ? and Corollary ? to update multiple samples $\{z_{jn}^{1:k}\}_{j=1}^M$ from $q(\mathcal{Z})$. This suggests an importance-weighted learning procedure that alternates between updates of particles $\{\theta_j^{1:k}\}_{j=1}^M$ and $\{z_{jn}^{1:k}\}_{j=1}^M$, similar to the one in Section 2.2. Detailed update equations are provided in Appendix B.

4 Semi-Supervised Learning with Stein VAE

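Consider labeled data as pairs $\mathcal{D}_l = \{x_n,y_n\}_{n=1}^{N_l}$, where the label $y_n\in\{1,\ldots,C\}$, and the decoder is modeled as $(x_n,y_n|z_n)\sim p(x_n,y_n|z_n;\theta,\tilde{\theta}) = p(x_n|z_n;\theta)\,p(y_n|z_n;\tilde{\theta})$, where $\tilde{\theta}$ represents the parameters of the decoder for labels. The set of codes associated with all labeled data is represented as $\mathcal{Z}_l = \{z_n\}_{n=1}^{N_l}$. We desire to approximate the posterior distribution on the entire dataset, $p(\theta,\tilde{\theta},\mathcal{Z},\mathcal{Z}_l|\mathcal{D},\mathcal{D}_l)$, via samples, where $\mathcal{D}$ represents the unlabeled data, and $\mathcal{Z}$ is the set of codes associated with $\mathcal{D}$. In the following, we will only discuss how to update the samples of $\theta$, $\tilde{\theta}$ and $\mathcal{Z}_l$. Updating samples of $\mathcal{Z}$ is the same as discussed in Section 2 and Section 3.2 for Stein VAE and Stein VIWAE, respectively.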

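Assume $\{\theta_j\}_{j=1}^M$ drawn from distribution $q(\theta)$, $\{\tilde{\theta}_j\}_{j=1}^M$ drawn from distribution $q(\tilde{\theta})$, and samples $\{z_{jn}\}_{j=1}^M$ drawn from (distinct) distribution $q(\mathcal{Z}_l)$. The following corollary generalizes Theorem 1 and (Equation 3), and is useful for defining how to best update $\tilde{\theta}$.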

Further details are provided in Appendix C.


5 Experiments

For all experiments, we use a radial basis-function (RBF) kernel as in [15], i.e., $k(x,x') = \exp(-\frac{1}{h}\|x-x'\|_2^2)$, where the bandwidth, $h$, is the median of pairwise distances between current samples. $q_0(\theta)$ and $q_0(\xi)$ are set to isotropic Gaussian distributions. We share the samples of $\xi$ across data points, i.e., $\xi_{jn} = \xi_j$, for $n=1,\ldots,N$ (this is not necessary, but it saves computation). The samples of $\theta$ and $\mathcal{Z}$, and parameters of the recognition model, $\eta$, are optimized via Adam [9] with learning rate 0.0002. We do not perform any dataset-specific tuning or regularization other than dropout [33] and early stopping on validation sets. We set and , and use minibatches of size 64 for all experiments, unless otherwise specified.

5.1 Expressive power of Stein recognition model

Approximation of posterior distribution: Stein VAE vs. VAE. The figures represent different samples of Stein VAE. (left) 10 samples, (center) 50 samples, and (right) 100 samples.

Gaussian Mixture Model We synthesize data by drawing , where , ; drawing , where and . The recognition model $f_\eta(x_n,\xi_j)$ is specified as a multi-layer perceptron (MLP) with 100 hidden units, by first concatenating $\xi_j$ and $x_n$ into a long vector. The dimension of $\xi_j$ is set to 2. The recognition model for standard VAE is also an MLP with 100 hidden units, and with the assumption of a Gaussian distribution for the latent codes [11].

We generate data points for training and 10 data points for testing. The analytic form of the true posterior distribution is provided in Appendix D. Figure ? shows the Stein VAE approximations of the true posterior; other similar examples are provided in Appendix F. The Stein recognition model is able to capture the multi-modal posterior and produce accurate density approximations.

Figure 1: Univariate marginals and pairwise posteriors. Purple, red and green represent the distribution inferred from MCMC, standard VAE and Stein VAE, respectively.

Poisson Factor Analysis Given a discrete vector $x_n\in\mathbb{Z}_+^P$, Poisson factor analysis [36] assumes $x_n$ is a weighted combination of $V$ latent factors, $x_n\sim\mathrm{Pois}(\theta z_n)$, where $\theta\in\mathbb{R}_+^{P\times V}$ is the factor loadings matrix and $z_n\in\mathbb{R}_+^V$ is the vector of factor scores. We consider topic modeling with Dirichlet priors on $\theta_v$ ($v$-th column of $\theta$) and gamma priors on each component of $z_n$.
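For readers less familiar with Poisson factor analysis, here is a minimal sampler for the generative model above, using only the Python standard library. The Dirichlet and gamma hyperparameters, the sizes, and the Knuth-style Poisson sampler are illustrative assumptions, not the settings used in the experiments.

```python
import math
import random

def sample_poisson(lam, rng):
    """Knuth's multiplication method; adequate for the moderate rates used here."""
    threshold, count, prod = math.exp(-lam), 0, rng.random()
    while prod > threshold:
        count += 1
        prod *= rng.random()
    return count

def sample_pfa(num_docs, vocab_size, num_topics, rng,
               dir_alpha=0.1, gamma_shape=1.0, gamma_scale=5.0):
    """Draw x_n ~ Pois(theta z_n) with Dirichlet columns theta_v and gamma factor scores z_n.
    All hyperparameter values are hypothetical."""
    theta_cols = []  # theta_cols[v] is the v-th column of theta: a distribution over words
    for _ in range(num_topics):
        g = [rng.gammavariate(dir_alpha, 1.0) for _ in range(vocab_size)]
        total = sum(g)
        theta_cols.append([gi / total for gi in g])  # normalized gammas give a Dirichlet draw
    scores, docs = [], []
    for _ in range(num_docs):
        z = [rng.gammavariate(gamma_shape, gamma_scale) for _ in range(num_topics)]
        rates = [sum(theta_cols[v][p] * z[v] for v in range(num_topics))
                 for p in range(vocab_size)]
        docs.append([sample_poisson(r, rng) for r in rates])
        scores.append(z)
    return theta_cols, scores, docs

theta_cols, scores, docs = sample_pfa(num_docs=5, vocab_size=20, num_topics=3,
                                      rng=random.Random(0))
```

Because each column of $\theta$ sums to one, the expected total word count of document $n$ equals the sum of its factor scores $z_n$.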

We evaluate our model on the 20 Newsgroups dataset containing 18,845 documents with a vocabulary of . The data are partitioned into 10,314 training, 1,000 validation and 7,531 test documents. The number of factors (topics) is set to . $\theta$ is first learned by Markov chain Monte Carlo (MCMC) [4]. We then fix $\theta$ at its MAP value, and only learn the recognition model $\eta$ using standard VAE and Stein VAE; this is done, as in the previous example, to examine how accurately the recognition model estimates the posterior of the latent factors, isolated from estimation of $\theta$. The recognition model is an MLP with 100 hidden units.

An analytic form of the true posterior distribution is intractable for this problem. Consequently, we employ samples collected from MCMC as ground truth. With $\theta$ fixed, we sample $z_n$ via Gibbs sampling, using 2,000 burn-in iterations followed by 2,500 collection draws, retaining every 10th collection sample. We show the marginal and pairwise posterior of one test data point in Figure 1. Additional results are provided in Appendix F. Stein VAE leads to a more accurate approximation than standard VAE, compared to the MCMC samples. Considering Figure 1, note that VAE significantly underestimates the variance of the posterior (examining the marginals), a well-known problem of variational Bayesian analysis [7]. In sharp contrast, Stein VAE yields highly accurate approximations to the true posterior.

5.2 Density estimation

Data We consider five benchmark datasets: MNIST and four text corpora: 20 Newsgroups (20News), New York Times (NYT), Science and RCV1-v2 (RCV2). For MNIST, we used the standard split of 50K training, 10K validation and 10K test examples. The latter three text corpora consist of 133K, 166K and 794K documents. These three datasets are split into 1K validation, 10K testing and the rest for training.

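Evaluation Given new data $\mathcal{D}^*$ (testing data), the marginal log-likelihood/perplexity values are estimated by the variational evidence lower bound (ELBO) while integrating the decoder parameters $\theta$ out, $\log p(\mathcal{D}^*) \ge \mathbb{E}_{q(\mathcal{Z}^*)}[\log p(\mathcal{D}^*,\mathcal{Z}^*)] + \mathcal{H}(q(\mathcal{Z}^*)) = \mathrm{ELBO}(q(\mathcal{Z}^*))$, where $p(\mathcal{D}^*,\mathcal{Z}^*) = \mathbb{E}_{q(\theta)}[p(\mathcal{D}^*|\mathcal{Z}^*,\theta)p(\mathcal{Z}^*)]$ and $\mathcal{H}(q(\cdot)) = -\mathbb{E}_q[\log q(\cdot)]$ is the entropy. The expectation is approximated with samples $\{\theta_j\}_{j=1}^M$ and $\{z_j^*\}_{j=1}^M$ with $z_j^* = f_\eta(x^*,\xi_j)$, $\xi_j\sim q_0(\xi)$. Directly evaluating $q(z^*)$ is intractable, thus it is estimated via density transformation $q(z) = q_0(\xi)\left|\det\frac{\partial f_\eta(x,\xi)}{\partial\xi}\right|^{-1}$.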
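The density-transformation step can be illustrated for a one-dimensional encoder that is invertible in $\xi$; invertibility (and equal dimensions of $\xi$ and $z$) is an assumption this sketch makes. For a hypothetical affine map $z = 0.3 + 2\xi$ with $\xi\sim N(0,1)$, the transformed log-density must match $N(0.3, 4)$ exactly, and the Monte Carlo entropy term used in the ELBO must match the Gaussian entropy.

```python
import math
import random

def log_std_normal(v):
    """Log-density of q0 = N(0, 1) at v."""
    return -0.5 * (math.log(2.0 * math.pi) + v * v)

def log_q_via_transform(xi, f, df_dxi):
    """Return (z, log q(z)) for z = f(xi), xi ~ N(0, 1), assuming f is invertible in xi:
    log q(z) = log q0(xi) - log|df/dxi|."""
    return f(xi), log_std_normal(xi) - math.log(abs(df_dxi(xi)))

def entropy_estimate(f, df_dxi, num_samples, rng):
    """Monte Carlo estimate of H(q) = -E_q[log q(z)] using the density transformation."""
    total = 0.0
    for _ in range(num_samples):
        _, log_q = log_q_via_transform(rng.gauss(0.0, 1.0), f, df_dxi)
        total -= log_q
    return total / num_samples

def f_affine(xi):
    return 0.3 + 2.0 * xi  # hypothetical encoder output for one fixed x

def df_affine(xi):
    return 2.0

z, log_q = log_q_via_transform(0.8, f_affine, df_affine)
log_q_closed = -0.5 * (math.log(2.0 * math.pi * 4.0) + (z - 0.3) ** 2 / 4.0)  # N(0.3, 2^2)
h_mc = entropy_estimate(f_affine, df_affine, 20_000, random.Random(0))
h_closed = 0.5 * math.log(2.0 * math.pi * math.e * 4.0)
print(log_q, log_q_closed, h_mc, h_closed)
```

For a general neural-network encoder the Jacobian determinant replaces `df_dxi`, which is where the computational cost of this evaluation lies.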

We further estimate the marginal log-likelihood/perplexity values via the stochastic variational lower bound, computed as the mean of the 5K-sample importance-weighted estimate [1]. Therefore, for each dataset, we report four results: (i) Stein VAE + ELBO, (ii) Stein VAE + S-ELBO, (iii) Stein VIWAE + ELBO and (iv) Stein VIWAE + S-ELBO. The first term denotes whether training uses Stein VAE (Section 2) or Stein VIWAE (Section 3); the second term denotes whether the testing log-likelihood/perplexity is estimated by the ELBO or by the stochastic variational lower bound, S-ELBO [1].

Model For MNIST, we train the model with one stochastic layer, $z_n$, with 50 hidden units and two deterministic layers, each with 200 units. The nonlinearity is set as . The visible layer, $x_n$, follows a Bernoulli distribution. For the text corpora, we build a three-layer deep Poisson network [25]. The sizes of hidden units are 200, 200 and 50 for the first, second and third layer, respectively (see [25] for detailed architectures).

NLL vs. Training/Testing time on MNIST with various numbers of samples for $\theta$.

Results The log-likelihood/perplexity results are summarized in Tables ? and ?. On MNIST, our Stein VAE achieves a variational lower bound of -85.21 nats, which outperforms standard VAE with the same model architecture. Our Stein VIWAE achieves a log-likelihood of -82.88 nats, exceeding normalizing flow (-85.1 nats) and the importance weighted autoencoder (-84.78 nats), the best prior result obtained by a feedforward neural network (FNN). DRAW [5] and PixelRNN [20], which exploit spatial structure, achieved log-likelihoods of around -80 nats. Our approach could also be applied to these models; we leave this to future work. To further illustrate the benefit of model averaging, we vary the number of samples for $\theta$ (while retaining 100 samples for $\mathcal{Z}$) and show the results associated with training/testing time in Figure ?. When $M=1$ for $\theta$, our model reduces to a point estimate for that parameter. Increasing the number of samples of $\theta$ (model averaging) improves the negative log-likelihood (NLL). The testing time using 100 samples of $\theta$ is around 0.12 ms per image.

5.3 Semi-supervised Classification

We consider semi-supervised classification on MNIST and ImageNet [30] data. For each dataset, we report the results obtained by (i) VAE, (ii) Stein VAE, and (iii) Stein VIWAE.

MNIST We randomly split the training set into a labeled and an unlabeled set, and the number of labeled samples in each category varies from 10 to 300. We perform testing on the standard test set with 20 different training-set splits. The decoder for labels is implemented as . We consider two types of decoders for images $p(x_n|z_n;\theta)$ and encoder $f_\eta(x,\xi)$: (i) FNN: Following [12], we use a 50-dimensional latent variable $z_n$ and two hidden layers, each with 600 hidden units, for both encoder and decoder; softplus is employed as the nonlinear activation function. (ii) All convolutional nets (CNN): Inspired by [32], we replace the two hidden layers with 32 and 64 kernels of size and a stride of 2. A fully connected layer is stacked on the CNN to produce a 50-dimensional latent variable $z_n$. We use the leaky rectified activation [16]. The input of the encoder is formed by spatially aligning and "stacking" $x_n$ and $\xi$, while the output of the decoder is the image itself.

Table ? shows the classification results. Our Stein VAE and Stein VIWAE consistently achieve better performance than the VAE. We further observe that, when labeled data are scarce, the variance of the Stein VIWAE results is much smaller than that of the Stein VAE results, indicating that the former produces more robust parameter estimates. State-of-the-art results [27] are achieved by the Ladder network, which could be combined with our Stein-based approach; we leave this extension to future work.

ImageNet 2012 We consider scalability of our model to large datasets. We split the 1.3 million training images into an unlabeled and a labeled set, and vary the proportion of labeled images from 1% to 40%. The classes are balanced to ensure that no particular class is over-represented, i.e., the ratio of labeled and unlabeled images is the same for each class. We repeat the training process 10 times for settings with 1% to 10% labeled images, and 5 times for settings with 20% to 40% labeled images. Each time, we use a different set of images as the unlabeled ones.

We employ an all convolutional net [32] for both the encoder and decoder, which replaces deterministic pooling (e.g., max-pooling) with strided convolutions. Residual connections [8] are incorporated to encourage gradient flow. The model architecture is detailed in Appendix E. Following [13], images are resized to . A crop is randomly sampled from each image or its horizontal flip, with the mean subtracted [13]. We set and .

Table ? shows classification results indicating that Stein VAE and Stein VIWAE outperform VAE in all the experiments, demonstrating the effectiveness of our approach for semi-supervised classification. When the proportion of labeled examples is too small (), DGDN [21] outperforms all the VAE-based models, which is not surprising given that our models are deeper and thus have considerably more parameters than DGDN [21].


6 Conclusion

We have employed SVGD to develop a new method for learning a variational autoencoder, in which we need not specify an a priori form for the encoder distribution. Fast inference is achieved by learning a recognition model that mimics the Stein dynamics by which the inferred code samples are refined. The method is further generalized and improved by performing importance sampling. An extensive set of results, for unsupervised and semi-supervised learning, demonstrates excellent performance and scaling to large datasets.


This research was supported in part by ARO, DARPA, DOE, NGA, ONR and NSF.


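Proof of Theorem 1 Recall the definition of KL divergence:

$$\mathrm{KL}(q_T\|p) = \mathbb{E}_{q_T(\theta)q(\mathcal{Z})}\big[\log q_T(\theta) + \log q(\mathcal{Z}) - \log p(\theta,\mathcal{Z}|\mathcal{D})\big],$$

where $q_T = q_T(\theta)q(\mathcal{Z})$ and $p = p(\theta,\mathcal{Z}|\mathcal{D})$. Since $p(\theta,\mathcal{Z}|\mathcal{D}) = p(\mathcal{D},\mathcal{Z},\theta)/p(\mathcal{D})$, we have

$$\mathrm{KL}(q_T\|p) = \mathbb{E}_{q_T(\theta)}\big[\log q_T(\theta) - \log\tilde{p}(\theta;\mathcal{D})\big] + \mathbb{E}_{q(\mathcal{Z})}[\log q(\mathcal{Z})] + \log p(\mathcal{D}),$$

with $\log\tilde{p}(\theta;\mathcal{D}) = \mathbb{E}_{\mathcal{Z}\sim q(\mathcal{Z})}[\log p(\mathcal{D},\mathcal{Z},\theta)]$; the last two terms do not depend on $\epsilon$. Following [15], we have

$$\nabla_\epsilon\big(\mathrm{KL}(q_T\|p)\big)\big|_{\epsilon=0} = -\mathbb{E}_{\theta\sim q(\theta)}\Big(\mathrm{trace}\big(\nabla_\theta\log\tilde{p}(\theta;\mathcal{D})\,\psi(\theta;\mathcal{D})^\top + \nabla_\theta\psi(\theta;\mathcal{D})\big)\Big).$$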

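Proof of Theorem 2 Following [1], write $w_i = p(\Theta^i|\mathcal{D})/q(\Theta^i)$; we have

$$\frac{1}{k}\sum_{i=1}^k w_i = \mathbb{E}_{I=\{i_1,\ldots,i_m\}}\Big[\frac{1}{m}\sum_{j=1}^m w_{i_j}\Big],$$

where $I=\{i_1,\ldots,i_m\}$, with $m<k$, is a uniformly distributed subset of $\{1,\ldots,k\}$. Using Jensen's inequality, we have

$$\mathrm{KL}_{q,p}^k(\Theta;\mathcal{D}) = -\mathbb{E}\Big[\log\mathbb{E}_I\Big[\frac{1}{m}\sum_{j=1}^m w_{i_j}\Big]\Big] \le -\mathbb{E}\Big[\mathbb{E}_I\Big[\log\frac{1}{m}\sum_{j=1}^m w_{i_j}\Big]\Big] = \mathrm{KL}_{q,p}^m(\Theta;\mathcal{D}),$$

and, since $\mathbb{E}_q[w_i] = 1$, also $\mathrm{KL}_{q,p}^k(\Theta;\mathcal{D}) \ge -\log\mathbb{E}\big[\frac{1}{k}\sum_{i=1}^k w_i\big] = 0$. If $p(\Theta|\mathcal{D})/q(\Theta)$ is bounded, then by the strong law of large numbers $\frac{1}{k}\sum_{i=1}^k w_i \to 1$ almost surely as $k\to\infty$, and we have

$$\lim_{k\to\infty}\mathrm{KL}_{q,p}^k(\Theta;\mathcal{D}) = -\log 1 = 0.$$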


Proof of Theorem 3 is defined as follows:

Assume denote the density of . We have

Note that

and when , we have


Therefore, (Equation 9) can be rewritten as

where and .

B Samples Updating for Stein VIWAE

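Let $\{\theta_j^{1:k}\}_{j=1}^M$ and $\{z_{jn}^{1:k}\}_{j=1}^M$ denote the samples acquired at iteration $t$ of the learning procedure. To update samples of $\theta^{1:k}$, we apply the transformation $\theta_j^{i,(t+1)} = \theta_j^{i,(t)} + \epsilon\,\psi(\theta_j^{i,(t)};\mathcal{D})$, for $i=1,\ldots,k$, by approximating the expectation by samples $\{\theta_j^{1:k}\}_{j=1}^M$, and we have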


Similarly, when updating samples of the latent variables, we have