Straight-Through Estimator as Projected Wasserstein Gradient Flow

Pengyu Cheng, Chang Liu, Chunyuan Li,
Dinghan Shen, Ricardo Henao and Lawrence Carin
Duke University, Tsinghua University, Microsoft Research

The Straight-Through (ST) estimator is a widely used technique for back-propagating gradients through discrete random variables. However, this effective method lacks theoretical justification. In this paper, we show that ST can be interpreted as the simulation of the projected Wasserstein gradient flow (pWGF). Based on this understanding, we establish a theoretical foundation that justifies the convergence properties of ST. Further, we propose another pWGF estimator variant, which exhibits superior performance on distributions with infinite support, e.g., Poisson distributions. Empirically, we show that ST and our proposed estimator, when applied to different types of discrete structures (including both Bernoulli and Poisson latent variables), achieve performance comparable to or better than other state-of-the-art methods. Our results shed light on the widespread adoption of the ST estimator and represent a helpful step towards exploring alternative gradient estimators for discrete variables.

1 Introduction

Learning distributions in discrete domains is a fundamental problem in machine learning. In general, this problem can be formulated as minimizing the expected cost

$$\min_\theta \mathcal{L}(\theta) = \mathbb{E}_{z \sim p_\theta(z)}\left[f(z)\right], \qquad (1)$$

where $f(\cdot)$ is the cost function and $z$ is a discrete (latent) random variable whose distribution $p_\theta(z)$ is parameterized by $\theta$. Typically, $\theta$ is obtained as the output of a Neural Network (NN), whose weights are learned by backpropagating the gradients through the discrete random variable $z$.

In practice, computing the gradient directly through the discrete random variables suffers from the curse of dimensionality: it requires traversing all possible joint configurations of the latent variable, whose number grows exponentially with the latent dimension. Because of this limitation, existing approaches estimate the gradient by approximating its expectation, typically with Monte Carlo sampling.
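To illustrate why exact computation does not scale, here is a minimal sketch (an illustration we add, not part of the method) that computes the exact gradient of the expected cost for a factorized Bernoulli latent variable by enumerating all $2^d$ joint configurations. The function `exact_grad` and the toy costs are hypothetical; the loop over `itertools.product` is where the exponential cost appears.

```python
import itertools
from typing import Callable, List, Sequence, Tuple


def exact_grad(theta: Sequence[float],
               f: Callable[[Tuple[int, ...]], float]) -> List[float]:
    """Exact gradient of E_{z ~ prod_j Bernoulli(theta_j)}[f(z)] w.r.t. theta.

    The outer loop visits all 2**d binary configurations, so the cost grows
    exponentially with the latent dimension d.
    """
    d = len(theta)
    grad = [0.0] * d
    for z in itertools.product((0, 1), repeat=d):
        fz = f(z)
        for j in range(d):
            # d p(z) / d theta_j = (+1 if z_j = 1 else -1) * prod_{k != j} p(z_k)
            others = 1.0
            for k in range(d):
                if k != j:
                    others *= theta[k] if z[k] else 1.0 - theta[k]
            sign = 1.0 if z[j] else -1.0
            grad[j] += sign * others * fz
    return grad


# For f(z) = z_0 * z_1, E[f] = theta_0 * theta_1, so the gradient is (theta_1, theta_0, 0).
print(exact_grad([0.2, 0.7, 0.5], lambda z: z[0] * z[1]))
```

Doubling $d$ squares the number of configurations, which is why sampling-based estimators such as ST are used instead.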

The Straight-Through (ST) estimator [14, 3] is widely applied due to its simplicity and effectiveness. The idea of ST is to directly use the gradients with respect to the discrete samples as the gradients with respect to the distribution parameters. Since discrete samples can be generated as the output of hard threshold functions that take the distribution parameters as input, Bengio et al. [3] explain the ST estimator by setting the gradients of the hard threshold functions to 1. However, this explanation provides no theoretical justification for that choice of gradient.

In this paper, we show that ST can be interpreted as simulating the projected Wasserstein gradient flow (pWGF) of the functional $F(p_\theta) = \mathbb{E}_{z \sim p_\theta}[f(z)]$, where $p_\theta$ is a distribution in the target discrete distribution family $\mathcal{P}_\theta$ with density parameterized by $\theta$. Further, we introduce a more general optimization scheme for (1). Instead of updating $p_\theta$ directly within the discrete distribution family, we first update $p_\theta$ to $q$ in a larger Wasserstein space of distributions, where gradients are easier to compute, and then project $q$ back onto the discrete distribution family to obtain the updated distribution. Moreover, the projection follows a descent direction of $F$ in $\mathcal{P}_\theta$, which justifies the effectiveness of ST. This pWGF-based updating scheme also motivates another estimator that converges faster when the target family of distributions has infinite support, e.g., Poisson.

2 Proposed Algorithm

Denote by $\mathcal{P}_\theta$ the family of $d$-dimensional discrete distributions parameterized by $\theta$. With $F(p) = \mathbb{E}_{z \sim p}[f(z)]$, the task (1) can be rewritten as

$$\min_{p_\theta \in \mathcal{P}_\theta} F(p_\theta), \qquad (2)$$

where $f$ is assumed to be differentiable. To solve (2), directly calculating the gradient $\nabla_\theta F(p_\theta)$ is challenging, because the discrete distribution family imposes strong restrictions on the gradients. Alternatively, if we relax the discrete constraint and perform updates in an appropriate larger space $\mathcal{P}_2$, the gradient becomes much easier to compute. Therefore, as shown in Fig. 1, in the $t$-th iteration we first update the current distribution $p_{\theta_t}$ to $q_{t+1}$ with stepsize $\epsilon$ in the larger 2-Wasserstein space $\mathcal{P}_2$ [24], and then project $q_{t+1}$ back onto $\mathcal{P}_\theta$ as the updated discrete distribution $p_{\theta_{t+1}}$. Theorem C.2 in the supplement guarantees that, with a small enough stepsize $\epsilon$, each update improves the objective $F$.

Figure 1: Updating scheme.
Figure 2: Algorithm outline.

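Using the Wasserstein gradient flow (WGF) [24], we show in Appendix B that the gradient of $F$ in the larger space is $\operatorname{grad} F(p) = \nabla_z f(z)$. Hence, if $p_{\theta_t}$ is represented by a group of its samples $\{z_i\}_{i=1}^N$, then $\{z_i - \epsilon \nabla_z f(z_i)\}_{i=1}^N$ can be treated as a group of samples from $q_{t+1}$. Therefore, we can update $p_{\theta_t}$ to $q_{t+1}$ along the WGF simply by updating its samples. To project $q_{t+1}$ back onto $\mathcal{P}_\theta$ as $p_{\theta_{t+1}}$, we need to solve $p_{\theta_{t+1}} = \operatorname{arg\,min}_{p \in \mathcal{P}_\theta} W_2(p, q_{t+1})$, which is equivalent to solving $\theta_{t+1} = \operatorname{arg\,min}_{\theta} W_2^2(p_\theta, q_{t+1})$, where $W_2^2$ is the square of the 2-Wasserstein distance [24]. Consequently, our pWGF algorithm proceeds in three steps, shown in Fig. 2: (A) draw samples $\{z_i\}_{i=1}^N$ from the current distribution $p_{\theta_t}$; (B) update them to $z_i' = z_i - \epsilon \nabla_z f(z_i)$ as samples from $q_{t+1}$; (C) project $q_{t+1}$ back onto $\mathcal{P}_\theta$ by minimizing the Wasserstein distance.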
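The three steps (A), (B), (C) can be sketched as follows for a one-dimensional Bernoulli family. This is a minimal illustration under our own assumptions: samples are moved by $z_i' = z_i - \epsilon f'(z_i)$, and step (C) uses the expectation-matching approximation of Section 2.1 (the clipped sample mean of the moved samples) rather than an exact Wasserstein projection. The function name `pwgf_bernoulli`, the toy cost, and all hyperparameters are hypothetical.

```python
import random
from typing import Callable


def pwgf_bernoulli(theta0: float, grad_f: Callable[[float], float], step: float = 0.1,
                   n_samples: int = 100, n_iters: int = 100, seed: int = 0) -> float:
    """Illustrative pWGF loop for a 1-D Bernoulli family."""
    rng = random.Random(seed)
    theta = theta0
    for _ in range(n_iters):
        # (A) Draw samples z_i ~ Bernoulli(theta).
        z = [1.0 if rng.random() < theta else 0.0 for _ in range(n_samples)]
        # (B) Move each sample along the WGF of F: z_i' = z_i - step * f'(z_i).
        z_new = [zi - step * grad_f(zi) for zi in z]
        # (C) Project back onto {Bernoulli(theta)}: minimizing (theta - mean(z'))^2
        #     over theta in [0, 1] gives the clipped sample mean.
        mean = sum(z_new) / n_samples
        theta = min(1.0, max(0.0, mean))
    return theta


# Toy cost f(z) = (z - 1.5)^2, so f'(z) = 2 (z - 1.5) and E[f] = 2.25 - 2 theta.
theta_final = pwgf_bernoulli(0.2, lambda z: 2.0 * (z - 1.5))
print(theta_final)
```

With this toy cost the expected cost decreases in $\theta$, and the iterate is driven to the boundary $\theta = 1$. Swapping in a different projection (for instance the MMD approximation of Section 2.2) would change only step (C).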

Since distributions in $\mathcal{P}_\theta$ are multidimensional, the exact Wasserstein distance is difficult to derive. We make the standard assumption [8] that $p_\theta$ and $q_{t+1}$ are factorized distributions. Under this assumption, we prove in Theorem 2.1 that minimizing the Wasserstein distance between factorized distributions is equivalent to minimizing the marginal distance on every dimension. Therefore, for simplicity, we describe our projection step using one-dimensional distributions. As the updated distribution $q_{t+1}$ is implicit, we cannot obtain the Wasserstein distance $W_2(p_\theta, q_{t+1})$ in closed form. We therefore consider two approximations of $W_2$.

Theorem 2.1.

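If $d$-dimensional distributions $p$ and $q$ are factorized, then $W_2^2(p, q) = \sum_{i=1}^d W_2^2(p_i, q_i)$, where $p_i$ and $q_i$ are the marginal distributions of $p$ and $q$ on the $i$-th dimension, respectively.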

2.1 ST estimator: Absolute Difference of Expectation

We find that the Straight-Through (ST) estimator [3] is a special case of pWGF in which the Wasserstein distance is approximated by its lower bound, the absolute difference of expectations.

Theorem 2.2.

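For two one-dimensional distributions $p$ and $q$, the absolute difference between $\mathbb{E}_p[z]$ and $\mathbb{E}_q[z]$ is a lower bound of $W_2(p, q)$, i.e.,

$$\left|\mathbb{E}_{p}[z] - \mathbb{E}_{q}[z]\right| \le W_2(p, q).$$

(For any coupling $\gamma$ of $p$ and $q$, Jensen's inequality gives $|\mathbb{E}_\gamma[x - y]| \le (\mathbb{E}_\gamma[(x - y)^2])^{1/2}$; taking the infimum over $\gamma$ yields the bound.) If $p$ and $q$ are Bernoulli, then $\left|\mathbb{E}_{p}[z] - \mathbb{E}_{q}[z]\right| = W_2^2(p, q)$, which means that minimizing the expectation difference is equivalent to minimizing the 2-Wasserstein distance in the Bernoulli case.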
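The bound and the Bernoulli identity can be checked numerically. The sketch below, which we add for illustration, assumes finitely supported one-dimensional distributions and uses the fact that in one dimension the monotone (quantile) coupling is optimal for the squared cost, so $W_2^2$ can be computed by matching mass from left to right along the sorted supports. The helper names are hypothetical.

```python
import math
from typing import Dict


def w2_squared_1d(p: Dict[float, float], q: Dict[float, float]) -> float:
    """Squared 2-Wasserstein distance between two finitely supported 1-D distributions.

    p and q map support points to probability masses (each summing to 1). The
    monotone coupling is optimal for (x - y)^2 in 1-D, so we transport mass
    greedily from left to right along the sorted supports.
    """
    xs, ys = sorted(p.items()), sorted(q.items())
    i = j = 0
    mass_x, mass_y = xs[0][1], ys[0][1]
    total = 0.0
    while i < len(xs) and j < len(ys):
        m = min(mass_x, mass_y)               # mass moved between the current pair of atoms
        total += m * (xs[i][0] - ys[j][0]) ** 2
        mass_x -= m
        mass_y -= m
        if mass_x <= 1e-12:                   # current atom of p exhausted
            i += 1
            if i < len(xs):
                mass_x = xs[i][1]
        if mass_y <= 1e-12:                   # current atom of q exhausted
            j += 1
            if j < len(ys):
                mass_y = ys[j][1]
    return total


def mean(p: Dict[float, float]) -> float:
    return sum(x * w for x, w in p.items())


# Bernoulli(0.3) vs Bernoulli(0.7): W_2^2 coincides with |E_p[z] - E_q[z]|.
p_b, q_b = {0: 0.7, 1: 0.3}, {0: 0.3, 1: 0.7}
print(w2_squared_1d(p_b, q_b), abs(mean(p_b) - mean(q_b)))

# A non-Bernoulli pair: the expectation difference is only a lower bound on W_2.
p, q = {0: 0.2, 3: 0.8}, {1: 0.5, 2: 0.5}
print(math.sqrt(w2_squared_1d(p, q)) >= abs(mean(p) - mean(q)))
```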

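For a one-dimensional Bernoulli distribution $p_\theta = \mathrm{Bernoulli}(\theta)$, noting that $\mathbb{E}_{p_\theta}[z] = \theta$ and $\mathbb{E}_{q_{t+1}}[z] \approx \frac{1}{N}\sum_{i=1}^N z_i'$ with $z_i' = z_i - \epsilon \nabla_z f(z_i)$, we minimize the squared expectation difference and approximate the parameter gradient by

$$\nabla_\theta \Big(\theta - \frac{1}{N}\sum_{i=1}^N z_i'\Big)^2 \Big|_{\theta = \theta_t} = 2\Big(\theta_t - \frac{1}{N}\sum_{i=1}^N z_i + \frac{\epsilon}{N}\sum_{i=1}^N \nabla_z f(z_i)\Big).$$

To reduce the variance caused by the sample mean $\frac{1}{N}\sum_{i} z_i$, we use the control variate method [4], replacing it with its expectation $\theta_t$, and write

$$\nabla_\theta \approx \frac{2\epsilon}{N}\sum_{i=1}^N \nabla_z f(z_i).$$

Thus, we have derived the pWGF estimator with the expectation-difference approximation, which has the same form as a multi-sample version of the ST estimator [3], up to the constant factor $2\epsilon$ that can be absorbed into the learning rate. Parameter gradients for Poisson and Categorical distributions can be derived in a similar way.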
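The effect of the control variate can be seen in a small simulation. This sketch is written under our own assumptions (toy cost $f(z) = (z - 1.5)^2$, stepsize $\epsilon = 0.1$, $N = 10$ samples, hypothetical function names). It computes both the raw gradient of $(\theta - \frac{1}{N}\sum_i z_i')^2$ with $z_i' = z_i - \epsilon f'(z_i)$ and the control-variate version, in which the sample mean $\frac{1}{N}\sum_i z_i$ is replaced by its expectation $\theta$. Both estimate the same quantity, but only the second has the form of a multi-sample ST estimator.

```python
import random
import statistics
from typing import Callable, Tuple


def grad_f(z: float) -> float:
    """Derivative of the toy cost f(z) = (z - 1.5)^2."""
    return 2.0 * (z - 1.5)


def st_gradients(theta: float, df: Callable[[float], float], step: float,
                 n_samples: int, rng: random.Random) -> Tuple[float, float]:
    """Return (raw, control_variate) estimates of d/dtheta (theta - mean(z'))^2."""
    z = [1.0 if rng.random() < theta else 0.0 for _ in range(n_samples)]
    mean_z = sum(z) / n_samples
    mean_df = sum(df(zi) for zi in z) / n_samples
    # Raw estimate: mean(z') = mean_z - step * mean_df.
    raw_est = 2.0 * (theta - mean_z + step * mean_df)
    # Control variate: replace mean_z by its expectation theta. What remains is the
    # multi-sample ST gradient mean_i f'(z_i), scaled by 2 * step.
    cv_est = 2.0 * step * mean_df
    return raw_est, cv_est


rng = random.Random(0)
draws = [st_gradients(0.3, grad_f, step=0.1, n_samples=10, rng=rng) for _ in range(2000)]
raw, cv = zip(*draws)
# Both are unbiased for 2 * 0.1 * E[f'(z)] = 0.2 * 2 * (0.3 - 1.5) = -0.48.
print(statistics.mean(raw), statistics.mean(cv))
print(statistics.variance(raw), statistics.variance(cv))
```

For this particular cost the variance ratio between the two estimators is $(1 - 2\epsilon)^2 / (4\epsilon^2)$, i.e. 16 at $\epsilon = 0.1$.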

2.2 Proposed estimator: Maximum Mean Discrepancy

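A more principled way to approximate the Wasserstein distance is to use the Maximum Mean Discrepancy (MMD) [12]: $\mathrm{MMD}^2(p, q) = \mathbb{E}_{z, z' \sim p}[k(z, z')] - 2\,\mathbb{E}_{z \sim p,\, y \sim q}[k(z, y)] + \mathbb{E}_{y, y' \sim q}[k(y, y')]$, where $k(\cdot, \cdot)$ is a selected kernel. In practice, instead of minimizing $\mathrm{MMD}^2(p_\theta, q_{t+1})$ itself, we minimize its empirical estimate computed from the samples $\{z_i'\}_{i=1}^N$ of $q_{t+1}$. Details on the parameter gradients are given in the supplement.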
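A minimal sketch of the empirical MMD objective for a one-dimensional Bernoulli family follows. It assumes an RBF kernel with bandwidth 1, computes expectations under $p_\theta$ exactly over the support $\{0, 1\}$, and uses a V-statistic over the moved samples for $q_{t+1}$. These estimator choices, the grid-search projection in the usage lines, and all names are our assumptions rather than the exact procedure in the supplement.

```python
import math
from typing import Sequence


def rbf(x: float, y: float, bandwidth: float = 1.0) -> float:
    return math.exp(-((x - y) ** 2) / (2.0 * bandwidth ** 2))


def mmd2_bernoulli(theta: float, samples: Sequence[float], bandwidth: float = 1.0) -> float:
    """Squared MMD between Bernoulli(theta) and the empirical distribution of `samples`.

    Expectations under Bernoulli(theta) are exact over its support {0, 1};
    expectations under q use the samples (a V-statistic, including i == j terms).
    """
    atoms = ((0.0, 1.0 - theta), (1.0, theta))
    n = len(samples)
    pp = sum(wa * wb * rbf(a, b, bandwidth) for a, wa in atoms for b, wb in atoms)
    pq = sum(wa * rbf(a, y, bandwidth) for a, wa in atoms for y in samples) / n
    qq = sum(rbf(x, y, bandwidth) for x in samples for y in samples) / n ** 2
    return pp - 2.0 * pq + qq


# Usage: project hypothetical moved samples z_i' back onto the Bernoulli family.
pushed = [1.05, 0.98, 1.1, -0.02, 0.03, 0.0, -0.05, 0.01, 0.04, 0.02]
grid = [k / 100 for k in range(101)]
theta_star = min(grid, key=lambda t: mmd2_bernoulli(t, pushed))
print(theta_star)
```

Unlike the expectation-difference approximation, the MMD compares the whole distributions through the kernel, not only their first moments.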

3 Experiments

We demonstrate the advantage of pWGF for updating Poisson distributions, and report benchmark performance on a binary latent model in the supplement. Since the only difference between our pWGF version of ST and the original ST is a scalar factor on the learning rate, unless otherwise mentioned we refer to both pWGF-ST and the original ST as ST, and call our MMD version pWGF.

3.1 Poisson Parameter Estimation

We apply pWGF to infer the parameter of a one-dimensional Poisson distribution. We use the true distribution $\mathrm{Poisson}(\lambda_0)$ to generate data samples $\{x_j\}$, and use a generative adversarial learning framework to learn the model parameter. A generator is constructed as $z \sim \mathrm{Poisson}(\lambda)$. A discriminator $D(\cdot)$ is a network that distinguishes true from fake samples; it outputs the probability that its input comes from the true distribution. During adversarial training, the generator aims to increase $D(z)$, while the discriminator tries to decrease $D(z)$ and increase $D(x)$. We can write the training process as a min-max game with objective function

$$\min_\lambda \max_D \; \mathbb{E}_{x \sim \mathrm{Poisson}(\lambda_0)}[\log D(x)] + \mathbb{E}_{z \sim \mathrm{Poisson}(\lambda)}[\log(1 - D(z))].$$

Similar to the observation in [10], the training process should eventually converge to the true distribution, $\lambda = \lambda_0$. Therefore, for the generator, learning $\lambda$ amounts to optimizing $\mathbb{E}_{z \sim \mathrm{Poisson}(\lambda)}[\log(1 - D(z))]$, which has the form of (1) with $f(z) = \log(1 - D(z))$. We compare our pWGF against ST, REINFORCE, and MuProp [13], and show the learning curves of the estimated $\lambda$ in Figure 3. pWGF converges faster than the other methods and exhibits much smaller oscillation. In Table 1, we report the mean and the standard deviation of the inferred parameter after a fixed number of training epochs; pWGF exhibits higher inference accuracy and lower variance.

Figure 3: Learning curves of the Poisson parameter.

Table 1: Mean and standard deviation of the inferred parameter.
Method      Mean     Std
pWGF        5.0076   0.013
ST          5.1049   0.161
MuProp      5.0196   0.159
REINFORCE   4.9452   0.173
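The generator-side update in this experiment can be sketched as follows. Everything here is hypothetical: the discriminator is frozen as $D(z) = \mathrm{sigmoid}(a z + b)$ with made-up $a$ and $b$ instead of a trained network, Poisson samples are drawn with Knuth's multiplication method, and the gradient for $\lambda$ uses the expectation-matching (ST-style) projection, which carries over because $\mathbb{E}_{\mathrm{Poisson}(\lambda)}[z] = \lambda$. The generator cost $f(z) = \log(1 - D(z))$ has derivative $f'(z) = -a\,D(z)$. The sketch covers only the ST-style update, not the MMD version of pWGF.

```python
import math
import random


def sample_poisson(lam: float, rng: random.Random) -> int:
    """Knuth's multiplication method; adequate for small lam."""
    threshold = math.exp(-lam)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def discriminator(z: float, a: float = 0.5, b: float = -2.5) -> float:
    """Hypothetical frozen discriminator D(z) = sigmoid(a * z + b)."""
    return 1.0 / (1.0 + math.exp(-(a * z + b)))


def generator_grad(lam: float, n_samples: int, rng: random.Random,
                   a: float = 0.5, b: float = -2.5) -> float:
    """ST-style estimate of d/dlam E_{z ~ Poisson(lam)}[log(1 - D(z))].

    With f'(z) = -a * D(z) and E[z] = lam, the expectation-matching projection
    makes the gradient proportional to the sample mean of f'(z_i); the constant
    factor 2 * step is folded into the learning rate here.
    """
    zs = [sample_poisson(lam, rng) for _ in range(n_samples)]
    return sum(-a * discriminator(z, a, b) for z in zs) / n_samples


rng = random.Random(0)
lam, lr = 3.0, 0.5
g = generator_grad(lam, n_samples=256, rng=rng)
lam_new = max(1e-3, lam - lr * g)   # keep the Poisson rate positive
print(g, lam_new)
```

Because this $D$ increases with $z$, $f'$ is negative and the update raises $\lambda$, moving the generator toward samples the discriminator scores as more realistic.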

4 Conclusion

We presented a theoretical foundation that justifies the strong empirical performance of the Straight-Through (ST) estimator for backpropagating gradients through discrete latent variables. Specifically, we showed that ST can be interpreted as a simulation of the projected gradient flow on the Wasserstein space. Based on this theoretical framework, we further proposed another gradient estimator for learning discrete variables, which performs even better when applied to distributions with infinite support, e.g., Poisson.


  • [1] L. Ambrosio, N. Gigli, and G. Savaré (2008) Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media.
  • [2] J. Benamou and Y. Brenier (2000) A computational fluid mechanics solution to the Monge-Kantorovich mass transfer problem. Numerische Mathematik 84 (3), pp. 375–393.
  • [3] Y. Bengio, N. Léonard, and A. Courville (2013) Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432.
  • [4] P. P. Boyle (1977) Options: a Monte Carlo approach. Journal of Financial Economics 4 (3), pp. 323–338.
  • [5] C. Chen, C. Li, L. Chen, W. Wang, Y. Pu, and L. Carin (2017) Continuous-time flows for deep generative models. arXiv preprint arXiv:1709.01179.
  • [6] C. Chen, R. Zhang, W. Wang, B. Li, and L. Chen (2018) A unified particle-optimization framework for scalable Bayesian sampling. UAI submission.
  • [7] C. Chen and R. Zhang (2017) Particle optimization in stochastic gradient MCMC. arXiv preprint arXiv:1711.10927.
  • [8] X. Chen, Y. Duan, R. Houthooft, J. Schulman, I. Sutskever, and P. Abbeel (2016) InfoGAN: interpretable representation learning by information maximizing generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2172–2180.
  • [9] C. Doersch (2016) Tutorial on variational autoencoders. arXiv preprint arXiv:1606.05908.
  • [10] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio (2014) Generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2672–2680.
  • [11] W. Grathwohl, D. Choi, Y. Wu, G. Roeder, and D. Duvenaud (2017) Backpropagation through the void: optimizing control variates for black-box gradient estimation. arXiv preprint arXiv:1711.00123.
  • [12] A. Gretton, K. M. Borgwardt, M. Rasch, B. Schölkopf, and A. J. Smola (2007) A kernel method for the two-sample-problem. In Advances in Neural Information Processing Systems, pp. 513–520.
  • [13] S. Gu, S. Levine, I. Sutskever, and A. Mnih (2015) MuProp: unbiased backpropagation for stochastic neural networks. arXiv preprint arXiv:1511.05176.
  • [14] G. Hinton (2012) Neural networks for machine learning, video lectures. Coursera.
  • [15] E. Jang, S. Gu, and B. Poole (2017) Categorical reparameterization with Gumbel-Softmax. In International Conference on Learning Representations.
  • [16] D. P. Kingma and M. Welling (2014) Auto-encoding variational Bayes. In International Conference on Learning Representations.
  • [17] C. Liu, J. Zhuo, P. Cheng, R. Zhang, J. Zhu, and L. Carin (2018) Accelerated first-order methods on the Wasserstein space for Bayesian inference. arXiv preprint arXiv:1807.01750.
  • [18] Q. Liu and D. Wang (2016) Stein variational gradient descent: a general purpose Bayesian inference algorithm. In Advances in Neural Information Processing Systems, pp. 2378–2386.
  • [19] Q. Liu (2017) Stein variational gradient descent as gradient flow. In Advances in Neural Information Processing Systems, pp. 3118–3126.
  • [20] S. Mukherjee, Q. Wu, D. Zhou, et al. (2010) Learning gradients on manifolds. Bernoulli 16 (1), pp. 181–207.
  • [21] F. Otto (2001) The geometry of dissipative evolution equations: the porous medium equation.
  • [22] D. J. Rezende, S. Mohamed, and D. Wierstra (2014) Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pp. 1278–1286.
  • [23] G. Tucker, A. Mnih, C. J. Maddison, J. Lawson, and J. Sohl-Dickstein (2017) REBAR: low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Information Processing Systems, pp. 2627–2636.
  • [24] C. Villani (2008) Optimal transport: old and new. Vol. 338, Springer Science & Business Media.
  • [25] M. Yin and M. Zhou (2018) ARM: augment-reinforce-merge gradient for discrete latent variable models. arXiv preprint arXiv:1807.11143.
  • [26] R. Zhang, C. Chen, C. Li, and L. Carin (2018) Policy optimization as Wasserstein gradient flows. arXiv preprint arXiv:1808.03030.

Appendix A Background

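To minimize the expected cost in (1), we assume that $\mathbb{E}_{p_\theta(z)}[\nabla_\theta f(z)] = 0$ if the cost function $f$ depends on $\theta$. For instance, in the Variational Autoencoder (VAE) [9], we seek to maximize the Evidence Lower Bound (ELBO) $\mathbb{E}_{q_\theta(z|x)}[\log p(x, z) - \log q_\theta(z|x)]$, where the cost depends on the parameter $\theta$ through the variational posterior approximation $q_\theta(z|x)$. Since $\mathbb{E}_{q_\theta(z|x)}[\nabla_\theta \log q_\theta(z|x)] = 0$, we have $\mathbb{E}_{q_\theta(z|x)}[\nabla_\theta f(z)] = 0$.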

As described above, there are two types of methods for updating $\theta$ in (1): estimating the parameter gradient $\nabla_\theta \mathcal{L}(\theta)$, and continuously relaxing the discrete variable $z$.

A.1 Continuous relaxation

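Another approach to obtaining updates for $\theta$ in (1) is to generate samples of $z$ from a deterministic function $z = g(\theta, u)$ of $\theta$ and an independent random variable $u$ with a simple distribution $p(u)$, e.g., uniform or normal, so that $\mathcal{L}(\theta) = \mathbb{E}_{p(u)}[f(g(\theta, u))]$. Then we can use the chain rule to derive the gradient of (1) as

$$\nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{p(u)}\left[\nabla_z f(z)\big|_{z = g(\theta, u)}\, \nabla_\theta g(\theta, u)\right].$$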
We can take the expectation of the gradients, which is very convenient because $\nabla_\theta g(\theta, u)$ can be computed by the chain rule, noting that $p(u)$ does not depend on $\theta$. This reparameterization trick works quite well when $z$ comes from a continuous distribution. For example, given a normal distribution $z \sim \mathcal{N}(\mu, \sigma^2)$, we can rewrite $z = \mu + \sigma u$ with $u \sim \mathcal{N}(0, 1)$, and directly obtain $\partial z / \partial \mu = 1$ and $\partial z / \partial \sigma = u$. This reparameterization has been widely used to train variational autoencoders with latent Gaussian priors [16, 22].
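A minimal Monte Carlo sketch of the Gaussian reparameterization gradient follows. The toy cost $f(z) = z^2$ is our own choice, picked so that $\mathbb{E}[f] = \mu^2 + \sigma^2$ and the exact gradients $2\mu$ and $2\sigma$ are known; the function name `reparam_grads` is hypothetical.

```python
import random
from typing import Callable, Tuple


def reparam_grads(mu: float, sigma: float, df: Callable[[float], float],
                  n_samples: int = 100_000, seed: int = 0) -> Tuple[float, float]:
    """Monte Carlo estimate of (d/dmu, d/dsigma) E_{z ~ N(mu, sigma^2)}[f(z)]."""
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        u = rng.gauss(0.0, 1.0)      # u ~ N(0, 1), independent of (mu, sigma)
        z = mu + sigma * u           # z = g(theta, u)
        dfz = df(z)                  # f'(z) at the reparameterized sample
        g_mu += dfz * 1.0            # chain rule with dz/dmu = 1
        g_sigma += dfz * u           # chain rule with dz/dsigma = u
    return g_mu / n_samples, g_sigma / n_samples


# For f(z) = z^2 (f'(z) = 2z), the exact gradients at (mu, sigma) = (1, 0.5) are (2, 1).
g_mu, g_sigma = reparam_grads(1.0, 0.5, lambda z: 2.0 * z)
print(g_mu, g_sigma)
```

No such differentiable $g$ exists for a hard-threshold Bernoulli sample, which is the gap that ST and Gumbel-Softmax address.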

In the discrete case, it is very difficult to find a differentiable deterministic function that generates samples from $p_\theta(z)$. For the categorical distribution, [15] introduced the Gumbel-Softmax distribution to relax the one-hot vector encoding commonly used for categorical variables. For the multidimensional (factorized) Bernoulli distribution with parameter $\theta$, the Straight-Through (ST) estimator [14, 3], which directly uses the gradient with respect to the samples $z$ as the gradient with respect to the parameter $\theta$, can also be explained by setting the derivative of the discrete (hard-threshold) sampling function, taken coordinate-wise, to the identity matrix [3].

A.2 Wasserstein gradient flow

Wasserstein gradient flows (WGFs) [6, 1, 24] have become popular in machine learning due to their generality over parametric distribution families and their computational tractability. The Wasserstein space is a metric space of distributions, and the WGF defines a family of steepest descent curves in it. WGFs have been applied to Bayesian inference, where the KL divergence from an approximating distribution to a target one is minimized by simulating its gradient flow. [6] developed a unified framework to simulate the WGF, which includes Stein Variational Gradient Descent (SVGD) [18, 19] and Stochastic Gradient MCMC as special cases. [7] and [17] proposed acceleration frameworks for these methods. WGFs have also been applied to deep generative models [5] and to policy optimization in reinforcement learning [26]. However, all previous methods focus on simulating WGFs to approximate distributions in continuous domains; there has been little, if any, research on WGFs for discrete domains.

Appendix B Updating via Wasserstein gradient flow

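Gradient computation and Wasserstein Gradient Flow (WGF) simulation are made possible by the Riemannian structure of $\mathcal{P}_2$, which consists of a proper inner product in the tangent space that is consistent with the Wasserstein distance [2, 21]. The tangent space of $\mathcal{P}_2$ at $q$ can be represented by a subspace of vector fields on $\mathbb{R}^d$ ([24], Thm 13.8; [1], Thm 8.3.1, Prop 8.4.5):

$$T_q \mathcal{P}_2 = \overline{\{\nabla \varphi : \varphi \in C_c^\infty(\mathbb{R}^d)\}}^{\mathcal{L}^2_q},$$

where $C_c^\infty(\mathbb{R}^d)$ is the set of compactly supported smooth functions on $\mathbb{R}^d$, $\mathcal{L}^2_q$ is the Hilbert space of vector fields with inner product $\langle u, v \rangle_{\mathcal{L}^2_q} = \mathbb{E}_{z \sim q}[u(z) \cdot v(z)]$, and the overline denotes the closure in $\mathcal{L}^2_q$.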

With the inner product inherited from $\mathcal{L}^2_q$, $\mathcal{P}_2$ becomes a Riemannian manifold that is consistent with the Wasserstein distance, due to the Benamou-Brenier formula [2]. We can then express the gradient of a functional on $\mathcal{P}_2$ in the Riemannian sense. The explicit expression was proposed intuitively as Otto's calculus ([21]; [24], Chapter 15) and rigorously verified by subsequent work, e.g., [24], Thm 23.18; [1], Lem 10.4.1. Specifically, they showed that given a functional $F$ with $F(q) = \mathbb{E}_{z \sim q}[f(z)]$, its gradient is $\operatorname{grad} F(q) = \nabla f$, a vector field on $\mathbb{R}^d$. This means that we can, in principle, compute the desired gradient using $\nabla f$.

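Another convenient property of $\mathcal{P}_2$, based on the physical interpretation of its tangent vectors, makes the gradient flow simulation possible. Consider a smooth curve of absolutely continuous measures $(q_t)_t$ with corresponding tangent vectors $v_t \in T_{q_t}\mathcal{P}_2$, for which the gradient flow is simulated iteratively at discrete times $t_k$ to estimate $q^*$ (the target distribution). For almost every $t$, Proposition 8.4.6 of [1] guarantees that $W_2\big(q_{t+\epsilon}, (\mathrm{id} + \epsilon v_t)_{\#} q_t\big) = o(\epsilon)$, where $\mathrm{id} + \epsilon v_t$ is a transformation on $\mathbb{R}^d$ ($\mathrm{id}$ is the identity map and $v_t$ is a vector field on $\mathbb{R}^d$), and $(\mathrm{id} + \epsilon v_t)_{\#} q_t$ is the push-forward measure of $q_t$, which moves $q_t$ along the tangent vector $v_t$ with step $\epsilon$; see Figure 1. When $(q_t)$ is the gradient flow (steepest descent curve) of $F$ defined in the form above, $v_t = -\operatorname{grad} F(q_t) = -\nabla f$, as described before. Then, if $\{z_i\}_{i=1}^N$ is a set of samples of $q_t$, by the definition of the push-forward measure [1], $\{z_i - \epsilon \nabla f(z_i)\}_{i=1}^N$ is a set of samples of $(\mathrm{id} - \epsilon \nabla f)_{\#} q_t$, which is a first-order approximation of $q_{t+\epsilon}$. Since $(\mathrm{id} - \epsilon \nabla f)_{\#} q_t$ is a good approximation of $q_{t+\epsilon}$ (the optimal measure along the WGF), as discussed above, we can use $\{z_i - \epsilon \nabla f(z_i)\}_{i=1}^N$ to approximate $q_{t+\epsilon}$, and then return to the discrete family by projecting onto $\mathcal{P}_\theta$. Then, per Theorem C.2, with a small enough positive $\epsilon$, we always get a set of samples whose distribution improves $F$, the functional of the cost in (1).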

Appendix C Proofs

Theorem C.1.

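Let $F$ be a differentiable function on a manifold $\mathcal{M}$ and $\mathcal{N}$ a submanifold of $\mathcal{M}$, $\mathcal{N} \subset \mathcal{M}$. Then at any $x \in \mathcal{N}$,

$$\operatorname{grad}_{\mathcal{N}} F(x) = \Pi_x\big(\operatorname{grad}_{\mathcal{M}} F(x)\big),$$

where $\Pi_x(\operatorname{grad}_{\mathcal{M}} F(x))$ is the projection of $\operatorname{grad}_{\mathcal{M}} F(x)$ onto $T_x\mathcal{N}$.

Proof of Theorem C.1.

By the definition of $\operatorname{grad}_{\mathcal{M}}$ [20], for any vector $v \in T_x\mathcal{M}$,

$$\langle \operatorname{grad}_{\mathcal{M}} F(x), v \rangle = dF_x(v). \qquad (3)$$

By the definition of $\operatorname{grad}_{\mathcal{N}}$, for any vector $v \in T_x\mathcal{N}$,

$$\langle \operatorname{grad}_{\mathcal{N}} F(x), v \rangle = dF_x(v). \qquad (4)$$

Since $T_x\mathcal{N}$ is a subspace of $T_x\mathcal{M}$, by the definition of $\Pi_x$ we have, for any $v \in T_x\mathcal{N}$,

$$\langle \Pi_x(\operatorname{grad}_{\mathcal{M}} F(x)), v \rangle = \langle \operatorname{grad}_{\mathcal{M}} F(x), v \rangle. \qquad (5)$$

By (3), (4), and (5), for any $v \in T_x\mathcal{N}$,

$$\langle \operatorname{grad}_{\mathcal{N}} F(x) - \Pi_x(\operatorname{grad}_{\mathcal{M}} F(x)), v \rangle = 0.$$

Since both vectors lie in $T_x\mathcal{N}$, taking $v = \operatorname{grad}_{\mathcal{N}} F(x) - \Pi_x(\operatorname{grad}_{\mathcal{M}} F(x))$ gives $\operatorname{grad}_{\mathcal{N}} F(x) = \Pi_x(\operatorname{grad}_{\mathcal{M}} F(x))$. ∎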

Theorem C.2.

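Let $\mathcal{P}_\theta \subset \mathcal{P}_2$ and let $W_2$ be the 2-Wasserstein distance on $\mathcal{P}_2$. Update $p_{\theta_t}$ in $\mathcal{P}_2$ along the direction $-\operatorname{grad} F(p_{\theta_t})$ to $q_{t+1} = \operatorname{Exp}_{p_{\theta_t}}\big(-\epsilon \operatorname{grad} F(p_{\theta_t})\big)$ (exponential map [20]), then project $q_{t+1}$ back onto $\mathcal{P}_\theta$ as $p_{\theta_{t+1}} = \operatorname{arg\,min}_{p \in \mathcal{P}_\theta} W_2(p, q_{t+1})$. If $\operatorname{grad} F$ is Lipschitz continuous, then there exists $\epsilon_0 > 0$ such that for any $0 < \epsilon < \epsilon_0$, $F(p_{\theta_{t+1}}) \le F(p_{\theta_t})$.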

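Proof of Theorem 2.1.

Let $\Pi(p, q)$ denote the set of couplings of $p$ and $q$, i.e., joint distributions on $\mathbb{R}^d \times \mathbb{R}^d$ with marginals $p$ and $q$.

(1) First, we show that $W_2^2(p, q) \ge \sum_{i=1}^d W_2^2(p_i, q_i)$.

Arbitrarily selecting $\gamma \in \Pi(p, q)$ and $i \in \{1, \dots, d\}$, we define $\gamma_i$ as the marginal distribution of $\gamma$ on the coordinates $(x_i, y_i)$. Since $\gamma \in \Pi(p, q)$, for any measurable set $A \subset \mathbb{R}$ we have

$$\gamma_i(A \times \mathbb{R}) = \gamma\big(\{(x, y) : x_i \in A\}\big) = p\big(\{x : x_i \in A\}\big) = p_i(A), \qquad (6)$$

which means the marginal distribution of $\gamma_i$ on $x_i$ is $p_i$. Similarly, the marginal distribution of $\gamma_i$ on $y_i$ is $q_i$. Therefore, $\gamma_i \in \Pi(p_i, q_i)$. Then

$$\int \|x - y\|^2 \, d\gamma(x, y) = \sum_{i=1}^d \int |x_i - y_i|^2 \, d\gamma(x, y) = \sum_{i=1}^d \int |x_i - y_i|^2 \, d\gamma_i(x_i, y_i). \qquad (7)$$

On the other hand, since $\gamma_i \in \Pi(p_i, q_i)$,

$$\int |x_i - y_i|^2 \, d\gamma_i(x_i, y_i) \ge W_2^2(p_i, q_i). \qquad (8)$$

By (7) and (8), we have

$$\int \|x - y\|^2 \, d\gamma(x, y) \ge \sum_{i=1}^d W_2^2(p_i, q_i). \qquad (9)$$

Taking the infimum over $\gamma \in \Pi(p, q)$ on both sides of (9),

$$W_2^2(p, q) \ge \sum_{i=1}^d W_2^2(p_i, q_i). \qquad (10)$$

(2) Then we show that $W_2^2(p, q) \le \sum_{i=1}^d W_2^2(p_i, q_i)$.

Arbitrarily selecting $\gamma_i \in \Pi(p_i, q_i)$ for $i = 1, \dots, d$, we define the product coupling $\gamma(dx, dy) = \prod_{i=1}^d \gamma_i(dx_i, dy_i)$. Note that, since $p$ and $q$ are factorized,

$$p(dx) = \prod_{i=1}^d p_i(dx_i), \qquad q(dy) = \prod_{i=1}^d q_i(dy_i). \qquad (11)$$

By Fubini's theorem, integrating out $y$ gives the marginal of $\gamma$ on $x$:

$$\int_{y} \gamma(dx, dy) = \prod_{i=1}^d \int_{y_i} \gamma_i(dx_i, dy_i) = \prod_{i=1}^d p_i(dx_i) = p(dx). \qquad (12)$$

Similarly, the marginal of $\gamma$ on $y$ is $q$. Therefore, $\gamma \in \Pi(p, q)$. Then, using Fubini's theorem again (each term of $\|x - y\|^2$ depends on a single coordinate pair, and the remaining factors of $\gamma$ integrate to one),

$$W_2^2(p, q) \le \int \|x - y\|^2 \, d\gamma(x, y) = \sum_{i=1}^d \int |x_i - y_i|^2 \, d\gamma_i(x_i, y_i). \qquad (13)$$

Taking the infimum over each $\gamma_i \in \Pi(p_i, q_i)$ on the right-hand side of (13),

$$W_2^2(p, q) \le \sum_{i=1}^d W_2^2(p_i, q_i). \qquad (14)$$

Combining (10) and (14), $W_2^2(p, q) = \sum_{i=1}^d W_2^2(p_i, q_i)$. ∎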

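Proof of Remark 2.2.

For $p = \mathrm{Bernoulli}(a)$ and $q = \mathrm{Bernoulli}(b)$, any coupling of $p$ and $q$ is determined by $\gamma_{jk} = \gamma(x = j, y = k)$ for $j, k \in \{0, 1\}$, and

$$W_2^2(p, q) = \min_{\gamma \ge 0} \; \gamma_{01} + \gamma_{10} \quad \text{s.t.} \quad \gamma_{10} + \gamma_{11} = a, \;\; \gamma_{00} + \gamma_{01} = 1 - a, \;\; \gamma_{01} + \gamma_{11} = b, \;\; \gamma_{00} + \gamma_{10} = 1 - b, \qquad (15)$$

where the cost $(x - y)^2$ equals $1$ if $x \ne y$ and $0$ otherwise, and $a = \mathbb{E}_p[z]$, $b = \mathbb{E}_q[z]$.

Problem (15) is a linear program. Subtracting the third constraint from the first gives $\gamma_{10} - \gamma_{01} = a - b$, so $\gamma_{01} + \gamma_{10} \ge |a - b|$, with equality when $\gamma_{11} = \min(a, b)$. Hence the minimum value of (15) is $|a - b| = |\mathbb{E}_p[z] - \mathbb{E}_q[z]|$. ∎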

Lemma C.3.

Let be an arbitrary distribution and be a Bernoulli distribution. Then


where .

Appendix D Gradient For MMD Projection

We take the radial basis function (RBF) kernel as an example.

For Bernoulli distribution, ,
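As a sketch of what such a gradient can look like (our own derivation under stated assumptions, not a reproduction of this appendix's formulas): with an RBF kernel $k(x, y) = \exp(-(x - y)^2 / (2h^2))$, expectations under $\mathrm{Bernoulli}(\theta)$ taken exactly over $\{0, 1\}$, and a V-statistic over the moved samples $y_i$, the squared MMD is a convex quadratic in $\theta$, so both its derivative and its minimizer over $[0, 1]$ have closed forms. The bandwidth $h$ and all function names are hypothetical.

```python
import math
from typing import Sequence


def rbf(x: float, y: float, h: float = 1.0) -> float:
    return math.exp(-((x - y) ** 2) / (2.0 * h * h))


def mmd2(theta: float, ys: Sequence[float], h: float = 1.0) -> float:
    """MMD^2 between Bernoulli(theta) and the empirical distribution of ys (V-statistic)."""
    k01 = rbf(0.0, 1.0, h)
    pp = (1 - theta) ** 2 + 2 * theta * (1 - theta) * k01 + theta ** 2   # k(0,0) = k(1,1) = 1
    pq = sum((1 - theta) * rbf(0.0, y, h) + theta * rbf(1.0, y, h) for y in ys) / len(ys)
    qq = sum(rbf(a, b, h) for a in ys for b in ys) / len(ys) ** 2
    return pp - 2 * pq + qq


def mmd2_grad(theta: float, ys: Sequence[float], h: float = 1.0) -> float:
    """Closed-form d MMD^2 / d theta; the q-q term does not depend on theta."""
    k01 = rbf(0.0, 1.0, h)
    d_pp = -2 * (1 - theta) + 2 * (1 - 2 * theta) * k01 + 2 * theta
    d_pq = sum(rbf(1.0, y, h) - rbf(0.0, y, h) for y in ys) / len(ys)
    return d_pp - 2 * d_pq


def mmd_project(ys: Sequence[float], h: float = 1.0) -> float:
    """Minimizer of MMD^2 over theta in [0, 1] (setting mmd2_grad to zero, then clipping)."""
    k01 = rbf(0.0, 1.0, h)
    d_pq = sum(rbf(1.0, y, h) - rbf(0.0, y, h) for y in ys) / len(ys)
    theta_star = 0.5 + d_pq / (2.0 * (1.0 - k01))
    return min(1.0, max(0.0, theta_star))


# Check the analytic gradient against a central finite difference.
pushed = [0.9, 1.1, -0.05, 0.1, 1.0]
eps = 1e-6
fd = (mmd2(0.4 + eps, pushed) - mmd2(0.4 - eps, pushed)) / (2 * eps)
print(mmd2_grad(0.4, pushed), fd, mmd_project(pushed))
```

When all moved samples lie in $\{0, 1\}$, the closed-form minimizer reduces to the fraction of ones, matching the expectation-difference projection in that special case.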

Appendix E Binary Latent Models

As most previously proposed algorithms are designed specifically for discrete variables with finite support, we use a binary latent model as the benchmark: a variational autoencoder (VAE) [16] with Bernoulli latent variables (Bernoulli VAE). We compare pWGF with the baseline methods ST and Gumbel-Softmax [15], as well as three state-of-the-art algorithms: REBAR [23], RELAX [11], and ARM [25]. Following the settings in [25], we build the model with different network architectures. We apply all methods and architectures to the MNIST dataset and show the results in Table 2. pWGF is comparable with ST, and both pWGF and ST outperform the other competing methods except ARM in all tested network architectures.

Linear 119.8 119.1 110.3 122.1 123.2 129.2
Two Layers 108.3 107.6 98.2 114 113.7 NA
Nonlinear 104.6 104.2 101.3 110.9 111.6 112.5
Table 2: Testing ELBO for Bernoulli VAE on MNIST