Scaling Limits of Wide Neural Networks with Weight Sharing:
Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation
Abstract
Several recent trends in machine learning theory and practice, from the design of state-of-the-art Gaussian Processes to the convergence analysis of deep neural nets (DNNs) under stochastic gradient descent (SGD), have found it fruitful to study wide random neural networks. Central to these approaches are certain scaling limits of such networks. We unify these results by introducing a notion of a straight-line tensor program that can express most neural network computations, and we characterize its scaling limit when its tensors are large and randomized. From our framework follow
(1) the convergence of random neural networks to Gaussian processes for architectures such as recurrent neural networks, convolutional neural networks, residual networks, attention, and any combination thereof, with or without batch normalization;
(2) conditions under which the gradient independence assumption (that weights in backpropagation can be assumed to be independent from weights in the forward pass) leads to correct computation of gradient dynamics, and corrections when it does not;
(3) the convergence of the Neural Tangent Kernel, a recently proposed kernel used to predict training dynamics of neural networks under gradient descent, at initialization for all architectures in (1) without batch normalization. Mathematically, our framework is general enough to rederive classical random matrix results such as the semicircle and the Marchenko-Pastur laws, as well as recent results on neural network Jacobian singular values. We hope our work opens a way toward the design of even stronger Gaussian Processes, initialization schemes to avoid gradient explosion/vanishing, and a deeper understanding of SGD dynamics in modern architectures.
Greg Yang
Microsoft Research AI
gregyang@microsoft.com
1 Introduction
Several recent trends in machine learning theory and practice have found it fruitful to study wide random neural networks, such as neural network inspired Gaussian Processes, signal propagation in DNNs, small learning rate SGD dynamics, and even, in some sense, the celebrated Approximate Message Passing algorithm for compressed sensing. We review these subjects and more in Section 2. All of these works involve some theory that derives, rigorously or semi-rigorously, some scaling limit of a neural network as its width goes to infinity. In this paper, we give a unifying treatment of such scaling limits:
We define a notion of tensor programs which can express most neural network computations, and a natural notion of tensor program scaling that corresponds to increasing width with Glorot initialization (Glorot & Bengio, 2010). Our main theorems characterize the scaling limits in the two most common scenarios, which roughly correspond to DNN inference and backpropagation, as well as in the general tensor program case. They are proved via a Gaussian conditioning technique first used in Bolthausen (2012) for analyzing the TAP equations in spin glass theory.
We obtain corollaries that fully justify semi-rigorous derivations in prior works and strengthen previous results in the different strands of research mentioned above. In the next section we highlight the most important corollaries and discuss others briefly, leaving their details to the appendix.
By standard architecture we mean any DNN architecture that is some composition of multilayer perceptrons (MLPs), recurrent neural networks (RNNs) (e.g., Long Short-Term Memory (LSTM) (Hochreiter & Schmidhuber, 1997) or Gated Recurrent Unit (GRU) (Cho et al., 2014)), skip connections (He et al., 2016; Huang et al., 2016), (self-)attention (Bahdanau et al., 2014; Vaswani et al., 2017), convolution (LeCun et al., 1998, 1999), and/or batch normalization (batchnorm) (Ioffe & Szegedy, 2015). We use readout layer to mean any linear layer converting some hidden states to an output vector. While most of our corollaries are stated for standard architectures, they are typically more general; we highlight only the most relevant cases for a deep learning audience.
2 Related Works and Our Corollaries
We formulate informal versions of our main corollaries and comment on other results inline, marked by a star (★).
2.1 Gaussian Behavior of Wide Neural Networks
In 1995, Neal first discovered the Gaussian Process behavior of wide neural networks. He showed that under certain conditions, a single-layer neural network with random parameters can converge in distribution to a Gaussian process as its width goes to infinity. Later works extended the conditions under which this scaling limit takes place (Williams, 1997; Le Roux & Bengio, 2007; Hazan & Jaakkola, 2015). Recently, Lee et al. (2018) and Matthews et al. (2018) empirically and/or theoretically investigated analogous correspondences for infinite-width, finite-depth deep MLPs, and likewise Novak et al. (2018) for deep convolutional networks. Daniely et al. (2016) also proved similar results in the framework of kernel methods, where they introduced a notion of “computational skeleton,” similar to the tensor programs introduced here, that covers feedforward computation with no weight-sharing (so that, for example, it can express locally connected networks but not convolutional networks). (They claim that dealing with weight-tying is straightforward. It is unclear what they had in mind, however, as there is a significant difference in the scaling behavior of sharing matrix transposes vs. sharing no matrix transposes; see Section 5 and Section 6.)
Many previous works have exploited this DNN-GP correspondence implicitly or explicitly to build new models (Cho & Saul, 2009; Lawrence & Moore, 2007; Damianou & Lawrence, 2013; Wilson et al., 2016b, a; Bradshaw et al., 2017; van der Wilk et al., 2017; Kumar et al., 2018; Blomqvist et al., 2018; Borovykh, 2018). In particular, Lee et al. (2018), Garriga-Alonso et al. (2018), and Novak et al. (2018) directly converted DNNs to GPs using this correspondence. Lee et al. (2018) constructed the state-of-the-art (SOTA) permutation-invariant GP on MNIST, and Novak et al. (2018) achieved SOTA on CIFAR-10 for any GP with untrainable kernel.
In this paper, we generalize the DNN-GP correspondence to standard architectures and very general nonlinearities.
Corollary (DNN-GP correspondence, informal). Let $f$ be a network of fixed standard architecture, with linear readout layer, and with nonlinearities bounded uniformly by $\exp(O(x^{2-\epsilon}))$ for some $\epsilon > 0$. Fix a finite input set $\mathcal X$ of the right signature (e.g. a set of batches for a batchnorm network; a set of sequences for an RNN). Sampling $f$'s parameters from iid Gaussians induces a distribution of functions on $\mathcal X$. If the readout layer weights are sampled independently from hidden parameters, then this distribution weakly converges to a Gaussian process as the network widths go to infinity (with fixed input and output dimensions). See Sections D.4 and D.1.
In contrast, Matthews et al. (2018) require $\phi$ to be linearly bounded in norm; Daniely et al. (2016) require $\phi$ to be twice-differentiable with $|\phi|, |\phi'|, |\phi''|$ all bounded, or that $\phi = \mathrm{ReLU}$; and a sufficient condition given in Novak et al. (2018) is that $\phi'$ exists and is bounded by $\exp(O(x^{2-\epsilon}))$, though it is unclear how the more general set of 3 conditions given there compares with ours.
2.2 Signal Propagation in Neural Networks
Glorot & Bengio (2010) and He et al. (2015) derived the popular Glorot and He initializations from consideration of hidden state norms in a DNN with random weights. A recent line of work generalizes their studies significantly by examining the evolution with depth of the covariance between hidden states $h^l(x)$ and $h^l(x')$, and between gradients $\nabla_{h^l} f(x)$ and $\nabla_{h^l} f(x')$, for distinct inputs $x$ and $x'$, when the DNN $f$ is wide and its parameters are randomized. This evolution is referred to as (forward and backward) signal propagation in the literature (Poole et al., 2016; Schoenholz et al., 2017; Yang & Schoenholz, 2017, 2018; Hanin & Rolnick, 2018; Chen et al., 2018; Yang et al., 2018; Pennington et al., 2017). It has been used to optimize initialization hyperparameters to prevent gradient explosion/vanishing, even to allow training of a 10,000-layer CNN without batchnorm or skip connections (Xiao et al., 2018).
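Suppose $\{x_i\}_i$ is a set of inputs. Let $f$ be an $L$-layer MLP with activation $\phi$ and uniform width $n$, with pre-activations $h^1 = W^1 x + b^1$ and $h^l = W^l \phi(h^{l-1}) + b^l$. If $\Sigma^l_{ij} = \frac1n \mathbb E\langle h^l(x_i), h^l(x_j)\rangle$ and $\Pi^l_{ij} = \frac1n \mathbb E\langle \nabla_{h^l} f(x_i), \nabla_{h^l} f(x_j)\rangle$, with expectation taken over $W^l_{ab} \sim \mathcal N(0, \sigma_w^2/\text{fan-in})$ and $b^l_a \sim \mathcal N(0, \sigma_b^2)$ for each layer $l$, then the signal propagation literature posits that, in the large-width limit, the dynamics of $\Sigma^l$ and $\Pi^l$ are summarized by
$$\Sigma^l_{ij} = \sigma_w^2\, \mathbb E\big[\phi(z_i)\phi(z_j)\big] + \sigma_b^2, \qquad z \sim \mathcal N(0, \Sigma^{l-1}), \tag{1}$$
$$\Pi^{l-1}_{ij} = \sigma_w^2\, \mathbb E\big[\phi'(z_i)\phi'(z_j)\big]\, \Pi^{l}_{ij}, \qquad z \sim \mathcal N(0, \Sigma^{l-1}). \tag{2}$$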
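To make the forward recursion concrete, here is a minimal numerical sketch, assuming tanh activations, two inputs, and fan-in-scaled Gaussian weights with shared Gaussian biases; the helper names `next_sigma` and `empirical_sigmas` and all sizes are our own choices. It iterates $\Sigma^l = \sigma_w^2 \mathbb E[\phi(z)\phi(z)^\top] + \sigma_b^2$, $z \sim \mathcal N(0, \Sigma^{l-1})$, with Gauss-Hermite quadrature and compares the result with $\frac1n \langle h^l(x_i), h^l(x_j)\rangle$ measured in one random MLP.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def next_sigma(S, sw, sb, phi, deg=40):
    """One step of the forward recursion for a 2x2 covariance S.
    The 2D Gaussian expectation uses a tensor-product probabilists' Gauss-Hermite rule."""
    u, w = hermegauss(deg)                      # nodes/weights for weight exp(-u^2 / 2)
    U1, U2 = np.meshgrid(u, u, indexing="ij")
    W = np.outer(w, w) / (2 * np.pi)            # weights now sum to 1
    L = np.linalg.cholesky(S)                   # z = L u has covariance S
    z = [L[0, 0] * U1, L[1, 0] * U1 + L[1, 1] * U2]
    S_new = np.empty((2, 2))
    for i in range(2):
        for j in range(2):
            S_new[i, j] = sw**2 * np.sum(W * phi(z[i]) * phi(z[j])) + sb**2
    return S_new

def empirical_sigmas(X, depth, sw, sb, phi, n, rng):
    """(1/n) <h^l(x_i), h^l(x_j)> in one random MLP of width n; X has shape (d, 2)."""
    d = X.shape[0]
    W = rng.standard_normal((n, d)) * sw / np.sqrt(d)
    b = rng.standard_normal((n, 1)) * sb        # bias shared by both inputs
    H = W @ X + b                               # pre-activations h^1 for both inputs
    out = [H.T @ H / n]
    for _ in range(depth - 1):
        W = rng.standard_normal((n, n)) * sw / np.sqrt(n)
        b = rng.standard_normal((n, 1)) * sb
        H = W @ phi(H) + b
        out.append(H.T @ H / n)
    return out

rng = np.random.default_rng(0)
d, n, depth, sw, sb = 50, 2000, 4, 1.5, 0.1
x = rng.standard_normal(d)
X = np.stack([x, 0.6 * x + 0.8 * rng.standard_normal(d)], axis=1)
theory = [sw**2 * X.T @ X / d + sb**2]          # Sigma^1 comes directly from the inputs
for _ in range(depth - 1):
    theory.append(next_sigma(theory[-1], sw, sb, np.tanh))
empirical = empirical_sigmas(X, depth, sw, sb, np.tanh, n, rng)
for l, (T, E) in enumerate(zip(theory, empirical), start=1):
    print(l, np.round(T, 3).tolist(), np.round(E, 3).tolist())
```

The two printed matrices per layer should agree up to finite-width fluctuations of order $1/\sqrt n$; the comparison is only meant for a fixed, small depth.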
Note that the forward covariance is essentially the kernel of the corresponding GP, and Eq. 1 is the same one used in the DNN-GP correspondence. Pennington et al. (2017) more generally computed the singular value distribution of the input-output Jacobian matrix of an MLP and characterized conditions under which this distribution concentrates around 1. To make this computation and to derive Eq. 2, they and others (Schoenholz et al., 2017; Yang & Schoenholz, 2017, 2018; Chen et al., 2018; Xiao et al., 2018; Pennington et al., 2017) relied on the following assumption.
Assumption (Gradient Independence Assumption). In backpropagation, whenever we multiply by $W^\top$ for some weight matrix $W$, we multiply by an iid copy instead.
This assumption was first discovered by Schoenholz et al. (2017) to make good predictions and later formulated explicitly by Yang & Schoenholz (2017). In this paper we show the following.
Corollary (gradient independence is conditionally correct, informal). In an MLP having nonlinearities with polynomially bounded weak derivatives, the gradient independence assumption leads to the correct equation Eq. 2 and the correct singular value distribution computation from Pennington et al. (2017), as long as the readout layer is sampled independently from other parameters and has mean 0. In general, the assumption does not induce correct computations (for example, when the last layer is global mean pooling), and we rigorously give the correct equations and, more generally, a way to compute the singular value distribution of the neural network Jacobian, both generalized to all standard architectures without batchnorm. See Section D.5.
As an example, we computed the scaling limit for the gradient norms of an LSTM and compared it against empirical simulation (Fig. 1). The theoretical prediction is already very precise at 1000 neurons, a width typical for applications of LSTMs.
Note that this literature also studies the limit of iterating Eqs. 2 and 1 (large depth limit), but our results only apply to a fixed number of iterations, and so do not rigorously justify such limits.
Chen et al. (2018) estimate the signal propagation in tied-weights RNNs with the equations for untied-weights RNNs. They find this a fantastic approximation for simple RNNs but not quite so for gated RNNs. As a corollary of the theorem in Section 4, we show that, indeed, the tied- and untied-weights theories agree for simple RNNs, but not for general (say, gated) RNNs. We give the simplest counterexample, a weight-tied residual network. See Section D.6.
Recently, Li & Nguyen (2018) investigated (forward-only) signal propagation in weight-tied autoencoders. A version of their main theorem allowing for arbitrary polynomially bounded activations, without restriction on smoothness, also follows as a corollary of the theorem in Section 6. See Section D.6.
2.3 Neural Tangent Kernel
For any parametrized function $f(x; \theta)$, the Neural Tangent Kernel can in general be defined as $K_\theta(x, x') = \langle \nabla_\theta f(x; \theta), \nabla_\theta f(x'; \theta)\rangle$ (Jacot et al., 2018). In the case when $f$ is a feedforward neural network, with parameters appropriately scaled (see Section D.1), there is a scaling limit $K_\theta \to K$ when $\theta$ is randomized and $f$'s widths grow to infinity (Jacot et al., 2018). This convergence allows one to predict the evolution of $f$ due to gradient descent on $\theta$. For example, if we apply gradient flow on a training set $\mathcal X$ of pairs $(x, y)$ with loss function $\frac{1}{2|\mathcal X|}\sum_{(x,y) \in \mathcal X} (f(x; \theta) - y)^2$, then, writing $f_t = f(\cdot\,; \theta_t)$, Jacot et al. (2018) derived
$$\partial_t f_t = -\frac{1}{|\mathcal X|}\, K_{\theta_t} (f_t - f^*),$$
where $f^*$ is the “ground truth” function that sends $x \mapsto y$ for every $(x, y) \in \mathcal X$, and $f_t$ and $f^*$ are thought of as $|\mathcal X|$-dimensional vectors. Jacot et al. (2018) proved that under suitable conditions, with training time $T$ fixed and width $\to \infty$, $K_{\theta_t} \to K$ for all $0 \le t \le T$. This means that, in the large width regime, $f$ (in function space) evolves approximately according to a linear differential equation under gradient flow. In this paper we show the following.
Corollary (NTK convergence, informal). Fix a finite input set $\mathcal X$. Let $f$ be a network of fixed standard architecture, with linear readout layer, and having nonlinearities with polynomially bounded weak derivatives (so in particular $f$ cannot have batchnorm). Then over $\mathcal X$, $K_\theta \to K$ almost surely as the widths of $f$ go to infinity and $\theta$ is suitably randomized, for some $K$. See Sections D.7 and D.1.
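As a hedged illustration of NTK convergence at initialization, consider the hypothetical one-hidden-layer ReLU network $f(x) = \frac{1}{\sqrt n}\sum_k v_k\, \mathrm{relu}(w_k^\top x/\sqrt d)$ with all parameters iid $\mathcal N(0, 1)$ (one common "NTK parametrization", assumed here rather than taken from Section D.1). The sketch forms $K_\theta = JJ^\top$ from the explicit Jacobian $J$ and compares it with the infinite-width limit, which for ReLU has a closed form via the arc-cosine formulas $\mathbb E[\mathrm{relu}(z)\mathrm{relu}(z')] = \frac{\sqrt{qq'}}{2\pi}(\sin\theta + (\pi - \theta)\cos\theta)$ and $\mathbb P(z > 0, z' > 0) = \frac{\pi - \theta}{2\pi}$.

```python
import numpy as np

def relu_ntk_limit(X):
    """Infinite-width NTK of f(x) = v . relu(W x / sqrt(d)) / sqrt(n), all params iid N(0, 1)."""
    d = X.shape[1]
    S = X @ X.T / d                                  # covariance of the pre-activations
    norms = np.sqrt(np.diag(S))
    cos = np.clip(S / np.outer(norms, norms), -1.0, 1.0)
    theta = np.arccos(cos)
    k_relu = np.outer(norms, norms) * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)
    k_step = (np.pi - theta) / (2 * np.pi)           # P(z > 0, z' > 0)
    return k_relu + k_step * S                       # gradient w.r.t. v, plus gradient w.r.t. W

def relu_ntk_empirical(X, n, rng):
    """K_theta(x, x') = <grad_theta f(x), grad_theta f(x')> at one random initialization."""
    N, d = X.shape
    W = rng.standard_normal((n, d))
    v = rng.standard_normal(n)
    U = X @ W.T / np.sqrt(d)                         # (N, n) pre-activations
    grad_v = np.maximum(U, 0.0) / np.sqrt(n)         # df/dv_k
    # df/dW_kl = v_k * 1[u_k > 0] * x_l / sqrt(n d), flattened per input
    grad_W = (v * (U > 0))[:, :, None] * X[:, None, :] / np.sqrt(n * d)
    J = np.concatenate([grad_v, grad_W.reshape(N, -1)], axis=1)   # row i: grad_theta f(x_i)
    return J @ J.T

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 10))                     # 4 inputs in R^10
K_inf = relu_ntk_limit(X)
for n in (100, 100_000):
    K_emp = relu_ntk_empirical(X, n, rng)
    print(n, np.max(np.abs(K_emp - K_inf)))          # deviation shrinks roughly like 1/sqrt(n)
```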
While Jacot et al. (2018) is groundbreaking in producing an equation to predict the behavior of gradient descent in the small learning rate, large width regime, its proof of the convergence is incomplete, in that it implicitly assumes gradient independence. Thus our NTK corollary above simultaneously completes its proof and generalizes it to arbitrary standard architectures. We give an example computation of the NTK for a CNN in Section D.7; this is a new result that has not appeared in prior literature.
Amari et al. (2018) and Karakida et al. (2018) recently used the gradient independence assumption to study the empirical Fisher information matrix (FIM) of random neural networks over finitely many datapoints drawn from an isotropic Gaussian, specifically its spectral properties. If we let $J$ be the matrix whose rows are $\nabla_\theta f(x)$ for the datapoints $x$, then the (empirical) FIM is proportional to $J^\top J$, while the NTK is $JJ^\top$ (Karakida et al. (2018) called the NTK the dual matrix). Thus the spectral properties of the empirical FIM and the NTK are identical up to scaling. By our NTK corollary, we can then justify the computations of Amari et al. (2018) and Karakida et al. (2018) rigorously.
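The "identical up to scaling" claim rests on the fact that $J^\top J$ and $JJ^\top$ share their nonzero eigenvalues. A quick check with a random stand-in for $J$ (not an actual network Jacobian) and an assumed $1/N$ normalization of the empirical FIM:

```python
import numpy as np

rng = np.random.default_rng(0)
num_points, num_params = 5, 200
J = rng.standard_normal((num_points, num_params))   # stand-in: row i plays the role of grad_theta f(x_i)
ntk = J @ J.T                                        # num_points x num_points
fim = J.T @ J / num_points                           # num_params x num_params (assumed 1/N normalization)
ntk_eigs = np.linalg.eigvalsh(ntk)                   # ascending order
all_fim_eigs = np.linalg.eigvalsh(fim)               # ascending order
fim_eigs = all_fim_eigs[-num_points:]                # the (at most num_points) nonzero eigenvalues
print(np.allclose(num_points * fim_eigs, ntk_eigs))  # True: same spectrum up to the 1/N scaling
print(np.max(np.abs(all_fim_eigs[:-num_points])))    # remaining eigenvalues are numerically zero
```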
2.4 Other Works
Very recently, Du et al. (2018b, a) and Allen-Zhu et al. (2018b, c) formally proved that GD or SGD can reduce an overparametrized DNN’s training error to 0 by showing that random initialization imbues the network with certain good properties (using tools similar to ones in the signal propagation literature) and that, with small learning rate, the network never moves too far from its initialization (using reasoning similar to Jacot et al. (2018)). Allen-Zhu et al. (2018a) also show generalization bounds for 3-layer networks using similar reasoning.
There is a long line of work investigating random classic spiking or Hopfield networks, for example Landau & Sompolinsky (2018); Crisanti & Sompolinsky (2018); Kadmon & Sompolinsky (2015); Stern et al. (2014); Rajan et al. (2010); Sompolinsky et al. (1988); Amit et al. (1985). In the reinforcement learning literature, Osband et al. (2018) and Burda et al. (2018a, b) used random DNNs for exploration. Other than the works discussed above, Li & Saad (2018); Giryes et al. (2016); Gabrié et al. (2018); Reeves (2017); Fletcher & Rangan (2017) also considered neural networks with random weights.
Our technique is general enough to rederive the semicircle law for the Gaussian Orthogonal Ensemble and the Marchenko-Pastur law for Wishart matrices (Tao, 2012). See Sections D.3 and D.2.
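For orientation, here is a small numerical sanity check of the two laws themselves (not of the tensor-program derivation). For a GOE matrix scaled so that its spectrum fills $[-2, 2]$, the even moments of the semicircle law are the Catalan numbers $1, 2, 5, \ldots$; for a Wishart matrix $XX^\top/m$ with aspect ratio $c = n/m$, the Marchenko-Pastur law has mean 1 and second moment $1 + c$. The sizes below are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
G = rng.standard_normal((n, n))
H = (G + G.T) / np.sqrt(2 * n)        # GOE: off-diagonal variance 1/n, spectrum fills [-2, 2]
lam = np.linalg.eigvalsh(H)
print(np.mean(lam**2), np.mean(lam**4), lam.max())   # close to 1, 2 (Catalan numbers), and 2

m = 2000
X = rng.standard_normal((n, m))
wishart = X @ X.T / m                 # aspect ratio c = n / m = 0.5
mu = np.linalg.eigvalsh(wishart)
print(np.mean(mu), np.mean(mu**2))    # close to 1 and 1 + c = 1.5
```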
Approximate Message Passing is an algorithm for recovering a ground truth vector from noisy measurements (compressed sensing) (Donoho et al., 2009). In one view, the algorithm repeatedly applies a certain neural network to the noisy measurement, and it succeeds if the result eventually converges to the ground truth vector. Previous works have shown that when the measurement matrix is randomized and the dimension goes to infinity, this algorithm satisfies a set of equations called State Evolution that can be used to reason about its behavior (Bayati & Montanari, 2011; Berthier et al., 2017). Their proofs are based on the same Gaussian conditioning technique used here. In Section D.8, we detail the algorithm and State Evolution, and prove the validity of the State Evolution equations for arbitrary polynomially bounded nonlinearities and test functions, removing the smoothness assumption of Bayati & Montanari (2011) (in exchange for a stronger moment condition on the measurements).
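The following is a minimal sketch of soft-thresholding AMP in the form popularized by Donoho et al. (2009), assuming a measurement matrix with iid $\mathcal N(0, 1/n)$ entries, noiseless measurements, and a threshold set to a fixed multiple $\alpha$ of the estimated effective noise level; $\alpha = 1.5$, the sparsity, and the sizes are illustrative assumptions, not the setup of Section D.8. The last term of the residual update is the Onsager correction, which is what makes State Evolution hold.

```python
import numpy as np

def soft_threshold(u, theta):
    """Coordinatewise soft thresholding eta(u; theta)."""
    return np.sign(u) * np.maximum(np.abs(u) - theta, 0.0)

def amp(y, A, iters=50, alpha=1.5):
    """Soft-thresholding AMP; A is n x N with iid N(0, 1/n) entries."""
    n, N = A.shape
    x = np.zeros(N)
    z = y.copy()
    for _ in range(iters):
        theta = alpha * np.linalg.norm(z) / np.sqrt(n)   # alpha times the estimated noise level
        x = soft_threshold(x + A.T @ z, theta)
        # Onsager term: (N / n) * mean(eta') = (number of nonzeros in x) / n
        z = y - A @ x + z * (np.count_nonzero(x) / n)
    return x

rng = np.random.default_rng(0)
n, N, k = 1000, 2000, 50
A = rng.standard_normal((n, N)) / np.sqrt(n)
x0 = np.zeros(N)
x0[rng.choice(N, size=k, replace=False)] = rng.standard_normal(k)
y = A @ x0                                               # noiseless measurements
x_hat = amp(y, A)
rel_err = np.linalg.norm(x_hat - x0) / np.linalg.norm(x0)
print(rel_err)
```

Dropping the Onsager term turns this into plain iterative soft thresholding, whose iterates are no longer described by State Evolution.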
This concludes the discussion of related works and our corollaries. We now present the tensor program framework and our main theorems.
3 Tensor Programs
Consider programs of the following form, which we call tensor programs. Each line contains an assignment and a dimension annotation and can have the following types.
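VecIn (G): a vector input $x$,
$\ell: g^\ell := x \in \mathbb R^{n^\ell}$
MatIn (A): a matrix input $A$,
$\ell: A^\ell := A \in \mathbb R^{n_1^\ell \times n_2^\ell}$
T (A): transpose of an Avar,
$\ell: A^\ell := (A^j)^\top \in \mathbb R^{n_1^\ell \times n_2^\ell} = \mathbb R^{n_2^j \times n_1^j}$
MatMul (G): if $A^k$ and $g^j$ have $n_2^k = n^j$, then an assignment via a linear mapping,
$\ell: g^\ell := A^k g^j \in \mathbb R^{n^\ell} = \mathbb R^{n_1^k}$,
or similarly for Hvars, $\ell: g^\ell := A^k h^j$, where $j, k < \ell$
LinComb (G): if $n^{j_1} = \cdots = n^{j_k}$, then an assignment via linear combination of Gvars that appeared in previous lines:
$\ell: g^\ell := \sum_{i=1}^k a^\ell_{j_i} g^{j_i} \in \mathbb R^{n^\ell}$, with coefficients $a^\ell_{j_i} \in \mathbb R$ and $n^\ell = n^{j_1}$
Nonlin (H): if $n^{j_1} = \cdots = n^{j_k}$, then an assignment via some general (possibly nonlinear) function $f^\ell: \mathbb R^k \to \mathbb R$, acting coordinatewise,
$\ell: h^\ell := f^\ell(g^{j_1}, \ldots, g^{j_k}) \in \mathbb R^{n^\ell}$, with $n^\ell = n^{j_1}$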
Here (G) marks those variables that we call Gvars, and similarly we have Avars and Hvars. Vars introduced by VecIn and MatIn are also called (vector and matrix) input vars. The initial “$\ell$:” marks the line number, and each new variable formed from this line is labeled with a superscript $\ell$. A partial program with the dimensions $\{n^\ell\}$ and the input Gvars and Avars unspecified is called a (program) skeleton, typically denoted by Greek letters like $\pi$. This skeleton can be thought of as a generalization of the skeleton in Daniely et al. (2016) in the language of a straight-line program that allows weight sharing (transposed or not) and simple type checking.
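To make the line types concrete, here is a toy interpreter under our own hypothetical encoding, in which each line is a tuple `(name, op, *args)`. It only executes a program on given inputs and checks the dimension constraints; it does not perform the G/H type checking described above. The example program computes an MLP layer, a linear combination, and the gradient of $v^\top \tanh(Wx)$ with respect to $x$ using a T line followed by a MatMul line.

```python
import numpy as np

def run_program(lines):
    """Execute a tensor program given as (name, op, *args) tuples (hypothetical encoding).

    VecIn/MatIn: (name, op, value); T: (name, "T", src); MatMul: (name, "MatMul", mat, vec);
    LinComb: (name, "LinComb", [(coef, var), ...]); Nonlin: (name, "Nonlin", f, [vars]).
    """
    env = {}
    for name, op, *args in lines:
        if op in ("VecIn", "MatIn"):
            env[name] = np.asarray(args[0], dtype=float)
        elif op == "T":
            env[name] = env[args[0]].T
        elif op == "MatMul":
            A, x = env[args[0]], env[args[1]]
            assert A.shape[1] == x.shape[0], "MatMul: dimension mismatch"
            env[name] = A @ x
        elif op == "LinComb":
            terms = args[0]
            assert len({env[v].shape for _, v in terms}) == 1, "LinComb: dimension mismatch"
            env[name] = sum(c * env[v] for c, v in terms)
        elif op == "Nonlin":
            f, inputs = args
            assert len({env[v].shape for v in inputs}) == 1, "Nonlin: dimension mismatch"
            env[name] = f(*(env[v] for v in inputs))   # f acts coordinatewise
        else:
            raise ValueError(f"unknown line type {op!r}")
    return env

rng = np.random.default_rng(0)
n = 4
x0 = rng.standard_normal(n)
W = rng.standard_normal((n, n)) / np.sqrt(n)
v = rng.standard_normal(n)
prog = [
    ("x", "VecIn", x0),
    ("v", "VecIn", v),
    ("W", "MatIn", W),
    ("g1", "MatMul", "W", "x"),                   # forward pre-activation W x
    ("h1", "Nonlin", np.tanh, ["g1"]),            # forward activation tanh(W x)
    ("s", "LinComb", [(2.0, "g1"), (-1.0, "x")]),
    ("Wt", "T", "W"),                             # transpose used by the backward pass
    ("d1", "Nonlin", lambda vv, g: vv * (1 - np.tanh(g) ** 2), ["v", "g1"]),
    ("dx", "MatMul", "Wt", "d1"),                 # gradient of v . tanh(W x) w.r.t. x
]
env = run_program(prog)
print(env["dx"])
```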
3.1 Examples
Such tensor programs can express the computation in most neural network scenarios. In Appendix B, we give example programs for computations in (1) MLP, forward and backward passes (B.1); (2) batch of input (B.2); (3) residual networks (B.3); (4) simple RNN (B.4); and (5) batchnorm (B.5).
3.2 Setup
Lines of type T, MatMul, LinComb, and Nonlin induce equality constraints on the dimensions $n^\ell$. Given a skeleton $\pi$ and a possible set of additional dimension constraints $\Lambda$, consider the smallest equivalence relation $\sim$ on Gvars such that $g^\ell \sim g^m$ if $\Lambda$ requires $n^\ell = n^m$ or if they are constrained to have equal dimension by some line of type T, MatMul, LinComb, or Nonlin. We call each class a common dimension class (CDC) of $(\pi, \Lambda)$ and write $c(g)$ for the class of a Gvar $g$. The collection of all common dimension classes is written as $\mathcal C_{\pi,\Lambda}$, or just $\mathcal C$ when $\pi$ and $\Lambda$ are understood from context. An algorithm to compute CDCs is presented in Appendix A.
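Appendix A gives the algorithm used in this work; as an independent sketch, CDCs can also be computed with union-find over dimension "slots", one per vector var and two per matrix (rows and columns). The tuple encoding of lines is hypothetical, Nonlin and LinComb arguments are listed by name only, and `extra_equal` plays the role of $\Lambda$.

```python
def common_dimension_classes(lines, extra_equal=()):
    """Group the Gvars of a skeleton into common dimension classes via union-find."""
    parent = {}

    def find(a):
        parent.setdefault(a, a)
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    def union(a, b):
        parent[find(a)] = find(b)

    gvars = []
    for name, op, *args in lines:
        if op == "VecIn":
            find(name)
            gvars.append(name)
        elif op == "MatIn":
            find((name, 0))  # row dimension
            find((name, 1))  # column dimension
        elif op == "T":
            (src,) = args
            union((name, 0), (src, 1))
            union((name, 1), (src, 0))
        elif op == "MatMul":
            mat, vec = args
            union((mat, 1), vec)     # columns of the matrix = length of the vector
            union(name, (mat, 0))    # output length = rows of the matrix
            gvars.append(name)
        elif op in ("LinComb", "Nonlin"):
            for arg in args[0]:
                union(name, arg)
            if op == "LinComb":
                gvars.append(name)   # Nonlin outputs are Hvars, not Gvars
        else:
            raise ValueError(f"unknown line type {op!r}")
    for a, b in extra_equal:
        union(a, b)
    classes = {}
    for g in gvars:
        classes.setdefault(find(g), set()).add(g)
    return list(classes.values())

mlp_tied = [
    ("x", "VecIn"), ("A", "MatIn"),
    ("g1", "MatMul", "A", "x"), ("h1", "Nonlin", ["g1"]), ("g2", "MatMul", "A", "h1"),
]
print(common_dimension_classes(mlp_tied))     # one class: reusing A forces it to be square
mlp_untied = [
    ("x", "VecIn"), ("A", "MatIn"), ("B", "MatIn"),
    ("g1", "MatMul", "A", "x"), ("h1", "Nonlin", ["g1"]), ("g2", "MatMul", "B", "h1"),
]
print(common_dimension_classes(mlp_untied))   # three classes: {x}, {g1}, {g2}
```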
In this work, for every skeleton $\pi$ (equipped with $\Lambda$), we study the behavior of vars in its realizations when the input vars are appropriately randomized and as the dimensions go to infinity. More precisely, we consider a sequence (in $t \in \mathbb N$) of dimensions $\{n^\ell_t\}_\ell$ respecting $\Lambda$, along with input Gvars and Avars $g^\ell_t, A^\ell_t$ of appropriate dimensions. For each CDC $c$, let $n^c_t = n^\ell_t$ for $g^\ell \in c$. We extend the notations $g^\ell_t$ and $h^\ell_t$ to the non-input Gvars and Hvars computed from these inputs.
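At time $t$, we sample independently $A^\ell_{t,ij} \sim \mathcal N(0, (\sigma^\ell_t)^2 / n^\ell_{2,t})$ for a set $\{\sigma^\ell_t\}_{A^\ell}$, where $n^\ell_{2,t}$ is the number of columns of $A^\ell_t$. (We could as well assume that there is an infinite 2D array of independent Gaussian variables, and at time $t$ set $A^\ell_t$ to a suitably rescaled block of it. In that case, we do not need $n^c_t$ to increase strictly with $t$.) For each CDC $c$, we also sample independently $g^{c,\mathrm{in}}_{t,i} \sim \mathcal N(\mu^{c,\mathrm{in}}_t, K^{c,\mathrm{in}}_t)$ for each coordinate $i$. Here $c^{\mathrm{in}}$ is the set of input Gvars in $c$, $g^{c,\mathrm{in}}_{t,i} = (g^\ell_{t,i})_{g^\ell \in c^{\mathrm{in}}}$, and $\mu^{c,\mathrm{in}}_t$ and $K^{c,\mathrm{in}}_t$ are specified mean and covariance at time $t$.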
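Thus, given $\pi$, the data $\{n^c_t\}_c$, $\{\sigma^\ell_t\}_{A^\ell}$, and $\{\mu^{c,\mathrm{in}}_t, K^{c,\mathrm{in}}_t\}_c$ realize a random program $\pi(t)$. Its vars are random variables, and our theorems will concern certain “moments” of them.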
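We assume that, as $t \to \infty$, for all CDCs $c, c'$:
(1) $n^c_t$ is increasing with $t$ and $\lim_t n^c_t = \infty$;
(2) $\lim_t n^c_t / n^{c'}_t = \alpha_{c,c'} \in (0, \infty)$ for some constant $\alpha_{c,c'}$ depending only on $c, c'$;
(3) $\sigma^\ell_t \to \sigma^\ell_\infty$ for some finite $\sigma^\ell_\infty$ for each input Avar $A^\ell$;
(4) $\mu^{c,\mathrm{in}}_t \to \mu^{c,\mathrm{in}}_\infty$ and $K^{c,\mathrm{in}}_t \to K^{c,\mathrm{in}}_\infty$ for some finite $\mu^{c,\mathrm{in}}_\infty, K^{c,\mathrm{in}}_\infty$, and $\operatorname{rank} K^{c,\mathrm{in}}_t = \operatorname{rank} K^{c,\mathrm{in}}_\infty$ for all large $t$.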
Discussion
Tensor programs are meant to represent the “body” of a neural network where all dimensions are large compared to input and output dimensions. The CDCs are used to capture the varying widths of practical neural networks; for example, while widths typically increase and decrease in classical networks, they are held constant in blocks in residual networks (see Section B.3 for an example). On a first read-through, we recommend assuming that all dimensions are the same, so that there is a single CDC consisting of all Gvars.
The sampling of Avars reflects variants of Glorot initialization (Glorot & Bengio, 2010) used in practice. The sampling of input Gvars models the distribution of the first hidden layer across multiple inputs, the sampling of the first layer parameters (see Section B.1 for an example), and/or the sampling of bias vectors. Most often, the vector vars should be thought of as hidden layer quantities whose dimensions go to infinity; neural network inputs (of fixed dimension) are indirectly expressed as above, and outputs (of fixed dimension) are obtained as some coordinates of a vector var.
The theorems in Sections 4 and 5 below say that, under certain conditions, Gvars converge to Gaussians of specific means and covariances (hence the name “Gvar”). But the theorem in Section 6 shows that in general this may not be true.
Notation
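We will often identify functions $Z: \mathcal A \to \mathbb R$ with vectors in $\mathbb R^{\mathcal A}$ (which should be thought of as dictionaries with keys in $\mathcal A$). Given a subset $\mathcal B \subseteq \mathcal A$, $Z|_{\mathcal B}$ is the subvector of $Z$ supported on $\mathcal B$, or, as a function, the restriction of $Z$ to $\mathcal B$. For $\psi: \mathbb R^{\mathcal B} \to \mathbb R$, we write $\psi(Z) := \psi(Z|_{\mathcal B})$, i.e. we automatically ignore the values of $Z$ outside $\mathcal B$. We use $\xrightarrow{\text{a.s.}}$ for almost sure convergence.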
4 Programs with No Transposes
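For any CDC $c$, we recursively define
$$\mu^c(g^\ell) = \begin{cases} \mu^{c,\mathrm{in}}_\infty(g^\ell) & \text{if } g^\ell \in c^{\mathrm{in}} \\ \sum_i a^\ell_{j_i}\, \mu^c(g^{j_i}) & \text{if } g^\ell := \sum_i a^\ell_{j_i} g^{j_i} \\ 0 & \text{otherwise} \end{cases} \tag{3}$$
and recursively define
$$K^c(g^\ell, g^m) = \begin{cases} K^{c,\mathrm{in}}_\infty(g^\ell, g^m) & \text{if } g^\ell, g^m \in c^{\mathrm{in}} \\ \sum_i a^m_{j_i}\, K^c(g^\ell, g^{j_i}) & \text{if } g^m := \sum_i a^m_{j_i} g^{j_i} \\ \sum_i a^\ell_{j_i}\, K^c(g^{j_i}, g^m) & \text{if } g^\ell := \sum_i a^\ell_{j_i} g^{j_i} \\ (\sigma^k_\infty)^2\, \mathbb E_z\big[f^a(z)\, f^b(z)\big] & \text{if } g^\ell := A^k h^a,\ g^m := A^k h^b \\ 0 & \text{otherwise} \end{cases} \tag{4}$$
where $h^a := f^a(g^{j_1}, \ldots, g^{j_k})$, $h^b := f^b(g^{j'_1}, \ldots, g^{j'_{k'}})$, and $z \sim \mathcal N(\mu^{c'}, K^{c'})$ with $c'$ the CDC of the arguments of $f^a$ and $f^b$. We also make branch 4 cover the case when $g^\ell := A^k g^a$ or $g^m := A^k g^b$ by “typecasting” $g^a$ to an Hvar and setting $f^a = \mathrm{id}$ (similarly for $g^b$). Note that, as discussed in Notation above, $f^a$ and $f^b$ will ignore irrelevant components of $z$, and the expectations only depend on the entries of $\mu^{c'}$ and $K^{c'}$ that correspond to already-defined values, so this describes a valid recursion.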
We introduce the following technical assumption.
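Assumption (Almost Sure Rank Convergence). For any CDC $c$ and any collection $S$ of Gvars and Hvars in $c$, let $H_t$ be the $n^c_t \times |S|$ matrix whose columns are $g^\ell_t$ or $h^\ell_t$ for each $g^\ell$ or $h^\ell$ in $S$. If $\frac{1}{n^c_t} H_t^\top H_t$ converges almost surely to some $C_*$ as $t \to \infty$, then $\operatorname{rank} H_t = \operatorname{rank} C_*$ almost surely for all large $t$.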
If we don’t have lines of type LinComb, and no $f^\ell$ is a polynomial, then the $C_*$'s are all full rank, implying rank convergence by the lower semicontinuity of rank. LinComb lines may add linear dependencies, but they are constant in $t$, so they are also present in the limit $C_*$ and we still have rank convergence.
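For $\alpha > 0$, a function $\phi: \mathbb R^k \to \mathbb R$ is said to be $\alpha$-controlled if for some $C, c > 0$, $|\phi(x)| \le e^{C \sum_{i=1}^k |x_i|^\alpha + c}$ for all $x \in \mathbb R^k$.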
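Theorem (programs without transposes). Consider dimension constraints $\Lambda$ and a skeleton $\pi$ without T lines, i.e. no transpose allowed. Suppose all $f^\ell$ are $\alpha$-controlled for some $\alpha < 2$. Sample all input vars as in Section 3.2 and assume almost sure rank convergence. Then for any CDC $c$ and any $\alpha$-controlled function $\phi: \mathbb R^c \to \mathbb R$, $\alpha < 2$,
$$\frac{1}{n^c_t} \sum_{i=1}^{n^c_t} \phi(g^c_{t,i}) \xrightarrow{\text{a.s.}} \mathbb E\, \phi(Z),$$
where $g^c_{t,i} = (g^\ell_{t,i})_{g^\ell \in c}$ and $Z \sim \mathcal N(\mu^c, K^c)$, with $\mu^c$ and $K^c$ as defined in Eqs. 3 and 4.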
Discussion
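Roughly speaking, Gvars created from the same matrix have nonzero correlations, but otherwise are asymptotically independent modulo LinComb. Intuitively, for large $t$, the tuples $(g_i)_{g \in c}$ of $i$-th coordinates of the Gvars in a CDC $c$ are approximately iid Gaussian across $i$, with the mean and covariance given by Eqs. 3 and 4.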
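A hedged numerical illustration of the theorem in Section 4 on a small program without transposes that reuses one matrix: $g^1 = Ax$, $h^1 = \tanh(g^1)$, $g^2 = Ah^1$, with $x \sim \mathcal N(0, I)$ and $A_{ij} \sim \mathcal N(0, 1/n)$. The predicted limits are $K(g^1, g^1) = \mathbb E[x^2] = 1$, $K(g^2, g^2) = \mathbb E[\tanh(z)^2]$ with $z \sim \mathcal N(0, 1)$, and $K(g^1, g^2) = \mathbb E[x \tanh(g^1)] = 0$, since $x$ and $g^1$ are uncorrelated and hence independent Gaussians in the limit. The width and seed are arbitrary choices, and `gaussian_expectation` is our helper.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_expectation(f, var, deg=60):
    """E[f(Z)] for Z ~ N(0, var), via probabilists' Gauss-Hermite quadrature."""
    u, w = hermegauss(deg)
    return float(np.sum(w * f(np.sqrt(var) * u)) / np.sqrt(2 * np.pi))

rng = np.random.default_rng(0)
n, sigma = 3000, 1.0
x = rng.standard_normal(n)                             # VecIn: x ~ N(0, 1) iid
A = rng.standard_normal((n, n)) * sigma / np.sqrt(n)   # MatIn: A_ij ~ N(0, sigma^2 / n)
g1 = A @ x                                             # MatMul
h1 = np.tanh(g1)                                       # Nonlin
g2 = A @ h1                                            # MatMul reusing A (weight tying, no transpose)

K11 = sigma**2 * 1.0                                   # sigma^2 E[x^2]
K22 = sigma**2 * gaussian_expectation(lambda z: np.tanh(z) ** 2, K11)
K12 = 0.0                                              # sigma^2 E[x] E[tanh(g1)]
print(np.mean(g1**2), K11)
print(np.mean(g2**2), K22)
print(np.mean(g1 * g2), K12)
```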
There is an apparent contradiction in the theorem of Section 4 if we consider deep linear networks with tied ($g^l = W g^{l-1}$) and untied weights ($g^l = W^l g^{l-1}$). Via simple computations of $\mu$ and $K$, one sees that, by the theorem of Section 4, as the width $n \to \infty$, $g^L$ is distributed “similarly” in either case (in that $\alpha$-controlled moments match asymptotically). This seems to contradict our intuition that, in the tied case, $g^L$ should blow up or decay exponentially with $L$ along the direction of the eigenvector of $W$ corresponding to its largest eigenvalue, whereas in the untied case it is easy to see that the coordinates of $g^L$ converge in distribution to i.i.d. Gaussians.
This apparent paradox is resolved by noting that the theorem only applies to fixed skeletons (so fixed $L$ in this example) as widths $n \to \infty$. By Rider (2003), the maximum eigenvalue (in modulus) of $W$ tends to 1 if $W_{ij} \sim \mathcal N(0, 1/n)$, and so does that of $W^L$ for fixed $L$. Furthermore, as $n$ increases, the components of $g^0$ corresponding to large eigenvalues of $W$ decrease in magnitude to 0 in probability, by the circular law (Tao, 2012). So the depth $L$ at which the exponentiating effect of $W$ kicks in increases with $n$.
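The resolution can be probed numerically. The sketch below (sizes, depth, and seed are our choices) compares $\frac1n \|g^L\|^2$ for tied and untied deep linear networks at a fixed depth, where the theorem predicts the same limit 1 for both, and prints the spectral radius of a normalized Gaussian matrix, which is close to 1.

```python
import numpy as np

def mean_sq_output(n, depth, tied, rng):
    """(1/n) ||g^L||^2 for g^l = W^l g^{l-1}, W_ij ~ N(0, 1/n), g^0 ~ N(0, I)."""
    g = rng.standard_normal(n)
    W = rng.standard_normal((n, n)) / np.sqrt(n)
    for layer in range(depth):
        if not tied and layer > 0:
            W = rng.standard_normal((n, n)) / np.sqrt(n)   # untied: fresh weights per layer
        g = W @ g
    return float(np.mean(g**2))

rng = np.random.default_rng(0)
tied_val = mean_sq_output(2000, 4, tied=True, rng=rng)
untied_val = mean_sq_output(2000, 4, tied=False, rng=rng)
print(tied_val, untied_val)                        # both close to 1 at this fixed depth

M = rng.standard_normal((500, 500)) / np.sqrt(500)
radius = float(np.max(np.abs(np.linalg.eigvals(M))))
print(radius)                                      # close to 1
```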
5 Backprop with Zero Mean Gradients


Let $\pi$ be a skeleton with $L$ lines but no T. WLOG, suppose all input vars appear at the start, with matrix inputs first, as in Fig. 1(a). Consider an extension $\tilde\pi$ of $\pi$ in the following way. The first few appended lines are transposes of Avars in $\pi$, followed by a series of new vector input vars, which we denote collectively by $v$, as in Fig. 1(b). Lines appended below this can be arbitrary non-input lines except that
(1) lines of type MatMul must use one of the transposed matrices, and the Gvar or Hvar being multiplied must have been introduced after $\pi$ (i.e. on a line numbered above $L$); and
(2) any $f^\ell$ with $\ell > L$ must be an odd function of $v$ for any fixed values of its other arguments; likewise, every LinComb var introduced after $\pi$ must be odd in $v$. This in particular means that LinComb lines cannot involve $g^j$ for $j \le L$.
If $\pi$ expresses the forward computation of an NN $f$ without matrix transposes, then $\tilde\pi$ has enough power to express the backpropagation of $f$ and compute the gradients with respect to hidden states. For example, if $f(x) = v^\top \rho(x)$, so that $\nabla_x f = (\partial \rho / \partial x)^\top v$, then $\nabla_x f$ is an odd function of $v$ and can be expressed as a Gvar as above (see Section B.1 for a concrete example). In general, multiple such $v$ allow for multiple NN outputs.
CDCs are naturally defined for (see Appendix A) just like before. We extend and to vars introduced in : For and set and when or , set
where , and branch 4 covers the case when or by taking or to be identity. Note that covariances between vars of and new vars in are 0.
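Theorem (gradient independence). Sample the new vector inputs $v$ of $\tilde\pi$ with zero mean (i.e. $\mu(v) = 0$ for each of them) and independently from the input vars of $\pi$ (i.e. $K(v, g) = 0$ for every input Gvar $g$ of $\pi$). (In our previous example of $f(x) = v^\top \rho(x)$, this corresponds to the readout layer sampled with zero mean and independently from $x$ and the other parameters of $\rho$.) Sample all other vars in $\tilde\pi$ according to Section 3.2. Assume all $f^\ell$ of $\tilde\pi$ are polynomially bounded and $\tilde\pi$ satisfies almost sure rank convergence. Then for any dimension constraints $\Lambda$, any CDC $c$, and any polynomially bounded function $\phi: \mathbb R^c \to \mathbb R$,
$$\frac{1}{n^c_t} \sum_{i=1}^{n^c_t} \phi(g^c_{t,i}) \xrightarrow{\text{a.s.}} \mathbb E\, \phi(Z),$$
where $g^c_{t,i}$ collects the $i$-th coordinates of the Gvars in $c$ and $Z \sim \mathcal N(\mu^c, K^c)$, with $\mu^c$ and $K^c$ as extended above.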
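A minimal check of this theorem in its simplest instance, assuming a two-layer tanh network $f = v^\top \tanh(W \tanh(x))$ with a zero-mean readout $v$ drawn independently of $W$ and $x$ (sizes, seed, and names are ours). The mean squared gradient with respect to $\tanh(x)$, computed once with the true $W^\top$ and once with an iid copy of $W$, should both approach $\mathbb E[v^2]\,\mathbb E[\mathrm{sech}^4(z)]$ with $z \sim \mathcal N(0, \mathbb E[\tanh^2(Z)])$, $Z \sim \mathcal N(0, 1)$.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_expectation(f, var, deg=60):
    """E[f(Z)] for Z ~ N(0, var) by probabilists' Gauss-Hermite quadrature."""
    u, w = hermegauss(deg)
    return float(np.sum(w * f(np.sqrt(var) * u)) / np.sqrt(2 * np.pi))

rng = np.random.default_rng(1)
n = 3000
x = rng.standard_normal(n)
W = rng.standard_normal((n, n)) / np.sqrt(n)
v = rng.standard_normal(n)                        # readout: zero mean, independent of W and x
g = W @ np.tanh(x)                                # forward pass; f = v . tanh(g)
delta = v * (1 - np.tanh(g) ** 2)                 # df/dg
true_grad = W.T @ delta                           # exact backprop through W^T
indep_grad = rng.standard_normal((n, n)).T @ delta / np.sqrt(n)   # iid copy in place of W

Kg = gaussian_expectation(lambda z: np.tanh(z) ** 2, 1.0)          # limit of mean(g**2)
pred = gaussian_expectation(lambda z: (1 - np.tanh(z) ** 2) ** 2, Kg)
true_val, indep_val = np.mean(true_grad**2), np.mean(indep_grad**2)
print(true_val, indep_val, pred)
```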
Note that our result here does not apply to batchnorm, whose Jacobian has a singularity on a 1-dimensional affine subspace (in particular, at the origin). This theorem allows one to justify the gradient independence assumption rigorously; see Section D.5.
6 General Tensor Programs
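The theorem of Section 5 does not give the correct computation if the readout weights do not have zero mean. Consider a one-hidden-layer MLP with quadratic activation, $f(x) = \frac12 \mathbf 1^\top (Wx)^{\odot 2}$, where the square acts coordinatewise and $\mathbf 1$ is the all-ones readout vector. Then $\nabla_x f = W^\top W x$. If $x \sim \mathcal N(0, I_n)$ and $W_{ij} \sim \mathcal N(0, 1/n)$, then $\frac1n \|\nabla_x f\|^2 \xrightarrow{\text{a.s.}} 2$. If we had assumed the theorem of Section 5 to be correct here, then we would have (incorrectly) computed $\frac1n \|\widetilde W^\top W x\|^2 \to 1$, where $\widetilde W$ is an iid copy of $W$.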
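A numerical sketch of this failure, assuming the concrete instance $f(x) = \frac12 \mathbf 1^\top (Wx)^{\odot 2}$ with the all-ones readout, so that $\nabla_x f = W^\top W x$, and $x \sim \mathcal N(0, I_n)$, $W_{ij} \sim \mathcal N(0, 1/n)$ (size and seed are our choices). The true $\frac1n \|\nabla_x f\|^2$ approaches 2, the second moment of the Marchenko-Pastur law with ratio 1, while substituting an iid copy for $W^\top$ gives about 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
x = rng.standard_normal(n)
W = rng.standard_normal((n, n)) / np.sqrt(n)
W_copy = rng.standard_normal((n, n)) / np.sqrt(n)   # iid copy used in place of W^T
y = W @ x
true_grad = W.T @ y          # gradient of f(x) = 0.5 * sum((W x)**2) is W^T W x
indep_grad = W_copy.T @ y    # what the gradient independence assumption would use
true_val = np.mean(true_grad**2)     # tends to 2
indep_val = np.mean(indep_grad**2)   # tends to 1
print(true_val, indep_val)
```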
Below, we develop a theory of the scaling limit of general tensor programs, from which follows the correct way of computing gradients when the readout weights do not have zero mean.
We first introduce “extended syntax” programs, which are equivalent semantically to programs of original syntax, and then show that we can “compile” original syntax programs to extended syntax programs with no transposes, with the same scaling limit in a suitable sense.
Extended syntax programs are those that allow all line types of Section 3 and in addition
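Comp (H): if $n^{j_1} = \cdots = n^{j_k}$, then an assignment via some general (possibly nonlinear) function $f^\ell: \mathbb R^k \to \mathbb R$,
$\ell: h^\ell := f^\ell(x^{j_1}, \ldots, x^{j_k}) \in \mathbb R^{n^\ell}$, with $n^\ell = n^{j_1}$,
where $x^{j_1}, \ldots, x^{j_k}$ are previous G or Hvars, and $f^\ell$ acts coordinatewise.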
So in essence, extended syntax programs just allow Nonlin lines to take Hvars in addition to Gvars. While in the original syntax Hvars can only feed into lines of type MatMul, in extended syntax they can also be used to create new Hvars via coordinatewise action.
One can define CDCs for extended syntax programs just as before (see Appendix A). Each extended syntax program is equivalent to an original syntax program, obtained by expanding the definition of each Hvar into a function of previous Gvars. For example, if $h^\ell := f^\ell(h^k, g^m)$ and $h^k := f^k(g^j)$, then the expanded definition of $h^\ell$ is $f^\ell(f^k(g^j), g^m)$. We write $\hat f^\ell$ for this expanded function, so that $h^\ell = \hat f^\ell(g^{j_1}, \ldots, g^{j_r})$ for some previous Gvars $g^{j_1}, \ldots, g^{j_r}$; for a Gvar $g^\ell$, we also define $\hat f^\ell = \mathrm{id}$. (In our example above, $\hat f^\ell(x, y) = f^\ell(f^k(x), y)$.) So by replacing each line of type Comp with its expanded definition, we can convert an extended syntax program to an original syntax program with the same semantics for all vector vars.
Let be an original syntax skeleton with associated sampling data We define an extended syntax skeleton , called the detransposition of , by induction on line number as follows. During this process, we keep track of an injective mapping taking vector (resp. matrix) vars of to vector (resp. matrix) vars of , along with a specialized mapping taking a Gvar of produced by 3 to a Gvar of . We use a check to denote objects of the detransposition. We also simultaneously set , and of the detransposition. They propagate according to the usual rules, Eqs. 3 and 4, to determine and . Let be the current line number of we are processing, and let denote the 1 + length of the current (this is where we are adding new lines in ).

If is a 3 line, then add a line of the same type to , . Set , where and . Set for all input Gvars with .

If is a 3 line, then add to the line . Set and for all .

If is a 3 line, then add to an input line for a new input sampled iid as . Set and

Suppose is a line of type 3 in , where is some previous var. Consider the Avar where if is an input Avar, or if is a transposed var. Let be all previous lines of type 3 involving , where can be G or Hvar. Define by
where (the expectation will only depend on components of corresponding to previous lines). Compute , where Then we add the following to :
If are all Gvars, we typecast line to 3 and write instead. Set and
See Section B.1.1 for a concrete example of detransposition. Below, for any , let be , i.e. the collection of H or Gvars with the same dimension constraint; see Appendix A.
Theorem (general tensor programs). Let $\pi$ be an original syntax program with sampling instructions, and $\check\pi$ be the detransposition of $\pi$, with the mapping from vector vars of $\pi$ to vector vars of $\check\pi$ described above. Assume all nonlinearities of $\check\pi$ are polynomially bounded, and that almost sure rank convergence holds for $\check\pi$. Sample input vars of $\pi$ according to Section 3.2. Then for any dimension constraints $\Lambda$, any CDC $c$, and any polynomially bounded function $\phi$,
where is the sequence of the th coordinates of all vector vars in , and