Using Hindsight to Anchor Past Knowledge in Continual Learning

Abstract

In continual learning, the learner faces a stream of data whose distribution changes over time. Modern neural networks are known to suffer under this setting, as they quickly forget previously acquired knowledge. To address such catastrophic forgetting, many continual learning methods implement different types of experience replay, re-learning on past data stored in a small buffer known as episodic memory. In this work, we complement experience replay with a new objective that we call “anchoring”, where the learner uses bilevel optimization to update its knowledge on the current task, while keeping intact the predictions on some anchor points of past tasks. These anchor points are learned using gradient-based optimization to maximize forgetting, which is approximated by fine-tuning the currently trained model on the episodic memory of past tasks. Experiments on several supervised learning benchmarks for continual learning demonstrate that our approach improves over standard experience replay in terms of both accuracy and forgetting metrics, for various sizes of episodic memories.

1 Introduction

We study the problem of continual learning, where a machine learning model experiences a sequence of tasks. Each of these tasks is presented as a stream of input-output pairs, where each pair is drawn identically and independently (iid) from the corresponding task probability distribution. Since the length of the learning experience is not specified a priori, the learner can only assume a single pass over the data and, due to space constraints, store nothing but a few examples in a small episodic memory. At all times during the lifetime of the model, predictions on examples from all tasks may be requested. Addressing continual learning is an important research problem, since it would enable the community to move past the assumption of “identically and independently distributed data” and allow a better deployment of machine learning in the wild. However, continual learning presents one major challenge: catastrophic forgetting (McCloskey and Cohen, 1989). That is, as the learner experiences new tasks, it quickly forgets previously acquired knowledge. This is a hindrance especially for state-of-the-art deep learning models, where all parameters are updated after observing each example.

Continual learning has received increasing attention from the scientific community during the last decade. The state-of-the-art algorithms for continual learning fall into three categories. First, regularization-based approaches reduce forgetting by restricting updates to parameters that were important for previous tasks (Kirkpatrick et al., 2016; Rebuffi et al., 2017; Aljundi et al., 2018; Chaudhry et al., 2018; Nguyen et al., 2018). However, when the number of tasks is large, the regularization of past tasks becomes obsolete, leading to representation drift (Titsias et al., 2019). Second, modular approaches (Rusu et al., 2016; Lee et al., 2017) add new modules to the learner as new tasks are learned. While modular architectures overcome forgetting by design, the memory complexity of these approaches scales with the number of tasks. Third, memory-based methods (Lopez-Paz and Ranzato, 2017; Hayes et al., 2018; Isele and Cosgun, 2018; Riemer et al., 2019; Chaudhry et al., 2019a) store a few examples from past tasks in an “episodic memory”, to be revisited when training for a new task. Contrary to modular approaches, memory-based methods add a very small memory overhead for each new task. Memory-based methods are the reigning state of the art, but their performance remains a far cry from that of a simple oracle accessing all the data at once, which turns the continual learning experience back into a standard supervised learning task. Despite intense research efforts, this gap in performance renders continual learning an open research question.

Contribution

We propose Hindsight Anchor Learning (HAL), a continual learning approach that improves the performance of memory-based continual learning algorithms. HAL leverages bilevel optimization to regularize the training objective with one representative point per class per task, called an anchor. These anchors are constructed via gradient ascent in the image space, by maximizing an approximation to the forgetting loss of the current task throughout the entire continual learning experience. We estimate, in hindsight, the amount of forgetting that the learner would suffer on these anchors if it were trained on future tasks: that is, by measuring forgetting on a temporary predictor that has been fine-tuned on the episodic memories of past tasks. Anchors learned in such a way lie close to the classifier’s decision boundary, as visualized in Figure 3. Since points near the decision boundary are easiest to forget when updating the learner on future tasks, keeping predictions invariant on such anchors preserves performance on previous tasks effectively. In sum, the overall parameter update of HAL uses nested optimization to minimize the loss on the current mini-batch, while keeping the predictions at all anchors invariant.

Results

We compare HAL to EWC (Kirkpatrick et al., 2016), ICARL (Rebuffi et al., 2017), VCL (Nguyen et al., 2018), AGEM (Chaudhry et al., 2019a), experience replay (Hayes et al., 2018; Riemer et al., 2019), and MER (Riemer et al., 2019), across four commonly used benchmarks in supervised continual learning (MNIST permutations, MNIST rotations, split CIFAR-100, and split miniImageNet). In these experiments, HAL achieves state-of-the-art performance, improving accuracy by up to 7.5% and reducing forgetting by almost 23% in the best case. We show that these results hold across a range of episodic memory sizes.

We now begin our exposition by reviewing the continual learning setup. The rest of the manuscript then presents our new algorithm HAL (Section 3), showcases its empirical performance (Section 4), surveys the related literature (Section 5), and offers some concluding remarks (Section 6).

2 Continual learning setup

In continual learning, we experience a stream of data triplets $(x_i, y_i, t_i)$ containing an input $x_i$, a target $y_i$, and a task identifier $t_i$. Each input-target pair is an identically and independently distributed example drawn from some unknown distribution $P_{t_i}$, representing the $t_i$-th learning task. We assume that the tasks are experienced in order ($t_i \leq t_j$ for all $i \leq j$), and that the total number of tasks $T$ is not known a priori. Under this setup, our goal is to estimate a predictor $f_\theta$, parameterized by $\theta$, and composed of a feature extractor $\phi_\theta$ and a classifier $w_\theta$, that minimizes the multi-task error

$\sum_{t=1}^{T} \mathbb{E}_{(x, y) \sim P_t} \left[ \ell\left(f_\theta(x, t), y\right) \right],$    (1)

where $f_\theta = w_\theta \circ \phi_\theta$, and $\ell$ is a loss function.

Inspired by prior literature in continual learning (Lopez-Paz and Ranzato, 2017; Hayes et al., 2018; Riemer et al., 2019; Chaudhry et al., 2019a), we consider streams of data that are experienced only once. Therefore, the learner cannot revisit any but a small number of data triplets, chosen to be stored in a small episodic memory $\mathcal{M}$. More specifically, we consider tiny “ring” episodic memories, which contain the last $m$ observed examples per class for each of the experienced tasks, where $m$ is small. That is, considering as variables the number of experienced tasks $t$ and the number of examples per class $m$, we study continual learning algorithms with an $O(m \cdot t)$ memory footprint.
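
To make the memory mechanism concrete, the snippet below sketches such a ring buffer in Python; the class name, the per-(task, class) slot layout, and the NumPy-based sampling are illustrative choices on our part, not the paper's implementation.

```python
from collections import defaultdict, deque

import numpy as np


class RingMemory:
    """Tiny "ring" episodic memory: keeps the last m observed examples
    per class for each experienced task, giving an O(m * t) footprint."""

    def __init__(self, m):
        # (task, class) -> FIFO slot holding at most m examples.
        self.slots = defaultdict(lambda: deque(maxlen=m))

    def add(self, x, y, t):
        # The oldest example in the (task, class) slot is evicted automatically.
        self.slots[(t, y)].append((x, y, t))

    def sample(self, batch_size, rng=None):
        # Uniform sample, without replacement, over everything currently stored.
        rng = rng or np.random.default_rng()
        pool = [triplet for slot in self.slots.values() for triplet in slot]
        if not pool:
            return []
        idx = rng.choice(len(pool), size=min(batch_size, len(pool)), replace=False)
        return [pool[i] for i in idx]
```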

Following Lopez-Paz and Ranzato (2017) and Chaudhry et al. (2018), we monitor two statistics to evaluate the quality of continual learning algorithms: final average accuracy, and final maximum forgetting. First, the final average accuracy of a predictor is defined as

$A_T = \frac{1}{T} \sum_{j=1}^{T} a_{T, j},$    (2)

where $a_{i, j}$ denotes the test accuracy on task $j$ after the model has finished experiencing task $i$. That is, the final average accuracy measures the test performance of the model at every task after the continual learning experience has finished. Second, the final maximum forgetting is defined as

$F_T = \frac{1}{T-1} \sum_{j=1}^{T-1} \max_{i \in \{1, \dots, T-1\}} \left( a_{i, j} - a_{T, j} \right),$    (3)

that is, the decrease in performance at each of the tasks between their peak accuracy and their accuracy after the continual learning experience has finished.
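
As a concrete reference, the sketch below computes both metrics from a matrix of per-task accuracies; the indexing convention (row $i$ holds the accuracies measured after training on task $i$) is our own assumption.

```python
import numpy as np


def final_metrics(acc):
    """acc[i, j] = test accuracy on task j measured after training on task i,
    with shape (T, T). Returns (final average accuracy, final maximum forgetting)."""
    avg_accuracy = acc[-1].mean()                              # Eq. 2
    # For each task j < T: peak accuracy before the end minus its final accuracy.
    forgetting = acc[:-1, :-1].max(axis=0) - acc[-1, :-1]      # Eq. 3, per task
    return avg_accuracy, forgetting.mean()
```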

Finally, following Chaudhry et al. (2019a), we use the first three tasks to cross-validate the hyper-parameters of each of the considered continual learning algorithms. These first three tasks are not considered when computing the final average accuracy and maximum forgetting metrics.

3 Hindsight Anchor Learning (HAL)

The current state-of-the-art algorithms for continual learning are based on experience replay (Hayes et al., 2018; Riemer et al., 2019; Chaudhry et al., 2019b). These methods update the model while storing a small number of past observed triplets in an episodic memory $\mathcal{M}$. For a new mini-batch $\mathcal{B}_t$ of observations from task $t$, the learner samples a mini-batch $\mathcal{B}_\mathcal{M}$ from $\mathcal{M}$ at random, and employs the rule $\theta \leftarrow \theta - \alpha \nabla_\theta \, \ell(\theta, \mathcal{B}_t \cup \mathcal{B}_\mathcal{M})$ to update its parameters, where

$\ell(\theta, \mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum_{(x, y, t) \in \mathcal{B}} \ell\left(f_\theta(x, t), y\right)$    (4)

denotes the average loss across a collection of triplets $\mathcal{B}$, and $\alpha$ is the learning rate. In general, $\mathcal{B}_\mathcal{M}$ is constructed to have the same size as $\mathcal{B}_t$, but it can be smaller if the episodic memory does not yet contain enough samples.
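
For concreteness, a minimal PyTorch-style sketch of this replay update is shown below; it ignores task-conditioned output heads and assumes the memory mini-batch has already been collated, both simplifications on our part.

```python
import torch
import torch.nn.functional as F


def experience_replay_step(model, optimizer, stream_batch, memory_batch=None):
    """One SGD step on the union of the current mini-batch and a memory mini-batch."""
    x, y = stream_batch
    if memory_batch is not None:                 # the memory may still be empty early on
        xm, ym = memory_batch
        x, y = torch.cat([x, xm]), torch.cat([y, ym])
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)          # average loss over B_t ∪ B_M (Eq. 4)
    loss.backward()
    optimizer.step()
    return loss.item()
```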

In the update above, the episodic memory reminds the predictor how to perform on past tasks using only a small amount of data. As such, the behaviour of the predictor on past tasks outside the data stored in $\mathcal{M}$ is not guaranteed. Also, since $\mathcal{M}$ is usually very small, the performance of the predictor becomes sensitive to the choice of samples stored in the episodic memory. For this reason, we propose to further fix the behaviour of the predictor at a collection of carefully constructed anchor points $e_{t'}$, one per class per past task $t'$, at each parameter update.

Let us assume for now that the anchor points $e_{t'}$ are given (we will see later how to construct them in practice). To constrain the change of the predictor at these anchor points, we propose a two-step parameter update rule:

$\tilde{\theta} \leftarrow \theta - \alpha \nabla_\theta \, \ell(\theta, \mathcal{B}_t \cup \mathcal{B}_\mathcal{M}),$
$\theta \leftarrow \theta - \alpha \nabla_\theta \left[ \ell(\theta, \mathcal{B}_t \cup \mathcal{B}_\mathcal{M}) + \lambda \sum_{t' < t} \left\| f_\theta(e_{t'}, t') - f_{\tilde{\theta}}(e_{t'}, t') \right\|^2 \right].$    (5)

The first step computes a temporary parameter vector $\tilde{\theta}$ by minimizing the loss on a mini-batch $\mathcal{B}_t$ from the current task $t$ and the episodic memory $\mathcal{M}$ of past tasks (this is the usual experience replay parameter update). The second step employs nested optimization to perform the final update of the parameters $\theta$, trading off the minimization of (i) the loss value at the current mini-batch and the episodic memory, and (ii) the change in predictions at the anchor points of all past tasks, where $\lambda$ controls the strength of the anchoring term. The proposed rule not only updates the predictor conservatively, thereby reducing forgetting, but also, as shown in Appendix A, improves transfer by maximizing the inner product between the gradients on $\mathcal{B}_t \cup \mathcal{B}_\mathcal{M}$ and on the anchor points. In this respect, it bears similarity to gradient-based meta-learning approaches (Finn et al., 2017; Nichol and Schulman, 2018; Riemer et al., 2019).
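
The sketch below illustrates this two-step rule in PyTorch. It is a first-order simplification: the temporary parameters' predictions on the anchors are treated as fixed targets rather than differentiated through, task-conditioned heads are omitted, and it assumes a non-empty memory batch and at least one past task's anchors; the function and argument names are ours.

```python
import copy
import torch
import torch.nn.functional as F


def hal_step(model, lr, lam, stream_batch, memory_batch, anchors):
    """Two-step HAL update (Eq. 5), first-order sketch. `anchors` stacks the anchor
    images of all past tasks; predictions on them are anchored to those of θ̃."""
    x = torch.cat([stream_batch[0], memory_batch[0]])
    y = torch.cat([stream_batch[1], memory_batch[1]])

    # Step 1: temporary parameters θ̃ via a plain experience-replay SGD step.
    tmp = copy.deepcopy(model)
    loss = F.cross_entropy(tmp(x), y)
    grads = torch.autograd.grad(loss, list(tmp.parameters()))
    with torch.no_grad():
        for p, g in zip(tmp.parameters(), grads):
            p -= lr * g

    # Step 2: update θ on the same data while keeping its predictions on the
    # anchors close to those of θ̃ (the anchoring regularizer).
    with torch.no_grad():
        targets = tmp(anchors)
    loss = F.cross_entropy(model(x), y) + lam * F.mse_loss(model(anchors), targets)
    grads = torch.autograd.grad(loss, list(model.parameters()))
    with torch.no_grad():
        for p, g in zip(model.parameters(), grads):
            p -= lr * g
```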

Next, let us discuss how to choose the anchor points $e_t$ (one per class per task) so as to preserve the performance on the current task throughout the entire learning experience. Ideally, the anchor points should attempt to minimize the forgetting of the current task as the learner is updated on future tasks. One could achieve this by letting $e_t$ be an example from task $t$ that would undergo maximum forgetting during the entire continual learning experience, including past and future tasks. Then, requiring the predictions to remain invariant at such an $e_t$, by using Eq. 5, could effectively reduce forgetting on the current task. Mathematically, the desirable $e_t$ for a label $y$ is obtained by maximizing the following forgetting loss:

$e_t = \arg\max_{x \sim P_t} \; \ell\left(f_{\theta_T}(x, t), y\right) - \ell\left(f_{\theta_t}(x, t), y\right),$    (6)

where $\theta_t$ is the parameter vector after training on task $t$ and $\theta_T$ is the final parameter vector after the entire learning experience. Thus, keeping predictions intact on the pair $(e_t, y)$ above can effectively preserve the performance on task $t$. However, the idealistic Eq. 6 requires (i) access to the entire distribution $P_t$ to compute the maximization, and (ii) access to all future distributions to compute the final parameter vector $\theta_T$. Both are unrealistic assumptions under the continual learning setup described in Section 2, as the former requires storing the entire dataset of task $t$, and the latter needs access to future tasks.

To circumvent (i), we can recast Eq. 6 as an optimization problem and learn the desired $e_t$ by initializing it at random and applying gradient ascent updates, for a given label $y$, in the image space. The proposed optimization objective is given by:

$\max_{e_t} \; \ell\left(f_{\theta_T}(e_t, t), y\right) - \ell\left(f_{\theta_t}(e_t, t), y\right) - \gamma \left\| \phi_{\theta_t}(e_t) - \bar{\phi}_t \right\|^2,$    (7)

where the regularizer, given by a mean embedding loss with strength $\gamma$, constrains the search space by pushing the anchor point embedding towards the mean data embedding of the task. We recall that $\phi_\theta$ denotes the feature extractor of the predictor, and $\bar{\phi}_t$ is the neural mean embedding (Smola et al., 2007) of all observed examples from task $t$. Since the feature extractor is updated after experiencing each data point, the mean embedding is computed as a running average. That is, after observing a mini-batch $\mathcal{B}_t$ of task $t$, we update:

$\bar{\phi}_t \leftarrow \beta \, \bar{\phi}_t + (1 - \beta) \, \frac{1}{|\mathcal{B}_t|} \sum_{(x, y) \in \mathcal{B}_t} \phi_\theta(x),$    (8)

where $\beta$ is a decay rate and $\bar{\phi}_t$ is initialized to zero at the beginning of the learning experience. In our experiments, we learn one anchor $e_t$ per class for each task. We fix the label $y$ to the corresponding class label, and discard $\bar{\phi}_t$ after training on task $t$. Learning $e_t$ in this manner circumvents the requirement of storing the entire distribution $P_t$ for the current task $t$.
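
A one-line sketch of this running average is given below; the exact interpolation with the decay rate $\beta$ is our reading of the update (Appendix C lists $\beta$ as a decay rate), and the inputs are assumed to be PyTorch tensors.

```python
def update_mean_embedding(phi_bar, batch_features, beta=0.5):
    """Running average of the task's mean feature embedding (Eq. 8).
    phi_bar starts at zero; batch_features holds φ_θ(x) for the current mini-batch."""
    return beta * phi_bar + (1.0 - beta) * batch_features.mean(dim=0)
```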

Still, Eq. 7 requires the parameter vector $\theta_T$, to be obtained in the distant future after all learning tasks have been experienced. To waive this impossible requirement, we propose to approximate the future by simulating the past. That is, instead of measuring the forgetting that would happen after the model is trained on future tasks, we measure the forgetting that happens when the model is fine-tuned on past tasks. In this way, we say that forgetting is estimated in hindsight, using past experiences. More concretely, after training on task $t$ and obtaining the parameter vector $\theta_t$, we minimize the loss during one epoch on the episodic memory $\mathcal{M}$ to obtain a temporary parameter vector $\theta_\mathcal{M}$ that approximates $\theta_T$, and update $e_t$ as:

$e_t \leftarrow e_t + \alpha \nabla_{e_t} \left[ \ell\left(f_{\theta_\mathcal{M}}(e_t, t), y\right) - \ell\left(f_{\theta_t}(e_t, t), y\right) - \gamma \left\| \phi_{\theta_t}(e_t) - \bar{\phi}_t \right\|^2 \right].$    (9)
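
The sketch below puts Eq. 9 together: a copy of the model is fine-tuned for one epoch on the episodic memory to stand in for the future parameters, and an anchor is then learned by gradient ascent in image space. The interface (e.g. a `model.features` attribute exposing the feature extractor) and the default hyper-parameter values are our own assumptions.

```python
import copy
import torch
import torch.nn.functional as F


def learn_anchor(model, memory_loader, phi_bar, label, image_shape,
                 lr=0.1, gamma=0.1, n_steps=100):
    """Learn one anchor for the just-finished task in hindsight (Eq. 9)."""
    # Approximate the future parameters θ_T by fine-tuning a copy on the memory (θ_M).
    tuned = copy.deepcopy(model)
    opt = torch.optim.SGD(tuned.parameters(), lr=lr)
    for x, y in memory_loader:                      # one epoch over the episodic memory
        opt.zero_grad()
        F.cross_entropy(tuned(x), y).backward()
        opt.step()

    # Gradient ascent in image space: maximize the hindsight forgetting loss while
    # staying close to the task's mean embedding.
    e = torch.randn(1, *image_shape, requires_grad=True)
    y = torch.tensor([label])
    for _ in range(n_steps):
        forgetting = F.cross_entropy(tuned(e), y) - F.cross_entropy(model(e), y)
        embedding_reg = gamma * (model.features(e) - phi_bar).pow(2).sum()
        grad, = torch.autograd.grad(forgetting - embedding_reg, e)
        with torch.no_grad():
            e += lr * grad                          # ascent step on the anchor image
    return e.detach()
```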

This completes the description of our proposed algorithm for continual learning, which combines experience replay with anchors learned in hindsight. We call our approach Hindsight Anchor Learning (HAL) and summarize the entire learning process as follows:

Hindsight Anchor Learning (HAL)

  • Initialize the parameters $\theta$ and the anchors $e$ from normal distributions.

  • Initialize $\mathcal{M} \leftarrow \{\}$

  • For each task $t = 1, \dots, T$:

    • Initialize $\bar{\phi}_t \leftarrow 0$

    • For each minibatch $\mathcal{B}_t$ from task $t$:

      • Sample $\mathcal{B}_\mathcal{M}$ from $\mathcal{M}$

      • Update $\theta$ using Eq. 5

      • Update $\bar{\phi}_t$ using Eq. 8

      • Update $\mathcal{M}$ by adding $\mathcal{B}_t$ in a first-in first-out (FIFO) ring buffer

    • Fine-tune $\theta$ on $\mathcal{M}$ to obtain $\theta_\mathcal{M}$

    • Build $e_t$ by applying Eq. 9 $K$ times

    • Discard $\theta_\mathcal{M}$

  • Return $\theta$.

4 Experiments

We now evaluate the performance of HAL against a variety of baselines on commonly used supervised continual learning benchmarks.

4.1 Datasets and tasks

We perform experiments on four supervised classification benchmarks for continual learning.

  • Permuted MNIST is a variant of the MNIST dataset of handwritten digits (LeCun, 1998) where each task applies a fixed random pixel permutation to the original dataset. This benchmark contains tasks, each with samples from different classes.

  • Rotated MNIST is another variant of MNIST, where each task applies a fixed random image rotation (between and degrees) to the original dataset. This benchmark contains tasks, each with samples from different classes.

  • Split CIFAR is a variant of the CIFAR-100 dataset (Krizhevsky and Hinton, 2009; Zenke et al., 2017), where each task contains the data pertaining to a set of classes sampled without replacement from the total number of classes. This benchmark contains tasks, each with samples per each of the classes.

  • Split miniImageNet is a variant of the ImageNet dataset (Russakovsky et al., 2015; Vinyals et al., 2016), containing a subset of images and classes from the original dataset. This benchmark contains tasks, each with samples per each of the classes.

For all datasets, the first three tasks are used for hyper-parameter optimization (grids available in Appendix C). The learners can perform multiple epochs on these three initial tasks, which are later discarded for evaluation.

4.2 Baselines

Method Permuted MNIST Rotated MNIST
Accuracy Forgetting Accuracy Forgetting
Finetune 53.5 (±1.46) 0.29 (±0.01) 41.9 (±1.37) 0.50 (±0.01)
EWC (Kirkpatrick et al., 2016) 63.1 (±1.40) 0.18 (±0.01) 44.1 (±0.99) 0.47 (±0.01)
VCL (Nguyen et al., 2018) 51.8 (±1.54) 0.44 (±0.01) 48.2 (±0.99) 0.50 (±0.01)
VCL-Random (Nguyen et al., 2018) 52.3 (±0.66) 0.43 (±0.01) 54.4 (±1.44) 0.44 (±0.01)
AGEM (Chaudhry et al., 2019a) 62.1 (±1.39) 0.21 (±0.01) 50.9 (±0.92) 0.40 (±0.01)
MER (Riemer et al., 2019) 69.9 (±0.40) 0.14 (±0.01) 66.0 (±2.04) 0.23 (±0.01)
ER-Ring (Chaudhry et al., 2019b) 70.2 (±0.56) 0.12 (±0.01) 65.9 (±0.41) 0.24 (±0.01)
HAL (ours) 73.6 (±0.31) 0.09 (±0.01) 68.4 (±0.72) 0.21 (±0.01)
Clone-and-finetune 81.4 (±0.35) 0.0 87.5 (±0.11) 0.0
Multitask 83.0 0.0 83.3 0.0
Method Split CIFAR Split miniImageNet
Accuracy Forgetting Accuracy Forgetting
Finetune 42.9 (±2.07) 0.25 (±0.03) 34.7 (±2.69) 0.26 (±0.03)
EWC (Kirkpatrick et al., 2016) 42.4 (±3.02) 0.26 (±0.02) 37.7 (±3.29) 0.21 (±0.03)
ICARL (Rebuffi et al., 2017) 46.4 (±1.21) 0.16 (±0.01) - -
AGEM (Chaudhry et al., 2019a) 54.9 (±2.92) 0.14 (±0.03) 48.2 (±2.49) 0.13 (±0.02)
MER (Riemer et al., 2019) 49.7 (±2.97) 0.19 (±0.03) 45.5 (±1.49) 0.15 (±0.01)
ER-Ring (Chaudhry et al., 2019b) 56.2 (±1.93) 0.13 (±0.01) 49.0 (±2.61) 0.12 (±0.02)
HAL (ours) 60.4 (±0.54) 0.10 (±0.01) 51.6 (±2.02) 0.10 (±0.01)
Clone-and-finetune 60.3 (±0.55) 0.0 50.3 (±1.00) 0.0
Multitask 68.3 0.0 63.5 0.0
Table 1: Accuracy (Eq. 2) and Forgetting (Eq. 3) results of continual learning experiments. Averages and standard deviations are computed over five runs using different random seeds. When used, episodic memories contain up to one example per class per task. Last two rows are oracle baselines.

We compare our proposed model HAL to the following baselines.

  • Finetune is a single model trained on the stream of data, without any regularization or episodic memory.

  • ICARL (Rebuffi et al., 2017) uses nearest-mean-of-exemplars rule for classification and avoids catastrophic forgetting by regularizing over the feature representations of previous tasks using knowledge distillation loss (Hinton et al., 2014).

  • EWC (Kirkpatrick et al., 2016) is a continual learning method that limits changes to parameters critical to past tasks, as measured by the Fisher information matrix.

  • VCL (Nguyen et al., 2018) is a continual learning method that uses online variational inference for approximating the posterior distribution which is then used to regularize the model.

  • AGEM (Chaudhry et al., 2019a) is a continual learning method improving on (Lopez-Paz and Ranzato, 2017), which uses an episodic memory of parameter gradients to limit forgetting.

  • MER (Riemer et al., 2019) is a continual learning method that combines episodic memories with meta-learning to limit forgetting.

  • ER-Ring (Chaudhry et al., 2019b) is a continual learning method that uses a ring buffer as episodic memory.

  • Multitask is an oracle baseline that has access to all data to optimize Eq. 1, useful to estimate an upper bound on the obtainable Accuracy (Eq. 2).

  • Clone-and-finetune is an oracle baseline training one independent model per task, where the model for task $t$ is initialized by cloning the parameters of the model for task $t-1$.

All baselines use the same neural network architectures: a perceptron with two hidden layers of 256 ReLU neurons in the MNIST experiments, and a ResNet18 with three times fewer feature maps across all layers, similar to Lopez-Paz and Ranzato (2017), in the CIFAR and ImageNet experiments. The task identifiers are used to select the output head in the CIFAR and ImageNet experiments, and are ignored in the MNIST experiments. The batch size for both the stream of data and the episodic memories is fixed across experiments and models. The size of the episodic memories varies per experiment, measured in examples per class per task. The results of VCL are compiled by running the official implementation1, which only works for fully-connected networks, in our continual learning setup. All other baselines use our unified code base, which is available at https://bit.ly/2mw8bsE.

4.3 Results


Figure 1: Evolution of Accuracy (Eq. 2) on Permuted MNIST and Split CIFAR as new tasks are learned. When used, episodic memories contain up to one example per class per task.
Method Permuted MNIST Rotated MNIST
VCL-Random 55.8 (±1.29) 58.5 (±1.21) 61.2 (±0.12) 64.4 (±0.16)
AGEM 63.2 (±1.47) 64.1 (±0.74) 49.9 (±1.49) 53.0 (±1.52)
MER 74.9 (±0.49) 78.3 (±0.19) 76.5 (±0.30) 77.3 (±1.13)
ER-Ring 73.5 (±0.43) 75.8 (±0.24) 74.7 (±0.56) 76.5 (±0.48)
HAL (ours) 76.2 (±0.52) 78.4 (±0.27) 77.0 (±0.66) 78.7 (±0.97)
Method Split CIFAR Split miniImageNet
ICARL 51.7 (±1.41) 51.2 (±1.32) - -
AGEM 56.9 (±3.45) 59.9 (±2.64) 51.6 (±2.69) 54.3 (±1.56)
MER 57.7 (±2.59) 60.6 (±2.09) 49.4 (±3.43) 54.8 (±1.79)
ER-Ring 60.9 (±1.44) 62.6 (±1.77) 53.5 (±1.42) 54.2 (±3.23)
HAL (ours) 62.9 (±1.49) 64.4 (±2.15) 56.5 (±0.87) 57.2 (±1.54)
Table 2: Accuracy (Eq. 2) results for larger episodic memory sizes (several examples per class per task). Here we only compare methods that use an episodic memory; the two columns per dataset correspond to the two memory sizes considered. Averages and standard deviations are computed over five runs using different random seeds.

Table 1 summarizes the main results of our experiments when an episodic memory of only one example per class per task is used. First, our proposed HAL achieves the highest Accuracy (Eq. 2) and the lowest Forgetting (Eq. 3) on all benchmarks. This comparison excludes the oracle baselines Multitask (which has access to all data simultaneously) and Clone-and-finetune (which trains a separate model per task). Second, the relative gains from the second-best method, ER-Ring, to HAL are significant, confirming that the anchoring objective (Eq. 5) allows experience-replay methods to generalize better from the same amount of episodic memory.

Third, regularization-based approaches, such as EWC (Kirkpatrick et al., 2016) and VCL (Nguyen et al., 2018), suffer under the single-epoch setup. As noted by Chaudhry et al. (2019a), EWC requires multiple passes over the samples of each task to perform well. The poor performance of VCL is attributed to noisy posterior estimation in the single-pass setup. Note that approaches making use of an episodic memory (MER, ER, and HAL) work significantly better in this setup.

Fourth, ICARL (Rebuffi et al., 2017), another method making use of an episodic memory, performs poorly in our setup. From Table 1, it can be argued that direct training on a very small episodic memory, as done in experience replay, allows a method to generalize better than when the same memory is used indirectly through the knowledge distillation loss (Hinton et al., 2014), as done in ICARL.

Figure 1 shows a finer-grained analysis of average accuracy as new tasks are learned on Permuted MNIST and Split CIFAR. HAL preserves the performance of the predictor more effectively than the other baselines.

Table 2 shows the Accuracy of methods employing an episodic memory as the memory size is increased to several examples per class per task, which proportionally increases the total memory size for the MNIST, CIFAR, and ImageNet experiments. The corresponding Forgetting numbers are given in Appendix B. HAL consistently improves over ER-Ring and the other baselines.

Figure 2 reports the training time of the continual learning baselines on the MNIST benchmarks. Although HAL adds an overhead on top of the experience replay baseline, it is significantly faster than MER, another approach that uses nested optimization to reduce forgetting. However, HAL requires extra memory to store the task anchors which, as we show next, are more effective than the additional data samples one could store for experience replay instead. Overall, we conclude that HAL provides the best trade-off between efficiency and performance.

Ablation Study

We now turn our attention to two questions: (i) whether, for the same episodic memory size in bytes, HAL improves over the experience replay baseline, and (ii) whether fine-tuning on the replay buffer is a good approximation of the forgetting incurred when the learner is updated on future tasks.

To answer the first question, let $|\mathcal{M}|$ be the total size of the episodic memory across all tasks when one example per class per task is stored in the replay buffer. We then run experience replay with double the size of the episodic memory, i.e., storing two examples per class per task instead of one. The episodic memory size in HAL, on the other hand, is kept at $|\mathcal{M}|$. This makes the memory footprint in bytes of experience replay and HAL equal, as the latter requires extra memory to store the anchors. Table 3 summarizes the results of this study. For the same memory size in bytes, HAL performs better than experience replay with additional real data samples stored in the episodic memory. It is surprising that the anchors learned by HAL, initialized from random noise and learned using gradient-based optimization, perform better than randomly sampled real data. To understand this, Figure 3 visualizes HAL’s anchors along with the task data in image and feature space on the Permuted MNIST benchmark. The left panel shows that HAL anchors lie within the data cluster of a class in the image space, suggesting that the mean embedding loss in Eq. 9 effectively regularizes against outliers. More interestingly, the right panel shows that these anchors lie at or close to the cluster edges in the feature space. In other words, the anchor points learned by HAL lie close to the classifier decision boundary. This can explain their effectiveness compared to real data samples, which can lie anywhere in the data cluster in feature space.

Figure 2: Training time (s) of MNIST experiments for the entire continual learning experience. MER and HAL both use meta-learning objectives to reduce forgetting.

Figure 3: t-SNE visualization of images and anchors (HAL) in the image space (left) and the feature space (right) on the Permuted MNIST benchmark for a single task. Anchor points are exaggerated in size for better visualization. The left plot shows that anchor points lie within the data cluster of a class, whereas the right plot shows that, in the feature space, anchor points lie close to the edge of the class cluster, near the decision boundaries.
Method Permuted MNIST Split CIFAR
Accuracy Forgetting Accuracy Forgetting
ER-Ring (1 example/class/task) 70.2 (±0.56) 0.12 (±0.01) 56.2 (±1.93) 0.13 (±0.01)
ER-Ring (2 examples/class/task) 71.9 (±0.31) 0.11 (±0.01) 58.6 (±2.68) 0.12 (±0.01)
HAL (1 example/class/task) 73.6 (±0.31) 0.09 (±0.01) 60.4 (±0.54) 0.10 (±0.01)
Table 3: Comparison of HAL with experience replay under the same memory footprint in bytes. The first ER-Ring row and HAL store one example per class per task in the episodic memory, whereas the second ER-Ring row stores two examples per class per task. Averages and standard deviations are computed over five runs using different random seeds.
Anchor type Permuted MNIST Split CIFAR
Accuracy Forgetting Accuracy Forgetting
HAL 73.6 (±0.31) 0.09 (±0.01) 60.4 (±0.54) 0.10 (±0.01)
Oracle 73.9 (±0.41) 0.09 (±0.01) 61.1 (±0.94) 0.09 (±0.01)
Table 4: Performance comparison of HAL with an Oracle where the learner has access to all future tasks to exactly quantify the forgetting of an anchor. Averages and standard deviations are computed over five runs using different random seeds.

Finally, to answer the second question, we assume a non-continual setup where at each step the learner has oracle access to all future tasks. After training on task $t$, the learner is fine-tuned on all future tasks and the anchor points are subsequently learned by optimizing the idealized objective of Eq. 7. The results are reported in Table 4. The proposed HAL performs very close to this non-continual oracle baseline. This suggests that HAL’s approximation of future forgetting, obtained by replaying past data, is effective on many existing continual learning benchmarks.

5 Related work

In continual learning (Ring, 1997), also called lifelong learning (Thrun, 1998), a learner addresses a sequence of changing tasks without storing the complete datasets of these tasks. This is in contrast to multitask learning (Caruana, 1997), where the learner assumes simultaneous access to data from all tasks. The main challenge in continual learning is to avoid catastrophic interference (McCloskey and Cohen, 1989; McClelland et al., 1995; Goodfellow et al., 2013), that is, the learner forgetting previously acquired knowledge when learning new tasks. The state-of-the-art methods in continual learning can be categorized into three classes.

First, regularization approaches discourage updating parameters important for past tasks (Kirkpatrick et al., 2016; Aljundi et al., 2018; Nguyen et al., 2018; Zenke et al., 2017). While efficient in terms of memory and computation, these approaches suffer from brittleness due to feature drift as the number of tasks increases (Titsias et al., 2019). Additionally, these approaches are only effective when we can perform multiple passes over each dataset (Chaudhry et al., 2019a), a case deemed unrealistic in this work.

Second, modular approaches use different parts of the prediction function for each new task (Fernando et al., 2017; Aljundi et al., 2017; Rosenbaum et al., 2018; Chang et al., 2018; Xu and Zhu, 2018; Ferran Alet, 2018). These approaches do not scale to a large number of tasks, as they require searching over a combinatorial space of module architectures. Another family of modular approaches (Rusu et al., 2016; Lee et al., 2017) adds new parts to the prediction function as new tasks are learned. By construction, modular approaches have zero forgetting, but their memory requirements increase with the number of tasks.

Third, episodic memory approaches maintain and revisit a small episodic memory of data from past tasks. In some of these methods (Li and Hoiem, 2016; Rebuffi et al., 2017), examples in the episodic memory are replayed and predictions are kept invariant by means of distillation (Hinton et al., 2014). In other approaches (Lopez-Paz and Ranzato, 2017; Chaudhry et al., 2019a; Aljundi et al., 2019b) the episodic memory is used as an optimization constraint that discourages increases in loss at past tasks. More recently, several works (Hayes et al., 2018; Riemer et al., 2019; Rolnick et al., 2018; Chaudhry et al., 2019b) have shown that directly optimizing the loss on the episodic memory, also known as experience replay, is cheaper than constraint-based approaches and improves prediction performance. Our contribution in this paper has been to improve experience replay methods with task anchors learned in hindsight.

There are other definitions of continual learning, such as the one of task-free continual learning. The task-free formulation does not consider the notion of tasks, and instead works on undivided data streams (Aljundi et al., 2019a, b). We have focused on the task-based definition of continual learning and, similar to many recent works (Lopez-Paz and Ranzato, 2017; Hayes et al., 2018; Riemer et al., 2019; Chaudhry et al., 2019a), assumed that only a single pass through the data was possible.

Finally, our gradient-based learning of anchors bears similarity to Simonyan et al. (2014) and Wang et al. (2018). In Simonyan et al. (2014), the authors use gradient ascent on class scores to find saliency maps of a classification model. In contrast, our proposed hindsight learning objective optimizes for the forgetting metric, as reducing it is necessary for continual learning. Dataset distillation (Wang et al., 2018) proposes to encode the entire dataset in a few synthetic points at a given parameter vector by a gradient-based optimization process. Their method requires access to the entire dataset of a task for optimization purposes. We, instead, learn anchors in hindsight from the replay buffer of past tasks after training on the current task is finished. While Wang et al. (2018) aim to replicate the performance of the entire dataset from the synthetic points, we focus on reducing forgetting of an already learned task.

6 Conclusion

We introduced a bilevel optimization objective, dubbed anchoring, for continual learning. In our approach, we learn one “anchor point” per class per task, where predictions are required to remain invariant by means of nested optimization. These anchors are learned using gradient-based optimization, and represent points that would maximize the forgetting of the current task throughout the entire learning experience. We simulate, in hindsight, the forgetting that would happen during the learning of future tasks, by taking temporary gradient steps on a small episodic memory of past tasks. We call our approach Hindsight Anchor Learning (HAL). As shown in our experiments, anchoring in hindsight complements and improves the performance of continual learning methods based on experience replay, achieving a new state of the art on four standard continual learning benchmarks.

Acknowledgement

The authors would like to thank Marc’Aurelio Ranzato for helpful discussions. This work was supported by the ERC grant ERC-2012-AdG 321162-HELIOS, EPSRC grant Seebibyte EP/M013774/1 and EPSRC/MURI grant EP/N019474/1. We would also like to acknowledge the Royal Academy of Engineering and FiveAI. Arslan is funded by an Amazon Research Award.

Appendix

Section A describes the approximate update performed by anchoring objective (Eq. 5 in the main paper). Section B reports the Forgetting metric (Eq. 3) for bigger episodic memories. Section C provides the grid considered for hyper-parameters. Section D gives pseudo-code for HAL.

Appendix A Approximate Update by Anchoring Objective

Here we use a Taylor series expansion to approximate the update performed by the anchoring objective (Eq. 5 in the main paper). In particular, we are interested in the regularization part of the anchoring objective, which involves the nested update; we refer to its gradient as $g_{HAL}$. We follow arguments similar to those of Nichol and Schulman (2018).

Let $\theta$ be the parameter vector before the temporary update in the anchoring objective (Eq. 5). Also, let $\ell_{CE}$ and $\ell_{L2}$ be the cross-entropy loss (on the current mini-batch and the episodic memory) and the L2 loss (on the anchors), respectively, with gradients $g_{CE}, g_{L2}$ and Hessians $H_{CE}, H_{L2}$ evaluated at $\theta$.

Let $U(\theta) = \theta - \alpha \, g_{CE}$ be the operator giving the temporary update in the two-step process of Eq. 5, and let $\tilde{\theta} = U(\theta)$ be the temporary update itself (the same quantity denoted $\tilde{\theta}$ in the main paper). The gradient $g_{HAL}$ of the anchoring regularizer with respect to $\theta$ is given by:

(10)

where the second step is obtained by using the chain rule. Now, computing the first-order Taylor series approximation of the gradient of $\ell_{L2}$ at $\tilde{\theta}$ around $\theta$,

(11)

where in the second step we substituted the value of $\tilde{\theta} - \theta = -\alpha \, g_{CE}$. Substituting Eq. 11 into Eq. 10 and simplifying, we get:

(12)

This form is very similar to the second-order MAML gradient formulation, Eq. 25 in Nichol and Schulman (2018). Further simplification of the inner product terms between the Hessians and gradients yields an inner product between the gradients $g_{CE}$ and $g_{L2}$. This shows that, similar to MAML (Finn et al., 2017), Reptile (Nichol and Schulman, 2018), and MER (Riemer et al., 2019), the anchoring objective described in Eq. 5 of the main paper maximizes the inner product between gradients. However, unlike the other meta-learning approaches, in the anchoring objective these gradients correspond to different loss functions: the cross-entropy loss on data from the current task and the episodic memory, and the L2 loss on the HAL anchors, respectively.
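
For completeness, the identity behind this last simplification step, also used in Reptile-style analyses, can be written as follows; $g_{CE}, g_{L2}$ and $H_{CE}, H_{L2}$ are the gradients and Hessians defined above, in our notation.

```latex
% Hessian-gradient cross terms are the gradient of a gradient inner product
% (by the product rule, using the symmetry of the Hessians):
H_{CE}\, g_{L2} + H_{L2}\, g_{CE} \;=\; \nabla_\theta \big( g_{CE}^\top g_{L2} \big),
\qquad g_{\cdot} = \nabla_\theta \ell_{\cdot}(\theta), \quad H_{\cdot} = \nabla^2_\theta \ell_{\cdot}(\theta).
```

Descending along the negative of such terms therefore ascends the inner product between the two gradients, which is the transfer-maximizing behaviour discussed above.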

Appendix B More Results

Table 5 shows the Forgetting of methods employing an episodic memory as the memory size is increased to several examples per class per task, which proportionally increases the total memory size for the MNIST, CIFAR, and ImageNet experiments.

Method Permuted MNIST Rotated MNIST
VCL-Random 0.39 (±0.01) 0.36 (±0.01) 0.37 (±0.01) 0.33 (±0.01)
AGEM 0.20 (±0.01) 0.19 (±0.01) 0.41 (±0.01) 0.38 (±0.01)
MER 0.14 (±0.01) 0.09 (±0.01) 0.12 (±0.01) 0.11 (±0.01)
ER-Ring 0.09 (±0.01) 0.07 (±0.01) 0.15 (±0.01) 0.13 (±0.01)
HAL (ours) 0.07 (±0.01) 0.05 (±0.01) 0.12 (±0.01) 0.11 (±0.01)
Method Split CIFAR Split miniImageNet
ICARL 0.13 (±0.02) 0.13 (±0.02) - -
AGEM 0.13 (±0.03) 0.10 (±0.02) 0.10 (±0.02) 0.08 (±0.01)
MER 0.11 (±0.01) 0.09 (±0.02) 0.12 (±0.02) 0.07 (±0.01)
ER-Ring 0.09 (±0.01) 0.06 (±0.01) 0.07 (±0.02) 0.08 (±0.02)
HAL (ours) 0.08 (±0.01) 0.06 (±0.01) 0.06 (±0.01) 0.06 (±0.01)
Table 5: Forgetting (Eq. 3) results for larger episodic memory sizes (several examples per class per task). Here we only compare methods that use an episodic memory; the two columns per dataset correspond to the two memory sizes considered. Averages and standard deviations are computed over five runs using different random seeds.

Appendix C Hyper-parameter Selection

In this section, we report the hyper-parameter grids considered for the experiments. The best values for different benchmarks are given in parentheses.

  • Multitask

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

  • Clone-and-finetune

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

  • Finetune

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

  • EWC

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

    • regularization: [0.1, 1, 10 (MNIST perm, rot, CIFAR, miniImageNet), 100, 1000]

  • AGEM

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

  • MER

    • learning rate: [0.003, 0.01, 0.03 (MNIST, CIFAR, miniImageNet), 0.1, 0.3, 1.0]

    • within batch meta-learning rate: [0.01, 0.03, 0.1 (MNIST, CIFAR, miniImageNet), 0.3, 1.0]

    • current batch learning rate multiplier: [1, 2, 5 (CIFAR, miniImageNet), 10 (MNIST)]

  • ER-Ring

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

  • HAL

    • learning rate: [0.003, 0.01, 0.03 (CIFAR, miniImageNet), 0.1 (MNIST perm, rot), 0.3, 1.0]

    • regularization ($\lambda$): [0.01, 0.03, 0.1 (MNIST perm, rot), 0.3 (miniImageNet), 1 (CIFAR), 3, 10]

    • mean embedding strength ($\gamma$): [0.01, 0.03, 0.1 (MNIST perm, rot, CIFAR, miniImageNet), 0.3, 1, 3, 10]

    • decay rate ($\beta$): 0.5

    • gradient steps on anchors ($K$): 100

Appendix D HAL Algorithm

Algorithm 1 provides pseudo-code for HAL.

1:procedure HAL($\mathcal{D}$, mem_sz, $\alpha$, $\lambda$, $\beta$, $\gamma$, $K$)
2:      $\mathcal{M} \leftarrow \{\}$
3:      Initialize $\theta$ at random
4:      for $t = 1, \dots, T$ do
5:            $\bar{\phi}_t \leftarrow 0$
6:            for $\mathcal{B}_t \sim \mathcal{D}_t$ do Sample a batch from current task
7:                  $\mathcal{B}_\mathcal{M} \sim \mathcal{M}$ Sample a batch from episodic memory
8:                  $\tilde{\theta} \leftarrow \theta - \alpha \nabla_\theta \ell(\theta, \mathcal{B}_t \cup \mathcal{B}_\mathcal{M})$ Temporary parameter update
9:                  $\theta \leftarrow \theta - \alpha \nabla_\theta [\ell(\theta, \mathcal{B}_t \cup \mathcal{B}_\mathcal{M}) + \lambda \sum_{t' < t} \| f_\theta(e_{t'}, t') - f_{\tilde{\theta}}(e_{t'}, t') \|^2]$ Anchoring objective (Eq. 5)
10:                  $\bar{\phi}_t \leftarrow \beta \bar{\phi}_t + (1-\beta) \frac{1}{|\mathcal{B}_t|} \sum_{x \in \mathcal{B}_t} \phi_\theta(x)$ Running average of mean embedding
11:                  $\mathcal{M} \leftarrow \mathcal{M} \cup \mathcal{B}_t$ (FIFO, capacity mem_sz) Add samples to a ring buffer
12:            end for
13:            $e_t \leftarrow$ GetAnchors($\theta$, $\mathcal{M}$, $\bar{\phi}_t$, $\alpha$, $\gamma$, $K$) Get anchors for current task
14:      end for
15:      return $\theta$
16:end procedure

1:procedure GetAnchors($\theta$, $\mathcal{M}$, $\bar{\phi}_t$, $\alpha$, $\gamma$, $K$)
2:      $\theta' \leftarrow \theta$
3:      for $\mathcal{B}_\mathcal{M} \sim \mathcal{M}$ (one epoch) do
4:            $\theta' \leftarrow \theta' - \alpha \nabla_{\theta'} \ell(\theta', \mathcal{B}_\mathcal{M})$ Finetune by taking SGD steps on the episodic memory
5:      end for
6:      $\theta_\mathcal{M} \leftarrow \theta'$ Store the updated parameter
7:      $e_t \sim \mathcal{N}(0, I)$ Initialize the task anchors
8:      for $k = 1, \dots, K$ do
9:            $e_t \leftarrow e_t + \alpha \nabla_{e_t} [\ell(f_{\theta_\mathcal{M}}(e_t, t), y) - \ell(f_{\theta}(e_t, t), y) - \gamma \| \phi_{\theta}(e_t) - \bar{\phi}_t \|^2]$ Maximize forgetting (Eq. 9)
10:      end for
11:      return $e_t$
12:end procedure
Algorithm 1 Training of HAL on sequential data $\mathcal{D} = (\mathcal{D}_1, \dots, \mathcal{D}_T)$, with total replay buffer size ’mem_sz’, learning rate ’$\alpha$’, regularization strength ’$\lambda$’, mean embedding decay ’$\beta$’, mean embedding strength ’$\gamma$’, and number of gradient steps on anchors ’$K$’.

Footnotes

  1. https://github.com/nvcuong/variational-continual-learning

References

  1. Memory aware synapses: learning what (not) to forget. In ECCV.
  2. Expert gate: lifelong learning with a network of experts. In CVPR, pp. 7120–7129.
  3. Task-free continual learning. In CVPR, pp. 11254–11263.
  4. Online continual learning with no task boundaries. arXiv preprint arXiv:1903.08671.
  5. Multitask learning. Machine Learning 28 (1), pp. 41–75.
  6. Automatically composing representation transformations as a means for generalization. In ICML workshop Neural Abstract Machines and Program Induction v2.
  7. Riemannian walk for incremental learning: understanding forgetting and intransigence. In ECCV.
  8. Efficient lifelong learning with A-GEM. In ICLR.
  9. Continual learning with tiny episodic memories. arXiv preprint arXiv:1902.10486.
  10. PathNet: evolution channels gradient descent in super neural networks. arXiv preprint arXiv:1701.08734.
  11. Modular meta-learning. arXiv preprint arXiv:1806.10166v1.
  12. Model-agnostic meta-learning for fast adaptation of deep networks. In ICML, pp. 1126–1135.
  13. An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211.
  14. Memory efficient experience replay for streaming learning. arXiv preprint arXiv:1809.05922.
  15. Distilling the knowledge in a neural network. In NIPS.
  16. Selective experience replay for lifelong learning. arXiv preprint arXiv:1802.10269.
  17. Overcoming catastrophic forgetting in neural networks. PNAS.
  18. Learning multiple layers of features from tiny images. https://www.cs.toronto.edu/~kriz/cifar.html.
  19. The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/.
  20. Lifelong learning with dynamically expandable networks. arXiv preprint arXiv:1708.01547.
  21. Learning without forgetting. In ECCV, pp. 614–629.
  22. Gradient episodic memory for continual learning. In NIPS.
  23. Why there are complementary learning systems in the hippocampus and neocortex: insights from the successes and failures of connectionist models of learning and memory. Psychological Review 102 (3), pp. 419.
  24. Catastrophic interference in connectionist networks: the sequential learning problem. Psychology of Learning and Motivation 24, pp. 109–165.
  25. Variational continual learning. In ICLR.
  26. Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999.
  27. iCaRL: incremental classifier and representation learning. In CVPR.
  28. Learning to learn without forgetting by maximizing transfer and minimizing interference. In ICLR.
  29. CHILD: a first step towards continual learning. Machine Learning 28 (1), pp. 77–104.
  30. Experience replay for continual learning. arXiv preprint arXiv:1811.11682.
  31. Routing networks: adaptive selection of non-linear functions for multi-task learning. In ICLR.
  32. ImageNet large scale visual recognition challenge. IJCV 115 (3), pp. 211–252.
  33. Progressive neural networks. arXiv preprint arXiv:1606.04671.
  34. Deep inside convolutional networks: visualising image classification models and saliency maps. In ICLR.
  35. A Hilbert space embedding for distributions. In ALT, pp. 13–31.
  36. Lifelong learning algorithms. In Learning to Learn, pp. 181–209.
  37. Functional regularisation for continual learning using Gaussian processes. arXiv preprint arXiv:1901.11356.
  38. Matching networks for one shot learning. In NIPS, pp. 3630–3638.
  39. Dataset distillation. arXiv preprint arXiv:1811.10959.
  40. Reinforced continual learning. arXiv preprint arXiv:1805.12369v1.
  41. Continual learning through synaptic intelligence. In ICML.