BooVAE: A scalable framework for continual VAE learning under boosting approach
Variational Autoencoders (VAEs) are capable of generating realistic images, sounds and video sequences. From a practitioner's point of view, we are usually interested in solving problems where tasks are learned sequentially, in a way that avoids revisiting all previous data at each stage. We address this problem by introducing a conceptually simple and scalable end-to-end approach that incorporates past knowledge by learning the prior directly from the data. We consider a scalable boosting-like approximation of the intractable theoretically optimal prior. We provide empirical studies on two commonly used benchmarks, MNIST and Fashion MNIST, on disjoint sequential image generation tasks. For each dataset, the proposed method delivers results that are the best or comparable to the state of the art, avoiding catastrophic forgetting in a fully automatic way.
Since most real-world datasets are unlabeled, unsupervised learning is an important part of the machine learning field. Generative models allow us to obtain samples from observed empirical distributions of complicated high-dimensional objects such as images, sound or text. Moreover, they can be used to capture the intrinsic structure of such objects (Donahue et al., 2016). Generative Adversarial Networks (GANs) (Goodfellow et al., 2014) and Variational Autoencoders (VAEs) (Kingma and Welling, 2013) are the most successful approaches applied to unsupervised learning of complicated empirical distributions. This work is mostly devoted to VAEs, with a focus on the incremental learning setting.
It has been observed that VAEs ignore some dimensions of the latent variables and produce blurred reconstructions (Burda et al., 2015; Sønderby et al., 2016). There are several approaches to address these issues, including amortization gap reduction (Kim et al., 2018), KL-term annealing (Cremer et al., 2018) and the introduction of alternative optimization objectives (Rezende and Viola, 2018). In all cases, it was observed that the choice of the prior distribution is highly relevant and that the default Gaussian prior overregularizes the encoder.
In this work we address the problem of constructing the optimal prior for a VAE. The form of the optimal prior (Tomczak and Welling, 2018) was obtained by maximizing a lower bound of the marginal likelihood (ELBO); it is the aggregated posterior over the whole training dataset. To construct a feasible approximation, we consider a method of greedy KL-divergence projection. Applying the maximum entropy approach allows us to formulate a feasible optimization problem and avoid overfitting.
The final VAE training algorithm is formulated as an iterative MM algorithm (Polson et al., 2015). Each iteration consists of two steps: in the first step, the optimal prior is approximated for the current encoder-decoder pair; in the second step, the parameters of the encoder and decoder are updated by optimizing the ELBO with the new prior.
It has been shown that deep neural networks are prone to catastrophic forgetting (McCloskey and Cohen, 1989; Goodfellow et al., 2013; Nguyen et al., 2017). Several solutions have been proposed in the literature to overcome this drawback, such as weight regularization and dynamic architectures. However, most of these approaches have been successfully applied to discriminative models only.
We formulate a straightforward algorithm for incremental learning. From a practitioner's point of view, it is essential to be able to store one model capable of solving several tasks arriving sequentially. Hence, we propose an algorithm with a single encoder-decoder pair that updates only the prior. We validate our method on disjoint sequential image generation tasks on the MNIST and Fashion-MNIST datasets.
Our main contributions can be summarized as follows.
We propose an MM algorithm for training a VAE model.
We propose a simple and efficient algorithm for incremental learning which shares prior knowledge between tasks while keeping a single encoder-decoder pair. In this setting, the prior optimization algorithm can be viewed as automatic coreset selection.
We empirically validate the proposed algorithm on several commonly used benchmark datasets (MNIST and Fashion-MNIST) in both the offline and the incremental settings.
2 Related work
2.1 Prior selection for Variational Autoencoders
Although the standard Gaussian distribution is the default choice of prior for Variational Autoencoders (VAEs) and leads to a mathematically convenient model, it significantly restricts the expressiveness of the model (Hoffman and Johnson, 2016; Goyal et al., 2017). The choice of the prior distribution is critical for VAE performance and, at the same time, it is not a trivial choice to make. One way to mitigate this problem proposed in the literature is to treat the prior distribution as a constraint on the capacity of the latent representation (Higgins et al., 2017; Rezende and Viola, 2018).
Several works propose to choose the prior as a reflection of the assumed structure of the latent space. Davidson et al. (2018) suggest using the von Mises-Fisher (vMF) distribution as the prior to properly model data with a spherical latent representation. Their experiments demonstrate that a properly chosen prior results in better recovery of hyperspherical latent structure. Mathieu et al. (2019) suggest training a VAE with a hyperbolic latent space in order to account for hierarchies in the data.
We argue that learning the prior from data is better than specifying it ad hoc. Goyal et al. (2017) propose to learn a non-parametric prior using a nested Chinese restaurant process. Tomczak and Welling (2018) propose to simultaneously train the encoder-decoder pair and the prior, which is a mixture of variational posteriors evaluated at a set of pseudo-inputs. We propose knowledge distillation from the VampPrior to avoid overfitting and to obtain the coreset in an automatic way. Motivated by the fact that the optimal prior is learned in the form of a mixture of distributions, we propose to use it for incremental learning.
2.2 Incremental Learning
The significant difference between offline and incremental learning lies in the availability of the training dataset. The offline setting for both discriminative and generative models assumes that data are drawn i.i.d. from a fixed empirical training dataset. In contrast, incremental learning assumes that the distribution of the training data changes over time, for example when data from a new class arrive. Notably, real-world applications of deep models should be treated as incremental learning, as the distribution of the data can change drastically over time and the whole dataset cannot be kept because of computational or privacy constraints.
Despite the importance of unsupervised learning, the incremental learning setting is mainly considered for discriminative models. Several directions to overcome catastrophic forgetting have been suggested, such as weight regularization and dynamic architectures. Weight regularization approaches such as Elastic Weight Consolidation (EWC) (Kirkpatrick et al., 2017) add a quadratic regularization term to the objective (Liu et al., 2018; Huszár, 2017), which forms a trust region for the weights around the optimum of the previous task, while the model architecture remains the same. In contrast, the other approach is based on incrementally updating the architecture of the network, increasing its capacity while sharing the feature extractor across tasks (Rusu et al., 2016; Li and Hoiem, 2018).
The conceptual difference between the two model classes makes it impossible to directly apply methods for discriminative models to generative ones. Seff et al. (2017) apply an EWC-augmented loss to the generator for continual learning with generative adversarial nets (GANs). However, the approach was successfully applied only to class-conditional GANs. Nguyen et al. (2017) proposed a dynamic architecture for VAEs, where a specific head is used for each task. Another class of works considers using generative models to generate historical data and thereby overcome catastrophic forgetting when learning a classifier (Wu et al., 2018).
We propose a novel incremental learning approach for the VAE generative model with a static architecture. The suggested model incorporates knowledge from the previous tasks by updating the prior, which connects our work with the loss-augmentation approach.
2.3 Greedy KL minimization
The approximation of an unnormalized distribution with sequential mixture models has been considered previously in several studies. A number of works (Miller et al., 2017; Gershman et al., 2012) perform direct optimization of the ELBO with respect to the parameters of the new component. Unfortunately, this leads to a complex optimization problem. Therefore, another line of work considers the functional Frank-Wolfe framework, where subproblems are linearized (Wang et al., 2015). At each step, the KL-divergence functional is linearized at the current approximation point by its convex perturbation, yielding tractable optimization subproblems over the space of distributions for each component. Different regularizers were proposed to avoid degenerate solutions. Guo et al. (2016) suggest using concave log-det regularization for Gaussian base learners. Egorov et al. (2019) formulated the MaxEntropy Pursuit Variational Inference (MEPVI) algorithm from the Maximum Entropy principle, where each subproblem is represented as a maximum entropy problem with a linear constraint.
In the present work, we consider knowledge distillation from the prior based on the MEPVI approach.
3.1 Variational Autoencoders
VAEs consider a two-step generative process defined by a prior over the latent space $p(z)$ and a conditional generative distribution $p_\theta(x|z)$, which is parametrized by a deep neural network (DNN). Our goal is to maximize the marginal log-likelihood $\log p(x)$, which is intractable in the general case. Therefore, the variational inference (VI) framework is considered:
$$\log p(x) \geq \mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(z|x)} \log p_\theta(x|z) - \mathrm{KL}\left(q_\phi(z|x) \,\|\, p(z)\right), \qquad (1)$$
where $q_\phi(z|x)$ is a variational posterior distribution. Given the data distribution $p_e(x)$, we aim at maximizing the average marginal log-likelihood $\mathbb{E}_{p_e(x)} \log p(x)$. Following the variational autoencoder architecture, amortized inference is achieved by the choice of the variational distribution $q_\phi(z|x)$, which is also parametrized by a DNN.
Tomczak and Welling (2018) propose to combine the variational and empirical Bayes approaches and optimize the objective (1) over the prior distribution. If the empirical data density for the observed dataset is $p_e(x) = \frac{1}{N}\sum_{n=1}^{N} \delta(x - x_n)$, then the optimal prior has the following form:
$$p^*(z) = \frac{1}{N}\sum_{n=1}^{N} q_\phi(z|x_n).$$
Clearly, such a prior leads to overfitting. Hence, keeping the same functional form, a truncated version with $K$ pseudo-inputs $u_1, \dots, u_K$ was proposed (Tomczak and Welling, 2018) as the VampPrior:
$$p^{V}(z) = \frac{1}{K}\sum_{k=1}^{K} q_\phi(z|u_k). \qquad (2)$$
Hence, the VAE learning algorithm is the simultaneous optimization of the ELBO over the network parameters and the pseudo-inputs:
$$\max_{\theta, \phi, u_1, \dots, u_K} \; \mathbb{E}_{p_e(x)} \, \mathcal{L}(\theta, \phi, u_1, \dots, u_K).$$
In the present paper we address two important drawbacks of the VampPrior (2). Firstly, for large values of $K$, variational inference with such a prior distribution is very computationally expensive. Even for the MNIST dataset, Tomczak and Welling (2018) used a mixture with a large number of components. Secondly, it is not clear how to choose $K$ for a particular dataset, as there is a trade-off between prior capacity and overfitting. For this purpose we adapt the maximum entropy variational inference framework: we add components to the prior during training in a greedy manner.
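The cost argument can be made concrete with a minimal sketch of how the log-density of a mixture prior of the VampPrior form (2) might be evaluated. It assumes diagonal Gaussian posteriors; the function names and the pure-Python list representation are illustrative assumptions, not the authors' implementation.

```python
import math


def log_normal_diag(z, mean, var):
    """Log-density of a diagonal Gaussian N(z | mean, diag(var))."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (zi - m) ** 2 / v)
        for zi, m, v in zip(z, mean, var)
    )


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def log_mixture_prior(z, means, variances, weights=None):
    """log p(z) for p(z) = sum_k w_k N(z | mean_k, diag(var_k)).

    With uniform weights w_k = 1/K this has the VampPrior form, where
    (mean_k, var_k) would be the encoder output at pseudo-input u_k
    (assumption: the encoder returns a mean and a variance vector).
    """
    k = len(means)
    if weights is None:
        weights = [1.0 / k] * k
    terms = [
        math.log(w) + log_normal_diag(z, m, v)
        for w, m, v in zip(weights, means, variances)
    ]
    return logsumexp(terms)
```

Every evaluation loops over all $K$ components, so the KL term of the ELBO scales linearly in $K$ for each latent sample, which is the first drawback discussed above.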
3.2 MaxEntropy Variational Inference
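Assume that we want to approximate a complex distribution $p^*$ by a mixture of simple components $p_i$:
$$p^* \approx \sum_{i} \alpha_i p_i, \qquad \sum_{i} \alpha_i = 1, \quad \alpha_i \geq 0.$$
We learn each component of the mixture greedily. At the first stage we initialize the mixture with a standard normal distribution, $p_0 = \mathcal{N}(0, I)$.
Afterwards, we add new components $h$ from some family of distributions $\mathcal{Q}$, e.g. Gaussians, with weight $\alpha$, one by one: $p_t = (1 - \alpha)\, p_{t-1} + \alpha h$. We do it in two stages:
Find the optimal $h \in \mathcal{Q}$
Choose $\alpha$, corresponding to the $h$ from the previous step
In order to choose the optimal $h$, we suggest applying the Maximum Entropy approach: the entropy of $h$ is maximized subject to a linear constraint.
It can be shown that such a problem can be reformulated as a KL-minimization problem (3). See Appendix A for the detailed derivation.
$$h^* = \arg\min_{h \in \mathcal{Q}} \mathrm{KL}\left(h \,\Big\|\, \frac{p^*}{p_{t-1}}\right). \qquad (3)$$
Assume that we have found $h$ from (3); now we should determine its weight $\alpha$ in the mixture. Since we want this mixture to be as close as possible to the target density, we optimize the KL-divergence between the two:
$$\alpha^* = \arg\min_{\alpha \in [0, 1]} \mathrm{KL}\left((1 - \alpha)\, p_{t-1} + \alpha h \,\|\, p^*\right).$$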
Figure 1 illustrates how boosting can approximate a complex mixture of Gaussians using only two components.
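As a toy illustration of the two-stage procedure, the sketch below runs greedy boosting on a one-dimensional grid. It assumes that the component is chosen by minimizing $\mathrm{KL}(h \,\|\, p^*/p_{t-1})$ over a small grid of Gaussian candidates and that the weight is chosen by a grid search on $\mathrm{KL}(p_t \,\|\, p^*)$. The target, the candidate grids and all function names are hypothetical, and the exhaustive search stands in for the gradient-based optimization one would use in practice.

```python
import math

GRID = [-8.0 + 0.05 * i for i in range(321)]  # integration grid on [-8, 8]
DZ = 0.05


def log_gauss(z, mean, std):
    """Log-density of N(z | mean, std^2)."""
    return -0.5 * math.log(2.0 * math.pi * std * std) - 0.5 * ((z - mean) / std) ** 2


def log_mixture(z, comps):
    """Log-density of a mixture given as a list of (weight, mean, std)."""
    terms = [math.log(w) + log_gauss(z, m, s) for w, m, s in comps]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def kl_to_target(comps, log_target):
    """Grid estimate of KL(mixture || target)."""
    total = 0.0
    for z, lt in zip(GRID, log_target):
        lp = log_mixture(z, comps)
        total += math.exp(lp) * (lp - lt) * DZ
    return total


def component_loss(mean, std, log_target, log_current):
    """Grid estimate of KL(h || p*/p_{t-1}) for h = N(mean, std^2)."""
    total = 0.0
    for z, lt, lc in zip(GRID, log_target, log_current):
        lh = log_gauss(z, mean, std)
        total += math.exp(lh) * (lh - (lt - lc)) * DZ
    return total


def boost(target, n_new):
    """Greedily add n_new Gaussian components to p_0 = N(0, 1)."""
    log_target = [log_mixture(z, target) for z in GRID]
    candidates = [(-4.0 + 0.25 * i, s) for i in range(33) for s in (0.3, 0.5, 0.7, 1.0)]
    alphas = [0.01 * i for i in range(1, 100)]
    current = [(1.0, 0.0, 1.0)]
    history = [kl_to_target(current, log_target)]
    for _ in range(n_new):
        log_current = [log_mixture(z, current) for z in GRID]
        # Stage 1: choose the new component h.
        mean, std = min(
            candidates,
            key=lambda c: component_loss(c[0], c[1], log_target, log_current),
        )

        # Stage 2: choose the weight alpha of h in the updated mixture.
        def mixed(alpha):
            old = [(w * (1.0 - alpha), m, s) for w, m, s in current]
            return old + [(alpha, mean, std)]

        alpha = min(alphas, key=lambda a: kl_to_target(mixed(a), log_target))
        current = mixed(alpha)
        history.append(kl_to_target(current, log_target))
    return current, history


TARGET = [(0.5, -2.0, 0.5), (0.5, 2.0, 0.5)]  # toy bimodal target
mixture, kl_history = boost(TARGET, n_new=2)
```

Here `kl_history` records $\mathrm{KL}(p_t \,\|\, p^*)$ after each added component, so it can be inspected to see how quickly the greedy mixture approaches the target.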
In this paper we suggest combining a boosting algorithm for approximating densities with the idea of the optimal (in terms of the empirical Bayes approach) prior for a Variational Autoencoder to deal with the problem of catastrophic forgetting in the incremental learning setting. In this section, we first describe how BooVAE can be applied when the whole dataset is available for training. Then we switch to the incremental learning setting.
The proposed algorithm for training a Variational Autoencoder consists of two steps. In the first one, we optimize the evidence lower bound (1) with respect to the parameters of the encoder and decoder. In the second one, we learn a new component $h$ of the prior distribution and its weight $\alpha$, keeping the parameters of the encoder and decoder fixed. The target prior distribution is chosen to be a mixture of variational posteriors, following Tomczak and Welling (2018):
$$p^*(z) = \frac{1}{N}\sum_{n=1}^{N} q_\phi(z|x_n).$$
We learn each component to be the posterior given a learnable artificial input $u$: $h(z) = q_\phi(z|u)$. The parameters of the first component are obtained by ELBO maximization simultaneously with the network parameters, as shown in Algorithm 1. Starting from the second component, we apply MaxEntropy Variational Inference for component optimization, alternated with ELBO maximization steps, until the desired number of components in the prior is reached. From that point on, only the model parameters are updated until convergence.
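The alternation described above can be sketched as a simple schedule. This is only an illustration under stated assumptions: the function name, the `epochs_per_component` parameter and the rule of adding one component after a fixed number of ELBO epochs are hypothetical, since the text does not specify when exactly a component is added.

```python
def training_schedule(total_epochs, max_components, epochs_per_component=1):
    """Return the sequence of actions of an alternating training scheme.

    "elbo" stands for an epoch of ELBO maximization over the encoder and
    decoder; "add_component" stands for learning a new prior component and
    its weight with the networks fixed.
    """
    # The first component is trained jointly with the network parameters.
    n_components = 1
    plan = []
    for epoch in range(total_epochs):
        plan.append("elbo")
        due = (epoch + 1) % epochs_per_component == 0
        if n_components < max_components and due:
            plan.append("add_component")
            n_components += 1
    return plan
```

For example, `training_schedule(5, 3)` adds two components during the first two epochs and then performs only ELBO steps, mirroring the description that only the model parameters are updated once the desired number of components is reached.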
4.2 Incremental learning
In the incremental learning setting we do not have access to the whole dataset. Instead, subsets arrive sequentially and may come from different domains. For the first task we follow Algorithm 1 to obtain the prior and the optimal values of the network parameters. Let us now consider the training procedure for a subsequent task $t$. Each epoch consists of two steps: maximizing the ELBO with respect to the encoder and decoder parameters, and adding a new component. We now consider both of them in more detail.
4.2.1 ELBO maximization step
When we start training a new task, we already have a prior distribution in the form of a mixture of Gaussian components with fixed weights:
$$p^{(t-1)}(z) = \sum_{i} \alpha_i \, \mathcal{N}\left(z \mid \mu_i, \operatorname{diag}(\sigma_i^2)\right).$$
That is, we store the prior components from the previous tasks in the form of Gaussian mean and variance vectors $(\mu_i, \sigma_i^2)$, computed using the optimal encoder parameters for the given task (after convergence).
To make sure that the model keeps encoding and decoding each prior component as it did during training of the corresponding task, we add a regularization term to the ELBO. For that purpose, at the end of each task we compute the reconstructions of the components' mean values $\mu_i$.
All things considered, the objective for the maximization step is the ELBO augmented with this regularization term.
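The exact form of the regularization term is not recoverable from the text, so the following is only a hypothetical sketch of what such a penalty could look like. It assumes that each stored component keeps a Gaussian mean and variance vector together with the reconstruction of its mean, and it penalizes drift of the current encoder and decoder outputs with a closed-form Gaussian KL term and a squared error. The dictionary keys, function names and the `strength` parameter are all assumptions.

```python
import math


def kl_diag_gauss(mean_a, var_a, mean_b, var_b):
    """Closed-form KL(N(mean_a, diag(var_a)) || N(mean_b, diag(var_b)))."""
    return 0.5 * sum(
        math.log(vb / va) + (va + (ma - mb) ** 2) / vb - 1.0
        for ma, va, mb, vb in zip(mean_a, var_a, mean_b, var_b)
    )


def squared_error(x, y):
    """Sum of squared differences between two equally long vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def component_regularizer(stored, current, strength=1.0):
    """Hypothetical penalty that keeps old prior components stable.

    stored / current: lists of dicts with keys "mean" and "var" (encoder
    output for a component) and "recon" (decoder output at the stored mean).
    `stored` is computed once at the end of a task, `current` is recomputed
    with the networks being trained on the new task.
    """
    total = 0.0
    for old, new in zip(stored, current):
        total += kl_diag_gauss(old["mean"], old["var"], new["mean"], new["var"])
        total += squared_error(old["recon"], new["recon"])
    return strength * total
```

Such a penalty would be subtracted from the ELBO during the maximization step; it is zero when the networks still encode and decode every stored component exactly as before.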
4.2.2 Prior update step
When training a new component and its weight, we want the optimal prior that we approximate to be close to the mixture of variational posteriors at all the training points seen by the model. Since we do not have access to the data from the previous tasks, we suggest using the trained prior as a proxy for the corresponding part of the mixture. Therefore, the optimal prior for the tasks seen so far can be expressed as a mixture of the optimal prior from the previous step and the aggregated posterior over the current training dataset.
During training, we approximate the optimal prior from the previous step by the trained mixture and evaluate the aggregated posterior on a random subset of the current dataset, so that the optimal prior is approximated by a mixture of the stored prior components and the variational posteriors of the sampled observations.
This allows the algorithm not to forget information from the previous tasks, which is stored in the prior distribution. We can add new tasks in the same manner as long as we have them, increasing the number of components in the prior while keeping one encoder-decoder pair for all the tasks.
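That being said, the prior update step is the following:
Train a new component by solving the KL-minimization problem (3) against the approximated optimal prior
Train the component weight by minimizing the KL-divergence between the updated mixture and the approximated optimal prior
Update the prior, using the above component and weight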
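The target used in the prior update can be sketched as a two-part mixture: the stored prior from earlier tasks plus the aggregated posterior over a minibatch of the current task. The sketch below assumes diagonal Gaussian components; the mixing weight `old_weight` is a hypothetical parameter (for instance the fraction of previously seen tasks or observations), because the exact weighting is not recoverable from the text.

```python
import math


def log_normal_diag(z, mean, var):
    """Log-density of a diagonal Gaussian N(z | mean, diag(var))."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (zi - m) ** 2 / v)
        for zi, m, v in zip(z, mean, var)
    )


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def log_incremental_target(z, old_components, batch_posteriors, old_weight):
    """Log of old_weight * p_old(z) + (1 - old_weight) * mean_n q(z | x_n).

    old_components: list of (alpha_i, mean_i, var_i) stored from earlier tasks.
    batch_posteriors: list of (mean_n, var_n), encoder outputs on a minibatch
        of the current task.
    old_weight: mixing weight of the stored prior, strictly between 0 and 1
        (hypothetical parameter, see the lead-in).
    """
    old_terms = [
        math.log(a) + log_normal_diag(z, m, v) for a, m, v in old_components
    ]
    new_terms = [log_normal_diag(z, m, v) for m, v in batch_posteriors]
    log_old = logsumexp(old_terms)
    log_new = logsumexp(new_terms) - math.log(len(new_terms))
    return logsumexp(
        [math.log(old_weight) + log_old, math.log(1.0 - old_weight) + log_new]
    )
```

Because the stored prior enters the target directly, no observations from earlier tasks are needed to evaluate it.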
We perform experiments on the MNIST dataset (LeCun et al., 2010), containing ten classes of hand-written digits, and on Fashion MNIST (Xiao et al., 2017), which also has ten classes of different clothing items. Therefore, we have ten tasks in each dataset for sequential learning.
To evaluate the performance of the VAE approach, we estimate the negative log-likelihood (NLL) on the test set. The NLL is calculated by importance sampling, using 5000 samples for each test observation.
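A minimal sketch of the importance-sampling estimate is given below, assuming the usual estimator $\log p(x) \approx \log \frac{1}{S}\sum_{s} p(x, z_s) / q(z_s|x)$ with $z_s \sim q(z|x)$. The callables `sample_q`, `log_q` and `log_joint` are hypothetical stand-ins for the encoder, its density and the joint model density; the one-dimensional linear Gaussian model at the end is a toy example, not the model used in the experiments.

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of a univariate Gaussian N(x | mean, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def importance_nll(x, sample_q, log_q, log_joint, n_samples=5000, rng=random):
    """Estimate -log p(x) with n_samples draws from the proposal q(z | x)."""
    log_weights = []
    for _ in range(n_samples):
        z = sample_q(x, rng)
        log_weights.append(log_joint(x, z) - log_q(z, x))
    return -(logsumexp(log_weights) - math.log(n_samples))


# Toy model: z ~ N(0, 1), x | z ~ N(z, 1), so p(x) = N(0, 2) and the exact
# posterior is N(x / 2, 1 / 2). With the exact posterior as proposal every
# importance weight equals p(x), so the estimate is exact.
def toy_nll(x, n_samples=200, seed=0):
    rng = random.Random(seed)
    return importance_nll(
        x,
        sample_q=lambda x, rng: rng.gauss(0.5 * x, math.sqrt(0.5)),
        log_q=lambda z, x: log_normal(z, 0.5 * x, 0.5),
        log_joint=lambda x, z: log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0),
        n_samples=n_samples,
        rng=rng,
    )
```

With an imperfect proposal, as is the case for a trained encoder, the estimate has higher variance, which is why a large number of samples per observation is used.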
We also want our model to generate images that are diverse in terms of classes. To measure this property, we calculate the KL-divergence between the uniform distribution and the distribution of classes generated by the model.
Since we want to assign classes to generated images automatically, we train a classification network that classifies images with high accuracy (more than 90%), use it to label generated objects and calculate the empirical class distribution over 10000 generated samples.
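The diversity metric can be sketched as follows. The direction of the divergence is an assumption: the sketch computes $\mathrm{KL}(q \,\|\, u)$ between the empirical class distribution $q$ of the generated samples and the uniform distribution $u$, which stays finite even when some class is never generated. The function name and the integer label encoding are illustrative.

```python
import math
from collections import Counter


def kl_to_uniform(labels, n_classes=10):
    """KL(q || u) for the empirical class distribution q of `labels`.

    labels: predicted class indices in range(n_classes), one per generated
    sample. Classes with zero count contribute nothing (0 * log 0 = 0).
    """
    counts = Counter(labels)
    n = len(labels)
    kl = 0.0
    for c in range(n_classes):
        q = counts.get(c, 0) / n
        if q > 0.0:
            kl += q * math.log(q * n_classes)
    return kl
```

The value is 0 for perfectly balanced generations and reaches $\log 10$ when all generated samples fall into a single class.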
5.2 Offline setting
In the offline setting we compare our method to the VampPrior and to a Mixture of Gaussians prior. In both cases the goal was to achieve similar performance using fewer components. Figure 2 compares the NLL values for different numbers of components in the prior. The graph illustrates that BooVAE not only shows performance comparable to other data-driven priors, but also requires fewer components.
5.3 Incremental setting
Tables 1 and 2 provide NLL results for the incremental setting. The first column states how many tasks the VAE has seen in total (in an incremental manner). We compare our approach with a VAE with a standard normal prior, which is the default setting for this model. Elastic Weight Consolidation (EWC), a method proposed by Kirkpatrick et al. (2017) to deal with catastrophic forgetting, was also evaluated for comparison. The idea is to add a regularization term to the loss function, which forces the weights to stay close to the optimal solution of the previous tasks. Note that this method can be easily combined with different types of prior. Finally, we train an incremental MoG, where we add new components to the mixture for every task, just as we do in boosting, and apply the same regularization for a fair comparison.
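For reference, the EWC regularizer mentioned above is commonly written as a quadratic penalty weighted by a diagonal Fisher information estimate. The sketch below follows that standard form; the flat-list parameter representation, the function name and the `strength` argument are illustrative assumptions and do not reflect the configuration used in the experiments.

```python
def ewc_penalty(params, anchor_params, fisher_diag, strength):
    """Quadratic EWC penalty: strength / 2 * sum_i F_i * (theta_i - theta*_i)^2.

    params: current weights (flat list of floats).
    anchor_params: weights at the optimum of the previous task.
    fisher_diag: diagonal Fisher information estimate, one value per weight.
    strength: scalar trade-off between the new task and the old optimum.
    """
    return 0.5 * strength * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher_diag)
    )
```

The penalty is added to the loss of the new task, so weights with a large Fisher value are kept close to their previous optimum while unimportant weights remain free to change.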
Table 1: NLL results for the incremental setting.
| # tasks | Standard | Standard + EWC | MoG | Boo (ours) | Boo (ours) + EWC |
|---|---|---|---|---|---|
| 2 | 343.54 (26.38) | 256.55 (8.38) | 96.50 (1.95) | 100.11 (1.39) | 97.49 (0.40) |
| 3 | 122.05 (2.31) | 121.91 (1.31) | 107.78 (3.59) | 104.33 (1.34) | 102.90 (0.95) |
| 4 | 146.06 (0.32) | 142.00 (2.28) | 123.95 (5.05) | 118.78 (1.45) | 117.07 (0.46) |
| 5 | 197.02 (5.68) | 192.84 (0.12) | 143.44 (8.05) | 132.08 (0.64) | 130.80 (0.87) |
| 6 | 164.29 (3.78) | 159.80 (3.14) | 143.33 (2.49) | 135.42 (1.64) | 131.83 (1.24) |
| 7 | 205.21 (5.58) | 187.43 (5.20) | 163.14 (9.02) | 142.21 (1.85) | 137.38 (1.57) |
| 8 | 213.25 (9.22) | 189.06 (4.72) | 172.00 (12.93) | 140.80 (2.42) | 138.47 (2.50) |
| 9 | 171.04 (3.64) | 160.47 (2.53) | 164.18 (9.49) | 141.70 (0.97) | 140.13 (2.67) |
| 10 | 186.79 (2.32) | 170.26 (2.20) | 181.53 (29.02) | 142.92 (1.99) | 140.68 (1.86) |
Table 2: NLL results for the incremental setting.
| # tasks | Standard | Standard + EWC | MoG | Boo (ours) | Boo (ours) + EWC |
|---|---|---|---|---|---|
| 2 | 262.22 (2.92) | 271.14 (6.05) | 239.43 (2.76) | 227.83 (3.34) | 229.81 (2.31) |
| 3 | 289.45 (2.72) | 287.45 (3.87) | 266.18 (1.87) | 255.85 (1.61) | 256.47 (2.16) |
| 4 | 274.08 (2.42) | 272.82 (1.02) | 264.35 (3.16) | 248.96 (0.85) | 249.08 (1.40) |
| 5 | 272.87 (1.80) | 270.44 (0.98) | 264.51 (1.93) | 253.12 (1.43) | 253.26 (1.20) |
| 6 | 487.05 (43.78) | 417.81 (7.44) | 282.00 (4.06) | 250.87 (2.12) | 250.64 (1.18) |
| 7 | 274.72 (2.93) | 272.09 (6.25) | 292.93 (9.16) | 250.87 (0.69) | 253.50 (2.59) |
| 8 | 1827.62 (489.47) | 565.81 (22.94) | 448.55 (103.92) | 260.05 (5.25) | 250.30 (0.48) |
| 9 | 321.49 (17.36) | 289.17 (2.43) | 321.72 (14.11) | 256.42 (1.00) | 256.33 (0.78) |
| 10 | 964.90 (237.27) | 427.83 (21.19) | 440.96 (49.75) | 284.86 (21.21) | 256.58 (1.27) |
Compared to other methods, boosting performs quite well. The instability of other methods can be explained by the fact that some classes in the dataset are quite similar: even though the model forgets an old class, knowledge about a new class lets it reconstruct the old one relatively well. We do not see such oscillations for boosting, which indicates that it was able to properly remember previous tasks.
There are also works where multihead architectures are used to improve performance in the incremental setting, meaning that task-specific layers are added to the network. In our experiments the architecture was fixed, that is, we only increase the number of components in the prior, so it is not really fair to compare the two. Nevertheless, we provide results for the multihead approach with a standard normal prior and EWC regularization in Tables 3 and 4.
Table 3: NLL results for the multihead comparison in the incremental setting.
| # tasks | Standard | Multihead | Multihead + EWC | Boo (ours) + EWC |
|---|---|---|---|---|
| 2 | 343.54 (26.38) | 399.23 (6.81) | 119.64 (18.54) | 97.49 (0.40) |
| 3 | 122.05 (2.31) | 203.87 (1.55) | 114.22 (6.92) | 102.90 (0.95) |
| 4 | 146.06 (0.32) | 217.63 (3.50) | 115.31 (6.97) | 117.07 (0.46) |
| 5 | 197.02 (5.68) | 281.46 (2.56) | 115.86 (8.04) | 130.80 (0.87) |
| 6 | 164.29 (3.78) | 215.02 (2.31) | 113.51 (5.78) | 131.83 (1.24) |
| 7 | 205.21 (5.58) | 247.58 (4.22) | 113.69 (6.09) | 137.38 (1.57) |
| 8 | 213.25 (9.22) | 300.97 (5.82) | 112.23 (6.63) | 138.47 (2.50) |
| 9 | 171.04 (3.64) | 210.58 (5.03) | 112.31 (5.91) | 140.13 (2.67) |
| 10 | 186.79 (2.32) | 256.55 (6.26) | 111.30 (5.60) | 140.68 (1.86) |
Table 4: NLL results for the multihead comparison in the incremental setting.
| # tasks | Standard | Multihead | Multihead + EWC | Boo (ours) + EWC |
|---|---|---|---|---|
| 2 | 262.22 (2.92) | 393.79 (10.75) | 249.66 (7.76) | 229.81 (2.31) |
| 3 | 289.45 (2.72) | 319.98 (0.38) | 256.63 (1.12) | 256.47 (2.16) |
| 4 | 274.08 (2.42) | 328.72 (2.83) | 245.16 (1.80) | 249.08 (1.40) |
| 5 | 272.87 (1.80) | 333.74 (4.15) | 251.48 (0.93) | 253.26 (1.20) |
| 6 | 487.05 (43.78) | 681.65 (15.13) | 250.57 (1.39) | 250.64 (1.18) |
| 7 | 274.72 (2.93) | 375.84 (3.24) | 251.46 (1.01) | 253.50 (2.59) |
| 8 | 1827.62 (489.47) | 597.63 (43.93) | 244.55 (1.59) | 250.30 (0.48) |
| 9 | 321.49 (17.36) | 490.90 (26.85) | 255.41 (1.73) | 256.33 (0.78) |
| 10 | 964.90 (237.27) | 566.25 (13.94) | 254.43 (2.10) | 256.58 (1.27) |
As mentioned earlier, we want our model to generate diverse images. Figure 3 depicts the KL-divergence between the class distribution of generated images and the uniform distribution, evaluated on 10000 samples. We want this value to stay as close as possible to 0 as the number of tasks grows, since this means that the model keeps generating diverse images. We can see a drastic difference between boosting and other approaches: samples from the prior estimated by the boosting approach are very close to uniform, in contrast to all the compared methods.
In this work we propose a method for learning a data-driven prior using an MM algorithm, which allows us to reduce the number of components in the prior distribution without loss of performance. Based on this method we suggest an efficient algorithm for incremental VAE learning which has a single encoder-decoder pair for all the tasks and drastically reduces catastrophic forgetting. In the experiments we also show that the results of the proposed approach are comparable with methods with multihead architectures, which have many more parameters.
The Authors acknowledge the usage of the Skoltech CDISE HPC cluster Zhores for obtaining the results presented in this paper.
- Burda et al. (2015). Importance weighted autoencoders. arXiv preprint arXiv:1509.00519.
- Cremer et al. (2018). Inference suboptimality in variational autoencoders. In International Conference on Machine Learning, pp. 1086–1094.
- Davidson et al. (2018). Hyperspherical variational auto-encoders. arXiv preprint arXiv:1804.00891.
- Donahue et al. (2016). Adversarial feature learning. arXiv preprint arXiv:1605.09782.
- Egorov et al. (2019). MaxEntropy pursuit variational inference. In International Symposium on Neural Networks, pp. 409–417.
- Gershman et al. (2012). Nonparametric variational inference. In Proceedings of the 29th International Conference on Machine Learning, pp. 235–242.
- Goodfellow et al. (2013). An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211.
- Goodfellow et al. (2014). Generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2672–2680.
- Goyal et al. (2017). Nonparametric variational auto-encoders for hierarchical representation learning. In Proceedings of the IEEE International Conference on Computer Vision, pp. 5094–5102.
- Guo et al. (2016). Boosting variational inference. arXiv preprint arXiv:1611.05559.
- Higgins et al. (2017). Beta-VAE: learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations, Vol. 3.
- Hoffman and Johnson (2016). ELBO surgery: yet another way to carve up the variational evidence lower bound. In Workshop in Advances in Approximate Bayesian Inference, NIPS, Vol. 1.
- Huszár (2017). On quadratic penalties in elastic weight consolidation. arXiv preprint arXiv:1712.03847.
- Kim et al. (2018). Semi-amortized variational autoencoders. In International Conference on Machine Learning, pp. 2683–2692.
- Kingma and Welling (2013). Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114.
- Kirkpatrick et al. (2017). Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences 114 (13), pp. 3521–3526.
- LeCun et al. (2010). MNIST handwritten digit database. AT&T Labs [Online]. Available: http://yann.lecun.com/exdb/mnist 2, pp. 18.
- Li and Hoiem (2018). Learning without forgetting. IEEE Transactions on Pattern Analysis and Machine Intelligence 40 (12), pp. 2935–2947.
- Liu et al. (2018). Rotate your networks: better weight consolidation and less catastrophic forgetting. In 2018 24th International Conference on Pattern Recognition (ICPR), pp. 2262–2268.
- Mathieu et al. (2019). Hierarchical representations with Poincaré variational auto-encoders. arXiv preprint arXiv:1901.06033.
- McCloskey and Cohen (1989). Catastrophic interference in connectionist networks: the sequential learning problem. In Psychology of Learning and Motivation, Vol. 24, pp. 109–165.
- Miller et al. (2017). Variational boosting: iteratively refining posterior approximations. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 2420–2429.
- Nguyen et al. (2017). Variational continual learning. arXiv preprint arXiv:1710.10628.
- Polson et al. (2015). Proximal algorithms in statistics and machine learning. Statistical Science 30 (4), pp. 559–581.
- Rezende and Viola (2018). Taming VAEs. arXiv preprint arXiv:1810.00597.
- Rusu et al. (2016). Progressive neural networks. arXiv preprint arXiv:1606.04671.
- Seff et al. (2017). Continual learning in generative adversarial nets. arXiv preprint arXiv:1705.08395.
- Sønderby et al. (2016). Ladder variational autoencoders. In Advances in Neural Information Processing Systems, pp. 3738–3746.
- Tomczak and Welling (2018). VAE with a VampPrior. In International Conference on Artificial Intelligence and Statistics, pp. 1214–1223.
- Wang et al. (2015). Functional Frank-Wolfe boosting for general loss functions. arXiv preprint arXiv:1510.02558.
- Wu et al. (2018). Incremental classifier learning with generative adversarial networks. arXiv preprint arXiv:1802.00853.
- Xiao et al. (2017). Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.
A. MaxEntropy VI: objective derivation
Expanding the second term of the objective, we get:
Using a first-order Taylor expansion, we can simplify some terms:
The constraint takes the following form:
Omitting the terms which do not depend on the new component and considering only first-order terms results in the optimization problem:
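One possible way to fill in the first-order argument is sketched below. It is a reconstruction under stated assumptions rather than the authors' derivation: $p_{t-1}$ denotes the current mixture, $h$ the new component, $\alpha$ its weight, $p^*$ the target, and the Lagrange multiplier of the linearized term is set to one.

```latex
\begin{align*}
p_t &= (1-\alpha)\,p_{t-1} + \alpha h = p_{t-1} + \alpha\,(h - p_{t-1}),\\
\mathrm{KL}(p_t \,\|\, p^*) &= \mathrm{KL}(p_{t-1} \,\|\, p^*)
  + \alpha \int \bigl(h(z) - p_{t-1}(z)\bigr)\,
    \log\frac{p_{t-1}(z)}{p^*(z)}\,dz + O(\alpha^2),\\
\max_{h}\; \Bigl[\, H(h) - \int h(z)\,\log\frac{p_{t-1}(z)}{p^*(z)}\,dz \Bigr]
 &= \max_{h}\; \Bigl[ -\int h(z)\,\log\frac{h(z)\,p_{t-1}(z)}{p^*(z)}\,dz \Bigr]
 \;\Longleftrightarrow\;
 \min_{h}\; \mathrm{KL}\!\left(h \,\Big\|\, \frac{p^*}{p_{t-1}}\right).
\end{align*}
```

In the second line the constant arising from differentiating $p \log p$ vanishes because $h$ and $p_{t-1}$ both integrate to one. Only the term $\int h \log (p_{t-1}/p^*)\,dz$ depends on $h$, and combining it with the entropy $H(h) = -\int h \log h \, dz$ gives a KL-divergence to the unnormalized ratio $p^*/p_{t-1}$.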