# Adam with Bandit Sampling for Deep Learning

## Abstract

Adam is a widely used optimization method for training deep learning models. It computes individual adaptive learning rates for different parameters. In this paper, we propose a generalization of Adam, called Adambs, that allows us to also adapt to different training examples based on their importance in the model's convergence. To achieve this, we maintain a distribution over all examples, selecting a mini-batch in each iteration by sampling according to this distribution, which we update using a multi-armed bandit algorithm. This ensures that examples that are more beneficial to the model training are sampled with higher probabilities. We theoretically show that Adambs improves the convergence rate of Adam in some cases: the dependence of the regret bound on the number of training examples $n$ improves from $\sqrt{n\log n}$ to $\log n$. Experiments on various models and datasets demonstrate Adambs's fast convergence in practice.

## 1 Introduction

However, another form of adaptivity that has proven beneficial in the context of basic SGD variants is adaptivity with respect to different examples in the training set [18, 6, 8, 38, 13]. In this paper, we show, to the best of our knowledge for the first time, that accounting for the varying importance of training examples can also improve Adam's performance and convergence rate. In other words, Adam has so far considered only the varying importance among parameters, not among training examples. Although some prior work exploits importance sampling to improve SGD's convergence rate [18, 6, 8, 38, 13], these methods often rely on special properties of particular models to estimate a sampling distribution over examples in advance. As a more general and more adaptive approach, we focus on learning the distribution during the training procedure. This is key to achieving a faster convergence rate, because the sampling distribution typically changes from one iteration to the next, depending on the relationship between the training examples and the model parameters at each iteration.

Specifically, we cast the process of learning an optimal distribution as an adversarial multi-armed bandit problem. We use a multi-armed bandit method to update the distribution over all of the training examples, based on the partial information gained from the mini-batch at each iteration. We use the EXP3 method [3], but extend it to allow for sampling multiple actions at each step. The original EXP3 method only samples one action at a time, and collects the partial feedback by observing the loss incurred by that single action. In contrast, in optimization frameworks such as Adam, we need to sample a mini-batch, which typically contains more than one example. We thus need a bandit method that samples multiple actions and observes the loss incurred by each of those actions. In this paper, we extend EXP3 to use feedback from multiple actions and update its distribution accordingly. Although ideas similar to bandit sampling have been applied to some SGD variants and coordinate descent methods [30, 29, 24], extending this idea to Adam is not straightforward, as Adam is considerably more complex, due to its momentum mechanism and parameter adaptivity. To the best of our knowledge, we are the first to propose and analyze the improvement of using bandit sampling for Adam. Maintaining and updating a distribution over all training examples incurs some computational overhead. We will show an efficient way to update this distribution, whose time complexity is logarithmic in the total number of training examples. With this efficient update, the per-iteration cost is dominated by the gradient computation, whose time complexity depends on the mini-batch size.

To endow Adam with the ability to adapt to different examples while keeping its original structure, we interleave our bandit method with Adam's original parameter update, except that we select the mini-batch according to our maintained distribution. Adambs therefore adapts to both different parameters and different examples. We provide a theoretical analysis of this new method showing that our bandit sampling does indeed improve Adam's convergence rate. Through an extensive empirical study across various optimization tasks and datasets, we also show that this new method yields significant speedups in practice.

## 2 Related Work

Importance sampling methods. Importance sampling for convex optimization problems has been extensively studied over the last few years. [27] proposed a generalized coordinate descent algorithm that samples coordinate sets to optimize the algorithm’s convergence rate. More recent works [38, 25] discuss the variance of the gradient estimates of stochastic gradient descent and show that the optimal sampling distribution is proportional to the per-sample gradient norm. [24] proposed an adaptive sampling method for both block coordinate descent and stochastic gradient descent. For coordinate descent, the parameters are partitioned into several prefixed blocks; for stochastic gradient descent, the training examples are partitioned into several prefixed batches. However, it is difficult to determine an effective way to partition blocks of parameters or batches of examples. [18] proposed to sample a big batch at every iteration to compute a distribution based on gradient norms of these examples from the big batch, followed by a mini-batch that is sampled from the big batch for parameter update. However, it is unclear how much speedup their method can achieve in terms of theoretical convergence rate.

## 3 Preliminaries

We consider the following convex optimization problem: $\min_\theta f(\theta)$, where $f$ is a scalar-valued objective function that needs to be minimized, and $\theta$ is the parameter of the model. Let the gradient of $f$ with respect to $\theta$ be denoted as $\nabla f(\theta)$. Assuming the training dataset is of size $n$, we have $\nabla f(\theta) = \frac{1}{n}\sum_{i=1}^{n}\nabla f_i(\theta)$, where $\nabla f_i(\theta)$ is the gradient computed with only the $i$-th example. Furthermore, at each iteration $t$, we select a mini-batch of $K$ examples from the whole training set. We denote the realization of $f$ with respect to the mini-batch selected at iteration $t$ as $f_t$, and the gradient of $f_t$ with respect to $\theta$ as $G_t$. Depending on the sampling strategy of a mini-batch, in some cases $G_t$ could be a biased estimate of $\nabla f(\theta)$, meaning $\mathbb{E}[G_t] \neq \nabla f(\theta)$. However, an unbiased estimate is required to update model parameters in stochastic optimization methods such as SGD and Adam. In such cases, we need to construct an unbiased estimate $\hat G_t$, ensuring that $\mathbb{E}[\hat G_t] = \nabla f(\theta)$. When a mini-batch is selected by sampling uniformly from all of the training examples, we have $\mathbb{E}[G_t] = \nabla f(\theta)$, thus allowing $\hat G_t$ to simply be $G_t$.

Adam [19] selects every mini-batch by uniform sampling and updates exponential moving averages of the gradient ($m_t$) and the squared gradient ($v_t$) with hyperparameters $\beta_1, \beta_2 \in [0, 1)$, which control the exponential decay rates of these moving averages: $m_t = \beta_1 m_{t-1} + (1-\beta_1)\hat G_t$ and $v_t = \beta_2 v_{t-1} + (1-\beta_2)\hat G_t^2$, where $\hat G_t^2$ indicates the element-wise square of $\hat G_t$. The moving averages $m_t$ and $v_t$ are estimates of the 1st moment (the mean) and 2nd raw moment (the uncentered variance) of the gradient. These moment estimates are biased toward zero and are then corrected, resulting in bias-corrected estimates $\hat m_t = m_t/(1-\beta_1^t)$ and $\hat v_t = v_t/(1-\beta_2^t)$, where $\beta_1^t$ and $\beta_2^t$ are $\beta_1$ and $\beta_2$ raised to the power $t$, respectively. Next, the parameter is updated according to $\theta_{t+1} = \theta_t - \alpha\,\hat m_t/(\sqrt{\hat v_t} + \epsilon)$, where $\epsilon$ is a small value, to avoid division by zero.
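The update rules above can be sketched in a few lines of NumPy. This is a minimal illustration of the standard Adam step, not the authors' implementation; the stochastic gradient estimate is simply passed in as `grad`.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, alpha=0.001,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update: moving averages, bias correction, parameter step."""
    m = beta1 * m + (1 - beta1) * grad        # 1st moment estimate
    v = beta2 * v + (1 - beta2) * grad**2     # 2nd raw moment estimate
    m_hat = m / (1 - beta1**t)                # bias correction
    v_hat = v / (1 - beta2**t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```

On a toy quadratic objective, iterating this step moves the parameter toward the minimum by roughly $\alpha$ per step early on, which is Adam's characteristic bounded-step behavior.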

A flexible framework for analyzing iterative optimization methods such as Adam is the online learning framework. In this online setup, at each iteration $t$, the optimization algorithm picks a point $\theta_t$. A loss function $f_t$ is then revealed based on the selected mini-batch, and the algorithm incurs loss $f_t(\theta_t)$. At the end of $T$ iterations, the algorithm's regret is given by $R(T) = \sum_{t=1}^{T}\big[f_t(\theta_t) - f_t(\theta^*)\big]$, where $\theta^*$ is the fixed point with the smallest total loss in hindsight. In order for any optimization method to converge, it is necessary to ensure that $R(T)/T \to 0$. For Adam, the convergence rate is summarized in Theorem 4.1 from [19]. Under further assumptions as in Corollary 4.2 from [19], it can be shown that $R(T)/T = O(1/\sqrt{T})$. In this paper, we propose to endow Adam with bandit sampling, which can further improve the convergence rate under some assumptions.

## 4 Adam with Bandit Sampling

Suppose there are $n$ training examples. At iteration $t$, a mini-batch of size $K$ is selected by sampling with replacement from all of the training examples, according to a distribution $p^t = (p^t_1, \dots, p^t_n)$. Here, $p^t$ represents the relative importance of each example during the model training procedure. Denote the indices of examples selected in the mini-batch as the set $\{I^t_k\}_{k=1}^{K}$. Assume the gradient computed with respect to only the $i$-th example is $g^t_i$. Its unbiased estimate is $\hat g^t_i = g^t_i/(n p^t_i)$. We can easily verify that $\hat g^t_i$ is unbiased because $\mathbb{E}_{i \sim p^t}\big[\hat g^t_i\big] = \sum_{j=1}^{n} p^t_j \frac{g^t_j}{n p^t_j} = \frac{1}{n}\sum_{j=1}^{n} g^t_j$. Therefore, we define the unbiased gradient estimate according to the batch at iteration $t$ as

$$\hat G_t = \frac{1}{K}\sum_{k=1}^{K} \hat g_{I^t_k}. \tag{1}$$

Similarly, we can verify that $\mathbb{E}\big[\hat G_t\big] = \nabla f(\theta_t)$.
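As a quick sanity check, the unbiasedness of this estimator can be verified exactly on a toy instance (random gradients and a random sampling distribution; all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
g = rng.normal(size=(n, 3))            # per-example gradients g_j
p = rng.dirichlet(np.ones(n))          # sampling distribution p_j

full_grad = g.mean(axis=0)             # (1/n) sum_j g_j

# Exact expectation of the importance-weighted estimate
# g_hat_j = g_j / (n * p_j), with j ~ p:
expected = sum(p[j] * g[j] / (n * p[j]) for j in range(n))

assert np.allclose(expected, full_grad)
```

The same calculation with $p_j = 1/n$ recovers the plain mini-batch average.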

It is worth noting that the gradient estimate $\hat G_t$ defined for Adambs is different from that of Adam. This is because the sampling strategy is different, and appropriate bias correction is necessary here. We use $\hat G_t$, defined in Equation 1, to update the first moment estimate and second moment estimate in each iteration. Specifically, let $m_t$ and $v_t$ be the first and second moment estimates at iteration $t$, respectively. Then we update them in the following way

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\hat G_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\hat G_t^2 \tag{2}$$

where $\beta_1, \beta_2 \in [0, 1)$ are hyperparameters that control the exponential decay rates of these moving averages. Our new method Adambs is described in Algorithm 1. Details about the distribution update function (Algorithm 2) are given in the next subsection.

Our method maintains a fine-grained probability distribution over all examples. This provides more flexibility in choosing mini-batches than prior work that uses a coarse-grained probability distribution over pre-fixed mini-batches [24], because it is generally hard to decide how to partition the batches for pre-fixed mini-batches. If the training set is partitioned randomly, any mini-batch is likely to contain some important examples and some unimportant examples, making any two mini-batches equally good. In this case, prioritizing one mini-batch over another will not bring any advantage. Partitioning the dataset in a good way also requires a fair amount of preprocessing time, especially when the dataset is large. Some might argue that we could simply set the batch size to one in [24]. While the issue of batch partitioning would then disappear, this would significantly hurt the convergence rate because only one example is processed at each iteration. In contrast, Adambs does not require pre-partitioned mini-batches. At every iteration, a new mini-batch is formed dynamically by sampling from the whole dataset according to the distribution $p^t$. Here, the distribution is learned so that important examples are selected into one mini-batch with high probability. Thus, it is more likely to get a mini-batch with all important examples, which can significantly boost the training performance.
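Putting the pieces together, the main loop can be sketched as follows. This is a schematic reading of Algorithm 1, not the authors' code: in particular, the multiplicative weight update at the end is a simplified, hypothetical stand-in for the EXP3-style rule of Algorithm 2 (it merely increases the weight of examples with large gradient norm), and `grad_fn` is an assumed helper returning the gradient of a single example.

```python
import numpy as np

def adambs_sketch(grad_fn, theta, n, K=32, T=100, alpha=1e-3,
                  beta1=0.9, beta2=0.999, eps=1e-8,
                  alpha_p=0.01, p_min=1e-4):
    """Schematic Adambs loop: bandit-sampled mini-batches + Adam update."""
    w = np.ones(n)                     # bandit weights over examples
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for t in range(1, T + 1):
        p = np.maximum(w / w.sum(), p_min)
        p /= p.sum()
        batch = np.random.choice(n, size=K, p=p)      # sample with replacement
        grads = [grad_fn(theta, j) / (n * p[j]) for j in batch]
        G_hat = np.mean(grads, axis=0)                # unbiased estimate, Eq. (1)
        m = beta1 * m + (1 - beta1) * G_hat           # Eq. (2)
        v = beta2 * v + (1 - beta2) * G_hat**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
        for j in set(batch):                          # simplified weight update
            w[j] *= np.exp(alpha_p * np.linalg.norm(grad_fn(theta, j)))
    return theta
```

Only the $K$ sampled weights change per iteration, which is what makes the logarithmic-time update discussed above possible.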

We analyze the convergence of Adambs in Algorithm 1 using the same online learning framework [39] that is used by Adam. The following Theorem 1 characterizes the convergence rate of Adambs.

###### Theorem 1.

Assume that the gradient estimate is bounded, $\|\hat G_t\|_2 \le G$ and $\|\hat G_t\|_\infty \le G_\infty$, that the distance between any iterates is bounded, $\|\theta_n - \theta_m\|_2 \le D$ and $\|\theta_n - \theta_m\|_\infty \le D_\infty$ for any $m, n \in \{1, \dots, T\}$, and that $\beta_1, \beta_2 \in [0, 1)$ satisfy $\frac{\beta_1^2}{\sqrt{\beta_2}} < 1$. Let $\alpha_t = \frac{\alpha}{\sqrt{t}}$, $\beta_{1,t} = \beta_1 \lambda^{t-1}$ with $\lambda \in (0, 1)$, and $\gamma = \frac{\beta_1^2}{\sqrt{\beta_2}}$. Adambs achieves the following convergence rate, for all $T \ge 1$,

$$R(T) \le \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\frac{1}{n^2 K}\sum_{t=1}^{T}\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j}\right]} + \rho_3 \tag{3}$$

where $d$ is the dimension of the parameter space and

$$\rho_1 = \frac{D^2 G_\infty}{2\alpha(1-\beta_1)}, \qquad \rho_2 = \frac{\alpha(1+\beta_1)G_\infty}{(1-\beta_1)\sqrt{1-\beta_2}\,(1-\gamma)^2}, \qquad \rho_3 = \sum_{i=1}^{d}\frac{D_\infty^2 G_\infty\sqrt{1-\beta_2}}{2\alpha(1-\beta_1)(1-\gamma)^2}. \tag{4}$$

### 4.2 Update of Distribution $p^t$

From Equation 3, we can see that the distribution $p^t$ affects the convergence rate. We wish to choose $p^t$ so as to obtain a faster convergence rate. We derive how to update $p^t$ by minimizing the right-hand side of Equation 3; specifically, we want to minimize $\sum_{j=1}^{n}\|g^t_j\|^2/p^t_j$. It can be shown that for every iteration $t$, the optimal distribution is proportional to the gradient norm of the individual example [25, 2]. Formally speaking, for any $t$, the optimal solution to the problem $\min_{p}\sum_{j=1}^{n}\|g^t_j\|^2/p_j$ is $p^t_j = \|g^t_j\|\big/\sum_{i=1}^{n}\|g^t_i\|$. It is computationally prohibitive to obtain this optimal solution, because we would need to compute the gradient norm of every example at each iteration. Instead, we use a multi-armed bandit method to learn this distribution from the partial information that is available during training. The multi-armed bandit method maintains the distribution over all examples, and keeps updating this distribution at every training iteration. The partial information that we have at every iteration is the gradients of the examples in the mini-batch. We use a bandit method based on EXP3 [3], extended to handle multiple actions at every iteration. The pseudocode is described in Algorithm 2.
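The optimality of sampling proportionally to the gradient norms is easy to confirm numerically: by the Cauchy-Schwarz inequality, $\sum_j \|g_j\|^2/p_j \ge \big(\sum_j \|g_j\|\big)^2$, with equality at $p_j \propto \|g_j\|$. A toy check:

```python
import numpy as np

rng = np.random.default_rng(1)
norms = rng.uniform(0.1, 2.0, size=8)     # ||g_j|| for 8 toy examples

def objective(p):
    # sum_j ||g_j||^2 / p_j, the variance-related term in Eq. (3)
    return np.sum(norms**2 / p)

p_opt = norms / norms.sum()               # p_j proportional to ||g_j||

# Equality case of Cauchy-Schwarz:
assert np.isclose(objective(p_opt), norms.sum() ** 2)

# The proportional-to-norm distribution beats random alternatives:
for _ in range(1000):
    p_rand = rng.dirichlet(np.ones(8))
    assert objective(p_opt) <= objective(p_rand) + 1e-9
```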

To further illustrate our distribution update algorithm from the perspective of the bandit setting: the number of arms is $n$, where each arm corresponds to one training example, and the loss of pulling arm $j$ is defined in Algorithm 2. We denote by $L$ the upper bound on the per-example gradient norm, i.e., $\|g^t_j\| \le L$. A similar upper bound is commonly used in the related literature [19, 24], because of the popular gradient clipping trick [37]. Each time we update the distribution, we only pull the arms from the set $\{I^t_k\}_{k=1}^{K}$, which is the mini-batch of training examples at the current iteration. In Algorithm 2, $B_\phi(p, p^t)$ is the KL divergence between $p$ and $p^t$, and the set $\mathcal{P}$ is defined as $\mathcal{P} = \{p : \sum_{j=1}^{n} p_j = 1,\ p_j \ge p_{\min}\}$, where $p_{\min}$ is a constant. From the definition of the loss in Algorithm 2, we can see that the loss is always nonnegative, and is inversely correlated with the gradient norm. This implies that an example with a small gradient norm will receive a large loss value, resulting in its weight getting decreased. Thus, examples with large gradient norms will be sampled with higher probabilities in subsequent iterations. Using a segment tree structure to store the weights of all the examples, we are able to update $p^t$ in $O(\log n)$ time [24]. With this efficient distribution update, the per-iteration cost is dominated by the computation of gradients, especially for large models in deep learning. For simplicity, our theoretical analysis is still focused on the convergence rate with respect to the number of iterations. In the experiments, we demonstrate that our method also achieves faster convergence with respect to wall-clock time.
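The $O(\log n)$ update relies on a sum tree, i.e., a segment tree over the example weights. A minimal sketch (not the authors' implementation):

```python
import random

class SumTree:
    """Segment tree over nonnegative weights.

    Supports O(log n) weight updates and O(log n) sampling of an index
    with probability proportional to its weight.
    """
    def __init__(self, n):
        self.n = n
        self.tree = [0.0] * (2 * n)       # leaves live at indices n..2n-1

    def update(self, i, w):
        i += self.n
        self.tree[i] = w
        while i > 1:                      # propagate partial sums to the root
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def sample(self):
        r = random.random() * self.tree[1]
        i = 1
        while i < self.n:                 # descend toward the matching leaf
            if r <= self.tree[2 * i]:
                i = 2 * i
            else:
                r -= self.tree[2 * i]
                i = 2 * i + 1
        return i - self.n
```

Initializing every leaf to the same weight gives uniform sampling; at each iteration only the $K$ sampled leaves need `update` calls, so both sampling and updating stay logarithmic in $n$.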

Let $B_\phi$ be the Bregman divergence associated with $\phi$, where $\phi(p) = \sum_{j=1}^{n} p_j \log p_j$. By choosing this form of Bregman divergence, we can study our EXP3-based distribution update algorithm under the general framework of online mirror descent with bandit feedback [7]. We have the following lemma regarding the convergence rate of Algorithm 2.

###### Lemma 1.

Assume that $\|g^t_j\| \le L$ for all $t$ and $j$, and that $B_\phi(p, p^1) \le R$ for any $p \in \mathcal{P}$. If the learning rate $\alpha_p$ in Algorithm 2 is set appropriately, the update in Algorithm 2 implies

$$\mathbb{E}\sum_{t=1}^{T}\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j} - \min_{p\in\mathcal{P}}\mathbb{E}\sum_{t=1}^{T}\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p_j} \le \frac{R L^2}{p_{\min}^2}\sqrt{2nT}. \tag{5}$$

Combining Theorem 1 and Lemma 1, we have the following theorem regarding the convergence of Adambs.

###### Theorem 2.

Under assumptions from both Theorem 1 and Lemma 1, Adambs with the distribution update rule in Algorithm 2 achieves the following convergence rate

$$R(T) \le \rho_1 d\sqrt{T} + \rho_2\,\frac{\sqrt{d}}{n\sqrt{K}}\sqrt{M} + \rho_2\,\frac{L\sqrt{R}}{p_{\min}}\,\frac{\sqrt{d}}{n\sqrt{K}}\,(2nT)^{1/4} + \rho_3 \tag{6}$$

where $M = \min_{p\in\mathcal{P}}\sum_{t=1}^{T}\mathbb{E}\big[\sum_{j=1}^{n}\|g^t_j\|^2/p_j\big]$, and $\rho_1$, $\rho_2$, and $\rho_3$ are defined in Equation 4.

Obviously, for Adambs we have $\lim_{T\to\infty} R(T)/T = 0$, ensuring that it can converge.

### 4.3 Comparison with Uniform Sampling

We now study a setting in which we can further bound Equation 6 to show that the convergence rate of Adambs is provably better than that of Adam, which uses uniform sampling. For simplicity, we consider a neural network with one hidden layer for a binary classification problem. The hidden layer contains one neuron with ReLU activation, and the output layer also contains one neuron with sigmoid activation. Cross-entropy loss is used for this binary classification problem. The total loss of this neural network model can be written as the sum of the per-example cross-entropy losses, where $\sigma$ is the sigmoid activation function, $\mathrm{ReLU}$ is the ReLU activation function, and $y_j$ and $x_j$ are the label and the feature vector of the $j$-th example, respectively. We further assume that the feature vectors follow a doubly heavy-tailed distribution, i.e., a distribution whose tails are heavy both across examples and across feature dimensions.

###### Lemma 2.

If the examples are sampled from the uniform distribution, i.e., $p_j = 1/n$, then, assuming that the feature vectors follow a doubly heavy-tailed distribution, for the aforementioned neural network model the quantity $\mathbb{E}\big[\sum_{j=1}^{n}\|g_j\|^2/p_j\big]$ grows linearly in $n$, up to logarithmic factors.

Following Lemma 2, we have

###### Theorem 3.

Assuming that the feature vectors follow a doubly heavy-tailed distribution, for the aforementioned neural network model, the original Adam achieves the following rate

$$R(T) \le O\big(d\sqrt{T}\big) + O\left(\frac{\sqrt{d}\log d}{\sqrt{K}}\cdot\frac{\sqrt{n\log n}}{n}\sqrt{T}\right). \tag{7}$$

On the other hand, we have

###### Lemma 3.

Assuming that the feature vectors follow a doubly heavy-tailed distribution, for the aforementioned neural network model, we have

$$\min_{p_j \ge p_{\min}}\ \sum_{j=1}^{n}\frac{\|g_j\|^2}{p_j} = O\big(\log d \log^2 n\big). \tag{8}$$

By plugging Lemma 3 into Theorem 2, we have

###### Theorem 4.

Assuming that the feature vectors follow a doubly heavy-tailed distribution, for the aforementioned neural network model, Adambs achieves the following rate

$$R(T) \le O\big(d\sqrt{T}\big) + O\left(\frac{\sqrt{d}\log d}{\sqrt{K}}\cdot\frac{\log n}{n}\sqrt{T}\right). \tag{9}$$

Comparing the second terms of Equations 7 and 9, we see that our Adambs converges faster than Adam. The convergence rates are summarized in Table 1. In this table, we also compare against the adaptive sampling methods from [24], which maintain a distribution over prefixed batches. Adam with adaptive sampling over prefixed batches is called Adam-apt, and Adam with uniform sampling over prefixed batches is called Adam-uni. We can see that our method Adambs achieves a faster convergence rate than the others. Depending on constants not shown in the big-$O$ notation, it is also possible that the convergence rate is dominated by the first term $O(d\sqrt{T})$, which would make our improvement marginal. We therefore also rely on the experiments in the next section to demonstrate our method's faster convergence in practice.

## 5 Experiments

### 5.2 Convolutional Neural Networks

Convolutional neural networks (CNNs) with several layers of convolution, pooling, and non-linear units have shown considerable success in computer vision tasks. We train CNN models on three different datasets: CIFAR10, CIFAR100, and Fashion MNIST. CIFAR10 and CIFAR100 are labeled subsets of the 80 million tiny images dataset [20]. CIFAR10 consists of 60,000 color images of size 32×32 in 10 classes, with 6,000 images per class, whereas CIFAR100 consists of 60,000 color images of size 32×32 in 100 classes, with 600 images per class. The Fashion MNIST dataset [35] is similar to the MNIST dataset, except that its images fall into 10 fashion categories.

For CIFAR10 and CIFAR100, our CNN architecture stacks several convolutional layers with max pooling; dropout is applied after two of the convolutional layers, and the network ends with a fully connected layer of hidden units. For Fashion MNIST, since the dataset is simpler, we use a simpler CNN model: it contains fewer convolutional layers with max pooling, followed by a fully connected layer. The same mini-batch size and learning rate are used for all methods on all three datasets. All three methods are used to train the CNN models for the same number of epochs, and the results are shown in Figure 2. For CIFAR10 and CIFAR100, we can see that our method Adambs achieves a lower loss value than the others very quickly, within the first few epochs. For Fashion MNIST, Adambs is worse than the others at the very beginning, but keeps decreasing the loss value at a faster rate; after a short amount of wall-clock time, Adambs achieves a lower loss value than the others.

### 5.3 Recurrent Neural Networks

Recurrent neural networks, such as LSTMs and GRUs, are popular models for learning from sequential data. To showcase the generality of our method, we use Adambs to accelerate the training of RNN models on image classification problems, where the image pixels are fed to the RNN as a sequence. Specifically, we train an LSTM with ReLU as the recurrent activation function, followed by a fully connected layer to predict the image class. We use two datasets: MNIST and Fashion MNIST. The batch size, learning rate, and maximum number of epochs are the same for all methods on both datasets. The results are shown in Figure 4. Adambs quickly achieves a lower loss value than the others.

### 5.4 Recurrent Convolutional Neural Networks

Recently, it has been shown that RNNs combined with CNNs can achieve good performance on some NLP tasks [21]. This architecture is called a recurrent convolutional neural network (RCNN). We train an RCNN model for the sentiment classification task on an IMDB movie review dataset, which contains movie reviews from IMDB labeled by sentiment (positive or negative). Reviews are encoded as sequences of word indexes. On this dataset, we train an RCNN model consisting of a convolutional layer, a max pooling layer, an LSTM, and a fully connected layer. The same batch size and learning rate are used for all methods, which are run for the same number of epochs. The result is shown in Figure 4. We can see that all methods converge to the same loss value, but Adambs arrives at convergence much faster than the others.

## 6 Conclusion

We have presented an efficient method for accelerating the training of Adam by endowing it with bandit sampling. Our new method, Adambs, is able to adapt to different examples in the training set, complementing Adam's ability to adapt to different parameters. A distribution is maintained over all examples and represents their relative importance. Learning this distribution can be viewed as an adversarial bandit problem, because only partial information is available at each iteration. We use a multi-armed bandit approach to learn this distribution, which is interleaved with the original parameter update of Adam. We provided a theoretical analysis showing that our method can improve the convergence rate of Adam in some settings. Our experiments further demonstrate that Adambs is effective at reducing the training time for several tasks with different deep learning models.

## Acknowledgements

This material is based upon work supported by the National Science Foundation under Grant No. 1629397 and the Michigan Institute for Data Science (MIDAS) PODS. The authors would like to thank Junghwan Kim and Morgan Lovay for their detailed feedback on the manuscript, and the anonymous reviewers for their insightful comments.

## Broader Impact

As machine learning techniques are used in more and more real-life products, deep learning is the most notable driving force behind them. Deep learning models have achieved state-of-the-art performance in scenarios such as image recognition and natural language processing, and our society has benefited greatly from their success. However, this success normally relies on large amounts of data being available to train the models using optimization methods such as Adam. In this paper, we propose a generalization of Adam that can train models on large amounts of data more efficiently, especially when the datasets are imbalanced. We believe our method could become a widely adopted optimization method for training deep learning models, thus bringing broad impact to the many real-life products that rely on these models.

## Appendix A Proof of Theorem 1

According to Theorem 4.1 in [19], the convergence rate of Adam is

$$\sum_{t=1}^{T}\big[f_t(\theta_t) - f_t(\theta^*)\big] \le \frac{D^2}{2\alpha(1-\beta_1)}\sum_{i=1}^{d}\sqrt{T\hat v_{T,i}} + \frac{\alpha(1+\beta_1)G_\infty}{(1-\beta_1)\sqrt{1-\beta_2}\,(1-\gamma)^2}\sum_{i=1}^{d}\|g_{1:T,i}\|_2 + \sum_{i=1}^{d}\frac{D_\infty^2 G_\infty\sqrt{1-\beta_2}}{2\alpha(1-\beta_1)(1-\lambda)^2} \tag{10}$$

First, we show that $\sum_{i=1}^{d}\sqrt{\hat v_{T,i}} \le dG_\infty$.

###### Proof.
$$\sum_{i=1}^{d}\sqrt{\hat v_{T,i}} = \sum_{i=1}^{d}\sqrt{\sum_{j=1}^{T}\frac{1-\beta_2}{1-\beta_2^T}\beta_2^{T-j}\, g_{j,i}^2} \le \sum_{i=1}^{d}\sqrt{\sum_{j=1}^{T}\frac{1-\beta_2}{1-\beta_2^T}\beta_2^{T-j}\, G_\infty^2} = G_\infty\sqrt{\frac{1-\beta_2}{1-\beta_2^T}}\sum_{i=1}^{d}\sqrt{\sum_{j=1}^{T}\beta_2^{T-j}} = G_\infty\sqrt{\frac{1-\beta_2}{1-\beta_2^T}}\sum_{i=1}^{d}\sqrt{\frac{1-\beta_2^T}{1-\beta_2}} = dG_\infty \tag{11}$$
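The chain above can also be checked numerically: the exponentially weighted coefficients in $\hat v_{T,i}$ sum to one, so each coordinate is bounded by $G_\infty^2$. A small sketch with random bounded gradients:

```python
import numpy as np

rng = np.random.default_rng(2)
d, T, beta2 = 4, 50, 0.999
G_inf = 1.0
g = rng.uniform(-G_inf, G_inf, size=(T, d))   # |g_{j,i}| <= G_inf

# Bias-corrected second moment at iteration T:
# v_hat_{T,i} = sum_j (1-beta2)/(1-beta2^T) * beta2^(T-j) * g_{j,i}^2
weights = (1 - beta2) / (1 - beta2**T) * beta2 ** (T - np.arange(1, T + 1))
v_hat = weights @ (g**2)                      # coefficients sum to 1

assert np.sum(np.sqrt(v_hat)) <= d * G_inf + 1e-9
```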

Therefore, the above bound can be rewritten as

$$\sum_{t=1}^{T}\big[f_t(\theta_t) - f_t(\theta^*)\big] \le \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\sum_{t=1}^{T}\|\hat G_t\|^2} + \rho_3 \tag{12}$$

where

$$\rho_1 = \frac{D^2 G_\infty}{2\alpha(1-\beta_1)}, \qquad \rho_2 = \frac{\alpha(1+\beta_1)G_\infty}{(1-\beta_1)\sqrt{1-\beta_2}\,(1-\gamma)^2}, \qquad \rho_3 = \sum_{i=1}^{d}\frac{D_\infty^2 G_\infty\sqrt{1-\beta_2}}{2\alpha(1-\beta_1)(1-\gamma)^2} \tag{13}$$
###### Proof.
$$\begin{aligned}
\sum_{t=1}^{T}\big[f_t(\theta_t) - f_t(\theta^*)\big] &\le \rho_1 d\sqrt{T} + \rho_2\sum_{i=1}^{d}\|g_{1:T,i}\|_2 + \rho_3 \\
&= \rho_1 d\sqrt{T} + d\rho_2\sum_{i=1}^{d}\frac{1}{d}\sqrt{\sum_{t=1}^{T}(\hat G_{t,i})^2} + \rho_3 \\
&\le \rho_1 d\sqrt{T} + d\rho_2\sqrt{\sum_{i=1}^{d}\frac{1}{d}\sum_{t=1}^{T}(\hat G_{t,i})^2} + \rho_3 \quad \text{(since $\sqrt{\cdot}$ is concave)} \\
&= \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\sum_{t=1}^{T}\|\hat G_t\|^2} + \rho_3
\end{aligned} \tag{14}$$

We can then obtain the following bound for convergence when each mini-batch of size $K$ is selected by sampling with replacement.

$$\sum_{t=1}^{T}\big[f_t(\theta_t) - f_t(\theta^*)\big] \le \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\frac{1}{n^2 K}\sum_{t=1}^{T}\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j}\right]} + \rho_3 \tag{15}$$
###### Proof.

Since the mini-batch indices $I^t_1, \dots, I^t_K$ are sampled independently according to $p^t$, we have

$$\mathbb{E}\big[\|\hat G_t\|^2\big] \le \frac{1}{n^2 K^2}\sum_{k=1}^{K}\mathbb{E}\left[\frac{\|g_{I^t_k}\|^2}{p_{I^t_k}^2}\right] = \frac{1}{n^2 K^2}\sum_{k=1}^{K}\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{(p^t_j)^2}\,p^t_j\right] = \frac{1}{n^2 K^2}\sum_{k=1}^{K}\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j}\right] = \frac{1}{n^2 K}\,\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j}\right] \tag{16}$$

From the previous step, we know

$$\begin{aligned}
\sum_{t=1}^{T}\big[f_t(\theta_t) - f_t(\theta^*)\big] &\le \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\sum_{t=1}^{T}\|\hat G_t\|^2} + \rho_3 \\
&\le \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\sum_{t=1}^{T}\frac{1}{n^2 K}\,\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j}\right]} + \rho_3 \\
&= \rho_1 d\sqrt{T} + \sqrt{d}\,\rho_2\sqrt{\frac{1}{n^2 K}\sum_{t=1}^{T}\mathbb{E}\left[\sum_{j=1}^{n}\frac{\|g^t_j\|^2}{p^t_j}\right]} + \rho_3
\end{aligned} \tag{17}$$
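The importance-weighting step inside Equation 16, which replaces the expectation over a sampled index by an explicit sum over all examples, can be verified exactly on a toy instance:

```python
import numpy as np

rng = np.random.default_rng(3)
n = 6
g = rng.normal(size=(n, 2))            # per-example gradients g_j
p = rng.dirichlet(np.ones(n))          # sampling distribution p_j

# E over j ~ p of ||g_j / (n p_j)||^2, computed exactly
lhs = sum(p[j] * np.sum((g[j] / (n * p[j]))**2) for j in range(n))
# (1/n^2) * sum_j ||g_j||^2 / p_j
rhs = np.sum(np.sum(g**2, axis=1) / p) / n**2

assert np.isclose(lhs, rhs)
```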

## Appendix B Proof of Theorem 2

Theorem 2 follows from Theorem 1 and Lemma 1. Therefore, we focus on the proof of Lemma 1 here. We prove Lemma 1 using the framework of online learning with bandit feedback.

Online optimization is concerned with choosing $p^t \in \mathcal{P}$ to solve the following problem

$$\min_{p^t \in \mathcal{P},\ 1 \le t \le T}\ \sum_{t=1}^{T} L_t(p^t) \tag{18}$$

where $L_t(p^t)$ is the loss that $p^t$ incurs at iteration $t$. Equivalently, the goal is the same as minimizing the pseudo-regret:

$$\bar R_T = \mathbb{E}\sum_{t=1}^{T} L_t(p^t) - \min_{p\in\mathcal{P}}\mathbb{E}\sum_{t=1}^{T} L_t(p) \tag{19}$$

The following Algorithm 3, similar to EXP3, can be used to solve the above problem.

To be clear, the Bregman divergence is $B_\phi(x, y) = \phi(x) - \phi(y) - \langle\nabla\phi(y),\, x - y\rangle$. Note that the updating step of Algorithm 3 is equivalent to

$$w^{t+1} = \nabla\phi^*\big(\nabla\phi(p^t) - \alpha_p \hat h_t\big), \qquad p^{t+1} = \operatorname*{arg\,min}_{y\in\mathcal{P}} B_\phi(y, w^{t+1}) \tag{20}$$

We have the following convergence result for Algorithm 3.

###### Proposition B.1.

Algorithm 3 has the following convergence result

$$\bar R_T = \mathbb{E}\sum_{t=1}^{T} L_t(\tilde p^t) - \min_{p\in\mathcal{P}}\mathbb{E}\sum_{t=1}^{T} L_t(p) \le \frac{B_\phi(p, p^1)}{\alpha_p} + \frac{1}{\alpha_p}\sum_{t=1}^{T}\mathbb{E}\,B_{\phi^*}\big(\nabla\phi(p^t) - \alpha_p \hat h_t,\ \nabla\phi(p^t)\big) + \sum_{t=1}^{T}\mathbb{E}\big[\|p^t - \tilde p^t\|\,\|\hat h_t\|_*\big] \tag{21}$$

If $L_t$ is linear (i.e., $L_t(p) = \langle h_t, p\rangle$), then we have

$$\bar R_T = \mathbb{E}\sum_{t=1}^{T} L_t(\tilde p^t) - \min_{p\in\mathcal{P}}\mathbb{E}\sum_{t=1}^{T} L_t(p) \le \frac{B_\phi(p, p^1)}{\alpha_p} + \frac{1}{\alpha_p}\sum_{t=1}^{T}\mathbb{E}\,B_{\phi^*}\big(\nabla\phi(p^t) - \alpha_p \hat h_t,\ \nabla\phi(p^t)\big) + \sum_{t=1}^{T}\mathbb{E}\big[\|p^t - \mathbb{E}[\tilde p^t \mid p^t]\|\,\|\hat h_t\|_*\big] \tag{22}$$

For example, let $\phi(p) = \sum_{j=1}^{n} p_j \log p_j$. Then its convex conjugate is $\phi^*(\theta) = \sum_{j=1}^{n} e^{\theta_j - 1}$. In this case, due to Equation (20), the updating step of Algorithm 3 becomes

$$w^{t+1}_j = p^t_j \exp\big(-\alpha_p \hat h_{t,j}\big), \qquad \forall\, 1 \le j \le n \tag{23}$$
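With this choice of $\phi$, the Bregman divergence between two probability vectors is exactly the KL divergence, and Equation (23) is the familiar exponential-weights update. A small numeric sketch confirming both facts:

```python
import numpy as np

def phi(p):
    # negative entropy: phi(p) = sum_j p_j log p_j
    return np.sum(p * np.log(p))

def bregman(x, y):
    # B_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>
    return phi(x) - phi(y) - (1.0 + np.log(y)) @ (x - y)

rng = np.random.default_rng(0)
x = rng.dirichlet(np.ones(5))
y = rng.dirichlet(np.ones(5))

# For probability vectors, B_phi reduces to KL(x || y)
assert np.isclose(bregman(x, y), np.sum(x * np.log(x / y)))

# Mirror step through phi*: grad phi*(theta) = exp(theta - 1),
# so the update multiplies each weight by exp(-alpha_p * h_j)
alpha_p, h = 0.1, rng.uniform(0, 1, size=5)
w = np.exp((1.0 + np.log(y)) - alpha_p * h - 1.0)
assert np.allclose(w, y * np.exp(-alpha_p * h))
```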

The convergence result can also be simplified because

$$B_{\phi^*}\big(\nabla\phi(p^t) - \alpha_p \hat h_t,\ \nabla\phi(p^t)\big) = \sum_{j=1}^{n} p^t_j\Big(\exp\big(-\alpha_p \hat h_{t,j}\big) + \alpha_p \hat h_{t,j} - 1\Big) \le \frac{\alpha_p^2}{2}\sum_{j=1}^{n} p^t_j\, \hat h_{t,j}^2 \quad \text{(assuming $\hat h_{t,j} \ge 0$ and using $e^{z} - z - 1 \le z^2/2$ for $z \le 0$)} \tag{24}$$

Linear Case: Let us consider the special case where $L_t$ is linear, i.e., $L_t(p) = \langle h_t, p\rangle$. Assume $p$ is a probability distribution, i.e., $\sum_{j=1}^{n} p_j = 1$ with $p_j \ge 0$. In this case, $\nabla L_t(p) = h_t$. At iteration $t$, assume that we cannot observe the whole vector $h_t$; instead, we observe only one coordinate $h_{t,j}$, where $j$ is sampled according to the distribution $p^t$. This is exactly the bandit-feedback setting. Obviously,

 E[~ptj]=E[