Backward Feature Correction: How Deep Learning Performs Deep Learning

Abstract

How does a 110-layer ResNet learn a high-complexity classifier using relatively few training examples and short training time? We present a theory towards explaining this in terms of hierarchical learning. By hierarchical learning we mean that the learner learns to represent a complicated target function by decomposing it into a sequence of simpler functions, thereby reducing sample and time complexity. This paper formally analyzes how multi-layer neural networks can perform such hierarchical learning efficiently and automatically, simply by applying stochastic gradient descent (SGD) to the training objective.

On the conceptual side, we present, to the best of our knowledge, the first theoretical result indicating how very deep neural networks can still be sample and time efficient on certain hierarchical learning tasks for which no known non-hierarchical algorithm (such as kernel methods, linear regression over feature mappings, tensor decomposition, sparse coding, or their simple combinations) is efficient. We establish a new principle called “backward feature correction”, which we believe is the key to understanding hierarchical learning in multi-layer neural networks.

On the technical side, we show, for regression and even for binary classification, that for every input dimension $d$ there is a concept class consisting of degree-$2^{L}$ multi-variate polynomials such that, using $L$-layer neural networks as learners, SGD can learn any target function from this class in $\mathsf{poly}(d)$ time using $\mathsf{poly}(d)$ samples to any inverse-polynomial regression or classification error, by learning to represent it as a composition of $L$ layers of quadratic functions. In contrast, we present lower bounds stating that several non-hierarchical learners, including any kernel method and neural tangent kernels, must suffer super-polynomial sample or time complexity to learn functions in this concept class even to any non-trivial error.


Prelude.   Deep learning is sometimes also referred to as hierarchical learning.1 In practice, multi-layer neural networks, as the most representative hierarchical learning model, often outperform non-hierarchical ones such as kernel methods, SVM over feature mappings, etc. However, from a theory standpoint,

Are multi-layer neural networks actually performing deep learning?

With huge non-convexity obstacles arising from the structure of neural networks, it is perhaps not surprising that existing theoretical works, to the best of our knowledge, have only been able to demonstrate that multi-layer neural networks can efficiently perform tasks that are already known to be solvable by non-hierarchical (i.e., shallow) learning methods. This is especially true for all recent results based on neural tangent kernels [6, 40, 3, 5, 18, 8, 7, 62, 17, 15, 35, 25, 43, 28, 57, 13], which are kernel methods rather than hierarchical learning. This motivates our research into the hierarchical learning process of multi-layer neural networks.

1 Introduction

How does a 110-layer ResNet [30] learn a high-complexity classifier for an image data set using relatively few training examples? How can the 100-th layer of the network discover a sophisticated function of the input image efficiently by simply applying stochastic gradient descent (SGD) on the training objective? In this paper, we present a theoretical result towards explaining this efficient deep learning process of such multi-layer neural networks in terms of hierarchical learning.

The term hierarchical learning in supervised learning refers to an algorithm that learns the target function (e.g., the labeling function) using a composition of simpler functions. The algorithm first learns to represent the target function using simple functions of the input, and then aggregates these simple functions layer by layer to create more and more complicated functions to fit the target. Empirically, it has long been observed that, in many applications, hierarchical learning requires fewer training examples [11] compared to non-hierarchical methods that learn the target function in one shot.

Hierarchical learning is also everywhere around us. There is strong evidence that human brains perform learning in hierarchically organized circuits, which is key to our ability to learn new concept classes from relatively few examples [24]. Moreover, it is also observed that human decision making follows a hierarchical process, from “meta goals” to specific actions, which was the motivation for hierarchical reinforcement learning [38]. In machine learning, hierarchical learning is also the key to the success of many models, such as hierarchical linear models [55], Bayesian networks/probabilistic directed acyclic graphical models [20], hierarchical Dirichlet allocation [52], hierarchical clustering [47], and deep belief networks [39].

Hierarchical learning and multi-layer neural networks.   Perhaps the most representative example of hierarchical learning is the neural network. A multi-layer neural network is defined via a sequence of layers, where each layer represents a simple function (a linear map followed by an activation) of the previous layers. Thus, a multi-layer neural network defines a natural hierarchy: during learning, each network layer can use simple compositions of the learnt functions from previous layers to eventually represent the target function. Empirically, neural networks have shown great success across many different domains [37, 30, 26, 49]. Moreover, it is also well known [59] that in learning tasks such as image recognition, each layer of the network indeed uses compositions of previous layers to learn functions of increasing complexity.

In learning theory, however, little is known about hierarchical learning, especially for neural networks. Known results mostly focus on representation power: there are functions that can be represented using 3-layer networks (under certain distributions) but require an exponentially larger size to represent using 2-layer networks [19]. However, the constructed function and data distribution in [19] separating the power of 2- and 3-layer networks are quite contrived.2 When connecting this back to the actual learning process, to the best of our knowledge, there is no theoretical guarantee that training a 3-layer network from scratch (e.g., via SGD from random initialization) can actually learn this separating function efficiently.3 Hence, while the hierarchical structure of a 3-layer network gives it more representation power than 2-layer ones, can the actual learning algorithm exploit this “power of hierarchy” efficiently? In other words, from a theory standpoint, these representation results cannot answer the following question:

How can multi-layer neural networks perform efficient hierarchical learning when trained by SGD?

Before addressing the “how” in this question, let us quickly mention that, to this date and to the best of our knowledge, it remains unclear in theory even whether, for every $L$, some $L$-layer neural network trained via SGD “can” actually use its hierarchical structure as an advantage to learn a function class efficiently that is otherwise not efficiently learnable by non-hierarchical models. In fact, due to the extreme non-convexity in a multi-layer network, for theoretical purposes the hierarchical structure is typically even a disadvantage for efficient training. One example of such a “disadvantage” is the deep linear network (DLN), whose hierarchical structure gives it no advantage over linear functions in representation power, yet becomes an obstacle for training.4

In other words, not only is “How can multi-layer neural networks perform efficient hierarchical learning?” unanswered in theory, but even the significantly simpler question “Can multi-layer neural networks efficiently learn simple functions that are already learnable by non-hierarchical models?” is non-trivial, due to the extreme non-convexity caused by the hierarchical structure of multi-layer networks. For this reason, it is not surprising that most existing theoretical works on the efficient learning regime of neural networks focus either on (1) two-layer networks [36, 51, 56, 22, 50, 53, 12, 61, 41, 10, 42, 54, 23, 9, 46, 58], which do not have any hierarchical structure, or (2) multi-layer networks in which essentially only the last layer is trained [14, 33], or (3) reducing a multi-layer hierarchical neural network to non-hierarchical models such as kernel methods (a.k.a. the neural tangent kernel approach) [6, 40, 3, 5, 18, 8, 7, 62, 17, 15, 35, 25, 43, 28, 57, 13].

While the cited theoretical works shed great light on understanding the learning process of neural networks, none of them treat neural networks as a hierarchical learning model. Thus, we believe they are insufficient for understanding the ultimate power of neural networks. Motivated by such insufficiency, we propose to study the following fundamental question regarding the hierarchical learning in neural networks:

Question.

For every $L$, can we prove that $L$-layer neural networks can efficiently learn a concept class that is not learnable by any shallower network of the same type (i.e., with the same activation function), and, more importantly, not learnable by non-hierarchical methods such as kernel methods (including neural tangent kernels defined by randomly initialized neural nets) or linear regression over feature mappings, given the same amount of sample and time complexity?

We say that two neural networks are of the same type if they are equipped with the same activation function. A positive answer to the first part of the question indicates that going deeper in the network hierarchy indeed allows a larger class of functions to be learnt efficiently. A positive answer to the second part indicates that the hierarchical structure of the network is indeed used as an advantage compared to non-hierarchical learning methods, hence the neural network is indeed performing hierarchical learning.

In this paper we give the first theoretical result towards answering this question: for every $L$, there is a certain type of $L$-layer neural network equipped with quadratic activation functions such that training such networks by SGD indeed efficiently and hierarchically learns a concept class. Here, “efficient” means that to achieve any inverse-polynomial generalization error, the number of training examples required to train the network is polynomial in the input dimension, and the total running time is also polynomial. Moreover, we give lower bounds showing that this concept class is not efficiently learnable by non-hierarchical learning methods: any kernel method (in particular, including the neural tangent kernel given by the initialization of the learner network), any linear regression over feature mappings, and even two-layer networks with certain polynomial activation functions all require super-polynomial sample or time complexity.

Figure 1: Convolutional features of the first layer in AlexNet. In the first 80 epochs, we train only the first layer, freezing layers 2 through 5; in the next 120 epochs, we train all the layers together (starting from the weights at epoch 80). Details in Appendix J.
Observation: In the first 80 epochs, when the first layer is trained until convergence, its features can already capture certain meaningful signals, but they do not improve further. As soon as the 2nd through 5th layers are added to the training parameters, the features of the first layer improve again.

How can deep learning perform hierarchical learning?   Our paper not only proves such a separation, but also gives, to the best of our knowledge, the first result showing how deep learning can actually perform hierarchical learning when trained by SGD. We identify two critical steps in the hierarchical learning process (see also Figure 1):

  • The forward feature learning step, where a higher-level layer can learn its features using the simple combinations of the learnt features from lower-level layers.

  • The backward feature correction step, where a lower-level layer can also learn to further improve the quality of its features using the learnt features in higher-level layers.

While “forward feature learning” is standard in theory, to the best of our knowledge, “backward feature correction” has not been recorded anywhere in the theory literature. As we demonstrate both in theory and in experiments (see Figure 1), this is the most critical step in the hierarchical learning process of multi-layer neural networks, and we view it as the main conceptual contribution of this paper.

1.1 Our Theorem

Let us now set up the notation. The type of network we consider is the DenseNet [32]:

Here, $\sigma$ is the activation function, which we pick to be the quadratic $\sigma(z)=z^{2}$ in this paper, the $W_{\ell,j}$ are weight matrices, and the final (1-dimensional) output is a weighted summation of the outputs of all layers. The index sets $J_\ell$ define the connection graph (the structure of the network). For a vanilla feed-forward network, $J_\ell=\{\ell-1\}$, so each layer only uses information from the immediately previous layer. ResNet [30] (with skip connections) corresponds to also connecting each layer to the one two levels below, with weight sharing. In this paper, we can handle any connection graph with the only restriction being that there is at least one “skip link”: for every $\ell$ we require $\ell-1\in J_\ell$, but for some $\ell$ the set $J_\ell$ also contains an index $j\le\ell-2$.
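
For concreteness, one natural way to write such a quadratic DenseNet recursion (a sketch consistent with the verbal description above; the paper's exact display, including scaling factors and the output weights $a_\ell$, may differ) is
$$G_0(x)=x,\qquad G_\ell(x)=\sigma\Big(\textstyle\sum_{j\in J_\ell}W_{\ell,j}\,G_j(x)\Big)\ \text{ with }\ \sigma(z)=z^{2}\text{ applied coordinate-wise},\qquad F(x)=\textstyle\sum_{\ell=1}^{L}a_\ell^{\top}G_\ell(x).$$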

One of the main reasons we pick the quadratic activation is that it makes it easy to measure the network’s representation power. Clearly, in a quadratic DenseNet, each layer learns a quadratic function of the (weighted) summation of previous layers, so at layer $\ell$ the hidden neurons represent a degree-$2^{\ell}$ multivariate polynomial of the input $x$. Hence, the concept class that can be represented by an $L$-layer quadratic DenseNet is obviously increasing with $L$. The question that remains is: can an $L$-layer quadratic DenseNet use its hierarchical structure as an advantage to learn a certain class of degree-$2^{L}$ polynomials more efficiently than non-hierarchical models?
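
To make the degree count explicit, here is the one-line induction behind this claim (a sketch, using the recursive form written above, with $\deg G_0=\deg x=1$):
$$\deg G_\ell \;=\; 2\cdot\max_{j\in J_\ell}\deg G_j \;\le\; 2\cdot 2^{\ell-1} \;=\; 2^{\ell}\qquad\text{for every }\ell\ge 1,$$
with equality attained through the chain $\ell-1\in J_\ell$.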

We answer this positively. Our main result can be sketched as follows:

Theorem (informal).

For every input dimension $d$ and every $L$, there is a class of degree-$2^{L}$ polynomials and a family of input distributions such that:

  • Given $\mathsf{poly}(d)$ training samples and running time, by performing SGD over the training objective starting from random initialization, an $L$-layer quadratic DenseNet can learn any function in this concept class to any inverse-polynomial generalization error.

  • Any kernel method, any linear regression over feature mappings, and any two-layer neural network equipped with arbitrary polynomial activations requires super-polynomial sample or time complexity to achieve even a non-trivial generalization error.

The concept class (the class of functions to be learnt) considered in this paper is simply given by $L$-layer quadratic DenseNets with a prescribed number of neurons in each layer. Thus, each function in this concept class is equipped with a hierarchical structure defined by a DenseNet, and our positive result is “using a DenseNet to learn the hierarchical structure of an unknown DenseNet.”

We also point out that in our result, the learner network has quadratically more neurons per layer than the target network in the concept class. Thus, the learner network is over-parameterized, which is standard in deep learning. However, the necessity of over-parameterization here is for a very different reason compared to existing theory work [6]. We discuss this more in Section 6.1.

1.2 Our Conceptual Message: How Deep Learning Performs Hierarchical Learning

Intuitively, the polynomials in our concept class are of degree $2^{L}$ and can depend on many unknown directions of the input. Thus, when $2^{L}$ is large, the typical sample or time complexity of non-hierarchical learning methods is at least $d^{\Omega(2^{L})}$ (and we prove lower bounds of this type for kernel methods and linear regression over feature mappings). Even if the learner performs hierarchical learning over only a few levels, it still cannot avoid learning, within a single level, a polynomial of high degree over many variables, which typically requires super-polynomial sample/time complexity.

In contrast, our quadratic DenseNet only uses $\mathsf{poly}(d)$ sample and time complexity. The efficiency gain comes from the fact that the learning process is highly hierarchical: the network first learns a crude degree-4 approximation of the target function, and then it identifies hidden features and uses them to fit a degree-8 approximation of the target (using a degree-2 polynomial over the hidden features). Continuing in this fashion, the learning task decomposes from learning a degree-$2^{L}$ polynomial in one shot, which requires super-polynomial time, into learning one quadratic function at a time for $L$ rounds, each of which can be done in $\mathsf{poly}(d)$ time. This is, at a high level, where the efficiency gain comes from, but there is more to say:
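
To make the counting concrete, here is a back-of-the-envelope comparison (a sketch; the precise exponents in our formal bounds are stated in the theorems). A non-hierarchical learner that writes down all monomials of degree up to $2^{L}$ handles a feature space of size
$$\binom{d+2^{L}}{2^{L}}\;\ge\;\Big(\tfrac{d}{2^{L}}\Big)^{2^{L}},$$
which is super-polynomial in $d$ as soon as $2^{L}$ is super-constant (and, say, $2^{L}\le d^{0.9}$). Hierarchical learning instead pays for $L$ quadratic fits, each over only $\mathsf{poly}(d)$ learnt features, for a total of $L\cdot\mathsf{poly}(d)$.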

Critical observation: backward feature correction.   In our quadratic DenseNet model, when training the lowest trainable layer of the learner network, that layer tries to fit the target function using the best low-degree polynomial. This polynomial might not be the one used by the target network, due to over-fitting to the higher-degree terms of the target. As a concrete example, the best degree-4 polynomial approximation of the whole target is usually not the degree-4 part of the target, even under the Gaussian distribution. However, through the hierarchical learning process, those higher-degree terms of the target will gradually be discovered by the higher levels of the learner network and “subtracted” from the training objective.

As a result, the features (i.e., intermediate outputs) of the lower levels of the learner network get improved because there is less over-fitting. We explain this phenomenon more carefully in Figure 2 for the case of a 4-layer network. It provides a theoretical explanation of how lower levels get improved through hierarchical learning when we train lower and higher levels together.
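
As a minimal illustration of why the “subtraction” matters, consider fitting the univariate toy target $g(x)=x^{2}+x^{4}$ with $x\sim\mathcal{N}(0,1)$ (a toy example for intuition only, not the construction analyzed in this paper). Writing $x^{4}$ in the probabilists’ Hermite basis,
$$x^{4}=\mathrm{He}_4(x)+6\,\mathrm{He}_2(x)+3,\qquad \mathrm{He}_2(x)=x^{2}-1,\quad \mathrm{He}_4(x)=x^{4}-6x^{2}+3,$$
the best degree-2 approximation of $g$ in $L^{2}(\mathcal{N}(0,1))$ is $7x^{2}-3$, not the “true” low-degree part $x^{2}$. Only after the degree-4 term has been learnt by a higher layer and subtracted from the objective does the lower layer’s best fit become $x^{2}$ itself: backward feature correction in miniature.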

Figure 2: We explain the hierarchical learning process in a 4-layer example. The black and blue arrows correspond to “forward feature learning”; the red arrows correspond to “backward feature correction”.

Hierarchical learning is NOT layer-wise training.   Our result also sheds light on the following critical observation in practice: layer-wise training (i.e., training the layers one by one, starting from the lower levels)5 typically performs much worse than training all the layers together. The fundamental reason is the missing piece of “backward feature correction”: the functions learnt by the lower levels are not accurate enough if we only train the lower levels; by training lower and higher levels together, the functions computed by the lower layers also get improved (see the sketch below).
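
The following PyTorch sketch contrasts the two training schedules on a synthetic two-hidden-layer quadratic target. Everything here, from the widths to the learning rates to the target itself, is an illustrative assumption rather than the setup analyzed in the paper, and the hyperparameters may need tuning.

```python
# Layer-wise vs. joint training of a toy two-hidden-layer quadratic network.
import torch

torch.manual_seed(0)
d, k, n = 20, 5, 4096                      # input dim, hidden width, #samples
X = torch.randn(n, d)

# Synthetic hierarchical target: layer 1 is quadratic in x, layer 2 is quadratic
# in the layer-1 features; the label sums contributions from both layers.
W1_star = torch.randn(k, d) / d ** 0.5
W2_star = torch.randn(k, k) / k ** 0.5
G1_star = (X @ W1_star.T) ** 2
G2_star = (G1_star @ W2_star.T) ** 2
y = G1_star.sum(dim=1) + 0.1 * G2_star.sum(dim=1)

def make_learner():
    W1 = torch.nn.Parameter(torch.randn(2 * k, d) / d ** 0.5)
    W2 = torch.nn.Parameter(torch.randn(2 * k, 2 * k) / (2 * k) ** 0.5)
    return W1, W2

def forward(W1, W2, X):
    h1 = (X @ W1.T) ** 2                   # layer-1 features (degree 2 in x)
    h2 = (h1 @ W2.T) ** 2                  # layer-2 features (degree 4 in x)
    return h1.sum(dim=1) + 0.1 * h2.sum(dim=1)

def train(W1, W2, steps, lr):
    params = [p for p in (W1, W2) if p.requires_grad]
    opt = torch.optim.SGD(params, lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((forward(W1, W2, X) - y) ** 2).mean()
        loss.backward()
        opt.step()
    return loss.item()

# (a) Layer-wise: train layer 1 alone, then freeze it and train layer 2 alone.
W1, W2 = make_learner()
W2.requires_grad_(False)
print("layer-wise, stage 1:", train(W1, W2, 2000, 1e-3))
W1.requires_grad_(False); W2.requires_grad_(True)
print("layer-wise, stage 2:", train(W1, W2, 2000, 1e-3))

# (b) Joint: train both layers together, so the layer-1 features can keep being
#     corrected while layer 2 learns ("backward feature correction").
W1, W2 = make_learner()
print("joint training:     ", train(W1, W2, 4000, 1e-3))
```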

Hierarchical learning is NOT simulating known (non-hierarchical) algorithms.   To the best of our knowledge, this seems to be the first theory result in the literature in which training a neural network via SGD provably solves an underlying problem not yet known to be solvable by existing algorithms, such as kernel methods (including applying kernel methods multiple times), tensor decomposition methods, etc. Thus, neural network training could indeed be performing hierarchical learning, instead of simulating known (non-hierarchical) algorithms.

1.3 More on Related Works

Learning Two-Layer Networks [36, 51, 56, 22, 50, 53, 12, 61, 41, 10, 42, 54, 23, 9, 46, 58].   There is a rich history of works on the learnability of neural networks trained by SGD. However, as we mentioned before, many of these works only consider networks with two layers, or settings in which only one layer of the network is trained. Hence, the learning process is not hierarchical.

Neural Tangent Kernel [6, 40, 3, 5, 18, 8, 7, 62, 17, 15, 35, 25, 43, 28, 57, 13].   There is a rich literature approximating the learning process of over-parameterized networks using the neural tangent kernel (NTK) approach, where the kernel is defined by the gradient of a neural network at random initialization [35]. We stress that one should not confuse this hierarchically defined kernel with the hierarchical learning performed by a multi-layer network studied in this paper. As we pointed out, hierarchical learning means that each layer learns a combination of previous layers. In NTK, such combinations are prescribed by the random initialization of the neural network and are not learnt during the training process. As our negative result shows, for certain learning tasks, hierarchical learning is indeed superior to any kernel method, including those hierarchically defined kernels prescribed by any neural network. Hence, in this task, the learnt combinations are indeed superior to the randomly prescribed ones given by the initialization of the network.

Three-Layer Result [4].   This paper shows that 3-layer neural networks can learn the so-called “second-order NTK,” which is not a linear model; however, the second-order NTK is also learnable via nuclear-norm-constrained linear regression over the feature mappings defined by the initialization of a neural network. Thus, the underlying learning process is still not truly hierarchical.

Three-Layer ResNet Result [2].   This recent paper shows that 3-layer ResNet can perform some weaker form of implicit hierarchical learning, with better sample or time complexity than any kernel method or linear regression over feature mappings. Our result is greatly inspired by [2], but with several major differences.

First and foremost, the result of [2] can also be achieved by non-hierarchical methods, such as simply applying the kernel method twice, which is essentially layer-wise learning without backward feature correction.6 Thus, the work [2] establishes a weaker form of hierarchical learning, without backward feature correction.

Second, we prove in this paper a “poly vs. super-poly” running time separation, which is what one refers to as “efficient vs non-efficient” in the traditional theoretical computer science language. In contrast, the result in [2] is regarding “poly vs. bigger poly” (in the standard regime where the output dimension is constant).7

Third, without backward feature correction, the error incurred from the lower layers in [2] cannot be improved through training (see Footnote 6), and thus their theoretical result does not lead to arbitrarily small generalization error as we achieve in this paper. This also prevents [2] from going beyond three layers; our result holds for every $L$, demonstrating that going deeper in the hierarchy can offer a consistent advantage.

2 Target Network and Learner Network

We consider a target network defined as

where the weight matrices of the target network are defined for every layer $\ell$ and every $j\in J_\ell$. Each index set $J_\ell$ is a subset of $\{0,1,\dots,\ell-1\}$. We assume that (1) $\ell-1\in J_\ell$ (so there is a connection to the immediately previous layer), and (2) for every $\ell$, $J_\ell$ contains some index $j\le\ell-2$ (so there is at least one skip connection).

We consider target functions consisting of the coordinate summation of each layer:

where the $\alpha_\ell$’s are scalar coefficients,8 appropriately normalized; we provide more explanation of these coefficients in Section 3. For analysis purposes, we adopt a convention for the degenerate base case of the recursion, and define

We remark that the layer-$\ell$ features of the target network are polynomials of degree $2^{\ell}$ in the input $x$. It is convenient to think of them as the “features” used by the target network.

2.1 Learner Network

Typically, in theory, when training a learner neural network, the objective is to construct a network of the same structure (possibly with over-parameterization) so that it simulates the target network:

Here, we choose the width of each learner layer to be quadratically larger than the corresponding width of the target network. In other words, the amount of over-parameterization is quadratic per layer. We want to construct the weight matrices so that the learner’s output is close to the target’s output.

Learner Network Re-parameterization.   In this paper, it is more convenient to consider a re-parameterized learner network: we first re-parameterize each weight matrix as a product of two factors, where

  • the first factors are randomly initialized for all layers and are not changed during training;

  • the second factors are trainable, one for every layer $\ell$ and every $j\in J_\ell$, with dimensions matching the corresponding layer widths.

Define the following layer functions,9 which it is convenient to think of as the “features” used by the learner network:

(2.1)
(2.2)

We define the learner’s final output analogously, as a weighted summation over the layers, and we shall use this function to fit the target.

It is easy to verify that, under suitable choices of the two factors, this re-parameterization represents the same class of functions as the original parameterization. In this paper, we will mostly work with this re-parameterization for efficient training purposes. As we shall see, we will impose regularizers during training to enforce that the two representations stay close to each other. The idea of using a larger unit for training and using a smaller unit to learn the larger one is called knowledge distillation, which is commonly used in practice [31].

Truncated Quadratic Activation.   To make our analysis simpler, it is easier to work with an activation function that has bounded derivatives on the entire space. For each layer $\ell$, we consider a “truncated but smoothed” version of the square activation function, defined as follows. For some sufficiently large threshold (to be chosen later), we let the truncated activation equal $z^{2}$ inside the threshold, and outside that range it can be chosen as any monotone increasing function whose first few derivatives are bounded. Our final choice of the threshold will ensure that, when taking expectation over the data, the difference between the truncated activation and the true square activation is negligible.

Accordingly, we define the network with respect to the truncated activation as follows.

The truncated activation is only used for training purposes, to ensure the network is Lipschitz so that we can obtain an efficient running time; the original quadratic activation does not have an absolute Lipschitz bound. We also use the truncated version in place of the original one when it is clear from context.
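
One concrete choice of such a truncated activation (an assumption for illustration; the paper only requires agreement with $z^{2}$ below the threshold and bounded derivatives beyond it) is the following:

```python
# One concrete (assumed) choice: quadratic inside the threshold B, linear beyond,
# so the function is C^1 and globally 2B-Lipschitz.
import numpy as np

def truncated_square(z, B=10.0):
    """Equals z**2 for |z| <= B; continues with slope 2B beyond the threshold."""
    z = np.asarray(z, dtype=float)
    return np.where(np.abs(z) <= B, z ** 2, 2.0 * B * np.abs(z) - B ** 2)

# Matches the square inside the threshold, grows only linearly outside:
print(truncated_square([0.5, 3.0, 100.0]))   # values: 0.25, 9.0, 1900.0
```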

For notational simplicity, we concatenate the weight matrices used in the same layer as follows:

2.2 Training Objective

For simplicity, we first state our result for the regression problem in the realizable case, where we simply want to minimize the difference between the output of the learner network and the labels; we state the result for the agnostic case and for classification in the next section.

Intuitively, we shall add a regularizer to ensure that the distilled (smaller) weights form a low-rank approximation of the trained (larger) weights, which keeps the two representations close. The main reason for this “low-rank approximation” is explained in Section 6. Furthermore, we shall add a regression loss to minimize the difference between the learner’s output and the labels, which ensures that the learner fits the target.

Specifically, we use the following training objective:

where the loss is the squared loss between the learner’s output and the label, and

and for a given training set consisting of i.i.d. samples from the true distribution, we minimize the empirical objective below (where the expectation is over a uniform sample from the training set):

(2.3)

The other regularizers we use are just (squared) Frobenius norms of the weight matrices, which are standard in practice. For the original quadratic-activation network (without truncation), we define the analogous objective in the same way.
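
For orientation, a generic shape of such an objective, consistent with the verbal description above, is sketched below; the exact loss, regularizers, and weights are the ones specified in (2.3) and Section A, and $\widetilde F$ denotes the truncated-activation learner output:
$$\widetilde{\mathrm{Obj}}\;=\;\operatorname*{\mathbb{E}}_{(x,y)\sim\text{training set}}\Big[\big(\widetilde F(x)-y\big)^{2}\Big]\;+\;\underbrace{\sum_{\ell}\lambda_{\ell}\,\mathcal R_{\ell}}_{\text{feature-consistency (low-rank) regularizers}}\;+\;\underbrace{\sum_{\ell}\lambda'_{\ell}\,\|W_{\ell}\|_{F}^{2}}_{\text{Frobenius-norm regularizers}}.$$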

3 Statement of Main Result

For simplicity, we state here only a special case of our main theorem, which is already sufficiently interesting; the full theorem can be found in Appendix A.

In this special case, there exist absolute constants such that, for every $L$, we consider any target network and underlying data distribution satisfying the properties defined in Section 5 (with suitable parameters). Suppose in addition that the network width is diminishing across layers and that there is an information gap between the coefficients $\alpha_\ell$ of consecutive layers. Moreover, we assume a restriction on the connection graph, meaning that the skip connections do not go very deep unless directly connected to the input.

Theorem 3.1 (special case).

For every input dimension $d$, every target error $\varepsilon>0$, and every target network satisfying the above parameters, the following holds. Given polynomially many i.i.d. samples from the data distribution, by applying SGD over the training objective (2.3), with probability at least 0.99, we can find a learner network in polynomial time such that:

Note that the information gap implies that, to achieve a sufficiently small regression error, the learning algorithm has to truly learn all the layers of the target network, as opposed to, for instance, ignoring the last layer, which would already incur a noticeably larger error. We give more details about the training algorithm in Section 4.

Comparing to Kernel Methods.   The target functions in our concept class are of degree $2^{L}$. We show as a lower bound in Appendix H.1 that any kernel method must suffer super-polynomial sample complexity once $2^{L}$ is super-constant, even when the other parameters are benign. This is because kernel methods cannot perform hierarchical learning, so they essentially have to “write down” all the monomials of degree up to $2^{L}$ in order to express the target function, which is very costly in sample complexity.

On the other hand, one might hope for a “sequential kernel” approach to learning this target function: first apply a kernel method to identify the low-degree polynomials used in the lower layers of the target (e.g., its quadratic features), and then use them as features to learn higher degrees. We point out:

  • Even if we know all the low-degree polynomials of the lower layers, the width of a higher layer can still be polynomially large, so we still need to learn a high-degree polynomial over a polynomially large number of variables. This cannot be done by a kernel method with polynomial sample complexity.

  • Even if we apply the “sequential kernel” for $L$ rounds, this is similar to layer-wise training and misses the crucial “backward feature correction.” As we pointed out in the introduction, and as we explain later in Section 6.1, this is unlikely to recover the target function to good accuracy.

  • Even if we combine the “sequential kernel” with “backward feature correction,” this may not work, since the backward correction may not lead to sufficient accuracy on the intermediate features. Concretely, suppose we optimistically know the lower-layer feature mappings up to some error and fit the target function by a kernel method on top of these features. This does not mean that the fitted function is close to the corresponding part of the target network (sophisticated reasons deferred to Section 6). Thus, we cannot improve the quality of the features of previous layers.10


Significance of Our Result? To the best of our knowledge,

  • We do not know any other simple algorithm that can learn the target functions considered in this paper with the same efficiency; the only simple learning algorithm we are aware of is training a neural network so that it performs hierarchical learning.

  • This seems to be the only theory result in the literature in which training a neural network via SGD solves an underlying problem that is not known to be solvable by existing algorithms, such as kernel methods (including applying kernel methods multiple times), tensor decomposition methods, sparse coding, etc. Thus, the neural network is indeed performing hierarchical learning, instead of simulating known (non-hierarchical) algorithms.

Agnostic Learning.   Our theorem also works in the agnostic setting, where the labeling function is only approximately realizable by some unknown target network, up to a certain error. In this case, the SGD algorithm can learn a function whose error is at most a constant factor larger than that of the best target network, given polynomially many i.i.d. samples. Thus, the learner can compete with the performance of the best target network. We present the result in Appendix A.4 and state its special case below.

Theorem 3.2 (special case, agnostic).

In the same setting as Theorem 3.1, for any constant approximation factor, given polynomially many i.i.d. samples and their corresponding labels, by applying SGD over the training objective, with probability at least 0.99, we can find a learner network in polynomial time such that:

3.1 On Information Gap and the Classification Problem

We have made a gap assumption on the coefficients $\alpha_\ell$. We can view this “gap” as saying that, in the target function, higher levels contribute less to the output. This is typical for tasks such as image classification on CIFAR-10, where the first convolutional layer can already be used to classify a large fraction of the data correctly. The higher-level layers have diminishing contributions to the signal (see Figure 3 for an illustration; we also refer to [60] for concrete measures). We emphasize that, in practice, researchers do fight for even the last bit of performance gain by going to (much) larger networks, so those higher-level functions cannot be ignored.

To formally justify this gap assumption, it is also beneficial to consider a classification problem. Let us w.l.o.g. rescale the target function appropriately, and consider a two-class labeling function given as:

where the additional term is a Gaussian random variable independent of the rest of the input. It can be viewed either as a coordinate of the input $x$ or, more generally, as a linear direction of the input; for notational simplicity, we focus on the former view.

Using probabilistic arguments, one can derive that, except for a small fraction of the inputs $x$, the label function is fully determined by the target function up to layer $\ell$.11 In other words, the contribution of layer $\ell$ is (approximately) the increment in classification accuracy when we use an $\ell$-layer network compared to an $(\ell-1)$-layer one.
Therefore, the gap assumption is equivalent to saying that harder data (which require deeper networks to learn correctly) are rarer in the training set, which can be very natural. For instance, around 70% of the images in the CIFAR-10 data set can be classified correctly by merely looking at their rough colors and patterns using a one-hidden-layer network. Only the final accuracy gains require much more refined reasoning, such as whether there is a beak on the animal’s face, which can only be detected using very deep networks. As another example, humans use many more training examples to learn counting than to learn basic calculus, and more for basic calculus than for advanced calculus.

We refer the reader to Figure 3, which shows that the increment in accuracy is indeed diminishing as we go deeper in neural networks.

Figure 3: Performance of ResNet on CIFAR-10 dataset with various depths. One can confirm that deeper layers have diminishing contributions to the classification error. Experiment details in Appendix J.

In this classification regime, our Theorem 3.1 still applies, as follows. Recall the cross-entropy (i.e., logistic) loss function, which takes as arguments the label and the prediction. In this regime, we can choose a training loss function

where a scaling parameter is included for proper normalization, and the training objective is

(3.2)

We have the following corollary of Theorem 3.1:

Theorem 3.3.

In the same setting as Theorem 3.1, and under an additional mild assumption on the parameters, given polynomially many i.i.d. samples and their corresponding labels, by applying SGD over the training objective (3.2), with probability at least 0.99, we can find a learner network in polynomial time such that:

Intuitively, Theorem 3.3 is possible because, under our choice of the scaling parameter, and up to small multiplicative factors, “the $\ell_2$-loss is small” becomes nearly identical to “the cross-entropy loss is correspondingly small.” This is also why we need to add a scaling factor in front of the regularizers in (3.2). We make this more rigorous in Appendix G (see Proposition G.43).
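
A back-of-the-envelope way to see the connection (a sketch, not the argument in Appendix G): for the logistic loss with label $y\in\{\pm 1\}$ and a prediction $f$ scaled to be small,
$$\log\big(1+e^{-yf}\big)\;=\;\log 2\;-\;\tfrac{1}{2}\,yf\;+\;\tfrac{1}{8}\,f^{2}\;+\;O(f^{4}),$$
so, near the origin, the cross-entropy objective is governed by the same low-order terms as a squared-loss objective; the scaling parameter controls how accurate this local approximation needs to be.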

4 Training Algorithm

We describe our algorithm in Algorithm 1.12 It is almost the vanilla SGD algorithm: in each iteration, it gets a random sample, computes the (stochastic) gradient of the objective at that sample, and moves in the negative gradient direction with a fixed step length.

Besides standard operations such as setting learning rates and regularizer weights, our only difference from vanilla SGD is to invoke (at most once per layer) a k-SVD decomposition to obtain a warm start for each trainable matrix when it first becomes available. This warm start is mainly for theoretical purposes, to avoid singularities, and it plays little role in actually learning the target. Essentially all of the learning is done by SGD.13

We emphasize once again that, when layer $\ell$ begins to train (by setting its step length to be nonzero), Algorithm 1 continues to train all the lower layers. This helps to “correct” the errors in those lower layers (recall “backward feature correction” and Figure 1). The algorithm does not work if one just trains layer $\ell$ and ignores the others.

We specify the choices of the thresholds and the regularizer weights in full in Section A. Below, we calculate their values in the special case of Theorem 3.1.

(4.1)

As for the network width, sample size, and SGD learning rate, in the special case of Theorem 3.1 one can set them to suitable polynomial values (spelled out in Section A).

0:  Data set of size N, network size, learning rate, target error.
1:  Initialize the current target error and the step length of every layer (only the lowest trainable layer starts with a nonzero step length).
2:  while the final target error has not been reached do
3:     while the current target error has not been reached do
4:        for each layer do
5:           if the layer has not started training and its starting condition holds then
6:              warm-start its weights and give it a nonzero step length.
7:           end if
8:           if the layer's regularizer-update condition holds then
9:              set the regularizer weights according to (4.1)
10:             and update the corresponding thresholds and step lengths.
11:          end if
12:          take one SGD step on the objective, for a random sample from the training set.
13:          add polynomially-small Gaussian noise to the iterate
14:        end for                 (the noise is only for theory purposes, to escape saddle points [21])
15:     end while
16:     decrease the current target error and update the step length of every layer.
17:  end while
18:  return the trained weight matrices, representing the learnt network.
Algorithm 1 SGD for DenseNet
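
A minimal Python sketch of this training schedule is given below. The function name, the warm-start, the thresholds, the noise scale, and the hyperparameters are illustrative placeholders rather than the quantities specified in (4.1); the point it illustrates is that when a deeper layer is enabled, all shallower layers keep training.

```python
# A sketch of the progressive training schedule (illustrative placeholders only;
# the real thresholds, warm-start, regularizer weights and noise scale are the
# ones specified in (4.1) and Section A).
import torch

def train_densenet(layers, data_loader, lr=1e-3, target_err=1e-2, stage_steps=1000):
    """`layers` is a list of torch.nn.Module blocks, applied in sequence here for
    simplicity (the paper's DenseNet also feeds earlier layers forward).  Deeper
    blocks are enabled one by one, but *all* already-enabled blocks keep training,
    which is what allows backward feature correction."""
    for layer in layers:                              # nothing trains until enabled
        for p in layer.parameters():
            p.requires_grad_(False)

    enabled = []
    for layer in layers:
        enabled.append(layer)
        for l in enabled:                             # (re-)enable ALL lower layers too
            for p in l.parameters():
                p.requires_grad_(True)
        opt = torch.optim.SGD([p for l in enabled for p in l.parameters()], lr=lr)

        for _, (x, y) in zip(range(stage_steps), data_loader):
            opt.zero_grad()
            out, feats = x, []
            for l in enabled:                         # forward pass through enabled blocks
                out = l(out)
                feats.append(out)
            pred = sum(f.sum(dim=1) for f in feats)   # output sums contributions of all layers
            loss = ((pred - y) ** 2).mean()
            loss.backward()
            with torch.no_grad():                     # polynomially-small noise: a theory
                for p in opt.param_groups[0]["params"]:   # device to escape saddle points
                    if p.grad is not None:
                        p.grad += 1e-6 * torch.randn_like(p.grad)
            opt.step()
            if loss.item() < target_err:
                break
    return layers

# Example call with toy affine blocks standing in for the quadratic layers:
# train_densenet([torch.nn.Linear(20, 10), torch.nn.Linear(10, 5)],
#                [(torch.randn(32, 20), torch.randn(32))] * 100)
```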

5 Assumptions on Target Network and Distribution

Target Network.   We assume the target network satisfies the following properties

  1. (monotone) the widths of the target network are monotonically non-increasing with depth.

  2. (normalized) the weight matrices and layer outputs are normalized to a fixed scale for all layers.

  3. (well-conditioned) the singular values of the weight matrices are bounded between two constants for all layer pairs.

These properties are standard and are satisfied by many practical networks (in fact, many practical networks have weight matrices close to unitary; see e.g. [34]).

For the well-conditionedness property, although there exist worst-case matrices violating it, we point out that when each weight matrix is built from random row/column orthonormal matrices, the property holds with high probability as long as the widths are not too small.14 Another view is that practical networks are all equipped with batch normalization, which keeps the layer outputs normalized.

Input Distribution.   We assume the input distribution has the following property:

  1. (isotropic). There is an absolute constant such that, for every direction of the input, we have the following moment bound:

    (5.1)
  2. (degree-preserving). For every positive integer $q$, there exists a positive value such that the following holds: for every polynomial over the input with maximum degree $q$, letting its top-degree part be the polynomial consisting of all the monomials of degree exactly $q$ in it, we have

    (5.2)

    For the standard Gaussian distribution, such an inequality holds with an absolute constant (this can be easily proved using a Hermite polynomial expansion). 15

  3. (hyper-contractivity). There exists an absolute constant such that, for some value, the following holds for every low-degree polynomial:

    (5.3)

    If the input distribution is standard Gaussian, an explicit bound holds (see Lemma 2). This implies that, for some value, we also have the following for every low-degree polynomial and every integer moment:

    (5.4)

    For standard Gaussian inputs these values can again be taken to be explicit constants depending only on the degree; more generally, an analogous bound holds.

Assumptions 1 and 3 are very common assumptions on distributions, and they are satisfied by sub-Gaussian distributions and even by certain heavy-tailed distributions. Assumption 2 says that the data has a certain amount of variance along every “high-degree direction”, which is also typical for distributions such as the Gaussian or heavy-tailed distributions.
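
For concreteness, the Gaussian case of the hyper-contractivity assumption follows from the classical Gaussian hypercontractivity inequality (a known fact recalled here for orientation; the constants used in our formal statements may differ): for every polynomial $P$ of degree at most $q$ and $x\sim\mathcal N(0,I_d)$,
$$\operatorname{\mathbb{E}}\big[P(x)^{4}\big]\;\le\;3^{2q}\Big(\operatorname{\mathbb{E}}\big[P(x)^{2}\big]\Big)^{2},\qquad\text{and more generally}\qquad\operatorname{\mathbb{E}}\big[P(x)^{2t}\big]\;\le\;(2t-1)^{qt}\Big(\operatorname{\mathbb{E}}\big[P(x)^{2}\big]\Big)^{t}\ \text{ for every integer }t\ge 1.$$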

We would like to point out that a distribution satisfying Assumption 2 can be a mixture of distributions none of which individually satisfies Assumption 2. For example, the distribution can be a mixture of several distributions, where in the $i$-th component one coordinate is degenerate and the other coordinates are i.i.d. standard Gaussian. Then none of the individual components is degree-preserving; however, their mixture is, as long as the number of components is large enough relative to the degree.

It is easy to check that simple distributions satisfy our assumptions with the following parameters.

Proposition 5.4.

Our distributional assumptions are satisfied when the input is a linear transformation of a standard Gaussian by a matrix with constant singular values; they are also satisfied for a mixture of arbitrarily many such distributions, as long as each transformation matrix has constant singular values and, for each coordinate, the corresponding row has the same norm across all components of the mixture.

In the special case of the main theorem stated in Theorem 3.1, we work with the above parameters. In our full Theorem A.6, we make the dependency on those parameters transparent.

6 Proof Intuitions

In this high-level intuition, let us first ignore the difference between the truncated activations and the true quadratic activation; we shall explain at the end why we need the truncation.

6.1 A Thought Experiment

We provide intuition for the proof by first considering an extremely simplified example with only two trainable layers which, for notational convenience, we refer to as the second and third layers.

Richer representation by over-parameterization.   One would hope for the second layer of the network to learn the corresponding features of the target (by some representation of its neurons) and feed them as input to the third layer. If so, the third layer could learn a quadratic function over