Fast Convergence for Stochastic and Distributed Gradient Descent in the Interpolation Limit
Modern supervised learning techniques, particularly those using so-called deep nets, involve fitting high-dimensional labelled data sets with functions containing very large numbers of parameters. Much of this work is empirical, and interesting phenomena have been observed that require theoretical explanations; however, the non-convexity of the loss functions complicates the analysis. Recently it has been proposed that some of the success of these techniques resides in the effectiveness of the simple stochastic gradient descent (SGD) algorithm in the so-called interpolation limit, in which all labels are fit perfectly. This analysis is made possible since the SGD algorithm reduces to a stochastic linear system near the interpolating minimum of the loss function. Here we exploit this insight by analyzing a distributed algorithm for gradient descent, also in the interpolating limit. The algorithm corresponds to gradient descent applied to a simple penalized distributed loss function, $L(\mathbf w_1,\ldots,\mathbf w_n) = \sum_i l_i(\mathbf w_i) + \mu\sum_{\langle i,j\rangle}|\mathbf w_i-\mathbf w_j|^2$. Here each node $i$ is allowed its own parameter vector $\mathbf w_i$, and $\langle i,j\rangle$ denotes edges of a connected graph defining the communication links between nodes. It is shown that this distributed algorithm converges linearly (ie the error reduces exponentially with iteration number), with a rate $R$ satisfying $1-\frac{\eta}{n}\lambda_{\min}(H) < R < 1$, where $\eta$ is the learning rate and $\lambda_{\min}(H)$ is the smallest nonzero eigenvalue of the sample covariance or the Hessian $H$. In contrast with previous usage of similar penalty functions to enforce consensus between nodes, in the interpolating limit it is not required to take the penalty parameter to infinity for consensus to occur. The analysis further reinforces the utility of this limit in the theoretical treatment of modern machine learning algorithms.
Empirical performance advances in using so-called deep networks for machine learning have given rise to a growing body of theoretical work that tries to explain the success of these methods. While this theoretical area is not yet in a mature state, several interesting observations have been made and ideas put forth. One observation is that overfitting surprisingly does not seem to degrade the generalization performance of deep nets in supervised learning. A related phenomenon is that relatively simple randomized optimization algorithms, particularly stochastic gradient descent (SGD), are quite effective, despite the complicated non-convex landscapes of the associated loss functions. Recently these phenomena have been connected, and it has been shown that in the overfitting or interpolating limit, stochastic gradient descent converges rapidly to an interpolating minimum of the loss function. In addition, under suitable assumptions, the performance of the SGD algorithm saturates at relatively small mini-batch sizes, raising questions about the role of data parallelism.
Data and model parallelism play an important role in the area of deep nets, since both data and model sizes can be very large and resource constraints enforce the need for such parallelism. Many algorithmic variants of parallelized SGD have been proposed and studied. Most of these variants employ centralized communications, although some fully distributed algorithms have also been studied. Much of this work is empirical. While there are theoretical results concerning convergence, these can be involved and lack intuitive simplicity. A particular benefit of the interpolating limit is that the analysis can be carried out within the analytically tractable reach of linear systems theory, with the caveat that it is necessary to understand the behavior in the limit of large system sizes as well as in the presence of randomness.
We exploit this simplicity to gain insight into the role of data correlations in determining the impact of parallelism on SGD, and to study a fully distributed consensus-based algorithm for gradient descent, with one data sample per computational node. The algorithm corresponds to gradient descent applied to a penalized distributed loss function,
$$L(\mathbf w_1,\ldots,\mathbf w_n) = \sum_{i=1}^n l_i(\mathbf w_i) + \mu\sum_{\langle i,j\rangle}|\mathbf w_i-\mathbf w_j|^2. \tag{1}$$
Here each node $i$ is allowed its own parameter vector $\mathbf w_i$, and $\langle i,j\rangle$ denotes edges of a graph defining the communication links between nodes. At the minimum of the distributed loss function, each node recovers the same weight vector, corresponding to an interpolating minimum of the centralized loss $\sum_i l_i(\mathbf w)$. It is shown that this algorithm converges linearly (ie the approximation error decreases exponentially with iteration number). The proposed distributed algorithm points to a number of avenues for future work.
II Relation to previous work
There is a large relevant literature on stochastic optimization, including considerations of parallelism and distributed computation, spanning multiple disciplines, that is not practical to review here, although we provide a few salient references and pointers. Closely related to the current work is the analysis of the exponential convergence of the randomized Kaczmarz algorithm, which is analogous to SGD with a minibatch size of 1, and its distributed version, as well as recent work on consensus-based distributed SGD for learning deep nets. However, the distributed algorithm obtained by performing gradient descent on the loss function in Eq.(1) does not appear to have been analyzed in the previous literature, particularly in the interpolating limit, which greatly simplifies the analysis.
There is also a relevant body of literature on penalty-based methods for distributed optimization, as well as ADMM related approaches where the consensus constraint is enforced using Lagrange multipliers. While the current approach is in effect penalty-based, the important point is that in the interpolating limit the penalty term does not have to be made large to achieve consensus, with the exact optimum being obtained for any value of the penalty term (thus the penalty parameter may be optimized to speed up convergence). In contrast with ADMM the algorithm does not require the introduction of dual variables. Nevertheless, one future direction from the current work would be to re-examine these other methods of performing distributed optimization specifically in the interpolating limit, where analytical simplicities may emerge.
The energy function represented in Eq.(1) has a standard form familiar from (and motivated by) statistical physics, representing a sum of on-site energies and quadratic or elastic couplings between neighboring sites (on a suitable graph). In the statistical physics case, the variables relevant to physical systems are typically low dimensional, whereas we are interested in parameter vectors of very high dimensionality, as is typical of modern machine learning applications. Note that although not relevant for real physical systems, interaction energies with such high dimensional vectors are still of theoretical interest to physicists, since mean field analysis can become exact in this limit.
Convergence proofs in the relevant literature are often presented for more general loss functions than considered in this manuscript, together with smoothness and strong convexity constraints. Here the loss function reduces to a quadratic form near the interpolating minimum, and the convergence analysis corresponds to bounding the largest singular value of the relevant linear operator governing the parameter iteration. However, smoothness and strong convexity constraints in turn imply that the loss functions are bounded above and below by quadratic forms, at least near the optimum. It should be possible to reformulate the results presented here in that more general context, but it is not clear that any further insights would be gained by such reformulations, which introduce considerably more notational complexity.
II-A Fast convergence of SGD in the interpolating limit
The starting point for this work is the recent work by Belkin et al presenting an analysis of the fast convergence of the SGD algorithm in the interpolating limit. We briefly recount some of the results of this paper and also present a slightly modified derivation of the convergence rate, as it is relevant to understanding the efficiency of parallel distributed computation for the problem at hand.
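Consider the standard supervised learning setup with a data set consisting of labelled pairs $(\mathbf x_i, y_i)$, $i=1,\ldots,n$. The task is to learn a parametrized function $f(\mathbf x;\mathbf w)$, chosen from a suitable function class, by minimizing the empirical risk $L(\mathbf w) = \sum_{i=1}^n l_i(\mathbf w)$, with $l_i(\mathbf w) = l(f(\mathbf x_i;\mathbf w), y_i)$ corresponding to a loss function $l$ (a constant prefactor such as $1/n$ can be absorbed into the learning rate).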
The interpolating limit is defined by the conditions $l_i(\mathbf w^*) = 0$ for all $i$, ie the loss is zero at each sample point (the interpolating function passes through each data point). In this limit $l_i(\mathbf w)$ is close to zero if $\mathbf w-\mathbf w^*$ is small. We will assume that $l(f,y_i)$ is quadratic in $f$ around $f=y_i$, and that the function $f(\mathbf x_i;\mathbf w)$ is differentiable at $\mathbf w^*$. Under these conditions $l_i(\mathbf w)\approx\left(\mathbf g_i\cdot(\mathbf w-\mathbf w^*)\right)^2$ (up to a constant factor), where $\mathbf g_i$ is the gradient $\nabla_{\mathbf w}f(\mathbf x_i;\mathbf w)$ evaluated at $\mathbf w^*$. Suitably redefining variables as $\tilde{\mathbf w} = \mathbf w-\mathbf w^*$ and $\tilde{\mathbf x}_i = \mathbf g_i$, we see that we are left with a loss function corresponding to a linear model, $L(\tilde{\mathbf w}) = \sum_i(\tilde{\mathbf x}_i\cdot\tilde{\mathbf w}-\tilde y_i)^2$, where $\tilde y_i = 0$. In the following we will drop the tildes for notational simplicity. We are now effectively analyzing linear regression, and we denote the dimensionality of $\mathbf w$ and $\mathbf x_i$ by $d$. To be in the linear interpolation regime one generally needs to overfit, ie $d>n$. Note that in this case the Hessian $H$ has a number of zero singular values (at least $d-n$), corresponding to a null space about which the data is not informative. The ERM procedure will not reduce error in this null space, so our attention will be confined to the range of $H$, corresponding to its nonzero singular values. In the linear regime the null space is left invariant. For notational simplicity the vectors $\mathbf w$ denote only the projections orthogonal to the null space (the projection parallel to the null space is simply left invariant by the iterative procedures below).
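The reduction to a linear model can be checked numerically. The following minimal NumPy sketch uses hypothetical random data in place of the linearized sample vectors $\mathbf x_i$. It assumes the loss is normalized as $\frac12\sum_i(\mathbf x_i\cdot\mathbf w)^2$, so that one GD step is $\mathbf w\leftarrow(I-\eta H)\mathbf w$ with $H=\sum_i\mathbf x_i\mathbf x_i^T$. With $d>n$ the Hessian has $d-n$ zero eigenvalues, and the component of $\mathbf w$ in the corresponding null space is untouched by the update.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 12                               # hypothetical sizes with d > n
X = rng.standard_normal((n, d))            # row i stands in for the linearized x_i

H = X.T @ X                                # H = sum_i x_i x_i^T, a d x d matrix
eigvals, eigvecs = np.linalg.eigh(H)
tol = 1e-10 * eigvals.max()
null_basis = eigvecs[:, eigvals <= tol]    # directions the data is not informative about

eta = 1.0 / eigvals.max()                  # any stable learning rate would do
w0 = rng.standard_normal(d)
w1 = w0 - eta * H @ w0                     # one GD step: w <- (I - eta H) w

# The component of w in the null space of H is left unchanged by the step.
print(null_basis.shape[1], np.allclose(null_basis.T @ w0, null_basis.T @ w1))
```

Only the component in the range of $H$ shrinks, which is why the analysis can be restricted to that subspace.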
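With the above setup it is easy to verify that the corresponding GD algorithm is given by ($\eta$ is the learning rate, with constant factors absorbed into it):
$$\mathbf w_{t+1} = (I-\eta H)\,\mathbf w_t,\qquad H=\sum_{i=1}^n P_i,\qquad P_i=\mathbf x_i\mathbf x_i^T.$$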
We have made some notational choices to simplify the following considerations, and have written the Hessian $H$ (equivalently the sample covariance of the vectors $\mathbf x_i$, up to a factor $1/n$) in terms of a sum of (unnormalized) projection operators corresponding to the individual data points. Note that $P_i^2 = |\mathbf x_i|^2P_i$, and if the vectors $\mathbf x_i$ are orthogonal then $P_iP_j = 0$ when the indices are unequal (in general the $P_i$ do not commute). We now introduce iteration-dependent stochastic binary variables $z_i^t\in\{0,1\}$, chosen i.i.d. from a Bernoulli distribution with $\Pr(z_i^t=1) = p$. The idea is that $z_i^t = 1$ iff sample $i$ is picked in the minibatch used in time step $t$. Note also that the GD case is recovered for $p=1$, where each sample is picked with probability 1.
This stochastic formulation is slightly different from the usual setting, where minibatches are picked with fixed size $m$. Here the batch sizes fluctuate from one iteration to another, with an average size of $m = np$. It is easy to treat the fixed batch size case using the same formulation, as long as one keeps track of the correlations introduced between different $z_i^t$ at a given iteration by the fixed batch size constraint, but we will not present the corresponding formulae here. Note that the minibatch sampling procedure used by Belkin et al is with replacement, which is slightly different from the setting here. The reason for this definition of the SGD algorithm is that the randomness has been made explicit as an uncorrelated binary process associated with each sample, which makes the analysis a bit simpler. With this notation we define an SGD algorithm (technically with variable batch size) by the iteration
$$\mathbf w_{t+1} = A_t\mathbf w_t,\qquad A_t = I-\eta\sum_{i=1}^n z_i^tP_i.$$
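As an illustration, the Bernoulli-sampled SGD iteration $\mathbf w_{t+1} = (I-\eta\sum_i z_i^t\mathbf x_i\mathbf x_i^T)\mathbf w_t$ can be sketched in NumPy as follows. The function name `bernoulli_sgd`, the seed, and the use of orthonormal sample vectors with $\eta=1$ are illustrative assumptions. With these choices each selected sample's component of $\mathbf w$ is removed exactly, so the squared error never increases.

```python
import numpy as np

def bernoulli_sgd(X, w0, eta, p, steps, rng):
    """SGD with i.i.d. Bernoulli(p) sample selection (illustrative sketch).

    X holds the sample vectors x_i as rows. At each step every sample is kept
    independently with probability p, and w <- (I - eta * sum_i z_i x_i x_i^T) w.
    Batch sizes therefore fluctuate around m = n p. Returns |w_t|^2 for each t.
    """
    w = w0.copy()
    sq_errors = []
    for _ in range(steps):
        z = rng.random(X.shape[0]) < p        # boolean mask: z_i^t = 1 if picked
        Xb = X[z]                             # picked samples (possibly none)
        w = w - eta * Xb.T @ (Xb @ w)         # sum over picked i of x_i (x_i . w)
        sq_errors.append(float(w @ w))
    return np.array(sq_errors)

rng = np.random.default_rng(1)
n = d = 20
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
X = Q.T                                       # orthonormal sample vectors as rows
w0 = rng.standard_normal(d)
sq_errors = bernoulli_sgd(X, w0, eta=1.0, p=0.3, steps=60, rng=rng)
print(sq_errors[:3], sq_errors[-1])
```

Drawing an independent Bernoulli mask per step, rather than a fixed-size subset, keeps the $z_i^t$ uncorrelated.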
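Note that $A_t$ is uncorrelated with $\mathbf w_t$, since $\mathbf w_t$ depends only on the $z_i^{t'}$ with $t'<t$. Therefore $\langle|\mathbf w_{t+1}|^2\rangle$ can be computed as $\langle\mathbf w_t^T\langle A_t^TA_t\rangle\mathbf w_t\rangle$, where the inner expectation is over the stochastic processes $z_i^t$. To analyze convergence one needs to bound the largest eigenvalue of the symmetric positive semidefinite matrix $\langle A_t^TA_t\rangle$ and optimize it with respect to the learning rate $\eta$. Let us first consider the case where the sample vectors are orthogonal (equivalently $P_iP_j=0$ for $i\neq j$). Then, noting that $(z_i^t)^2 = z_i^t$, so that $\langle(z_i^t)^2\rangle = p$, we have
$$\langle A_t^TA_t\rangle = I-\sum_{i=1}^n\eta p\left(2-\eta|\mathbf x_i|^2\right)P_i.$$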
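Thus in the orthogonal case the eigenvalues of $\langle A_t^TA_t\rangle$ are given by $1-\eta p\lambda_i(2-\eta\lambda_i)$, with $\lambda_i = |\mathbf x_i|^2$. To obtain the best bound for the decay rate one has to maximize this expression over $i$, then minimize that result over $\eta$. Consider the case $|\mathbf x_i| = 1$, ie orthonormal sample vectors. Then the eigenvalues are all equal and are given by $1-\eta p(2-\eta)$. The minimum value is obtained for $\eta=1$, and is given by $\rho = 1-p$. If the $|\mathbf x_i|^2$ are not all equal to 1, we get $\rho = 1-p\,(g/a)^2$ (attained at $\eta = 1/a$), with $g$ and $a$ being the geometric and arithmetic means of $\lambda_{\min}=\min_i|\mathbf x_i|^2$ and $\lambda_{\max}=\max_i|\mathbf x_i|^2$. Thus $\rho$ is less than 1 for any $p>0$, and this shows exponential convergence to zero error with iteration number.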
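A quick numerical check of the orthogonal-case result, using a hypothetical set of squared norms $\lambda_i=|\mathbf x_i|^2$. The sketch evaluates $\max_i[1-\eta p\lambda_i(2-\eta\lambda_i)]$ on a fine grid of learning rates. It then compares the grid minimum with the closed form $1-p\,(g/a)^2$ at $\eta=1/a$. The function name `rate_orthogonal` is illustrative.

```python
import numpy as np

def rate_orthogonal(lams, eta, p):
    """max_i [1 - eta p lam_i (2 - eta lam_i)] for orthogonal x_i, lam_i = |x_i|^2.

    eta may be a scalar or an array of learning rates.
    """
    lams = np.asarray(lams, dtype=float)
    eta = np.asarray(eta, dtype=float)[..., None]      # broadcast against samples
    return np.max(1.0 - eta * p * lams * (2.0 - eta * lams), axis=-1)

lams = np.array([0.5, 0.8, 1.3, 2.0])    # hypothetical squared norms |x_i|^2
p = 0.25
g = np.sqrt(lams.min() * lams.max())      # geometric mean of lambda_min, lambda_max
a = 0.5 * (lams.min() + lams.max())       # arithmetic mean
rho_closed = 1.0 - p * (g / a) ** 2       # predicted optimum, attained at eta = 1/a

etas = np.linspace(1e-3, 2.0 / lams.min(), 200001)
rho_grid = rate_orthogonal(lams, etas, p).min()
print(rho_closed, rho_grid, rate_orthogonal(lams, 1.0 / a, p))
```

Intermediate norms never set the maximum, because $u(2-u)$ is concave in $u=\eta\lambda$. At the optimum the two extreme samples are balanced.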
In the orthonormal case, the number of iterations $T$ required to achieve (on average) a relative error of $\epsilon$ is given by $(1-p)^T=\epsilon$, ie $T=\log(1/\epsilon)/\left(-\log(1-p)\right)$. During that same time a computational cost of $mT$ sample-gradient evaluations is incurred. The total computation cost to achieve a fixed total error therefore depends on the batch size as $m/\left(-\log(1-m/n)\right)$. This cost decreases as $m$ increases, so bigger batch sizes produce the same error at a lower computational cost. This indicates that the problem will continue to benefit from data parallelism as $m$ increases. However, the situation is different if the data vectors have strong correlations, and in particular if the Hessian has a few large eigenvalues that dominate its trace. In this case, Belkin et al show that the gain from parallelism is limited, and the parallelism gains saturate when $m$ reaches a critical value $m^*$. For normalized sample vectors this is $m^*\approx\mathrm{tr}(H)/\lambda_{\max}(H)$.
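Next we consider the more general case where the $\mathbf x_i$ are not orthogonal but are normalized ($|\mathbf x_i| = 1$, so that $P_i^2 = P_i$). Noting that $\langle z_i^tz_j^t\rangle = p^2+p(1-p)\delta_{ij}$, it is easy to show that $\langle A_t^TA_t\rangle = (I-\eta pH)^2+\eta^2p(1-p)H$. As expected, setting $p=1$ we recover the GD matrix $(I-\eta H)^2$. The eigenvalues of $\langle A_t^TA_t\rangle$ are then given by
$$\rho(\lambda,\eta) = (1-\eta p\lambda)^2+\eta^2p(1-p)\lambda,$$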
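where $\lambda$ ranges over the eigenvalues of the sample covariance or Hessian $H$. The bound on the decay rate is given by $\rho = \min_\eta\max_\lambda\rho(\lambda,\eta)$. Since $\rho(\lambda,\eta)$ is a convex quadratic in $\lambda$, the maximum over $\lambda$ is achieved at either $\lambda_{\min}$ or $\lambda_{\max}$. For fixed $\lambda$ the dependence on $\eta$ is also quadratic. If one plots $\rho(\lambda,\eta)$ vs $\eta$ for $\lambda_{\min}$ and $\lambda_{\max}$, one obtains two intersecting parabolas. The upper envelope follows one parabola and then the other, with its overall minimum occurring where $\rho(\lambda_{\min},\eta) = \rho(\lambda_{\max},\eta)$. Solving this equation one obtains the following result. It assumes $p(\lambda_{\max}-\lambda_{\min})\ge 1-p$, so that the intersection lies between the vertices of the two parabolas:
$$\eta^* = \frac{2}{p(\lambda_{\min}+\lambda_{\max})+1-p},\qquad \rho = 1-\frac{4p^2\lambda_{\min}\lambda_{\max}}{\left[p(\lambda_{\min}+\lambda_{\max})+1-p\right]^2}.$$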
For $p=1$ one obtains the GD result
$$\rho_{GD} = \left(\frac{\kappa-1}{\kappa+1}\right)^2,$$
where $\kappa = \lambda_{\max}/\lambda_{\min}$ is the condition number of $H$. If $\lambda_{\max}\gg 1$, as would be the case when there are strong correlations, then $\rho$ approaches $\rho_{GD}$ fairly quickly as $p$ increases (once $p\lambda_{\max}\gg 1-p$). For some parameter choices the total computation cost shows a minimum at a value of $p<1$ but close to 1. However, if $\lambda_{\max}\approx 1$ (as is the case for orthonormal sample vectors), then $\rho$ approaches $\rho_{GD}$ more slowly. Thus the optimal choice of $p$ depends on the degree of correlation between the normalized sample vectors.
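Both steps of this argument can be checked numerically on small, hypothetical examples.

- **Exact expectation.** The sketch first computes $\langle A_t^TA_t\rangle$ exactly, by summing $A^TA$ over all $2^n$ selection patterns weighted by their Bernoulli probabilities. It compares the result with $(I-\eta pH)^2+\eta^2p(1-p)H$ for normalized, non-orthogonal $\mathbf x_i$.
- **Min-max over $\eta$.** It then scans the upper envelope of $\rho(\lambda_{\min},\eta)$ and $\rho(\lambda_{\max},\eta)$, where $\rho(\lambda,\eta)=(1-\eta p\lambda)^2+\eta^2p(1-p)\lambda$. The spectrum is assumed: $\lambda_{\min}=0.2$, $\lambda_{\max}=6$, $p=0.5$, which satisfies $p(\lambda_{\max}-\lambda_{\min})\ge1-p$. The envelope minimum is compared with $\eta^*=2/D$ and $\rho=1-4p^2\lambda_{\min}\lambda_{\max}/D^2$, where $D=p(\lambda_{\min}+\lambda_{\max})+1-p$.
- **GD limit.** Finally it confirms that $p=1$ reproduces $((\kappa-1)/(\kappa+1))^2$.

```python
import itertools
import numpy as np

# Part 1: exact <A^T A> for normalized, non-orthogonal x_i (hypothetical data).
rng = np.random.default_rng(2)
n, d, eta, p = 6, 4, 0.3, 0.4
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)     # |x_i| = 1, so P_i^2 = P_i
P = np.einsum("ia,ib->iab", X, X)                 # P_i = x_i x_i^T
H = P.sum(axis=0)
I = np.eye(d)

expect = np.zeros((d, d))
for z in itertools.product((0, 1), repeat=n):     # all 2^n selection patterns
    z = np.array(z)
    A = I - eta * np.tensordot(z, P, axes=1)      # A = I - eta * sum_i z_i P_i
    k = z.sum()
    expect += p**k * (1 - p) ** (n - k) * (A.T @ A)
closed = (I - eta * p * H) @ (I - eta * p * H) + eta**2 * p * (1 - p) * H
print(np.max(np.abs(expect - closed)))

# Part 2: min over eta of the max over {lambda_min, lambda_max}.
def rho(lam, eta, p):
    """Eigenvalue of <A^T A> for an eigenvalue lam of H."""
    return (1 - eta * p * lam) ** 2 + eta**2 * p * (1 - p) * lam

def closed_form(lmin, lmax, p):
    """Intersection of the two parabolas: returns (eta_star, rho_star)."""
    D = p * (lmin + lmax) + 1 - p
    return 2.0 / D, 1.0 - 4 * p**2 * lmin * lmax / D**2

lmin, lmax, p_b = 0.2, 6.0, 0.5      # hypothetical spectrum and selection probability
assert p_b * (lmax - lmin) >= 1 - p_b # the intersection is then the optimum
eta_star, rho_star = closed_form(lmin, lmax, p_b)

etas = np.linspace(1e-4, 2.0, 400001)
envelope = np.maximum(rho(lmin, etas, p_b), rho(lmax, etas, p_b))
best = np.argmin(envelope)
print(eta_star, etas[best], rho_star, envelope[best])

kappa = lmax / lmin
print(closed_form(lmin, lmax, 1.0)[1], ((kappa - 1) / (kappa + 1)) ** 2)   # GD case
```

Enumerating all patterns is only feasible for small $n$, but it checks the identity without Monte Carlo noise.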
III Distributed GD with Elastic Penalty on a Graph in the Interpolating Limit
These considerations show both that SGD converges rapidly in the interpolation regime, and that data parallelism should be computationally beneficial if the sample vectors are not strongly correlated. Note that data parallelism is not dictated by computational cost alone. It may be practically impossible to store all the data locally at a central compute node, and one also has to consider the communication costs of centralized parallel computation. Even for data-parallel implementations of SGD, centralized communication to a parameter server may cause a problem. In the extreme case, where each compute node has one data vector, communicating all parameter vectors to a central server after gradient updates would require a communication link with bandwidth proportional to $nd$ per iteration. With both $n$ and $d$ large, this may be impossible to provide. With these motivations we proceed to study the fully distributed case, where individual compute nodes only communicate locally with the set of nodes connected to them. For simplicity we consider the case of a fixed, connected graph, although similar results should continue to hold on a fluctuating graph topology as long as the fluctuations still permit diffusion of signals.
We assume there are $n$ compute nodes, each with a single data vector $\mathbf x_i$ and a node-specific parameter vector $\mathbf w_i$. Define the penalized loss function as in Eq.(1), ie $L(\mathbf w_1,\ldots,\mathbf w_n) = \sum_i(\mathbf x_i\cdot\mathbf w_i)^2+\mu\sum_{\langle i,j\rangle}|\mathbf w_i-\mathbf w_j|^2$. Generally, the penalty term will not be minimized by a set of $\mathbf w_i$ that also minimize an un-penalized loss function with this form. However, the interpolating limit is special, since there exists a vector $\mathbf w^*$ that minimizes each $l_i$, that minimum value being zero, ie $l_i(\mathbf w^*) = 0$ for all $i$. Clearly, the penalty term also equals zero if $\mathbf w_i = \mathbf w^*$ for all $i$. Since all terms in the sum are non-negative, an interpolating minimum of the loss is simultaneously a minimum of the penalized loss. This considerably simplifies things. Note again that we will ignore the presence of zero modes, as the GD dynamics leaves the null subspace unchanged.
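The distributed GD algorithm is given by
$$\mathbf W_{t+1} = \left[I-\eta\left(\mathcal P+\mu\,\Delta\otimes I_d\right)\right]\mathbf W_t,\tag{14}$$
where $\Delta$ is the graph Laplacian, defined by the quadratic form $\sum_{\langle i,j\rangle}|\mathbf w_i-\mathbf w_j|^2 = \mathbf W^T(\Delta\otimes I_d)\mathbf W$. Here $\mathcal P = \mathrm{diag}(P_1,\ldots,P_n)$ is block diagonal, and a factor of 2 has been absorbed into $\eta$. We have defined a concatenated vector $\mathbf W = (\mathbf w_1,\ldots,\mathbf w_n)$. In components, node $i$ updates $\mathbf w_i\leftarrow\mathbf w_i-\eta\left[P_i\mathbf w_i+\mu\sum_{j:\langle i,j\rangle}(\mathbf w_i-\mathbf w_j)\right]$, so it only needs the parameter vectors of its neighbours. To prove exponential convergence of this distributed GD algorithm, one only needs to show that the linear operator governing the dynamics has all its eigenvalues less than 1 in absolute value, outside the invariant null space.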
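Let $M = \mathcal P+\mu\,\Delta\otimes I_d$, where $\mathcal P$ has diagonal blocks given by $P_i$ and $\Delta$ is the graph Laplacian. Since $M$ is symmetric and positive semidefinite, for a learning rate $\eta<2/\lambda_{\max}(M)$ this is equivalent to showing that the smallest eigenvalue of $M$ is greater than zero, on the subspace orthogonal to its null space. Then $\lambda_{\min}(M) = \min_{\mathbf W}\mathbf W^TM\mathbf W$, with the minimum being taken over unit vectors $\mathbf W$ in this subspace. The proof follows by expanding out the quadratic form:
$$\mathbf W^TM\mathbf W = \sum_{i=1}^n(\mathbf x_i\cdot\mathbf w_i)^2+\mu\sum_{\langle i,j\rangle}|\mathbf w_i-\mathbf w_j|^2.\tag{15}$$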
We have to demonstrate that $\lambda_{\min}(M)>0$. Note that the quantity being minimized is a sum of squares, so for the sum to be zero, each individual term must be simultaneously zero. However, this is impossible. On a connected graph, the only way for the Laplacian penalty to vanish is for the $\mathbf w_i$ to be all equal, $\mathbf w_i = \mathbf w/\sqrt n$ for all $i$ (the normalization factor ensures that $|\mathbf W| = |\mathbf w| = 1$). Plugging this choice of $\mathbf W$ into the first part of Eq.(15) we get the quantity $\frac1n\sum_i(\mathbf x_i\cdot\mathbf w)^2 = \frac1n\mathbf w^TH\mathbf w$. Recall that we are only considering the subspace corresponding to the nonzero eigenvalues of $H$ (the dynamics leaves the null space of $H$ invariant). There is no choice of $\mathbf w$ in this subspace which will make $\mathbf w^TH\mathbf w$ vanish: the minimum value is $\lambda_{\min}(H)$, the smallest nonzero eigenvalue of $H$. For any other choice of $\mathbf W$ the penalty will be strictly positive. Thus, there is no unit vector $\mathbf W$ for which all the terms in Eq.(15) vanish. Since the minimum over unit vectors is attained, $\lambda_{\min}(M)>0$.
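This argument does not provide an explicit estimate of $\lambda_{\min}(M)$. However, taking $\mathbf W$ to be the consensus vector built from the eigenvector of $H$ with eigenvalue $\lambda_{\min}(H)$ shows that $0<\lambda_{\min}(M)\le\lambda_{\min}(H)/n$. Thus the largest eigenvalue of the evolution operator in Eq.(14), $1-\eta\lambda_{\min}(M)$, lies (strictly) between $1-\frac{\eta}{n}\lambda_{\min}(H)$ and 1. This concludes the proof that the error in the distributed GD algorithm shrinks exponentially to zero in the interpolating limit, with a rate $R$ satisfying $1-\frac{\eta}{n}\lambda_{\min}(H)<R<1$. If the penalty parameter $\mu$ is large, the minimizing $\mathbf W$ is pushed towards the consensus subspace, where only the first term in the quadratic form contributes. We can then expect $\lambda_{\min}(M)\approx\lambda_{\min}(H)/n$.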
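The distributed iteration and the eigenvalue bound can be explored numerically. The sketch below is a minimal illustration under several assumptions:

- a hypothetical path graph on $n=6$ nodes;
- random normalized sample vectors in $d=10$ dimensions;
- $\mu=1$;
- a learning rate $\eta=1/\lambda_{\max}(M)$, with $M=\mathcal P+\mu\,\Delta\otimes I_d$, $\mathcal P=\mathrm{diag}(\mathbf x_1\mathbf x_1^T,\ldots,\mathbf x_n\mathbf x_n^T)$, and constant factors absorbed into $\eta$.

Each node updates $\mathbf w_i$ using only $\mathbf x_i$ and its neighbours' vectors. The dense matrix $M$ is formed only to run the following checks:

- $0<\lambda_{\min}(M)\le\lambda_{\min}(H)/n$;
- the error contracts at least as fast as $(1-\eta\lambda_{\min}(M))^t$;
- the limit is a consensus vector in the null space of $H$;
- $\lambda_{\min}(M)$ approaches $\lambda_{\min}(H)/n$ for large $\mu$.

All helper names are illustrative.

```python
import numpy as np

def path_neighbours(n):
    """Adjacency lists of an n-node path graph (hypothetical topology)."""
    return [[j for j in (i - 1, i + 1) if 0 <= j < n] for i in range(n)]

def distributed_gd(X, nbrs, mu, eta, steps, W0):
    """GD on sum_i (x_i . w_i)^2 + mu * sum_<ij> |w_i - w_j|^2 (factor 2 in eta).

    Row i of W holds w_i; node i uses only x_i, w_i and its neighbours' w_j.
    """
    W = W0.copy()
    for _ in range(steps):
        grad = X * np.sum(X * W, axis=1, keepdims=True)     # P_i w_i = x_i (x_i . w_i)
        for i, js in enumerate(nbrs):
            grad[i] += mu * sum(W[i] - W[j] for j in js)    # elastic coupling
        W = W - eta * grad
    return W

def dense_operator(X, lap, mu):
    """M = blockdiag(x_i x_i^T) + mu * (Laplacian kron I_d), for checking only."""
    n, d = X.shape
    M = mu * np.kron(lap, np.eye(d))
    for i in range(n):
        M[i*d:(i+1)*d, i*d:(i+1)*d] += np.outer(X[i], X[i])
    return M

def smallest_positive(evals, rel_tol=1e-9):
    """Smallest eigenvalue above a relative tolerance (excludes the null space)."""
    return evals[evals > rel_tol * evals.max()].min()

rng = np.random.default_rng(3)
n, d, mu, steps = 6, 10, 1.0, 2000
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)
nbrs = path_neighbours(n)
lap = np.diag([float(len(js)) for js in nbrs])
for i, js in enumerate(nbrs):
    lap[i, js] = -1.0

evals, evecs = np.linalg.eigh(dense_operator(X, lap, mu))
lam_min_M = smallest_positive(evals)
lam_min_H = smallest_positive(np.linalg.eigvalsh(X.T @ X))

eta = 1.0 / evals.max()
W0 = rng.standard_normal((n, d))
W = distributed_gd(X, nbrs, mu, eta, steps, W0)

# Limit of the iteration: projection of W0 onto the null space of M.
null = evecs[:, evals <= 1e-9 * evals.max()]
W_inf = (null @ (null.T @ W0.ravel())).reshape(n, d)
contraction = np.linalg.norm(W - W_inf) / np.linalg.norm(W0 - W_inf)

lam_big_mu = smallest_positive(np.linalg.eigvalsh(dense_operator(X, lap, 1000.0)))
print(lam_min_M, lam_min_H / n, contraction, lam_big_mu)
```

The limit $\mathbf W_\infty$ is not zero here because the random initial condition has a component in the null space. This is the invariant null-space component that the text sets aside.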
IV Conclusion and Discussion
In this manuscript we have exploited the simplicities arising in the interpolating limit for function learning, to analyze the convergence of stochastic and distributed gradient descent close to an interpolating minimum of the loss function. While the analysis is relatively simple and is based on linear regression using a least-squares loss, we expect the conclusions to hold in qualitative terms for more general loss functions, with suitable smoothness and strong convexity constraints near the interpolating minimum.
We have introduced a variant of SGD in which data samples are picked using i.i.d. Bernoulli processes, so that minibatch sizes follow a binomial distribution. This variant simplifies theoretical analysis and may have more general utility than shown here. The analysis shows the importance of the correlation structure of the input vectors in determining the efficiency of SGD. The empirical efficiency of the SGD algorithm may point to the presence of strong correlations in real-life data sets. Even though the input dimensions are nominally very large, it is possible that the effective dimensionality is still modest.
We have presented and analyzed a distributed gradient descent algorithm, also in the interpolating limit, with each compute node holding one data sample. We have shown that a graph Laplacian-penalized distributed loss function adequately couples the nodes to drive the system to an interpolating minimum, with error exponentially decreasing with iterations (for finite-sized connected graphs). Stochastic and asynchronous variants of the algorithm presented should be interesting to analyze (eg where the gradient update step is decoupled from the diffusion step). One possibility not discussed here is to run the distributed GD algorithm in Eq.(14) in the under-parametrized regime, with the connectivity graph between data nodes determined by similarities in the sample vectors. In this case, the individual losses will not all be reduced to zero, but the penalty term would enforce local smoothness of the parameter vector on the connectivity graph. This would amount to a form of local linear regression.
PPM thanks Misha Belkin and Saikat Chatterjee for extensive discussions.
-  I. Goodfellow, Y. Bengio, and A. Courville, Deep learning, vol. 1. MIT press Cambridge, 2016.
-  C. Zhang, S. Bengio, M. Hardt, B. Recht, and O. Vinyals, “Understanding deep learning requires rethinking generalization,” ArXiv e-prints, Nov. 2016.
-  M. Belkin, S. Ma, and S. Mandal, “To understand deep learning we need to understand kernel learning,” ArXiv e-prints, Feb. 2018.
-  P. Goyal, P. Dollár, R. Girshick, P. Noordhuis, L. Wesolowski, A. Kyrola, A. Tulloch, Y. Jia, and K. He, “Accurate, large minibatch SGD: training ImageNet in 1 hour,” arXiv preprint arXiv:1706.02677, 2017.
-  S. Ma, R. Bassily, and M. Belkin, “The power of interpolation: Understanding the effectiveness of SGD in modern over-parametrized learning,” arXiv preprint arXiv:1712.06559, 2017.
-  J. Dean, G. Corrado, R. Monga, K. Chen, M. Devin, M. Mao, A. Senior, P. Tucker, K. Yang, Q. V. Le, et al., “Large scale distributed deep networks,” in Advances in neural information processing systems, pp. 1223–1231, 2012.
-  Z. Jiang, A. Balu, C. Hegde, and S. Sarkar, “Collaborative deep learning in fixed topology networks,” in Advances in Neural Information Processing Systems, pp. 5906–5916, 2017.
-  T. Strohmer and R. Vershynin, “A randomized Kaczmarz algorithm with exponential convergence,” Journal of Fourier Analysis and Applications, vol. 15, no. 2, p. 262, 2009.
-  G. Kamath, P. Ramanan, and W.-Z. Song, “Distributed randomized Kaczmarz and applications to seismic imaging in sensor network,” in Distributed Computing in Sensor Systems (DCOSS), 2015 International Conference on, pp. 169–178, IEEE, 2015.
-  S. Boyd, N. Parikh, E. Chu, B. Peleato, J. Eckstein, et al., “Distributed optimization and statistical learning via the alternating direction method of multipliers,” Foundations and Trends® in Machine Learning, vol. 3, no. 1, pp. 1–122, 2011.