Closed-form Expressions for Maximum Mean Discrepancy with Applications to Wasserstein Auto-Encoders


Raif M. Rustamov, AT&T Labs Research, Bedminster, NJ
Abstract

The Maximum Mean Discrepancy (MMD) has found numerous applications in statistics and machine learning, most recently as a penalty in the Wasserstein Auto-Encoder (WAE). In this paper we compute closed-form expressions for estimating the Gaussian kernel based MMD between a given distribution and the standard multivariate normal distribution. We introduce the standardized version of the MMD as a penalty for the WAE training objective, allowing for better interpretability of MMD values and more compatibility across different hyperparameter settings. Next, we propose using a version of batch normalization at the code layer; this has the benefits of making the kernel width selection easier, reducing the training effort, and preventing outliers in the aggregate code distribution. Finally, we discuss the appropriate null distributions and provide thresholds for multivariate normality testing with the standardized MMD, leading to a number of easy rules of thumb for monitoring the progress of WAE training. Curiously, our MMD formula reveals a connection to the Baringhaus-Henze-Epps-Pulley (BHEP) statistic of the Henze-Zirkler test and provides further insights about the MMD. Our experiments on synthetic and real data show that the analytic formulation improves over the commonly used stochastic approximation of the MMD, and demonstrate that code normalization provides significant benefits when training WAEs.

1 Introduction

The Maximum Mean Discrepancy (MMD) is a measure of divergence between distributions [mmd] which has found numerous applications in statistics and machine learning; see the recent review [MMD_review] and citations therein. MMD has a well-established theory, based on which a number of approaches are available for computing the thresholds for hypothesis testing, allowing one to make sense of the raw MMD values; however, the whole process can be somewhat intricate. Given the increasing adoption, it is desirable to have closed-form expressions for the MMD so as to make it more accessible to a general practitioner and to streamline its use. Additionally, since the raw MMD values are hard to interpret, it would be useful to convert the MMD to a more intuitive scale and to provide some easy-to-remember thresholds for testing and evaluating model convergence.

To focus the paper, we will concentrate on an application of the MMD in the context of Wasserstein Auto-Encoders, which we now review. MMD quickly entered the neural network arena as a penalty/regularization term in generative modeling—initially within the moment-matching generative networks [GMMN1, GMMN2] and later on as a replacement for the adversarial penalty in Adversarial Auto-Encoders [AdversAE] leading to the MMD version of Wasserstein Auto-Encoders [InfoVAE, vegan_cookbook, WAE]. These WAE-MMDs, to which we will refer simply as WAEs, use an objective that in addition to the reconstruction error includes an MMD term that pushes the latent representation of data towards some reference distribution. Similarly to Variational Auto-Encoders [VAE], WAEs can be used to generate new data samples by feeding random samples from the reference distribution to the decoder. By making a fundamental connection to optimal transport distances in the data space, [vegan_cookbook, WAE] establish theory proving the correctness of this generative procedure.

Already in the context of WAEs there has been an effort to replace the MMD with closed-form alternatives. For example, Tabor et al. [Cramer_Wold_AE] introduce the Cramer-Wold Auto-Encoders inspired by the slicing idea of [kolouri2018sliced]. While their Cramer-Wold distance has a closed-form expression, it depends on special functions unless one uses an approximation. In addition, similarly to the situation with the MMD, the raw values of the Cramer-Wold distance are not directly interpretable.

In this paper, we carry out the analytical computation of the MMD in a special case where the reference distribution is the standard multivariate normal and the MMD kernel is a Gaussian RBF. We are also able to compute the variance of the MMD in closed-form, which allows us to introduce the standardized version of the MMD. This version has the advantage of being more amenable to direct interpretation, which is demonstrated by a number of easy to remember rules of thumb suitable for model evaluation and hypothesis testing. As a curious development, our MMD formula reveals a relationship to the Baringhaus-Henze-Epps-Pulley (BHEP) statistic [EP, BH] and the Henze-Zirkler test [HZ], which also allows making a connection to the Cramer-Wold distance.

Focusing on WAEs as an application, we discuss the use of the closed-form standardized MMD as a penalty in the WAE training objective. Estimating the MMD the usual way requires sampling both from the latent code and the target reference distributions. The latter sampling incurs additional stochasticity which has an immediate effect on the gradients for training; using the analytic formula for the MMD essentially integrates out this extra stochasticity. We also argue that standardization of the MMD induces better compatibility across different hyperparameter settings, which can be advantageous for model selection. As another contribution, we propose using code normalization—a version of the batch normalization [batchnorm] applied at the code layer—when training WAEs. This has the benefits of making the selection of width for the MMD kernel easier, reducing the training effort, and preventing outliers in the aggregate latent distribution.

The paper is organized as follows. Section 2 provides closed-form expressions for the MMD and its variance. In Section 3, we discuss the standardized MMD and code normalization in the context of WAE training. Section 4 discusses thresholds for hypothesis testing, and their application to monitoring the WAE training progress. Section 5 provides an empirical evaluation on synthetic and real data. The derivations of the formulas and relevant code are provided in the appendix.

2 Closed-Form Expressions for MMD

The maximum mean discrepancy is a divergence measure between two distributions $P$ and $Q$. In the context of WAEs, applying the encoder net to the distribution of the input data (e.g. images) yields the aggregate distribution $Q_Z$ of the latent variables. One of the goals of WAE training is to make $Q_Z$ (which depends on the neural net parameters) as close as possible to some fixed target distribution $P_Z$. This is achieved by incorporating $\mathrm{MMD}(P_Z, Q_Z)$ as a regularizer into the WAE objective.

The computation of the MMD requires specifying a positive-definite kernel; in this paper we always assume it to be the Gaussian RBF kernel of width $\sigma$, namely, $k(x, y) = \exp\left(-\|x - y\|^2/(2\sigma^2)\right)$. Here, $x, y \in \mathbb{R}^d$, where $d$ is the dimension of the code/latent space, and we use $\|\cdot\|$ to denote the Euclidean norm. The population MMD can be most straightforwardly computed via the formula [mmd]:

$$\mathrm{MMD}^2(P_Z, Q_Z) = \mathbb{E}_{w, w' \sim P_Z}[k(w, w')] - 2\,\mathbb{E}_{w \sim P_Z,\, z \sim Q_Z}[k(w, z)] + \mathbb{E}_{z, z' \sim Q_Z}[k(z, z')]. \qquad (2.1)$$

In this paper, the target reference distribution $P_Z$ is always assumed to be the standard multivariate normal distribution $N(0, I_d)$ with the density $\phi(w) = (2\pi)^{-d/2}\exp(-\|w\|^2/2)$, $w \in \mathbb{R}^d$.

In practical situations, we only have access to $Q_Z$ through a sample. For example, during each step of the WAE training, the encoder neural net will compute the codes $z_1, \dots, z_n$ corresponding to the input data in the batch (we use “batch” to mean “mini-batch”) and the current values of the neural network parameters. Given this sample from $Q_Z$, our goal is to derive a closed-form estimate of $\mathrm{MMD}^2(P_Z, Q_Z)$.

In this section, we first consider deterministic encoders and derive an analytic formula for an unbiased estimator of the MMD and its variance. Next, we discuss the biased estimator and its connection to BHEP statistics. Finally, we derive a formula for the estimator of the MMD in the case of random encoders.

2.1 Deterministic Encoders

Unbiased Estimator

We start with the expression Eq. (2.1) and, using the sample $z_1, \dots, z_n$, we replace the last two terms by the sample average and the U-statistic respectively to obtain the unbiased estimator:

$$\mathrm{MMD}^2_{U,n} = \mathbb{E}_{w, w' \sim P_Z}[k(w, w')] - \frac{2}{n}\sum_{i=1}^{n}\mathbb{E}_{w \sim P_Z}[k(w, z_i)] + \frac{1}{n(n-1)}\sum_{i \neq j} k(z_i, z_j). \qquad (2.2)$$

This quantity is denoted by $\mathrm{MMD}^2_u$ in [mmd]; our slightly different notation allows including the sample size as a subscript. In Appendix A.1 we show that the expectations in this expression can be computed analytically to yield the formula

$$\mathrm{MMD}^2_{U,n} = \left(\frac{\sigma^2}{\sigma^2 + 2}\right)^{d/2} - \frac{2}{n}\left(\frac{\sigma^2}{\sigma^2 + 1}\right)^{d/2}\sum_{i=1}^{n}\exp\left(-\frac{\|z_i\|^2}{2(\sigma^2 + 1)}\right) + \frac{1}{n(n-1)}\sum_{i \neq j}\exp\left(-\frac{\|z_i - z_j\|^2}{2\sigma^2}\right).$$

This formula for $\mathrm{MMD}^2_{U,n}$ reveals two forces at play when optimizing $Q_Z$ to have small divergence from the standard multivariate normal distribution. One force is pulling the sample points towards the origin, and the other is pushing them apart from each other. Namely, we can see that the second term encourages the sample points to be as close as possible to the origin so as to make the exponentials as large as possible. If not for the third term, all of the points would have collapsed onto the origin, making the exponentials equal to $1$. However, the third term introduces repelling forces between the sample points and pushes them away from each other. It is interesting to note that the second and third terms have different widths for the Gaussian kernels. Another observation is that one can compute the optimal translation transform for a given sample, and surprisingly it is not the one that places the center of mass at the origin. In fact, during this shift optimization the third term stays constant, and the second term can be interpreted (up to a constant factor) as a kernel density estimate with the kernel width of $\sqrt{\sigma^2 + 1}$. The optimal shift is the one that places the mode of this density estimate at the origin.
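For concreteness, the closed-form estimator above is straightforward to evaluate. The following NumPy sketch is a minimal illustration; the function name and interface are ours, and the expression mirrors the formula as reconstructed above rather than the paper's own Appendix A.4 code.

    import numpy as np

    def mmd2_unbiased_normal(z, sigma2):
        """Closed-form unbiased estimate of MMD^2 between a sample z (n x d array
        of codes) and the standard normal N(0, I_d), for the Gaussian RBF kernel
        k(x, y) = exp(-||x - y||^2 / (2 * sigma2))."""
        n, d = z.shape
        sq_norms = np.sum(z**2, axis=1)                      # ||z_i||^2
        # Term 1: E_{w,w'~N(0,I)}[k(w, w')]
        t1 = (sigma2 / (sigma2 + 2.0)) ** (d / 2.0)
        # Term 2: average of E_{w~N(0,I)}[k(z_i, w)] over the batch
        t2 = (sigma2 / (sigma2 + 1.0)) ** (d / 2.0) * np.mean(
            np.exp(-sq_norms / (2.0 * (sigma2 + 1.0))))
        # Term 3: U-statistic over distinct pairs (i != j)
        sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * z @ z.T
        kmat = np.exp(-sq_dists / (2.0 * sigma2))
        np.fill_diagonal(kmat, 0.0)
        t3 = kmat.sum() / (n * (n - 1))
        return t1 - 2.0 * t2 + t3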

Since its computation involves taking a random sample from $Q_Z$, we see that $\mathrm{MMD}^2_{U,n}$ is a random variable. Thus, even when $Q_Z = P_Z$, the estimator will not be identically zero. It is important to understand the behavior of this random variable; using the hypothesis testing terminology, we refer to this as the distribution of $\mathrm{MMD}^2_{U,n}$ under the null—the null hypothesis being $Q_Z = P_Z = N(0, I_d)$. By unbiasedness, we have that the null mean is zero:

$$\mathbb{E}\left[\mathrm{MMD}^2_{U,n}\right] = 0,$$

where the expectation is over various realizations of the sample from $P_Z$. This immediately means that, in contrast to the population $\mathrm{MMD}^2$, the estimator $\mathrm{MMD}^2_{U,n}$ can take negative values.

Next, we would like to obtain the variance under the null. First, we rewrite $\mathrm{MMD}^2_{U,n}$ by defining

$$h(z, z') = k(z, z') - \mathbb{E}_{w \sim P_Z}[k(z, w)] - \mathbb{E}_{w \sim P_Z}[k(w, z')] + \mathbb{E}_{w, w' \sim P_Z}[k(w, w')],$$

and noting that

$$\mathrm{MMD}^2_{U,n} = \frac{1}{n(n-1)}\sum_{i \neq j} h(z_i, z_j).$$

Now according to [mmd, Appendix B.3] we have

$$\mathrm{Var}\left[\mathrm{MMD}^2_{U,n}\right] = \frac{2}{n(n-1)}\,\mathrm{Var}_{z, z' \sim P_Z}\left[h(z, z')\right].$$

This expression can be computed in a closed form using manipulations similar to those used for computing $\mathrm{MMD}^2_{U,n}$. Since the mean of $h(z, z')$ under the null is $0$, the null variance is equal to the second moment $\mathbb{E}[h(z, z')^2]$, and we obtain the formula

$$\mathrm{Var}\left[\mathrm{MMD}^2_{U,n}\right] = \frac{2}{n(n-1)}\left[\left(\frac{\sigma^2}{\sigma^2 + 4}\right)^{d/2} - 2\left(\frac{\sigma^4}{(\sigma^2 + 1)(\sigma^2 + 3)}\right)^{d/2} + \left(\frac{\sigma^2}{\sigma^2 + 2}\right)^{d}\right]. \qquad (2.3)$$
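The null variance is likewise a few lines of code. The sketch below assumes the form of Eq. (2.3) as reconstructed above; the function name is ours.

    import numpy as np

    def mmd2_null_variance(n, d, sigma2):
        """Closed-form variance of MMD^2_{U,n} under the null Q_Z = N(0, I_d),
        following Eq. (2.3) as reconstructed in the text."""
        a = (sigma2 / (sigma2 + 4.0)) ** (d / 2.0)
        b = (sigma2**2 / ((sigma2 + 1.0) * (sigma2 + 3.0))) ** (d / 2.0)
        c = (sigma2 / (sigma2 + 2.0)) ** d
        return 2.0 * (a - 2.0 * b + c) / (n * (n - 1))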

Biased Estimator and BHEP Statistic

The biased estimator $\mathrm{MMD}^2_{b,n}$ from [mmd] can be computed in closed form in a similar manner. The only difference is the use of the V-statistic for the third term in Eq. (2.1); the final expression is as follows:

$$\mathrm{MMD}^2_{b,n} = \left(\frac{\sigma^2}{\sigma^2 + 2}\right)^{d/2} - \frac{2}{n}\left(\frac{\sigma^2}{\sigma^2 + 1}\right)^{d/2}\sum_{i=1}^{n}\exp\left(-\frac{\|z_i\|^2}{2(\sigma^2 + 1)}\right) + \frac{1}{n^2}\sum_{i, j}\exp\left(-\frac{\|z_i - z_j\|^2}{2\sigma^2}\right).$$
Interestingly, this expression is equivalent to a statistic proposed for testing multivariate normality, its history going back to as early as 1983. The Baringhaus-Henze-Epps-Pulley (BHEP) statistic is named after the authors of [EP, BH] as coined by [Csorgo]. This statistic is used in the Henze-Zirkler test of multivariate normality [HZ]. We give a quick review of this connection since it provides some useful insights about the MMD.

The BHEP statistic is a measure of divergence between two distributions $P$ and $Q$ that captures how different their characteristic functions are. It is defined as the weighted $L^2$-distance:

$$\int_{\mathbb{R}^d}\left|\varphi_P(t) - \varphi_Q(t)\right|^2 w(t)\, dt,$$

where $\varphi_P$ and $\varphi_Q$ are the characteristic functions of the distributions $P$ and $Q$, and $w$ is a weight function. When $P$ is the multivariate normal distribution $N(0, I_d)$, we have $\varphi_P(t) = \exp(-\|t\|^2/2)$. Selecting the weight function to be the density $\varphi_\beta$ of $N(0, \beta^2 I_d)$ and noting that $Q$ is available through a sample $z_1, \dots, z_n$, the BHEP statistic takes the following form:

$$T_{\beta, n} = \int_{\mathbb{R}^d}\left|\psi_n(t) - \exp(-\|t\|^2/2)\right|^2 \varphi_\beta(t)\, dt.$$

Here $\psi_n$ is the empirical characteristic function of the sample,

$$\psi_n(t) = \frac{1}{n}\sum_{j=1}^{n}\exp\left(i\, t^\top z_j\right).$$

A closed-form formula for $T_{\beta, n}$ can be obtained (see e.g. [HZ, HenzeWagner1997]) and it coincides with the expression for $\mathrm{MMD}^2_{b,n}$ when one sets the Gaussian RBF kernel width $\sigma = 1/\beta$.

This connection has a number of useful consequences. Henze and Zirkler [HZ] show that the BHEP statistic can be equivalently obtained as the $L^2$-distance between kernel density estimates; in our context, this is a concrete example of the connection described in [mmd, Section 3.3.1]. Based on this equivalence, [HZ] suggests using a specific value of $\beta$ from optimal density estimation theory. The corresponding $\sigma$ is

$$\sigma_{HZ} = \frac{1}{\beta_{HZ}} = \sqrt{2}\left(\frac{(2d+1)\,n}{4}\right)^{-1/(d+4)}; \qquad (2.4)$$

we will refer to this setting as HZ in our experimental section. We also note that the one-dimensional distance used in the definition of the Cramer-Wold distance [Cramer_Wold_AE] is based on exactly the same $L^2$-distance between kernel density estimates. As a result, we see that the Cramer-Wold distance is the integral of $\mathrm{MMD}^2_{b,n}$ over all one-dimensional projections of the sample. Of course, by similarly integrating $\mathrm{MMD}^2_{U,n}$ instead, one could introduce a new version of the Cramer-Wold distance that is zero centered under the null.

On a conceptual level, this connection allows transferring some insights about the BHEP statistic to the MMD. For example, inspecting the relationship between the MMD and the characteristic function formulation of the BHEP statistic, we see that this formulation more transparently expresses the fact that MMD is performing moment matching. As another example, [HZ] makes the following qualitative observation: “Choosing a small value of $\beta$ entails that the weight function puts most of its mass near the origin of $\mathbb{R}^d$. Since the tail behavior of a probability distribution is reflected by the behavior of its characteristic function at the origin, the test should be sensitive against alternative distributions with heavy tails.” This intuition is made concrete by studying the limiting behavior of the BHEP statistic in [Henze_limiting] and making connections to Mardia’s kurtosis and skewness statistics [Mardia], with a summary provided in [HenzeWagner1997]. When translated to our setting, this means that MMD with a large value of $\sigma$ is useful for distinguishing distributions that have heavy tails.

Remark: We leave out the computation of the null mean and variance of $\mathrm{MMD}^2_{b,n}$; this can be carried out similarly to $\mathrm{MMD}^2_{U,n}$. Note that the mean and variance of the BHEP statistic are computed in closed-form in [HZ]. However, these expressions are based on a different null hypothesis (the composite null of Section 4) and have corrections for nuisance parameter estimation. Thus, they should not be used for standardization in Section 3.1.

2.2 Random Encoders

In this section we consider Gaussian random encoders, where instead of one code per input data point, we obtain a distribution of codes given as $z_i \sim N(\mu_i, \Sigma_i)$, $i = 1, \dots, n$. Here $n$ is the batch size, and $\Sigma_i = \mathrm{diag}(s_{i1}^2, \dots, s_{id}^2)$ is a diagonal covariance matrix. Both the mean vectors $\mu_i$ and the variance vectors $s_i^2$ are computed by applying neural nets to the input data. Our goal is once again to obtain an estimator for $\mathrm{MMD}^2(P_Z, Q_Z)$.

Note that the implied distribution of the codes for the current batch is an equally weighted mixture of Gaussians,

$$\widehat{Q}_Z = \frac{1}{n}\sum_{i=1}^{n} N(\mu_i, \Sigma_i), \qquad \Sigma_i = \mathrm{diag}(s_{i1}^2, \dots, s_{id}^2),$$

where $s_{ik}^2$ is the $k$-th component of the variance vector $s_i^2$. We will replace sampling from $Q_Z$ in the formula Eq. (2.1) by sampling from $\widehat{Q}_Z$, and compute the second and third terms in a closed form. Note that the first term depends only on $P_Z$ and will be the same as before; the computation of the remaining terms is demonstrated in Appendix A.2, and yields the following unbiased estimator:

$$\left(\frac{\sigma^2}{\sigma^2 + 2}\right)^{d/2} - \frac{2}{n}\sum_{i=1}^{n}\prod_{k=1}^{d}\left(\frac{\sigma^2}{\sigma^2 + 1 + s_{ik}^2}\right)^{1/2}\exp\left(-\frac{\mu_{ik}^2}{2(\sigma^2 + 1 + s_{ik}^2)}\right) + \frac{1}{n^2}\sum_{i, j}\prod_{k=1}^{d}\left(\frac{\sigma^2}{\sigma^2 + s_{ik}^2 + s_{jk}^2}\right)^{1/2}\exp\left(-\frac{(\mu_{ik} - \mu_{jk})^2}{2(\sigma^2 + s_{ik}^2 + s_{jk}^2)}\right).$$

When the noise is isotropic, namely $\Sigma_i = s_i^2 I_d$ with a scalar $s_i^2$ (note that $s_i^2$ was a vector in the general case above, but here it is a single number), we can rewrite this formula in a simpler form:

$$\left(\frac{\sigma^2}{\sigma^2 + 2}\right)^{d/2} - \frac{2}{n}\sum_{i=1}^{n}\left(\frac{\sigma^2}{\sigma^2 + 1 + s_i^2}\right)^{d/2}\exp\left(-\frac{\|\mu_i\|^2}{2(\sigma^2 + 1 + s_i^2)}\right) + \frac{1}{n^2}\sum_{i, j}\left(\frac{\sigma^2}{\sigma^2 + s_i^2 + s_j^2}\right)^{d/2}\exp\left(-\frac{\|\mu_i - \mu_j\|^2}{2(\sigma^2 + s_i^2 + s_j^2)}\right).$$

Note that setting the variances $s_i^2 = 0$ gives rise to the deterministic encoders where $z_i = \mu_i$, and the resulting estimator is the same as $\mathrm{MMD}^2_{b,n}$ and not $\mathrm{MMD}^2_{U,n}$. The difference is that the last term in the unbiased deterministic estimator includes an average over distinct pairs $i \neq j$, whereas for the unbiased random estimator the average runs over all pairs $(i, j)$. The latter is appropriate here because when $i = j$, in Eq. (2.1) one can sample two codes independently from the same component of the Gaussian mixture. Doing so in the deterministic case would have resulted in a biased estimate: essentially, instead of the U-statistic we would have gotten the upwards biased V-statistic.
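A sketch of the random-encoder estimator in the general diagonal-covariance case is given below; it assumes the form of the estimator as reconstructed above, with mu and s2 holding the per-code mean and variance vectors (function and variable names are ours).

    import numpy as np

    def mmd2_random_encoder(mu, s2, sigma2):
        """Closed-form MMD^2 estimate for a Gaussian random encoder where code i
        is distributed as N(mu[i], diag(s2[i])); mu and s2 are (n, d) arrays."""
        n, d = mu.shape
        t1 = (sigma2 / (sigma2 + 2.0)) ** (d / 2.0)
        # Cross term: E[k(x, w)] with x ~ N(mu_i, diag(s2_i)), w ~ N(0, I)
        v2 = sigma2 + 1.0 + s2                               # (n, d)
        t2 = np.mean(np.prod(np.sqrt(sigma2 / v2) *
                             np.exp(-mu**2 / (2.0 * v2)), axis=1))
        # Pair term: independent draws from components i and j, over all pairs
        v3 = sigma2 + s2[:, None, :] + s2[None, :, :]        # (n, n, d)
        diff2 = (mu[:, None, :] - mu[None, :, :]) ** 2       # (n, n, d)
        t3 = np.mean(np.prod(np.sqrt(sigma2 / v3) *
                             np.exp(-diff2 / (2.0 * v3)), axis=2))
        return t1 - 2.0 * t2 + t3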

3 WAE Training

3.1 Standardized MMD Penalty

In the original formulation of the WAE, the MMD penalty enters the objective as the term $\lambda \cdot \mathrm{MMD}^2(P_Z, Q_Z)$, and an estimate is computed by sampling from both $Q_Z$ and $P_Z$. Obviously, the closed-form formulas for the MMD presented in the previous section can be used instead. In addition, we suggest standardizing the MMD values to make them more directly interpretable.

We start by defining the standardized MMD by applying centering and scaling to the unbiased estimator. Since the mean of $\mathrm{MMD}^2_{U,n}$ under the null is zero, and the variance under the null is as given by Eq. (2.3), we define for a batch of codes $z_1, \dots, z_n$ sampled from $Q_Z$:

$$\mathrm{SMMD}_{U,n} = \frac{\mathrm{MMD}^2_{U,n}}{\sqrt{\mathrm{Var}\left[\mathrm{MMD}^2_{U,n}\right]}}.$$

A Python implementation of this formula is provided in Appendix A.4. Note that this definition can be used both for deterministic and random encoders, with the only difference being in the formula for the MMD estimator in the numerator. Now, the WAE objective looks as follows:

$$\min_{\mathrm{encoder},\,\mathrm{decoder}}\ \mathbb{E}\left[c\big(x, \mathrm{decoder}(\mathrm{encoder}(x))\big)\right] + \lambda \cdot \mathrm{SMMD}_{U,n},$$

where $c$ is the reconstruction cost and $\lambda$ is the regularization coefficient.
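As a stand-in for the Appendix A.4 implementation (not reproduced here), the following minimal sketch standardizes the closed-form estimator using the two helper functions sketched in Section 2; the names and interface are ours.

    import numpy as np

    def smmd_u(z, sigma2):
        """Standardized MMD for a batch of deterministic codes z (n x d): the
        closed-form unbiased estimator divided by the null standard deviation.
        Assumes mmd2_unbiased_normal and mmd2_null_variance from the earlier
        sketches."""
        n, d = z.shape
        return mmd2_unbiased_normal(z, sigma2) / np.sqrt(
            mmd2_null_variance(n, d, sigma2))

    # Example: a batch drawn from the target itself should give an O(1) value.
    z = np.random.randn(100, 8)
    print(smmd_u(z, sigma2=1.0))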

Having a closed-form formula for the SMMDU in both the deterministic and random cases is advantageous for optimization. Computing this penalty the usual way [mmd, WAE, InfoVAE] relies on taking a sample from both $Q_Z$ and $P_Z$. As a result, this incurs additional stochasticity due to the sampling from $P_Z$. Our formula essentially integrates out this stochasticity, and results in an estimator with a smaller variance. This allows better discrimination between distributions (see Section 5), and, as a result, potentially provides higher quality gradients for training. In some settings, there can also be a computational benefit to using our formula, as it requires the computation of a single pairwise distance matrix among the codes, in contrast to the additional code-to-sample and sample-to-sample distance matrices required by the sampling approach.

While at the theoretical level the suggested scaling can be seen as merely a re-definition of the regularization coefficient, it has a number of benefits in practice. The main advantage of using the SMMDU is that it is amenable to quick inspection when one wants to have a sense of how far the current code distribution is from the target multivariate normal distribution. Indeed, in contrast to the raw MMD numbers, there are easy to remember rules of thumb for SMMDU values that can be used when monitoring the WAE training progress; these will be provided in Section 4.

Figure 1: The null variance of $\mathrm{MMD}^2_{U,n}$ as a function of the kernel width $\sigma$ and the latent dimensionality $d$, at a fixed batch size $n$.

Next, before discussing other benefits of the standardization, we use Figure 1 to give an insight into the behavior of the scaling term in the denominator of the SMMDU expression. This figure depicts the graphs of $\sqrt{\mathrm{Var}[\mathrm{MMD}^2_{U,n}]}$ for different values of the kernel width $\sigma$ and the latent dimension $d$, at a fixed batch size $n$. We can clearly see that the scaling varies widely not only across dimensions, but also across kernel width choices at a fixed dimension.

The use of the SMMDU as a penalty in the WAE objective is potentially beneficial for model selection. The choice of the best hyperparameters is usually carried out via cross-validation, which among other things includes trying out different values of the penalty coefficient $\lambda$, kernel width $\sigma$, and latent dimension $d$. Without the proposed scaling of the MMD term, the values of $\lambda$ are not universal across the choices of $\sigma$ and $d$, which makes cross-validation more difficult. For example, if a small list of $\lambda$'s is used when cross-validating, then disparate regions of the optimization space would be considered across the choices of $\sigma$ and $d$, perhaps resulting in a suboptimal model being chosen.

Our scaled formulation can also be beneficial for the commonly used trick of combining kernels of different widths in order to boost the performance of the MMD metric and to avoid a search over the kernel width. This is equivalent to adding more MMD terms as a penalty, for example, using a penalty of the form $\lambda \sum_{\ell} \mathrm{MMD}^2_{\sigma_\ell}$ with several widths $\sigma_\ell$. However, when such a combination is performed without the proposed standardization, the MMDs coming from different kernel width choices can be of different orders of magnitude. As a result, one may end up with a single kernel width dominating. In fact, the common choice of combining kernel widths that are far apart would lead to this issue, as can be seen from Figure 1. One can see that this observation is also relevant in cases where $\sigma$ is set adaptively per batch, this time leading to various amounts of penalty being applied to each batch.

3.2 Code Normalization

In this subsection we propose to apply a variant of batch normalization [batchnorm] on top of the code layer: for each batch, we center and scale the codes so that their distribution has zero mean and unit variance in each dimension; we will refer to this as “code normalization”. Importantly, no scaling or shifting is applied after normalizing (i.e. $\gamma = 1$ and $\beta = 0$ in the notation of [batchnorm]) as the decoder network expects a normally distributed input. For deterministic encoders this version of batch normalization is readily available in the existing neural net packages. For example, in Keras [keras] this would be the layer BatchNormalization(center=False, scale=False). We note, however, that the existing implementations often use the biased estimate of the sample variance, which leads to some distortion in the long run; this can be easily remedied by multiplying the variance by $n/(n-1)$.
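A minimal tf.keras sketch of a deterministic encoder head with code normalization is shown below. The architecture is illustrative only (the layer sizes and function name are our assumptions); the essential piece is the BatchNormalization layer with center=False and scale=False, so no learned shift or scale is applied after normalizing.

    import tensorflow as tf
    from tensorflow.keras import layers

    def make_encoder(input_dim, latent_dim):
        """Toy deterministic encoder whose codes are batch-normalized with no
        learned shift/scale, so each latent dimension has (approximately) zero
        mean and unit variance within a batch."""
        inputs = tf.keras.Input(shape=(input_dim,))
        h = layers.Dense(128, activation="relu")(inputs)
        codes = layers.Dense(latent_dim)(h)
        codes = layers.BatchNormalization(center=False, scale=False)(codes)
        return tf.keras.Model(inputs, codes)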

Random encoders require a separate treatment of the mean and variance network outputs. Namely, using the notation of Section 2.2, code normalization is given in a coordinate-wise manner by

$$\tilde{\mu}_{ik} = \frac{\mu_{ik} - m_k}{v_k}, \qquad \tilde{s}_{ik} = \frac{s_{ik}}{v_k},$$

where the subscript $k$ is used to refer to the $k$-th coordinate, and $m_k$ and $v_k^2$ are the mean and the variance of the $k$-th coordinate of the aggregate code distribution $\widehat{Q}_Z$. Since $\widehat{Q}_Z$ is a mixture of Gaussians, closed form expressions for the mean and variance are available:

$$m_k = \frac{1}{n}\sum_{i=1}^{n}\mu_{ik}, \qquad v_k^2 = \frac{1}{n}\sum_{i=1}^{n}\left(s_{ik}^2 + \mu_{ik}^2\right) - m_k^2.$$
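For random encoders, the same normalization can be applied directly to the mean and variance network outputs. The NumPy sketch below implements the mixture-moment formulas above; the function and variable names are ours.

    import numpy as np

    def code_normalize_random(mu, s2, eps=1e-6):
        """Code normalization for a Gaussian random encoder (Section 2.2 notation):
        makes the aggregate mixture (1/n) sum_i N(mu_i, diag(s2_i)) have zero mean
        and unit variance in every coordinate.  mu and s2 are (n, d) arrays."""
        m = mu.mean(axis=0)                              # mixture mean per coordinate
        v2 = (s2 + mu**2).mean(axis=0) - m**2            # mixture variance per coordinate
        mu_norm = (mu - m) / np.sqrt(v2 + eps)
        s2_norm = s2 / (v2 + eps)
        return mu_norm, s2_norm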

During inference, one uses the population statistics to normalize the codes. This can be achieved by processing multiple mini-batches from the training data at once to compute the required means and standard deviations of the code layer. Another option is to keep a running exponential average of the mean and variance and use them for normalization—this is what the Keras implementation does.

Below under separate headings we discuss the benefits of code normalization for the WAE training; we will use the term “MMD penalty” to refer to any kind of penalty based on MMD, including SMMDU.

Easier Kernel Width Selection

One advantage of code normalization is that a single setting of the width $\sigma$ for the Gaussian RBF kernel can be used when computing the MMD penalty. Without code normalization, a fixed choice of $\sigma$ leads to issues. For example, when $\sigma$ is small and the codes are far away from the origin and from each other, the MMD penalty term has small gradients, which makes learning difficult or even impossible. Indeed, the exponentials become vanishingly small, and since they enter the gradient multiplicatively this makes the gradients small as well. The same issue arises when choosing a large value of $\sigma$ when the codes are not far away from the origin. Thus, one has to use an adaptive choice of $\sigma$ in order to deal with this problem, see e.g. [WAE] and also our Appendix A.4 for one such particular choice. On the other hand, in the long run, code normalization makes sure that the codes have distances commensurate with $\sigma$ throughout the training process, alleviating the need for an adaptive $\sigma$. This makes it possible to decouple the choice of $\sigma$ from the neural network training and to provide practical recommendations as we do in Section 5.

Reduced Training Effort

Figure 2: Code normalization shifts the distribution of $\mathrm{SMMD}_{U,n}$ to lower values (samples drawn from $N(0, I_d)$; batch size, dimension, and kernel scale fixed).

Code normalization shifts and scales codes to be in the “right” part of the space, namely where the target standard multivariate normal distribution lives, and we speculate that this reduces the training effort. The intuition comes from inspecting the relationship between the MMD and the characteristic function formulation of the BHEP statistic. This formulation expresses the fact that at some level MMD is performing moment matching, and so by rendering the first two moments (marginal) equal to those of the standard multivariate normal distribution, code normalization focuses the training effort on matching the higher moments.

To illustrate this point, Figure 2 shows the distribution of $\mathrm{SMMD}_{U,n}$ values for samples of size $n$ taken from $N(0, I_d)$. The value of $\mathrm{SMMD}_{U,n}$ is computed for the original sample and then for the sample to which code normalization was applied. Note that on average the normalized codes have smaller MMD values compared to the original ones. See Section 4 for more about this distribution shift, together with an even stronger shift when samples are whitened; in some sense, normalized samples are more “ideal” from the point of view of the MMD. This means that even if the neural network has converged to the target normal distribution, the gradient for a batch will not be zero but will have components in the direction of shifting and scaling the codes to reduce the MMD for the given batch. Code normalization directly takes care of this reduction, and allows the training process to spend its effort on improving the reconstruction error. Technically, this is achieved by projecting out the components of the gradient corresponding to shifting and scaling, which is automatically achieved by normalization; see [BatchRenorm, Section 3, penultimate paragraph].

This observation reveals an interesting aspect of training with the MMD as compared to training in an adversarial manner [AdversAE]. When training in an adversarial manner, the goal is to make the codes in each batch resemble a sample from the standard multivariate normal distribution. At an intuitive level, we expect this would happen with the MMD penalty as well. However, this is not the case—we see that, on average, the MMD penalty considers normalized/whitened samples more “ideal” than the actual samples from the target distribution. Luckily, the neural network cannot learn batch-wise operations (e.g. it cannot learn to do batch-wise normalization or whitening by itself) assuming that at inference time the inputs are processed independently of each other. As a result, this phenomenon will not prevent convergence to the target distribution. A rigorous argument follows from unbiasedness, $\mathbb{E}[\mathrm{MMD}^2_{U,n}] = \mathrm{MMD}^2(P_Z, Q_Z) \geq 0$, where the expectation is taken over i.i.d. samples and equality holds only at convergence to the target distribution; this makes any overall shift to the left at the inference time impossible.

Avoiding Outliers

Another benefit of code normalization is that it provides a solution to the outlier insensitivity problem of the MMD penalty, described below. Indeed, scaling by the standard deviation (rather than by a robust surrogate) controls the tail behavior of the code distribution. Due to this control, the code distribution ends up having a light tail and no code falls too far away from the origin.

The outlier insensitivity problem is not specific to our closed-form formula or the choice of the kernel (see Section 5 for an empirical verification); this problem is relevant to any kernel such that $k(x, y) \to 0$ as $\|x - y\| \to \infty$.

Given a sample $z_1, \dots, z_n$ from the standard multivariate normal distribution $N(0, I_d)$, consider a modified sample in which one of the points, say $z_n$, is replaced by an outlier $z_o$ that is far from the origin. Expressing the sums of vanishingly small exponentials involving $z_o$ via the $o(1)$-notation, we can compute the difference in $\mathrm{MMD}^2_{U,n}$ incurred by this change:

$$\Delta = \frac{2}{n}\left(\frac{\sigma^2}{\sigma^2 + 1}\right)^{d/2}\exp\left(-\frac{\|z_n\|^2}{2(\sigma^2 + 1)}\right) - \frac{2}{n}\cdot\frac{1}{n-1}\sum_{i \neq n}\exp\left(-\frac{\|z_i - z_n\|^2}{2\sigma^2}\right) + o(1). \qquad (3.1)$$

Note that the second term contains the sample average approximation of $\mathbb{E}_{w \sim N(0, I_d)}[k(z_n, w)]$. This expectation can be computed analytically (in fact it is equivalent to the summand in the second term of Eq. (2.2)) and it precisely cancels the first term here in Eq. (3.1), giving $\Delta \approx 0$. Thus, the MMD changes very little despite the presence of the large outlier.
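The insensitivity is easy to observe numerically. The sketch below, which assumes the mmd2_unbiased_normal function from the Section 2.1 sketch, compares the estimate before and after planting a single far-away outlier; the two values typically differ by less than the null standard deviation.

    import numpy as np
    np.random.seed(0)

    d, n, sigma2 = 8, 100, 1.0
    z = np.random.randn(n, d)
    z_out = z.copy()
    z_out[-1] = 50.0          # push one code very far from the origin
    # Both estimates are small and nearly equal despite the outlier.
    print(mmd2_unbiased_normal(z, sigma2), mmd2_unbiased_normal(z_out, sigma2))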

Given the mixed objective and stochasticity inherent in the training process, this issue has an effect on WAE training even before reaching the limits of computer precision. Indeed, in addition to the MMD penalty, the WAE objective contains the reconstruction term. Given the incentive to reconstruct well, the optimizer will realize that it is beneficial to push some of the codes far away from the origin, since the origin is where most of the codes concentrate. If this happens only for a few codes in a batch, the MMD penalty will not be large enough to pull these codes back towards the origin. As a result, the training process will result in a distribution that has outliers. Our experiments show that the proposed code normalization provides a solution to this issue without a need for using adaptive kernel widths or extra penalties.

4 Hypothesis Tests for Multivariate Normality

In this section we discuss hypothesis testing using $\mathrm{SMMD}_{U,n}$ and provide thresholds that can be useful when monitoring the progress of WAE code distribution convergence at the single-batch and multi-batch levels. Our initial discussion is set in a broader manner so as to encompass general testing for multivariate normality.

We quickly review the hypothesis testing setting following [HZ] with some notational changes. Let $x_1, \dots, x_n \in \mathbb{R}^d$ be i.i.d. random vectors from some underlying distribution. The problem is to test the hypothesis that the underlying distribution is a non-degenerate $d$-variate normal distribution $N_d(\mu, \Sigma)$, for some mean vector $\mu$ and non-degenerate covariance matrix $\Sigma$. Note that the population mean vector and covariance matrix are not known.

The test of multivariate normality proceeds as follows. Let $\bar{x}$ be the sample mean, and $S$ be the sample covariance matrix. Assuming non-degeneracy, define the centered and whitened vectors

$$y_i = S^{-1/2}(x_i - \bar{x}), \qquad i = 1, \dots, n.$$

Now, the task of testing multivariate normality of $x_1, \dots, x_n$ reduces to the simpler problem of testing whether the underlying distribution of $y_1, \dots, y_n$ is $N(0, I_d)$.

While the Henze-Zirkler test [HZ] carries out this last step by using the BHEP statistic, it can also be achieved by using the $\mathrm{SMMD}_{U,n}$. One computes the statistic for the sample $y_1, \dots, y_n$ and checks whether it is above the test threshold; if so, the null hypothesis gets rejected. The most straightforward way to compute the threshold is to run a Monte Carlo simulation: sample $n$ points from the null distribution $N(0, I_d)$, and compute the corresponding $\mathrm{SMMD}_{U,n}$ value; repeat this many times to obtain the empirical sampling distribution of the statistic, and use its $(1-\alpha)$-quantile as the threshold for the $\alpha$-level test.

However, this approach is problematic due to the treatment of the nuisance parameters $\mu$ and $\Sigma$: the same sample is used both for estimating the mean and covariance, and then for testing (this is somewhat like training and testing on the same data). The most apparent consequence is that one introduces dependencies within $y_1, \dots, y_n$—namely, the $y_i$ have zero sample mean and identity sample covariance—rendering it no longer an i.i.d. sample. Thus, when using the Monte Carlo approach with samples drawn directly from $N(0, I_d)$ we would end up with a wrong null distribution and, so, with the wrong thresholds. The Henze-Zirkler test [HZ] uses appropriate corrections to account for the nuisance parameters when computing the moments under the null. These moments are then used to obtain a log-normal approximation to the null distribution. A similar path can potentially be taken with the $\mathrm{SMMD}_{U,n}$ statistic, but for simplicity we will explain how to correct the issue with Monte Carlo sampling.

To fix the problem, the computation of the null distribution should proceed from samples that satisfy the dependency relationships mentioned above. Fortunately, constructing such samples is easy: we sample from $N(0, I_d)$, then apply centering by the sample mean and whitening by the sample covariance matrix. The resulting sample satisfies the required relationships (zero sample mean and identity sample covariance). This centered and whitened sample is used to compute the $\mathrm{SMMD}_{U,n}$ values and to obtain the thresholds. To prove the correctness of this procedure one has to show that there is a measure preserving and test statistic preserving one-to-one mapping between these samples originating from $N(0, I_d)$ and the centered and whitened samples as if they were to originate from $N_d(\mu, \Sigma)$ with the true $\mu$ and $\Sigma$. Using the non-degeneracy of $\Sigma$, with some linear algebra one can show that there is indeed such a mapping given by an orthogonal linear transformation, see Appendix A.3. Since $N(0, I_d)$ is rotation-invariant, this mapping is measure preserving, and the resulting sampling distributions coincide.

Before proceeding, we would like to mention a modification of the above test where the goal is to test whether the sample comes from a normal distribution with a diagonal covariance. This is a test that both checks each dimension for normality and establishes the independence between the dimensions. When conducting this test, only the diagonal of the sample covariance matrix is computed and used for transforming $x_i$ to $y_i$. The corresponding Monte Carlo procedure takes samples from $N(0, I_d)$, and applies centering by the mean and scaling of each dimension by its standard deviation (just like code normalization).
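The Monte Carlo procedure for all three null variants can be sketched as follows; it assumes the smmd_u function from the Section 3.1 sketch, and the mode names are our own labels.

    import numpy as np

    def mc_threshold(n, d, sigma2, n_rep=2000, alpha=0.05, mode="original", seed=0):
        """Monte Carlo null threshold for SMMD_{U,n}.
        mode: 'original' for the simple null, 'scaled' or 'whitened' for the
        composite nulls of Section 4 (diagonal / general covariance)."""
        rng = np.random.default_rng(seed)
        vals = np.empty(n_rep)
        for r in range(n_rep):
            z = rng.standard_normal((n, d))
            if mode == "scaled":
                z = (z - z.mean(0)) / z.std(0, ddof=1)
            elif mode == "whitened":
                zc = z - z.mean(0)
                cov = zc.T @ zc / (n - 1)
                # whiten with the inverse symmetric square root of the sample covariance
                w, V = np.linalg.eigh(cov)
                z = zc @ V @ np.diag(w ** -0.5) @ V.T
            vals[r] = smmd_u(z, sigma2)
        return np.quantile(vals, 1.0 - alpha)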

Table 1 displays the thresholds corresponding to the $\alpha$-level test, for a fixed sample size $n$ and for varying dimensions and kernel scales (we have included the larger dimensions to give an idea about the overall trend; one expects the test to lose power with increasing dimensionality [Ramdas_decreasing_power]). The column “Sample Type” indicates what processing was applied to the original sample from $N(0, I_d)$, if any. The “Original” thresholds can be used for testing the following simple hypothesis: given a sample, we would like to test whether the underlying distribution is $N(0, I_d)$. The “Centered+Scaled” and “Centered+Whitened” rows give the correct thresholds for composite nulls, i.e. testing whether the underlying distribution is $N_d(\mu, \Sigma)$ for unknown $\mu$ and $\Sigma$. “Centered+Scaled” corresponds to the case where $\Sigma$ is assumed to be diagonal, and “Centered+Whitened” corresponds to the case of a general non-degenerate $\Sigma$. As expected, dependencies within the sample shift the null distribution of $\mathrm{SMMD}_{U,n}$ to the left considerably; also see Figure 2 for side by side histograms of the “Centered+Scaled” versus “Original” null distributions. Therefore, using the original thresholds for composite hypotheses would have resulted in tests that are rather liberal.

Dimension  Sample Type  |  Kernel scale: 1  1/2  1/4  1/8  1/16  HZ
Original 1.97 1.97 1.95 1.92 1.91 1.98
Centered+Scaled/Whitened -0.13 0.32 0.75 1.05 1.24 0.22
Original 1.93 1.94 1.90 1.85 1.83 1.90
Centered+Scaled -0.57 -0.12 0.39 0.86 1.16 0.23
Centered+Whitened -0.79 -0.33 0.22 0.72 1.06 0.03
Original 1.90 1.87 1.83 1.79 1.76 1.83
Centered+Scaled -1.09 -0.66 -0.02 0.63 1.12 0.30
Centered+Whitened -1.60 -1.25 -0.56 0.24 0.85 -0.16
Original 1.85 1.83 1.80 1.77 1.74 1.75
Centered+Scaled -1.76 -1.36 -0.59 0.34 1.11 0.55
Centered+Whitened -2.63 -2.57 -1.88 -0.60 0.58 -0.30
Original 1.81 1.80 1.77 1.74 1.78 1.78
Centered+Scaled -2.65 -2.30 -1.47 -0.15 1.08 1.05
Centered+Whitened -4.00 -4.47 -4.17 -2.31 0.00 -0.03
Original 1.77 1.77 1.74 1.71 1.76 1.30
Centered+Scaled -3.87 -3.59 -2.78 -1.13 0.75 1.02
Centered+Whitened -5.86 -7.12 -7.91 -5.72 -1.24 0.06
Table 1: Empirical thresholds for hypothesis tests of size $\alpha$. See text for the details of when each kind of threshold should be used. The kernel width is expressed via the scale $\gamma$ as in Section 5; HZ is the $\sigma$ suggested by Henze and Zirkler [HZ] as given by the formula Eq. (2.4).

Monitoring WAE Training Progress

We will consider two ways of monitoring progress: at the single-batch and multi-batch levels. When inspecting the value of $\mathrm{SMMD}_{U,n}$ for a single batch, one can use the above thresholds for hypothesis testing as a guideline. Assuming that this batch is from the validation or test set, we can use the thresholds listed in the “Original” rows of Table 1. By looking at these values, we suggest using a value of about $2$ as an easy to remember liberal threshold. This applies to code normalized batches as long as the normalization is done using population statistics. However, when the batch is normalized using its own statistics, then the appropriate thresholds are given by the “Centered+Scaled” rows. We should stress again that even upon convergence to the target distribution, one should still expect oscillations of the $\mathrm{SMMD}_{U,n}$ values: it is not the case that samples from the target distribution all have $\mathrm{SMMD}_{U,n}$ equal to zero; instead, they follow the appropriate null distribution.

Remark: With neural nets it is customary to track the training and validation losses during learning. When code normalization is used, the distribution shift exemplified in Figure 2 and seen in Table 1 will result in an even smaller training loss compared to the validation loss. This is because code normalization uses the batch statistics during training and the population statistics at validation/test time. The difference between these losses can be on the order of several $\lambda$'s; here $\lambda$ is the penalty coefficient.

In the multi-batch case, such as when computing the SMMDU for codes corresponding to the validation or test set, one can use the same batch size as used for training and compute the average SMMDU value. This average has a very simple asymptotic distribution under the null, as explained below. Assume that the validation/test set contains $N$ batches of size $n$, and the corresponding batches of codes are $Z_1, \dots, Z_N$. The average SMMDU value is computed as

$$\overline{\mathrm{SMMD}} = \frac{1}{N}\sum_{t=1}^{N}\mathrm{SMMD}_{U,n}(Z_t).$$

Note that under the null, each summand has zero mean and unit variance due to the standardization. Assuming that $N$ is big enough, we can apply the Central Limit Theorem [CLT_Ref], giving that the null distribution of $\overline{\mathrm{SMMD}}$ is asymptotically normal with mean $0$ and variance $1/N$. Thus, as a rule of thumb, values of $\overline{\mathrm{SMMD}}$ that do not fall into the three-sigma interval $\pm 3/\sqrt{N}$ should be considered as an indication that the aggregate code distribution has not converged to the target standard multivariate normal distribution. The raw MMD version of this test together with theoretical results can be found in [B_test], where it is called the B-test. Thus, we will refer to $\overline{\mathrm{SMMD}}$ as the B-Statistic.
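A sketch of the B-Statistic computation, assuming the smmd_u function from the Section 3.1 sketch:

    import numpy as np

    def b_statistic(codes, batch_size, sigma2):
        """Average SMMD_U over consecutive batches of the given size (the
        B-Statistic), together with the 3/sqrt(N) half-width of the three-sigma
        interval; codes is an (m, d) array of validation/test codes."""
        n_batches = codes.shape[0] // batch_size
        vals = [smmd_u(codes[t * batch_size:(t + 1) * batch_size], sigma2)
                for t in range(n_batches)]
        return np.mean(vals), 3.0 / np.sqrt(n_batches)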

Another popular way of keeping track of progress metrics is exponential moving averaging. The Lyapunov/Lindeberg version of the Central Limit Theorem [CLT_Ref, Chapter 27] can be applied to obtain the corresponding interval. Suppose that the exponential moving average with momentum $\theta$ is used to keep track of a per-batch quantity $B_t$; that is, $M_t = \theta M_{t-1} + (1 - \theta) B_t$ is used for $t = 1, 2, \dots$. Note that $M_t$ can be written as

$$M_t = (1 - \theta)\sum_{k=1}^{t}\theta^{t-k} B_k + \theta^{t} M_0;$$

here $M_0$ is some initial value, usually $0$, which we will use. Assuming that the $B_k$ are standardized to have zero mean and unit variance, the application of the CLT to these random variables gives that $M_t$ is approximately normally distributed:

$$M_t \sim N\!\left(0,\ (1 - \theta)^2\,\frac{1 - \theta^{2t}}{1 - \theta^2}\right).$$

By dropping the factor $1 - \theta^{2t}$, we can use $(1 - \theta)/(1 + \theta)$ as an upper bound for the variance. This gives the three-sigma interval for the E-Statistic, liberally, as $\pm 3\sqrt{(1 - \theta)/(1 + \theta)}$; for instance, with momentum $\theta = 0.9$ this evaluates to approximately $\pm 0.69$. Once again this interval can be used when monitoring the exponential moving average of $\mathrm{SMMD}_{U,n}$; falling outside this interval should be considered as an indication that the aggregate code distribution has not converged to the target standard multivariate normal distribution.
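The corresponding E-Statistic interval is a one-liner; the helper below simply evaluates the bound derived above (the illustrative momentum value is our choice, not one taken from the paper).

    import numpy as np

    def ema_three_sigma(theta):
        """Liberal three-sigma half-width for an exponential moving average of
        standardized per-batch SMMD_U values with momentum theta."""
        return 3.0 * np.sqrt((1.0 - theta) / (1.0 + theta))

    print(ema_three_sigma(0.9))   # approximately 0.69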

Of course, the single-batch approach above that treats the whole validation/test set as one batch would result in a more powerful test. However, the B-statistic or E-statistic tests are simple to state and are computationally inexpensive, as they avoid constructing the large pairwise distance matrices needed for the overall test. Moreover, neural network packages such as Keras provide these types of averages automatically if one adds the corresponding quantity as a validation metric. At a theoretical level, one should keep in mind that given enough power we will always reject the null: with real-life data one rarely expects the neural net to exactly reproduce the normal distribution. Rejecting the null at high power does not mean that the distributions are easily distinguishable: the practical difference can be so small that a classifier trained to distinguish the two distributions (think of an adversary from an adversarial WAE) would perform at a nearly chance level. Based on these considerations, using the B-Statistic as described above should be a reasonable choice; see the discussion in [FRISTON20121300, Appendix 1], albeit in a different context. For power calculations for MMD based tests one can refer to [Sutherland_model_criticismMMD].

5 Experiments

First we discuss our parameterization for the kernel width used in the computation of the various MMD measures. A rule of thumb choice of the kernel width is $\sigma^2 = d$, where $d$ is the dimension of the code space (see e.g. [InfoVAE, WAE]). This choice is based on considering the average pairwise squared distance between two points drawn from the standard multivariate normal distribution, which equals $2d$, and halving it to offset the multiplication by $2$ in the expression for the kernel. We will see that this choice gives rather suboptimal results, yet it provides a good point of reference for defining the scale $\gamma$ of the kernel via $\sigma = \gamma\sqrt{d}$. We will experiment with various choices of $\gamma$, where $\gamma > 1$ gives wider and $\gamma < 1$ gives narrower kernels.

Validation

We first experimentally verify that our closed-form formula for $\mathrm{SMMD}_{U,n}$ results in zero mean and unit variance when $Q_Z = N(0, I_d)$. To this end, we sample $n$ points from the standard $d$-variate normal distribution and compute the value of $\mathrm{SMMD}_{U,n}$. This process is repeated 10,000 times to obtain the empirical distribution of the values. Figure 1 shows the violin plots of these empirical distributions computed for several values of the kernel scale $\gamma$ and dimensionality $d$. The red segments in this plot are centered at the mean, and they extend between the mean plus/minus one standard deviation. We observe from the graph that the means are close to zero and the standard deviations are close to 1, as expected.

Figure 1: Violin plots verify that $\mathrm{SMMD}_{U,n}$ has zero mean and unit variance under the null. Here, the batch size is $n$ and the kernel width is expressed via the scale $\gamma$ as $\sigma = \gamma\sqrt{d}$.

Discriminative Performance

The goal of the next experiment is to compare our closed-form estimator of the MMD (referred to as “Analytic RBF”) to the commonly used sampling based estimator using the same Gaussian RBF kernel (“Empirical RBF”). We also compare to the sampling based estimator with the inverse multiquadratics (IMQ) kernel defined by $k(x, y) = C/(C + \|x - y\|^2)$; we call this “Empirical IMQ”. The IMQ kernel is often claimed to be superior to the RBF kernel due to its slower tail decay.

In our first experiment we would like to determine which one of these three methods is most effective at distinguishing the standard $d$-variate normal distribution from the uniform distribution. Since our goal is to train neural networks rather than perform hypothesis testing, we will not use the test power as the metric of interest; instead we will rely on the effect size defined below. In addition, we are not studying the dependence on the latent dimension, so we do not have to worry about the fair choice of alternatives [Ramdas_decreasing_power].

Table 2: Discrimination power between the $d$-variate standard normal distribution and the uniform distribution, as measured by the effect size $D$.
Figure 2: Graphical representation of the discrimination experiment between the normal and uniform distributions.
Table 3: Discrimination power between the $d$-variate standard normal distribution and a latent $d$-dimensional embedding of MNIST obtained with an unregularized autoencoder.
Figure 3: Graphical representation of the discrimination experiment with the MNIST latent embedding.

The uniform distribution under consideration is the uniform distribution on $[-\sqrt{3}, \sqrt{3}]^d$. Note that this particular uniform distribution has mean 0 and variance 1 in each dimension, just like the normal distribution. As a result, distinguishing the two distributions requires going beyond the first two moments. For each of the three methods, for a fixed dimension $d$ and kernel scale $\gamma$, we sample $n$ points from the standard $d$-variate normal distribution and compute the corresponding MMD estimate. Next we sample $n$ points from the uniform distribution and compute the corresponding MMD estimate. We repeat this 200 times, and compute the corresponding means $m_N$ and $m_U$, and the standard deviations $s_N$ and $s_U$, corresponding to each of the two sets of 200 MMD values. (Of course, we expect $m_N \approx 0$ since all of the three methods are unbiased. For the Analytic RBF, we also know the theoretical value of $s_N$ from the closed-form formula for the variance. However, for fairness we will use empirical estimates for all of the three methods.) Now we can measure the discriminativeness of a given method by computing

$$D = \frac{m_U - m_N}{\sqrt{(s_N^2 + s_U^2)/2}}.$$

Note that this is the effect size of a two sample t-test as measured by Cohen's d [cohen1988spa]. Larger values of $D$ mean better discrimination, which potentially translates to better gradients for neural network training.
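A sketch of the effect size computation; the pooled-standard-deviation convention shown here is one common form of Cohen's d and is an assumption on our part.

    import numpy as np

    def cohens_d(vals_normal, vals_alt):
        """Effect size D between two sets of MMD values (Cohen's d with a
        pooled standard deviation)."""
        m_n, m_a = np.mean(vals_normal), np.mean(vals_alt)
        s_n, s_a = np.std(vals_normal, ddof=1), np.std(vals_alt, ddof=1)
        return (m_a - m_n) / np.sqrt((s_n**2 + s_a**2) / 2.0)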

The results are presented in Table 2. Note that the experiment was done for different values of the kernel scale; due to the heavier tail, we included more scale choices for the IMQ kernel than for the RBF kernel. For each method and dimensionality choice $d$, the best performing choice of the kernel scale corresponds to the maximum value of $D$; these values are shown in boldface (we also highlight the values that are within a small margin of the maximum).

Figure 2 provides a graphical display for this experiment for a representative dimension $d$. In this graph, for each method, the best choice of the kernel scale was used to compute the distributions of MMD values. For each method, the boxes are centered at the corresponding mean and the half-height of a box reflects the corresponding standard deviation. The whiskers span the range of all of the values; the blue dots correspond to the individual MMD values. For a given method, when the boxes corresponding to the normal and uniform distributions overlap, it means that the method has difficulty discriminating the two distributions. In terms of training neural networks, this means that the corresponding MMD penalty may not be able to provide a strong gradient direction for training because the difference is lost within the stochastic noise.

We repeat the same experiment, but instead of the uniform distribution we use a distribution obtained from a neural network. We use the MNIST dataset and train auto-encoders (both encoder and decoder have two hidden layers with 128 neurons each, with ReLU activations) with different latent dimensions and with no regularization. The codes corresponding to the test data are extracted and shifted to have zero mean. We observed that with growing $d$ the various latent dimensions were highly correlated (with some large pairwise Pearson correlations); thus, to make the task more difficult, we applied PCA-whitening to the latent codes. The resulting discrimination performance is presented in Table 3 and Figure 3.

By examining both of the tables above, we can see that the Analytic RBF method outperforms both the Empirical RBF and Empirical IMQ methods in terms of discrimination power. Another observation is that the commonly recommended choice of $\sigma^2 = d$ (which corresponds to the kernel scale $\gamma = 1$) is never a good choice; a similar finding for the median heuristic was spelled out in [Sutherland_model_criticismMMD]. The kernel width recommended for the Henze-Zirkler test gives mixed results, which is somewhat expected—optimality for density estimation does not guarantee optimal discriminative performance. Examining the Analytic RBF results, it seems that kernel scales well below $\gamma = 1$ provide a good rule of thumb choice. Finally, in these particular examples we see that despite having a larger repertoire of kernel scale choices, Empirical IMQ does not perform as well as Empirical RBF.

While these results are limited to two datasets, they bring into question the commonly recommended choices of the kernel and its width. Of course, our analysis assumes that the alternative distribution has zero mean and unit variance in each dimension. We believe that this is the most relevant setting for WAE learning because during the late stages of WAE training the code distribution starts converging to the normal distribution.

Outliers

Figure 4: Outlier discrimination experiment carried out for a fixed dimension $d$. For each method, the most discriminative (i.e. maximum $D$) kernel scale is chosen.

Here we experimentally verify the outlier insensitivity of the MMD and demonstrate that the issue is not peculiar to our approach. To this end, we run the discrimination experiment above, but this time trying to distinguish a sample from the standard $d$-variate normal distribution from the same sample with one of its points replaced by a point far away from the origin. Figure 4 shows that all of the three methods fail to distinguish these two distributions in practice.

Figure 5: Qualitative results for WAEs trained on MNIST with the CodeNorm (60 epochs), AdaptiveBN (80 epochs), and AdaptivePlain (80 epochs) versions: (a) test reconstructions (odd rows are the real images), (b) random samples, (c) a slice through the code space.

WAE results

Here we present the results of training WAEs on the MNIST dataset. The architecture for the neural net is borrowed from RStudio's “Keras Variational Auto-encoder with Deconvolutions” example (https://keras.rstudio.com/articles/examples/variational_autoencoder_deconv.html). This network has about 3.5M trainable parameters, almost an order of magnitude less than the 22M parameter network used by Tolstikhin et al. [WAE]. We consider three versions:

  • CodeNorm—code normalization is used, the kernel width is kept fixed.

  • Adaptive—no code normalization is used, kernel width is chosen adaptively. This has two versions:

    • AdaptiveBN—since code normalization can have other benefits (e.g. improved optimization [BatchNormHow]), we add batch normalization as the initial layer of the decoder;

    • AdaptivePlain—no batch normalization layer added.

The CodeNorm version was trained for 60 epochs, but to allow the Adaptive versions to reach a favorable configuration in the code space we trained them for an extra 20 epochs at the initial learning rate. All versions use the closed-form SMMDU penalty; the latent dimension and further details are provided in Appendix A.5.

Figure 5 (a)-(b) shows the reconstructions of test images and random samples generated from Gaussian noise fed to the decoder. Next we take a planar slice through the origin in the code space and feed the codes on a regular grid along this plane into the decoder. Figure 5 (c) depicts the resulting digit images, giving a taste of the manifold structure captured by the models. Qualitatively, both of the Adaptive results are of lower quality than CodeNorm, despite the former being trained for more epochs.

Quantitative results are presented in Table 4. CodeNorm achieves the best test reconstruction error. We speculate that the reason for this is that the gradient components of the MMD penalty pointing in the direction of “ideal” samples (see Section 3.2) add oscillations that hinder reduction in the reconstruction loss of the Adaptive models.

Next, we follow the suggestion of [Cramer_Wold_AE] to compute Mardia’s multivariate skewness and normalized kurtosis statistics of the latent code distribution of test data; we used the formulas provided in [Cramer_Wold_AE] and obtained the values as shown in the table. We see that for both measures, the CodeNorm version is better. Skewness is a measure of symmetry, so its small magnitude indicates that the code distribution is symmetrically distributed around the origin. Since kurtosis is a measure of outlier presence [kurtosis_RIP], its small value indicates that there are no outliers present in the code distribution. We verified experimentally (not presented here) that code normalization is responsible for keeping kurtosis under control. Indeed, removing the code normalization layer from a trained network, modifying the latent layer incoming weights so that the codes have zero mean and unit variance, and continuing to train afterwards leads to increased kurtosis as predicted in Section 3.2.

Finally, we analyze the results using the B-statistic discussed in Section 4, which gives a more in-depth summary of the data than Mardia's statistics. We computed the B-statistic using batches from the test partition of MNIST, with the corresponding three-sigma interval $\pm 3/\sqrt{N}$ computed accordingly. Both CodeNorm and AdaptiveBN look good in terms of this statistic, with CodeNorm falling inside the interval; on the other hand, AdaptivePlain is somewhat farther away, indicating that its code distribution more noticeably deviates from the target distribution.

WAE Version Test MSE Normalized Kurtosis Skewness B-Statistic
CodeNorm 0.0156 -0.90 0.56 0.355
AdaptiveBN 0.0244 6.85 2.80 0.449
AdaptivePlain 0.0242 3.81 2.35 0.519
Table 4: Quantitative comparison of different WAE versions.

6 Conclusion

This paper introduces closed-form formulas for MMD and its variance in the case of the standard multivariate normal target distribution. This allows us to propose a properly normalized and more interpretable standardized version of MMD as a penalty in the WAE training objective. We point out a relationship with the BHEP statistic that provides further insights about the MMD and allows making a connection to the Cramer-Wold distance. In addition, we propose using code normalization when training WAEs; this has the benefits of making the kernel width selection easier, reducing the training effort, and preventing outliers in the aggregate code distribution. Finally, we discuss the appropriate null distributions and provide thresholds for multivariate normality testing with SMMDU. A number of rules of thumb are provided for monitoring the progress of WAE training. Our experiments on synthetic and real data confirm that the analytic formulation improves over the commonly used stochastic approximation of the MMD, and demonstrate that code normalization provides significant benefits when training WAEs.

References

Appendix A Appendix

A.1 Deterministic Encoders

We start with the expression

$$\mathrm{MMD}^2_{U,n} = \mathbb{E}_{w, w' \sim N(0, I_d)}[k(w, w')] - \frac{2}{n}\sum_{i=1}^{n}\mathbb{E}_{w \sim N(0, I_d)}[k(w, z_i)] + \frac{1}{n(n-1)}\sum_{i \neq j} k(z_i, z_j),$$

and show that the first two expectations can be computed in closed form. Let us start with the second term, and rewrite each summand as an integral:

$$\mathbb{E}_{w \sim N(0, I_d)}[k(w, z_i)] = (2\pi\sigma^2)^{d/2}\int (2\pi\sigma^2)^{-d/2}\exp\left(-\frac{\|z_i - w\|^2}{2\sigma^2}\right)(2\pi)^{-d/2}\exp\left(-\frac{\|w\|^2}{2}\right) dw. \qquad (A.1)$$

Since $(2\pi\sigma^2)^{-d/2}\exp\left(-\|z_i - w\|^2/(2\sigma^2)\right)$ is the density of $N(w, \sigma^2 I_d)$ evaluated at $z_i$, the integral in this expression can be recognized as the probability density function, evaluated at $z_i$, of the sum $v = w + u$, where $w \sim N(0, I_d)$ and $u \sim N(0, \sigma^2 I_d)$. Being a sum of two normal distributions, adding means and variances we get $v \sim N(0, (\sigma^2 + 1) I_d)$, and the above expression computes to

$$\mathbb{E}_{w \sim N(0, I_d)}[k(w, z_i)] = (2\pi\sigma^2)^{d/2}\,(2\pi(\sigma^2 + 1))^{-d/2}\exp\left(-\frac{\|z_i\|^2}{2(\sigma^2 + 1)}\right) = \left(\frac{\sigma^2}{\sigma^2 + 1}\right)^{d/2}\exp\left(-\frac{\|z_i\|^2}{2(\sigma^2 + 1)}\right). \qquad (A.2)$$

Next, we compute the first term in the formula Eq. (2.2) by rewriting it as an integral:

$$\mathbb{E}_{w, w' \sim N(0, I_d)}[k(w, w')] = \int\!\!\int \exp\left(-\frac{\|w - w'\|^2}{2\sigma^2}\right)(2\pi)^{-d/2} e^{-\|w\|^2/2}\,(2\pi)^{-d/2} e^{-\|w'\|^2/2}\, dw\, dw'. \qquad (A.3)$$

In this expression, let us replace $w - w'$ by $w - w' - x$, and remember that we would get the sought value by setting $x = 0$. Rewriting this as

$$F(x) = (2\pi\sigma^2)^{d/2}\int\left[\int (2\pi\sigma^2)^{-d/2}\exp\left(-\frac{\|x + w' - w\|^2}{2\sigma^2}\right)(2\pi)^{-d/2} e^{-\|w\|^2/2}\, dw\right](2\pi)^{-d/2} e^{-\|w'\|^2/2}\, dw'. \qquad (A.4)$$

With this replacement, we can recognize the inner integral as the density function, evaluated at $x + w'$, of the sum of two multivariate normal variables. Interpreting the outer integral similarly, we can see that the entire double integral captures the probability density function, evaluated at $x$, of the sum $v = w + u - w'$, where $w, w' \sim N(0, I_d)$ and $u \sim N(0, \sigma^2 I_d)$. Being a sum of three normal distributions, adding means and variances we get $v \sim N(0, (\sigma^2 + 2) I_d)$, immediately giving the expression for this integral as

$$(2\pi(\sigma^2 + 2))^{-d/2}\exp\left(-\frac{\|x\|^2}{2(\sigma^2 + 2)}\right).$$

Including the multiplier $(2\pi\sigma^2)^{d/2}$ in front of the integral, and setting $x = 0$, we obtain:

$$\mathbb{E}_{w, w' \sim N(0, I_d)}[k(w, w')] = \left(\frac{\sigma^2}{\sigma^2 + 2}\right)^{d/2}.$$