Adam with Bandit Sampling for Deep Learning
Abstract
Adam is a widely used optimization method for training deep learning models. It computes individual adaptive learning rates for different parameters. In this paper, we propose a generalization of Adam, called Adambs, that allows us to also adapt to different training examples based on their importance in the model's convergence. To achieve this, we maintain a distribution over all examples, selecting a minibatch in each iteration by sampling according to this distribution, which we update using a multi-armed bandit algorithm. This ensures that examples that are more beneficial to the model training are sampled with higher probabilities. We theoretically show that Adambs improves the convergence rate of Adam, achieving O(√(log n/T)) instead of O(√(n/T)) in some cases. Experiments on various models and datasets demonstrate Adambs's fast convergence in practice.
1 Introduction
Stochastic gradient descent (SGD) is a popular optimization method, which iteratively updates the model parameters by moving them in the direction of the negative gradient of the loss evaluated on a minibatch. However, standard SGD does not have the ability to use past gradients or adapt to individual parameters (a.k.a. features). Some variants of SGD, such as AdaGrad [16], RMSprop [34], AdaDelta [36], or Nadam [15], can exploit past gradients or adapt to individual features. Adam [19] combines the advantages of these SGD variants: it uses momentum on past gradients, but also computes adaptive learning rates for each individual parameter by estimating the first and second moments of the gradients. This adaptive behavior is quite beneficial, as different parameters might be of different importance in terms of the convergence of the model training. In fact, by adapting to different parameters, Adam has been shown to outperform its competitors in various applications [19], and as such, has gained significant popularity.
However, another form of adaptivity that has proven beneficial in the context of basic SGD variants is adaptivity with respect to different examples in the training set [18, 6, 8, 38, 13]. For the first time, to the best of our knowledge, in this paper we show that accounting for the varying importance of training examples can even improve Adam's performance and convergence rate. In other words, Adam has only considered the varying importance among parameters but not among the training examples. Although some prior work exploits importance sampling to improve SGD's convergence rate [18, 6, 8, 38, 13], it often relies on special properties of particular models to estimate a sampling distribution over examples in advance. As a more general approach, we focus on learning the distribution during the training procedure, which is more adaptive. This is key to achieving a faster convergence rate, because the sampling distribution typically changes from one iteration to the next, depending on the relationship between training examples and model parameters at each iteration.
In this paper, we propose a new general optimization method based on bandit sampling, called Adambs (Adam with Bandit Sampling), that endows Adam with the ability to adapt to different examples. To achieve this goal, we maintain a distribution over all examples, representing their relative importance to the overall convergence. In Adam, at each iteration, a minibatch is selected uniformly at random from training examples. In contrast, we select our minibatch according to our maintained distribution. We then use this minibatch to compute the first and second moments of the gradient and update the model parameters, in the same way as Adam. While seemingly simple, this process introduces another challenge: how to efficiently obtain this distribution and update it at each iteration? Ideally, to obtain the optimal distribution, at each iteration one would need to compute the gradients for all training examples. However, in Adam, we only have access to gradients of the examples in the selected minibatch. Since we only have access to partial feedback, we use a multiarmed bandit method to address this challenge. Our idea is illustrated in Figure 1.
Specifically, we cast the process of learning an optimal distribution as an adversarial multi-armed bandit problem. We use a multi-armed bandit method to update the distribution over all of the training examples, based on the partial information gained from the minibatch at each iteration. We use the EXP3 method [3], but extend it to allow for sampling multiple actions at each step. The original EXP3 method only samples one action at a time, and collects the partial feedback by observing the loss incurred by that single action. In contrast, in optimization frameworks such as Adam, we need to sample a minibatch, which typically contains more than one example. We thus need a bandit method that samples multiple actions and observes the loss incurred by each of those actions. In this paper, we extend EXP3 to use feedback from multiple actions and update its distribution accordingly. Although ideas similar to bandit sampling have been applied to some SGD variants and coordinate descent methods [30, 29, 24], extending this idea to Adam is not straightforward, as Adam is considerably more complex, due to its momentum mechanism and parameter adaptivity. To the best of our knowledge, we are the first to propose and analyze the improvement of using bandit sampling for Adam. Maintaining and updating a distribution over all training examples incurs some computational overhead. We will show an efficient way to update this distribution, which has time complexity logarithmic in the total number of training examples. With this efficient update, the per-iteration cost is dominated by gradient computation, whose time complexity depends on the minibatch size.
To endow Adam with the ability to adapt to different examples while keeping its original structure, we interleave our bandit method with Adam's original parameter update, except that we select the minibatch according to our maintained distribution. Adambs therefore adapts to both different parameters and different examples. We provide a theoretical analysis of this new method showing that our bandit sampling does indeed improve Adam's convergence rate. Through an extensive empirical study across various optimization tasks and datasets, we also show that this new method yields significant speedups in practice.
2 Related Work
Boosting and bandit methods. The idea of taking advantage of the difference among training examples has been utilized in many boosting algorithms [32, 33, 31, 22, 12]. The well-known AdaBoost algorithm [32] builds an ensemble of base classifiers iteratively, of which each base classifier is trained on the same set of training examples with adjusted weights. Because of the way AdaBoost adjusts weights on training examples, it is able to focus on examples that are hard, thus decreasing the training error very quickly. In addition, it has often been observed in experiments that AdaBoost has very good generalization ability, which is discussed and analyzed in several works [33, 28, 22]. Both AdaBoost and our method aim to improve training by iteratively adjusting example weights. However, the amount of available information is very different every time they adjust example weights. AdaBoost receives full information, in the sense that each training example needs to run through the up-to-date ensemble model to determine which examples are still misclassified. Our method only receives partial information because we can only select a minibatch at each iteration, which brings up the tradeoff between exploration (i.e., select other examples to get more information) and exploitation (i.e., select the empirically best examples based on already collected information). The multi-armed bandit problem is a general setting for studying the exploration-exploitation tradeoff that also appears in many other cases [4, 3]. For example, it has been applied to speed up maximum inner product search when only a subset of vector coordinates can be selected for floating-point multiplication at each round [23].
Importance sampling methods. Importance sampling for convex optimization problems has been extensively studied over the last few years. [27] proposed a generalized coordinate descent algorithm that samples coordinate sets to optimize the algorithm's convergence rate. More recent works [38, 25] discuss the variance of the gradient estimates of stochastic gradient descent and show that the optimal sampling distribution is proportional to the per-sample gradient norm. [24] proposed an adaptive sampling method for both block coordinate descent and stochastic gradient descent. For coordinate descent, the parameters are partitioned into several prefixed blocks; for stochastic gradient descent, the training examples are partitioned into several prefixed batches. However, it is difficult to determine an effective way to partition blocks of parameters or batches of examples. [18] proposed to sample a big batch at every iteration, compute a distribution based on the gradient norms of the examples in the big batch, and then sample a minibatch from the big batch for the parameter update. However, it is unclear how much speedup their method can achieve in terms of theoretical convergence rate.
Other sample selection methods. Several strategies have been proposed to carefully select minibatches in order to improve the training of deep learning models. Curriculum learning [5, 17] is another optimization strategy that leverages a pretrained teacher model to train the target model. The teacher model is critical in determining how to assign examples to minibatches. In this paper, we focus on the case where we do not have access to an extra teacher model. However, utilizing a teacher model is likely to further improve the performance of our method. For example, it can be used to initialize the example weights, which can help the bandit method learn the weights more quickly. In addition, Locality-Sensitive Hashing (LSH) has been used to improve the convergence rate of SGD by adaptively selecting examples [9]. It is worth noting that a recent paper [26] points out that Adam's rapid decay of the learning rate, using the exponential moving averages of squared past gradients, essentially limits the reliance of the update to only the past few gradients. This prevents Adam from converging in some cases. They proposed a variant of Adam, called AMSGrad, with long-term memory of past gradients. In the main text of this paper, we remain focused on the original Adam. Similar analysis could be carried over to AMSGrad, which is discussed in the appendix. We note that Adam with learning-rate warmup (another variant of Adam) is the common practice in training transformer models for NLP tasks [14, 11]. However, due to the lack of well-studied theoretical analysis of this variant in the literature, we still base our discussion on the original Adam.
3 Preliminaries about Adam
We consider the following convex optimization problem: min_θ f(θ), where f is a scalar-valued objective function that needs to be minimized, and θ is the parameter of the model. Let the gradient of f with respect to θ be denoted as ∇f(θ). Assuming the training dataset is of size n, we have ∇f(θ) = (1/n) Σ_{i=1}^n ∇f_i(θ), where ∇f_i(θ) is the gradient computed with only the i-th example. Furthermore, at each iteration t, we select a minibatch of m examples from the whole training set. We denote the realization of f with respect to the minibatch selected at iteration t as f_t, and the gradient of f_t with respect to θ_t as g_t. Depending on the sampling strategy of a minibatch, in some cases, g_t could be a biased estimate of ∇f(θ_t), meaning E[g_t] ≠ ∇f(θ_t). However, an unbiased estimate is required to update model parameters in stochastic optimization methods such as SGD and Adam. In such cases, we need to form an unbiased estimate ĝ_t, ensuring that E[ĝ_t] = ∇f(θ_t). When a minibatch is selected by sampling uniformly from all of the training examples, we have E[g_t] = ∇f(θ_t), thus allowing ĝ_t to simply be g_t.
Adam [19] selects every minibatch by uniform sampling and updates exponential moving averages of the gradient and the squared gradient with hyperparameters β1, β2 ∈ [0, 1), which control the exponential decay rates of these moving averages: m_t = β1 m_{t−1} + (1 − β1) g_t and v_t = β2 v_{t−1} + (1 − β2) g_t², where g_t² indicates the element-wise square of g_t. The moving averages m_t and v_t are estimates of the 1st moment (the mean) and 2nd raw moment (the uncentered variance) of the gradient. These moment estimates are biased toward zero and are then corrected, resulting in bias-corrected estimates m̂_t = m_t/(1 − β1^t) and v̂_t = v_t/(1 − β2^t), where β1^t and β2^t are β1 and β2 raised to the power t, respectively. Next, the parameter is updated according to θ_{t+1} = θ_t − α m̂_t/(√v̂_t + ε), where ε is a small value, to avoid division by zero.
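The Adam update described above can be written compactly. The following is a minimal sketch in numpy (the function name `adam_step` and the toy usage are ours, not from the paper):

```python
import numpy as np

def adam_step(theta, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update, following the moment estimates described above.

    theta: parameter vector; grad: gradient at theta; m, v: running first
    and second moment estimates; t: 1-based iteration counter.
    """
    m = beta1 * m + (1 - beta1) * grad           # first moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2      # second raw moment estimate
    m_hat = m / (1 - beta1 ** t)                 # bias correction toward zero
    v_hat = v / (1 - beta2 ** t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```

For instance, iterating this step on f(θ) = θ² (gradient 2θ) drives θ toward 0, with the effective step size bounded by roughly α per iteration.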
A flexible framework to analyze iterative optimization methods such as Adam is the online learning framework. In this online setup, at each iteration t, the optimization algorithm picks a point θ_t. A loss function f_t is then revealed based on the selected minibatch, and the algorithm incurs loss f_t(θ_t). At the end of T iterations, the algorithm's regret is given by R(T) = Σ_{t=1}^T f_t(θ_t) − min_θ Σ_{t=1}^T f_t(θ). In order for any optimization method to converge, it is necessary to ensure that R(T)/T → 0. For Adam, the convergence rate is summarized in Theorem 4.1 from [19]. Under further assumptions as in Corollary 4.2 from [19], it can be shown that R(T)/T = O(1/√T). In this paper, we propose to endow Adam with bandit sampling, which can further improve the convergence rate under some assumptions.
4 Adam with Bandit Sampling
4.1 Adaptive MiniBatch Selection
Suppose there are n training examples. At iteration t, a minibatch of size m is selected by sampling with replacement from all of the training examples, according to a distribution p_t = (p_t(1), ..., p_t(n)). Here, p_t represents the relative importance of each example during the model training procedure. Denote the indices of examples selected in the minibatch as the set B_t. Assume the gradient computed with respect to only the i-th example is ∇f_i(θ_t). Its unbiased estimate is ∇f_i(θ_t)/(n p_t(i)). We can easily verify that it is unbiased because E_{i∼p_t}[∇f_i(θ_t)/(n p_t(i))] = Σ_i p_t(i) · ∇f_i(θ_t)/(n p_t(i)) = (1/n) Σ_i ∇f_i(θ_t) = ∇f(θ_t). Therefore, we define the unbiased gradient estimate according to batch B_t at iteration t as
ĝ_t = (1/m) Σ_{i ∈ B_t} ∇f_i(θ_t) / (n p_t(i))    (1)
Similarly, we can verify that E[ĝ_t] = ∇f(θ_t).
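The unbiasedness of this importance-weighted estimator can be checked exactly for a small synthetic problem. The sketch below (all names and random data are ours) computes the expectation of the single-sample reweighted gradient in closed form and confirms it equals the full-batch gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m = 8, 3, 4                              # examples, parameter dim, batch size

per_example_grads = rng.normal(size=(n, d))    # stand-ins for grad f_i(theta)
full_grad = per_example_grads.mean(axis=0)     # (1/n) sum_i grad f_i(theta)

p = rng.random(n)                              # an arbitrary sampling distribution p_t
p /= p.sum()

def estimate(batch):
    # Eq. (1): (1/m) * sum_{i in batch} grad f_i / (n * p_i)
    return per_example_grads[batch].T @ (1.0 / (n * p[batch])) / m

# Exact expectation over a single draw i ~ p equals the full gradient:
expected = sum(p[i] * per_example_grads[i] / (n * p[i]) for i in range(n))
assert np.allclose(expected, full_grad)
```

Note the check holds for any strictly positive p, which is why the bandit method is free to reshape the distribution without biasing the parameter updates.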
It is worth noting that the gradient estimate ĝ_t defined for Adambs is different from that of Adam. This is because the sampling strategy is different, and appropriate bias correction is necessary here. We use ĝ_t, defined in Equation 1, to update the first moment estimate and the second moment estimate in each iteration. Specifically, let m_t and v_t be the first and second moment estimates at iteration t, respectively. Then we update them in the following way
m_t = β1 m_{t−1} + (1 − β1) ĝ_t,    v_t = β2 v_{t−1} + (1 − β2) ĝ_t²    (2)
where β1, β2 ∈ [0, 1) are hyperparameters that control the exponential decay rates of these moving averages. Our new method Adambs is described in Algorithm 1. Details about the distribution-update function in Algorithm 1 are given in the next subsection.
Our method maintains a fine-grained probability distribution over all examples. This provides more flexibility in choosing minibatches than prior work that uses a coarse-grained probability distribution over prefixed minibatches [24], because it is generally hard to decide how to partition the batches for prefixed minibatches. If the training set is partitioned randomly, any minibatch is likely to contain some important examples and some unimportant examples, making any two minibatches equally good. In this case, prioritizing one minibatch over another will not bring any advantage. It requires a fair amount of time to preprocess the training dataset and partition it in a good way, especially when the dataset is large. Some might argue that we could simply set the batch size to one in [24]. While the issue of batch partitioning would then no longer exist, this would significantly hurt the convergence rate because only one example is processed at each iteration. In contrast, Adambs does not require prepartitioned minibatches. At every iteration, a new minibatch is formed dynamically by sampling from the whole dataset according to the distribution p_t. Here, the distribution is learned so that important examples can be selected into one minibatch with high probability. Thus, it is more likely to get a minibatch with all important examples, which can significantly boost the training performance.
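To make the interleaving of the Adam update and the distribution update concrete, here is a minimal sketch of the overall loop. All names are ours; `grad_fn` is a hypothetical per-example gradient callback, and the p-update is a simplified multiplicative-weights step, not the exact projected update of Algorithm 2:

```python
import numpy as np

def adambs(grad_fn, theta, n, m=16, T=200, alpha=0.05,
           beta1=0.9, beta2=0.999, eps=1e-8, eta=0.01, gamma=0.01, seed=0):
    """Sketch of Adambs: Adam interleaved with a bandit-learned sampling
    distribution p over the n training examples. grad_fn(i, theta) returns
    the gradient for example i (hypothetical API)."""
    rng = np.random.default_rng(seed)
    p = np.full(n, 1.0 / n)                      # importance distribution
    mom1 = np.zeros_like(theta)
    mom2 = np.zeros_like(theta)
    for t in range(1, T + 1):
        batch = rng.choice(n, size=m, replace=True, p=p)
        grads = np.stack([grad_fn(i, theta) for i in batch])
        # Unbiased gradient estimate (Eq. 1): reweight by 1 / (n * p_i).
        g_hat = (grads / (n * p[batch])[:, None]).mean(axis=0)
        # Standard Adam moment updates with bias-corrected step.
        mom1 = beta1 * mom1 + (1 - beta1) * g_hat
        mom2 = beta2 * mom2 + (1 - beta2) * g_hat ** 2
        m_hat = mom1 / (1 - beta1 ** t)
        v_hat = mom2 / (1 - beta2 ** t)
        theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
        # Bandit-style feedback: upweight sampled examples with large gradients.
        p[batch] *= np.exp(eta * np.linalg.norm(grads, axis=1))
        p /= p.sum()
        p = (1 - gamma) * p + gamma / n          # mix with uniform to explore
    return theta, p
```

The mixing with the uniform distribution keeps every example's probability strictly positive, so the importance weights 1/(n p_i) stay bounded and every example remains reachable.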
We analyze the convergence of Adambs in Algorithm 1 using the same online learning framework [39] that is used by Adam. The following theorem characterizes its convergence rate.
Theorem 1.
Assume that the gradient estimate is bounded, ‖ĝ_t‖∞ ≤ G∞, and the distance between any two iterates is bounded, ‖θ_m − θ_n‖∞ ≤ D∞ for any m, n ∈ {1, ..., T}, and β1, β2 ∈ [0, 1) satisfy β1²/√β2 < 1. Let α_t = α/√t, β1,t = β1 λ^{t−1}, and λ ∈ (0, 1). Adambs achieves the following convergence rate, for all T ≥ 1,
(3) 
where d is the dimension of the parameter space and
(4) 
4.2 Update of Distribution
From Equation 3, we can see that p_t affects the convergence rate. We wish to choose p_t so as to achieve a faster convergence rate. We derive how to update p_t by minimizing the right side of Equation 3; specifically, we want to minimize the term of the bound that depends on p_t. It can be shown that for every iteration t, the optimal distribution is proportional to the gradient norm of the individual example [25, 2]. Formally speaking, for any t, the optimal solution is p_t(i) ∝ ‖∇f_i(θ_t)‖. It is computationally prohibitive to obtain the optimal solution, because we would need to compute the gradient norm of every example at each iteration. Instead, we use a multi-armed bandit method to learn this distribution from the partial information that is available during training. The multi-armed bandit method maintains the distribution over all examples, and keeps updating this distribution at every training iteration. The partial information that we have at every iteration is the gradients of the examples in the minibatch. We use a bandit method based on EXP3 [3] but extended to handle multiple actions at every iteration. The pseudocode is described in Algorithm 2.
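The claim that sampling proportionally to the gradient norm is optimal can be checked numerically: for a single-draw importance-weighted estimator, the second moment E_{i∼p} ‖g_i/(n p_i)‖² = Σ_i ‖g_i‖²/(n² p_i) is minimized at p_i ∝ ‖g_i‖ (by Cauchy-Schwarz). A small sketch with synthetic gradients (all names and data are ours):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 50, 5
# Synthetic per-example gradients with widely varying norms.
g = rng.normal(size=(n, d)) * rng.lognormal(size=(n, 1))

def second_moment(p):
    # E_{i~p} || g_i / (n p_i) ||^2  =  sum_i ||g_i||^2 / (n^2 p_i)
    return np.sum((g ** 2).sum(axis=1) / (n ** 2 * p))

norms = np.linalg.norm(g, axis=1)
p_opt = norms / norms.sum()              # claimed optimum: proportional to norm
p_uni = np.full(n, 1.0 / n)              # uniform sampling (Adam's choice)
p_rnd = rng.random(n); p_rnd /= p_rnd.sum()

assert second_moment(p_opt) <= second_moment(p_uni)
assert second_moment(p_opt) <= second_moment(p_rnd)
```

Since the estimator is unbiased under any strictly positive p, minimizing this second moment is equivalent to minimizing the variance of the gradient estimate.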
To further illustrate our distribution update algorithm from the perspective of the bandit setting, the number of arms is n, where each arm corresponds to a training example, and the loss of pulling an arm is defined in Algorithm 2. We denote by G the upper bound on the per-example gradient norm, i.e., ‖∇f_i(θ)‖ ≤ G. A similar upper bound is commonly used in the related literature [19, 24], because of the popular gradient-clipping trick [37]. Each time we update the distribution, we only pull the arms from the set B_t, which is the minibatch of training examples at the current iteration. The projection step in Algorithm 2 uses the KL divergence between distributions, over a constraint set that keeps every probability bounded below by a constant. From the definition of the loss in Algorithm 2, we can see that the loss is always nonnegative, and is inversely correlated with the gradient norm. This implies that an example with a small gradient norm will receive a large loss value, resulting in its weight getting decreased. Thus, examples with large gradient norms will be sampled with higher probabilities in subsequent iterations. Using a segment tree structure to store the weights of all the examples, we are able to update p_t in O(log n) time [24]. With this efficient distribution update, the per-iteration cost is dominated by the computation of gradients, especially for large models in deep learning. For simplicity, our theoretical analysis is still focused on the convergence rate with respect to the number of iterations. In our experiments, we demonstrate that our method also achieves faster convergence with respect to wall-clock time.
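The segment-tree trick mentioned above works as follows: store per-example weights at the leaves and partial sums at internal nodes, so that both changing one weight and drawing an index proportionally to weight cost O(log n). A minimal sketch (class and method names are ours, not from the paper):

```python
import numpy as np

class SumTree:
    """Segment tree over example weights: O(log n) weight updates and
    O(log n) sampling of an index proportionally to its weight."""

    def __init__(self, weights):
        self.n = len(weights)
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.tree = np.zeros(2 * self.size)
        self.tree[self.size:self.size + self.n] = weights
        for i in range(self.size - 1, 0, -1):       # build internal sums
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, i, w):
        """Set the weight of example i to w, refreshing ancestors' sums."""
        i += self.size
        self.tree[i] = w
        while i > 1:
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def sample(self, u):
        """Map u in [0, 1) to an index drawn with probability w_i / sum_w."""
        target = u * self.tree[1]                   # tree[1] holds the total
        i = 1
        while i < self.size:                        # descend toward a leaf
            i *= 2
            if target >= self.tree[i]:              # skip the left subtree
                target -= self.tree[i]
                i += 1
        return i - self.size
```

For example, with weights [1, 2, 3, 4] (total 10), `sample(0.25)` lands in the slice [1, 3) owned by the second example and returns index 1.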
Let D_ψ(p, q) be the Bregman divergence associated with the potential ψ. By choosing this form of Bregman divergence, we can study our EXP3-based distribution update algorithm under the general framework of online mirror descent with bandit feedback [7]. We have the following lemma regarding the convergence rate of Algorithm 2.
Lemma 1.
Assume that for all and , and for any , if we set , the update in Algorithm 2 implies
(5) 
Theorem 2.
Obviously, we have that, for Adambs, R(T)/T → 0 as T → ∞, ensuring that it can converge.
4.3 Comparison with Uniform Sampling
We now study a setting where we can further bound Equation 6 to show that the convergence rate of Adambs is provably better than that of Adam, which uses uniform sampling. For simplicity, we consider a neural network with one hidden layer for a binary classification problem. The hidden layer contains one neuron with ReLU activation, and the output layer also contains one neuron with sigmoid activation. Cross-entropy loss is used for this binary classification problem. The total loss of this neural network model is the average of the per-example cross-entropy losses, where σ is the sigmoid activation function, relu is the ReLU activation function, and y_i and x_i are the label and the feature vector of the i-th example, respectively. We further assume that the feature vector follows a doubly heavy-tailed distribution.
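For concreteness, the per-example loss of this one-hidden-neuron network can be sketched as follows (the function name and parameterization theta = (w1, w2) are ours; this is only an illustration of the model described above):

```python
import numpy as np

def example_loss(theta, x, y):
    """Cross-entropy loss of the one-hidden-neuron network from the text:
    prediction = sigmoid(w2 * relu(w1 . x)), with label y in {0, 1}."""
    w1, w2 = theta
    h = max(0.0, float(np.dot(w1, x)))        # single ReLU hidden unit
    pred = 1.0 / (1.0 + np.exp(-w2 * h))      # sigmoid output neuron
    pred = min(max(pred, 1e-12), 1 - 1e-12)   # clip for numerical safety
    return -(y * np.log(pred) + (1 - y) * np.log(1 - pred))
```

Note that when the ReLU is inactive (w1 . x ≤ 0), the prediction is exactly 0.5 and the loss is log 2 for either label, so such examples contribute nothing to the gradient, which is one source of the variation in per-example gradient norms that Adambs exploits.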
Lemma 2.
If the examples are sampled from the uniform distribution, i.e., p_t(i) = 1/n for all i, assuming that the feature vector follows a doubly heavy-tailed distribution, for the aforementioned neural network model, we have the following.
Following Lemma 2, we have
Theorem 3.
Assuming that the feature vector follows a doubly heavy-tailed distribution, for the aforementioned neural network model, the original Adam achieves the following rate
(7) 
On the other hand, we have
Lemma 3.
Assuming that the feature vector follows a doubly heavy-tailed distribution, for the aforementioned neural network model, we have
(8) 
Theorem 4.
Assuming that the feature vector follows a doubly heavy-tailed distribution, for the aforementioned neural network model, Adambs achieves the following rate
(9) 
Comparing the second terms in Equations 7 and 9, we see that our Adambs converges faster than Adam. The convergence rates are summarized in Table 1. In this table, we also compare against the adaptive sampling methods from [24], which maintain a distribution over prefixed batches. Adam with adaptive sampling over prefixed batches is called Adamapt, and Adam with uniform sampling over prefixed batches is called Adamuni. We can see that our method Adambs achieves a faster convergence rate than the others. Depending on constants not shown in the big-O notation, it is also possible that the convergence rate is dominated by the first term, which would make our improvement marginal. We therefore also rely on the experiments in the next section to demonstrate our method's faster convergence rate in practice.
Algorithm  Convergence Rate  Algorithm  Convergence Rate 

Adambs  Adamapt  
Adam  Adamuni 
5 Experiments
5.1 Setup
To empirically evaluate the proposed method, we investigate several popular deep learning models.
We use the same parameter initialization when comparing different optimization methods.
In total, five datasets are used: MNIST, Fashion MNIST, CIFAR10, CIFAR100, and IMDB. We run experiments on these datasets because they are benchmark datasets commonly used to compare optimization methods for deep learning [19, 26, 18].
It is worth noting that the importance sampling method proposed in [18] could also be applied to Adam.
In addition, they proposed an efficient way to upper bound the perexample gradient norm for neural networks to compute the distribution for importance sampling.
This could also be beneficial to our method, because the upper bound of gradient norm could be used in place of the gradient norm itself to update our distribution.
In the experiments, we compare our method against Adam and Adam with importance sampling (as described in [18], which we call Adamimpt).
To be fair, we use the upper bound in the place of perexample gradient norm in our method.
All the previous analysis also holds if an upper bound of the gradient norm is used, because such a bound still upper bounds the per-example gradient norm in our theorems.
Experiments are conducted using Keras [10] with TensorFlow [1] based on the code from [18].
To see if our method can accelerate the training procedure, we plot the curves of training loss value vs. wall-clock time for these three methods.
5.2 Convolutional Neural Networks
Convolutional neural networks (CNN) with several layers of convolution, pooling, and nonlinear units have shown considerable success in computer vision tasks. We train CNN models on three different datasets: CIFAR10, CIFAR100, and Fashion MNIST. CIFAR10 and CIFAR100 are labeled subsets of the 80 million tiny images dataset [20]. CIFAR10 consists of 60,000 32x32 color images in 10 classes with 6,000 images per class, whereas CIFAR100 consists of 60,000 32x32 color images in 100 classes with 600 images per class. The Fashion MNIST dataset [35] is similar to the MNIST dataset except that images belong to 10 fashion categories.
For CIFAR10 and CIFAR100, our CNN architecture has several layers of convolution filters and max pooling. Dropout is applied to the 2nd and 4th convolutional layers. This is then followed by a fully connected layer of hidden units. For Fashion MNIST, since the dataset is simpler, we use a simpler CNN model: it contains convolutional layers with max pooling applied after the 2nd convolutional layer, followed by a fully connected layer of hidden units. The minibatch size and learning rate are set to the same values for all methods on all three datasets. All three methods are used to train the CNN models for the same number of epochs, and the results are shown in Figure 2. For CIFAR10 and CIFAR100, we can see that our method Adambs achieves a lower loss value than the others very quickly, within a few epochs. For Fashion MNIST, our method Adambs is worse than the others at the very beginning, but keeps decreasing the loss value at a faster rate, and is eventually able to achieve a lower loss value than the others.
5.3 Recurrent Neural Networks
Recurrent neural networks, such as LSTM and GRU, are popular models for learning from sequential data. To showcase the generality of our method, we use Adambs to accelerate the training of RNN models on image classification problems, where image pixels are fed to the RNN as a sequence. Specifically, we train an LSTM with ReLU as the recurrent activation function, followed by a fully connected layer to predict the image class. We use two datasets: MNIST and Fashion MNIST. The batch size, learning rate, and maximum number of epochs are set to the same values for all methods on both datasets. The results are shown in Figure 4. Adambs was able to quickly achieve a lower loss value than the others.
5.4 Recurrent Convolutional Neural Networks
Recently, it has been shown that RNNs combined with CNNs can achieve good performance on some NLP tasks [21]. This new architecture is called the recurrent convolutional neural network (RCNN). We train an RCNN model for the sentiment classification task on an IMDB movie review dataset. It contains movie reviews from IMDB, labeled by sentiment (positive or negative). Reviews are encoded as sequences of word indexes. On this dataset, we train an RCNN model, which consists of a convolutional layer and a max-pooling layer, followed by an LSTM and a fully connected layer. We set the batch size and learning rate to the same values for all methods and run them for the same number of epochs. The result is shown in Figure 4. We can see that all methods converge to the same loss value, but Adambs arrives at convergence much faster than the others.
6 Conclusion
We have presented an efficient method for accelerating the training of Adam by endowing it with bandit sampling. Our new method, Adambs, is able to adapt to different examples in the training set, complementing Adam’s adaptive ability for different parameters. A distribution is maintained over all examples and represents their relative importance. Learning this distribution could be viewed as an adversarial bandit problem, because only partial information is available at each iteration. We use a multiarmed bandit approach to learn this distribution, which is interleaved with the original parameter update by Adam. We provided a theoretical analysis to show that our method can improve the convergence rate of Adam in some settings. Our experiments further demonstrate that Adambs is effective in reducing the training time for several tasks with different deep learning models.
Acknowledgements
This material is based upon work supported by the National Science Foundation under Grant No. 1629397 and the Michigan Institute for Data Science (MIDAS) PODS. The authors would like to thank Junghwan Kim and Morgan Lovay for their detailed feedback on the manuscript, and the anonymous reviewers for their insightful comments.
Broader Impact
As machine learning techniques are being used in more and more real-life products, deep learning is the most notable driving force behind them. Deep learning models have achieved state-of-the-art performance in scenarios such as image recognition, natural language processing, and so on. Our society has benefited greatly from the success of deep learning models. However, this success normally relies on large amounts of data available to train the models using optimization methods such as Adam. In this paper, we propose a generalization of Adam that can train models more efficiently on large amounts of data, especially when the datasets are imbalanced. We believe our method could become a widely adopted optimization method for training deep learning models, thus bringing broad impact to many real-life products that rely on these models.
Appendix A Proof of Theorem 1
According to Theorem 4.1 in [19], the convergence rate of Adam is
(10) 
First, we show
Proof.
(11) 
∎
Therefore, the above bound can be rewritten as
(12) 
where
(13) 
Proof.
(14) 
∎
we can obtain the following theorem for the convergence of Adam with sampling with replacement and batch size m.
(15) 
Proof.
Since , we have
(16) 
From previous step, we know
(17) 
∎
Appendix B Proof of Theorem 2
Theorem 2 follows from Theorem 1 and Lemma 1. Therefore, we focus on the proof of Lemma 1 here. We prove Lemma 1 using the framework of online learning with bandit feedback.
Online optimization is concerned with choosing a sequence of points to solve the following problem
(18) 
where ℓ_t is the loss incurred at each iteration. Equivalently, the goal is the same as minimizing the pseudo-regret:
(19) 
The following Algorithm 3, similar to EXP3, can be used to solve the above problem.
To be clear, the Bregman divergence is D_ψ(p, q) = ψ(p) − ψ(q) − ⟨∇ψ(q), p − q⟩. Note that the updating step of Algorithm 3 is equivalent to
(20) 
We have the following convergence result for Algorithm 3.
Proposition B.1.
The Algorithm 3 has the following convergence result
(21) 
If the loss is linear (i.e., ℓ_t(p) = ⟨ℓ_t, p⟩), then we have
(22) 
For example, let ψ(p) = Σ_i p(i) log p(i), the negative entropy. In this case, due to Equation (20), the updating step of Algorithm 3 becomes
p_{t+1}(i) ∝ p_t(i) exp(−η ℓ̂_t(i))    (23)
The convergence result can also be simplified because
(24) 
Linear Case: Let us consider a special case where the loss is linear, i.e., ℓ_t(p) = ⟨ℓ_t, p⟩. Assume p is a probability distribution, i.e., Σ_i p(i) = 1 and p(i) ≥ 0. At iteration t, assume that we cannot observe the whole vector ℓ_t. Instead, we observe only one coordinate ℓ_t(i), where i is sampled according to the distribution p_t. Obviously,