Multi-Step Model-Agnostic Meta-Learning: Convergence and Improved Algorithms
Abstract
As a popular meta-learning approach, the model-agnostic meta-learning (MAML) algorithm has been widely used due to its simplicity and effectiveness. However, the convergence of the general multi-step MAML still remains unexplored. In this paper, we develop a new theoretical framework, under which we characterize the convergence rate and the computational complexity of multi-step MAML. Our results indicate that although the estimation bias and variance of the stochastic meta gradient involve exponential factors of N (the number of inner-stage gradient updates), MAML still attains convergence with complexity increasing only linearly with N under a properly chosen inner stepsize. We then take a further step to develop a more efficient Hessian-free MAML. We first show that the existing zeroth-order Hessian estimator contains a constant-level estimation error, so that the MAML algorithm can perform unstably. To address this issue, we propose a novel Hessian estimator via a gradient-based Gaussian smoothing method, and show that it achieves a much smaller estimation bias and variance, and that the resulting algorithm achieves the same performance guarantee as the original MAML under mild conditions. Our experiments validate our theory and demonstrate the effectiveness of the proposed Hessian estimator.
1 Introduction
Meta-learning, or learning to learn (Thrun and Pratt, 2012; Naik and Mammone, 1992; Bengio et al., 1991), is a powerful tool for quickly learning new tasks by using the prior experience from related tasks. Recent works have empowered this idea with neural networks, and their proposed meta-learning algorithms have been shown to enable fast learning over unseen tasks using only a few samples by efficiently extracting the knowledge from a range of observed tasks (Santoro et al., 2016; Vinyals et al., 2016; Finn et al., 2017a). Current meta-learning algorithms can be generally categorized into metric-learning based (Koch et al., 2015; Snell et al., 2017), model-based (Vinyals et al., 2016; Munkhdalai and Yu, 2017), and optimization-based (Finn et al., 2017a; Nichol and Schulman, 2018; Rajeswaran et al., 2019) approaches. Among them, optimization-based meta-learning is a simple and effective approach used in a wide range of domains including classification/regression (Rajeswaran et al., 2019), reinforcement learning (Finn et al., 2017a), robotics (Al-Shedivat et al., 2017), federated learning (Chen et al., 2018), and imitation learning (Finn et al., 2017b).
Model-agnostic meta-learning (MAML) (Finn et al., 2017a) is a popular optimization-based method, which is simple and generally compatible with models trained by gradient descent. MAML consists of two nested stages, where the inner stage runs a few steps of (stochastic) gradient descent for each individual task, and the outer stage updates the meta parameter over all the sampled tasks. The goal of MAML is to find a good meta initialization based on the observed tasks such that for a new task, starting from this initialization, a few (stochastic) gradient steps suffice to find a good model parameter. Although MAML has been widely used (Antoniou et al., 2018; Grant et al., 2018; Zintgraf et al., 2018; Nichol et al., 2018), its theoretical convergence is not well explored except for a few recent attempts. In particular, Finn et al. (2019) extended MAML to the online setting, and analyzed the regret for strongly convex objective functions. More recently, Fallah et al. (2019) provided an analysis of one-step MAML for general nonconvex functions, where each inner stage takes only a single stochastic gradient descent (SGD) step. However, the convergence of the more general and frequently used multi-step MAML has not been explored yet. Several new challenges arise in the theoretical analysis due to the multi-step inner-stage updates in MAML. First, the meta gradient of multi-step MAML has a nested and recursive structure, which requires analyzing the performance of an optimization path over a nested structure. In addition, the multi-step update also yields a complicated bias error in the Hessian estimation as well as a statistical correlation between the Hessian and gradient estimators, both of which cause further difficulty in the analysis of the meta gradient.

The first contribution of this paper lies in the development of a new theoretical framework for analyzing multi-step MAML, with techniques for handling the above challenges. For the resampling case (presented in the main body of the paper), where each iteration requires sampling fresh data (e.g., in reinforcement learning), our analysis enables us to decouple the Hessian approximation error from the gradient approximation error based on a novel bound on the distance between two different inner optimization paths, which facilitates the analysis of the overall convergence of MAML. For the finite-sum case (presented in the appendix), where the objective function is based on pre-assigned samples (e.g., supervised learning), we develop novel techniques to handle the difference between the two losses over the training and test sets in the analysis.
Our analysis provides a guideline for choosing the inner-stage stepsize at the order of 1/N, and shows that N-step MAML is guaranteed to converge with the gradient and Hessian computation complexities growing only linearly with N. In addition, for problems where Hessians are small, e.g., most classification/regression meta-learning problems (Finn et al., 2017a), we show that the inner stepsize can be set larger while still maintaining the convergence, which explains the empirical findings for MAML training in Finn et al. (2017a); Antoniou et al. (2018); Rajeswaran et al. (2019).
Although achieving promising performance in many areas, MAML has high computation and memory costs due to the computation of Hessians, and such costs increase dramatically if the inner stage takes multiple gradient steps (Rajeswaran et al., 2019). To address this issue, Hessian-free MAML algorithms have gained a lot of attention recently (Nichol and Schulman, 2018; Finn et al., 2017a; Song et al., 2020). Among them, Song et al. (2020) recently proposed an evolution-strategies-based MAML (ES-MAML) approach that uses zeroth-order (i.e., function value) information for Hessian approximation. Though ES-MAML exhibits promising empirical performance, Song et al. (2020) did not provide a convergence guarantee for this algorithm, and their experiments demonstrated that the zeroth-order Hessian approximation sometimes leads to inferior and unstable training in some RL problems.

The second contribution of this paper includes two parts. We first provide a theoretical analysis of the zeroth-order Hessian approximation (Song et al., 2020) and show that such an estimate contains a constant-level error, so that ES-MAML cannot converge exactly to a stationary point but only to a possibly large neighborhood of one. This explains the inferior and unstable training of ES-MAML observed in Song et al. (2020).
We then propose a first-order Hessian estimator based on Gaussian smoothing, which we show to achieve a much smaller bias and variance. More importantly, we show that the resulting algorithm (which we call GGS-MAML) achieves the same performance guarantee as the original MAML under mild conditions. Our analysis develops novel techniques for characterizing the variance of Gaussian-smoothing-based Hessian estimators via random matrix theory, which can be of independent interest. Our experiments validate the effectiveness of GGS-MAML.
1.1 Related Work
Optimization-based meta-learning. Optimization-based meta-learning approaches have been widely used due to their simplicity and efficiency (Li et al., 2017; Ravi and Larochelle, 2016; Finn et al., 2017a). As a pioneer along this line, MAML (Finn et al., 2017a) aims to find an initialization such that gradient descent from it achieves fast adaptation. Many follow-up studies (Grant et al., 2018; Finn et al., 2019; Jerfel et al., 2018; Finn and Levine, 2017; Finn et al., 2018; Mi et al., 2019; Liu et al., 2019; Rothfuss et al., 2018; Foerster et al., 2018; Fallah et al., 2019; Collins et al., 2020; Fallah et al., 2020) have extended MAML from different perspectives. For example, Finn et al. (2019) provided a follow-the-meta-leader extension of MAML for online learning. As an alternative to meta-initialization algorithms such as MAML, meta-regularization approaches aim to learn a good bias for a regularized empirical risk minimization problem for intra-task learning (Alquier et al., 2016; Denevi et al., 2018b, a, 2019; Rajeswaran et al., 2019; Balcan et al., 2019; Zhou et al., 2019). For example, Rajeswaran et al. (2019) proposed the efficient iMAML algorithm using a conjugate-gradient (CG) based solver. Balcan et al. (2019) formalized a connection between meta-initialization and meta-regularization from an online learning perspective. Zhou et al. (2019) proposed an efficient meta-learning approach based on a minibatch proximal update.
Theory for MAML-type algorithms. There have been only a few studies on the statistical and convergence performance of MAML-type algorithms. Finn and Levine (2017) showed that MAML is a universal learning algorithm approximator under certain conditions. Finn et al. (2019) analyzed online MAML for a strongly convex objective function under a bounded-gradient assumption. Rajeswaran et al. (2019) proposed a meta-regularization variant of MAML named iMAML, and analyzed its convergence by assuming that the regularized empirical risk minimization problem in the inner optimization stage is strongly convex. Fallah et al. (2020) proposed a variant of MAML for RL named stochastic gradient meta-reinforcement learning (SG-MRL), and analyzed the convergence and complexity performance of one-step SG-MRL. Fallah et al. (2019) developed a convergence analysis of one-step MAML for a general nonconvex objective in the resampling case. Our study here provides a new convergence analysis for multi-step MAML in the nonconvex setting for both the resampling and finite-sum cases.
Hessian-free MAML. Various Hessian-free MAML algorithms have been proposed, including but not limited to FO-MAML (Finn et al., 2017a), Reptile (Nichol and Schulman, 2018), ES-MAML (Song et al., 2020), and HF-MAML (Fallah et al., 2019). In particular, FO-MAML (Finn et al., 2017a) omits all second-order derivatives in its meta-gradient computation, ES-MAML (Song et al., 2020) approximates Hessian matrices in RL using a zeroth-order smoothing method, and HF-MAML (Fallah et al., 2019) estimates the meta gradient in one-step MAML using a Hessian-vector product approximation. In this paper, we propose a new Hessian-free MAML algorithm based on a first-order Hessian estimator and study its analytical and empirical performance.
2 Problem Setup
In this paper, we study the convergence of the multi-step MAML algorithm. Two types of objective functions are commonly used in practice: (a) the resampling case (Finn et al., 2017a; Fallah et al., 2019), where loss functions take an expectation form and new data are sampled as the algorithm runs; and (b) the finite-sum case (Antoniou et al., 2018), where loss functions take a finite-sum form over given samples. The resampling case often occurs in reinforcement learning, where data are continuously sampled as the algorithm iterates, whereas the finite-sum case typically occurs in classification problems, where the datasets are sampled in advance. The main body of this paper focuses on the resampling case, and we relegate all materials for the finite-sum case to Appendix D.
2.1 Multi-Step MAML Based on Full Gradient
Suppose a set {T_i, i = 1, 2, …} of tasks is available for learning, and tasks are sampled based on a probability distribution p(T) over the task set. Assume that each task T_i is associated with a loss l_i(w) parameterized by w.
The goal of multi-step MAML is to find a good initial parameter w such that, after observing a new task, a few gradient descent steps starting from this point can efficiently approach a minimizer (or a stationary point) of the corresponding loss function. Toward this end, multi-step MAML consists of two nested stages, where the inner stage consists of multiple steps of (stochastic) gradient descent for each individual task, and the outer stage updates the meta parameter over all the sampled tasks. More specifically, at each inner stage, each task T_i initializes at the meta parameter, i.e., w_0^i = w, and runs N gradient descent steps as
$w_{k+1}^i = w_k^i - \alpha \nabla l_i(w_k^i), \quad k = 0, 1, \ldots, N-1,$ (1)
where α > 0 denotes the inner stepsize.
Thus, the loss of task T_i after the N-step inner-stage iteration is given by l_i(w_N^i), where w_N^i depends on the meta parameter w through the iteration updates in (1), and can hence be written as w_N^i(w). We further define L_i(w) := l_i(w_N^i(w)), and hence the overall meta objective is given by
$L(w) := \mathbb{E}_{T_i \sim p(T)} [L_i(w)] = \mathbb{E}_{T_i \sim p(T)} [\, l_i(w_N^i(w))\,].$ (2)
Then the outer stage of the meta update is a gradient descent step to optimize the above objective function. Using the chain rule, we obtain a simplified form of the per-task gradient ∇L_i(w) as
$\nabla L_i(w) = \Big[\prod_{k=0}^{N-1} \big(I - \alpha \nabla^2 l_i(w_k^i)\big)\Big] \nabla l_i(w_N^i),$ (3)
where $\partial w_{k+1}^i / \partial w_k^i = I - \alpha \nabla^2 l_i(w_k^i)$ for each task T_i (see Appendix F). The outer stage then updates the meta parameter by a gradient descent step
$w_{m+1} = w_m - \beta \nabla L(w_m) = w_m - \beta\, \mathbb{E}_{T_i \sim p(T)} [\nabla L_i(w_m)],$ (4)
where β > 0 denotes the outer (meta) stepsize.
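To make the nested structure concrete, the following sketch (a toy illustration; the quadratic task and all variable names are ours, not from the paper) runs the inner update (1) on a quadratic loss, forms the per-task meta-gradient via the product structure of (3), and verifies it against finite differences of the meta objective.

```python
import numpy as np

rng = np.random.default_rng(0)
d, N, alpha = 4, 5, 0.05

# Toy task: quadratic loss l(w) = 0.5 w^T A w + b^T w,
# so grad l(w) = A w + b and the Hessian is the constant matrix A.
A = rng.standard_normal((d, d))
A = A @ A.T / d
b = rng.standard_normal(d)
grad = lambda w: A @ w + b

def inner_path(w):
    """N steps of inner gradient descent as in Eq. (1), starting from the meta parameter."""
    path = [w]
    for _ in range(N):
        w = w - alpha * grad(w)
        path.append(w)
    return path

def meta_gradient(w):
    """Per-task meta-gradient as in Eq. (3): a product of (I - alpha * Hessian) factors."""
    g = grad(inner_path(w)[-1])
    for _ in range(N):                       # here the Hessian equals A at every iterate
        g = (np.eye(d) - alpha * A) @ g
    return g

def meta_loss(w):
    """Meta objective L(w) = l(w_N(w))."""
    wN = inner_path(w)[-1]
    return 0.5 * wN @ A @ wN + b @ wN

# Check the product-form meta-gradient against central finite differences.
w0 = rng.standard_normal(d)
eps = 1e-6
g_fd = np.array([(meta_loss(w0 + eps * e) - meta_loss(w0 - eps * e)) / (2 * eps)
                 for e in np.eye(d)])
assert np.allclose(meta_gradient(w0), g_fd, atol=1e-5)
```

Because the loss here is quadratic, all N Hessian factors coincide; for a general loss, each factor I − α∇²l_i(w_k^i) must be evaluated at its own inner iterate.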
2.2 Multi-Step MAML Based on Stochastic Gradient
The inner- and outer-stage updates of MAML given in (1) and (4) involve the gradient and the Hessian of the loss function l_i, which takes the form of an expectation over the distribution of data samples as given by
$l_i(w) := \mathbb{E}_{\tau} [\, l_i(w; \tau)\,],$ (5)
where τ represents a data sample. In practice, these two quantities based on the population loss function are estimated from samples. Specifically, each task T_i samples a batch of data under the current parameter, and uses the corresponding sample averages as unbiased estimates of the gradient ∇l_i(w) and the Hessian ∇²l_i(w), respectively. See Section B.1 for an RL example.
For practical multi-step MAML, as shown in Algorithm 1, at the outer stage we sample a set B of tasks. Then, at the inner stage, each task T_i ∈ B samples a training set S_k^i for each iteration k of the inner stage, uses the sample average ∇l_i(w_k^i; S_k^i) as an estimate of ∇l_i(w_k^i) in (1), and runs an SGD update as
$w_{k+1}^i = w_k^i - \alpha \nabla l_i(w_k^i; S_k^i), \quad k = 0, 1, \ldots, N-1,$ (6)
where the initialization parameter w_0^i = w_m for all tasks T_i ∈ B, with w_m denoting the meta parameter at the m-th outer iteration.
At the outer stage, we draw batches D_i and D'_i of data samples that are independent of each other and both independent of the sets S_k^i (for k = 0, …, N − 1), and use ∇l_i(w_N^i; D_i) and ∇²l_i(w_k^i; D'_i) to estimate the corresponding gradient and Hessian terms in (4), respectively. Then, the meta parameter at the outer stage is updated by the following SGD step
$w_{m+1} = w_m - \beta_m \frac{1}{|B|} \sum_{T_i \in B} \widehat{G}_i(w_m),$ (7)
where the meta-gradient estimate $\widehat{G}_i(w_m)$ of task T_i is given by
$\widehat{G}_i(w_m) = \Big[\prod_{k=0}^{N-1} \big(I - \alpha \nabla^2 l_i(w_k^i; D'_i)\big)\Big] \nabla l_i(w_N^i; D_i).$ (8)
For simplicity, we assume in this paper that the sample sets S_k^i, D_i, and D'_i have sizes S, D, and D', respectively.
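A minimal end-to-end sketch of the stochastic updates (6)–(8) is given below on a synthetic task family of our own choosing (purely illustrative): each task has loss l_i(w) = 0.5‖w − c_i‖² with the optimum c_i drawn around a common center, so the stochastic Hessian estimate is exactly the identity and the product in (8) collapses to a scalar.

```python
import numpy as np

rng = np.random.default_rng(1)
d, N, alpha, beta = 2, 3, 0.1, 0.5
sigma = 0.1                                   # inner-gradient noise level

# Hypothetical task family: task i has loss l_i(w) = 0.5 ||w - c_i||^2
# with optimum c_i ~ N(mu, I); stochastic gradients carry Gaussian noise.
mu = np.array([3.0, -2.0])
sample_task = lambda: mu + rng.standard_normal(d)

def meta_grad_estimate(w, c):
    """One task's meta-gradient estimate, Eq. (8), along a noisy inner path, Eq. (6)."""
    wk = w
    for _ in range(N):                        # inner SGD, Eq. (6)
        wk = wk - alpha * (wk - c + sigma * rng.standard_normal(d))
    g = wk - c + sigma * rng.standard_normal(d)   # outer-stage gradient estimate
    return (1 - alpha) ** N * g               # product of (I - alpha * I) factors

def meta_loss(w):
    """Population meta objective L(w), analytic in the noiseless limit."""
    r = (1 - alpha) ** N                      # w_N - c = r (w - c) after N inner steps
    return 0.5 * r**2 * (((w - mu) ** 2).sum() + d)

w = np.zeros(d)
losses = [meta_loss(w)]
for _ in range(200):                          # outer SGD over task batches, Eq. (7)
    batch = [sample_task() for _ in range(8)]
    w = w - beta * np.mean([meta_grad_estimate(w, c) for c in batch], axis=0)
    losses.append(meta_loss(w))
assert losses[-1] < losses[0]                 # the meta objective decreases
```

In this toy setting the meta parameter drifts toward the task-center mu, which is the behavior the meta objective (2) encodes: a good initialization for the whole task family.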
3 Convergence of Multi-Step MAML
In this section, we provide a convergence analysis for the multi-step MAML algorithm in the resampling case.
3.1 Basic Assumptions
Assumption 1.
The loss l_i(w) of each task T_i, given by (5), satisfies:

The loss l_i(w) is bounded below, i.e., inf_w l_i(w) > −∞.

The gradient ∇l_i(·; τ) is Lipschitz for any data sample τ, i.e., there exists a constant L > 0 such that, for any w, u, ‖∇l_i(w; τ) − ∇l_i(u; τ)‖ ≤ L‖w − u‖.

The Hessian ∇²l_i(·; τ) is Lipschitz for any data sample τ, i.e., there exists a constant ρ > 0 such that, for any w, u, ‖∇²l_i(w; τ) − ∇²l_i(u; τ)‖ ≤ ρ‖w − u‖.
By the definition of the objective function in (2), item 1 of Assumption 1 implies that L(w) is bounded below. In addition, item 2 implies that ‖∇²l_i(w; τ)‖ ≤ L for any w and τ. For notational convenience, we use the same constants L and ρ across all tasks.
The following assumptions impose bounded-variance conditions on the gradient and Hessian estimators.
Assumption 2.
The stochastic gradient (with uniformly randomly chosen from set ) has bounded variance, i.e., there exists a constant such that, for any ,
where the expected loss function .
Assumption 3.
For any and , there exist constants such that
Note that the above assumptions are made only on the individual loss functions l_i rather than on the total loss L(w), because some of the conditions do not hold for L(w), as shown later.
3.2 Challenges of Analyzing Multi-Step MAML
Several new challenges arise when we analyze the convergence of multi-step MAML (with N > 1) compared to the one-step case (with N = 1).
First, each iteration of the meta parameter affects the overall objective function via a nested structure of N-step SGD optimization paths over all tasks. Hence, our analysis of the convergence of the meta parameter needs to characterize this nested structure and the recursive updates.
Second, the meta-gradient estimator given in (8) involves stochastic Hessian estimates at the inner iterates w_k^i for k = 0, …, N − 1, which are all biased estimators of the corresponding true Hessians in terms of the randomness over the inner-stage training sets. This is because each w_k^i is a stochastic estimate of its deterministic counterpart, obtained via random training sets along an N-step SGD optimization path in the inner stage. In fact, such a bias error occurs only for multi-step MAML with N > 1 (it equals zero for N = 1), and requires additional effort to handle.
Third, both the Hessian terms and the final gradient term in the meta-gradient estimator given in (8) depend on the sample sets used in the inner-stage iterations to obtain the iterates w_k^i, and hence they are statistically correlated even conditioned on the meta parameter. Such a complication also occurs only for multi-step MAML with N > 1 and requires new treatment (the two terms are independent for N = 1).
In Section 3.3, we develop a theoretical framework to handle the above challenges and establish the convergence of N-step MAML.
3.3 Properties of Meta Gradient
Differently from the conventional gradient, whose corresponding loss is evaluated directly at the current parameter w, the meta gradient has a more complicated nested structure with respect to w, because its loss is evaluated at the final output of the inner optimization stage, i.e., after N steps of SGD updates. As a result, analyzing the meta gradient is very different from, and more challenging than, analyzing the conventional gradient. In this subsection, we establish some important properties of the meta gradient, which are useful for characterizing the convergence of multi-step MAML.
Recall that ∇L(w) = E_{T_i∼p(T)}[∇L_i(w)] with ∇L_i(w) given by (3). The following proposition characterizes the Lipschitz property of the meta gradient ∇L(w).
Proposition 1.
The proof of Proposition 1 handles the first challenge described in Section 3.2. More specifically, we bound the differences between the iterates and the Hessians along two separate inner optimization paths starting from two distinct meta parameters, and then connect these differences to the distance between the two meta parameters. Proposition 1 shows that the objective L(w) has a gradient-Lipschitz parameter that depends on the gradient norm and can hence be unbounded, because the gradient norm may be unbounded. Similarly to Fallah et al. (2019), we use
(10) 
to estimate ∇L(w_m) at the meta parameter w_m, where we independently sample the data sets. As will be shown in Theorem 1, we set the meta stepsize β_m to be inversely proportional to the norm of this estimate to handle the possible unboundedness. In the experiments, we find that the gradients are well bounded during the optimization process, and hence a constant outer-stage stepsize is sufficient in practice.
We next characterize several estimation properties of the meta-gradient estimator in (8). Here, we address the second and third challenges described in Section 3.2. We first quantify how far the stochastic inner-stage iterates deviate from their deterministic (full-gradient) counterparts, and then provide upper bounds on the first- and second-moment distances between the meta-gradient estimate and the true meta gradient, as below.
Proposition 2 (Pointwise distance between two paths).
Proposition 2 shows that we can effectively upper-bound the pointwise distance between the two paths by choosing the inner stepsize α and the inner batch size S properly. Based on Proposition 2, we provide an upper bound on the first-moment estimation error of the meta-gradient estimator.
Proposition 3 (Firstmoment error).
In contrast to the one-step case, i.e., Lemma 4.12 in Fallah et al. (2019), the estimation error for the multi-step case shown in Proposition 3 involves an additional term, which cannot be avoided due to the Hessian approximation error caused by the randomness over the inner-stage sample sets. Somewhat interestingly, our later analysis shows that this term does not affect the final convergence rate if we choose the batch sizes properly. The following proposition provides an upper bound on the second moment of the meta-gradient estimator.
Proposition 4 (Secondmoment error).
By choosing the batch sizes and the inner stepsize α properly, the exponential factor in the second-moment error bound in (14) can be kept at a constant level, and the first two error terms can be made sufficiently small, so that the variance of the meta-gradient estimator is well controlled in the convergence analysis, as shown later.
3.4 Main Convergence Result
Using the properties of the meta gradient established in Section 3.3, we provide the convergence rate of the multi-step MAML algorithm (Algorithm 1) in the following theorem.
Theorem 1.
The proof of Theorem 1 (see Appendix G for details) consists of four main steps: step 1 bounds one iteration of the meta update via the meta-gradient smoothness established in Proposition 1; step 2 characterizes the first-moment estimation error of the meta-gradient estimator via Proposition 3; step 3 characterizes the second-moment estimation error of the meta-gradient estimator via Proposition 4; and step 4 combines steps 1–3 and telescopes to yield the convergence.
In Theorem 1, the convergence rate given by (16) mainly contains three parts: the first term indicates that the meta parameter converges sublinearly with the number of meta iterations; the second term captures the estimation error of the gradient estimate in (10) for approximating the full gradient, which can be made sufficiently small by choosing a large sample size; and the third term captures the estimation error and variance of the stochastic meta gradient, which can be made small by choosing large batch sizes.
Our analysis reveals several insights for the convergence of multi-step MAML as follows. (a) To guarantee convergence, we require the inner stepsize α to scale inversely with the number N of inner steps (e.g., α = O(1/N)). Hence, if the number of inner gradient steps is large and the smoothness parameter L is not small (e.g., for some RL problems), we need to choose a small inner stepsize α so that the last output of the inner stage retains a strong dependence on the initialization (i.e., the meta parameter), as also shown and explained in Rajeswaran et al. (2019). (b) For problems with small Hessians, such as many classification/regression problems (Finn et al., 2017a), L (which is an upper bound on the spectral norm of the Hessian matrices) is small, and hence we can choose a larger α. This explains the empirical findings in Finn et al. (2017a); Antoniou et al. (2018).
We next specify the selection of parameters to simplify the convergence result in Theorem 1 and derive the complexity of Algorithm 1 for finding an ε-accurate stationary point.
Corollary 1 (Stepsize choice).
In contrast to conventional SGD, MAML requires a higher gradient complexity by an additional multiplicative factor, which is unavoidable because MAML runs over multiple tasks to achieve an accurate meta point, whereas SGD runs over only one task.
Corollary 1 shows that, given a properly chosen inner stepsize, e.g., at the order of 1/N, MAML is guaranteed to converge with both the gradient and the Hessian computation complexities growing only linearly with N. These results explain some empirical findings for MAML training in Rajeswaran et al. (2019). The above results can also be obtained by using a larger stepsize such as α = c/N for a certain constant c.
4 Efficient Hessian-Free MAML Algorithms
The presence of Hessians in MAML causes significant computational complexity and storage cost, especially in the multi-step case. More recently, a Hessian-free MAML algorithm based on a zeroth-order Hessian estimator was proposed in Song et al. (2020). In this section, we first theoretically show the performance limitation of the zeroth-order Hessian estimator, and then propose a new Hessian-free MAML algorithm with a provable performance guarantee.
4.1 Limitations of Zeroth-Order Hessian Estimation
The zeroth-order Hessian estimator takes the form
(17) 
where the function values are evaluated at Gaussian-perturbed parameters, the perturbation directions are i.i.d. standard Gaussian vectors, M is the batch size of Gaussian vectors, and δ is a positive smoothing parameter.
We next provide an analysis of the estimation error of the zeroth-order Hessian estimator, which characterizes its performance limitation.
Proposition 5.
Suppose Assumption 1 holds. Let l_δ be the Gaussian-smoothed approximation of the loss function l. Then, conditioning on the inner-stage iterates, we have
where is bounded by .
Though the first term in the above equality can be sufficiently close to the true Hessian for a small δ, and the error term can also be made small for small enough δ, there remains a constant-level error term. Due to such a constant bias, the corresponding algorithm can converge only to a neighborhood around a stationary point, and the associated error can be substantial. Furthermore, we show in Proposition 8 in Appendix C that the inner stepsize must satisfy a more restrictive requirement than in the original MAML. This issue persists in the finite-sum case.
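For intuition, the sketch below implements a generic zeroth-order Gaussian-smoothing Hessian estimator of the same flavor as (17) (the exact form used in Song et al. (2020) may differ; the quadratic test function is ours). It also illustrates how easily a zeroth-order construction picks up a constant-level, trace-proportional bias: the raw second-difference estimator is off by (tr(A)/2)·I even on a pure quadratic, while a (u uᵀ − I) correction removes that bias.

```python
import numpy as np

rng = np.random.default_rng(2)
d, delta, M = 3, 0.1, 400_000

# Quadratic test function f(w) = 0.5 w^T A w, whose true Hessian is A.
A = np.array([[2.0, 0.5, 0.0],
              [0.5, 1.0, 0.3],
              [0.0, 0.3, 1.5]])
w = np.array([1.0, -1.0, 0.5])
quad = lambda X: 0.5 * np.einsum('...i,ij,...j->...', X, A, X)

# Zeroth-order estimation: second-order function-value differences along
# i.i.d. standard Gaussian directions u_j, paired with outer products.
U = rng.standard_normal((M, d))
coef = (quad(w + delta * U) + quad(w - delta * U) - 2 * quad(w)) / (2 * delta**2)
H_raw = np.einsum('m,mi,mj->ij', coef, U, U) / M      # expectation = A + 0.5 tr(A) I
H_corr = H_raw - np.mean(coef) * np.eye(d)            # (u u^T - I) correction

assert np.max(np.abs(H_corr - A)) < 0.1               # corrected: close to A
assert np.min(np.abs(np.diag(H_raw - A))) > 1.0       # raw: constant-level bias on the diagonal
```

The raw estimator's constant diagonal offset is the same kind of non-vanishing error that Proposition 5 identifies for (17): shrinking δ does not remove it, only a structurally different estimator does.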
4.2 A Simple MAML Variant via Gradient-Based Gaussian Smoothing
Our analysis in Section 4.1 shows that zeroth-order MAML does not provide the desired performance due to a possibly large estimation error of the Hessian. Here, we develop a more accurate Hessian estimator via a gradient-based Gaussian smoothing (GGS) method, and propose an easy-to-implement MAML variant (GGS-MAML). We show in the next two subsections that GGS-MAML admits a much larger inner stepsize and achieves a better convergence guarantee than zeroth-order MAML.
As shown in Algorithm 2, GGS-MAML has a structure similar to that of the original MAML, but constructs its Hessian approximation using Gaussian random vectors by
(18) 
where M is the batch size of the Gaussian vectors. We then update by
(19) 
where .
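The idea behind a gradient-based Gaussian-smoothing Hessian estimate can be sketched as follows (a hedged illustration with our own test function; the exact estimator in (18) may differ). For the Gaussian-smoothed loss, Stein's identity gives ∇²l_δ(w) = (1/δ) E[∇l(w + δu) uᵀ]; the antithetic version below recovers the Hessian of a non-quadratic function with only O(δ²) bias and no constant-level term.

```python
import numpy as np

rng = np.random.default_rng(3)
d, delta, M = 3, 0.05, 200_000

# Non-quadratic test function f(w) = 0.25 ||w||^4 with known derivatives:
# grad f(w) = (w^T w) w,  Hessian f(w) = (w^T w) I + 2 w w^T.
grad = lambda X: (X * X).sum(-1, keepdims=True) * X
w = np.array([1.0, 0.5, -0.5])
H_true = (w @ w) * np.eye(d) + 2 * np.outer(w, w)

# First-order Gaussian-smoothing Hessian estimate (antithetic form):
#   Hhat = (1 / (2 delta M)) * sum_j [grad f(w + delta u_j) - grad f(w - delta u_j)] u_j^T
U = rng.standard_normal((M, d))
diff = grad(w + delta * U) - grad(w - delta * U)
H_hat = diff.T @ U / (2 * delta * M)
H_hat = (H_hat + H_hat.T) / 2                 # symmetrize the estimate

assert np.max(np.abs(H_hat - H_true)) < 0.05  # bias O(delta^2), no constant-level term
```

Because the estimator differences gradients rather than function values, its leading error is the O(δ²) Taylor remainder, which is consistent with the bias behavior stated for the GGS estimator in the next subsection.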
4.3 Convergence Analysis for GGS-MAML
The following proposition bounds the estimation error of the GGS estimator in (18) for approximating the true Hessian.
Proposition 6.
Suppose that Assumption 1 holds. Then, conditioning on the inner-stage iterates, we have
Proposition 6 shows that our Hessian estimator eliminates the constant error term present in the zeroth-order Hessian estimator, and that its estimation error can be made sufficiently small by choosing a small smoothing parameter. We next provide an upper bound on the estimation variance of the GGS estimator via random matrix theory.
Proposition 7.
Suppose Assumption 1 holds. Then, conditioning on the inner-stage iterates, we have
Var  (20) 
where the parameter is given by
(21) 
where and is a constant independent of and .
Our analysis of the variance bound involves novel techniques based on random matrix theory and a generalized Rudelson's inequality, which may be of independent interest.
Based on Propositions 6 and 7, we next provide the convergence and complexity analysis for GGS-MAML.
Theorem 2.
Suppose that Assumptions 1, 2, and 3 hold. Set and choose such that , , and , where and are given in Proposition 7. We then have
To achieve an ε-accurate stationary point, Algorithm 2 requires at most iterations, and gradient computations per meta iteration.
Theorem 2 captures a trade-off between the inner stepsize α and the number M of gradients used for the Hessian approximation in (18). The choice of α in the above theorem indicates that increasing M yields a more accurate Hessian approximation and hence admits a larger inner stepsize α, but incurs a higher gradient computation cost. For one choice of parameters, the corresponding inner stepsize matches that in Corollary 1 for the original MAML, but the gradient computation complexity becomes larger. On the other hand, if we choose a smaller α, or if the norm of the Hessian is small (i.e., L is small), we can choose a smaller M to save computation cost. This is because the Hessian-related components in the meta-gradient estimator scale with α, and thus a smaller α yields a smaller variance in the meta-gradient estimation.
5 Experiments
In this section, we validate our theory and the effectiveness of the proposed GGS-MAML algorithm. We present our experiments on two meta-learning problems, i.e., rank-one matrix factorization and regression, and relegate further experiments on reinforcement learning to Appendix A. All experiments are run using PyTorch (Paszke et al., 2019). The code for matrix factorization and regression is available at https://github.com/JunjieYang97/GGSMAMLDL and the code for reinforcement learning is available at https://github.com/JunjieYang97/GGSMAMLRL.
We compare the performance among the following four algorithms: GGS-MAML proposed in this paper (Hessian-free, based on the first-order Gaussian-smoothing Hessian estimator), ZO-MAML (Hessian-free, based on the zeroth-order Hessian estimator in (17)), FO-MAML (Finn et al., 2017a) (which omits all second-order derivatives), and the original MAML.
5.1 Rank-One Matrix Factorization
We study a rank-one matrix factorization problem. For each task T_i, consider a loss function defined by a random vector v_i sampled from a Gaussian distribution. At each meta iteration, a number of different tasks are sampled for meta-training. For all algorithms, we run inner-stage gradient updates with a fixed stepsize. At the outer stage (meta iteration), we use SGD as the meta-optimizer with the learning rate chosen from a candidate set. For GGS-MAML and ZO-MAML, we use Gaussian vectors for the Hessian approximation and choose the best smoothing parameter from a candidate set for each algorithm.
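As a concrete (assumed) instantiation of this task family — the paper's exact loss may differ — task i could draw v_i ~ N(0, I) and use l_i(w) = ‖w wᵀ − v_i v_iᵀ‖²_F, whose analytic gradient is 4(w wᵀ − v_i v_iᵀ)w. The sketch below sanity-checks that gradient against finite differences.

```python
import numpy as np

rng = np.random.default_rng(4)
d = 5

# Hypothetical rank-one factorization task: target matrix v v^T with v ~ N(0, I),
#   l(w) = || w w^T - v v^T ||_F^2,  grad l(w) = 4 (w w^T - v v^T) w.
v = rng.standard_normal(d)
loss = lambda w: np.sum((np.outer(w, w) - np.outer(v, v)) ** 2)
grad = lambda w: 4 * (np.outer(w, w) - np.outer(v, v)) @ w

# Sanity-check the analytic gradient against central finite differences.
w0 = rng.standard_normal(d)
eps = 1e-6
g_fd = np.array([(loss(w0 + eps * e) - loss(w0 - eps * e)) / (2 * eps)
                 for e in np.eye(d)])
assert np.allclose(grad(w0), g_fd, atol=1e-4)
```

This loss is nonconvex in w (its Hessian is indefinite away from the optimum), which makes it a natural small-scale test of the nonconvex convergence theory in Section 3.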
In Figure 1, we plot how the meta loss changes as the algorithms iterate. Since we sample new tasks at each iteration (which corresponds to the expectation form of the objective function in the resampling case), the meta loss value also represents the test loss of the meta parameter. Figure 1 illustrates that the proposed GGS-MAML outperforms the other two Hessian-free algorithms, ZO-MAML and FO-MAML, and achieves an accuracy comparable to the original MAML. This validates the effectiveness of the proposed Hessian estimator used in GGS-MAML. The figure also indicates that ZO-MAML converges much more slowly than the other algorithms due to the large bias of the zeroth-order Hessian estimator, as we show theoretically in Proposition 5.
5.2 Sine Wave Regression
Following Finn and Levine (2017), we further consider a sine-wave regression problem here. Each task uses a loss function