Prototype Recalls for Continual Learning
Continual learning is the ability to continually acquire and transfer knowledge without catastrophically forgetting previously learned knowledge. However, enabling continual learning for AI remains a long-standing challenge. In this work, we propose a novel method, Prototype Recalls, that efficiently embeds and recalls previously learnt knowledge to tackle the catastrophic forgetting issue. In particular, we consider continual learning in classification tasks. For each classification task, our method learns a metric space containing a set of prototypes, where embeddings of samples from the same class cluster around their prototype and class-representative prototypes are separated apart. To alleviate catastrophic forgetting, our method preserves the embedding function from the samples to the previous metric space through our proposed prototype recalls from previous tasks. Specifically, the recalling process is implemented by replaying a small number of samples from previous tasks and matching their embeddings to their nearest class-representative prototypes. Compared with recent continual learning methods, our contributions are fourfold: first, our method achieves the best memory retention capability while adapting quickly to new tasks. Second, our method uses metric learning for classification and does not require adding new neurons for new object classes. Third, our method is more memory efficient since only class-representative prototypes need to be recalled. Fourth, our method suggests a promising solution for few-shot continual learning: without compromising performance on initial tasks, it learns novel concepts given only a few training examples per class in new tasks.
1 Introduction
Continual learning, also known as lifelong learning, is the crucial human ability to continually acquire and transfer new knowledge across the lifespan while retaining previously learnt experiences hassabis2017neuroscience. This ability is also critical for artificial intelligence (AI) systems that interact with the real world and process continuous streams of information thrun1995lifelong. However, the continual acquisition of incrementally available data from non-stationary data distributions generally leads to catastrophic forgetting mccloskey1989catastrophic; ratcliff1990connectionist; french1999catastrophic. Continual learning remains a long-standing challenge for deep neural network models, since these models typically learn representations from stationary batches of training data and tend to fail to retain good performance on previous tasks when data become incrementally available over tasks kemker2018measuring; maltoni2019continuous.
Numerous methods for alleviating catastrophic forgetting have been proposed. The most pragmatic way is to jointly train deep neural network models on both old and new tasks, which however demands a large amount of resources to store previous training data and hinders learning from novel data in real time. Another option is to complement the training data for each new task with “pseudo-data” of the previous tasks shin2017continual; robins1995catastrophic. Besides the main model for task performance, a separate generative model is trained to generate fake historical data used for pseudo-rehearsal. Deep Generative Replay (DGR) shin2017continual replaces the storage of the previous training data with a Generative Adversarial Network that synthesizes training data for all previously learnt tasks. These generative approaches have succeeded on very simple and artificial inputs, but they cannot tackle more complicated inputs atkinson2018pseudo. Moreover, to synthesize the historical data reasonably well, the generative model usually has to be very large, which costs much memory wen2018few. An alternative is to store the weights of the model trained on previous tasks and impose constraints on weight updates for new tasks he2018overcoming; kirkpatrick2017overcoming; zenke2017continual; lee2017overcoming; lopez2017gradient. For example, Elastic Weight Consolidation (EWC) kirkpatrick2017overcoming and Learning Without Forgetting (LwF) li2018learning store all the model parameters from previously learnt tasks, estimate their importance on those tasks, and penalize future changes to the weights on new tasks. However, selecting the “important” parameters for previous tasks complicates the implementation through exhaustive hyper-parameter tuning. In addition, state-of-the-art neural network models often involve millions of parameters, and storing all network parameters from previous tasks does not necessarily reduce the memory cost wen2018few.
In contrast with these methods, storing a small subset of examples from previous tasks and replaying this exact subset substantially boosts performance kemker2017fearnet; rebuffi2017icarl; nguyen2017variational. To achieve the desired network behavior on previous tasks, the incremental Classifier and Representation Learner (iCARL) rebuffi2017icarl and Few-shot Self-Reminder (FSR) wen2018few follow the idea of logit matching or knowledge distillation from model compression ba2014deep; bucilua2006model; hinton2015distilling. However, such approaches ignore the topological relations among clusters in the embedding space and rely heavily on a small amount of individual data, which may result in overfitting, as shown in our experiments (Section 4.2). In contrast, without compromising memory retention, our method learns embedding functions and compares feature similarities represented by class prototypes in the embedding space, which improves generalization, especially in few-shot settings, as has also been verified in hoffer2015deep; snell2017prototypical.
In this paper, we propose the method, Prototype Recalls, for continual learning in classification tasks. Similar to snell2017prototypical, we use a neural network to learn class-representative prototypes in an embedding space and classify embedded test data by finding their nearest class prototype. To tackle the problem of catastrophic forgetting, we impose additional constraints on the network by classifying the embedded test data based on prototypes from previous tasks, which promotes the preservation of the initial embedding function. For example (Figure 3), in the first task (Figure 3a), the network learns color prototypes to classify blue and yellow circles, and in the second task (Figure 3b), the network learns shape prototypes to classify green circles and triangles. If color features are catastrophically forgotten, the network extracts only shape features and fails to classify the blue and yellow circles from the first task. To alleviate catastrophic forgetting, our method replays the embedded previous samples (blue and yellow circles) and matches them with the previous color prototypes (blue and yellow), which reminds the network to extract both color and shape features in both classification tasks.
We evaluate our method under two typical experimental protocols for continual learning, incremental domain and incremental class, across three benchmark datasets: MNIST deng2012mnist, CIFAR10 krizhevsky2009learning and miniImageNet deng2009imagenet. Compared with state-of-the-art methods, ours significantly boosts continual learning performance in terms of memory retention capability while remaining able to adapt to new tasks. Unlike parameter regularization methods, iCARL, or FSR, our approach further reduces the memory storage by replacing the logits of each example, or the network parameters, with one prototype per class in the episodic memory. Moreover, in contrast to these methods, where the last layer of traditional classification networks structurally depends on the number of classes, our method leverages metric learning, maintains the same network architecture, and does not require adding new neurons or layers for new object classes. Additionally, without sacrificing classification accuracy on initial tasks, our method can generalize to learn new concepts given only a few training examples in new tasks, owing to the advantages of metric learning commonly used in few-shot settings snell2017prototypical; hoffer2015deep.
2 Proposed Method
We propose the method, Prototype Recalls, for continual learning. For a sequence of datasets $\{D_1, D_2, \ldots\}$, given dataset $D_t$ in any task $t$ where $t \in \{1, \ldots, T\}$, the goal for the model is to retain good classification performance on all datasets $D_1, \ldots, D_t$ after being sequentially trained over $t$ tasks. The value of $T$ is not pre-determined. The model with learnable parameters $\theta$ is only allowed to carry over a limited amount of information from the previous $t-1$ tasks. This constraint eliminates the naive solution of combining all previous datasets to form one big training set for fine-tuning the model at task $t$. Each dataset $D_t$ consists of labeled examples $(x_i, y_i)$, where each $x_i \in \mathbb{R}^D$ is the $D$-dimensional feature vector of an example and $y_i$ is the corresponding class label. $D_t^k$ denotes the set of examples labeled with class $k$.
At task $t$, if we simply train a model by only minimizing the classification loss on dataset $D_t$, the model will forget how to perform classification on previous datasets, which is known as the catastrophic forgetting problem mccloskey1989catastrophic; ratcliff1990connectionist; french1999catastrophic. Here we show how the model trained with our method retains good performance on all previous tasks while adaptively learning new tasks. The loss over all the previous datasets is denoted by $L_{1:t-1}$. Our objective is to learn $\theta_t$ defined as follows:
$$\theta_t = \arg\min_{\theta} \; L_t(\theta) + L_{1:t-1}(\theta)$$
where $L_t(\theta)$ defines the classification loss of the model with parameters $\theta$ on dataset $D_t$, and $L_{1:t-1}(\theta)$ measures the differences in the network behaviors in the embedding spaces learnt by $\theta_t$ and $\theta_{t-1}$ on the previous datasets, as introduced later in Eq. 7. Given that $\theta_1, \ldots, \theta_{t-1}$ are learnt from the previous tasks, at task $t$, learning $\theta_t$ requires minimizing both terms. In the subsections below and Figure 3, we describe how to optimize these two terms.
2.1 Prototype Classification
To perform classification on dataset $D_t$, our method learns an embedding space in which points cluster around a single prototype representation for each class, and classification is performed by finding the nearest class prototype snell2017prototypical (Figure 3a). Compared to traditional classification networks with a specific classification layer attached at the end, such as iCARL and FSR, our method keeps the network architecture unchanged while finding the nearest neighbour in the embedding space, which leads to more efficient memory usage. For example, in one of the continual learning protocols snell2017prototypical where models are asked to classify incremental classes (also see Section 3.1), traditional classification networks have to expand their architectures by accommodating more output units in the last classification layer based on the number of incremental classes; consequently, additional network parameters have to be added into the memory.
Without loss of generality, here we show how our method performs classification on $D_t$. First, the model learns an embedding function $f_\theta(\cdot)$ and computes an $M$-dimensional prototype $c_k$, which is the mean of the embeddings of the examples in $D_t^k$:
$$c_k = \frac{1}{|D_t^k|} \sum_{(x_i, y_i) \in D_t^k} f_\theta(x_i)$$
The distance between an embedding and the prototype of its own class should be smaller than its distances to the prototypes of other classes. Our method introduces a distance function $d(\cdot, \cdot)$. For each example $x$, it estimates a distance distribution based on a softmax over the distances to the prototypes of all classes in the embedding space:
$$p(y = k \mid x) = \frac{\exp(-d(f_\theta(x), c_k))}{\sum_{k'} \exp(-d(f_\theta(x), c_{k'}))}$$
The objective function is to minimize the negative log-probability $-\log p(y = k \mid x)$ of the ground truth class label $k$ via Stochastic Gradient Descent bottou2010large.
In practice, when $D_t$ is large, computing the prototypes over the full dataset is costly and memory inefficient during training. Thus, at each training iteration, we randomly sample two complementary subsets from $D_t$ over all classes: one for computing prototypes and the other for estimating the distance distribution. Our primary choice of the distance function is the squared Euclidean distance, which has been verified to be effective in snell2017prototypical. In addition, we include a temperature hyperparameter $\tau$ in the softmax, as introduced in the network distillation literature hinton2015distilling, and set its value empirically based on the validation sets. A higher value of $\tau$ produces a softer probability distribution over classes.
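The prototype-and-softmax machinery above can be sketched in a few lines. The following is a minimal NumPy illustration under our own naming (not the authors' released code): prototypes are class means of embeddings, and the predictive distribution is a temperature-scaled softmax over negative squared Euclidean distances.

```python
import numpy as np

def compute_prototypes(embeddings, labels, num_classes):
    """Prototype c_k: mean embedding of the examples labeled k."""
    return np.stack([embeddings[labels == k].mean(axis=0)
                     for k in range(num_classes)])

def distance_distribution(queries, prototypes, tau=1.0):
    """Softmax over negative squared Euclidean distances to prototypes.
    A higher temperature tau gives a softer distribution."""
    d = ((queries[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -d / tau
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def classification_loss(probs, labels):
    """Mean negative log-probability of the ground-truth class."""
    return -np.log(probs[np.arange(len(labels)), labels]).mean()
```

At each training iteration one would embed two disjoint subsets of the batch, compute prototypes from the first, and evaluate the loss on the second, mirroring the episodic sampling described above.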
2.2 Prototype Recall
Regardless of the changes of the network parameters from $\theta_{t-1}$ to $\theta_t$ at tasks $t-1$ and $t$ respectively, the primary goal of $\theta_t$ is to learn an embedding function that induces a metric space on dataset $D_{t-1}$ similar to the one learnt by $\theta_{t-1}$ in task $t-1$ (Figure 3b). Given a limited amount of memory, a direct approach is to randomly sample a small subset $\tilde{D}_{t-1}$ from $D_{t-1}$ and replay these examples at task $t$. There have been some attempts chen2012super; koh2017understanding; brahma2018subset at selecting representative examples based on different scoring functions. However, recent work wen2018few has shown that random sampling uniformly across classes already yields outstanding performance in continual learning tasks. Hence, we adopt the same random sampling strategy to form $\tilde{D}_{t-1}$.
Intuitively, if the number of data samples in $\tilde{D}_{t-1}$ were very large, the network could reproduce the metric space of task $t-1$ at task $t$ by replaying $\tilde{D}_{t-1}$, which is our desired goal. However, this does not hold in practice given limited memory capacity. With the simple inductive bias that the metric space at task $t-1$ can be characterized by class-representative prototypes, we introduce another loss requiring that each embedded data sample in $\tilde{D}_{t-1}$ should still be closest to its corresponding class prototype among all prototypes at task $t$. This ensures that the metric space represented by the set of prototypes learnt from $D_{t-1}$ by $\theta_{t-1}$ is well approximated at task $t$.
Formally, for any task $t'$ after task $t$, we formulate the regularization of network behaviors in the metric space of task $t$ by satisfying two criteria: first, $\theta_{t'}$ learns a metric space to classify $D_{t'}$ by minimizing the classification loss $L_{t'}$, as introduced in Sec. 2.1 above; second, to preserve the similar topological structure among clusters on dataset $D_t$, the embeddings predicted by $\theta_{t'}$ on $\tilde{D}_t$ should produce a similar distance distribution based on a softmax over the distances to the prototypes $c_k^{(t)}$ computed using $\theta_t$ on dataset $D_t$:
$$p(y = k \mid x) = \frac{\exp(-d(f_{\theta_{t'}}(x), c_k^{(t)}))}{\sum_{k'} \exp(-d(f_{\theta_{t'}}(x), c_{k'}^{(t)}))}$$
Concretely, the recall loss minimizes the negative log-probability of the ground truth class label conditioned on the prototypes $c_k^{(t)}$, which are pre-computed using $\theta_t$ in Eq. 5 at task $t$ and stored in the episodic memory until task $t'$.
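As a sketch, the recall term can be implemented by scoring replayed examples, embedded by the *current* network, against the prototypes frozen in episodic memory at the end of their original task. This is our own minimal NumPy illustration (function and variable names are assumptions, not the authors' code):

```python
import numpy as np

def recall_loss(replay_embeddings, replay_labels, stored_prototypes, tau=1.0):
    """Mean negative log-probability of each replayed example's true class,
    where class scores are softmax-normalized negative squared Euclidean
    distances to the prototypes stored at the earlier task."""
    d = ((replay_embeddings[:, None, :]
          - stored_prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -d / tau
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    return -np.log(p[np.arange(len(replay_labels)), replay_labels]).mean()
```

The loss is small when the current embedding function still places each replayed sample nearest its original class prototype, and grows as the metric space drifts away from the one learnt at the earlier task.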
Overall, we define $L_{1:t-1}$ in Eq. 1 as the sum of the recall losses over all previous tasks.
2.3 Dynamic Episodic Memory Allocation
Given a limited amount of memory with capacity $C$, our proposed method has to store a small subset $\tilde{D}_j$ with $n_j$ examples randomly sampled from $D_j$, and the prototypes computed using the embedding function learnt at task $j$ on $D_j$, where $j = 1, \ldots, t-1$. The following constraint has to be satisfied:
$$\sum_{j=1}^{t-1} \left( n_j S_x + K_j S_c \right) \le C$$
where $S_x$ and $S_c$ denote the storage costs of one example and one prototype respectively, and $K_j$ is the number of classes in task $j$.
When the number of tasks is small, $n_j$ can be large and the episodic memory stores more examples per task. Dynamic memory allocation that enables more example replays for earlier tasks puts more emphasis on reviewing earlier tasks, which are easier to forget, and introduces more variety in the data distributions matched against the prototypes. Pseudocode for our proposed algorithm over one training episode of continual learning is provided in Algorithm 1.
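One possible reading of this allocation scheme in code (our own sketch; the exact budgeting in Algorithm 1 is not reproduced here): at each new task, prototype storage for every stored task is reserved first, and the leftover capacity is split evenly across the previous tasks, so episodes run earlier in the sequence replay more examples per task.

```python
def replay_budget(capacity, task_idx, proto_cost_per_task):
    """Number of replay examples allotted to each previous task when
    training on task `task_idx` (tasks numbered from 1). Prototype
    storage is reserved first; the remainder is divided uniformly over
    the task_idx - 1 previous tasks, so the per-task budget shrinks as
    tasks accumulate."""
    num_prev = task_idx - 1
    if num_prev < 1:
        raise ValueError("no previous tasks to replay at the first task")
    remaining = capacity - num_prev * proto_cost_per_task
    if remaining <= 0:
        raise ValueError("capacity exhausted by prototypes alone")
    return remaining // num_prev
```

For a fixed capacity, the budget at task 2 is far larger than at task 20, matching the intuition that earlier tasks end up being replayed with more examples.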
3 Experimental Details
We introduce two task protocols for evaluating continual learning algorithms with different memory usage over three benchmark datasets. Source code will be made publicly available upon acceptance.
3.1 Task Protocols
Permuted MNIST in the incremental domain task is a benchmark protocol in continual learning lee2017overcoming; lopez2017gradient; zenke2017continual (Figure 4a). In each task, a fixed permutation sequence is randomly generated and applied to the input images in MNIST deng2012mnist. Though the input distribution changes across tasks, models are trained to classify 10 digits in each task and the model structure is always the same. There are 20 tasks in total. During testing, the task identity is not available to the models, which have to classify input images into 1 out of 10 digits.
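This protocol is straightforward to reproduce: each task fixes one random pixel permutation and applies it to every image. A minimal sketch (function name is ours), assuming images are flattened to vectors:

```python
import numpy as np

def make_permuted_tasks(images, num_tasks, seed=0):
    """Build permuted-MNIST style tasks from flattened images of shape
    (num_images, num_pixels): each task applies one fixed random pixel
    permutation to every image; labels are unchanged."""
    rng = np.random.default_rng(seed)
    num_pixels = images.shape[1]
    perms = [rng.permutation(num_pixels) for _ in range(num_tasks)]
    tasks = [images[:, p] for p in perms]
    return tasks, perms
```

Because only the pixel order changes, every task keeps the same label space and input statistics, which is what makes this an incremental *domain* rather than incremental *class* protocol.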
Split CIFAR10 and split miniImageNet in the incremental class task constitute a more challenging protocol, where models need to infer the task identity and meanwhile solve each task. The input data are also more complex: natural images from CIFAR10 krizhevsky2009learning and miniImageNet deng2009imagenet. The former contains 10 classes and the latter 100 classes. On CIFAR10, the model is first trained with 2 classes and then with 1 more class in each subsequent task, for 9 tasks in total with 5,000 training images per class. On miniImageNet, models are trained with 10 classes in each task, for 10 tasks in total with 480 training images per class.
Few-shot continual learning. Humans can learn novel concepts from a few examples without sacrificing classification accuracy on initial tasks gidaris2018dynamic. However, typical continual learning schemes assume that a large amount of training data is always available over all tasks for fine-tuning networks to new data distributions, which does not always hold in practice. We therefore revise the task protocols into more challenging ones: networks are trained with only a few examples per class in every task after the first. For example, on CIFAR10/miniImageNet, we train the models with 5,000/480 example images per class in the first task and 50/100 images per class in subsequent tasks.
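Constructing such a few-shot stream amounts to subsampling each class of the later tasks' training sets. A small helper (our own naming, not part of the paper) might look like:

```python
import numpy as np

def few_shot_subset(images, labels, shots_per_class, seed=0):
    """Keep at most `shots_per_class` randomly chosen examples of each
    class, e.g. 50 per class for CIFAR10 tasks after the first."""
    rng = np.random.default_rng(seed)
    keep = []
    for k in np.unique(labels):
        idx = np.flatnonzero(labels == k)
        n = min(shots_per_class, idx.size)
        keep.extend(rng.choice(idx, size=n, replace=False))
    keep = np.sort(np.array(keep))
    return images[keep], labels[keep]
```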
3.2 Compared Methods
We include the following categories of continual learning methods for comparison with our method. To eliminate the effect of network structures on performance, we use the same architecture complexity for all methods in the same task across all experiments.
Parameter Regularization Methods: Elastic Weight Consolidation (EWC) kirkpatrick2017overcoming, Synaptic Intelligence (SI) zenke2017continual and Memory Aware Synapses (MAS) aljundi2018memory, where regularization terms are added to the loss function; online EWC schwarz2018progress, an extension of EWC that scales to a large number of tasks; an L2 baseline, where the L2 distance between parameters across tasks is added to the loss kirkpatrick2017overcoming; and SGD, a naive baseline without any regularization terms, optimized with Stochastic Gradient Descent bottou2010large sequentially over all tasks.
Memory Distillation and Replay Methods: incremental Classifier and Representation Learner (iCARL) rebuffi2017icarl and Few-shot Self-Reminder (FSR) wen2018few regularize network behaviors by exact pseudo replay. FSR has two variants: FSR-KLD, for logit matching via a Kullback–Leibler divergence loss, and FSR-MSE, for logit distillation via an L2 distance loss.
Performance is reported as mean and standard deviation over 10 runs per protocol. Since generative model-based approaches van2018generative; shin2017continual greatly alter the architecture of the classification networks, we do not compare with them.
3.3 Memory Comparison
For fair comparison, we use the same feed-forward architecture for all methods and allocate a comparable amount of memory to that of EWC kirkpatrick2017overcoming and other parameter regularization methods, for storing example images per class and their prototypes. In EWC, the model often allocates a memory size twice the number of network parameters for computing the Fisher information matrix, which is used to regularize changes of network parameters kirkpatrick2017overcoming. In more challenging classification tasks, the network size tends to be larger and hence these methods require much more memory. In Table 16, we show an example of memory allocation on split CIFAR10 in incremental class tasks with full memory and little memory respectively. Weight regularization methods require twice as much memory as the number of parameters of the feed-forward classification network. The inputs are RGB images. Via Eq. 8, our method can allocate episodic memory with full capacity, equivalent to storing a fixed number of example images per class. In experiments with little training data, as described in Section 3.1, we reduce the number of stored example images per class accordingly.
4 Experimental Results
4.1 Alleviating Forgetting
Figure 11 reports the results of continual learning methods with full memory under the two task protocols. All compared continual learning methods outperform SGD (cyan), a baseline that makes no attempt to prevent catastrophic forgetting. Our method (red) achieves the highest average classification accuracy among all compared methods, including both parameter regularization methods and memory-based methods, with minimal forgetting.
A good continual learning method should not only show good memory retention but also adapt to new tasks. In Figure 11a, although our method (red) performs on par with EWC (brown) and FSR (date) in retaining classification accuracy on the dataset from the first task over 20 sequential tasks, the average classification accuracy of our method is far higher than that of EWC (brown) and FSR (date), as shown in Figure 11b, indicating that both of these methods retain memory well but fail to learn new tasks. After the 13th task, the average classification performance of EWC is even worse than that of SGD. Across all 20 tasks, our method leads FSR (date) by 3% in accuracy on average. Similar reasoning applies to the comparison with SI (green): although our method performs comparably to SI in terms of average classification accuracy, SI fails to retain the classification accuracy on the first task, which is 6% lower than ours at the 20th task.
Figures 11c and 11d show the average task classification accuracy over sequential tasks in the incremental class protocol. The incremental class protocol is more challenging than the incremental domain protocol, since the models have to infer both the task identity and the class labels within the task. Our method (red) performs slightly better than iCARL (date) and has the highest average classification accuracy in continual learning. Compared with the third best method, FSR (green), our method consistently yields around 5% higher accuracy on average across all tasks on CIFAR10 and miniImageNet respectively. Note that most weight regularization methods, such as EWC (brown), perform as badly as SGD. A possible reason is that EWC computes the Fisher matrix to maintain local information and does not consider scenarios where data distributions across tasks are far apart. On the contrary, our method maintains remarkably better performance than EWC, because ours focuses primarily on the behaviors of the network outputs, which indirectly relaxes the constraint on the change of network parameters.
4.2 Few-shot Continual Learning
We evaluate continual learning methods with little memory under the two task protocols, with few training data from the second task onwards. Figure 14 reports their performance. Our method (red) has the highest average classification accuracy over all sequential tasks among state-of-the-art methods, reaching 27% and 11% versus 19% and 4% for FSR-KLD (yellow), the second best method, at the 9th and 10th tasks on CIFAR10 and miniImageNet respectively. Weight regularization methods, such as online EWC (blue) and MAS (brown), perform as badly as SGD (cyan), worse than logit matching methods such as FSR (green and yellow) or iCARL (purple). These observations mirror those of Figure 11 with full training data.
Compared with logit matching methods, our method has the highest average task classification accuracy. This shows that our method performs classification via metric learning in an effective few-shot manner. It also benefits from a network architecture that does not depend on the number of output classes, so the knowledge from previous tasks can be preserved and transferred to new tasks. This is superior to traditional networks, where new parameters added in the last classification layer easily lead to overfitting. As a side benefit, given the same number of example inputs in the episodic memory, our method is more memory efficient since it stores one prototype per class instead of the logits of each example input, as shown in Table 16.
[Table 16: memory allocation with full training and full memory size, and with little training and little memory size.]
4.3 Network Analysis
We also study the effects of the following three factors on performance. Figure 16 reports the average classification accuracy of these ablated methods. (1) Intuitively, limited memory capacity restricts the number of example inputs to replay and leads to a performance drop. On permuted MNIST in the incremental domain protocol, with full memory capacity reduced by 2.5 times (from 5,000 example inputs to 2,000), our method shows a moderate decrease in average classification accuracy of 1% at the 20th task. (2) We also compare our method with memory replay optimized by a cross-entropy loss under full memory conditions. A performance drop of around 1.5% is observed, which validates that classifying example inputs based on the initial prototypes yields better memory retention. (3) Given fixed capacity $C$, our method adopts the strategy of decreasing the number of example inputs in memory as the number of tasks increases. The performance drop of 1.5% with uniform memory allocation demonstrates the usefulness of dynamic memory allocation, which replays more examples from earlier tasks and therefore promotes memory retention.
In Figure 17, we provide visualizations of class embeddings by projecting the latent representations of classes into 2D space. Our method is capable of clustering latent representations belonging to the same class while accommodating new class embeddings across sequential tasks. Interestingly, the clusters are topologically organized based on feature similarities among classes, and this topological structure is preserved across tasks. For example, the cluster of “bird” (black) is close to that of “plane” (orange) in Task 3, and the same two clusters remain close in Task 9. This again validates that classifying example inputs from previous tasks based on the initial prototypes promotes the preservation of the topological structure of the initial metric space.
5 Conclusion
We address the problem of catastrophic forgetting by proposing prototype recalls for classification tasks. In addition to significantly alleviating catastrophic forgetting on benchmark datasets, our method is superior to others in terms of efficient memory usage and in generalizing to novel concepts given only a few training examples in new tasks.
However, given a finite memory capacity and a large number of tasks, we recognize that our method, like other memory-based continual learning algorithms, is limited in the number of prototypes it can store. The memory requirement of our method increases linearly with the number of tasks. In practice, there is always a trade-off between memory usage and retention. We believe that our method is among the most efficient continual learning methods at eliminating catastrophic forgetting with a modest amount of memory. Moreover, we restrict ourselves to classification tasks with discrete prototypes. In future work, to apply our algorithm to more complex and challenging problems such as regression and reinforcement learning (RL), one possible solution is to quantize the continuous space in regression or formulate RL over discrete state-action pairs.