Continuous Learning in a Hierarchical Multiscale Neural Network
Abstract
We reformulate the problem of encoding a multiscale representation of a sequence in a language model by casting it in a continuous learning framework. We propose a hierarchical multiscale language model in which short timescale dependencies are encoded in the hidden state of a lower-level recurrent neural network, while longer timescale dependencies are encoded in the dynamics of the lower-level network by having a meta-learner update the weights of the lower-level neural network in an online meta-learning fashion. We use elastic weight consolidation at the higher level to prevent catastrophic forgetting in our continuous learning framework.
1 Introduction
Language models are a major class of natural language processing (NLP) models whose development has led to major progress in many areas like translation, speech recognition or summarization \citep{schwenk_continuous_2012,arisoy_deep_2012,rush_neural_2015,nallapati_abstractive_2016}. Recently, the task of language modeling has been shown to be an adequate proxy for learning high-quality unsupervised representations in tasks like text classification \citep{howard_finetuned_2018}, sentiment detection \citep{radford_learning_2017} or word vector learning \citep{peters_deep_2018}.
More generally, language modeling is an example of an online/sequential prediction task, in which a model tries to predict the next observation given a sequence of past observations. The development of better models for sequential prediction is believed to be beneficial for a wide range of applications like model-based planning or reinforcement learning, as these models have to encode some form of memory or causal model of the world to accurately predict a future event given past events.
One of the main issues limiting the performance of language models (LMs) is the problem of capturing long-term dependencies within a sequence.
Neural-network-based language models \citep{hochreiter_long_1997,cho_learning_2014} learn to implicitly store dependencies in a vector of hidden activities \citep{mikolov_recurrent_2010}. They can be extended with attention mechanisms, memories or caches \citep{bahdanau_neural_2014,tran_recurrent_2016,graves_neural_2014} to capture long-range connections more explicitly. Unfortunately, the very local context is often so highly informative that LMs typically end up using their memories mostly to store short-term context \citep{daniluk_frustratingly_2016}.
In this work, we study the possibility of combining short-term representations, stored in neural activations (hidden state), with medium-term representations encoded in a set of dynamical weights of the language model. Our work extends a series of recent experiments on networks with dynamically evolving weights \citep{ba_using_2016,ha_hypernetworks_2016,krause_dynamic_2017,moniz_nested_2018} which show improvements on sequential prediction tasks. We build upon these works by formulating the task as a hierarchical online meta-learning task, as detailed below.
The motivation behind this work stems from two observations.
On the one hand, there is evidence, from a physiological point of view, that time-coherent processes like working memory can involve differing mechanisms at differing timescales. Biological neural activations typically have a 10 ms coherence timescale, while short-term synaptic plasticity can temporarily modulate the dynamics of the neural network itself on timescales of 100 ms to minutes. On longer timescales (a few minutes to several hours), long-term learning kicks in with permanent modifications to neural excitability \citep{tsodyks_neural_1998,abbott_synaptic_2004,barak_persistent_2007,ba_using_2016}. Interestingly, these physiological observations are paralleled, on the computational side, by a series of recent works on recurrent networks with dynamically evolving weights that show benefits from dynamically updating the weights of a network during a sequential task \citep{ba_using_2016,ha_hypernetworks_2016,krause_dynamic_2017,moniz_nested_2018}.
In parallel, it has also been shown that temporal data with dependencies at multiple timescales can naturally be encoded in a hierarchical representation in which higher-level features change slowly to store long timescale dependencies, while lower-level features change faster to encode short timescale dependencies and local timing \citep{schmidhuber_learning_1992,el_hihi_hierarchical_1995,koutnik_clockwork_2014,chung_hierarchical_2016}.
As a consequence, we would like our model to encode information in a multiscale hierarchical representation where

short timescale dependencies can be encoded in fast-updated neural activations (hidden state),

medium timescale dependencies can be encoded in the dynamics of the network by using dynamic weights updated more slowly, and

a long timescale memory can be encoded in a static set of parameters of the model.
In the present work, we take as dynamic weights the full set of weights of an RNN language model (usually the word embeddings plus the recurrent, input and output weights of each recurrent layer).
2 Dynamical Language Modeling
Given a sequence of discrete symbols $S = (s_1, \dots, s_n)$, the language modeling task consists in assigning a probability $P(S)$ to the sequence, which can be written, using the chain rule, as

$P(s_1, \dots, s_n) = \prod_{i=1}^{n} P(s_i \mid s_{i-1}, \dots, s_1; \theta)$  (1)

where $\theta$ is a set of parameters of the language model.
In the case of a neural-network-based language model, the conditional probability $P(s_i \mid s_{i-1}, \dots, s_1; \theta)$ is typically parametrized using an autoregressive neural network as

$P(s_i \mid s_{i-1}, \dots, s_1; \theta) = f_{\theta}(s_{i-1}, \dots, s_1)$  (2)

where $\theta$ are the parameters of the neural network.
In a dynamical language modeling framework, the parameters of the language model are not tied over the sequence but are allowed to evolve. Thus, prior to computing the probability of a future token $s_i$, a set of parameters $\theta_i$ is estimated from the past parameters and tokens as $\theta_i = g(\theta_{i-1}, \dots, \theta_1, s_{i-1}, \dots, s_1)$, and the updated parameters $\theta_i$ are used to compute the probability of the next token $P(s_i \mid s_{i-1}, \dots, s_1; \theta_i)$.
In our hierarchical neural network language model, the updated parameters $\theta_i$ are estimated by a higher-level neural network $g$ parametrized by a set of (static) parameters $\phi$:

$\theta_i = g_{\phi}(\theta_{i-1}, \dots, \theta_1, s_{i-1}, \dots, s_1)$  (3)
2.1 Online meta-learning formulation
The function $g_{\phi}$ computed by the higher-level network, estimating $\theta_i$ from a history of parameters and data points, can be seen as an online meta-learning task in which a high-level meta-learner network is trained to update the weights of a low-level network from the loss of the low-level network on a previous batch of data.
Such a meta-learner can be trained \citep{andrychowicz_learning_2016} to reduce the loss of the low-level network, with the idea that it will generalize the simple gradient descent rule
$\theta_i = \theta_{i-1} - \alpha_i \nabla_{\theta_{i-1}} \mathcal{L}_i$  (4)

where $\alpha_i$ is a learning rate at time $i$ and $\nabla_{\theta_{i-1}} \mathcal{L}_i$ is the gradient of the loss of the language model on the $i$-th dataset with respect to the previous parameters $\theta_{i-1}$.
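As a reference point, the gradient descent rule of Eq. (4) can be sketched directly. This is a minimal illustration with a toy quadratic loss standing in for the language model loss; the function names and the loss itself are our illustrative assumptions, not part of the model:

```python
import numpy as np

# Toy stand-in for the LM loss gradient on one batch: a squared-error
# loss pulling the parameters toward the batch mean.
def loss_grad(theta, batch):
    return 2.0 * (theta - batch.mean())

# Eq. (4): theta_i = theta_{i-1} - alpha_i * grad(L_i w.r.t. theta_{i-1})
def sgd_update(theta_prev, batch, alpha=0.25):
    return theta_prev - alpha * loss_grad(theta_prev, batch)

theta = np.zeros(3)
for batch in [np.ones(4), np.ones(4)]:
    theta = sgd_update(theta, batch)
# theta has moved from 0 toward the batch mean 1.0
```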
Ravi and Larochelle (\citeyear{ravi_optimization_2016}) made the observation that such a gradient descent rule bears similarities with the update rule for LSTM cell states

$c_i = f_i \odot c_{i-1} + i_i \odot \tilde{c}_i$  (5)

when $f_i = 1$, $i_i = \alpha_i$ and $\tilde{c}_i = -\nabla_{\theta_{i-1}} \mathcal{L}_i$, identifying $c_{i-1}$ with $\theta_{i-1}$.
We extend this analogy to the case of a multiscale hierarchical recurrent model, illustrated in Figure 1 and composed of:

Lower level / short timescale: an RNN-based language model encoding representations in the activations of a hidden state,

Middle level / medium timescale: a meta-learner updating the set of weights of the language model to store medium-term representations, and

Higher level / long timescale: a static long-term memory of the dynamics of the RNN-based language model (see below).
The meta-learner is trained to update the lower-level network by computing gates $f_i$, $i_i$ and $o_i$ and updating the set of weights as

$\theta_i = f_i \odot \theta_{i-1} + i_i \odot \tilde{\theta}_i + o_i \odot \theta^{LT}$  (6)

where $\tilde{\theta}_i = -\nabla_{\theta_{i-1}} \mathcal{L}_i$ and $\theta^{LT}$ is the static long-term memory (see below).
This hierarchical network can be seen as an analog of the hierarchical recurrent neural networks of \citep{chung_hierarchical_2016}, where the gates $f_i$, $i_i$ and $o_i$ can be seen as controlling a set of COPY, UPDATE and FLUSH operations:

COPY ($f_i$): part of the state copied from the previous state $\theta_{i-1}$,

UPDATE ($i_i$): part of the state updated by the loss gradients on the previous batch, and

FLUSH ($o_i$): part of the state reset from a static long-term memory $\theta^{LT}$.
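A minimal sketch of this gated update, combining the COPY, UPDATE and FLUSH operations. The gate values (here scalars `f`, `i`, `o`) and the long-term memory `theta_lt` are illustrative placeholders, not learned quantities:

```python
import numpy as np

def gated_update(theta_prev, grad, theta_lt, f, i, o):
    # COPY:   f * theta_prev  keeps part of the previous weights
    # UPDATE: i * (-grad)     moves part of the weights down the loss gradient
    # FLUSH:  o * theta_lt    resets part of the weights from long-term memory
    return f * theta_prev + i * (-grad) + o * theta_lt

theta_prev = np.array([1.0, 2.0, 3.0])
grad = np.array([0.5, 0.5, 0.5])
theta_lt = np.zeros(3)

# With f=1, i=0, o=0 the update degenerates to a pure COPY.
copied = gated_update(theta_prev, grad, theta_lt, f=1.0, i=0.0, o=0.0)
```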
One difference with the work of \citet{chung_hierarchical_2016} is that, in the latter, the memory was confined to the hidden state, while the memory of our hierarchical network is split between the weights of the lower-level network and its hidden state.
The meta-learner can be a feed-forward or an RNN network. In our experiments, a simple linear feed-forward network led to the lowest perplexities, probably because it was easier to regularize and optimize. The meta-learner implements coordinate-sharing, as described in \citep{andrychowicz_learning_2016,ravi_optimization_2016}, and takes as input the loss and loss gradients over a previous batch (a sequence of tokens, as illustrated in Figure 1). The size of the batch adjusts the trade-off between the noise of the loss/gradients and the updating frequency of the medium-term memory, smaller batches leading to faster updates with higher noise.
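To illustrate coordinate-sharing: the same small learned transformation is applied independently to every weight coordinate, so the meta-learner's parameter count does not grow with the size of the language model. This is a hedged sketch; the two-feature linear parametrization below is our simplifying assumption, not the exact feature set of the paper:

```python
import numpy as np

class LinearMetaLearner:
    """Linear feed-forward meta-learner with coordinate-sharing."""

    def __init__(self, w_prev=1.0, w_grad=-0.1):
        # These two scalars are the *whole* meta-learner: the same
        # coefficients are reused for every coordinate, so its size is
        # independent of the number of weights it updates.
        self.w_prev = w_prev
        self.w_grad = w_grad

    def update(self, theta_prev, grad):
        # Applied elementwise to each coordinate of the LM's weights.
        return self.w_prev * theta_prev + self.w_grad * grad

meta = LinearMetaLearner()
new_theta = meta.update(np.ones(5), np.full(5, 2.0))
```

With `w_prev = 1` and `w_grad = -0.1` this degenerates to plain SGD with learning rate 0.1; training the coefficients lets the meta-learner deviate from that rule.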
2.2 Continual learning
The interaction between the meta-learner and the language model implements a form of continual learning, and the language model thus faces a phenomenon known as catastrophic forgetting \citep{french_catastrophic_1999}. In our case, this corresponds to the lower-level network over-specializing to the lexical field of a particular topic after several updates of the meta-learner (e.g. while processing a long article on a specific topic).
To mitigate this effect, we use a higher-level static memory initialized using "elastic weight consolidation" (EWC), introduced by Kirkpatrick et al. (\citeyear{kirkpatrick_overcoming_2017}) to reduce forgetting in multitask reinforcement learning.
Casting our task in the EWC framework, we define a task A, which is the language modeling task (prediction of the next token) when no context is stored in the weights of the lower-level network. The solution of task A is a set of weights toward which the model could advantageously come back when the context stored in the weights becomes irrelevant (for example when switching between paragraphs on different topics). To obtain a set of weights for task A, we train the lower-level network (RNN) alone on the training dataset and obtain a set of weights $\theta_A^*$ that performs well on average, i.e. when no specific context has been provided by a context-dependent weight update performed by the meta-learner.
We then define a task B which is a language modeling task when a context has been stored in the weights of the lowerlevel network by an update of the metalearner. The aim of EWC is to learn task B while retaining some performance on task A.
Empirical results suggest that many weight configurations result in similar performances \citep{sussmann_uniqueness_1992}, so there is likely a solution for task B close to a solution for task A. The idea behind EWC is to learn task B while protecting the performance on task A by constraining the parameters to stay around the solution found for task A.
This constraint is implemented as a quadratic penalty, similar to springs anchoring the parameters, hence the name elastic. The stiffness of the springs should be greater for the parameters that most affect performance on task A. We can formally write this constraint by using Bayes' rule to express the conditional log probability of the parameters $\log p(\theta \mid D)$ when the training dataset $D$ is split between the training dataset for task A ($D_A$) and the training dataset for task B ($D_B$):

$\log p(\theta \mid D) = \log p(D_B \mid \theta) + \log p(\theta \mid D_A) - \log p(D_B)$  (7)

The true posterior probability on task A, $p(\theta \mid D_A)$, is intractable, so we approximate the posterior as a Gaussian distribution with mean given by the task-A parameters $\theta_A^*$ and a diagonal precision given by the diagonal of the Fisher information matrix $F$, which is equivalent to the second derivative of the loss near a minimum and can be computed from first-order derivatives alone.
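The resulting penalty can be sketched as follows: the diagonal Fisher is estimated from squared first-order per-sample gradients, and the quadratic term anchors the parameters to the task-A solution. The scaling factor `lam` is an illustrative hyper-parameter, and the variable names are ours:

```python
import numpy as np

def fisher_diagonal(per_sample_grads):
    # Diagonal Fisher information estimated from first-order gradients
    # only: the average of squared per-sample log-likelihood gradients.
    g = np.stack(per_sample_grads)
    return (g ** 2).mean(axis=0)

def ewc_penalty(theta, theta_a, fisher_diag, lam=1.0):
    # Quadratic "spring" penalty anchoring theta to the task-A solution
    # theta_a, with per-parameter stiffness given by the Fisher diagonal.
    return 0.5 * lam * float(np.sum(fisher_diag * (theta - theta_a) ** 2))

fisher = fisher_diagonal([np.array([1.0, 0.0]), np.array([1.0, 2.0])])
penalty = ewc_penalty(np.array([1.0, 1.0]), np.zeros(2), fisher, lam=2.0)
```

The penalty is added to the task-B loss during training, so parameters with a large Fisher value (important for task A) are pulled back harder than unimportant ones.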
3 Related work
Several works have been devoted to dynamically updating the weights of neural networks during inference. A few recent architectures are the Fast Weights of \citet{ba_using_2016}, the HyperNetworks of \citet{ha_hypernetworks_2016} and the Nested LSTMs of \citet{moniz_nested_2018}. The weight update rules of these models take as input one or several of (i) a previous hidden state of an RNN or a higher-level network and/or (ii) the current or previous inputs to the network. However, these models do not use the predictions of the network on the previous tokens (i.e. the loss and the gradient of the loss of the model) as in the present work. The architecture most closely related to the present work is the dynamical evaluation study of \citet{krause_dynamic_2017}, in which a loss function similar to the one of the present work is obtained empirically and optimized using a large hyper-parameter search over the parameters of an SGD-like rule.
4 Experiments
4.1 Architecture and hyperparameters
As mentioned in Section 2.2, a set of pretrained weights of the RNN language model is first obtained by training the lower-level network alone and computing the diagonal of the Fisher matrix around the final weights.
Then, the meta-learner is trained in an online meta-learning fashion on the validation dataset (alternatively, a subset of the training dataset could be used). A training sequence is split into a sequence of mini-batches $B_i$, each batch containing $n$ input tokens and the associated targets (the same tokens shifted by one position). In our experiments, we varied $n$ between 5 and 20.
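The batching scheme can be sketched as follows (a minimal illustration; the function and variable names are ours):

```python
def split_into_batches(tokens, n):
    """Split a token sequence into consecutive mini-batches of n input
    tokens, each paired with the n targets shifted one position right."""
    batches = []
    for start in range(0, len(tokens) - n, n):
        inputs = tokens[start:start + n]
        targets = tokens[start + 1:start + n + 1]
        batches.append((inputs, targets))
    return batches

# A toy "sequence" of 11 token ids split into batches of n = 5 tokens.
batches = split_into_batches(list(range(11)), 5)
```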
The meta-learner is trained as described in \citep{andrychowicz_learning_2016,li_learning_2016} by minimizing the sum of the LM losses over the sequence: $\mathcal{L} = \sum_i \mathcal{L}_i$. The meta-learner is trained by truncated back-propagation through time and is unrolled over at least 40 steps, as the reward from the medium-term memory is relatively sparse \citep{li_learning_2016}.
To be able to unroll the model over a sufficient number of steps while using a state-of-the-art language model with more than 30 million parameters, we use a memory-efficient version of back-propagation through time based on gradient checkpointing, as described by Gruslys et al. (\citeyear{gruslys_memoryefficient_2016}).
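The memory/compute trade-off behind gradient checkpointing can be sketched as follows. This is a toy illustration of the principle, not the actual algorithm of Gruslys et al.; the RNN step is a stand-in and all names are ours:

```python
def forward_with_checkpoints(h0, inputs, step, k=4):
    """Run the forward pass but store hidden states only every k steps."""
    checkpoints = {0: h0}
    h = h0
    for t, x in enumerate(inputs, start=1):
        h = step(h, x)
        if t % k == 0:
            checkpoints[t] = h  # keep roughly 1/k of the activations
    return h, checkpoints

def recompute_state(checkpoints, inputs, t, step):
    """Recover the hidden state at step t by replaying the forward pass
    from the nearest earlier checkpoint (extra compute, less memory)."""
    t0 = max(c for c in checkpoints if c <= t)
    h = checkpoints[t0]
    for x in inputs[t0:t]:
        h = step(h, x)
    return h

step = lambda h, x: h + x  # toy stand-in for one RNN step
h_final, cps = forward_with_checkpoints(0, [1] * 10, step)
```

During the backward pass, each segment's intermediate states are rebuilt on demand with `recompute_state`, so peak memory scales with the checkpoint spacing rather than the full unroll length.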
4.2 Experiments
We performed a series of experiments on the WikiText-2 dataset \citep{merity_pointer_2016} using an AWD-LSTM language model \citep{merity_regularizing_2017} and feed-forward and RNN meta-learners.
The test perplexities are similar to the perplexities obtained using dynamical evaluation \citep{krause_dynamic_2017}, with the best results obtained using a linear feed-forward meta-learner on top of a one-level language model baseline.
In our experiments, the perplexity could not be improved by using an RNN meta-learner or a deeper meta-learner. We hypothesize that this may be due to several factors. First, storing a hidden state in the meta-learner may be less important in an online meta-learning setup than in a standard meta-learning setup \citep{andrychowicz_learning_2016}, as the target distribution of the weights is non-stationary. Second, the size of the hidden state cannot be increased significantly without reducing the number of steps over which the meta-learner is unrolled during meta-training, which may be detrimental.
Some quantitative experiments are shown in Figure 2, using a linear feed-forward network, to illustrate the effect of the various levels of the hierarchical model. The curves show differences in batch perplexity between model variants.
The top curve compares a one-level model (language model) with a two-level model (language model + meta-learner). The meta-learner is able to learn medium-term representations that progressively reduce perplexity along an article (see e.g. articles C and E). Right sample 1 (resp. 2) details sentences at the beginning (resp. middle) of article E, related to a warship called "Ironclad". The addition of the meta-learner reduces the loss on a number of expressions related to the warship, like "ironclad" or "steel armor".
The bottom curve compares a three-level model (language model + meta-learner + long-term memory) with the two-level model. The local loss is reduced at topic changes and at the beginning of new topics (see e.g. articles B, D and F). The right sample 3 can be contrasted with sample 1 to illustrate how the hierarchical model is able to better recover a good parameter space following a change in topic.
References
L. F. Abbott and Wade G. Regehr. 2004. Synaptic computation. Nature, 431(7010):796–803.
Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W. Hoffman, David Pfau, Tom Schaul, Brendan Shillingford, and Nando de Freitas. 2016. Learning to learn by gradient descent by gradient descent. arXiv:1606.04474 [cs].
Ebru Arisoy, Tara N. Sainath, Brian Kingsbury, and Bhuvana Ramabhadran. 2012. Deep neural network language models. In Proceedings of the NAACL-HLT 2012 Workshop: Will We Ever Really Replace the N-gram Model? On the Future of Language Modeling for HLT, pages 20–28. Association for Computational Linguistics.
Jimmy Ba, Geoffrey Hinton, Volodymyr Mnih, Joel Z. Leibo, and Catalin Ionescu. 2016. Using Fast Weights to Attend to the Recent Past. arXiv:1610.06258 [cs, stat].
Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. 2014. Neural machine translation by jointly learning to align and translate. arXiv:1409.0473 [cs].
Omri Barak and Misha Tsodyks. 2007. Persistent activity in neural networks with dynamic synapses. PLoS Computational Biology, 3(2):e35.
Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. 2014. Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. arXiv:1406.1078 [cs, stat].
Junyoung Chung, Sungjin Ahn, and Yoshua Bengio. 2016. Hierarchical Multiscale Recurrent Neural Networks. arXiv:1609.01704 [cs].
Michał Daniluk, Tim Rocktäschel, Johannes Welbl, and Sebastian Riedel. 2016. Frustratingly Short Attention Spans in Neural Language Modeling.
Salah El Hihi and Yoshua Bengio. 1995. Hierarchical Recurrent Neural Networks for Long-term Dependencies. In Proceedings of the 8th International Conference on Neural Information Processing Systems, NIPS'95, pages 493–499, Cambridge, MA, USA. MIT Press.
Robert M. French. 1999. Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences, 3(4):128–135.
Alex Graves, Greg Wayne, and Ivo Danihelka. 2014. Neural Turing Machines. arXiv:1410.5401 [cs].
Audrūnas Gruslys, Remi Munos, Ivo Danihelka, Marc Lanctot, and Alex Graves. 2016. Memory-Efficient Backpropagation Through Time. arXiv:1606.03401 [cs].
David Ha, Andrew Dai, and Quoc V. Le. 2016. HyperNetworks. arXiv:1609.09106 [cs].
Sepp Hochreiter and Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Computation, 9(8):1735–1780.
Jeremy Howard and Sebastian Ruder. 2018. Fine-tuned Language Models for Text Classification. arXiv:1801.06146 [cs, stat].
James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, Demis Hassabis, Claudia Clopath, Dharshan Kumaran, and Raia Hadsell. 2017. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526.
Jan Koutník, Klaus Greff, Faustino Gomez, and Jürgen Schmidhuber. 2014. A Clockwork RNN. In Proceedings of the 31st International Conference on Machine Learning, ICML'14, pages 1863–1871, Beijing, China. JMLR.org.
Ben Krause, Emmanuel Kahembwe, Iain Murray, and Steve Renals. 2017. Dynamic Evaluation of Neural Sequence Models. arXiv:1709.07432 [cs].
Ke Li and Jitendra Malik. 2016. Learning to Optimize. arXiv:1606.01885 [cs, math, stat].
Stephen Merity, Nitish Shirish Keskar, and Richard Socher. 2017. Regularizing and Optimizing LSTM Language Models. arXiv:1708.02182 [cs].
Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. 2016. Pointer Sentinel Mixture Models. arXiv:1609.07843 [cs].
Tomas Mikolov, Martin Karafiát, Lukas Burget, Jan Černocký, and Sanjeev Khudanpur. 2010. Recurrent neural network based language model, volume 2.
Joel Ruben Antony Moniz and David Krueger. 2018. Nested LSTMs. arXiv:1801.10308 [cs].
Ramesh Nallapati, Bowen Zhou, Caglar Gulcehre, and Bing Xiang. 2016. Abstractive text summarization using sequence-to-sequence RNNs and beyond. arXiv:1602.06023 [cs].
Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, and Luke Zettlemoyer. 2018. Deep contextualized word representations. arXiv:1802.05365 [cs].
Alec Radford, Rafal Jozefowicz, and Ilya Sutskever. 2017. Learning to Generate Reviews and Discovering Sentiment. arXiv:1704.01444 [cs].
Sachin Ravi and Hugo Larochelle. 2016. Optimization as a Model for Few-Shot Learning.
Alexander M. Rush, Sumit Chopra, and Jason Weston. 2015. A neural attention model for abstractive sentence summarization. arXiv:1509.00685 [cs].
J. Schmidhuber. 1992. Learning Complex, Extended Sequences Using the Principle of History Compression. Neural Computation, 4(2):234–242.
Holger Schwenk. 2012. Continuous space translation models for phrase-based statistical machine translation. In Proceedings of COLING 2012: Posters, pages 1071–1080.
Héctor J. Sussmann. 1992. Uniqueness of the weights for minimal feedforward nets with a given input-output map. Neural Networks, 5(4):589–593.
Ke Tran, Arianna Bisazza, and Christof Monz. 2016. Recurrent Memory Networks for Language Modeling. arXiv:1601.01272 [cs].
Misha Tsodyks, Klaus Pawelzik, and Henry Markram. 1998. Neural Networks with Dynamic Synapses. Neural Computation, 10(4):821–835.