Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the Neural Tangent Kernel
In suitably initialized wide networks, small learning rates transform deep neural networks (DNNs) into neural tangent kernel (NTK) machines, whose training dynamics are well approximated by a linear expansion of the network in its weights at initialization. Standard training, however, diverges from its linearization in ways that are poorly understood. We study the relationship between the training dynamics of nonlinear deep networks, the geometry of the loss landscape, and the time evolution of a data-dependent NTK. We do so through a large-scale phenomenological analysis of training, synthesizing diverse measures that characterize loss landscape geometry and NTK dynamics. Across multiple neural architectures and datasets, we find these diverse measures evolve in a highly correlated manner, revealing a universal picture of the deep learning process. In this picture, deep network training exhibits a highly chaotic rapid initial transient that, within 2 to 3 epochs, determines the final linearly connected basin of low loss containing the end point of training. During this chaotic transient, the NTK changes rapidly, learning useful features from the training data that enable it to outperform the standard initial NTK by a factor of 3 in less than 3 to 4 epochs. After this rapid chaotic transient, the NTK changes at constant velocity, and its performance matches that of full network training in 15% to 45% of training time. Overall, our analysis reveals a striking correlation between a diverse set of metrics over training time, governed by a rapid chaotic-to-stable transition in the first few epochs, that together poses challenges and opportunities for the development of more accurate theories of deep learning.
The remarkable empirical success of deep learning across a range of domains stands in stark contrast to our theoretical understanding of the mechanisms underlying this same success Bahri2020-mi. Indeed, we are currently far from a mature, unified mathematical theory of deep learning that is powerful enough to universally guide engineering design choices. As in many other fields of inquiry, a key prerequisite to any such theory is careful empirical measurement of the deep learning process, with the scientific aim of unearthing combinations of variables that obey correlated dynamical laws that can serve as the inspiration for future theories. Indeed, a large body of work has studied, mainly in isolation, diverse and intriguing phenomenological properties, as well as extreme simplifying theoretical limits, of deep learning. In particular, we focus on intertwined aspects of deep learning that have previously been studied largely in isolation: (1) the large scale structure of deep learning loss surfaces, (2) the local geometry of such loss surfaces, and (3) the performance of linearized training methods, like the neural tangent kernel (NTK), which has gained attention through its ability to theoretically describe an infinite width, low learning rate limit of deep learning in terms of kernel machines with random data-independent kernels. The fundamental goal of this work is to obtain a more integrative view of the intertwined relations between loss landscape geometry at multiple scales of organization and the dynamics of learning in deep networks, by performing simultaneous measurements of many diverse properties. We describe the previous work that motivates our current measurements in Section 1, and we summarize our results and contributions in Section 8, which can be read right after Section 1.
1 Diverse aspects of deep learning phenomenology
The large scale geometric structure of neural loss landscapes.
Recent work has revealed many insights into the shape of loss functions over the high dimensional space of neural network parameters. For example, (li2018measuring; goldilocks) demonstrate that training even within a random, low-dimensional affine subspace of parameter space can yield a network with low test loss. This suggests that the region of parameter space with low test loss must be a relatively high dimensional object, such that low dimensional random affine hyperplanes can generically intersect it. Moreover, (draxler2018essentially; kuditipudi2019explaining; garipov2018loss) show that different, independently trained networks with low loss can be connected in weight space through nonlinear pathways (found via an optimization process) that never leave this low loss manifold. However, direct linear pathways connecting two such independently trained networks typically leave the low loss manifold. The loss function restricted to such linear paths then exhibits a loss barrier at an intermediate point between the two networks. (fort2019large) builds and provides evidence for a unifying geometric model of the low-loss manifold consisting of a network of mutually intersecting high dimensional basins (Fig. 1A). Two networks within a basin can be connected by a straight line that never leaves the low-loss manifold, while two networks in different basins can be connected by a piecewise linear path of low loss that is forced to traverse the intersection between two basins. fort2019deep uses these insights to argue that deep ensembles are hard to beat using local subspace sampling methods due to the geometry of this underlying loss landscape.
frankle2019linear provides further evidence for this large-scale structure by demonstrating that after a very early stage of training of a parent network (but not earlier) two child networks trained starting from the parameters of the parent end up in the same low loss basin at the end of training, and could be connected by a linear path in weight space that does not leave the low loss manifold (Fig. 1B). Furthermore, jastrzebski2020breakeven; leclerc2020two show that the properties of the final minimum found are strongly influenced by the very early stages of training. Taken together, these results present an intriguing glimpse into the large scale structure of the low loss manifold, and the importance of early training dynamics in determining the final position on the manifold.
Neural tangent kernels, linearized training and the infinite width limit.
The neural tangent kernel (NTK) has garnered much attention as it provides a theoretical foothold to understand deep networks, at least in an infinite width limit with appropriate initialization scale and low learning rate jacot2018neural; novak2019neural. In such a limit, a network does not move very far in weight space over the course of training, and so one can view learning as a linear process occurring along the tangent space, taken at the initial function, to the manifold of functions realizable by the parameters (Fig. 1C). This learning process is well described by kernel regression with a certain random kernel associated with the tangent space at initialization. The NTK is also a special case of Taylorized training bai2020taylorized, which approximates the realizable function space to higher order in the vicinity of initialization. Various works compare the training of deep networks to the NTK arora2019exact; lee2019wide; arora2019harnessing; shankar2020neural. In many cases, state of the art networks outperform their random kernel counterparts by significant margins, suggesting that deep learning in practice may indeed explore regions of function space far from initialization, with the tangent space twisting significantly over training time, and hence the kernel being learned from the data (Fig. 1D). However, the nature and extent of this function space motion, the degree of tangent space twisting, and how and when data is infused into a learned tangent kernel, remain poorly understood.
The local geometric structure of neural loss landscapes.
Much effort has gone into characterizing the local geometry of loss landscapes in terms of Hessian curvature and its impact on generalization and learning. Interestingly papyan2019measurements analyses the Hessian eigenspectrum of loss landscapes at scale, demonstrating that learning leads to the emergence of a small number of large Hessian eigenvalues, and many small ones, bolstering evidence for the existence of many flat directions in low loss regions depicted schematically in Fig. 1A. fort2019emergent shows that the gradients of logits with respect to parameters cluster tightly based on the logit over training time, leading directly to the emergence of very sharp Hessian eigenvalues. Moreover, a variety of work has explored relations between the curvature of local minima found by training and their generalization properties (dziugaite2017computing; chaudhari2019entropy; langford2002bounding; hochreiter1997flat; keskar2016large; fort2019large; fort2018goldilocks), and how learning rate and batch size affect the curvature of the minima found (jastrzbski2017factors; sagun2017empirical; NIPS2018_8049), with larger learning rates generically enabling escape from sharper minima (Fig. 1E). lewkowycz2020large makes a connection between learning rates and the validity of NTK training, showing that for infinitely wide networks, training with a learning rate above a scale determined by the top eigenvalue of the Hessian at initialization results in a learning trajectory that outperforms NTK training, presumably by exploring nonlocal regions of function space far away from initialization.
Towards an integrative view.
Above, we have reviewed previously distinct strands of inquiry into deep learning phenomenology that have made little to no contact with each other. Indeed, we currently have no understanding of how local and global loss geometry interacts with the degree of kernel learning in state of the art architectures and training regimes used in practice. For example, at what point in training is the fate of the final chosen basin in Fig. 1 A,B irrevocably determined? Does the kernel change significantly from initialization, as in Fig. 1D? If so, when during training does the tangent kernel start to acquire knowledge of the data? Also, when does kernel learning finally stabilize? What relations do either of these times have to the time at which basin fate is determined? How does local geometry, in terms of curvature, change as all these events occur? Here we address these questions to obtain an integrative view of the learning process across a range of networks and datasets. While we only present results for ResNet20 trained on CIFAR10 and CIFAR100 in the main paper, in Appendix C we find similar results for a WideResNet, variations of ResNets, and a simple CNN trained on CIFAR10 and CIFAR100, indicating our results hold generally across architectures, datasets and training protocols. Full experimental details are covered in the Appendix.
2 Definition of measurement metrics for geometry and training
We now mathematically formalize the quantities introduced in the previous section, as well as define further quantities whose dynamics we will measure during training. Let $(x_1, y_1), \ldots, (x_n, y_n)$ be $n$ training examples, with $y_i \in \{1, \ldots, C\}$, where $C$ is the number of classes. Let $z(x; w) \in \mathbb{R}^C$ denote the $C$-dimensional output vector of logits of a neural network parameterized by weights $w \in \mathbb{R}^p$ on input $x$. We are interested in the average classification error $e(w)$ over the $n$ samples. For training purposes, we also consider a (surrogate) loss $\ell(z, y)$ for predicting $z$ when the true label is $y$. Denote by $g_i(w)$ the gradient of $\ell(z(x_i; w), y_i)$ with respect to $w$, evaluated at $w$. Write $g(w)$ for the concatenation of the gradient vectors $g_i(w)$, for $i = 1, \ldots, n$. Let $J_i(w) \in \mathbb{R}^{C \times p}$ be the Jacobian of $z(x_i; w)$ with respect to the parameters $w$. Define $J(w) \in \mathbb{R}^{nC \times p}$ to be the concatenation of $J_1(w), \ldots, J_n(w)$, which is then the Jacobian of $(z(x_1; w), \ldots, z(x_n; w))$ with respect to the parameters $w$. The $r$-th row of $J(w)$, denoted $J_r(w)$, is a vector in $\mathbb{R}^p$. Let $H(w)$ be the tensor whose slice $H_c(w) \in \mathbb{R}^{p \times p}$ is the Hessian of logit $c$ w.r.t. the weights $w$.
Training Dynamics, Linearized training, and introduction of a data-dependent NTK.
Let $w_t$ be the weights at iteration $t$ of SGD, based on minibatch estimates $\hat{L}(w)$ of the training loss $L(w)$, where each minibatch is a subsample of the data of size $b$. We write $\hat{g}(w)$ for the minibatch estimate of $g(w)$, and similarly $\hat{J}(w)$, $\hat{e}(w)$, and $\hat{H}(w)$ for the minibatch estimates of $J(w)$, $e(w)$, and $H(w)$. The SGD update with learning rate $\eta$ is then
$$ w_{t+1} = w_t - \eta\, \hat{g}(w_t). \qquad (1) $$
Consider also a second-order Taylor expansion approximating the change in the logits for input $x$ after one SGD step:
$$ z(x; w_{t+1}) - z(x; w_t) \approx J(x; w_t)\,(w_{t+1} - w_t) + \tfrac{1}{2}\,(w_{t+1} - w_t)^\top H(x; w_t)\,(w_{t+1} - w_t), \qquad (2) $$
where $J(x; w_t)$ and $H(x; w_t)$ denote the logit Jacobian and logit Hessians evaluated at the input $x$.
Note that for an infinitesimal $\eta$, the dynamics in Eq. 1 are those of gradient flow, and terms higher than order 1 in Eq. 2 vanish. In this case, steepest descent in parameter space corresponds to steepest descent in function space under a neural tangent kernel (NTK),
$$ \Theta_t(x, x') = J(x; w_t)\, J(x'; w_t)^\top. $$
Let $\Theta_t \in \mathbb{R}^{nC \times nC}$ denote the $nC$ by $nC$ Gram matrix with entries $(\Theta_t)_{r,r'} = \langle J_r(w_t), J_{r'}(w_t) \rangle$, i.e., $\Theta_t = J(w_t)\, J(w_t)^\top$ evaluated on the training points. If $\Theta_t = \Theta_0$ for all $t$, i.e., if the tangent kernel is constant over time, then the dynamics correspond to those of training the neural network linearized at time $0$. The kernel has been shown to be nearly constant in the case of very wide neural networks at initialization (see, e.g., (jacot2018neural; zou2019improved; ji2019polylogarithmic; du2018gradient; lee2019wide; chen2019much)). Intuitively, we can think of each of the $p$ columns of $J(w_t)$ as a tangent vector to the manifold of realizable neural network functions, in the ambient space of all functions from inputs to logits, at the point $z(\cdot\,; w_t)$ in function space. Thus the span of the columns of $J(w_t)$, as $t$ varies, constitutes the tangent planes in function space depicted schematically in Fig. 1CD. Since the kernel is the Gram matrix associated with these tangent functions, evaluated at the training points, if the tangent space twists substantially, then the kernel necessarily changes (as in Fig. 1D).
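As a concrete illustration of how such a finite-width Gram matrix can be estimated empirically, the sketch below builds the Jacobian of a toy two-layer network by central finite differences and forms $J J^\top$. The architecture, sizes, and function names here are ours, purely for illustration, not the paper's actual networks:

```python
import numpy as np

def logits(w, X):
    """Toy two-layer network producing C=3 logits for each input in X.
    w packs a (2,4) and a (4,3) weight matrix; purely illustrative."""
    W1 = w[:8].reshape(2, 4)
    W2 = w[8:].reshape(4, 3)
    return np.tanh(X @ W1) @ W2          # shape (n, 3)

def ntk_gram(w, X, eps=1e-5):
    """Empirical NTK Gram matrix Theta = J J^T, where row (i, c) of J is the
    gradient of logit c on input x_i, estimated by central differences."""
    n, C, P = X.shape[0], 3, w.size
    J = np.zeros((n * C, P))
    for p in range(P):
        dw = np.zeros(P)
        dw[p] = eps
        J[:, p] = (logits(w + dw, X) - logits(w - dw, X)).ravel() / (2 * eps)
    return J @ J.T                        # (nC, nC), symmetric PSD

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))               # n=5 toy inputs of dimension 2
w = rng.normal(size=20)                   # packed weights
Theta = ntk_gram(w, X)                    # 15 x 15 Gram matrix
```

By construction $\Theta = J J^\top$ is symmetric and positive semi-definite, matching the kernel interpretation above.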
Conversely, if the NTK does not change substantially from initialization, then full SGD training can be well approximated by training along the tangent space at initialization, yielding the linearized training dynamics. This approach can be generalized to training along higher order Taylor approximations of the manifold in the vicinity of the initial function bai2020taylorized. In this work, in order to explore function space geometry and its impact on training, we extend this approach by doing full network training up to a time $t_0$, and then linearized training subsequently. This yields a linearized training trajectory, which can then be compared to the weight dynamics under full training (see Appendix for details). Geometrically, this corresponds to training along an intermediate tangent plane (one of the green planes in Fig. 1CD), or equivalently, to learning with a data-dependent NTK. This novel examination of how much training time is required to learn a high performing NTK, distinct from the random one used at initialization, and of the relations between this time and both the local and large scale structure of the loss landscape, constitutes a key contribution of our work.
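The two-phase protocol just described can be sketched on a toy regression problem: train the full nonlinear model for a while, then freeze the tangent plane at that point and train only the linearization. Everything below (model, sizes, learning rates, step counts) is an illustrative stand-in for the paper's actual setup:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(20, 3))
Y = rng.normal(size=(20, 2))             # toy regression targets

def f(w, X):
    """Tiny nonlinear model with 2 outputs; w packs a (3,4) and (4,2) matrix."""
    return np.tanh(X @ w[:12].reshape(3, 4)) @ w[12:].reshape(4, 2)

mse = lambda w: 0.5 * np.mean((f(w, X) - Y) ** 2)

def grad(w, eps=1e-5):
    """Finite-difference gradient of the MSE loss (fine for 20 parameters)."""
    g = np.zeros_like(w)
    for p in range(w.size):
        dw = np.zeros_like(w)
        dw[p] = eps
        g[p] = (mse(w + dw) - mse(w - dw)) / (2 * eps)
    return g

# Phase 1: full nonlinear training up to an intermediate time t0.
w = 0.5 * rng.normal(size=20)
for _ in range(300):
    w -= 0.2 * grad(w)

# Phase 2: freeze the tangent plane at w0 and train only the linearization
# z0 + J (w - w0), i.e. learn with the data-dependent kernel J J^T.
w0, z0 = w.copy(), f(w, X)
J = np.zeros((z0.size, w0.size))
for p in range(w0.size):
    dw = np.zeros_like(w0)
    dw[p] = 1e-5
    J[:, p] = (f(w0 + dw, X) - f(w0 - dw, X)).ravel() / 2e-5

base_loss = 0.5 * np.mean((z0 - Y) ** 2)   # loss at the expansion point
d = np.zeros_like(w0)                      # displacement from w0
for _ in range(500):
    resid = z0.ravel() + J @ d - Y.ravel()
    d -= 0.05 * (J.T @ resid) / resid.size # exact gradient of linearized MSE

lin_loss = 0.5 * np.mean((z0.ravel() + J @ d - Y.ravel()) ** 2)
```

Because the linearized objective is a convex quadratic in the displacement, gradient descent with a small step monotonically reduces it below the loss at the expansion point.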
Hierarchical exploration of the loss landscape through parents and children.
In order to explore the loss landscape and the stability of training dynamics in a more multiscale hierarchical manner than is possible using completely independent training runs, we employ a method of parent-child spawning frankle2019linear (shown schematically in Fig. 1B). In this process, a parent network is trained from initialization up to a spawning time $t_s$, yielding a parent weight trajectory. At the spawn time $t_s$, several copies of the parent network are made, and these so-called children are then trained with independent minibatch stochasticity, yielding different child weight trajectories $w^{(c)}_t$ for $t_s \le t \le T$, where $c$ indexes the children and $T$ is the final training time. We will be interested in various measures of the distance between children after training, as a function of their spawn time $t_s$, as well as measures of the distance between the same network (either parent or child) at two different training times. We turn to these various distance measures next.
For finite width networks, the kernel $\Theta_t$ changes with training time $t$. We compare two kernel Gram matrices in a scale-invariant manner by computing a kernel distance:
$$ d(\Theta_s, \Theta_t) = 1 - \frac{\langle \Theta_s, \Theta_t \rangle_F}{\|\Theta_s\|_F\, \|\Theta_t\|_F}. \qquad (3) $$
We further track the speed at which the kernel changes. As discussed above, for nonlinear neural networks we do not expect the distance in Eq. 3 to vanish. In order to capture the evolution of this quantity, we compute the kernel velocity, i.e., the rate of change of the kernel distance, using a fixed time separation (measured in epochs) large enough to capture appreciable change.
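Concretely, the scale-invariant kernel distance can be taken as one minus the Frobenius cosine similarity between the two Gram matrices, with the kernel velocity as its rate of change between snapshots. The following numpy sketch (function names ours) illustrates both:

```python
import numpy as np

def kernel_distance(K1, K2):
    """Scale-invariant distance between two Gram matrices:
    1 - <K1, K2>_F / (||K1||_F ||K2||_F).  Zero for proportional kernels."""
    return 1.0 - np.sum(K1 * K2) / (np.linalg.norm(K1) * np.linalg.norm(K2))

def kernel_velocity(kernels, dt):
    """Rate of change of kernel distance between consecutive snapshots
    taken dt epochs apart."""
    return [kernel_distance(kernels[i], kernels[i + 1]) / dt
            for i in range(len(kernels) - 1)]

rng = np.random.default_rng(0)
A = rng.normal(size=(6, 10))
K = A @ A.T                                   # a toy PSD Gram matrix
vels = kernel_velocity([K, 2.0 * K, K + np.eye(6)], dt=1.0)
```

Note that rescaling a kernel leaves the distance unchanged, which is exactly the scale invariance the metric is designed for.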
Error barrier between children.
To assess (and indeed define) whether two children arrive at the same basin at the end of training (see, e.g., Fig. 1AB), we compute the error barrier between children along a linear path interpolating between them in weight space. Let $w_\alpha = (1 - \alpha)\, w^{(1)} + \alpha\, w^{(2)}$, where $w^{(1)}$ and $w^{(2)}$ are the weights of two child networks spawned at some iteration $t_s$, and $\alpha \in [0, 1]$. At various $\alpha$ we compute the classification error $e(w_\alpha)$; the maximal excess error along this path, relative to the average error of the two endpoints, is what we call the error barrier. Note that the error barrier at the end of training between two children is the same as the instability measure in (frankle2019linear).
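A minimal sketch of the error-barrier computation. Here a toy one-dimensional double-well "error landscape" stands in for a trained network's test error; the interpolation logic is the point, and the error function is illustrative:

```python
import numpy as np

def error_barrier(w1, w2, error_fn, num_alphas=11):
    """Maximum excess error along the linear path (1-a) w1 + a w2, relative
    to the average endpoint error; near zero when both weights share a basin."""
    alphas = np.linspace(0.0, 1.0, num_alphas)
    errs = [error_fn((1 - a) * w1 + a * w2) for a in alphas]
    return max(errs) - 0.5 * (errs[0] + errs[-1])

# Toy error landscape with two separated minima, at w = -1 and w = +1.
error = lambda w: float(np.minimum((w - 1) ** 2, (w + 1) ** 2).mean())

same_basin = error_barrier(np.array([0.9]), np.array([1.1]), error)   # ~0
diff_basin = error_barrier(np.array([-1.0]), np.array([1.0]), error)  # large
```

Two points inside the same well yield a vanishing barrier, while points in different wells force the path over the ridge between them.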
ReLU activation pattern distance.
In a ReLU network, each post-nonlinearity activation is either strictly positive or zero. We can thus construct a binary tensor $A(w)$, with an entry equal to $1$ if, for input $x_i$, node $j$ in layer $l$ is strictly positive, and $0$ otherwise. We compare the ReLU on/off similarity between networks parameterized by $w_1$ and $w_2$ by computing the Hamming distance between $A(w_1)$ and $A(w_2)$, normalized by the total number of entries in $A(w)$.
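The activation-pattern distance reduces to a normalized Hamming distance between binary tensors. Here is a sketch for a single ReLU layer; the network and shapes are stand-ins:

```python
import numpy as np

def relu_pattern(w, X):
    """Binary on/off pattern of one hidden ReLU layer (illustrative model)."""
    return (X @ w) > 0.0                  # shape (n_inputs, n_hidden)

def pattern_distance(w1, w2, X):
    """Normalized Hamming distance between the two activation patterns."""
    a1, a2 = relu_pattern(w1, X), relu_pattern(w2, X)
    return float(np.mean(a1 != a2))

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))             # 100 toy inputs
w = rng.normal(size=(5, 8))               # one hidden layer of 8 units
d_same = pattern_distance(w, w, X)        # identical nets: distance 0
d_flip = pattern_distance(w, -w, X)       # sign-flipped weights: every unit flips
```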
Figures 4, 4 and 4: An integrated view of learning. (A) Parent network learning curves. (B) Error barrier between pairs of children at the end of training, as a function of spawn time, with children trained for the same number of epochs as the parent. (C) and (D) Heatmaps of the ReLU and kernel distances between the parent network at different pairs of training times. Dashed black lines indicate epochs at which the learning rate is dropped. (E) ReLU, function space, and kernel distances between pairs of children at the end of training, as a function of spawn time.
Function space distance.
To compute the distance between two functions parameterized by weights $w_1$ and $w_2$, we would ideally like to calculate the degree of disagreement between their outputs averaged over the whole input space. Since this is computationally intractable, we approximate this distance by the normalized fraction of test examples on which their predicted labels disagree. Letting $\mathcal{D}$ denote the test set, we count the test inputs on which the two networks predict different labels, and divide by a normalizing constant $Z$ chosen to aid comparison. In particular, we define $Z$ to be the expected number of examples on which two classifiers would disagree assuming each made random independent predictions with the same error rates, $e_1$ and $e_2$, as their error rates on the test set; $Z$ is then a closed-form function of $e_1$, $e_2$, and the number of classes. This quantity is used also by (fort2019deep). A unit distance indicates that two networks make uncorrelated errors.
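A sketch of this normalized disagreement metric. The chance-level normalizer below assumes each classifier's errors land independently and uniformly on the $C-1$ wrong labels; this is one plausible reading of the expected-disagreement constant, not necessarily the paper's exact formula:

```python
import numpy as np

def function_distance(pred1, pred2, labels, C):
    """Fraction of test points where two classifiers disagree, divided by the
    disagreement expected if each erred independently (assumed uniform over
    the C-1 wrong labels -- our assumption, hedged in the lead-in)."""
    pred1, pred2, labels = map(np.asarray, (pred1, pred2, labels))
    e1 = np.mean(pred1 != labels)
    e2 = np.mean(pred2 != labels)
    # P(disagree) under independence: exactly one classifier wrong, or both
    # wrong but on different wrong labels (probability (C-2)/(C-1) given both
    # wrong, under the uniform-wrong-label assumption).
    chance = e1 * (1 - e2) + e2 * (1 - e1) + e1 * e2 * (C - 2) / (C - 1)
    return float(np.mean(pred1 != pred2) / chance)

labels = np.zeros(1000, dtype=int)
rng = np.random.default_rng(0)
p1 = np.where(rng.random(1000) < 0.2, 1, 0)   # ~20% error rate
p2 = np.where(rng.random(1000) < 0.2, 2, 0)   # independent ~20% error rate
d = function_distance(p1, p2, labels, C=10)    # near 1: uncorrelated errors
```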
3 An integrative view of learning dynamics
Figs. 4 and 4 plot the full range of metrics defined in Section 2 for two SOTA networks. Panel A presents standard training curves. Panel B confirms the results of (frankle2019linear): the error barrier on a linear path between two children decreases rapidly with spawning time, falling close to zero within two to five epochs. Panels C and D indicate that the NTK changes rapidly early in training and more slowly later in training, as quantified by the ReLU activation distance (C) and kernel distance (D) measured on a parent run at different pairs of training times. Finally, Panel E shows that the function, kernel and ReLU distances between children at the end of training also drop as a function of spawn time.
We note that the SOTA training protocols in Figs. 4 and 4 involve learning rate drops later in training, which alone could account for a slowing of the NTK evolution. Therefore we ran a constant learning rate experiment in Fig. 4. We see that all tracked metrics still exhibit the same key patterns: the error barrier drops rapidly within a few epochs (B), the NTK evolves very rapidly early on, but continues to evolve at a constant slow velocity later in training (C,D), and final distances between children drop at an early spawn time and remain constant thereafter (E).
Overall, these results provide an integrative picture of the learning process, which reveals an early, extremely short chaotic period in which the final basin chosen by a child is highly sensitive to SGD noise and the NTK evolves very rapidly, followed by a later more stable phase in which the basin selection is determined, the NTK continues to evolve, albeit more slowly, and the final distance between children remains smaller. In the next few sections we explore these results in more detail.
4 The local and global geometry of the loss landscape surrounding children
We first explore how both the global and local landscape geometry surrounding a pair of children and their spawning parent depends on the spawn time in Fig. 5. These three networks define a 2D affine plane in weight space and a 2D curved manifold in function space.
The first two columns of Fig. 5 clearly indicate that two children spawned at an early time in the chaotic training regime arrive at two different loss basins that are well separated in function space (top row), while two children spawned at a later time in the stable training regime arrive at the same loss basin, though this loss basin can still exhibit non-negligible function diversity (albeit smaller than the diversity between basins). Furthermore, the right two columns of Fig. 5 indicate that the test error as a function of position along the tangent plane to the 2D curved manifold in function space (either at the spawn point or a child point) is insufficient to describe the error along the full curved manifold in function space when the children are in different basins (top row), but can approximately describe the loss landscape when the children are in the same basin (bottom row). Thus Fig. 5 constitutes a new direct data-driven visualization of loss landscape geometry that provides strong evidence for several aspects of the conceptual picture laid out in Fig. 1: the existence of multiple basins (Fig. 1A), the chaotic sensitivity of basin fate to SGD choices early in training (Fig. 1B), and the twisting of tangent planes in function space that occur as one travels from one basin to another (Fig. 1ACD). See also Fig. 11 for a t-SNE visualization of the bifurcating evolution of all parents and children in function space that further corroborates this picture of loss landscape geometry and its impact on training dynamics.
In Fig. 6 we explore more quantitatively the relationship between the final function space distance between children, spawn time of children, and the error barrier. This figure demonstrates that the error barrier drops to zero rapidly within 2-3 epochs, and then after that, the later two children are spawned, the closer they remain to each other. Since these experiments were done by training parents and children at a constant learning rate over 200 epochs such that child distances stabilized, the reduction in achievable function space distance between children as a function of spawn time cannot be explained either by learning rate drops or by insufficient training time for children (see Fig. 10).
5 NTK velocity slows down and stabilizes after basin fate is determined
We next explore the relation between error barrier and kernel velocity in Fig. 7 by zooming in on the early epochs, compared to the full training shown in Fig. 4-4 panels B and D. This higher resolution view clearly reveals that the early chaotic training regime is characterized by a tightly correlated reduction in both error barrier and kernel velocity, with the latter stabilizing to a low non-zero velocity after the error barrier disappears. Thus the NTK evolves relatively rapidly until basin fate is sealed.
6 The data-dependent NTK rapidly learns features useful for performance
The rapid evolution of the NTK during the chaotic training regime and its subsequent constant velocity motion after basin fate determination, as shown in Fig. 7, raises a fundamental question: at what point during training does the NTK learn useful features that can yield high task performance, or even match full network training? We answer this question in Fig. 8 through a two-step training protocol. We first train the full nonlinear network up to a time $t_0$. We then Taylor expand the full nonlinear network obtained at time $t_0$ with respect to its weights, and perform linearized training thereafter up to the total training time. Geometrically, this corresponds to training for time $t_0$ up to one of the intermediate green points in Fig. 1CD, and then subsequently training only within the green tangent space about that point in function space. We can think of this as training with a data-dependent NTK that has been learned using data from time $0$ to $t_0$. Classic NTK training corresponds to the random kernel arising when the onset time of linearized training is $t_0 = 0$.
Using this two-step procedure, Fig. 8 demonstrates several key findings. First, extremely rapidly, within 3 to 4 epochs, the data-dependent NTK has learned features that allow it to achieve significantly better performance (i.e. error drops by at least a factor of 3) compared to the classic NTK obtained at initialization (see the rapid initial drop of the green curves in Fig. 8). Second, by about 30 to 90 epochs, representing 15% to 45% of training, the data-dependent NTK essentially matches the performance of a network trained for the standard full 200 epochs (compare the green curves to the purple baseline in Fig. 8). This indicates that the early chaotic training period, characterized by rapid drops in error barrier and kernel velocity in Fig. 7, also corresponds to rapid kernel learning: useful information is acquired within a few epochs. This kernel learning continues, albeit more slowly, after the initial chaotic learning period is over and the basin fate is already determined.
7 NTK and nonlinear training remain different even at low learning rates
In the NTK limit, which involves both infinite widths and infinitesimal learning rates, linearized training and full nonlinear training dynamics provably coincide. However, the persistent performance gap up to 30 to 90 epochs between linearized and full nonlinear training (green curves versus purple baselines in Fig. 8) indicates the NTK limit does not accurately describe training dynamics used in practice, at finite widths and large learning rates. We remove one of the two reasons for the discrepancy by comparing the same linearized training dynamics to extremely low learning rate nonlinear training dynamics (blue curves in Fig. 8). In this finite width low learning rate regime, we find, remarkably, that a significant performance gap persists between linearized and nonlinear training (the red nonlinear training advantage region in Fig. 9, left), but only during the first few epochs of training, corresponding precisely to the chaotic regime before basin fate is sealed. Indeed the disappearance of this low learning rate nonlinear advantage is tightly correlated with the disappearance of the error barrier (Fig. 9, right). This indicates that while the data-dependent NTK limit can describe well the low (but not high) learning rate dynamics after the first few epochs, this same NTK limit cannot accurately describe the full nonlinear learning dynamics during the highly chaotic early phase prior to basin fate determination, even when the full nonlinear training uses very low learning rates, and when the NTK is learned from the data. We present additional experiments with Taylor expansions of order 2 on ResNet in Fig. 19 and linear order for WideResNet in Fig. 19. In Section C.3, we also perform the same set of linearized training experiments on networks trained with no batch normalization to ensure that we observe the same effect.
8 Summary of contributions and discussion
In summary, we have performed large scale simultaneous measurements of diverse metrics (Figs. 4, 4 and 4), finding a strikingly universal chaotic-to-stable training transition across datasets and architectures that completes within two to three epochs. During the early chaotic transient: (1) the final basin fate of a network is determined (Fig. 7 left); (2) the NTK changes rapidly, at high speed (Fig. 7 middle and Fig. 4D); (3) the NTK rapidly learns useful features from the training data, outperforming the standard NTK at initialization by a factor of 3 within 3 to 4 epochs (Fig. 8 green curves); (4) even low learning rate training retains a nonlinear performance advantage over linearized NTK training with a learned kernel (Fig. 9 red regions); and (5) the error barrier, kernel velocity, and low learning rate nonlinear advantage all fall together in a tightly correlated manner (Fig. 7 right and Fig. 9 right). After this rapid chaotic transient, training enters a more stable regime in which: (6) SGD stochasticity allows more limited child exploration in terms of function space distance, leading to smaller function diversity within basins compared to between basins (Figs. 6 and 5); (7) the kernel velocity stabilizes to a fixed nonzero speed (Fig. 7 middle and Fig. 4D); (8) the data-dependent kernel performance continues to improve, matching that of full network training by 30 to 90 epochs, representing 15% to 45% of the full 200 epochs (Fig. 8 green curves).
The empirical picture uncovered by our work is much richer than what any theory of deep learning can currently capture. In particular, the NTK theory attempts to describe the entire nonlinear deep learning process using a fixed random kernel at initialization. While this description is provably accurate at infinite width and low learning rate, our results show it is a poor description of what occurs in practice at finite widths and large learning rates (Figs. 8 and 7). More interestingly, the NTK theory is even a poor description of nonlinear training at finite width and extremely low learning rates, especially during the early chaotic training phase (Fig. 9).
This rich phenomenological picture of the rapid sequential nature of the learning process could potentially yield practical dividends in terms of a theory for the rational design of learning rate schedules. For example, the timing of optimized learning rate drops coincides with the time when the data-dependent tangent kernel can achieve high accuracy. Indeed, our observations are consistent with findings in (leclerc2020two). But more generally, we hope that our empirical measurements of such a rich phenomenology may serve as an inspiration for developing an equally rich unifying theory of deep learning that can simultaneously capture these diverse phenomena.
The goal of our work is to gain a better understanding of deep neural networks. This could potentially make machine learning applications more reliable and transparent in the long run.
DMR was supported, in part, by an NSERC Discovery Grant, Ontario Early Researcher Award, and a stipend provided by the Charles Simonyi Endowment. SG thanks the Simons Foundation, James S. McDonnell Foundation, NTT Research, and an NSF Career award for support. This research was in part carried out while GKD and DMR participated in the Special Year on Optimization, Statistics, and Theoretical Machine Learning at the Institute of Advanced Studies.
The authors would like to thank Jonathan Frankle, Ekansh Sharma, and Mufan Li for feedback on drafts, and Shems Saleh for helping to produce Fig. 1.
Appendix A Function distance between children runs
The function distance between the child runs is shown in Fig. 10 (right column) for ResNet20 on CIFAR-10 (top row) and CIFAR-100 (bottom row). Since the signal is relatively noisy from iteration to iteration, we overlay the raw data with a smoothed version, computed with a sliding window over epochs, in Fig. 10.
We also produced a t-SNE (vanDerMaaten2008) visualization of parent and child evolution in function space. To do so, we took the predictions of the parent and child runs on the test set at different stages of their training. We then flattened the vector of predicted probabilities over all images and all of their classes into a single long vector, one for each stage of training of each network. We then used a t-SNE embedding to embed all parent and child runs into a 2D space. In the individual panels of Fig. 11 we highlight the relevant embedded points; due to the nature of t-SNE, however, all predictions were in fact embedded together.
Appendix B Definitions of additional metrics
B.1 Logit gradient centroid alignment
The logit gradient centroid for each class $c$ is defined as the mean logit gradient over the training set, $\bar g_c(w) = \frac{1}{N}\sum_{i=1}^{N} \nabla_w f_c(x_i, w)$, where $f_c$ denotes the $c$-th logit. Previous work (fort2019emergent; papyan2019measurements) has shown that the span of the logit gradient centroids approximately tracks an important local quantity: the span of the top directions of maximal Hessian curvature (see also Section B.2). Thus the span of the logit gradient centroids describes the orientation of the walls of the basins shown schematically in Fig. 1AB, and two trained networks with highly dissimilar logit gradient centroids likely lie in differently oriented basins. In order to evaluate how the logit gradient centroids compare at two weight configurations $w_1$ and $w_2$, we compute the average cosine similarity
$$\frac{1}{C}\sum_{c=1}^{C} \frac{\bar g_c(w_1) \cdot \bar g_c(w_2)}{\lVert \bar g_c(w_1) \rVert\, \lVert \bar g_c(w_2) \rVert},$$
which we refer to as the logit gradient centroid alignment.
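The alignment can be sketched numerically, assuming per-example logit gradients are available as an array of shape (examples, classes, parameters); the shapes below are illustrative:

```python
import numpy as np

def logit_gradient_centroids(grads):
    """grads: (n_examples, n_classes, n_params) per-example logit gradients.
    Returns per-class centroids of shape (n_classes, n_params)."""
    return grads.mean(axis=0)

def centroid_alignment(grads_a, grads_b):
    """Average cosine similarity between matching class centroids of two networks."""
    ca = logit_gradient_centroids(grads_a)
    cb = logit_gradient_centroids(grads_b)
    cos = np.sum(ca * cb, axis=1) / (
        np.linalg.norm(ca, axis=1) * np.linalg.norm(cb, axis=1)
    )
    return float(cos.mean())

# Identical gradients give perfect alignment of 1; negated gradients give -1.
rng = np.random.default_rng(0)
g = rng.standard_normal((32, 10, 500))
```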
B.2 Logit gradient centroids and the top Hessian eigenvectors
Here we note qualitative relations between the span of the logit gradient centroids and the principal Hessian subspace. First, without loss of generality, consider a single data point $(x, y)$ and mean squared error (MSE) loss, with empirical risk term $\ell(w) = \frac{1}{2}\lVert f(x, w) - y \rVert^2$, where $f(x, w)$ is the vector of logits (as defined in Section 2). Then
$$\nabla_w^2 \ell(w) = \sum_{c=1}^{C} \nabla_w f_c(x, w)\, \nabla_w f_c(x, w)^\top + \sum_{c=1}^{C} \big(f_c(x, w) - y_c\big)\, \nabla_w^2 f_c(x, w).$$
If $f(x, w) = y$, then each factor $f_c(x, w) - y_c$ is zero and the second term vanishes. Similarly, for a perfectly fit training set of size $N$, the Hessian of the MSE loss is equal to $H = \frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} \nabla_w f_c(x_i, w)\, \nabla_w f_c(x_i, w)^\top$.
Since the logit gradient centroids are defined as $\bar g_c = \frac{1}{N}\sum_{i=1}^{N} \nabla_w f_c(x_i, w)$, we decompose $\nabla_w f_c(x_i, w) = \bar g_c + \delta_{c,i}$, where the residuals satisfy $\frac{1}{N}\sum_{i=1}^{N} \delta_{c,i} = 0$ for each $c$. This decomposition of the logit gradients has been previously studied in (fort2019emergent; papyan2019measurements). Then the cross terms cancel and $H = \sum_c \bar g_c \bar g_c^\top + \frac{1}{N}\sum_c \sum_i \delta_{c,i}\, \delta_{c,i}^\top$.
According to our empirical results as well as other literature (papyan2019measurements; fort2019emergent), the logit gradient centroids are mutually almost orthogonal and the centroid term dominates the top of the spectrum. Therefore, $H \approx \sum_c \bar g_c \bar g_c^\top$. For mutually orthogonal gradient centroids $\bar g_c$, this amounts to a singular value decomposition with non-zero singular values $\lVert \bar g_c \rVert^2$ associated with singular vectors $\bar g_c / \lVert \bar g_c \rVert$.
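As a sanity check of this spectral picture, the following numpy sketch builds mutually orthogonal synthetic centroids (dimensions are illustrative) and confirms that the sum of their outer products has the squared centroid norms as its non-zero eigenvalues:

```python
import numpy as np

# C mutually orthogonal centroids in a P-dimensional weight space.
C, P = 4, 50
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((P, C)))  # orthonormal directions
lengths = np.array([4.0, 3.0, 2.0, 1.0])
G = Q * lengths                                    # columns are centroids g_c

# H = sum_c g_c g_c^T: eigenvalues should be ||g_c||^2, the rest zero.
H = G @ G.T
eigvals = np.linalg.eigvalsh(H)[::-1]              # sorted descending
```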
While the relative lengths of the centroids $\bar g_c$ can change with training time, the empirically observed stability of their directions constrains the Hessian eigenvectors associated with the largest eigenvalues to lie primarily within the span of $\{\bar g_c\}_{c=1}^{C}$. This subspace of dimension $C$ ($C = 10$ for CIFAR-10, $C = 100$ for CIFAR-100) has significantly lower dimension than the weight space of a typical network.
When evaluating the cosine similarity of the logit gradient centroids as in Eq. 5, we are therefore approximately estimating the overlap between the low-dimensional subspaces to which the sharpest directions of the Hessian are constrained in the two networks.
B.3 Escape threshold
Consider iterative optimizers that at each iteration $t$ perform an update on the weights of the form
$$w_{t+1} = w_t + \Delta w_t(B_t),$$
where $B_t$ is a minibatch, i.e., a random subset of the training set.
Let $H_t$ be the Hessian of the empirical loss $\hat L$ at $w_t$. Then, by a second order Taylor expansion of $\hat L$ around $w_t$, we have
$$\hat L(w_{t+1}) - \hat L(w_t) \approx \nabla \hat L(w_t)^\top \Delta w_t + \tfrac{1}{2}\, \Delta w_t^\top H_t\, \Delta w_t.$$
The loss after one iteration decreases if this difference is negative. Under the second order approximation, the condition for non-increasing loss is equivalent to
$$\nabla \hat L(w_t)^\top \Delta w_t + \tfrac{1}{2}\, \Delta w_t^\top H_t\, \Delta w_t \le 0.$$
We refer to the left-hand side of Eq. 10 as the escape threshold. If the escape threshold is below zero, then the trajectory will be descending in the quadratic basin, under the assumption that the local quadratic approximation is accurate. If the escape threshold is positive, the loss will increase or the trajectory will escape the quadratic basin.
The gradient descent (GD) update is $\Delta w_t = -\eta\, \nabla \hat L(w_t)$. Combining this with the bound $\nabla \hat L(w_t)^\top H_t\, \nabla \hat L(w_t) \le \lVert H_t \rVert_2\, \lVert \nabla \hat L(w_t) \rVert^2$, where $\lVert H_t \rVert_2$ is the spectral norm of $H_t$, Eq. 10 is satisfied whenever $\lVert H_t \rVert_2 \le 2/\eta$. We refer to $\lVert H_t \rVert_2 - 2/\eta$ as the escape threshold for GD.
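A small sketch of the GD escape threshold, taking it as the gap between the Hessian spectral norm and $2/\eta$ (a sign convention assumed here for illustration), together with a 1D quadratic where the stability boundary is exact:

```python
import numpy as np

def gd_escape_threshold(eta, hessian):
    """Escape threshold for gradient descent: ||H||_2 - 2/eta.
    Negative => the quadratic approximation guarantees non-increasing loss;
    positive => GD can increase the loss and escape the quadratic basin."""
    spectral_norm = np.linalg.norm(hessian, ord=2)
    return float(spectral_norm - 2.0 / eta)

# 1D quadratic L(w) = 0.5 * h * w^2: GD iterates w <- (1 - eta*h) w, which
# diverges exactly when eta*h > 2, matching the sign of the threshold.
h = 5.0
H = np.array([[h]])

eta_stable = 0.1   # threshold 5 - 20 < 0: iterates contract
w = 1.0
for _ in range(50):
    w *= (1.0 - eta_stable * h)
```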
Appendix C Additional results
Here we present additional experiments, similar to those in Figs. 4, 4 and 4, that include measurements of logit gradient centroid clustering, Hessian spectral norms, and escape threshold analysis (for the networks without batch norm).
C.1 Diverse metrics for loss landscape and training are highly correlated
We compute all of the above metrics for SOTA networks in Figs. 17, 17, 14, 4 and 4 (see Appendix D for details of training and hyperparameters). The top rows of Figs. 17, 17 and 14 describe the dynamics of parent training. From left to right: the test and training errors drop, the top Hessian eigenvalue decreases over time, and all three distances (ReLU pattern, mean logit gradient, and kernel distance), computed on a parent run between pairs of training epochs, reveal a rapid change around a very early point at about 5 epochs, followed by a slow freezing of all three quantities. In particular, there is significant kernel learning. Moreover, there is a period of time early on where the top Hessian eigenvalue satisfies $\lambda_{\max} > 2/\eta$, where $\eta$ is the learning rate. This large learning rate condition would be necessary for gradient descent to escape a quadratic minimum (Fig. 1E, top).
The bottom row computes various distances between pairs of children at the end of training as a function of the common time at which the pairs of children were spawned. These plots indicate that at an extremely early spawn epoch (x-axis), the basin fate of the two children is sealed by their parent. Beyond this epoch, the two children end up in the same basin, as evidenced by the lack of a loss barrier along a linear path between them, and are much closer to each other, as measured by distances (from left to right) in weight space, ReLU pattern, logit gradient, function space, and kernel space. In contrast, before this early spawn epoch, the final basin choice of the children displays a chaotic sensitivity to SGD steps, as evidenced by a large loss barrier and larger distances. Intriguingly, significant, albeit slow, kernel learning continues after basin fate selection.
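The loss-barrier measurement between two children can be sketched generically; here `loss_fn`, the toy convex loss, and all shapes are illustrative stand-ins for evaluating the training loss of interpolated networks:

```python
import numpy as np

def barrier_along_path(loss_fn, w1, w2, n_points=25):
    """Maximum excess loss along the straight line between two weight vectors,
    relative to the linear interpolation of the endpoint losses. A value near
    zero indicates the two solutions lie in the same linearly connected basin."""
    ts = np.linspace(0.0, 1.0, n_points)
    path = np.array([loss_fn((1 - t) * w1 + t * w2) for t in ts])
    ends = (1 - ts) * loss_fn(w1) + ts * loss_fn(w2)
    return float(np.max(path - ends))

# Convex toy loss: any two points are linearly connected, so no barrier arises.
loss = lambda w: float(np.sum(w ** 2))
rng = np.random.default_rng(0)
w1, w2 = rng.standard_normal(10), rng.standard_normal(10)
```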
C.2 Further discussion
The integrative analysis of diverse measurements during neural network training (Figs. 17, 17, 17, 14, 4 and 4) and the results of linearized training (Fig. 8) reveal a uniform and striking story. First, very early in the training process, between 1% and 10% of training time, the final basin fate of the neural network is determined. After this point, the large scale motion of the neural network is no longer influenced by SGD noise, and networks trained with different SGD noise converge to low loss points in the same (linearly connected) basin. This is indicated by the fact that children spawned after this point are linearly connected through low error networks: the error barrier between spawned children goes below 0 at around 10 epochs for all networks. Additionally, the escape threshold and spectral norm analysis reveals that, after this point, the learning rate is small enough to keep the network within the local quadratic approximation of the loss surface, further supporting the hypothesis that the network does not leave the basin. Once the network is in the basin, we find that various metrics of network distance start decreasing.
Our linearized training analysis (Figs. 19 and 19) reveals further interesting features of this early period of training. Very early in the training process, before even an epoch is completed, the data-dependent NTK, the first order approximation of the neural network, rapidly starts learning useful information about the data, as evidenced by the fall-off of the green lines in Figs. 19, 19 and 8. This occurs uniformly across networks and datasets. When a kernel machine is trained with these data-dependent features, it performs significantly better than the NTK at random initialization. In fact, less than halfway through training, the data-dependent neural tangent kernel machine performs nearly as well as the full network.
These observations have two important implications. First, while the NTK at random initialization may not be very representative of training finite-sized networks, the data-dependent NTK obtained from a small amount of training is, across a range of networks. Second, the features built up by the NTK relatively early in the training process are sufficient for achieving low errors, competitive with the full non-linear networks.
While deep learning research often focuses on improving accuracy towards the end of training, our results show that the early phase of training is important for determining the final fate of the network. A better understanding of this phase may provide us with tools to diagnose and improve networks early on in the training process, thus decreasing the cost of training neural networks.
Grey dashed lines in F–J represent a straight line between the y-axis values at epoch 0 and the final epoch. The x-axis in F–J indicates the spawning epoch. Dashed lines in C–E mark epochs when the learning rate was dropped. In A, the 'child' line represents the final test error of a child spawned at the epoch indicated on the x-axis. Fig. 14B, Fig. 14B, and Fig. 14B depict the spectral norm of the Hessian compared to $2/\eta$, where $\eta$ is the learning rate. These are additional results extending Figs. 4, 4 and 4.
Grey dashed lines in F–J represent a straight line between the y-axis values at epoch 0 and the final epoch. The x-axis in F–J indicates the spawning epoch. Dashed lines in C–E mark epochs when the learning rate was dropped. In A, the 'child' line represents the final test error of a child spawned at the epoch indicated on the x-axis. Fig. 17B depicts the spectral norm of the Hessian compared to $2/\eta$, where $\eta$ is the learning rate. Fig. 17B and Fig. 17B look at the escape threshold (see Section B.3). These are additional results extending Figs. 14, 14, 14, 4, 4 and 4.
C.3 Nonlinear advantage and function space distance without batch norm
In Fig. 6 we show the function space distance between children runs and its correlation with the error barrier for a network with batch normalization. The nonlinear advantage with batch normalization is also shown in Fig. 9. While state-of-the-art models use batch normalization, and we therefore focused on it in the main text, we wanted to confirm that the same phenomena are observed for networks without batch normalization. Fig. 20 shows the function space results, and Fig. 21 shows the nonlinear advantage, for ResNet20v1 without batch normalization trained on CIFAR-10 with Adam at learning rate 1e-3 (see Table 1).
Appendix D Experimental details
D.1 Architectures
SimpleCNN: SimpleCNN is a 6-layer fully convolutional neural network. Each convolution has a 3x3 kernel, stride 1, and a bias. The layers have 32, 32, 64, 64, 128, and 128 channels from the first to the last layer. The weights are initialized using Kaiming initialization (he2015delving) and the biases are initialized to 0. There is a 2D max pooling with a 2x2 kernel and stride 2 after layers 2, 4, and 6. Layer 6 is followed by a 2D global average pool, which results in a 1x128 feature vector from which the classes are linearly predicted.
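As a quick check of the architecture's shape arithmetic (assuming 'same' padding for the convolutions, which the text does not specify), the spatial sizes on a 32x32 CIFAR input work out as:

```python
# SimpleCNN shape bookkeeping: 3x3 convs with stride 1 and (assumed) 'same'
# padding preserve the spatial size; each 2x2/stride-2 max pool halves it;
# the global average pool collapses it, leaving a 128-dim feature vector.
channels = [32, 32, 64, 64, 128, 128]
pool_after = {2, 4, 6}  # max pool follows layers 2, 4, and 6 (1-indexed)

size = 32
for layer in range(1, len(channels) + 1):
    # conv: spatial size unchanged under the 'same'-padding assumption
    if layer in pool_after:
        size //= 2  # 2x2 max pool, stride 2

feature_dim = channels[-1]  # global average pool -> 1 x 128
```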
ResNet20: We use the ResNet20 used with CIFAR-10/100 data in the original paper (he2015deep).
ResNet20 without batchnorm (RN no BN): Same as ResNet20 but with batchnorm turned off.
WideResNet-16-4 (WRN-16-4): We use the WideResNet-16-4 described in (zagoruyko2016wide) (16 layers, widen factor 4, no dropout).
D.2 Training details
See Table 1 for training details and hyperparameters.
| Network | Dataset | Opt | LR | Mom | WD | LR Decay | Decay Epochs | Total Epochs |
|---|---|---|---|---|---|---|---|---|
| ResNet20 | CIFAR-100 | SGD | 1e-1 | 0.9 | 1e-4 | 0.1 | 60, 120, 160 | 200 |
| RN20 no BN | CIFAR-10 | Adam | 1e-3 | - | 0 | 0.1 | 80, 120 | 160 |
| RN20 no BN | CIFAR-100 | Adam | 1e-3 | - | 0 | 0.1 | 80, 120 | 160 |
D.3 Extended training time and learning rate ablations
In Figs. 9 and 8, we train networks with a lower learning rate to compare to linearized training and measure the nonlinear advantage. This is done as follows: we first train a parent network at a constant learning rate of 0.1. At certain epochs (the x-axis of Fig. 8 and the labels of Fig. 9) we spawn a child network and train it with a learning rate of 0.001 and independent SGD noise until convergence (this learning rate is chosen as the smallest at which the network converges in a reasonable amount of time, 1000 epochs in our case). The learning-rate-dropped accuracy is then the final accuracy to which the child converges. For stability, we repeat this process and average the two independent runs. The linearized training is performed using the neuraltangents2020 package and implemented in JAX (jax2018github). It trains for an additional 200 epochs at learning rate 0.001, chosen based on a small-scale grid search for linearized training specifically. The nonlinear advantage is computed as the difference in final test error between the low learning rate child network and the linearized neural network.
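The nonlinear-advantage bookkeeping can be sketched as follows (the sign convention, linearized error minus averaged child error, and the example numbers are assumptions for illustration):

```python
def nonlinear_advantage(linearized_test_error, child_test_errors):
    """Nonlinear advantage at a given spawn epoch: how much lower the averaged
    low-learning-rate children's test error is than the linearized model's.
    child_test_errors: errors of the independent low-LR children runs."""
    child_mean = sum(child_test_errors) / len(child_test_errors)
    return linearized_test_error - child_mean

# E.g., linearized training reaching 18% error vs. children averaging 10%
# gives an 8-percentage-point advantage for the nonlinear network.
adv = nonlinear_advantage(0.18, [0.11, 0.09])
```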
D.4 Linearized training details
The linearized training was performed using the Taylor expansion tools described in novak2019neural and implemented in neuraltangents2020. We trained for 200 epochs using SGD with momentum, implementing the training loop in JAX (jax2018github). The learning rate used was 0.001, chosen in the end from a small-scale grid search for the best-performing learning rate.