Finite Depth and Width Corrections to the Neural Tangent Kernel
We prove the precise scaling, at finite depth and width, for the mean and variance of the neural tangent kernel (NTK) in a randomly initialized ReLU network. The standard deviation is exponential in the ratio of network depth to width. Thus, even in the limit of infinite overparameterization, the NTK is not deterministic if depth and width simultaneously tend to infinity. Moreover, we prove that for such deep and wide networks, the NTK has a non-trivial evolution during training by showing that the mean of its first SGD update is also exponential in the ratio of network depth to width. This is in sharp contrast to the regime where depth is fixed and network width is very large. Our results suggest that, unlike relatively shallow and wide networks, deep and wide ReLU networks are capable of learning data-dependent features even in the so-called lazy training regime.
B. Hanin: email@example.com; firstname.lastname@example.org
Modern neural networks are typically overparameterized: they have many more parameters than the size of the datasets on which they are trained. That some setting of parameters in such networks can interpolate the data is therefore not surprising. But it is a priori unexpected not only that such interpolating parameter values can be found by stochastic gradient descent (SGD) on the highly non-convex empirical risk, but also that the resulting network function not only interpolates but also extrapolates to unseen data. In an overparameterized neural network, individual parameters can be difficult to interpret, and one way to understand training is to rewrite the SGD updates
$$\Delta \theta_p = -\lambda \frac{\partial \mathcal L}{\partial \theta_p}, \qquad p = 1, \ldots, P,$$
of trainable parameters $\theta = \{\theta_p\}_{p=1}^{P}$ with a loss $\mathcal L$ and learning rate $\lambda$ as kernel gradient descent updates for the values $\mathcal N(x)$ of the function computed by the network:
$$\Delta \mathcal N(x) = -\lambda \big\langle K_{\mathcal N}(x, \cdot),\, \nabla_{\mathcal N} \mathcal L \big\rangle_{\mathcal B}. \qquad (1)$$
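Here $\mathcal B$ is the current batch, the inner product is the empirical inner product over $\mathcal B$, and $K_{\mathcal N}$ is the neural tangent kernel (NTK):
$$K_{\mathcal N}(x, x') = \sum_{p=1}^{P} \frac{\partial \mathcal N}{\partial \theta_p}(x)\, \frac{\partial \mathcal N}{\partial \theta_p}(x').$$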
Relation (1) is valid to first order in $\lambda$. It translates between two ways of thinking about the difficulty of neural network optimization:
The parameter space view, where the loss $\mathcal L$, a complicated function of $\theta$, is minimized using gradient descent with respect to a simple (Euclidean) metric;
The function space view, where the loss $\mathcal L$, which is a simple function of the network mapping $x \mapsto \mathcal N(x)$, is minimized over the manifold of all functions representable by the architecture of $\mathcal N$ using gradient descent with respect to a potentially complicated Riemannian metric $K_{\mathcal N}$ on this manifold.
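As a concrete check of relation (1), the following plain-Python sketch (standard library only) builds a small fully connected ReLU network with a linear scalar output, computes the on-diagonal NTK $K_{\mathcal N}(x,x) = \sum_p (\partial \mathcal N(x)/\partial \theta_p)^2$ by backpropagation, and verifies that one small SGD step on the square loss $(\mathcal N(x) - y)^2/2$ with the batch $\{(x, y)\}$ changes $\mathcal N(x)$ by approximately $-\lambda K_{\mathcal N}(x,x)(\mathcal N(x) - y)$. The He-style Gaussian initialization, the layer widths, the input, the target, and all function names (`init_params`, `gradients`, `ntk_diag`, `sgd_step`, `check_first_order`) are illustrative assumptions, not taken from the paper.

```python
import math
import random


def init_params(widths, rng):
    """He-style Gaussian init (variance 2 / fan_in) and zero biases.

    widths = [n_0, n_1, ..., n_d, 1]; this initialization is an assumption.
    """
    Ws = [[[rng.gauss(0.0, math.sqrt(2.0 / a)) for _ in range(a)] for _ in range(b)]
          for a, b in zip(widths[:-1], widths[1:])]
    bs = [[0.0] * b for b in widths[1:]]
    return Ws, bs


def forward(Ws, bs, x):
    """Return post-activations (acts[0] = x) and pre-activations of every layer."""
    acts, pres = [list(x)], []
    for j, (W, b) in enumerate(zip(Ws, bs)):
        z = [sum(w * h for w, h in zip(row, acts[-1])) + bi for row, bi in zip(W, b)]
        pres.append(z)
        # ReLU on hidden layers, linear output layer
        acts.append(z if j == len(Ws) - 1 else [max(v, 0.0) for v in z])
    return acts, pres


def gradients(Ws, bs, x):
    """Backpropagation: return N(x), dN/dW and dN/db for a scalar output."""
    acts, pres = forward(Ws, bs, x)
    gW, gb, delta = [None] * len(Ws), [None] * len(Ws), [1.0]
    for j in range(len(Ws) - 1, -1, -1):
        # delta[i] = dN / d(pre-activation i of layer j + 1)
        gW[j] = [[d * a for a in acts[j]] for d in delta]
        gb[j] = list(delta)
        if j > 0:
            back = [sum(Ws[j][i][k] * delta[i] for i in range(len(delta)))
                    for k in range(len(acts[j]))]
            delta = [bk if pres[j - 1][k] > 0 else 0.0 for k, bk in enumerate(back)]
    return acts[-1][0], gW, gb


def ntk_diag(Ws, bs, x):
    """On-diagonal NTK: K_N(x, x) = sum over all parameters of (dN/dtheta_p)^2."""
    _, gW, gb = gradients(Ws, bs, x)
    return (sum(g * g for M in gW for row in M for g in row)
            + sum(g * g for v in gb for g in v))


def sgd_step(Ws, bs, x, y, lam):
    """One SGD step on the square loss (N(x) - y)^2 / 2 with batch {(x, y)}."""
    out, gW, gb = gradients(Ws, bs, x)
    r = out - y  # derivative of the loss with respect to N(x)
    newW = [[[w - lam * r * g for w, g in zip(row, grow)] for row, grow in zip(W, G)]
            for W, G in zip(Ws, gW)]
    newb = [[b - lam * r * g for b, g in zip(v, gv)] for v, gv in zip(bs, gb)]
    return newW, newb


def check_first_order(lam=1e-6, seed=0):
    """Compare the actual change in N(x) with the prediction of relation (1)."""
    Ws, bs = init_params([4, 8, 8, 8, 1], random.Random(seed))
    x, y = [1.0, 0.5, -0.3, 0.2], 1.0
    before = forward(Ws, bs, x)[0][-1][0]
    newW, newb = sgd_step(Ws, bs, x, y, lam)
    after = forward(newW, newb, x)[0][-1][0]
    predicted = -lam * ntk_diag(Ws, bs, x) * (before - y)
    return after - before, predicted


actual, predicted = check_first_order()
print(actual, predicted)
```

Because (1) holds only to first order in $\lambda$, the gap between `actual` and `predicted` should shrink roughly like $\lambda^2$ as the learning rate is decreased.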
A remarkable observation of Jacot et al. is that $K_{\mathcal N}$ simplifies dramatically when the network depth is fixed and its width tends to infinity. In this setting, by the universal approximation theorem [7, 13], the manifold of functions representable by $\mathcal N$ fills out any (reasonable) ambient linear space of functions. Their results then show that the kernel $K_{\mathcal N}$ in this limit is frozen throughout training to the infinite width limit of its average $\mathbb E[K_{\mathcal N}]$ at initialization, which depends on the depth and non-linearity of $\mathcal N$ but not on the dataset.
This reduction of neural network SGD to kernel gradient descent for a fixed kernel can be viewed as two separate statements. First, at initialization, the distribution of $K_{\mathcal N}$ converges in the infinite width limit to the delta function on the infinite width limit of its mean $\mathbb E[K_{\mathcal N}]$. Second, the infinite width limit of SGD dynamics in function space is kernel gradient descent for this limiting mean kernel, for any fixed number of SGD iterations. This shows that as long as the loss is well-behaved with respect to the network outputs and the limiting kernel is non-degenerate in the subspace of function space given by values on inputs from the dataset, SGD for infinitely wide networks will converge with probability $1$ to a minimum of the loss. Further, kernel method-based theorems show that even in this infinitely overparameterized regime neural networks will have non-vacuous guarantees on generalization.
But replacing neural network training by gradient descent for a fixed kernel in function space is also not completely satisfactory, for several reasons. First, it shows that no feature learning occurs during training for infinitely wide networks, in the sense that the kernel (and hence its associated feature map) is data-independent. In fact, empirically, networks with finite but large width trained with initially large learning rates often outperform NTK predictions at infinite width. One interpretation is that, at finite width, $K_{\mathcal N}$ evolves through training, learning data-dependent features not captured by the infinite width limit of its mean at initialization. In part for such reasons, it is important to study, both empirically and theoretically, finite width corrections to $K_{\mathcal N}$. Another interpretation is that the specific NTK scaling of weights at initialization [4, 5, 17, 18, 19, 20] and the implicit small learning rate limit obscure important aspects of SGD dynamics. Second, even in the infinite width limit, although the limiting kernel is deterministic, it has no simple analytical formula for deep networks, since it is defined via a layer by layer recursion. In particular, the exact dependence of the limiting kernel on network depth is not well understood, even in the infinite width limit.
Moreover, the joint statistical effects of depth and width on $K_{\mathcal N}$ in finite size networks remain unclear, and the purpose of this article is to shed light on the simultaneous effects of depth and width on $K_{\mathcal N}$ for finite but large widths $n$ and any depth $d$. Our results apply to fully connected ReLU networks at initialization, for which we will show:
In contrast to the regime in which the depth $d$ is fixed but the width $n$ is large, $K_{\mathcal N}$ is not approximately deterministic at initialization so long as $d/n$ is bounded away from $0$. Specifically, for a fixed input $x$ the normalized on-diagonal second moment of $K_{\mathcal N}$ satisfies
Thus, when $d/n$ is bounded away from $0$, even when both $n$ and $d$ are large, the standard deviation of $K_{\mathcal N}(x,x)$ is at least as large as its mean, showing that its distribution at initialization is not close to a delta function. See Theorem 1.
Moreover, when $\mathcal L$ is the square loss, the average of the SGD update $\Delta K_{\mathcal N}(x,x)$ to $K_{\mathcal N}(x,x)$ from a batch of size one containing $x$ satisfies
where is the input dimension. Therefore, if the NTK will have the potential to evolve in a data-dependent way. Moreover, if is comparable to and then it is possible that this evolution will have a well-defined expansion in See Theorem 2.
In both statements above, means is bounded above and below by universal constants. We emphasize that our results hold at finite $n$ and $d$, and the implicit constants in both and in the error terms are independent of $n$ and $d$. Moreover, our precise results, stated in §2 below, hold for networks with variable layer widths. We have denoted network width by $n$ only for the sake of exposition. The appropriate generalization of $d/n$ to networks with varying layer widths is the parameter
which in light of the estimates in (1) and (2) plays the role of an inverse temperature.
1.1. Prior Work
A number of articles [3, 8, 15, 22] have followed up on the original NTK work . Related in spirit to our results is the article , which uses Feynman diagrams to study finite width corrections to general correlation functions (and in particular the NTK). The most complete results obtained in  are for deep linear networks, but a number of estimates hold for general non-linear networks as well. The results there, like those in essentially all previous work, fix the depth and let the layer widths tend to infinity. The results here and in [9, 10, 11], however, do not treat the depth as a constant, suggesting that the expansions (e.g. in ) can be promoted to expansions. Also, the sum-over-path approach to studying correlation functions in randomly initialized ReLU nets was previously taken up for the forward pass in  and for the backward pass in  and .
1.2. Implications and Future Work
Taken together, (1) and (2) above (as well as Theorems 1 and 2) show that in fully connected ReLU nets that are both deep and wide, the neural tangent kernel is genuinely stochastic and enjoys a non-trivial evolution during training. This suggests that in the overparameterized limit with , the kernel may learn data-dependent features. Moreover, our results show that the fluctuations of both the NTK and its time derivative are exponential in the inverse temperature.
It would be interesting to obtain an exact description of its statistics at initialization and to describe the law of its trajectory during training. Assuming this trajectory turns out to be data-dependent, our results suggest that the double descent curve [1, 2, 21] that trades off complexity vs. generalization error may display significantly different behaviors depending on the mode of network overparameterization.
However, it is also important to point out that the results in [9, 10, 11] show that, at least for fully connected ReLU nets, gradient-based training is not numerically stable unless is relatively small (but not necessarily zero). Thus, we conjecture that there may exist a “weak feature learning” NTK regime in which network depth and width are both large but . In such a regime, the network will be stable enough to train but flexible enough to learn data-dependent features. In the language of  one might say this regime displays weak lazy training in which the model can still be described by a stochastic positive definite kernel whose fluctuations can interact with data.
Finally, it is an interesting question to what extent our results hold for non-linearities other than ReLU and for network architectures other than fully connected (e.g. convolutional and residual). Typical ConvNets, for instance, are significantly wider than they are deep, and we leave it to future work to adapt the techniques from the present article to these more general settings.
2. Formal Statement of Results
Consider a ReLU network with input dimension , hidden layer widths , and output dimension . We will assume that the output layer of is linear and initialize the biases in to zero. Therefore, for any input the network computes given by
and is a fixed probability measure on that we assume has a density with respect to Lebesgue measure and satisfies:
The three assumptions in (4) hold for virtually all standard network initialization schemes. The on-diagonal NTK is
We emphasize that although we have initialized the biases to zero, we have not removed them from the list of trainable parameters. Our first result is the following:
Theorem 1 (Mean and Variance of NTK on Diagonal at Init).
Moreover, we have that is bounded above and below by universal constants times
times a multiplicative error , where means is bounded above and below by universal constants times In particular, if all the hidden layer widths are equal (i.e. , for ), we have
This result shows that in the deep and wide double scaling limit
the NTK does not converge to a constant in probability. This is in contrast to the wide and shallow regime, in which the widths tend to infinity and the depth is fixed.
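The failure of concentration can be explored numerically. The following hedged Monte Carlo sketch samples independent ReLU networks with equal hidden widths $n$ and depth $d$ (He-style Gaussian initialization and zero biases, an illustrative assumption), computes $K_{\mathcal N}(x,x)$ for a fixed input, and estimates the normalized second moment $\mathbb E[K_{\mathcal N}(x,x)^2] / \mathbb E[K_{\mathcal N}(x,x)]^2$. One expects the estimate to stay near $1$ when $d/n$ is small and to grow as $d/n$ grows. The sketch does not attempt to recover the constants in Theorem 1, and for large $d/n$ the heavy right tail of $K_{\mathcal N}$ makes such estimates noisy at modest sample sizes. The names `ntk_diag` and `second_moment_ratio` are hypothetical.

```python
import math
import random


def ntk_diag(widths, x, rng):
    """Sample one network (He-style init, zero biases) and return K_N(x, x)."""
    Ws = [[[rng.gauss(0.0, math.sqrt(2.0 / a)) for _ in range(a)] for _ in range(b)]
          for a, b in zip(widths[:-1], widths[1:])]
    # Forward pass; biases are zero, so they are omitted from the arithmetic.
    acts, pres = [list(x)], []
    for j, W in enumerate(Ws):
        z = [sum(w * h for w, h in zip(row, acts[-1])) for row in W]
        pres.append(z)
        acts.append(z if j == len(Ws) - 1 else [max(v, 0.0) for v in z])
    # Backward pass: delta = dN/d(pre-activation), accumulated into K_N(x, x).
    k, delta = 0.0, [1.0]
    for j in range(len(Ws) - 1, -1, -1):
        dd = sum(d * d for d in delta)
        # weights of layer j contribute dd * |acts[j]|^2, its biases contribute dd
        k += dd * sum(a * a for a in acts[j]) + dd
        if j > 0:
            back = [sum(Ws[j][i][m] * delta[i] for i in range(len(delta)))
                    for m in range(len(acts[j]))]
            delta = [bm if pres[j - 1][m] > 0 else 0.0 for m, bm in enumerate(back)]
    return k


def second_moment_ratio(depth, width, n_samples=200, seed=0):
    """Monte Carlo estimate of E[K_N(x,x)^2] / E[K_N(x,x)]^2 at initialization."""
    rng = random.Random(seed)
    widths = [width] * (depth + 1) + [1]  # n_0 = n_1 = ... = n_depth = width
    x = [1.0] * width
    ks = [ntk_diag(widths, x, rng) for _ in range(n_samples)]
    mean = sum(ks) / n_samples
    return sum(k * k for k in ks) / n_samples / mean ** 2


print(second_moment_ratio(depth=1, width=8), second_moment_ratio(depth=8, width=8))
```

The two calls compare $d/n = 1/8$ with $d/n = 1$ at the same width. By the Cauchy-Schwarz inequality the empirical ratio is always at least $1$.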
Our next result shows that when the loss is the square loss, the NTK is not frozen during training. To state it, fix an input $x$ to the network and define $\Delta K_{\mathcal N}(x,x)$ to be the update to $K_{\mathcal N}(x,x)$ from one step of SGD with a batch of size one containing $x$ (and learning rate $\lambda$).
Theorem 2 (Mean of Time Derivative of NTK on Diagonal at Init).
We have that is bounded above and below by universal constants times
times a multiplicative error of size , where as in Theorem 1, In particular, if all the hidden layer widths are equal (i.e. , for ), we find
Observe that when is fixed and the pre-factor in front of scales like . This is in keeping with the results from [8, 14]. Moreover, it shows that if grow in any way so that , the update to from the batch at initialization will have mean It is unclear whether this will be true also for larger batches and when the arguments of are not equal. In contrast, if and is bounded away from , and the is proportional to the average update has the same order of magnitude as .
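The mean update in Theorem 2 concerns $\Delta K_{\mathcal N}(x,x)$ after a single SGD step. The following hedged sketch computes this quantity directly for one sampled network: it takes one SGD step on the square loss $(\mathcal N(x) - y)^2/2$ with the batch $\{(x, y)\}$ and returns the resulting change in $K_{\mathcal N}(x,x)$. Since the update is first order in $\lambda$, the ratio $\Delta K_{\mathcal N}(x,x)/\lambda$ should be nearly independent of $\lambda$ when $\lambda$ is small; averaging it over many seeds would give a Monte Carlo estimate of the mean update. The architecture, He-style initialization, input, target $y$, and names such as `ntk_update` are illustrative assumptions.

```python
import math
import random


def init_params(widths, rng):
    # Illustrative He-style init: weights ~ N(0, 2 / fan_in), zero biases.
    Ws = [[[rng.gauss(0.0, math.sqrt(2.0 / a)) for _ in range(a)] for _ in range(b)]
          for a, b in zip(widths[:-1], widths[1:])]
    bs = [[0.0] * b for b in widths[1:]]
    return Ws, bs


def gradients(Ws, bs, x):
    """Return N(x) and the gradients dN/dW, dN/db via backpropagation."""
    acts, pres = [list(x)], []
    for j, (W, b) in enumerate(zip(Ws, bs)):
        z = [sum(w * h for w, h in zip(row, acts[-1])) + bi for row, bi in zip(W, b)]
        pres.append(z)
        acts.append(z if j == len(Ws) - 1 else [max(v, 0.0) for v in z])
    gW, gb, delta = [None] * len(Ws), [None] * len(Ws), [1.0]
    for j in range(len(Ws) - 1, -1, -1):
        gW[j] = [[d * a for a in acts[j]] for d in delta]
        gb[j] = list(delta)
        if j > 0:
            back = [sum(Ws[j][i][k] * delta[i] for i in range(len(delta)))
                    for k in range(len(acts[j]))]
            delta = [bk if pres[j - 1][k] > 0 else 0.0 for k, bk in enumerate(back)]
    return acts[-1][0], gW, gb


def ntk_diag(Ws, bs, x):
    """K_N(x, x) = sum over all parameters of (dN/dtheta_p)^2."""
    _, gW, gb = gradients(Ws, bs, x)
    return (sum(g * g for M in gW for row in M for g in row)
            + sum(g * g for v in gb for g in v))


def ntk_update(lam, seed=0, widths=(4, 8, 8, 8, 1), x=(1.0, 0.5, -0.3, 0.2), y=1.0):
    """Change in K_N(x, x) after one SGD step on the square loss (N(x) - y)^2 / 2."""
    Ws, bs = init_params(list(widths), random.Random(seed))
    out, gW, gb = gradients(Ws, bs, list(x))
    r = out - y  # derivative of the loss with respect to N(x)
    newW = [[[w - lam * r * g for w, g in zip(row, grow)] for row, grow in zip(W, G)]
            for W, G in zip(Ws, gW)]
    newb = [[b - lam * r * g for b, g in zip(v, gv)] for v, gv in zip(bs, gb)]
    return ntk_diag(newW, newb, list(x)) - ntk_diag(Ws, bs, list(x))


slope = ntk_update(1e-6) / 1e-6
print(slope)
```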
2.1. Organization for the Rest of the Article
The remainder of this article is structured as follows. First, in §3 we introduce some notation about paths and edges in the computation graph of the network. This notation will be used in the proofs of Theorems 1 and 2, which are outlined in §4 and particularly in §4.1, where we give an in-depth but informal explanation of our strategy for computing moments of $K_{\mathcal N}$ and its time derivative. Then, §5-§7 give the detailed argument. The computations in §5 explain how to handle the contribution to $K_{\mathcal N}$ and $\Delta K_{\mathcal N}$ coming only from the weights of the network. They are the most technical and we give them in full detail. Then, the discussion in §6 and §7 shows how to adapt the method developed in §5 to treat the contribution of biases and mixed bias-weight terms in $K_{\mathcal N}$ and $\Delta K_{\mathcal N}$. Since the arguments are simpler in these cases, we omit some details and focus only on highlighting the salient differences.
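The split into weight and bias contributions can be seen in a small example. Even though the biases start at zero (so the forward pass is unchanged), $\partial \mathcal N / \partial b$ equals the backpropagated signal, which is generally nonzero; for the output bias it is exactly $1$. The hedged sketch below splits $K_{\mathcal N}(x,x)$ into these two parts. The network sizes, He-style initialization, input, and the names `init_params` and `ntk_split` are illustrative assumptions.

```python
import math
import random


def init_params(widths, rng):
    # Illustrative He-style init: weights ~ N(0, 2 / fan_in), biases set to zero.
    Ws = [[[rng.gauss(0.0, math.sqrt(2.0 / a)) for _ in range(a)] for _ in range(b)]
          for a, b in zip(widths[:-1], widths[1:])]
    bs = [[0.0] * b for b in widths[1:]]
    return Ws, bs


def ntk_split(Ws, bs, x):
    """Return (weight part, bias part) of K_N(x, x) = sum_p (dN/dtheta_p)^2."""
    # Forward pass: ReLU on hidden layers, linear scalar output.
    acts, pres = [list(x)], []
    for j, (W, b) in enumerate(zip(Ws, bs)):
        z = [sum(w * h for w, h in zip(row, acts[-1])) + bi for row, bi in zip(W, b)]
        pres.append(z)
        acts.append(z if j == len(Ws) - 1 else [max(v, 0.0) for v in z])
    # Backward pass: delta = dN/d(pre-activation), starting from the output.
    k_w = k_b = 0.0
    delta = [1.0]
    for j in range(len(Ws) - 1, -1, -1):
        dd = sum(d * d for d in delta)
        k_b += dd                                # dN/db_i = delta_i
        k_w += dd * sum(a * a for a in acts[j])  # dN/dW_ik = delta_i * a_k
        if j > 0:
            back = [sum(Ws[j][i][k] * delta[i] for i in range(len(delta)))
                    for k in range(len(acts[j]))]
            delta = [bk if pres[j - 1][k] > 0 else 0.0 for k, bk in enumerate(back)]
    return k_w, k_b


Ws, bs = init_params([4, 8, 8, 1], random.Random(0))
k_w, k_b = ntk_split(Ws, bs, [1.0, -0.5, 0.3, 0.8])
print(k_w, k_b)
```

The bias part is never smaller than $1$, since the output bias alone contributes $(\partial \mathcal N / \partial b_{\mathrm{out}})^2 = 1$.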
It will also be convenient to denote
Given a ReLU network with input dimension hidden layer widths , and output dimension , its computational graph is a directed multipartite graph whose vertex set is the disjoint union and in which edges are all possible ways of connecting vertices from with vertices from for The vertices are the neurons in , and we will write for and
Definition 1 (Path in the computational graph of ).
Given and , a path in the computational graph of from neuron to neuron is a collection of neurons in layers :
Further, we will write
Given a collection of neurons
we denote by
Note that with this notation, we have for each . For we also set
Correspondingly, we will write
If each edge in the computational graph of is assigned a weight , then associated to a path is a collection of weights:
Definition 2 (Weight of a path in the computational graph of ).
Fix , and let be a path in the computation graph of starting at layer and ending at the output. The weight of this path at a given input to is
is the event that all neurons along are open for the input Here is as in (2).
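Definition 2 can be checked numerically in the bias-free case. The sketch below enumerates every path $(a_0, \ldots, a_L)$ with one neuron per layer (the letter $L$ for the number of weight layers is notation local to this example), forms each path's weight as $x_{a_0}$ times the product of its edge weights times the indicator that every hidden neuron on the path is open at $x$, and confirms that these path weights sum to the network output. The widths, Gaussian weights, input, and function names are illustrative assumptions.

```python
import itertools
import math
import random


def forward_preacts(Ws, x):
    """Zero-bias ReLU net with linear scalar output: pre-activations per layer."""
    pres, h = [], list(x)
    for j, W in enumerate(Ws):
        z = [sum(w * hi for w, hi in zip(row, h)) for row in W]
        pres.append(z)
        h = z if j == len(Ws) - 1 else [max(v, 0.0) for v in z]
    return pres


def all_paths(widths):
    """Every path (a_0, ..., a_L): one neuron index in each layer."""
    return itertools.product(*(range(n) for n in widths))


def path_weight(Ws, pres, x, path):
    """x[a_0] * product of edge weights * 1{every hidden neuron on the path is open}."""
    # Ws[j][out][in] is the weight on the edge (layer j, in) -> (layer j + 1, out)
    if not all(pres[j][path[j + 1]] > 0 for j in range(len(Ws) - 1)):
        return 0.0
    wt = x[path[0]]
    for j in range(len(Ws)):
        wt *= Ws[j][path[j + 1]][path[j]]
    return wt


def compare(seed=0, widths=(3, 4, 4, 1), x=(0.7, -1.2, 0.4)):
    """Return (network output, sum of path weights, number of paths)."""
    rng = random.Random(seed)
    Ws = [[[rng.gauss(0.0, math.sqrt(2.0 / a)) for _ in range(a)] for _ in range(b)]
          for a, b in zip(widths[:-1], widths[1:])]
    pres = forward_preacts(Ws, x)
    paths = list(all_paths(widths))
    total = sum(path_weight(Ws, pres, x, p) for p in paths)
    return pres[-1][0], total, len(paths)


out, total, n_paths = compare()
print(out, total, n_paths)
```

The identity holds because, with zero biases, each ReLU acts as multiplication by the indicator that its pre-activation is positive.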
Next, for an edge in the computational graph of we will write
Definition 3 (Unordered multisets of edges and their endpoints).
to be the unordered multiset of edges in the complete directed bipartite graph oriented from to For every define its left and right endpoints to be
where are unordered multi-sets.
Using this notation, for any collection of neurons and define for each the associated unordered multiset
of edges between layers and that are present in Similarly, we will write
for the set of all possible edge multisets realized by paths in On a number of occasions, we will also write
We will moreover say that for a path an edge in the computational graph of belongs to (written ) if
Finally, for an edge in the computational graph of , we set
for the normalized and unnormalized weights on the edge corresponding to (see (3)).
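Distinct collections of paths can realize the same unordered multiset of edges, which is why a multiplicity (Jacobian) factor appears when sums over paths are rewritten as sums over edge multisets. The following sketch encodes a path as a tuple of neuron indices (one per layer), computes edge multisets with `collections.Counter`, and counts how many ordered $k$-tuples of paths in a tiny network realize each multiset. The helper names and the tiny widths are hypothetical choices for illustration.

```python
import itertools
from collections import Counter


def edges(path):
    """Edges ((j, a_j), (j + 1, a_{j+1})) traversed by the path (a_0, ..., a_L)."""
    return [((j, path[j]), (j + 1, path[j + 1])) for j in range(len(path) - 1)]


def edge_multiset(collection):
    """Unordered multiset of edges realized by a collection of paths (hashable)."""
    counts = Counter()
    for path in collection:
        counts.update(edges(path))
    return frozenset(counts.items())


def multiplicities(widths, k):
    """Map each edge multiset to the number of ordered k-tuples of paths realizing it."""
    paths = list(itertools.product(*(range(n) for n in widths)))
    return Counter(edge_multiset(c) for c in itertools.product(paths, repeat=k))


# Two different pairs of paths that both pass through neuron 0 of layer 1
g1, g2 = (0, 0, 0, 0), (1, 0, 1, 0)
h1, h2 = (0, 0, 1, 0), (1, 0, 0, 0)
same = edge_multiset([g1, g2]) == edge_multiset([h1, h2])  # True
counts = multiplicities([2, 2, 2, 1], 2)
print(same, counts[edge_multiset([g1, g2])])
```

In this tiny network the multiset realized by `(g1, g2)` arises from four ordered pairs: `(g1, g2)`, `(g2, g1)`, `(h1, h2)`, and `(h2, h1)`.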
The proofs of Theorems 1 and 2 are so similar that we will prove them at the same time. In this section and in §4.1 we present an overview of our argument. Then, we carry out the details in §5-§7 below. Fix an input to Recall from (5) that
where we’ve set
and have suppressed the dependence on Similarly, we have
where we have introduced
and have used that the loss on the batch is given by for some target value To prove Theorem 1 we must estimate the following quantities:
To prove Theorem 2, we must control in addition
The most technically involved computations will turn out to be those involving only weights: namely, the terms These terms are controlled by writing each as a sum over certain paths that traverse the network from the input to the output layers. The corresponding results for terms involving the bias will then turn out to be very similar but with paths that start somewhere in the middle of the network (corresponding to which bias term was used to differentiate the network output). The main result about the pure weight contributions to is the following:
Proposition 3 (Pure weight moments for ).
We prove Proposition 3 in §5 below. The proof already contains all the ideas necessary to treat the remaining moments. In §6 and §7 we explain how to modify the proof of Proposition 3 to prove the following two Propositions:
Proposition 4 (Pure bias moments for ).
Finally, with probability
Proposition 5 (Mixed bias-weight moments for ).
The statements in Theorems 1 and 2 that hold for general now follow directly from Propositions 3-5. To see the asymptotics in Theorem 1 when we find after some routine algebra that when , the second moment equals
up to a multiplicative error of When is small, this expression is bounded above and below by a constant times
we find that when is small,
Similarly, if is large but is small, then, still assuming we find is well-approximated by
Before turning to the details of the proof of Propositions 3-5 below, we give an intuitive explanation of the key steps in our sum-over-path analysis of the moments of $K_{\mathcal N}$. Since the proofs of all three Propositions follow a similar structure and Proposition 3 is the most complicated, we will focus on explaining how to obtain the first moments of . The first moment of has a similar flavor. Since the biases are initialized to zero and involves only derivatives with respect to the weights, for the purposes of analyzing the biases play no role. Without the biases, the output $\mathcal N(x)$ of the neural network can be expressed as a weighted sum over paths in the computational graph of the network:
where the weight of a path was defined in (10) and includes both the product of the weights along and the condition that every neuron in is open at . The path begins at some neuron in the input layer of and passes through a neuron in every subsequent layer until ending up at the unique neuron in the output layer (see (7)). Being a product over edge weights in a given path, the derivative of with respect to a weight on an edge of the computational graph of is:
There is a subtle point here: the path weight also involves indicator functions of the events that neurons along the path are open at the input. However, with probability $1$, the derivative of these indicator functions with respect to the edge weight is identically $0$ at that input. The details are in Lemma 11.
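The derivative formula above can be compared against ordinary backpropagation. In the hedged sketch below (bias-free network, illustrative widths and Gaussian weights, hypothetical function names), `path_sum_grad` sums, over all paths through a chosen edge, the path weight with that edge's weight removed, holding the open/closed indicators fixed as the almost-sure vanishing of their derivatives permits; `backprop_weight_grads` computes the same derivatives by the chain rule with ReLU derivative $\mathbf 1[z > 0]$.

```python
import itertools
import math
import random


def forward(Ws, x):
    """Zero-bias ReLU net with linear scalar output: post- and pre-activations."""
    acts, pres = [list(x)], []
    for j, W in enumerate(Ws):
        z = [sum(w * h for w, h in zip(row, acts[-1])) for row in W]
        pres.append(z)
        acts.append(z if j == len(Ws) - 1 else [max(v, 0.0) for v in z])
    return acts, pres


def backprop_weight_grads(Ws, x):
    """dN/dWs[j][i][k] by backpropagation (ReLU derivative = 1[z > 0])."""
    acts, pres = forward(Ws, x)
    grads, delta = [None] * len(Ws), [1.0]
    for j in range(len(Ws) - 1, -1, -1):
        grads[j] = [[d * a for a in acts[j]] for d in delta]
        if j > 0:
            back = [sum(Ws[j][i][k] * delta[i] for i in range(len(delta)))
                    for k in range(len(acts[j]))]
            delta = [bk if pres[j - 1][k] > 0 else 0.0 for k, bk in enumerate(back)]
    return grads


def path_sum_grad(Ws, x, widths, j, i, k):
    """Sum over paths through the edge (layer j, k) -> (layer j + 1, i) of the
    path weight with Ws[j][i][k] removed; open/closed indicators held fixed."""
    _, pres = forward(Ws, x)
    total = 0.0
    for path in itertools.product(*(range(n) for n in widths)):
        if path[j] != k or path[j + 1] != i:
            continue  # path does not use this edge
        if not all(pres[l][path[l + 1]] > 0 for l in range(len(Ws) - 1)):
            continue  # some hidden neuron on the path is closed
        term = x[path[0]]
        for l in range(len(Ws)):
            if l != j:
                term *= Ws[l][path[l + 1]][path[l]]
        total += term
    return total


def max_discrepancy(seed=0, widths=(3, 4, 4, 1), x=(0.7, -1.2, 0.4)):
    """Largest gap between the path-sum and backprop derivatives over all edges."""
    rng = random.Random(seed)
    Ws = [[[rng.gauss(0.0, math.sqrt(2.0 / a)) for _ in range(a)] for _ in range(b)]
          for a, b in zip(widths[:-1], widths[1:])]
    grads = backprop_weight_grads(Ws, x)
    return max(abs(grads[j][i][k] - path_sum_grad(Ws, x, widths, j, i, k))
               for j in range(len(Ws))
               for i in range(widths[j + 1])
               for k in range(widths[j]))


print(max_discrepancy())
```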
Because is a sum of derivatives squared (see (16)), ignoring the dependence on the network input , the kernel roughly takes the form
where the sum is over collections of two paths in the computation graph of $\mathcal N$ and edges in the computational graph of $\mathcal N$ that lie on both (see Lemma 6 for the precise statement). When computing the mean, by the mean zero assumption on the weights (see (4)), the only contribution comes from configurations in which every edge in the computational graph of $\mathcal N$ is traversed by an even number of paths. Since there are exactly two paths, the only contribution is when the two paths are identical, dramatically simplifying the problem. This gives rise to the simple formula for the mean (see (23)). The expression
for the second moment is more complex. It involves sums over four paths in the computational graph of $\mathcal N$, as in the second statement of Lemma 6. Again recalling that the odd moments of the weights vanish, the only collections of paths that contribute to the second moment are those in which every edge in the computational graph of $\mathcal N$ is covered an even number of times:
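The even-cover principle can be verified by brute force on a tiny graph. For illustration only, the sketch below takes i.i.d. uniform $\pm 1$ edge weights and ignores the ReLU indicators. A $\pm 1$ distribution has no density and so does not satisfy the assumptions in (4), but it shares the relevant property that odd moments vanish. The sketch computes $\mathbb E\big[\prod_{\gamma} \prod_{e \in \gamma} W_e\big]$ exactly by enumerating every sign assignment and compares the result with the parity of the edge multiplicities. The helper names and widths are hypothetical.

```python
import itertools
from collections import Counter


def edges(path):
    """Edges ((j, a_j), (j + 1, a_{j+1})) traversed by the path (a_0, ..., a_L)."""
    return [((j, path[j]), (j + 1, path[j + 1])) for j in range(len(path) - 1)]


def all_edges(widths):
    """Every edge of the complete layered graph with the given layer widths."""
    return [((j, a), (j + 1, b))
            for j in range(len(widths) - 1)
            for a in range(widths[j]) for b in range(widths[j + 1])]


def exact_moment(widths, collection):
    """E[prod over paths of prod of edge weights] with i.i.d. uniform +-1 weights.

    Computed exactly by enumerating all 2^(number of edges) sign assignments;
    ReLU indicators are deliberately ignored in this illustration.
    """
    index = {e: i for i, e in enumerate(all_edges(widths))}
    uses = [index[e] for path in collection for e in edges(path)]
    total = 0
    for signs in itertools.product((-1, 1), repeat=len(index)):
        prod = 1
        for i in uses:
            prod *= signs[i]
        total += prod
    return total / 2 ** len(index)


def every_edge_even(collection):
    """True iff every edge is traversed an even number of times."""
    counts = Counter(e for path in collection for e in edges(path))
    return all(c % 2 == 0 for c in counts.values())


widths = [2, 2, 2, 1]
g1, g2 = (0, 0, 0, 0), (1, 0, 1, 0)
h1, h2 = (0, 0, 1, 0), (1, 0, 0, 0)
print(exact_moment(widths, [g1, g2, h1, h2]), every_edge_even([g1, g2, h1, h2]))  # 1.0 True
print(exact_moment(widths, [g1, g2, g1, h1]), every_edge_even([g1, g2, g1, h1]))  # 0.0 False
```

The four paths `g1, g2, h1, h2` are all distinct, yet together they cover every edge exactly twice, which is the kind of interaction among four paths that makes the second moment harder to evaluate than the mean.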
However, there are now several ways the four paths can interact to give such a configuration. It is the combinatorics of these interactions, together with the stipulation that the marked edges belong to particular pairs of paths, which complicates the analysis of We estimate this expectation in several steps:
Notice that depends only on the un-ordered multiset of edges determined by (see (14)). We therefore change variables in the sum from the previous step to find
where counts how many collections of four paths that have the same also have paths pass through and paths pass through. Lemma 8 gives a precise expression for this Jacobian. It turns out, as explained just below Lemma 8, that