Mean field theory for deep dropout networks: digging up gradient backpropagation deeply
In recent years, mean field theory has been applied to the study of neural networks with considerable success. The theory has been applied to various network structures, including CNNs, RNNs, residual networks, and batch normalization, and recent work has also covered dropout. Mean field theory shows the existence of depth scales that limit the maximum depth of signal propagation and gradient backpropagation. However, the gradient backpropagation result is derived under the gradient independence assumption, namely that the weights used during feed-forward are drawn independently from the ones used in backpropagation. This is not how neural networks are trained in a real setting. Instead, the same weights used in a feed-forward step need to be carried over to the corresponding backpropagation. Under this realistic condition, we perform a theoretical computation on linear dropout networks and a series of experiments on dropout networks. Our empirical results show an interesting phenomenon: the lengths over which gradients can backpropagate for a single input and for a pair of inputs are governed by the same depth scale. In addition, we study the relationship between the variance and mean of statistical metrics of the gradient and show an emergence of universality. Finally, we investigate the maximum trainable length for deep dropout networks through a series of experiments on MNIST and CIFAR10 and provide an empirical formula that describes the trainable length more precisely than the original work.
Deep neural networks have achieved exceptional results in a range of fields since their inception lecun2015deep (). Several seminal innovations have since been proposed to improve their performance further. For example, residual networks he2016deep () and batch normalization ioffe2015batch (), which were introduced to overcome the vanishing and exploding gradient problem, made it possible to train very deep networks. Another technique is dropout srivastava2014dropout (), a regularization method for reducing over-fitting, which is the focus of this paper. In a dropout network, units are randomly dropped during training, which can prevent complex co-adaptations srivastava2014dropout ().
More recently, we have witnessed progress in deep learning made using mean field theory poole2016exponential (); schoenholz2016deep (); pennington2017resurrecting (). Mean field theory considers networks after random initialization, whose weights and biases are i.i.d. Gaussian distributed, in the limit where the width of each layer tends to infinity. By studying signal propagation under mean field theory, an order-to-chaos expressivity phase transition separated by a critical line was found poole2016exponential (). Later, the impact of parameter initialization on the gradient of backpropagation was studied, and it was shown that the ordered and chaotic phases correspond to regions of vanishing and exploding gradients, respectively schoenholz2016deep (). These results were applied equally to networks with and without dropout.
The main contribution of mean field theory for random networks is that it shows the existence of depth scales that limit the maximum depth of signal propagation and gradient backpropagation. In practice, this leads to the hypothesis that random networks may be trained precisely when information can travel through them. Thus, the depth scales provide bounds on how deep a network may be trained for a specific choice of hyper-parameters schoenholz2016deep (). This ansatz was tested and verified by practical experiments on the MNIST and CIFAR10 datasets with wide fully-connected networks schoenholz2016deep (), deep dropout networks schoenholz2016deep (), and residual networks yang2017mean ().
However, the mean field calculation for the gradient is based on the so-called gradient independence assumption, which states that the weights used during feed-forward are drawn independently from the ones used in backpropagation. This assumption is made so that the calculation of the gradient is feasible regardless of the choice of activation function. It was later formulated explicitly for residual networks yang2017mean () and was illustrated in a review yang2019scaling (). While it correctly predicts the gradient dynamics in some cases, our experiments show that when the weights in feed-forward are carried over to the corresponding backpropagation, the lengths over which gradients can backpropagate for a single input and for a pair of inputs are instead governed by the same depth scale on deep dropout networks.
By further studying the mean and variance of gradient statistics metrics on deep dropout networks, we show an emergence of universality for the relationship between the mean and variance. This universality exists regardless of the choice of hyper-parameters, including dropout rate and activation function. After summarizing the theoretical results about the trainable length of deep dropout networks governed by maximum depth of signal propagation and gradient backpropagation, we perform a series of experiments to investigate it. Empirically, we find a more precise way to describe the maximum trainable length for deep dropout networks, compared with the original results schoenholz2016deep ().
2 Related Work
The mean field theory has been applied to different network architectures, including CNNs lecun1995convolutional (), RNNs mikolov2010recurrent (), Residual networks he2016deep (), Batch normalization ioffe2015batch (), LSTM gers1999learning (), and GRUs chung2014empirical (). These networks have been investigated by xiao2018dynamical (); chen2018dynamical (); yang2017mean (); yang2019mean (); gilboa2019dynamical (), respectively, which form a large family of the mean field theory for deep neural networks.
Following the mean field theory, pennington2017resurrecting () studied all singular values of the input-output Jacobian and found a strong connection between dynamical isometry and fast training speed. Later, the analysis of the spectrum of the input-output Jacobian was developed to provide a detailed analytic understanding pennington2018emergence () and a nonlinear random matrix theory for deep learning pennington2017nonlinear (). The study of the spectrum of the input-output Jacobian is based on mean field theory. We do not address it in this work, since it is straightforward to extend the analysis method of pennington2017resurrecting () to dropout networks.
In contrast to the mean field view of random networks, daniely2016toward () studied the relationship between random networks and kernels, while lee2017deep (); matthews2018gaussian () adopted the view of Gaussian processes (GPs) in the realm of Bayesian learning. The correspondence between single-layer infinite neural networks and Gaussian processes was first observed by neal1996priors (). Moreover, a mean field study of the dynamics of networks in the infinite width limit jacot2018neural () has recently been pursued with great success by lee2019wide ().
Finally, dropout training in deep neural networks can be viewed as approximate Bayesian inference in deep Gaussian processes gal2016dropout (). Further, dropout can be used in the Neural Network GP lee2017deep (). While this topic is interesting, we do not include the Bayesian learning of random dropout networks in our work.
In this section, we review the mean field theory for deep dropout networks. We give the main definitions, setup, and notations, and introduce the results of theory for random networks at initialization, including signal feed-forward and gradient backpropagation, respectively.
3.1 Feed Forward
Consider a feed-forward, fully-connected, untrained, and dropout network of depth with layer width . We denote synaptic weight and bias for the -th layer by and ; pre-activations and post-activations by and respectively. Finally, we take the input to be and the dropout keep rate to be . The information propagation in this network is governed by,
where is the activation function and Bernoulli(). We adopt the mean field theory assumption poole2016exponential (); schoenholz2016deep (), where , , and the width tends to infinity. Since the weights and biases are randomly distributed, these equations define a probability distribution on the pre-activations over an ensemble of untrained neural networks. Under the mean field approximation, can be replaced by a Gaussian distribution with zero mean.
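As a concrete illustration of the setup above, the following minimal sketch propagates one input through a random dropout network. It assumes the layer update z^l = (1/rho) W^l (p^l * y^(l-1)) + b^l with W_ij ~ N(0, sigma_w^2 / N), b_i ~ N(0, sigma_b^2), and p_j ~ Bernoulli(rho), which is our reading of the setup in schoenholz2016deep () and is an assumption here. The function names and the example width, depth, and hyper-parameters are hypothetical.

```python
import math
import random


def init_layer(width, sigma_w, sigma_b, rng):
    """Draw one layer: W_ij ~ N(0, sigma_w^2 / width), b_i ~ N(0, sigma_b^2)."""
    std_w = sigma_w / math.sqrt(width)
    weights = [[rng.gauss(0.0, std_w) for _ in range(width)] for _ in range(width)]
    biases = [rng.gauss(0.0, sigma_b) for _ in range(width)]
    return weights, biases


def dropout_forward(x, layers, keep_rate, phi, rng):
    """Propagate x through the layers and return the pre-activations of every layer."""
    y = list(x)
    pre_activations = []
    for weights, biases in layers:
        # Keep each unit with probability keep_rate and rescale by 1 / keep_rate (assumed form).
        dropped = [(v / keep_rate) if rng.random() < keep_rate else 0.0 for v in y]
        z = [sum(w * d for w, d in zip(row, dropped)) + b
             for row, b in zip(weights, biases)]
        pre_activations.append(z)
        y = [phi(v) for v in z]
    return pre_activations


def mean_square(values):
    """Empirical length quantity: the mean squared pre-activation of one layer."""
    return sum(v * v for v in values) / len(values)


if __name__ == "__main__":
    rng = random.Random(0)
    width, depth = 100, 5  # hypothetical sizes
    layers = [init_layer(width, 1.5, 0.5, rng) for _ in range(depth)]
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    for z in dropout_forward(x, layers, 0.9, math.tanh, rng):
        print(mean_square(z))
```

Averaging `mean_square` over many random draws of the weights and masks gives an empirical estimate of the length quantity at finite width.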
Consider a single input , where the subscript refers to the index of the input. We define the length quantity , which is the mean squared pre-activation. According to the mean field approximation, the length quantity is described by the recursion relation,
where is the measure for a normal distribution. This equation describes how a single input evolves through a random neural network. To study this evolution, we investigate the fixed point at . One way to estimate the fixed point is to plot Equation (2) together with the identity line; their intersection is the fixed point. We show the result for Equation (2) with the linear dropout network and the Tanh dropout network in Figure 1(a)(b). Note that the smaller the keep rate , the larger the fixed point value .
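The fixed point can also be estimated by iterating the recursion numerically instead of reading it off a plot. The sketch below assumes the recursion has the form q -> (sigma_w^2 / rho) E[phi(sqrt(q) z)^2] + sigma_b^2 with z a standard normal variable, which is the form we take from schoenholz2016deep () and is an assumption here. The helper names, the integration grid, and the hyper-parameter values are hypothetical.

```python
import math


def gaussian_expectation(f, half_range=10.0, steps=2001):
    """Approximate E[f(z)] for z ~ N(0, 1) with the trapezoidal rule on [-half_range, half_range]."""
    h = 2.0 * half_range / (steps - 1)
    total = 0.0
    for i in range(steps):
        z = -half_range + i * h
        weight = 0.5 if i in (0, steps - 1) else 1.0
        total += weight * f(z) * math.exp(-0.5 * z * z)
    return total * h / math.sqrt(2.0 * math.pi)


def length_map(q, sigma_w, sigma_b, keep_rate, phi):
    """One step of the assumed length recursion for a dropout network."""
    root = math.sqrt(q)
    second_moment = gaussian_expectation(lambda z: phi(root * z) ** 2)
    return sigma_w ** 2 / keep_rate * second_moment + sigma_b ** 2


def fixed_point(sigma_w, sigma_b, keep_rate, phi, q0=1.0, iters=100):
    """Iterate the length map from q0; the result approximates the fixed point when the map converges."""
    q = q0
    for _ in range(iters):
        q = length_map(q, sigma_w, sigma_b, keep_rate, phi)
    return q


if __name__ == "__main__":
    for keep_rate in (1.0, 0.9, 0.8):
        print(keep_rate, fixed_point(1.5, 0.5, keep_rate, math.tanh))
```

For a linear activation the assumed map reduces to q -> (sigma_w^2 / rho) q + sigma_b^2, whose fixed point sigma_b^2 / (1 - sigma_w^2 / rho) exists only when sigma_w^2 < rho; this gives a simple check of the numerical routine.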
The propagation of a pair of inputs and , where the subscripts and refer to different inputs, can be studied by looking at the correlation between the two inputs after layers. We define this correlation quantity as . Similarly, the correlation is given by the recurrence relation,
where and , with
This equation also has a fixed point at . It is known that when , while when schoenholz2016deep (). We show the result of Equation (4) on the ReLU and Erf dropout networks in Figure 1(c)(d), which demonstrates the main conclusion about the fixed point without () and with () dropout.
The main contribution of mean field theory for the fully-connected networks without dropout () is that it presents a phase diagram, which is determined by a crucial quantity,
This quantity was first introduced by poole2016exponential () to determine whether or not the is an attractive fixed point. When , the fixed point is unstable. Conversely, when , the fixed point is stable. Thus, the critical line separates two phases. One is the chaotic phase (), where a pair of inputs end up asymptotically decorrelated, and the other is the ordered phase, in which a pair of inputs end up asymptotically correlated.
We comment here on the difference between and . Random networks in the infinite width limit can be viewed as Gaussian processes, where and are the diagonal and off-diagonal elements of the compositional kernel lee2017deep (), respectively. Intuitively, the off-diagonal element of the kernel measures the correlation between different data points, while the diagonal element measures the information of one input itself.
The study of information propagation shows the existence of depth scales , which represent the propagation lengths of the following quantities:
where , with , where and . Intuitively, the depth scale measures how far the correlation between two different inputs can survive through the network.
3.2 Back Propagation
There is a duality between the forward propagation of signals and the backpropagation of gradients. Given a loss , we have
where . We define the metric of gradient for both a single input and a pair of inputs cases:
Within mean field theory, the scale of fluctuations of the gradient of weights in a layer will be proportional to , which can be written as, schoenholz2016deep (). On the other hand, the correlation between gradients of a pair of inputs will be proportional to , namely, .
In order to work out the recurrence relations for and , an approximation called the gradient independence assumption was made schoenholz2016deep (): the weights used during forward propagation are drawn independently from the weights used in backpropagation. In this way, the term , and in Equation (7) can be treated independently. The recurrence relations for and are then obtained,
where we redefine the quantity for the dropout networks,
Equation (9) has an exponential solution with,
Similar to the signal propagation, gradient backpropagation can limit the trainable length in the way of gradient vanishing or gradient exploding, which is measured by the depth-scales and .
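To make the link between the redefined quantity and a depth scale concrete, the following sketch evaluates both numerically. It assumes that the quantity takes the form chi = (sigma_w^2 / rho) E[phi'(sqrt(q*) z)^2] and that the depth scale of an exponential solution is 1 / |ln chi|; both forms are our reading of schoenholz2016deep () and are assumptions here, as are all function names and parameter values.

```python
import math


def gaussian_expectation(f, half_range=10.0, steps=2001):
    """Approximate E[f(z)] for z ~ N(0, 1) with the trapezoidal rule."""
    h = 2.0 * half_range / (steps - 1)
    total = 0.0
    for i in range(steps):
        z = -half_range + i * h
        weight = 0.5 if i in (0, steps - 1) else 1.0
        total += weight * f(z) * math.exp(-0.5 * z * z)
    return total * h / math.sqrt(2.0 * math.pi)


def chi(q_star, sigma_w, keep_rate, dphi):
    """Assumed form of chi for a dropout network; dphi is the derivative of the activation."""
    root = math.sqrt(q_star)
    return sigma_w ** 2 / keep_rate * gaussian_expectation(lambda z: dphi(root * z) ** 2)


def depth_scale(chi_value):
    """Assumed depth scale 1 / |ln chi|; it diverges as chi approaches 1."""
    if chi_value <= 0.0:
        raise ValueError("chi must be positive")
    log_chi = math.log(chi_value)
    return math.inf if log_chi == 0.0 else 1.0 / abs(log_chi)


def tanh_derivative(u):
    """Derivative of tanh, written as 1 - tanh(u)^2."""
    return 1.0 - math.tanh(u) ** 2


if __name__ == "__main__":
    q_star = 1.0  # hypothetical fixed-point value; in practice obtained from the length recursion
    for keep_rate in (1.0, 0.9, 0.8):
        c = chi(q_star, 1.2, keep_rate, tanh_derivative)
        print(keep_rate, c, depth_scale(c))
```

For a linear activation the derivative is 1, so under this assumed form chi reduces to sigma_w^2 / rho and the depth scale depends only on that ratio.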
4 Gradient Backpropagation
In this section, we first calculate the gradient metrics and theoretically, without the gradient independence assumption, on linear dropout networks. We then conduct a series of experiments on the gradient metrics of deep dropout networks, including non-linear cases. Finally, we show the emergence of a universal relationship between the mean and variance of the gradient metrics.
4.1 Breaking the gradient independence assumption
Here we use the fact that the weights used in a feed-forward pass are carried over to its backpropagation. We first provide a theoretical treatment of linear networks, in which we assume the output is the last layer of the network without soft-max. The labels of the data are set to zero, and the loss is the mean squared loss.
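The realistic condition can be illustrated with a small simulation in which the backward pass reuses exactly the weights and dropout masks of the forward pass. The sketch below is written under assumptions: a linear dropout network without biases, the layer update z^l = W^l (p^l * z^(l-1)) / rho, and the loss E = (1/N) sum_i (z^L_i)^2 with zero labels. All function names and sizes are hypothetical, and this is an illustration rather than the code used for the experiments.

```python
import math
import random


def make_network(depth, width, sigma_w, rng):
    """Draw depth square weight matrices with W_ij ~ N(0, sigma_w^2 / width)."""
    std = sigma_w / math.sqrt(width)
    return [[[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
            for _ in range(depth)]


def sample_masks(depth, width, keep_rate, rng):
    """Draw one Bernoulli(keep_rate) mask per layer input."""
    return [[1.0 if rng.random() < keep_rate else 0.0 for _ in range(width)]
            for _ in range(depth)]


def forward(x, weights, masks, keep_rate):
    """Return the dropped inputs fed to each layer and the final output."""
    y = list(x)
    dropped_inputs = []
    for w, m in zip(weights, masks):
        d = [mj * yj / keep_rate for mj, yj in zip(m, y)]
        dropped_inputs.append(d)
        y = [sum(wij * dj for wij, dj in zip(row, d)) for row in w]  # linear activation
    return dropped_inputs, y


def loss(output):
    """Mean squared loss with all labels set to zero."""
    return sum(v * v for v in output) / len(output)


def backward(output, weights, masks, keep_rate):
    """Return dE/dz for every layer, reusing the forward weights and masks."""
    n = len(output)
    delta = [2.0 * v / n for v in output]
    deltas = [delta]
    for layer in range(len(weights) - 1, 0, -1):
        w, m = weights[layer], masks[layer]
        delta = [m[j] / keep_rate * sum(w[i][j] * delta[i] for i in range(n))
                 for j in range(n)]
        deltas.append(delta)
    deltas.reverse()
    return deltas


if __name__ == "__main__":
    rng = random.Random(0)
    depth, width, keep_rate = 6, 50, 0.9  # hypothetical sizes
    weights = make_network(depth, width, 0.9, rng)
    masks = sample_masks(depth, width, keep_rate, rng)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    _, out = forward(x, weights, masks, keep_rate)
    for layer, delta in enumerate(backward(out, weights, masks, keep_rate)):
        print(layer, sum(v * v for v in delta) / width)
```

The gradient with respect to a weight is then `deltas[l][i] * dropped_inputs[l][j]`. Drawing a fresh, independent set of weights for `backward` would instead emulate the gradient independence assumption, so the two conditions can be compared in the same script.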
For reasons of space, we omit the details of the calculation and present the primary analysis and final results here. The main problem is that we should expand when calculating in Equation (7), since can correlate with without the gradient independence assumption. Using as an example, we perform:
Starting from the last layer , we compute and use this result to compute .
Then we compute with the result of and .
By the same reasoning, we obtain the results for the penultimate layer . The correlations between terms that contain and are taken into account.
As the index of the layer decreases, the amount of calculation grows rapidly. We therefore use induction to obtain the results for the remaining layers.
We use the same approach to derive the result for . As a result, we have,
By analyzing the first formula of Equation 12, we find that . This can be better observed by dividing the expression related to layer into two factors: one is , and the other is . The first factor accounts for , where for linear dropout networks. The second factor becomes stable a few layers away from the last layer due to . In Figure 2(a), we show an excellent match between the theoretical calculation above and simulations using networks with width and layer over 100 different instantiations of the network.
Despite the successful prediction of the theoretical calculation for , our theoretical results for hold only in the case of and fail to predict the experimental behavior, except for the last few layers, when , as shown in Figure 2(b)(c). A few layers away from , the variance begins to increase dramatically, as shown in Figure 2(c). We notice that, unlike the case of computing , using is prohibitive for computing . On the other hand, we try a function of to fit , and make the interesting observation that is a much more compatible term for , i.e., . This is demonstrated in Figure 2(d).
The mismatch between the theoretical calculation and the experimental results for begins with the emergence of variance, as shown in Figure 2(c). One possible explanation is that the emergence of variance is caused by the limited network width. If so, we could reduce this variance simply by increasing the network width. To check whether this explanation holds, we further investigate the relationship between the variance and mean of for different network widths . The answer is that holds regardless of the finite width. We demonstrate this in the next section.
After studying the gradient behavior in linear networks, we conduct a series of experiments on the nonlinear case, since the theoretical formulation for nonlinear activations or with the soft-max layer is intractable. We first use as the metric of gradient and find that it has a huge variance when . This is because the elements of the gradient matrix with a pair of inputs can be either negative or positive. To find a metric with low variance, we consider the metric whose elements are all positive. Moreover, it is the norm of the gradient matrix.
We plot and as a function of in Figure 3. Interestingly, our simulations show that both and are governed by in a range of activations. Thus we make a conjecture that the relation,
holds on deep dropout networks.
|Summary||feed-forward propagation||gradient backpropagation||empirical results|
|realistic condition (our work)||-|
|independent assumption schoenholz2016deep ()||-|
4.2 Emergence of Universality
We have studied three statistical metrics of the gradient, i.e. , , and using their mean value. The variance of these metrics can also give us essential information about the gradient. To examine it, we performed a series of experiments to obtain the mean and variance of , and with different activations and different network widths .
First, we show the relationship between variance and mean of the metric of gradient with different activations, including Linear, ReLU, Tanh, and Hard Tanh. We denote the mean of , and as , , and , while naming the variance as , , and respectively. We show the variance as a function of mean in Figure 4, and find the emergence of universality between the variance and mean regardless of dropout rate and choice of activation for , , and .
The plot of variance as a function of mean shows a power law between them, since it appears as a straight line in the log-log plot. To estimate the exponent, we compare a simple equation with the experimental results. Surprisingly, all three curves are consistent with . Thus we conjecture that the universal power coefficient between the variance and mean is 2.
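A power-law exponent of this kind can be estimated with an ordinary least-squares fit in log-log coordinates. The sketch below is a generic illustration of that procedure, not the fitting code used for the figures; the function name and the synthetic data in the usage example are hypothetical.

```python
import math


def fit_power_law(means, variances):
    """Least-squares fit of log(var) = log(c) + alpha * log(mean); returns (alpha, c)."""
    xs = [math.log(m) for m in means]
    ys = [math.log(v) for v in variances]
    n = len(xs)
    x_bar = sum(xs) / n
    y_bar = sum(ys) / n
    sxx = sum((x - x_bar) ** 2 for x in xs)
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    alpha = sxy / sxx
    c = math.exp(y_bar - alpha * x_bar)
    return alpha, c


if __name__ == "__main__":
    # Synthetic data obeying var = 3 * mean^2, used only to exercise the fit.
    means = [10.0 ** (-k) for k in range(1, 6)]
    variances = [3.0 * m * m for m in means]
    print(fit_power_law(means, variances))
```

Both inputs must be strictly positive for the logarithms to be defined, which holds for the mean and variance of metrics whose elements are all positive.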
Next, we investigate the relationship between variance and mean for different network widths and show the results in Figure 5. This time, we perform experiments on Tanh networks with different network widths . Again, the relationship between variance and mean satisfies universality, which means that Equation (13) does not depend on the network width .
We point out that we have performed the same investigation on and . However, we did not observe a similar universal relationship between the variance and mean of and . This may be due to the different behavior of () and (). As Equation 6 shows, the mean of converges to a fixed point after several layers, which means that the mean of is stable in deeper layers. We therefore do not expect a universal relation between the mean and the variance in this case.
In summary, we have varied every free parameter that we can tune, and the universal power coefficient between the variance and mean remains the same. We conclude that once the topological structure of the neural network is set, the power coefficient is universal.
According to the theoretical results, during feed-forward we expect the length scale to control the propagation of , while measures the number of layers that the gradient metrics and can survive during backpropagation. However, schoenholz2016deep () claimed that networks both with and without dropout have a limited trainable length, which is governed by the depth scale . As our experimental results, presented later, show, this statement is not exactly right. To summarize, we present the comparison of the length scales between schoenholz2016deep () and our work in Table 1.
5.1 Training speed
Before investigating this problem, we study the relationship between training speed and the choice of hyper-parameters. We confine the hyper-parameters to the critical line for networks with and without dropout, and train networks over a range of lengths with width for steps with a batch size of on the standard CIFAR10 dataset. Strictly speaking, is not the critical line when , since . For the learning rates of each network, we consider logarithmically spaced in steps . To search for the optimal learning rate, we select a threshold accuracy of and measure the first step at which performance exceeds . We show the steps as a function of learning rate for networks with dropout rate , and 0.98 in Figure 6.
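The search procedure described above can be sketched in a few lines. The exponent range, grid step, threshold, and accuracy curve below are placeholders rather than the values used in the experiments, and both function names are hypothetical.

```python
def log_spaced(low_exp, high_exp, step_exp):
    """Return 10**e for e = low_exp, low_exp + step_exp, ..., up to high_exp."""
    count = int(round((high_exp - low_exp) / step_exp)) + 1
    return [10.0 ** (low_exp + i * step_exp) for i in range(count)]


def first_step_above(accuracies, threshold):
    """Index of the first recorded step whose accuracy exceeds threshold, or None if it never does."""
    for step, accuracy in enumerate(accuracies):
        if accuracy > threshold:
            return step
    return None


if __name__ == "__main__":
    learning_rates = log_spaced(-4.0, -1.0, 0.5)  # placeholder grid
    print(learning_rates)
    # Placeholder accuracy curve; in practice one curve is recorded per learning rate.
    print(first_step_above([0.10, 0.18, 0.24, 0.31], threshold=0.25))
```

The learning rate whose curve crosses the threshold earliest is then taken as the optimal one for that network depth.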
We find that for networks without dropout, there is a universal scaling between the steps and learning rate, where is a scaling function, as shown in Figure 6(a). Note that this differs from the result that in pennington2017resurrecting (), where the standard CIFAR10 dataset is augmented with random flips, crops, and so on. The difference may be caused by the preprocessing of the dataset in pennington2017resurrecting (). In addition, we study the networks with and , and find that the scaling holds up to a limited length for and for , as shown in Figure 6(b) and (c) respectively.
5.2 Trainable length
Now we study the problem of trainable length. We consider random networks of depth , and with . We train these networks using Stochastic Gradient Descent (SGD) and RMSProp on MNIST and CIFAR10 with Gaussian and orthogonal weights; the latter can be seen as another variant of weight initialization in mean field theory pennington2017resurrecting (). For a comprehensive study, we perform four experiments on the network without dropout () with different datasets, optimizers, and learning rates, and plot the results in Figure 7(a)-(d). In addition, four experiments are conducted on the dropout networks (), and the results are shown in Figure 7(e)-(h). We color in bright yellow the training accuracy that networks achieved as a function of and for different dropout rates. From the heatmap, we can observe a boundary at which accuracy begins to drop. We notice that there are two boundaries, left and right. In order to show their relationship with and , we superimpose them onto the heatmap.
In Figure 7(a), we use the same learning rate and optimizer as those in Figure 5(a)-(c) of schoenholz2016deep (). We use a learning rate of for SGD when , and for larger . From the plot, we find that underestimates the range of trainability in the - plane, while is more compatible with the experimental result. We note that the same underestimation by also appears in Figure 5(b)(c) of schoenholz2016deep (). In Figure 7(b), we adopt the same learning rate and optimizer as those in Figure 5(d) of schoenholz2016deep (), namely a learning rate of and the RMSProp optimizer. The only difference is that we use 1000 training steps instead of the 300 training steps in schoenholz2016deep (). According to the simulation result, (solid line) and (dashed line) are identical on the left boundary, while they differ on the right side. Comparing and , we find that agrees much better with the trainable length, while overestimates the trainable length on the right side.
Based on the analysis of Figure 7(a)(b), we may conclude that can be used to measure the maximum trainable length of the network without dropout. We further reinforce this conclusion by performing experiments with different learning rates, weight initializations, and datasets. In Figure 7(c), we use orthogonal weight initialization. In Figure 7(d), we perform experiments on the CIFAR10 dataset and adopt a learning rate of , where is constant. These learning rates were selected because each of them leads to the fewest steps needed to reach a certain test accuracy at , as shown in Figure 6. In short, we attribute the maximum trainable length to , where the relation holds for the network without dropout.
Furthermore, we consider the dropout case in Figure 7(e)-(h). We have studied three different dropout rates: (Figure 7(e)), (Figure 7(f)(g)), and (Figure 7(h)). We find that both and are connected to the trainable length: the networks appear to be trainable when . Networks on the left side are influenced by , while they are constrained by on the right side. Note that the formula is also valid in the no-dropout case, as discussed above. To conclude, we show an improved relationship between the maximum trainable length and the length scales and compared with schoenholz2016deep (). This conclusion, that both and are connected to the trainable length instead of only schoenholz2016deep (), is more compatible with the theoretical results.
In this paper, we have investigated dropout networks by calculating their statistical metrics of the gradient during backpropagation at initialization, and we conjecture that the gradient metrics for both a single input and a pair of inputs are governed by the same quantity . We further investigate the relationship between the variance and mean of these statistical metrics empirically and find an emergence of universality. Our finding of a universal relationship between the variance and mean of statistical metrics of gradient backpropagation suggests a deeper mechanism behind it. This mechanism may be better understood by studying more network structures, such as ResNet. Finally, for networks with or without dropout, we attribute the maximum trainable length to the formula , which is novel and important.
Appendix A Theoretical derivation of gradient metrics at initialization
A.1 Derivation of on linear dropout networks with a single input
We rewrite Eq (S1) as: