A Unified Framework of Online Learning Algorithms for Training Recurrent Neural Networks
We present a framework for compactly summarizing many recent results in efficient and/or biologically plausible online training of recurrent neural networks (RNN). The framework organizes algorithms according to several criteria: (a) past vs. future facing, (b) tensor structure, (c) stochastic vs. deterministic, and (d) closed form vs. numerical. These axes reveal latent conceptual connections among several recent advances in online learning. Furthermore, we provide novel mathematical intuitions for their degree of success. Testing various algorithms on two synthetic tasks shows that performances cluster according to our criteria. Although a similar clustering is also observed for gradient alignment, alignment with exact methods does not alone explain ultimate performance, especially for stochastic algorithms. This suggests the need for better comparison metrics.
Keywords: real-time recurrent learning, backpropagation through time, approximation, biologically plausible learning, local, online
Training recurrent neural networks (RNN) to learn sequence data is traditionally done with stochastic gradient descent (SGD), using the backpropagation through time algorithm (BPTT, Werbos et al., 1990) to calculate the gradient. This requires “unrolling” the network over some range of $T$ time steps and performing backpropagation as though the network were feedforward under the constraint of sharing parameters across time steps (“layers”). BPTT’s success in a wide range of applications (Mikolov et al., 2010; Graves, 2013; Bahdanau et al., 2016, 2014; Cho et al., 2015; Graves et al., 2016) has made it the industry standard; however, there exist alternative online algorithms for training RNNs. These compute gradients in real time as the network runs forward, without explicitly referencing past activity or averaging over batches of data. There are two reasons for considering online alternatives to BPTT. One is practical: computational costs do not scale with $T$. The other is conceptual: human brains are able to learn long-term dependencies without explicitly memorizing all past brain states, and understanding online learning is a key step in the larger project of understanding human learning.
The classic online learning algorithm is real-time recurrent learning (RTRL, Williams and Zipser, 1989), which is equivalent to BPTT in the limit of a small learning rate (Murray, 2019). RTRL recursively updates the total derivative of the hidden state with respect to the parameters, eliminating the need to reference past activity but introducing an order-$n^3$ memory requirement for a network of $n$ hidden units (one float for each of the $n$ units and each of the $O(n^2)$ parameters). In practice, this is often more computationally demanding than BPTT (order $nT$ in memory), hence RTRL is not frequently used in applications. Nor is RTRL at face value a good model of biological learning, for the same reason: no known biological mechanism exists to store, let alone manipulate, a float for each synapse-neuron pair. Thus RTRL and online learning more broadly have remained relatively obscure footnotes to both the deep learning revolution itself and its impact on computational neuroscience.
Recent advances in recurrent network architectures have brought the issue of online learning back into the spotlight. While vanishing gradients used to significantly limit the extent of the temporal dependencies that an RNN could learn, new architectures like LSTMs (Hochreiter and Schmidhuber, 1997) and GRUs (Cho et al., 2014) have dramatically expanded this learnable time horizon. Unfortunately, taking advantage of this capacity requires an equally dramatic expansion in computational resources, if using BPTT. This has led to an explosion of novel online learning algorithms (Tallec and Ollivier, 2017; Mujika et al., 2018; Roth et al., 2019; Murray, 2019; Jaderberg et al., 2017) which aim to improve on the efficiency of RTRL, in many cases using update rules that might be implemented by a biological circuit.
The sheer number and variety of these approaches pose challenges for both theory and practice. It is not always completely clear what makes various algorithms different from one another, how they are conceptually related, or even why they might work in the first place. There is a pressing need in the field for a cohesive framework for describing and comparing online methods. Here we aim to provide a thorough overview of modern online algorithms for training RNNs, in a way that provides a clearer understanding of the mathematical structure underlying all these different approaches. Our framework organizes the existing literature along several axes that encode meaningful conceptual distinctions:
Past facing vs. future facing
The tensor structure of the algorithm
Stochastic vs. deterministic update
Closed form vs. numerical solution for update
These axes will be explained in detail later, but briefly: the past vs. future axis is a root distinction that divides algorithms by the type of gradient they calculate, while the other three describe their representations and update principles. Table 1 contains (to our knowledge) all recently published online learning algorithms for RNNs, categorized according to these criteria. We can already see that many combinations of these characteristics manifest in the literature, suggesting that new algorithms could be developed by mixing and matching properties. (We provide a concrete example of this in §3.4.)
Here we describe each algorithm in unified notation that makes clear their classification by these criteria. In the process, we generate novel intuitions about why different approximations can be successful and discuss some of the finer points of their biological plausibility. Finally, we simulate each algorithm on a common set of synthetic tasks with vanilla RNN architecture for simplicity. We compare performance and analyze gradient alignments to see to what extent their empirical similarity is predicted by their similarity according to our framework. Algorithm performance roughly clusters according to criteria (a)-(d) across tasks, lending credence to our approach. Curiously, gradient alignment with exact methods (RTRL and BPTT) does not predict performance, despite its ubiquity as a tool for analyzing approximate learning algorithms.
2 Past- and future-facing perspectives of online learning
Before we dive into the details of these algorithms, we first articulate what we mean by past- and future-facing, related to the “reverse/forward accumulation” distinction concurrently described by Cooijmans and Martens (2019). Consider a recurrent neural network that contains, at each time step $t$, a state $a^{(t)} \in \mathbb{R}^n$. This state is updated via a function $F_w : \mathbb{R}^m \to \mathbb{R}^n$, which is parameterized by a flattened vector of parameters $w \in \mathbb{R}^P$. Here $m = n + n_{\text{in}} + 1$ counts the total number of input dimensions, including the recurrent inputs $a^{(t-1)} \in \mathbb{R}^n$, task inputs $x^{(t)} \in \mathbb{R}^{n_{\text{in}}}$, and an additional input clamped to $1$ (to represent bias). For some initial state $a^{(0)}$, $F_w$ defines the network dynamics by
$$a^{(t)} = F_w\big(a^{(t-1)}, x^{(t)}, 1\big).$$
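At each time step an output $y^{(t)} \in \mathbb{R}^{n_{\text{out}}}$ is computed by another function $F^{\text{out}}_{w^o} : \mathbb{R}^n \to \mathbb{R}^{n_{\text{out}}}$, parameterized by $w^o \in \mathbb{R}^{P_o}$. We will typically choose an affine-softmax readout for $F^{\text{out}}_{w^o}$, with output weights/bias $W^o \in \mathbb{R}^{n_{\text{out}} \times (n+1)}$. A loss function $L(y^{(t)}, y^{*(t)})$ calculates an instantaneous loss $L^{(t)}$, quantifying to what degree the predicted output $y^{(t)}$ matches the target output $y^{*(t)}$.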
The goal is to train the network by gradient descent (or other gradient-based optimizers such as ADAM from Kingma and Ba, 2014) on the total loss $\mathcal{L} = \sum_t L^{(t)}$ w.r.t. the parameters $w$ and $w^o$. It is natural to learn $w^o$ online, because only information at present time $t$ is required to calculate the gradient $\partial L^{(t)}/\partial w^o$. So the heart of the problem is to calculate $\partial \mathcal{L}/\partial w$.
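The parameter $w$ is applied via $F_w$ at every time step, and we denote a particular application of $w$ at time $s$ as $w^{(s)}$. Of course, a recurrent system is constrained to share parameters across time steps, so a perturbation $\delta w$ is effectively a perturbation across all applications $\delta w^{(s)}$, i.e., $\partial w^{(s)}/\partial w = I$. In principle, each application of the parameters affects all future losses $L^{(t)}$, $t \ge s$. The core of any recurrent learning algorithm is to estimate the influence $\partial L^{(t)}/\partial w^{(s)}$ of one parameter application $w^{(s)}$ on one loss $L^{(t)}$, since these individual terms are necessary and sufficient to define the global gradient
$$\frac{\partial \mathcal{L}}{\partial w} = \sum_t \sum_{s \le t} \frac{\partial L^{(t)}}{\partial w^{(s)}}. \qquad (1)$$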
This raises the question of how to sum these components to produce individual gradients to pass to the optimizer. In truncated BPTT, one unrolls the graph over some range of time steps and sums $\partial L^{(t)}/\partial w^{(s)}$ for all $t, s$ in that range with $s \le t$ (see §4.1.1). This does not qualify as an “online” learning rule, because it requires two independent time indices: at most one can represent “real time”, leaving the other to represent the future or the past. If we can account for one of the summations via dynamic updates, then the algorithm is online or temporally local, i.e. not requiring explicit reference to the past or future. As depicted in Fig. 1, there are two possibilities. If $t$ from Eq. (1) corresponds to real time, then the gradient passed to the optimizer is
$$\frac{\partial L^{(t)}}{\partial w} = \sum_{s \le t} \frac{\partial L^{(t)}}{\partial w^{(s)}}. \qquad (2)$$
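In this case, we say learning is past facing, because the gradient is a sum of the influences of past applications of $w$ on the current loss. On the other hand, if $s$ from Eq. (1) represents real time, then the gradient passed to the optimizer is
$$\frac{\partial \mathcal{L}}{\partial w^{(s)}} = \sum_{t \ge s} \frac{\partial L^{(t)}}{\partial w^{(s)}}. \qquad (3)$$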
Here we say learning is future facing, because the gradient is a sum of the influences of the current application of $w$ on future losses.
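As a toy illustration of Eqs. (1)-(3), the NumPy sketch below fills a hypothetical lower-triangular array `G[t, s]` with random stand-ins for $\partial L^{(t)}/\partial w^{(s)}$ (a single scalar parameter is assumed for simplicity). Row sums give the past-facing gradients of Eq. (2), column sums give the future-facing gradients of Eq. (3), and both groupings recover the same total of Eq. (1); only the way the lattice is split into optimizer steps differs.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 6  # number of time steps in this toy example

# G[t, s] stands in for dL^(t)/dw^(s); causality forces G[t, s] = 0 whenever s > t.
G = np.tril(rng.normal(size=(T, T)))

# Past facing (Eq. 2): at real time t, sum the influences of all applications s <= t.
pf_grads = np.array([G[t, : t + 1].sum() for t in range(T)])

# Future facing (Eq. 3): at real time s, sum the influences on all losses t >= s.
ff_grads = np.array([G[s:, s].sum() for s in range(T)])

# Eq. (1): both groupings visit every lattice point exactly once.
print(np.isclose(pf_grads.sum(), G.sum()), np.isclose(ff_grads.sum(), G.sum()))  # True True
```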
2.1 Past-facing online learning algorithms
Here we derive a fundamental relation leveraged by past-facing (PF) online algorithms. Let $t$ index real time, and define the influence matrix $M^{(t)} \in \mathbb{R}^{n \times P}$, where $n$ and $P$ are respectively the number of hidden units and the number of parameters defining $F_w$. $M^{(t)}$ tracks the derivatives of the current state $a^{(t)}$ with respect to each parameter $w_k$:
$$M^{(t)}_{ik} = \frac{\partial a^{(t)}_i}{\partial w_k}. \qquad (4)$$
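Let’s rewrite Eq. (4) in matrix notation and unpack it by one time step:
$$M^{(t)} = \frac{\partial a^{(t)}}{\partial w} = \sum_{s \le t} \frac{\partial a^{(t)}}{\partial w^{(s)}} = \sum_{s \le t-1} \frac{\partial a^{(t)}}{\partial a^{(t-1)}} \frac{\partial a^{(t-1)}}{\partial w^{(s)}} + \frac{\partial a^{(t)}}{\partial w^{(t)}} = J^{(t)} M^{(t-1)} + \bar M^{(t)}. \qquad (5)$$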
A simple recursive formula emerges, wherein the influence matrix is updated by multiplying its current value by the Jacobian $J^{(t)} = \partial a^{(t)}/\partial a^{(t-1)}$ of the network and then adding the immediate influence $\bar M^{(t)} = \partial a^{(t)}/\partial w^{(t)}$. To compute the gradient that ultimately gets passed to the optimizer, we simply use the chain rule over the current hidden state $a^{(t)}$:
$$\frac{\partial L^{(t)}}{\partial w} = \frac{\partial L^{(t)}}{\partial a^{(t)}} \frac{\partial a^{(t)}}{\partial w} = \bar c^{(t)} M^{(t)}, \qquad (6)$$
where the immediate credit assignment vector $\bar c^{(t)} \in \mathbb{R}^{1 \times n}$ is defined to be $\partial L^{(t)}/\partial a^{(t)}$ and is calculated by backpropagating the output error through the derivative of the output function $F^{\text{out}}_{w^o}$ (or approximated by Feedback Alignment, see Lillicrap et al., 2016). In the end, we compute a derivative in Eq. (6) that is implicitly a sum over the many terms of Eq. (2), using formulae that depend explicitly only on times $t$ and $t-1$. For this reason, such a learning algorithm is online, and it is past facing because the gradient computation is of the form in Eq. (2).
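To make Eqs. (5) and (6) concrete, here is a minimal NumPy sketch of the past-facing recursion. It assumes the leaky tanh RNN that §3.1.1 introduces as Eq. (9), a hypothetical inverse time constant α = 0.5, and, in place of the affine-softmax readout, a squared-error loss read directly off the hidden state, so that $\bar c^{(t)} = a^{(t)} - y^{*(t)}$. The columns of the influence matrix follow a row-major flattening of $W$ (also an assumption of this sketch), and all function names are hypothetical. Summing $\bar c^{(t)} M^{(t)}$ over a short sequence is checked against central finite differences of the total loss.

```python
import numpy as np

ALPHA = 0.5  # inverse time constant (assumed value)

def rnn_step(W, a_prev, x):
    """Leaky vanilla RNN step in the form of Eq. (9), assuming phi = tanh."""
    a_hat = np.concatenate([a_prev, x, [1.0]])  # recurrent input, task input, bias
    h = W @ a_hat
    return (1 - ALPHA) * a_prev + ALPHA * np.tanh(h), a_hat, h

def total_loss(W, xs, ys):
    """Sum of assumed squared-error losses L^(t) = 0.5 * ||a^(t) - y*^(t)||^2."""
    a, loss = np.zeros(W.shape[0]), 0.0
    for x, y in zip(xs, ys):
        a, _, _ = rnn_step(W, a, x)
        loss += 0.5 * np.sum((a - y) ** 2)
    return loss

def rtrl_gradient(W, xs, ys):
    """Accumulate c_bar^(t) M^(t) (Eq. 6) while updating M^(t) by Eq. (5)."""
    n, m = W.shape
    a = np.zeros(n)
    M = np.zeros((n, n * m))          # influence matrix; column i*m + j <-> W[i, j]
    grad = np.zeros(n * m)
    for x, y in zip(xs, ys):
        a, a_hat, h = rnn_step(W, a, x)
        D = ALPHA * (1 - np.tanh(h) ** 2)                    # alpha * phi'(h^(t))
        J = (1 - ALPHA) * np.eye(n) + D[:, None] * W[:, :n]  # J^(t) = da^(t)/da^(t-1)
        M_bar = np.kron(np.diag(D), a_hat[None, :])          # immediate influence
        M = J @ M + M_bar                                    # Eq. (5)
        c_bar = a - y                                        # dL^(t)/da^(t)
        grad += c_bar @ M                                    # Eq. (6)
    return grad.reshape(n, m)

# Usage: compare with central finite differences on a tiny random problem.
rng = np.random.default_rng(1)
n, n_in, steps = 3, 2, 8
W = 0.5 * rng.normal(size=(n, n + n_in + 1))
xs, ys = rng.normal(size=(steps, n_in)), rng.normal(size=(steps, n))
g = rtrl_gradient(W, xs, ys)
g_fd = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    E = np.zeros_like(W)
    E[idx] = 1e-6
    g_fd[idx] = (total_loss(W + E, xs, ys) - total_loss(W - E, xs, ys)) / 2e-6
print(np.max(np.abs(g - g_fd)))  # finite-difference-level discrepancy
```

Storing `M` ($n \times nm$ floats) and forming `J @ M` are exactly the costs that make this exact approach expensive; the approximations discussed in §3 replace that step.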
2.2 Future-facing online learning algorithms
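Here we show a symmetric relation for future-facing (FF) online algorithms. The credit assignment vector $c^{(s)} \in \mathbb{R}^{1 \times n}$ is a row vector defined as the gradient of the loss $\mathcal{L}$ with respect to the hidden state $a^{(s)}$. It plays a role analogous to $M^{(t)}$ and has a recursive update similar to Eq. (5):
$$c^{(s)} = \frac{\partial \mathcal{L}}{\partial a^{(s)}} = \sum_{t \ge s} \frac{\partial L^{(t)}}{\partial a^{(s)}} = \frac{\partial L^{(s)}}{\partial a^{(s)}} + \sum_{t \ge s+1} \frac{\partial L^{(t)}}{\partial a^{(s+1)}} \frac{\partial a^{(s+1)}}{\partial a^{(s)}} = \bar c^{(s)} + c^{(s+1)} J^{(s+1)}. \qquad (7)$$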
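As in the PF case, the gradient is ultimately calculated using the chain rule over $a^{(s)}$:
$$\frac{\partial \mathcal{L}}{\partial w^{(s)}} = \frac{\partial \mathcal{L}}{\partial a^{(s)}} \frac{\partial a^{(s)}}{\partial w^{(s)}} = c^{(s)} \bar M^{(s)}. \qquad (8)$$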
The recursive relations for PF and FF algorithms are of identical form given the following changes: (1) swap the roles of $\mathcal{L}$ and $w$, (2) swap the roles of $s$ and $t$, and (3) flip the direction of all derivatives. This clarifies the fundamental trade-off between the PF and FF approaches to online learning. On the one hand, memory requirements favor FF because $\mathcal{L}$ is a scalar while $w$ has $P$ components, so $c^{(s)}$ holds $n$ numbers whereas $M^{(t)}$ holds $nP$. On the other, only PF can truly be run online, because the time direction of the update in FF ($c^{(s)}$ is computed from $c^{(s+1)}$) is opposite the forward pass. Thus, efficient PF algorithms must compress $M^{(t)}$, while efficient FF algorithms must predict $c^{(s)}$.
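The future-facing counterpart can be sketched the same way, under the same assumptions as before (tanh RNN of Eq. 9, α = 0.5, squared-error loss on the hidden state; all function names hypothetical). Because the credit recursion runs backward in $s$, the sketch must first store the whole trajectory, which is exactly why FF methods cannot run online without predicting $c^{(s)}$. The per-step gradients $c^{(s)} \bar M^{(s)}$ are summed and compared with finite differences.

```python
import numpy as np

ALPHA = 0.5  # inverse time constant (assumed value)

def rnn_step(W, a_prev, x):
    """Leaky vanilla RNN step in the form of Eq. (9), assuming phi = tanh."""
    a_hat = np.concatenate([a_prev, x, [1.0]])
    h = W @ a_hat
    return (1 - ALPHA) * a_prev + ALPHA * np.tanh(h), a_hat, h

def total_loss(W, xs, ys):
    """Sum of assumed squared-error losses L^(t) = 0.5 * ||a^(t) - y*^(t)||^2."""
    a, loss = np.zeros(W.shape[0]), 0.0
    for x, y in zip(xs, ys):
        a, _, _ = rnn_step(W, a, x)
        loss += 0.5 * np.sum((a - y) ** 2)
    return loss

def ff_gradients(W, xs, ys):
    """Return [dL/dw^(s) for each s], using Eq. (7) backward in s and Eq. (8)."""
    n, m = W.shape
    a, cache = np.zeros(n), []
    for x, y in zip(xs, ys):                    # forward pass: store c_bar, J, M_bar per step
        a, a_hat, h = rnn_step(W, a, x)
        D = ALPHA * (1 - np.tanh(h) ** 2)
        J = (1 - ALPHA) * np.eye(n) + D[:, None] * W[:, :n]
        cache.append((a - y, J, np.kron(np.diag(D), a_hat[None, :])))
    grads, c_next, J_next = [], np.zeros(n), np.zeros((n, n))
    for c_bar, J, M_bar in reversed(cache):     # backward pass, last step first
        c = c_bar + c_next @ J_next             # Eq. (7): c^(s) = c_bar^(s) + c^(s+1) J^(s+1)
        grads.append((c @ M_bar).reshape(n, m)) # Eq. (8): c^(s) M_bar^(s)
        c_next, J_next = c, J
    return grads[::-1]                          # grads[s] = dL/dw^(s)

rng = np.random.default_rng(1)
n, n_in, steps = 3, 2, 8
W = 0.5 * rng.normal(size=(n, n + n_in + 1))
xs, ys = rng.normal(size=(steps, n_in)), rng.normal(size=(steps, n))
grads = ff_gradients(W, xs, ys)
g_fd = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    E = np.zeros_like(W)
    E[idx] = 1e-6
    g_fd[idx] = (total_loss(W + E, xs, ys) - total_loss(W - E, xs, ys)) / 2e-6
print(np.allclose(sum(grads), g_fd, atol=1e-6))  # True
```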
3 Past-facing algorithms
3.1 Real-Time Recurrent Learning
The Real-Time Recurrent Learning (RTRL, Williams and Zipser, 1989) algorithm directly applies Eqs. (5) and (6) as written. We call the application of Eq. (5) the “update” to the learning algorithm, which is deterministic and in closed form. Implementing Eq. (5) requires storing $nP$ floats in $M^{(t)}$ and performing $n^2 P$ multiplications in $J^{(t)} M^{(t-1)}$, which is neither especially efficient nor biologically plausible. However, several efficient (and in some cases, biologically plausible) online learning algorithms have recently been developed, including Unbiased Online Recurrent Optimization (UORO; Tallec and Ollivier, 2017; §3.2), Kronecker-Factored RTRL (KF-RTRL; Mujika et al., 2018; §3.3), Kernel RNN Learning (KeRNL; Roth et al., 2019; §3.5), and Random-Feedback Online Learning (RFLO; Murray, 2019; §3.6). We claim that these learning algorithms, whether explicitly derived as such or not, are all implicitly approximations to RTRL, each a special case of a general class of techniques for compressing $M^{(t)}$. In the following sections, we clarify how each of these learning algorithms fits into this broad structure.
3.1.1 Approximations to RTRL
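To concretely illuminate these ideas, we will work with a special case of $F_w$, a time-continuous vanilla RNN:
$$a^{(t)} = F_w\big(a^{(t-1)}, x^{(t)}, 1\big) = (1 - \alpha)\, a^{(t-1)} + \alpha\, \phi\big(h^{(t)}\big), \qquad h^{(t)} = W \hat a^{(t-1)}, \qquad (9)$$
where $\hat a^{(t-1)} = \mathrm{concat}\big(a^{(t-1)}, x^{(t)}, 1\big) \in \mathbb{R}^m$, $W \in \mathbb{R}^{n \times m}$, $\phi$ is some point-wise nonlinearity (e.g. $\tanh$), and $\alpha \in (0, 1]$ is the network’s inverse time constant. The trainable parameters $w$ are folded, one parameter per index pair $(i, j)$, into the weight matrix $W$, whose columns hold the recurrent weights, the input weights, and a bias. By reshaping $w$ into its natural matrix form $W$, we can write the influence matrix as an order-3 influence tensor
$$M^{(t)}_{kij} = \frac{\partial a^{(t)}_k}{\partial W_{ij}}.$$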
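Thus $M^{(t)}_{kij}$ specifies the effect on the $k$-th unit of perturbing the direct connection from the $j$-th unit to the $i$-th unit. The immediate influence can also be written as a tensor. By differentiating Eq. (9), we see it takes the sparse form
$$\bar M^{(t)}_{kij} = \alpha\, \delta_{ki}\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j,$$
because $W_{ij}$ can affect the $k$-th unit directly only if $k = i$. Many approximations of RTRL involve a decomposition of $M_{kij}$ into a product of lower-order tensors. For example, UORO represents $M_{kij}$ by an outer product $A_k B_{ij}$, which has a memory requirement of only $O(n^2)$. Similarly, KF-RTRL uses a Kronecker-product decomposition $A_j B_{ki}$. We can generalize these cases into a set of six possible decompositions of $M_{kij}$ into products of lower-order tensors $A$ and $B$:
$$M_{kij} \approx A_k B_{ij}\ \text{(UORO, §3.2)}, \quad A_j B_{ki}\ \text{(KF-RTRL, §3.3)}, \quad A_i B_{kj}\ \text{(R-KF-RTRL, §3.4)}, \quad A_{ki} B_{ij}\ \text{(KeRNL/RFLO, §3.5-3.6)}, \quad A_{kj} B_{ij}, \quad A_{ki} B_{kj}.$$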
Each such decomposition has a memory requirement of $O(n^2)$ (taking $m = O(n)$), since neither factor carries more than two indices. Of course, it is not sufficient to write down an idealized decomposition for a particular time point; there must exist some efficient way to update the decomposition as the network runs forwards. We now go through each algorithm, show the mathematical techniques used to derive its update equations, and categorize it by the criteria outlined in Table 1.
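The NumPy sketch below checks two claims of this subsection under assumed values (a tanh nonlinearity, α = 0.5, random weights and inputs): that perturbing $W$ in the current step only, with $a^{(t-1)}$ held fixed, reproduces the sparse immediate influence $\bar M^{(t)}_{kij}$, and that each of the six factorizations stores $O(n^2)$ floats rather than the $n^2 m$ entries of the full tensor. The einsum subscript strings simply encode the index patterns listed above; the larger sizes in the second part are hypothetical.

```python
import numpy as np

ALPHA = 0.5                                    # inverse time constant (assumed value)
rng = np.random.default_rng(2)
n, n_in = 3, 2
m = n + n_in + 1
W = rng.normal(size=(n, m))
a_prev, x = rng.normal(size=n), rng.normal(size=n_in)
a_hat = np.concatenate([a_prev, x, [1.0]])     # a_hat^(t-1) = concat(a^(t-1), x^(t), 1)
D = ALPHA * (1 - np.tanh(W @ a_hat) ** 2)      # alpha * phi'(h_i), phi = tanh

def step(W):
    """Eq. (9) with a^(t-1) held fixed, so only the current application w^(t) varies."""
    return (1 - ALPHA) * a_prev + ALPHA * np.tanh(W @ a_hat)

# Sparse form: M_bar[k, i, j] = alpha * delta_ki * phi'(h_i) * a_hat_j
M_bar = np.einsum("ki,i,j->kij", np.eye(n), D, a_hat)

# Central finite differences of a^(t)_k with respect to each W_ij
M_fd = np.zeros((n, n, m))
for i, j in np.ndindex(n, m):
    E = np.zeros_like(W)
    E[i, j] = 1e-6
    M_fd[:, i, j] = (step(W + E) - step(W - E)) / 2e-6
print(np.allclose(M_bar, M_fd, atol=1e-8))     # True

# Each of the six index patterns yields an order-3 tensor M_kij ...
dims = {"k": n, "i": n, "j": m}
patterns = ["k,ij", "j,ki", "i,kj", "ki,ij", "kj,ij", "ki,kj"]
for p in patterns:
    a_idx, b_idx = p.split(",")
    A = rng.normal(size=[dims[c] for c in a_idx])
    B = rng.normal(size=[dims[c] for c in b_idx])
    assert np.einsum(p + "->kij", A, B).shape == (n, n, m)

# ... while storing O(n^2) floats instead of n * n * m (hypothetical larger sizes).
big = {"k": 64, "i": 64, "j": 67}
stored = {p: sum(int(np.prod([big[c] for c in f])) for f in p.split(",")) for p in patterns}
full = 64 * 64 * 67
print(max(stored.values()), "floats at most, vs", full)   # 8576 floats at most, vs 274432
```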
3.2 Unbiased Online Recurrent Optimization (UORO)
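Tallec and Ollivier (2017) discovered a technique for approximating $M^{(t)}$ as an outer product $A^{(t)} B^{(t)}$, where $A^{(t)} \in \mathbb{R}^{n \times 1}$ and $B^{(t)} \in \mathbb{R}^{1 \times P}$. The authors proved a crucial lemma (see Appendix A or Tallec and Ollivier, 2017) that gives, in closed form, an unbiased rank-1 estimate of a given matrix over the choice of a random vector $\nu \in \mathbb{R}^n$ with $\mathbb{E}[\nu \nu^\top] = I$ and $\mathbb{E}[\nu] = 0$. They leverage this result to derive a closed-form update rule for $A^{(t)}$ and $B^{(t)}$ at each time step, without ever having to explicitly (and expensively) calculate $M^{(t)}$. We present an equivalent formulation in terms of tensor components, i.e.,
$$M^{(t)}_{kij} \approx A^{(t)}_k B^{(t)}_{ij},$$
where $B^{(t)}_{ij}$ represents the “rolled-up” components of $B^{(t)}$, as in $W$ w.r.t. $w$. Intuitively, the $kij$-th component of the influence matrix is constrained to be the product of the $k$-th unit’s “sensitivity” $A_k$ and the $ij$-th parameter’s “efficacy” $B_{ij}$. Eqs. (10) and (11) show the form of the update and why it is unbiased over $\nu$, respectively:
$$A^{(t)}_k = \rho_0 \sum_{k'} J^{(t)}_{kk'} A^{(t-1)}_{k'} + \rho_1 \nu_k, \qquad B^{(t)}_{ij} = \rho_0^{-1} B^{(t-1)}_{ij} + \rho_1^{-1} \sum_{k'} \nu_{k'} \bar M^{(t)}_{k'ij}, \qquad (10)$$
$$\mathbb{E}_\nu\big[A^{(t)}_k B^{(t)}_{ij}\big] = \sum_{k'} J^{(t)}_{kk'} A^{(t-1)}_{k'} B^{(t-1)}_{ij} + \sum_{k'} \mathbb{E}[\nu_k \nu_{k'}]\, \bar M^{(t)}_{k'ij} + \mathbb{E}_\nu[\text{cross terms}] = \sum_{k'} J^{(t)}_{kk'} A^{(t-1)}_{k'} B^{(t-1)}_{ij} + \bar M^{(t)}_{kij}. \qquad (11)$$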
The cross terms vanish in expectation because $\mathbb{E}[\nu] = 0$. Thus, by induction over $t$, the estimate of $M^{(t)}$ remains unbiased at every time step. The constants $\rho_0, \rho_1$ are chosen at each time step to minimize total variance of the estimate by balancing the norms of the cross terms. This algorithm’s update is stochastic due to its reliance on the random vector $\nu$, but it is in closed form because it has an explicit update formula (Eq. 10). Both its memory and computational complexity are $O(n^2)$.
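Below is a minimal NumPy sketch of one UORO step (Eq. 10) in the tensor-component form used here, with hypothetical function and variable names and placeholder random inputs. The constants $\rho_0, \rho_1$ use the norm-balancing choice of Tallec and Ollivier (2017), and $\nu$ is drawn uniformly from $\{-1, 1\}^n$, which satisfies $\mathbb{E}[\nu] = 0$ and $\mathbb{E}[\nu\nu^\top] = I$. Because that distribution is finite, the expectation in Eq. (11) can be computed exactly by enumerating all $2^n$ sign vectors instead of sampling.

```python
import numpy as np
from itertools import product

def uoro_step(A, B, J, M_bar, nu, eps=1e-12):
    """One UORO update (Eq. 10) of the rank-1 estimate M_kij ~ A_k B_ij.

    A: (n,) sensitivity; B: (n, m) rolled-up efficacy; J: (n, n) Jacobian J^(t);
    M_bar: (n, n, m) immediate influence; nu: (n,) random signs.
    """
    JA = J @ A
    nu_M = np.einsum("k,kij->ij", nu, M_bar)   # random projection of the immediate influence
    # Norm balancing (Tallec and Ollivier, 2017). rho1 is unchanged under nu -> -nu,
    # so terms linear in nu still cancel in expectation.
    rho0 = np.sqrt((np.linalg.norm(B) + eps) / (np.linalg.norm(JA) + eps))
    rho1 = np.sqrt((np.linalg.norm(nu_M) + eps) / (np.linalg.norm(nu) + eps))
    return rho0 * JA + rho1 * nu, B / rho0 + nu_M / rho1

# Usage with placeholder values: exact expectation over nu uniform on {-1, +1}^n.
rng = np.random.default_rng(4)
n, m = 3, 4
A, B, J = rng.normal(size=n), rng.normal(size=(n, m)), rng.normal(size=(n, n))
D, a_hat = rng.normal(size=n), rng.normal(size=m)
M_bar = np.einsum("ki,i,j->kij", np.eye(n), D, a_hat)   # sparse form of the immediate influence

estimates = []
for signs in product([-1.0, 1.0], repeat=n):             # all 2^n equally likely sign vectors
    A_new, B_new = uoro_step(A, B, J, M_bar, np.array(signs))
    estimates.append(np.einsum("k,ij->kij", A_new, B_new))
mean = np.mean(estimates, axis=0)
target = np.einsum("kl,lij->kij", J, np.einsum("k,ij->kij", A, B)) + M_bar   # Eq. (5)
print(np.allclose(mean, target))  # True
```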
3.3 Kronecker-Factored RTRL (KF-RTRL)
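Mujika et al. (2018) leverage the same lemma as in UORO, but using a decomposition of $M^{(t)}$ in terms of a Kronecker product $A^{(t)} \otimes B^{(t)}$, where now $A^{(t)} \in \mathbb{R}^{1 \times m}$ and $B^{(t)} \in \mathbb{R}^{n \times n}$. This decomposition is more natural, because the immediate influence $\bar M^{(t)}$ factors exactly as a Kronecker product $\hat a^{(t-1)} \otimes D^{(t)}$ for vanilla RNNs, where $D^{(t)}_{ki} = \alpha\, \delta_{ki}\, \phi'\big(h^{(t)}_i\big)$. To derive the update rule for UORO, one must first generate a rank-1 estimate of $\bar M^{(t)}$ as an intermediate step, introducing more variance, but in KF-RTRL this step is unnecessary. In terms of components, the compression takes the form
$$M^{(t)}_{kij} \approx A^{(t)}_j B^{(t)}_{ki}, \qquad (12)$$
with the update, for a random sign $\nu \in \{-1, 1\}$,
$$A^{(t)}_j = \rho_0 A^{(t-1)}_j + \nu\, \rho_1\, \hat a^{(t-1)}_j, \qquad B^{(t)}_{ki} = \rho_0^{-1} \sum_{k'} J^{(t)}_{kk'} B^{(t-1)}_{k'i} + \nu\, \rho_1^{-1} D^{(t)}_{ki}. \qquad (13)$$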
As in UORO, the cross terms (those linear in $\nu$) vanish in expectation, and the estimate is unbiased by induction over $t$. This algorithm’s updates are also stochastic and in closed form. Its memory complexity is $O(n^2)$, but its computation time is $O(n^3)$ because of the matrix-matrix product in Eq. (13).
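An analogous sketch for KF-RTRL (Eq. 13), again with hypothetical names and placeholder random inputs. Here a single random sign $\nu \in \{-1, 1\}$ suffices, so the exact expectation is an average over two outcomes. The norm-balancing $\rho_0, \rho_1$ are an assumption patterned on UORO; any positive constants that do not depend on $\nu$ keep the estimate unbiased.

```python
import numpy as np

def kf_rtrl_step(A, B, J, a_hat, D, nu, eps=1e-12):
    """One KF-RTRL update (Eq. 13) of the estimate M_kij ~ A_j B_ki.

    A: (m,); B: (n, n); J: (n, n) Jacobian; a_hat: (m,) input a_hat^(t-1);
    D: (n,) diagonal of D^(t), i.e. alpha * phi'(h^(t)); nu: random sign in {-1, +1}.
    """
    JB = J @ B                                   # the O(n^3) matrix-matrix product
    D_mat = np.diag(D)
    # Norm-balancing constants (assumed here, patterned on UORO); independent of nu.
    rho0 = np.sqrt((np.linalg.norm(JB) + eps) / (np.linalg.norm(A) + eps))
    rho1 = np.sqrt((np.linalg.norm(D_mat) + eps) / (np.linalg.norm(a_hat) + eps))
    return rho0 * A + nu * rho1 * a_hat, JB / rho0 + nu * D_mat / rho1

# Usage with placeholder values; the expectation over nu is an average of two outcomes.
rng = np.random.default_rng(5)
n, m = 3, 4
A, B, J = rng.normal(size=m), rng.normal(size=(n, n)), rng.normal(size=(n, n))
a_hat, D = rng.normal(size=m), rng.normal(size=n)
outcomes = [kf_rtrl_step(A, B, J, a_hat, D, nu) for nu in (-1.0, 1.0)]
mean = np.mean([np.einsum("j,ki->kij", A1, B1) for A1, B1 in outcomes], axis=0)
M_prev = np.einsum("j,ki->kij", A, B)
M_bar = np.einsum("ki,i,j->kij", np.eye(n), D, a_hat)
target = np.einsum("kl,lij->kij", J, M_prev) + M_bar     # RTRL update, Eq. (5)
print(np.allclose(mean, target))  # True
```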
3.4 Reverse KF-RTRL (R-KF-RTRL)
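Our exploration of the space of different approximations naturally raises a question: is an approximation of the form
$$M^{(t)}_{kij} \approx A^{(t)}_i B^{(t)}_{kj} \qquad (14)$$
also possible? We refer to this method as “Reverse” KF-RTRL (R-KF-RTRL) because, in matrix notation, this would be formulated as $M^{(t)} \approx B^{(t)} \otimes A^{(t)}$, where $A^{(t)} \in \mathbb{R}^{1 \times n}$ and $B^{(t)} \in \mathbb{R}^{n \times m}$. We propose the following update for $A^{(t)}$ and $B^{(t)}$ in terms of a random vector $\nu \in \{-1, 1\}^n$:
$$A^{(t)}_i = \rho_0 A^{(t-1)}_i + \rho_1 \nu_i, \qquad B^{(t)}_{kj} = \rho_0^{-1} \sum_{k'} J^{(t)}_{kk'} B^{(t-1)}_{k'j} + \rho_1^{-1} \nu_k\, \alpha\, \phi'\big(h^{(t)}_k\big)\, \hat a^{(t-1)}_j, \qquad (15)$$
$$\mathbb{E}_\nu\big[A^{(t)}_i B^{(t)}_{kj}\big] = \sum_{k'} J^{(t)}_{kk'} A^{(t-1)}_i B^{(t-1)}_{k'j} + \mathbb{E}[\nu_i \nu_k]\, \alpha\, \phi'\big(h^{(t)}_k\big)\, \hat a^{(t-1)}_j = \sum_{k'} J^{(t)}_{kk'} A^{(t-1)}_i B^{(t-1)}_{k'j} + \bar M^{(t)}_{kij}. \qquad (16)$$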
Eq. (16) shows that this estimate is unbiased, using updates that are stochastic and in closed form, like its sibling algorithms. Its memory and computational complexity are $O(n^2)$ and $O(n^3)$, respectively. R-KF-RTRL is actually more similar to UORO than to KF-RTRL, because $\bar M^{(t)}$ does not naturally factor like Eq. (14), so, as in UORO, a rank-1 estimate of it must be formed with the random vector $\nu$, introducing more variance. Worse, it has the computational complexity of KF-RTRL due to the matrix-matrix multiplication in Eq. (15). KF-RTRL stands out as the most effective of these three algorithms, because it estimates $M^{(t)}$ with the lowest variance due to its natural decomposition structure. (See Mujika et al., 2018 for variance calculations.)
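The proposed R-KF-RTRL step (Eq. 15) can be checked in the same way. In this sketch (hypothetical names, placeholder inputs, norm-balancing $\rho$'s assumed as before), $\nu$ ranges over $\{-1, 1\}^n$, and averaging over all $2^n$ sign vectors confirms Eq. (16). As in UORO, the immediate influence enters only through the random projection `nu * D`, which is the source of the extra variance discussed above.

```python
import numpy as np
from itertools import product

def r_kf_rtrl_step(A, B, J, a_hat, D, nu, eps=1e-12):
    """One R-KF-RTRL update (Eq. 15) of the estimate M_kij ~ A_i B_kj.

    A: (n,); B: (n, m); J: (n, n) Jacobian; a_hat: (m,); D: (n,) = alpha * phi'(h^(t));
    nu: (n,) random signs.
    """
    JB = J @ B                                     # O(n^2 m) matrix-matrix product
    nu_M = (nu * D)[:, None] * a_hat[None, :]      # nu_k * alpha * phi'(h_k) * a_hat_j
    # Norm-balancing constants (assumed, patterned on UORO); symmetric under nu -> -nu.
    rho0 = np.sqrt((np.linalg.norm(JB) + eps) / (np.linalg.norm(A) + eps))
    rho1 = np.sqrt((np.linalg.norm(nu_M) + eps) / (np.linalg.norm(nu) + eps))
    return rho0 * A + rho1 * nu, JB / rho0 + nu_M / rho1

# Usage with placeholder values: exact expectation over all 2^n sign vectors.
rng = np.random.default_rng(6)
n, m = 3, 4
A, B, J = rng.normal(size=n), rng.normal(size=(n, m)), rng.normal(size=(n, n))
a_hat, D = rng.normal(size=m), rng.normal(size=n)
estimates = []
for signs in product([-1.0, 1.0], repeat=n):
    A_new, B_new = r_kf_rtrl_step(A, B, J, a_hat, D, np.array(signs))
    estimates.append(np.einsum("i,kj->kij", A_new, B_new))
mean = np.mean(estimates, axis=0)
M_prev = np.einsum("i,kj->kij", A, B)
M_bar = np.einsum("ki,i,j->kij", np.eye(n), D, a_hat)
target = np.einsum("kl,lij->kij", J, M_prev) + M_bar     # right-hand side of Eq. (16)
print(np.allclose(mean, target))  # True
```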
3.4.1 Optimal Kronecker-Sum Approximation (OK)
We briefly mention an extension of KF-RTRL by Benzing et al. (2019), where the influence matrix is approximated not by one but rather by a sum of $r$ Kronecker products, or, in components,
$$M^{(t)}_{kij} \approx \sum_{l=1}^{r} A^{(t)}_{lj} B^{(t)}_{lki}.$$
On the RTRL update, the $k$ index of each $B^{(t-1)}_{lki}$ is propagated forward by the Jacobian, and then the immediate influence (itself a Kronecker product) is added. Now $M^{(t)}$ is approximated by $r + 1$ Kronecker products
$$M^{(t)}_{kij} \approx \sum_{l=1}^{r} A^{(t-1)}_{lj} \sum_{k'} J^{(t)}_{kk'} B^{(t-1)}_{lk'i} + \hat a^{(t-1)}_j\, \alpha\, \delta_{ki}\, \phi'\big(h^{(t)}_i\big),$$
but the authors developed a technique to optimally reduce this sum back to $r$ Kronecker products, keeping the memory complexity $O(rn^2)$ and computational complexity $O(rn^3)$ constant over time. This update is stochastic because it requires explicit randomness in the flavor of the above algorithms, and it is numerical because there is no closed-form solution to the update. We leave the details to the original paper.
3.5 Kernel RNN Learning (KeRNL)
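Roth et al. (2019) developed a learning algorithm for RNNs that is essentially a compression of the influence matrix of the form $M^{(t)}_{kij} \approx A_{ki} B^{(t)}_{ij}$. We will show that this algorithm is also an implicit approximation of RTRL, although the update rules are fundamentally different from those for UORO, KF-RTRL and R-KF-RTRL. The eligibility trace $B^{(t)} \in \mathbb{R}^{n \times m}$ updates by temporally filtering the immediate influences $\alpha\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j$ with unit-specific, learnable timescales $\gamma_i$:
$$B^{(t)}_{ij} = (1 - \gamma_i)\, B^{(t-1)}_{ij} + \alpha\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j. \qquad (17)$$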
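The sensitivity matrix $A \in \mathbb{R}^{n \times n}$ is chosen to approximate the multi-step Jacobian $\partial a^{(t)}_k / \partial a^{(t')}_i$ with help from the learned timescales:
$$\frac{\partial a^{(t)}_k}{\partial a^{(t')}_i} \approx A_{ki} (1 - \gamma_i)^{t - t'}. \qquad (18)$$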
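We will describe how $A$ and $\gamma$ are learned later, but for now we assume this approximation holds and use it to show how the KeRNL update is equivalent to that of RTRL. We have dropped the explicit time-dependence from $A$, because it updates too slowly for Eq. (18) to be specific to any one time point. If we unpack this approximation by one time step, we uncover the consistency relation
$$A_{ki} (1 - \gamma_i)^{t - t'} \approx \frac{\partial a^{(t)}_k}{\partial a^{(t')}_i} = \sum_{k'} J^{(t)}_{kk'} \frac{\partial a^{(t-1)}_{k'}}{\partial a^{(t')}_i} \approx \sum_{k'} J^{(t)}_{kk'} A_{k'i} (1 - \gamma_i)^{t - 1 - t'}, \qquad (19)$$
which, together with the case $t' = t$ of Eq. (18), yields
$$A_{ki} \approx \delta_{ki}, \qquad \sum_{k'} J^{(t)}_{kk'} A_{k'i} \approx (1 - \gamma_i)\, A_{ki}. \qquad (20)$$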
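Then the eligibility trace update effectively implements the RTRL update, assuming inductively that $M^{(t-1)}_{k'ij}$ is well approximated by $A_{k'i} B^{(t-1)}_{ij}$:
$$M^{(t)}_{kij} = \sum_{k'} J^{(t)}_{kk'} M^{(t-1)}_{k'ij} + \bar M^{(t)}_{kij} \approx \sum_{k'} J^{(t)}_{kk'} A_{k'i} B^{(t-1)}_{ij} + \alpha\, \delta_{ki}\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j \approx A_{ki} \Big[(1 - \gamma_i)\, B^{(t-1)}_{ij} + \alpha\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j\Big] = A_{ki} B^{(t)}_{ij}. \qquad (21)$$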
In Eq. (21), we use each of the special cases from Eq. (20). Of course, $A$ and $\gamma$ have to be learned, and Roth et al. (2019) use gradient descent to do so. We leave details to the original paper; briefly, they run in parallel a perturbed forward trajectory to estimate the LHS of Eq. (18) and then perform SGD on the squared difference between the LHS and RHS, giving gradients for $A$ and $\gamma$.
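A minimal sketch of the KeRNL eligibility trace (Eq. 17) and of the approximate gradient implied by $M^{(t)}_{kij} \approx A_{ki} B^{(t)}_{ij}$, namely $\sum_k \bar c^{(t)}_k A_{ki} B^{(t)}_{ij}$. The sensitivity matrix $A$ and timescales $\gamma$ are fixed, hypothetical values here, since their learning by perturbed trajectories is left to Roth et al. (2019); the tanh RNN of Eq. (9) with α = 0.5 and random inputs and targets are assumed. The accompanying check confirms that the recursive trace equals the explicit exponentially filtered sum $\sum_{s \le t} (1-\gamma_i)^{t-s}\, \alpha\, \phi'(h^{(s)}_i)\, \hat a^{(s-1)}_j$.

```python
import numpy as np

ALPHA = 0.5  # inverse time constant (assumed value)

def kernl_trace(B, gamma, h, a_hat):
    """Eq. (17): filter the immediate influences alpha*phi'(h_i)*a_hat_j at rates gamma_i."""
    return (1 - gamma)[:, None] * B + ALPHA * (1 - np.tanh(h) ** 2)[:, None] * a_hat[None, :]

def kernl_gradient(c_bar, A, B):
    """Approximate dL^(t)/dW_ij = sum_k c_bar_k A_ki B_ij, from M_kij ~ A_ki B_ij."""
    return (c_bar @ A)[:, None] * B

# Usage with hypothetical, fixed A and gamma (their learning is not sketched here).
rng = np.random.default_rng(7)
n, n_in, steps = 3, 2, 10
m = n + n_in + 1
W = 0.5 * rng.normal(size=(n, m))
A = np.eye(n) + 0.1 * rng.normal(size=(n, n))
gamma = rng.uniform(0.2, 0.8, size=n)
a, B, history = np.zeros(n), np.zeros((n, m)), []
for t in range(steps):
    a_hat = np.concatenate([a, rng.normal(size=n_in), [1.0]])   # random task input
    h = W @ a_hat
    a = (1 - ALPHA) * a + ALPHA * np.tanh(h)                    # Eq. (9)
    B = kernl_trace(B, gamma, h, a_hat)
    history.append((h, a_hat))
grad = kernl_gradient(a - rng.normal(size=n), A, B)   # c_bar for a random placeholder target
```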
3.6 Random-Feedback Online Learning (RFLO)
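Coming from a computational neuroscience perspective, Murray (2019) developed a beautifully simple and biologically plausible learning rule for RNNs, which he calls Random-Feedback Online Learning (RFLO). He formulates the rule in terms of an eligibility trace $B^{(t)} \in \mathbb{R}^{n \times m}$ that filters the non-zero immediate influence elements $\phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j$ by the network inverse time constant $\alpha$:
$$B^{(t)}_{ij} = (1 - \alpha)\, B^{(t-1)}_{ij} + \alpha\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j. \qquad (22)$$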
Then the approximate gradient is ultimately calculated as¹
$$\frac{\partial L^{(t)}}{\partial W_{ij}} \approx \bar c^{(t)}_i B^{(t)}_{ij}.$$
¹ As the “random feedback” part of the name suggests, Murray goes a step further in approximating $\bar c^{(t)}$ by random feedback weights à la Lillicrap et al., 2016, but we assume exact feedback in this paper for easier comparisons with other algorithms.
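By observing that
$$\bar c^{(t)}_i B^{(t)}_{ij} = \sum_k \bar c^{(t)}_k\, \delta_{ki}\, B^{(t)}_{ij},$$
we see that RFLO is a special case of KeRNL, in which we fix $A_{ki} = \delta_{ki}$ and $\gamma_i = \alpha$. Alternatively, and as hinted in the original paper, we can view RFLO as a special case of RTRL under the approximation $J^{(t)} \approx (1 - \alpha) I$, because the RTRL update reduces to RFLO with $M^{(t)}$ containing $B^{(t)}$ along the diagonals, $M^{(t)}_{kij} = \delta_{ki} B^{(t)}_{ij}$:
$$M^{(t)}_{kij} = \sum_{k'} J^{(t)}_{kk'} M^{(t-1)}_{k'ij} + \bar M^{(t)}_{kij} \approx (1 - \alpha)\, \delta_{ki}\, B^{(t-1)}_{ij} + \alpha\, \delta_{ki}\, \phi'\big(h^{(t)}_i\big)\, \hat a^{(t-1)}_j = \delta_{ki} B^{(t)}_{ij}.$$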
Fig. 2 illustrates how $B^{(t)}$ is contained in the influence matrix $M^{(t)}$. This algorithm’s update is deterministic and in closed form, with memory and computational complexity $O(n^2)$.
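The RTRL view suggests a simple check, sketched below under the same assumptions as the earlier examples (tanh RNN of Eq. 9, α = 0.5, squared-error loss on the hidden state, exact feedback, hypothetical names). If the recurrent block of $W$ is set to zero, then $J^{(t)} = (1-\alpha)I$ holds exactly, so the RFLO gradient should coincide with the true gradient, including the entries for the recurrent weights themselves; with nonzero recurrent weights it is only an approximation.

```python
import numpy as np

ALPHA = 0.5  # inverse time constant (assumed value)

def rnn_step(W, a_prev, x):
    """Leaky vanilla RNN step in the form of Eq. (9), assuming phi = tanh."""
    a_hat = np.concatenate([a_prev, x, [1.0]])
    h = W @ a_hat
    return (1 - ALPHA) * a_prev + ALPHA * np.tanh(h), a_hat, h

def total_loss(W, xs, ys):
    """Sum of assumed squared-error losses L^(t) = 0.5 * ||a^(t) - y*^(t)||^2."""
    a, loss = np.zeros(W.shape[0]), 0.0
    for x, y in zip(xs, ys):
        a, _, _ = rnn_step(W, a, x)
        loss += 0.5 * np.sum((a - y) ** 2)
    return loss

def rflo_gradient(W, xs, ys):
    """Sum over t of c_bar_i^(t) B_ij^(t), with B filtered at rate alpha (Eq. 22)."""
    n, m = W.shape
    a, B, grad = np.zeros(n), np.zeros((n, m)), np.zeros((n, m))
    for x, y in zip(xs, ys):
        a, a_hat, h = rnn_step(W, a, x)
        B = (1 - ALPHA) * B + ALPHA * (1 - np.tanh(h) ** 2)[:, None] * a_hat[None, :]
        grad += (a - y)[:, None] * B        # c_bar^(t)_i B_ij^(t), exact feedback
    return grad

# Zero recurrent weights make J^(t) = (1 - alpha) I exact, so RFLO should be exact too.
rng = np.random.default_rng(7)
n, n_in, steps = 3, 2, 8
W = 0.5 * rng.normal(size=(n, n + n_in + 1))
W[:, :n] = 0.0
xs, ys = rng.normal(size=(steps, n_in)), rng.normal(size=(steps, n))
g_fd = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    E = np.zeros_like(W)
    E[idx] = 1e-6
    g_fd[idx] = (total_loss(W + E, xs, ys) - total_loss(W - E, xs, ys)) / 2e-6
print(np.allclose(rflo_gradient(W, xs, ys), g_fd, atol=1e-6))  # True
```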
4 Future-facing algorithms
4.1 Backpropagation Through Time (BPTT)
For many applications, a recurrent network is unrolled only for some finite number of time steps, and backpropagation through time (BPTT) manifests as the computation of the sum of $\partial L^{(t)}/\partial w^{(s)}$ over every pair $s \le t$ in the graph. This can be efficiently accomplished using
$$c^{(s)} = \bar c^{(s)} + c^{(s+1)} J^{(s+1)} \qquad (23)$$
(see Eq. 7) to propagate credit assignment backwards. However, in our framework, where a network is run on an infinite-time horizon, there are two qualitatively different ways of unrolling the network. We call them “efficient” and “future-facing” BPTT.
4.1.1 Efficient backpropagation through time (E-BPTT)
For this method, we simply divide the graph into non-overlapping segments of truncation length $T$ and perform BPTT between $t - T$ and $t$ as described above, using Eq. (23). It takes $O(n^2 T)$ computation time to compute one gradient, but since this computation is only performed once every $T$ time steps, the computation time is effectively $O(n^2)$ per time step, with memory requirement $O(nT)$. A problem with this approach is that it does not treat all time points the same: an application of $w$ occurring near the end of the graph segment has less of its future influence accounted for than applications of $w$ occurring before it, as can be visualized in Fig. 3. And since any one gradient passed to the optimizer is a sum across both $s$ and $t$, it is not an online algorithm by the framework we presented in §2. Therefore, for the purpose of comparing with online algorithms, we also show an alternative version of BPTT that calculates a future-facing gradient (up to truncation) for every $w^{(s)}$.
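A minimal sketch of E-BPTT under the same assumptions as the earlier examples (tanh RNN of Eq. 9, α = 0.5, squared-error loss on the hidden state, hypothetical names). The hidden state is carried across segment boundaries but credit is not, which is the truncation described above. When $T$ covers the whole sequence, the single gradient equals the full gradient, checked here against finite differences.

```python
import numpy as np

ALPHA = 0.5  # inverse time constant (assumed value)

def rnn_step(W, a_prev, x):
    """Leaky vanilla RNN step in the form of Eq. (9), assuming phi = tanh."""
    a_hat = np.concatenate([a_prev, x, [1.0]])
    h = W @ a_hat
    return (1 - ALPHA) * a_prev + ALPHA * np.tanh(h), a_hat, h

def total_loss(W, xs, ys):
    """Sum of assumed squared-error losses L^(t) = 0.5 * ||a^(t) - y*^(t)||^2."""
    a, loss = np.zeros(W.shape[0]), 0.0
    for x, y in zip(xs, ys):
        a, _, _ = rnn_step(W, a, x)
        loss += 0.5 * np.sum((a - y) ** 2)
    return loss

def e_bptt_gradients(W, xs, ys, T):
    """One gradient per non-overlapping segment of length T; credit stops at segment ends."""
    n, m = W.shape
    a, grads = np.zeros(n), []
    for start in range(0, len(xs), T):
        cache = []
        for x, y in zip(xs[start:start + T], ys[start:start + T]):   # forward over segment
            a, a_hat, h = rnn_step(W, a, x)        # state carries over between segments
            D = ALPHA * (1 - np.tanh(h) ** 2)
            J = (1 - ALPHA) * np.eye(n) + D[:, None] * W[:, :n]
            cache.append((a - y, J, np.kron(np.diag(D), a_hat[None, :])))
        g, c_next, J_next = np.zeros(n * m), np.zeros(n), np.zeros((n, n))
        for c_bar, J, M_bar in reversed(cache):    # Eq. (23), within the segment only
            c = c_bar + c_next @ J_next
            g += c @ M_bar
            c_next, J_next = c, J
        grads.append(g.reshape(n, m))
    return grads

rng = np.random.default_rng(8)
n, n_in, steps = 3, 2, 8
W = 0.5 * rng.normal(size=(n, n + n_in + 1))
xs, ys = rng.normal(size=(steps, n_in)), rng.normal(size=(steps, n))
g_fd = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    E = np.zeros_like(W)
    E[idx] = 1e-6
    g_fd[idx] = (total_loss(W + E, xs, ys) - total_loss(W - E, xs, ys)) / 2e-6
grads_trunc = e_bptt_gradients(W, xs, ys, T=3)       # three segments: 3 + 3 + 2 steps
grads_full = e_bptt_gradients(W, xs, ys, T=steps)    # one segment: untruncated BPTT
print(len(grads_trunc), np.allclose(grads_full[0], g_fd, atol=1e-6))  # 3 True
```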
4.1.2 Future-facing backpropagation through time (F-BPTT)
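In this version of BPTT, we keep a dynamic list of truncated credit assignment estimates $\hat c^{(s)}$ for times $s \in \{t - T, \dots, t\}$:
$$\big[\hat c^{(t)}, \hat c^{(t-1)}, \dots, \hat c^{(t-T)}\big],$$
where each truncated credit assignment estimate includes the influences of $a^{(s)}$ only up to time $t$:
$$\hat c^{(s)} = \sum_{t' = s}^{t} \frac{\partial L^{(t')}}{\partial a^{(s)}}.$$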
At current time $t$, every element $\hat c^{(s)}$ is extended by adding $\partial L^{(t)}/\partial a^{(s)}$, calculated by backpropagating from the current loss $L^{(t)}$, while the explicit credit assignment $\bar c^{(t)}$ is appended to the front of the list. To compensate, the oldest credit assignment estimate $\hat c^{(t-T)}$ is removed and combined with the immediate influence $\bar M^{(t-T)}$ to form a (truncated) gradient
$$\hat c^{(t-T)} \bar M^{(t-T)} \approx \frac{\partial \mathcal{L}}{\partial w^{(t-T)}},$$
which is passed to the optimizer to update the network. This algorithm is “online” in that it produces strictly future-facing gradients at each time step, albeit delayed by the truncation time $T$ and requiring memory of the network states from $t - T$ to $t$. Each update step requires $O(n^2 T)$ computation, and since the update is performed at every time step, computation is a factor of $T$ more expensive than E-BPTT. Memory requirement is still $O(nT)$. Fig. 3 illustrates the differences among these methods and RTRL, using a triangular lattice as a visualization tool. Each point in the lattice is one derivative $\partial L^{(t)}/\partial w^{(s)}$ with $s \le t$, and the points are grouped together into discrete gradients passed to the optimizer.
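The list bookkeeping of F-BPTT can be sketched with a `collections.deque`, again assuming the tanh RNN of Eq. (9), α = 0.5 and a squared-error loss on the hidden state (names hypothetical). Each entry stores $\hat c^{(s)}$ together with $J^{(s)}$ and $\bar M^{(s)}$; at each step the current loss is backpropagated into every stored estimate, and the oldest entry is popped once it has accumulated losses up to $s + T$. For a finite sequence, the remaining entries are flushed at the end, so with $T$ at least the sequence length the gradients sum to the full gradient.

```python
import numpy as np
from collections import deque

ALPHA = 0.5  # inverse time constant (assumed value)

def rnn_step(W, a_prev, x):
    """Leaky vanilla RNN step in the form of Eq. (9), assuming phi = tanh."""
    a_hat = np.concatenate([a_prev, x, [1.0]])
    h = W @ a_hat
    return (1 - ALPHA) * a_prev + ALPHA * np.tanh(h), a_hat, h

def total_loss(W, xs, ys):
    """Sum of assumed squared-error losses L^(t) = 0.5 * ||a^(t) - y*^(t)||^2."""
    a, loss = np.zeros(W.shape[0]), 0.0
    for x, y in zip(xs, ys):
        a, _, _ = rnn_step(W, a, x)
        loss += 0.5 * np.sum((a - y) ** 2)
    return loss

def f_bptt_gradients(W, xs, ys, T):
    """Truncated future-facing gradients c_hat^(s) M_bar^(s), each emitted T steps late."""
    n, m = W.shape
    a, window, grads = np.zeros(n), deque(), []    # newest entry at the right end
    for x, y in zip(xs, ys):
        a, a_hat, h = rnn_step(W, a, x)
        D = ALPHA * (1 - np.tanh(h) ** 2)
        J = (1 - ALPHA) * np.eye(n) + D[:, None] * W[:, :n]
        c_bar = a - y
        # Extend every stored c_hat^(s) by dL^(t)/da^(s) = c_bar^(t) J^(t) ... J^(s+1).
        v, J_later = c_bar, J
        for entry in reversed(window):             # s = t-1, t-2, ..., oldest
            v = v @ J_later
            entry[0] += v
            J_later = entry[1]
        window.append([c_bar.copy(), J, np.kron(np.diag(D), a_hat[None, :])])
        if len(window) > T:                        # oldest s = t-T has seen losses up to t
            c_hat, _, M_bar = window.popleft()
            grads.append((c_hat @ M_bar).reshape(n, m))
    # Finite sequence: flush the remaining estimates, which have seen every future loss.
    grads += [(c_hat @ M_bar).reshape(n, m) for c_hat, _, M_bar in window]
    return grads

rng = np.random.default_rng(9)
n, n_in, steps = 3, 2, 8
W = 0.5 * rng.normal(size=(n, n + n_in + 1))
xs, ys = rng.normal(size=(steps, n_in)), rng.normal(size=(steps, n))
g_fd = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    E = np.zeros_like(W)
    E[idx] = 1e-6
    g_fd[idx] = (total_loss(W + E, xs, ys) - total_loss(W - E, xs, ys)) / 2e-6
grads_trunc = f_bptt_gradients(W, xs, ys, T=3)       # one gradient per application w^(s)
grads_full = f_bptt_gradients(W, xs, ys, T=steps)    # no truncation within this sequence
print(len(grads_trunc), np.allclose(sum(grads_full), g_fd, atol=1e-6))  # 8 True
```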
4.2 Decoupled Neural Interfaces (DNI)
Jaderberg et al. (2017) developed a framework for online learning by predicting credit assignment. Whereas PF algorithms face the problem of a large influence tensor $M^{(t)}$ that needs a compressed representation, FF algorithms face the problem of incomplete information: at time $s$, it is impossible to calculate $c^{(s)} = \partial \mathcal{L}/\partial a^{(s)}$ without access to future network variables. The approach of Decoupled Neural Interfaces (DNI) is to simply make a linear prediction of $c^{(s)}$ (Czarnecki et al., 2017) based on the current hidden state $a^{(s)}$ and the current labels $y^{*(s)}$: