When Recurrent Models Don’t Need To Be Recurrent
We prove stable recurrent neural networks are well approximated by feed-forward networks for the purpose of both inference and training by gradient descent. Our result applies to a broad range of non-linear recurrent neural networks under a natural stability condition, which we observe is also necessary. Complementing our theoretical findings, we verify the conclusions of our theory on both real and synthetic tasks. Furthermore, we demonstrate recurrent models satisfying the stability assumption of our theory can have excellent performance on real sequence learning tasks.
John Miller, University of California, Berkeley
Moritz Hardt, University of California, Berkeley
Preprint. Work in progress.
Recurrent neural networks are a popular modeling choice for solving sequence learning problems arising in domains such as speech recognition and natural language processing. At their core, recurrent neural networks are non-linear dynamical systems commonly trained to fit sequence data via some variant of gradient descent.
Recurrent models feature flexibility and expressivity that come at a cost. Empirical experience shows these models are often more delicate to tune and more brittle to train pascanu2013difficulty () than standard feed-forward architectures. Recurrent architectures can also introduce significant computational burden compared with feed-forward implementations.
In response to these shortcomings, a growing line of empirical research succeeds in replacing recurrent models effectively by feed-forward models in important applications, including translation vaswani2017attention (); gehring2017convolutional (), speech synthesis van2016wavenet (), and language modeling dauphin2017language ().
This development raises an intriguing question for theoretical investigation:
Can well-behaved recurrent neural networks in principle always be replaced by feed-forward models of comparable size without loss in performance?
To answer this question, we need to understand what class of recurrent neural networks we ought to call well-behaved. In principle, it is easy to contrive non-linear recurrent models that on some input sequences cannot be approximated by feed-forward models. But would such recurrent models be trainable by gradient descent?
Characterizing exactly which recurrent models are learnable by gradient descent is a delicate task beyond the reach of current theory. We will therefore instead work with the fundamental control-theoretic notion of stability. This criterion roughly agrees with the requirement that the gradients of the training objective do not explode over time.
Loosely speaking, we prove stable recurrent models have good feed-forward approximations. Moreover, we prove that if gradient descent succeeds in training the recurrent model, it will also succeed in training the feed-forward model and vice versa. This shows that the models are equivalent not only for inference but also for training via gradient descent.
Of course, it is easy to violate the stability assumption; systems used in practice don’t necessarily satisfy stability. However, we experimentally show stability can be enforced without loss of performance on a benchmark sequence task. In other cases not captured by our theory, we show competitive sequence models exhibit the same qualitative phenomena (for instance, limited sensitivity to inputs in the distant past) that allow us to approximate stable recurrent models by feed-forward models.
We prove stable recurrent models do not need to be recurrent, and we give experimental evidence suggesting this conclusion extends to a broader class of commonly used models. Taken together, our theory and experiments suggest that recurrent models are not an inevitable choice in sequence learning.
In this work, we make the following contributions.
- We identify stability as a natural requirement for the analysis of recurrent models and show, under the stability assumption, feed-forward networks can approximate recurrent networks for both inference and training.
- We provide a unified analysis for general non-linear dynamical systems, and we complement this analysis with sufficient conditions that imply the assumptions of our theorems for several commonly used model classes, including long short-term memory (LSTM) networks.
- We empirically validate our results on synthetic data and show the same principles and phenomena underlying our theoretical analysis also appear in competitive models trained on a benchmark language modeling task.
2 Problem statement and results
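We consider general non-linear dynamical systems given by a differentiable state-transition map $\phi_w\colon \mathbb{R}^n \times \mathbb{R}^d \to \mathbb{R}^n$, parameterized by $w \in \mathbb{R}^m$. The hidden state $h_t \in \mathbb{R}^n$ evolves in discrete time steps according to the update rule
$$h_t = \phi_w(h_{t-1}, x_t). \qquad (1)$$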
Here, the vector $x_t \in \mathbb{R}^d$ is an arbitrary input provided to the system at time $t$. This formulation allows us to unify the analysis for several examples of interest, including linear dynamical systems, recurrent neural networks (RNN), and Long Short-Term Memory (LSTM) networks. For instance, in the RNN case, given $W \in \mathbb{R}^{n \times n}$, $U \in \mathbb{R}^{n \times d}$, and a pointwise non-linearity $\rho$, the system evolves according to
$$h_t = \rho(W h_{t-1} + U x_t).$$
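We assume the state-transition map $\phi_w$ is smooth in $w$ and $h$, and the initial state is $h_0 = 0$. Without loss of generality, we also assume $\phi_w(0, 0) = 0$ for all $w$. Otherwise, we can reparameterize $\phi_w(h, x) \mapsto \phi_w(h, x) - \phi_w(0, 0)$ without affecting the expressivity of $\phi_w$.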
Throughout this paper, we focus on stable recurrent models. This corresponds to assuming the state-transition map is contractive, so there exists some $\lambda < 1$ such that, for any weights $w$, states $h, h'$, and input $x$,
$$\|\phi_w(h, x) - \phi_w(h', x)\| \le \lambda \|h - h'\|.$$
In the RNN case, stability corresponds to requiring $\|W\| < 1/L_\rho$, where $L_\rho$ is the Lipschitz constant of $\rho$. More broadly, we assume there is some compact, convex domain $\Theta \subset \mathbb{R}^m$ so that the map $\phi_w$ is $\lambda$-contractive for all $w \in \Theta$.
We study when the system (1) can be approximated by a feed-forward model with finite context. While there are many choices for a feed-forward approximation, we consider the simplest one: truncation of the system to some finite context $k$. In other words, the feed-forward approximation moves over the input sequence with a sliding window of length $k$, producing an output every time the sliding window advances by one step. Formally, for context length $k$ chosen in advance, we define the truncated model via the update rule
$$h_t^k = \phi_w(h_{t-1}^k, x_t), \qquad h_{t-k}^k = 0.$$
Note that $h_t^k$ is a function only of the previous $k$ inputs $x_{t-k+1}, \dots, x_t$, and can be implemented as an autoregressive, depth-$k$ feed-forward model.
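The sliding-window construction can be made concrete with a small simulation. The sketch below assumes a tanh-RNN with hypothetical random weights (the helper names `rnn_step`, `full_state`, and `truncated_state` are ours) and compares $h_T$ with $h_T^k$. Because every row of $W$ has absolute sum at most $0.5$ and tanh is 1-Lipschitz, this particular map is $0.5$-contractive in the $\ell_\infty$ norm, so the gap should shrink at least as fast as $0.5^k$.

```python
import math
import random

def rnn_step(W, U, h, x):
    """One step of the tanh-RNN update h_t = tanh(W h_{t-1} + U x_t)."""
    n = len(h)
    return [math.tanh(sum(W[i][j] * h[j] for j in range(n))
                      + sum(U[i][j] * x[j] for j in range(len(x))))
            for i in range(n)]

def full_state(W, U, xs):
    """Run the recurrent model from h_0 = 0 over every input in xs."""
    h = [0.0] * len(W)
    for x in xs:
        h = rnn_step(W, U, h, x)
    return h

def truncated_state(W, U, xs, t, k):
    """Truncated model h_t^k: restart from zero and read only x_{t-k+1}, ..., x_t.

    xs is 0-indexed, so x_s is xs[s - 1] and the window is xs[t - k : t].
    """
    return full_state(W, U, xs[max(0, t - k):t])

def max_abs_diff(a, b):
    """l_inf distance between two states."""
    return max(abs(u - v) for u, v in zip(a, b))

random.seed(0)
n, d, T = 4, 3, 50
# Hypothetical weights: each entry of W is at most 0.5 / n in magnitude,
# so every row sum (the l_inf operator norm) is at most 0.5.
W = [[random.uniform(-1, 1) * 0.5 / n for _ in range(n)] for _ in range(n)]
U = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]

h_full = full_state(W, U, xs)
for k in (1, 5, 10, 20):
    print(k, max_abs_diff(h_full, truncated_state(W, U, xs, T, k)))
```

Since $|h_{T-k}|$ is bounded by 1 entrywise for a tanh network, contraction gives $\|h_T - h_T^k\|_\infty \le 0.5^k$ here.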
Let $f\colon \mathbb{R}^n \to \mathbb{R}^p$ denote a prediction function that maps a state $h_t$ to outputs $f(h_t) = y_t$. Let $y_t^k$ denote the predictions from the truncated model. To simplify the presentation, the prediction function $f$ is not parameterized. This is without loss of generality because it’s always possible to fold the parameters into the system $\phi_w$ itself. In the sequel, we study $\|y_t - y_t^k\|$ both during and after training.
2.1 Our results
Our first result concerns inference in stable recurrent models. For fixed weights $w$, the predictions of the truncated model closely approximate the predictions of the full recurrent model at test time.
Proposition (Informal version of Proposition 1).
Assuming the system $\phi$ is $\lambda$-contractive and under additional Lipschitz assumptions, we show that if $k \ge O(\log(1/\varepsilon))$, then the difference in predictions between the recurrent and truncated model is negligible, $\|y_t - y_t^k\| \le \varepsilon$.
Equipped with our approximation result, we turn towards optimization. We prove if it is possible to train a stable recurrent model via gradient descent to perform well on some task, then it is possible to get equally good performance by instead training an autoregressive feed-forward model.
Concretely, suppose both the full recurrent model and the truncated model are initialized at a common point and optimized to minimize some loss function on a common sequence of inputs. This results in a weight vector $w_{\text{recurr}}$ for the full recurrent model and a weight vector $w_{\text{trunc}}$ for the truncated model. We show that for truncation parameter $k = O(\log(N/\varepsilon))$, after $N$ steps of gradient descent, the weights of the recurrent and feed-forward model are $\varepsilon$-close in Euclidean distance.
Theorem (Informal version of Theorem 1).
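Assume the system $\phi$ is $\lambda$-contractive. Under additional smoothness and Lipschitz assumptions on the system $\phi$, the prediction function $f$, and the loss $p$, we show that if
$$k \ge O\big(\log(N/\varepsilon)\big),$$
then after $N$ steps of projected gradient descent with decaying step size $\alpha_t = O(1/t)$, $\|w_{\text{recurr}} - w_{\text{trunc}}\| \le \varepsilon$, which in turn implies
$$\|y_t(w_{\text{recurr}}) - y_t^k(w_{\text{trunc}})\| \le O(\varepsilon).$$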
In practice, the cost of training a fully recurrent model can be prohibitive, in which case truncation is commonly used for computational reasons. Our theorem gives reassurance that this truncation step does not hurt training performance. Contrast this with operations such as compression and weight sparsification of a neural net, which, when done after training, do not hurt inference but can certainly make optimization harder by reducing the number of trainable model parameters.
2.2 Related work
In the linear dynamical system setting, tu2017non () exploit the connection between stability and a truncated system approximation to prove bounds on the number of samples needed to learn a truncated approximation to the full stable system. Their approximation result is the same as our inference result in the linear dynamical system case, and we extend this result to the non-linear setting. We also analyze the impact of truncation on training with gradient descent. To our knowledge, results of this kind are new.
Learning dynamical systems with gradient descent has been a recent topic of interest in the machine learning community. For instance, hardt2016gradient () show gradient descent can efficiently learn stable, linear dynamical systems. In contrast, our analysis controls the difference between the truncated and full-system solutions obtained by gradient descent. Roughly speaking, these results can be combined with ours to show, when gradient descent succeeds for a class of stable dynamical systems, it succeeds for the truncated systems as well. Work by sedghi2016training () gives a moment-based approach for learning some classes of non-linear recurrent neural networks.
The vanishing gradient problem was first identified in bengio1994learning () and further explored in pascanu2013difficulty (). Our work is complementary to both of these papers; while they view the vanishing gradient problem primarily as an optimization issue to be overcome, we interpret vanishing gradients as a representational limitation that restricts the power of recurrent architectures. In particular, recurrent models with vanishing gradients can be well approximated by feed-forward models with limited context. Further, this result applies not just at inference time, but throughout training via gradient descent.
From an empirical perspective, bai2018empirical () conducted a detailed evaluation of recurrent and convolutional, feed-forward models on a variety of sequence modeling tasks. In diverse settings, they reliably find feed-forward models outperform their recurrent counterparts. However, their work does not offer a principled explanation for this phenomenon.
Our training time analysis builds on the stability analysis of gradient descent in hardt2016train (), but interestingly uses it for an entirely different purpose.
3 Approximation during inference
Suppose we train a full recurrent model and obtain a prediction $y_t$. For an appropriate choice of context $k$, the truncated model makes essentially the same prediction $y_t^k$ as the full recurrent model. To show this result, we first control the difference between the hidden states of both models.
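Lemma 1. Assume $\phi_w$ is $\lambda$-contractive and $L_x$-Lipschitz in $x$. Assume the input sequence satisfies $\|x_t\| \le B_x$ for all $t$. If $k \ge \log\!\left(\frac{L_x B_x}{(1-\lambda)\varepsilon}\right) / \log(1/\lambda)$, then the difference in hidden states satisfies $\|h_t - h_t^k\| \le \varepsilon$.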
Lemma 1 effectively says stable models do not have long-term memory: distant inputs do not change the states of the system. For this reason, it is a key building block for both our inference and our subsequent training-time analysis. A full proof is given in the appendix. If the prediction function $f$ is Lipschitz, Lemma 1 immediately implies that the predictions of the recurrent and truncated models are nearly identical. This leads us to the following proposition.
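Proposition 1. If $\phi_w$ is an $L_x$-Lipschitz and $\lambda$-contractive map, the inputs satisfy $\|x_t\| \le B_x$, $f$ is $L_f$-Lipschitz, and $k \ge \log\!\left(\frac{L_f L_x B_x}{(1-\lambda)\varepsilon}\right) / \log(1/\lambda)$, then $\|y_t - y_t^k\| \le \varepsilon$.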
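As a worked example of the truncation length in Lemma 1 and Proposition 1, the sketch below computes the smallest $k$ satisfying $L_f L_x B_x \lambda^k / (1-\lambda) \le \varepsilon$ and checks it on the scalar system $h_t = \lambda h_{t-1} + x_t$, which is $\lambda$-contractive and 1-Lipschitz in $x$. Driving it with the constant input $x_t = B_x$ is our choice of a worst case; the function names are hypothetical.

```python
import math

def truncation_length(lam, L_x, B_x, eps, L_f=1.0):
    """Smallest integer k with L_f * L_x * B_x * lam**k / (1 - lam) <= eps.

    With L_f = 1 this is the hidden-state bound of Lemma 1; with the Lipschitz
    constant L_f of the prediction function it is the bound of Proposition 1.
    """
    if not 0.0 < lam < 1.0:
        raise ValueError("need a lam-contractive system with 0 < lam < 1")
    ratio = L_f * L_x * B_x / ((1.0 - lam) * eps)
    return max(0, math.ceil(math.log(ratio) / math.log(1.0 / lam)))

def worst_case_gap(lam, B_x, T, k):
    """|h_T - h_T^k| for h_t = lam * h_{t-1} + x_t driven by x_t = B_x."""
    h_full, h_trunc = 0.0, 0.0
    for _ in range(T):
        h_full = lam * h_full + B_x
    for _ in range(min(k, T)):   # truncated model: restart at 0, last k inputs
        h_trunc = lam * h_trunc + B_x
    return abs(h_full - h_trunc)

k = truncation_length(lam=0.9, L_x=1.0, B_x=1.0, eps=1e-3)
print(k)                                              # 88
print(worst_case_gap(0.9, 1.0, T=500, k=k) <= 1e-3)   # True
```

The required context grows only logarithmically in $1/\varepsilon$, but the factor $1/\log(1/\lambda)$ makes it large when $\lambda$ is close to 1.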
4 Approximation during training
In this section, we show gradient descent for stable recurrent models finds essentially the same solutions as gradient descent for truncated models. Consequently, both the recurrent and truncated models found by gradient descent make essentially the same predictions.
Our proof technique is to initialize both the recurrent and truncated models at the same point and track the divergence in weights throughout the course of gradient descent. Roughly, we show if $k = O(\log(N/\varepsilon))$, then after $N$ steps of gradient descent, the difference in the weights between the recurrent and truncated models is at most $\varepsilon$.
Even if the gradients are similar for both models at the same point, it is a priori possible that slight differences in the gradients accumulate over time and lead to divergent weights where no meaningful comparison is possible. Building on similar techniques as hardt2016train (), we show that gradient descent itself is stable, and this type of divergence cannot occur.
The gradient descent result requires two essential lemmas. The first bounds the difference in gradient between the full and the truncated model. The second establishes the gradient map of both the full and truncated models is Lipschitz. We defer proofs of both lemmas to the appendix.
Let $p_T(w)$ denote the loss function evaluated on the recurrent model after $T$ time steps, and define $p_T^k(w)$ similarly for the truncated model.
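Lemma 2. Assume $p$ (and therefore $p^k$) is Lipschitz and smooth. Assume $\phi_w$ is smooth, $\lambda$-contractive, and Lipschitz in $x$ and $w$. Assume the inputs satisfy $\|x_t\| \le B_x$; then
$$\|\nabla_w p_T(w) - \nabla_w p_T^k(w)\| \le \gamma k \lambda^k,$$
where $\gamma = O\big(B_x (1-\lambda)^{-2}\big)$, suppressing dependence on the Lipschitz and smoothness parameters.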
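Lemma 3. For any $w, w' \in \Theta$, suppose $\phi_w$ is smooth, $\lambda$-contractive, and Lipschitz in $w$. If $p$ is Lipschitz and smooth, then
$$\|\nabla_w p_T(w) - \nabla_w p_T(w')\| \le \beta \|w - w'\|,$$
where $\beta = O\big((1-\lambda)^{-3}\big)$, suppressing dependence on the Lipschitz and smoothness parameters.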
Let $w_i^{\text{recurr}}$ be the weights of the recurrent model on step $i$ and define $w_i^{\text{trunc}}$ similarly for the truncated model. At initialization, $w_0^{\text{recurr}} = w_0^{\text{trunc}}$. For $k$ sufficiently large, Lemma 2 guarantees the difference between the gradient of the recurrent and truncated models is negligible. Therefore, after a gradient update, $\|w_1^{\text{recurr}} - w_1^{\text{trunc}}\|$ is small. Lemma 3 then guarantees that this small difference in weights does not lead to large differences in the gradient on the subsequent time step. For an appropriate choice of learning rate, formalizing this argument leads to the following proposition.
The decaying step size in our theorem is consistent with the regime in which gradient descent is known to be stable for non-convex training objectives hardt2016train (). While the decay is faster than many learning rates encountered in practice, classical results nonetheless show that with this learning rate gradient descent still converges to a stationary point; see p. 119 in bertsekas99nonlinear () and references there. In Section 7, we give empirical evidence that the $O(1/t)$ rate is necessary for our theorem and show examples of stable systems trained with constant or $O(1/\sqrt{t})$ rates that do not satisfy our bound.
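To see why a $1/t$ decay yields a bound that is only polynomial in $N$, here is a sketch of the recursion behind the argument, in our own notation rather than a restatement of Proposition 2. It assumes projected steps $w_t = \Pi_\Theta(w_{t-1} - \alpha_t \nabla p(w_{t-1}))$ with $\alpha_t = \alpha/t$, that $\Pi_\Theta$ is non-expansive (true because $\Theta$ is convex), that $\gamma k \lambda^k$ bounds the gradient gap as in Lemma 2, and that $\beta$ is the gradient Lipschitz constant as in Lemma 3. Write $\delta_t = \|w_t^{\text{recurr}} - w_t^{\text{trunc}}\|$, with $\delta_0 = 0$.

```latex
\begin{aligned}
\delta_t
 &\le \delta_{t-1}
   + \alpha_t \big\|\nabla p_T(w^{\mathrm{recurr}}_{t-1}) - \nabla p_T(w^{\mathrm{trunc}}_{t-1})\big\|
   + \alpha_t \big\|\nabla p_T(w^{\mathrm{trunc}}_{t-1}) - \nabla p^k_T(w^{\mathrm{trunc}}_{t-1})\big\| \\
 &\le (1 + \alpha_t \beta)\,\delta_{t-1} + \alpha_t\,\gamma k \lambda^k , \\[4pt]
\delta_N
 &\le \gamma k \lambda^k \sum_{t=1}^{N} \frac{\alpha}{t} \prod_{s=t+1}^{N} \Big(1 + \frac{\alpha\beta}{s}\Big)
  \le \gamma k \lambda^k\, \alpha N^{\alpha\beta} \sum_{t=1}^{N} t^{-1-\alpha\beta}
  \le \Big(\alpha + \frac{1}{\beta}\Big)\, \gamma k \lambda^k N^{\alpha\beta} .
\end{aligned}
```

The second line unrolls the recursion and uses $1 + x \le e^x$ together with $\sum_{s=t+1}^{N} 1/s \le \ln(N/t)$, so each product is at most $(N/t)^{\alpha\beta}$; the last step uses $\sum_{t \ge 1} t^{-1-\alpha\beta} \le 1 + 1/(\alpha\beta)$. For fixed $N$ the right-hand side vanishes as $k \to \infty$, and with a constant step size the product would instead grow exponentially in $N$.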
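Critically, the bound in Proposition 2 goes to $0$ as $k \to \infty$. In particular, if we take $\alpha = 1$ (step size $\alpha_t = 1/t$) and $k \ge O(\log(\gamma N^\beta / \varepsilon))$, then after $N$ steps of projected gradient descent, $\|w_N^{\text{recurr}} - w_N^{\text{trunc}}\| \le \varepsilon$. For this choice of $k$, we obtain the main theorem. The proof is left to the appendix.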
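Theorem 1. Let $p$ be Lipschitz and smooth. Assume $\phi_w$ is smooth, $\lambda$-contractive, and Lipschitz in $x$ and $w$. Assume the inputs are bounded, and the prediction function $f$ is $L_f$-Lipschitz. If $k \ge O(\log(\gamma N^\beta / \varepsilon))$, then after $N$ steps of projected gradient descent with step size $\alpha_t = 1/t$, $\|y_T(w_N^{\text{recurr}}) - y_T^k(w_N^{\text{trunc}})\| \le \varepsilon$.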
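A toy numerical check of the training-time claim, offered as a sketch under hypothetical choices (not the Section 7.1 setup): a scalar linear system $h_t = a h_{t-1} + x_t$ with inputs in $[-1, 1]$, the squared loss $p_T(a) = \tfrac12 (h_T - y)^2$ on the final state, $\Theta = [-0.5, 0.5]$ (so $\lambda = 0.5$), and projected gradient descent with step size $\alpha/t$. Both models start at $a = 0$; the names `loss_and_grad`, `window`, and `projected_gd` are ours.

```python
import random

def loss_and_grad(a, xs, y):
    """0.5 * (h_T - y)**2 and its derivative in a, for h_t = a*h_{t-1} + x_t, h_0 = 0.

    dh_t/da = h_{t-1} + a * dh_{t-1}/da is propagated forward alongside h_t.
    """
    h, dh = 0.0, 0.0
    for x in xs:
        h, dh = a * h + x, h + a * dh  # right-hand side uses the previous h and dh
    return 0.5 * (h - y) ** 2, (h - y) * dh

def window(xs, k):
    """Last k inputs (the truncated model's context); k >= len(xs) is the full model."""
    return xs[max(0, len(xs) - k):]

def projected_gd(xs, y, k, steps, alpha, radius):
    """Projected gradient descent with step size alpha / t onto [-radius, radius]."""
    a = 0.0                                  # common initialization
    ctx = window(xs, k)
    for t in range(1, steps + 1):
        _, g = loss_and_grad(a, ctx, y)
        a = min(radius, max(-radius, a - (alpha / t) * g))  # step, then project
    return a

random.seed(1)
T, y = 200, 1.0
xs = [random.uniform(-1.0, 1.0) for _ in range(T)]
a_full = projected_gd(xs, y, k=T, steps=50, alpha=0.01, radius=0.5)
for k in (2, 5, 10, 30):
    a_trunc = projected_gd(xs, y, k=k, steps=50, alpha=0.01, radius=0.5)
    print(k, abs(a_full - a_trunc))
```

For $k = 30$ the gradient gap is below $10^{-6}$ everywhere on $\Theta$ in this toy problem, and with the small step sizes the accumulated weight difference stays below the same level.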
5 Counterexamples without stability
While the stability assumption on the state-transition map $\phi_w$ might seem limiting, it is in fact necessary on two counts. First, without stability, there are trivial counterexamples where finite-length truncation can be arbitrarily bad, even for large values of $k$. This alone rules out both the inference and optimization results without additional assumptions. Second, without stability, it is difficult to show gradient descent converges to a stationary point, even in the linear dynamical system case. Indeed, there exist simple counterexamples where gradient descent fails to converge. Both points are made precise in the propositions below, and the proofs are deferred to the appendix.
Proposition 3. There exists an unstable system $\phi_w$ such that, for any finite truncation length $k$, $\|y_T - y_T^k\| \to \infty$ as $T \to \infty$.
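A simple illustration of the phenomenon in Proposition 3, assuming the scalar system $h_t = a h_{t-1} + x_t$ with constant input; this is our example and not necessarily the construction used in the appendix. For $|a| > 1$ the gap $h_T - h_T^k = a^k h_{T-k}$ grows without bound in $T$ for any fixed $k$, whereas for $|a| < 1$ it stays below $a^k/(1-a)$.

```python
def truncation_gap(a, T, k, x=1.0):
    """|h_T - h_T^k| for h_t = a*h_{t-1} + x_t with constant input x_t = x and h_0 = 0."""
    h_full = 0.0
    for _ in range(T):
        h_full = a * h_full + x
    h_trunc = 0.0
    for _ in range(min(k, T)):   # truncated model: restart from 0, last k inputs only
        h_trunc = a * h_trunc + x
    return abs(h_full - h_trunc)

# Unstable (a = 1.5) versus stable (a = 0.5) with the same context length k = 10.
for T in (20, 40, 80):
    print(T, truncation_gap(1.5, T, k=10), truncation_gap(0.5, T, k=10))
```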
Proposition 4. There exists a system such that, if $w$ is not constrained to the set $\Theta$ where $\phi_w$ is stable, then gradient descent does not converge to a stationary point, and $\|\nabla_w p_T(w_i)\| \to \infty$ as the number of iterations $i \to \infty$.
6 Examples of stable models
Our results are stated in the language of general non-linear dynamical systems, and our assumptions are given in terms of a generic state-to-state transition map $\phi_w$. In this section, we show how linear dynamical systems, recurrent neural networks, and LSTMs fit into this general framework and give non-trivial sufficient conditions to ensure stability for each class.
6.1 Linear dynamical systems
Given matrices $A \in \mathbb{R}^{n \times n}$ and $B \in \mathbb{R}^{n \times d}$, the state-transition map for a linear dynamical system is
$$h_t = A h_{t-1} + B x_t.$$
The model is stable provided $\|A\| < 1$, and $\phi_w$ is Lipschitz in $x$ provided $\|B\|$ is bounded. For a stable linear dynamical system with inputs $\|x_t\| \le B_x$, it is easy to show $\|h_t\| \le \|B\| B_x / (1 - \|A\|)$, and consequently the model is Lipschitz in $w = (A, B)$. It’s a simple exercise to check that such a system satisfies the remaining Lipschitz and smoothness assumptions.
6.2 Recurrent neural networks
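Given a Lipschitz, point-wise non-linearity $\rho$ and matrices $W \in \mathbb{R}^{n \times n}$ and $U \in \mathbb{R}^{n \times d}$, the state-transition map for a recurrent neural network (RNN) is
$$h_t = \rho(W h_{t-1} + U x_t).$$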
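If $\rho$ is $L_\rho$-Lipschitz, the model is stable provided $\|W\| < 1/L_\rho$. Indeed, for any states $h, h'$, and any $x$,
$$\|\rho(Wh + Ux) - \rho(Wh' + Ux)\| \le L_\rho \|Wh + Ux - Wh' - Ux\| \le L_\rho \|W\| \, \|h - h'\|.$$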
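A quick numerical sanity check of this inequality for $\rho = \tanh$ ($L_\rho = 1$), assuming hypothetical random weights. To keep the check certain rather than estimated, the sketch scales $W$ to Frobenius norm $0.9$; since the spectral norm never exceeds the Frobenius norm, this guarantees $\|W\| \le 0.9$, so no sampled ratio should exceed $0.9$.

```python
import math
import random

def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * t for m, t in zip(row, v)) for row in M]

def norm2(v):
    return math.sqrt(sum(t * t for t in v))

def frobenius(M):
    return math.sqrt(sum(m * m for row in M for m in row))

def rnn_map(W, U, h, x):
    """phi_w(h, x) = tanh(W h + U x); tanh is 1-Lipschitz, so L_rho = 1."""
    return [math.tanh(a + b) for a, b in zip(matvec(W, h), matvec(U, x))]

def max_contraction_ratio(W, U, trials=1000, seed=0):
    """Largest observed ||phi(h, x) - phi(h', x)|| / ||h - h'|| over random (h, h', x)."""
    rng = random.Random(seed)
    n, d = len(W), len(U[0])
    worst = 0.0
    for _ in range(trials):
        h = [rng.gauss(0, 1) for _ in range(n)]
        h2 = [rng.gauss(0, 1) for _ in range(n)]
        x = [rng.gauss(0, 1) for _ in range(d)]
        out = [a - b for a, b in zip(rnn_map(W, U, h, x), rnn_map(W, U, h2, x))]
        inp = [a - b for a, b in zip(h, h2)]
        worst = max(worst, norm2(out) / norm2(inp))
    return worst

rng = random.Random(42)
n, d, lam = 5, 3, 0.9
W = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(n)]
scale = lam / frobenius(W)          # ||W||_2 <= ||W||_F, so afterwards ||W||_2 <= lam
W = [[scale * w for w in row] for row in W]
U = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
print(max_contraction_ratio(W, U))  # at most lam = 0.9, up to rounding
```

The Frobenius scaling is conservative; the spectral norm itself is the quantity the stability condition uses.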
Our remaining Lipschitz and smoothness assumptions are satisfied if $\rho$ is smooth and $\|U\|$ is bounded. For concreteness, in the appendix, we show each of the assumptions holds when $\rho$ is $\tanh$, which is 1-smooth and 1-Lipschitz. On the other hand, our results do not apply to the non-smooth ReLU.
6.3 Long short-term memory networks
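Long Short-Term Memory (LSTM) networks are another commonly used class of sequence models hochreiter1997long (). The state is a pair of vectors $s = (c, h) \in \mathbb{R}^{2n}$, and the model is parameterized by eight matrices, $W_\square \in \mathbb{R}^{n \times n}$ and $U_\square \in \mathbb{R}^{n \times d}$, for $\square \in \{i, f, o, z\}$. The state-transition map $\phi_{\text{LSTM}}$ is given by
$$\begin{aligned}
f_t &= \sigma(W_f h_{t-1} + U_f x_t), \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t), \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t), \\
z_t &= \tanh(W_z h_{t-1} + U_z x_t), \\
c_t &= i_t \circ z_t + f_t \circ c_{t-1}, \\
h_t &= o_t \circ \tanh(c_t),
\end{aligned}$$
where $\circ$ denotes elementwise multiplication, and $\sigma$ is the logistic function.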
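The state-transition map above can be transcribed directly. The sketch below is a plain-Python reading of these equations (no bias terms, matching the displayed update), with a hypothetical parameter layout `params[gate] = (W, U)` for gate in `"f"`, `"i"`, `"o"`, `"z"`; it is not a library API.

```python
import math
import random

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def matvec(M, v):
    return [sum(m * t for m, t in zip(row, v)) for row in M]

def lstm_step(params, c_prev, h_prev, x):
    """One application of phi_LSTM: (c_{t-1}, h_{t-1}) -> (c_t, h_t)."""
    def pre(gate):
        W, U = params[gate]
        return [a + b for a, b in zip(matvec(W, h_prev), matvec(U, x))]

    f = [sigmoid(v) for v in pre("f")]     # forget gate, entries in (0, 1)
    i = [sigmoid(v) for v in pre("i")]     # input gate
    o = [sigmoid(v) for v in pre("o")]     # output gate
    z = [math.tanh(v) for v in pre("z")]   # candidate update
    c = [ii * zz + ff * cc for ii, zz, ff, cc in zip(i, z, f, c_prev)]
    h = [oo * math.tanh(cc) for oo, cc in zip(o, c)]
    return c, h

rng = random.Random(0)
n, d = 3, 2
params = {g: ([[rng.uniform(-0.5, 0.5) for _ in range(n)] for _ in range(n)],
              [[rng.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(n)])
          for g in "fioz"}
c, h = [0.0] * n, [0.0] * n
for _ in range(10):
    c, h = lstm_step(params, c, h, [rng.gauss(0, 1) for _ in range(d)])
print(c, h)
```

With all weights set to zero every gate equals $1/2$, so a single step maps $c_{t-1}$ to $c_{t-1}/2$: the forget gate discards half of the previous cell state, the effect that the stability conditions below build on.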
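We provide conditions under which the iterated system $\phi_{\text{LSTM}}^r = \phi_{\text{LSTM}} \circ \cdots \circ \phi_{\text{LSTM}}$ is stable. Let $\|f\|_\infty = \sup_t \|f_t\|_\infty$. If the weights $W_f, U_f$ and inputs $x_t$ are bounded, then $\|f\|_\infty < 1$ since $|\sigma| < 1$ for any finite input. This means the next state $c_t$ must “forget” a non-trivial portion of $c_{t-1}$. We leverage this phenomenon to give sufficient conditions for $\phi_{\text{LSTM}}$ to be contractive in the $\ell_\infty$ norm, which in turn implies the iterated system $\phi_{\text{LSTM}}^r$ is contractive in the $\ell_2$ norm for $r = O(\log n)$. Let $\|W\|_\infty$ denote the induced $\ell_\infty$ matrix norm, which corresponds to the maximum absolute row sum $\max_i \sum_j |W_{ij}|$.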
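Two ingredients of this argument can be sketched numerically. The induced $\ell_\infty$ norm is the maximum absolute row sum, and one way to pass from $\ell_\infty$ to $\ell_2$ contraction is norm equivalence: on $\mathbb{R}^m$, $\|v\|_\infty \le \|v\|_2 \le \sqrt{m}\,\|v\|_\infty$, so if a map is $\lambda_\infty$-contractive in $\ell_\infty$ then its $r$-fold composition is contractive in $\ell_2$ once $\sqrt{m}\,\lambda_\infty^r < 1$, i.e. $r = O(\log m)$. Here $m = 2n$ for the LSTM state; the function names are hypothetical.

```python
import math

def induced_inf_norm(W):
    """||W||_inf = max_i sum_j |W_ij| (maximum absolute row sum)."""
    return max(sum(abs(w) for w in row) for row in W)

def iterations_for_l2_contraction(lam_inf, dim):
    """Smallest r with sqrt(dim) * lam_inf**r < 1.

    If phi is lam_inf-contractive in l_inf on R^dim, then for phi^r:
    ||phi^r(s) - phi^r(s')||_2 <= sqrt(dim) * lam_inf**r * ||s - s'||_2.
    """
    if not 0.0 < lam_inf < 1.0:
        raise ValueError("need 0 < lam_inf < 1")
    r = 1
    while math.sqrt(dim) * lam_inf ** r >= 1.0:
        r += 1
    return r

print(induced_inf_norm([[1.0, -2.0], [0.5, 0.5]]))   # 3.0
print(iterations_for_l2_contraction(0.5, 16))        # 3
```

The number of compositions grows only logarithmically with the state dimension, which is why the extra cost in the truncation length is mild.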
If , , , and , then the iterated system is stable on the set of reachable states.
The proof is given in the appendix. As a consequence of Proposition 5, both Proposition 1 and Theorem 1 apply to $\phi_{\text{LSTM}}^r$ at the cost of an additional factor $r$ in the choice of truncation length $k$. We leave it as an open problem to find other parameter regimes in which the system is stable, and to resolve whether the original system $\phi_{\text{LSTM}}$ is itself stable.
7 Experiments
In our experiments, we verify the conclusions of our theoretical investigations using synthetic data, and we present evidence that our results hold beyond the settings captured by our theory using a benchmark language modeling task, WikiText-2 merity2016pointer (). All of our language modeling experiments use publicly available code (https://github.com/pytorch/examples/tree/master/word_language_model), and details of the hyperparameters for all experiments are given in the appendix.
7.1 Understanding the main gradient descent bound
The key result underlying Theorem 1 is the bound on the parameter difference while running gradient descent obtained in Proposition 2. We show this bound has the correct qualitative scaling using random instances.
Concretely, we sample random Gaussian input sequences and for . We fix and randomly initialize a stable linear dynamical system or $\tanh$-RNN (details in appendix). We fix the truncation length to , set the learning rate to for , and take gradient steps. These parameters are chosen so that the bound does not become vacuous: by the triangle inequality, we always have $\|w^{\text{recurr}} - w^{\text{trunc}}\| \le \|w^{\text{recurr}}\| + \|w^{\text{trunc}}\|$.
In Figure 1(b), we plot the parameter error $\|w_i^{\text{recurr}} - w_i^{\text{trunc}}\|$ as training progresses for both a linear dynamical system and a recurrent neural network with $\tanh$ non-linearities (averaged over 10 runs). The error scales comparably with the bound given in Proposition 2. We also find that for larger step sizes like $\alpha/\sqrt{t}$ or constant $\alpha$ (omitted from the plot to reduce clutter), the bound fails to hold, suggesting the $O(1/t)$ condition is necessary.
7.2 Moving beyond stability
Our theoretical results require stability, and Section 5 shows stability is necessary without further assumptions. On the other hand, recurrent models trained in practice are not a priori stable. We first show imposing the stability constraint need not significantly change the performance of benchmark models. We then provide evidence that phenomena like vanishing gradients and truncated system approximation hold outside of settings captured by our theory, posing challenges for future work.
Enforcing stability doesn’t hurt performance.
Stable models can achieve good performance on WikiText-2. To demonstrate this, we trained two single-layer recurrent neural networks. The first model is unconstrained, and the second model is constrained to $\|W\| < 1$, which ensures stability by Section 6. Concretely, after each gradient update, we project the hidden-to-hidden matrix $W$ onto the spectral norm ball by computing the SVD and thresholding the singular values to lie in $[0, 1)$. All of the hyperparameters were chosen via grid search to maximize the performance of the unconstrained model. At convergence, there is little difference between the two models: the unconstrained model achieves a final test perplexity of 146.7, whereas the stable, constrained model achieves a final test perplexity of 143.5.
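One way to implement this projection, offered as a sketch rather than the code used in the experiments: the Euclidean (Frobenius) projection onto $\{W : \|W\| \le \tau\}$ keeps the singular vectors and clips each singular value at $\tau$. To stay dependency-free, the sketch computes the SVD with one-sided Jacobi rotations; in practice one would call a library routine such as `numpy.linalg.svd` or `torch.linalg.svd`. The clipping level $\tau = 0.99$ in the example is our choice of a value just inside the open ball $\|W\| < 1$, and `svd_clip` is a hypothetical name.

```python
import math
import random

def svd_clip(W, tau, sweeps=60, tol=1e-12):
    """Project W onto {M : ||M||_2 <= tau} by clipping singular values at tau.

    One-sided Jacobi: rotate column pairs of A = W V until they are orthogonal.
    Then column j of A equals sigma_j * u_j, and W = A V^T.
    """
    m, n = len(W), len(W[0])
    A = [row[:] for row in W]
    V = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for _ in range(sweeps):
        rotated = False
        for p in range(n - 1):
            for q in range(p + 1, n):
                alpha = sum(A[i][p] ** 2 for i in range(m))
                beta = sum(A[i][q] ** 2 for i in range(m))
                gamma = sum(A[i][p] * A[i][q] for i in range(m))
                if abs(gamma) <= tol * math.sqrt(alpha * beta):
                    continue                      # columns p, q already orthogonal
                rotated = True
                zeta = (beta - alpha) / (2.0 * gamma)
                t = math.copysign(1.0, zeta) / (abs(zeta) + math.sqrt(1.0 + zeta * zeta))
                c = 1.0 / math.sqrt(1.0 + t * t)
                s = c * t
                for M in (A, V):                  # same rotation keeps A = W V
                    for i in range(len(M)):
                        mp, mq = M[i][p], M[i][q]
                        M[i][p] = c * mp - s * mq
                        M[i][q] = s * mp + c * mq
        if not rotated:
            break
    for j in range(n):                            # shrink columns whose sigma_j > tau
        sigma = math.sqrt(sum(A[i][j] ** 2 for i in range(m)))
        if sigma > tau:
            for i in range(m):
                A[i][j] *= tau / sigma
    return [[sum(A[i][j] * V[r][j] for j in range(n)) for r in range(n)] for i in range(m)]

rng = random.Random(0)
W = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(4)]
W_stable = svd_clip(W, tau=0.99)  # e.g. applied to W after each gradient update
```

Singular values already below $\tau$ are left untouched, so the projection changes $W$ only in the directions that violate the constraint.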
The importance of this result is two-fold. First, our theory applies to models that achieve reasonable performance on a benchmark task. Second, it suggests that other conditions on the data distribution or the weight matrices may combine so that models trained in practice are effectively stable and thus permit approximation by feed-forward networks.
The central phenomenon that makes feed-forward approximation during training possible is vanishing gradients. Indeed, vanishing gradients are a key ingredient in our proof of Lemma 2. LSTMs and recurrent neural networks trained in practice exhibit vanishing gradients and limited sensitivity to past inputs beyond what is guaranteed by our theory. In Figure 2, we train an LSTM and an RNN on WikiText-2 and, at the end of each epoch, plot $\|\partial h_t / \partial x_{t-i}\|$ as a function of $i$ for $t$ ranging over the validation set. We do not enforce the spectral norm constraint in either case. The LSTM and the RNN both suffer from limited sensitivity to distant inputs at initialization and throughout training. The gradients of the LSTM vanish more slowly than those of the RNN, but both models exhibit the same qualitative behavior.
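Sensitivity to distant inputs can be measured by backpropagating through time. The sketch below does this for a scalar tanh-RNN $h_t = \tanh(w h_{t-1} + u x_t)$ with hypothetical values $w = 0.8$, $u = 1$, a stand-in for the trained models rather than a reproduction of the experiment. Each factor $\partial h_j / \partial h_{j-1} = w(1 - h_j^2)$ has magnitude at most $|w|$, so $|\partial h_T / \partial x_{T-i}| \le |u|\,|w|^i$.

```python
import math
import random

def forward(w, u, xs):
    """Hidden states [h_0, h_1, ..., h_T] of h_t = tanh(w*h_{t-1} + u*x_t), h_0 = 0."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def input_sensitivity(w, u, xs):
    """[dh_T/dx_{T-i} for i = 0, ..., T-1], computed backward through time."""
    hs = forward(w, u, xs)
    T = len(xs)
    grads = []
    chain = 1.0                        # running value of dh_T/dh_s
    for i in range(T):
        s = T - i
        local = 1.0 - hs[s] ** 2       # tanh' at step s, since h_s = tanh(pre_s)
        grads.append(chain * u * local)  # dh_T/dh_s * dh_s/dx_s
        chain *= w * local             # extend to dh_T/dh_{s-1}
    return grads

random.seed(0)
w, u = 0.8, 1.0
xs = [random.gauss(0, 1) for _ in range(60)]
sens = input_sensitivity(w, u, xs)
for i in (0, 5, 10, 20, 40):
    print(i, abs(sens[i]))  # bounded by |u| * |w|**i
```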
Approximation during training.
In settings not captured by our existing results, a phenomenon similar to Proposition 2 holds empirically. In particular, grows slowly with , and the rate of this growth decreases as the value of the truncation parameter increases. As a representative example, we trained truncated RNN and LSTM models on Wikitext-2 for various values of . Training the full recurrent model is impractical, and hence we assume well captures the full-recurrent model. All of the models are initialized at the same point, and we track the distance between the hidden-to-hidden matrices as training progresses.
In Figure (3), we plot for as training proceeds. In the RNN case, denotes the recurrent matrix trained with truncation length , and in the LSTM case, denotes the concatenation of trained with truncation length .
After an initial rapid increase in distance, grows slowly, similar to the results obtained in the random Gaussian case and in Proposition 2. Moreover, as suggested by our theory, there is a diminishing return to choosing larger values of the truncation parameter in terms of the accuracy of the approximation.
-  Shaojie Bai, J Zico Kolter, and Vladlen Koltun. An empirical evaluation of generic convolutional and recurrent networks for sequence modeling. arXiv preprint arXiv:1803.01271, 2018.
-  Yoshua Bengio, Patrice Simard, and Paolo Frasconi. Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks, 5(2):157–166, 1994.
-  Dimitri P. Bertsekas. Nonlinear Programming. Athena Scientific, 1999.
-  Yann N Dauphin, Angela Fan, Michael Auli, and David Grangier. Language modeling with gated convolutional networks. In International Conference on Machine Learning, pages 933–941, 2017.
-  Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N Dauphin. Convolutional sequence to sequence learning. In International Conference on Machine Learning, pages 1243–1252, 2017.
-  Moritz Hardt, Tengyu Ma, and Benjamin Recht. Gradient descent learns linear dynamical systems. arXiv preprint arXiv:1609.05191, 2016.
-  Moritz Hardt, Benjamin Recht, and Yoram Singer. Train faster, generalize better: Stability of stochastic gradient descent. In International Conference on Machine Learning, pages 1225–1234, 2016.
-  Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.
-  Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.
-  Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In International Conference on Machine Learning, pages 1310–1318, 2013.
-  Hanie Sedghi and Anima Anandkumar. Training input-output recurrent neural networks through spectral methods. arXiv preprint arXiv:1603.00954, 2016.
-  Stephen Tu, Ross Boczar, Andrew Packard, and Benjamin Recht. Non-asymptotic analysis of robust control from coarse-grained identification. arXiv preprint arXiv:1707.04791, 2017.
-  Aaron Van Den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol Vinyals, Alex Graves, Nal Kalchbrenner, Andrew Senior, and Koray Kavukcuoglu. WaveNet: A generative model for raw audio. arXiv preprint arXiv:1609.03499, 2016.
-  Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000–6010, 2017.
Appendix A Deferred proofs
A.1 Proofs from Section 3
Proof of Lemma 1.
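For any $t \ge 1$, by the triangle inequality,
$$\|h_t\| = \|\phi_w(h_{t-1}, x_t) - \phi_w(0, 0)\| \le \|\phi_w(h_{t-1}, x_t) - \phi_w(0, x_t)\| + \|\phi_w(0, x_t) - \phi_w(0, 0)\|.$$
Applying the stability and Lipschitz assumptions and then summing a geometric series,
$$\|h_t\| \le \lambda \|h_{t-1}\| + L_x \|x_t\| \le L_x B_x \sum_{i=0}^{t-1} \lambda^i \le \frac{L_x B_x}{1 - \lambda}.$$
Now, consider the difference between hidden states at time step $t$. Unrolling the iterates $k$ steps and then using the previous display yields
$$\|h_t - h_t^k\| = \|\phi_w(h_{t-1}, x_t) - \phi_w(h_{t-1}^k, x_t)\| \le \lambda \|h_{t-1} - h_{t-1}^k\| \le \lambda^k \|h_{t-k} - h_{t-k}^k\| = \lambda^k \|h_{t-k}\| \le \frac{\lambda^k L_x B_x}{1 - \lambda},$$
and solving for $k$ gives the result. ∎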
A.2 Proofs from Section 4
Before proceeding, we introduce notation for our smoothness assumption. We assume the map satisfies four smoothness conditions: for any reachable states , and any weights , there are some scalars such that
A.2.1 Gradient difference due to truncation is negligible
In this section, we argue the difference in gradient with respect to the weights between the recurrent and truncated models is $O(k\lambda^k)$. For sufficiently large $k$ (independent of the sequence length), the impact of truncation is therefore negligible. The proof leverages the “vanishing-gradient” phenomenon: the long-term components of the gradient of the full recurrent model quickly vanish. The remaining challenge is to show the short-term components of the gradient are similar for the full and truncated models.
Proof of Lemma 2.
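The Jacobian of the loss with respect to the weights is
$$\frac{\partial p_T}{\partial w} = \frac{\partial p_T}{\partial h_T} \left( \sum_{t=1}^{T} \frac{\partial h_T}{\partial h_t} \frac{\partial \phi_w}{\partial w}(h_{t-1}, x_t) \right),$$
where $\frac{\partial \phi_w}{\partial w}(h_{t-1}, x_t)$ is the partial derivative of $h_t$ with respect to $w$, assuming $h_{t-1}$ is constant with respect to $w$. Expanding the expression for the gradient, we wish to bound
$$\left\| \frac{\partial p_T}{\partial h_T} \sum_{t=1}^{T-k} \frac{\partial h_T}{\partial h_t} \frac{\partial \phi_w}{\partial w}(h_{t-1}, x_t) \right\| + \left\| \sum_{t=T-k+1}^{T} \left( \frac{\partial p_T}{\partial h_T} \frac{\partial h_T}{\partial h_t} \frac{\partial \phi_w}{\partial w}(h_{t-1}, x_t) - \frac{\partial p_T^k}{\partial h_T^k} \frac{\partial h_T^k}{\partial h_t^k} \frac{\partial \phi_w}{\partial w}(h_{t-1}^k, x_t) \right) \right\|.$$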
The first term consists of the “long-term components” of the gradient for the recurrent model. The second term is the difference in the “short-term components” of the gradients between the recurrent and truncated models. We bound each of these terms separately.
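For the first term, by the Lipschitz assumptions, $\|\partial p_T / \partial h_T\| \le L_p$ and $\|\partial \phi_w / \partial w\| \le L_w$. Since $\phi_w$ is $\lambda$-contractive, $\|\partial \phi_w / \partial h\| \le \lambda$, so $\|\partial h_T / \partial h_t\| \le \lambda^{T-t}$. Using submultiplicativity of the spectral norm,
$$\left\| \frac{\partial p_T}{\partial h_T} \sum_{t=1}^{T-k} \frac{\partial h_T}{\partial h_t} \frac{\partial \phi_w}{\partial w}(h_{t-1}, x_t) \right\| \le L_p L_w \sum_{t=1}^{T-k} \lambda^{T-t} \le \frac{L_p L_w \lambda^k}{1 - \lambda}.$$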
Focusing on the second term, by triangle inequality and smoothness,
Using Lemma 1 to upper bound (a),
Using the triangle inequality, Lipschitz and smoothness, (b) is bounded by
where the last line used for . It remains to bound (c), the difference of the hidden-to-hidden Jacobians. Peeling off one term at a time and applying triangle inequality, for any ,
so (c) is bounded by . Ignoring Lipschitz and smoothness constants, we’ve shown the entire sum is . ∎
A.2.2 Stable recurrent models are smooth
In this section, we prove that the gradient map $\nabla_w p_T$ is Lipschitz. First, we show on the forward pass, the difference between hidden states $h_t(w)$ and $h_t(w')$ obtained by running the model with weights $w$ and $w'$, respectively, is bounded in terms of $\|w - w'\|$. Using smoothness of $\phi$, the difference in gradients can be written in terms of $\|h_t(w) - h_t(w')\|$, which in turn can be bounded in terms of $\|w - w'\|$. We repeatedly leverage this fact to conclude the total difference in gradients must be similarly bounded.
We first show small differences in weights don’t significantly change the trajectory of the recurrent model.
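For some $w, w' \in \Theta$, suppose $\phi_w, \phi_{w'}$ are $\lambda$-contractive and $L_w$-Lipschitz in $w$. Let $h_t(w)$ be the hidden state at time $t$ obtained from running the model with weights $w$ on common inputs $\{x_t\}$. If $h_0(w) = h_0(w')$, then
$$\|h_t(w) - h_t(w')\| \le \frac{L_w \|w - w'\|}{1 - \lambda}.$$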
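By the triangle inequality, followed by the Lipschitz and contractivity assumptions,
$$\begin{aligned}
\|h_t(w) - h_t(w')\| &\le \|\phi_w(h_{t-1}(w), x_t) - \phi_w(h_{t-1}(w'), x_t)\| + \|\phi_w(h_{t-1}(w'), x_t) - \phi_{w'}(h_{t-1}(w'), x_t)\| \\
&\le \lambda \|h_{t-1}(w) - h_{t-1}(w')\| + L_w \|w - w'\|.
\end{aligned}$$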
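Iterating this argument and then using $h_0(w) = h_0(w')$, we obtain a geometric series in $\lambda$:
$$\|h_t(w) - h_t(w')\| \le L_w \|w - w'\| \sum_{i=0}^{t-1} \lambda^i \le \frac{L_w \|w - w'\|}{1 - \lambda}.$$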