Multi-Step Model-Agnostic Meta-Learning: Convergence and Improved Algorithms
As a popular meta-learning approach, the model-agnostic meta-learning (MAML) algorithm has been widely used due to its simplicity and effectiveness. However, the convergence of the general multi-step MAML still remains unexplored. In this paper, we develop a new theoretical framework, under which we characterize the convergence rate and the computational complexity of multi-step MAML. Our results indicate that although the estimation bias and variance of the stochastic meta gradient involve exponential factors of N (the number of inner-stage gradient updates), MAML still converges with a complexity that increases only linearly with N under a properly chosen inner stepsize. We then take a further step to develop a more efficient Hessian-free MAML. We first show that the existing zeroth-order Hessian estimator contains a constant-level estimation error, so that the MAML algorithm can perform unstably. To address this issue, we propose a novel Hessian estimator via a gradient-based Gaussian smoothing method, and show that it achieves a much smaller estimation bias and variance, and that the resulting algorithm achieves the same performance guarantee as the original MAML under mild conditions. Our experiments validate our theory and demonstrate the effectiveness of the proposed Hessian estimator.
Meta-learning or learning to learn (Thrun and Pratt, 2012; Naik and Mammone, 1992; Bengio et al., 1991) is a powerful tool for quickly learning new tasks by using the prior experience from related tasks. Recent works have empowered this idea with neural networks, and their proposed meta-learning algorithms have been shown to enable fast learning over unseen tasks using only a few samples by efficiently extracting the knowledge from a range of observed tasks (Santoro et al., 2016; Vinyals et al., 2016; Finn et al., 2017a). Current meta-learning algorithms can be generally categorized into metric-learning based (Koch et al., 2015; Snell et al., 2017), model-based (Vinyals et al., 2016; Munkhdalai and Yu, 2017), and optimization-based (Finn et al., 2017a; Nichol and Schulman, 2018; Rajeswaran et al., 2019) approaches. Among them, optimization-based meta-learning is a simple and effective approach used in a wide range of domains including classification/regression (Rajeswaran et al., 2019), reinforcement learning (Finn et al., 2017a), robotics (Al-Shedivat et al., 2017), federated learning (Chen et al., 2018), and imitation learning (Finn et al., 2017b).
Model-agnostic meta-learning (MAML) (Finn et al., 2017a) is a popular optimization-based method, which is simple and generally compatible with any model trained by gradient descent. MAML consists of two nested stages, where the inner stage runs a few steps of (stochastic) gradient descent for each individual task, and the outer stage updates the meta parameter over all the sampled tasks. The goal of MAML is to find a good meta initialization based on the observed tasks such that, for a new task, a few (stochastic) gradient steps starting from this initialization suffice to find a good model parameter. Although it has been widely used (Antoniou et al., 2018; Grant et al., 2018; Zintgraf et al., 2018; Nichol et al., 2018), the theoretical convergence of MAML is not well explored except for a few recent attempts. In particular, Finn et al. (2019) extended MAML to the online setting, and analyzed the regret for strongly convex objective functions. More recently, Fallah et al. (2019) provided an analysis of one-step MAML for general nonconvex functions, where each inner stage takes only a single stochastic gradient descent (SGD) step. However, the convergence of the more general and frequently used multi-step MAML has not been explored yet. Several new challenges arise in the theoretical analysis due to the multi-step inner-stage updates. First, the meta gradient of multi-step MAML has a nested and recursive structure, which requires analyzing an optimization path over a nested structure. In addition, the multi-step update also yields a complicated bias error in the Hessian estimation as well as a statistical correlation between the Hessian and gradient estimators, both of which cause further difficulty in the analysis of the meta gradient.
The first contribution of this paper lies in the development of a new theoretical framework for analyzing multi-step MAML, with techniques for handling the above challenges. For the resampling case (presented in the main body of the paper), where each iteration samples fresh data (e.g., in reinforcement learning), our analysis enables us to decouple the Hessian approximation error from the gradient approximation error based on a novel bound on the distance between two different inner optimization paths, which facilitates the analysis of the overall convergence of MAML. For the finite-sum case (presented in the appendix), where the objective function is based on pre-assigned samples (e.g., supervised learning), we develop novel techniques to handle the difference between the two losses over the training and test sets in the analysis.
Our analysis provides a guideline for choosing the inner-stage stepsize α at the order of O(1/N), and shows that N-step MAML is guaranteed to converge with the gradient and Hessian computation complexities growing only linearly with N. In addition, for problems where Hessians are small, e.g., most classification/regression meta-learning problems (Finn et al., 2017a), we show that the inner stepsize can be set larger while still maintaining convergence, which explains the empirical findings for MAML training in Finn et al. (2017a); Antoniou et al. (2018); Rajeswaran et al. (2019).
Although achieving promising performance in many areas, MAML has a high computation and memory cost due to the computation of Hessians, and such a cost increases dramatically if the inner stage takes multiple gradient steps (Rajeswaran et al., 2019). To address this issue, Hessian-free MAML algorithms have gained a lot of attention recently (Nichol and Schulman, 2018; Finn et al., 2017a; Song et al., 2020). Among them, Song et al. (2020) proposed an evolution-strategies-based MAML (ES-MAML) approach that uses zeroth-order (i.e., function-value) information for Hessian approximation. Though ES-MAML exhibits promising empirical performance, Song et al. (2020) did not provide a convergence guarantee for this algorithm, and their experiments demonstrated that the zeroth-order Hessian approximation sometimes leads to inferior and unstable training in some RL problems.
The second contribution of this paper includes two parts. We first provide a theoretical analysis of the zeroth-order Hessian approximation (Song et al., 2020) and show that such estimation contains a constant-level error, so that ES-MAML cannot converge exactly to a stationary point and may suffer from a large convergence error. This explains the inferior and unstable training of ES-MAML observed in Song et al. (2020).
We then propose a first-order Hessian estimator based on Gaussian smoothing, which we show achieves a much smaller bias and variance. More importantly, we show that the resulting algorithm (which we call GGS-MAML) achieves the same performance guarantee as the original MAML under mild conditions. Our analysis develops novel techniques for characterizing the variance of Gaussian-smoothing-based Hessian estimators via random matrix theory, which can be of independent interest. Our experiments validate the effectiveness of GGS-MAML.
1.1 Related Work
Optimization-based meta-learning. Optimization-based meta-learning approaches have been widely used due to their simplicity and efficiency (Li et al., 2017; Ravi and Larochelle, 2016; Finn et al., 2017a). As a pioneer along this line, MAML (Finn et al., 2017a) aims to find an initialization such that gradient descent from it achieves fast adaptation. Many follow-up studies (Grant et al., 2018; Finn et al., 2019; Jerfel et al., 2018; Finn and Levine, 2017; Finn et al., 2018; Mi et al., 2019; Liu et al., 2019; Rothfuss et al., 2018; Foerster et al., 2018; Fallah et al., 2019; Collins et al., 2020; Fallah et al., 2020) have extended MAML from different perspectives. For example, Finn et al. (2019) provided a follow-the-meta-leader extension of MAML for online learning. Alternatively to meta-initialization algorithms such as MAML, meta-regularization approaches aim to learn a good bias for a regularized empirical risk minimization problem for intra-task learning (Alquier et al., 2016; Denevi et al., 2018b, a, 2019; Rajeswaran et al., 2019; Balcan et al., 2019; Zhou et al., 2019). For example, Rajeswaran et al. (2019) proposed efficient iMAML using a conjugate gradient (CG) based solver. Balcan et al. (2019) formalized a connection between meta-initialization and meta-regularization from an online learning perspective. Zhou et al. (2019) proposed an efficient meta-learning approach based on a minibatch proximal update.
Theory for MAML-type algorithms. There have been only a few studies on the statistical and convergence performance of MAML-type algorithms. Finn and Levine (2017) showed that MAML is a universal learning algorithm approximator under certain conditions. Finn et al. (2019) analyzed online MAML for a strongly convex objective function under a bounded-gradient assumption. Rajeswaran et al. (2019) proposed a meta-regularization variant of MAML named iMAML, and analyzed its convergence by assuming that the regularized empirical risk minimization problem in the inner optimization stage is strongly convex. Fallah et al. (2020) proposed a variant of RL-MAML named stochastic gradient meta-reinforcement learning (SG-MRL), and analyzed the convergence and complexity performance of one-step SG-MRL. Fallah et al. (2019) developed a convergence analysis for one-step MAML for a general nonconvex objective in the resampling case. Our study here provides a new convergence analysis for multi-step MAML in the nonconvex setting for both the resampling and finite-sum cases.
Hessian-free MAML. Various Hessian-free MAML algorithms have been proposed, which include but are not limited to FOMAML (Finn et al., 2017a), Reptile (Nichol and Schulman, 2018), ES-MAML (Song et al., 2020), and HF-MAML (Fallah et al., 2019). In particular, FOMAML (Finn et al., 2017a) omits all second-order derivatives in its meta-gradient computation, ES-MAML (Song et al., 2020) approximates Hessian matrices in RL using a zeroth-order smoothing method, and HF-MAML (Fallah et al., 2019) estimates the meta gradient in one-step MAML using a Hessian-vector product approximation. In this paper, we propose a new Hessian-free MAML algorithm based on a first-order Hessian estimator and study its analytical and empirical performance.
2 Problem Setup
In this paper, we study the convergence of the multi-step MAML algorithm. Two types of objective functions are commonly used in practice: (a) resampling case (Finn et al., 2017a; Fallah et al., 2019), where loss functions take the form in expectation and new data are sampled as the algorithm runs; and (b) finite-sum case (Antoniou et al., 2018), where loss functions take the finite-sum form with given samples. The resampling case occurs often in reinforcement learning where data are continuously sampled as the algorithm iterates, whereas the finite-sum case typically occurs in classification problems where the datasets are already sampled in advance. The main body of this paper will focus on the resampling case, and we relegate all materials for the finite-sum case to Appendix D.
2.1 Multi-Step MAML Based on Full Gradient
Suppose a set of tasks is available for learning, and tasks are sampled based on a probability distribution over the task set. Assume that each task is associated with a loss function parameterized by the model parameter w.
The goal of multi-step MAML is to find a good initial parameter such that, after observing a new task, a few gradient descent steps starting from this point can efficiently approach a minimizer (or a stationary point) of the corresponding loss function. Towards this end, multi-step MAML consists of two nested stages, where the inner stage consists of multiple steps of (stochastic) gradient descent for each individual task, and the outer stage updates the meta parameter over all the sampled tasks. More specifically, at each inner stage, each task initializes at the meta parameter and runs N gradient descent steps as
Thus, the loss of task i after the N-step inner-stage iteration is given by the task loss evaluated at the final inner iterate, which depends on the meta parameter w through the iteration updates in (1) and can hence be written as a function of w. We further define F_i(w) to be this loss of task i viewed as a function of the meta parameter w, and hence the overall meta objective is given by
Then the outer stage of the meta update is a gradient descent step to optimize the above objective function. Using the chain rule, we provide a simplified form of the meta gradient as
∇F_i(w) = [ ∏_{j=0}^{N-1} (I − α ∇²l_i(w_j^i)) ] ∇l_i(w_N^i),
where α denotes the inner stepsize and w_j^i denotes the j-th inner-stage iterate of task i (see Appendix F).
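To make the chain-rule structure of the meta gradient concrete, the sketch below computes the N-step inner path and the product-form meta gradient for a single toy quadratic task, and checks it against finite differences. The quadratic loss, stepsize, and step count are our own illustrative assumptions, not quantities from the paper.

```python
import numpy as np

# Toy quadratic task loss l(w) = 0.5 * (w - c)^T A (w - c); A, c, N,
# and alpha are illustrative assumptions, not the paper's settings.
A = np.array([[2.0, 0.3], [0.3, 1.0]])
c = np.array([1.0, -1.0])
alpha, N = 0.1, 3

grad = lambda w: A @ (w - c)   # gradient of the task loss
hess = lambda w: A             # Hessian (constant for a quadratic)

def inner_path(w):
    """Run N inner-stage gradient descent steps from the meta parameter w."""
    path = [w]
    for _ in range(N):
        path.append(path[-1] - alpha * grad(path[-1]))
    return path

def meta_gradient(w):
    """Chain rule: grad F(w) = prod_j (I - alpha * hess(w_j)) @ grad l(w_N)."""
    path = inner_path(w)
    J = np.eye(len(w))
    for wj in path[:-1]:
        J = J @ (np.eye(len(w)) - alpha * hess(wj))
    return J @ grad(path[-1])

# Sanity check against finite differences of F(w) = l(w_N(w)).
w0 = np.array([0.5, 0.2])
F = lambda w: 0.5 * (inner_path(w)[-1] - c) @ A @ (inner_path(w)[-1] - c)
eps = 1e-6
fd = np.array([(F(w0 + eps * e) - F(w0 - eps * e)) / (2 * eps)
               for e in np.eye(2)])
assert np.allclose(meta_gradient(w0), fd, atol=1e-4)
```

For a non-quadratic loss the Hessian factors differ across inner iterates, but the product structure is the same; this is what makes the meta gradient both nested and recursive.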
2.2 Multi-Step MAML Based on Stochastic Gradient
The inner- and outer-stage updates of MAML given in (1) and (4) involve the gradient and the Hessian of the loss function, which takes the form of an expectation over the distribution of data samples, as given by
where τ represents a data sample. In practice, these two quantities based on the population loss function are estimated by samples. Specifically, each task samples a batch of data under the current parameter, and uses the corresponding sample averages as unbiased estimates of the gradient and the Hessian, respectively. See Section B.1 for an RL example.
For practical multi-step MAML as shown in Algorithm 1, at the outer stage, we sample a set of tasks. Then, at the inner stage, each sampled task draws a fresh training set at each inner iteration, uses the resulting minibatch gradient as an estimate of the full gradient in (1), and runs an SGD update as
where the initialization parameter is set to the current meta parameter for all sampled tasks.
At the outer stage, we draw two batches of data samples that are independent of each other and of the inner-stage training sets, and use them to estimate the gradient and the Hessian in (4), respectively. Then, the meta parameter at the outer stage is updated by the following SGD step
where the meta-gradient estimate of each task is given by
For simplicity, we assume throughout this paper that the inner-stage training sets and the two outer-stage sample sets have fixed batch sizes.
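The stochastic sampling scheme described above can be sketched in a few lines. The following is a hedged toy illustration: the quadratic task model, noise levels, batch sizes, and stepsizes are all our own assumptions for demonstration, not the paper's Algorithm 1 verbatim.

```python
import numpy as np

rng = np.random.default_rng(0)
d, N, alpha, beta = 2, 3, 0.05, 0.1   # dimension, inner steps, stepsizes

def sample_task():
    """A toy task: noisy gradient/Hessian oracles of a quadratic loss."""
    center = rng.normal(size=d)
    g = lambda w: (w - center) + 0.1 * rng.normal(size=d)     # fresh batch per call
    H = lambda w: np.eye(d) + 0.01 * rng.normal(size=(d, d))  # independent Hessian batch
    return g, H

def meta_step(w, num_tasks=5):
    """One outer-stage SGD step over a batch of sampled tasks."""
    meta_grad = np.zeros(d)
    for _ in range(num_tasks):
        g, H = sample_task()
        # inner stage: N SGD steps from the meta parameter, fresh data each step
        path = [w]
        for _ in range(N):
            path.append(path[-1] - alpha * g(path[-1]))
        # outer stage: independent batches estimate the Hessians and final gradient
        J = np.eye(d)
        for wj in path[:-1]:
            J = J @ (np.eye(d) - alpha * H(wj))
        meta_grad += J @ g(path[-1])
    return w - beta * meta_grad / num_tasks

w = np.zeros(d)
for _ in range(20):
    w = meta_step(w)
```

Note how the Hessian factors and the final gradient are estimated from batches drawn independently of the inner path, mirroring the independence requirements stated above.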
3 Convergence of Multi-Step MAML
In this section, we provide a convergence analysis for the multi-step MAML algorithm in the resampling case.
3.1 Basic Assumptions
The loss function l_i of each task given by (5) satisfies:
The loss l_i is bounded below, i.e., inf_w l_i(w) > −∞.
The gradient ∇l_i is L-Lipschitz, i.e., for any w, u, ||∇l_i(w) − ∇l_i(u)|| ≤ L ||w − u||.
The Hessian ∇²l_i is ρ-Lipschitz, i.e., for any w, u, ||∇²l_i(w) − ∇²l_i(u)|| ≤ ρ ||w − u||.
The following assumptions impose bounded-variance conditions on the stochastic estimates of the loss gradient and Hessian.
The stochastic gradient ∇l_i(w; τ) (with τ uniformly randomly chosen from the data set) has bounded variance, i.e., there exists a constant σ > 0 such that, for any w,
E ||∇l_i(w; τ) − ∇l_i(w)||² ≤ σ²,
where l_i(w) denotes the expected loss function.
For any parameter w and data sample τ, there exist constants such that
Note that the above assumptions are made only on the individual loss functions l_i rather than on the total loss, because some of these conditions do not hold for the total loss, as shown later.
3.2 Challenges of Analyzing Multi-Step MAML
Several new challenges arise when we analyze the convergence of multi-step MAML (with N > 1) compared to the one-step case (with N = 1).
First, each iteration of the meta parameter affects the overall objective function via a nested structure of N-step SGD optimization paths over all tasks. Hence, our analysis of the convergence of the meta parameter needs to characterize this nested structure and the recursive updates.
Second, the meta-gradient estimator given in (8) involves stochastic Hessians evaluated along the inner-stage path, which are all biased estimators of the corresponding true Hessians in terms of the randomness over the inner-stage training sets. This is because each inner iterate is itself a stochastic quantity obtained via random training sets along an N-step SGD optimization path. In fact, such a bias error occurs only for multi-step MAML with N > 1 (it equals zero for N = 1), and requires additional effort to handle.
Third, both the Hessian terms and the gradient term in the meta-gradient estimator given in (8) depend on the sample sets used in the inner-stage iterations, and hence they are statistically correlated even conditioned on the meta parameter. Such a complication also occurs only for multi-step MAML with N > 1, and requires new treatment (the two terms are independent for N = 1).
In Section 3.3, we develop a theoretical framework to handle the above challenges and establish the convergence of N-step MAML.
3.3 Properties of Meta Gradient
Unlike the conventional gradient, whose corresponding loss is evaluated directly at the current parameter, the meta gradient has a more complicated nested structure with respect to the meta parameter, because its loss is evaluated at the final output of the inner optimization stage, i.e., after N steps of SGD updates. As a result, analyzing the meta gradient is very different from, and more challenging than, analyzing the conventional gradient. In this subsection, we establish some important properties of the meta gradient that are useful for characterizing the convergence of multi-step MAML.
Recall the meta objective F(w), where each task objective F_i is given by (3). The following proposition characterizes the Lipschitz property of the meta gradient.
The proof of Proposition 1 handles the first challenge described in Section 3.2. More specifically, we bound the differences between the inner-stage gradients and Hessians along two separate optimization paths, and then connect these differences to the distance between the two initial points. Proposition 1 shows that the objective has a gradient-Lipschitz parameter that depends on the gradient norms along the inner-stage paths, which can be unbounded. Similarly to Fallah et al. (2019), we use an estimate of this quantity at the meta parameter, computed with independently sampled data sets. As will be shown in Theorem 1, we set the meta stepsize to be inversely proportional to this estimate to handle the possible unboundedness. In our experiments, we find that the gradients are well bounded during the optimization process, and hence a constant outer-stage stepsize suffices in practice.
We next characterize several estimation properties of the meta-gradient estimator in (8). Here, we address the second and third challenges described in Section 3.2. We first quantify how far the stochastic inner-stage path deviates from its deterministic (full-gradient) counterpart, and then provide upper bounds on the first- and second-moment errors of the meta-gradient estimator, as below.
Proposition 2 (Point-wise distance between two paths).
Proposition 2 shows that we can effectively upper-bound the point-wise distance between the two paths by choosing the inner stepsize and the inner-stage batch size properly. Based on Proposition 2, we provide an upper bound on the first-moment estimation error of the meta-gradient estimator.
Proposition 3 (First-moment error).
In contrast to the one-step case, i.e., Lemma 4.12 in Fallah et al. (2019), the estimation error for the multi-step case shown in Proposition 3 involves an additional term, which cannot be avoided due to the Hessian approximation error caused by the randomness over the inner-stage sample sets. Somewhat interestingly, our later analysis shows that this term does not affect the final convergence rate if we choose the batch size properly. The following proposition provides an upper bound on the second moment of the meta-gradient estimator.
Proposition 4 (Second-moment error).
By choosing the batch sizes and the inner stepsize properly, the multiplicative factor in the second-moment error bound in (14) can be kept at a constant level, and the first two error terms can be made sufficiently small, so that the variance of the meta-gradient estimator can be well controlled in the convergence analysis, as shown later.
3.4 Main Convergence Result
The proof of Theorem 1 (see Appendix G for details) consists of four main steps: step 1 bounds one iteration of the meta update using the meta-gradient smoothness established by Proposition 1; step 2 characterizes the first-moment estimation error of the meta-gradient estimator by Proposition 3; step 3 characterizes the second-moment estimation error of the meta-gradient estimator by Proposition 4; and step 4 combines steps 1-3 and telescopes to yield the convergence.
In Theorem 1, the convergence rate given by (16) mainly contains three parts: the first term indicates that the meta parameter converges sublinearly with the number of meta iterations; the second term captures the estimation error of the stochastic inner-stage updates for approximating the full-gradient updates, which can be made sufficiently small by choosing a large inner-stage sample size; and the third term captures the estimation error and variance of the stochastic meta gradient, which can be made small by choosing large outer-stage batch sizes.
Our analysis reveals several insights for the convergence of multi-step MAML as follows. (a) To guarantee convergence, we require the inner stepsize α to be at the order of O(1/N). Hence, if the number N of inner gradient steps is large and the smoothness parameter L is not small (e.g., for some RL problems), we need to choose a small inner stepsize so that the last output of the inner stage retains a strong dependence on the initialization (i.e., the meta parameter), as also shown and explained in Rajeswaran et al. (2019). (b) For problems with small Hessians, such as many classification/regression problems (Finn et al., 2017a), L (which is an upper bound on the spectral norm of the Hessian matrices) is small, and hence we can choose a larger α. This explains the empirical findings in Finn et al. (2017a); Antoniou et al. (2018).
Corollary 1 (Stepsize α at the order of O(1/N)).
Differently from the conventional SGD analysis, MAML requires a higher gradient complexity by an additional multiplicative factor, which is unavoidable because MAML samples a batch of tasks at each iteration to achieve an ε-accurate meta point, whereas SGD runs over only one task.
Corollary 1 shows that, with a properly chosen inner stepsize at the order of O(1/N), MAML is guaranteed to converge with both the gradient and the Hessian computation complexities growing only linearly with N. These results explain some empirical findings for MAML training in Rajeswaran et al. (2019). The above results can also be obtained by using a larger stepsize with a properly chosen constant.
4 Efficient Hessian-Free MAML Algorithms
The presence of Hessians in MAML causes significant computational complexity and storage cost, especially in the multi-step case. More recently, a Hessian-free MAML algorithm based on a zeroth-order Hessian estimator has been proposed in Song et al. (2020). In this section, we first theoretically show the performance limitation of the zeroth-order Hessian estimator, and then propose a new Hessian-free MAML algorithm that has a provable performance guarantee.
4.1 Limitations of Zeroth-Order Hessian Estimation
The zeroth-order Hessian estimator constructs the Hessian approximation from loss values evaluated at parameters perturbed along i.i.d. standard Gaussian directions, where the perturbations are applied to the current inner-stage parameters, b denotes the number of Gaussian perturbation vectors, and μ > 0 is a smoothing parameter.
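For intuition, a generic zeroth-order Hessian estimator of this kind can be sketched as follows. This is a textbook Gaussian-smoothing construction with antithetic perturbations, not necessarily the exact ES-MAML estimator; the quadratic test function and all parameter values are our own assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy quadratic f(w) = 0.5 * w^T A w, whose true Hessian is A.
A = np.array([[2.0, 0.5], [0.5, 1.0]])
f = lambda w: 0.5 * w @ A @ w

def zo_hessian(w, b=4000, mu=0.1):
    """Zeroth-order Hessian estimate built from function values only."""
    d = len(w)
    est = np.zeros((d, d))
    for _ in range(b):
        u = rng.normal(size=d)   # Gaussian perturbation direction
        # Antithetic evaluations reduce variance; the (u u^T - I) weight
        # follows the Gaussian-smoothing identity for the Hessian.
        fv = (f(w + mu * u) + f(w - mu * u) - 2 * f(w)) / (2 * mu**2)
        est += fv * (np.outer(u, u) - np.eye(d))
    return est / b

H_zo = zo_hessian(np.array([0.3, -0.2]))
```

Even on this benign quadratic, the estimate is noticeably noisy at moderate b, which hints at why function-value-only Hessian approximations can destabilize training.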
We next provide an analysis of the estimation error of this estimator, which characterizes its performance limitation.
Suppose Assumption 1 holds, and let the Gaussian-smoothed approximation of the loss function be defined with smoothing parameter μ. Then, conditioning on the inner-stage iterates, the expectation of the zeroth-order estimator decomposes into the Hessian of the smoothed loss, a constant-level bias term, and a remaining error term whose norm is bounded by a quantity proportional to μ.
Though the first term in the above equality can be sufficiently close to the true Hessian for a small μ, and the remaining error term can also be made small for a small enough μ, there exists a constant-level error term that does not vanish. Due to such a constant bias, the corresponding algorithm can converge only to a neighborhood around a stationary point, with a convergence error that can be substantial. Furthermore, we show by Proposition 8 in Appendix C that the inner stepsize must satisfy a more restrictive requirement than in the original MAML. Such an issue persists in the finite-sum case.
4.2 A Simple MAML Variant by Gradient-Based Gaussian Smoothing
Our analysis in Section 4.1 shows that zeroth-order MAML does not provide the desired performance due to a possibly large estimation error of the Hessian. Here, we develop a more accurate Hessian estimator via a gradient-based Gaussian smoothing (GGS) method, and propose an easy-to-implement MAML variant (GGS-MAML). We show in the next two subsections that GGS-MAML admits a much larger stepsize and achieves a better convergence guarantee than the zeroth-order MAML.
As shown in Algorithm 2, GGS-MAML has a similar structure to the original MAML, but constructs the Hessian approximation from first-order (gradient) information at Gaussian-perturbed parameters, where b denotes the number of i.i.d. standard Gaussian perturbation vectors and μ > 0 the smoothing parameter. We then update the meta parameter by replacing the true Hessians in the meta gradient with these estimates.
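A gradient-based Gaussian-smoothing Hessian estimate can be sketched as below. We use a symmetrized difference of gradients for variance reduction, which may differ in detail from the exact construction in Algorithm 2; the quadratic gradient oracle and parameter values are our own assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy quadratic with gradient oracle grad f(w) = A w and true Hessian A.
A = np.array([[2.0, 0.5], [0.5, 1.0]])
grad_f = lambda w: A @ w

def ggs_hessian(w, b=500, mu=0.01):
    """First-order (gradient-based) Gaussian smoothing Hessian estimate."""
    d = len(w)
    est = np.zeros((d, d))
    for _ in range(b):
        u = rng.normal(size=d)   # Gaussian perturbation direction
        # (grad f(w + mu u) - grad f(w - mu u)) / (2 mu) approximates H(w) u;
        # the outer product with u recovers the Hessian in expectation.
        hv = (grad_f(w + mu * u) - grad_f(w - mu * u)) / (2 * mu)
        est += np.outer(hv, u)
    return est / b

H_ggs = ggs_hessian(np.array([0.3, -0.2]))
```

Compared with the zeroth-order construction, each sample here already carries a Hessian-vector-product-level signal, which is consistent with the smaller bias and variance established in Section 4.3.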
4.3 Convergence Analysis for GGS-MAML
The following proposition bounds the estimation error of for approximating the true Hessian .
Suppose that Assumption 1 holds. Then, conditioning on the inner-stage iterates, we have
Proposition 6 shows that our Hessian estimator eliminates the constant error term in the zeroth-order Hessian estimation, and its estimation error can be made sufficiently small by choosing a small . We next provide an upper bound on the estimation variance of by random matrix theory.
Suppose Assumption 1 holds. Then, conditioning on the inner-stage iterates, the estimation variance of the proposed estimator is bounded by a parameter that scales inversely with the number b of perturbation vectors, where the constant involved is independent of the dimension and of b.
Our analysis of the variance bound involves novel techniques using random matrix theory and a generalized Rudelson’s inequality, which may be of independent interest.
Theorem 2 captures a trade-off between the inner stepsize α and the number b of gradients used for the Hessian approximation in (18). To see this, consider a representative parameter choice for illustration. The choice of the inner stepsize in the above theorem imposes a lower bound on b. Such a requirement indicates that increasing b yields a more accurate Hessian approximation and hence admits a larger inner stepsize α, but causes a higher gradient computation cost. For example, choosing b large enough allows the same inner stepsize as in Corollary 1 for the original MAML, at the price of a larger gradient computation complexity. On the other hand, if we choose a smaller b, or if the norm of the Hessian is small (i.e., L is small), we can choose a smaller α to save computation cost. This is because the Hessian-related components in the meta-gradient estimator scale with the product of α and the Hessian norms, and thus a smaller α yields a smaller variance of the meta-gradient estimation.
5 Experiments
In this section, we validate our theory and the effectiveness of the proposed GGS-MAML algorithm. We present our experiments on two meta-learning problems, i.e., rank-one matrix factorization and regression, and relegate further experiments on reinforcement learning to Appendix A. All experiments are run using PyTorch (Paszke et al., 2019). The code for matrix factorization and regression is available at https://github.com/JunjieYang97/GGS-MAML-DL, and the code for reinforcement learning is available at https://github.com/JunjieYang97/GGS-MAML-RL.
We compare the performance among the following four algorithms: GGS-MAML proposed in this paper (Hessian-free and based on the first-order Gaussian smoothing Hessian estimator), ZO-MAML (Hessian-free and based on the zeroth-order Hessian estimator (Song et al., 2020)), FO-MAML (which omits all second-order derivatives (Finn et al., 2017a)), and the original MAML.
5.1 Rank-One Matrix Factorization
We study a rank-one matrix factorization problem. For each task, the loss function is defined by a task-specific random vector sampled from a Gaussian distribution. At each meta iteration, a batch of different tasks is sampled for meta-training. For all algorithms, we run multiple inner-stage gradient updates with a fixed stepsize. At the outer stage (meta iteration), we use SGD as the meta-optimizer with the learning rate chosen from a grid. For GGS-MAML and ZO-MAML, we use a batch of Gaussian vectors for Hessian approximation and choose the best smoothing parameter from a grid for each algorithm.
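As a concrete illustration of such a task, a natural rank-one factorization objective is l_i(w) = 0.25 * ||w w^T - v_i v_i^T||_F^2 with a Gaussian task vector v_i; this specific form is our own assumption and may differ from the paper's exact loss.

```python
import numpy as np

rng = np.random.default_rng(3)
d = 5

def sample_task(d):
    """Sample a rank-one factorization task defined by a Gaussian vector v."""
    v = rng.normal(size=d)
    M = np.outer(v, v)   # rank-one target matrix v v^T
    loss = lambda w: 0.25 * np.linalg.norm(np.outer(w, w) - M) ** 2
    grad = lambda w: (np.outer(w, w) - M) @ w   # gradient of the assumed loss
    return loss, grad

loss, grad = sample_task(d)
w = 0.1 * rng.normal(size=d)
for _ in range(100):             # inner-stage gradient descent on one task
    w = w - 0.01 * grad(w)
final_loss = loss(w)
```

Because each task is fully determined by its random vector, fresh tasks can be sampled at every meta iteration, matching the resampling setting analyzed in Section 3.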
In Figure 1, we plot how the loss function changes as the algorithms iterate. Since we sample new tasks at each iteration (which matches the expectation form of the objective function in the resampling case), the meta loss value also represents the test loss of the meta parameter. Figure 1 illustrates that the proposed GGS-MAML outperforms the other two Hessian-free algorithms, ZO-MAML and FO-MAML, and achieves an accuracy comparable to the original MAML. This validates the effectiveness of the proposed Hessian estimator used in GGS-MAML. The figure also indicates that ZO-MAML converges much more slowly than the other algorithms, due to the large bias of the zeroth-order Hessian estimator, as we show theoretically in Proposition 5.
5.2 Sine Wave Regression
Following Finn and Levine (2017), we further consider a sine wave regression problem here. Each task uses a loss function