Meta-learning for mixed linear regression

In modern supervised learning, there are a large number of tasks, but many of them are associated with only a small amount of labelled data. These include data from medical image processing and robotic interaction. Even though each individual task cannot be meaningfully trained in isolation, one seeks to meta-learn across the tasks from past experiences by exploiting some similarities. We study a fundamental question of interest: When can abundant tasks with small data compensate for a lack of tasks with big data? We focus on a canonical scenario where each task is drawn from a mixture of $k$ linear regressions, and identify sufficient conditions for such a graceful exchange to hold; the total number of examples necessary with only small data tasks scales similarly to the case when big data tasks are available. To this end, we introduce a novel spectral approach and show that we can efficiently utilize small data tasks with the help of medium data tasks, each with $\tilde{\Omega}(\sqrt{k})$ examples.

1 Introduction

Recent advances in machine learning highlight successes on a small set of tasks where a large number of labeled examples have been collected and exploited. These include image classification with 1.2 million labeled examples Deng et al. (2009) and French-English machine translation with 40 million paired sentences Bojar et al. (2014). For common tasks, however, collecting clean labels is costly, as they require human expertise (as in medical imaging) or physical interactions (as in robotics), for example. Real-world datasets collected in this way therefore follow a long-tailed distribution, in which a dominant set of tasks has only a small number of training examples Wang et al. (2017).

Inspired by human ingenuity in quickly solving novel problems by leveraging prior experience, meta-learning approaches aim to jointly learn from past experience to quickly adapt to new tasks with little available data Schmidhuber (1987); Thrun and Pratt (2012). This has had a significant impact in few-shot supervised learning, where each task is associated with only a few training examples. By leveraging structural similarities among those tasks, one can achieve accuracy far greater than what can be achieved for each task in isolation Finn et al. (2017); Ravi and Larochelle (2016); Koch et al. (2015); Oreshkin et al. (2018); Triantafillou et al. (2019); Rusu et al. (2018). The success of such approaches hinges on the following fundamental question: When can we jointly train small data tasks to achieve the accuracy of large data tasks?

We investigate this trade-off under a canonical scenario where the tasks are linear regressions in $d$ dimensions and the regression parameters are drawn i.i.d. from a discrete set of support size $k$. Although this model is widely studied, the existing literature addresses only the scenario where all tasks have the same fixed number of examples. We defer formal comparisons to Section 6.

On one extreme, when large training data of sample size $t = \tilde{\Omega}(d)$ is available for each task, each task can easily be learned in isolation; in this case, a relatively small number of such tasks is sufficient to learn all $k$ regression parameters. This is illustrated by a solid circle in Figure 1. On the other extreme, when each task has only one example, existing approaches require exponentially many tasks (see Table 1). This is illustrated by a solid square.

Several aspects of few-shot supervised learning make training linear models challenging. The number of training examples varies significantly across tasks, and for all tasks it is significantly smaller than the dimension $d$ of the data. The number of tasks is also limited, which rules out algorithms with exponential sample complexity. An example distribution of such heterogeneous tasks is illustrated in Figure 1 with a bar graph in blue, where both the solid circle and the solid square are far outside the regime covered by the typical distribution of tasks.

Figure 1: A realistic pool of meta-learning tasks does not include large data tasks (circle) or an extremely large number of small data tasks (square), which is where existing approaches achieve high accuracy. The horizontal axis denotes the number of examples $t$ per task, and the vertical axis denotes the number of tasks in the pool that have at least $t$ examples. The proposed approach succeeds whenever some point in the light (green) region and some point in the heavy (yellow) region are both covered by the blue bar graph, as is the case in this example. The blue graph summarizes the pool of tasks at hand, illustrating the cumulative count of tasks with more than $t$ examples. We ignore constants and logarithmic factors.

In this data-scarce regime, we show that we can still efficiently achieve any desired accuracy in estimating the meta-parameters that define the meta-learning problem. This is shown in the informal version of our main result in Corollary 1.1. As long as we have a sufficient number of light tasks, each with a small number of examples, we can achieve any accuracy $\epsilon$ with the help of a small number of heavy tasks, each with $\tilde{\Omega}(\sqrt{k})$ examples. We only require the total number of examples jointly across all light tasks to be sufficiently large; the number of light tasks and the number of examples per task trade off gracefully. This is illustrated by the green region in Figure 1. Further, we only need a small number of heavy tasks, each with $\tilde{\Omega}(\sqrt{k})$ examples, shown in the yellow region. As long as the cumulative count of tasks in the blue graph intersects both the light (green) and the heavy (yellow) regions, we can recover the meta-parameters accurately.

Corollary 1.1 (Special case of Theorem 1, informal).

Given two batches of samples, the first batch with

and the second batch with

Algorithm 1 estimates the meta-parameters up to any desired accuracy $\epsilon$ with high probability, under certain assumptions on the meta-parameters.

We design a novel spectral approach inspired by Vempala and Wang (2004) that first learns a subspace using the light tasks, and then clusters the heavy tasks in the projected space. To get the desired tight bound on the sample complexity, we improve upon a perturbation bound from Li and Liang (2018), and borrow techniques from recent advances in property testing in Kong et al. (2019).

2 Problem formulation and notations

There are two broad perspectives on meta-learning: optimization-based Li et al. (2017); Bertinetto et al. (2019); Zhou et al. (2018); Zintgraf et al. (2019); Rajeswaran et al. (2019), and probabilistic Grant et al. (2018); Finn et al. (2018); Kim et al. (2018); Harrison et al. (2018). Our approach is motivated by the probabilistic view, and we present a brief preliminary in Section 2.1. In Section 2.2, we present a simple but canonical scenario where the tasks are linear regressions, which is the focus of this paper.

2.1 Review of probabilistic view on meta-learning

A standard meta-training setup for few-shot supervised learning assumes that we are given a collection of $n$ meta-training tasks $\{\mathcal{T}_i\}_{i \in [n]}$ drawn from some distribution $\mathbb{P}_{\mathcal{T}}$. Each task $\mathcal{T}_i$ is associated with a dataset $\mathcal{D}_i$ of size $t_i$, collectively denoted as the meta-training dataset $\mathcal{D}_{\text{meta-train}} = \{\mathcal{D}_i\}_{i \in [n]}$. Exploiting some structural similarities in $\mathcal{D}_{\text{meta-train}}$, the goal is to train a model for a new task $\mathcal{T}_{\rm new}$, coming from $\mathbb{P}_{\mathcal{T}}$, from a small training dataset $\mathcal{D}^{\rm train}_{\rm new}$.

Each task $\mathcal{T}_i$ is associated with a model parameter $\phi_i$, and the meta-training data $\mathcal{D}_i$ is drawn independently as $(x_{i,j}, y_{i,j}) \sim \mathbb{P}_{x,y}(\cdot \mid \phi_i)$ for all $j \in [t_i]$. The prior distribution of the tasks, and hence of the model parameters, is fully characterized by a meta-parameter $\theta$ such that $\phi_i \sim \mathbb{P}_{\phi}(\cdot \mid \theta)$ i.i.d.

Following the definition from Grant et al. (2018), the meta-learning problem is defined as estimating the most likely meta-parameter $\theta$ given the meta-training data by solving

$\hat{\theta} \;\in\; \arg\max_{\theta}\; \log \mathbb{P}\big(\mathcal{D}_{\text{meta-train}} \,\big|\, \theta\big)\,,$   (1)

which is a special case of empirical Bayes methods for learning the prior distribution from data Carlin and Louis (2010). Once meta-learning is done, the model parameter $\phi_{\rm new}$ of a newly arriving task $\mathcal{T}_{\rm new}$ can be estimated by a Maximum a Posteriori (MAP) estimator:

$\hat{\phi}_{\rm new} \;\in\; \arg\max_{\phi}\; \log \mathbb{P}\big(\phi \,\big|\, \mathcal{D}^{\rm train}_{\rm new}, \hat{\theta}\big)\,,$   (2)

or a Bayes optimal estimator:

$\hat{\phi}_{\rm new} \;\in\; \arg\min_{\phi'}\; \mathbb{E}_{\phi \sim \mathbb{P}(\cdot \,\mid\, \mathcal{D}^{\rm train}_{\rm new}, \hat{\theta})}\big[\, \ell(\phi, \phi') \,\big]\,,$   (3)

for a choice of loss function $\ell(\cdot, \cdot)$. The estimated parameter is then used to predict the label of a new data point $x_{\rm new}$ in task $\mathcal{T}_{\rm new}$ as

$\hat{y}_{\rm new} \;=\; \mathbb{E}\big[\, y \,\big|\, x_{\rm new}, \hat{\phi}_{\rm new} \,\big]\,.$   (4)

General notations.

We define $[n] := \{1, 2, \ldots, n\}$, and write $\|\cdot\|_2$ for the standard $\ell_2$-norm (and for the spectral norm when applied to a matrix). $\mathcal{N}(\mu, \Sigma)$ denotes the multivariate normal distribution with mean $\mu$ and covariance $\Sigma$, and $\mathbb{1}\{\mathcal{E}\}$ denotes the indicator of an event $\mathcal{E}$.

2.2 Linear regression with a discrete prior

In general, the meta-learning problem of (1) is computationally intractable and no statistical guarantees are known. To investigate the trade-offs involved, we assume a simple but canonical scenario where the tasks are linear regressions:

$y_{i,j} \;=\; \beta_i^\top x_{i,j} \;+\; \epsilon_{i,j}\,,$   (5)

for the $i$-th task and the $j$-th example. Each task is associated with a model parameter $\beta_i \in \mathbb{R}^d$. The noise $\epsilon_{i,j}$ is drawn i.i.d. from $\mathcal{N}(0, \sigma_i^2)$, and $x_{i,j}$ is drawn i.i.d. from a centered sub-Gaussian distribution $\mathbb{P}_x$. Without loss of generality, we assume that $\mathbb{P}_x$ is an isotropic (i.e., $\mathbb{E}[x x^\top] = I_d$) centered sub-Gaussian distribution. If $\mathbb{P}_x$ is not isotropic, we assume a large number of unlabelled $x$'s are available for whitening, such that the whitened distribution is sufficiently close to isotropic.

We do not make any assumption on the prior of the $\beta_i$'s other than that they come from a discrete distribution of support size $k$. Concretely, the meta-parameter $\theta = (W, s, p)$ defines a discrete prior (also known as a mixture of linear experts Chaganty and Liang (2013)) on the $(\beta_i, \sigma_i)$'s, where $W = [w_1, \ldots, w_k] \in \mathbb{R}^{d \times k}$ are the candidate model parameters, $s = [s_1, \ldots, s_k]$ are the candidate noise parameters, and $p = [p_1, \ldots, p_k]$ are the mixing probabilities. The $i$-th task is randomly chosen from one of the $k$ components according to $p$; its index is denoted by $z_i \in [k]$. The training data is then drawn independently from (5) with $\beta_i = w_{z_i}$ and $\sigma_i = s_{z_i}$.
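For concreteness, the following sketch samples a pool of tasks from such a discrete prior under the model (5); the variable names (`W`, `p`, `s`, `task_sizes`) are illustrative and not part of the formal notation.

```python
import numpy as np

def sample_tasks(W, p, s, task_sizes, rng):
    """Sample tasks from a mixture-of-linear-regressions prior.

    W : (d, k) candidate regression vectors w_1, ..., w_k
    p : (k,)   mixing probabilities
    s : (k,)   candidate noise standard deviations
    task_sizes : number of examples t_i for each task
    """
    d, k = W.shape
    tasks = []
    for t_i in task_sizes:
        z = rng.choice(k, p=p)                              # component index of this task
        X = rng.standard_normal((t_i, d))                   # isotropic (sub-)Gaussian covariates
        y = X @ W[:, z] + s[z] * rng.standard_normal(t_i)   # equation (5)
        tasks.append((X, y))
    return tasks

# Example: k = 3 components in d = 20 dimensions, 100 tasks with 5 examples each.
rng = np.random.default_rng(0)
d, k = 20, 3
W = rng.standard_normal((d, k))
tasks = sample_tasks(W, p=np.ones(k) / k, s=0.1 * np.ones(k), task_sizes=[5] * 100, rng=rng)
```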

We want to characterize the sample complexity of this meta-learning problem. This depends on how complex the ground-truth prior is, which can be measured by the number of components $k$, the separation between the parameters $\Delta := \min_{j \neq j'} \|w_j - w_{j'}\|_2$, the minimum mixing probability $p_{\min} := \min_{j \in [k]} p_j$, and the minimum positive eigenvalue $\lambda_k$ of the matrix $\sum_{j \in [k]} p_j w_j w_j^\top$.

Notations. We define $\rho_i$ as the sub-Gaussian norm of a label in the $i$-th task, and $\rho := \max_{i} \rho_i$. Without loss of generality, we assume $\rho = O(1)$, which can always be achieved by scaling the meta-parameters appropriately. We also use the quantities $\Delta$, $p_{\min}$, and $\lambda_k$ defined above. The exponent $\omega$ is such that two $n \times n$ matrices can be multiplied in time $O(n^{\omega})$.

3 Algorithm

We propose a novel spectral approach (Algorithm 1) to solve meta-learning for linear regression, consisting of three sub-algorithms: subspace estimation, clustering, and classification. These sub-algorithms require different types of tasks, depending on how many labelled examples are available.

Clustering requires heavy tasks, where each task is associated with many labelled examples, but only a relatively small number of such tasks is needed. On the other hand, for subspace estimation and classification, light tasks are sufficient, where each task is associated with only a few labelled examples; however, we need a large number of such tasks. In this section, we present the intuition behind our algorithm design and the types of tasks required. Precisely analyzing these requirements is the main contribution of this paper, to be presented in Section 4.

3.1 Intuitions behind the algorithm design

We give a sketch of the algorithm below. Each step of meta-learning is spelled out in full detail in Section 5. This provides an estimated meta-parameter $\hat{\theta} = (\hat{W}, \hat{s}, \hat{p})$. When a new task arrives, this estimate can be readily applied to prediction, as defined in Definition 4.5.

Meta-learning

  1. Subspace estimation. Compute a subspace $\hat{U}$ which approximates $\mathrm{span}(w_1, \ldots, w_k)$, using singular value decomposition.

  2. Clustering. Project the heavy tasks onto the subspace $\hat{U}$, perform distance-based clustering, and estimate a single regression parameter for each cluster.

  3. Classification. Perform likelihood-based classification of the light tasks using the parameters estimated in the clustering step, and compute more refined estimates of $(w_\ell, s_\ell, p_\ell)$ for $\ell \in [k]$.

Prediction

  1. Prediction. Perform MAP or Bayes optimal prediction using the estimated meta-parameter as a prior.

Algorithm 1
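To make the outline above concrete, here is a minimal end-to-end sketch of the meta-learning phase in Python. It assumes the helper functions sketched later in this paper's walkthrough (`estimate_subspace`, `cluster_heavy_tasks`, and `classify_and_refine`), and it illustrates the data flow only, not the exact procedure analyzed in Section 4.

```python
import numpy as np

def pooled_lstsq(task_group):
    """Least-squares fit on the pooled examples of a group of (X, y) tasks."""
    X = np.vstack([Xi for Xi, _ in task_group])
    y = np.concatenate([yi for _, yi in task_group])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return beta, float(np.std(y - X @ beta))

def meta_learn(light_tasks_1, heavy_tasks, light_tasks_2, k):
    """Sketch of the meta-learning phase: subspace estimation, clustering, classification."""
    # 1. Subspace estimation on the first light dataset.
    U_hat = estimate_subspace(light_tasks_1, k)                  # (d, k), orthonormal columns
    # 2. Cluster the heavy tasks in the projected k-dimensional space, then form
    #    a rough estimate per cluster and lift it back to R^d (assumes no empty cluster).
    proj = [(X @ U_hat, y) for X, y in heavy_tasks]
    labels = cluster_heavy_tasks(proj, k)                        # labels in 1..k
    W_rough = np.zeros((U_hat.shape[0], k))
    s_rough = np.zeros(k)
    for j in range(1, k + 1):
        beta_proj, s = pooled_lstsq([t for t, z in zip(proj, labels) if z == j])
        W_rough[:, j - 1], s_rough[j - 1] = U_hat @ beta_proj, s
    # 3. Classify the second light dataset and refine the estimates.
    return classify_and_refine(light_tasks_2, W_rough, s_rough)
```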

Subspace estimation. The subspace spanned by the regression vectors, $\mathrm{span}(w_1, \ldots, w_k)$, can be easily estimated using data from the (possibly) light tasks with only two examples each. Using any two independent examples from the same task $i$, it holds that $\mathbb{E}\big[ y_{i,1}\, y_{i,2}\, x_{i,1} x_{i,2}^\top \big] = \sum_{j \in [k]} p_j\, w_j w_j^\top$. With a sufficiently large total number of such examples, this matrix can be accurately estimated under the spectral norm, and so can its column space. We call this step subspace estimation.
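As a minimal illustration of this step (assuming isotropic covariates and at least two examples per task), the sketch below averages $y_{i,1} y_{i,2}\, x_{i,1} x_{i,2}^\top$ over tasks and takes the top $k$ singular vectors of the symmetrized result; it is a simplified stand-in for Algorithm 2, not its exact implementation.

```python
import numpy as np

def estimate_subspace(tasks, k):
    """Estimate span(w_1, ..., w_k) from tasks with at least two examples each.

    tasks : list of (X, y) with X of shape (t_i, d), t_i >= 2
    Returns a (d, k) matrix with orthonormal columns.
    """
    d = tasks[0][0].shape[1]
    M = np.zeros((d, d))
    for X, y in tasks:
        # For two independent examples of the same task (isotropic x):
        # E[y_1 y_2 x_1 x_2^T] = E[beta beta^T] = sum_j p_j w_j w_j^T
        M += y[0] * y[1] * np.outer(X[0], X[1])
    M = (M + M.T) / (2 * len(tasks))          # average and symmetrize
    U, _, _ = np.linalg.svd(M)
    return U[:, :k]                           # top-k singular vectors
```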

Clustering. Given an accurate estimate of the subspace, we can reduce the problem from a $d$-dimensional to a $k$-dimensional regression problem by projecting the data onto the estimated subspace $\hat{U}$. Tasks with $\Omega(k)$ examples can then be trained individually, as the unknown parameter now lies in a $k$-dimensional space. The fundamental question we address is: What can we do when the number of examples per task is much smaller than $k$? We propose clustering such tasks based on estimates of their regression vectors, and jointly solving a single regression problem for each cluster.

To this end, we borrow techniques from recent advances in property estimation for linear regression. Recently, in the contextual bandit setting, Kong et al. (2019) proposed an estimator for the correlation between the linear regressors of a pair of datasets. Concretely, given two datasets $\mathcal{D}_1$ and $\mathcal{D}_2$ whose true (unknown) regression vectors are $\beta_1$ and $\beta_2$, one can estimate $\langle \beta_1, \beta_2 \rangle$, $\|\beta_1\|_2^2$, and $\|\beta_2\|_2^2$ accurately with far fewer examples than are needed to estimate $\beta_1$ or $\beta_2$ themselves. We use this technique to estimate $\|\beta_1 - \beta_2\|_2^2$, whose value can be used to check whether the two tasks belong to the same cluster. We then group the tasks whose estimated pairwise distances are small into disjoint clusters. We call this step clustering.
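The sketch below implements this style of estimator under the isotropic-covariate assumption: within one task, $\mathbb{E}[y_i y_j \langle x_i, x_j \rangle] = \|\beta\|_2^2$ for $i \neq j$, and across two tasks, $\mathbb{E}[y^{(1)}_i y^{(2)}_j \langle x^{(1)}_i, x^{(2)}_j \rangle] = \langle \beta_1, \beta_2 \rangle$. It is a minimal illustration rather than the exact estimator of Kong et al. (2019).

```python
import numpy as np

def norm_sq_estimate(X, y):
    """Unbiased estimate of ||beta||^2 from one task (isotropic x, independent noise)."""
    t = len(y)
    G = (X @ X.T) * np.outer(y, y)            # G[i, j] = y_i y_j <x_i, x_j>
    return (G.sum() - np.trace(G)) / (t * (t - 1))

def inner_product_estimate(X1, y1, X2, y2):
    """Unbiased estimate of <beta_1, beta_2> from two independent tasks."""
    return float((y1 @ X1) @ (X2.T @ y2)) / (len(y1) * len(y2))

def dist_sq_estimate(X1, y1, X2, y2):
    """Estimate ||beta_1 - beta_2||^2 without estimating either regression vector."""
    return (norm_sq_estimate(X1, y1) + norm_sq_estimate(X2, y2)
            - 2.0 * inner_product_estimate(X1, y1, X2, y2))
```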

After clustering, the resulting parameter estimates have two sources of error: the error in the subspace estimation, and the error in the parameter estimation within each cluster. If we cluster more heavy tasks, we can reduce the second error but not the first. We could increase the number of samples used in subspace estimation, but there is a more sample-efficient way: classification.

Classification. We start the classification step once each cluster has accumulated enough data points to obtain a rough estimate of its corresponding regression vector. This rough accuracy is sufficient for us to add more data points and grow each of the clusters. Once enough data points are accumulated in each cluster, we can achieve any desired accuracy with this larger set of accurately classified tasks. This separation of the roles of the three sub-algorithms is critical in achieving the tightest sample complexity.

In contrast to the $\tilde{\Omega}(\sqrt{k})$ examples per task necessary for the clustering step, we show that one can accurately determine which cluster a new task belongs to with only a logarithmic number of examples, once we have a rough initial estimate of the parameters. We grow the clusters by adding tasks with a logarithmic number of examples until we have enough data points per cluster to achieve the desired accuracy. We call this step classification. This concludes our algorithm for the parameter estimation (i.e., meta-learning) phase.

4 Main results

Suppose we have $n_H$ heavy tasks, each with at least $t_H$ training examples, and $n_L$ light tasks, each with at least $t_L$ training examples. If the heavy tasks are data rich (e.g., $t_H = \tilde{\Omega}(d)$), we can learn the meta-parameters straightforwardly from a relatively small number of such tasks. If the light tasks are data rich (e.g., $t_L = \tilde{\Omega}(k)$), they can be straightforwardly clustered after projection onto the $k$-dimensional subspace. We therefore focus on the following challenging regime of data scarcity.

Assumption 1.

The heavy dataset $\mathcal{D}_H$ consists of $n_H$ heavy tasks, each with at least $t_H$ samples. The first light dataset $\mathcal{D}_{L1}$ consists of $n_{L1}$ light tasks, each with at least $t_{L1}$ samples. The second light dataset $\mathcal{D}_{L2}$ consists of $n_{L2}$ light tasks, each with at least $t_{L2}$ samples. We assume the data-scarce regime where $t_H \le d$ and $t_{L1}, t_{L2} \le k$.

To give a more fine-grained analysis of the sufficient conditions, we assume two types of light tasks are available, with potentially differing sizes (Remark 4.3). In the meta-learning step of Algorithm 1, subspace estimation uses $\mathcal{D}_{L1}$, clustering uses $\mathcal{D}_H$, and classification uses $\mathcal{D}_{L2}$. We provide proofs of the main results in Appendices A, B, and C.

4.1 Meta-learning

We characterize a sufficient condition for achieving a target accuracy $\epsilon$ in estimating the meta-parameters.

Theorem 1 (Meta-learning).

For any failure probability $\delta \in (0,1)$ and accuracy $\epsilon > 0$, given three batches of samples under Assumption 1, the meta-learning step of Algorithm 1 estimates the meta-parameters with accuracy

with probability at least $1 - \delta$, if the following holds. The numbers of tasks satisfy

and the numbers of samples per task, $t_{L1}$, $t_H$, and $t_{L2}$, satisfy corresponding lower bounds, where $\lambda_k$ is the smallest non-zero eigenvalue of $\sum_{j \in [k]} p_j w_j w_j^\top$.

In the following remarks, we explain each of the conditions.

Remark 4.1 (Dependency on $n_{L1}$ and $t_{L1}$).

The total number of samples used in subspace estimation is $n_{L1} t_{L1}$. The sufficient condition scales linearly in the dimension $d$, which matches the information-theoretically necessary condition up to logarithmic factors. If the matrix $\sum_{j \in [k]} p_j w_j w_j^\top$ is well conditioned, for example when the $w_j$'s are all orthogonal to each other, subspace estimation is easy and requires fewer samples; otherwise, the problem gets harder and more samples are needed. Note that in this regime, tensor decomposition approaches often fail to provide any meaningful guarantee (see Table 1). In proving this result, we improve upon a matrix perturbation bound in Li and Liang (2018) to shave off an extra factor (see Lemma A.11).

Remark 4.2 (Dependency on $n_H$ and $t_H$).

The clustering step requires $t_H = \tilde{\Omega}(\sqrt{k})$, which is necessary for distance-based clustering approaches such as single-linkage clustering. From Kong and Valiant (2018); Kong et al. (2019) we know that on the order of $\sqrt{k}$ samples are necessary (and sufficient), even for the simpler problem of testing whether two labelled datasets with linear models $\beta_1$ and $\beta_2$ satisfy $\beta_1 = \beta_2$ or are well separated.

Our clustering step is inspired by Vempala and Wang (2004) on clustering under Gaussian mixture models, where the algorithm succeeds under a comparable condition. Although a straightforward adaptation of their approach fails, we match their sufficient condition.

Up to logarithmic factors, the required number of heavy samples matches what is information-theoretically necessary.

Remark 4.3 (Gain of using two types of light tasks).

To get the tightest guarantee, it is necessary to use a different set of light tasks to perform the final estimation step. First, notice that the first light dataset cannot play the role of the second, since classification requires a lower bound on the number of examples per task that need not hold for the first dataset. On the other hand, the second light dataset cannot play the role of the first in regimes where certain problem parameters are very small.

Remark 4.4 (Dependency on $n_{L2}$ and $t_{L2}$).

Classification and prediction use the same routine to classify a given task. Hence, the requirement on $t_{L2}$ is tight, as it matches our lower bound in Proposition 4.6. The extra logarithmic terms come from the union bound over all tasks, which ensures that every task is correctly classified. It is possible to replace this union bound by a milder argument, showing that a small fraction of incorrectly classified tasks does not change the estimate by more than the target accuracy. Up to logarithmic factors, we only require the information-theoretically necessary number of samples.

4.2 Prediction

Given an estimated meta-parameter $\hat{\theta} = (\hat{W}, \hat{s}, \hat{p})$ and a new dataset $\mathcal{D}_{\rm new}$, we make predictions on the new task with unknown model parameter $\beta_{\rm new}$ using two estimators: the MAP estimator and the Bayes optimal estimator.

Definition 4.5.

Define the maximum a posteriori (MAP) estimator as
$\hat{\beta}_{\rm MAP} := \hat{w}_{\hat{z}}\,, \qquad \hat{z} \in \arg\max_{j \in [k]} \Big\{ \log \hat{p}_j \;+\; \sum_{(x,y) \in \mathcal{D}_{\rm new}} \log \mathcal{N}\big(y\,;\, \hat{w}_j^\top x,\, \hat{s}_j^{\,2}\big) \Big\}\,.$

Define the posterior mean estimator as
$\hat{\beta}_{\rm Bayes} := \sum_{j \in [k]} \hat{w}_j\; \hat{\mathbb{P}}\big(z_{\rm new} = j \,\big|\, \mathcal{D}_{\rm new}, \hat{\theta}\big)\,, \qquad \hat{\mathbb{P}}\big(z_{\rm new} = j \,\big|\, \mathcal{D}_{\rm new}, \hat{\theta}\big) \;\propto\; \hat{p}_j \prod_{(x,y) \in \mathcal{D}_{\rm new}} \mathcal{N}\big(y\,;\, \hat{w}_j^\top x,\, \hat{s}_j^{\,2}\big)\,.$

If the true prior is known, the posterior mean estimator achieves the smallest expected squared prediction error. Hence, we refer to it as the Bayes optimal estimator. The MAP estimator maximizes the probability of exactly recovering the model parameter.
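As a concrete illustration of Definition 4.5, the sketch below computes the posterior weights of the $k$ components from the new task's training data under a Gaussian noise model, and returns either the MAP or the posterior-mean (Bayes optimal) prediction. The names `W_hat`, `s_hat`, `p_hat` stand for the learned meta-parameters and are illustrative.

```python
import numpy as np

def posterior_weights(X, y, W_hat, s_hat, p_hat):
    """Posterior probability of each component given the new task's training data (X, y)."""
    resid = y[:, None] - X @ W_hat                      # (t, k) residuals per candidate
    loglik = -0.5 * np.sum((resid / s_hat) ** 2, axis=0) - len(y) * np.log(s_hat)
    logpost = np.log(p_hat) + loglik
    logpost -= logpost.max()                            # numerical stability
    post = np.exp(logpost)
    return post / post.sum()

def predict(X, y, x_new, W_hat, s_hat, p_hat, mode="bayes"):
    """Predict the label of x_new for a new task with training data (X, y)."""
    post = posterior_weights(X, y, W_hat, s_hat, p_hat)
    if mode == "map":                                   # MAP: most likely component
        beta_hat = W_hat[:, np.argmax(post)]
    else:                                               # Bayes optimal: posterior mean of beta
        beta_hat = W_hat @ post
    return float(x_new @ beta_hat)
```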

Theorem 2 (Prediction).

Under the hypotheses of Theorem 1, with the meta-learning accuracy set sufficiently small, the expected prediction errors of both the MAP and the Bayes optimal estimators are bounded as

(6)

if the new task has at least a logarithmic number of training examples, where the true meta-parameter is $\theta = (W, s, p)$, and the expectation is over the new task with model parameter $\beta_{\rm new}$, its training dataset $\mathcal{D}_{\rm new}$, and the test data $(x_{\rm new}, y_{\rm new})$.

Note that the additive noise term in (6) is due to the label noise and cannot be avoided by any estimator. With accurate meta-learning, we can achieve a prediction error arbitrarily close to this statistical limit. Although both predictors achieve the same guarantee, the Bayes optimal estimator achieves smaller training and test errors in Figure 2, especially in challenging regimes with little data.

(a) Training error. (b) Prediction error.
Figure 2: The Bayes optimal estimator achieves smaller errors on an example, where $\mathbb{P}_x$ and the noise distribution are standard Gaussian and the remaining problem parameters are set as in the simulations of Appendix E. The parameters were learnt using the meta-learning part of Algorithm 1, as a continuation of the simulations discussed in Appendix E, where we provide extensive experiments confirming our analyses.

We show that a certain number of training samples is necessary (even if the ground-truth meta-parameter is known) to achieve an error approaching this statistical limit. Let $\Theta_k$ denote the set of all meta-parameters with $k$ components satisfying appropriate separation and noise conditions. The following minimax lower bound shows that there exists a threshold on the number of training examples below which no algorithm can achieve the fundamental limit in this minimax setting.

Proposition 4.6 (Lower bound for prediction).

For any , if , then

(7)

where the minimization is over all measurable functions of the meta-parameter and the training data of size .

5 Details of the algorithm and the analyses

We explain and analyze each step in Algorithm 1. These analyses imply our main result in meta-learning, which is explicitly written in Appendix A.

5.1 Subspace estimation

In the following, we use a $k$-SVD routine that outputs the top $k$ singular vectors. As the number of samples grows, this outputs an estimate of the subspace spanned by the true parameters. We show that as long as each task has at least two examples, the accuracy depends only on the total number of examples used, so a sufficiently large total sample size suffices regardless of how it is distributed across tasks.

  Input: data $\{(x_{i,j}, y_{i,j})\}_{j \in [t_i],\, i \in [n_{L1}]}$, number of components $k$
  compute for all $i \in [n_{L1}]$ (using two independent examples of task $i$)
      $\hat{M}_i \leftarrow y_{i,1}\, y_{i,2}\; x_{i,1} x_{i,2}^\top$
  $\hat{M} \leftarrow \frac{1}{2 n_{L1}} \sum_{i \in [n_{L1}]} \big(\hat{M}_i + \hat{M}_i^\top\big)$
  $\hat{U} \leftarrow$ $k$-SVD$\big(\hat{M}\big)$
  Output: $\hat{U}$
Algorithm 2 Subspace estimation

The dependence on the accuracy changes with the ground-truth meta-parameters. In an ideal case, when $W$ is an orthonormal matrix (with condition number one), the sample complexity is smallest; for the worst-case $W$, it is larger.

Lemma 5.1 (Learning the subspace).

Suppose Assumption 1 holds, and let $\hat{U} \in \mathbb{R}^{d \times k}$ be the matrix whose columns are the top $k$ eigenvectors of the matrix $\hat{M}$ computed in Algorithm 2. For any failure probability $\delta \in (0,1)$ and accuracy $\epsilon > 0$, if the sample size is large enough such that

and , we have

$\big\| \big(I_d - \hat{U}\hat{U}^\top\big)\, w_j \big\|_2 \;\le\; \epsilon$   (8)

for all $j \in [k]$, with probability at least $1 - \delta$, where $\lambda_k$ is the smallest non-zero eigenvalue of $\sum_{j \in [k]} p_j w_j w_j^\top$.

Time complexity: the running time is dominated by the cost of computing $\hat{M}$ and the cost of the $k$-SVD Allen-Zhu and Li (2016).

5.2 Clustering

Once we have the subspace, we can efficiently cluster any task associated with $\tilde{\Omega}(\sqrt{k})$ samples. In the following, the matrix $\hat{D} \in \mathbb{R}^{n_H \times n_H}$ estimates the pairwise distances between the task parameters in the projected $k$-dimensional space. If there were no error in $\hat{D}$, then $\hat{D}_{i,i'} = \|\beta_i - \beta_{i'}\|_2^2 > 0$ whenever tasks $i$ and $i'$ are from different components, and $\hat{D}_{i,i'} = 0$ otherwise. Any clustering algorithm can be applied, treating $\hat{D}$ as a distance matrix.

  Input: heavy data $\{(x_{i,j}, y_{i,j})\}_{j \in [t_H],\, i \in [n_H]}$, subspace estimate $\hat{U}$, number of components $k$, number of sample splits, clustering threshold
  compute for all $i \in [n_H]$ and $j \in [t_H]$
      $\tilde{x}_{i,j} \leftarrow \hat{U}^\top x_{i,j}$
      (split the examples of task $i$ into disjoint subsets, one per sample split)
  compute for all pairs $(i, i')$ and sample splits $r$
      $\hat{d}^{(r)}_{i,i'} \leftarrow$ estimate of $\|\hat{U}^\top(\beta_i - \beta_{i'})\|_2^2$ from the $r$-th subsets of tasks $i$ and $i'$
  compute for all pairs $(i, i')$
      $\hat{D}_{i,i'} \leftarrow$ median$\big(\{\hat{d}^{(r)}_{i,i'}\}_{r}\big)$
  Cluster the heavy tasks using $\hat{D}$ as a distance matrix and return its partition $\{\mathcal{C}_\ell\}_{\ell \in [k]}$
  compute for all $\ell \in [k]$
      $\tilde{w}_\ell \leftarrow \hat{U} \cdot \big($least-squares estimate on the pooled projected data of $\mathcal{C}_\ell\big)$
      $\tilde{s}_\ell \leftarrow$ standard deviation of the residuals of $\mathcal{C}_\ell$
      $\tilde{p}_\ell \leftarrow |\mathcal{C}_\ell| / n_H$
  Output: $\{\tilde{w}_\ell, \tilde{s}_\ell, \tilde{p}_\ell\}_{\ell \in [k]}$
Algorithm 3 Clustering and estimation

This is inspired by Vempala and Wang (2004), where clustering of mixtures of Gaussians is studied. One might wonder whether it is possible to apply their clustering approach to the projected data directly. This approach fails, as it crucially relies on certain pairwise statistics concentrating with high probability; under our linear regression setting, a single estimate of the pairwise distance does not concentrate. We instead propose taking the median of several independent estimates, which yields the desired sufficient condition.
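A minimal sketch of this clustering step, reusing `dist_sq_estimate` from the earlier sketch: distances are estimated on the data projected onto the estimated subspace, the median over several random sample splits tames the heavy tails of a single estimate, and single-linkage clustering is run on the resulting distance matrix. The splitting scheme and the use of SciPy's single-linkage routine are illustrative choices, not the exact procedure analyzed in Lemma 5.2.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

def median_dist_sq(X1, y1, X2, y2, n_splits=10, rng=np.random.default_rng(0)):
    """Median over random sample splits of the squared-distance estimate
    (each split should contain at least two examples per task)."""
    ests = []
    for _ in range(n_splits):
        i1 = rng.permutation(len(y1))[: len(y1) // 2]
        i2 = rng.permutation(len(y2))[: len(y2) // 2]
        ests.append(dist_sq_estimate(X1[i1], y1[i1], X2[i2], y2[i2]))
    return float(np.median(ests))

def cluster_heavy_tasks(proj_tasks, k):
    """Single-linkage clustering of heavy tasks from pairwise distance estimates.

    proj_tasks : list of (X_proj, y), with covariates already projected onto the subspace
    Returns cluster labels in 1..k.
    """
    n = len(proj_tasks)
    D = np.zeros((n, n))
    for a in range(n):
        for b in range(a + 1, n):
            (Xa, ya), (Xb, yb) = proj_tasks[a], proj_tasks[b]
            D[a, b] = D[b, a] = max(median_dist_sq(Xa, ya, Xb, yb), 0.0)
    Z = linkage(D[np.triu_indices(n, 1)], method="single")   # condensed distance matrix
    return fcluster(Z, t=k, criterion="maxclust")
```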

Lemma 5.2 (Clustering and initial parameter estimation).

Under Assumption 1, and given an orthonormal matrix $\hat{U}$ satisfying (8) with a sufficiently small accuracy, Algorithm 3 correctly clusters all heavy tasks with probability at least $1 - \delta$. Further, if

(9)

for any $\epsilon > 0$, with probability at least $1 - \delta$,

(10a)
(10b)

where for all .

Time complexity: It takes time linear in the total number of heavy samples to compute the projected data. Then, by using matrix multiplication, the distance matrix $\hat{D}$ can be computed, and the single-linkage clustering algorithm takes $O(n_H^2)$ time Sibson (1973).

5.3 Classification

Once we have the rough estimates from the clustering step, we can efficiently classify any task with a logarithmic number of samples; the extra logarithmic factor is needed to apply the union bound over all tasks. This allows us to use the light samples to refine the clusters estimated with the heavy samples. This separation allows us to achieve the desired sample complexity on both the light tasks and the heavy tasks.
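A minimal sketch of this step, mirroring Algorithm 4 below: each light task is assigned to the candidate with the highest Gaussian log-likelihood, and each component is then re-estimated by a pooled least-squares fit (playing the role of the Least_Squares routine). The function and variable names are illustrative.

```python
import numpy as np

def classify_and_refine(light_tasks, W_rough, s_rough):
    """Assign each light task to its most likely component, then refine each component."""
    d, k = W_rough.shape
    assignments = []
    for X, y in light_tasks:
        resid = y[:, None] - X @ W_rough                  # residual under each candidate
        loglik = -0.5 * np.sum((resid / s_rough) ** 2, axis=0) - len(y) * np.log(s_rough)
        assignments.append(int(np.argmax(loglik)))

    W_hat, s_hat, p_hat = np.zeros((d, k)), np.zeros(k), np.zeros(k)
    for j in range(k):
        idx = [i for i, z in enumerate(assignments) if z == j]
        if not idx:                                       # empty cluster: keep the rough estimate
            W_hat[:, j], s_hat[j] = W_rough[:, j], s_rough[j]
            continue
        X_j = np.vstack([light_tasks[i][0] for i in idx])
        y_j = np.concatenate([light_tasks[i][1] for i in idx])
        W_hat[:, j], *_ = np.linalg.lstsq(X_j, y_j, rcond=None)   # pooled least squares
        s_hat[j] = float(np.std(y_j - X_j @ W_hat[:, j]))
        p_hat[j] = len(idx) / len(light_tasks)
    return W_hat, s_hat, p_hat, assignments
```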

In the following, we use a Least_Squares routine that outputs the least-squares estimate computed from all the examples in each cluster. Once each cluster has accumulated enough samples, we can accurately estimate the meta-parameters.

  Input: light data $\{(x_{i,j}, y_{i,j})\}_{j \in [t_{L2}],\, i \in [n_{L2}]}$, rough estimates $\{\tilde{w}_\ell, \tilde{s}_\ell\}_{\ell \in [k]}$
  compute for all $i \in [n_{L2}]$
      $\hat{z}_i \leftarrow \arg\max_{\ell \in [k]} \sum_{j \in [t_{L2}]} \log \mathcal{N}\big(y_{i,j}\,;\, \tilde{w}_\ell^\top x_{i,j},\, \tilde{s}_\ell^{\,2}\big)$
  compute for all $\ell \in [k]$,
      $\hat{w}_\ell \leftarrow$ Least_Squares$\big(\{(x_{i,j}, y_{i,j}) : \hat{z}_i = \ell\}\big)$
      $\hat{s}_\ell \leftarrow$ standard deviation of the residuals of the tasks with $\hat{z}_i = \ell$
      $\hat{p}_\ell \leftarrow |\{ i : \hat{z}_i = \ell \}| \,/\, n_{L2}$
  Output: $\{\hat{w}_\ell, \hat{s}_\ell, \hat{p}_\ell\}_{\ell \in [k]}$
Algorithm 4 Classification and estimation
Lemma 5.3 (Refined parameter estimation via classification).

Under Assumption 1, given estimated parameters $\{\tilde{w}_\ell, \tilde{s}_\ell\}_{\ell \in [k]}$ from the clustering step satisfying a rough accuracy guarantee for all $\ell \in [k]$, and given tasks with $t_{L2}$ examples each, Algorithm 4 correctly classifies all the tasks with probability at least $1 - \delta$. Further, for any $\epsilon > 0$, if

(11)

the following holds for all $\ell \in [k]$:

(12a)
(12b)
(12c)

Time complexity: the running time is dominated by computing the classification of all tasks and by the least-squares estimation for each cluster.

6 Related Work


References Noise # Samples
Chaganty and Liang (2013) Yes
Yi et al. (2016) No
Zhong et al. (2016) No
Sedghi et al. (2016) Yes
Li and Liang (2018) No
Chen et al. (2020) No
Table 1: Sample complexity of previous work on MLR to achieve a small constant error in parameter recovery for the mixed linear regression problem. We ignore constants and logarithmic factors. Let $n$, $d$, and $k$ denote the number of samples, the dimension of the data points, and the number of clusters, respectively. The sample complexities of Yi et al. (2016) and Chaganty and Liang (2013) depend on the $k$-th singular value of some moment matrix, and that of Sedghi et al. (2016) depends on the $k$-th singular value of the matrix of the regression vectors. Note that the resulting sample complexities can be infinite even when the regression vectors are well separated. The algorithm of Zhong et al. (2016) additionally requires some spectral properties.

Meta-learning of linear models has been studied in two contexts: mixed linear regression and multi-task learning.

Mixed Linear Regression (MLR). When each task has only one sample (i.e., $t_i = 1$), the problem has been widely studied. Prior work on MLR is summarized in Table 1. We emphasize that the sample and time complexity of all previous work either has a super-polynomial dependency on $k$, as in Zhong et al. (2016); Li and Liang (2018); Chen et al. (2020), or depends on the inverse of the $k$-th singular value of some moment matrix, as in Chaganty and Liang (2013); Yi et al. (2016); Sedghi et al. (2016), which can be infinite. Chen et al. (2020) cannot achieve a vanishing error when there is noise.

Multi-task learning. Baxter (2000); Ando and Zhang (2005); Rish et al. (2008); Orlitsky (2005) address a similar problem of finding an unknown low-dimensional subspace in which all tasks can be accurately solved. The main difference is that all tasks have the same number of examples, and the performance is evaluated on the observed tasks used in training. Typical approaches use the trace norm to encourage low-rank solutions for the matrix of stacked task parameters, posed as a convex program Argyriou et al. (2008); Harchaoui et al. (2012); Amit et al. (2007); Pontil and Maurer (2013).

Closer to our work is the streaming setting, where tasks arrive in an online fashion and one can choose how many examples to collect for each. Balcan et al. (2015) provide an online algorithm using only a small amount of memory, but require some tasks to have a large number of examples. In comparison, we only need tasks with far fewer examples, at the cost of more memory. Bullins et al. (2019) also use only a small amount of memory, but require a larger total number of samples to perform the subspace estimation under the setting studied in this paper.

Empirical Bayes/Population of parameters. A simple canonical setting of probabilistic meta-learning is when