A Mean Field Theory of Batch Normalization

Greg Yang, Jeffrey Pennington, Vinay Rao, Jascha Sohl-Dickstein, & Samuel S. Schoenholz

Microsoft Research AI, Google Brain
gregyang@microsoft.com, {jpennin,vinaysrao,jaschasd,schsam}@google.com
Abstract

We develop a mean field theory for batch normalization in fully-connected feedforward neural networks. In so doing, we provide a precise characterization of signal propagation and gradient backpropagation in wide batch-normalized networks at initialization. We find that gradient signals grow exponentially in depth and that these exploding gradients cannot be eliminated by tuning the initial weight variances or by adjusting the nonlinear activation function. Indeed, batch normalization itself is the cause of gradient explosion. As a result, vanilla batch-normalized networks without skip connections are not trainable at large depths for common initialization schemes, a prediction that we verify with a variety of empirical simulations. While gradient explosion cannot be eliminated, it can be reduced by tuning the network close to the linear regime, which improves the trainability of deep batch-normalized networks without residual connections. Finally, we investigate the learning dynamics of batch-normalized networks and observe that after a single step of optimization the networks achieve a relatively stable equilibrium in which gradients have dramatically smaller dynamic range.

1 Introduction

Deep neural networks have been enormously successful across a broad range of disciplines. These successes are often driven by architectural innovations. For example, the combination of convolutions (LeCun et al., 1990), residual connections (He et al., 2015), and batch normalization (Ioffe & Szegedy, 2015) has allowed for the training of very deep networks, and these components have become essential parts of models in vision (Zoph et al., 2018), language (Chen & Wu, 2017), and reinforcement learning (Silver et al., 2017). However, a fundamental problem that has accompanied this rapid progress is a lack of theoretical clarity. An important consequence of this gap between theory and experiment is that two important issues become conflated. In particular, it is generally unclear whether novel neural network components improve generalization or whether they merely increase the number of hyperparameter configurations where good generalization can be achieved. Resolving this confusion has the promise of allowing researchers to more effectively and deliberately design neural networks.

Recently, progress has been made (Poole et al., 2016; Schoenholz et al., 2016; Daniely et al., 2016; Pennington et al., 2017; Hanin & Rolnick, 2018; Yang, 2019) in this direction by considering neural networks at initialization, before any training has occurred. In this case, the parameters of the network are random variables, which induces a distribution over the activations of the network as well as the gradients. Studying these distributions is equivalent to understanding the prior over functions that these random neural networks compute. Picking hyperparameters that correspond to well-conditioned priors ensures that the neural network will be trainable, and this fact has been extensively verified experimentally. However, to fulfill its promise of making neural network design less of a black box, these techniques must be applied to neural network architectures that are used in practice. Over the past year, this gap has closed significantly, and theory for networks with skip connections (Yang & Schoenholz, 2017, 2018), convolutional networks (Xiao et al., 2018), and gated recurrent networks (Chen et al., 2018) has been developed. After the publication of this paper, Yang (2019) devised a unifying framework rigorously proving the mean field dynamics for all of the above as well as attention and other modern deep learning modules.

Before state-of-the-art models can be studied in this framework, a slowly-decreasing number of architectural innovations must be studied. One particularly important component that has thus far remained elusive is batch normalization. In this paper, we develop a theory of random, fully-connected networks with batch normalization. A significant complication in the case of batch normalization (compared to e.g. layer normalization or weight normalization) is that the statistics of the network depend non-locally on the entire batch. Thus, our first main result is to recast the theory for random fully-connected networks so that it can be applied to batches of data. We then extend the theory to include batch normalization explicitly and validate this theory against Monte-Carlo simulations. We show that, as in previous cases, we can leverage our theory to predict valid hyperparameter configurations.

In the process of our investigation, we identify a number of previously unknown properties of batch normalization that make training unstable. In particular, we show that for any choice of nonlinearity, gradients of fully-connected networks with batch normalization explode exponentially in the depth of the network. This imposes strong limits on the maximum trainable depth of batch normalized networks, which can be ameliorated by pushing activation functions to be more linear at initialization. It might seem that such gradient explosion ought to lead to learning dynamics that are unfavorable. However, we show that batch normalization causes the scale of the gradients to naturally equilibrate after a single step of gradient descent (provided the gradients are not so large as to cause numerical instabilities).

Finally, we note that there is a related vein of research that has emerged that leverages the prior over functions induced by random networks to perform exact Bayesian inference (Lee et al., 2017; de G. Matthews et al., 2018; Novak et al., 2019; Yang, 2019). One of the natural consequences of this work is that the prior for networks with batch normalization can be computed exactly in the wide network limit. As such, it is now possible to perform exact Bayesian inference in the case of wide neural networks with batch normalization.

2 Related Work

Batch normalization has rapidly become an essential part of the deep learning toolkit. Since its introduction, a number of similar modifications have been proposed, including layer normalization (Ba et al., 2016) and weight normalization (Salimans & Kingma, 2016). Comparisons of performance between these different schemes have been challenging and inconclusive (Gitman & Ginsburg, 2017). The original batchnorm paper (Ioffe & Szegedy, 2015) proposed that batchnorm prevents “internal covariate shift” as an explanation for its effectiveness. Since then, several papers have approached batchnorm from a theoretical angle, especially following Ali Rahimi’s catalyzing call to action at NIPS 2017. Balduzzi et al. (2017) found that batchnorm in resnets allows deep gradient signal propagation, in contrast to the case without batchnorm. Santurkar et al. (2018) found that batchnorm does not reduce internal covariate shift, but rather helps by smoothing the loss landscape. Bjorck et al. (2018) reached the opposite conclusion from ours for residual networks with batchnorm: that batchnorm works in this setting because it induces beneficial gradient dynamics and thus allows a much larger learning rate. Luo et al. (2018) explores the similar idea that batchnorm allows large learning rates and likewise uses random matrix theory to support its claims. Kohler et al. (2018) identified situations in which batchnorm can provably induce acceleration in training. Of the works above that mathematically analyze batchnorm, all but Santurkar et al. (2018) make simplifying assumptions on the form of batchnorm and typically do not have gradients flowing through the batch variance. Even Santurkar et al. (2018) only analyzes a deep linear network to which a single batchnorm layer is added at one moment in training. Our analysis here works for arbitrarily deep batchnorm networks with any activation function used in practice (for example, any activation upper bounded by an exponential function). It is an initialization-time analysis, but we use its insights to predict training- and test-time behavior.

3 Theory

We begin with a brief recapitulation of mean field theory in the fully-connected setting. In addition to recounting earlier results, we rephrase the formalism developed previously to compute statistics of neural networks over a batch of data. Later, we will extend the theory to include batch normalization. We consider a fully-connected network of depth $L$ whose layers have width $N_l$, activation function $\phi$ (the activation function may be layer dependent, but for ease of exposition we assume that it is not), weights $W^l \in \mathbb{R}^{N_l \times N_{l-1}}$, and biases $b^l \in \mathbb{R}^{N_l}$. Given a batch of $B$ inputs $\{x_i\}_{i=1}^{B}$ (throughout the text, we assume that all elements of the batch are unique), the pre-activations of the network are defined by the recurrence relation,

$h^{l}_{i} = W^{l}\,\phi\!\left(h^{l-1}_{i}\right) + b^{l}, \qquad h^{1}_{i} = W^{1} x_{i} + b^{1}.$   (1)

At initialization, we choose the weights and biases to be i.i.d. as $W^{l}_{jk} \sim \mathcal{N}(0, \sigma_w^{2}/N_{l-1})$ and $b^{l}_{j} \sim \mathcal{N}(0, \sigma_b^{2})$. We will be concerned with understanding the statistics of the pre-activations and of the gradients induced by the randomness in the weights and biases. For ease of exposition we will typically take the network to have constant width $N_l = N$.
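To make the setup concrete, the following NumPy sketch samples a network at this initialization and propagates a batch through the recurrence in eq. (1). It is only an illustration: the helper names (`init_params`, `forward_batch`) and the particular values of $\sigma_w$, $\sigma_b$, width, and depth are ours.

```python
import numpy as np

def init_params(rng, depth, width, sigma_w=1.5, sigma_b=0.05):
    """Sample i.i.d. weights W^l ~ N(0, sigma_w^2 / N) and biases b^l ~ N(0, sigma_b^2)."""
    Ws = [rng.normal(0.0, sigma_w / np.sqrt(width), size=(width, width)) for _ in range(depth)]
    bs = [rng.normal(0.0, sigma_b, size=width) for _ in range(depth)]
    return Ws, bs

def forward_batch(x, Ws, bs, phi=np.tanh):
    """Propagate a batch x (shape [B, N]) through eq. (1); returns the pre-activations of every layer."""
    hs = [x @ Ws[0].T + bs[0]]              # h^1 = W^1 x + b^1, batched over the rows of x
    for W, b in zip(Ws[1:], bs[1:]):
        hs.append(phi(hs[-1]) @ W.T + b)    # h^l = W^l phi(h^{l-1}) + b^l
    return hs

rng = np.random.default_rng(0)
Ws, bs = init_params(rng, depth=20, width=1000)
x = rng.normal(size=(8, 1000))              # a batch of B = 8 random inputs
print(forward_batch(x, Ws, bs)[-1].shape)   # (8, 1000)
```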

In the mean field approximation, we iteratively replace the pre-activations in eq. (1) by Gaussian random variables with matching first and second moments. In the infinite width limit this approximation becomes exact (Lee et al., 2017). Since the weights are i.i.d. with zero mean, it follows that the mean of each pre-activation is zero and that the covariance between distinct neurons is zero. The pre-activation statistics are therefore given by $(h^{l}_{1j}, \dots, h^{l}_{Bj}) \sim \mathcal{N}(0, \Sigma^{l})$, independently for each neuron $j$, where the $\Sigma^{l}$ are $B \times B$ covariance matrices over the batch. The covariance matrices are defined by the recurrence relation,

$\Sigma^{l} = \sigma_w^{2}\,\mathrm{V}_{\phi}\!\left(\Sigma^{l-1}\right) + \sigma_b^{2}\,\mathbf{1}\mathbf{1}^{T}$   (2)

where $\mathrm{V}_{\phi}(\Sigma) = \mathbb{E}\big[\phi(h)\,\phi(h)^{T}\big]$ computes the matrix of uncentered second moments of $\phi(h)$ for $h \sim \mathcal{N}(0, \Sigma)$. At first eq. (2) may seem challenging since the expectation involves a Gaussian integral in $B$ dimensions. However, each entry of $\mathrm{V}_{\phi}(\Sigma)$ involves at most a pair of pre-activations, and so the expectation may be reduced to the evaluation of two-dimensional integrals. These integrals can either be performed analytically (Cho & Saul, 2009; Williams, 1997) or efficiently approximated numerically (Lee et al., 2017), and so eq. (2) defines a computationally efficient method for computing the statistics of neural networks after random initialization. This theme of dimensionality reduction will play a prominent role in the forthcoming discussion of batch normalization.
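As an illustrative sanity check on eq. (2) (not the evaluation scheme used in the paper), one can estimate $\mathrm{V}_\phi$ by brute-force Monte Carlo and compare the predicted covariance map against the empirical second moments of a single very wide random layer; `V_phi_mc` and `cov_map` are our names, and the hyperparameter values are arbitrary.

```python
import numpy as np

def V_phi_mc(Sigma, phi=np.tanh, n_samples=200000, rng=None):
    """Monte Carlo estimate of V_phi(Sigma) = E[phi(h) phi(h)^T], h ~ N(0, Sigma)."""
    rng = rng or np.random.default_rng(0)   # fixed seed: the estimate is a deterministic function of Sigma
    h = rng.multivariate_normal(np.zeros(len(Sigma)), Sigma, size=n_samples)
    ph = phi(h)
    return ph.T @ ph / n_samples

def cov_map(Sigma, sigma_w2=2.25, sigma_b2=0.0025, phi=np.tanh):
    """One step of the mean field covariance recursion, eq. (2)."""
    return sigma_w2 * V_phi_mc(Sigma, phi) + sigma_b2 * np.ones_like(Sigma)

# Compare the mean field map against one very wide random layer.
rng = np.random.default_rng(1)
B, N = 4, 3000
Sigma0 = 0.5 * np.eye(B) + 0.3                                   # an arbitrary input covariance
h0 = rng.multivariate_normal(np.zeros(B), Sigma0, size=N).T      # pre-activations, shape [B, N]
W = rng.normal(0.0, np.sqrt(2.25 / N), size=(N, N))
b = rng.normal(0.0, 0.05, size=N)
h1 = np.tanh(h0) @ W.T + b                                       # next layer's pre-activations
print(np.round(h1 @ h1.T / N, 2))                                # empirical second moments over neurons
print(np.round(cov_map(Sigma0), 2))                              # mean field prediction
```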

Eq. (2) defines a dynamical system over the space of covariance matrices. Studying the statistics of random feed-forward networks therefore amounts to investigating this dynamical system, which is an enormous simplification compared with studying the pre-activations of the network directly. As is common in the dynamical systems literature, a significant amount of insight can be gained by investigating the behavior of eq. (2) in the vicinity of its fixed points. For most common activation functions, eq. (2) has a fixed point at some $\Sigma^{*}$. Moreover, when the inputs are non-degenerate, this fixed point generally has a simple structure, $\Sigma^{*} = q^{*}\big[(1-c^{*})\mathbf{I} + c^{*}\mathbf{1}\mathbf{1}^{T}\big]$, owing to permutation symmetry among elements of the batch. We refer to fixed points with such symmetry as Batch Symmetry Breaking 1 (BSB1) fixed points. As we will discuss later, in the context of batch normalization other fixed points with fewer symmetries may become preferred. In the fully-connected setting, fixed points may be computed efficiently by solving the fixed point equation induced by eq. (2) in the special case $B = 2$. The structure of this fixed point implies that in asymptotically deep feed-forward neural networks all inputs yield pre-activations of identical norm with identical angles between them. Neural networks that are deep enough that their pre-activation statistics lie in this regime have been shown to be untrainable (Schoenholz et al., 2016).
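The BSB1 structure can be seen to emerge by simply iterating the covariance map; this short sketch reuses `cov_map` from the previous snippet and shows the diagonal and off-diagonal entries each collapsing to a single value.

```python
import numpy as np
# Reuses cov_map from the sketch after eq. (2).

rng = np.random.default_rng(2)
A = rng.normal(size=(4, 4))
Sigma = A @ A.T + np.eye(4)                  # arbitrary non-degenerate input covariance
for _ in range(25):
    Sigma = cov_map(Sigma)                   # iterate eq. (2)
off_diag = Sigma[~np.eye(4, dtype=bool)]
# At a BSB1 fixed point the diagonal is a single value q* and the off-diagonal a single value q* c*.
print("diag:", np.round(np.diag(Sigma), 3))
print("off-diag range:", np.round(off_diag.min(), 3), "to", np.round(off_diag.max(), 3))
```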

Notation

As we often talk about matrices as well as linear operators acting on matrices, we write $\mathrm{T}\{\Sigma\}$ for an operator $\mathrm{T}$ applied to a matrix $\Sigma$, while matrix multiplication is still written as juxtaposition. Composition of matrix operators is denoted with $\circ$.

To understand the behavior of eq. (2) near its fixed point we can consider the Taylor series in the deviation from the fixed point, $\Delta\Sigma^{l} = \Sigma^{l} - \Sigma^{*}$. To lowest order we generically find,

$\Delta\Sigma^{l} \approx \mathrm{J}\!\left\{\Delta\Sigma^{l-1}\right\}$   (3)

where $\mathrm{J}$ is the Jacobian of the map in eq. (2) evaluated at the fixed point. In most prior work, where $\phi$ was a pointwise nonlinearity, one could consider the special case $B = 2$, which naturally gave rise to linearized dynamics in the diagonal and off-diagonal components of $\Sigma^{l}$. However, in the case of batch normalization we will see that one must consider the evolution of eq. (3) as a whole. This is qualitatively reminiscent of the case of convolutional networks studied in Xiao et al. (2018), where the evolution of the entire pixel $\times$ pixel covariance matrix had to be evaluated. The dynamics induced by eq. (3) will be controlled by the eigenvalues of $\mathrm{J}$. Suppose $\mathrm{J}$ has eigenvalues $\lambda_1 \geq \lambda_2 \geq \cdots$ with associated eigen“vectors” $e_i$ (note that the $e_i$ will themselves be matrices). It follows that if $\Delta\Sigma^{l_0} = \sum_i c_i e_i$ for some choice of constants $c_i$, then $\Delta\Sigma^{l_0 + l} = \sum_i c_i \lambda_i^{l} e_i$. Thus, if $\lambda_i < 1$ for all $i$, then $\Delta\Sigma^{l}$ will approach zero exponentially and the fixed point will be stable. The number of layers over which $\Delta\Sigma^{l}$ approaches zero is given by $\xi_i = -1/\log\lambda_i$. By contrast, if $\lambda_i > 1$ for any $i$, then the fixed point will be unstable. In this case, there is typically a different, stable fixed point that must be identified. It follows that if the eigenvalues of $\mathrm{J}$ can be computed then the dynamics will follow immediately.
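As a rough numerical illustration of this linearization (again reusing `cov_map` from the earlier sketch), one can converge to the fixed point, inject a small symmetric perturbation, and read off the asymptotic per-layer contraction factor, which estimates the largest eigenvalue and hence the depth scale $\xi_1 = -1/\log\lambda_1$ when the fixed point is stable.

```python
import numpy as np
# Reuses cov_map from the sketch after eq. (2).

B, eps = 4, 1e-3
Sigma = np.eye(B) + 0.2
for _ in range(60):                          # converge (approximately) to the fixed point
    Sigma = cov_map(Sigma)
Sigma_star = Sigma

rng = np.random.default_rng(3)
P = rng.normal(size=(B, B))
P = (P + P.T) / 2                            # symmetric perturbation
Sigma = Sigma_star + eps * P
ratio = None
for _ in range(15):
    Sigma_next = cov_map(Sigma)
    ratio = np.linalg.norm(Sigma_next - Sigma_star) / np.linalg.norm(Sigma - Sigma_star)
    Sigma = Sigma_next                       # the ratio approaches the largest eigenvalue (power iteration)
print("lambda_1 ~", round(float(ratio), 3), "  xi_1 ~", round(float(-1.0 / np.log(ratio)), 1), "layers")
```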

At face value, $\mathrm{J}$ is a complicated object since it simultaneously has large dimension ($B^2 \times B^2$) and possesses an intricate block structure. However, the permutation symmetry of $\Sigma^{*}$ induces strong symmetries in $\mathrm{J}$ that significantly simplify the analysis [A.2.4]. In particular, $\mathrm{J}$ is a four-index object $\mathrm{J}_{ij,kl}$; however, $\mathrm{J}_{\pi(i)\pi(j),\pi(k)\pi(l)} = \mathrm{J}_{ij,kl}$ for all permutations $\pi$ on $\{1,\dots,B\}$ and all indices $i,j,k,l$. We call linear operators possessing such symmetries ultrasymmetric and show that all ultrasymmetric operators admit an eigen-decomposition that contains three distinct eigenspaces with associated eigenvalues [A.2.4].

Theorem 1.

Let $\mathrm{T}$ be an ultrasymmetric matrix operator. Then it has the following eigenspaces,

  1. Two one-dimensional eigenspaces whose eigenvectors have identical structure to $\Sigma^{*}$,

    (4)

    each with its own eigenvalue.

  2. Two $(B-1)$-dimensional eigenspaces whose eigenvectors are permutations of the matrix,

    (5)

    with two corresponding eigenvalues.

  3. An eigenspace of dimension $B(B-3)/2$ whose eigenvectors are symmetric matrices with zero diagonal whose rows sum to zero. The eigenvalue of all such eigenvectors is the same.

The eigenvalues, as well as the constants that parameterize the eigenvectors above, are not arbitrary but depend on the specific choice of ultrasymmetric operator. In the case of fully-connected networks, the number of distinct eigenspaces reduces to two, whose eigenvalues are identical to those found via the simplified analysis presented in Schoenholz et al. (2016) [A.2.4.2].
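The eigenspace structure of Theorem 1 can be inspected numerically. The sketch below constructs a random operator that is invariant under simultaneous permutation of all four batch indices (our reading of ultrasymmetry), additionally symmetrizes it so that it is self-adjoint on symmetric matrices (a simplification we make so that the eigenvalues are real), and prints the eigenvalue multiplicities.

```python
import numpy as np
from itertools import product

B = 6

def pattern(idx):
    """Equality pattern of an index tuple; invariant under relabeling the batch elements."""
    seen, out = {}, []
    for v in idx:
        out.append(seen.setdefault(v, len(seen)))
    return tuple(out)

# Random ultrasymmetric tensor: the entry depends only on the equality pattern of (i, j, k, l).
rng = np.random.default_rng(0)
coeff, T = {}, np.zeros((B, B, B, B))
for idx in product(range(B), repeat=4):
    key = pattern(idx)
    if key not in coeff:
        coeff[key] = rng.normal()
    T[idx] = coeff[key]
# Symmetrize within each index pair (so symmetric matrices map to symmetric matrices)
# and across the two pairs (so the operator is self-adjoint and has real eigenvalues).
T = (T + T.transpose(1, 0, 2, 3) + T.transpose(0, 1, 3, 2) + T.transpose(1, 0, 3, 2)) / 4
T = (T + T.transpose(2, 3, 0, 1)) / 2

# Matrix of the operator Sigma -> sum_{kl} T[:, :, k, l] Sigma[k, l] in an orthonormal basis of Sym(B).
basis = []
for a in range(B):
    for c in range(a, B):
        E = np.zeros((B, B))
        E[a, c] = E[c, a] = 1.0
        basis.append(E / np.linalg.norm(E))
M = np.array([[np.sum(Ei * np.einsum('ijkl,kl->ij', T, Ej)) for Ej in basis] for Ei in basis])

vals, counts = np.unique(np.round(np.linalg.eigvalsh(M), 6), return_counts=True)
print(dict(zip(vals.tolist(), counts.tolist())))
# Generically one finds two simple eigenvalues, two eigenvalues of multiplicity B-1,
# and one eigenvalue of multiplicity B*(B-3)/2 (here 1, 1, 5, 5, 9 for B = 6).
```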

Similar arguments allow us to develop a theory for the statistics of gradients. The backpropagation algorithm gives an efficient method of propagating gradients from the end of the network to the earlier layers as,

$\delta^{l}_{i} = \phi'\!\left(h^{l}_{i}\right) \odot \left(W^{l+1}\right)^{T}\!\delta^{l+1}_{i}$   (6)

Here the $\delta^{l}_{i}$ are $N$-dimensional vectors that describe the error signal at the neurons in the $l$'th layer due to the $i$'th element of the batch. The preceding discussion gave a precise characterization of the statistics of the $h^{l}_{i}$ that we can leverage to understand the statistics of the $\delta^{l}_{i}$. It is easy to see that $\mathbb{E}\big[\delta^{l}_{ij}\,\delta^{l}_{i'j'}\big] = 0$ for $j \neq j'$ and $\mathbb{E}\big[\delta^{l}_{ij}\,\delta^{l}_{i'j}\big] = \Pi^{l}_{ii'}$, where $\Pi^{l}$ is a $B \times B$ covariance matrix and we may once again drop the neuron index. We can construct a recurrence relation to compute $\Pi^{l}$,

$\Pi^{l} = \sigma_w^{2}\,\mathrm{V}_{\phi'}\!\left(\Sigma^{l}\right) \odot \Pi^{l+1}$   (7)

Typically, we will be interested in understanding the dynamics of $\Pi^{l}$ when $\Sigma^{l}$ has converged exponentially towards its fixed point. Thus, we study the approximation,

$\Pi^{l} \approx \sigma_w^{2}\,\mathrm{V}_{\phi'}\!\left(\Sigma^{*}\right) \odot \Pi^{l+1}$   (8)

Since these dynamics are linear, explosion and vanishing of gradients will be controlled by the eigenvalues of the linear operator $\Pi \mapsto \sigma_w^{2}\,\mathrm{V}_{\phi'}(\Sigma^{*}) \odot \Pi$.
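For vanilla networks the diagonal of this recursion reduces to the familiar scalar multiplier $\chi = \sigma_w^{2}\,\mathbb{E}\big[\phi'(h)^2\big]$ evaluated at the variance fixed point (Poole et al., 2016; Schoenholz et al., 2016), which can be computed by one-dimensional Gaussian quadrature. The sketch below does this for tanh; the helper names and hyperparameter values are ours.

```python
import numpy as np

def gauss_expect(f, q, n=101):
    """E[f(h)] for h ~ N(0, q), via probabilists' Gauss-Hermite quadrature."""
    z, w = np.polynomial.hermite_e.hermegauss(n)
    return np.sum((w / w.sum()) * f(np.sqrt(q) * z))

def q_star(sigma_w2, sigma_b2, phi=np.tanh, iters=200):
    """Fixed point of the single-input variance map q -> sigma_w^2 E[phi(h)^2] + sigma_b^2."""
    q = 1.0
    for _ in range(iters):
        q = sigma_w2 * gauss_expect(lambda h: phi(h) ** 2, q) + sigma_b2
    return q

def chi(sigma_w2, sigma_b2, dphi=lambda h: 1.0 / np.cosh(h) ** 2):
    """Per-layer multiplier of the squared gradient norm: chi = sigma_w^2 E[phi'(h)^2] at q*."""
    q = q_star(sigma_w2, sigma_b2)
    return sigma_w2 * gauss_expect(lambda h: dphi(h) ** 2, q)

# chi > 1: exploding gradients; chi < 1: vanishing gradients; chi = 1 is the critical line.
for sw2 in [1.0, 1.5, 2.25, 4.0]:
    print(sw2, round(chi(sw2, 0.05), 3))
```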

3.1 Batch Normalization

We now extend the mean field formalism to include batch normalization. Here, the definition for the neural network is modified to the coupled equations,

$h^{l}_{ij} = \sum_{k} W^{l}_{jk}\,\phi\!\left(\tilde h^{l-1}_{ik}\right) + b^{l}_{j}, \qquad \tilde h^{l}_{ij} = \frac{h^{l}_{ij} - \mu^{l}_{j}}{\sigma^{l}_{j}},$   (9)

where $\mu^{l}_{j} = \frac{1}{B}\sum_{i} h^{l}_{ij}$ and $(\sigma^{l}_{j})^{2} = \frac{1}{B}\sum_{i}\big(h^{l}_{ij} - \mu^{l}_{j}\big)^{2} + \epsilon$ are the per-neuron batch statistics. In practice a small $\epsilon > 0$ is used to prevent division by zero, but in this paper, unless stated otherwise (in the last few sections), $\epsilon$ is assumed to be 0. Unlike in the case of vanilla fully-connected networks, here the pre-activations are invariant to $\sigma_w^{2}$ and $\sigma_b^{2}$. Without loss of generality, we therefore set $\sigma_w^{2} = 1$ and $\sigma_b^{2} = 0$ for the remainder of the text. In principle, batch normalization additionally yields a pair of hyperparameters $\gamma$ and $\beta$, which are set to be constants. However, these may be incorporated into the nonlinearity, and so without loss of generality we set $\gamma = 1$ and $\beta = 0$.
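For concreteness, a minimal NumPy version of the batch normalization operation as used here (per-neuron statistics over the batch, $\gamma = 1$, $\beta = 0$, and $\epsilon = 0$ unless stated otherwise); the function names are ours.

```python
import numpy as np

def batchnorm(h, eps=0.0):
    """Batchnorm over the batch dimension (rows) with gamma = 1 and beta = 0, as in eq. (9)."""
    mu = h.mean(axis=0, keepdims=True)            # per-neuron batch mean
    var = h.var(axis=0, keepdims=True)            # per-neuron batch variance (biased, i.e. divided by B)
    return (h - mu) / np.sqrt(var + eps)

def bn_relu(h, eps=0.0):
    """The combined map B_phi for phi = relu: batchnorm followed by the coordinatewise nonlinearity."""
    return np.maximum(batchnorm(h, eps), 0.0)

h = np.random.default_rng(0).normal(size=(8, 5))  # a batch of B = 8 pre-activations at 5 neurons
print(batchnorm(h).mean(axis=0).round(6))         # ~0 for every neuron
print(batchnorm(h).std(axis=0).round(6))          # ~1 for every neuron
```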

The arguments from the previous section can proceed identically and we conclude that as the width of the network grows, the pre-activations will be jointly Gaussian with identically distributed neurons. Thus, we arrive at an analogous expression to eq. (2),

$\Sigma^{l} = \mathrm{V}_{\mathcal{B}_\phi}\!\left(\Sigma^{l-1}\right), \qquad \mathcal{B}_\phi(h) = \phi\!\left(\sqrt{B}\,\frac{G h}{\lVert G h\rVert}\right).$   (10)

Here we have introduced the projection operator $G = \mathbf{I} - \frac{1}{B}\mathbf{1}\mathbf{1}^{T}$, which is defined such that $Gh = h - \mu\mathbf{1}$ with $\mu = \frac{1}{B}\sum_i h_i$. Unlike $\phi$, the map $\mathcal{B}_\phi$ does not act component-wise on $h$. It is therefore not obvious whether $\mathrm{V}_{\mathcal{B}_\phi}(\Sigma)$ can be evaluated without performing a $B$-dimensional Gaussian integral.

We present a pair of results that simplify eq. (10) to a small number of integrals, independent of $B$, by finding integral transforms that relate $\mathrm{V}_{\mathcal{B}_\phi}$ to $\mathrm{V}_{\phi}$. From previous work (Poole et al., 2016), $\mathrm{V}_{\phi}$ can be expressed in terms of two-dimensional Gaussian integrals independent of $B$. When $\phi$ is degree-$\alpha$ positive homogeneous (e.g. rectified linear activations) we can relate $\mathrm{V}_{\mathcal{B}_\phi}$ and $\mathrm{V}_{\phi}$ by the Laplace transform [A.2.1.1].

Theorem 2.

Suppose $\phi$ is degree-$\alpha$ positive homogeneous. For any positive semi-definite matrix $\Sigma$, define the projection $\Sigma^{G} = G\Sigma G$. Then

(11)

Using this parameterization, when $\mathrm{V}_{\phi}$ has a closed-form solution, $\mathrm{V}_{\mathcal{B}_\phi}$ involves only a single integral. We further show that for any $\phi$, $\mathrm{V}_{\mathcal{B}_\phi}$ can be related to $\mathrm{V}_{\phi}$ by a Fourier transform at the expense of an additional integral to perform the change of variables [A.2.1.2].

Theorem 3.

For a general $\phi$ with finite Gaussian moments,

(12)

Together these theorems provide analytic recurrence relations for random neural networks with batch normalization over a wide range of activation functions. By analogy to the fully-connected case we would like to study the dynamical system over covariance matrices induced by these equations.
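As a brute-force reference point for these analytic reductions, $\mathrm{V}_{\mathcal{B}_\phi}$ can also be estimated directly by Monte Carlo over the $B$-dimensional Gaussian, which is exactly the cost Theorems 2 and 3 avoid. The sketch below does this for ReLU, reusing `bn_relu` from the snippet after eq. (9); `V_bn_phi_mc` is our name.

```python
import numpy as np
# Reuses bn_relu from the sketch after eq. (9).

def V_bn_phi_mc(Sigma, n_samples=100000, rng=None):
    """Monte Carlo estimate of V_{B_phi}(Sigma) = E[B_phi(h) B_phi(h)^T], h ~ N(0, Sigma), phi = relu."""
    rng = rng or np.random.default_rng(0)         # fixed seed: a deterministic function of Sigma
    B = len(Sigma)
    # Each column is one draw of a whole batch; batchnorm couples its B coordinates,
    # so naively this is a genuinely B-dimensional integral.
    h = rng.multivariate_normal(np.zeros(B), Sigma, size=n_samples).T   # shape (B, n_samples)
    g = bn_relu(h)
    return g @ g.T / n_samples

Sigma = np.eye(4) + 0.2
print(np.round(V_bn_phi_mc(Sigma), 3))            # one step of the batch-normalized map, eq. (10)
```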

We begin by investigating the fixed point structure of eq. (10). As in the case of feed-forward networks, permutation symmetry implies that there exist fixed points of the form $\Sigma^{*} = q^{*}\big[(1-c^{*})\mathbf{I} + c^{*}\mathbf{1}\mathbf{1}^{T}\big]$. A low-dimensional integral expression for $q^{*}$ and $c^{*}$ can be obtained by transforming to hyperspherical coordinates [A.2.3.1].

Theorem 4.

For a general nonlinearity $\phi$, the BSB1 fixed point satisfies,

(13)
(14)

where

(15)

While these equations allow for the efficient computation of fixed points for arbitrary activation functions, significant simplification occurs when the activation functions are degree-$\alpha$ positive homogeneous [A.2.3.2]. In particular, for rectified linear activations we arrive at the following result.

Theorem 5.

When $\phi = \mathrm{relu}$, there is a unique fixed point of the BSB1 form $\Sigma^{*} = q^{*}\big[(1-c^{*})\mathbf{I} + c^{*}\mathbf{1}\mathbf{1}^{T}\big]$ with,

(16)

where $J_1(\theta) = \sin\theta + (\pi - \theta)\cos\theta$ is the degree-1 arccosine kernel of Cho & Saul (2009).

Together, these results describe the fixed points for most commonly used activation functions.
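Numerically, the BSB1 fixed point for ReLU can be found by simply iterating a Monte Carlo estimate of the map in eq. (10) from a generic starting covariance (with $\sigma_w^2 = 1$ and $\sigma_b^2 = 0$ as set above); the sketch below reuses `V_bn_phi_mc` from the earlier snippet and reads off $q^*$ and $c^*$.

```python
import numpy as np
# Reuses V_bn_phi_mc from the sketch after the Laplace/Fourier discussion.

B = 8
rng = np.random.default_rng(4)
A = rng.normal(size=(B, B))
Sigma = A @ A.T + np.eye(B)                       # arbitrary non-degenerate starting covariance
for _ in range(30):
    Sigma = V_bn_phi_mc(Sigma)                    # iterate eq. (10) for phi = relu
q_star = np.diag(Sigma).mean()
c_star = Sigma[~np.eye(B, dtype=bool)].mean() / q_star
residual = Sigma - q_star * ((1 - c_star) * np.eye(B) + c_star * np.ones((B, B)))
print("q* ~", round(float(q_star), 3), "  c* ~", round(float(c_star), 3))
print("max deviation from the BSB1 form:", round(float(np.abs(residual).max()), 4))
```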

In the presence of batch normalization, when the activation function grows quickly, a winner-take-all phenomenon can occur in which a subset of samples in the batch have much bigger activations than the others. This causes the covariance matrix to form blocks of differing magnitude, breaking the BSB1 symmetry. One observes this, for example, as the degree $\alpha$ of the $\alpha$-ReLU activation increases past a point depending on the batch size $B$. We examine this in more detail and give concrete examples in the appendix. However, by far most of the nonlinearities used in practice, like ReLU, leaky ReLU, tanh, and sigmoid, lead to BSB1 fixed points. Thus, from here on, we assume that any nonlinearity mentioned induces $\Sigma^{l}$ to converge to a BSB1 fixed point.

3.1.1 Linearized Dynamics

With the fixed point structure for batch-normalized networks having been described, we now investigate the linearized dynamics of eq. (10) in the vicinity of these fixed points. As in the vanilla setting, we leverage the properties of ultrasymmetric operators; however, as a consequence of the mean subtraction in batch normalization, here there are only three unique eigenspaces. These eigenspaces have an intuitive interpretation: the first captures the size of the batch; the second captures the fluctuation between norms of the elements of the batch; the third captures the correlation subject to the zero-mean constraint.

To determine the eigenvalues of the Jacobian it is helpful to consider the action of batch normalization in more detail. In particular, we notice that $\mathcal{B}_\phi$ can be decomposed into the composition of three separate operations, $\mathcal{B}_\phi = \phi \circ n \circ G$. As discussed above, $G$ subtracts the per-neuron mean from $h$, and we introduce the new function $n$, which normalizes its (zero-mean) argument by its standard deviation. Applying the chain rule, we can rewrite the Jacobian as,

(17)

where $\circ$ denotes composition and each Jacobian factor is extended in the natural way to act on covariance matrices (a linear map $A$ on vectors acts on a matrix as $\Sigma \mapsto A\Sigma A^{T}$). It ends up being advantageous to study the operator obtained by cyclically permuting these factors, and to note that the nonzero eigenvalues of this object are identical to the nonzero eigenvalues of the Jacobian [A.2.4].

As in the previous section there are two distinct ways to make progress on the spectrum of eq. (17). For arbitrary nonlinearity one can transform to hyperspherical coordinates which leads to tractable integral equations for the eigenvalues. The resulting equations for the eigenvalues can be evaluated, but are complicated and the specific form is relatively unenlightening [A.2.4.1] . In the case of positive-homogeneous activation functions we arrive at a relatively compact representation for the different eigenvalues [A.2.4.2]. Here, we summarize the results for rectified linear networks.

Theorem 6.

For $\phi = \mathrm{relu}$, the eigenvalues for the different eigenspaces outlined above are

(18)
(19)
(20)

Together these eigenvalues along with the fixed point outlined in Theorem 5 completely characterize the statistics of pre-activations in deep networks with batch normalization.

3.1.2 Gradient Backpropagation

With a mean field theory of the pre-activations of feed-forward networks with batch normalization having been developed, we turn our attention to the backpropagation of gradients. In contrast to the case of networks without batch normalization, we will see that exploding gradients at initialization are a severe problem here. To this end, one of the main results of this section will be to show that fully-connected networks with batch normalization feature exploding gradients for any choice of nonlinearity such that $\Sigma^{l}$ converges to a BSB1 fixed point. Below, by “rate of gradient explosion” we mean the rate at which the gradient norm squared grows with depth.

As a starting point we seek an analog of eq. (8) in the case of batch normalization. However, because the activation functions no longer act point-wise on the pre-activations, the backpropagation equation becomes,

(21)

where we observe the additional sum over the batch. Computing the resulting covariance matrix $\Pi^{l}$, we arrive at the recurrence relation,

(22)

where we have defined a linear operator, built from the Jacobian of $\mathcal{B}_\phi$, that acts on $\Pi^{l+1}$. As in the case of vanilla feed-forward networks, here we will be concerned with the behavior of gradients when $\Sigma^{l}$ is close to its fixed point. We therefore study the asymptotic approximation to eq. (22) obtained by replacing $\Sigma^{l}$ with $\Sigma^{*}$. In this case the dynamics of $\Pi^{l}$ are linear and are therefore naturally determined by the eigenvalues of the corresponding linear operator.

As in the forward case, batch normalization is the composition of three operations, $\mathcal{B}_\phi = \phi \circ n \circ G$. Applying the chain rule, eq. (22) can be rewritten as,

(23)

with each factor appropriately defined. The nonzero eigenvalues of this operator coincide with those of a related operator obtained by cyclically permuting its factors, so it suffices to study the eigendecomposition of the latter. Due to the symmetry of $\Sigma^{*}$, this operator is ultrasymmetric, so its eigenspaces are those described in Theorem 1 and we can compute its eigenvalues as in Section 3.1.1. However, this computation is not so enlightening as to the dependence of these eigenvalues on the nonlinearity. We instead use the Laplace and Fourier methods to derive more explicit representations of the eigenvalues. Here we highlight our results on the max eigenvalue, $\lambda_{\max}$, which determines the asymptotic dynamics of $\Pi^{l}$.

Theorem 7.

For any well-behaved nonlinearity $\phi$ such that $\Sigma^{l}$ converges to a BSB1 fixed point with depth, the squared gradient norm explodes asymptotically at the rate of

(24)

where .

Theorem 8.

In a ReLU-batchnorm network, gradients explode asymptotically at the rate

(25)

which decreases toward a limiting rate strictly greater than 1 as $B \to \infty$. In contrast, for a linear batchnorm network, the gradients explode asymptotically at a rate that goes to 1 as $B \to \infty$.

Fig. 1 shows theory and simulation for ReLU gradient dynamics.
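The explosion rate can also be measured directly in a finite-width network by manual backpropagation through the linear, batchnorm, and ReLU operations (with $\sigma_w^2 = 1$, $\sigma_b^2 = 0$, $\gamma = 1$, $\epsilon = 0$ as above). The sketch below is ours and purely illustrative: it injects a random error signal at the top layer and reports the average per-layer growth of the squared gradient norm, to be compared against the rates plotted in Fig. 1.

```python
import numpy as np

def forward(x, Ws):
    """ReLU-batchnorm network, eq. (9), with gamma = 1, beta = 0, eps = 0; returns a per-layer cache."""
    cache = []
    for W in Ws:
        h = x @ W.T
        sd = h.std(axis=0, keepdims=True)               # per-neuron batch standard deviation (biased)
        y = (h - h.mean(axis=0, keepdims=True)) / sd    # batch-normalized pre-activations
        cache.append((y, sd))
        x = np.maximum(y, 0.0)                          # relu
    return cache

def grad_sq_norms(g_x, Ws, cache):
    """Manual backprop of dL/dx^L; returns ||dL/dh^l||^2 for each layer, ordered from input to output."""
    out = []
    for W, (y, sd) in zip(reversed(Ws), reversed(cache)):
        g_y = g_x * (y > 0)                                               # through relu
        g_h = (g_y - g_y.mean(axis=0) - y * (g_y * y).mean(axis=0)) / sd  # through batchnorm
        out.append(np.sum(g_h ** 2))
        g_x = g_h @ W                                                     # through h^l = x^{l-1} (W^l)^T
    return out[::-1]

rng = np.random.default_rng(0)
B, N, L = 32, 800, 40
Ws = [rng.normal(0.0, 1.0 / np.sqrt(N), size=(N, N)) for _ in range(L)]
cache = forward(rng.normal(size=(B, N)), Ws)
sq = grad_sq_norms(rng.normal(size=(B, N)), Ws, cache)     # random error signal injected at the top
rates = [sq[l] / sq[l + 1] for l in range(L - 1)]           # per-layer growth of the squared gradient norm
print("mean per-layer growth rate:", round(float(np.mean(rates[: L // 2])), 3))
```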

By noticing that the integral in Theorem 7 diagonalizes over the Gegenbauer basis, we obtain the following lower bound on the gradient explosion rate (Section A.3.1.1):

Theorem 9 (Batchnorm causes gradient explosion).

Suppose $\phi$ has the Gegenbauer expansion $\phi = \sum_{n \geq 0} a_n C_n$, normalized so that

(26)

Then

(27)

where the coefficients appearing in eq. (27) are nonnegative for all $n$. Consequently, for any non-constant $\phi$ (i.e. there is an $n \geq 1$ such that $a_n \neq 0$), the gradient explosion rate is strictly greater than 1; $\phi$ minimizes the explosion rate iff it is linear (i.e. only the $n = 1$ coefficient is nonzero), in which case gradients explode at the rate of the linear network given in Theorem 8.

This contrasts starkly with the case of non-normalized fully-connected networks, which can use the weight and bias variances to control its mean field network dynamics (Poole et al., 2016; Schoenholz et al., 2016). As a corollary, we disprove the conjecture of the original batchnorm paper (Ioffe & Szegedy, 2015) that “Batch Normalization may lead the layer Jacobians to have singular values close to 1” in the initialization setting, and in fact prove the exact opposite, that batchnorm forces the layer Jacobian singular values away from 1.

Figure 1: Numerical confirmation of theoretical predictions. (a,b) Comparison between theoretical predictions (dashed lines) and Monte Carlo simulations (solid lines) for the eigenvalues of the backward Jacobian as a function of batch size and for the magnitude of gradients as a function of depth, respectively, for rectified linear networks. In each case Monte Carlo simulations are averaged over 200 sample networks of width 1000 and shaded regions denote 1 standard deviation. (c,d) Demonstration of the existence of a BSB1 to BSB2 symmetry breaking transition as a function of $\alpha$ for $\alpha$-homogeneous activation functions. In (c) we plot the empirical variance of the eigenvalues of the covariance matrix, which clearly shows a jump at the transition. In (d) we plot representative covariance matrices for the two phases (BSB1 bottom, BSB2 top).
Effect of $\epsilon$ as a hyperparameter

In practice, $\epsilon$ is usually treated as a small constant and is not regarded as a hyperparameter to be tuned. Nevertheless, we can investigate its effect on gradient explosion. A straightforward generalization of the analysis presented above to the case $\epsilon > 0$ suggests that somewhat larger values of $\epsilon$ than typically used can ameliorate (but not eliminate) gradient explosion. See Fig. 4(c,d).

4 Experiments

Having developed a theory for neural networks with batch normalization at initialization, we now explore the relationship between the properties of these random networks and their learning dynamics. We will see that the trainability of networks with batch normalization is controlled by gradient explosion. We quantify the depth scale over which gradients explode by $\xi = 1/\log\lambda_{\max}$ where, as above, $\lambda_{\max}$ is the largest eigenvalue of the backward Jacobian. Across many different experiments we will see strong agreement between $\xi$ and the maximum trainable depth.

Figure 2: Batch normalization strongly limits the maximum trainable depth. Colors show test accuracy for rectified linear networks with batch normalization, as a function of depth and batch size. (a) Trained on MNIST for 10 epochs. (b) Trained with a fixed overall batch size, with batch statistics computed over sub-batches of the size shown. (c) Trained using RMSProp. (d) Trained on CIFAR10 for 50 epochs.

We first investigate the relationship between trainability and initialization for rectified linear networks as a function of batch size. The results of these experiments are shown in fig. 2, where in each case we plot the test accuracy after training as a function of the depth and the batch size, and overlay the predicted depth scale in white dashed lines. In fig. 2 (a) we consider networks trained using SGD on MNIST, where we observe that networks deeper than about 50 layers are untrainable regardless of batch size. In (b) we compare standard batch normalization with a modified version in which the overall batch size is held fixed but batch statistics are computed over smaller sub-batches. This removes subtle gradient fluctuation effects noted in Smith & Le (2018). In (c) we do the same experiment with RMSProp, and in (d) we train the networks on CIFAR10. In all cases we observe a nearly identical trainable region.

Figure 3: Gradients in networks with batch normalization quickly achieve dynamical equilibrium. Plots of the relative magnitudes of (a) the weights (b) the gradients of the loss with respect to the pre-activations and (c) the gradients of the loss with respect to the weights for rectified linear networks of varying depths during the first 10 steps of training. Colors show step number from 0 (black) to 10 (green).

It is counterintuitive that training can occur at intermediate depths where there is significant gradient explosion. To gain insight into the behavior of the network during learning, we record the magnitudes of the weights, the gradients with respect to the pre-activations, and the gradients with respect to the weights for the first 10 steps of training for networks of different depths. The result of this experiment is shown in fig. 3. Here we see that before learning, as expected, the norm of the weights is constant and independent of layer while the gradients feature exponential explosion. However, we observe that two related phenomena occur after a single step of learning: the weights grow exponentially with depth, and the magnitude of the gradients is stable up to some threshold depth, after which it vanishes exponentially. Thus, it seems that although the gradients of batch-normalized networks at initialization are ill-conditioned, the gradients quickly reach a stable dynamical equilibrium. Pathologically, at very large depths, the relative gradient vanishing can in fact be so severe as to cause the lower layers to stay mostly constant during training.

Figure 4: Three techniques for counteracting gradient explosion. Test accuracy on MNIST as a function of different hyperparameters along with theoretical predictions (white dashed line) for the maximum trainable depth. (a) Network in which we change the overall scale of the pre-activations so as to approach the linear regime. (b) Rectified linear network in which we change the mean of the pre-activations so as to approach the linear regime. (c,d) As a function of $\epsilon$ for the two activation functions respectively; here we observe a well-defined phase transition near a critical value of $\epsilon$. Note that in the case of rectified linear activations we shift the pre-activations so that the function is locally linear about the operating point. We also find that initializing $\gamma$ and/or setting $\epsilon$ appropriately has a positive effect on VGG19 with batchnorm. See Figs. A.4 and A.3.

As discussed in the theoretical exposition above, batch normalization necessarily features exploding gradients for any nonlinearity that converges to a BSB1 fixed point. We performed a number of experiments exploring different ways of ameliorating this gradient explosion. These experiments are shown in fig. 4 with theoretical predictions for the maximum trainable depth overlaid; in all cases we see exceptional agreement. In fig. 4 (a,b) we explore two different ways of tuning the degree to which activation functions in a network are nonlinear. In fig. 4 (a) we tune the overall scale of the pre-activations and note that in the small-scale limit the activation function is effectively linear. In fig. 4 (b) we tune the mean of the pre-activations for networks with rectified linear activations and note, similarly, that in the large-mean limit the function is locally linear. As expected, we see the maximum trainable depth increase significantly with decreasing scale and increasing mean. In fig. 4 (c,d) we vary $\epsilon$ for the two families of networks respectively. In both cases, we observe a critical point at large $\epsilon$ where gradients do not explode and very deep networks are trainable.

5 Conclusion

In this work we have presented a theory for neural networks with batch normalization at initialization. In the process of doing so, we have uncovered a number of counterintuitive aspects of batch normalization and, in particular, the fact that at initialization it unavoidably causes gradients to explode with depth. We have introduced several methods to reduce the degree of gradient explosion, enabling the training of significantly deeper networks in the presence of batch normalization. Finally, this work paves the way for future work on more advanced, state-of-the-art, network architectures and topologies.

References

  • Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer Normalization. arXiv:1607.06450 [cs, stat], July 2016. URL http://arxiv.org/abs/1607.06450.
  • Balduzzi et al. (2017) David Balduzzi, Marcus Frean, Lennox Leary, J. P. Lewis, Kurt Wan-Duo Ma, and Brian McWilliams. The Shattered Gradients Problem: If resnets are the answer, then what is the question? In PMLR, pp. 342–350, July 2017. URL http://proceedings.mlr.press/v70/balduzzi17b.html.
  • Bjorck et al. (2018) Johan Bjorck, Carla Gomes, and Bart Selman. Understanding Batch Normalization. June 2018. URL https://arxiv.org/abs/1806.02375.
  • Chen et al. (2018) Minmin Chen, Jeffrey Pennington, and Samuel Schoenholz. Dynamical isometry and a mean field theory of RNNs: Gating enables signal propagation in recurrent neural networks. In Jennifer Dy and Andreas Krause (eds.), Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp. 873–882, Stockholmsmässan, Stockholm Sweden, 10–15 Jul 2018. PMLR. URL http://proceedings.mlr.press/v80/chen18i.html.
  • Chen & Wu (2017) Q. Chen and R. Wu. CNN Is All You Need. ArXiv e-prints, December 2017.
  • Cho & Saul (2009) Youngmin Cho and Lawrence K. Saul. Kernel methods for deep learning. In Advances in neural information processing systems, pp. 342–350, 2009. URL http://papers.nips.cc/paper/3628-kernel-methods-for-deep-learning.
  • Daniely et al. (2016) Amit Daniely, Roy Frostig, and Yoram Singer. Toward Deeper Understanding of Neural Networks: The Power of Initialization and a Dual View on Expressivity. arXiv:1602.05897 [cs, stat], February 2016. URL http://arxiv.org/abs/1602.05897.
  • de G. Matthews et al. (2018) Alexander G. de G. Matthews, Jiri Hron, Mark Rowland, Richard E. Turner, and Zoubin Ghahramani. Gaussian process behaviour in wide deep neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=H1-nGgWC-.
  • Gitman & Ginsburg (2017) Igor Gitman and Boris Ginsburg. Comparison of Batch Normalization and Weight Normalization Algorithms for the Large-scale Image Classification. arXiv:1709.08145 [cs], September 2017. URL http://arxiv.org/abs/1709.08145.
  • Hanin & Rolnick (2018) Boris Hanin and David Rolnick. How to start training: The effect of initialization and architecture. arXiv preprint arXiv:1803.01719, 2018.
  • He et al. (2015) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep Residual Learning for Image Recognition. arXiv:1512.03385 [cs], December 2015. URL http://arxiv.org/abs/1512.03385.
  • Ioffe & Szegedy (2015) Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv:1502.03167 [cs], February 2015. URL http://arxiv.org/abs/1502.03167.
  • Kohler et al. (2018) Jonas Kohler, Hadi Daneshmand, Aurelien Lucchi, Ming Zhou, Klaus Neymeyr, and Thomas Hofmann. Towards a Theoretical Understanding of Batch Normalization. arXiv:1805.10694 [cs, stat], May 2018. URL http://arxiv.org/abs/1805.10694.
  • LeCun et al. (1990) Yann LeCun, Bernhard E Boser, John S Denker, Donnie Henderson, Richard E Howard, Wayne E Hubbard, and Lawrence D Jackel. Handwritten digit recognition with a back-propagation network. In Advances in neural information processing systems, pp. 396–404, 1990.
  • Lee et al. (2017) Jaehoon Lee, Yasaman Bahri, Roman Novak, Samuel S Schoenholz, Jeffrey Pennington, and Jascha Sohl-Dickstein. Deep neural networks as gaussian processes. arXiv preprint arXiv:1711.00165, 2017.
  • Luo et al. (2018) Ping Luo, Xinjiang Wang, Wenqi Shao, and Zhanglin Peng. Understanding Regularization in Batch Normalization. arXiv:1809.00846 [cs, stat], September 2018. URL http://arxiv.org/abs/1809.00846.
  • Novak et al. (2019) Roman Novak, Lechao Xiao, Jaehoon Lee, Yasaman Bahri, Greg Yang, Jiri Hron, Daniel A. Abolafia, Jeffrey Pennington, and Jascha Sohl-Dickstein. Bayesian deep convolutional networks with many channels are gaussian processes. International Conference of Learning Representations, 2019. URL https://openreview.net/forum?id=B1g30j0qF7.
  • Pennington et al. (2017) Jeffrey Pennington, Samuel Schoenholz, and Surya Ganguli. Resurrecting the sigmoid in deep learning through dynamical isometry: theory and practice. In Advances in neural information processing systems, pp. 4785–4795, 2017.
  • Poole et al. (2016) Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. arXiv:1606.05340 [cond-mat, stat], June 2016. URL http://arxiv.org/abs/1606.05340.
  • Salimans & Kingma (2016) Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks. February 2016. URL https://arxiv.org/abs/1602.07868.
  • Santurkar et al. (2018) Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry. How Does Batch Normalization Help Optimization? (No, It Is Not About Internal Covariate Shift). arXiv:1805.11604 [cs, stat], May 2018. URL http://arxiv.org/abs/1805.11604.
  • Schoenholz et al. (2016) Samuel S. Schoenholz, Justin Gilmer, Surya Ganguli, and Jascha Sohl-Dickstein. Deep Information Propagation. arXiv:1611.01232 [cs, stat], November 2016. URL http://arxiv.org/abs/1611.01232.
  • Silver et al. (2017) David Silver, Julian Schrittwieser, Karen Simonyan, Ioannis Antonoglou, Aja Huang, Arthur Guez, Thomas Hubert, Lucas Baker, Matthew Lai, Adrian Bolton, et al. Mastering the game of go without human knowledge. Nature, 550(7676):354, 2017.
  • Smith & Le (2018) Samuel L Smith and Quoc V Le. A bayesian perspective on generalization and stochastic gradient descent. 2018.
  • (25) P.K. Suetin. Ultraspherical polynomials - Encyclopedia of Mathematics. URL https://www.encyclopediaofmath.org/index.php/Ultraspherical_polynomials.
  • (26) Eric W. Weisstein. Gegenbauer Polynomial. URL http://mathworld.wolfram.com/GegenbauerPolynomial.html.
  • Williams (1997) Christopher KI Williams. Computing with infinite networks. In Advances in neural information processing systems, pp. 295–301, 1997.
  • Xiao et al. (2018) Lechao Xiao, Yasaman Bahri, Jascha Sohl-Dickstein, Samuel Schoenholz, and Jeffrey Pennington. Dynamical isometry and a mean field theory of CNNs: How to train 10,000-layer vanilla convolutional neural networks. In Jennifer Dy and Andreas Krause (eds.), Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp. 5393–5402, Stockholmsmässan, Stockholm Sweden, 10–15 Jul 2018. PMLR. URL http://proceedings.mlr.press/v80/xiao18a.html.
  • Yang (2019) Greg Yang. Scaling limits of wide neural networks with weight sharing: Gaussian process behavior, gradient independence, and neural tangent kernel derivation. arXiv preprint arXiv:1902.04760, 2019.
  • Yang & Schoenholz (2017) Greg Yang and Samuel S. Schoenholz. Meanfield Residual Network: On the Edge of Chaos. In Advances in neural information processing systems, 2017.
  • Yang & Schoenholz (2018) Greg Yang and Samuel S Schoenholz. Deep mean field theory: Layerwise variance and width variation as methods to control gradient explosion. 2018.
  • Zoph et al. (2018) Barret Zoph, Vijay Vasudevan, Jonathon Shlens, and Quoc V Le. Learning transferable architectures for scalable image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.


Figure A.1: Batch norm leads to a chaotic input-output map with increasing depth. A linear network with batch norm is shown acting on two minibatches of size 64 after random orthogonal initialization. The datapoints in the minibatch are chosen to form a 2d circle in input space, except for one datapoint that is perturbed separately in each minibatch (leftmost datapoint at input layer 0). (a) Each pane shows a scatterplot of activations at a given layer for all datapoints in the minibatch, projected onto the top two PCA directions. PCA directions are computed using the concatenation of the two minibatches. Due to the batch norm nonlinearity, minibatches that are nearly identical in input space grow increasingly dissimilar with depth. Intuitively, this chaotic input-output map can be understood as the source of exploding gradients when batch norm is applied to very deep networks, since very small movements in input space correspond to very large movements in output space. (b) The correlation between the two minibatches, as a function of layer, for the same network. Despite having a correlation near one at the input layer, the two minibatches rapidly decorrelate with depth. See A.4 for a theoretical treatment.

Appendix A.1 VGG19 with Batchnorm on CIFAR100

Even though at initialization time batchnorm causes gradient explosion, after the first few epochs the relative gradient norms for the weight parameters $W$ and the BN scale parameters $\gamma$ equilibrate to about the same magnitude. See Fig. A.2.

Figure A.2: Relative gradient norms of different parameters in layer order (input to output from left to right), with $W$ and $\gamma$ interleaving. From dark to light blue, each curve is separated by (a) 3, (b) 5, or (c) 10 epochs. We see that after 10 epochs, the relative gradient norms of both $W$ and $\gamma$ for all layers become approximately equal despite the initial gradient explosion.

We find acceleration effects, especially early in training, due to setting $\epsilon$ and/or the initialization of $\gamma$ appropriately. See Figs. A.4 and A.3.

Figure A.3: We sweep over different values of the learning rate, the initialization of $\gamma$, and $\epsilon$, in training VGG19 with batchnorm on CIFAR100 with data augmentation. We use 8 random seeds for each combination, and assign to each combination the median training/validation accuracy over all runs. We then aggregate these scores here. In the first row we look at training accuracy for different learning rates vs $\gamma$ initializations at different epochs of training, presenting the max over $\epsilon$. In the second row we do the same for validation accuracy. In the third row, we look at the matrix of training accuracy for learning rate vs $\epsilon$, taking the max over the $\gamma$ initialization. In the fourth row, we do the same for validation accuracy.
Figure A.4: In the same setting as Fig. A.3, except we don’t take the max over the unseen hyperparameter but rather set it to 0 (the default value).

In what follows, we adopt a slightly different notation from the main text in order to express the mean field theory of batchnorm more faithfully.

Appendix A.2 Forward Dynamics

Definition.

Let $\mathcal{S}_B$ be the space of PSD matrices of size $B \times B$. Given a measurable function $\Phi: \mathbb{R}^B \to \mathbb{R}^B$, define the integral transform $\mathrm{V}_{\Phi}: \mathcal{S}_B \to \mathcal{S}_B$ by $\mathrm{V}_{\Phi}(\Sigma) = \mathbb{E}\big[\Phi(h)\,\Phi(h)^{T}\big]$ for $h \sim \mathcal{N}(0, \Sigma)$. (This definition of $\mathrm{V}$ absorbs the corresponding scalar definitions in Yang & Schoenholz (2017).) When $\phi: \mathbb{R} \to \mathbb{R}$ and $B$ is clear from context, we also write $\mathrm{V}_{\phi}$ for $\mathrm{V}$ applied to the function acting coordinatewise by $\phi$.

Definition.

For any $\phi: \mathbb{R} \to \mathbb{R}$, let $\mathcal{B}_\phi: \mathbb{R}^B \to \mathbb{R}^B$ be batchnorm (applied to a batch of neuronal activations) followed by coordinatewise application of $\phi$, i.e. $\mathcal{B}_\phi(h) = \phi(\mathcal{B}(h))$, where $\mathcal{B}(h)$ subtracts the mean of the coordinates of $h$ and divides by their standard deviation. When $\phi = \mathrm{id}$ we will also write $\mathcal{B} = \mathcal{B}_{\mathrm{id}}$. We write $\mathrm{relu}$ for ReLU, so that $\mathcal{B}_{\mathrm{relu}}$ is batchnorm followed by ReLU.

Definition.

Define the matrix $G = \mathbf{I}_B - \frac{1}{B}\mathbf{1}\mathbf{1}^{T}$. Let $\mathcal{S}_B^{G}$ be the space of PSD matrices of size $B \times B$ with zero mean across rows and columns, $\mathcal{S}_B^{G} = \{G\Sigma G : \Sigma \in \mathcal{S}_B\}$.

When $B$ is clear from context, we will suppress the subscript/superscript $B$. In short, for $h \in \mathbb{R}^B$, $Gh$ zeros the sample mean of $h$, and $G$ is a projection matrix onto the subspace of vectors with zero coordinate sum. With the above definitions, we then have $\mathcal{B}(h) = \sqrt{B}\, Gh / \lVert Gh \rVert$.

In this section we will be interested in studying the dynamics on PSD matrices of the form $\Sigma^{l} = \mathrm{V}_{\mathcal{B}_\phi}(\Sigma^{l-1})$, where $\Sigma^{l} \in \mathcal{S}_B$ and $\Sigma^{0}$ is determined by the input batch.

A.2.1 Simplification

On the face of it, the iteration map $\Sigma \mapsto \mathrm{V}_{\mathcal{B}_\phi}(\Sigma)$ requires one to do a $B$-dimensional integral. However, one can reduce this down to a constant number of dimensions (independent of $B$) if $\mathrm{V}_{\phi}$ has a closed form. There are two ways to do this: 1) the Laplace method, which reduces the integral down to 1 dimension but requires the assumption that $\phi$ is positive homogeneous, and 2) the Fourier method, which reduces the integral down to 2 dimensions but allows $\phi$ to be any function.

A.2.1.1 Laplace Method

The key insight in the Laplace method is to apply the Schwinger parametrization to deal with the normalization. Lemma (Schwinger parametrization). For $c > 0$ and $z > 0$, $z^{-c} = \frac{1}{\Gamma(c)} \int_0^{\infty} s^{c-1} e^{-zs}\, ds.$

The following is the key lemma in the Laplace method. Lemma (The Laplace Method Master Equation).

For , let and let . Suppose for some nondecreasing function such that exists for every . Define . Then on , is well-defined and continuous, and furthermore satisfies

(28)
Proof.

is well-defined for full rank because the singularity at is Lebesgue-integrable in a neighborhood of 0 in dimension .

We prove Eq. 28 in the case when is full rank and then apply a continuity argument.

Proof of Eq. 28 for full rank . First, we will show that we can exchange the order of integration

by Fubini-Tonelli’s theorem. Observe that

For ,

Because exists by assumption,

as , by dominated convergence with dominating function . By the same reasoning, the function is continuous. In particular this implies that . Combined with the fact that as ,

which is bounded by our assumption that . This shows that we can apply Fubini-Tonelli’s theorem to allow exchanging order of integration.

Thus,

Domain and continuity of . The LHS of Eq. 28, , is defined and continuous on . Indeed, if , where is a full rank matrix with , then

This is integrable in a neighborhood of 0 iff , while it’s always integrable outside a ball around 0 because by itself already is. So is defined whenever . Its continuity can be established by dominated convergence.

Proof of Eq. 28 for . Observe that is continuous in and, by an application of dominated convergence as in the above, is continuous in . So the RHS of Eq. 28 is continuous in whenever the integral exists. By the reasoning above, is bounded in and , so that the integral exists iff .

To summarize, we have proved that both sides of Eq. 28 are defined and continuous for . Because the full rank matrices are dense in this set, by continuity Eq. 28 holds for all . ∎

If $\phi$ is degree-$\alpha$ positive homogeneous, i.e. $\phi(\lambda x) = \lambda^{\alpha}\phi(x)$ for any $\lambda > 0$ and $x$, we can then compute