Multi-task learning with deep model based reinforcement learning



In recent years, model-free methods that use deep learning have achieved great success in many different reinforcement learning environments. Most successful approaches focus on solving a single task, while multi-task reinforcement learning remains an open problem. In this paper, we present a model-based approach to deep reinforcement learning which we use to solve different tasks simultaneously. We show that our approach not only does not degrade but actually benefits from learning multiple tasks. For our model, we also present a new kind of recurrent neural network inspired by residual networks that decouples memory from computation, allowing us to model complex environments that do not require lots of memory.


Recently, there has been a lot of success in applying neural networks to reinforcement learning, achieving super-human performance in many ATARI games ([8]; [9]). Most of these algorithms are based on Q-learning, which is a model-free approach to reinforcement learning. These approaches learn which actions to perform in each situation, but do not learn an explicit model of the environment. Apart from that, learning to play multiple games simultaneously remains an open problem, as these approaches degrade heavily when the number of tasks to learn increases.

In contrast, we present a model based approach that can learn multiple tasks simultaneously. The idea of learning predictive models has been previously proposed ([13]; [12]), but all of them focus on learning the predictive models in an unsupervised way. We propose using the reward as a means to learn a representation that captures only that which is important for the game. This also allows us to do the training in a fully supervised way. In the experiments, we show that our approach can surpass human performance simultaneously on three different games. In fact, we show that transfer learning occurs and it benefits from learning multiple tasks simultaneously.

In this paper, we first discuss why Q-learning fails to learn multiple tasks and what its drawbacks are. Then, we present our approach, Predictive Reinforcement Learning, as an alternative to overcome those weaknesses. In order to implement our model, we present a recurrent neural network architecture based on residual nets that is especially well suited for our task. Finally, we discuss our experimental results on several ATARI games.

2 Previous work: Deep Q-learning

In recent years, approaches that use Deep Q-learning have achieved great success, making an important breakthrough when [8] presented a neural network architecture that was able to achieve human performance on many different ATARI games, using just the pixels on the screen as input.

As the name indicates, this approach revolves around the Q-function. Given a state s and an action a, Q(s, a) returns the expected future reward we will get if we perform action a in state s. Formally, the Q-function is defined in Equation 1.

For the rest of this subsection, we assume the reader is already familiar with Deep Q-learning and we discuss its main problems. Otherwise, we recommend skipping to the next section directly as none of the ideas discussed here are necessary to understand our model.

As the true value of the Q-function is not known, the idea of Deep Q-learning is iteratively approximating this function using a neural network (see footnote 1), which introduces several problems.

First, the Q-values depend on the strategy the network is playing. Thus, the target output for the network given a state-action pair is not constant, since it changes as the network learns. This means that apart from learning a strategy, the network also needs to remember which strategy it is playing. This is one of the main problems when learning multiple tasks, as the network needs to remember how it is acting on each of the different tasks. [11] and [10] have managed to successfully learn multiple tasks using Q-learning. Both approaches follow a similar idea: an expert network learns to play a single game, while a multi-tasking network learns to copy the behavior of an expert for each different game. This means that the multi-tasking network does not iteratively approximate the Q-function; it just learns to copy the function that the single-task expert has approximated. That is why their approach works: they manage to avoid the problem of simultaneously approximating all the Q-functions, as this is done by each single-task expert.

Apart from that, the network has to change the strategy very slightly at each update, as drastically changing the strategy would change the Q-values a lot and cause the approximation process to diverge or slow down. This forces the model to interact many times with the environment in order to find good strategies. This is not problematic in simulated environments like ATARI games, where the simulation can easily be sped up using more computing power. Still, in real-world environments such as robotics, this is not the case and data efficiency can be an important issue.

3 Predictive Reinforcement Learning

In order to avoid the drawbacks of Deep Q-learning, we present Predictive Reinforcement Learning (PRL). In our approach, we separate the understanding of the environment from the strategy. This has the advantage of being able to learn from different strategies simultaneously while also being able to play strategies that are completely different from the ones that it learns from. We will also argue that this approach makes generalization easier. But before we present it, we need to define what we want to solve.

3.1 Prediction problem

The problem we want to solve is the following: given the current state of the environment and the actions we will make in the future, how is our score going to change through time?

To formalize this problem we introduce the following notation:

  • a_i: The observation of the environment at time i. In the case of ATARI games, this corresponds to the pixels of the screen.

  • r_i: The total accumulated reward at time i. In the case of ATARI games, this corresponds to the in-game score.

  • c_i: The control that was performed at time i. In the case of ATARI games, this corresponds to the inputs of the ATARI controller: up, right, shoot, etc.

Then, we want to solve the following problem: For a given time i and a positive integer k, let the input to our model be an observation a_i and a set of future controls c_{i+1}, ..., c_{i+k}. Then, we want to predict the change in score for the next k time steps, i.e. r_{i+1} - r_i, ..., r_{i+k} - r_i. Figure 1 illustrates this with an example.

Figure 1: We chose i = 0 and k = 1. We assume a_0 to be the pixels in the current image (the left one) and c_1 to be the jump action. Then, given that input, we want to predict r_{1} - r_0, which is 1, because we earn a reward from time 0 to time 1.
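
As a minimal sketch of how the prediction targets described above could be computed from a recorded episode, consider the following. The function name `score_change_targets` and the toy score values are assumptions made for illustration; they are not taken from the original implementation.

```python
def score_change_targets(rewards, i, k):
    """Return [r_{i+1} - r_i, ..., r_{i+k} - r_i] for accumulated rewards."""
    if i + k >= len(rewards):
        raise ValueError("episode too short for the requested horizon")
    return [rewards[i + j] - rewards[i] for j in range(1, k + 1)]


# Toy episode: accumulated score at times 0..4 (illustrative values only).
accumulated = [0, 1, 1, 1, 3]
targets = score_change_targets(accumulated, i=0, k=3)
```

With i = 0 and k = 1, this reproduces the situation of Figure 1, where a single reward is earned between time 0 and time 1.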

Observe that, unlike in Q-learning, our predictions do not depend on the strategy being played. The outputs only depend on the environment we are trying to predict. So, the output for a given state-actions pair is always the same or, in the case of non-deterministic environments, it comes from the same distribution.


We have defined what we want to solve but we still need to specify how to implement a model that will do it. We will use neural networks for this and we will divide it into three different networks as follows:

  • Perception: This network reads a state and converts it to a lower dimensional vector that is used by the Prediction.

  • Prediction: For each j = 1, ..., k, this network reads the vector from the previous step and the corresponding control c_{i+j} and generates a vector that will be used in the next steps of the Prediction and Valuation. Observe that this is actually a recurrent neural network.

  • Valuation: For each j = 1, ..., k, this network reads the current vector of the Prediction and predicts the difference in score between the initial time and the current one, i.e., r_{i+j} - r_i.

Figure 2 illustrates the model and Figure 3 shows it unfolded in time. Observe that what we actually want to solve is a supervised learning problem. Thus, the whole model can be jointly trained with simple backpropagation. We will now proceed to explain each of the components in more detail.
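
The way the three components fit together can be sketched as follows. This is only an illustration of the data flow: the three networks are passed in as plain callables, and the toy stand-ins at the bottom are hypothetical placeholders, not the networks used in the experiments.

```python
def predict_score_changes(observation, controls, perception, prediction, valuation):
    """Unroll Perception, Prediction and Valuation over future controls."""
    state = perception(observation)        # low dimensional vector
    outputs = []
    for control in controls:               # one recurrent step per future control
        state = prediction(state, control)
        outputs.append(valuation(state))   # estimate for this time step
    return outputs


# Toy stand-ins (hypothetical): the "vector" is a single number.
def toy_perception(observation):
    return float(sum(observation))


def toy_prediction(state, control):
    return state + control


def toy_valuation(state):
    return state > 2.0


result = predict_score_changes([1, 0, 1], [1, -1, 1],
                               toy_perception, toy_prediction, toy_valuation)
```

Because the Perception runs only once per observation, many candidate control sequences can reuse the same initial vector.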

Figure 2: The recurrent model
Figure 3: The same model unfolded in time


The Perception has to be tailored for the kind of observations the environment returns. For now, we will focus only on vision based Perception. As we said before, the idea of this network is to convert the high dimensional input to a low dimensional vector that contains only the necessary information for predicting the score. In the case of video games, it is easy to see that such a vector exists. The input will consist of thousands of pixels but all we care about is the position of a few key objects, for example the main character or the enemies. This information can easily be encoded using very few neurons. In our experiments, we convert an input consisting of pixels into a vector of just real values.

In order to do this, we use deep convolutional networks. These networks have recently achieved super-human performance in very complex image recognition tasks [3]. In fact, it has been observed that the upper layers in these models learn lower dimensional abstract representations of the input ([14], [6]). Given this, it seems reasonable to believe that if we use any of the successful architectures for vision, our model will be able to learn a useful representation that can be used by the Prediction.


For the Prediction network, we present a new kind of recurrent network based on residual neural networks [3], which is especially well suited for our task and achieved better results than an LSTM [4] with a similar number of parameters in our initial tests.

Residual Recurrent Neural Network (RRNN) We define the RRNN in Figure 4 using the following notation: LN is the layer normalization function [1], which normalizes the activations to have a mean of 0 and a standard deviation of 1. The concatenation operator joins two vectors. The update function can be any parameterizable and differentiable function, e.g., a multilayer perceptron.

Figure 4: The equations of the RRNN and a diagram of the network.

As in residual networks, instead of calculating what the new state of the network should be, we calculate how it should change. As shown by [3], this prevents vanishing gradients or optimization difficulties. Layer normalization outputs a vector with mean 0 and standard deviation 1. As we prove in the observation below (see footnote 2), this prevents internal exploding values that may arise from repeatedly adding the change to the state. It also avoids the problem of vanishing gradients in saturating functions like sigmoid or hyperbolic tangent.

Taking into account that the mean is 0 and the standard deviation is 1, simply substituting the values in the formula for the standard deviation shows the observation.
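
The following sketch illustrates the idea of a residual recurrent step together with layer normalization. It rests on an assumption: the exact RRNN equations are given in Figure 4 and are not reproduced here, so the form used below (the change is computed from the normalized state concatenated with the control and then added to the state) is only one plausible reading. The names `layer_norm`, `rrnn_step` and `toy_update` are hypothetical. The sketch also checks a simple consequence of normalization: for a vector of n values with mean 0 and standard deviation 1, the squares sum to n, so no entry can exceed sqrt(n) in absolute value.

```python
import math


def layer_norm(values, eps=1e-8):
    """Normalize a vector to mean 0 and (population) standard deviation 1."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    std = math.sqrt(var + eps)
    return [(v - mean) / std for v in values]


def rrnn_step(state, control, update_fn):
    """One residual recurrent step: add a computed change to the state.

    ASSUMPTION: the update reads the normalized state concatenated with the
    control; the paper's exact equations are those of its Figure 4.
    """
    change = update_fn(layer_norm(state) + list(control))  # list concatenation
    return [s + d for s, d in zip(state, change)]


# Hypothetical update function: a fixed linear map, for illustration only.
def toy_update(inputs):
    return [0.1 * inputs[0], -0.1 * inputs[1], 0.05 * inputs[-1]]


state = [3.0, -1.0, 4.0]
normalized = layer_norm(state)
new_state = rrnn_step(state, control=[1.0], update_fn=toy_update)
```

In this sketch the size of the state (three values) is independent of how complex `toy_update` is, which is the sense in which memory is decoupled from computation.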

The idea behind this network is mimicking how a video game’s logic works. A game has some variables (like positions or speeds of different objects) that are slightly modified at each step. Our intuition is that the network can learn a representation of these variables in its state vector, while the update function learns how they are transformed at each frame. Apart from that, this model decouples memory from computation, allowing us to increase the complexity of the update function without having to increase the number of neurons in the state vector. This is especially useful as the number of real valued neurons needed to represent the state of a game is quite small. Still, the function to move from one frame to the next can be quite complex, as it has to model all the interactions between the objects such as collisions, movements, etc.

Even if this method looks like it may be just tailored for video games, it should work equally well for real world environments. After all, physics simulations that model the real world work in the same way, with some variables that represent the current state of the system and some equations that define how that system evolves over time.


The Valuation network reads the vector at time i + j and outputs the change in reward for that time step, i.e. r_{i+j} - r_i. Still, it is a key part of our model as it allows us to decouple the representation learned by the Prediction from the reward function. For example, consider a robot in a real world environment. If the Perception learns to capture the physical properties of all surrounding objects (shape, mass, speed, etc.) and the Prediction learns to make a physical simulation of the environment, this model can be used for any possible task in that environment; only the Valuation would need to be changed.


As we previously said, finding an optimal strategy is a very hard problem and this part is the most complicated. So, in order to test our model in the experiments, we opted for hard-coding a strategy. There, we generate a set of future controls uniformly at random and then we pick the one that would maximize our reward, given that the probability of dying is low enough. Because of this, the games we have tried have been carefully selected such that they do not need very sophisticated and long-term strategies.

Still, our approach learns a predictive model that is independent of any strategy and this can be beneficial in two ways. First, the model can play a strategy that is completely different from the ones it learns from. Apart from that, a predictive model is very hard to over-fit. Consider a game with n possible control inputs and a training set where we consider the next k time steps. Then, there are n^k possible control sequences. This means that, in practice, every sequence we train on is unique and this forces the model to generalize. Unfortunately, there is also a downside. Our approach is not able to learn only from good strategies because we test our model with many different ones in order to pick the best. Some of these strategies will be quite bad and thus, the model needs to learn what makes the difference between a good and a bad set of moves.



Our experiments have been performed on a computer with a GeForce GTX 980 GPU and an Intel Xeon E5-2630 CPU. For the neural network, we have used the Torch7 framework and for the ATARI simulations, we have used Alewrap, which is a Lua wrapper for the Arcade Learning Environment [2].


Figure 5: Each layer is followed by a Batch Normalization and a Rectifier Linear Unit.

For the Perception, we used a network inspired by deep residual networks [3]. Figure 5 shows the architecture. The reason for this is that even if the Perception is relatively shallow, when unfolding the Prediction network over time, the depth of the resulting model is over layers deep.

For the Prediction, we use a Residual Recurrent Neural Network. Table 1 describes the network used for its update function. Finally, Table 2 illustrates the Valuation network.

Table 1: Network used for the update function of the Prediction.

Layer            Input      Output
ReLU + Linear    100 + 3    500
ReLU + Linear    500        100

Table 2: Valuation network. We apply Layer Normalization to bound the incoming values to the network.

Layer               Input    Output
LN                  100      100
Linear + ReLU       100      100
Linear + Sigmoid    100      2


In our experiments, we have trained on three different ATARI games simultaneously: Breakout, Pong and Demon Attack.

We preprocess the images following the same technique as [8]. We take the maximum from the last 2 frames to get a single black and white image for the current observation. The input to the Perception is a tensor containing the last observations. This is necessary to be able to use a feed-forward network for the Perception. If we observed a single frame, it would not be possible to infer the speed and direction of a moving object. Not doing this would force us to use a recurrent network in the Perception, making the training of the whole model much slower.
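
A minimal sketch of this preprocessing, using plain nested lists as images, could look as follows. The helper names, the stack size of 3, and the choice to pad the stack with copies of the first observation at the start of an episode are assumptions made for illustration.

```python
from collections import deque


def max_of_frames(frame_a, frame_b):
    """Pixel-wise maximum of two equally sized grayscale frames."""
    return [[max(p, q) for p, q in zip(row_a, row_b)]
            for row_a, row_b in zip(frame_a, frame_b)]


class ObservationStack:
    """Keep the most recent observations as input for a feed-forward Perception."""

    def __init__(self, size):
        self.frames = deque(maxlen=size)

    def push(self, observation):
        self.frames.append(observation)
        # ASSUMPTION: pad with copies of the oldest observation until full.
        while len(self.frames) < self.frames.maxlen:
            self.frames.appendleft(self.frames[0])
        return list(self.frames)


stack = ObservationStack(size=3)  # the stack size is an assumed value
obs = max_of_frames([[0, 5], [2, 0]], [[1, 3], [0, 0]])
stacked = stack.push(obs)
```

Stacking several observations is what lets a feed-forward network see motion: the difference between consecutive entries encodes speed and direction.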

In order to train the Prediction, we unfold the network over time (25 time steps) and treat the model as a feed-forward network with shared weights.

For our Valuation network, we output two values. First, the probability that our score is higher than in the initial time step. Second, we output the probability of dying. This is trained using cross entropy loss.
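
For reference, the cross entropy loss for two probability outputs can be written as below. This is a generic sketch of binary cross entropy, not the code used in the experiments; averaging over the two outputs and clipping the probabilities with `eps` are assumptions.

```python
import math


def binary_cross_entropy(predicted, target, eps=1e-12):
    """Mean binary cross-entropy between probabilities and 0/1 targets."""
    total = 0.0
    for p, y in zip(predicted, target):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(predicted)


# Two predicted probabilities against their binary labels (illustrative values).
loss = binary_cross_entropy([0.9, 0.2], [1, 0])
```

The loss approaches zero as the predicted probabilities approach the labels and grows without bound as they approach the opposite label.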

To train the model, we use an off-line learning approach for simplicity. During training we alternate between two steps. First, generate and store data and then, train the model off-line on that data.

4.4 Generating data

In order to generate the data, we store tuples as we are playing the game. That is, for each time i, we store the following:

  • : A tensor, containing consecutive black and white frames of size each.

  • c_{i+1}, ..., c_{i+k}: For j = 1, ..., k, each c_{i+j} is a 3-dimensional vector that encodes the control action performed at time i + j. The first dimension corresponds to the shoot action, the second to horizontal actions and the third to vertical actions. For example, represent pressing shoot and left.

  • For j = 1, ..., k, we store a 2-dimensional binary vector. Its first component is 1 if we die between time i and i + j. Its second component is 1 if we have not lost a life and we also earn a point between time i and i + j.
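
The binary targets in the last item can be derived from a recorded episode roughly as follows. This is a sketch under assumptions: it supposes that the accumulated score and the number of remaining lives are logged at every time step, and the names `binary_labels`, `scores` and `lives` are hypothetical.

```python
def binary_labels(scores, lives, i, j):
    """Labels for the window between time i and time i + j.

    died:   1 if a life was lost in the window.
    scored: 1 if no life was lost and the score increased in the window.
    """
    died = int(lives[i + j] < lives[i])
    scored = int(not died and scores[i + j] > scores[i])
    return died, scored


# Toy episode (illustrative values): a point at time 2, a lost life at time 3.
scores = [0, 0, 1, 1]
lives = [3, 3, 3, 2]
labels = [binary_labels(scores, lives, 0, j) for j in (1, 2, 3)]
```

Note that once a life is lost inside the window, the second label is 0 even if a point was earned earlier in that window, matching the definition above.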

Initially, we have an untrained model, so at each time step, we pick an action uniformly at random and perform it. For the next iterations, we pick a and do the following to play the game:

  1. Run the Perception network on the last frames to obtain the initial vector.

  2. Generate sequences of actions uniformly at random. Apart from that, take the best sequence from the previous time step and also consider it. This gives a total of sequences. Then, for each sequence, run the Prediction and Valuation networks with the vector obtained in Step 1.

  3. Finally, pick a sequence of actions as follows. Consider only the moves that have a low enough probability of dying. From those, pick the one that has the highest probability of earning a point. If none has a high enough probability, just pick the one with the lowest probability of dying.
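
The selection procedure in the steps above can be sketched as follows. The thresholds `max_death_prob` and `min_point_prob`, the function names, and the reading of the fallback rule (use the safest sequence whenever no safe sequence is promising enough) are assumptions for illustration; the actual thresholds are described in Appendix A.

```python
import random


def random_sequences(num_sequences, horizon, controls, rng=random):
    """Sample control sequences uniformly at random."""
    return [[rng.choice(controls) for _ in range(horizon)]
            for _ in range(num_sequences)]


def pick_sequence(candidates, max_death_prob, min_point_prob):
    """Choose among (sequence, p_point, p_death) triples.

    Keep candidates that are safe enough and take the one most likely to
    score; if that is not promising enough, fall back to the safest one.
    """
    safe = [c for c in candidates if c[2] <= max_death_prob]
    if safe:
        best = max(safe, key=lambda c: c[1])
        if best[1] >= min_point_prob:
            return best[0]
    return min(candidates, key=lambda c: c[2])[0]


# Hypothetical model outputs for three candidate sequences.
candidates = [
    (["left"], 0.7, 0.50),
    (["right"], 0.4, 0.05),
    (["shoot"], 0.1, 0.01),
]
chosen = pick_sequence(candidates, max_death_prob=0.1, min_point_prob=0.3)
```

In the game loop, the best sequence from the previous time step would simply be appended to the freshly sampled candidates before calling `pick_sequence`.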

We start with and increase it every few iterations up to . For the full details check Appendix A. In order to accelerate training, we run several games in parallel. This allows us to run the Perception, Prediction and Valuation networks together with the ATARI simulation in parallel, which heavily speeds up the generation of data without any drawback.


In the beginning, we generate training cases for each of the games by playing randomly, which gives us a total of training cases. Then, for the subsequent iterations, we generate additional training cases per game ( in total) and train again on the whole dataset. That is, at first we have training cases, afterwards , then and so on.

The training is done in a supervised way as depicted in Figure 3. The observation and the controls are given as input to the network and the stored binary vectors as target. We minimize the cross-entropy loss using mini-batch gradient descent. For the full details on the learning schedule check Appendix A.

In order to accelerate the process, instead of training a new network in each iteration, we keep training the model from the previous iteration. This has the effect that we would train much more on the initial training cases while the most recent ones would have an ever smaller effect as the training set grows. To avoid this, we assign a weight to each iteration and sample according to these weights during training. Every three iterations, we multiply by three the weights we assign to them. By doing this, we manage to focus on recent training cases, while still preserving the whole training set.
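
One way to read this weighting scheme is sketched below: the weight of an iteration is multiplied by three every three iterations, and training cases are drawn by first sampling an iteration in proportion to its weight. The function names and the exact form of the schedule are assumptions; the text does not specify more than the factor and the period.

```python
import random


def iteration_weights(num_iterations, period=3, factor=3.0):
    """Weight per iteration: multiplied by `factor` every `period` iterations."""
    return [factor ** (t // period) for t in range(num_iterations)]


def sample_iteration(weights, rng=random):
    """Pick the index of the iteration to draw the next training case from."""
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


weights = iteration_weights(7)
picked = sample_iteration(weights, random.Random(0))
```

With seven iterations, the most recent one carries nine times the weight of each of the first three, so recent data dominates without older data ever being dropped.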

Observe that we never tell our network which game it is playing, but it learns to infer it from the observation a_i. Also, at each iteration, we add cases that are generated using a different neural network. So our training set contains instances generated using many different strategies.


We have trained a model on the three games for a total of iterations, which correspond to time steps per game ( hours of play at Hz). Each iteration takes around two hours on our hardware. We have also trained an individual model for each game for time steps. In the individual models, we reduced the length of the training such that the number of parameter updates per game is the same as in the multi-task case. Unless some kind of transfer learning occurs, one would expect some degradation in performance in the multi-task model. Figures 6, 7 and 8 show that not only is there no degradation in Pong and Demon Attack, but there is also a considerable improvement in Breakout. This supports our initial belief that our approach is especially well suited for multi-task learning.

Figure 6: Breakout
Figure 7: Pong
Figure 8: Demon Attack

We have also argued that our model can potentially play a very different strategy from the one it has observed. Table 3 shows that this is actually the case. A model that has learned only from random play is able to play at least times better.

Demon Attack’s plot in Figure 8 shows a potential problem we mentioned earlier which also happens in the other two games to a lesser extent. Once the strategy is good enough, the agent dies very rarely. This causes the model to “forget” which actions lead to a death and makes the score oscillate.

Table 3: After one iteration Predictive Reinforcement Learning (PRL) has only observed random play but it can play much better. This means that it is able to generalize well to many situations it has not observed during training.

                   Pong     Breakout    Demon Attack
Human score        9.3      31.8        3401
Random             -20.7    1.7         152
PRL Iteration 2    -8.62    13.2        1220


We have presented a novel model based approach to deep reinforcement learning that opens new lines of research in this area. We have shown that it can beat human performance in three different tasks simultaneously and that it can benefit from learning multiple tasks.

Still, the model has two areas that can be addressed in future work: long-term dependencies and the instability during training. The first can potentially be solved by combining our approach with Q-learning based techniques. For the instability, balancing the training set or oversampling hard training cases could alleviate the problem.

Finally, we have also presented a new kind of recurrent network which can be very useful for problems where little memory and a lot of computation are needed.


I thank Angelika Steger and Florian Meier for their hardware support in the final experiments and comments on previous versions of the paper.

A Implementation details

Due to the huge cost involved in training the agents, we have not exhaustively searched over all the possible hyper parameters. Still, we present them here for reproducibility of the results.

  • Number of strategies: As explained in Section 4.4, we need to pick a number of strategies we consider at each step. Initially, we pick , raise it to at iteration and finally, at iteration , we set it to for the remainder of the experiment.

  • Confidence interval: We also need to pick how safe we want to play, i.e., where we set the threshold for the set of actions we consider. For simplicity, in Breakout and Pong, we set it to and only pick the safest option. In Demon Attack, initially we only consider actions with a survival probability higher than for three iterations. After that, we reduce it to for another three iterations. Then, we set it to until iteration and finally, reduce it to for the rest of the iterations.

  • Learning schedule: For training we use the Adam [7] optimizer with a batch size of . We use a learning rate of for the first iterations, then reduce it to for the next iterations and finally set it to for the rest of the experiment. We make a total of parameter updates per iteration ( in the case of single-task networks) and divide the learning rate in half after updates for the remainder of the iteration. We add a weight decay of and clamp the gradients element-wise to the range.

Apart from that, at the beginning of each episode, we pick an uniformly at random and do not perform any action for the initial time steps of that episode. This idea was also used by [8] to avoid any possible over-fitting. In addition, we also press shoot to start a new episode every time we die in Breakout, since in the first iterations the model learns that the safest option is not to start a new episode. This causes the agent to waste a lot of time without starting a new episode.


  1. We do not explain the process, but [8] give a good explanation on how this is done.
  2. The bound is not tight but it is sufficient for our purposes and straightforward to prove.


  1. Layer Normalization.
    Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. arXiv
  2. The arcade learning environment: An evaluation platform for general agents.
    Marc G. Bellemare, Yavar Naddaf, Joel Veness, and Michael Bowling. In IJCAI International Joint Conference on Artificial Intelligence, volume 2015-January, pp. 4148–4152, 2015.
  3. Deep Residual Learning for Image Recognition.
    Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. arXiv
  4. Long Short-Term Memory.
    Sepp Hochreiter and Jürgen Schmidhuber. Neural computation
  5. Batch normalization: Accelerating deep network training by reducing internal covariate shift.
    Sergey Ioffe and Christian Szegedy. arXiv
  6. Deep visual-semantic alignments for generating image descriptions.
    Andrej Karpathy and Fei Fei Li. In Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, volume 07-12-June-2015, pp. 3128–3137, 2015.
  7. Adam: A method for stochastic optimization.
    Diederik Kingma and Jimmy Ba. arXiv
  8. Human-level control through deep reinforcement learning.
    Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A Rusu, Joel Veness, Marc G Bellemare, Alex Graves, Martin Riedmiller, Andreas K Fidjeland, Georg Ostrovski, Stig Petersen, Charles Beattie, Amir Sadik, Ioannis Antonoglou, Helen King, Dharshan Kumaran, Daan Wierstra, Shane Legg, and Demis Hassabis. Nature
  9. Asynchronous Methods for Deep Reinforcement Learning.
    Volodymyr Mnih, Adrià Puigdomènech Badia, Mehdi Mirza, Alex Graves, Timothy P Lillicrap, Tim Harley, David Silver, and Koray Kavukcuoglu. arXiv
  10. Actor-Mimic: Deep Multitask and Transfer Reinforcement Learning.
    Emilio Parisotto, Jimmy Lei Ba, and Ruslan Salakhutdinov. arXiv
  11. Policy Distillation.
    Andrei A Rusu, Sergio Gomez Colmenarejo, Caglar Gulcehre, Guillaume Desjardins, James Kirkpatrick, Razvan Pascanu, Volodymyr Mnih, Koray Kavukcuoglu, and Raia Hadsell. arXiv
  12. Learning a Driving Simulator.
    Eder Santana and George Hotz. arXiv
  13. On Learning to Think: Algorithmic Information Theory for Novel Combinations of Reinforcement Learning Controllers and Recurrent Neural World Models.
    Jürgen Schmidhuber. arXiv
  14. Understanding Neural Networks Through Deep Visualization.
    Jason Yosinski, Jeff Clune, Anh Nguyen, Thomas Fuchs, and Hod Lipson. arXiv