Personalized Federated Learning: A Meta-Learning Approach
Abstract
The goal of federated learning is to design algorithms in which several agents communicate with a central node, in a privacy-protecting manner, to minimize the average of their loss functions. In this approach, each node not only shares the required computational budget but also has access to a larger data set, which improves the quality of the resulting model. However, this method only develops a common output for all the agents and therefore does not adapt the model to each user's data. This is an important missing feature, especially given the heterogeneity of the underlying data distributions of the various agents. In this paper, we study a personalized variant of federated learning in which our goal is to find, in a distributed manner, a shared initial model that can be slightly updated by either a current or a new user by performing one or a few steps of gradient descent with respect to its own loss function. This approach keeps all the benefits of the federated learning architecture while leading to a more personalized model for each user. We show that this problem can be studied within the Model-Agnostic Meta-Learning (MAML) framework. Inspired by this connection, we propose a personalized variant of the well-known Federated Averaging algorithm and evaluate its performance in terms of gradient norm for nonconvex loss functions. Further, we characterize how this performance is affected by the closeness of the underlying distributions of user data, measured in terms of distribution distances such as Total Variation and the 1-Wasserstein metric.
1 Introduction
In Federated Learning (FL), we consider a network of users that are all connected to a central node (i.e., a star connectivity graph), where each user has access only to its local data (Konečný et al., 2016). In this setting, the goal of the users is to come up with a model that is trained over all the data points in the network without exchanging their local data with other users or with the central node, i.e., the server, due to privacy issues or communication limitations.
More formally, the classic FL setting studies a star-shaped network with $n$ users and one server, which all coordinate to solve the following optimization problem:
$$\min_{w \in \mathbb{R}^d} f(w) := \frac{1}{n}\sum_{i=1}^{n} f_i(w), \tag{1}$$
where $f_i : \mathbb{R}^d \to \mathbb{R}$ denotes the loss function corresponding to user $i$. In particular, consider a supervised learning application, where $f_i$ represents the expected loss over the data distribution of user $i$, i.e.,
$$f_i(w) := \mathbb{E}_{(x,y)\sim p_i}\big[l_i(w; x, y)\big], \tag{2}$$
where $l_i(w; x, y)$ measures the error of model $w$ in predicting the true label $y \in \mathcal{Y}_i$ given the input $x \in \mathcal{X}_i$, and $p_i$ is the distribution over $\mathcal{X}_i \times \mathcal{Y}_i$. We emphasize that in this paper we study the case in which the probability distributions of the users in the network are not identical, i.e., the data distribution is heterogeneous.
To illustrate this formulation, consider the problem of training a Natural Language Processing (NLP) model over the devices of a set of users. In this problem, $p_i$ represents the empirical distribution of words and expressions used by user $i$, and hence, $f_i(w)$ can be expressed as
$$f_i(w) = \sum_{(x,y)\in\mathcal{S}_i} p_i(x, y)\, l_i(w; x, y), \tag{3}$$
where $\mathcal{S}_i$ is the data set corresponding to user $i$ and $p_i(x, y)$ is the probability that user $i$ assigns to a specific word, which is proportional to how frequently user $i$ uses this word.
In most algorithms designed for FL, problem (1) is solved in multiple rounds. At each round, the center sends the current model to a fraction of the users, and those users update the model with respect to their own loss functions, usually by performing a few steps of a gradient-based method. Then, these users return their updated models to the center, which combines the received models to update the global model (for example, by averaging, as in the FedAvg algorithm (McMahan et al., 2016)) and sends the updated model to a (possibly different) fraction of the users for the next round. This way, the computational power of all the users in the network is used to train the global model. In addition, the shared model is trained over a larger data set, which could lead to a better model. Indeed, this approach leads to a model that solves problem (1), and the resulting solution performs well over all users on average.
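As a rough illustration of one such round, the following sketch runs FedAvg on a toy problem. It assumes (hypothetically) a scalar model, quadratic local losses $f_i(w) = \frac{1}{2} a_i (w - c_i)^2$, and exact local gradients in place of stochastic ones; the names `quadratic_grad` and `fedavg_round` are illustrative and not part of any library.

```python
import random


def quadratic_grad(a, c):
    """Gradient of the hypothetical local loss f_i(w) = 0.5 * a * (w - c)**2."""
    return lambda w: a * (w - c)


def fedavg_round(w, grads, rate, tau, beta, rng):
    """One FedAvg round on a scalar model.

    The server samples round(rate * n) users; each selected user runs `tau`
    local gradient steps with stepsize `beta` starting from `w`; the server
    returns the average of the resulting local models.
    """
    n = len(grads)
    selected = rng.sample(range(n), max(1, round(rate * n)))
    local_models = []
    for i in selected:
        w_local = w
        for _ in range(tau):
            w_local -= beta * grads[i](w_local)
        local_models.append(w_local)
    return sum(local_models) / len(local_models)


rng = random.Random(0)
grads = [quadratic_grad(1.0, c) for c in (0.0, 2.0, 4.0)]
w = 0.0
for _ in range(50):
    w = fedavg_round(w, grads, rate=1.0, tau=1, beta=0.5, rng=rng)
print(round(w, 6))  # 2.0, the minimizer of the average loss (1) here
```

With equal curvatures ($a_i = 1$), the averaged model converges to the mean of the $c_i$, i.e., the minimizer of (1); this shared model is not the optimum of any individual user.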
Closeness of the users' data distributions is crucial for the success of the federated learning framework. However, the data samples of all users are not necessarily drawn from a common underlying distribution. This heterogeneity leads to an issue with formulation (1): the resulting model is only good on average, and it does not take into account the heterogeneity of the users' data distributions. In other words, the solution of problem (1) is not personalized for each user. To better highlight this point, recall the NLP example above: although the distribution over words and expressions varies from one person to another, the solution to problem (1) only provides a shared answer for all users, and therefore it is not fully capable of achieving a user-adapted model.
Hence, in the setting where the underlying distributions of the users' data points are not identical, solving the average problem defined in (1) could lead to poor local performance for each user. In this paper, we overcome this issue by considering a new problem formulation. We further introduce an efficient method for solving the proposed formulation and characterize its convergence properties. A detailed list of our contributions follows:

We consider a modified formulation of the federated learning problem that incorporates personalization (Section 2). Building on the Model-Agnostic Meta-Learning (MAML) problem formulation introduced by Finn et al. (2017), the goal of our formulation is to find an initial point shared between all users that performs well after each user updates it with respect to its own loss function, potentially by performing a few steps of a gradient-based method. This way, while the initial model is derived in a distributed manner over the whole network (as in the classic FL setting), the final model implemented by each user differs from those of the other users, based on his or her own data.

We also propose a personalized variant of the FedAvg algorithm, called Per-FedAvg, designed for solving the proposed personalized FL problem (Section 3). In particular, we elaborate on its connections with the original FedAvg algorithm (McMahan et al., 2016) and discuss a number of considerations that one needs to take into account when implementing Per-FedAvg.

We study the convergence properties of the proposed Per-FedAvg algorithm for nonconvex loss functions in terms of the gradient norm of the objective function (Section 4). In particular, we characterize the role of data heterogeneity, i.e., the closeness of the data distributions of different users measured by distribution distances such as Total Variation (TV) or 1-Wasserstein, in the convergence of the Per-FedAvg method.
1.1 Related Work
As mentioned earlier, McMahan et al. (2016) proposed the FedAvg algorithm, where the global model is updated by averaging local SGD updates. Later, Guha et al. (2019) proposed one-shot Federated Learning (FL), in which the master node learns the model after a single round of communication. Also, several approaches have been used to address the communication limitations in FL. These include quantization and compression ideas (Reisizadeh et al., 2019; Dai et al., 2019) as well as performing multiple local updates before communicating with the master (Stich, 2018; Lin et al., 2018; Wang and Joshi, 2018). Several works have studied the problem of preserving privacy in federated learning (Duchi et al., 2014; McMahan et al., 2017; Agarwal et al., 2018; Zhu et al., 2019). More related to our paper, several works study the statistical heterogeneity of users' data points in FL (Zhao et al., 2018; Sahu et al., 2018; Karimireddy et al., 2019; Haddadpour and Mahdavi, 2019; Khaled et al., 2019; Li et al., 2019), but they do not attempt to find a personalized solution for each user. In addition, Smith et al. (2017) used a multi-task learning framework and proposed a new method, MOCHA, to address these statistical and systems challenges (including data heterogeneity as well as communication efficiency).
The idea of personalization in FL and its connections with meta-learning have recently gained attention in a number of papers. Khodak et al. (2019) proposed ARUBA, a meta-learning algorithm inspired by online convex optimization, and showed empirically how applying it to the FedAvg method improves its performance. Jiang et al. (2019) proposed a personalized FedAvg algorithm in which the classic FedAvg is first deployed, then the method switches to Reptile, a meta-learning algorithm proposed in (Nichol et al., 2018), and finally local updates are run to achieve personalization. Note that this approach is different from our proposed framework: in this paper, we do not perform the classic FedAvg; instead, we look for a good initial point that performs well after it is fine-tuned for each user. Moreover, Chen et al. (2018) focused on recommendation systems and proposed a meta-federated learning framework in which a parameterized meta-algorithm is used to train parameterized recommendation models, and both the meta-algorithm parameters and the local models' parameters need to be optimized. For the special case in which the meta-algorithm parameter is its initialization, this framework reduces to our formulation. The authors evaluated the success of this framework empirically over various data sets and with different meta-algorithms. In our work, by contrast, we specifically focus on the case in which the meta-algorithm parameter is the initial point, characterize its convergence theoretically, and highlight the role of different parameters, including the heterogeneity of data distributions. We further provide empirical results for our proposed method. For a detailed survey on the connections of FL with multi-task learning and meta-learning, see Section 3.3 of (Kairouz et al., 2019).
2 Personalized Federated Learning via Model-Agnostic Meta-Learning (MAML)
As we stated in Section 1, our goal in this section is to show how the fundamental idea behind the Model-Agnostic Meta-Learning (MAML) framework in (Finn et al., 2017) can be exploited to design a personalized variant of the FL problem. To do so, let us first briefly recap the MAML formulation. Given a set of tasks drawn from an underlying distribution, in MAML, in contrast to the traditional supervised learning setting, the goal is not to find a model that performs well on all the tasks in expectation. Instead, in MAML, we assume we have a limited computational budget to update our model after a new task arrives, and in this new setting, we look for an initialization that performs well after it is updated with respect to this new task, possibly by one or a few steps of gradient descent. In particular, if we assume each user takes the initial point and updates it using one step of gradient descent with respect to its own loss function, then problem (1) changes to
$$\min_{w \in \mathbb{R}^d} F(w) := \frac{1}{n}\sum_{i=1}^{n} f_i\big(w - \alpha \nabla f_i(w)\big), \tag{4}$$
where $\alpha \ge 0$ is the learning rate (stepsize). The strength of this formulation is that not only does it allow us to maintain the advantages of FL (limited communication), but it also captures the differences between users, as either existing or new users can take the solution of this new problem as an initial point and slightly update it with respect to their own data. Going back to the NLP example (3), this means that each user could take the resulting initialization and update it by going over her or his own data and performing just one or a few steps of gradient descent to obtain a model that works well for her or his own data set.
As we mentioned earlier, for the considered heterogeneous model of data distribution, solving problem (1) is not the ideal choice, as it returns a single model that, even after a few steps of local gradient descent, may not quickly adjust to each user's local data. By solving (4), in contrast, we find an initial model (meta-model) that is trained in such a way that one step of local gradient descent leads to a good model for each individual user. This formulation can also be extended to the case in which each user runs a few steps of gradient updates, but to simplify our notation we only focus on the single gradient update case.
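To see why (4) can select a different initialization than (1), consider a hypothetical scalar example with quadratic losses $f_i(w) = \frac{1}{2} a_i (w - c_i)^2$. Then $F_i(w) = \frac{1}{2} a_i (1 - \alpha a_i)^2 (w - c_i)^2$, so both problems have closed-form minimizers: (1) weights each $c_i$ by $a_i$, while (4) weights it by $a_i (1 - \alpha a_i)^2$. The sketch below assumes this toy model; the helper names are illustrative.

```python
def adapted_loss(w, a, c, alpha):
    """F_i(w) in (5) for f_i(w) = 0.5 * a * (w - c)**2: loss after one local step."""
    w_adapted = w - alpha * a * (w - c)  # one local gradient step from w
    return 0.5 * a * (w_adapted - c) ** 2


def meta_objective(w, users, alpha):
    """F(w) in (4): average loss after each user's one-step adaptation."""
    return sum(adapted_loss(w, a, c, alpha) for a, c in users) / len(users)


def fl_minimizer(users):
    """Minimizer of (1) for quadratic losses: curvature-weighted mean of the c_i."""
    return sum(a * c for a, c in users) / sum(a for a, _ in users)


def meta_minimizer(users, alpha):
    """Minimizer of (4): same form, with weights a * (1 - alpha * a)**2."""
    weights = [a * (1 - alpha * a) ** 2 for a, _ in users]
    return sum(wt * c for wt, (_, c) in zip(weights, users)) / sum(weights)


users = [(1.0, 0.0), (4.0, 1.0)]  # hypothetical (a_i, c_i) pairs
alpha = 0.2
w_fl, w_meta = fl_minimizer(users), meta_minimizer(users, alpha)
print(round(w_fl, 6), round(w_meta, 6))  # 0.8 0.2
print(round(meta_objective(w_fl, users, alpha), 6),
      round(meta_objective(w_meta, users, alpha), 6))  # 0.104 0.032
```

In this example the second user has large curvature, so a single local step already moves it close to its own optimum; the meta-objective therefore places the shared initialization closer to the first user's optimum, and its post-adaptation loss is lower than that of the FL solution.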
The centralized version of this formulation was first proposed by Finn et al. (2017) and was followed by a number of papers studying its empirical characteristics (Antoniou et al., 2019; Li et al., 2017; Grant et al., 2018; Nichol et al., 2018; Zintgraf et al., 2019; Behl et al., 2019) as well as its convergence properties (Fallah et al., 2019). In this work, we focus on exploiting the MAML formulation to introduce a personalized solution for the federated learning setting. The analysis of the proposed algorithm for the FL setting is more challenging than that of the centralized case, as we discuss in Section 4.
3 Personalized FedAvg
In this section, we introduce our proposed Personalized FedAvg method for solving problem (4). This algorithm is inspired by the FedAvg algorithm originally proposed for the classic federated learning problem (1), but it is modified so that the resulting method targets a solution of (4) instead of (1). To better highlight this connection, let us recap the main steps of the FedAvg algorithm. In FedAvg, at each round $k$, the server chooses a fraction of users $\mathcal{A}_k$ with size $rn$ (with $r \in (0, 1]$) and sends its current model $w_k$ to these users. Each selected user $i \in \mathcal{A}_k$ updates this model according to its own loss function $f_i$ by running $\tau$ steps of stochastic gradient descent. Then, the users return their updated models to the server. Finally, the server updates the global model by computing the average of the models received from these selected users, and then the next round follows.
The proposed personalized FedAvg method follows the same principle and aims to implement a similar algorithm for minimizing the function $F$ defined in (4). Before formally stating the update of personalized FedAvg, note that the global objective function $F$ in (4) can be written as the average of the meta-functions $F_1, \dots, F_n$, where the meta-function $F_i$ associated with user $i$ is defined as
$$F_i(w) := f_i\big(w - \alpha \nabla f_i(w)\big). \tag{5}$$
In other words, in this case, each local function $F_i$ is defined as the value of the local loss function $f_i$ after running one step of gradient descent.
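To follow a similar scheme as FedAvg for solving problem (4), the first step is to compute the gradients of the local functions, which in this case are the gradients $\nabla F_i$, given by
$$\nabla F_i(w) = \big(I - \alpha \nabla^2 f_i(w)\big)\, \nabla f_i\big(w - \alpha \nabla f_i(w)\big). \tag{6}$$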
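Formula (6) is the chain rule applied to (5), and a quick sanity check is to compare it with a finite-difference derivative of $F_i$. The sketch below does this in one dimension (so $I$ becomes 1) for a hypothetical smooth loss $f_i(w) = \log\cosh(w - c)$, whose first and second derivatives are $\tanh(w - c)$ and $1 - \tanh^2(w - c)$; all function names are illustrative.

```python
import math


def grad_f(w, c):
    """f_i'(w) for the hypothetical loss f_i(w) = log(cosh(w - c))."""
    return math.tanh(w - c)


def hess_f(w, c):
    """f_i''(w) for the same loss."""
    return 1.0 - math.tanh(w - c) ** 2


def meta_loss(w, c, alpha):
    """F_i(w) = f_i(w - alpha * f_i'(w)), as in (5)."""
    return math.log(math.cosh(w - alpha * grad_f(w, c) - c))


def meta_grad(w, c, alpha):
    """Equation (6) in one dimension: (1 - alpha * f_i''(w)) * f_i'(w - alpha * f_i'(w))."""
    return (1.0 - alpha * hess_f(w, c)) * grad_f(w - alpha * grad_f(w, c), c)


def central_difference(fun, w, h=1e-6):
    """Numerical derivative, used only to check the closed form."""
    return (fun(w + h) - fun(w - h)) / (2.0 * h)


w, c, alpha = 0.7, -0.3, 0.5
analytic = meta_grad(w, c, alpha)
numeric = central_difference(lambda v: meta_loss(v, c, alpha), w)
print(abs(analytic - numeric) < 1e-6)  # True
```

Setting `alpha = 0` recovers the plain gradient $f_i'(w)$, consistent with (4) reducing to (1).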
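Note that computing the exact gradient $\nabla f_i$ at every round is usually not computationally tractable; we therefore take a batch of data $\mathcal{D}_i$ with respect to the distribution $p_i$ to obtain an unbiased estimate $\tilde\nabla f_i(w, \mathcal{D}_i)$ given by
$$\tilde\nabla f_i(w, \mathcal{D}_i) := \frac{1}{|\mathcal{D}_i|}\sum_{(x,y)\in\mathcal{D}_i} \nabla l_i(w; x, y). \tag{7}$$
Similarly, we could replace the Hessian $\nabla^2 f_i(w)$ in (6) by its unbiased estimate $\tilde\nabla^2 f_i(w, \mathcal{D}_i)$ over the batch $\mathcal{D}_i$.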
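At round $k$ of the Personalized FedAvg algorithm, similar to FedAvg, the server first sends the current global model $w_k$ to a fraction of users $\mathcal{A}_k$, chosen uniformly at random with size $rn$. Each user $i \in \mathcal{A}_k$ performs $\tau$ steps of stochastic gradient descent locally with respect to $F_i$. In particular, these local updates generate a local sequence $\{w^i_{k+1,t}\}_{t=0}^{\tau}$, where $w^i_{k+1,0} = w_k$ and, for $\tau \ge t \ge 1$,
$$w^i_{k+1,t} = w^i_{k+1,t-1} - \beta\, \tilde\nabla F_i\big(w^i_{k+1,t-1}\big), \tag{8}$$
where $\beta$ is the local learning rate (stepsize) and $\tilde\nabla F_i(w^i_{k+1,t-1})$ is an estimate of $\nabla F_i(w^i_{k+1,t-1})$ in (6). Note that the stochastic gradient $\tilde\nabla F_i(w^i_{k+1,t-1})$ for all local iterates is computed using independent batches $\mathcal{D}^i_t$, $\mathcal{D}'^i_t$, and $\mathcal{D}''^i_t$ as follows:
$$\tilde\nabla F_i\big(w^i_{k+1,t-1}\big) := \Big(I - \alpha\, \tilde\nabla^2 f_i\big(w^i_{k+1,t-1}, \mathcal{D}''^i_t\big)\Big)\, \tilde\nabla f_i\Big(w^i_{k+1,t-1} - \alpha\, \tilde\nabla f_i\big(w^i_{k+1,t-1}, \mathcal{D}^i_t\big), \mathcal{D}'^i_t\Big). \tag{9}$$
We emphasize that $\tilde\nabla F_i(w^i_{k+1,t-1})$ is a biased estimator of $\nabla F_i(w^i_{k+1,t-1})$, because it is a stochastic gradient that contains another stochastic gradient inside: the minibatch gradient over $\mathcal{D}'^i_t$ is evaluated at a point computed from the minibatch gradient over $\mathcal{D}^i_t$.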
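The estimator (9) can be sketched as follows for a single user. As a hypothetical stand-in for $l_i$, each data point is a pair $(a, c)$ with loss $\frac{1}{2} a (w - c)^2$ on a scalar model, and batches are drawn with replacement via Python's `random.choices`. Because the gradients of this toy loss are linear in $w$, the estimate happens to be unbiased here, so the sketch illustrates the three independent batches rather than the bias discussed above.

```python
import random


def batch_grad(w, batch):
    """Minibatch gradient (7) of the per-sample loss 0.5 * a * (w - c)**2."""
    return sum(a * (w - c) for a, c in batch) / len(batch)


def batch_hess(w, batch):
    """Minibatch Hessian estimate; the per-sample second derivative is a."""
    return sum(a for a, _ in batch) / len(batch)


def stochastic_meta_grad(w, data, alpha, sizes, rng):
    """Estimate (9) of grad F_i(w) from three independent batches D, D', D''."""
    size_d, size_d1, size_d2 = sizes
    batch_d = rng.choices(data, k=size_d)    # D: inner adaptation step
    batch_d1 = rng.choices(data, k=size_d1)  # D': gradient at the adapted point
    batch_d2 = rng.choices(data, k=size_d2)  # D'': Hessian correction
    w_adapted = w - alpha * batch_grad(w, batch_d)
    return (1.0 - alpha * batch_hess(w, batch_d2)) * batch_grad(w_adapted, batch_d1)


def exact_meta_grad(w, data, alpha):
    """Equation (6) with f_i taken as the average loss over `data`."""
    w_adapted = w - alpha * batch_grad(w, data)
    return (1.0 - alpha * batch_hess(w, data)) * batch_grad(w_adapted, data)


rng = random.Random(0)
data = [(1.0, 0.0), (2.0, 1.0), (3.0, -1.0)]  # hypothetical (a, c) samples
w, alpha = 0.5, 0.1
samples = [stochastic_meta_grad(w, data, alpha, (2, 2, 2), rng) for _ in range(20000)]
print(round(exact_meta_grad(w, data, alpha), 4))  # 0.8533
print(round(sum(samples) / len(samples), 2))
```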
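Once the local updates are evaluated, all users in $\mathcal{A}_k$ send their updated models $w^i_{k+1,\tau}$ to the server, and the server updates its global model by averaging over the received models, i.e.,
$$w_{k+1} = \frac{1}{rn}\sum_{i\in\mathcal{A}_k} w^i_{k+1,\tau}. \tag{10}$$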
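These steps are depicted in Algorithm 1. Note that, as in other MAML algorithms (Finn et al., 2017; Fallah et al., 2019), the update in (8), which exploits the stochastic gradient estimate in (9), can be implemented in two levels: (i) first, for each user $i$ and each iteration $t$, we perform the update
$$\tilde w^i_{k+1,t} = w^i_{k+1,t-1} - \alpha\, \tilde\nabla f_i\big(w^i_{k+1,t-1}, \mathcal{D}^i_t\big),$$
and then (ii) evaluate $w^i_{k+1,t}$ by following the update
$$w^i_{k+1,t} = w^i_{k+1,t-1} - \beta \Big(I - \alpha\, \tilde\nabla^2 f_i\big(w^i_{k+1,t-1}, \mathcal{D}''^i_t\big)\Big)\, \tilde\nabla f_i\big(\tilde w^i_{k+1,t}, \mathcal{D}'^i_t\big).$$
Indeed, it can be verified that the outcome of these two steps is equivalent to the update in (8). To simplify the notation, throughout the paper we assume that the sizes of $\mathcal{D}^i_t$, $\mathcal{D}'^i_t$, and $\mathcal{D}''^i_t$ are equal to $D$, $D'$, and $D''$, respectively, for any $i$ and $t$.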
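A sketch of one Per-FedAvg round, written with the two-level update (i)-(ii) and the averaging step (10), is given below. It reuses a hypothetical scalar toy loss $\frac{1}{2} a (w - c)^2$ per data point; the function names and the choice of sampling batches with replacement are assumptions for illustration, so this is a sketch under those assumptions rather than a reference implementation of Algorithm 1.

```python
import random


def batch_grad(w, batch):
    """Minibatch gradient of the per-sample loss 0.5 * a * (w - c)**2."""
    return sum(a * (w - c) for a, c in batch) / len(batch)


def batch_hess(w, batch):
    """Minibatch Hessian estimate (the per-sample second derivative is a)."""
    return sum(a for a, _ in batch) / len(batch)


def local_step_one_shot(w, batches, alpha, beta):
    """Update (8) with the estimate (9) written as a single expression."""
    d, d1, d2 = batches
    return w - beta * (1.0 - alpha * batch_hess(w, d2)) * batch_grad(w - alpha * batch_grad(w, d), d1)


def local_step_two_level(w, batches, alpha, beta):
    """The same update split into level (i) and level (ii)."""
    d, d1, d2 = batches
    w_tilde = w - alpha * batch_grad(w, d)  # (i) temporary adapted model
    return w - beta * (1.0 - alpha * batch_hess(w, d2)) * batch_grad(w_tilde, d1)  # (ii)


def per_fedavg_round(w, user_data, r, tau, alpha, beta, sizes, rng):
    """Sample round(r * n) users, run tau local steps each, then average as in (10)."""
    selected = rng.sample(range(len(user_data)), max(1, round(r * len(user_data))))
    local_models = []
    for i in selected:
        w_i = w
        for _ in range(tau):
            batches = tuple(rng.choices(user_data[i], k=s) for s in sizes)  # D, D', D''
            w_i = local_step_two_level(w_i, batches, alpha, beta)
        local_models.append(w_i)
    return sum(local_models) / len(local_models)


# Hypothetical users, each holding a single (a, c) point so that batches are exact.
user_data = [[(1.0, 0.0)], [(4.0, 1.0)]]
w = 0.0
rng = random.Random(0)
for _ in range(200):
    w = per_fedavg_round(w, user_data, r=1.0, tau=1, alpha=0.2, beta=0.5, sizes=(1, 1, 1), rng=rng)
print(round(w, 6))  # 0.2
```

Since each toy user holds a single point and $\tau = 1$, every round is an exact gradient step on $F$, so the iterates approach the minimizer of (4) for these two users ($w = 0.2$) rather than the minimizer of (1) ($w = 0.8$).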
4 Theoretical Results
In this section, we study the convergence properties of our proposed Personalized FedAvg (Per-FedAvg) method. We focus on nonconvex settings and characterize the overall number of communication rounds between the server and the users required to achieve first-order stationarity. To do so, we first formally define the notion of an $\epsilon$-approximate first-order stationary point.
4.1 Definitions and Assumptions
Definition 4.1.
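A random vector $w_\epsilon \in \mathbb{R}^d$ is called an $\epsilon$-approximate First-Order Stationary Point (FOSP) for problem (4) if it satisfies
$$\mathbb{E}\big[\|\nabla F(w_\epsilon)\|^2\big] \le \epsilon.$$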
Next, we formally state the assumptions required for proving our main results.
Assumption 1.
Function $F$ is bounded below, i.e., $\min_{w \in \mathbb{R}^d} F(w) > -\infty$.
Assumption 2.
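For every $i \in \{1, \dots, n\}$, $f_i$ is twice continuously differentiable and $L_i$-smooth, and its gradient is bounded by a nonnegative constant $B_i$, i.e.,
$$\|\nabla f_i(w) - \nabla f_i(u)\| \le L_i \|w - u\| \quad \forall w, u \in \mathbb{R}^d, \tag{11a}$$
$$\|\nabla f_i(w)\| \le B_i \quad \forall w \in \mathbb{R}^d. \tag{11b}$$
It is worth noting that (11a) also implies that $f_i$ satisfies the following conditions for all $w, u \in \mathbb{R}^d$:
$$-L_i I \preceq \nabla^2 f_i(w) \preceq L_i I, \tag{12a}$$
$$f_i(u) \le f_i(w) + \nabla f_i(w)^\top (u - w) + \frac{L_i}{2}\|u - w\|^2. \tag{12b}$$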
As we discussed in Section 3, the second-order derivatives of all functions $f_i$ appear in the update rule of the Per-FedAvg algorithm. Hence, in the next assumption, we impose a regularity condition on the Hessian of each $f_i$, which is also a customary assumption in the analysis of second-order methods.
Assumption 3.
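For every $i \in \{1, \dots, n\}$, the Hessian of function $f_i$ is $\rho_i$-Lipschitz continuous, i.e.,
$$\|\nabla^2 f_i(w) - \nabla^2 f_i(u)\| \le \rho_i \|w - u\| \quad \forall w, u \in \mathbb{R}^d. \tag{13}$$
To simplify the analysis, in the rest of the paper we define $B := \max_i B_i$, $L := \max_i L_i$, and $\rho := \max_i \rho_i$, which can be considered, respectively, as a bound on the norm of the gradient of $f_i$, the smoothness parameter of $f_i$, and the Lipschitz continuity parameter of the Hessian $\nabla^2 f_i$, for all $i$.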
Now, we state the next assumption, which provides upper bounds on the variances of the gradient and Hessian estimates.
Assumption 4.
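For any $w \in \mathbb{R}^d$ and any $i \in \{1, \dots, n\}$, the stochastic gradient $\nabla l_i(w; x, y)$ and Hessian $\nabla^2 l_i(w; x, y)$, computed with respect to a single data point $(x, y) \in \mathcal{X}_i \times \mathcal{Y}_i$, have bounded variance, i.e.,
$$\mathbb{E}_{(x,y)\sim p_i}\big[\|\nabla l_i(w; x, y) - \nabla f_i(w)\|^2\big] \le \sigma_G^2, \tag{14}$$
$$\mathbb{E}_{(x,y)\sim p_i}\big[\|\nabla^2 l_i(w; x, y) - \nabla^2 f_i(w)\|^2\big] \le \sigma_H^2, \tag{15}$$
where $\sigma_G$ and $\sigma_H$ are nonnegative constants.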
Finally, we state our last assumption which characterizes the similarity between the tasks of users.
Assumption 5.
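The variance of the local gradients and Hessians around those of the average function $f = \frac{1}{n}\sum_{i=1}^n f_i$ is bounded, i.e., for some nonnegative $\gamma_G$ and $\gamma_H$, we have
$$\frac{1}{n}\sum_{i=1}^{n} \big\|\nabla f_i(w) - \nabla f(w)\big\|^2 \le \gamma_G^2, \tag{16a}$$
$$\frac{1}{n}\sum_{i=1}^{n} \big\|\nabla^2 f_i(w) - \nabla^2 f(w)\big\|^2 \le \gamma_H^2, \tag{16b}$$
for any $w \in \mathbb{R}^d$.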
Note that Assumption 2 implies that this assumption holds automatically for $\gamma_G = 2B$ and $\gamma_H = 2L$. However, we state this assumption separately to highlight the role of the similarity of the functions corresponding to different users in the convergence analysis of Per-FedAvg. In particular, in the following subsection, we highlight the connections between this assumption and the similarity of the distributions $p_i$ for the case of supervised learning (2) under two different distribution distances.
4.2 On the Connections of Task Similarity and Distribution Distances
Recall the definition of $f_i$ for the supervised learning problem stated in (2). As mentioned above, Assumption 5 captures the similarity of the loss functions of different users, and one fundamental question here is whether this has any connection with the closeness of the distributions $p_i$. We study this connection by considering two different distances: the Total Variation (TV) distance and the 1-Wasserstein distance. Throughout this subsection, we assume all users have the same loss function over the same set of inputs and labels, i.e., $l_i(w; z) = l(w; z)$ for all $i$, where $z := (x, y) \in \mathcal{Z} := \mathcal{X} \times \mathcal{Y}$. Also, let $p := \frac{1}{n}\sum_{i=1}^n p_i$ denote the average of all users' distributions.
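Total Variation (TV) Distance: For distributions $\mu$ and $\nu$ over a countable set $\mathcal{Z}$, their TV distance is given by
$$\|\mu - \nu\|_{TV} = \frac{1}{2}\sum_{z\in\mathcal{Z}} |\mu(z) - \nu(z)|. \tag{17}$$
If we further assume that a stronger version of Assumption 2 holds, where for any $w \in \mathbb{R}^d$ and $z \in \mathcal{Z}$ we have
$$\|\nabla_w l(w; z)\| \le B, \qquad \|\nabla_w^2 l(w; z)\| \le L, \tag{18}$$
then Assumption 5 holds with (check Appendix A for the proof)
$$\gamma_G = 2B \max_i \|p_i - p\|_{TV}, \tag{19a}$$
$$\gamma_H = 2L \max_i \|p_i - p\|_{TV}. \tag{19b}$$
This simple derivation shows that $\gamma_G$ and $\gamma_H$ are controlled directly by how far each user's distribution is from the average distribution $p$ in TV distance, and hence they capture the degree of heterogeneity among the users' probability distributions.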
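The bound (19a) can be checked numerically on a small example. The sketch below assumes a hypothetical countable sample space of a few integers and the per-sample loss $l(w; z) = \log\cosh(w - z)$ on a scalar model, whose gradient $\tanh(w - z)$ is bounded by $B = 1$; it computes (17), the average distribution $p$, and the left-hand side of (16a). The helper names are illustrative.

```python
import math


def tv_distance(mu, nu):
    """Equation (17) for pmfs stored as {point: probability} dictionaries."""
    support = set(mu) | set(nu)
    return 0.5 * sum(abs(mu.get(z, 0.0) - nu.get(z, 0.0)) for z in support)


def average_distribution(dists):
    """p = (1/n) * sum_i p_i."""
    avg = {}
    for dist in dists:
        for z, prob in dist.items():
            avg[z] = avg.get(z, 0.0) + prob / len(dists)
    return avg


def expected_grad(w, dist):
    """grad f(w) = E_{z ~ dist}[tanh(w - z)] for l(w; z) = log(cosh(w - z))."""
    return sum(prob * math.tanh(w - z) for z, prob in dist.items())


def gradient_dissimilarity(w, dists):
    """Left-hand side of (16a): (1/n) * sum_i |grad f_i(w) - grad f(w)|^2."""
    g_bar = expected_grad(w, average_distribution(dists))
    return sum((expected_grad(w, d) - g_bar) ** 2 for d in dists) / len(dists)


users = [{0: 0.6, 1: 0.4}, {0: 0.4, 1: 0.6}, {0: 0.5, 1: 0.3, 2: 0.2}]  # hypothetical p_i
p = average_distribution(users)
B = 1.0  # |tanh| <= 1
gamma_G = 2 * B * max(tv_distance(p_i, p) for p_i in users)  # (19a)
print(round(gamma_G, 4))  # 0.3333
for w in (-2.0, -0.5, 0.0, 0.7, 2.5):
    print(w, gradient_dissimilarity(w, users) <= gamma_G ** 2)
```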
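1-Wasserstein Distance: The 1-Wasserstein distance between two probability measures $\mu$ and $\nu$ over a metric space $\mathcal{Z}$ is defined as
$$W_1(\mu, \nu) := \inf_{\pi \in \Pi(\mu, \nu)} \int_{\mathcal{Z}\times\mathcal{Z}} d(z, z')\, d\pi(z, z'), \tag{20}$$
where $d$ is a distance function over the metric space $\mathcal{Z}$ and $\Pi(\mu, \nu)$ denotes the set of all measures on $\mathcal{Z} \times \mathcal{Z}$ with marginals $\mu$ and $\nu$ on the first and second coordinates, respectively. Here, we assume all $p_i$ have bounded support (note that this assumption holds in many cases, either because $\mathcal{Z}$ itself is bounded or because we normalize the data). Also, we assume that for any $w \in \mathbb{R}^d$, the gradient $\nabla_w l(w; z)$ and the Hessian $\nabla_w^2 l(w; z)$ are both Lipschitz with respect to the parameter $z$ and the distance $d$, i.e.,
$$\|\nabla_w l(w; z) - \nabla_w l(w; z')\| \le L_Z\, d(z, z'), \tag{21a}$$
$$\|\nabla_w^2 l(w; z) - \nabla_w^2 l(w; z')\| \le \rho_Z\, d(z, z'). \tag{21b}$$
Then, Assumption 5 holds with (check Appendix A for the proof)
$$\gamma_G = L_Z \max_i W_1(p_i, p), \tag{22a}$$
$$\gamma_H = \rho_Z \max_i W_1(p_i, p). \tag{22b}$$
It is worth noting that this derivation does not use other assumptions such as Assumption 2 and holds in general whenever (21a) and (21b) are satisfied.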
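On the real line, the 1-Wasserstein distance between two empirical distributions with the same number of atoms is attained by matching sorted samples, which gives a simple way to evaluate (22a) on toy data. The sketch assumes hypothetical one-dimensional samples for each user, equal sample counts, and the loss $l(w; z) = \log\cosh(w - z)$, whose gradient $\tanh(w - z)$ is 1-Lipschitz in $z$, so $L_Z = 1$. To compare $p_i$ with the pooled distribution $p$, each user's samples are repeated $n$ times, which leaves $p_i$ unchanged but matches the number of atoms of $p$.

```python
import math


def wasserstein_1d(xs, ys):
    """W_1 between two equal-size empirical distributions on the real line.

    In one dimension an optimal coupling pairs the k-th smallest points.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


def empirical_grad(w, samples):
    """grad f(w) = mean of tanh(w - z) over the samples."""
    return sum(math.tanh(w - z) for z in samples) / len(samples)


user_samples = [[0.0, 0.5, 1.0], [2.0, 2.5, 3.0], [-1.0, 0.0, 1.0]]  # hypothetical
n = len(user_samples)
pooled = [z for samples in user_samples for z in samples]  # samples of p
w1 = [wasserstein_1d(samples * n, pooled) for samples in user_samples]
L_Z = 1.0  # tanh is 1-Lipschitz
gamma_G = L_Z * max(w1)  # (22a)
print([round(d, 4) for d in w1])  # [0.7222, 1.5, 1.0]
for w in (-1.0, 0.0, 1.5, 3.0):
    gaps = [abs(empirical_grad(w, s) - empirical_grad(w, pooled)) for s in user_samples]
    print(w, max(gaps) <= gamma_G)
```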
4.3 Convergence Analysis of Per-FedAvg Algorithm
In this subsection, we derive the overall complexity of Per-FedAvg for achieving an $\epsilon$-first-order stationary point. To do so, we first prove the following intermediate result, which shows that under Assumptions 2 and 3, the local meta-functions $F_i$ defined in (5) and their average function $F$ are smooth.
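Lemma 4.2.
Suppose that the conditions in Assumptions 2 and 3 are satisfied. If $\alpha \in [0, 1/L]$, then $F_i$ defined in (5) is smooth with parameter $L_F := 4L + \alpha\rho B$. As a consequence, the average function $F(w) = \frac{1}{n}\sum_{i=1}^n F_i(w)$ is also smooth with parameter $L_F$.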
Proof.
Check Appendix B. ∎
The conditions in Assumption 4 provide upper bounds on the variances of the gradient and Hessian estimates for the functions $f_i$. To analyze the convergence of Per-FedAvg, however, we need an upper bound on the variance of the gradient estimate of the functions $F_i$. We derive such an upper bound in the following lemma.
Lemma 4.3.
Proof.
Check Appendix C. ∎
To measure the tightness of this result, we consider two special cases. First, if the exact gradients and Hessians are available, i.e., $\sigma_G = \sigma_H = 0$, then $\sigma_F = 0$ as well, which is expected as we can compute $\nabla F_i$ exactly. Second, for the classic federated learning problem, i.e., $\alpha = 0$ and , we have which is tight up to constants.
Next, we use the similarity conditions for the functions $f_i$ in Assumption 5 to study the similarity between the gradients of the functions $F_i$.
Lemma 4.4.
Proof.
Check Appendix D. ∎
It is worth going over the two special cases that we discussed for Lemma 4.3 to see how tight Lemma 4.4 is. First, if all $f_i$ are equal, i.e., $\gamma_G = \gamma_H = 0$, then $\gamma_F = 0$ as well. This is indeed expected, as all $F_i$ are equal to each other in this case. Second, for the classic federated learning problem, i.e., $\alpha = 0$ and , we have which is optimal up to a constant factor given the conditions in Assumption 5.
Now, we are ready to state the main result of our paper on the convergence of our proposed Per-FedAvg method.
Theorem 4.5.
Consider the objective function $F$ defined in (4) for the case that $\alpha \in (0, 1/L]$. Suppose that the conditions in Assumptions 1-4 are satisfied, and recall the definitions of $L_F$, $\sigma_F$, and $\gamma_F$ from Lemmas 4.2-4.4. Consider running Algorithm 1 for $K$ rounds with $\tau$ local updates in each round and with . Then, the following first-order stationary condition holds
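where $\bar w_{k+1,t}$ is the average of the iterates of the users in $\mathcal{A}_k$ at time $t$, i.e.,
$$\bar w_{k+1,t} := \frac{1}{rn}\sum_{i\in\mathcal{A}_k} w^i_{k+1,t},$$
and in particular, $\bar w_{k+1,0} = w_k$ and $\bar w_{k+1,\tau} = w_{k+1}$.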
Proof.
Check Appendix F. ∎
The result in Theorem 4.5 shows that if each user runs $\tau$ local updates at each iteration, then after $K$ rounds of communication between the users and the server, the average squared gradient norm in expectation converges at a sublinear rate of to a neighborhood of zero with radius . This result shows that to find an $\epsilon$-FOSP, we need to ensure that the parameters and satisfy the condition .
Note that $\sigma_F$ is not a constant: as expressed in Lemma 4.3, we can make it arbitrarily small by choosing the batch sizes $D$, $D'$, or $D''$ large enough. Also, and as we discussed after Lemma 4.4, would be zero if we assume we have access to the exact gradients and Hessians. Similarly, Lemma 4.4 implies that small values of $\gamma_G$ and $\gamma_H$ imply that $\gamma_F$ is also small. As we discussed in Section 4.2, this observation is related to the closeness of the data distributions of the agents with respect to distribution distances such as Total Variation or the 1-Wasserstein metric. In particular, consider the special case in which $f_i$ admits the finite-sum representation (3) and the data distributions are homogeneous, i.e., all users' data points are drawn from a common underlying distribution. Then, having more samples for each user, i.e., a larger $|\mathcal{S}_i|$ in (3), leads to smaller $\gamma_G$ and $\gamma_H$, as the empirical distribution of each user becomes closer to the underlying distribution (see (Reisizadeh et al., 2019)).
Remark 4.6.
The result of Theorem 4.5 provides an upper bound on the average of $\mathbb{E}\big[\|\nabla F(\bar w_{k+1,t})\|^2\big]$ over all $k$ and $t$. However, one concern here is that, due to the structure of Algorithm 1, for any $k$ the server only has access to $\bar w_{k+1,t}$ for $t \in \{0, \tau\}$. To address this issue, at any iteration $k$, the server can choose an index $t_k$ uniformly at random and ask all the users in $\mathcal{A}_k$ to send back $w^i_{k+1,t_k}$ to the server, possibly in addition to $w^i_{k+1,\tau}$. If we follow such a scheme, then we can ensure that
5 Numerical Experiments
In this section, we design a numerical setting to highlight the role of personalization when the data distributions are heterogeneous. In particular, we consider the problem of classifying handwritten digits from the MNIST dataset (LeCun, 1998) and distribute the training data among $n$ users as follows:

Half of the users have images of each of the digits 0-4.

The rest each have images from one of the digits 0-4 and images from one of the digits 5-9.
This way, we create an example in which the distributions of images of all the users differ from each other. Similarly, we divide the test data over the nodes with the same distribution as the one for the training data.
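The label-skewed split above can be sketched as follows, operating only on a list of labels (loading MNIST itself is omitted). The per-user image counts are not recoverable from this text, so `a`, `b_low`, and `b_high` are hypothetical parameters, as is the helper name `partition_labels`; digits are drawn without replacement from shuffled per-class pools.

```python
import random
from collections import Counter


def partition_labels(labels, n_users, a, b_low, b_high, seed=0):
    """Split example indices among users following the scheme described above.

    Users 0 .. n_users//2 - 1 each get `a` examples of every digit 0-4.
    Each remaining user gets `b_low` examples of one random digit in 0-4 and
    `b_high` examples of one random digit in 5-9.
    """
    rng = random.Random(seed)
    pools = {d: [i for i, y in enumerate(labels) if y == d] for d in range(10)}
    for pool in pools.values():
        rng.shuffle(pool)

    def take(digit, k):
        # Remove and return k unused indices of the given digit.
        if len(pools[digit]) < k:
            raise ValueError(f"not enough examples of digit {digit}")
        chosen = pools[digit][:k]
        del pools[digit][:k]
        return chosen

    users = []
    for u in range(n_users):
        if u < n_users // 2:
            indices = [i for d in range(5) for i in take(d, a)]
        else:
            indices = take(rng.randrange(5), b_low) + take(rng.randrange(5, 10), b_high)
        users.append(indices)
    return users


labels = [d for d in range(10) for _ in range(100)]  # stand-in for MNIST labels
users = partition_labels(labels, n_users=4, a=5, b_low=3, b_high=7)
for indices in users:
    print(sorted(Counter(labels[i] for i in indices).items()))
```

Drawing without replacement keeps the users' training sets disjoint; calling the same helper on the test labels with the same parameters would reproduce the matching test split described above.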
We consider three algorithms in this setting. First, the classic FedAvg method, where the users find a shared model that all of them implement without any update during test time. Second, we take the output of the FedAvg method, update it with one step of gradient descent with respect to the test data, and then evaluate its performance. Third, we consider our proposed algorithm, Per-FedAvg, and update its output, again with one step of gradient descent, during test time. Similar to MAML, the implementation of Per-FedAvg requires access to second-order information, which is computationally costly. To address this issue, we replace the gradient estimate at each iteration with its first-order approximation, which ignores the Hessian term, i.e., $\tilde\nabla F_i(w^i_{k+1,t-1})$ in (9) is approximated by
$$\tilde\nabla F_i\big(w^i_{k+1,t-1}\big) \approx \tilde\nabla f_i\Big(w^i_{k+1,t-1} - \alpha\, \tilde\nabla f_i\big(w^i_{k+1,t-1}, \mathcal{D}^i_t\big), \mathcal{D}'^i_t\Big). \tag{23}$$
This is the same idea deployed in First-Order MAML (FO-MAML) (Finn et al., 2017), and it has been shown to achieve almost the same level of performance as MAML when the learning rate is small (Fallah et al., 2019). Also, in Appendix G, we discuss how our analysis can be extended to first-order approximations of Per-FedAvg, such as the one implemented for this experiment.
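The effect of dropping the Hessian term in (23) can be seen on a one-dimensional toy loss. In the sketch below (hypothetical loss $f_i(w) = \log\cosh(w - c)$, for which $|f_i'| \le 1$ and $0 < f_i'' \le 1$), the gap between the exact expression (6) and its first-order version equals $\alpha\, f_i''(w)\, |f_i'(w - \alpha f_i'(w))|$, so it is at most $\alpha$ and shrinks as $\alpha$ decreases. Exact gradients are used instead of minibatch estimates to isolate the approximation error.

```python
import math


def grad_f(w, c):
    """f_i'(w) for the hypothetical loss f_i(w) = log(cosh(w - c))."""
    return math.tanh(w - c)


def hess_f(w, c):
    """f_i''(w) for the same loss; it lies in (0, 1]."""
    return 1.0 - math.tanh(w - c) ** 2


def exact_meta_grad(w, c, alpha):
    """Equation (6): Hessian-corrected gradient at the adapted point."""
    return (1.0 - alpha * hess_f(w, c)) * grad_f(w - alpha * grad_f(w, c), c)


def first_order_meta_grad(w, c, alpha):
    """First-order approximation as in (23): drop the (I - alpha * Hessian) factor."""
    return grad_f(w - alpha * grad_f(w, c), c)


w, c = 1.0, 0.0
for alpha in (0.5, 0.1, 0.01):
    gap = abs(exact_meta_grad(w, c, alpha) - first_order_meta_grad(w, c, alpha))
    print(alpha, round(gap, 4))
```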
For this experiment, we use a neural network classifier with two hidden layers of sizes 80 and 60, respectively, and the Exponential Linear Unit (ELU) activation function. We run all three algorithms for rounds. At each round, we assume a fraction of agents $\mathcal{A}_k$ with size $rn$ is chosen to run $\tau$ local updates. The batch sizes and the learning rates are chosen as and . Further, we consider the case that there are users in the network. We note that part of the code is adapted from (Langelaar, 2019).
The results for different values of the number of local updates $\tau$ and the ratio of active users $r$ are illustrated in Figure 1. As expected, in all considered cases, the model trained by running FedAvg to solve the classic FL problem (1) performs worse than the same model after one step of local gradient descent in the test phase. Hence, if extra computation is available at test time, the FedAvg model after one step of gradient descent leads to a more personalized solution.
More importantly, the Per-FedAvg method, which is designed to find a point that performs well once it is updated using one step of local gradient descent, has the best performance among the three considered approaches. In other words, its model has a better test accuracy than the model obtained by running one step of local gradient descent on the solution of FedAvg. These experiments show that, in this heterogeneous setting, solving the MAML variant of the FL problem yields a solution that performs better.
6 Conclusion
In this paper, we studied the Federated Learning (FL) problem in the heterogeneous case in which the probability distributions of the users in the network are not identical. To solve this problem, we studied a personalized variant of the classic FL formulation in which our goal is to find a proper initial model for the users in the network that can be quickly adapted to the local data of each user after the training phase. In particular, we introduced a Model-Agnostic Meta-Learning (MAML) variant of FL in which, instead of minimizing the average loss over the data of all users, we find the best initial model that, after one step of local gradient descent, leads to a good model for each individual user. As expected, this approach leads to a more personalized model for each user. We then introduced a personalized variant of the FedAvg algorithm, called Per-FedAvg, to solve the proposed personalized FL problem. We also characterized the overall complexity of the Per-FedAvg method for nonconvex settings. Specifically, for the case that each user runs $\tau$ local updates at each iteration, we showed that after $K$ rounds of communication between the users and the server, Per-FedAvg converges to a neighborhood of a first-order stationary point at a rate of , where the radius of this neighborhood depends on the closeness of the data distributions of different users. Finally, we provided a numerical experiment to illustrate the performance of Per-FedAvg and compare it with the FedAvg method.
Appendix
Appendix A Proofs of results in Subsection 4.2
A.1 TV Distance
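Note that
$$\|\nabla f_i(w) - \nabla f(w)\| = \Big\|\sum_{z\in\mathcal{Z}} \nabla_w l(w; z)\big(p_i(z) - p(z)\big)\Big\| \le \sum_{z\in\mathcal{Z}} \|\nabla_w l(w; z)\|\,|p_i(z) - p(z)| \le 2B\,\|p_i - p\|_{TV}, \tag{24}$$
where for the last inequality we used the assumption that $\|\nabla_w l(w; z)\| \le B$ for any $w$ and $z$, together with the definition (17). Plugging (24) into
$$\frac{1}{n}\sum_{i=1}^{n} \|\nabla f_i(w) - \nabla f(w)\|^2 \le \max_i \|\nabla f_i(w) - \nabla f(w)\|^2 \tag{25}$$
gives us the desired result. The other result, on Hessians, can be proved similarly.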
A.2 1-Wasserstein Distance
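We claim that for any $w \in \mathbb{R}^d$ and any $i$, we have
$$\|\nabla f_i(w) - \nabla f(w)\| \le L_Z\, W_1(p_i, p), \tag{26}$$
which immediately gives us one of the two results. To show this, first note that
$$\|\nabla f_i(w) - \nabla f(w)\| = \max_{\|u\| \le 1} u^\top\big(\nabla f_i(w) - \nabla f(w)\big).$$
Thus, we need to show that for any $u$ with $\|u\| \le 1$, we have
$$u^\top\Big(\mathbb{E}_{z\sim p_i}\big[\nabla_w l(w; z)\big] - \mathbb{E}_{z\sim p}\big[\nabla_w l(w; z)\big]\Big) \le L_Z\, W_1(p_i, p). \tag{27}$$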
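Next, note that since $p_i$ and $p$ both have bounded support, by the Kantorovich–Rubinstein duality theorem (Villani, 2008), we have
$$W_1(p_i, p) = \sup\Big\{\mathbb{E}_{z\sim p_i}[g(z)] - \mathbb{E}_{z\sim p}[g(z)] \;:\; g \text{ is 1-Lipschitz with respect to } d\Big\}. \tag{28}$$
Using this result, to show (27), it suffices to show that $g(z) := u^\top \nabla_w l(w; z) / L_Z$ is 1-Lipschitz. Note that the Cauchy-Schwarz inequality implies
$$\big|u^\top \nabla_w l(w; z) - u^\top \nabla_w l(w; z')\big| \le \|u\|\,\big\|\nabla_w l(w; z) - \nabla_w l(w; z')\big\| \le L_Z\, d(z, z'), \tag{29}$$
where the last inequality is obtained using $\|u\| \le 1$ along with (21a).
Finally, note that we can similarly show the result for the Hessians by considering the fact that, for symmetric matrices,
$$\|\nabla^2 f_i(w) - \nabla^2 f(w)\| = \max_{\|u\| \le 1} \big|u^\top\big(\nabla^2 f_i(w) - \nabla^2 f(w)\big)u\big|,$$
and taking the function $g(z) := u^\top \nabla_w^2 l(w; z)\, u / \rho_Z$ and using the Kantorovich–Rubinstein duality theorem again.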
Appendix B Proof of Lemma 4.2
Recall that
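$$\nabla F_i(w) = \big(I - \alpha \nabla^2 f_i(w)\big)\, \nabla f_i\big(w - \alpha \nabla f_i(w)\big). \tag{30}$$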
Given this, note that
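$$\|\nabla F_i(w) - \nabla F_i(u)\| = \Big\|\big(I - \alpha \nabla^2 f_i(w)\big)\Big(\nabla f_i\big(w - \alpha \nabla f_i(w)\big) - \nabla f_i\big(u - \alpha \nabla f_i(u)\big)\Big) + \alpha\big(\nabla^2 f_i(u) - \nabla^2 f_i(w)\big)\nabla f_i\big(u - \alpha \nabla f_i(u)\big)\Big\| \tag{31}$$
$$\le \big\|I - \alpha \nabla^2 f_i(w)\big\|\,\Big\|\nabla f_i\big(w - \alpha \nabla f_i(w)\big) - \nabla f_i\big(u - \alpha \nabla f_i(u)\big)\Big\| + \alpha\,\big\|\nabla^2 f_i(w) - \nabla^2 f_i(u)\big\|\,\Big\|\nabla f_i\big(u - \alpha \nabla f_i(u)\big)\Big\|, \tag{32}$$
where (31) is obtained by adding and subtracting $\big(I - \alpha \nabla^2 f_i(w)\big)\nabla f_i\big(u - \alpha \nabla f_i(u)\big)$, and the last inequality follows from the triangle inequality and the definition of the matrix norm. Now, we bound the two terms of (32) separately.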
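First, note that by (12a), $\|I - \alpha \nabla^2 f_i(w)\| \le 1 + \alpha L$. Using this along with the smoothness of $f_i$, we have
$$\begin{aligned}\big\|I - \alpha \nabla^2 f_i(w)\big\|\,\Big\|\nabla f_i\big(w - \alpha \nabla f_i(w)\big) - \nabla f_i\big(u - \alpha \nabla f_i(u)\big)\Big\| &\le (1 + \alpha L)\, L\, \big\|w - u - \alpha\big(\nabla f_i(w) - \nabla f_i(u)\big)\big\| \\ &\le L (1 + \alpha L)^2 \|w - u\| \le 4L\|w - u\|,\end{aligned} \tag{33}$$
where we used the smoothness of $f_i$ along with $\alpha \le 1/L$ for the last line.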
Appendix C Proof of Lemma 4.3
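Recall that the expression for the stochastic gradient $\tilde\nabla F_i(w)$ is given by
$$\tilde\nabla F_i(w) = \big(I - \alpha\, \tilde\nabla^2 f_i(w, \mathcal{D}'')\big)\, \tilde\nabla f_i\big(w - \alpha\, \tilde\nabla f_i(w, \mathcal{D}), \mathcal{D}'\big), \tag{35}$$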
which can be written as
(36) 
Note that in the above expression and are given by
and
It can be easily shown that