New insights and perspectives on the natural gradient method
Natural gradient descent is an optimization method traditionally motivated from the perspective of information geometry, and works well for many applications as an alternative to stochastic gradient descent. In this paper we critically analyze this method and its properties, and show how it can be viewed as a type of approximate 2nd-order optimization method, where the Fisher information matrix can be viewed as an approximation of the Hessian. This perspective turns out to have significant implications for how to design a practical and robust version of the method. Additionally, we make the following contributions to the understanding of natural gradient and 2nd-order methods: a thorough analysis of the convergence speed of stochastic natural gradient descent (and more general stochastic 2nd-order methods) as applied to convex quadratics, a critical examination of the oft-used “empirical” approximation of the Fisher matrix, and an analysis of the (approximate) parameterization invariance property possessed by natural gradient methods, which we show still holds for certain choices of the curvature matrix other than the Fisher, but notably not the Hessian.
Keywords: natural gradient methods, 2nd-order optimization, neural networks, convergence rate, parameterization invariance
- 1 Introduction and overview
- 2 Neural Networks
- 3 Supervised learning framework
- 4 KL divergence objectives
- 5 Various definitions of the natural gradient and the Fisher information matrix
- 6 Geometric interpretation
- 7 2nd-order optimization
- 8 The generalized Gauss-Newton matrix
- 9 Computational aspects of the natural gradient and connections to the generalized Gauss-Newton matrix
- 10 Constructing practical natural gradient methods, and the role of damping
- 11 The empirical Fisher
- 12 Asymptotic convergence speed
- 13 A critical analysis of parameterization invariance
- 14 A new interpretation of the natural gradient
- 15 Conclusions and open questions
- A Extra derivations for Theorem 1
- B Proof of Theorem 4
- C Derivations of bounds for Section 12.2.1
- D Proof of Theorem 6
1 Introduction and overview
The natural gradient descent approach, pioneered by Amari and collaborators (e.g. Amari, 1998), is a popular alternative to traditional gradient descent methods which has received a lot of attention over the past several decades, motivating many new and related approaches. It has been successfully applied to a variety of problems such as blind source separation (Amari and Cichocki, 1998), reinforcement learning (Peters and Schaal, 2008), and neural network training (e.g. Park et al., 2000; Martens and Grosse, 2015; Desjardins et al., 2015).
Natural gradient descent is generally applicable to the optimization of probabilistic models (this includes neural networks, which can be cast as conditional models), and involves the use of the so-called “natural gradient” in place of the standard gradient, which is defined as the inverse of the model’s Fisher information matrix times the gradient (see Section 5). In many applications, natural gradient descent seems to require far fewer total iterations than gradient descent, making it a potentially attractive alternative method. Unfortunately, for models with very many parameters such as large neural networks, computing the natural gradient is impractical due to the extreme size of the Fisher information matrix (“the Fisher”). This problem can be addressed through the use of one of various approximations to the Fisher (e.g. Le Roux et al., 2008; Ollivier, 2015; Grosse and Salakhutdinov, 2015; Martens and Grosse, 2015) that are designed to be easier to compute, store and invert than the exact Fisher.
Natural gradient descent is classically motivated as a way of implementing steepest descent in the space of realizable distributions (those distributions which correspond to some setting of the model’s parameters) instead of the space of parameters, where distance in the distribution space is measured with a special “Riemannian metric” (Amari and Nagaoka, 2007). This metric depends only on the properties of the distributions themselves and not their parameters, and in particular is defined so that it approximates the square root of the KL divergence within small neighborhoods. Under this interpretation (discussed in detail in Section 6), natural gradient descent is invariant to any smooth and invertible reparameterization of the model, putting it in stark contrast to gradient descent, whose performance is highly parameterization dependent.
In practice however, natural gradient descent still operates within the default parameter space, and works by computing directions in the space of distributions and then translating them back to the default space before taking a discrete step. Because of this, the above discussed interpretation breaks down unless the step-size becomes arbitrarily small, and as discussed in Section 10, this breakdown has important implications for designing a natural gradient method that can work well in practice. Another problem with this interpretation is that it doesn’t provide any obvious reason why a step of natural gradient descent should make more progress optimizing the objective than a step of standard gradient descent (assuming well chosen step-sizes for both).
Given a large step-size one also loses the parameterization invariance property of the natural gradient method, although it will still hold approximately under certain conditions which are described in Section 13.
In Section 10 we argue for an alternative view of natural gradient descent: as an approximate 2nd-order method which utilizes the Fisher as an approximation to the Hessian, so that the natural gradient approximates a 2nd-order step computed using the Hessian. As discussed in Section 7, 2nd-order methods work by forming a local quadratic approximation to the objective around the current iterate and produce the next iterate by optimizing this approximation within some restricted region where the approximation is thought to be accurate. According to this view, natural gradient descent makes more progress per step than gradient descent because it implicitly uses a local quadratic model/approximation of the objective function which is more accurate and less conservative than the one implicitly used by gradient descent.
In support of this view is the fact that the Fisher can be cast as an approximation of the Hessian in at least two different ways (provided the objective has the form discussed in Section 4). First, as discussed in Section 5, it corresponds to the expected Hessian of the loss under the model’s distribution over predicted outputs instead of the usual empirical one used to compute the exact Hessian. And second, as we establish in Section 9, it is very often equivalent to the so-called “Generalized Gauss-Newton matrix” (GGN) (defined in Section 8), which is a well established and rigorously justified approximation of the Hessian that has been used in practical 2nd-order optimization methods such as those of Martens (2010) and Vinyals and Povey (2012).
Viewing natural gradient descent as an approximate 2nd-order method is also prescriptive, since it suggests the use of various damping/regularization techniques often used in the optimization literature for dealing with the problem of quadratic model trust. Indeed, such techniques have been successfully applied in 2nd-order methods such as that of Martens (2010) and Martens and Grosse (2015), where they proved crucial in achieving good and robust performance in practice.
The Fisher, which is used in computing the natural gradient direction, is defined as the covariance of the gradient of the model’s log likelihood function with respect to cases sampled from its distribution. Because it is often simpler to implement and somewhat more economical, a commonly used approximation of the Fisher, which we discuss in Section 11, is to use cases sampled from the training set instead. Known as the “empirical Fisher”, this matrix differs from the usual Fisher in subtle but very important ways, which as shown in Section 11.1, make it considerably less useful as an approximation to the Fisher and as a curvature matrix within 2nd-order optimization methods. Using the empirical Fisher also breaks some of the theory regarding natural gradient descent, although it nonetheless preserves the (approximate) parameterization invariance enjoyed by the method (as shown in Section 13). Despite these objections, the empirical Fisher has been used in many works such as Le Roux et al. (2008) and the recent spate of methods based on diagonal approximations of this matrix (which we review and critically examine in Section 11.2).
A well-known and often quoted result about stochastic natural gradient descent is that it is asymptotically “Fisher efficient” (Amari, 1998). Roughly speaking, this means that it provides an asymptotically unbiased estimate of the parameters with the lowest possible variance among all unbiased estimators (given the same amount of data), thus achieving the best possible expected objective function value. Unfortunately, as discussed in Section 12.1, this result comes with several important caveats which severely limit its applicability. Moreover even when it is applicable it only provides an asymptotically accurate characterization of the method which may not accurately describe its behavior given a realistic number of iterations.
To address these issues we build on the work of Murata (1998) in Section 12.2 and Section 12.3 to develop a more powerful convergence theory for approximate stochastic 2nd-order methods (including natural gradient descent) as applied to convex quadratic objectives. Our results provide a more precise expression for the convergence speed of such methods than existing results do, and properly account for the effect of the starting point. And as we discuss in Section 12.2.1 and Section 12.3.1 they imply various interesting consequences about the relative performance of various 1st and 2nd-order stochastic optimization methods.
Perhaps the most interesting conclusion of this analysis is that with parameter averaging applied, stochastic gradient descent with a constant step-size/learning-rate achieves the same asymptotic convergence speed as natural gradient descent (and is thus also “Fisher efficient”), although 2nd-order methods (such as the latter) can enjoy a more favorable dependence on the starting point, which means that they can make much more progress given a limited iteration budget.
Unfortunately these results fail to fully explain why 2nd-order optimization with the GGN/Fisher works so much better than classical 2nd-order schemes such as Newton’s method. And so in Section 15 we propose several important open questions in this direction that we leave for future research.
Table of notation
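| Notation | Meaning |
|---|---|
| $[v]_i$ | $i$-th entry of a vector $v$ |
| $[A]_{i,j}$ | $(i,j)$-th entry of a matrix $A$ |
| $\nabla\gamma$ | gradient of a scalar function $\gamma$ |
| $J_\gamma$ | Jacobian of a vector-valued function $\gamma$ |
| $H_\gamma$ | Hessian of a scalar function $\gamma$ (typically taken with respect to $\theta$ unless otherwise specified) |
| $\theta$ | vector of parameters |
| $W_i$ | weight matrix at layer $i$ |
| $s_i$ | unit inputs at layer $i$ |
| $a_i$ | unit activities at layer $i$ |
| $\ell$ | number of layers |
| $m$ | dimension of the network’s output $f(x,\theta)$ |
| $m_i$ | number of units in $i$-th layer of the network |
| $f(x,\theta)$ | function mapping the neural network’s inputs to its output |
| $M_k(\delta)$ | local quadratic approximation of $h$ at $\theta_k$ |
| $\lambda$ | strength constant for penalty-based damping |
| $\lambda_i(A)$ | $i$-th largest eigenvalue of a symmetric matrix $A$ |
| $G$ | generalized Gauss-Newton matrix (GGN) |
| $R_{y\mid z}$ | predictive distribution used at network’s output (so $P_{y\mid x}(\theta) = R_{y\mid f(x,\theta)}$) |
| $p$, $q$, $r$ | density functions associated with above $P$, $Q$, and $R$ (resp.) |
| $F$ | Fisher information matrix (typically associated with $P_{x,y}$) |
| $F_R$ | Fisher information matrix associated with parameterized distribution $R_{y\mid z}$ |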
2 Neural Networks
Feed-forward neural networks are structured very similarly to classical circuits. They typically consist of a sequence of “layers” of units, where each unit in a given layer receives inputs from the units in the previous layer, and computes an affine function of these, followed by a scalar non-linear function called an “activation function”. The input vector to the network, denoted by $x$, is given by the units of the first layer, which is called the “input layer” (and is not counted towards the total $\ell$). The output vector of the network, denoted by $f(x,\theta)$, is given by the units of the network’s last layer (called the “output layer”). The other layers are referred to as the network’s “hidden layers”.
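Formally, given input $x$, and parameters $\theta \in \mathbb{R}^n$ which determine weight matrices $W_1, W_2, \ldots, W_\ell$ and biases $b_1, b_2, \ldots, b_\ell$, the network computes its output $f(x,\theta) = a_\ell$ according to

$$s_i = W_i a_{i-1} + b_i, \qquad a_i = \phi_i(s_i),$$

where $a_0 = x$. Here, $a_i$ is the vector of values (“activities”) of the network’s $i$-th layer, and $\phi_i(\cdot)$ is the vector-valued non-linear function computed at layer $i$, and is often given by some simple monotonic activation function applied coordinate-wise.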
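As a minimal sketch of the forward computation above, the following Python/NumPy code evaluates $f(x,\theta)$ for a small network. The helper names (`init_params`, `forward`), the layer sizes, and the choice of tanh hidden units with an identity output layer are illustrative assumptions, not part of the original text.

```python
import numpy as np

def init_params(layer_sizes, rng):
    """Hypothetical helper: one (W_i, b_i) pair per layer, with small random weights."""
    return [(rng.standard_normal((m_out, m_in)) / np.sqrt(m_in), np.zeros(m_out))
            for m_in, m_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def forward(x, params, activations):
    """Compute f(x, theta) = a_ell using s_i = W_i a_{i-1} + b_i and a_i = phi_i(s_i)."""
    a = x                          # a_0 = x (the input layer is not counted in ell)
    for (W, b), phi in zip(params, activations):
        s = W @ a + b              # unit inputs s_i
        a = phi(s)                 # unit activities a_i
    return a

rng = np.random.default_rng(0)
params = init_params([4, 8, 3], rng)         # ell = 2 layers: 4 -> 8 -> 3
activations = [np.tanh, lambda s: s]          # tanh hidden layer, identity output layer
z = forward(rng.standard_normal(4), params, activations)
```

Storing $\theta$ as a list of per-layer pairs rather than one flat vector is only a convenience here; later sketches flatten parameters when a single vector $\theta$ is needed.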
Note that most of the results discussed in this document will apply to the more general setting where $f(x,\theta)$ is an arbitrary differentiable function (in both $x$ and $\theta$).
3 Supervised learning framework
The goal of optimization/learning is to find some setting of $\theta$ so that the output of the network (which we will sometimes call its “prediction”) matches certain target outputs as closely as possible. In particular, given a training set $S$ consisting of training pairs $(x,y)$, the goal of learning is to minimize the objective function

$$h(\theta) \equiv \frac{1}{|S|}\sum_{(x,y)\in S} L(y, f(x,\theta)),$$

where $L(y,z)$ is a “loss function” which measures the amount of disagreement between $y$ and $z$.

The prediction $z$ may be a guess for $y$, in which case $L$ might measure the inaccuracy of this guess (e.g. using the familiar squared error $\frac{1}{2}\|y-z\|^2$). Or $z$ could encode the parameters of some simple predictive distribution. For example, $z$ could be the set of probabilities which parameterize a multinomial distribution over the possible discrete values of $y$, with $L(y,z)$ being the negative log probability of $y$ under this distribution.
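A small sketch of the two example losses and the objective $h(\theta)$, assuming NumPy. The names `squared_error`, `multinomial_nll` and `objective`, and the scalar toy model, are hypothetical and chosen only for illustration.

```python
import numpy as np

def squared_error(y, z):
    """L(y, z) = 1/2 ||y - z||^2, used when z is a direct guess for y."""
    return 0.5 * np.sum((y - z) ** 2)

def multinomial_nll(y, z):
    """L(y, z) = -log [z]_y, where z holds class probabilities and y is a class index."""
    return -np.log(z[y])

def objective(S, f, theta, loss):
    """h(theta) = (1/|S|) * sum over (x, y) in S of L(y, f(x, theta))."""
    return np.mean([loss(y, f(x, theta)) for x, y in S])

# Hypothetical toy model: a scalar linear predictor f(x, theta) = theta * x.
f = lambda x, theta: theta * x
S = [(1.0, 2.0), (2.0, 2.0)]
h = objective(S, f, 1.0, squared_error)   # (0.5 * 1 + 0.5 * 0) / 2 = 0.25
```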
4 KL divergence objectives
The natural gradient method of Amari (1998) can be potentially applied to any objective function which measures the performance of some statistical model. However, it enjoys richer theoretical properties when applied to objective functions based on the KL divergence between the model’s distribution and the target distribution, or certain approximations/surrogates of these.
In this section we will establish the basic notation and properties of these objective functions, and discuss the various ways in which they can be formulated. Each of these formulations will be analogous to a particular formulation of the Fisher information matrix and natural gradient (as defined in Section 5), which will differ in subtle but important ways.
In the idealized setting, input vectors $x$ are drawn independently from a target distribution $Q_x$ with density function $q(x)$, and the corresponding (target) outputs $y$ from a conditional target distribution $Q_{y\mid x}$ with density function $q(y\mid x)$.
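We define the goal of learning as the minimization of the KL divergence from target joint distribution $Q_{x,y}$, whose density is $q(y,x) = q(y\mid x)\,q(x)$, to the learned distribution $P_{x,y}(\theta)$, whose density is $p(x,y\mid\theta) = p(y\mid x,\theta)\,q(x)$. Note that the second $q$ is not a typo here, since we are not learning the distribution over $x$, only the conditional distribution of $y$ given $x$. Our objective function is thus

$$h(\theta) = \mathrm{KL}(Q_{x,y}\,\|\,P_{x,y}(\theta)) = \int q(x,y)\log\frac{q(x,y)}{p(x,y\mid\theta)}\,dx\,dy.$$

This is equivalent to the expected KL divergence

$$E_{Q_x}\left[\mathrm{KL}(Q_{y\mid x}\,\|\,P_{y\mid x}(\theta))\right] \tag{2}$$

since we have

$$\int q(x,y)\log\frac{q(x,y)}{p(x,y\mid\theta)}\,dx\,dy = \int q(x)\int q(y\mid x)\log\frac{q(y\mid x)\,q(x)}{p(y\mid x,\theta)\,q(x)}\,dy\,dx = E_{Q_x}\left[\mathrm{KL}(Q_{y\mid x}\,\|\,P_{y\mid x}(\theta))\right].$$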
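The identity above can be checked numerically when $x$ and $y$ are discrete. This sketch assumes small hypothetical distributions drawn at random; the model reuses the target's $q(x)$, mirroring the definition of $p(x,y\mid\theta)$.

```python
import numpy as np

rng = np.random.default_rng(0)
nx, ny = 3, 4
q_x = rng.dirichlet(np.ones(nx))                   # q(x)
q_y_given_x = rng.dirichlet(np.ones(ny), size=nx)  # row i: q(y | x = i)
p_y_given_x = rng.dirichlet(np.ones(ny), size=nx)  # row i: model p(y | x = i, theta)

q_xy = q_x[:, None] * q_y_given_x                  # target joint q(x, y)
p_xy = q_x[:, None] * p_y_given_x                  # model joint uses the same q(x)

joint_kl = np.sum(q_xy * np.log(q_xy / p_xy))      # KL(Q_{x,y} || P_{x,y})
cond_kl = np.sum(q_y_given_x * np.log(q_y_given_x / p_y_given_x), axis=1)
expected_kl = np.sum(q_x * cond_kl)                # E_{Q_x}[KL(Q_{y|x} || P_{y|x})]
```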
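It is often the case that we only have samples from $Q_x$ and no direct knowledge of its density function. Or the expectation w.r.t. $Q_x$ in eqn. 2 may be too difficult to compute. In such cases, we can substitute an empirical training distribution $\hat Q_x$ in for $Q_x$, which is given by a set $S_x$ of samples from $Q_x$. This gives the objective

$$h(\theta) = E_{\hat Q_x}\left[\mathrm{KL}(Q_{y\mid x}\,\|\,P_{y\mid x}(\theta))\right] = \frac{1}{|S_x|}\sum_{x\in S_x}\mathrm{KL}(Q_{y\mid x}\,\|\,P_{y\mid x}(\theta)).$$

Provided that $q(y\mid x)$ is known for each $x$ in $S_x$ and that $\mathrm{KL}(Q_{y\mid x}\,\|\,P_{y\mid x}(\theta))$ can be efficiently computed, we can use the above expression as our objective.

Otherwise, as is often the case, we might only have access to a single sample $y$ from $Q_{y\mid x}$ for each $x \in S_x$, giving an empirical training distribution $\hat Q_{y\mid x}$. Substituting this in for $Q_{y\mid x}$ gives the objective function

$$h(\theta) = \frac{1}{|S|}\sum_{(x,y)\in S}\log\frac{1}{p(y\mid x,\theta)} = -\frac{1}{|S|}\sum_{(x,y)\in S}\log p(y\mid x,\theta),$$

where we have extended $S_x$ to a set $S$ of the pairs $(x,y)$ (which agrees with how $S$ was defined in Section 3). This is the same objective as is minimized in standard maximum likelihood learning.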
This kind of objective function fits into the general supervised learning framework described in Section 3 as follows. We define the learned conditional distribution $P_{y\mid x}(\theta)$ to be the composition of the deterministic neural network function $f(x,\theta)$, and an “output” conditional distribution $R_{y\mid z}$ (with associated density function $r(y\mid z)$), so that

$$P_{y\mid x}(\theta) = R_{y\mid f(x,\theta)}.$$

We then define the loss function as $L(y,z) = -\log r(y\mid z)$.

Given a loss function $L$ which is not explicitly defined this way one can typically still find a corresponding $R$ to make the definition apply. In particular, if $\exp(-L(y,z))$ has the same finite integral w.r.t. $y$ for each $z$, then one can define $R$ by taking $r(y\mid z) \propto \exp(-L(y,z))$, where the constant of proportionality is the same for all $y$ and $z$.
5 Various definitions of the natural gradient and the Fisher information matrix
The usual definition of the natural gradient (Amari, 1998) which appears in the literature is

$$\tilde\nabla h = F^{-1}\nabla h,$$

where $F$ is the Fisher information matrix of $P_{x,y}(\theta)$ w.r.t. $\theta$. $F$ is given by

$$F = E_{P_{x,y}}\left[\nabla\log p(x,y\mid\theta)\,\nabla\log p(x,y\mid\theta)^\top\right] = -E_{P_{x,y}}\left[H_{\log p(x,y\mid\theta)}\right],$$

where gradients and Hessians are taken w.r.t. $\theta$. For the purposes of brevity we will often refer to the Fisher information matrix simply as the “Fisher”.

It can be immediately seen from the first of these expressions for $F$ that it is positive semi-definite (PSD) (since it’s the expectation of something which is trivially PSD, a vector outer-product). And from the second expression we can see that it also has the interpretation of being the negative expected Hessian of $\log p(x,y\mid\theta)$.
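Because $p(x,y\mid\theta) = p(y\mid x,\theta)\,q(x)$ where $q(x)$ doesn’t depend on $\theta$, we have

$$\nabla\log p(x,y\mid\theta) = \nabla\log p(y\mid x,\theta) + \nabla\log q(x) = \nabla\log p(y\mid x,\theta),$$

and so $F$ can also be written as the expectation (w.r.t. $Q_x$) of the Fisher information matrix of $P_{y\mid x}(\theta)$ as follows:

$$F = E_{Q_x}\left[E_{P_{y\mid x}}\left[\nabla\log p(y\mid x,\theta)\,\nabla\log p(y\mid x,\theta)^\top\right]\right] = -E_{Q_x}\left[E_{P_{y\mid x}}\left[H_{\log p(y\mid x,\theta)}\right]\right].$$

In Amari (1998), this version of $F$ is computed explicitly for a basic perceptron model (basically a neural network with 0 hidden layers) in the case where $Q_x = N(0, I)$.

However in practice the real $Q_x$ may not be directly available, or it may be difficult to integrate $H_{\log p(y\mid x,\theta)}$ over $Q_x$. For example, the conditional Hessian $H_{\log p(y\mid x,\theta)}$ corresponding to a multilayer neural network may be far too complicated to be analytically integrated, even for a very simple $Q_x$. In such situations $Q_x$ may be replaced with its empirical version $\hat Q_x$, giving

$$F = \frac{1}{|S_x|}\sum_{x\in S_x} E_{P_{y\mid x}}\left[\nabla\log p(y\mid x,\theta)\,\nabla\log p(y\mid x,\theta)^\top\right].$$

This is the version of $F$ considered in Park et al. (2000).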
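To illustrate the two expressions for the Fisher, the sketch below assumes a hypothetical linear softmax model $p(y\mid x,\theta) = \mathrm{softmax}(Wx)_y$ with $\theta = \mathrm{vec}(W)$ in row-major order. It takes the expectation over $y \sim P_{y\mid x}$ exactly by summing over classes, and averages over a few sampled inputs, which play the role of $\hat Q_x$. Note that no training labels are used anywhere.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def conditional_fisher(W, x):
    """Fisher of P_{y|x}(theta) for p(y|x, theta) = softmax(W x)_y, theta = W.ravel().
    Returns the outer-product form E_y[g g^T] and the Hessian form -E_y[H_{log p}]."""
    c, d = W.shape
    p = softmax(W @ x)
    outer = np.zeros((c * d, c * d))
    for y in range(c):                        # exact expectation over y ~ P_{y|x}
        g = np.kron(np.eye(c)[y] - p, x)      # grad_theta log p(y | x, theta)
        outer += p[y] * np.outer(g, g)
    # -H_{log p} = (diag(p) - p p^T) kron (x x^T); for this model it does not depend on y
    neg_hess = np.kron(np.diag(p) - np.outer(p, p), np.outer(x, x))
    return outer, neg_hess

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 2))
xs = rng.standard_normal((5, 2))              # samples standing in for \hat Q_x
pairs = [conditional_fisher(W, x) for x in xs]
F_outer = np.mean([o for o, _ in pairs], axis=0)
F_hess = np.mean([nh for _, nh in pairs], axis=0)
```

Both forms agree because $E_{y}[(e_y - p)(e_y - p)^\top] = \mathrm{diag}(p) - pp^\top$.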
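From these expressions we can see that when $L(y,z) = -\log r(y\mid z)$ (as in Section 4), the Fisher has the interpretation of being the expectation under $P_{x,y}$ of the Hessian of $L(y, f(x,\theta))$:

$$F = E_{P_{x,y}}\left[H_{L(y,f(x,\theta))}\right].$$

Meanwhile, the Hessian $H$ of $h$ is also given by the expected value of the Hessian of $L(y, f(x,\theta))$, except under the distribution $\hat Q_{x,y}$ instead of $P_{x,y}$ (where $\hat Q_{x,y}$ is given by the density $\hat q(x,y) = \hat q(y\mid x)\,\hat q(x)$). In other words

$$H = E_{\hat Q_{x,y}}\left[H_{L(y,f(x,\theta))}\right].$$

Thus $F$ can be seen as an approximation of $H$ in some sense.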
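The following sketch contrasts $F$ and $H$ under stated assumptions: a hypothetical scalar model $f(x,w) = \tanh(w^\top x)$ with a unit-variance Gaussian predictive distribution, so that $L(y,z) = \frac{1}{2}(y-z)^2$ up to a constant. The Hessian of $L(y, f(x,w))$ is $J_f^\top J_f - (y - f)H_f$, and since $E_{P_{y\mid x}}[y - f] = 0$, the model expectation drops the second term while the training expectation keeps it. $H$ is cross-checked by finite differences.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((20, 3))
Y = rng.standard_normal(20)                    # observed training targets
w = 0.5 * rng.standard_normal(3)

def parts(w):
    """Per-case f, J_f and H_f for the hypothetical model f(x, w) = tanh(w . x)."""
    t = np.tanh(X @ w)
    J = (1 - t**2)[:, None] * X                                        # row i: J_f for case i
    Hf = (-2 * t * (1 - t**2))[:, None, None] * X[:, :, None] * X[:, None, :]
    return t, J, Hf

def grad_h(w):
    t, J, _ = parts(w)
    return -np.mean((Y - t)[:, None] * J, axis=0)                      # L = 1/2 (y - z)^2

t, J, Hf = parts(w)
# F: expectation under the model, y ~ N(f, 1), where E[y - f] = 0.
F = np.mean(J[:, :, None] * J[:, None, :], axis=0)
# H: expectation under the training distribution, using the observed targets.
H = F - np.mean((Y - t)[:, None, None] * Hf, axis=0)

# Central finite-difference check of H
eps = 1e-5
H_fd = np.stack([(grad_h(w + eps * e) - grad_h(w - eps * e)) / (2 * eps)
                 for e in np.eye(3)], axis=1)
```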
6 Geometric interpretation
The negative gradient $-\nabla h$ can be interpreted as the steepest descent direction for $h$ in the sense that it yields the most reduction in $h$ per unit of change in $\theta$, where change is measured using the standard Euclidean norm $\|\cdot\|$. More formally we have

$$-\frac{\nabla h}{\|\nabla h\|} = \lim_{\varepsilon\to 0}\frac{1}{\varepsilon}\arg\min_{d:\,\|d\|\le\varepsilon} h(\theta+d).$$

This interpretation exposes the strong dependence of the gradient on the Euclidean geometry of the parameter space (as defined by the norm $\|\cdot\|$).
One way to motivate the natural gradient is to show that it can be viewed as a steepest descent direction, much like the negative gradient can be, except with respect to a metric that is intrinsic to the distributions being modeled as opposed to the default Euclidean metric in parameter space. In particular, the natural gradient can be derived by adapting the steepest descent formulation to use an alternative definition of (local) distance based on the “information geometry” (Amari and Nagaoka, 2000) of the space of probability distributions (as induced by the parameters). The particular distance function which gives rise to the natural gradient (not a formal “distance” function in the usual sense, since it is not symmetric) turns out to be

$$\mathrm{KL}(P_{x,y}(\theta+d)\,\|\,P_{x,y}(\theta)).$$
To make this formal, we will first show how the KL divergence and the Fisher are fundamentally connected. The Taylor series expansion of the above distance is

$$\mathrm{KL}(P_{x,y}(\theta+d)\,\|\,P_{x,y}(\theta)) = \frac{1}{2}d^\top F d + O(d^3),$$

where “$O(d^3)$” is short-hand to mean terms that are order 3 or higher in the entries of $d$. Thus $F$ defines the local quadratic approximation of this distance, and so gives the mechanism of local translation between the geometry of the space of distributions, and that of the original parameter space with its default Euclidean geometry.
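A numerical check of this expansion, assuming a hypothetical categorical model $p(y\mid x,\theta) = \mathrm{softmax}(Wx)_y$ at a single fixed input, for which $F = (\mathrm{diag}(p) - pp^\top)\otimes xx^\top$. For a small random $d$, both $\mathrm{KL}(P(\theta+d)\,\|\,P(\theta))$ and $\mathrm{KL}(P(\theta)\,\|\,P(\theta+d))$ should be close to $\frac{1}{2}d^\top F d$.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

rng = np.random.default_rng(2)
c, d = 4, 3
W = rng.standard_normal((c, d))
x = rng.standard_normal(d)

def kl(Wa, Wb):
    """KL(P(Wa) || P(Wb)) for the categorical model softmax(W x) at the fixed input x."""
    lpa, lpb = log_softmax(Wa @ x), log_softmax(Wb @ x)
    return np.sum(np.exp(lpa) * (lpa - lpb))

p = np.exp(log_softmax(W @ x))
F = np.kron(np.diag(p) - np.outer(p, p), np.outer(x, x))   # Fisher at theta = W.ravel()

D = rng.standard_normal((c, d))
D *= 1e-3 / np.linalg.norm(D)             # a small perturbation d
quad = 0.5 * D.ravel() @ F @ D.ravel()    # 1/2 d^T F d
ratio_fwd = kl(W + D, W) / quad           # KL(P(theta + d) || P(theta)) / quad
ratio_rev = kl(W, W + D) / quad           # KL(P(theta) || P(theta + d)) / quad
```

Both ratios approach 1 as $d \to 0$, which also illustrates the local symmetry of the KL divergence discussed below.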
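To make use of this connection we first observe, as in Arnold et al. (2011), that for a general positive definite matrix $A$ we have

$$-\frac{A^{-1}\nabla h}{\|\nabla h\|_{A^{-1}}} = \lim_{\varepsilon\to 0}\frac{1}{\varepsilon}\arg\min_{d:\,\|d\|_A\le\varepsilon} h(\theta+d),$$

where the notation $\|v\|_B$ is defined by $\|v\|_B = \sqrt{v^\top B v}$.

Then taking $A = \frac{1}{2}F$ and using the above Taylor series expansion of the KL divergence to show that $\mathrm{KL}(P_{x,y}(\theta+d)\,\|\,P_{x,y}(\theta)) \to \frac{1}{2}d^\top F d = \|d\|^2_{\frac{1}{2}F}$ as $d \to 0$, with some extra work (Arnold et al., 2011) it follows that

$$-\sqrt{2}\,\frac{\tilde\nabla h}{\|\nabla h\|_{F^{-1}}} = \lim_{\varepsilon\to 0}\frac{1}{\varepsilon}\arg\min_{d:\,\mathrm{KL}(P_{x,y}(\theta+d)\,\|\,P_{x,y}(\theta))\le\varepsilon^2} h(\theta+d).$$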
Thus the negative natural gradient is indeed the steepest descent direction in the space of distributions where distance is (approximately) measured in local neighborhoods by the KL divergence. While this might seem impossible since the KL divergence is in general not symmetric in its two arguments, it turns out that $\mathrm{KL}(P_{x,y}(\theta+d)\,\|\,P_{x,y}(\theta))$ is locally/asymptotically symmetric as $d$ goes to zero, and so will be (approximately) symmetric in a local neighborhood (this follows from the fact that the second-order term of the Taylor series of $\mathrm{KL}(P_{x,y}(\theta)\,\|\,P_{x,y}(\theta+d))$ is also given by $\frac{1}{2}d^\top F d$).
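The identity for a general positive definite $A$ can be illustrated to first order in $\varepsilon$, where $h(\theta+d) \approx h(\theta) + \nabla h^\top d$, so the constrained problem reduces to minimizing $g^\top d$ over the ellipsoid $\|d\|_A \le \varepsilon$. In this sketch, `A` and `g` are random stand-ins (assumptions) for a matrix such as $\frac{1}{2}F$ and for $\nabla h$; random directions on the boundary of the ellipsoid should never do better than the closed-form minimizer.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 5
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)      # a positive definite matrix (stand-in for 1/2 F)
g = rng.standard_normal(n)       # stand-in for the gradient of h
eps = 1e-2

Ainv_g = np.linalg.solve(A, g)
d_star = -eps * Ainv_g / np.sqrt(g @ Ainv_g)     # -eps A^{-1} g / ||g||_{A^{-1}}

# Random directions rescaled onto the boundary ||d||_A = eps
U = rng.standard_normal((1000, n))
U *= eps / np.sqrt(np.einsum('ij,jk,ik->i', U, A, U))[:, None]
best_random = (U @ g).min()
```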
Note that both $\nabla h$ and $\tilde\nabla h$ are defined in terms of the standard basis in $\theta$-space, and so obviously depend on the parameterization of $h$. But the KL divergence does not, and instead only depends on the form of the predictive distribution $P_{y\mid x}$. Thus, the direction in distribution space defined implicitly by $\tilde\nabla h$ will be invariant to our choice of parameterization (whereas the direction defined by $\nabla h$ will not be).

By using the smoothly varying PSD matrix $F$ to locally define a metric tensor at every point in parameter space, a Riemannian manifold can be generated over the space of distributions. Note that the associated metric of this space won’t be the KL divergence (this isn’t even a valid metric), although it will be “locally equivalent” to the square root of the KL divergence in the sense that the two will approximate each other within a small neighborhood.
7 2nd-order optimization
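The basic idea in 2nd-order optimization is to compute the update $\delta_k$ to $\theta$ by minimizing some local quadratic approximation of $h$ centered around the current iterate $\theta_k$. That is, we compute $\delta_k = \arg\min_\delta M_k(\delta)$ and then update $\theta$ according to $\theta_{k+1} = \theta_k + \delta_k$, where $M_k(\delta)$ is defined by

$$M_k(\delta) = \frac{1}{2}\delta^\top B_k\,\delta + \nabla h(\theta_k)^\top\delta + h(\theta_k),$$

and where $B_k$ is the “curvature matrix”, which is symmetric. The “sub-problem” of optimizing $M_k$ can be performed exactly by solving the $n$-dimensional linear system $B_k\delta = -\nabla h(\theta_k)$, whose solution is $\delta_k = -B_k^{-1}\nabla h(\theta_k)$ when $B_k$ is positive definite (so that $M_k$ has a unique minimizer).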
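A minimal sketch of one exact 2nd-order step, assuming a hypothetical convex quadratic $h(\theta) = \frac{1}{2}\theta^\top A\theta - b^\top\theta$. With $B_k = H = A$, the minimizer of $M_k$ lands on the exact optimum $A^{-1}b$ in a single step. `second_order_step` is an illustrative name, and `np.linalg.solve` is used instead of forming $B_k^{-1}$ explicitly.

```python
import numpy as np

def second_order_step(grad, B):
    """delta_k = argmin_delta M_k(delta) = -B^{-1} grad, assuming B is positive definite."""
    return np.linalg.solve(B, -grad)

rng = np.random.default_rng(4)
M = rng.standard_normal((4, 4))
A = M @ M.T + np.eye(4)        # Hessian of h(theta) = 1/2 theta^T A theta - b^T theta
b = rng.standard_normal(4)
grad_h = lambda th: A @ th - b

theta0 = rng.standard_normal(4)
theta1 = theta0 + second_order_step(grad_h(theta0), A)   # B_k = H (Newton's method)
```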
Gradient descent, the canonical 1st-order method, can be viewed in the framework of 2nd-order methods as making the choice $B_k = \frac{1}{\alpha}I$ for some learning rate $\alpha > 0$, resulting in the update $\delta_k = -\alpha\nabla h(\theta_k)$. In the case where $h$ is convex and Lipschitz-smooth with constant $L$ (by this we mean that $\|\nabla h(\theta) - \nabla h(\theta')\| \le L\|\theta - \theta'\|$ for all $\theta$ and $\theta'$), a safe/conservative choice that will ensure convergence is $\alpha = 1/L$, i.e. $B_k = LI$ (e.g. Nesterov, 2013). The intuition behind this choice is that $B_k$ will act as a global upper bound on the curvature of $h$, in the sense that $B_k = LI \succeq H(\theta)$ for all $\theta$ (here we define $A \succeq C$ to mean that $A - C$ is PSD), so that $\delta_k$ never extends past the point that would be safe in the worst-case scenario where the curvature sharply increases to the upper bound as one travels along $\delta_k$. More concretely, one can show that given this choice of $B_k$, $M_k(\delta)$ upper bounds $h(\theta_k + \delta)$, and will therefore never predict a reduction in $h$ where there is actually a sharp increase (e.g. due to $h$ curving unexpectedly upward on the path from $\theta_k$ to $\theta_k + \delta$). Minimizing $M_k(\delta)$ is therefore guaranteed not to increase $h$ beyond the current value $h(\theta_k)$, since $M_k(0) = h(\theta_k)$. But despite these nice properties, this choice will almost always overestimate the curvature in most directions, leading to updates that move unnecessarily slowly along directions of consistent low curvature.
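The upper-bound argument can be checked on a hypothetical convex quadratic $h(\theta) = \frac{1}{2}\theta^\top A\theta$, for which the Lipschitz constant $L$ of $\nabla h$ is the largest eigenvalue of $A$. With $B_k = LI$, the sketch verifies that $M_k(\delta_k) \ge h(\theta_k + \delta_k)$ at every iterate and that $h$ never increases.

```python
import numpy as np

rng = np.random.default_rng(5)
M = rng.standard_normal((6, 6))
A = M @ M.T                            # convex quadratic h(theta) = 1/2 theta^T A theta
h = lambda th: 0.5 * th @ A @ th
grad = lambda th: A @ th
L = np.linalg.eigvalsh(A).max()        # Lipschitz constant of grad h (largest curvature)

def model(th, delta):
    """M_k(delta) with the gradient-descent curvature matrix B_k = L * I."""
    return 0.5 * L * delta @ delta + grad(th) @ delta + h(th)

theta = rng.standard_normal(6)
values, bound_ok = [h(theta)], True
for _ in range(50):
    delta = -grad(theta) / L            # minimizer of M_k, i.e. a step with alpha = 1/L
    bound_ok &= model(theta, delta) >= h(theta + delta) - 1e-12
    theta = theta + delta
    values.append(h(theta))
```

Progress along eigen-directions of $A$ with small eigenvalues is slow here, which is the conservativeness described above.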
While neural networks haven’t been closely studied by optimization researchers, many of the local optimization issues related to neural network learning can be seen as extreme special cases of problems which arise more generally in continuous optimization. For example, tightly coupled parameters with strong local dependencies, and large variations in scale along different directions in parameter space (which may arise due to the “vanishing gradient” phenomenon (Hochreiter et al., 2000)), are precisely the sorts of issues for which 2nd-order optimization is well suited. Gradient descent on the other hand is well-known to be very sensitive to such issues, and in order to avoid large oscillations and instability must use a learning rate which is inversely proportional to the size of the curvature along the highest curvature direction. 2nd-order optimization methods provide a much more powerful and elegant solution to the problem of variations in scale/curvature along different directions, by selectively re-scaling the gradient along different eigen-directions of the curvature matrix according to their associated curvature (eigenvalue), instead of employing a one-size-fits-all step-size.
In the classical Newton’s method we take $B_k = H(\theta_k)$, in which case $M_k$ becomes the 2nd-order Taylor-series approximation of $h$ centered at $\theta_k$. This choice gives us the most accurate local model of the curvature possible, and allows for very rapid exploration of low-curvature directions yielding faster convergence. Unfortunately, Newton’s method runs into numerous problems when applied to neural network training objectives, such as $H(\theta_k)$ being sometimes indefinite (and thus $M_k$ being unbounded below in directions of negative curvature) and related issues of model trust, where the method implicitly “trusts” its own local quadratic model of the objective too much, causing it to generate huge and nonsensical updates that increase $h$. This problem is particular to 2nd-order methods because they use a much less conservative model of the curvature than gradient descent, that may start out as accurate around $\theta_k$, but which may quickly become a severe underestimate as one travels along $\delta_k$. Fortunately, using the Gauss-Newton approximation to the Hessian (as discussed in Section 8), and applying various update damping/trust-region techniques (as discussed in Section 10), these issues can be mostly overcome.

Another important issue preventing the naive application of 2nd-order methods to neural networks is the very high dimensionality of the parameter space, which prohibits the calculation/storage/inversion of the $n^2$-entry curvature matrix $B_k$. To address this, various approximate Newton methods have been developed within the optimization and machine learning communities. These methods work by approximating $B_k$ with something easier to compute/store/invert such as a low-rank or diagonal matrix, or by performing only approximate/incomplete optimization of $M_k$. A survey of such methods is outside the scope of this report, so we refer the reader to Martens (2016).
8 The generalized Gauss-Newton matrix
The classical Gauss-Newton matrix (or more simply the Gauss-Newton matrix) is the curvature matrix which arises in the Gauss-Newton method for non-linear least squares problems. It is applicable to our standard neural network training objective $h$ in the case where $L(y,z) = \frac{1}{2}\|y-z\|^2$, and is given by

$$G = \frac{1}{|S|}\sum_{(x,y)\in S} J_f^\top J_f,$$
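where $J_f$ is the Jacobian of $f(x,\theta)$ w.r.t. the parameters $\theta$. It is usually defined as the approximation to the Hessian $H$ of $h$ (w.r.t. $\theta$) obtained by dropping the second term inside the sum of the following expression for $H$:

$$H = \frac{1}{|S|}\sum_{(x,y)\in S}\left(J_f^\top J_f - \sum_{j=1}^{m}[y - f(x,\theta)]_j\,H_{[f]_j}\right),$$

where $H_{[f]_j}$ is the Hessian (w.r.t. $\theta$) of the $j$-th component of $f(x,\theta)$.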
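To make the relation between the classical Gauss-Newton matrix and the Hessian concrete, this sketch assumes a hypothetical two-parameter, two-output function with hand-derived Jacobian and component Hessians, and a single training case. It checks $H = J_f^\top J_f - \sum_j [y - f]_j H_{[f]_j}$ against a finite-difference Hessian; $G$ differs from $H$ whenever the residual $y - f$ is nonzero.

```python
import numpy as np

y = np.array([0.3, -0.2])

def f(th):
    # Hypothetical two-output function: f(theta) = [theta_0 * theta_1, sin(theta_0)]
    return np.array([th[0] * th[1], np.sin(th[0])])

def jac(th):
    return np.array([[th[1], th[0]], [np.cos(th[0]), 0.0]])

def hess_f(th):
    # Hessians (w.r.t. theta) of the two components of f
    return [np.array([[0.0, 1.0], [1.0, 0.0]]),
            np.array([[-np.sin(th[0]), 0.0], [0.0, 0.0]])]

def grad_h(th):
    return -jac(th).T @ (y - f(th))     # gradient of h = 1/2 ||y - f||^2

th = np.array([0.7, -1.1])
J, r = jac(th), y - f(th)
G = J.T @ J                                             # classical Gauss-Newton matrix
H = G - sum(r[j] * Hj for j, Hj in enumerate(hess_f(th)))

eps = 1e-5
H_fd = np.stack([(grad_h(th + eps * e) - grad_h(th - eps * e)) / (2 * eps)
                 for e in np.eye(2)], axis=1)
```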
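An alternative way to derive the classical Gauss-Newton is to simply replace the non-linear function $f(x,\theta)$ by its own local linear approximation, centered at the current value $\theta_k$ of $\theta$. In particular, we replace $f$ by $\tilde f(x,\theta) = f(x,\theta_k) + J_f(\theta - \theta_k)$ (with $J_f$ evaluated at $\theta_k$) so that $h$ becomes a quadratic function of $\theta$, whose gradient at $\theta_k$ is $\nabla h(\theta_k)$ and whose Hessian is $G$.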
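Schraudolph (2002) showed how the idea of the Gauss-Newton matrix can be generalized to the situation where $L(y,z)$ is any loss function which is convex in $z$. The generalized formula for $G$ is

$$G = \frac{1}{|S|}\sum_{(x,y)\in S} J_f^\top H_L J_f,$$

where $H_L$ is the Hessian of $L(y,z)$ w.r.t. $z$, evaluated at $z = f(x,\theta)$. Because $L$ is convex, $H_L$ will be PSD for each $(x,y)$, and thus so will $G$. We will call this the Generalized Gauss-Newton matrix (GGN). Analogously to the case of the classical Gauss-Newton matrix (which assumed $L(y,z) = \frac{1}{2}\|y-z\|^2$), the GGN can be obtained by dropping the second term inside the sum of the following expression for the Hessian (e.g. Nocedal and Wright, 2006):

$$H = \frac{1}{|S|}\sum_{(x,y)\in S}\left(J_f^\top H_L J_f + \sum_{j=1}^{m}\left[\nabla_z L(y,z)\big|_{z=f(x,\theta)}\right]_j H_{[f]_j}\right) \tag{6}$$

Here, $\nabla_z L(y,z)|_{z=f(x,\theta)}$ is the gradient of $L(y,z)$ w.r.t. $z$, evaluated at $z = f(x,\theta)$. Note that if we have $[\nabla_z L(y,z)|_{z=f(x,\theta^*)}]_j = 0$ at some local optimum $\theta^*$ for each $(x,y) \in S$ and each $j$, which corresponds to the network making an optimal prediction for each training case over each dimension, then $G(\theta^*) = H(\theta^*)$. In such a case, the behavior of a 2nd-order optimizer using $G$ will approach the behavior of the ideal Newton method as it converges to $\theta^*$.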
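A sketch of the GGN for a convex loss, assuming softmax cross-entropy on the network's logits (so $H_L = \mathrm{diag}(p) - pp^\top$, which does not depend on $y$) and a hypothetical two-layer tanh network without biases. The Jacobian is formed by central finite differences purely for illustration; practical implementations do not compute it this way.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def net(x, th):
    # Hypothetical tiny network: 2 inputs -> 3 tanh units -> 3 logits (no biases)
    W1, W2 = th[:6].reshape(3, 2), th[6:].reshape(3, 3)
    return W2 @ np.tanh(W1 @ x)

def jacobian_fd(x, th, eps=1e-6):
    """Central finite-difference Jacobian J_f of net(x, .) w.r.t. theta."""
    cols = [(net(x, th + eps * e) - net(x, th - eps * e)) / (2 * eps)
            for e in np.eye(th.size)]
    return np.stack(cols, axis=1)

def ggn(xs, th):
    """G = (1/|S|) sum of J_f^T H_L J_f, with H_L = diag(p) - p p^T for softmax cross-entropy."""
    G = np.zeros((th.size, th.size))
    for x in xs:
        J = jacobian_fd(x, th)
        p = softmax(net(x, th))
        G += J.T @ (np.diag(p) - np.outer(p, p)) @ J
    return G / len(xs)

rng = np.random.default_rng(6)
th = rng.standard_normal(15)
G = ggn(rng.standard_normal((8, 2)), th)
```

$G$ is PSD by construction, since each $H_L$ is PSD, regardless of the sign structure of the network's own curvature.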
Like the Hessian, the GGN can be used to define a local quadratic model of $h$, as given by:

$$M_k(\delta) = \frac{1}{2}\delta^\top G\,\delta + \nabla h(\theta_k)^\top\delta + h(\theta_k),$$

where $G$ is evaluated at $\theta_k$.

In approximate Newton/2nd-order methods based on the GGN, parameter updates are computed by minimizing $M_k(\delta)$ w.r.t. $\delta$. The exact minimizer is often too difficult to compute, and so practical methods like the Hessian-free optimization of Martens (2010), or Krylov Subspace Descent (Vinyals and Povey, 2012) will only approximately minimize $M_k(\delta)$.
A key property of $G$ which is not shared by the Hessian $H$ is that it is PSD, and can thus be used to define a local quadratic model of the objective which is bounded below. While the unboundedness of local quadratic models defined by the Hessian can be worked around by imposing a trust region, it has nevertheless been observed by various researchers (Schraudolph, 2002; Martens, 2010; Vinyals and Povey, 2012) that $G$ works much better in practice for neural network optimization.

Since computing the whole matrix $G$ explicitly is usually too expensive, the GGN is typically accessed via matrix-vector products. To compute such products efficiently one can use the method of Schraudolph (2002), which is a generalization of the well-known method for computing such products with the classical Gauss-Newton. The method is similar in cost and structure to standard backpropagation, although it can sometimes be tricky to implement (see Martens and Sutskever, 2012).
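A minimal sketch of a matrix-free GGN-vector product, assuming a hypothetical linear softmax model $f(x,W) = Wx$ with $\theta = \mathrm{vec}(W)$ in row-major order. The product $Gv = \frac{1}{|S|}\sum J_f^\top H_L J_f v$ is formed as a forward product $J_f v$, a multiplication by $H_L$, and a backward product $J_f^\top(\cdot)$, without ever materializing $G$. For a linear layer these reduce to $Vx$ and an outer product; for deep networks, Schraudolph's method plays this role. The explicit $G$ is built only to check the result.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def ggn_vec(W, xs, v):
    """G v for logits f(x, W) = W x and softmax cross-entropy, without forming G.
    For this linear model, J_f v = V x (V is v reshaped like W) and J_f^T u = outer(u, x)."""
    V = v.reshape(W.shape)
    out = np.zeros_like(W)
    for x in xs:
        p = softmax(W @ x)
        u = V @ x                        # forward pass: J_f v
        w = p * u - p * (p @ u)          # H_L u with H_L = diag(p) - p p^T
        out += np.outer(w, x)            # backward pass: J_f^T w
    return out.ravel() / len(xs)

rng = np.random.default_rng(7)
W = rng.standard_normal((3, 4))
xs = rng.standard_normal((10, 4))
v = rng.standard_normal(W.size)
Gv = ggn_vec(W, xs, v)

# Explicit G for comparison (feasible only because this example is tiny)
G = np.mean([np.kron(np.diag(p) - np.outer(p, p), np.outer(x, x))
             for x in xs for p in [softmax(W @ x)]], axis=0)
```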
As pointed out in Martens and Sutskever (2011), the GGN can also be derived by generalizing the previously described alternative derivation of the classical Gauss-Newton matrix to the situation where $L$ is an arbitrary convex loss. In particular, if we substitute the linearization $\tilde f$ for $f$ in $h$ as before (where $\tilde f$ is the linearization of $f$ around $\theta_k$), it is not difficult to see that the Hessian of the resulting $h$ at $\theta_k$ will be equal to the GGN.

Schraudolph (2002) advocated that when computing the GGN, $L$ and $f$ be redefined so that as much as possible of the network’s computation is formally performed by $L$ instead of $f$, while maintaining the convexity of $L$. This is because, unlike $f$, $L$ is not linearly approximated in the GGN, and so its associated second-order derivative terms are faithfully captured. What this almost always means in practice is that what is usually thought of as the final non-linearity of the network (i.e. $\phi_\ell$) is folded into $L$, and the network itself just computes the identity function at its top layer. Interestingly, in many natural situations which occur in practice, doing this gives a much simpler and more elegant expression for $H_L$. Exactly when and why this happens will be made clear in Section 9.
8.1 Speculation on possible advantages of the GGN over the Hessian
Unlike the Hessian, the GGN is positive semi-definite (PSD). This means that it never models the curvature in any direction as negative. The most obvious problem with negative curvature is that the quadratic model will predict an unbounded quadratic improvement in the objective for moving in certain directions. Indeed, without the use of some kind of trust-region or damping technique (as discussed in Section 10) the update produced by minimizing the quadratic model will be infinitely large in any direction of negative curvature.
While curvature can indeed be negative in a local neighborhood (as measured by the Hessian), we know it must quickly become non-negative as we travel along any particular direction, given that our loss is convex in $z$ and bounded below. Meanwhile, positive curvature predicts a quadratic penalty, and in the worst case merely underestimates how badly the objective will eventually increase along a particular direction.

Because contributions made to the GGN for each training case and each individual component of $f(x,\theta)$ are PSD, there can be no cancellation between positive and negative/indefinite contributions. This means that the GGN can be more robustly estimated from subsets of the training data than the Hessian. By analogy, consider how much harder it is to estimate the scale of the mean value of a variable when that variable can take on both positive and negative values, and has a mean close to $0$.

This property of being PSD for individual training cases and components of $f(x,\theta)$ also means that positive curvature from one training case, or one component of the network’s prediction, cannot be cancelled out by negative curvature from others. If we believe that negative curvature is less “trustworthy” than positive curvature over larger distances, then it seems like a good idea to prevent positive curvature from being cancelled in this manner.
Notably the GGN is not an upper bound on the Hessian (in the PSD sense), as it fails to model all of the positive curvature contained in the latter. But crucially, it only fails to model the (positive or negative) curvature coming from the network function , as opposed to the curvature coming from the loss function . (To see this, recall the decomposition of the Hessian from eqn. 6, noting that the term dropped from the Hessian depends only on the gradients of and the Hessian of components of .) Curvature coming , whether it is positive or negative, is arguably less trustworthy/stable across long distance than curvature coming from , as argued below.
The following decomposition of the Hessian is a generalization of eqn. 6:
Here, is the gradient of w.r.t. , is the Hessian of (i.e. the function which computes ) w.r.t. , and is the Jacobian of (viewed as a function of and ) w.r.t. .
We can see from this equation that the curvature coming from the network function is a sum of curvature terms coming from each neural unit, weighted by the gradient of the loss w.r.t. that unit’s output . It seems reasonable to expect that the sign of these terms may be subject to frequent change, due both to changes in the sign of “local Hessian” of ( is typically non-convex), and to changes in the sign of the loss derivative w.r.t. that unit’s output (), which depends on the behavior of all of the layers above . This is to be contrasted with the curvature term arising from the loss, which remains PSD everywhere.
Finally, it is worth noting that for networks with piece-wise linear activation functions, such as the popular RELUs (given by $\phi(s) = \max(s, 0)$, applied element-wise), the network function has zero curvature almost everywhere, since $\phi''(s) = 0$ when $s \neq 0$, and $\phi''$ is undefined otherwise. Thus the Hessian will coincide with the GGN for such networks at all points where the former is defined.
9 Computational aspects of the natural gradient and connections to the generalized Gauss-Newton matrix
9.1 Computing the Fisher (and matrix-vector products with it)
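By the chain rule, the gradient of the log-likelihood can be written as
$$\nabla \log p(y|x,\theta) = J_f^\top \nabla_z \log r(y|z)$$
where $J_f$ is the Jacobian of $f(x,\theta)$ w.r.t. $\theta$, and $\nabla_z \log r(y|z)$ is the gradient of $\log r(y|z)$ w.r.t. $z$, evaluated at $z = f(x,\theta)$ (with $r$ defined as near the end of Section 4).
As was first shown by Park et al. (2000), the Fisher information matrix is thus given by
$$F = E_{Q_x}\left[E_{P_{y|x}}\left[J_f^\top \nabla_z \log r(y|z)\, \nabla_z \log r(y|z)^\top J_f\right]\right] = E_{Q_x}\left[J_f^\top F_R J_f\right]$$
where $F_R$ is the Fisher information matrix of the predictive distribution $R_{y|z}$ at $z = f(x,\theta)$.
$F_R$ is itself given by
$$F_R = E_{R_{y|z}}\left[\nabla_z \log r(y|z)\, \nabla_z \log r(y|z)^\top\right] = -E_{R_{y|z}}\left[H_{\log r}\right]$$
where $H_{\log r}$ is the Hessian of $\log r(y|z)$ w.r.t. $z$, evaluated at $z = f(x,\theta)$.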
Note that even if $Q_x$'s density function is known, and is relatively simple, only for certain choices of $R_{y|z}$ and $f(x,\theta)$ will it be possible to analytically evaluate the expectation w.r.t. $Q_x$ in the above expression for $F$.
For example, if we take $Q_x = N(0, I)$, $R_{y|z} = N(z, \sigma^2)$, and $f$ to be a simple neural network with no hidden units and a single tan-sigmoid output unit, then both $F$ and its inverse can be computed efficiently (Amari, 1998). This situation is exceptional, however, and for even slightly more complex models, such as neural networks with one or more hidden layers, it has never been demonstrated how to make such computations feasible in high dimensions.
Fortunately the situation improves significantly if $Q_x$ is replaced by $\hat{Q}_x$, as this gives
$$F = \frac{1}{|S|}\sum_{(x,y)\in S} J_f^\top F_R J_f \qquad (7)$$
which is easy to evaluate when $F_R$ is. Moreover, this is essentially equivalent to the expression in eqn. 5 for the generalized Gauss-Newton matrix (GGN), except that we have the Fisher of the predictive distribution ($F_R$) instead of the Hessian of the loss ($H_L$) as the “inner” matrix.
It also suggests a straightforward and efficient way of computing matrix-vector products with $F$, using an approach similar to the one in Schraudolph (2002) for computing matrix-vector products with the GGN. In particular, one can multiply $v$ by $J_f$ using a linearized forward pass, then multiply by $F_R$ (which will be easy if $R_{y|z}$ is sufficiently simple), and then finally multiply by $J_f^\top$ using a standard backwards pass.
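As a minimal sketch of this three-step recipe, assume (purely for illustration) a linear model $z = Wx$ with a softmax/multinomial predictive distribution $R_{y|z}$, for which $F_R = \mathrm{diag}(p) - pp^\top$ with $p$ the softmax of $z$. For this model $J_f v$ reduces to $Vx$ (with $V$ the direction $v$ reshaped like $W$) and $J_f^\top u$ reduces to the outer product $ux^\top$, so neither the Jacobian nor the Fisher is ever formed. The function names are hypothetical; a real network would obtain the two Jacobian products from automatic differentiation.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def fisher_vector_product(W, xs, V):
    """Return F v for the linear model z = W x with a softmax predictive distribution.

    W and V are m x d matrices (lists of rows); V plays the role of the vector v.
    xs holds the training inputs, so F = (1/|S|) sum_x J_f^T F_R J_f.
    """
    m, d = len(W), len(W[0])
    out = [[0.0] * d for _ in range(m)]
    for x in xs:
        z = [sum(W[i][k] * x[k] for k in range(d)) for i in range(m)]
        p = softmax(z)
        # Step 1 (linearized forward pass): J_f v = V x for this linear model
        jv = [sum(V[i][k] * x[k] for k in range(d)) for i in range(m)]
        # Step 2: multiply by F_R = diag(p) - p p^T without forming it
        p_dot = sum(pi * ui for pi, ui in zip(p, jv))
        fu = [p[i] * (jv[i] - p_dot) for i in range(m)]
        # Step 3 (backward pass): J_f^T u = u x^T, averaged over the training inputs
        for i in range(m):
            for k in range(d):
                out[i][k] += fu[i] * x[k] / len(xs)
    return out
```

For this particular loss (softmax folded into $L$) we also have $H_L = F_R$, so the same routine computes GGN-vector products.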
9.2 Qualified equivalence of the GGN and the Fisher
As we shall see in this subsection, the connections between the GGN and the Fisher run deeper than just similar expressions and similar algorithms for computing matrix-vector products.
In Park et al. (2000) it was shown that if the density function of $R_{y|z}$ has the form $r(y|z) = \prod_{j=1}^{m} c(y_j - z_j)$, where $c$ is some univariate density function over $\mathbb{R}$, then $F$ is equal to a re-scaled[7] version of the classical Gauss-Newton matrix for non-linear least squares, with regression function given by $f$. And in particular, choosing $c$ to be the standard normal density turns the learning problem into exactly non-linear least squares, and $F$ into precisely the classical Gauss-Newton matrix.
[7] Where the re-scaling constant is determined by properties of $c$.
Heskes (2000) showed that the Fisher and the classical Gauss-Newton matrix are equivalent in the case of the squared error loss and proposed using the Fisher as an approximation to the Hessian in more general contexts. Concurrently with this work (Martens, 2014), Pascanu and Bengio (2014) showed that for several common loss functions like cross-entropy and squared error, the GGN and the Fisher are equivalent.
We will show that in fact there is a much more general equivalence between the two matrices, starting from the observation that the expressions for the GGN in eqn. 5 and the Fisher in eqn. 7 are identical up to the equivalence of $H_L$ and $F_R$.
First, note that $L(y,z)$ may not even be convex in $z$, and so the GGN won't necessarily be well-defined. But even if $L(y,z)$ is convex in $z$, it won't be true in general that $F_R = H_L$, and so the GGN and the Fisher will differ. However, there is an important class of $R_{y|z}$'s for which $F_R = H_L$ will hold, provided that we have $L(y,z) = -\log r(y|z)$ (putting us in the framework of Section 4).
Notice that $F_R = -E_{R_{y|z}}[H_{\log r}]$, and $H_L = -H_{\log r}$ (which follows from $L(y,z) = -\log r(y|z)$). Thus, the two matrices being equal is equivalent to the condition
$$E_{R_{y|z}}\left[H_{\log r}\right] = H_{\log r} \qquad (8)$$
While this condition may seem arbitrary, it is actually very natural and holds in the important case where $R_{y|z}$ corresponds to an exponential family model with “natural” parameters given by $z$. That is, when we have
$$r(y|z) \propto \exp\left(z^\top T(y) - \log Z(z)\right)$$
for some function $T(y)$, where $Z(z)$ is the normalizing constant/partition function. In this case we have $H_{\log r} = -H_{\log Z}$, which doesn't depend on $y$, and so eqn. 8 holds trivially.
Examples of such ’s include:
multivariate normal distributions where $z$ parameterizes only the mean $\mu$
multivariate normal distributions where $z$ is the concatenation of $\Sigma^{-1}\mu$ and the vectorization of $\Sigma^{-1}$
multinomial distributions where the softmax of $z$ is the vector of probabilities for each class
Note that the loss function corresponding to the multivariate normal is the familiar squared error, and the one corresponding to the multinomial distribution is the familiar cross-entropy.
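For the multinomial case, eqn. 8 can be checked numerically. The sketch below assumes the softmax is folded into the loss, so that $L(y,z) = -z_y + \log\sum_j \exp(z_j)$ with $y$ a class index, and compares a finite-difference Hessian of $L$ w.r.t. $z$ (for every possible $y$) against $F_R = \mathrm{diag}(p) - pp^\top$. The logits and helper names are made up for illustration.

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def loss(y, z):
    # Cross-entropy with the softmax folded into L: -z_y + log(sum_j exp(z_j))
    m = max(z)
    return -z[y] + m + math.log(sum(math.exp(v - m) for v in z))

def multinomial_fisher(z):
    # F_R = diag(p) - p p^T for the multinomial distribution with natural parameters z
    p = softmax(z)
    n = len(z)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)] for i in range(n)]

def fd_hessian(fun, z, eps=1e-4):
    # Central finite-difference Hessian of fun at z
    n = len(z)

    def shifted(i, di, j, dj):
        w = list(z)
        w[i] += di
        w[j] += dj
        return fun(w)

    return [[(shifted(i, eps, j, eps) - shifted(i, eps, j, -eps)
              - shifted(i, -eps, j, eps) + shifted(i, -eps, j, -eps)) / (4 * eps * eps)
             for j in range(n)] for i in range(n)]

z = [0.3, -1.2, 0.7]
F_R = multinomial_fisher(z)
# H_L[y] is the loss Hessian w.r.t. z for target class y
H_L = {y: fd_hessian(lambda w: loss(y, w), z) for y in range(3)}
```

Every `H_L[y]` agrees with `F_R` up to finite-difference error, which is exactly the statement that $H_{\log r}$ does not depend on $y$ for this exponential family.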
As discussed in Section 8, when constructing the GGN one must pay attention to how $f$ and $L$ are defined with regards to what parts of the neural network's computation are performed by each function. For example, the softmax computation performed at the final layer of a classification network is usually considered to be part of the network itself and hence to be part of $f$. The output of this computation is a vector of normalized probabilities, which is then fed into a cross-entropy loss of the form $L(y,z) = -\sum_j y_j \log z_j$. But the other way of doing it, which Schraudolph (2002) recommends, is to have the softmax function be part of $L$ instead of $f$, which results in a GGN which is slightly closer to the Hessian due to “less” of the computational pipeline being linearized before taking the 2nd-order Taylor series approximation. The corresponding loss function is $L(y,z) = -\sum_j y_j z_j + \log\left(\sum_j \exp(z_j)\right)$ in this case. As we have established above, doing it this way also has the nice side effect of making the GGN equivalent to the Fisher, provided that $R_{y|z}$ is an exponential family model with $z$ as its natural parameters.
This (qualified) equivalence between the Fisher and the GGN suggests how the GGN can be generalized to cases where it might not otherwise be well-defined. In particular, it suggests formulating the loss as the negative log density $-\log r(y|z)$ of some distribution $R_{y|z}$ and then taking the Fisher of this distribution. Sometimes, this might be as simple as defining $r(y|z) \propto \exp(-L(y,z))$, as per the discussion at the end of Section 4.
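For example, suppose our loss is defined as the negative log probability of a (univariate) normal distribution parameterized by its mean $\mu$ and standard deviation $\sigma$ (so that $z = [\mu\ \ \sigma]^\top$). In other words, suppose that
$$L(y,z) = -\log r(y|z) = \frac{(y-\mu)^2}{2\sigma^2} + \log\sigma + \frac{1}{2}\log(2\pi)$$
In this case the loss Hessian is equal to
$$H_L = \begin{bmatrix} 1/\sigma^2 & 2(y-\mu)/\sigma^3 \\ 2(y-\mu)/\sigma^3 & 3(y-\mu)^2/\sigma^4 - 1/\sigma^2 \end{bmatrix}$$
It is not hard to verify that this matrix is indefinite for certain settings of $y$ and $z$ (e.g. $y - \mu = 2$ and $\sigma = 1$, for which $\det H_L = -5$). Therefore $L$ is not convex in $z$ and we cannot define a GGN matrix from it (since the definition of the GGN requires this).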
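The indefiniteness is easy to confirm numerically. This sketch assumes the loss $L(y,z) = (y-\mu)^2/(2\sigma^2) + \log\sigma + \frac{1}{2}\log(2\pi)$ with $z = [\mu\ \ \sigma]^\top$, evaluates its analytic Hessian at the illustrative point $y = 2$, $\mu = 0$, $\sigma = 1$, and inspects the determinant; the helper names are hypothetical.

```python
import math

def gaussian_nll(y, mu, sigma):
    # L(y, z) with z = [mu, sigma]: negative log density of N(y; mu, sigma^2)
    return (y - mu) ** 2 / (2 * sigma ** 2) + math.log(sigma) + 0.5 * math.log(2 * math.pi)

def loss_hessian(y, mu, sigma):
    # Analytic Hessian of gaussian_nll w.r.t. (mu, sigma)
    r = y - mu
    return [[1 / sigma ** 2, 2 * r / sigma ** 3],
            [2 * r / sigma ** 3, 3 * r ** 2 / sigma ** 4 - 1 / sigma ** 2]]

H = loss_hessian(y=2.0, mu=0.0, sigma=1.0)
det = H[0][0] * H[1][1] - H[0][1] * H[1][0]
# A symmetric 2x2 matrix with negative determinant has one positive and one
# negative eigenvalue, i.e. it is indefinite. Here H = [[1, 4], [4, 11]], det = -5.
```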
To resolve this problem we can use the Fisher $F_R$ in place of $H_L$ in the formula for the GGN, which by eqn. 7 yields $F$. Alternatively, we can insert reparameterization operations into our network to transform $\mu$ and $\sigma$ into the natural parameters $\mu/\sigma^2$ and $-1/(2\sigma^2)$, and then proceed to compute the GGN as usual, noting that $F_R = H_L$ in this case, so that $H_L$ will be PSD (and therefore $L$ will be convex in $z$). Either way will yield the same curvature matrix, due to the above discussed equivalence of the Fisher and GGN matrix for natural parameterizations.
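A quick numerical check of the natural-parameter fix: for a univariate normal, the natural parameters are $\eta_1 = \mu/\sigma^2$ and $\eta_2 = -1/(2\sigma^2)$, and the negative log density becomes $-(\eta_1 y + \eta_2 y^2) + A(\eta)$ with log-partition $A(\eta) = -\eta_1^2/(4\eta_2) - \frac{1}{2}\log(-2\eta_2)$ (up to an additive constant). The sketch below (hypothetical helper names, illustrative values $\mu = \sigma = 1$) shows that the loss Hessian w.r.t. $\eta$ is the same for two different targets and is positive definite; at this point it equals $\mathrm{Cov}(y, y^2) = [[1, 2], [2, 6]]$.

```python
import math

def nll_natural(y, eta1, eta2):
    """-log N(y; mu, sigma^2) in natural parameters eta1 = mu/sigma^2, eta2 = -1/(2 sigma^2).

    The additive constant 0.5*log(2*pi) is dropped; eta2 must be negative.
    """
    log_partition = -eta1 ** 2 / (4 * eta2) - 0.5 * math.log(-2 * eta2)
    return -(eta1 * y + eta2 * y * y) + log_partition

def fd_hessian2(fun, a, b, h=1e-4):
    # Central finite-difference Hessian of a function of two scalars
    haa = (fun(a + h, b) - 2 * fun(a, b) + fun(a - h, b)) / h ** 2
    hbb = (fun(a, b + h) - 2 * fun(a, b) + fun(a, b - h)) / h ** 2
    hab = (fun(a + h, b + h) - fun(a + h, b - h)
           - fun(a - h, b + h) + fun(a - h, b - h)) / (4 * h ** 2)
    return [[haa, hab], [hab, hbb]]

mu, sigma = 1.0, 1.0
eta1, eta2 = mu / sigma ** 2, -1 / (2 * sigma ** 2)
# Loss Hessian w.r.t. the natural parameters, at two different targets y
H_y1 = fd_hessian2(lambda a, b: nll_natural(2.0, a, b), eta1, eta2)
H_y2 = fd_hessian2(lambda a, b: nll_natural(-3.0, a, b), eta1, eta2)
```

Because the $y$-dependent part of the loss is linear in $\eta$, only $A(\eta)$ contributes curvature, which is why the two Hessians coincide.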
10 Constructing practical natural gradient methods, and the role of damping
Assuming that it is easy to compute, the simplest way to use the natural gradient in optimization is to substitute it in place of the standard gradient within a basic gradient descent approach. This gives the iteration
$$\theta_{k+1} = \theta_k - \alpha_k F^{-1}\nabla h(\theta_k) \qquad (9)$$
where $\{\alpha_k\}$ is a schedule of step-sizes (also called learning rates).
Choosing the step-size schedule can be difficult. There are adaptive schemes which are largely heuristic in nature (Amari, 1998) and some non-adaptive prescriptions such as $\alpha_k = c/k$, which have certain theoretical convergence guarantees in the stochastic setting, but which won't necessarily work well in practice.
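As a minimal sketch of the iteration $\theta_{k+1} = \theta_k - \alpha_k F^{-1}\nabla h(\theta_k)$, consider a hypothetical toy problem (not from the text): fitting $N(\mu, \sigma^2)$ to a few scalar observations by minimizing the average negative log-likelihood in the $(\mu, \sigma)$ parameterization. Here the exact Fisher is $\mathrm{diag}(1/\sigma^2, 2/\sigma^2)$, so $F^{-1}\nabla h$ has a closed form; the constant step-size and function name are assumptions for illustration.

```python
def natural_gradient_fit(ys, mu, sigma, alpha=1.0, steps=20):
    """Natural gradient descent for h(mu, sigma) = mean of -log N(y; mu, sigma^2)."""
    n = len(ys)
    ybar = sum(ys) / n
    for _ in range(steps):
        m2 = sum((y - mu) ** 2 for y in ys) / n   # mean squared deviation from current mu
        grad_mu = (mu - ybar) / sigma ** 2
        grad_sigma = 1.0 / sigma - m2 / sigma ** 3
        # F^{-1} grad with the exact Fisher F = diag(1/sigma^2, 2/sigma^2)
        nat_mu = sigma ** 2 * grad_mu
        nat_sigma = 0.5 * sigma ** 2 * grad_sigma
        mu, sigma = mu - alpha * nat_mu, sigma - alpha * nat_sigma
    return mu, sigma

mu, sigma = natural_gradient_fit([1.0, 2.0, 4.0], mu=0.0, sigma=5.0)
```

With $\alpha = 1$ the mean jumps to the sample mean in one step, and the $\sigma$ update reduces to $\sigma \leftarrow (\sigma + m_2/\sigma)/2$ (Heron's square-root iteration), so the fit converges to the sample mean and standard deviation. Closed-form Fisher inverses like this one are the exception rather than the rule.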
Ideally, we would want to apply the natural gradient method with infinitesimally small steps and produce a smooth idealized path through the space of realizable distributions. But since this is usually impossible in practice, and we don't have access to any other simple description of the class of distributions parameterized by $\theta$ that we could work with more directly, we must take non-negligible discrete steps in the given parameter space.[8]
[8] In principle, we could move to a much more general class of distributions, such as those given by some non-parametric formulation, where we could work directly with the distributions themselves. But even assuming such an approach would be practical from a computational efficiency standpoint, we would lose the various advantages that we get from working with powerful parametric models like neural networks. In particular, we would lose their ability to generalize to unseen data by modeling the “computational process” which explains the data, instead of merely using smoothness and locality to generalize.
The fundamental problem with simple schemes such as the one in eqn. 9 is that they implicitly assume that the natural gradient is a good direction to follow over non-negligible distances in the original parameter space, which will not be true in general. Traveling along a straight line in the original parameter space will not yield a straight line in distribution space, and so the resulting path may instead veer far away from the target that the natural gradient originally pointed towards. This is illustrated in Figure 1.
Fortunately, we can exploit the (qualified) equivalence between the Fisher and the GGN in order to produce natural gradient-like updates which will often be appropriate to take with $\alpha_k = 1$. In particular, we know from the discussion in Section 8 that the GGN matrix $G$ can serve as a reasonable proxy for the Hessian of $h$, and will often produce smaller and more “conservative” updates, as it tends to model the curvature as being higher in most directions than the Hessian does. Meanwhile, the update produced by minimizing the GGN-based local quadratic model is given by $-G^{-1}\nabla h$, which will be equal to the natural gradient update (with $\alpha_k = 1$) when $F = G$. Thus, the natural gradient, with scaling factor $\alpha = 1$, can be seen as the optimal update according to an approximate, and perhaps slightly conservative, 2nd-order model of $h$.
But just as in the case of approximate 2nd-order methods, the break-down in the accuracy of the quadratic approximation of $h$ over long distances, combined with the potential for the natural gradient to be very large (e.g. when $F$ contains some very small eigenvalues), can often lead to very large and very poor update proposals. And simply re-scaling the update by reducing $\alpha$ may be too crude a mechanism to deal with this subtle problem, as it will affect all eigen-directions (of $F$) equally, including those in which the natural gradient is already sensible or even overly conservative.
Instead, the connection between natural gradient descent and 2nd-order methods suggests the use of some of the various update “damping” techniques that have been developed for the latter, which work by constraining or penalizing the solution for $\delta$ in various ways during the optimization of the local quadratic model $M(\delta)$. Examples include Tikhonov regularization/damping and the closely related trust-region method (e.g. Nocedal and Wright, 2006), and other more sophisticated ones such as the “structural damping” approach of Martens and Sutskever (2011), or the approach present in Krylov Subspace Descent (Vinyals and Povey, 2012). See Martens and Sutskever (2012) for an in-depth discussion of these and other damping techniques.
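Tikhonov damping, the simplest of the techniques listed above, replaces the curvature matrix $F$ by $F + \lambda I$ before solving for the update. The sketch below assumes, for simplicity, that $F$ is diagonal (equivalently, that we work in its eigenbasis), and uses made-up eigenvalues and gradient entries; it contrasts damping with the cruder alternative of shrinking the whole undamped step.

```python
import math

def damped_step(curv_eigs, grad, lam):
    """Return -(F + lam*I)^{-1} grad, assuming F is diagonal (i.e. we work in its eigenbasis)."""
    return [-g / (f + lam) for f, g in zip(curv_eigs, grad)]

def norm(v):
    return math.sqrt(sum(c * c for c in v))

eigs = [1e-6, 1.0, 10.0]                  # hypothetical curvature eigenvalues, one nearly zero
grad = [0.1, 0.1, 0.1]                    # hypothetical gradient in the same basis
raw = damped_step(eigs, grad, 0.0)        # undamped natural-gradient step
damped = damped_step(eigs, grad, 0.1)     # Tikhonov damping with lambda = 0.1
# Crude alternative: shrink the raw step (i.e. reduce alpha) to the damped step's length
alpha = norm(damped) / norm(raw)
rescaled = [alpha * c for c in raw]
```

Damping caps the step along the near-null direction at roughly $|g|/\lambda$ while leaving the well-conditioned directions nearly untouched, whereas reducing $\alpha$ shrinks every direction by the same factor (here about $10^{-5}$).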
This idea is well supported by practical experience since, for example, the Hessian-free optimization approach of Martens (2010) generates its updates using a Tikhonov damping scheme applied to the GGN matrix (which for the objectives optimized in that paper was equivalent to the Fisher), and these updates are used effectively with $\alpha_k = 1$ and make a lot more progress on the objective than optimally re-scaled updates computed without damping (i.e. the raw natural gradient).
11 The empirical Fisher
An approximation of the Fisher known as the “empirical Fisher” (denoted $\bar{F}$), which is often used in practical natural gradient methods, is obtained by taking the inner expectation of eqn. 3 over the target distribution $Q_{y|x}$ (or its empirical surrogate $\hat{Q}_{y|x}$) instead of the model's distribution $P_{y|x}$.
In the case that one uses $\hat{Q}_{x,y}$, this yields the following simple form:
$$\bar{F} = \frac{1}{|S|}\sum_{(x,y)\in S} \nabla \log p(y|x,\theta)\, \nabla \log p(y|x,\theta)^\top$$
This matrix is often incorrectly referred to as the Fisher, or even the Gauss-Newton, although it is in general not equivalent to either of those matrices.
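To make the difference concrete, here is a minimal sketch for a hypothetical logistic-regression model $p(y=1|x) = \mathrm{sigmoid}(\theta^\top x)$ (not a model used in the text), for which $\nabla \log p(y|x,\theta) = (y - s)x$ with $s = \mathrm{sigmoid}(\theta^\top x)$. Both matrices are built from the same per-case outer products; the only difference is whether $y$ is the observed target ($\bar{F}$) or is averaged over the model's own predictive distribution ($F$, using $\hat{Q}_x$ for the inputs). The data and parameter values are made up.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def outer_add(M, u, c):
    # M += c * u u^T, in place
    for i in range(len(u)):
        for j in range(len(u)):
            M[i][j] += c * u[i] * u[j]

def fishers(theta, data):
    """Return (empirical Fisher, Fisher) for logistic regression p(y=1|x) = sigmoid(theta . x)."""
    d = len(theta)
    F_emp = [[0.0] * d for _ in range(d)]
    F = [[0.0] * d for _ in range(d)]
    n = len(data)
    for x, y in data:
        s = sigmoid(sum(t * xi for t, xi in zip(theta, x)))
        # Empirical Fisher: gradient of log p(y|x, theta) = (y - s) x at the observed target y
        outer_add(F_emp, [(y - s) * xi for xi in x], 1.0 / n)
        # True Fisher: the same outer product averaged over y ~ p(y|x, theta)
        for y_model, prob in ((1, s), (0, 1 - s)):
            outer_add(F, [(y_model - s) * xi for xi in x], prob / n)
    return F_emp, F

theta = [0.1, 0.3]
data = [([1.0, 0.5], 1), ([1.0, -1.5], 0), ([1.0, 2.0], 1)]
F_emp, F = fishers(theta, data)
```

For this model the true Fisher simplifies to the average of $s(1-s)xx^\top$, whereas the empirical Fisher weights each case by $(y - s)^2$, so the two generally differ.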
11.1 Comparisons to the standard Fisher
Like the Fisher $F$, the empirical Fisher $\bar{F}$ is PSD. But unlike $F$, it is essentially free to compute, provided that one is already computing the gradient of $h$. It can also be applied to objective functions which might not involve a probabilistic model in any obvious way.
Compared to $F$, which is of rank $\le |S|m$, $\bar{F}$ has a rank of $\le |S|$, which can make it easier to work with in practice. For example, the problem of computing the diagonal (or various blocks) is easier for the empirical Fisher than it is for higher rank matrices like the standard Fisher (Martens et al., 2012). This has motivated its use in optimization methods such as TONGA (Le Roux et al., 2008), and as the diagonal preconditioner of choice in the Hessian-free optimization method (Martens, 2010). Interestingly, however, there are stochastic estimation methods (Chapelle and Erhan, 2011; Martens et al., 2012) which can be used to efficiently estimate the diagonal (or various blocks) of the standard Fisher $F$, and these work quite well in practice.
Despite the various practical advantages of using $\bar{F}$, there are good reasons to use the true Fisher $F$ instead of $\bar{F}$ whenever possible. In addition to Amari's extensive theory developed for the exact natural gradient (which uses $F$), perhaps the best reason for using $F$ over $\bar{F}$ is that $F$ turns out to be a reasonable approximation to the Hessian of $h$ in certain important special cases, which is a property that $\bar{F}$ lacks in general.
For example, as discussed in Section 5, when the loss is given by $-\log r(y|z)$ (as in Section 4), $F$ can be seen as an approximation of $H$, because both matrices have the interpretation of being the expected Hessian of the loss under some distribution. Due to the similarity of the expression for $F$ in eqn. 3 and the one above for $\bar{F}$, it might be tempting to think that $\bar{F}$ is given by the expected Hessian of the loss under $\hat{Q}_{x,y}$ (which is actually the formula for $H$) in the same way that $F$ is given by eqn. 4. However, this is not the case in general.
And as we saw in Section 9, given certain assumptions about how the GGN is computed, and some additional assumptions about the form of the loss function $L$, $F$ turns out to be equivalent to the GGN. This is very useful since the GGN can be used to define a local quadratic approximation of $h$, whereas $\bar{F}$ normally doesn't have such an interpretation. Moreover, Schraudolph (2002) and later Martens (2010) compared $\bar{F}$ to the GGN and observed that the latter performed much better as a curvature matrix within various neural network optimization methods.
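As concrete evidence for why the empirical Fisher is, at best, a questionable choice for the curvature matrix, consider the following example. We will set $m = 1$, $f(x,\theta) = \theta$, $R_{y|z} = N(z, 1)$, and $S = \{(0, 0)\}$, so that $h$ is a simple convex quadratic function of $\theta$ given by $h(\theta) = \frac{1}{2}\theta^2$. In this example we have that $\nabla h = \theta$ and $\bar{F} = \theta^2$, while $F = 1$. If we use $\bar{F}^{\xi}$ as our curvature matrix for some exponent $\frac{1}{2} \le \xi \le 1$, then it is easy to see that an iteration of the form
$$\theta_{k+1} = \theta_k - \alpha_k (\theta_k^2)^{-\xi}\theta_k$$
will fail to converge to the minimizer ($\theta = 0$) unless $\xi < 1$ and the step-size $\alpha_k$ goes to $0$ sufficiently fast. And even when it does converge, it will only be at a rate comparable to the speed at which $\alpha_k$ goes to $0$, which in typical situations will be either $O(1/k)$ or $O(1/\sqrt{k})$. Meanwhile, a similar iteration of the form
$$\theta_{k+1} = \theta_k - \alpha_k F^{-1}\nabla h = (1 - \alpha_k)\theta_k$$
which uses the exact Fisher $F$ as the curvature matrix, will experience very fast linear convergence,[9] with rate $|1-\alpha|$, for any fixed step-size $\alpha_k = \alpha$ satisfying $0 < \alpha < 2$.
[9] Here we mean “linear” in the classical sense that $|\theta_k - 0| \le |1-\alpha|^k |\theta_0 - 0|$, and not in the sense that $|\theta_k - 0| \le c/k$.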
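The toy example can be reproduced numerically. The sketch below assumes the setup just described ($h(\theta) = \theta^2/2$, $F = 1$, $\bar{F} = \theta^2$) and picks the arbitrary values $\theta_0 = 1.3$, a fixed step-size $\alpha = 0.5$, and $\xi = 1/2$ for the empirical-Fisher variant; the function names are hypothetical.

```python
def iterate(theta, alpha, steps, curvature):
    """Run theta <- theta - alpha * curvature(theta)^{-1} * grad on h(theta) = theta^2 / 2."""
    for _ in range(steps):
        grad = theta  # gradient of h(theta) = theta^2 / 2
        theta = theta - alpha * grad / curvature(theta)
    return theta

def exact_fisher(theta):
    return 1.0  # F = 1 for R_{y|z} = N(z, 1)

def empirical_fisher_power(xi):
    # (F_bar)^xi with F_bar = theta^2; undefined if an iterate lands exactly on 0
    return lambda theta: (theta * theta) ** xi

theta_exact = iterate(1.3, 0.5, 30, exact_fisher)
theta_emp = iterate(1.3, 0.5, 30, empirical_fisher_power(0.5))
```

With $\xi = 1/2$ the empirical-Fisher update reduces to $\theta - \alpha\,\mathrm{sign}(\theta)$, so with a fixed $\alpha$ the iterates end up bouncing between about $0.3$ and $-0.2$ instead of approaching $0$, while the exact-Fisher iterates shrink by the factor $1 - \alpha = 0.5$ at every step.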
It is important to note that this example uses a noise-free version of the gradient, and that this kind of linear convergence is (provably) impossible in most realistic stochastic/online settings. Nevertheless, we would argue that a highly desirable property of any stochastic optimization method should be that it can, in principle, revert to an optimal (or nearly optimal) behavior in the deterministic setting. This might matter a lot in practice, since the gradient may end up being sufficiently well estimated in earlier stages of optimization from only a small amount of data (which is a common occurrence in our experience), or in later stages provided that larger mini-batches or other variance-reducing procedures are employed (e.g. Le Roux et al., 2012; Johnson and Zhang, 2013). More concretely, the pre-asymptotic convergence rate of stochastic 2nd-order optimizers can still strongly depend on the choice of the curvature matrix, as we will show in Section 12.
11.2 Recent diagonal methods based on the empirical Fisher
Recently, a spate of stochastic optimization methods have been proposed that are all based on diagonal approximations of the empirical Fisher $\bar{F}$. These include the diagonal version of AdaGrad (Duchi et al., 2011), RMSProp (Tieleman and Hinton, 2012), Adam (Ba and Kingma, 2015), etc. Such methods use iterations of the following form (possibly with some slight modifications):
$$\theta_{k+1} = \theta_k - \alpha_k (B_k + \lambda I)^{-\xi} g_k$$
where the curvature matrix $B_k$ is taken to be a diagonal matrix $\mathrm{diag}(u_k)$ with $u_k$ adapted to maintain some kind of estimate of the diagonal of $\bar{F}$ (possibly using information from previous iterates/mini-batches), $g_k$ is an estimate of $\nabla h$ produced from the current mini-batch,