Federated Learning for Keyword Spotting
We propose a practical approach based on federated learning to solve out-of-domain issues with continuously running embedded speech-based models such as wake word detectors. We conduct an extensive empirical study of the federated averaging algorithm for the “Hey Snips” wake word based on crowdsourced data on two distinct tasks: learning from scratch and language adaptation. We also reformulate the global averaging step of the federated averaging algorithm as a gradient update step, applying per-coordinate adaptive learning rate strategies such as Adam in place of standard weighted model averaging. We then empirically demonstrate that using adaptive averaging strategies greatly reduces the number of communication rounds required to reach a target performance.
David Leroy, Alice Coucke, Thibaut Lavril, Thibault Gisselbrecht and Joseph Dureau
Snips, 18 rue Saint Marc, 75002 Paris, France
Keyword spotting, embedded speech recognition, federated learning
Wake word detection is used to start an interaction with a voice assistant. A specific case of keyword spotting (KWS), it continuously listens to an audio stream to detect a predefined keyword or set of keywords. Well-known examples of wake words include Apple’s “Hey Siri” or Google’s “OK Google”. Once the wake word is detected, voice input is activated and processed by a Spoken Language Understanding (SLU) engine, powering the perception abilities of the voice assistant.
Wake word detectors usually run on device in an always-on fashion, which brings two major difficulties. First, the detector should run with a minimal memory footprint and computational cost. The resource constraints for our wake word detector are 200k parameters (based on the medium sized model proposed in ), and 20 MFLOPS.
Secondly, the wake word detector should be robust in any usage setting and should behave consistently with background noise. The audio signal is highly sensitive to recording proximity (close or far field), the recording hardware, but also to the room configuration. Robustness also implies a strong speaker variability coverage (genders, accents, etc.). While the use of digital signal processing (DSP) front-ends can help mitigate issues related to bad recording conditions, speaker variability remains a major challenge. High accuracy is all the more important since the model can be triggered at any time: it is therefore expected to capture most of the commands (high recall, or low false rejection rate) while not triggering when it should not (low false alarm rate). In the case of cloud-based speech recognition, wake word false alarms can lead to major privacy leaks as the subsequent audio signal that can include very sensitive content is streamed to a remote server.
Today, wake word detectors are typically trained on collected datasets that are not representative of the data seen at run-time, hence the need to train the detector directly on in-domain data. While centralized collection can be considered to perform model adaptation over time, it comes with major privacy drawbacks. In this work, we investigate the use of Federated Learning (FL) in the context of an embedded wake word detector that aims at being well behaved in real usage settings. Federated Learning is a decentralized optimization procedure that makes it possible to train a central model on the local data of many users without ever uploading this data to a central server. The training workload is moved to the users’ devices, which perform training steps on the local data. Local updates from users are then averaged by a parameter server in order to create a global model.
2 Related work
Most research around decentralized learning has historically been done in the context of a highly controlled cluster/data center setting, e.g. with a dataset evenly partitioned in an i.i.d. fashion. The multi-core and multi-GPU distributed training setting has been specifically studied in the context of speech recognition. Efforts on decentralized training with highly distributed, unbalanced and non-i.i.d. data are relatively recent, as the foundations were laid down with the introduction of the federated averaging (FedAvg) algorithm and its application to a set of computer vision (MNIST, CIFAR-10) and NLP tasks (Shakespeare dataset and Google Plus for Language Modeling). There are for now very few real-life experiments that we know of, except for Google’s keyboard Gboard for Android and more recently Mozilla’s URL suggestion bar. To our knowledge, the present work is the first experiment of its kind on decentralized user speech data.
The federated optimization problem in the context of convex objective functions has been studied in prior work. The authors proposed a stochastic variance-reduced gradient descent optimization procedure (SVRG) with both local and global per-coordinate gradient scaling to improve convergence. Their global per-coordinate gradient averaging strategy relies on a sparsity measure of the given coordinate in users’ local datasets and is only applicable in the context of sparse linear-in-the-features models. The latter assumption does not hold in the context of neural networks for speech-based applications.
Several improvements to the initial FedAvg algorithm have been suggested with a focus on client selection, budget-constrained optimization and upload cost reduction for clients. A dynamic model averaging strategy robust to concept drift, based on a local model divergence criterion, was also recently introduced. While these contributions present efficient strategies to reduce the communication costs inherent to federated optimization, the present work is, as far as we know, the first one introducing a dynamic per-coordinate gradient update in place of the global averaging step.
The next section describes the federated optimization procedure, and how its global averaging step can be seen as a gradient descent update. It is followed by the experiments section, where both the crowdsourced data and model used to train our wake word detector are introduced. Two types of experiments are run: learning a model from scratch and language adaptation from a pre-trained model. A communication cost analysis is also provided. Finally, the next steps towards training a real wake word detector on decentralized user data are described.
3 Federated optimization
The model is initialized with a given architecture and a set of initial weights on a central server. Training then repeats the following steps:
The central model is shared with a subset of users randomly selected from the pool of K online users, given a participation ratio C.
Each selected user performs one or several training steps on their local data using mini-batch stochastic gradient descent (SGD) with a local learning rate. The number of steps performed locally is determined by the number of datapoints available locally, the number of local epochs E and the local batch size B.
The selected users send their model updates back to the parameter server once local training is finished.
The server computes an average model from the users’ individual updates, each update being weighted by that user’s share of the datapoints held by the users selected in the round. It is assumed that, on average, each communication round involves the same total number of datapoints.
One back and forth interaction between the parameter server and the selected users’ devices is called a communication round. In order to obtain empirical convergence, many communication rounds are required, in the same way that one requires many training steps to train a model in a centralized fashion.
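The steps of a communication round can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions rather than the implementation used in the experiments: the function names, the flat-list representation of model weights, and the rounding of the local step count with a ceiling are all choices made for the example.

```python
import math
import random


def select_users(user_ids, participation_ratio, rng):
    """Randomly pick a fraction C of the available users (at least one)."""
    count = max(1, round(participation_ratio * len(user_ids)))
    return rng.sample(user_ids, count)


def local_steps(num_datapoints, epochs, batch_size):
    """Local SGD steps: E passes over ceil(n / B) mini-batches (assumed rounding)."""
    return epochs * max(1, math.ceil(num_datapoints / batch_size))


def weighted_average(user_weights, user_sizes):
    """Average user models, weighting each by its share of the round's datapoints."""
    total = sum(user_sizes)
    dim = len(user_weights[0])
    return [
        sum(w[i] * n for w, n in zip(user_weights, user_sizes)) / total
        for i in range(dim)
    ]


if __name__ == "__main__":
    rng = random.Random(0)
    selected = select_users(list(range(1370)), 0.10, rng)
    print(len(selected), "users selected for this round")
    print(local_steps(num_datapoints=40, epochs=3, batch_size=50), "local steps")
    print(weighted_average([[0.0, 2.0], [4.0, 6.0]], user_sizes=[10, 30]))
```

In the last call, the second user holds three times as many datapoints as the first, so its weights dominate the average.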
When B = ∞ (i.e. the batch size is equal to the local dataset size) and E = 1, a single gradient update is performed on each user’s data. It is strictly equivalent to doing a single gradient computation on a batch including all of the selected users’ data points. This specific case is called FedSGD, i.e. stochastic gradient descent with each batch being the data of the federation of selected users at a given round. FedAvg (federated averaging) is the generic case where more than one update is performed locally for each user. The more local training is performed, the faster the model is expected to converge, assuming the individual models do not diverge too much from one another.
The global averaging step can be rewritten as a global optimization step with a global learning rate. Setting the global learning rate to 1 is equivalent to the weighted averaging case.
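One way to write this step, consistent with the description above, is given below. The notation is an assumption made for illustration: w_{t-1} and w_t are the global weights before and after round t, S_t is the set of selected users, w_{t,k} the weights returned by user k, n_k that user's number of datapoints, and \eta_{global} the global learning rate.

```latex
\begin{align}
G_t &= \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j}\,\bigl(w_{t-1} - w_{t,k}\bigr), \\
w_t &= w_{t-1} - \eta_{\mathrm{global}}\, G_t .
\end{align}
```

Since the weights n_k / \sum_j n_j sum to one, taking \eta_{global} = 1 cancels w_{t-1} and leaves w_t equal to the weighted average of the user models.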
Equation 1 highlights the fact that averaging is a gradient update. Alternatives that have proven successful for deep neural network optimization, such as the Adam optimizer, can be applied in place of the standard averaging case. We empirically show that using the Adam optimization algorithm (Equation 2) in place of the global averaging step smooths the global gradient estimate for a given communication round by taking into account the previous rounds’ updates, which were computed on different user subsets. The exponentially-decayed first and second order moments perform the same kind of regularization that occurs in the mini-batch gradient descent setting, where Adam has proven to be successful on a wide range of tasks with various neural-network based architectures.
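A minimal sketch of such a server-side update is shown below. It applies the standard Adam recursion to the difference between the current global weights and the weighted average of the user models. The class name, the flat-list weight representation and the hyperparameter defaults (those of the original Adam paper) are assumptions for the example, not the settings used in the experiments.

```python
import math


class ServerAdam:
    """Adam applied to the per-round difference between global and averaged user weights."""

    def __init__(self, dim, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * dim  # first-moment estimate, one entry per coordinate
        self.v = [0.0] * dim  # second-moment estimate, one entry per coordinate
        self.t = 0            # number of communication rounds seen so far

    def step(self, weights, averaged_user_weights):
        """Return the new global weights for one communication round."""
        self.t += 1
        new_weights = []
        for i, (w, w_avg) in enumerate(zip(weights, averaged_user_weights)):
            g = w - w_avg  # coordinate of the round's global gradient estimate
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            new_weights.append(w - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_weights
```

Because the moment estimates persist across rounds, an update computed on one subset of users is smoothed by those computed on the subsets selected earlier.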
Unlike generic speech recognition tasks, there is no reference dataset for wake word detection. The reference dataset for multi-class keyword spotting is the speech command dataset, which includes 65k utterances from 1.8k speakers on 37 distinct labels, including silence and unknown word. The speech command task is generally preceded by a wake word detector and is focused on minimizing the confusion across classes rather than on robustness to false alarms, hence the need for a wake word specific dataset.
|Dataset|Statistic|Train Set|Test Set|Total|
|EN Hey Snips|Users|1,370|400|1,770|
|EN Hey Snips|Recordings|53,867|15,217|69,084|
|EN Hey Snips|Positive share|18%|18%|18%|
|FR Hey Snips|Users|201|60|261|
|FR Hey Snips|Recordings|10,101|2,987|13,088|
|FR Hey Snips|Positive share|10.7%|10.1%|10.5%|
The data was collected from thousands of people who recorded themselves on their device with their own microphone while saying several occurrences of the Hey Snips wake word along with randomly chosen negative sentences. Each recorded audio sample has gone through a validation process, where at least two of three distinct contributors have validated that the pronounced utterance matches the transcript. Collection was run through several campaigns on various crowdsourcing platforms. Some contributors took part in most of them while others only participated once, resulting in highly imbalanced user datasets. The statistics about the collected datasets are summarized in Table 1.
The federated optimization procedure differs from a classic distributed optimization procedure in the way the data is distributed across users. We recall these properties below and further specify how they apply to the specific use-case of wake word detection:
Non-i.i.d.: the data residing on each user device is intrinsically not independent and identically distributed (i.i.d.), as it is the data generated in their own usage setting (specific speakers, room configuration, hardware setup, etc.). The crowdsourced datasets (see Table 1) used in the experiments were collected in close field, on the contributor’s device with various microphone qualities and recording environments.
Unbalanced: the amount of data available locally for each user depends on their interactions with the device itself. The crowdsourced datasets are representative of this discrepancy, as the standard deviation of the number of data points available locally is 23 samples for the French dataset and 32 for English.
Massively distributed: the number of clients participating is much larger than the number of examples per client, which is typically true in the wake word setting. This condition also applies to the crowdsourced datasets. For instance, the English training set consists of 1,370 distinct users with 40 local datapoints per user on average.
This crowdsourcing-induced data distribution mimics a real-world distributed setting, and a parallel is therefore drawn in the following between a crowdsourcing contributor and a voice assistant user for the sake of our demonstration.
Acoustic features are generated based on 40-dimensional mel-frequency cepstrum coefficients (MFCC) computed every 10ms over a window of 25ms. The input window consists of 32 stacked frames, symmetrically distributed in left and right contexts (i.e. a 320ms context is used for each prediction). The architecture is a CNN with 5 stacked dilated convolutional layers of increasing dilation rate, followed by two fully-connected layers and a softmax. The total number of parameters is 190,852. The model is trained using cross entropy loss on frame predictions. The neural network has 4 output labels, assigned via a custom aligner specialized on the target utterance “Hey Snips”: “Hey”, “sni”, “ps”, and “filler”. A posterior handling step generates a confidence score for every frame by combining the smoothed label posteriors. The model triggers if the confidence score reaches a certain threshold, defining the operating point that maximizes recall for a certain amount of False Alarms per Hour (FAH). We set the number of false alarms per hour to 5 on a “hard” negative dataset consisting of the remaining crowdsourced utterances, recorded by the same contributors. This usually yields much lower FAH in a practical setting.
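A minimal sketch of how such an operating point could be selected is given below. It assumes one peak confidence score per utterance and a known total duration of negative audio. Both are assumptions for the example, and the function names are hypothetical.

```python
def pick_threshold(negative_scores, negative_hours, max_fah):
    """Lowest threshold keeping false alarms per hour on negative data within max_fah."""
    allowed = int(max_fah * negative_hours)  # number of false alarms tolerated
    ranked = sorted(negative_scores, reverse=True)
    if allowed >= len(ranked):
        return 0.0  # every negative could fire without exceeding the budget
    # The model triggers when score >= threshold, so the threshold must sit
    # just above the highest negative score that is not allowed to fire.
    return ranked[allowed] + 1e-6


def recall(positive_scores, threshold):
    """Fraction of positive utterances whose score reaches the threshold."""
    hits = sum(1 for score in positive_scores if score >= threshold)
    return hits / len(positive_scores)


if __name__ == "__main__":
    threshold = pick_threshold([0.1, 0.2, 0.9, 0.95, 0.3], negative_hours=1.0, max_fah=1)
    print(threshold, recall([0.99, 0.92, 0.5, 0.97], threshold))
```

Choosing the lowest admissible threshold is what maximizes recall for the given false alarm budget, since any higher threshold can only reject more positive utterances.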
We conduct an extensive empirical study of the federated averaging algorithm for the Hey Snips wake word based on crowdsourced data from Table 1 on two distinct tasks. The first one is the task of learning the model from scratch in a federated fashion: training and testing are run on the English dataset. Federated optimization results are compared with a standard setting, i.e. centralized mini-batch SGD with data from users being randomly shuffled. In a realistic scenario, an initial wake word model would be trained in a standard setting on crowdsourced data, and federated optimization would be used to perform domain adaptation on the decentralized in-domain data residing on users’ devices. Domain adaptation is tested through the second experiment, which consists of language adaptation. A base model is trained in a standard manner on the English dataset and fine-tuned on French data in a federated setting. Testing is done on leftover French users.
Our aim is to evaluate the number of communication rounds that are required in order to reach our target metric of 95% recall for 5 False Alarms per hour (FAH). Experiments are summarized in Table 2.
|Learning from scratch|Dataset EN Hey Snips|Standard setting - training from scratch|
|Language adaptation from model trained in standard setting on Dataset EN|Dataset FR Hey Snips|FL setting - training from scratch|
4.3.1 Learning from scratch
Figure 1 shows the standard training results on the English dataset. Performance plateaus at 98% recall for 5 FAH on the test set in 1000 training steps when trained with the Adam optimizer with a batch size of 256. In this setting, the target performance is reached in 300 steps (about 1.5 epochs), while SGD with and without clipping is much slower to converge.
User parallelism: The higher the ratio C of users selected at each round, the more data is seen, and the faster the convergence, assuming that local training does not diverge too much. The aim is to choose a value that is realistic in a practical setup with the constraint that selected users have to be online. Figure 2 shows the impact of C on convergence: the gain of using half of the users is limited compared with using 10%, specifically in the later stages of convergence. A fraction of 10% of users per round is also more realistic in a practical setup. With lower participation ratios, the gradients are much more sensitive and a smoothing strategy such as learning rate decay could be used to ease learning. In the next experiments, C is set to 10%. In this setting, the target of 95% recall per 5 FAH is reached after 142 communication rounds. On average, each user datapoint is seen 14 times during training. In comparison, the same performance is reached within 1.5 epochs in a standard setting.
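The figure of 14 passes per datapoint follows from the participation ratio. The check below assumes users are drawn uniformly at random and that one local epoch is run per round (E = 1, an assumption made for this check):

```latex
\[
\underbrace{142}_{\text{rounds}} \times \underbrace{0.10}_{C} \times \underbrace{1}_{E}
\approx 14.2 \ \text{passes over each datapoint on average.}
\]
```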
Global averaging: Experimentally, standard training with Adam converges much faster (see Figure 1). The same conclusion is drawn in the federated setting (Figure 3): adaptive learning rates based on Adam accelerate convergence when compared with standard averaging strategies with or without moving averages. The standard global averaging strategy with a global learning rate of 1.0 (i.e. no weighted averaging is performed) yields 70% recall per 5 FAH after 400 communication rounds. A much greater number of rounds would be expected to reach the target performance in this setting.
Local training: Picking the local batch size B and the number of epochs E is more challenging, as these two parameters account for the number of local training steps as described in Algorithm 1. Experiments from Table 3 show that doing more local training reduces the number of communication rounds needed to reach target performance. Fastest convergence is obtained for E = 3 and B = 50, yielding a 50% speedup with an average of 4.2 local updates per worker taking part in a round. We also experimented with more local epochs and smaller batch sizes, which mostly caused the training to diverge due to local data overfitting.
Unlike some previously published experiments, the speedup coming from increasing the amount of local training steps does not lead to order of magnitude improvements in convergence speed. This difference can be related to the input’s semantic variability across users. In the MNIST and CIFAR experiments, the input semantics are the same across users. For instance, images of the digit 9 that are attributed to various users are all very similar. In the wake word setting, each user has their own vocalization of the same wake word utterance, with significant differences in pitch and accent. This input discrepancy for the same output label can lead to a great variability in the representation learned by the first dilated convolutional layers, with an increased risk of divergence across users. With diverging lower stage representations, learning a joint semantic of the input by averaging individual representations is therefore harder, hence the limited performance gain obtained from increasing the number of local training steps at each round.
|Type|E|B|u|Nb Rounds|
|FedAvg|3|∞|3.0|122 (1.2 x)|
|FedAvg|1|50|1.4|112 (1.3 x)|
|FedAvg|3|50|4.2|97 (1.5 x)|
|FedAvg|1|20|2.4|114 (1.3 x)|
|FedAvg|3|20|7.2|109 (1.3 x)|
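The average number of local updates per worker reported in the table depends on how the local dataset sizes compare with the batch size. The sketch below shows one way to compute it, assuming each user runs E passes over ceil(n / B) mini-batches. The function name is hypothetical, and `batch_size=None` is used here to stand for a batch covering the whole local dataset.

```python
import math


def average_local_updates(user_sizes, epochs, batch_size=None):
    """Mean number of local SGD steps per participating user."""
    steps = []
    for n in user_sizes:
        # One batch per epoch when the batch covers the whole local dataset.
        batches = 1 if batch_size is None else max(1, math.ceil(n / batch_size))
        steps.append(epochs * batches)
    return sum(steps) / len(steps)


if __name__ == "__main__":
    sizes = [20, 40, 60, 120]  # hypothetical local dataset sizes
    print(average_local_updates(sizes, epochs=3))                 # full-batch case
    print(average_local_updates(sizes, epochs=1, batch_size=50))
    print(average_local_updates(sizes, epochs=3, batch_size=20))
```

With a full local batch the result is exactly E whatever the dataset sizes, which matches the 3.0 local updates listed for the first row of the table.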
A dynamic model averaging strategy that prevents the local models from diverging too much from the base model could ease averaging. Smaller local learning rates did not prove to work better, as they strongly slow down early stage convergence. While local learning rate decay could greatly help in that matter, it might be hard to configure properly.
4.3.2 Domain adaptation
In this second experiment, the task of domain adaptation to French speakers is studied, starting from an existing “Hey Snips” model trained on English data.
|Type|E|B|u|Nb Rounds|Nb Rounds|
|FedAvg|3|∞|3.0|100 (1.5 x)|184 (1.3 x)|
|FedAvg|1|50|1.4|98 (1.6 x)|208 (1.2 x)|
|FedAvg|3|50|4.6|139 (1.1 x)|120 (2.0 x)|
|FedAvg|1|20|2.9|104 (1.5 x)|130 (1.8 x)|
|FedAvg|3|20|8.8|119 (1.3 x)|262 (0.9 x)|
Results from Table 4 show the same speedup of 50% that was highlighted in experiment 1 when using FedAvg compared with standard FedSGD. Transfer learning using pre-training slows down convergence on the French target distribution. Two potential explanations can be put forward to account for this result. First, considering the model size constraints, this is likely due to the fact that our model architecture does not significantly over-parametrize the objective function, in which case pre-training can be equivalent to a bad weight initialization strategy. The second explanation is related to the transfer learning strategy used. In the wake word detection task, the semantics of the output are the same across languages while those of the input vary widely, since vocalizations for the same phonemes can be very different. One countermeasure suggested in the literature is to only share the upper layers and learn lower stage representations from scratch for each task.
4.4 Communication cost analysis
Communication cost is a strong constraint when learning from decentralized data, especially when users’ devices have limited connectivity and bandwidth.
Considering the asymmetrical nature of broadband speeds, the communication bottleneck for federated learning is the data transfer from clients to the parameter server once local training is completed. The total client upload bandwidth requirement is provided in the equation below:
Based on our results for experiment 1, this would yield a cost of 8MB per client for both experiments. On its end, the server receives 137 updates per round when C = 10%, amounting to 110GB over the course of the whole optimization process with 1.4k users involved.
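The per-client figure can be checked with a short calculation. The sketch below assumes that weights are uploaded as 32-bit floats and that training stops after 100 communication rounds. Both values are assumptions for this check, as is the function name. The parameter count and participation ratio are those given in the text.

```python
NUM_PARAMETERS = 190_852    # model size stated in the text
BYTES_PER_PARAMETER = 4     # assumption: 32-bit floating point weights
PARTICIPATION_RATIO = 0.10  # C, fraction of users selected per round
NUM_ROUNDS = 100            # assumption: stopping criterion of 100 rounds


def client_upload_megabytes(num_parameters, bytes_per_parameter,
                            participation_ratio, num_rounds):
    """Expected upload per client: model size times expected number of participations."""
    model_megabytes = num_parameters * bytes_per_parameter / 1e6
    return model_megabytes * participation_ratio * num_rounds


if __name__ == "__main__":
    cost = client_upload_megabytes(NUM_PARAMETERS, BYTES_PER_PARAMETER,
                                   PARTICIPATION_RATIO, NUM_ROUNDS)
    print(f"{cost:.1f} MB uploaded per client")
```

Under these assumptions each client uploads roughly 7.6 MB over the course of training, in line with the 8MB quoted above.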
5 Conclusion and future work
In this work, we investigate the use of federated learning on crowdsourced speech data to learn a resource-constrained wake word detector. We show that a revisited Federated Averaging algorithm with per-coordinate Adam update in place of standard global averaging allows the training to converge to the target metric of 95% recall per 5 FAH within 100 communication rounds on English and French datasets. The associated upstream communication costs per client are estimated at 8MB in our chosen setting.
The next step towards a real-life implementation is to figure out a strategy for local data collection and labeling. The challenge lies in building unbiased local datasets, as it is crucial to avoid the pitfall of selection bias. Collection cannot rely solely on true and false positives: an appropriate strategy for false negatives is also needed. A dedicated application is being developed specifically for this purpose. Semi-supervised approaches for labeling are also studied in order to reduce the need for the user’s intervention.
-  Diederik P. Kingma and Jimmy Ba, “Adam: A method for stochastic optimization,” CoRR, vol. abs/1412.6980, 2014.
-  Alice Coucke, Alaa Saade, Adrien Ball, Théodore Bluche, Alexandre Caulier, David Leroy, Clément Doumouro, Thibault Gisselbrecht, Francesco Caltagirone, Thibaut Lavril, Maël Primet, and Joseph Dureau, “Snips voice platform: an embedded spoken language understanding system for private-by-design voice interfaces,” CoRR, vol. abs/1805.10190, 2018.
-  H. Brendan McMahan, Eider Moore, Daniel Ramage, and Blaise Agüera y Arcas, “Federated learning of deep networks using model averaging,” CoRR, vol. abs/1602.05629, 2016.
-  Daniel Povey, Xiaohui Zhang, and Sanjeev Khudanpur, “Parallel training of deep neural networks with natural gradient and parameter averaging,” CoRR, vol. abs/1410.7455, 2014.
-  Brendan McMahan and Daniel Ramage, “Federated learning: Collaborative machine learning without centralized training data,” https://ai.googleblog.com/2017/04/federated-learning-collaborative.html, 2017.
-  Florian Hartmann, “Federated learning,” https://florian.github.io/federated-learning/, 2018.
-  Jakub Konecný, H. Brendan McMahan, Daniel Ramage, and Peter Richtárik, “Federated optimization: Distributed machine learning for on-device intelligence,” CoRR, vol. abs/1610.02527, 2016.
-  Takayuki Nishio and Ryo Yonetani, “Client selection for federated learning with heterogeneous resources in mobile edge,” CoRR, vol. abs/1804.08333, 2018.
-  Shiqiang Wang, Tiffany Tuor, Theodoros Salonidis, Kin K. Leung, Christian Makaya, Ting He, and Kevin Chan, “When edge meets learning: Adaptive control for resource-constrained distributed machine learning,” CoRR, vol. abs/1804.05271, 2018.
-  Jakub Konecný, H. Brendan McMahan, Felix X. Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon, “Federated learning: Strategies for improving communication efficiency,” CoRR, vol. abs/1610.05492, 2016.
-  Michael Kamp, Linara Adilova, Joachim Sicking, Fabian Hüger, Peter Schlicht, Tim Wirtz, and Stefan Wrobel, “Efficient decentralized deep learning by dynamic model averaging,” CoRR, vol. abs/1807.03210, 2018.
-  Pete Warden, “Speech commands: A dataset for limited-vocabulary speech recognition,” CoRR, vol. abs/1804.03209, 2018.
-  Ian Goodfellow, Yoshua Bengio, and Aaron Courville, Deep Learning, MIT Press, 2016, http://www.deeplearningbook.org.