Continual Learning in Deep Neural Network by Using a Kalman Optimiser
Abstract
Learning and adapting to new distributions, or learning new tasks sequentially without forgetting previously learned knowledge, is a challenging problem for continual learning models. Most conventional deep learning models cannot learn new tasks sequentially in a single model without forgetting the previously learned ones. We address this issue by using a Kalman Optimiser. The Kalman Optimiser divides the neural network into two parts: long-term and short-term memory units. The long-term memory units are used to remember the learned tasks, and the short-term memory units adapt to the new task. We have evaluated our method on the MNIST, CIFAR10 and CIFAR100 datasets and compared our results with state-of-the-art baseline models. The results show that our approach enables the model to continually learn and adapt to new changes without forgetting the previously learned tasks.
1 Introduction
Conventional deep learning models have achieved significant success in a variety of fields, including computer vision and speech recognition. However, most of the dominant models have to be trained on all the expected tasks or data variations at the same time. Otherwise, they tend to forget learned knowledge when they switch between different tasks and datasets in a periodically changing environment. The problem of a model forgetting how to perform previously learned tasks is often referred to as catastrophic forgetting (McCloskey & Cohen, 1989; Goodfellow et al., 2013). It typically occurs when a model adjusts its parameters to accommodate new tasks and the new parameter values no longer give accurate results on the previously learned tasks when those tasks occur again. The parameters may change significantly when the new training task is very different from the previously learned ones. For example, in Figure 1, training a neural network aims to find an ideal solution (the red circle in Figure 1) for the training data. When the model continues learning on a new set of data, the network finds another ideal solution (the green circle in Figure 1), which can be very different from the previous one. Consequently, the performance of the network on the previously learned task(s) decreases; in other words, the network forgets how to perform on the old dataset.
This learning process is very different from biological learning, which can acquire new knowledge sequentially. In real-world scenarios, we cannot ensure that the training data is fully representative, and it may not cover all the tasks in advance. Possible solutions include training a new model for each new task, re-training the model when previously learned tasks reappear or are required again, and storing all the training data and frequently re-training the model on the whole dataset. Storing all the training data requires substantial resources, and repeatedly re-training the model for reappearing tasks and goals is inefficient and computationally costly. Our goal is to develop a continual learning model that can learn new tasks without forgetting the previously learned ones.
Ensemble learning is one of the solutions proposed in existing work (Woźniak et al., 2014; Polikar et al., 2001; Dai et al., 2007). The fundamental idea of ensemble learning is to build a separate network for each task or learning target, so the number of networks grows with the number of tasks to be learned. This is not always desirable because of the high memory requirements and complexity (Kemker et al., 2017). Fernando et al. (2017) propose PathNet, which is based on ensemble learning but has lower complexity; in PathNet, the previously learned networks can contribute to training the model on new tasks.
Different from ensemble learning, the dual-memory-based learning approach, which interleaves new training data with previously learned samples, has lower complexity. Lopez-Paz et al. (2017) propose Gradient Episodic Memory (GEM) to deal with the forgetting issue. GEM stores a subset of the samples from previous tasks and uses them while training on new tasks: the losses on these stored samples are not allowed to increase. Other developments in this direction reduce the memory required for old knowledge by means of a pseudo-rehearsal technique; e.g., Robins (1995) constructs probabilistic models that generate training samples based on what has been seen in the past, which reduces the memory needed to store and reload large volumes of old samples. Shin et al. (2017) propose an architecture consisting of deep generative models that generate training samples and task solvers that perform the classification task. Similarly, Kamra et al. (2017) use a variational autoencoder to generate the training samples. Hinton & Plaut (1987) propose a dual-memory system in which each synaptic connection has two weights: a plastic weight with a slow change rate for long-term knowledge and a fast-changing weight for short-term knowledge.
Another family of solutions, regularisation-based methods, does not need to build multiple models or store the training samples. Kirkpatrick et al. (2017) propose Elastic Weight Consolidation (EWC) to overcome the forgetting problem. EWC penalises changes to the parameters that are important for previously learned tasks, pulling them back towards their old values while the model learns a new task. Huszár (2017) extends EWC so that the penalty is updated recursively as new tasks are learned. Similarly, Zenke et al. (2017) address the forgetting problem by avoiding changes to the influential parameters that contribute to the previously learned tasks.
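For reference, the quadratic penalty that EWC adds to the loss of a new task can be sketched as below. This is a minimal illustration assuming a diagonal Fisher Information and a hypothetical penalty strength `lam`; it is not the cited authors' code. The gradient of the penalty shows how important parameters are pulled back while unimportant ones (Fisher value zero) are left free.

```python
def ewc_penalty(params, old_params, fisher, lam=1.0):
    """EWC-style penalty: (lam / 2) * sum_i F_i * (theta_i - theta_old_i)^2.

    params:     current parameter values
    old_params: parameter values after the previous task
    fisher:     diagonal Fisher Information for the previous task
    lam:        hypothetical penalty strength
    """
    return 0.5 * lam * sum(f * (p - q) ** 2
                           for p, q, f in zip(params, old_params, fisher))


def ewc_penalty_grad(params, old_params, fisher, lam=1.0):
    """Gradient of the penalty: lam * F_i * (theta_i - theta_old_i)."""
    return [lam * f * (p - q) for p, q, f in zip(params, old_params, fisher)]


old = [1.0, -2.0, 0.5]
fisher = [10.0, 0.0, 1.0]   # first parameter important, second not at all
new = [1.5, 3.0, 0.5]
print(ewc_penalty(new, old, fisher))       # only the important parameter contributes
print(ewc_penalty_grad(new, old, fisher))  # the second parameter is not pulled back
```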
Although various works address the forgetting problem in continual learning scenarios, few of them propose an optimiser or leverage the uncertainty information in neural networks to resolve catastrophic forgetting. In this paper, we propose a Kalman-Filter-based optimiser. Some existing works combine the Kalman Filter with neural networks. For instance, Ollivier (2017) shows that using an extended Kalman Filter to estimate the parameters is equivalent to online natural gradient descent. Trebatický (2005) uses an extended Kalman Filter to train a recurrent neural network. Haarnoja et al. (2016) combine the Kalman Filter with a feed-forward convolutional neural network. Wu & Wang (2012) propose Extended and Unscented Kalman Filter algorithms for training feed-forward neural networks. Different from previous work, our key idea is to find an optimal solution for all the tasks by letting different parts of the parameters adapt to different tasks (see Figure 1). More specifically, we identify the parameters that are important for the previously learned tasks and group them into long-term memory units; the other parameters are grouped into short-term memory units. Unlike classical Long Short-Term Memory networks (LSTM) and Hinton & Plaut (1987), we only use a single neural network layer, and the parameters involved in each unit are dynamically decided at the end of each training process. We use a Kalman Filter to restrict changes to the parameters in the long-term memory units, while the short-term memory units adapt to the new tasks.
2 Methodology
2.1 Gradient as an Uncertainty Measure
We treat the weights as uncertain rather than deterministic values. Unlike Bayesian neural networks (Blundell et al., 2015), our goal is to track the changes in the values and uncertainties of the parameters and then adjust them. To determine the uncertainty, we use the gradient as a measure, since it reflects how uncertain the model is about the current parameters given the current data. From the gradient descent point of view, the further a parameter is from its optimal value, the larger its gradient tends to be; in other words, the model is highly uncertain about this parameter given the training data. For example, in Figure 1, the parameters take larger steps at the beginning of training since they are far from the optimal value, and the steps become smaller as training continues. When the model is later trained on a new task, this process is repeated.
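The behaviour described above can be illustrated on a toy problem. This is a minimal sketch assuming a single parameter with the quadratic loss (w - target)^2 and using the absolute gradient as the uncertainty measure; the paper does not state whether the raw, absolute or squared gradient is used, and the function names are hypothetical.

```python
def loss_grad(w, target):
    # Gradient of the toy loss (w - target)**2 with respect to w.
    return 2.0 * (w - target)


def train(w, target, lr=0.1, steps=20):
    """Plain gradient descent on one parameter.

    Returns the final weight and the history of |gradient|, used here as a
    stand-in for the parameter's uncertainty on the current task.
    """
    uncertainty = []
    for _ in range(steps):
        g = loss_grad(w, target)
        uncertainty.append(abs(g))
        w -= lr * g
    return w, uncertainty


w, u_task1 = train(0.0, target=3.0)   # first task: optimum at w = 3
w, u_task2 = train(w, target=-2.0)    # second task: optimum at w = -2
print(u_task1[0], u_task1[-1])        # large at the start, small at the end
print(u_task2[0])                     # large again once the task changes
```

The uncertainty shrinks as the parameter approaches the first optimum and jumps back up when the second task begins, which is the signal the Kalman Optimiser relies on.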
2.2 Kalman Optimiser
Based on the values of the weights and the uncertainty measure, we use a Kalman Filter (Rhudy et al., 2017; Enshaeifar et al., 2016) to restrict the changes of the weights so that their new values have lower uncertainty. At the end of the first training phase, we take the optimal solution of the first task and its uncertainty as our prior knowledge. During training on the second task, we track the changes in the weights and their uncertainty, and use the Kalman Filter to adjust the weights based on the prior knowledge, so that the values predicted by the filter are close to values with lower uncertainty. The filter is initialised with the parameters of the pre-trained model and the uncertainty of the model on the previously trained dataset or task; a very small hyper-parameter is added to avoid a zero denominator. While training on a new dataset or task, at each mini-batch the gradient descent step gives a predicted value for each parameter, and the gradient of the parameter on that batch gives an uncertainty measure. The Kalman gain, the optimal value and the updated uncertainty are then calculated according to the following equations:
[Equations (1a)-(1d): Kalman gain, optimal parameter value and uncertainty updates]
Since the uncertainty measure on the new dataset or task is given by the gradient, it is relatively high compared with the prior uncertainty, especially at the beginning of training. The filter therefore places more weight on the prior knowledge than on the new gradient step, and the predicted values stay close to the previous optimal solution.
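A minimal per-parameter sketch of this kind of update follows. It assumes a textbook scalar Kalman measurement update, not the paper's exact Equations (1a)-(1d): the SGD step supplies the measurement, the absolute mini-batch gradient supplies its uncertainty R, and the gain is K = P / (P + R + eps). The names `kalman_step`, `lr` and `eps` are hypothetical.

```python
def kalman_step(theta, P, grad, lr=0.1, eps=1e-8):
    """One Kalman-style update of a single parameter (illustrative sketch).

    theta: current estimate (starts at the previous task's optimum)
    P:     current uncertainty (starts at the previous task's uncertainty)
    grad:  mini-batch gradient of the new task's loss
    """
    theta_pred = theta - lr * grad            # prediction from a plain SGD step
    R = abs(grad)                             # uncertainty of that prediction
    K = P / (P + R + eps)                     # Kalman gain; eps avoids a zero denominator
    theta_new = theta + K * (theta_pred - theta)
    P_new = (1.0 - K) * P                     # uncertainty shrinks after each update
    return theta_new, P_new


# An important parameter: previous optimum 3.0 with low uncertainty.
theta, P = 3.0, 0.01
sgd = 3.0                                     # same parameter under plain SGD
for _ in range(10):
    theta, P = kalman_step(theta, P, grad=2.0 * (theta + 2.0))  # new optimum at -2
    sgd -= 0.1 * 2.0 * (sgd + 2.0)
print(theta, sgd)                             # theta barely moves; sgd heads to -2

# An unimportant parameter: large prior uncertainty, so K is close to 1 (almost SGD).
theta2, _ = kalman_step(3.0, 1e6, grad=10.0)
print(theta2)
```

A large gradient relative to the prior uncertainty yields a small gain, so the parameter stays near the old optimum; this is why restricting the filter to only some parameters, as described next, is needed for the model to learn the new task.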
However, the model cannot learn the new task well if all the weights stay close to the previous optimal solution. To address this issue, we let the Kalman Optimiser identify which parameters should be constrained to retain the learned knowledge and which parameters are free to adapt to the new task. To achieve this, we find the parameters that are important to the previously learned tasks and group them as long-term memory, and group the remaining parameters as short-term memory. To find the important parameters, we use the Fisher Information matrix, which approximates the second-order derivatives of the loss near a minimum. We assume that the covariance matrix of the posterior distribution for a trained task is diagonal and obtain the Fisher Information of the parameters on the training dataset using Equation (2).
[Equation (2): diagonal Fisher Information of the parameters on the training dataset]
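As a minimal sketch, assuming the usual EWC-style estimate in which each diagonal entry of the Fisher Information is the mean squared gradient of the log-likelihood over the training data (an assumption about Equation (2)), the computation for a toy one-feature logistic-regression model looks as follows. It uses the observed labels, i.e. the empirical Fisher; the model and data are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def diagonal_fisher(params, data):
    """Empirical diagonal Fisher for p(y | x, w, b) = Bernoulli(sigmoid(w * x + b)).

    F_i = mean over samples of (d log p(y | x, theta) / d theta_i)^2
    params: (w, b); data: list of (x, y) pairs with y in {0, 1}
    """
    w, b = params
    fisher = [0.0, 0.0]
    for x, y in data:
        err = y - sigmoid(w * x + b)   # d log p / dz for the Bernoulli likelihood
        fisher[0] += (err * x) ** 2    # squared gradient w.r.t. w
        fisher[1] += err ** 2          # squared gradient w.r.t. b
    return [f / len(data) for f in fisher]


data = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]
print(diagonal_fisher((1.0, 0.0), data))
```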
We further normalise the Fisher Information by using Equation (3) to obtain the rate of importance for the parameters:
[Equation (3): normalised Fisher Information]
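One simple way to turn Fisher values into a rate of importance is to divide by the largest entry so that every value lies in [0, 1]. This max-normalisation is an assumption for illustration; Equation (3) may use a different scheme, and `normalise_fisher` is a hypothetical name.

```python
def normalise_fisher(fisher, eps=1e-12):
    """Scale importances to [0, 1] by dividing by the largest entry.

    Assumed normalisation; eps keeps an all-zero vector from dividing by zero.
    """
    top = max(fisher)
    return [f / (top + eps) for f in fisher]


print(normalise_fisher([0.5, 2.0, 0.0, 1.0]))  # approximately [0.25, 1.0, 0.0, 0.5]
```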
We then use the normalised Fisher Information to divide the units into two categories, i.e. long-term and short-term memory units. A pre-defined threshold on the normalised Fisher Information sets the boundary: the significant parameters selected by this threshold form the long-term memory. The final update procedure is shown in Equation (4). Hence, the Kalman Optimiser can identify which parameters to adjust instead of adjusting all of them.
[Equations (4a)-(4c): final update procedure, using the threshold to separate long-term and short-term memory parameters]
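The thresholded procedure can be sketched as follows, under stated assumptions: parameters whose normalised importance exceeds a hypothetical threshold `tau` are treated as long-term memory and receive a textbook scalar Kalman update, while the remaining short-term memory parameters take an ordinary SGD step and keep their uncertainty unchanged. This is an illustration, not the exact form of Equation (4).

```python
def kalman_optimiser_step(thetas, Ps, grads, importance, tau=0.5, lr=0.1, eps=1e-8):
    """Apply a Kalman-style update only to long-term memory parameters.

    importance: normalised Fisher Information per parameter, in [0, 1]
    tau:        hypothetical threshold; importance > tau means long-term memory
    """
    new_thetas, new_Ps = [], []
    for theta, P, g, imp in zip(thetas, Ps, grads, importance):
        theta_pred = theta - lr * g                    # plain SGD prediction
        if imp > tau:                                  # long-term memory unit
            K = P / (P + abs(g) + eps)                 # gain: prior vs gradient uncertainty
            new_thetas.append(theta + K * (theta_pred - theta))
            new_Ps.append((1.0 - K) * P)
        else:                                          # short-term memory unit
            new_thetas.append(theta_pred)
            new_Ps.append(P)
    return new_thetas, new_Ps


thetas, Ps = [1.0, 1.0], [0.01, 0.01]
grads, importance = [5.0, 5.0], [0.9, 0.1]
print(kalman_optimiser_step(thetas, Ps, grads, importance))
```

With identical gradients, the important parameter barely moves while the unimportant one takes the full SGD step.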
When learning several tasks, Kirkpatrick et al. (2017) use multiple Fisher Information matrices to constrain the parameters. This requires heavy computation and increases complexity, and computing and storing a large number of Fisher Information matrices can quickly become intractable. Our method needs only one Fisher Information matrix: the Kalman Optimiser keeps, for each parameter, only the largest value across the different Fisher Information matrices. We update the normalised Fisher Information recursively, which addresses the complexity and scalability issues of computing and storing many Fisher Information matrices.
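The recursive update can be sketched as an element-wise maximum, so that memory stays at one vector however many tasks are learned. The per-task importance values below are hypothetical, and the element-wise reading of "only the larger value" is an assumption.

```python
def merge_importance(stored, new_task):
    """Element-wise maximum of the stored and the new task's normalised Fisher values."""
    return [max(s, n) for s, n in zip(stored, new_task)]


# Hypothetical normalised Fisher Information for three tasks over four parameters.
task_fishers = [
    [1.0, 0.2, 0.0, 0.1],
    [0.3, 1.0, 0.4, 0.0],
    [0.0, 0.5, 1.0, 0.2],
]
importance = [0.0] * 4
for f in task_fishers:
    importance = merge_importance(importance, f)   # a single vector is kept
print(importance)  # [1.0, 1.0, 1.0, 0.2]
```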
The last step is to update the uncertainty information. During training, the prior knowledge is updated by the Kalman Filter; in addition, we update the prior knowledge at the end of each training phase, in case the Kalman Filter converges to a constant. Mathematically, the uncertainty of the model on the learned tasks can become very small by the end of training, and if it is not updated, the values predicted by the Kalman Optimiser will stay close to a constant. Hence, after training on a task, if a parameter is more important to that task than to the other tasks, we update its uncertainty using the data of that task. The other uncertainties remain unchanged, since the previous datasets are no longer accessible in our experiments. Furthermore, these small uncertainties help preserve the learned knowledge.
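A minimal sketch of this end-of-task refresh, assuming that "more important to this task than to the other tasks" means the task's normalised Fisher value exceeds the maximum stored from earlier tasks (so the function is called before the stored importance is merged). The per-parameter `task_uncertainty` measured on the task's data is supplied by the caller; all names and values are hypothetical.

```python
def refresh_uncertainty(P, stored_importance, task_importance, task_uncertainty):
    """Reset uncertainty only for parameters that the finished task dominates.

    P:                 current (possibly collapsed) uncertainties
    stored_importance: max normalised Fisher over earlier tasks
    task_importance:   normalised Fisher of the task just finished
    task_uncertainty:  uncertainty measured on the finished task's data
    Other entries are kept, since earlier datasets are no longer available.
    """
    return [u if t > s else p
            for p, s, t, u in zip(P, stored_importance, task_importance, task_uncertainty)]


P = [1e-4, 1e-4, 1e-4]          # nearly collapsed after many Kalman updates
stored = [0.9, 0.2, 0.5]
task = [0.1, 0.8, 0.5]
u = [0.05, 0.03, 0.07]
print(refresh_uncertainty(P, stored, task, u))  # only the second entry is refreshed
```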
3 Experiments
We evaluate our method by sequentially learning the disjoint MNIST (LeCun et al., 2010), CIFAR10 and CIFAR100 datasets (Krizhevsky et al., 2009). These benchmarks are commonly used to evaluate methods that address the forgetting problem in continual learning scenarios.
3.1 Disjoint MNIST
The first experiment uses the disjoint MNIST dataset. We split MNIST into 5 subsets of consecutive digits from 0 to 9, i.e. two digits per task. We use a shallow neural network with two hidden layers of 256 units each. Because the label distribution changes between tasks, we use a multi-head approach that only computes the loss for the digits present in the current task (Zenke et al., 2017). The results are shown in Figure 2.
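The task split and the multi-head loss can be sketched as below. This is a minimal illustration assuming one shared 10-way output layer whose logits are masked to the current task's classes before the softmax cross-entropy; the function names are hypothetical and other multi-head implementations (e.g. separate output layers per task) are equally possible.

```python
import math


def disjoint_splits(num_classes=10, classes_per_task=2):
    """Consecutive-class tasks, e.g. [[0, 1], [2, 3], ..., [8, 9]] for MNIST."""
    return [list(range(s, s + classes_per_task))
            for s in range(0, num_classes, classes_per_task)]


def multi_head_loss(logits, label, task_classes):
    """Cross-entropy over the current task's logits only.

    Logits of classes from other tasks are ignored, so they receive no
    gradient from this sample.
    """
    active = [logits[c] for c in task_classes]
    m = max(active)                                   # for numerical stability
    log_z = m + math.log(sum(math.exp(a - m) for a in active))
    return log_z - logits[label]


print(disjoint_splits())
print(multi_head_loss([0.0] * 10, 3, [2, 3]))  # log(2) instead of log(10)
```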
We compare the performance of the model with and without the proposed optimiser. To evaluate the performance, we compute the average accuracy of the model on all the tasks; ideally, the average accuracy should be close to 1.0 after the model has been trained on all the tasks. The first two panels of Figure 2 show that, without the proposed optimiser, performance on the first two tasks drops to the level of random guessing (less than 50%) after training on all the tasks, and performance on the third task also decreases. In contrast, the neural network with the proposed optimiser maintains its performance on all the learned tasks: its accuracy on the first two tasks stays close to 1.0, while its average accuracy on all the tasks keeps increasing.
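A small sketch of the evaluation metric follows, assuming an accuracy matrix where entry [i][j] is the accuracy on task j measured after training on task i, and averaging over the tasks seen so far. Whether the paper averages over seen tasks or over all tasks at every stage is not stated, so this is an assumption, and the numbers are hypothetical, not results from this paper.

```python
def average_accuracy(acc, stage):
    """Mean accuracy on tasks 0..stage, measured after training on task `stage`.

    acc[i][j]: accuracy on task j after training on task i.
    """
    seen = acc[stage][: stage + 1]
    return sum(seen) / len(seen)


# Hypothetical numbers, for illustration only.
acc = [
    [0.99, 0.00, 0.00],
    [0.60, 0.98, 0.00],
    [0.40, 0.55, 0.97],
]
print([round(average_accuracy(acc, k), 3) for k in range(3)])
```

A model that forgets shows this average falling as tasks are added, whereas a model that retains earlier tasks keeps it close to 1.0.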
3.2 Disjoint CIFAR10 and CIFAR100
This experiment consists of two parts. In the first part, we evaluate our method on the disjoint CIFAR10 dataset. We split CIFAR10 into two subsets of consecutive classes and use a convolutional neural network with four convolutional layers and two fully-connected layers, again with the multi-head approach. The results are shown in Figure 3. As shown in Figure 3(a), the performance of the baseline model on the first task decreases dramatically after it learns the subsequent tasks and then revisits the first task, whereas the proposed method still remembers how to perform the first task. The fraction accuracy, i.e. the average accuracy on all previous tasks (Figure 3(c)), of the proposed method keeps increasing, while that of the baseline model decreases steadily.
In the second part, we evaluate our method on CIFAR100. The first task is the original CIFAR10 dataset, and the second and third tasks each consist of ten different classes from CIFAR100; the other settings are the same. The results are shown in Figure 4. The performance of the baseline model on the learned tasks decreases, while the proposed method remembers how to perform them.
4 Conclusions
We present a novel optimisation method to learn and adapt to new tasks without forgetting previously learned tasks. The key idea is to find an optimal solution by letting different parts of the parameters adapt to different tasks. The proposed method uses gradients as an uncertainty measure and groups the learnable parameters into long-term and short-term memory units. These units define which parameters are restricted by the Kalman update procedure (i.e. the long-term memory). This update procedure adds an adjustment and control mechanism that allows the model to learn new tasks without significantly forgetting the previously learned ones. We evaluate our method on the disjoint MNIST and CIFAR10 datasets and by continually learning from CIFAR10 to CIFAR100, and compare the results with a baseline. The results show that the proposed method can preserve previously learned knowledge while efficiently learning and adapting to new tasks.
Acknowledgments
This work is partially supported by the EU H2020 IoTCrawler project under contract number: 779852.
References
- Blundell et al. (2015) Blundell, Charles, Cornebise, Julien, Kavukcuoglu, Koray, and Wierstra, Daan. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
- Dai et al. (2007) Dai, Wenyuan, Yang, Qiang, Xue, Gui-Rong, and Yu, Yong. Boosting for transfer learning. In Proceedings of the 24th International Conference on Machine Learning, ICML ’07, pp. 193–200, New York, NY, USA, 2007. ACM. ISBN 978-1-59593-793-3. doi: 10.1145/1273496.1273521. URL http://doi.acm.org/10.1145/1273496.1273521.
- Enshaeifar et al. (2016) Enshaeifar, Shirin, Spyrou, Loukianos, Sanei, Saeid, and Took, Clive Cheong. A regularised EEG informed Kalman filtering algorithm. Biomedical Signal Processing and Control, 25:196–200, 2016.
- Fernando et al. (2017) Fernando, Chrisantha, Banarse, Dylan, Blundell, Charles, Zwols, Yori, Ha, David, Rusu, Andrei A, Pritzel, Alexander, and Wierstra, Daan. PathNet: Evolution channels gradient descent in super neural networks. arXiv preprint arXiv:1701.08734, 2017.
- Goodfellow et al. (2013) Goodfellow, Ian J, Mirza, Mehdi, Xiao, Da, Courville, Aaron, and Bengio, Yoshua. An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211, 2013.
- Haarnoja et al. (2016) Haarnoja, Tuomas, Ajay, Anurag, Levine, Sergey, and Abbeel, Pieter. Backprop KF: Learning discriminative deterministic state estimators. In Advances in Neural Information Processing Systems, pp. 4376–4384, 2016.
- Hinton & Plaut (1987) Hinton, Geoffrey E and Plaut, David C. Using fast weights to deblur old memories. In Proceedings of the ninth annual conference of the Cognitive Science Society, pp. 177–186, 1987.
- Huszár (2017) Huszár, Ferenc. On quadratic penalties in elastic weight consolidation. arXiv preprint arXiv:1712.03847, 2017.
- Kamra et al. (2017) Kamra, Nitin, Gupta, Umang, and Liu, Yan. Deep generative dual memory network for continual learning. arXiv preprint arXiv:1710.10368, 2017.
- Kemker et al. (2017) Kemker, Ronald, McClure, Marc, Abitino, Angelina, Hayes, Tyler, and Kanan, Christopher. Measuring catastrophic forgetting in neural networks. arXiv preprint arXiv:1708.02072, 2017.
- Kirkpatrick et al. (2017) Kirkpatrick, James, Pascanu, Razvan, Rabinowitz, Neil, Veness, Joel, Desjardins, Guillaume, Rusu, Andrei A, Milan, Kieran, Quan, John, Ramalho, Tiago, Grabska-Barwinska, Agnieszka, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences, pp. 201611835, 2017.
- Krizhevsky et al. (2009) Krizhevsky, Alex, Nair, Vinod, and Hinton, Geoffrey. CIFAR-10 and CIFAR-100 datasets. URL: https://www.cs.toronto.edu/~kriz/cifar.html, 2009.
- LeCun et al. (2010) LeCun, Yann, Cortes, Corinna, and Burges, CJ. MNIST handwritten digit database. AT&T Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, 2010.
- Lopez-Paz et al. (2017) Lopez-Paz, David et al. Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, pp. 6467–6476, 2017.
- McCloskey & Cohen (1989) McCloskey, Michael and Cohen, Neal J. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of learning and motivation, volume 24, pp. 109–165. Elsevier, 1989.
- Ollivier (2017) Ollivier, Yann. Online natural gradient as a Kalman filter. arXiv preprint arXiv:1703.00209, 2017.
- Polikar et al. (2001) Polikar, Robi, Upda, Lalita, Upda, Satish S, and Honavar, Vasant. Learn++: An incremental learning algorithm for supervised neural networks. IEEE transactions on systems, man, and cybernetics, part C (applications and reviews), 31(4):497–508, 2001.
- Rhudy et al. (2017) Rhudy, Matthew B, Salguero, Roger A, and Holappa, Keaton. A Kalman filtering tutorial for undergraduate students. International Journal of Computer Science & Engineering Survey (IJCSES), 8:1–18, 2017.
- Robins (1995) Robins, Anthony. Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science, 7(2):123–146, 1995.
- Shin et al. (2017) Shin, Hanul, Lee, Jung Kwon, Kim, Jaehong, and Kim, Jiwon. Continual learning with deep generative replay. In Advances in Neural Information Processing Systems, pp. 2990–2999, 2017.
- Trebatický (2005) Trebatický, Peter. Recurrent neural network training with the extended Kalman filter. In IIT. SRC 2005: Student Research Conference, pp. 57, 2005.
- Woźniak et al. (2014) Woźniak, Michał, Graña, Manuel, and Corchado, Emilio. A survey of multiple classifier systems as hybrid systems. Information Fusion, 16:3–17, 2014.
- Wu & Wang (2012) Wu, Xuedong and Wang, Yaonan. Extended and unscented Kalman filtering based feedforward neural networks for time series prediction. Applied Mathematical Modelling, 36(3):1123–1131, 2012.
- Zenke et al. (2017) Zenke, Friedemann, Poole, Ben, and Ganguli, Surya. Continual learning through synaptic intelligence. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 3987–3995. JMLR. org, 2017.