Federated Learning with Additional Mechanisms on Clients to Reduce Communication Costs
Federated learning (FL) enables on-device training over distributed networks consisting of a massive number of modern smart devices, such as smartphones and IoT (Internet of Things) devices. However, the leading optimization algorithm in such settings, i.e., federated averaging (FedAvg), suffers from heavy communication costs and an inevitable performance drop, especially when the local data are distributed in a non-IID way. To alleviate this problem, we propose two potential solutions that introduce additional mechanisms into the on-device training.
The first (FedMMD) adopts a two-stream model with an MMD (Maximum Mean Discrepancy) constraint, instead of the single model trained on devices in vanilla FedAvg. Experiments show that the proposed method outperforms baselines, especially in non-IID FL settings, with a reduction of more than 20% in required communication rounds.
The second is FL with feature fusion (FedFusion). By aggregating the features from both the local and global models, we achieve higher accuracy at lower communication cost. Furthermore, the feature fusion modules offer better initialization for newly incoming clients and thus speed up convergence. Experiments in popular FL scenarios show that FedFusion outperforms baselines in both accuracy and generalization ability while reducing the number of required communication rounds by more than 60%.
1 Introduction
Mobile phones, wearable devices, and IoT (Internet of Things) devices play an important role in modern life. Intelligent applications on these devices, such as personal assistants, machine translation, and keyboard input suggestion, are becoming popular. These applications usually rely on pre-trained models and perform only forward inference on clients, which lacks flexibility and personalization. Meanwhile, smart edge devices are generating a tremendous amount of valuable yet privacy-sensitive data that has the potential to improve existing models. To take full advantage of the on-device data, traditional machine learning strategies require collecting data from clients, training a centralized model on servers, and then issuing the model back to clients, which puts a heavy burden on the communication networks and exposes users to high privacy risks.
Recently, a line of work called federated learning (FL) konevcny2015federated ; konevcny2016federated ; mcmahan2017communication enables on-device training directly over distributed networks. The aim of FL is to train a model from data generated by distributed clients (each device is treated as a client). Each client k generates data in a non-IID manner, which means the data distribution P_k on client k is not a uniform sample of the whole distribution. The leading algorithm in FL, i.e., federated averaging (FedAvg) mcmahan2017communication , assumes a synchronous updating scheme that proceeds in rounds of communication. Considering a fixed number of clients (each with a local dataset), a random fraction of clients is selected to participate at the beginning of each round. The server sends the current global model to each of these chosen clients. Each client then computes a local model based on the global model and its local data, and reports it to the server. Finally, the server updates the global model by model averaging and begins the next round. By adding more computational iterations on clients, FedAvg reduces the required communication rounds compared with traditional SGD-based methods. However, further studies jeong2018communication ; yao2018twostream point out that communication costs remain the main constraint in FL compared to other factors (e.g., computation costs), and that the accuracy of FedAvg drops significantly when the models are trained on pathological non-IID data.
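The round structure described above can be sketched as follows (a minimal NumPy sketch; the function names and the toy `local_update` interface are our own illustration, not part of FedAvg's specification):

```python
import numpy as np

def fedavg_round(global_weights, client_datasets, local_update, frac=0.5, rng=None):
    """One synchronous FedAvg round: sample a fraction of clients,
    run local training starting from the current global weights, then
    average the returned local models weighted by local dataset size."""
    if rng is None:
        rng = np.random.default_rng(0)
    n_clients = len(client_datasets)
    m = max(1, int(frac * n_clients))
    chosen = rng.choice(n_clients, size=m, replace=False)

    updates, sizes = [], []
    for k in chosen:
        local_w = local_update(global_weights.copy(), client_datasets[k])
        updates.append(local_w)
        sizes.append(len(client_datasets[k]))

    # Weighted model averaging: w <- sum_k (n_k / n) * w_k
    sizes = np.asarray(sizes, dtype=float)
    coeffs = sizes / sizes.sum()
    return sum(c * w for c, w in zip(coeffs, updates))
```

With a real model, `local_update` would run several epochs of SGD on the client's data; here it is left abstract so that the averaging structure stands out.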
To alleviate this problem, we propose two potential solutions by introducing additional mechanisms to the on-device training.
We first propose adopting a two-stream model, as commonly used in transfer learning and domain adaptation sun2016deep ; long2015learning ; zhuo2017deep , instead of a single model to be trained on devices in FL settings. A Maximum Mean Discrepancy (MMD) constraint gretton2012optimal is introduced into the on-device training iterations, which forces the local model to integrate more knowledge from the global one. Experiments show that the proposed FedMMD brings a reduction in required communication rounds without affecting the final test performance.
We then propose a new FL algorithm with a feature fusion mechanism (FedFusion) to reduce the communication costs. The feature fusion modules aggregate the features from both the local and global models after the feature extraction stage, at little extra computation cost. These modules make the training process on each client more efficient and handle non-IID data distributions more pertinently, since each client learns the feature fusion module most appropriate for itself.
In conclusion, our contributions are as follows:
We use a two-stream model with an MMD (Maximum Mean Discrepancy) constraint, instead of the single model trained on devices in vanilla FedAvg.
We propose aggregating the features from both the local and global models during the on-device training.
Experiments on popular FL settings show that the proposed methods outperform baselines in both accuracy and generalization ability while reducing the number of required communication rounds by more than 60% in the best case.
2 Related Work
2.1 Federated Learning
Federated learning (FL) was proposed by McMahan et al. mcmahan2017communication to tackle the problem of decentralized training over massively distributed intelligent devices without direct access to their privacy-sensitive data.
Considering that communication costs remain the main constraint in FL, several research efforts have already been made. Konečnỳ et al. konevcny2016federated proposed structured and sketched updates in the context of client-to-server communication. Yao et al. yao2018twostream introduced extra constraints into the on-device training procedure, aiming to integrate more knowledge from other clients while fitting the local data. Caldas et al. caldas2018expanding proposed federated dropout to train submodels on clients and extended lossy compression suresh2017distributed to server-to-client communication.
2.2 Maximum Mean Discrepancy
As the name suggests, Maximum Mean Discrepancy (MMD) measures the distance between the means of two data distributions. It is widely used in domain adaptation long2015learning ; long2013transfer ; tzeng2014deep to describe the difference between features generated from the source and target domains. By minimizing the MMD loss, these methods force the two-stream model to extract more generalized features, which is very similar to our purpose of learning better representations over the whole dataset. In this paper, we focus on the multiple-kernel variant of MMD (MK-MMD) proposed by Gretton et al. gretton2012optimal , which first maps the data distributions into a Reproducing Kernel Hilbert Space (RKHS).
Given two data distributions p and q, the square of the MMD between them can be expressed as:

\mathrm{MMD}^2(p, q) = \left\| \mathbb{E}_{x \sim p}[\phi(x)] - \mathbb{E}_{y \sim q}[\phi(y)] \right\|_{\mathcal{H}}^2, (1)

where \phi(\cdot) denotes the mapping to the RKHS. In practice, this mapping is unknown. Using the kernel trick to replace the inner products, we have:

\mathrm{MMD}^2(p, q) = \mathbb{E}_{x, x' \sim p}[k(x, x')] - 2\,\mathbb{E}_{x \sim p,\, y \sim q}[k(x, y)] + \mathbb{E}_{y, y' \sim q}[k(y, y')], (2)

where k(\cdot, \cdot) is the desired kernel function. In this paper, we use a standard radial basis function (RBF) kernel with multiple widths.
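As a concrete illustration, Equation (2) can be estimated from samples with a sum of RBF kernels of several widths (a NumPy sketch; the particular widths below are placeholders, not the values used in our experiments):

```python
import numpy as np

def rbf_kernel(X, Y, widths=(1.0, 5.0, 10.0)):
    """Sum of RBF kernels with multiple widths (a simple MK-MMD-style kernel)."""
    # Pairwise squared Euclidean distances between rows of X and Y.
    sq = np.sum(X**2, 1)[:, None] - 2.0 * X @ Y.T + np.sum(Y**2, 1)[None, :]
    return sum(np.exp(-sq / (2.0 * s)) for s in widths)

def mmd2(X, Y, widths=(1.0, 5.0, 10.0)):
    """Biased empirical estimate of MMD^2 via the kernel trick (Equation (2))."""
    kxx = rbf_kernel(X, X, widths).mean()
    kyy = rbf_kernel(Y, Y, widths).mean()
    kxy = rbf_kernel(X, Y, widths).mean()
    return kxx + kyy - 2.0 * kxy
```

Samples drawn from the same distribution yield an estimate near zero, while samples from shifted distributions yield a clearly larger value.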
3 Methods
In this section, we first introduce our two-stream FL with MMD (FedMMD), then describe the proposed feature fusion modules and present our FL algorithm with the feature fusion mechanism (FedFusion).
3.1 Two Stream Federated Learning
Formally, let w_G and w_L be the model parameters, that is, the weights and biases of all the layers, in the global and local models respectively. Let x^(k) and y^(k) denote the training data and the corresponding labels on client k, where n_k is the number of examples.
Figure 1 shows how our FedMMD works on a specific client k. As shown in the figure, the global model is received from the server at the beginning of the current round and kept fixed during the subsequent training, while the local model is initialized with the parameters of the global model (w_L ← w_G) and then trained on the local data and labels by minimizing a loss function of the form:

L = L_C(F_L(x), y) + λ · L_MMD(F_G(x), F_L(x)), (3)

where F_G(x) and F_L(x) denote the outputs of the global and local models for the input x respectively. L_C denotes a standard classification loss, instantiated as the cross-entropy loss in our experiments. L_MMD denotes the MMD loss between the outputs of the global and local models, computed by Equation (2) and weighted by the coefficient λ.
The training process of FedAvg is essentially a cyclic process of learning local representations, merging knowledge from different clients, and then learning again. In other words, the global model contains more knowledge from multiple clients, while the local model learns better representations of the local data. Compared with the single local model trained in FedAvg, we keep the global model as a reference instead of throwing it away after initialization. As discussed in Section 2.2, MMD measures the distance between the means of two data distributions. By minimizing the MMD loss between the outputs of the global and local models, we force the local model to integrate more knowledge from other clients in addition to the representations of the data on the current client, thus accelerating the convergence of the overall training process, in other words, reducing the communication rounds.
Our two-stream FL algorithm with MMD is described as Algorithm 1.
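The per-batch two-stream objective, classification loss plus a weighted MMD term between the two streams, can be sketched as follows (NumPy; a single-width RBF kernel is used for brevity, and the helper names are our own):

```python
import numpy as np

def cross_entropy(probs, labels):
    """Standard classification loss L_C on softmax outputs."""
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

def mmd2(X, Y, width=5.0):
    """Biased MMD^2 estimate with a single-width RBF kernel."""
    sq = lambda A, B: np.sum(A**2, 1)[:, None] - 2.0 * A @ B.T + np.sum(B**2, 1)[None, :]
    k = lambda A, B: np.exp(-sq(A, B) / (2.0 * width))
    return k(X, X).mean() + k(Y, Y).mean() - 2.0 * k(X, Y).mean()

def fedmmd_loss(local_out, global_out, probs, labels, lam=0.1):
    """Two-stream objective: classification loss on the local stream plus
    lambda * MMD^2 between the local and (frozen) global outputs."""
    return cross_entropy(probs, labels) + lam * mmd2(local_out, global_out)
```

In the full algorithm only the local stream receives gradients; the global stream is frozen and serves purely as the reference distribution for the MMD term.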
3.2 Feature Fusion Modules
The detailed architectures of feature fusion modules are illustrated in Figure 2.
Concretely, an input image is transformed into two feature spaces by the local feature extractor and the global one respectively, yielding the feature maps F_L and F_G. A fusion operator then embeds F_L and F_G into a fused feature space, producing fused feature maps F of the same shape. In this paper, we introduce three types of fusion operator, as follows.
The conv operator is implemented with 1×1 convolutions:

F = W_conv * [F_L, F_G],

where W_conv is the learned weight matrix and [·, ·] denotes the operation that concatenates the feature maps along the channel axis.
The multi operator introduces a learned weight vector λ and computes the weighted sum of the local and global feature maps:

F = λ ⊙ F_G + (1 − λ) ⊙ F_L,

where the weight vector λ is first broadcast to the shape of the feature maps and then multiplied with them element-wise, as illustrated in Figure 1(b).
The single operator uses a learned weight scalar λ and computes the weighted sum of the local and global feature maps:

F = λ F_G + (1 − λ) F_L,

where the global and local feature maps are scaled by λ and 1 − λ respectively and then added element-wise, as shown in Figure 1(c).
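The three operators can be sketched as follows (NumPy, channels-last layout; the exact parameterization, e.g. blending with λ and 1 − λ, is our reading of the figures and may differ in detail from the trained modules):

```python
import numpy as np

def fuse_conv(f_local, f_global, W):
    """Conv operator: concatenate along the channel axis, then a 1x1
    convolution (here a per-pixel matmul) maps 2C channels back to C."""
    cat = np.concatenate([f_local, f_global], axis=-1)  # (H, W, 2C)
    return cat @ W                                      # W: (2C, C) -> (H, W, C)

def fuse_multi(f_local, f_global, lam_vec):
    """Multi operator: a learned per-channel weight vector, broadcast
    over the spatial dims, blends the two feature maps channel-wise."""
    lam = lam_vec.reshape(1, 1, -1)
    return lam * f_global + (1.0 - lam) * f_local

def fuse_single(f_local, f_global, lam):
    """Single operator: one learned scalar blends the two feature maps."""
    return lam * f_global + (1.0 - lam) * f_local
```

All three operators preserve the feature-map shape, so the rest of the network is unchanged regardless of which operator a client uses.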
3.3 Federated Learning with Feature Fusion Mechanism
A typical training iteration of the proposed FedFusion is shown in Figure 3.
At the beginning of round t, we keep the feature extractor of the global model (denoted E_G) instead of throwing it away after initialization as in FedAvg. During training, E_G is frozen and an additional feature fusion module, as described in Section 3.2, is introduced. In practice, the global feature maps generated by E_G can be recorded in a single forward pass, so the feature fusion module brings limited extra computation cost. After the on-device training procedure, the local model combined with the feature fusion module is sent to the central server for model aggregation. For the multi and single operators, we use an exponential moving average strategy to smooth the update.
The pseudo code of FedFusion is shown in Algorithm 2.
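The server-side smoothing of the fusion weights for the multi and single operators can be sketched as follows (a minimal sketch; the function name and the decay factor are placeholders of our own):

```python
import numpy as np

def ema_update(server_lam, client_lams, decay=0.9):
    """Smooth the server-side fusion weights with an exponential moving
    average over the average of the weights reported by the clients."""
    avg = np.mean(client_lams, axis=0)
    return decay * server_lam + (1.0 - decay) * avg
```

This keeps the fusion weights from jumping round to round when different subsets of clients, each with its own learned blend, participate.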
4 Experiments
4.1 Experimental Setup
We use MNIST lecun1998gradient and CIFAR10 krizhevsky2009learning as the basic datasets in our experiments. We further propose three types of data partitions to benchmark our FedMMD and FedFusion against the vanilla FedAvg.
The first is the Artificial non-IID Partition, which is constructed by splitting an existing IID dataset to meet the FL settings and is commonly used in previous FL studies konevcny2015federated ; konevcny2016federated ; mcmahan2017communication ; caldas2018expanding ; zhao2018federated . In this partition, a single client usually holds a subset of the classes of the total data. For example, in mcmahan2017communication most clients have examples of at most two MNIST digits.
The second is the User Specific non-IID Partition, where the data on different clients usually have similar classes but follow different distributions. This setting is commonly used in multi-task learning studies smith2017federated ; chen2018federated ; caldas2018leaf .
For the MNIST digit recognition task, we use the same model as FedAvg mcmahan2017communication : a CNN with two 5×5 convolution layers (the first with 32 channels and the second with 64, each followed by a ReLU activation and 2×2 max pooling), a fully connected layer with 512 units (with a ReLU activation and random dropout), and a final softmax output layer.
For CIFAR10 we use a CNN with two 5×5 convolution layers (both with 64 channels, each followed by a ReLU activation and 3×3 max pooling with stride 2), two fully connected layers (the first with 384 units and the second with 192, each followed by a ReLU activation and random dropout), and a final softmax output layer.
4.2.1 CNN on CIFAR-10
For convenience but without loss of generality, we select a fixed group of hyper-parameters: two clients participate in the training process, with fixed local batch size and local epochs. During training, we use an SGD optimizer with a fixed learning rate. The penalty parameter for the MMD constraint is set to 0.1 in this experiment. As for the penalty parameter for the L2 norm, we tried 0.1 and 0.01; the results are shown in Figures 3(a) and 3(b).
We explore both non-IID and IID data distributions and show that our FedMMD outperforms the baseline methods, especially under the non-IID data distribution.
Figure 3(a) shows the non-IID situation, where we split the 10 classes of images in the CIFAR-10 dataset into 2 parts, each containing 5 classes without overlap. FedMMD needs fewer communication rounds to converge. More concretely, FedMMD reaches a test accuracy of 0.72 within 260 rounds of communication, a reduction of 20.2% compared to the 326 rounds needed by FedAvg. It is worth noting that MMD forces the local model to learn more knowledge from the global model but does not introduce new information into the overall optimization system compared with FedAvg; thus the final convergence results are the same, which is consistent with our expectation.
Figure 3(b) shows the IID situation. As we can see, FedMMD performs similarly to the vanilla FedAvg. In this case, the data on each client is a uniform sample of the overall dataset, which means the local model is able to learn the complete representations on its own, and as a result the role of the MMD constraint is weakened. The two-stream model constrained by the L2 norm underperforms the other methods, indicating that the choice of constraint does matter.
4.2.2 CNN on MNIST
Similar to the CIFAR-10 experiments, we first explore the binary separated non-IID distribution; the results are shown in Figure 3(c). During training, we use an SGD optimizer with a fixed learning rate. The penalty parameters for the MMD constraint and the L2 norm are 0.1 and 0.001, respectively. We use a rather small penalty for the L2 norm because we find that a larger one may lead to non-convergence. As we can see, our FedMMD converges faster than the other methods and does not lower the final test accuracy. The reduction in communication rounds is smaller than on CIFAR-10. We speculate that the variety of feature representations in MNIST is much lower due to its black-and-white images composed of simple strokes.
We further study a more complex non-IID data distribution described in mcmahan2017communication , where we first sort the data in MNIST by digit label, divide it into 200 shards of size 300, and assign each of 100 clients 2 shards. This is a typical non-IID partition of the data, as most clients will only have examples of two digits. The fraction of selected clients, the local batch size, and the number of local epochs are kept fixed. As shown in Figure 3(d), the test accuracy curve oscillates due to the extremely pathological data distribution. Our FedMMD achieves a test accuracy of 98% within 72 rounds of communication while FedAvg needs 128 rounds, which means a reduction of 23.4%.
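The shard-based split just described can be sketched as follows (NumPy; the parameter defaults mirror the description above, and the function name is our own):

```python
import numpy as np

def shard_partition(labels, n_clients=100, shards_per_client=2, shard_size=300, seed=0):
    """Pathological non-IID split: sort example indices by label, cut them
    into n_clients * shards_per_client contiguous shards, and deal each
    client its shards, so most clients see only a couple of classes."""
    order = np.argsort(labels, kind="stable")
    n_shards = n_clients * shards_per_client          # e.g. 200 shards of size 300
    shards = order[: n_shards * shard_size].reshape(n_shards, shard_size)
    rng = np.random.default_rng(seed)
    assign = rng.permutation(n_shards).reshape(n_clients, shards_per_client)
    return [shards[assign[i]].ravel() for i in range(n_clients)]
```

Because each contiguous shard of 300 sorted examples spans at most two digit classes, a client holding two shards sees at most four classes, and typically just two.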
4.3.1 Artificial non-IID Partition
Under the artificial non-IID scenarios, we use a learning rate of 0.003 with an exponential decay factor of 0.985 per round for all FedFusion variants (with different fusion operators) and the compared FedAvg. Their convergence behaviors are illustrated in Figures 4(a) and 4(b), while the final accuracies at convergence are shown in Table 1.
The curve representing FedFusion with the multi operator stays above the others, which means it achieves higher accuracy at lower communication cost. The accuracy of FedFusion with conv also rises faster at the beginning but fails to reach as good a convergence point. FedFusion with single and FedAvg perform relatively worse.
Such results are mainly due to the design of the multi fusion operator. As stated before, most clients hold a subset of the total classes in the artificial non-IID scenarios. The multi operator allows the models on clients to select the feature maps that are helpful for their local data. In contrast, FedAvg offers no such selection, and the single operator does not provide enough room for adjustment.
4.3.2 User Specific non-IID Partition
To simulate the user specific non-IID partition, we apply a different pixel permutation to MNIST on each client, the so-called Permuted MNIST used in several previous studies goodfellow2013empirical ; zenke2017continual . We use a learning rate of 0.002 with an exponential decay factor of 0.99 per round for all methods.
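The per-client permutation can be sketched as follows (NumPy; the function names and seed handling are our own illustration):

```python
import numpy as np

def client_permutations(n_clients, n_pixels=784, seed=0):
    """One fixed pixel permutation per client (Permuted MNIST)."""
    rng = np.random.default_rng(seed)
    return [rng.permutation(n_pixels) for _ in range(n_clients)]

def permute_images(images, perm):
    """Apply a client's permutation to flattened images: the classes are
    unchanged, but every client sees a different input distribution."""
    flat = images.reshape(len(images), -1)
    return flat[:, perm].reshape(images.shape)
```

Each client keeps its own fixed permutation throughout training, so the label semantics are shared while the pixel statistics differ per client.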
The number of communication rounds needed to reach certain accuracy milestones (94% and 95% here), as well as the reduction in communication rounds versus FedAvg, is shown in Table 2. The results indicate that FedFusion with conv leads by a large margin, in contrast to the artificial non-IID partition: it achieves the best performance while reducing the number of communication rounds by more than 60%. In the user specific non-IID partition, the data on clients have similar classes but follow different distributions, and the conv operator is better at integrating the feature maps from both the local and global models, in other words, the knowledge from other clients and data distributions. It is worth noting that the user specific non-IID partition is closer to realistic FL scenarios, so the improvement matters more in this case.
Additionally, we study the generalization ability of the model, which is usually ignored in previous FL research. The local epochs needed to reach convergence for newly incoming clients are illustrated in Figure 6. As we can see, when a new client joins an existing FL system, FedFusion with conv provides a better initialization than the other algorithms and thus speeds up convergence.
4.3.3 IID Partition
The IID partition is a simple yet necessary partition to evaluate FL algorithms. If one strategy cannot handle this partition, its effectiveness is questionable.
As shown in Figure 4(d), FedFusion with multi and conv achieves higher accuracy at lower communication cost. In terms of the final convergence accuracy, FedFusion with multi and conv show a clear improvement over the other methods.
A brief conclusion on the feature fusion operators follows. The multi operator offers flexible choices between the local and global feature maps and is more interpretable: entries in the weight vector account for the proportions of the global feature maps in the corresponding channels. When there are gaps between the classes held by different clients, the multi operator learns to choose the most helpful feature maps. The conv operator is better at integrating the knowledge from the global and local models: when the data on clients have similar classes but follow different distributions, it performs much better. Our experiments indicate that the single operator brings little improvement and should not be adopted in practice.
5 Conclusion
The heavy communication cost of FedAvg is an urgent problem to solve. In this paper, we first replace the single model trained on clients in FL settings with a two-stream model consisting of the global and local ones. Our experiments show that introducing the MMD constraint into the optimization brings a reduction in communication rounds, especially in non-IID FL settings. We further propose a new FL algorithm with feature fusion modules and evaluate it in popular FL setups. The experimental results show that the proposed method achieves higher accuracy while reducing the communication rounds by more than 60%. Moreover, we observe that FedFusion offers better generalization for newly incoming clients.
This work is supported by the National Key R&D Program of China (2018YFB1003703), National Natural Science Foundation of China (61472204 & 61521002), as well as Beijing Key Lab of Networked Multimedia (Z161100005016051).
-  S. Caldas, J. Konečny, H. B. McMahan, and A. Talwalkar. Expanding the reach of federated learning by reducing client resource requirements. arXiv preprint arXiv:1812.07210, 2018.
-  S. Caldas, P. Wu, T. Li, J. Konečnỳ, H. B. McMahan, V. Smith, and A. Talwalkar. Leaf: A benchmark for federated settings. arXiv preprint arXiv:1812.01097, 2018.
-  F. Chen, Z. Dong, Z. Li, and X. He. Federated meta-learning for recommendation. arXiv preprint arXiv:1802.07876, 2018.
-  I. J. Goodfellow, M. Mirza, D. Xiao, A. Courville, and Y. Bengio. An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211, 2013.
-  A. Gretton, D. Sejdinovic, H. Strathmann, S. Balakrishnan, M. Pontil, K. Fukumizu, and B. K. Sriperumbudur. Optimal kernel choice for large-scale two-sample tests. In Advances in neural information processing systems, pages 1205–1213, 2012.
-  E. Jeong, S. Oh, H. Kim, J. Park, M. Bennis, and S.-L. Kim. Communication-efficient on-device machine learning: Federated distillation and augmentation under non-iid private data. arXiv preprint arXiv:1811.11479, 2018.
-  J. Konečnỳ, B. McMahan, and D. Ramage. Federated optimization: Distributed optimization beyond the datacenter. arXiv preprint arXiv:1511.03575, 2015.
-  J. Konečnỳ, H. B. McMahan, F. X. Yu, P. Richtárik, A. T. Suresh, and D. Bacon. Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492, 2016.
-  A. Krizhevsky and G. Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.
-  Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
-  M. Long, Y. Cao, J. Wang, and M. I. Jordan. Learning transferable features with deep adaptation networks. arXiv preprint arXiv:1502.02791, 2015.
-  M. Long, J. Wang, G. Ding, J. Sun, and S. Y. Philip. Transfer feature learning with joint distribution adaptation. In Computer Vision (ICCV), 2013 IEEE International Conference on, pages 2200–2207. IEEE, 2013.
-  B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pages 1273–1282, 2017.
-  V. Smith, C.-K. Chiang, M. Sanjabi, and A. S. Talwalkar. Federated multi-task learning. In Advances in Neural Information Processing Systems, pages 4424–4434, 2017.
-  B. Sun and K. Saenko. Deep coral: Correlation alignment for deep domain adaptation. In European Conference on Computer Vision, pages 443–450. Springer, 2016.
-  A. T. Suresh, F. X. Yu, S. Kumar, and H. B. McMahan. Distributed mean estimation with limited communication. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 3329–3337. JMLR.org, 2017.
-  E. Tzeng, J. Hoffman, N. Zhang, K. Saenko, and T. Darrell. Deep domain confusion: Maximizing for domain invariance. arXiv preprint arXiv:1412.3474, 2014.
-  X. Yao, C. Huang, and L. Sun. Two-stream federated learning: Reduce the communication costs. In Visual Communications and Image Processing (VCIP), 2018, pages 1–4. IEEE, 2018.
-  X. Yao, T. Huang, C. Wu, R. Zhang, and L. Sun. Towards faster and better federated learning: A feature fusion approach. In IEEE International Conference on Image Processing, 2019.
-  F. Zenke, B. Poole, and S. Ganguli. Continual learning through synaptic intelligence. In International Conference on Machine Learning, pages 3987–3995, 2017.
-  Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra. Federated learning with non-iid data. arXiv preprint arXiv:1806.00582, 2018.
-  J. Zhuo, S. Wang, W. Zhang, and Q. Huang. Deep unsupervised convolutional domain adaptation. In Proceedings of the 2017 ACM on Multimedia Conference, pages 261–269. ACM, 2017.