Meta-learning with differentiable closed-form solvers
Adapting deep networks to new concepts from few examples is extremely challenging, due to the high computational and data requirements of standard fine-tuning procedures. Most works on meta-learning and few-shot learning have thus focused on simple learning techniques for adaptation, such as nearest neighbors or gradient descent. Nonetheless, the machine learning literature contains a wealth of methods that learn non-deep models very efficiently. In this work we propose to use these fast convergent methods as the main adaptation mechanism for few-shot learning. The main idea is to teach a deep network to use standard machine learning tools, such as logistic regression, as part of its own internal model, enabling it to quickly adapt to novel tasks. This requires back-propagating errors through the solver steps. While normally the matrix operations involved would be costly, the small number of examples works to our advantage, by making use of the Woodbury identity. We propose both iterative and closed-form solvers, based on logistic regression and ridge regression components. Our methods achieve excellent performance on three few-shot learning benchmarks, showing competitive performance on Omniglot and surpassing all state-of-the-art alternatives on miniImageNet and CIFAR-100.
Luca Bertinetto (University of Oxford), João F. Henriques (University of Oxford), Philip H.S. Torr (University of Oxford), Andrea Vedaldi (University of Oxford)
Preprint. Work in progress.
While modern machine learning techniques thrive in big data, applying them to low-data regimes should certainly be possible. For example, humans can easily perform fast mapping carey1978less (); carey1978acquiring (), namely learning a new concept after only a single exposure. Instead, supervised learning algorithms — and neural networks in particular — typically need to be trained from a vast amount of data in order to generalize well. This is problematic, as the availability of large labelled datasets cannot always be taken for granted. Labels could be costly to acquire: in drug discovery, for instance, researchers are often limited to characterizing only a handful of compounds altae2017low (). In other circumstances, data itself could be scarce. This can happen for example with the task of classifying rare animal species, whose exemplars are not easy to observe. Such a scenario, in which just one or a handful of training examples is provided, is referred to as few-shot learning fei2006one (); lake2015human () and has recently seen a tremendous surge in interest within the Machine Learning community (e.g. vinyals2016matching (); bertinetto2016learning (); ravi2017optimization (); finn2017model (); snell2017prototypical (); mishra2018simple (); garcia2018few (); sung2018learning ()).
Currently, few-shot learning is dominated by methods that operate within the general paradigm of meta-learning schmidhuber1987evolutionary (); naik1992meta (); bengio1992optimization (); thrun1998lifelong (). There are two main components in a meta-learning algorithm: a base learner and a meta learner. The base learner works at the level of individual episodes, namely learning problems characterised by having only a small set of labelled training images available. The meta learner, on the other hand, learns from several such episodes together with the goal of improving the performance of the base learner across tasks. In this way, knowledge from a single episode is not extracted in a vacuum as the meta learner can learn to capture the distribution of tasks.
Clearly, in any meta-learning algorithm it is of paramount importance to choose the base learner carefully. On one side of the spectrum, methods related to nearest-neighbours, such as learning similarity functions koch2015siamese (); vinyals2016matching (); snell2017prototypical (); sung2018learning (); garcia2018few () and learning how to access a memory module santoro2016meta (); kaiser2017learning (); munkhdalai2017meta (); sprechmann2018memory (), are fast but solely rely on the quality of the similarity metric, with no additional data-dependent adaptation at test-time. On the other side of the spectrum, methods that optimize standard iterative learning algorithms, such as backpropagating through gradient descent finn2017model (); finn2018meta () or explicitly learning the learner’s update rule bengio1992optimization (); younger2001meta (); hochreiter2001learning (); andrychowicz2016learning (); ravi2017optimization (), are accurate but slow.
In this paper, we propose to adopt an alternative family of base learners, namely the ones, such as ridge regression, that can be formulated as the closed-form solution of a simple optimization problem. We show that these algorithms hit a particularly sweet spot between expressiveness, generalization capability in a low-data regime, efficiency, and meta-learnability. Efficiency arises from the closed-form solution of these learners, and meta-learnability from the fact that such closed-form solutions can be back-propagated through efficiently. Furthermore, for the special case of ridge regressor and similar base learners, we show that we can use Woodbury’s identity petersen2008matrix () to exploit the low-data regime in which the base learner operates, obtaining a very significant gain in terms of computation speed. We demonstrate the strength of our approach by performing extensive experiments on Omniglot lake2015human (), CIFAR-100 krizhevsky2009learning () (adapted to the few-shot task) and miniImageNet vinyals2016matching (). We demonstrate that our base learners can achieve performance equivalent or superior to the state-of-the-art in terms of accuracy, while being significantly faster than most gradient-based methods.
2 Related Work
The concept of meta-learning (i.e. learning to learn) schmidhuber1987evolutionary (); naik1992meta (); thrun1998learning () has been of great importance in the Machine Learning community for several decades. It generally refers to a scenario in which a base learner adapts to a single new task, while a meta-learner (consisting of an outer training loop) is trained with several tasks to improve the performance of the base learner. Over the years the goal of adapting learners to different tasks has also been studied under the umbrellas of multi-task learning thrun1996learning (); caruana1998multitask () and domain adaptation ben2010theory (). These works adapted linear or kernel models, typically by considering the transformation of a distribution of training data to a new space, spanned by the test samples. Recent years have seen a renewed interest in these topics, fueled by the inclusion of deep learning architectures, which enable more complex objective functions.
Perhaps the simplest approach to meta-learning is to train a similarity function bromley1993signature (); chopra2005learning () by exposing it to millions of “matching” tasks koch2015siamese (). Despite its simplicity, this general strategy is particularly effective and it is at the core of several state-of-the-art few-shot classification algorithms vinyals2016matching (); snell2017prototypical (); sung2018learning (). Interestingly, Garcia et al. garcia2018few () interpret learning as information propagation from support (training) to query (test) images and propose a graph neural network that can generalize matching-based approaches. Since this line of work relies on learning a similarity metric, one distinctive characteristic is that parameter updates only occur within the long time horizon of the meta-learning loop. While this can clearly spare costly computations, it also prevents these methods from performing adaptation at test time. A possible way to overcome the lack of adaptability is to train a neural network capable of predicting (some of) its own parameters. This technique was first introduced by Schmidhuber schmidhuber1992learning (); schmidhuber1993neural () and recently revamped by Bertinetto et al. bertinetto2016learning () and Munkhdalai et al. munkhdalai2017meta (), with application to object tracking and few-shot classification.
Another popular approach to meta-learning is to interpret the gradient update of SGD as a parametric and learnable function bengio1992optimization () rather than a fixed ad-hoc routine. Younger et al. younger2001meta () and Hochreiter et al. hochreiter2001learning () observe that, because of the sequential nature of a learning algorithm, a recurrent neural network can be considered as a meta-learning system. They identify LSTMs as particularly apt for the task because of their ability to span long-term dependencies, which are important in order to meta-learn. A modern take on this idea has been presented by Andrychowicz et al. andrychowicz2016learning () and Ravi & Larochelle ravi2017optimization (), showing benefits on classification, style transfer and few-shot learning.
A recent and promising research direction is the one set by MacLaurin et al. maclaurin2015gradient () and by the MAML algorithm of Finn et al. finn2017model (). Instead of explicitly designing a meta-learner module to learn the update rule, they backpropagate through the very operation of gradient descent to optimize for the hyperparameters maclaurin2015gradient () or the initial parameters finn2017model () of the learner. Follow-up work finn2018meta () shows that, in terms of representational power, this simpler strategy does not have drawbacks w.r.t. explicit meta-learners. However, back-propagation through gradient descent steps is costly in terms of memory, and thus the total number of steps must be kept small.
In order to alleviate the drawback of catastrophic forgetting typical of deep neural networks mccloskey1989catastrophic (), several recent methods santoro2016meta (); kaiser2017learning (); munkhdalai2017meta (); sprechmann2018memory () make use of memory-augmented models, which can first retain and then access important and previously unseen information associated with newly encountered tasks. While such memory modules store and retrieve information in the long time range, approaches based on attention like the one of Vinyals et al. vinyals2016matching () are useful to specify the most relevant pieces of knowledge within a task. Mishra et al. mishra2018simple () complement soft attention with temporal convolutions, thus allowing the attention mechanism to access information related to past episodes.
Despite significant diversity, a common trait of all the previously mentioned approaches is the adoption of SGD within both the meta- and base-learning scopes. At the single task level, rather than adapting SGD for faster convergence, we instead argue for differentiable base learners which have an inherently fast rate of convergence before any adaptation. In similar spirit, Valmadre et al. valmadre2017end () propose a method to backpropagate through the solution of a closed-form problem. However, they resort to the Correlation Filter algorithm kumar2005correlation (), whose application is limited to scenarios in which the data matrix is circulant, such as object detection and tracking.
The goal of meta-learning is to enable a base learning algorithm to adapt to new tasks efficiently, by generalizing from a set of training tasks $\mathcal{T} \in \mathbb{T}$. Each task generally consists of a probability distribution of example inputs $x$ and outputs $y$, $(x, y) \sim \mathcal{T}$. Consider a generic feature extractor $\phi(x)$, such as commonly used pre-trained networks (note that in practice we do not use pre-trained networks, but are able to train them from scratch). Then, a much simpler task-specific predictor $f(\phi(x); w_\mathcal{T})$ can be trained to map input embeddings to outputs. The predictor is parameterized by a set of parameters $w_\mathcal{T}$, which are specific to the task $\mathcal{T}$. For example, the predictor might be trained on the Omniglot lake2015human () task (section 4.1) of character recognition in the Roman alphabet, as opposed to the Greek alphabet (which would represent another task).
To train and assess the predictor on a given task, we are given access to training samples $Z_\mathcal{T} = \{(x_i, y_i)\} \sim \mathcal{T}$ and test samples $Z'_\mathcal{T} = \{(x'_i, y'_i)\} \sim \mathcal{T}$. We can then use a learning algorithm $\Lambda$ to obtain the parameters $w_\mathcal{T} = \Lambda(\phi(Z_\mathcal{T}))$. With slight abuse of notation, the learning algorithm thus applies the same feature extractor $\phi$ to all the sample inputs in $Z_\mathcal{T}$. The expected quality of the trained predictor is then computed by a standard loss or error function $L$, which is evaluated on the test samples $Z'_\mathcal{T}$:
$$q(\mathcal{T}) = \frac{1}{|Z'_\mathcal{T}|} \sum_{(x, y) \in Z'_\mathcal{T}} L\big(f(\phi(x); w_\mathcal{T}),\, y\big), \quad \text{with } w_\mathcal{T} = \Lambda(\phi(Z_\mathcal{T})). \qquad (1)$$
Other than abstracting away the complexities of the learning algorithm as $\Lambda$, eq. (1) is not much different from the train-test protocol commonly employed in machine learning, here applied to a single task $\mathcal{T}$. However, simply re-training a predictor for each task ignores potentially useful knowledge that can be transferred between them, typically encoded in $\phi$. For this reason, we now take the step of parameterizing $\phi$ with a set of meta-parameters $\omega$, which are free to encode prior knowledge to bootstrap the training procedure. For example, the meta-parameters may represent the weights of a common set of convolutional layers, shared by all tasks. Learning these meta-parameters is what is commonly referred to as meta-learning schmidhuber1987evolutionary (); naik1992meta (); bengio1992optimization (); thrun1998learning (), although in some works additional meta-parameters are integrated into the learning algorithm andrychowicz2016learning (); ravi2017optimization ().
The meta-parameters $\omega$ will affect the generalization properties of the learned predictors. This motivates evaluating the result of training on a held-out test set (eq. (1)). In order to learn the meta-parameters $\omega$, we want to minimize the expected loss on held-out test sets over all tasks $\mathcal{T} \in \mathbb{T}$:
$$\min_{\omega} \frac{1}{|\mathbb{T}|} \sum_{\mathcal{T} \in \mathbb{T}} \frac{1}{|Z'_\mathcal{T}|} \sum_{(x, y) \in Z'_\mathcal{T}} L\big(f(\phi(x; \omega); w_\mathcal{T}),\, y\big), \quad \text{with } w_\mathcal{T} = \Lambda(\phi(Z_\mathcal{T}; \omega)). \qquad (2)$$
Since eq. (2) consists of a composition of non-linear functions, we can leverage the same tools used successfully in deep learning, namely back-propagation and stochastic gradient descent (SGD), to optimize it. The main obstacle is to choose a learning algorithm $\Lambda$ that is amenable to optimization with such tools. This means that, in practice, $\Lambda$ must be quite simple.
Examples of meta-learning algorithms. Many of the meta-learning methods in the literature differ only in the choice of learning algorithm $\Lambda$. The feature extractor $\phi$ is typically a standard CNN, whose intermediate layers are trained jointly as $\omega$ (and thus are not task-specific). The last layer represents the linear predictor $f$, with task-specific parameters $w_\mathcal{T}$. In a Siamese network bromley1993signature (); chopra2005learning (); koch2015siamese (), $\Lambda$ is a nearest neighbor classifier, employing inner-products or various distance functions. Prototypical networks snell2017prototypical (); ren2018meta () generalize them to $k$-nearest neighbors. The Learnet bertinetto2016learning () uses a factorized CNN or MLP to implement $\Lambda$, while MAML finn2017model () implements it using SGD (and furthermore adapts all parameters of the CNN). We propose instead to use closed-form optimization methods as $\Lambda$, namely least-squares based solutions for ridge regression and logistic regression.
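The single-task evaluation of eq. (1) can be sketched as a function that takes the learning algorithm as an argument, which makes explicit that the methods listed above differ mainly in that one component. This is only an illustration under stated assumptions: the names `episode_loss`, `nearest_centroid_learner` and `zero_one_loss` are hypothetical, the embedding is the identity function rather than a CNN, and the nearest-centroid learner is a simple stand-in, not the base learner proposed in this paper.

```python
import numpy as np


def episode_loss(embed, learner, loss, support, query):
    """Train on the support set, then average the loss over the query set."""
    support_x, support_y = support
    query_x, query_y = query
    # The learning algorithm sees only embedded training samples.
    predictor = learner(embed(support_x), support_y)
    return float(np.mean(loss(predictor(embed(query_x)), query_y)))


def nearest_centroid_learner(features, labels):
    """Hypothetical stand-in learner: classify by the closest class mean."""
    classes = np.unique(labels)
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])

    def predict(query_features):
        # Squared Euclidean distance from every query to every centroid.
        diff = query_features[:, None, :] - centroids[None, :, :]
        return classes[(diff ** 2).sum(axis=-1).argmin(axis=1)]

    return predict


def zero_one_loss(predicted, target):
    return (predicted != target).astype(float)


# Toy episode with two well-separated classes and an identity "embedding".
support = (np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]),
           np.array([0, 0, 1, 1]))
query = (np.array([[1.0, 0.0], [9.0, 10.0]]), np.array([0, 1]))
toy_loss = episode_loss(lambda x: x, nearest_centroid_learner, zero_one_loss,
                        support, query)
```

Swapping `nearest_centroid_learner` for another function with the same interface changes the base learner without touching the evaluation code.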
3.2 Efficient ridge regression base learners
Similarly to the methods discussed in section 3.1, over the course of a single episode/task we adapt the final layer of a CNN, which is a linear predictor $f$. Note that the remaining layers of the CNN are trained from scratch to generalize between tasks by the outer loop of meta-learning (eq. 2), but for the purposes of one task they are considered fixed. In this section we assume that the inputs were pre-processed by the CNN $\phi$, and that we are dealing only with the final linear predictor $f(x) = xW$, where the parameters are reorganized into a matrix $W$.
The motivation for our work is that, while not quite as simple as nearest neighbors, least-squares regressors admit closed-form solutions. Although simple least-squares is prone to overfitting, it is easy to augment it with regularization (controlled by a positive factor $\lambda$), in what is known as ridge regression:
$$\Lambda(Z) = \arg\min_{W} \|XW - Y\|^2 + \lambda \|W\|^2 \qquad (3)$$
$$\phantom{\Lambda(Z)} = (X^\top X + \lambda I)^{-1} X^\top Y, \qquad (4)$$
where $\Lambda(Z)$ denotes the parameters learned from the training set $Z$, and $X$ and $Y$ contain the $n$ sample pairs of input embeddings and outputs from $Z$, respectively, stacked as rows.
Because ridge regression admits a closed-form solution (eq. (4)), it is relatively easy to integrate into meta-learning (eq. (2)) using standard automatic differentiation packages. The only element that may have to be treated more carefully is the matrix inversion. For computational efficiency and numerical stability, eq. (4) should be implemented as solving a linear system ($Ax = b$) rather than a multiplication by an explicit inverse matrix ($x = A^{-1}b$). This is commonly implemented as matrix-vector division (e.g. A\b).
Another concern about eq. (4) is that the intermediate matrix $X^\top X$ grows quadratically with the embedding size $e$. Given the large sizes of layers typically employed in deep networks, the inversion could come at a very expensive cost. To alleviate this, we rely on the Woodbury formula petersen2008matrix (), obtaining:
$$W = \Lambda(Z) = X^\top (X X^\top + \lambda I)^{-1} Y. \qquad (5)$$
The main difference between eq. (4) and eq. (5) is that the intermediate matrix $X X^\top$ only grows with the number of samples in the task, $n$. As we are interested in one or few-shot learning, this is typically very small. The overall cost of eq. (5) is only linear in the embedding size $e$. Although this method was originally designed for regression, we found that it also works well when the target outputs are one-hot vectors representing classes.
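As a rough numerical check of the two points above, the following NumPy sketch computes the ridge regression weights twice: once from the standard closed form $(X^\top X + \lambda I)^{-1} X^\top Y$, and once from the Woodbury form $X^\top (X X^\top + \lambda I)^{-1} Y$. Both use a linear solve instead of an explicit inverse. The function names, the random data and the sizes (25 samples, 400-dimensional embeddings, 5 classes) are assumptions chosen for illustration, and the sketch does not include the back-propagation step.

```python
import numpy as np


def ridge_primal(X, Y, lam):
    """Ridge weights from the e-by-e system (X^T X + lam I) W = X^T Y."""
    e = X.shape[1]
    # Solve the linear system instead of forming an explicit inverse.
    return np.linalg.solve(X.T @ X + lam * np.eye(e), X.T @ Y)


def ridge_woodbury(X, Y, lam):
    """Same weights from the n-by-n system: W = X^T (X X^T + lam I)^{-1} Y."""
    n = X.shape[0]
    return X.T @ np.linalg.solve(X @ X.T + lam * np.eye(n), Y)


rng = np.random.default_rng(0)
n, e, o = 25, 400, 5                       # hypothetical sizes for illustration
X = rng.standard_normal((n, e))            # stand-in for embeddings, one row per sample
Y = np.eye(o)[rng.integers(0, o, size=n)]  # one-hot targets, one row per sample

W_primal = ridge_primal(X, Y, lam=50.0)
W_woodbury = ridge_woodbury(X, Y, lam=50.0)
max_gap = float(np.max(np.abs(W_primal - W_woodbury)))
```

In this sketch the first form solves a 400 x 400 system while the second solves a 25 x 25 one, and the two results agree up to floating-point error.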
3.3 Iterative base learners and logistic regression
It is natural to ask whether other learning algorithms can be integrated as efficiently as ridge regression within our meta-learning framework. In general, a similar derivation is possible for iterative solvers, as long as the operations are differentiable. For linear models with convex loss functions, the optimal solver is generally Newton’s method murphy2012machine (), which uses curvature (second-order) information to reach the solution in very few steps.
One learning objective of particular interest is logistic regression, which unlike ridge regression directly produces classification labels. When one applies Newton’s method to logistic regression, the resulting algorithm takes a familiar form: it consists of a series of weighted least squares (or ridge regression) problems, giving it the name Iteratively Reweighted Least Squares (IRLS) murphy2012machine (). Given inputs $X$ and binary outputs $y$, the $i$-th iteration updates the parameters as:
$$w_i = \left(X^\top \operatorname{diag}(s_i)\, X + \lambda I\right)^{-1} X^\top \operatorname{diag}(s_i)\, z_i, \qquad (6)$$
where $I$ is an identity matrix, $s_i = \mu_i (1 - \mu_i)$, $z_i = X w_{i-1} + (y - \mu_i) / s_i$ (with element-wise products and divisions), and $\mu_i = \sigma(X w_{i-1})$ applies a sigmoid function to the predictions using the previous parameters $w_{i-1}$.
Since eq. (6) takes a similar form to ridge regression, we can use it for meta-learning in the same way as in section 3.2, with the difference that a small number of steps (eq. (6)) must be performed in order to obtain the final parameters $w_i$. Similarly, we obtain a solution with a cost which is linear rather than quadratic in the embedding size by employing the Woodbury formula:
$$w_i = X^\top \left(X X^\top + \lambda \operatorname{diag}(s_i)^{-1}\right)^{-1} z_i, \qquad (7)$$
where the inner inverse has negligible cost since it is a diagonal matrix.
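A minimal NumPy sketch of one way to write these IRLS updates follows, using the textbook Newton step for regularized logistic regression. It assumes labels in {0, 1}, weights initialised at zero, and the per-sample quantities $\mu = \sigma(Xw)$, $s = \mu(1 - \mu)$ and $z = Xw + (y - \mu)/s$. The function names and the toy data are hypothetical. The primal step solves a system whose size is the embedding dimension, and the Woodbury step solves one whose size is the number of samples.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def irls_quantities(X, y, w):
    mu = sigmoid(X @ w)            # predictions with the previous parameters
    s = mu * (1.0 - mu)            # per-sample weights
    z = X @ w + (y - mu) / s       # working targets
    return s, z


def irls_primal_step(X, y, w, lam):
    """One step solving an e-by-e system (e = embedding size)."""
    s, z = irls_quantities(X, y, w)
    A = X.T @ (s[:, None] * X) + lam * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ (s * z))


def irls_woodbury_step(X, y, w, lam):
    """Equivalent step solving an n-by-n system (n = number of samples)."""
    s, z = irls_quantities(X, y, w)
    A = X @ X.T + lam * np.diag(1.0 / s)   # inverting a diagonal is element-wise
    return X.T @ np.linalg.solve(A, z)


def fit(X, y, step, lam=1.0, steps=5):
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w = step(X, y, w, lam)
    return w


rng = np.random.default_rng(0)
X = rng.standard_normal((10, 30))                 # 10 samples, 30-d embeddings (toy sizes)
y = np.array([0, 1] * 5, dtype=float)             # binary labels in {0, 1}
w_primal = fit(X, y, irls_primal_step)
w_woodbury = fit(X, y, irls_woodbury_step)
```

The sketch applies no safeguard against the weights $s$ underflowing to zero; a practical implementation would likely need one when predictions saturate.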
Note that essentially the same strategy could be followed for other learning algorithms based on IRLS, such as $L_1$ minimization and LASSO. We take logistic regression to be a sufficiently illustrative example, of particular interest for binary classification in one/few-shot learning, leaving the exploration of other variants as future work.
3.4 Training policy
Figure 1 illustrates our overall framework. Like most meta-learning techniques, we organize our training procedure into episodes, each of which corresponds to a few-shot classification task. In standard classification, training requires sampling from a distribution of sample inputs and outputs (labels). Instead, in our case we sample from a distribution of tasks, each containing its own (small) training set and test set. Each episode also contains two sets of labels: one to train the base learner, and one to compute the error of the just-trained base learner, which enables backpropagation at the meta-level in order to learn the generic meta-parameters (section 3.1). It is important not to confuse the small training and test sets that are used in an episode/task with the larger training set of tasks that they are drawn from (section 4.1).
In our implementation, one episode corresponds to a mini-batch of SGD and is composed of sample images belonging to $N$ different classes (ways). In particular, the base training set is represented by $K$ samples (shots) per class, while the test set is represented by $Q$ query/test images per class (using nomenclature from vinyals2016matching (); fei2006one ()).
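The episode construction described above can be sketched in plain Python as follows. The function name `sample_episode`, its parameter names and the toy dataset of string identifiers are hypothetical. The sketch only shows how classes are drawn first and how each class's images are then split into a support (training) part and a query (test) part.

```python
import random


def sample_episode(images_by_class, n_way, k_shot, n_query, rng):
    """Draw one few-shot episode: a support set and a disjoint query set."""
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Sample without replacement so support and query never overlap.
        picks = rng.sample(images_by_class[cls], k_shot + n_query)
        support += [(image, label) for image in picks[:k_shot]]
        query += [(image, label) for image in picks[k_shot:]]
    return support, query


# Toy dataset: 10 classes with 20 image identifiers each.
toy_data = {f"class_{c}": [f"class_{c}/img_{i}" for i in range(20)]
            for c in range(10)}
support, query = sample_episode(toy_data, n_way=5, k_shot=1, n_query=3,
                                rng=random.Random(0))
```

Labels are re-indexed from 0 within each episode, since the classes of one episode have no fixed identity across episodes.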
In this section, we provide practical details for the two novel methods introduced in sections 3.2 and 3.3, which we dub R2-D2 (Ridge Regression Differentiable Discriminator) and LR-D2 (Logistic Regression Differentiable Discriminator). We analyze their performance against the recent literature on multinomial and binary classification problems using three few-shot learning benchmarks: Omniglot lake2015human (), miniImageNet vinyals2016matching () and cifar-fs, which we introduce in this paper.
The code for both our methods will be made available online (project page: http://www.robots.ox.ac.uk/~luca/r2d2.html).
4.1 Few-shot learning benchmarks
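Let $I^\star$ and $C^\star$ be respectively the set of images and the set of classes belonging to a certain data split $\star$. In standard classification datasets, $I^{\text{train}} \cap I^{\text{test}} = \emptyset$ and $C^{\text{train}} = C^{\text{test}}$. Instead, the few-shot setup requires both $I^{\text{meta-train}} \cap I^{\text{meta-test}} = \emptyset$ and $C^{\text{meta-train}} \cap C^{\text{meta-test}} = \emptyset$.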
Omniglot lake2015human () is a dataset of handwritten characters that has been referred to as the “MNIST transpose” for its high number of classes and small number of instances per class. It contains 20 examples of each of 1623 characters, grouped in 50 different alphabets. In order to be able to compare against the state of the art, we adopt the same setup first introduced in santoro2016meta () and vinyals2016matching (). Hence, we resize the images, we sample character classes independently from the alphabet and we augment the dataset using four rotated versions of each instance. Including rotations, we use 4800 classes for meta-training and meta-validation and 1692 for meta-testing.
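The class counts quoted for Omniglot can be cross-checked with a few lines of arithmetic. The only assumption is that each of the four rotated versions of a character is treated as its own class, which is what "including rotations" suggests.

```python
num_characters = 1623      # characters in Omniglot
num_rotations = 4          # rotated versions per character, each counted as a class
total_classes = num_characters * num_rotations

meta_train_val_classes = 4800
meta_test_classes = 1692

# The two splits together should account for every rotated class.
splits_cover_all = (meta_train_val_classes + meta_test_classes == total_classes)

# Equivalent counts in terms of original (unrotated) characters.
train_val_characters = meta_train_val_classes // num_rotations
test_characters = meta_test_classes // num_rotations
```

Under this assumption the splits correspond to 1200 and 423 original characters, which add up to the full 1623.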
miniImageNet vinyals2016matching () aims at representing a challenging dataset without demanding large computational resources. It is randomly sampled from ImageNet russakovsky2015imagenet () and consists of a total of 60,000 images from 100 different classes, each with 600 instances. All images are RGB and have been downsampled. Like all recent work, we adopt the splits of ravi2017optimization (), who employ 64 classes for meta-training, 16 for meta-validation and 20 for meta-testing.
cifar-fs. On the one hand, despite being lightweight, Omniglot is becoming too simple for modern few-shot learning methods, especially with the splits and augmentations of vinyals2016matching (). On the other, miniImageNet is more challenging, but it might still require a model to train for several hours before convergence. Thus, we propose cifar-fs, which is sampled from CIFAR-100 krizhevsky2009learning () and exhibits exactly the same settings as miniImageNet. We observed that the average inter-class similarity is sufficiently high to represent a challenge for the current state of the art. Moreover, the limited original resolution of $32 \times 32$ of CIFAR-100 makes the task harder and at the same time allows fast prototyping. To ensure reproducibility, the data splits are available on the project website.
4.2 Experimental results
In order to produce the features for the base learners (eq. 4 and 6), like many recent methods we use a shallow network of four convolutional “blocks”, each consisting of a convolution, batch normalization, max pooling, and a leaky ReLU with factor 0.1.
The four convolutional layers have filters. Dropout is applied to the last two blocks for the experiments on miniImageNet and cifar-fs, respectively with probabilities 0.1 and 0.3. We do not use any fully connected layer. Instead, we flatten and concatenate the output of the third and fourth convolutional blocks and feed it to the base learner. It is important to mention that the use of the Woodbury formula (section 3.2) allows us to make use of high-dimensional features without incurring burdensome computations. In fact, in few-shot problems the data matrix is particularly “fat and short”. As an example, with a 5-way/5-shot problem from miniImageNet the data matrix has only 25 rows (5 classes times 5 shots), one per training sample, and one column per feature dimension. Applying the Woodbury identity, we obtain significant gains in computation, as in eq. 5 we get to invert a matrix that is only $25 \times 25$ instead of one whose side equals the feature dimension.
Similarly to Snell et al. snell2017prototypical (), we observe that using a higher number of classes during training is important. Hence, despite the few-shot tasks at test time being 5 or 20-way (i.e. number of classes), in our multinomial classification experiments we train using 60 classes for Omniglot and between 15 and 25 classes for miniImageNet and cifar-fs. At the meta-learning level, we train our methods with Adam kingma2015adam () with an initial learning rate of 0.005, dampened by 0.5 every 2000 episodes.
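One reading of the learning-rate schedule above is a step decay, in which the rate is halved once every 2000 episodes. The sketch below encodes that reading; the function and parameter names are hypothetical, and a smooth decay with the same half-life would be another possible interpretation.

```python
def learning_rate(episode, base_lr=0.005, decay=0.5, every=2000):
    """Step decay: multiply the base rate by `decay` once per `every` episodes."""
    return base_lr * decay ** (episode // every)


# Learning rates at a few episode counts.
schedule = {episode: learning_rate(episode) for episode in (0, 1999, 2000, 4000)}
```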
As for the base learners, the only hyper-parameter is the regularization factor $\lambda$, which we set to 50 for R2-D2 and to 1 for LR-D2. Since the ridge regression base learner does not directly produce classification labels, but rather “regresses” to class scores, we feed its output to an adjustment layer, consisting of a single learnable scale and bias. In the binary problem, unless differently specified, the logistic regression base learner performs ten steps of IRLS both at training and test time.
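To make the role of the adjustment layer concrete, here is a small NumPy sketch of the forward pass of a ridge regression base learner on one episode. It assumes one-hot targets, the Woodbury form of the ridge solution with a regularization factor of 50, and a scalar scale and bias that are passed in as plain arguments (in the method described here they would be learned at the meta-level). The function name `r2d2_scores` and the random stand-in embeddings are hypothetical.

```python
import numpy as np


def r2d2_scores(support_x, support_labels, query_x, n_classes,
                lam=50.0, scale=1.0, bias=0.0):
    """Fit ridge regression on the support set and score the query set."""
    Y = np.eye(n_classes)[support_labels]          # one-hot regression targets
    n = support_x.shape[0]
    # Woodbury form: only an n-by-n system is solved.
    W = support_x.T @ np.linalg.solve(
        support_x @ support_x.T + lam * np.eye(n), Y)
    # Adjustment layer: a single scale and bias applied to the regressed scores.
    return scale * (query_x @ W) + bias


rng = np.random.default_rng(0)
n_classes, dim = 5, 64
support_labels = np.arange(n_classes)              # 5-way, 1-shot episode
support_x = rng.standard_normal((n_classes, dim))  # stand-in for CNN embeddings
query_x = rng.standard_normal((10, dim))

scores = r2d2_scores(support_x, support_labels, query_x, n_classes,
                     scale=10.0, bias=-1.0)
predictions = scores.argmax(axis=1)
```

The adjusted scores could then be passed to a softmax cross-entropy loss, which is a common choice but not something the text above specifies.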
Multinomial classification. Tables 1 and 2 illustrate the performance of our closed-form base-learner R2-D2 against the current state of the art for shallow architectures on miniImageNet, cifar-fs and Omniglot. Values represent average classification accuracies obtained by sampling 1000 episodes from the meta test-set and are presented with 95% confidence intervals. For each column, the best performance is outlined in bold. For snell2017prototypical (), we report the results reproduced by the code provided by the authors, which are slightly inferior to the ones of the paper.
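The text does not state how the 95% confidence intervals are computed. A common choice, assumed here, is the normal approximation over per-episode accuracies: the mean plus or minus 1.96 standard errors. The function name `mean_and_ci95` and the toy accuracies are hypothetical.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Mean accuracy and half-width of a normal-approximation 95% interval."""
    mean = statistics.fmean(accuracies)
    standard_error = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, 1.96 * standard_error


# Toy example: 1000 episodes whose accuracy alternates between 60% and 80%.
episode_accuracies = [0.6, 0.8] * 500
mean_accuracy, half_width = mean_and_ci95(episode_accuracies)
```

With 1000 episodes the interval shrinks with the square root of the episode count, which is why the reported intervals are typically below one percentage point.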
In terms of feature embeddings, vinyals2016matching (); finn2017model (); snell2017prototypical (); ravi2017optimization () use 64 filters per layer (fewer for miniImageNet in ravi2017optimization (); finn2017model (), to limit overfitting). On top of this, sung2018learning () also uses a relation module of two convolutional plus two fully connected layers. GNN garcia2018few () employs its own convolutional embedding, a fully connected layer and a graph neural network (with its own extra parameters). In order to ensure a fair comparison, we increased the capacity of the architectures of three representative methods (MAML finn2017model (), prototypical networks snell2017prototypical (), GNN garcia2018few ()) to match ours. The results of these experiments are marked separately in Table 1. We make use of dropout on the last two layers for all the experiments on these higher-capacity baselines, as we verified it is helpful to reduce overfitting. Moreover, we report results for experiments on our R2-D2 in which we use the 64-channel embedding of vinyals2016matching (); snell2017prototypical (); ravi2017optimization ().
Despite its simplicity, our proposed method achieves an average accuracy that, on miniImageNet and cifar-fs, is consistently superior to the state of the art. For example, on the four tasks of Table 1, R2-D2 improves on average relative to GNN (the second best method), despite being much faster. R2-D2 also shows competitive results on Omniglot (Table 2). It achieves the best performance for both 5-shot tasks and the second best for both 1-shot tasks. The higher capacity is beneficial only for GNN on miniImageNet and prototypical networks on cifar-fs, while being detrimental in all the other cases, in particular for MAML. Finally, when we use the “lighter” embedding on our proposed method, we still observe performance in line with the state of the art.
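Table 1: 5-way classification accuracy on miniImageNet and cifar-fs (mean and 95% confidence interval where available).
| Method | Venue | miniImageNet, 5-way, 1-shot | miniImageNet, 5-way, 5-shot | cifar-fs, 5-way, 1-shot | cifar-fs, 5-way, 5-shot |
| Matching net vinyals2016matching () | NIPS '16 | 41.2% | 56.2% | - | - |
| Matching net FCE vinyals2016matching () | NIPS '16 | 44.2% | 57% | - | - |
| MAML finn2017model () | ICML '17 | 48.7 ± 1.8% | 63.1 ± 0.9% | 58.9 ± 1.9% | 71.5 ± 1.0% |
| Meta-LSTM ravi2017optimization () | ICLR '17 | 43.4 ± 0.8% | 60.6 ± 0.7% | - | - |
| Proto net snell2017prototypical () | NIPS '17 | 47.4 ± 0.6% | 65.4 ± 0.5% | 55.5 ± 0.7% | 72.0 ± 0.6% |
| Relation net sung2018learning () | CVPR '18 | 50.4 ± 0.8% | 65.3 ± 0.7% | 55.0 ± 1.0% | 69.3 ± 0.8% |
| GNN garcia2018few () | ICLR '18 | 50.3% | 66.4% | 61.9% | 75.3% |
| Ours/R2-D2 (with 64C) | - | 48.7 ± 0.6% | 65.5 ± 0.6% | 60.0 ± 0.7% | 76.1 ± 0.6% |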
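Table 2: 5-way and 20-way classification accuracy on Omniglot (mean and 95% confidence interval where available).
| Method | Venue | Omniglot, 5-way, 1-shot | Omniglot, 5-way, 5-shot | Omniglot, 20-way, 1-shot | Omniglot, 20-way, 5-shot |
| Siamese net koch2015siamese () | ICML '15 | 96.7% | 98.4% | 88% | 96.5% |
| Matching net vinyals2016matching () | NIPS '16 | 98.1% | 98.9% | 93.8% | 98.5% |
| MAML finn2017model () | ICML '17 | 98.7% | 99.9 ± 0.1% | 95.8 ± 0.3% | 98.9 ± 0.2% |
| Proto net snell2017prototypical () | NIPS '17 | 98.5 ± 0.2% | 99.5 ± 0.1% | 95.3 ± 0.2% | 98.7 ± 0.1% |
| GNN garcia2018few () | ICLR '18 | 99.2% | 99.7% | 97.4% | 99.0% |
| Ours/R2-D2 (with 64C) | - | 98.4 ± 0.2% | 99.7 ± 0.1% | 94.6 ± 0.2% | 98.8 ± 0.1% |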
Binary classification. Finally, in Table 3 we report the performance of both our ridge regression and logistic regression base learners, together with four other methods representative of the state of the art. Since LR-D2 is limited to operate in a binary classification setup, we run our R2-D2 and prototypical network without oversampling the number of ways. For both methods and prototypical networks, we report the performance obtained annealing the learning rate by a factor of 0.99, which within this binary setup works significantly better than originally described in snell2017prototypical (). Moreover, motivated by the small size of the mini-batches, we replace Batch Normalization with Group Normalization wu2018group ().
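Table 3: 2-way (binary) classification accuracy on miniImageNet and cifar-fs (mean and 95% confidence interval where available).
| Method | Venue | miniImageNet, 2-way, 1-shot | miniImageNet, 2-way, 5-shot | cifar-fs, 2-way, 1-shot | cifar-fs, 2-way, 5-shot |
| MAML finn2017model () | ICML '17 | 74.9 ± 3.0% | 84.4 ± 1.2% | 82.8 ± 2.7% | 88.3 ± 1.1% |
| Proto nets snell2017prototypical () | NIPS '17 | 71.7 ± 1.0% | 84.8 ± 0.7% | 76.4 ± 0.9% | 88.5 ± 0.6% |
| Relation net sung2018learning () | CVPR '18 | 76.2 ± 1.2% | 86.8 ± 1.0% | 75.0 ± 1.5% | 86.7 ± 0.9% |
| GNN garcia2018few () | ICLR '18 | 78.4% | 87.1% | 79.3% | 89.1% |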
Table 3 confirms the validity of both our approaches on the binary classification problem. In general, except for the 5-shot problem on cifar-fs, R2-D2 performs slightly better than LR-D2.
Although different in nature, both MAML and our LR-D2 make use of iterative base learners: the former is based on SGD, while the latter is based on Newton’s method (in the form of Iteratively Reweighted Least Squares). The use of second-order optimization might suggest that LR-D2 is characterized by computationally demanding steps. However, we can apply the Woodbury identity at every iteration and obtain a very significant speedup. In Figure 2 we compare the performance of LR-D2 with that of MAML for different numbers of base-learner steps. On both datasets, the two methods are comparable for the 1-shot case, but with a higher number of shots our logistic regression approach outperforms MAML, especially for a higher number of steps.
In this paper, we explored the feasibility of incorporating fast solvers with closed-form solutions (such as those based on ridge regression) as the base learning component of a meta-learning system. Importantly, the use of the Woodbury identity allows significant computational gains in a scenario presenting few samples with a high dimensionality, like the one of few-shot learning. We showed that these differentiable learning blocks work remarkably well, with excellent results on few-shot learning benchmarks, generalizing to new tasks that were not seen during training. We believe that our findings point in an exciting direction of more sophisticated online adaptation methods able to leverage the potential of prior knowledge distilled in an offline training phase. In future work, we would like to explore Newton’s methods with more complicated second-order structure than ridge regression, and experiment with cross-modal task learning.
This work was supported by the EPSRC, ERC grant ERC-2012-AdG 321162-HELIOS, EPSRC grant Seebibyte EP/M013774/1, ERC/677195-IDIU and EPSRC/MURI grant EP/N019474/1.
- (1) S. Carey, “Less may never mean more,” Recent advances in the psychology of language, 1978.
- (2) S. Carey and E. Bartlett, “Acquiring a single new word,” 1978.
- (3) H. Altae-Tran, B. Ramsundar, A. S. Pappu, and V. Pande, “Low data drug discovery with one-shot learning,” ACS central science, 2017.
- (4) L. Fei-Fei, R. Fergus, and P. Perona, “One-shot learning of object categories,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2006.
- (5) B. M. Lake, R. Salakhutdinov, and J. B. Tenenbaum, “Human-level concept learning through probabilistic program induction,” Science, 2015.
- (6) O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al., “Matching networks for one shot learning,” in Advances in Neural Information Processing Systems, 2016.
- (7) L. Bertinetto, J. F. Henriques, J. Valmadre, P. Torr, and A. Vedaldi, “Learning feed-forward one-shot learners,” in Advances in Neural Information Processing Systems, 2016.
- (8) S. Ravi and H. Larochelle, “Optimization as a model for few-shot learning,” in International Conference on Learning Representations, 2017.
- (9) C. Finn, P. Abbeel, and S. Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” in International Conference on Machine Learning, 2017.
- (10) J. Snell, K. Swersky, and R. Zemel, “Prototypical networks for few-shot learning,” in Advances in Neural Information Processing Systems, 2017.
- (11) N. Mishra, M. Rohaninejad, X. Chen, and P. Abbeel, “A simple neural attentive meta-learner,” in International Conference on Learning Representations, 2018.
- (12) V. Garcia and J. Bruna, “Few-shot learning with graph neural networks,” in International Conference on Learning Representations, 2018.
- (13) F. Sung, Y. Yang, L. Zhang, T. Xiang, P. H. Torr, and T. M. Hospedales, “Learning to compare: Relation network for few-shot learning,” in IEEE Conference on Computer Vision and Pattern Recognition, 2018.
- (14) J. Schmidhuber, Evolutionary principles in self-referential learning, or on learning how to learn: the meta-meta-… hook. PhD thesis, Technische Universität München, 1987.
- (15) D. K. Naik and R. Mammone, “Meta-neural networks that learn by learning,” in International Joint Conference on Neural Networks (IJCNN), IEEE, 1992.
- (16) S. Bengio, Y. Bengio, J. Cloutier, and J. Gecsei, “On the optimization of a synaptic learning rule,” in Preprints Conf. Optimality in Artificial and Biological Neural Networks, pp. 6–8, Univ. of Texas, 1992.
- (17) S. Thrun, “Lifelong learning algorithms,” in Learning to learn, Springer, 1998.
- (18) G. Koch, R. Zemel, and R. Salakhutdinov, “Siamese neural networks for one-shot image recognition,” in International Conference on Machine Learning workshops, 2015.
- (19) A. Santoro, S. Bartunov, M. Botvinick, D. Wierstra, and T. Lillicrap, “Meta-learning with memory-augmented neural networks,” in International Conference on Machine Learning, 2016.
- (20) Ł. Kaiser, O. Nachum, A. Roy, and S. Bengio, “Learning to remember rare events,” in International Conference on Learning Representations, 2017.
- (21) T. Munkhdalai and H. Yu, “Meta networks,” in International Conference on Machine Learning, 2017.
- (22) P. Sprechmann, S. M. Jayakumar, J. W. Rae, A. Pritzel, A. P. Badia, B. Uria, O. Vinyals, D. Hassabis, R. Pascanu, and C. Blundell, “Memory-based parameter adaptation,” 2018.
- (23) C. Finn and S. Levine, “Meta-learning and universality: Deep representations and gradient descent can approximate any learning algorithm,” 2018.
- (24) A. S. Younger, S. Hochreiter, and P. R. Conwell, “Meta-learning with backpropagation,” in International Joint Conference on Neural Networks (IJCNN), IEEE, 2001.
- (25) S. Hochreiter, A. S. Younger, and P. R. Conwell, “Learning to learn using gradient descent,” in International Conference on Artificial Neural Networks, pp. 87–94, Springer, 2001.
- (26) M. Andrychowicz, M. Denil, S. Gomez, M. W. Hoffman, D. Pfau, T. Schaul, and N. de Freitas, “Learning to learn by gradient descent by gradient descent,” in Advances in Neural Information Processing Systems, 2016.
- (27) K. B. Petersen, M. S. Pedersen, et al., “The matrix cookbook,” Technical University of Denmark, 2008.
- (28) A. Krizhevsky and G. Hinton, “Learning multiple layers of features from tiny images,” 2009.
- (29) S. Thrun and L. Pratt, Learning to learn. Springer Science & Business Media, 1998.
- (30) S. Thrun, “Is learning the n-th thing any easier than learning the first?,” in Advances in Neural Information Processing Systems, 1996.
- (31) R. Caruana, “Multitask learning,” in Learning to learn, Springer, 1998.
- (32) S. Ben-David, J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. W. Vaughan, “A theory of learning from different domains,” Machine learning, 2010.
- (33) J. Bromley, J. W. Bentz, L. Bottou, I. Guyon, Y. LeCun, C. Moore, E. Säckinger, and R. Shah, “Signature verification using a “Siamese” time delay neural network,” International Journal of Pattern Recognition and Artificial Intelligence, 1993.
- (34) S. Chopra, R. Hadsell, and Y. LeCun, “Learning a similarity metric discriminatively, with application to face verification,” in IEEE Conference on Computer Vision and Pattern Recognition, 2005.
- (35) J. Schmidhuber, “Learning to control fast-weight memories: An alternative to dynamic recurrent networks,” Neural Computation, 1992.
- (36) J. Schmidhuber, “A neural network that embeds its own meta-levels,” in IEEE International Conference on Neural Networks, IEEE, 1993.
- (37) D. Maclaurin, D. Duvenaud, and R. Adams, “Gradient-based hyperparameter optimization through reversible learning,” in International Conference on Machine Learning, pp. 2113–2122, 2015.
- (38) M. McCloskey and N. J. Cohen, “Catastrophic interference in connectionist networks: The sequential learning problem,” in Psychology of learning and motivation, 1989.
- (39) J. Valmadre, L. Bertinetto, J. Henriques, A. Vedaldi, and P. H. Torr, “End-to-end representation learning for correlation filter based tracking,” in IEEE Conference on Computer Vision and Pattern Recognition, 2017.
- (40) B. V. Kumar, A. Mahalanobis, and R. D. Juday, Correlation pattern recognition. Cambridge University Press, 2005.
- (41) M. Ren, E. Triantafillou, S. Ravi, J. Snell, K. Swersky, J. B. Tenenbaum, H. Larochelle, and R. S. Zemel, “Meta-learning for semi-supervised few-shot classification,” in International Conference on Learning Representations, 2018.
- (42) K. P. Murphy, Machine Learning: A Probabilistic Perspective. The MIT Press, 2012.
- (43) O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, et al., “Imagenet large scale visual recognition challenge,” 2015.
- (44) D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” 2015.
- (45) Y. Wu and K. He, “Group normalization,” CoRR, 2018.