A Stable Variational Autoencoder for Text Modelling
Abstract
Variational Autoencoders (VAEs) are a powerful method for learning representations of high-dimensional data. However, VAEs can suffer from an issue known as latent variable collapse (or KL loss vanishing), where the posterior collapses to the prior and the model ignores the latent codes in generative tasks. This issue is particularly prevalent when employing VAE-RNN architectures for text modelling (Bowman et al., 2016). In this paper, we present a simple architecture called holistic regularisation VAE (HR-VAE), which can effectively avoid latent variable collapse. Compared to existing VAE-RNN architectures, we show that our model achieves a much more stable training process and generates text of significantly better quality.
1 Introduction
Variational Autoencoders (VAEs; Kingma and Welling, 2013) are a powerful method for learning representations of high-dimensional data. However, recent attempts at applying VAEs to text modelling are still far less successful than their applications to images and speech (Bachman, 2016; Fraccaro et al., 2016; Semeniuta et al., 2017). When applying VAEs to text modelling, recurrent neural networks (RNNs; here we use the term to refer to any type of recurrent neural architecture, including LSTM and GRU) are commonly used as the architecture for both the encoder and the decoder (Bowman et al., 2016; Xu and Durrett, 2018; Dieng et al., 2019). While such a VAE-RNN architecture allows encoding and generating sentences of variable length effectively, it is also vulnerable to an issue known as latent variable collapse (or KL loss vanishing), where the posterior collapses to the prior and the model ignores the latent codes in generative tasks.
Various efforts have been made to alleviate the latent variable collapse issue. Bowman et al. (2016) use KL annealing, where a variable weight is added to the KL term in the cost function at training time. Yang et al. (2017) observed that there is a trade-off between the contextual capacity of the decoder and the effective use of encoding information, and developed a dilated CNN decoder which can vary the amount of conditioning context. They also introduced a loss clipping strategy in order to make the model more robust. Xu and Durrett (2018) addressed the problem by replacing the standard normal prior with the von Mises-Fisher (vMF) distribution. With vMF, the KL loss depends only on the concentration parameter, which is fixed during training and testing, and hence the KL loss is constant. In a more recent work, Dieng et al. (2019) avoided latent variable collapse by including skip connections in the generative model, where the skip connections enforce strong links between the latent variables and the likelihood function.
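The KL-annealing strategy mentioned above multiplies the KL term by a weight that grows from near 0 to near 1 over training, letting the decoder learn before the KL pressure takes effect. A minimal sketch of a sigmoid schedule follows; the slope `k` and `midpoint` values are illustrative defaults, not taken from any of the cited papers:

```python
import math

def kl_weight(step, total_steps, k=0.0025, midpoint=0.5):
    """Sigmoid KL-annealing weight in [0, 1].

    The weight is ~0 early in training, crosses 0.5 at
    `midpoint * total_steps`, and saturates towards 1 afterwards.
    `k` controls how sharply the transition happens.
    """
    x = step - midpoint * total_steps
    return 1.0 / (1.0 + math.exp(-k * x))
```

During training, the ELBO's KL term would be scaled by `kl_weight(step, total_steps)` at each optimisation step.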
Although the aforementioned works show some effectiveness in addressing the latent variable collapse issue, they either require careful engineering to balance the weight between the reconstruction loss and the KL loss (Bowman et al., 2016; Sønderby et al., 2016), or resort to designing more sophisticated model structures (Yang et al., 2017; Xu and Durrett, 2018; Dieng et al., 2019).
In this paper, we present a simple architecture called holistic regularisation VAE (HR-VAE), which can effectively avoid latent variable collapse. In contrast to existing VAE-RNN models for text modelling, which merely impose a standard normal distribution prior on the last hidden state of the RNN encoder, our HR-VAE model imposes regularisation on all hidden states of the RNN encoder. Another advantage of our model is that it is generic and can be applied to any existing VAE-RNN-based architecture.
We evaluate our model against several strong baselines which apply VAE to text modelling (Bowman et al., 2016; Yang et al., 2017; Xu and Durrett, 2018). We conducted experiments on two public benchmark datasets, namely, the Penn Treebank dataset (Marcus et al., 1993) and the end-to-end (E2E) text generation dataset (Novikova et al., 2017). Experimental results show that our HR-VAE model not only effectively mitigates the latent variable collapse issue with a stable training process, but also gives better predictive performance than the baselines, as evidenced by both quantitative (e.g., negative log likelihood and perplexity) and qualitative evaluation. The code for our model is available online at https://github.com/ruizheliUOA/HRVAE.
2 Methodology
2.1 Background of VAE
A variational autoencoder (VAE) is a deep generative model which combines variational inference with deep learning. The VAE modifies the conventional autoencoder architecture by replacing the deterministic latent representation $\mathbf{z}$ of an input $\mathbf{x}$ with a posterior distribution $q_\phi(\mathbf{z}|\mathbf{x})$, and imposing a prior distribution $p(\mathbf{z})$ on the posterior, such that the model allows sampling from any point of the latent space and yet is able to generate novel and plausible output. The prior is typically chosen to be a standard normal distribution, i.e., $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$, such that the KL divergence between the posterior and the prior can be computed in closed form (Kingma and Welling, 2013).
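The closed-form KL mentioned above, for a diagonal-Gaussian posterior against the standard-normal prior, can be sketched as follows (a pure-Python illustration; the encoder is assumed to output a per-dimension mean and log-variance, a common parameterisation for numerical stability):

```python
import math

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ).

    mu and log_var are per-dimension posterior parameters. For each
    dimension the contribution is 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1),
    which is zero exactly when mu = 0 and sigma = 1.
    """
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mu, log_var))
```

Note that the KL is zero only when the posterior matches the prior exactly, which is the degenerate situation latent variable collapse drives the model towards.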
To train a VAE, we need to optimise the marginal likelihood $p_\theta(\mathbf{x})$, where the log likelihood can take the following form:

$$\log p_\theta(\mathbf{x}) = \mathrm{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p_\theta(\mathbf{z}|\mathbf{x})\big) + \mathcal{L}(\theta, \phi; \mathbf{x}) \quad (1)$$

$$\mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\big[\log p_\theta(\mathbf{x}|\mathbf{z})\big] - \mathrm{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big) \quad (2)$$
Here $q_\phi(\mathbf{z}|\mathbf{x})$ is the variational approximation of the true posterior $p_\theta(\mathbf{z}|\mathbf{x})$. Specifically, $q_\phi(\mathbf{z}|\mathbf{x})$ can be regarded as an encoder (a.k.a. the recognition model) and $p_\theta(\mathbf{x}|\mathbf{z})$ as the decoder (a.k.a. the generative model). Both the encoder and the decoder are implemented via neural networks. As proved in Kingma and Welling (2013), optimising the marginal log likelihood is essentially equivalent to maximising $\mathcal{L}(\theta, \phi; \mathbf{x})$, i.e., the evidence lower bound (ELBO), which consists of two terms. The first term is the expected reconstruction error, indicating how well the model can reconstruct data given a latent variable. The second term is the KL divergence of the approximate posterior from the prior, i.e., a regularisation term pushing the learned posterior to be as close to the prior as possible.
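In practice, the expectation in the reconstruction term is estimated with samples drawn via the reparameterisation trick of Kingma and Welling (2013), which keeps sampling differentiable with respect to the encoder outputs. A minimal sketch:

```python
import math
import random

def reparameterise(mu, log_var, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    Writing the sample as a deterministic function of (mu, log_var) plus
    independent noise lets gradients flow back into the encoder, which
    direct sampling from N(mu, sigma^2) would not allow.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]
```

With a very small variance the sample concentrates on the mean, which is why a collapsed posterior (matching the prior for every input) carries no information about $\mathbf{x}$.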
2.2 Variational Autoencoder with Holistic Regularisation
In this section, we discuss the technical details of the proposed holistic regularisation VAE (HR-VAE) model, a general architecture which can effectively mitigate the KL vanishing phenomenon.
Our model design is motivated by a noticeable defect shared by the VAE-RNN based models in previous works (Bowman et al., 2016; Yang et al., 2017; Xu and Durrett, 2018; Dieng et al., 2019). That is, all these models, as shown in Figure 1(a), only impose a standard normal distribution prior on the last hidden state of the RNN encoder, which potentially leads to learning a suboptimal representation of the latent variable and leaves the model vulnerable to KL loss vanishing. Our hypothesis is that, to learn a good representation of the data and a good generative model, it is crucial to impose the standard normal prior on all the hidden states of the RNN-based encoder (see Figure 1(b)), which allows a better regularisation of the model learning process.
We implement the HR-VAE model using a two-layer LSTM for both the encoder and the decoder. However, one should note that our architecture can readily be applied to other types of RNN such as GRU. For each time stamp $t$ (see Figure 1(b)), we concatenate the hidden state $\mathbf{h}_t$ and the cell state $\mathbf{c}_t$ of the encoder. The concatenation (i.e., $[\mathbf{h}_t; \mathbf{c}_t]$) is then fed into two linear transformation layers for estimating $\boldsymbol{\mu}_t$ and $\boldsymbol{\sigma}_t^2$, which are parameters of a normal distribution corresponding to the concatenation of $\mathbf{h}_t$ and $\mathbf{c}_t$. Let $q_\phi(\mathbf{z}_t|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}_t, \mathrm{diag}(\boldsymbol{\sigma}_t^2))$; we wish $q_\phi(\mathbf{z}_t|\mathbf{x})$ to be close to a prior $p(\mathbf{z}_t)$, which is a standard Gaussian. Finally, the KL divergence between these two multivariate Gaussian distributions (i.e., $q_\phi(\mathbf{z}_t|\mathbf{x})$ and $p(\mathbf{z}_t)$) contributes to the overall KL loss of the ELBO. By taking the average of the KL loss over all time stamps $t \in [1, T]$, the resulting ELBO takes the following form:

$$\mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\big[\log p_\theta(\mathbf{x}|\mathbf{z})\big] - \frac{1}{T}\sum_{t=1}^{T} \mathrm{KL}\big(q_\phi(\mathbf{z}_t|\mathbf{x}) \,\|\, p(\mathbf{z}_t)\big) \quad (3)$$
As can be seen in Eq. (3), our solution to the KL collapse issue does not require any engineering for balancing the weight between the reconstruction term and the KL loss, as is commonly the case in existing works (Bowman et al., 2016; Sønderby et al., 2016). The weight between these two terms in our model is simply fixed to 1.
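The holistic KL term, i.e., the per-time-stamp closed-form KL averaged over the $T$ time stamps, can be sketched as follows. This is an illustrative reimplementation, not the authors' released code; `mus[t]` and `log_vars[t]` stand in for the outputs of the two linear layers applied to $[\mathbf{h}_t; \mathbf{c}_t]$:

```python
import math

def hr_vae_kl(mus, log_vars):
    """Average over T time stamps of KL( N(mu_t, diag(sigma_t^2)) || N(0, I) ).

    mus[t] and log_vars[t] parameterise the diagonal-Gaussian posterior
    for the concatenated hidden and cell state at time stamp t. Averaging
    (rather than summing) keeps the KL term on the same scale regardless
    of sequence length, matching the fixed weight of 1 in the ELBO.
    """
    total = 0.0
    for mu, lv in zip(mus, log_vars):
        total += 0.5 * sum(m * m + math.exp(l) - l - 1.0
                           for m, l in zip(mu, lv))
    return total / len(mus)
```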


Table 1: The statistics of the PTB and E2E datasets.

Dataset | Training | Development | Testing | Avg. sent. length | Vocab.
PTB     | 42,068   | 3,370       | 3,761   | 21.1              | 10K
E2E     | 42,061   | 4,672       | 4,693   | 22.67             | 2.8K

3 Experimental Setup
3.1 Datasets
We evaluate our model on two public datasets, namely, the Penn Treebank (PTB; Marcus et al., 1993) and the end-to-end (E2E) text generation corpus (Novikova et al., 2017), which have been used in a number of previous works on text generation (Bowman et al., 2016; Xu and Durrett, 2018; Wiseman et al., 2018; Su et al., 2018). PTB consists of more than 40,000 sentences from Wall Street Journal articles, whereas the E2E dataset contains over 50,000 sentences of restaurant reviews. The statistics of the two datasets are summarised in Table 1.
Table 2: Language modelling results on the PTB and E2E datasets. NLL is reported with the KL loss in parentheses; lower NLL and PPL indicate better performance.

                |         PTB                      |         E2E
                | Standard       | Inputless       | Standard        | Inputless
Model           | NLL (KL)  PPL  | NLL (KL)   PPL  | NLL (KL)   PPL  | NLL (KL)    PPL
VAE-LSTM-base   | 101 (2)   119  | 125 (15)   380  | 50 (1.88)  5.77 | 101 (5.48)  34.70
VAE-CNN         |  99 (3.1) 113  | 121 (16.2) 323  | 41 (3.02)  4.23 |  82 (5.95)  17.81
vMF-VAE         |  96 (5.7)  98  | 117 (18.6) 262  | 34 (7.63)  3.29 |  61 (19.58)  8.52
HR-VAE (Ours)   |  79 (10.4) 43  |  85 (17.32) 54  | 20 (5.37)  2.02 |  38 (7.78)   3.74


Table 3: Example input sentences from the E2E test set and the corresponding reconstructions by vMF-VAE and our model (inputless setting).

Input:
1. blue spice is a coffee shop in city centre .
2. giraffe is a coffee shop found near the bakers .
3. a pub in the city centre area called blue spice
4. pub located near café sicilia called cocum with a high customer rating
5. the cricketers is a one star coffee shop near the ranch that is not family friendly .

vMF-VAE:
1. blue spice is a coffee in city centre . it is not , and
2. cotto is a coffee shop located near the bakers . . is 5 out of
3. a coffee in the city city area is blue spice spice . the is is
4. located located near café rouge , cotto has a high customer rating and a customer
5. the cricketers is a low rated coffee shop near the bakers that is a star , is is

Ours:
1. blue spice is a coffee shop in city centre .
2. giraffe is a coffee shop located near the bakers .
3. a restaurant in the city centre called blue spice italian
4. located place near café sicilia called punter has a high customer rating
5. the cricketers is a one star coffee shop near ranch ranch that is not family friendly .

3.2 Implementation Details
For the PTB dataset, we used the train-test split of Bowman et al. (2016) and Xu and Durrett (2018). For the E2E dataset, we used the train-test split from the original dataset (Novikova et al., 2017) and indexed the words with a frequency higher than 3. We represent input data with 512-dimensional word2vec embeddings (Mikolov et al., 2013). We set the dimension of the hidden layers of both encoder and decoder to 256. The Adam optimiser (Kingma and Ba, 2014) was used for training with an initial learning rate of 0.0001. Each utterance in a mini-batch was padded to the maximum length for that batch, and the maximum batch size allowed is 128.
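The per-batch padding described above can be sketched as follows (`pad_id = 0` is an assumed padding-token index; the paper does not specify one):

```python
def pad_batch(batch, pad_id=0):
    """Pad each tokenised utterance to the length of the longest
    utterance in this mini-batch (per-batch, not corpus-wide, padding).

    Padding per batch keeps sequences as short as possible, which
    reduces wasted computation in the recurrent encoder and decoder.
    """
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in batch]
```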
3.3 Baselines
We compare our HRVAE model with three strong baselines using VAE for text modelling:
VAE-LSTM-base (https://github.com/timbmg/SentenceVAE): A variational autoencoder model which uses an LSTM for both the encoder and the decoder. KL annealing is used to tackle the latent variable collapse issue (Bowman et al., 2016);
VAE-CNN (https://github.com/kefirski/contiguoussuccotash): A variational autoencoder model with an LSTM encoder and a dilated CNN decoder (Yang et al., 2017);
vMF-VAE (https://github.com/jiachengxu/vmf_vae_nlp): A variational autoencoder model using an LSTM for both the encoder and the decoder, where the prior distribution is the von Mises-Fisher (vMF) distribution rather than a Gaussian distribution (Xu and Durrett, 2018).
4 Experimental Results
We evaluate our HR-VAE model in two experimental settings, following the setup of Bowman et al. (2016) and Xu and Durrett (2018). In the standard setting, the input to the decoder at each time stamp is the concatenation of the latent variable $\mathbf{z}$ and the ground truth word of the previous time stamp. Under this setting, the decoder is more powerful because it uses the ground truth word as input, with the result that little information about the training data is captured by $\mathbf{z}$. The inputless setting, in contrast, does not use the previous ground truth word as input for the decoder. In other words, the decoder needs to predict the entire sequence with only the help of the given latent variable $\mathbf{z}$. In this way, a high-quality representation abstracting the information of the input sentence is much needed by the decoder, which forces $\mathbf{z}$ to learn the required information.
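The two decoder-input regimes can be sketched as follows. Here embeddings are stand-in lists of floats and `decoder_step_inputs` is a hypothetical helper for illustration, not a function from the released code:

```python
def decoder_step_inputs(z, prev_gold_embeddings, inputless):
    """Build per-time-stamp decoder inputs for the two settings.

    Standard setting: each input is the concatenation [z ; e(w_{t-1})]
    of the latent variable and the previous ground-truth word embedding.
    Inputless setting: each input is z alone, so the decoder must rely
    entirely on the latent variable to reconstruct the sentence.
    """
    if inputless:
        return [list(z) for _ in prev_gold_embeddings]
    return [list(z) + list(e) for e in prev_gold_embeddings]
```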
Overall performance. Table 2 shows the language modelling results of our approach and the baselines. We report negative log likelihood (NLL), KL loss, and perplexity (PPL) on the test set. As expected, all the models have a higher KL loss in the inputless setting than in the standard setting, as $\mathbf{z}$ is required to encode more information about the input data for reconstruction. In terms of overall performance, our model outperforms all the baselines on both datasets (i.e., PTB and E2E). For instance, compared with the strongest baseline, vMF-VAE, in the standard setting, our model reduces NLL from 96 to 79 and PPL from 98 to 43 on PTB. In the inputless setting, our performance gain is even larger: NLL is reduced from 117 to 85 and PPL from 262 to 54. A similar pattern can be observed for the E2E dataset. These observations suggest that our approach can learn a better generative model for the data.
Loss analysis. To conduct a more thorough evaluation, we further investigate model behaviours in terms of both reconstruction loss and KL loss, as shown in Figure 2. These plots were obtained based on the E2E training set using the inputless setting.
We can see that the KL loss of VAE-LSTM-base, which uses sigmoid annealing (Bowman et al., 2016), collapses to zero, leading to poor generative performance as indicated by the high reconstruction loss. The KL losses of both VAE-CNN and vMF-VAE are non-zero, where the former mitigates the KL collapse issue with a KL loss clipping strategy and the latter by replacing the standard normal prior with the vMF distribution (with which the KL loss depends only on a fixed concentration parameter, and hence is constant). Although both VAE-CNN and vMF-VAE outperform VAE-LSTM-base by a large margin in terms of reconstruction loss, as shown in Figure 2, one should also notice that these two models actually overfit the training data, as their performance on the test set is much worse (cf. Table 2). In contrast to the baselines, which mitigate the KL collapse issue by carefully engineering the weight between the reconstruction loss and the KL loss or by choosing a different prior, we provide a simple and elegant solution through holistic KL regularisation, which effectively mitigates the KL collapse issue and achieves a better reconstruction error in both training and testing.
Sentence reconstruction. Lastly, we show some sentence examples reconstructed by vMF-VAE (i.e., the best baseline) and our model in the inputless setting, using sentences from the E2E test set as input. As shown in Table 3, the sentences generated by vMF-VAE contain repeated words in quite a few cases, such as ‘city city area’ and ‘blue spice spice’. In addition, vMF-VAE also tends to generate unnecessary or unrelated words at the end of sentences, making the generated sentences ungrammatical. The sentences reconstructed by our model, in contrast, are more grammatical and more similar to the corresponding ground truth sentences than those of vMF-VAE.
5 Conclusion
In this paper, we have presented a simple and generic architecture called holistic regularisation VAE (HR-VAE), which can effectively avoid latent variable collapse. In contrast to existing VAE-RNN models, which merely impose a standard normal distribution prior on the last hidden state of the RNN encoder, our HR-VAE model imposes regularisation on all the hidden states, allowing a better regularisation of the model learning process. Empirical results show that our model can effectively mitigate the latent variable collapse issue while giving better predictive performance than the baselines.
Acknowledgment
This work is supported by the award made by the UK Engineering and Physical Sciences Research Council (Grant number: EP/P011829/1).