Balancing Reconstruction Quality and Regularisation in Evidence Lower Bound for Variational Autoencoders

Shuyu Lin1, Stephen Roberts1, Niki Trigoni1, Ronald Clark2
Abstract

A trade-off exists between reconstruction quality and the strength of the prior regularisation in the Evidence Lower Bound (ELBO) loss that Variational Autoencoder (VAE) models use for learning. There are few satisfactory approaches for balancing the prior and reconstruction objectives, and most methods deal with this problem through heuristics. In this paper, we show that the noise variance (often set as a fixed value) in the Gaussian likelihood for real-valued data can naturally act to provide such a balance. By learning this noise variance so as to maximise the ELBO, we automatically obtain an optimal trade-off between the reconstruction error and the prior constraint on the posteriors. This variance can be interpreted intuitively as the necessary noise level for the current model to be the best explanation of the observed dataset. Further, by allowing the variance inference to be more flexible, it can conveniently be used as an uncertainty estimator for reconstructed or generated samples. We demonstrate that optimising the noise variance is a crucial component of VAE learning, and showcase the performance on MNIST, Fashion MNIST and CelebA datasets. We find our approach can significantly improve the quality of generated samples whilst maintaining a smooth latent-space manifold to represent the data. The method also offers an indication of uncertainty in the final generative model.

1 Introduction

Variational Auto-encoders (VAEs, [Kingma and Welling2013, Rezende, Mohamed, and Wierstra2014]) are a stable and efficient approach to unsupervised learning. VAEs naturally combine two learning outcomes: first, enabling generation of data similar to the observations, and second, offering a probabilistic embedding scheme in which the statistics at both sample and group levels satisfy required constraints. VAEs can, for example, be used for dimensionality reduction, feature extraction and efficient anomaly detection in the latent space.

Learning with a VAE naturally requires a balance between reconstruction loss minimisation and prior constraint enforcement. Several papers discuss the importance of this balance and the impact of different trade-off settings on the learned VAE models [Alemi et al.2018, Mathieu et al.2019]. In [Higgins et al.2017a], the authors propose scaling up the prior regularisation term in an attempt to promote disentanglement in the learnt representation. On the other hand, in [Alemi et al.2017] more emphasis is placed on the reconstruction loss term, to promote maximal information retention in the learned representation. The trade-off is most commonly tuned on a problem-by-problem basis [Vedantam et al.2018, Suzuki, Nakayama, and Matsuo2017, Higgins et al.2017b].

In this paper, we propose a mechanism to optimise the trade-off between prior and reconstruction loss. We show that, by parameterising this balance via a variance parameter, we can achieve significant performance gains, both in terms of achieving a closer bound to the true data log likelihood as well as improving the generated sample quality. Unlike the hyperparameter introduced in [Higgins et al.2017a], this variance hyperparameter intuitively represents the noise level in the dataset conditioned on the model explaining the observations. Furthermore, we develop a robust algorithm to learn to predict the variance hyperparameter conditioned on a given input. We show that the predicted variance can be used as an effective uncertainty estimator for reconstructed or generated samples.

By altering the strength of the prior regularisation term, the gap between the marginal latent distribution and the prior, measured for example using the Kullback-Leibler (KL) divergence between the aggregate posterior and the prior of the learnt model, becomes more prominent. This can significantly damage the quality of the generated samples as measured by common perceptual metrics such as the Inception Score (IS) [Salimans et al.2016] or the Fréchet Inception Distance (FID) [Heusel et al.2017a]. To avoid this problem, we propose a simple approximation to the learned latent marginal distribution and use this approximate distribution, rather than the prior, to generate samples.

In summary, our contributions are:

  • We show how optimisation of the variance hyperparameter under the VAE ELBO loss allows us to automate the trade-off between the reconstruction loss and the prior regularisation and we propose a stable learning procedure to accomplish this optimisation.

  • We study the impact of sampling from the learned latent marginal distribution as opposed to the prior and propose to use an approximate marginal distribution instead of the prior to generate samples.

2 Background

2.1 VAE ELBO Loss

Given a dataset of observations $\mathcal{D} = \{x_i\}_{i=1}^{N}$ and a latent variable model $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$, the learning objective in VAEs maximises the marginal likelihood of all the data points in $\mathcal{D}$ under the model parameters $\theta$, i.e.

$\max_{\theta}\; \mathbb{E}_{p_d(x)}\big[\log p_\theta(x)\big],$   (1)

where $p_d(x)$ describes the true data distribution and $p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$ is the marginal distribution of $x$ under the latent variable model, obtained by integrating out $z$ according to a prior distribution $p(z)$. Directly evaluating Equation (1) is often not feasible, because $p_\theta(x \mid z)$ is often parameterised by a neural network and integration over such a function cannot be easily evaluated. Instead, we can use the variational evidence lower bound (ELBO) on $\log p_\theta(x)$ as an alternative objective by introducing an amortised inference model $q_\phi(z \mid x)$, as shown below:

$\log p_\theta(x) = \log \mathbb{E}_{q_\phi(z \mid x)}\!\left[\frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right]$   (2)
$\geq \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right]$   (3)
$= \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p(z)\big).$   (4)

Substituting (4) into the VAE objective (1) and evaluating the expectation w.r.t. $p_d(x)$ with an empirical approximation, i.e. $\mathbb{E}_{p_d(x)}[f(x)] \approx \frac{1}{N}\sum_{i=1}^{N} f(x_i)$, we have the final VAE learning objective:

$\max_{\theta, \phi}\; \frac{1}{N}\sum_{i=1}^{N}\Big[\mathbb{E}_{q_\phi(z \mid x_i)}\big[\log p_\theta(x_i \mid z)\big] - D_{\mathrm{KL}}\big(q_\phi(z \mid x_i)\,\|\,p(z)\big)\Big].$   (5)

Facilitated by the reparameterisation trick [Williams1992], the above ELBO loss can be optimised with stochastic gradient descent algorithms, leading to an efficient learning method [Kingma and Welling2013, Rezende, Mohamed, and Wierstra2014].
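To make the objective concrete, the following is a minimal PyTorch-style sketch of the Monte-Carlo ELBO in (5) for a Gaussian encoder and a standard normal prior; the function names and the encoder/decoder interfaces are placeholders of our own choosing rather than the paper's implementation.

    import torch

    def elbo_per_batch(x, encoder, decoder, log_likelihood):
        """Monte-Carlo ELBO of Eq. (5) with one z-sample per data point.
        `encoder` returns the mean and log-variance of q_phi(z|x), `decoder`
        maps z back to data space, and `log_likelihood` evaluates log p_theta(x|z)
        for the reconstruction."""
        mu, logvar = encoder(x)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()   # reparameterisation trick
        # Closed-form KL(q_phi(z|x) || N(0, I)) for a diagonal Gaussian posterior
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
        return (log_likelihood(x, decoder(z)) - kl).mean()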

2.2 Competing Learning Objectives in ELBO

The two terms in the ELBO loss (Eq. 5) control different behaviours of the model. The reconstruction likelihood term $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$ attempts to reconstruct the input as faithfully as possible. The prior constraint term $D_{\mathrm{KL}}(q_\phi(z \mid x)\,\|\,p(z))$ enforces the posterior distribution of $z$ under the encoding function to comply with an assumed prior $p(z)$; if the constraint is effectively imposed, the aggregate posterior of the entire dataset [Tomczak and Welling2018] is likely to be very close to the prior, indicating the data distribution has been successfully projected onto the target prior in the latent space. Optimising the two losses together, we would hope to simultaneously achieve the best performance in both terms.

Unfortunately, the two losses are often conflicting. To see this, consider the case where the reconstruction likelihood term is maximised: the embeddings of different input samples $x_i$ and $x_j$ must then be easily distinguished, i.e. the overlap between their posteriors $q_\phi(z \mid x_i)$ and $q_\phi(z \mid x_j)$ should be small. On the other hand, if the prior constraint is satisfied, then all sample posteriors should approach $p(z)$, i.e. $D_{\mathrm{KL}}(q_\phi(z \mid x_i)\,\|\,p(z))$ should be zero. Therefore, the two terms almost always influence each other. As a result, finding the optimal balance between them is a crucial part of the VAE learning process. Unfortunately, there has not been a satisfying solution, and most methods handle this problem by introducing weight hyperparameters on the two losses and choosing their values based on heuristics [Higgins et al.2017a, Alemi et al.2017].

3 Our Proposal

In this work, we propose a method that automatically finds the best balance between the two contradicting objectives in the VAE ELBO loss.

3.1 Gaussian Likelihood in the Generative Model

For real-valued data $x \in \mathbb{R}^D$, we can use a Gaussian distribution to model the conditional distribution $p_\theta(x \mid z)$ in the generative model, i.e.

$p_\theta\big(x^{(d)} \mid z\big) = \mathcal{N}\big(x^{(d)};\, g_\theta^{(d)}(z),\, \sigma^2\big),$   (6)

where $\sigma^2$ is a common global variance parameter, reflecting the global noise properties of the data; $g_\theta$ represents a nonlinear mapping that transforms an encoding $z$ in the latent space to the data space; and $x^{(d)}$ represents the $d$-th dimension of the data variable $x$. If we assume that different dimensions in $x$ are independent, then the conditional density of the complete data variable is

$p_\theta(x \mid z) = \prod_{d=1}^{D} \mathcal{N}\big(x^{(d)};\, g_\theta^{(d)}(z),\, \sigma^2\big),$   (7)

where $D$ is the dimension of the data variable $x$. Therefore, $\log p_\theta(x \mid z)$ in the ELBO loss given by (5) can be computed as

$\log p_\theta(x \mid z) = -\frac{1}{2\sigma^2}\sum_{d=1}^{D}\big(x^{(d)} - g_\theta^{(d)}(z)\big)^2 - \frac{D}{2}\log\big(2\pi\sigma^2\big).$   (8)

Notice that the terms inside the summation are element-wise square errors between the generated sample $g_\theta(z)$ and the original sample $x$, and $\frac{1}{2\sigma^2}$ naturally appears as a weighting factor on the sum-of-squares error term. When the ELBO is maximised, the $-\frac{D}{2}\log(2\pi\sigma^2)$ term has an important regularising effect on the value of $\sigma^2$: it prevents $\sigma^2$ from taking very large values, which would otherwise allow the generator function to produce arbitrarily bad reconstructions.
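As a concrete illustration, a minimal NumPy sketch of the per-sample log-likelihood in Equation (8) is given below; the function name and the flattened-array interface are our own assumptions, not part of the paper.

    import numpy as np

    def gaussian_log_likelihood(x, x_recon, sigma2):
        """Per-sample log p(x|z) of Eq. (8): an isotropic Gaussian with a
        shared (global) variance sigma2 across all D data dimensions."""
        D = x.shape[-1]
        sse = np.sum((x - x_recon) ** 2, axis=-1)   # squared errors summed over dimensions
        return -sse / (2.0 * sigma2) - 0.5 * D * np.log(2.0 * np.pi * sigma2)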

3.2 ELBO Loss with A Global Variance Parameter

Substituting (8) into (5), we have the overall ELBO loss with the Gaussian likelihood assumption as

$\mathcal{L}_{\mathrm{ELBO}}(\theta, \phi, \sigma^2) = \frac{1}{N}\sum_{i=1}^{N}\Big[-\frac{1}{2\sigma^2}\,\mathbb{E}_{q_\phi(z \mid x_i)}\Big[\sum_{d=1}^{D}\big(x_i^{(d)} - g_\theta^{(d)}(z)\big)^2\Big] - \frac{D}{2}\log\big(2\pi\sigma^2\big) - D_{\mathrm{KL}}\big(q_\phi(z \mid x_i)\,\|\,p(z)\big)\Big].$   (9)

From Equation (9), we can see that the factor $\frac{1}{2\sigma^2}$ represents a relative weighting between the two competing learning objectives: the reconstruction error and the prior constraint on the posteriors $q_\phi(z \mid x)$. If we optimise the ELBO in (9) w.r.t. $\theta$, $\phi$ and $\sigma^2$, then we will reach a maximised lower bound of the data likelihood where the model automatically balances the objectives of minimising the information loss through the auto-encoding process while still keeping the latent marginal distribution close to the prior distribution. The optimal $\sigma^2$, which is obtained at the maximal ELBO and denoted as $\sigma^{*2}$, can be interpreted as the amount of noise that has to be assumed in the dataset for the current reconstruction to be considered as the best explanation of the observed samples.

3.3 Closed-form Solution for the Variance

For fixed $\theta$ and $\phi$, we can derive a closed-form solution for the optimal global variance $\sigma^{*2}$. To compute it, we take the derivative of $\mathcal{L}_{\mathrm{ELBO}}$ in Equation (9) w.r.t. $\sigma^2$ and set it to zero. Then

$\sigma^{*2} = \frac{1}{ND}\sum_{i=1}^{N}\mathbb{E}_{q_\phi(z \mid x_i)}\Big[\sum_{d=1}^{D}\big(x_i^{(d)} - g_\theta^{(d)}(z)\big)^2\Big].$   (10)

The significance of this result is that it improves the stability of the learning process despite the additional parameter, as we can always find a local optimum of the ELBO loss by updating $\sigma^2$ and the model parameters $\theta$ and $\phi$ iteratively.
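A sketch of the resulting alternating scheme is shown below; replacing the expectation in (10) with a single reconstruction per data point is our own simplification for illustration.

    import numpy as np

    def closed_form_sigma2(x, x_recon):
        """Optimal global variance of Eq. (10): the mean squared reconstruction
        error over all N samples and D data dimensions."""
        N, D = x.shape
        return float(np.sum((x - x_recon) ** 2) / (N * D))

    # Alternating optimisation: take a gradient step on (theta, phi) with sigma2
    # held fixed, then refresh sigma2 with the closed-form value above, and repeat.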

3.4 More Flexible Variance Estimation

So far we have taken $\sigma^2$ to be the same value for the entire dataset and across all data dimensions, but nothing stops us here. We can extend the estimation to be conditioned on particular encodings, be that the encoding of a real data sample or a latent code sampled from the prior distribution. We can also allow $\sigma^2$ to differ across data dimensions, and such a $\sigma^2$ conveniently highlights the regions with high uncertainty in a reconstructed or generated sample. To achieve these goals, we replace the global variance $\sigma^2$ in Equation (6) by a variance estimation function $\sigma^2_\psi(z)^{(d)}$, which is parameterised by $\psi$ and conditioned on the data dimension $d$ and an input encoding $z$. The corresponding $\log p_\theta(x \mid z)$ now becomes

$\log p_\theta(x \mid z) = -\sum_{d=1}^{D}\Big[\frac{\big(x^{(d)} - g_\theta^{(d)}(z)\big)^2}{2\,\sigma^2_\psi(z)^{(d)}} + \frac{1}{2}\log\big(2\pi\,\sigma^2_\psi(z)^{(d)}\big)\Big].$   (11)

Substituting (11) into (5), we get the ELBO loss associated with the input-dependent variance estimation, and learning can be done by optimising this ELBO w.r.t. all model parameters $\theta$, $\phi$ and $\psi$.

However, we notice that learning to predict $\sigma^2_\psi(z)$ together with learning the parameters for the auto-encoding task from scratch turns out to be extremely challenging, and the optimisation often gets stuck at predicting very small variance values. This is because the variance prediction is highly dependent on the reconstruction error given by the current model. At the early stage, when both the encoder and decoder are inaccurate, the reconstruction error is large and has high variance, which causes erratic updates in the variance prediction module. If the variance happens to arrive at a small value where the gradient of the $\log$ term dominates, then it is very hard for the gradient updates to escape the strong negative gradient from the $\log$ function. To prevent this from happening and obtain stable learning, we propose a staged learning process: we start by learning the global variance parameter until a convergence condition is reached, and then switch to the input-dependent variance prediction module, with the model parameters ($\theta$ and $\phi$) continuing to update from the optimised values given by the previous stage. We also use the optimised global variance as an effective lower bound to prevent the predicted variance from getting stuck at small values. The learning procedure is summarised in Algorithm 1.

3.5 Importance of Aggregate Posterior

Even when optimising the variance hyperparameter, the learnt posteriors might be quite different from the assumed prior distribution. If this is the case, then using the prior distribution to generate samples may no longer be a good idea. A better alternative is to generate samples from the aggregate posterior $q_\phi(z)$, which can be considered as the data distribution in the latent space, defined [Tomczak and Welling2018] as

$q_\phi(z) = \mathbb{E}_{p_d(x)}\big[q_\phi(z \mid x)\big] \approx \frac{1}{N}\sum_{i=1}^{N} q_\phi(z \mid x_i).$   (12)

The approximation in (12) is made by evaluating the expectation w.r.t. $p_d(x)$ empirically over the $N$ training samples, and with the isotropic Gaussian assumption for the posteriors $q_\phi(z \mid x_i)$, the aggregate posterior is effectively a Gaussian mixture distribution with $N$ mixture components. However, there are two problems with generating data samples directly from the aggregate posterior: 1) a large amount of memory is needed to store the statistics of all components of the aggregate posterior distribution and 2) as the training data are finite, the aggregate posterior is likely to be overfitted to the training samples, which means that only samples very similar to the observed data will be generated.

We propose to address both problems by using a simpler approximate distribution in place of the true aggregate posterior. There are many choices for the approximate distribution. In our case, we simply use a Gaussian mixture distribution with $K$ components ($K \ll N$, and we often take $K = 30$ for simple datasets, such as MNIST and Fashion MNIST), as shown below:

$\hat{q}(z) = \sum_{k=1}^{K} \pi_k\, \mathcal{N}\big(z;\, \mu_k,\, \Sigma_k\big),$   (13)

where $\pi_k$ is the weight of each mixture component ($\pi_k \geq 0$ and $\sum_{k=1}^{K} \pi_k = 1$) and $\mu_k$ and $\Sigma_k$ are the mean and covariance matrix of the $k$-th Gaussian component. The Gaussian mixture distribution is a good choice for the following reasons: 1) it is sufficiently expressive to represent the major modes and the low-density regions that $q_\phi(z)$ might possess; 2) with a limited number of components, it is much simpler than the original aggregate posterior, significantly reducing the risk of overfitting; and 3) there are efficient algorithms, such as EM [Bishop2006] and variational Bayesian inference methods [Blei, Kucukelbir, and McAuliffe2016], to derive the Gaussian mixture approximation.
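As an illustration, the sketch below fits such an approximation with scikit-learn; encoding the training set to posterior means before fitting, as well as the function and variable names, are our own assumptions rather than the paper's exact procedure.

    import numpy as np
    from sklearn.mixture import GaussianMixture

    def fit_approximate_aggregate_posterior(z_mu, n_components=30):
        """Fit the Gaussian mixture approximation of Eq. (13), where z_mu is an
        (N, latent_dim) array of posterior means of q_phi(z|x_i) over the training set."""
        gmm = GaussianMixture(n_components=n_components, covariance_type="full")
        gmm.fit(z_mu)
        return gmm

    # Generation: draw latent codes from the approximate aggregate posterior
    # instead of the prior, then decode them with the generator network.
    # z_samples, _ = gmm.sample(n_samples=64)
    # x_generated = decoder(z_samples)   # `decoder` is a hypothetical generator function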

  Initialize the model parameters $\theta$, $\phi$
  Initialize $\sigma^2$ as 1
  Set the stopping conditions $\epsilon_1$, $\epsilon_2$, n_epoch_1, n_epoch_2
  Set the epoch counter $e = 0$
  while $e <$ n_epoch_1 or the convergence criterion $\epsilon_1$ is not met do
     Update $\theta$, $\phi$ and $\sigma^2$ using AdamOptimizer on the ELBO loss defined in Equation (9)
     $e = e + 1$
  end while
  Freeze the global variance at its optimised value $\sigma^{*2}$
  Add the input-dependent variance module $\sigma^2_\psi(z)$, where $\sigma^2_\psi(z)^{(d)}$ is a function that takes $z$ as input and predicts the $d$-th dimension variance, lower-bounded by $\sigma^{*2}$
  while $e <$ n_epoch_2 or the convergence criterion $\epsilon_2$ is not met do
     Update $\theta$, $\phi$ and $\psi$ using AdamOptimizer on the ELBO loss with the log-likelihood term given by Equation (11)
     $e = e + 1$
  end while
  return $\theta$, $\phi$, $\psi$
Algorithm 1: VAE learning with variance prediction
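For concreteness, the following is a minimal PyTorch sketch of this two-stage procedure; the network architectures, layer sizes, single Monte-Carlo sample and the clamping of the predicted log-variance at $\log \sigma^{*2}$ are illustrative assumptions of ours, not the paper's exact models.

    import math
    import torch
    import torch.nn as nn

    class VAE(nn.Module):
        def __init__(self, x_dim=784, z_dim=16, h=512):
            super().__init__()
            self.enc = nn.Sequential(nn.Linear(x_dim, h), nn.ReLU(), nn.Linear(h, 2 * z_dim))
            self.dec = nn.Sequential(nn.Linear(z_dim, h), nn.ReLU(), nn.Linear(h, x_dim))
            self.var_head = nn.Sequential(nn.Linear(z_dim, h), nn.ReLU(), nn.Linear(h, x_dim))
            self.log_sigma2 = nn.Parameter(torch.zeros(()))      # global variance, init sigma^2 = 1

        def forward(self, x):
            mu, logvar = self.enc(x).chunk(2, dim=-1)
            z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterisation
            kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1)
            return self.dec(z), z, kl

    log_2pi = torch.log(torch.tensor(2.0 * math.pi))
    model = VAE()
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)

    def train_stage(loader, n_epochs, stage):
        """Stage 1 learns a global sigma^2 (Eq. 9); stage 2 switches to a
        per-dimension variance head, lower-bounded by the optimised sigma*^2."""
        for _ in range(n_epochs):
            for x in loader:                      # loader yields flattened image batches
                x_rec, z, kl = model(x)
                if stage == 1:
                    sse = (x - x_rec).pow(2).sum(-1)
                    log_lik = -0.5 * sse / model.log_sigma2.exp() \
                              - 0.5 * x.shape[-1] * (model.log_sigma2 + log_2pi)
                else:
                    floor = model.log_sigma2.detach()             # frozen log sigma*^2
                    log_s2 = torch.maximum(model.var_head(z), floor)
                    log_lik = (-0.5 * (x - x_rec).pow(2) / log_s2.exp()
                               - 0.5 * (log_s2 + log_2pi)).sum(-1)
                loss = -(log_lik - kl).mean()
                opt.zero_grad(); loss.backward(); opt.step()

    # train_stage(train_loader, n_epoch_1, stage=1)   # learn the global variance
    # train_stage(train_loader, n_epoch_2, stage=2)   # learn input-dependent variances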

4 Related Work

Optimising the variance parameter in a likelihood model has been adopted in Mixture Density Networks [Bishop1994] and Gaussian process regression models [Rasmussen and Williams2005]. Both works aim to learn a model for a regression task and use an isotropic Gaussian likelihood to model the conditional distribution between an input and an output. Although the nature of their learning tasks is vastly different from ours (supervised vs unsupervised), they emphasise the same message as our proposal: learning the variance parameter in the likelihood model for noisy observations is an integral part of the learning process and should not be omitted. Two major differences between our proposal and their treatment of the variance parameter are: 1) we extend the variance prediction to more flexible settings, such as being conditioned on an encoding and variable across data dimensions, whereas they only use a single-valued variance parameter; and 2) the learning objective in their models only contains the Gaussian likelihood term, and therefore they do not face the trade-off between reconstruction and regularisation that arises when using the ELBO loss. In the original VAE paper [Kingma and Welling2013], estimation of a similar variance in the decoder is briefly mentioned in the appendix, where their decoder estimates a single-valued variance for a given input sample. However, such prediction is often unstable, and thus our proposal uses a two-stage training procedure to stabilise the training.

The undesirable gap between the aggregate posterior and the prior has been noted in [Cremer, Li, and Duvenaud2018]. Many approaches have been proposed to mitigate this gap by using either a more expressive model for the sample posteriors [Rezende and Mohamed2015, Kingma, Salimans, and Welling2016, Ranganath, Tran, and Blei2016, Tomczak and Welling2016] or a more flexible model for the prior [Tomczak and Welling2018, Dilokthanakul et al.2016]. Our treatment of the gap differs from all the aforementioned works, as we recognise that such a gap is hard to eliminate. Hence, the best remedy is to replace the prior with a distribution that is better matched to the amortised posterior and the generative likelihood, so that the gap between the amortised posterior and the true posterior of the generative model can be effectively reduced, giving a closer bound to the data log likelihood.

5 Experiment Results

[Table 1 shows latent-space log-density plots of the prior, the aggregate posterior and the approximate aggregate posterior for models trained with $\sigma^2$ fixed at 1, 0.5, 0.035 (optimal) and 0.01; only the numerical row is reproduced here.]

$\sigma^2$: 1 | 0.5 | 0.035 (Optimal) | 0.01
KL gap: 0.094 ± 0.003 | 0.132 ± 0.003 | 0.422 ± 0.011 | 1.552 ± 0.034

Table 1: Comparison of the aggregate posterior and its Gaussian mixture approximation (32 mixtures) for models learnt under different $\sigma^2$ values on the MNIST dataset. The numerical row denotes the gap between the approximate aggregate posterior and the prior distribution, evaluated by Monte Carlo estimation with 10k samples (10 runs).

We carry out extensive experiments on the MNIST [LeCun1998], Fashion MNIST [Xiao, Rasul, and Vollgraf2017] and CelebA [Liu et al.2015] datasets. For all datasets, we treat the images as real-valued data so that our Gaussian likelihood assumption is appropriate. We take the original VAE and $\beta$-VAE as baseline methods for comparison. We also compare to WAE [Tolstikhin et al.2017] and DIP-VAE [Kumar, Sattigeri, and Balakrishnan2018] to demonstrate the significant performance gain given by optimising the variance parameter and sampling from the approximate aggregate posterior. More results and details of data pre-processing and model architectures are given in the Supplemental Materials.

5.1 Intuition through visualisation

First we try to gain an intuitive understanding of the impact of learning the $\sigma^2$ parameter in our proposal. To this end, we trained VAE models with a two-dimensional latent space on the MNIST dataset under different $\sigma^2$ values, and in Table 1 we show log-density plots of various distributions in the latent space.

When $\sigma^2 = 0.5$, the weighting between the reconstruction loss and the prior constraint is equal, which corresponds to the original VAE learning objective. When $\sigma^2 = 1$, the weight on the reconstruction loss is halved, resulting in an increased relative penalty on the prior regularisation, and this leads to the $\beta$-VAE's objective. $\sigma^2 = 0.035$ is learnt under our proposal by optimising the ELBO in (9), which learns the optimal balance between the two losses. Finally, $\sigma^2 = 0.01$ indicates the scenario where an extreme weight is placed on the reconstruction loss.

From the visualisation, we can see that learning under different $\sigma^2$ values leads to very different inference models. Larger $\sigma^2$ values result in an aggregate posterior closer to the prior distribution, often corresponding to a smooth density. As $\sigma^2$ gets smaller, sample posteriors become more distinctive and the aggregate posterior clearly becomes more complex and contains more sophisticated low-density regions. Either end of the $\sigma^2$ spectrum is sub-optimal: too large a value causes severe information loss of the input data and too small a value leads to a marginal distribution that is overly complex. We argue that the optimised $\sigma^2$ offers the best balance. More experimental evidence for this claim is given in Section 5.2.

(a) Optimised ELBO (higher is better). (b) FID score (lower is better). (c) KL divergence between the approximate aggregate posterior and the prior (lower is better).
Figure 1: Impact of learning $\sigma^2$ on 3 metrics of VAE performance on the MNIST dataset. Optimising $\sigma^2$ (green star) reaches the highest ELBO, a very low FID score and a reasonably small KL divergence to the prior.
Method: Ours (optimising $\sigma^2$) | VAE [Kingma and Welling2013] | $\beta$-VAE [Higgins et al.2017a] | DIP-VAE [Kumar, Sattigeri, and Balakrishnan2018] | WAE [Tolstikhin et al.2017] | Real images
MNIST: 8.9 | 74.9 | 438.1 | 114.9 | 142.8 | 1.9
Fashion MNIST: 10.6 | 123.8 | 281.5 | 103.6 | 149.5 | 0.7
CelebA: 74.7 | 78.4 | 106.4 | 87.5 | 85.4 | 2.9
Table 2: FID scores compared across 5 different methods for the MNIST, Fashion MNIST and CelebA datasets. For all datasets, optimising $\sigma^2$ achieves the best visual quality (samples are generated from the approximate aggregate posterior where possible).
Dataset | Sampling dist. | Ours: dataset $\sigma^2$ | Ours: pixel $\sigma^2$ | VAE | $\beta$-VAE | DIP-VAE
MNIST | Prior | 1307 | 6037 | -462 | -730 | -461
MNIST | Approx. Agg. Posterior | 1323 | 6052 | -461 | -730 | -460
Fashion MNIST | Prior | 1145 | 5708 | -463 | -731 | -463
Fashion MNIST | Approx. Agg. Posterior | 1157 | 5724 | -462 | -731 | -462
Table 3: IWAE test log likelihood (LL, higher is better) for MNIST and Fashion MNIST compared across 4 different methods. For our proposal, we evaluate the LL under two different variance estimation settings: 1) a global variance for the entire dataset and 2) input-dependent pixel-level variances.

The visualisation also demonstrates that, except for extremely small $\sigma^2$ values, the Gaussian mixture model can effectively approximate the aggregate posterior. Thanks to the closed-form solution of the global variance parameter derived in (10) and the stabilised learning procedure introduced in Section 3.4, we never enter this small-value regime in our learning scheme. Furthermore, the gap between the aggregate posterior and the prior persists no matter what value $\sigma^2$ takes. There are good texts explaining why such a gap appears [Cremer, Li, and Duvenaud2018]. We would like to stress that since such a gap is hard to avoid, samples should not be naively generated from the prior distribution but rather from an alternative distribution that better represents the aggregate posterior. We defer a more detailed discussion of this gap to Section 5.3.

5.2 Performance Gain from Optimising $\sigma^2$

Here we compare the learned VAE models under 3 metrics: 1) the optimised ELBO value (higher indicates a closer bound to the data log likelihood), 2) the FID score (lower indicates better visual quality of the generated samples [Heusel et al.2017b]) and 3) the KL divergence between the approximate aggregate posterior and the prior (smaller often indicates a simpler and smoother marginal latent distribution). In Figure 1, we show these metrics with $\sigma^2$ fixed at 6 different values ranging between 0 and 1, together with a VAE model whose $\sigma^2$ is learnt under the ELBO in (9), for the MNIST dataset. When $\sigma^2$ is optimised (indicated by a green star), the model achieves the highest ELBO, one of the lowest FID scores and still remains reasonably close to the prior distribution. In Figure 1(b), we plot the FID scores evaluated with samples drawn from the approximate aggregate posterior (blue) and the prior (red). In all cases the sample quality from the aggregate posterior is better than that from the prior distribution, and the gap grows exponentially as $\sigma^2$ gets smaller.

We also compared the model performance under our proposal (learning $\sigma^2$) with four other learning methods, namely the original VAE ($\sigma^2 = 0.5$), $\beta$-VAE ($\sigma^2 = 1$), DIP-VAE [Kumar, Sattigeri, and Balakrishnan2018] and WAE [Tolstikhin et al.2017], in terms of generated sample quality measured by FID scores on all three datasets. All methods adopt the same model architecture. The results are given in Table 2. For all datasets, optimising $\sigma^2$ achieves the best FID score. Examples of generated samples are given in the Supplemental Materials.

(a) A significant gap between the aggregate posterior (blue) and prior (orange) distributions consistently exists across all methods. (b) Samples generated from the approximate aggregate posterior (top) or the prior distribution (bottom) for the CelebA dataset.
Figure 2: The gap between the aggregate posterior and prior distributions causes bad-quality samples to be generated from the prior.

Finally, for VAE models with relatively low latent dimensions, we are able to evaluate the IWAE test log likelihood (LL) [Burda, Grosse, and Salakhutdinov2015], which is a tighter lower bound on the log likelihood than the ELBO and hence gives a better indication of how well the model is optimised in terms of approaching the data log likelihood objective. In Table 3, we list the IWAE test LL evaluated over 500 test data samples with 20 importance points per sample for the MNIST and Fashion MNIST datasets, where a VAE model with a 16-dimensional latent space is learnt. Optimising $\sigma^2$ significantly improves the test LL in comparison to the other learning methods. More interestingly, there is a huge performance boost when $\sigma^2$ is estimated across data dimensions and conditioned on the input data, compared to optimising a single-valued $\sigma^2$. Therefore, we conclude that using the full learning procedure to predict the more flexible variance estimator gives better results.
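For reference, a minimal sketch of such an importance-weighted estimate is given below, assuming a Gaussian encoder, a standard normal prior and the Gaussian likelihood of Equation (7); the function name and interfaces are our own assumptions.

    import torch

    def iwae_log_likelihood(x, encoder, decoder, sigma2, k=20):
        """Importance-weighted estimate of log p(x) with k importance samples
        [Burda, Grosse, and Salakhutdinov2015]; x has shape (batch, D)."""
        mu, logvar = encoder(x)                                   # parameters of q_phi(z|x)
        std = (0.5 * logvar).exp()
        z = mu.unsqueeze(0) + std.unsqueeze(0) * torch.randn(k, *mu.shape)
        x_rec = decoder(z)                                        # shape (k, batch, D)

        log_q = torch.distributions.Normal(mu, std).log_prob(z).sum(-1)
        log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(-1)
        log_lik = torch.distributions.Normal(x_rec, sigma2 ** 0.5).log_prob(x).sum(-1)

        log_w = log_lik + log_prior - log_q                       # shape (k, batch)
        return (torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(k)))).mean()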

5.3 Aggregate Posterior vs Prior

In Figure 2(a), we compare the FID scores of samples generated from both the approximate aggregate posterior and the prior distribution across 4 different methods, and the comparison indicates that sample quality is always better when samples are drawn from the approximate aggregate posterior. In Figure 2(b), samples from both distributions are generated under our proposal for the CelebA dataset. The samples from the aggregate posterior are consistently more realistic, whereas the samples from the prior distribution often have essential facial features missing or damaged.

5.4 Uncertainty Estimation

A very powerful application of optimising the $\sigma^2$ parameter, especially in the more flexible setting where $\sigma^2$ is conditioned on a latent encoding and allowed to vary across data dimensions, is to estimate uncertainty in the generated samples. Figure 3 illustrates the estimated uncertainty in reconstructed samples for the MNIST and CelebA datasets. Note that the predicted $\sigma^2$ highlights the regions where the reconstruction differs from the original data sample. During training, $\sigma^2$ is learnt to account for the difference between the reconstructed and the original data samples, and such differences are highly correlated with the statistical irregularity in the dataset. For example, the major difference between the reconstructed digit zero and its corresponding data sample is at the top, where the stroke starts and ends. This is the most unpredictable part of writing a zero, and that is why the model fails to reconstruct this part accurately. Figure 4 illustrates uncertainty estimation in generated samples for both the MNIST and CelebA datasets. Notice in Figure 4(b) that the high-uncertainty regions correspond to eyes, mouths, edges of hair and backgrounds, which are highly variable across the dataset. Such uncertainty estimation indicates that the model is aware that there are many other possibilities for generating these regions and it only renders one of them.

(a) MNIST.
(b) CelebA.
Figure 3: Uncertainty estimation for reconstructed data. a) ground truth images, b) reconstructed images and c) estimated $\sigma^2$ indicating uncertainty.
(a) MNIST.
(b) CelebA.
Figure 4: Uncertainty estimation for generated data. a) generated samples and b) estimated uncertainty.

5.5 Ablation Study

A very important hyperparameter in our learning procedure is the number of components in the Gaussian mixture model that we use to approximate the aggregate posterior $q_\phi(z)$. We carried out a study of its impact on the quality of generated samples, measured by FID scores, for the CelebA dataset; the results are shown in Table 4. Increasing the number of components does not lead to any obvious improvement in generation quality (the FID score only improves by 1.5, from 79.9 to 78.4, when the number of mixtures increases from 500 to 2000). Only when the number of mixtures is extremely large, such as 10k, do we see a significant gain, but this is likely to be caused by the approximation becoming overfitted to the 50k samples that we use to fit the mixture model. Therefore, in all of our experiments, we stick with a relatively small number of Gaussian mixtures, specifically 32 for MNIST and Fashion MNIST and 500 for CelebA.

N mixtures: 500 | 1000 | 2000 | 10000
FID: 79.9 | 79.2 | 78.4 | 73.1
Table 4: Impact of the number of Gaussian mixture components used in the approximation of the aggregate posterior on the generated sample quality (CelebA dataset).

6 Conclusion

In this paper, we propose a learning algorithm that automatically achieves (for real-valued data) an optimal balance between the two competing objectives (reconstruction loss and prior regularisation) in the VAE ELBO loss. A convenient by-product of our learning scheme is an effective uncertainty estimator for generated or reconstructed samples, allowing a wider range of potential applications, including safety-critical environments. We further study the gap between the aggregate posterior and the prior distribution, which is associated with poor samples being generated, and offer a simple solution to mitigate this problem.

