Three Mechanisms of Weight Decay Regularization
Abstract
Weight decay is one of the standard tricks in the neural network toolbox, but the reasons for its regularization effect are poorly understood, and recent results have cast doubt on the traditional interpretation in terms of L2 regularization. Literal weight decay has been shown to outperform L2 regularization for optimizers for which they differ. We empirically investigate weight decay for three optimization algorithms (SGD, Adam, and K-FAC) and a variety of network architectures. We identify three distinct mechanisms by which weight decay exerts a regularization effect, depending on the particular optimization algorithm and architecture: (1) increasing the effective learning rate, (2) approximately regularizing the input-output Jacobian norm, and (3) reducing the effective damping coefficient for second-order optimization. Our results provide insight into how to improve the regularization of neural networks.
1 Introduction
Weight decay has long been a standard trick to improve the generalization performance of neural networks (Krogh & Hertz, 1992; Bos & Chug, 1996) by encouraging the weights to be small in magnitude. It is widely interpreted as a form of L2 regularization because it can be derived from the gradient of the L2 norm of the weights in the gradient descent setting. However, several findings cast doubt on this interpretation:

Weight decay has sometimes been observed to improve training accuracy, not just generalization performance (e.g. Krizhevsky et al. (2012)).

Weight decay is widely used in networks with Batch Normalization (BN) (Ioffe & Szegedy, 2015). In principle, weight decay regularization should have no effect in this case, since one can scale the weights by a small factor without changing the network’s predictions. Hence, it does not meaningfully constrain the network’s capacity.
The effect of weight decay remains poorly understood, and we lack clear guidelines for which tasks and architectures it is likely to help or hurt. A better understanding of the role of weight decay would help us design more efficient and robust neural network architectures.
In order to better understand the effect of weight decay, we experimented with both weight decay and L2 regularization applied to image classifiers using three different optimization algorithms: SGD, Adam, and Kronecker-Factored Approximate Curvature (K-FAC) (Martens & Grosse, 2015). Consistent with the observations of Loshchilov & Hutter (2017), we found that weight decay consistently outperformed L2 regularization in cases where they differ. Weight decay gave an especially strong performance boost to the K-FAC optimizer, and closed most of the generalization gaps between first- and second-order optimizers, as well as between small and large batches. We then investigated the reasons for weight decay’s performance boost. Surprisingly, we identified three distinct mechanisms by which weight decay has a regularizing effect, depending on the particular algorithm and architecture:

In our experiments with first-order optimization methods (SGD and Adam) on networks with BN, we found that weight decay acts by way of the effective learning rate. Specifically, weight decay reduces the scale of the weights, increasing the effective learning rate, thereby increasing the regularization effect of gradient noise (Neelakantan et al., 2015; Keskar et al., 2016). As evidence, we found that almost all of the regularization effect of weight decay was due to applying it to layers with BN (for which weight decay is meaningless). Furthermore, when we computed the effective learning rate for the network with weight decay, and applied the same effective learning rate to a network without weight decay, this captured the full regularization effect.

We show that when K-FAC is applied to a linear network using the Gauss-Newton metric (K-FAC-G), weight decay is equivalent to regularizing the squared Frobenius norm of the input-output Jacobian (which was shown by Novak et al. (2018) to improve generalization). Empirically, we found that even for (nonlinear) classification networks, the Gauss-Newton norm (which K-FAC with weight decay is implicitly regularizing) is highly correlated with the Jacobian norm, and that K-FAC with weight decay significantly reduces the Jacobian norm.

Because the idealized, undamped version of K-FAC is invariant to affine reparameterizations, the implicit learning rate effect described above should not apply. However, in practice the approximate curvature matrix is damped by adding a multiple of the identity matrix, and this damping is not scale-invariant. We show that without weight decay, the weights grow large, causing the effective damping term to increase. If the effective damping term grows large enough to dominate the curvature term, it effectively turns K-FAC into a first-order optimizer. Weight decay keeps the effective damping term small, enabling K-FAC to retain its second-order properties, and hence improving generalization.
Hence, we have identified three distinct mechanisms by which weight decay improves generalization, depending on the optimization algorithm and network architecture. Our results underscore the subtlety and complexity of neural network training: the final performance numbers obscure a variety of complex interactions between phenomena. While more analysis and experimentation is needed to understand how broadly each of our three mechanisms applies (and to find additional mechanisms!), our work provides a starting point for understanding practical regularization effects in neural network training.
2 Preliminaries
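Supervised learning. Given a training set $\mathcal{S}$ consisting of training pairs $\{\mathbf{x}, y\}$, and a neural network $f_\theta(\mathbf{x})$ with parameters $\theta$ (including weights and biases), our goal is to minimize the empirical risk expressed as an average of a loss $\ell$ over the training set: $\mathcal{L}(\theta) \equiv \frac{1}{N}\sum_{(\mathbf{x}, y) \in \mathcal{S}} \ell(y, f_\theta(\mathbf{x}))$, where $N = |\mathcal{S}|$.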
Stochastic Gradient Descent. To minimize the empirical risk $\mathcal{L}(\theta)$, stochastic gradient descent (SGD) is used extensively in the deep learning community. Typically, gradient descent methods can be derived from the framework of steepest descent with respect to the standard Euclidean metric in parameter space. Specifically, gradient descent minimizes the following surrogate objective in each iteration:
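$$h(\Delta\theta) = \Delta\theta^\top\nabla_\theta\mathcal{L}(\theta) + \frac{1}{\eta}D(\theta, \theta + \Delta\theta) \qquad (1)$$
where the distance (or dissimilarity) function $D(\theta, \theta + \Delta\theta)$ is chosen as $\frac{1}{2}\|\Delta\theta\|_2^2$. In this case, solving equation 1 yields $\Delta\theta = -\eta\nabla_\theta\mathcal{L}(\theta)$, where $\eta$ is the learning rate.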
Natural gradient.
Though popular, gradient descent methods often struggle to navigate “valleys” in the loss surface with ill-conditioned curvature (Martens, 2010).
Natural gradient descent, as a variant of second-order methods (Martens, 2014), is able to make more progress per iteration by taking into account the curvature information.
One way to motivate natural gradient descent is to show that it can be derived by adapting the steepest descent formulation, much like gradient descent, except using an alternative local distance.
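The distance function which leads to natural gradient is the KL divergence on the model’s predictive distribution, $D(\theta, \theta + \Delta\theta) = \mathbb{E}_{\mathbf{x}}\left[D_{\mathrm{KL}}\left(f_\theta(\mathbf{x}) \,\|\, f_{\theta+\Delta\theta}(\mathbf{x})\right)\right] \approx \frac{1}{2}\Delta\theta^\top\mathbf{F}\Delta\theta$, where $\mathbf{F}$ is the Fisher information matrix
$$\mathbf{F} = \mathbb{E}_{\mathbf{x}}\,\mathbb{E}_{y \sim p(y|\mathbf{x},\theta)}\left[\nabla_\theta\log p(y|\mathbf{x},\theta)\,\nabla_\theta\log p(y|\mathbf{x},\theta)^\top\right] \qquad (2)$$
Applying this distance function to equation 1, we have $\Delta\theta = -\eta\mathbf{F}^{-1}\nabla_\theta\mathcal{L}(\theta)$.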
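Gauss-Newton algorithm. Another sensible distance function in equation 1 is the L2 distance on the output (logits) of the neural network, i.e. $D(\theta, \theta + \Delta\theta) = \frac{1}{2}\mathbb{E}_{\mathbf{x}}\left[\|f_\theta(\mathbf{x}) - f_{\theta+\Delta\theta}(\mathbf{x})\|_2^2\right] \approx \frac{1}{2}\Delta\theta^\top\mathbf{G}\Delta\theta$. This leads to the classical Gauss-Newton algorithm, which updates the parameters by $\Delta\theta = -\eta\mathbf{G}^{-1}\nabla_\theta\mathcal{L}(\theta)$, where the Gauss-Newton (GN) matrix is defined as
$$\mathbf{G} = \mathbb{E}\left[\mathbf{J}_\theta^\top\mathbf{J}_\theta\right] \qquad (3)$$
and $\mathbf{J}_\theta$ is the Jacobian of $f_\theta(\mathbf{x})$ w.r.t. $\theta$. The Gauss-Newton algorithm, much like natural gradient descent, is also invariant to the specific parameterization of the neural network function $f_\theta$.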
Two curvature matrices. It has been shown that the GN matrix is equivalent to the Fisher matrix in the case of a regression task with squared error loss (Heskes, 2000). However, they are not identical for classification, where the cross-entropy loss is commonly used. Nevertheless, Martens (2014) showed that the Fisher matrix is equivalent to the generalized GN matrix when the model prediction $p(y|\mathbf{x},\theta)$ corresponds to an exponential family model with natural parameters given by $\mathbf{z} = f_\theta(\mathbf{x})$, where the generalized GN matrix is given by
$$\mathbf{G} = \mathbb{E}\left[\mathbf{J}_\theta^\top\mathbf{H}_\ell\mathbf{J}_\theta\right] \qquad (4)$$
and $\mathbf{H}_\ell$ is the Hessian of $\ell(y, \mathbf{z})$ w.r.t. $\mathbf{z}$, evaluated at $\mathbf{z} = f_\theta(\mathbf{x})$. In regression with squared error loss, this Hessian happens to be the identity matrix, so equation 4 reduces to equation 3.
Preconditioned gradient descent. Since both natural gradient descent and the Gauss-Newton algorithm precondition the gradient with an extra curvature matrix (the Fisher matrix and the GN matrix, respectively), we refer to both as preconditioned gradient descent for convenience.
K-FAC. As modern neural networks may contain millions of parameters, computing and storing the exact curvature matrix and its inverse is impractical. Kronecker-factored approximate curvature (K-FAC; Martens & Grosse, 2015) uses a Kronecker-factored approximation to the curvature matrix to perform efficient approximate natural gradient updates. As shown by Luk & Grosse (2018), K-FAC can be applied to general pullback metrics, including the Fisher metric and the Gauss-Newton metric. For more details, we refer the reader to Appendix F or Martens & Grosse (2015).
Batch Normalization. Beyond sophisticated optimization algorithms, Batch Normalization (BN) plays an indispensable role in modern neural networks. Broadly speaking, BN is a mechanism that aims to stabilize the distribution (over a mini-batch) of inputs to a given network layer during training. This is achieved by augmenting the network with additional layers that subtract the mean $\mu$ and divide by the standard deviation $\sigma$. Typically, the normalized inputs are also scaled and shifted based on trainable parameters $\gamma$ and $\beta$:
$$\mathrm{BN}(\mathbf{x}) = \gamma\,\frac{\mathbf{x} - \mu}{\sigma} + \beta \qquad (5)$$
For clarity, we ignore the parameters $\gamma$ and $\beta$, which do not impact the performance in practice. We further note that BN is applied before the activation function and not used in the output layer.
3 The Effectiveness of Weight Decay
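Table 1: Final test accuracies (%) on CIFAR-10 and CIFAR-100. B: Batch Normalization; D: data augmentation. Each cell reports accuracy without / with weight decay (no WD / WD).
Dataset    Network   B  D   SGD (no WD / WD)   Adam (no WD / WD)   K-FAC-F (no WD / WD)   K-FAC-G (no WD / WD)
CIFAR-10   VGG16     ✗  ✗   83.20 / 84.87      83.16 / 84.12       85.58 / 89.60          83.85 / 89.81
CIFAR-10   VGG16     ✗  ✓   86.99 / 88.85      88.45 / 88.72       87.97 / 89.02          88.17 / 89.77
CIFAR-10   VGG16     ✓  ✓   91.71 / 93.39      92.89 / 93.62       93.12 / 93.90          93.19 / 93.80
CIFAR-10   ResNet32  ✗  ✗   85.47 / 86.63      84.43 / 87.54       86.82 / 90.22          85.24 / 90.64
CIFAR-10   ResNet32  ✗  ✓   86.13 / 90.65      89.46 / 90.61       89.78 / 91.24          89.94 / 90.91
CIFAR-10   ResNet32  ✓  ✓   92.95 / 95.14      93.63 / 94.66       93.80 / 95.35          93.44 / 95.04
CIFAR-100  VGG16     ✓  ✓   68.42 / 73.31      69.88 / 74.22       71.05 / 73.36          67.46 / 73.57
CIFAR-100  ResNet32  ✓  ✓   73.61 / 77.73      73.60 / 77.40       74.49 / 78.01          73.70 / 78.02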
Our goal is to understand weight decay regularization in the context of training deep neural networks. Towards this, we first discuss the relationship between L2 regularization and weight decay in different optimizers.
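Gradient descent with weight decay is defined by the following update rule: $\theta_{t+1} \leftarrow (1 - \eta\beta)\theta_t - \eta\nabla_\theta\mathcal{L}(\theta_t)$, where $\beta$ defines the rate of the weight decay per step and $\eta$ is the learning rate. In this case, weight decay is equivalent to L2 regularization. However, the two differ when the gradient update is preconditioned by a matrix $\mathbf{C}^{-1}$, as in Adam or K-FAC. The preconditioned gradient descent update with L2 regularization is given by
$$\theta_{t+1} \leftarrow (\mathbf{I} - \eta\beta\mathbf{C}^{-1})\theta_t - \eta\mathbf{C}^{-1}\nabla_\theta\mathcal{L}(\theta_t) \qquad (6)$$
whereas the weight decay update is given by
$$\theta_{t+1} \leftarrow (1 - \eta\beta)\theta_t - \eta\mathbf{C}^{-1}\nabla_\theta\mathcal{L}(\theta_t) \qquad (7)$$
The difference between these updates is whether the preconditioner is applied to $\theta_t$. The latter update can be interpreted as the preconditioned gradient descent update on a regularized objective where the regularizer is the squared $\mathbf{C}$-norm $\|\theta\|_{\mathbf{C}}^2 = \theta^\top\mathbf{C}\theta$: the gradient of $\frac{\beta}{2}\|\theta\|_{\mathbf{C}}^2$ is $\beta\mathbf{C}\theta$, and preconditioning it by $\mathbf{C}^{-1}$ gives exactly the decay term $\beta\theta$. If $\mathbf{C}$ is adapted based on statistics collected during training, as in Adam or K-FAC, this interpretation holds only approximately because gradient descent on $\|\theta\|_{\mathbf{C}}^2$ would require differentiating through $\mathbf{C}$. However, this approximate regularization term can still yield insight into the behavior of weight decay. (As we discuss later, this observation informs some, but not all, of the empirical phenomena we have observed.) Though the difference between the two updates may appear subtle, we find that it makes a substantial difference in terms of generalization performance.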
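A minimal sketch of the difference between equations 6 and 7, assuming a fixed diagonal preconditioner $\mathbf{C}$ as a stand-in for Adam's or K-FAC's curvature estimate (the diagonal form, the gradient values and the helper names are illustrative assumptions). It also checks numerically that the weight decay update coincides with one preconditioned step on $\mathcal{L}(\theta) + \frac{\beta}{2}\theta^\top\mathbf{C}\theta$ when $\mathbf{C}$ is held fixed.

```python
# Toy comparison of L2 regularization (eq. 6) and weight decay (eq. 7)
# under a fixed, diagonal preconditioner C (illustrative assumption).

def l2_update(theta, grad, c_diag, eta, beta):
    # eq. 6: theta <- (I - eta*beta*C^{-1}) theta - eta*C^{-1} grad
    return [t - eta * beta * t / c - eta * g / c
            for t, g, c in zip(theta, grad, c_diag)]


def wd_update(theta, grad, c_diag, eta, beta):
    # eq. 7: theta <- (1 - eta*beta) theta - eta*C^{-1} grad
    return [(1 - eta * beta) * t - eta * g / c
            for t, g, c in zip(theta, grad, c_diag)]


def c_norm_step(theta, grad, c_diag, eta, beta):
    # One preconditioned step on L(theta) + (beta/2) theta^T C theta with C
    # held fixed: the regularizer contributes beta * C theta to the gradient.
    return [t - eta * (g + beta * c * t) / c
            for t, g, c in zip(theta, grad, c_diag)]


theta = [1.0, -2.0]
grad = [0.5, 0.1]       # hypothetical loss gradient at theta
c_diag = [10.0, 0.1]    # badly scaled curvature estimate
eta, beta = 0.1, 0.01

l2 = l2_update(theta, grad, c_diag, eta, beta)
wd = wd_update(theta, grad, c_diag, eta, beta)
wd_as_c_norm = c_norm_step(theta, grad, c_diag, eta, beta)
print("L2:", l2)
print("WD:", wd)
```

With these numbers the L2 update shrinks the low-curvature coordinate by 1% and the high-curvature coordinate by 0.01% of its value, whereas weight decay shrinks both by 0.1%: the preconditioner rescales the L2 penalty but not the decay term.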
Initial Experiments.
We now present some empirical findings about the effectiveness of weight decay which the rest of the paper is devoted to explaining. Our experiments were carried out on two different datasets: CIFAR-10 and CIFAR-100 (Krizhevsky & Hinton, 2009) with varied batch sizes. We test VGG16 (Simonyan & Zisserman, 2014) and ResNet32 (He et al., 2016) on both CIFAR-10 and CIFAR-100 (for more details, see Appendix A). In particular, we investigate three different optimization algorithms: SGD, Adam and K-FAC. We consider two versions of K-FAC, which use the Gauss-Newton matrix (K-FAC-G) and Fisher information matrix (K-FAC-F).
Figure 1 shows the comparison between weight decay, L2 regularization and the baseline. We also compare weight decay to the baseline on more settings and report the final test accuracies in Table 1. Finally, the results for large-batch training are summarized in Table 3 (Appendix G). Based on these results, we make the following observations regarding weight decay:

In all experiments, weight decay regularization consistently improved the performance and was more effective than L2 regularization in cases where they differ (see Figure 1).

Weight decay significantly improved performance even for BN networks (See Table 1), where it does not meaningfully constrain the networks’ capacity.

Finally, we notice that weight decay gave an especially strong performance boost to the K-FAC optimizer when BN was disabled (see the first and fourth rows in Table 1).
In the following section, we seek to explain these phenomena. With further testing, we find that weight decay can work in unexpected ways, especially in the presence of BN.
4 Three Mechanisms of Weight Decay Regularization
4.1 Mechanism I: Higher Effective Learning Rate
As discussed in Section 3, when SGD is used as the optimizer, weight decay can be interpreted as penalizing the L2 norm of the weights. Classically, this was believed to constrain the model by penalizing explanations with large weight norm. However, for a network with Batch Normalization (BN), an L2 penalty does not meaningfully constrain the representation, because the network’s predictions are invariant to rescaling of the weights and biases. More precisely, if $g(\theta_l)$ denotes the output of a layer with parameters $\theta_l$ in which BN is applied before the activation function, then
$$g(\alpha\theta_l) = g(\theta_l) \qquad (8)$$
for any $\alpha > 0$. By choosing small $\alpha$, one can make the L2 norm arbitrarily small without changing the function computed by the network. Hence, in principle, adding weight decay to layers with BN should have no effect on the optimal solution. But empirically, weight decay appears to significantly improve generalization for BN networks (e.g. see Figure 1).
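Equation 8 is easy to check numerically. The sketch below, in plain Python with a made-up mini-batch, applies a single linear unit followed by BN (without $\gamma$ and $\beta$, as in Section 2) and shrinks its weights and bias by a factor of 100; the normalized outputs do not change. The helper names are hypothetical, and the normalization deliberately omits the small epsilon that practical BN implementations add to the variance, which makes the invariance only approximate for very small weights.

```python
import math


def batch_norm(values):
    # Normalize a mini-batch of pre-activations to zero mean, unit variance.
    # No epsilon is added, so the scale invariance is exact up to rounding.
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [(v - mean) / std for v in values]


def bn_unit(weights, bias, batch):
    # A single unit: linear map, then BN (applied before the activation).
    pre = [sum(w * x for w, x in zip(weights, xs)) + bias for xs in batch]
    return batch_norm(pre)


batch = [[0.5, -1.0, 2.0], [1.5, 0.3, -0.7], [-0.2, 0.8, 0.1], [2.0, -0.5, 0.4]]
w, b = [0.3, -0.6, 0.9], 0.2
alpha = 0.01  # shrink the parameters by a factor of 100

out = bn_unit(w, b, batch)
out_scaled = bn_unit([alpha * wi for wi in w], alpha * b, batch)
max_diff = max(abs(u - v) for u, v in zip(out, out_scaled))
print(max_diff)
```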
van Laarhoven (2017) observed that weight decay, by reducing the norm of the weights, increases the effective learning rate. Since higher learning rates lead to larger gradient noise, which has been shown to act as a stochastic regularizer (Neelakantan et al., 2015; Keskar et al., 2016), this means weight decay can indirectly exert a regularizing effect through the effective learning rate. In this section, we provide additional evidence supporting the hypothesis of van Laarhoven (2017). For simplicity, this section focuses on SGD, but we’ve observed similar behavior when Adam is used as the optimizer.
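Due to its invariance to the scaling of the weights, the key property of the weight vector is its direction $\hat\theta_l = \theta_l/\|\theta_l\|_2$. As shown by Hoffer et al. (2018), the weight direction is updated according to
$$\hat\theta_{l,t+1} \leftarrow \hat\theta_{l,t} - \eta\,\|\theta_{l,t}\|_2^{-2}\left(\mathbf{I} - \hat\theta_{l,t}\hat\theta_{l,t}^\top\right)\nabla_{\theta_l}\mathcal{L}(\hat\theta_t) + O(\eta^2) \qquad (9)$$
Therefore, the effective learning rate is approximately proportional to $\eta/\|\theta_l\|_2^2$. This means that by decreasing the scale of the weights, weight decay regularization increases the effective learning rate.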
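The $1/\|\theta_l\|_2^2$ scaling in equation 9 can be illustrated with a toy scale-invariant loss that depends on $\theta$ only through its direction, a stand-in for a layer followed by BN (the loss, target vector and step size are illustrative assumptions). Doubling the weight norm halves the gradient, and the angular change produced by one SGD step shrinks by roughly a factor of 4.

```python
import math


def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def loss(theta):
    # Scale-invariant toy loss: depends on theta only through theta/||theta||.
    target = [0.6, 0.8, 0.0]
    return sum((u - t) ** 2 for u, t in zip(normalize(theta), target))


def num_grad(f, theta, h=1e-6):
    # Central finite-difference gradient.
    grad = []
    for i in range(len(theta)):
        up, dn = list(theta), list(theta)
        up[i] += h
        dn[i] -= h
        grad.append((f(up) - f(dn)) / (2 * h))
    return grad


def angular_step(theta, eta):
    # Angle (radians) between theta and theta - eta * grad L(theta).
    g = num_grad(loss, theta)
    new = [t - eta * gi for t, gi in zip(theta, g)]
    cos = sum(a * b for a, b in zip(normalize(theta), normalize(new)))
    return math.acos(min(1.0, cos))


theta = [1.0, 0.2, 0.5]
theta_big = [2.0 * t for t in theta]
eta = 1e-3

grad, grad_big = num_grad(loss, theta), num_grad(loss, theta_big)
ratio = angular_step(theta, eta) / angular_step(theta_big, eta)
print(ratio)  # close to 4, i.e. (||theta_big|| / ||theta||) ** 2
```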
Figure 3 shows the effective learning rate over time for two BN networks trained with SGD (the results for Adam are similar), one with weight decay and one without it. Each network is trained with a typical learning rate decay schedule, including 3 factor-of-10 reductions in the learning rate parameter, spaced 60 epochs apart. Without weight decay, the normalization effects cause an additional effective learning rate decay (due to the increase of weight norm), which reduces the effective learning rate by a factor of 10 over the first 50 epochs. By contrast, when weight decay is applied, the effective learning rate remains more or less constant in each stage.
We now show that the effective learning rate schedule explains nearly the entire generalization effect of weight decay. First, we independently varied whether weight decay was applied to the top layer of the network, and to the remaining layers. Since all layers except the top one used BN, it is only in the top layer that weight decay would constrain the model. Training curves for SGD and Adam under all four conditions are shown in Figure 2. In all cases, we observe that whether weight decay was applied to the top (fully connected) layer did not appear to matter; whether it was applied to the remaining (convolutional) layers explained most of the generalization effect. This supports the effective learning rate hypothesis.
We further tested this hypothesis using a simple experimental manipulation. Specifically, we trained a BN network without weight decay, but after each epoch, rescaled the weights in each layer to match that layer’s norm from the corresponding epoch for the network with weight decay. This rescaling does not affect the network’s predictions, and is equivalent to setting the effective learning rate to match that of the network trained with weight decay. As shown in Figure 4, this effective learning rate transfer scheme (wn-conv) eliminates almost the entire generalization gap; it is fully closed by also adding weight decay to the top layer (wd-fc+wn-conv). Hence, we conclude that for BN networks trained with SGD or Adam, weight decay achieves its regularization effect primarily through the effective learning rate.
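A minimal sketch of the per-epoch rescaling behind the effective learning rate transfer scheme, assuming each layer's weights are available as a flat list and the reference norms were recorded from the run with weight decay (the helpers `transfer_weight_norms` and `effective_lr` and the toy numbers are hypothetical). The rescaling keeps each weight direction unchanged, so BN layers compute the same function, while the effective learning rate $\eta/\|\theta_l\|_2^2$ now matches the reference run.

```python
import math


def l2_norm(w):
    return math.sqrt(sum(x * x for x in w))


def transfer_weight_norms(layers, target_norms):
    # Rescale each layer's flattened weights to the norm recorded for the
    # same layer and epoch in the network trained with weight decay.
    return [[(target / l2_norm(w)) * x for x in w]
            for w, target in zip(layers, target_norms)]


def effective_lr(eta, w):
    # Effective learning rate of a scale-invariant (BN) layer, eq. 9.
    return eta / l2_norm(w) ** 2


# Toy example: two BN conv layers whose norms grew without weight decay.
layers = [[3.0, 4.0], [1.0, 2.0, 2.0]]   # norms 5 and 3
reference_norms = [1.0, 1.5]             # norms from the weight decay run
eta = 0.1

rescaled = transfer_weight_norms(layers, reference_norms)
print([effective_lr(eta, w) for w in layers])    # before rescaling
print([effective_lr(eta, w) for w in rescaled])  # after rescaling
```

In a training loop this would run once per epoch on the convolutional (BN) layers only; the variant that also closes the remaining gap additionally applies ordinary weight decay to the fully connected top layer.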
4.2 Mechanism II: Approximate Jacobian Regularization
In Section 3, we observed that when BN is disabled, weight decay has the strongest regularization effect when K-FAC is used as the optimizer. Hence, in this section we analyze the effect of weight decay for K-FAC with networks without BN. First, we show that in a certain idealized setting, K-FAC with weight decay regularizes the input-output Jacobian of the network. We then empirically investigate whether it behaves similarly for practical networks.
As discussed in Section 3, when the gradient updates are preconditioned by a matrix $\mathbf{C}^{-1}$, weight decay can be viewed as approximate preconditioned gradient descent on the norm $\|\theta\|_{\mathbf{C}}^2 = \theta^\top\mathbf{C}\theta$. This interpretation is only approximate because the exact gradient update requires differentiating through $\mathbf{C}$.
These norms are interesting from a regularization perspective. First, under certain conditions, they are proportional to the average squared L2 norm of the network’s outputs. Hence, the regularizer ought to make the network’s predictions less extreme. This is summarized by the following results:
Lemma 1 (Gradient structure).
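For a feed-forward neural network of depth $L$ (i.e. with weight matrices $\mathbf{W}_0, \dots, \mathbf{W}_L$) with ReLU activation function and no biases, the network’s outputs are related to the input-output Jacobian $\mathbf{J}_\mathbf{x}$ and parameter-output Jacobian $\mathbf{J}_\theta$ as follows:
$$f_\theta(\mathbf{x}) = \frac{1}{L+1}\mathbf{J}_\theta\theta = \mathbf{J}_\mathbf{x}\mathbf{x} \qquad (10)$$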
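Lemma 1 states that $f_\theta(\mathbf{x}) = \frac{1}{L+1}\mathbf{J}_\theta\theta = \mathbf{J}_\mathbf{x}\mathbf{x}$ for bias-free ReLU networks. A quick numerical sanity check, sketched in plain Python for a random network with three weight matrices (so $L + 1 = 3$; the sizes, seed and helper names are illustrative): $\mathbf{J}_\theta\theta$ and $\mathbf{J}_\mathbf{x}\mathbf{x}$ are directional derivatives of $f$ along $\theta$ and along $\mathbf{x}$, which the code estimates with central differences. Positive rescaling never flips a ReLU, so the estimates are accurate up to a tiny $O(t^2)$ term and rounding.

```python
import random


def relu(v):
    return [max(0.0, x) for x in v]


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def forward(weights, x):
    # Bias-free network: ReLU after every layer except the output (logits).
    a = x
    for l, W in enumerate(weights):
        z = matvec(W, a)
        a = z if l == len(weights) - 1 else relu(z)
    return a


def scaled(weights, s):
    return [[[s * w for w in row] for row in W] for W in weights]


random.seed(0)
sizes = [4, 5, 5, 3]  # input, two hidden layers, 3 logits
weights = [[[random.gauss(0, 1) for _ in range(sizes[i])] for _ in range(sizes[i + 1])]
           for i in range(len(sizes) - 1)]
x = [random.gauss(0, 1) for _ in range(sizes[0])]
num_layers = len(weights)  # L + 1 = 3
f = forward(weights, x)

t = 1e-4
# J_theta @ theta: derivative of f(s * theta) w.r.t. s at s = 1.
jtheta_theta = [(p - m) / (2 * t) for p, m in zip(
    forward(scaled(weights, 1 + t), x), forward(scaled(weights, 1 - t), x))]
# J_x @ x: derivative of f(s * x) w.r.t. s at s = 1.
jx_x = [(p - m) / (2 * t) for p, m in zip(
    forward(weights, [(1 + t) * xi for xi in x]),
    forward(weights, [(1 - t) * xi for xi in x]))]
print(f)
print([v / num_layers for v in jtheta_theta])
print(jx_x)
```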
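Lemma 2 (K-FAC Gauss-Newton Norm).
For a linear feed-forward network of depth $L$ without biases,
$$\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = (L+1)\,\mathbb{E}\left[\|f_\theta(\mathbf{x})\|_2^2\right] \qquad (11)$$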
Using these results, we show that for linear networks with whitened inputs, the K-FAC Gauss-Newton norm is proportional to the squared Frobenius norm of the input-output Jacobian. This is interesting from a regularization perspective, since Novak et al. (2018) found the norm of the input-output Jacobian to be consistently coupled to generalization performance.
Theorem 1 (Approximate Jacobian norm).
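For a linear feed-forward network of depth $L$ without biases, if we further assume that $\mathbb{E}[\mathbf{x}] = \mathbf{0}$ and $\mathrm{Cov}(\mathbf{x}) = \mathbf{I}$, then:
$$\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = (L+1)\,\|\mathbf{J}_\mathbf{x}\|_{\text{Frob}}^2 \qquad (12)$$
Proof. For a linear network, $\mathbf{J}_\mathbf{x}$ does not depend on $\mathbf{x}$, and by Lemma 1, $f_\theta(\mathbf{x}) = \mathbf{J}_\mathbf{x}\mathbf{x}$. Combining this with Lemma 2 and $\mathbb{E}[\mathbf{x}\mathbf{x}^\top] = \mathrm{Cov}(\mathbf{x}) + \mathbb{E}[\mathbf{x}]\mathbb{E}[\mathbf{x}]^\top = \mathbf{I}$, we obtain
$$\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = (L+1)\,\mathbb{E}\left[\mathbf{x}^\top\mathbf{J}_\mathbf{x}^\top\mathbf{J}_\mathbf{x}\mathbf{x}\right] = (L+1)\,\mathrm{tr}\left(\mathbf{J}_\mathbf{x}^\top\mathbf{J}_\mathbf{x}\,\mathbb{E}[\mathbf{x}\mathbf{x}^\top]\right) = (L+1)\,\|\mathbf{J}_\mathbf{x}\|_{\text{Frob}}^2. \quad \blacksquare$$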
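Theorem 1 can be checked on a small random linear network. The sketch below (plain Python; the sizes, seed and helper names are illustrative assumptions) builds, for each layer, the K-FAC Gauss-Newton factors $\mathbf{A}_l$ (second moment of the layer's inputs) and $\mathbf{G}_l = \mathbf{J}_l^\top\mathbf{J}_l$, where $\mathbf{J}_l$ is the constant Jacobian of the output w.r.t. the layer's output, and evaluates $\sum_l \mathrm{vec}(\mathbf{W}_l)^\top(\mathbf{A}_l \otimes \mathbf{G}_l)\,\mathrm{vec}(\mathbf{W}_l) = \sum_l \mathrm{tr}(\mathbf{W}_l^\top\mathbf{G}_l\mathbf{W}_l\mathbf{A}_l)$ (column-major vectorization) without forming Kronecker products. The inputs $\{\pm\sqrt{d}\,\mathbf{e}_i\}$ have zero mean and identity covariance, so the result should equal $(L+1)\|\mathbf{J}_\mathbf{x}\|_{\text{Frob}}^2$.

```python
import math
import random


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A):
    return [list(row) for row in zip(*A)]


def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def trace(A):
    return sum(A[i][i] for i in range(len(A)))


def chain(mats, n):
    # Product mats[-1] @ ... @ mats[0], starting from an n x n identity.
    P = identity(n)
    for M in mats:
        P = matmul(M, P)
    return P


def kfac_gn_norm(Ws, X):
    # sum_l vec(W_l)^T (A_l kron G_l) vec(W_l) = sum_l tr(W_l^T G_l W_l A_l)
    total = 0.0
    for l, W in enumerate(Ws):
        P = chain(Ws[:l], len(X[0]))          # maps x to the layer's input
        acts = [[sum(p * xi for p, xi in zip(row, x)) for row in P] for x in X]
        n = len(acts[0])
        A = [[sum(a[i] * a[j] for a in acts) / len(acts) for j in range(n)]
             for i in range(n)]
        J = chain(Ws[l + 1:], len(W))         # output Jacobian w.r.t. layer output
        G = matmul(transpose(J), J)
        total += trace(matmul(matmul(matmul(transpose(W), G), W), A))
    return total


random.seed(1)
d = 3
sizes = [d, 4, 2]  # linear network with L + 1 = 2 weight matrices
Ws = [[[random.gauss(0, 1) for _ in range(sizes[i])] for _ in range(sizes[i + 1])]
      for i in range(len(sizes) - 1)]

# Whitened inputs: {+sqrt(d) e_i, -sqrt(d) e_i} has zero mean, identity covariance.
X = [[s * math.sqrt(d) if j == i else 0.0 for j in range(d)]
     for i in range(d) for s in (1.0, -1.0)]

Jx = chain(Ws, d)  # input-output Jacobian of the linear network
lhs = kfac_gn_norm(Ws, X)
rhs = len(Ws) * sum(v * v for row in Jx for v in row)  # (L + 1) * ||J_x||_F^2
print(lhs, rhs)
```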
While the equivalence between the K-FAC GN norm and the Jacobian norm holds only for linear networks, we note that linear networks have been useful for understanding the dynamics of neural net training more broadly (e.g. Saxe et al. (2013)). Hence, Jacobian regularization may help inform our understanding of weight decay in practical (nonlinear) networks.
To test whether the K-FAC GN norm correlates with the Jacobian norm for practical networks, we trained feed-forward networks with a variety of optimizers on both MNIST and CIFAR-10. For MNIST, we used simple fully-connected networks with different depths and widths. For CIFAR-10, we adopted the VGG family (from VGG11 to VGG19). We defined the generalization gap to be the difference between training and test loss. Figure 5 shows the relationship of the Jacobian norm to the K-FAC GN norm and to the generalization gap for these networks. We observe that the Jacobian norm correlates strongly with the generalization gap (consistent with Novak et al. (2018)) and also with the K-FAC GN norm. Hence, Theorem 1 can inform the regularization of nonlinear networks.
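Table 2: Input-output Jacobian norm at the end of training, without / with weight decay (no WD / WD).
Optimizer   VGG16 (no WD / WD)   ResNet32 (no WD / WD)
SGD         564 / 142            2765 / 1074
K-FAC-G     498 / 51.44          2115 / 64.16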
To test if K-FAC with weight decay reduces the Jacobian norm, we compared the Jacobian norms at the end of training for networks with and without weight decay. As shown in Table 2, weight decay reduced the Jacobian norm by a much larger factor when K-FAC was used as the optimizer than when SGD was used as the optimizer (roughly 10× versus 4× for VGG16, and 33× versus 2.6× for ResNet32).
Our discussion so far has focused on the GN version of K-FAC. Recall that, in many cases, the Fisher information matrix differs from the GN matrix only in that it accounts for the output layer Hessian. Hence, this analysis may help inform the behavior of K-FAC-F as well. We also note that $\|\theta\|_{\mathbf{F}}^2 = \theta^\top\mathbf{F}\theta$, the Fisher-Rao norm, has been proposed as a complexity measure for neural networks (Liang et al., 2017). Hence, unlike in the case of SGD and Adam for BN networks, we interpret K-FAC with weight decay as constraining the capacity of the network.
4.3 Mechanism III: Smaller Effective Damping Parameter
We now return our attention to the setting of architectures with BN. The Jacobian regularization mechanism from Section 4.2 does not apply in this case, since rescaling the weights results in an equivalent network, and therefore does not affect the input-output Jacobian. Similarly, if the network is trained with K-FAC, then the effective learning rate mechanism from Section 4.1 also does not apply because the K-FAC update is invariant to affine reparameterization (Luk & Grosse, 2018) and therefore not affected by the scaling of the weights. More precisely, for a layer with BN, the curvature matrix $\mathbf{C}$ (either the Fisher matrix or the GN matrix) has the following property (since $g(\alpha\theta_l) = g(\theta_l)$ implies that the Jacobian w.r.t. $\theta_l$ scales as $1/\alpha$):
$$\mathbf{C}(\alpha\theta_l) = \frac{1}{\alpha^2}\,\mathbf{C}(\theta_l) \qquad (13)$$
where $\alpha > 0$ as in Section 4.1. Hence, the $\alpha^2$ factor in the preconditioner $\mathbf{C}^{-1}$ counteracts the $\alpha^{-2}$ factor in the effective learning rate, resulting in an equivalent effective learning rate regardless of the norm of the weights.
These observations raise the question of whether it is still useful to apply weight decay to BN layers when using K-FAC. To answer this question, we repeated the experiments in Figure 2 (applying weight decay to subsets of the layers), but with K-FAC as the optimizer. The results are summarized in Figure 6. Applying weight decay to the non-BN layers had the largest effect, consistent with the Jacobian regularization hypothesis. However, applying weight decay to the BN layers also led to significant gains, especially for K-FAC-F.
The reason this does not contradict the K-FAC invariance property is that practical K-FAC implementations (like many second-order optimizers) damp the updates by adding a multiple of the identity matrix to the curvature before inversion. According to Equation 13, as the norm of the weights gets larger, $\mathbf{C}$ gets smaller, and hence the damping term comes to dominate the preconditioner. Mathematically, we can understand this effect by deriving the following update rule for the normalized weights $\hat\theta = \theta/\|\theta\|_2$ (see Appendix D for proof):
$$\hat\theta_{t+1} \leftarrow \hat\theta_t - \eta\left(\mathbf{I} - \hat\theta_t\hat\theta_t^\top\right)\left(\mathbf{C}(\hat\theta_t) + \lambda\|\theta_t\|_2^2\,\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\hat\theta_t) + O(\eta^2) \qquad (14)$$
where $\lambda$ is the damping parameter. Hence, for large $\mathbf{C}(\hat\theta_t)$ or small $\|\theta_t\|_2$, the update is close to the idealized second-order update, while for small enough $\mathbf{C}(\hat\theta_t)$ or large enough $\|\theta_t\|_2$, K-FAC effectively becomes a first-order optimizer. Hence, by keeping the weights small, weight decay helps K-FAC to retain its second-order properties.
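The effect of the damping can be illustrated with a toy diagonal curvature for the normalized weights (the curvature values, gradient and damping below are illustrative assumptions, and the tangent-space projection in equation 14 is omitted because it is common to every case). The update direction is taken to be $(\mathbf{C}(\hat\theta) + \lambda\|\theta\|_2^2\mathbf{I})^{-1}\nabla_\theta\mathcal{L}(\hat\theta)$: for small weight norms it aligns with the undamped second-order direction, and for large norms it aligns with the plain gradient, as a first-order method would.

```python
import math


def damped_direction(grad, c_diag, lam, rho):
    # (C(theta_hat) + lam * rho^2 I)^{-1} grad for a diagonal C(theta_hat);
    # rho = ||theta||_2, so lam * rho^2 is the effective damping.
    return [g / (c + lam * rho ** 2) for g, c in zip(grad, c_diag)]


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))


grad = [1.0, 1.0, 1.0]            # hypothetical gradient at theta_hat
c_diag = [100.0, 1.0, 0.01]       # ill-conditioned curvature
lam = 1e-3                        # fixed damping parameter

second_order = [g / c for g, c in zip(grad, c_diag)]  # undamped direction
for rho in (0.1, 1.0, 10.0, 1000.0):
    u = damped_direction(grad, c_diag, lam, rho)
    print(f"||theta||={rho:g}: cos to 2nd-order={cosine(u, second_order):.3f}, "
          f"cos to gradient={cosine(u, grad):.3f}")

small = damped_direction(grad, c_diag, lam, 0.1)
large = damped_direction(grad, c_diag, lam, 1000.0)
```

Setting `lam = 0.0` makes the direction independent of `rho`, which mirrors the invariance of undamped K-FAC discussed above.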
Most implementations of K-FAC keep the damping parameter $\lambda$ fixed throughout training. Therefore, it would be convenient if $\mathbf{C}(\hat\theta_t)$ and $\|\theta_t\|_2$ did not change too much during training, so that a single value of $\lambda$ can work well throughout training. Interestingly, the norm of the GN matrix appears to be much more stable than the norm of the Fisher matrix. Figure 7 shows the norms of the Fisher matrix $\mathbf{F}$ and GN matrix $\mathbf{G}$ of the normalized weights for the first layer of a CIFAR-10 network throughout training. While the norm of $\mathbf{F}$ decays by 4 orders of magnitude over the first 50 epochs, the norm of $\mathbf{G}$ increases by only a factor of 2.
The explanation for this is as follows: in a classification task with cross-entropy loss, the Fisher matrix is equivalent to the generalized GN matrix $\mathbb{E}[\mathbf{J}_\theta^\top\mathbf{H}_\ell\mathbf{J}_\theta]$ (see Section 2). This differs from the GN matrix $\mathbb{E}[\mathbf{J}_\theta^\top\mathbf{J}_\theta]$ only in that it includes the output layer Hessian $\mathbf{H}_\ell = \mathrm{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top$, where $\mathbf{p}$ is the vector of estimated class probabilities. It is easy to see that $\mathbf{H}_\ell$ goes to zero as $\mathbf{p}$ collapses to one class, as is the case for tasks such as CIFAR-10 and CIFAR-100 where networks typically achieve perfect training accuracy. Hence, we would expect $\mathbf{F}$ to get much smaller over the course of training, consistent with Figure 7.
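A short numerical illustration of this argument (plain Python; the three-class logits and the logit scaling used to mimic growing confidence are illustrative assumptions): the Frobenius norm of $\mathbf{H}_\ell = \mathrm{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top$ is $\sqrt{2}/3$ for a uniform prediction over three classes and becomes vanishingly small as $\mathbf{p}$ approaches a one-hot vector.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def output_hessian(p):
    # Hessian of the cross-entropy loss w.r.t. the logits: diag(p) - p p^T.
    k = len(p)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(k)]
            for i in range(k)]


def frobenius(M):
    return math.sqrt(sum(v * v for row in M for v in row))


base_logits = [1.0, 0.0, -1.0]
norms = []
for scale in (0.0, 1.0, 5.0, 20.0):
    # Larger logit scale -> more confident prediction for class 0.
    p = softmax([scale * z for z in base_logits])
    norms.append(frobenius(output_hessian(p)))
print(norms)
```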
To summarize, when K-FAC is applied to BN networks, it can be advantageous to apply weight decay even to layers with BN, even though this appears unnecessary based on invariance considerations. The reason is that weight decay reduces the effective damping, helping K-FAC to retain its second-order properties. This effect is stronger for K-FAC-F than for K-FAC-G because the Fisher matrix shrinks dramatically over the course of training.
5 Discussion
Despite its long history, weight decay regularization remains poorly understood. We’ve identified three distinct mechanisms by which weight decay improves generalization, depending on the architecture and optimization algorithm: increasing the effective learning rate, reducing the Jacobian norm, and reducing the effective damping parameter. We would not be surprised if there remain additional mechanisms we have not found.
The dynamics of neural net training are incredibly complex, and it can be tempting to simply do what works and not look into why. But we think it is important to at least sometimes dig deeper to determine exactly why an algorithm has the effect that it does. Some of our analysis may seem mundane, or even tedious, as the interactions between different hyperparameters are not commonly seen as a topic worthy of detailed scientific study. But our experiments highlight that the dynamics of the norms of weights and curvature matrices, and their interaction with optimization hyperparameters, can have a substantial impact on generalization. We believe these effects deserve more attention, and would not be surprised if they can help explain the apparent success or failure of other neural net design choices. We also believe our results highlight the need for automatic adaptation of optimization hyperparameters, to eliminate potential experimental confounds and to allow researchers and practitioners to focus on higher-level design issues.
6 Acknowledgement
We thank Jimmy Ba, Kevin Luk, Maxime Gazeau, and Behnam Neyshabur for helpful discussions, and Tianqi Chen and Shengyang Sun for their feedback on early drafts. GZ was funded by an MRIS Early Researcher Award.
Appendix A Experimental Details
Throughout the paper, we perform experiments on image classification with three different datasets, MNIST (LeCun et al., 1998), CIFAR-10 and CIFAR-100 (Krizhevsky & Hinton, 2009). For MNIST, we use simple fully-connected networks with different depths and widths. For CIFAR-10 and CIFAR-100, we use VGG16 (Simonyan & Zisserman, 2014) and ResNet32 (He et al., 2016). To make the network more flexible, we widen all convolutional layers in ResNet32 by a factor of 4, following Zagoruyko & Komodakis (2016).
We investigate three different optimization methods: Stochastic Gradient Descent (SGD), Adam (Kingma & Ba, 2014) and K-FAC (Martens & Grosse, 2015). For K-FAC, two different curvature matrices are studied: the Fisher information matrix and the Gauss-Newton matrix.
By default, a batch size of 128 is used unless stated otherwise. For SGD and Adam, we train the networks with a budget of 200 epochs and decay the learning rate by a factor of 10 every 60 epochs for batch sizes of 128 and 640, and every 80 epochs for the batch size of 2K. For K-FAC, we train the networks for only 100 epochs and decay the learning rate every 40 epochs. Additionally, the curvature matrix is updated by a running average with re-estimation every 10 iterations, and the inverse operator is amortized over 100 iterations. For K-FAC, we use a fixed damping term unless stated otherwise. For each algorithm, the best hyperparameters (learning rate and regularization factor) are selected using grid search on a held-out 5k validation set. For the large-batch setting, we adopt the same strategies as Hoffer et al. (2017) for adjusting the search range of hyperparameters. Finally, we retrain the model with both training data and validation data.
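For concreteness, the training schedule described above can be written as a small configuration sketch (plain Python; the dictionary keys, helper names and base learning rate are hypothetical, since the base learning rates were chosen by grid search). It encodes the step decays and the K-FAC amortization intervals (curvature statistics every 10 iterations, inverses every 100 iterations). The decay factor for K-FAC is not stated in the text, so the factor of 10 used by `step_lr` is an assumption in that case.

```python
def step_lr(epoch, base_lr, decay_every, factor=10.0):
    # Step schedule: divide the learning rate by `factor` every `decay_every` epochs.
    return base_lr / factor ** (epoch // decay_every)


# Epoch budgets and decay intervals from Appendix A (keys are hypothetical).
SCHEDULES = {
    "sgd_adam_bs128_or_640": {"epochs": 200, "decay_every": 60},
    "sgd_adam_bs2k": {"epochs": 200, "decay_every": 80},
    "kfac": {"epochs": 100, "decay_every": 40},  # decay factor assumed to be 10
}

KFAC_STATS_EVERY = 10      # re-estimate curvature running averages
KFAC_INVERSE_EVERY = 100   # recompute the (amortized) inverses


def kfac_maintenance(iteration):
    # Which K-FAC bookkeeping steps would run at this iteration.
    return {
        "update_stats": iteration % KFAC_STATS_EVERY == 0,
        "update_inverse": iteration % KFAC_INVERSE_EVERY == 0,
    }


cfg = SCHEDULES["sgd_adam_bs128_or_640"]
base_lr = 0.1  # placeholder value, not a tuned setting
lrs = [step_lr(e, base_lr, cfg["decay_every"]) for e in range(cfg["epochs"])]
print(sorted(set(lrs), reverse=True))
```

With a 200-epoch budget and a decay every 60 epochs, this yields the three factor-of-10 reductions mentioned in Section 4.1.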
Appendix B Gradient Structure in Neural Networks (Lemma 1)
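Claim. For a feed-forward neural network of depth $L$ with ReLU activation function and no biases, one has the following property:
$$f_\theta(\mathbf{x}) = \frac{1}{L+1}\mathbf{J}_\theta\theta = \mathbf{J}_\mathbf{x}\mathbf{x} \qquad (15)$$
The key observation of Lemma 1 is that rectified neural networks are piecewise linear up to the output $f_\theta(\mathbf{x})$, and that the ReLU activation function satisfies the property $\sigma(z) = \sigma'(z)\,z$.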
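Proof.
For convenience, we introduce some notation. Index the layers by $l = 0, \dots, L$, let $\mathbf{a}_0 = \mathbf{x}$, let $\mathbf{z}_l = \mathbf{W}_l\mathbf{a}_l$ denote the output of the $l$-th layer, and let $\mathbf{a}_{l+1} = \sigma(\mathbf{z}_l)$, so that the output logits are $\mathbf{z}_L = f_\theta(\mathbf{x})$. Similarly, we define $\mathbf{D}_l = \mathrm{diag}(\sigma'(\mathbf{z}_l))$ and $\mathbf{J}_{\mathbf{z}_l} = \partial\mathbf{z}_L/\partial\mathbf{z}_l$. By definition and the property $\sigma(z) = \sigma'(z)\,z$, it is easy to see that
$$\mathbf{z}_{l+1} = \mathbf{W}_{l+1}\sigma(\mathbf{z}_l) = \mathbf{W}_{l+1}\mathbf{D}_l\mathbf{z}_l = \frac{\partial\mathbf{z}_{l+1}}{\partial\mathbf{z}_l}\,\mathbf{z}_l.$$
By induction, we conclude that $\mathbf{z}_L = \mathbf{J}_{\mathbf{z}_l}\mathbf{z}_l$ for every $l$; in particular, since $\mathbf{z}_0 = \mathbf{W}_0\mathbf{x}$, we have $f_\theta(\mathbf{x}) = \mathbf{J}_{\mathbf{z}_0}\mathbf{W}_0\mathbf{x} = \mathbf{J}_\mathbf{x}\mathbf{x}$.
On the other side, since $\mathbf{z}_l = \mathbf{W}_l\mathbf{a}_l$ is linear in $\mathbf{W}_l$, the chain rule gives
$$\mathbf{J}_{\mathbf{W}_l}\,\mathrm{vec}(\mathbf{W}_l) = \mathbf{J}_{\mathbf{z}_l}\mathbf{W}_l\mathbf{a}_l = \mathbf{J}_{\mathbf{z}_l}\mathbf{z}_l.$$
Since $\mathbf{z}_L = \mathbf{J}_{\mathbf{z}_l}\mathbf{z}_l$, we therefore get $\mathbf{J}_{\mathbf{W}_l}\,\mathrm{vec}(\mathbf{W}_l) = f_\theta(\mathbf{x})$ for every layer $l$.
Summing over all $L+1$ layers, we conclude
$$\mathbf{J}_\theta\theta = \sum_{l=0}^{L}\mathbf{J}_{\mathbf{W}_l}\,\mathrm{vec}(\mathbf{W}_l) = (L+1)\,f_\theta(\mathbf{x}).$$
∎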
Appendix C Proof of Lemma 2
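Claim. For a feed-forward neural network of depth $L$ with ReLU activation function and no biases, we observe:
$$\|\theta\|_{\mathbf{G}}^2 = (L+1)^2\,\mathbb{E}\left[\|f_\theta(\mathbf{x})\|_2^2\right] \qquad (16)$$
Furthermore, if we assume the network is linear, the K-FAC Gauss-Newton norm satisfies
$$\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = (L+1)\,\mathbb{E}\left[\|f_\theta(\mathbf{x})\|_2^2\right] \qquad (17)$$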
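Proof.
We first prove the equality $\|\theta\|_{\mathbf{G}}^2 = (L+1)^2\,\mathbb{E}[\|f_\theta(\mathbf{x})\|_2^2]$. Using the definition of the Gauss-Newton norm, we have
$$\|\theta\|_{\mathbf{G}}^2 = \theta^\top\mathbb{E}\left[\mathbf{J}_\theta^\top\mathbf{J}_\theta\right]\theta = \mathbb{E}\left[\|\mathbf{J}_\theta\theta\|_2^2\right].$$
By Lemma 1, $\mathbf{J}_\theta\theta = (L+1)\,f_\theta(\mathbf{x})$.
Combining the above equalities, we arrive at the conclusion.
For the second part, $\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = (L+1)\,\mathbb{E}[\|f_\theta(\mathbf{x})\|_2^2]$, we note that the Kronecker-product factorization is exact when the network is linear, which means $\mathbf{G}_{\text{K-FAC}}$ is the block-diagonal version of the Gauss-Newton matrix $\mathbf{G}$, with blocks $\mathbb{E}[\mathbf{J}_{\mathbf{W}_l}^\top\mathbf{J}_{\mathbf{W}_l}]$. Therefore
$$\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = \sum_{l=0}^{L}\mathbb{E}\left[\|\mathbf{J}_{\mathbf{W}_l}\,\mathrm{vec}(\mathbf{W}_l)\|_2^2\right].$$
As shown in the proof of Lemma 1, $\mathbf{J}_{\mathbf{W}_l}\,\mathrm{vec}(\mathbf{W}_l) = f_\theta(\mathbf{x})$ for every layer $l$, therefore we conclude that
$$\|\theta\|_{\mathbf{G}_{\text{K-FAC}}}^2 = (L+1)\,\mathbb{E}\left[\|f_\theta(\mathbf{x})\|_2^2\right].$$
∎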
Appendix D Derivation of equation 14
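Claim. During training, the weight direction $\hat\theta_t = \theta_t/\|\theta_t\|_2$ is updated according to
$$\hat\theta_{t+1} \leftarrow \hat\theta_t - \eta\left(\mathbf{I} - \hat\theta_t\hat\theta_t^\top\right)\left(\mathbf{C}(\hat\theta_t) + \lambda\|\theta_t\|_2^2\,\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\hat\theta_t) + O(\eta^2)$$
Proof.
The damped natural gradient update is given by
$$\theta_{t+1} \leftarrow \theta_t - \eta\left(\mathbf{C}(\theta_t) + \lambda\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\theta_t).$$
Denote $\rho_t = \|\theta_t\|_2$, so that $\theta_t = \rho_t\hat\theta_t$. By Equation 13 and the scale invariance of $\mathcal{L}$, we have
$$\mathbf{C}(\theta_t) = \rho_t^{-2}\,\mathbf{C}(\hat\theta_t), \qquad \nabla_\theta\mathcal{L}(\theta_t) = \rho_t^{-1}\,\nabla_\theta\mathcal{L}(\hat\theta_t),$$
and therefore
$$\left(\mathbf{C}(\theta_t) + \lambda\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\theta_t) = \rho_t\left(\mathbf{C}(\hat\theta_t) + \lambda\rho_t^2\,\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\hat\theta_t).$$
Additionally, we can rewrite the natural gradient update as follows
$$\rho_{t+1}\hat\theta_{t+1} = \rho_t\hat\theta_t - \eta\rho_t\left(\mathbf{C}(\hat\theta_t) + \lambda\rho_t^2\,\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\hat\theta_t).$$
And therefore,
$$\hat\theta_{t+1} = \frac{\rho_t}{\rho_{t+1}}\left(\hat\theta_t - \eta\left(\mathbf{C}(\hat\theta_t) + \lambda\rho_t^2\,\mathbf{I}\right)^{-1}\nabla_\theta\mathcal{L}(\hat\theta_t)\right).$$
Since $\hat\theta_{t+1}$ has unit norm, the factor $\rho_t/\rho_{t+1}$ simply renormalizes the bracketed vector; expanding this normalization to first order in $\eta$ removes the component of the step along $\hat\theta_t$, which yields the projection $\mathbf{I} - \hat\theta_t\hat\theta_t^\top$ and the $O(\eta^2)$ term in the claim.
∎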
Appendix E The Gradient of Gauss-Newton Norm
Appendix F Kronecker-factored approximate curvature (K-FAC)
Martens & Grosse (2015) proposed K-FAC for performing efficient natural gradient optimization in deep neural networks. Following that work, K-FAC has been adopted in many tasks (Wu et al., 2017; Zhang et al., 2017) to gain optimization benefits, and was shown to be amenable to distributed computation (Ba et al., 2016).
F.1 Basic idea of K-FAC
As shown by Luk & Grosse (2018), K-FAC can be applied to general pullback metrics, including the Fisher metric and the Gauss-Newton metric. For convenience, we introduce K-FAC here using the Fisher metric.
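Consider the $l$-th layer in the neural network, with input activations $\mathbf{a} \in \mathbb{R}^{n_1}$, weight $\mathbf{W} \in \mathbb{R}^{n_1 \times n_2}$, and output $\mathbf{s} \in \mathbb{R}^{n_2}$, so that $\mathbf{s} = \mathbf{W}^\top\mathbf{a}$. Therefore, the weight gradient is $\nabla_\mathbf{W}\mathcal{L} = \mathbf{a}(\nabla_\mathbf{s}\mathcal{L})^\top$. With this gradient formula, K-FAC decouples this layer's Fisher matrix $\mathbf{F}_l$ using mild approximations,
$$\mathbf{F}_l = \mathbb{E}\left[\mathrm{vec}\{\nabla_\mathbf{W}\mathcal{L}\}\,\mathrm{vec}\{\nabla_\mathbf{W}\mathcal{L}\}^\top\right] = \mathbb{E}\left[\{\nabla_\mathbf{s}\mathcal{L}\}\{\nabla_\mathbf{s}\mathcal{L}\}^\top \otimes \mathbf{a}\mathbf{a}^\top\right] \approx \mathbb{E}\left[\{\nabla_\mathbf{s}\mathcal{L}\}\{\nabla_\mathbf{s}\mathcal{L}\}^\top\right] \otimes \mathbb{E}\left[\mathbf{a}\mathbf{a}^\top\right] = \mathbf{S} \otimes \mathbf{A} \qquad (20)$$
where $\mathbf{A} = \mathbb{E}[\mathbf{a}\mathbf{a}^\top]$ and $\mathbf{S} = \mathbb{E}[\{\nabla_\mathbf{s}\mathcal{L}\}\{\nabla_\mathbf{s}\mathcal{L}\}^\top]$. The approximation above assumes independence between $\mathbf{a}$ and $\nabla_\mathbf{s}\mathcal{L}$, which proves to be accurate in practice. Further, assuming between-layer independence, the whole Fisher matrix $\mathbf{F}$ can be approximated as block diagonal, consisting of the layerwise Fisher matrices $\mathbf{F}_l$. Decoupling $\mathbf{F}_l$ into $\mathbf{A}$ and $\mathbf{S}$ not only avoids the memory issue, reducing storage from $O(n_1^2 n_2^2)$ to $O(n_1^2 + n_2^2)$, but also provides efficient natural gradient computation:
$$\mathbf{F}_l^{-1}\,\mathrm{vec}\{\nabla_\mathbf{W}\mathcal{L}\} = \left(\mathbf{S}^{-1} \otimes \mathbf{A}^{-1}\right)\mathrm{vec}\{\nabla_\mathbf{W}\mathcal{L}\} = \mathrm{vec}\left[\mathbf{A}^{-1}\,\nabla_\mathbf{W}\mathcal{L}\,\mathbf{S}^{-1}\right] \qquad (21)$$
As shown by equation 21, computing the natural gradient using K-FAC only involves matrix operations comparable in size to $\mathbf{W}$, making it very efficient.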
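Equation 21 can be verified on a tiny layer. The sketch below (plain Python; the layer sizes, seed, sample count and the small damping added to both factors are illustrative assumptions, not the settings used in the experiments) estimates $\mathbf{A}$ and $\mathbf{S}$ from random samples, computes $\mathbf{A}^{-1}\nabla_\mathbf{W}\mathcal{L}\,\mathbf{S}^{-1}$ using only an $n_1 \times n_1$ and an $n_2 \times n_2$ inverse, and compares it with explicitly inverting the $n_1 n_2 \times n_1 n_2$ Kronecker product $\mathbf{S} \otimes \mathbf{A}$ under column-major vectorization.

```python
import random


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def inverse(M):
    # Gauss-Jordan elimination with partial pivoting.
    n = len(M)
    aug = [list(row) + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(M)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[piv] = aug[piv], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                f = aug[r][col]
                aug[r] = [a - f * b for a, b in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def kron(A, B):
    return [[A[i][j] * B[k][q] for j in range(len(A[0])) for q in range(len(B[0]))]
            for i in range(len(A)) for k in range(len(B))]


def vec(M):
    # Column-major vectorization: stack the columns of M.
    return [M[i][j] for j in range(len(M[0])) for i in range(len(M))]


def second_moment(samples, damping):
    n = len(samples[0])
    return [[sum(s[i] * s[j] for s in samples) / len(samples) + (damping if i == j else 0.0)
             for j in range(n)] for i in range(n)]


random.seed(2)
n1, n2, N = 3, 2, 8   # input dim, output dim, number of samples
acts = [[random.gauss(0, 1) for _ in range(n1)] for _ in range(N)]   # a
grads = [[random.gauss(0, 1) for _ in range(n2)] for _ in range(N)]  # dL/ds

A = second_moment(acts, 1e-3)
S = second_moment(grads, 1e-3)
# Weight gradient averaged over samples: mean of a (dL/ds)^T, shape n1 x n2.
grad_W = [[sum(a[i] * g[j] for a, g in zip(acts, grads)) / N for j in range(n2)]
          for i in range(n1)]

# K-FAC route (eq. 21): A^{-1} grad_W S^{-1}.
kfac = vec(matmul(matmul(inverse(A), grad_W), inverse(S)))
# Reference: explicit (S kron A)^{-1} vec(grad_W).
full_inv = inverse(kron(S, A))
reference = [sum(m * v for m, v in zip(row, vec(grad_W))) for row in full_inv]
print(max(abs(u - v) for u, v in zip(kfac, reference)))
```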
F.2 Pseudocode of K-FAC
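A minimal sketch of K-FAC iterations for a single fully connected layer, under simplifying assumptions: running averages of $A$ and $S$, damping added equally to each factor (a simple form of factored Tikhonov damping), the preconditioned gradient computed as in equation 21, and weight decay applied directly to the weights. The helper `kfac_update` and all hyperparameter values are hypothetical, and this is an illustration rather than the authors' algorithm; practical K-FAC implementations (Martens & Grosse, 2015) also amortize the factor inversions over several iterations and adapt the damping.

```python
import numpy as np

def kfac_update(W, a, g_sampled, grad_W, state, lr, damping, wd, decay=0.95):
    """One K-FAC step for a layer s = W^T a (hypothetical helper, sketch only).

    W: (n1, n2) weights; a: (m, n1) input activations;
    g_sampled: (m, n2) output gradients computed with targets sampled from the model;
    grad_W: (n1, n2) gradient of the training loss; state: running factors A and S.
    """
    m = a.shape[0]
    # Running averages of the Kronecker factors A = E[a a^T] and S = E[g g^T].
    state["A"] = decay * state["A"] + (1 - decay) * (a.T @ a / m)
    state["S"] = decay * state["S"] + (1 - decay) * (g_sampled.T @ g_sampled / m)
    # Simple factored damping: add sqrt(damping) * I to each factor.
    r = np.sqrt(damping)
    A_d = state["A"] + r * np.eye(W.shape[0])
    S_d = state["S"] + r * np.eye(W.shape[1])
    # Preconditioned gradient as in equation 21: A^{-1} grad_W S^{-1}.
    nat_grad = np.linalg.solve(A_d, grad_W) @ np.linalg.inv(S_d)
    # Weight decay applied directly to the weights (an assumption of this sketch).
    return W - lr * (nat_grad + wd * W)


# Usage on a toy linear-regression problem with squared-error loss.
rng = np.random.default_rng(0)
n1, n2, m = 5, 2, 256
W_true = rng.standard_normal((n1, n2))
W = np.zeros((n1, n2))
state = {"A": np.eye(n1), "S": np.eye(n2)}
losses = []
for step in range(50):
    a = rng.standard_normal((m, n1))
    y = a @ W_true + 0.1 * rng.standard_normal((m, n2))
    s = a @ W                                     # row i holds s_i = W^T a_i
    grad_W = a.T @ (s - y) / m                    # gradient of half the mean squared error
    y_model = s + rng.standard_normal(s.shape)    # targets sampled from a unit-variance Gaussian model
    g_sampled = s - y_model                       # d(loss)/ds under the sampled targets
    W = kfac_update(W, a, g_sampled, grad_W, state, lr=0.5, damping=1e-3, wd=5e-4)
    losses.append(0.5 * np.mean(np.sum((a @ W - y) ** 2, axis=1)))
```

For a linear layer with squared loss the factor $S$ is close to the identity, so the step reduces to a damped, input-whitened gradient step and the loss drops to roughly the noise level within a few dozen iterations.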
Appendix G Additional Results
G.1 Large-Batch Training
It has been shown that K-FAC scales very favorably to larger mini-batches compared to SGD, enjoying a nearly linear relationship between mini-batch size and per-iteration progress for medium-to-large sized mini-batches (Martens & Grosse, 2015; Ba et al., 2016). However, Keskar et al. (2016) showed that large-batch methods converge to sharp minima and generalize worse. In this subsection, we measure the generalization performance of K-FAC with large-batch training and analyze the effect of weight decay.
In Table 3, we compare K-FAC with SGD using different batch sizes. In particular, we interpolate between small-batch (BS128) and large-batch (BS2000) training. We can see that, in accordance with previous works (Keskar et al., 2016; Hoffer et al., 2017), moving from small batches to large batches indeed incurs a substantial generalization gap. However, adding weight decay regularization to K-FAC almost closes the gap on CIFAR-10 and removes much of it on CIFAR-100. Surprisingly, the generalization gap of SGD also disappears with well-tuned weight decay regularization. Moreover, we observe that the training loss cannot decrease to zero if weight decay is not used, indicating that weight decay may also speed up training.
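Table 3: Classification accuracy (%) with different batch sizes (BS), without (no WD) and with (WD) weight decay.
Dataset | Network | Method | BS128 no WD | BS128 WD | BS640 no WD | BS640 WD | BS2000 no WD | BS2000 WD
CIFAR-10 | VGG-16 | SGD | 91.71 | 93.39 | 90.46 | 93.09 | 88.50 | 92.24
CIFAR-10 | VGG-16 | K-FAC-F | 93.12 | 93.90 | 92.93 | 93.55 | 92.17 | 93.31
CIFAR-10 | VGG-16 | K-FAC-G | 93.19 | 93.80 | 92.98 | 93.74 | 90.78 | 93.46
CIFAR-10 | ResNet-32 | SGD | 92.95 | 95.14 | 91.68 | 94.45 | 89.70 | 94.68
CIFAR-10 | ResNet-32 | K-FAC-F | 93.80 | 95.35 | 92.30 | 94.79 | 91.15 | 94.43
CIFAR-10 | ResNet-32 | K-FAC-G | 93.44 | 95.04 | 91.80 | 94.73 | 90.02 | 94.85
CIFAR-100 | ResNet-32 | SGD | 73.61 | 77.73 | 71.74 | 76.67 | 65.38 | 76.87
CIFAR-100 | ResNet-32 | K-FAC-F | 74.49 | 78.01 | 73.54 | 77.34 | 71.64 | 77.13
CIFAR-100 | ResNet-32 | K-FAC-G | 73.70 | 78.02 | 71.13 | 77.40 | 65.41 | 76.93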
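The gaps discussed above can be read off Table 3 directly. The sketch below takes the generalization gap to be the accuracy drop from BS128 to BS2000 under the same weight decay setting (a working definition assumed for this illustration) and computes it for every row; the dictionary simply transcribes the table.

```python
# Table 3 rows: (dataset, network, method) -> accuracies (%) in column order
# BS128 no WD, BS128 WD, BS640 no WD, BS640 WD, BS2000 no WD, BS2000 WD.
TABLE3 = {
    ("CIFAR-10", "VGG-16", "SGD"):        [91.71, 93.39, 90.46, 93.09, 88.50, 92.24],
    ("CIFAR-10", "VGG-16", "K-FAC-F"):    [93.12, 93.90, 92.93, 93.55, 92.17, 93.31],
    ("CIFAR-10", "VGG-16", "K-FAC-G"):    [93.19, 93.80, 92.98, 93.74, 90.78, 93.46],
    ("CIFAR-10", "ResNet-32", "SGD"):     [92.95, 95.14, 91.68, 94.45, 89.70, 94.68],
    ("CIFAR-10", "ResNet-32", "K-FAC-F"): [93.80, 95.35, 92.30, 94.79, 91.15, 94.43],
    ("CIFAR-10", "ResNet-32", "K-FAC-G"): [93.44, 95.04, 91.80, 94.73, 90.02, 94.85],
    ("CIFAR-100", "ResNet-32", "SGD"):    [73.61, 77.73, 71.74, 76.67, 65.38, 76.87],
    ("CIFAR-100", "ResNet-32", "K-FAC-F"): [74.49, 78.01, 73.54, 77.34, 71.64, 77.13],
    ("CIFAR-100", "ResNet-32", "K-FAC-G"): [73.70, 78.02, 71.13, 77.40, 65.41, 76.93],
}

def large_batch_gap(row):
    """Accuracy drop from BS128 to BS2000, (without WD, with WD), in points."""
    no_wd = round(row[0] - row[4], 2)
    wd = round(row[1] - row[5], 2)
    return no_wd, wd

for key, row in TABLE3.items():
    no_wd, wd = large_batch_gap(row)
    print(f"{' / '.join(key):32s} gap without WD: {no_wd:5.2f}  with WD: {wd:5.2f}")
```

With this definition the gap shrinks in every row once weight decay is added, for example from 8.23 to 0.86 points for SGD on CIFAR-100 with ResNet-32.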
G.2 Curves of test accuracy
G.3 Optimization performance of different optimizers
While this paper mostly focuses on generalization, we also report the convergence speed of different optimizers in deep neural networks, in terms of both per-epoch and wall-clock time performance.
We consider the task of image classification on the CIFAR-10 (Krizhevsky & Hinton, 2009) dataset. The models we use are VGG-16 (Simonyan & Zisserman, 2014) and ResNet-32 (He et al., 2016). We compare our K-FAC-G and K-FAC-F with SGD and Adam (Kingma & Ba, 2014).
We experiment with constant learning rates for K-FAC-G and K-FAC-F. For SGD and Adam, we set the batch size to 128. For K-FAC, we use a batch size of 640, as suggested by Martens & Grosse (2015).
In Figure 8, we report the training curves of different algorithms. Figure 8(a) shows that K-FAC-G yields better optimization than the other baselines in terms of training loss per epoch. We highlight that the training loss decreases to 1e-4 within 10 epochs with K-FAC-G. Although K-FAC-based algorithms take more time per epoch, Figure 8(b) still shows wall-clock time improvements over the baselines.
In Figures 8(c) and 8(d), we report similar results for ResNet-32. Note that we make the network wider with a widening factor of 4, following Zagoruyko & Komodakis (2016). K-FAC-G outperforms both K-FAC-F and the other baselines in terms of both optimization per epoch and compute time.
Footnotes
The underlying distribution for the expectation in equation 2 has been left ambiguous. Throughout the experiments, we sample the targets from the model's predictions, as done in Martens & Grosse (2015).
We show in Appendix E that this interpretation holds exactly in the case of the Gauss-Newton norm.
For the exact Gauss-Newton norm, the result also holds for deep rectified networks (see Appendix C).
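The footnote on equation 2 states that targets are sampled from the model's predictions rather than taken from the training labels. For a softmax output layer, the resulting expectation has a closed form with respect to the logits, $\mathrm{diag}(p) - p p^\top$, which a Monte Carlo estimate with sampled targets should approach. The logits and the sample size below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 0.5, -1.0])          # hypothetical network output for one input
p = np.exp(logits - logits.max())
p /= p.sum()                                 # softmax predictive distribution

# Sample targets y ~ p(y | x) from the model instead of using the data labels.
num_samples = 100_000
y = rng.choice(len(p), size=num_samples, p=p)

# Gradient of -log p(y | x) w.r.t. the logits is p - onehot(y); one row per sample.
grads = p[None, :] - np.eye(len(p))[y]
F_mc = grads.T @ grads / num_samples         # Monte Carlo Fisher w.r.t. the logits

# Closed form of the same expectation for a softmax output layer.
F_exact = np.diag(p) - np.outer(p, p)
print(np.abs(F_mc - F_exact).max())          # small Monte Carlo error
```

Using the fixed training label instead of sampled targets would give the empirical Fisher, which in general does not match this expectation.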
References
 Shun-Ichi Amari. Natural gradient works efficiently in learning. Neural Computation, 10(2):251–276, 1998.
 Jimmy Ba, Roger Grosse, and James Martens. Distributed second-order optimization using Kronecker-factored approximations. 2016.
 Siegfried Bos and E Chug. Using weight decay to optimize the generalization ability of a perceptron. In Neural Networks, 1996., IEEE International Conference on, volume 1, pp. 241–246. IEEE, 1996.
 Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.
 Tom Heskes. On “natural” learning and pruning in multilayered perceptrons. Neural Computation, 12(4):881–901, 2000.
 Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems, pp. 1731–1741, 2017.
 Elad Hoffer, Ron Banner, Itay Golan, and Daniel Soudry. Norm matters: efficient and accurate normalization schemes in deep networks. arXiv preprint arXiv:1803.01814, 2018.
 Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.
 Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
 Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
 Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.
 Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pp. 1097–1105, 2012.
 Anders Krogh and John A Hertz. A simple weight decay can improve generalization. In Advances in neural information processing systems, pp. 950–957, 1992.
 Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
 Tengyuan Liang, Tomaso Poggio, Alexander Rakhlin, and James Stokes. Fisher-Rao metric, geometry, and complexity of neural networks. arXiv preprint arXiv:1711.01530, 2017.
 Ilya Loshchilov and Frank Hutter. Fixing weight decay regularization in Adam. arXiv preprint arXiv:1711.05101, 2017.
 Kevin Luk and Roger Grosse. A coordinate-free construction of scalable natural gradient. arXiv preprint arXiv:1808.10340, 2018.
 James Martens. Deep learning via Hessian-free optimization. 2010.
 James Martens. New insights and perspectives on the natural gradient method. arXiv preprint arXiv:1412.1193, 2014.
 James Martens and Roger Grosse. Optimizing neural networks with Kronecker-factored approximate curvature. In International conference on machine learning, pp. 2408–2417, 2015.
 Arvind Neelakantan, Luke Vilnis, Quoc V Le, Ilya Sutskever, Lukasz Kaiser, Karol Kurach, and James Martens. Adding gradient noise improves learning for very deep networks. arXiv preprint arXiv:1511.06807, 2015.
 Roman Novak, Yasaman Bahri, Daniel A Abolafia, Jeffrey Pennington, and Jascha Sohl-Dickstein. Sensitivity and generalization in neural networks: an empirical study. arXiv preprint arXiv:1802.08760, 2018.
 Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013.
 Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
 Twan van Laarhoven. L2 regularization versus batch and weight normalization. arXiv preprint arXiv:1706.05350, 2017.
 Yuhuai Wu, Elman Mansimov, Roger B Grosse, Shun Liao, and Jimmy Ba. Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation. In Advances in neural information processing systems, pp. 5279–5288, 2017.
 Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
 Guodong Zhang, Shengyang Sun, David Duvenaud, and Roger Grosse. Noisy natural gradient as variational inference. arXiv preprint arXiv:1712.02390, 2017.