Keep and Learn: Continual Learning by Constraining the Latent Space for Knowledge Preservation in Neural Networks


Hyo-Eun Kim (corresponding author: hekim@lunit.io), Seungwook Kim, and Jaehwan Lee; Lunit Inc., Seoul, South Korea
Abstract

Data is one of the most important factors in machine learning. However, even when high-quality data exist, there are situations in which access to the data is restricted. For example, access to medical data from outside an institution is strictly limited due to privacy issues. In this case, we have to learn a model sequentially, using only the data accessible at each stage. In this work, we propose a new method for preserving learned knowledge by modeling the high-level feature space and the output space to be mutually informative, and constraining feature vectors to lie in the modeled space during training. The proposed method is easy to implement, as it can be applied by simply adding a reconstruction loss to the objective function. We evaluate the proposed method on CIFAR-10/100 and a chest X-ray dataset, and show benefits in terms of knowledge preservation compared to previous approaches.

1 Introduction

In a restricted multi-center learning environment where each chunk of data is only available at the corresponding center, we have to learn a model incrementally without the previous data chunks. Consider a scenario in which privacy-sensitive medical data are spread across multiple hospitals, so that a machine learning model has to be learned sequentially. If all data were available concurrently, simply training a state-of-the-art deep learning model such as ResNet for image recognition [5] or GNMT for machine translation [15] would be a good solution. However, if a data chunk from one stage is no longer available in the following learning stages, it is hard to preserve the knowledge learned from that data chunk because of the phenomenon known as catastrophic forgetting [4]. This is especially problematic in neural networks optimized with gradient descent [12].

Overcoming catastrophic forgetting is one of the key research topics in deep learning. One naive approach is to fine-tune (FT) the model with the data accessible at each stage, starting from the up-to-date model parameters [2]. Learning without Forgetting (LwF) is a representative method for overcoming catastrophic forgetting in neural networks [11]. Before training starts in the current stage, output logits (LwF-logits) of the current training examples are calculated, so that each example is paired with its true label and its pre-calculated LwF-logit. The LwF-logits are used as pseudo labels for preserving old knowledge. Elastic Weight Consolidation (EWC) maintains old knowledge by constraining important weights (i.e., model parameters) not to vary too much [8]. The relative importance between weights is defined based on the Fisher information matrix. Deep Generative Replay (GR) [13] uses a generative adversarial network [3]. GR learns a generative model and a task-solving model at the same time, and the learned generator is used for sampling old data during the current learning stage. The concept of GR is interesting, but samples from generative models are not suitable for certain applications such as medical imaging, where pixel-level details carry important radiographic features for diagnosis.

LwF and EWC are representative approaches for preventing catastrophic forgetting in neural networks, based on two distinct philosophies: controlling the output activations (LwF) or the model parameters (EWC). In this work, we preserve knowledge by modeling the feature space directly (we denote the feature space as the space of feature vectors, usually from the layer before the output layer). [11] showed that using LwF-vectors of the second-to-last hidden layer instead of LwF-logits of the output layer had no benefit. Based on the assumption that there exists a better feature space for knowledge preservation, we model the high-level feature space and the output (logit) space to be mutually informative, and constrain the feature space to lie in the modeled space during training. With experimental validation, we show that the proposed method preserves more knowledge than previous approaches.

2 Baseline models

LwF and EWC were originally proposed for preventing catastrophic forgetting in multi-task learning, where each task has its own data and the data used in previous tasks are not available when solving the current task. We call this multi-center multi-task learning. We focus on multi-center single-task learning, where the model is learned with different data chunks of the same task and access to each data chunk is restricted. In this section, we define several baseline models for the multi-center single-task learning environment.

Fine-tuning (FT) trains a model incrementally based on the model parameters learned in the previous stage. Figure 1(a) shows the model architecture for FT. x, z, and y are random variables for the input, latent, and output spaces, respectively. The target loss function L_new (e.g., negative log-likelihood for classification) optimizes the model parameters, which consist of θ_s (shared) and θ_n (new). In the first stage, {θ_s, θ_n} is randomly initialized. In the following stages, it is restored from the model learned in the previous stage.

Learning without Forgetting (LwF) trains a model using both ground-truth labels and pseudo labels (pre-calculated LwF-logits). Figure 1(b) demonstrates the n-th learning stage. The shared layers are parameterized by θ_s, the output layer for the current stage by θ_n, and the output layer kept for the i-th previous stage by θ_o^i. y_n and y_o^i are the model's outputs for the current stage and the i-th stage, for i in {1, ..., n−1}. The loss function is described as,

L = L_new(y_n, t) + λ_o Σ_{i=1}^{n−1} L_old^i(y_o^i, ŷ_o^i),   (1)

where L_new is the loss between the model output y_n and its ground-truth label t, L_old^i is the loss between the model output y_o^i and its LwF-logit ŷ_o^i, and λ_o is a weighting constant. θ_s and θ_n are initialized randomly in the first stage and restored from the previous stage in the following stages. In the n-th stage, θ_o^{n−1} is initialized with θ_n of the (n−1)-th stage and fine-tuned until the final stage. In the third stage, for example, θ_o^1 and θ_o^2 are restored from θ_o^1 and θ_n of the second stage, respectively. For classification tasks, L_new and L_old^i are typically the cross-entropy loss.
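
A minimal PyTorch-style sketch of this objective follows; the module and variable names (shared_net, new_head, old_heads, lwf_logits) are hypothetical, and the distillation term uses soft-target cross-entropy, which is one common choice for L_old.

```python
import torch.nn.functional as F

def soft_cross_entropy(logits, target_logits):
    """Cross-entropy between the model's softmax output and the softmax of
    pre-computed LwF-logits (the pseudo labels)."""
    target = F.softmax(target_logits, dim=1)
    return -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

def lwf_loss(shared_net, new_head, old_heads, x, label, lwf_logits, lambda_o):
    """Eq. (1): task loss on the new head plus distillation losses on the old
    heads, guided by LwF-logits computed before the current stage started.

    old_heads  : list of output layers kept from stages 1..n-1
    lwf_logits : list of pre-computed logit tensors, one per old head
    """
    z = shared_net(x)                               # shared feature extractor
    loss = F.cross_entropy(new_head(z), label)      # L_new with the ground truth
    for head, target_logits in zip(old_heads, lwf_logits):
        loss = loss + lambda_o * soft_cross_entropy(head(z), target_logits)
    return loss
```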

Figure 1: Model architectures: (a) FT/EWC, (b) LwF, and (c) modified LwF (LwF+).

In the multi-center multi-task learning environment, LwF preserves old knowledge by constraining the outputs of the old task-specific layers with the corresponding pseudo labels. However, finding the optimal feature space with respect to all tasks becomes harder as the number of tasks (i.e., output branches) increases.

Modified LwF (LwF+): LwF can be modified for multi-center single-task learning. All the previous task-specific layers are merged into a single knowledge-preserving layer (parameterized by θ_o, with output y_o) as shown in Figure 1(c). So the loss function becomes,

L = L_new(y_n, t) + λ_o L_old(y_o, ŷ_o),   (2)

where L_old is the loss between y_o and its pseudo label (LwF-logit) ŷ_o. θ_s and θ_n are initialized randomly in the first stage and restored from the previous model in the following stages. θ_o is initialized with θ_n from the first stage and fine-tuned until the end of the learning stages.

Elastic Weight Consolidation (EWC) constrains the model parameters by defining the importance of weights. Each parameter has its own weight-decay constant; the more important a parameter is, the larger its weight-decay constant. Based on the model in Figure 1(a), the loss function is,

L = L_new(y, t) + λ_f Σ_i F_i (θ_i − θ̂_i)^2,   (3)

where θ̂_i is the i-th model parameter learned in the previous stage and F_i is the i-th element of the diagonal of the Fisher matrix, weighting the i-th model parameter θ_i. λ_f is a weighting constant. All parameters {θ_s, θ_n} are randomly initialized in the first stage and restored from the previous model in the following stages.
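
The following is a rough PyTorch-style sketch of this idea, not the authors' implementation; the diagonal Fisher terms F_i are estimated here with squared gradients of the log-likelihood on data from the previous stage, and all names are illustrative.

```python
import torch
import torch.nn.functional as F

def estimate_diagonal_fisher(model, data_loader, n_batches):
    """Approximate the diagonal of the Fisher matrix with squared gradients
    of the log-likelihood on data from the previous stage."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for i, (x, y) in enumerate(data_loader):
        if i >= n_batches:
            break
        model.zero_grad()
        log_lik = -F.cross_entropy(model(x), y)   # mean log-likelihood of the batch
        log_lik.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
    return {n: f / n_batches for n, f in fisher.items()}

def ewc_penalty(model, old_params, fisher, lambda_f):
    """Quadratic penalty of Eq. (3): parameters with large F_i are constrained
    to stay close to the values learned in the previous stage."""
    penalty = 0.0
    for n, p in model.named_parameters():
        penalty = penalty + (fisher[n] * (p - old_params[n]) ** 2).sum()
    return lambda_f * penalty
```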

EWCLwF (EWCLwF+) is the combined model of EWC and LwF (LwF+). Since the two methods keep old knowledge in two distinct ways, they can be used complementarily. Based on the model architecture in Figure 1(b) with the loss function in Eq. (1), the weight-constraint term λ_f Σ_i F_i (θ_i − θ̂_i)^2 of Eq. (3) is merged in, so the loss function becomes L = L_new(y_n, t) + λ_o Σ_{i=1}^{n−1} L_old^i(y_o^i, ŷ_o^i) + λ_f Σ_i F_i (θ_i − θ̂_i)^2. EWCLwF+ is similar: based on the model LwF+ in Figure 1(c) with the loss in Eq. (2), the target loss becomes L = L_new(y_n, t) + λ_o L_old(y_o, ŷ_o) + λ_f Σ_i F_i (θ_i − θ̂_i)^2.

All the presented models originate from the two representative methods for knowledge preservation in neural networks. Details of the experimental set-up for the baseline models are explained in Section 4.

3 Proposed Methodology

In a general neural network model as in Figure 1(a), the output y of the input data x is compared with its true label, and the error is propagated backward from top to bottom, which encourages the latent variable z to be task-specific. To keep the previously learned knowledge, the latent space should be informative enough to include the information of the input x.

Figure 2: Proposed model architecture: (a) the first learning stage and (b) the following learning stages.
Figure 3: Top layers of ResNet: based on (a) an fc layer or (b) a 1×1 conv layer. Both are functionally equivalent.

While learning the feature extractor f (from x to z, parameterized by θ_s) and the classifier g (from z to y, parameterized by θ_n), an approximate inverse function h of g (parameterized by θ_r) can be modeled by minimizing the distance between the latent vector z and its reconstruction ẑ = h(g(z)), as in Figure 2(a). Without any constraints, minimizing the reconstruction loss easily makes the latent space trivial in terms of the information that z can represent, such that the entropy H(z) is low. Since z should be informative enough to minimize the task-solving loss, joint learning with both the reconstruction and task-solving losses prevents z from being trivial. It is known that minimizing the conditional entropy H(z|y) can be done by minimizing the reconstruction error of z under the auto-encoder framework [14], and minimizing the task-solving loss keeps H(z) from shrinking too much. As a result, z and y become mutually informative through joint learning with the two losses. (Note that the mutual information between z and y is I(z; y) = H(z) − H(z|y).)
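
A minimal sketch of this first-stage joint objective, assuming hypothetical modules feature_extractor (f), classifier (g), and inverse_fn (h), and taking the distance as mean squared error:

```python
import torch.nn.functional as F

def stage1_loss(feature_extractor, classifier, inverse_fn, x, label, lambda_r):
    """First-stage joint objective: the task loss keeps the latent space
    informative, while the reconstruction loss ties the latent space and the
    output (logit) space together."""
    z = feature_extractor(x)          # latent vector z
    y = classifier(z)                 # output logits y
    z_hat = inverse_fn(y)             # approximate inverse: reconstruct z from y
    task_loss = F.cross_entropy(y, label)
    recon_loss = F.mse_loss(z_hat, z)
    return task_loss + lambda_r * recon_loss
```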

Figure 2 shows the proposed model architecture. In the first stage, f, g, and h (respectively parameterized by θ_s, θ_n, and θ_r; all initialized randomly) are learned by minimizing the task-solving and reconstruction losses concurrently. In the next stage, the parameters θ_o and θ_r of the functions g_o (the knowledge-preserving output layer) and h are restored from the θ_n and θ_r of the first stage and fixed during the rest of the learning stages (they are used to restore the modeled space, so they do not need to be fine-tuned). y_n and y_o are the outputs for solving the task with current data and for preserving previously-learned knowledge, respectively. Based on the loss function for LwF+ in Eq. (2), the target space modeled in the first stage can be kept in the following stages by fixing θ_o of g_o and θ_r of h and guiding the output y_o with LwF-logits. The loss function is shown below,

L = L_new(y_n, t) + λ_o L_old(y_o, ŷ_o) + λ_r L_rec(z, ẑ),   (4)

where L_rec is the distance between the latent vector z and its reconstruction ẑ = h(y_o), and λ_r is a weighting constant for the reconstruction loss. LwF-logits for y_o are calculated in the same manner as in LwF+. θ_s and θ_n in the second stage are initialized with the parameters learned from the first stage and fine-tuned using the data in the corresponding stages until the end of the learning process.

Since we bound the latent space with the space modeled in the first stage and fix g_o and h (with LwF-logits), the feature extractor f tries to pull new data examples into the modeled space, which remembers the previous data examples.
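
Under the same hypothetical names, a sketch of the loss in Eq. (4) used in the later stages; old_head (g_o) and inverse_fn (h) are assumed to be frozen copies from the first stage, and the LwF term again uses soft-target cross-entropy as one possible choice of L_old.

```python
import torch.nn.functional as F

def soft_cross_entropy(logits, target_logits):
    """Cross-entropy against the softmax of pre-computed LwF-logits."""
    target = F.softmax(target_logits, dim=1)
    return -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

def later_stage_loss(feature_extractor, new_head, old_head, inverse_fn,
                     x, label, lwf_logits, lambda_o, lambda_r):
    """Eq. (4): task loss on the new head, LwF-style loss on the frozen old
    head, and a reconstruction loss that keeps z inside the space modeled in
    the first stage (old_head and inverse_fn have requires_grad=False)."""
    z = feature_extractor(x)
    y_new = new_head(z)                 # solves the task with current data
    y_old = old_head(z)                 # preserves previously learned knowledge
    z_hat = inverse_fn(y_old)           # reconstruction through the frozen inverse
    loss = F.cross_entropy(y_new, label)
    loss = loss + lambda_o * soft_cross_entropy(y_old, lwf_logits)
    loss = loss + lambda_r * F.mse_loss(z_hat, z)
    return loss
```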

4 Experiments

Figure 4: Proposed model described in Figure 2, based on the modified ResNet in Figure 3.

We compare the proposed method with the baseline models on several image classification tasks. The base network is ResNet [5], which consists of multiple residual blocks and average-pooling (avgpool) followed by a fully-connected (fc) layer, as shown in Figure 3(a). The 3-D feature map extracted from the top-most residual block is pooled into a 1-D feature vector z via avgpool, and the output vector y is obtained from z through the final fc. Given the feature map Z of an input example, the output is y = f_fc(avgpool(Z)), where f_fc is the fc layer parameterized by θ_fc. The fc layer and avgpool are commutative because avgpool is a linear operation. Based on the modified model in Figure 3(b), the output can be described as y = avgpool(f_conv(Z)), where f_conv is now a 1×1 convolution layer (conv1×1) parameterized by θ_conv. We use the modified ResNet in order to model the approximate inverse function accurately before avgpool. Both are equivalent in terms of their function, but the modified model requires more computation than the original ResNet. The proposed network architecture is shown in Figure 4. θ_n and θ_o are the model parameters of the conv1×1 layers that replace the fc layer of the original ResNet.
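
The equivalence between "avgpool then fc" and "1×1 conv then avgpool" is easy to verify numerically; a small NumPy sketch with illustrative shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W, K = 64, 8, 8, 10                  # channels, spatial size, classes
feat = rng.standard_normal((C, H, W))      # 3-D feature map from the last residual block
weight = rng.standard_normal((K, C))       # fc weight, reused as the 1x1 conv weight
bias = rng.standard_normal(K)

# (a) original ResNet: global average pooling followed by fc
z = feat.mean(axis=(1, 2))                 # 1-D feature vector z
y_fc = weight @ z + bias

# (b) modified ResNet: 1x1 convolution followed by global average pooling
logits_map = np.einsum('kc,chw->khw', weight, feat) + bias[:, None, None]
y_conv = logits_map.mean(axis=(1, 2))

assert np.allclose(y_fc, y_conv)           # both orderings give the same output
```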

Three datasets are used for experimental validation: CIFAR-10/100 [9] and chest X-rays (CXRs) for natural-image and medical-image classification. ResNet-56, ResNet-110, and ResNet-21 are the base models for CIFAR-10, CIFAR-100, and CXRs, respectively. Each network consists of an initial convolution layer, three sets of consecutive residual blocks, and a final layer. In ResNet-21, an additional convolution layer (kernel 3×3, filter width 32, stride 2) with maxpooling (kernel 2×2, stride 2) is added as conv-bn-relu-maxpool (bn: batch normalization [7], relu: rectified linear unit [10]) before the initial convolution to expand the receptive field for large-size CXRs. Table 1 summarizes the layer components. The top layer of ResNet-21 is modified from its original architecture, as explained in Section 4.2. The approximate inverse function h (of the conv1×1 output layer), parameterized by θ_r in Figure 4, consists of multiple consecutive convolutions. h in ResNet-56, ResNet-110, and ResNet-21 (for CIFAR-10, CIFAR-100, and CXRs) includes four, three, and three consecutive 3×3 (stride 1) convolution layers with filter widths (64, 128, 128, 64), (256, 256, 256), and (32, 64, 128), followed by a single bn-relu, respectively.

ResNet-56: 9 residual blocks per set
ResNet-110: 18 residual blocks per set
ResNet-21: 3 residual blocks per set
Table 1: Layer components. Each network has three sets of consecutive residual blocks, and each residual block consists of two consecutive 3×3 conv layers; e.g., one set of ResNet-110 has 18 blocks of two consecutive 3×3 conv layers with filter width 64. Downsampling with stride 2 is performed at the first block of the second and third sets.

For CIFAR-10/100, the initial learning rate of 0.1 is decayed every 40 epochs until the 120-th epoch. For CXRs, the initial learning rate of 0.01 is decayed every 20 epochs until the 80-th epoch. A weight-decay constant of 0.0001 and stochastic gradient descent with momentum 0.9 are used. For CIFAR-10/100, a 32×32 image is randomly cropped from a 40×40 zero-padded image (4 pixels on each side of the original 32×32 image) during training [5]. Each CXR is resized to 500×500, and a randomly cropped 448×448 image is used for training. λ_r for CIFAR-10, CIFAR-100, and CXRs is 0.1, 10.0, and 1.0, respectively, selected from the set {0.1, 1.0, 10.0} by cross-validation. λ_o in Eq. (1) is set according to the number of learning stages including the current one; λ_o and λ_f are 0.1 and 1.0, respectively. All experiments are done with TensorFlow [1].
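
A small NumPy sketch of the CIFAR augmentation step described above (the function name and interface are illustrative):

```python
import numpy as np

def random_crop_with_padding(image, pad=4, crop=32, rng=None):
    """CIFAR-style augmentation: zero-pad `pad` pixels on each side of an
    HxWxC image and take a random crop of size `crop` x `crop`."""
    rng = rng if rng is not None else np.random.default_rng()
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode='constant')
    top = rng.integers(0, 2 * pad + 1)      # offsets 0..2*pad over the 40x40 image
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + crop, left:left + crop, :]
```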

4.1 Cifar-10/100

CIFAR-10/100 have 10/100 classes with 50k training and 10k test images of size 32×32, respectively. In our experiment, 10k training images are used for validation, and the model which performs best on the validation set is selected for evaluation on the test set. The remaining 40k training images are split into four sets (10k per set). Each model is trained continually in the multi-center single-task learning set-up, where each center has 10k training images and the task is 10/100-class classification. Table 2 shows the error rates on the test set with mean (std) of five trials. LwF+ and EWCLwF+ mostly perform better than LwF and EWCLwF, i.e., LwF+ and EWCLwF+ are more appropriate for multi-center single-task learning. The proposed method performs the best, as shown in the table.

CIFAR-10 CIFAR-100
stage-1 stage-2 stage-3 stage-4 stage-1 stage-2 stage-3 stage-4
FT 20.21(.151) 16.76(.419) 15.40(.174) 15.02(.174) 50.13(1.25) 42.79(.692) 40.53(.467) 38.96(.354)
EWC 19.87(.421) 16.70(.178) 15.42(.258) 14.77(.331) 49.93(.937) 42.52(.299) 40.72(.231) 38.94(.504)
LwF 20.28(.532) 16.62(.453) 15.46(.220) 14.68(.304) 50.41(.422) 42.70(.334) 39.50(.417) 37.51(.319)
LwF+ 19.88(.574) 16.57(.194) 15.02(.238) 14.05(.115) 50.69(.760) 42.64(.887) 39.31(.490) 37.30(.558)
EWCLwF 19.79(.122) 16.62(.041) 15.45(.413) 14.49(.183) 50.15(.552) 42.22(.481) 39.62(.338) 37.44(.526)
EWCLwF+ 20.26(.474) 16.99(.410) 15.34(.440) 14.25(.239) 50.10(.439) 42.49(.335) 39.32(.288) 37.21(.377)
Proposed 20.11(.431) 16.12(.253) 14.54(.175) 13.74(.195) 49.87(.461) 42.00(.479) 38.81(.438) 36.42(.373)
Table 2: CIFAR-10/100: test set (10k images) error rates - mean (std) of five trials.

After stage-1, the training data of stage-1 (st-1-trn) is no longer used in the following stages. So, we evaluate the final model on st-1-trn to see how much of st-1-trn has been forgotten after the final stage. For CIFAR-10, 85.75%, 85.97%, 88.64%, 88.22%, and 89.40% of st-1-trn are still classified correctly at stage-4 for FT, EWC, LwF+, EWCLwF+, and Proposed, respectively. For CIFAR-100, 58.67%, 58.85%, 65.57%, 66.91%, and 69.34% of st-1-trn are classified correctly at the final stage (in the same order).
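
These retention numbers correspond to re-evaluating the final model on the stage-1 training set; a minimal sketch of that measurement, with illustrative names and assuming the model returns class scores as a NumPy array:

```python
import numpy as np

def retention_rate(final_model, stage1_images, stage1_labels):
    """Fraction of stage-1 training examples still classified correctly by
    the model obtained after the final stage."""
    logits = final_model(stage1_images)            # (N, num_classes) scores
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions == stage1_labels))
```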

4.2 Chest X-rays for Tuberculosis

We experiment with a real-field medical dataset in order to verify that the proposed method is also valid in a practical set-up. A total of 10,508 de-identified CXRs (from the Korean Institute of Tuberculosis [6]) are used, consisting of 3,556 abnormal (tuberculosis; TB) and 6,952 normal cases. CXRs are commonly used for screening TB, and cases which require a follow-up test are recalled by radiologists. Among the 3,556 abnormal cases, 1,438 cases were diagnosed as active TB (TB-A) at the screening stage. The status of the remaining 2,118 cases, which needed a follow-up sputum test, could not be specified radiologically at the screening stage (TB-U). 80% of the data are randomly selected for training and divided into four sets: 288 (TB-A), 424 (TB-U), and 1,390 (Normal) per set. The remaining 20% are split evenly for validation and test: 143 (TB-A), 211 (TB-U), and 696 (Normal) for each.

We modified the output layer of the model in order to exploit the status information of the abnormal cases. Two output layers are used, for 2-class (TB vs. normal) and 3-class (TB-A, TB-U, and normal) classification, respectively. The 3-class output is used for knowledge preservation; the 2-class output is used only for performance measurement (AUC: area under the ROC curve).

Table 3 summarizes the AUC of each model with mean (std) of five trials. Except for the first stage, the proposed method is always better than the others. The proposed method also performs the best in terms of the ensemble performance of the five trials: 0.9257, 0.9205, 0.9217, 0.9271, 0.9228, 0.9172, and 0.9363 for FT, EWC, LwF, LwF+, EWCLwF, EWCLwF+, and Proposed, respectively. Figure 5 shows the ROC curves on st-1-trn at stage-4 (analogous to the CIFAR-10/100 analysis), which implicitly shows that the proposed method helps preserve old knowledge.

stage-1 stage-2 stage-3 stage-4
FT 0.811(.025) 0.842(.019) 0.882(.011) 0.892(.015)
EWC 0.812(.016) 0.832(.025) 0.865(.012) 0.887(.008)
LwF 0.814(.020) 0.853(.026) 0.882(.020) 0.891(.019)
LwF+ 0.806(.010) 0.844(.022) 0.881(.018) 0.898(.014)
EWCLwF 0.821(.019) 0.841(.021) 0.869(.018) 0.890(.023)
EWCLwF+ 0.817(.012) 0.852(.018) 0.871(.019) 0.884(.017)
Proposed 0.813(.035) 0.869(.021) 0.896(.017) 0.909(.013)
Table 3: CXRs for TB: test set AUC - mean (std) of five trials.
Figure 5: ROC curves at stage-4 with stage-1 training data.

5 Conclusion

In this work, we raise the problem of catastrophic forgetting in the multi-center single-task learning environment and propose a new way to preserve old knowledge in neural networks. By modeling the high-level feature space to be appropriate for knowledge preservation in the first stage and constraining the feature space to lie in the modeled space during training in the following stages, we can preserve the knowledge learned in preceding stages. The proposed method is shown to be beneficial in terms of keeping old knowledge in classification tasks. More experimental analysis beyond classification, such as lesion detection or segmentation, is left for future work.

References

  • [1] Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G.S., Davis, A., Dean, J., Devin, M., et al.: TensorFlow: Large-scale machine learning on heterogeneous systems (2015), software available from http://tensorflow.org/
  • [2] Girshick, R., Donahue, J., Darrell, T., Malik, J.: Rich feature hierarchies for accurate object detection and semantic segmentation. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2014)
  • [3] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., Bengio, Y.: Generative adversarial nets. In: Advances in Neural Information Processing Systems (NIPS) (2014)
  • [4] Goodfellow, I.J., Mirza, M., Xiao, D., Courville, A., Bengio, Y.: An empirical investigation of catastrophic forgetting in gradient-based neural networks. In: International Conference on Learning Representations (ICLR) (2014)
  • [5] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016)
  • [6] Hwang, S., Kim, H.E., Jeong, J., Kim, H.J.: A novel approach for tuberculosis screening based on deep convolutional neural networks. In: Medical Imaging (2016)
  • [7] Ioffe, S., Szegedy, C.: Batch normalization: Accelerating deep network training by reducing internal covariate shift. In: International Conference on Machine Learning (ICML) (2015)
  • [8] Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G., Rusu, A.A., Milan, K., Quan, J., Ramalho, T., Grabska-Barwinska, A., et al.: Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences (2017)
  • [9] Krizhevsky, A., Hinton, G.: Learning multiple layers of features from tiny images. Technical report, University of Toronto (2009)
  • [10] Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep convolutional neural networks. In: Advances in Neural Information Processing Systems (NIPS) (2012)
  • [11] Li, Z., Hoiem, D.: Learning without forgetting. In: European Conference on Computer Vision (ECCV) (2016)
  • [12] McCloskey, M., Cohen, N.J.: Catastrophic interference in connectionist networks: The sequential learning problem. Psychology of Learning and Motivation 24, 109–165 (1989)
  • [13] Shin, H., Lee, J.K., Kim, J., Kim, J.: Continual learning with deep generative replay. arXiv preprint arXiv:1705.08690 (2017)
  • [14] Vincent, P., Larochelle, H., Lajoie, I., Bengio, Y., Manzagol, P.A.: Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion. Journal of Machine Learning Research (JMLR) 11, 3371–3408 (2010)
  • [15] Wu, Y., Schuster, M., Chen, Z., Le, Q.V., Norouzi, M., Macherey, W., Krikun, M., Cao, Y., Gao, Q., Macherey, K., et al.: Google's neural machine translation system: Bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144 (2016)