Multi-task Batch Reinforcement Learning with Metric Learning

Abstract

We tackle the Multi-task Batch Reinforcement Learning problem. Given multiple datasets collected from different tasks, we train a multi-task policy to perform well in unseen tasks sampled from the same distribution. The task identities of the unseen tasks are not provided. To perform well, the policy must infer the task identity from collected transitions by modelling its dependency on states, actions and rewards. Because the different datasets may have state-action distributions with large divergence, the task inference module can learn to ignore the rewards and spuriously correlate only state-action pairs to the task identity, leading to poor test-time performance. To robustify task inference, we propose a novel application of the triplet loss. To mine hard negative examples, we relabel the transitions from the training tasks by approximating their reward functions. When we allow further training on the unseen tasks, using the trained policy as an initialization leads to significantly faster convergence compared to randomly initialized policies across 5 different MuJoCo task distributions. We name our method MBML (Multi-task Batch RL with Metric Learning).

1 Introduction

Combining neural networks (NNs) with reinforcement learning (RL) has led to many recent advances. Since training NNs requires diverse datasets and collecting real-world data is expensive, most RL successes are limited to scenarios where the data can be cheaply generated in a simulation. Since offline data is essentially free for many applications, RL methods should use it whenever possible. This is especially true because practical deployments of RL are bottlenecked by its poor sample efficiency. Algorithms that work with offline datasets are crucial for applying RL to real-world scenarios, motivating a flurry of recent works in Batch RL Siegel et al. (2020); Agarwal et al. (2019); Kumar et al. (2019); Fujimoto et al. (2019); Chen et al. (2019). These works introduce specialized algorithms to stabilize training from offline datasets. However, offline datasets are not necessarily diverse. In this work, we focus on collecting diverse datasets and on how the properties of the datasets influence the policy search procedure. By collecting diverse offline datasets, we hope the networks will generalize without further training to unseen tasks or provide a good initialization that speeds up convergence when we perform further on-policy training.

To collect diverse datasets, a natural approach is to collect data from different tasks. However, datasets collected from different tasks may have state-action distributions with large divergence. Such dataset bias presents a unique challenge for robust task inference. We provide a brief description of the problem setting, the challenge and our contributions below. For ease of exposition, we hereafter refer to such datasets as having little overlap in their state-action visitation frequencies.

We tackle the Multi-task Batch RL problem. We train a policy from multiple datasets, each generated by interaction with a different task. We measure the performance of the trained policy on unseen tasks sampled from the same task distribution as the training tasks. To perform well, the policy must first infer the identity of the unseen tasks from collected transitions and then take the appropriate actions to maximize returns. To train the policy to infer the task identity, we can train it to distinguish between the different training tasks when given transitions from the tasks as input. These transitions are referred to as the context set Rakelly et al. (2019). Ideally, the policy should model the dependency of the task identity on both the rewards and the state-action pairs in the context set. To achieve this, we can train a task identification network that maps the collected experiences, including both state-action pairs and rewards, to the task identity or some task embedding. This approach, however, tends to fail in practice. Since the training context sets do not overlap significantly in state-action visitation frequencies, the learning procedure can minimize the loss function for task identification by correlating only the state-action pairs with the task identity and ignoring the rewards, which causes mistakes in identifying the testing tasks. This is an instance of the well-known phenomenon of ML algorithms cheating when given the chance Chu et al. (2017) and is further illustrated in Fig. 1. We limit our explanations to the cases where the tasks differ in reward functions. Extending our approach to task distributions with different transition functions is straightforward. We provide experimental results for both cases.

Our contributions are as follows. To the best of our knowledge, we are the first to highlight the issue of the task inference module learning the wrong correlation from biased datasets. We propose a novel application of the triplet loss to robustify task inference. To mine hard negative examples, we approximate the reward function of each task and relabel the rewards in the transitions from the other tasks. When we train the policy to differentiate between the original and relabelled transitions, we force it to consider the rewards, since their state-action pairs are the same. Training with the triplet loss generalizes better to unseen tasks compared to alternatives. When we allow further training on the unseen tasks, using the policy trained from the offline datasets as initialization significantly increases convergence speed and sample efficiency.

To the best of our knowledge, the most relevant related work is Siegel et al. (2020), which solves a different problem from ours. They assume access to the ground truth task identity and reward function of the testing task. Our policy does not know the testing task’s identity and must infer it from collected trajectories. We also do not have access to the reward function of the testing tasks.

Figure 1: A toy example to illustrate the challenge. The agent must navigate from the origin to a goal location. Left: Goal 1 and Goal 2 denote the two training tasks. The red and blue squares indicate the transitions collected from tasks 1 and 2 respectively. We can train the task inference module to infer the task identity to be 1 when the context set contains the red transitions and 2 when the context set contains the blue transitions. Since there is no overlap between the red and blue squares, the task inference module learns to correlate the state-action pairs to the task identity. Right: The failure of the task inference module. The policy must infer the task identity from the randomly collected transitions, denoted by the green squares. The agent needs to navigate to goal 1 during testing. However, if the green squares have more overlap with the blue squares, the task inference module will predict 2 to be the task identity. The agent therefore navigates to the wrong goal location.

2 Preliminaries and Problem Statement

To help the reader follow our explanation, we include a symbol definition table in Appendix A.

We model a task as a Markov Decision Process $\mathcal{M} = (\mathcal{S}, \mathcal{A}, P, \rho_0, R, H)$, with state space $\mathcal{S}$, action space $\mathcal{A}$, transition function $P$, initial state distribution $\rho_0$, reward function $R$, and horizon $H$. At each discrete timestep $t$, the agent is in a state $s_t$, picks an action $a_t$, arrives at $s_{t+1} \sim P(\cdot \mid s_t, a_t)$, and receives a reward $r_t = R(s_t, a_t)$. The performance measure of a policy $\pi$ is the expected sum of rewards $\mathbb{E}_{\tau \sim (\pi, \mathcal{M})}\big[\sum_{t=0}^{H} r_t\big]$, where $\tau$ is a trajectory generated by using $\pi$ to interact with $\mathcal{M}$.

2.1 Batch Reinforcement Learning

A Batch RL algorithm solves the task using an existing batch of transitions $\mathcal{B} = \{(s_t, a_t, r_t, s_t')\}$. A recent advance in this area is Batch Constrained Q-Learning (BCQ) Fujimoto et al. (2019). Here, we explain how BCQ selects actions. Given a state $s$, a generator $G$ outputs multiple candidate actions $a_k = G(s, \epsilon_k)$, where $\epsilon_k$ is random noise. A perturbation model $\xi$ takes as input each state-candidate action pair and generates a small correction $\xi(s, a_k)$. The corrected action with the highest value estimated by the learned value function $Q$ is selected as the action to execute:

$a^* = \operatorname{arg\,max}_{a_k + \xi(s, a_k)} \; Q\big(s, \, a_k + \xi(s, a_k)\big), \qquad a_k = G(s, \epsilon_k), \ \ \epsilon_k \sim \mathcal{N}(0, I), \ \ k = 1, \dots, N \qquad (1)$

To help the reader follow our discussion, we illustrate graphically how BCQ selects actions in Appendix B. In our paper, we use BCQ as a subroutine. The take-away is that BCQ takes as input a batch of transitions and outputs three learned functions: a value function $Q$, a candidate action generator $G$, and a perturbation model $\xi$.
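To make the selection rule concrete, the sketch below implements it in PyTorch. It assumes already-trained callables `Q`, `G`, and `xi` with the interfaces described above; the function name, the `noise_dim` argument, and the batching details are illustrative rather than taken from the official BCQ implementation.

```python
# Minimal sketch of BCQ-style action selection (Eq. 1); names are illustrative.
import torch

def bcq_select_action(state, Q, G, xi, num_candidates=10, noise_dim=32):
    """Return the perturbed candidate action with the highest estimated value."""
    with torch.no_grad():
        s = state.unsqueeze(0).expand(num_candidates, -1)    # repeat the state per candidate
        eps = torch.randn(num_candidates, noise_dim)         # decoder noise
        candidates = G(s, eps)                               # candidate actions
        corrected = candidates + xi(s, candidates)           # add small corrections
        values = Q(s, corrected).squeeze(-1)                 # value estimates per candidate
        return corrected[values.argmax()]                    # highest-value corrected action
```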

2.2 Multi-task Batch Reinforcement Learning

Given $K$ batches $\{\mathcal{B}_i\}_{i=1}^{K}$, each containing transition tuples from one task $\mathcal{M}_i \sim p(\mathcal{M})$, we define the Multi-task Batch RL problem as:

$\max_{\pi} \; \mathbb{E}_{\mathcal{M}_i \sim p(\mathcal{M})}\Big[\, \mathbb{E}_{\tau \sim (\pi, \mathcal{M}_i)}\Big[\, \textstyle\sum_{t=0}^{H} r_t \,\Big]\Big] \qquad (2)$

where the algorithm only has access to the batches, $p(\mathcal{M})$ defines a task distribution, and the subscript $i$ indexes the different tasks. The tasks have the same state and action space and only differ in the transition and reward functions Zintgraf et al. (2020). We measure performance by computing the average returns over unseen tasks sampled from the same task distribution. The policy is not given the identity of the unseen tasks before evaluation and must infer it from collected transitions.

In multi-task RL, we can use a task inference module to infer the task identity from a context set. The context set for task $i$ consists of transitions from task $i$ and is denoted $c^i$. The task inference module $q_\phi$ takes $c^i$ as input and outputs a posterior over the task identity. We sample a task identity $z$ from the posterior and input it to the policy in addition to the state, i.e. $\pi(a \mid s, z)$. We model $q_\phi$ with the probabilistic and permutation-invariant architecture from Rakelly et al. (2019), which outputs the parameters of a diagonal Gaussian. For conciseness, we sometimes use the term policy to also refer to the task inference module; it should be clear from the context whether we are referring to $\pi$ or $q_\phi$.
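The sketch below shows one way to realize such a permutation-invariant probabilistic encoder. It mean-pools per-transition embeddings, which is one simple permutation-invariant choice; the class name and layer sizes are illustrative assumptions, not the exact architecture from Rakelly et al. (2019).

```python
# Sketch of a permutation-invariant task encoder producing a diagonal Gaussian over z.
import torch
import torch.nn as nn

class TaskEncoder(nn.Module):
    def __init__(self, transition_dim, latent_dim=20, hidden=200):
        super().__init__()
        self.latent_dim = latent_dim
        self.net = nn.Sequential(
            nn.Linear(transition_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * latent_dim),        # mean and log-std of z
        )

    def forward(self, context):                        # context: [num_transitions, transition_dim]
        pooled = self.net(context).mean(dim=0)         # mean pooling -> order invariance
        mu, log_std = pooled.split(self.latent_dim)
        return torch.distributions.Normal(mu, log_std.exp())
```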

We assume that each batch $\mathcal{B}_i$ contains the data generated by a policy while it learned to solve task $\mathcal{M}_i$. Thus, if solving each task involves visiting a different subspace of the state space, the different batches do not have significant overlap in their state-action visitation frequencies. This is illustrated in Fig. 1.

3 Proposed algorithm

3.1 Learning multi-task policy from offline data with distillation

In Multi-task RL, Rusu et al. (2015); Teh et al. (2017); Ghosh et al. (2017); Czarnecki et al. (2019); Parisotto et al. (2015) demonstrate the success of distilling multiple single-task policies into a multi-task policy. Inspired by these works, we propose a distillation procedure to obtain a multi-task policy in the Multi-task Batch RL setting. In Sec. 3.2, we argue such a distillation procedure alone is insufficient due to the constraints the batch setting imposes on the policy search procedure.

The distillation procedure has two phases. In the first phase, we use BCQ to learn a different policy for each task, i.e., we learn $K$ different and independent policies. While we can use any Batch RL algorithm in the first phase, we use BCQ due to its simplicity. As described in Sec. 2.1, for each training batch $\mathcal{B}_i$, BCQ learns three functions: a state-action value function $Q_i$, a candidate action generator $G_i$ and a perturbation generator $\xi_i$. The output of the first phase thus consists of three sets of networks $\{Q_i\}_{i=1}^{K}$, $\{G_i\}_{i=1}^{K}$, and $\{\xi_i\}_{i=1}^{K}$, where $i$ indexes over the training tasks.

In the second phase, we distill each set into a single network by incorporating a task inference module. The distilled function should recover a different task-specific function depending on the inferred task identity. To distill the value functions into a single function $Q_\theta$, for each task $i$, we sample a context set $c^i$ and a pair $(s, a)$ from the batch $\mathcal{B}_i$. The task inference module takes $c^i$ as input and infers a task identity $z$. Given $(s, a, z)$ as input, $Q_\theta$ should assign a value similar to the one assigned by the value function $Q_i$ for task $i$. The loss function with a $\beta$-weighted KL regularization term Rakelly et al. (2019) is:

$\mathcal{L}_Q^i = \mathbb{E}_{(s, a) \sim \mathcal{B}_i, \; z \sim q_\phi(\cdot \mid c^i)}\Big[\big(Q_\theta(s, a, z) - Q_i(s, a)\big)^2\Big] + \beta \, \mathrm{KL}\big(q_\phi(\cdot \mid c^i) \,\big\|\, \mathcal{N}(0, I)\big) \qquad (3)$

We also use Eq. 3 to train the task inference module $q_\phi$ using the reparameterization trick Kingma and Welling (2013). Similarly, we distill the candidate action generators into $G_\theta$. $G_\theta$ takes as input a state $s$, random noise $\epsilon$ and a task identity $z$. Depending on the value of $z$, we train $G_\theta$ to regress towards a different single-task candidate action generator:

$\mathcal{L}_G^i = \mathbb{E}_{s \sim \mathcal{B}_i, \; \epsilon \sim \mathcal{N}(0, I)}\Big[\big(G_\theta(s, \epsilon, \bar{z}) - G_i(s, \epsilon)\big)^2\Big] \qquad (4)$

The bar on top of $z$ in Eq. 4 indicates the stop-gradient operation. We thus do not use the gradient of Eq. 4 to train the task inference module Rakelly et al. (2019). Lastly, we distill the perturbation generators into a single network $\xi_\theta$ (Eq. 5). $\xi_\theta$ takes as input a state $s$, a candidate action $\hat{a}$, and an inferred task identity $z$. We train $\xi_\theta$ to regress towards the output of $\xi_i$ given the same state and candidate action as input. We obtain the candidate action by passing the state through the distilled candidate action generator $G_\theta$.

$\mathcal{L}_\xi^i = \mathbb{E}_{s \sim \mathcal{B}_i, \; \epsilon \sim \mathcal{N}(0, I)}\Big[\big(\xi_\theta(s, \hat{a}, z) - \xi_i(s, \hat{a})\big)^2\Big], \qquad \hat{a} = G_\theta(s, \epsilon, \bar{z}) \qquad (5)$

Note that the gradient of $\mathcal{L}_\xi^i$ also updates $G_\theta$. The final distillation loss is given in Eq. 6. We parameterize $Q_\theta$, $G_\theta$ and $\xi_\theta$ with feedforward NNs as detailed in Appendix C.1.

$\mathcal{L}_{\mathrm{distill}} = \frac{1}{K} \sum_{i=1}^{K} \big(\mathcal{L}_Q^i + \mathcal{L}_G^i + \mathcal{L}_\xi^i\big) \qquad (6)$
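The sketch below shows one plausible way to compute the per-task distillation losses of Eqs. 3-5 and combine them as in Eq. 6, assuming the encoder sketched in Sec. 2.2, distilled networks that accept the task identity as an extra input, and frozen single-task BCQ networks. All names, the `noise_dim` argument, and the exact placement of the stop-gradient are assumptions for illustration, not the paper's exact implementation.

```python
# Sketch of the per-task distillation losses for one task i.
import torch
import torch.nn.functional as F

def distillation_losses(q_phi, Q_theta, G_theta, xi_theta, Q_i, G_i, xi_i,
                        s, a, context, noise_dim=32, beta=0.1):
    posterior = q_phi(context)
    z = posterior.rsample()                               # reparameterized task identity
    prior = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
    kl = torch.distributions.kl_divergence(posterior, prior).sum()

    zs = z.unsqueeze(0).expand(s.shape[0], -1)            # broadcast z over the batch
    z_sg = zs.detach()                                     # stop-gradient copy of z

    # Eq. 3: value function distillation plus beta-weighted KL regularization.
    loss_q = F.mse_loss(Q_theta(s, a, zs), Q_i(s, a).detach()) + beta * kl

    eps = torch.randn(s.shape[0], noise_dim)
    # Eq. 4: candidate action generator distillation (stop gradient on z).
    loss_g = F.mse_loss(G_theta(s, eps, z_sg), G_i(s, eps).detach())

    # Eq. 5: perturbation generator distillation on distilled candidate actions,
    # so its gradient also flows into G_theta.
    cand = G_theta(s, eps, z_sg)
    loss_xi = F.mse_loss(xi_theta(s, cand, z_sg), xi_i(s, cand.detach()).detach())

    return loss_q + loss_g + loss_xi                       # averaged over tasks in Eq. 6
```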

3.2 Robust task inference with triplet loss design

Figure 2: Top: Value function distillation loss (Eq. 3) during training. Bottom: The performance of the multi-task policy trained with Eq. 6 versus BCQ.

Given the high performance of distillation in Multi-task RL Rusu et al. (2015); Teh et al. (2017); Ghosh et al. (2017); Czarnecki et al. (2019); Parisotto et al. (2015), it is surprising that it performs poorly in Multi-task Batch RL, even on the training tasks. This is even more surprising because we can minimize the distillation losses (Fig. 2 top) and the single-task BCQ policies have high performance (Fig. 2 bottom). If the single-task policies perform well and we can distill them into a multi-task policy, why does the multi-task policy perform poorly? We argue the task inference module has learnt to model the posterior over the task identity as conditionally dependent on only the state-action pairs in the context set, i.e. $p(z \mid S, A)$, where $S$ and $A$ are random variables denoting states and actions, rather than the correct dependency $p(z \mid S, A, R)$, where $R$ denotes the rewards.

The behavior of the trained multi-task policy supports this argument. In this experiment, each task corresponds to a running direction. To maximize returns, the policy should run with maximal velocity in the target direction. We found that the multi-task policy often runs in the wrong target direction, indicating incorrect task inference. At the beginning of evaluation, the task identity is not provided. The policy takes random actions, after which it uses the collected transitions to infer the task identity. Having learnt the wrong conditional dependency, the task inference module assigns high probability mass in the posterior to the region of the task embedding space whose training batches overlap most with the collected transitions (Fig. 1).

The fundamental reason behind the wrong dependency is the non-overlapping nature of the training batches. Minimizing the distillation loss does not require the policy to learn the correct but more complex dependency. The multi-task policy should imitate a different single-task policy depending on which batch the context set was sampled from. If the batches do not overlap in state-action visitation frequencies, the multi-task policy can simply correlate the state-action pairs in the context with the single-task policy it should imitate. In short, if minimizing the training objective on the given datasets does not require the policy to model the dependency of the task identity on the rewards in the context set, there is no guarantee the policy will model this dependency. This is not surprising given the literature on the non-identifiability of causality from observations Pearl (2009b); Peters et al. (2017). This literature also emphasizes the benefit of using distribution change as a training signal to learn the correct causal relationship Bengio et al. (2020).

Inspired by this literature, we introduce a distribution change into our dataset by approximating the reward function of each task $i$ with a learned function $\hat{R}_i$ (training illustrated in Appendix D). Given a context set $c^j$ from task $j$, we relabel the reward of each transition in $c^j$ using $\hat{R}_i$. Let $t$ index the transitions and $c^j_{\hat{R}_i}$ denote the set of the relabelled transitions; we illustrate this process below:

$c^j = \big\{(s_t, a_t, r_t, s_t')\big\}_{t=1}^{T} \;\longrightarrow\; c^j_{\hat{R}_i} = \big\{(s_t, a_t, \hat{R}_i(s_t, a_t), s_t')\big\}_{t=1}^{T} \qquad (7)$
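A minimal sketch of this relabelling step, assuming transitions are stored as `(state, action, reward, next_state)` tuples and the learned reward model is a callable; both are illustrative assumptions.

```python
# Sketch of Eq. 7: replace only the reward of each transition in a context set
# from task j with the prediction of task i's learned reward model.
def relabel(context_j, R_hat_i):
    relabelled = []
    for (s, a, r, next_s) in context_j:
        relabelled.append((s, a, R_hat_i(s, a), next_s))
    return relabelled
```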

Given the relabelled transitions, we leverage the triplet loss from the metric learning community Hermans et al. (2017) to enforce robust task inference, which is the most important design choice in MBML. Let $K$ be the number of training tasks, $c^i$ be a context set for task $i$, $c^j$ be a context set for task $j$ ($j \neq i$), and $c^j_{\hat{R}_i}$ be the relabelled set as described above; the triplet loss for task $i$ is:

$\mathcal{L}_{\mathrm{triplet}}^i = \frac{1}{K-1} \sum_{j=1, j \neq i}^{K} \Big[\, D\big(q_\phi(\cdot \mid c^j_{\hat{R}_i}), \; q_\phi(\cdot \mid c^i)\big) - D\big(q_\phi(\cdot \mid c^j_{\hat{R}_i}), \; q_\phi(\cdot \mid c^j)\big) + m \,\Big]_+ \qquad (8)$

Input: Batches $\{\mathcal{B}_i\}_{i=1}^{K}$; BCQ-trained $\{Q_i\}$, $\{G_i\}$ and $\{\xi_i\}$; randomly initialized $Q_\theta$, $G_\theta$ and $\xi_\theta$ jointly parameterized by $\theta$; task inference module $q_\phi$ with randomly initialized $\phi$

1:repeat
2:     Sample a context set $c^i$ from each $\mathcal{B}_i$
3:     Obtain relabelled transitions according to Eq. 7 for all pairs of tasks $(i, j)$, $j \neq i$
4:     Calculate $\mathcal{L}_{\mathrm{triplet}}$ using Eq. 9
5:     Calculate $\mathcal{L}_{\mathrm{distill}}$ using Eq. 3, 4, 5
6:     Calculate the final loss $\mathcal{L}$ using Eq. 10
7:     Update $\theta$ and $\phi$ to minimize $\mathcal{L}$
8:until Done
Algorithm 1 Distillation and triplet loss
Figure 3: Action selection. Given a state $s$ and an inferred task identity $z$, $G_\theta$ generates candidate actions. $\xi_\theta$ generates small corrections for the actions. The policy takes the corrected action with the highest value as estimated by $Q_\theta$.

where $m$ is the triplet margin, $[\cdot]_+$ is the ReLU function and $D$ is a divergence measure. Since $q_\phi$ outputs the posterior over the task identity, we choose $D$ to be the KL divergence.

Minimizing Eq. 8 accomplishes two goals. It encourages the task inference module to infer similar task identities when given either $c^i$ or $c^j_{\hat{R}_i}$ as input. It also encourages the module to infer different task identities for $c^j_{\hat{R}_i}$ and $c^j$. We emphasize that the task inference module cannot learn to correlate only the state-action pairs with the task identity, since $c^j_{\hat{R}_i}$ and $c^j$ contain the same state-action pairs but correspond to different task identities. To minimize Eq. 8, the module must model the correct conditional dependency when inferring the task identity.

Eq. 8 calculates the triplet loss when we use the learned reward function of task $i$ to relabel transitions from the remaining tasks. Following similar procedures for the remaining tasks leads to the loss:

$\mathcal{L}_{\mathrm{triplet}} = \frac{1}{K} \sum_{i=1}^{K} \mathcal{L}_{\mathrm{triplet}}^i \qquad (9)$
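The sketch below illustrates one way to compute the losses in Eqs. 8-9 with KL divergences between the Gaussian posteriors, assuming the encoder from Sec. 2.2, per-task context tensors, learned reward models, and a hypothetical `relabel_tensor` helper that swaps the reward column of a context tensor as in Eq. 7; all names are illustrative.

```python
# Sketch of Eqs. 8-9: triplet loss over all ordered task pairs with relabelled anchors.
import torch
from torch.distributions import kl_divergence

def triplet_loss(q_phi, contexts, R_hats, relabel_tensor, margin=2.0):
    K = len(contexts)
    posteriors = [q_phi(c) for c in contexts]
    total = 0.0
    for i in range(K):
        per_task = 0.0
        for j in range(K):
            if j == i:
                continue
            anchor = q_phi(relabel_tensor(contexts[j], R_hats[i]))  # c^j relabelled with R_hat_i
            positive = kl_divergence(anchor, posteriors[i]).sum()   # should match task i
            negative = kl_divergence(anchor, posteriors[j]).sum()   # should differ from task j
            per_task = per_task + torch.relu(positive - negative + margin)
        total = total + per_task / (K - 1)
    return total / K
```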

The final loss to train the randomly initialized task inference module $q_\phi$, the distilled value function $Q_\theta$, the distilled candidate action generator $G_\theta$, and the distilled perturbation generator $\xi_\theta$ is:

$\mathcal{L} = \mathcal{L}_{\mathrm{distill}} + \mathcal{L}_{\mathrm{triplet}} \qquad (10)$

Alg. 1 illustrates the pseudo-code for the second phase of the distillation procedure. Detailed pseudo-code of the two-phase distillation procedure can be found in Appendix E. Fig. 3 briefly describes action selection from the multi-task policy; Appendix F provides detailed explanations. In theory, we could also use the relabelled transitions in Eq. 7 to train the single-task BCQ policies in the first phase, but we do not, since we focus on task inference in this work.

4 Discussions

The issue of learning the wrong dependency does not surface when multi-task policies are tested in Atari tasks because their state spaces do not overlap Parisotto et al. (2015); Hessel et al. (2019); Espeholt et al. (2018b). Each Atari task has distinctive image-based states. The policy can perform well even when it only learns to correlate the state to the task identity. When MuJoCo tasks are used to test online multi-task algorithms Zintgraf et al. (2020); Fakoor et al. (2019), the wrong dependency becomes self-correcting. If the policy infers the wrong task identity, it will collect training data that increases the overlap between the datasets of the different training tasks, correcting the issue over time. However, in the batch setting, the policy cannot collect more transitions to self-correct inaccurate task inference. Our insight also points to the exciting possibility of incorporating mechanisms to quickly infer the correct causal relationship and improve sample efficiency in Multi-task RL, similar to how causal inference methods have motivated new innovations in imitation learning de Haan et al. (2019).

Our first limitation is the reliance on the generalizability of simple feedforward NNs. Future research can explore more sophisticated architectures, such as Graph NNs with a reasoning inductive bias Xu et al. (2019); Scarselli et al. (2008); Wu et al. (2020); Zhou et al. (2018) or structural causal models Pearl (2010, 2009a), to ensure accurate task inference. We also assume the learnt reward function of one task can generalize to state-action pairs from the other tasks, even when their state-action visitation frequencies do not overlap significantly. To increase the prediction accuracy, we use a reward ensemble to estimate epistemic uncertainty (Appendix D). We note that the learnt reward functions do not need to generalize to every state-action pair, but only to enough pairs so that the task inference module is forced to consider the rewards when trained to minimize Eq. 8. Crucially, we do not need to solve the task inference challenge while learning the reward functions and using them for relabelling, allowing us to side-step that challenge entirely during relabelling.

The second limitation is in scope. We only demonstrate our results on tasks using proprioceptive states. Even though these represent high-dimensional variables in a highly nonlinear ODE, the model does not need to tackle visual complexity. The tasks we consider also have relatively dense reward functions rather than binary reward functions. These tasks, such as navigation and running, are also quite simple in the spectrum of possible tasks we want an embodied agent to perform. These limitations represent exciting directions for future work.

Another interesting future direction is to apply supervised-learning self-distillation techniques Xie et al. (2019); Mobahi et al. (2020), proven to improve generalization, to further improve the distillation procedure. To address the multi-task learning problem for long-horizon tasks, it would also be beneficial to consider skill discovery and composition from the batch data Peng et al. (2019); Sharma et al. (2020). However, in this setting, we still need effective methods to infer the correct task identity to perform well in unseen tasks. Our explanation in Sec. 3 only applies when the tasks differ in reward functions. Extending our approach to task distributions with varying transition functions is straightforward. Sec. 5 provides experimental results for both cases.

5 Experiment Results

We demonstrate the performance of our proposed algorithm (Sec. 5.1) and ablate the different design choices (Sec. 5.2). Sec. 5.3 shows that the multi-task policy can serve as a good initialization, significantly speeding up training on unseen tasks. Appendix C provides all hyper-parameters.

5.1 Performance evaluation on unseen tasks

Figure 4: Results on unseen test tasks. x-axis is training epochs. y-axis is average episode returns. The shaded areas denote one std.

We evaluate in five challenging task distributions from MuJoCo Todorov et al. (2012) and a modified task distribution UmazeGoal-M from D4RL Fu et al. (2020). In AntDir and HumanoidDir-M, a target direction defines a task. The agent maximizes returns by running with maximal speed in the target direction. In AntGoal and UmazeGoal-M, a task is defined by a goal location, to which the agent should navigate. In HalfCheetahVel, a task is defined as a constant velocity the agent should achieve. We also consider the WalkerParam environment where random physical parameters parameterize the agent, inducing different transition functions in each task. The state for each task distribution is the OpenAI gym state. We do not include the task-specific information, such as the goal location or the target velocity in the state. The target directions and goals are sampled from a circular arc. Details of these task distributions can be found in Appendix H.1.

We argue that the version of HumanoidDir used in prior works does not represent a meaningful task distribution, since a single-task policy can already achieve the optimal performance on unseen tasks. We thus modify the task distribution so that a policy has to infer the task identity to perform well, and denote it as HumanoidDir-M. More details of this task distribution can be found in Appendix G.

There are two natural baselines. The first is obtained by modifying PEARL Rakelly et al. (2019) to train from the batches, instead of allowing PEARL to collect more transitions. We thus do not execute the environment-interaction step in Algorithm 1 of the PEARL paper, and on line 13 we sample the context and the RL batch uniformly from the given batch. The second baseline is Contextual BCQ. We modify the networks in BCQ to accept the inferred task identity as input and train the task inference module using the gradient of the value function loss. MBML and the baselines have the same network architecture. We are very much inspired by PEARL and BCQ. However, we do not expect PEARL to perform well in our setting because it does not explicitly handle the difficulties of learning from a batch without interactions. We also expect that our proposed algorithm will outperform Contextual BCQ thanks to more robust task inference.

We measure performance by the average returns over unseen tasks, sampled from the same task distribution. We do not count the first two episodes’ returns Rakelly et al. (2019). We obtain the batch for each training task by training Soft Actor Critic (SAC) Haarnoja et al. (2018) with a fixed number of environment interactions. Appendix H provides more details on the environment setups and training procedures of the baselines.

From Fig. 4, MBML outperforms the baselines by a healthy margin in all task distributions. Even though PEARL does not explicitly handle the challenge of training from an offline batch, it is remarkably stable, only diverging in AntDir. Contextual BCQ is stable, but converges to a lower performance than MBML in all task distributions. An astute reader will notice the issue of overfitting, for example Contextual BCQ in HumanoidDir-M. Since our paper is not about determining early stopping conditions, and to ensure fair comparisons among the different algorithms, we compute the performance comparisons using the best results achieved by each algorithm during training.

We also compare with MetaGenRL Kirsch et al. (2019). Since it relies on DDPG Lillicrap et al. (2015) to estimate value functions, which diverges in Batch RL Fujimoto et al. (2019), we do not expect it to perform well in our setting. Fig. LABEL:fig:metagenrl confirms this: its performance quickly plummets and does not recover with more training. Combining MetaGenRL and MBML is an interesting direction since MetaGenRL generalizes to out-of-distribution tasks.

5.2 Ablations

Figure 5: Ablation study. x-axis is training epochs. y-axis is average episode returns. The shaded areas denote one std.

We emphasize that our contributions lie in the triplet loss design coupled with transition relabelling. Below, we provide ablation studies to demonstrate that both are crucial to obtain superior performance.

No relabelling. To obtain hard negative examples, we search over a mini-batch to find the hardest positive-anchor and negative-anchor pairs, a successful and strong baseline from metric learning Hermans et al. (2017). This requires sampling $P$ context sets $\{c^i_u\}_{u=1}^{P}$ for each task $i$, where $u$ indexes the context sets sampled for each task. Let $K$ be the number of training tasks; the triplet loss is:

$\mathcal{L}_{\mathrm{triplet}} = \frac{1}{K} \sum_{i=1}^{K} \Big[\, \max_{u, v} D\big(q_\phi(\cdot \mid c^i_u), \; q_\phi(\cdot \mid c^i_v)\big) \;-\; \min_{j \neq i, \; u, v} D\big(q_\phi(\cdot \mid c^i_u), \; q_\phi(\cdot \mid c^j_v)\big) + m \,\Big]_+ \qquad (11)$

The first term finds the positive-anchor pair for task $i$ by considering every pair of context sets from task $i$ and selecting the pair with the largest divergence in the posteriors over the task identity. The second term finds the negative-anchor pair for task $i$ by considering every possible pair between the context sets sampled for task $i$ and the context sets sampled for the other tasks. It then selects the pair with the lowest divergence in the posteriors over the task identity as the negative-anchor pair.
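The sketch below illustrates this batch-hard mining variant, assuming `posteriors[i][u]` holds the Gaussian posterior inferred from the u-th context set sampled for task i; the function name and data layout are illustrative.

```python
# Sketch of the "No relabelling" ablation (Eq. 11): batch-hard mining without relabelling.
import torch
from torch.distributions import kl_divergence

def batch_hard_triplet_loss(posteriors, margin=2.0):
    K = len(posteriors)
    total = 0.0
    for i in range(K):
        # Hardest positive: most divergent pair of context sets within task i.
        pos = torch.stack([
            kl_divergence(posteriors[i][u], posteriors[i][v]).sum()
            for u in range(len(posteriors[i]))
            for v in range(len(posteriors[i])) if u != v
        ]).max()
        # Hardest negative: least divergent pair between task i and any other task.
        neg = torch.stack([
            kl_divergence(posteriors[i][u], posteriors[j][v]).sum()
            for j in range(K) if j != i
            for u in range(len(posteriors[i]))
            for v in range(len(posteriors[j]))
        ]).min()
        total = total + torch.relu(pos - neg + margin)
    return total / K
```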

No triplet loss. We train the task inference module using only the gradient of the value function distillation loss (Eq. 3). To still make use of the relabelled transitions, the module also takes them as input during training. More concretely, given the context set $c^i$ from task $i$, we sample an equal number of relabelled transitions from the other tasks $j \neq i$. During training, the input to the task inference module is the union of the context set $c^i$ and the sampled relabelled transitions. In the full model, we also perform a similar modification to the input of the module during training.

No transition relabelling and no triplet loss. This method is a simple combination of a task inference module and the distillation process. We refer to this algorithm as Neither in the graphs.

Fig. 5 compares our full model and the ablated versions. Our full model obtains higher returns than most of the ablated versions. For WalkerParam, our full model does not exhibit improvement over Neither. However, from Fig. 4, our full model significantly outperforms the baselines. We thus conclude that, in WalkerParam, the improvement over the baselines comes from distillation.

Compared to the No relabelling ablation, transition relabelling leads to more efficient computation of the triplet loss. Without the relabelled transitions, the number of divergence computations needed for Eq. 11 grows quadratically in both the number of tasks and the number of sampled context sets, while our loss in Eq. 9 grows only quadratically in the number of tasks. We also need to relabel the transitions only once before training the multi-task policy, and relabelling is trivial to parallelize across tasks.

We also study reward estimation accuracy. Fig. LABEL:fig:reward_error shows that our reward model achieves low error on state-action pairs from another task, both with and without an ensemble. We also compare MBML against an ablated version that uses the ground truth reward function for relabelling on UmazeGoal-M. The model trained using the ground truth reward function performs only slightly better than the model trained using the learned reward function. We include in Appendix I experiments on margin sensitivity and the benefit of the reward ensemble.

5.3 Using the multi-task policy to enable faster convergence when training on unseen tasks

While the multi-task policy generalizes to unseen tasks, its performance is not optimal. If we allow further training, initializing networks with our multi-task policy significantly speeds up convergence to the optimal performance.

The initialization process is as follows. Given a new task, we use the multi-task policy to collect 10K transitions. We then train a new policy to imitate the actions taken by the multi-task policy by maximizing their log-likelihood. As commonly done, the new policy outputs the mean and variance of a diagonal Gaussian distribution. The new policy does not take a task identity as input. The task inference module infers a task identity z from the 10K transitions. Fixing z as input, the distilled value function initializes the new value function. Given the new policy and the initialized value function, we train them with SAC by collecting more data. To stabilize training, we perform target policy smoothing Fujimoto et al. (2018) and double-Q learning Van Hasselt et al. (2016) by training two identically initialized value functions with different mini-batches (pseudo-code and more motivation in Appendix J.1).
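The behavior-cloning step of this initialization could look like the sketch below, assuming the collected states and actions are stored as tensors and `new_policy` is a module returning a diagonal Gaussian over actions; the optimizer settings and names are illustrative.

```python
# Sketch of initializing a new policy by cloning the multi-task policy's actions.
import torch

def behavior_clone(new_policy, states, actions, epochs=100, lr=3e-4, batch_size=256):
    opt = torch.optim.Adam(new_policy.parameters(), lr=lr)
    n = states.shape[0]
    for _ in range(epochs):
        idx = torch.randperm(n)[:batch_size]
        dist = new_policy(states[idx])                      # diagonal Gaussian over actions
        loss = -dist.log_prob(actions[idx]).sum(-1).mean()  # maximize log-likelihood
        opt.zero_grad()
        loss.backward()
        opt.step()
    return new_policy
```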

Fig. 6 compares the performance of the policies initialized with our multi-task policy to randomly initialized policies. Initializing the policies with the MBML policy significantly increases convergence speed in all five task distributions, demonstrating our method’s robustness. Even in the complex HumanoidDir-M task distribution, our method significantly speeds up convergence, requiring only 85K environment interactions, while the randomly initialized policies require 350K, representing more than a fourfold improvement in sample efficiency. Similar conclusions hold when comparing against randomly initialized SAC where the two value functions are trained using different mini-batches (Appendix J.2). We also note that our initialization method does not require extensive hyper-parameter tuning.

Legend for Fig. 6: SAC initialized by our multi-task policy (Ours); randomly initialized SAC (Random).

Figure 6: Initialization results on five task distributions. x-axis is number of interactions in thousands. y-axis is the average episode returns over unseen tasks. The shaded areas denote one std.

6 Related Works

Batch RL Recent advances in Batch RL Agarwal et al. (2019); Kumar et al. (2019); Fujimoto et al. (2019); Chen et al. (2019); Kumar et al. (2020) focus on the single-task setting, which does not require training a task inference module; they are thus not directly applicable to Multi-task Batch RL. Siegel et al. (2020); Cabi et al. (2020) also consider the multi-task setting but assume access to the ground truth task identity and reward function of the test tasks. Our problem setting also differs in that the different training batches do not have significant overlap in state-action visitation frequencies, leading to the challenge of learning a robust task inference module.

Task inference in multi-task setting The challenge of task inference in a multi-task setting has been tackled under various umbrellas. Meta RL Rakelly et al. (2019); Zintgraf et al. (2020); Fakoor et al. (2019); Humplik et al. (2019); Lan et al. (2019); Sæmundsson et al. (2018); Zintgraf et al. (2019) trains a task inference module to infer the task identity from a context set. We also follow this paradigm. However, our setting presents an additional challenge in training a robust task inference module, which motivates our novel triplet loss design. As the choice of loss function is crucial to train a successful task inference module in our setting, we will explore other loss functions, e.g. those discussed in Roth et al. (2020), in future work. Other multi-task RL works Espeholt et al. (2018a); Yang et al. (2020); Yu et al. (); D’Eramo et al. (2019) focus on training a good multi-task policy, rather than the task inference module, which is an orthogonal research direction to ours.

Meta RL Meta RL Lan et al. (2019); Wang et al. (2016); Duan et al. (2016); Finn et al. (2017); Nichol et al. (2018); Houthooft et al. (2018) optimizes for quick adaptation. However, they require interactions with the environment during training. Even though we do not explicitly optimize for quick adaptation, we demonstrate that initializing a model-free RL algorithm with our policy significantly speeds up convergence on unseen tasks. Fakoor et al. (2019) uses the data from the training tasks to speed up convergence when learning on new tasks by propensity estimation techniques. This approach is orthogonal to ours and can potentially be combined to yield even greater performance improvement.

7 Conclusion

In Multi-task Batch RL, the non-overlapping nature of the different training batches presents a new and unique challenge in learning a robust task inference module. We propose a novel application of the triplet loss to robustify task inference. To mine hard negative examples for the triplet loss, we approximate the reward functions of the different training tasks and relabel the reward values of the transitions from each task. Using a simple initialization procedure, our multi-task policy significantly accelerates convergence when we perform further on-policy training on unseen tasks.

Acknowledgement

We would like to acknowledge Professor Keith Ross (NYU) for initial discussions and inspiration for this work. We would like to thank Fangchen Liu (UC Berkeley) for pointing out a figure issue right before the paper submission deadline. Computing needs were supported by the Nautilus Pacific Research Platform.

Broader Impact

Positive impact

Our work provides a solution for learning a policy that generalizes to a set of similar tasks from only observational data. The techniques we propose have great potential to benefit various areas of society. For example, in healthcare, we hope the proposed triplet loss design with hard negative mining can enable us to robustly train an automatic medical prescription system from a large batch of medical histories of different diseases and further generalize to new diseases Choi et al. (2019), e.g., COVID-19. Moreover, in robotics, our method can enable learning a single policy that solves a set of similar unseen tasks from only historical robot experiences, which tackles the sample efficiency issue given that sampling is expensive in real-world robotics Cabi et al. (2020). Even though our learned policy cannot be immediately applied in fields that require safe action selection, e.g., autonomous driving Geiger et al. (2012) and medical prescription, it can still serve as a good prior to accelerate further training.

Negative impact

Evidently, the algorithm we propose is a data-driven method. It is therefore likely to be biased by the training data. If the testing tasks are very different from the training tasks, the learned policy may result in behaviors even worse than a random policy, leading to safety issues. This motivates research into safe action selection and distributional-shift identification when learning policies for sequential decision processes from only observational data.


Appendix A Symbol definition

Symbol Definition Dimension
state space
action space
transition function
initial state distribution
reward function
horizon
MDP , which defines a task -
task distribution -
dimension of the policy parameter
number of tasks
trajectory generated by interacting with -
state at time step
action selected at time step
corrected action
state at time step
policy function, parameterized by
number of transition tuples from one task batch
batch of transition tuples for task -
learned reward function for task
number of transition tuples in a context set
context set for task
relabeled context set by using
Union of relabeled context sets by using
task identity for task
task inference module, parameterized by
, , , random variables: states, actions, rewards, task identity -
expected sum of rewards in induced by policy
expected sum of rewards in induced by policy
Q value function
Q value function for task
distilled Q value function
candidate action generator
candidate action generator for task
distilled candidate action generator
perturbation generator
perturbation generator for task
distilled perturbation generator
standard Gaussian distribution -
noise sampled from standard Gaussian distribution
triplet margin
divergence measure -
KL KL divergence -
stop gradient operation -
loss function to distill -
loss function to distill -
loss function to distill -
total distillation loss -
triplet loss for task -
mean triplet loss across all tasks -
final loss: -
Table 1: Symbol definition. Some of the symbols are overloaded. We make sure each term is clearly defined given the context.

Appendix B Action selection of BCQ policy

Figure 7: Action selection procedure of BCQ.

In this section, we provide the detailed action selection procedure for BCQ. To pick an action given a state $s$, we first sample a set of noise vectors $\{\epsilon_k\}$ from the standard Gaussian distribution. For each $\epsilon_k$, the candidate action generator $G$ generates a candidate action $a_k = G(s, \epsilon_k)$ for the state $s$. For each candidate action $a_k$, the perturbation model $\xi$ generates a small correction term by taking the state-candidate action pair as input. Therefore, a set of corrected candidate actions $\{a_k + \xi(s, a_k)\}$ is generated for the state $s$. The corrected candidate action with the highest value estimated by $Q$ is selected as the action to execute.

Appendix C Hyper-parameters

C.1 Hyper-parameters of our proposed models

Hyper-parameters Value
Number of evaluation episodes 5
Task identity dimension 20
Number of candidate actions 10
Learning rate 0.0003
Training batch size 128
Context set size 64
KL regularization weighting term 0.1
Triplet margin 2.0
Reward prediction ensemble threshold AntDir, AntGoal: 0.1
WalkerParam: 0.1
HumanoidDir-M: 0.2
HalfCheetahVel: 0.05
UmazeGoal-M: 0.02
Next state prediction ensemble threshold 0.1
Distilled Q function architecture MLP with 9 hidden layers, 1024 nodes each, ReLU activation
architecture MLP with 7 hidden layers, 1024 nodes each, ReLU activation
architecture MLP with 8 hidden layers, 1024 nodes each, ReLU activation
Table 2: Hyper-parameters of our proposed model
Hyper-parameters Value
Learning rate 0.0003
Training batch size 128
Reward prediction ensemble size 20
Reward prediction network architecture MLP with 1 hidden layer, 128 nodes, ReLU activation
Next state prediction ensemble size 20
Next state prediction network architecture MLP with 6 hidden layers, 256 nodes each, ReLU activation
Table 3: Hyper-parameters of reward and next state prediction ensemble

Table 2 provides the hyper-parameters for our proposed model and all of its ablated versions (Sec. 3, Sec. 5.2). The hyper-parameters for the reward ensembles and next state prediction ensembles are provided in Table 3. Our model uses the task inference module from PEARL with the same architecture, described in Table 4. Since the scale of the rewards differs across task distributions, we need to use different values for the reward prediction ensemble threshold.

We did not conduct extensive search to determine the hyper-parameters. Instead, we reuse some default hyper-parameter settings from the other multi-task learning literature on the MuJoCo benchmarks [41, 15]. As for the architecture of the distillation networks, we select reasonably deep networks.

When using BCQ to train the single-task policies in the first phase of the distillation procedure, we use the default hyper-parameters in the official implementation of BCQ, except for the learning rate, which we lower from its default value. We find lowering the learning rate leads to more stable learning for BCQ.

C.2 Hyper-parameters of Contextual BCQ

For Contextual BCQ, the value function, decoder, and perturbation model have the same architecture as in our model. The encoder also has the same architecture as the decoder. The task inference module has the same architecture as the task inference module in PEARL, described in Table 4.

The context set size used during training of Contextual BCQ is 128, twice the size of the context set in our model (64). This is because during training of our model, we use the combination of context transitions and the same number of relabelled transitions from the other tasks to infer the posterior over the task identity, as detailed in Sec. 5.2 and in the pseudo-code provided in Alg. 4. Therefore, the effective number of transitions used as input to the task inference module during training is the same for our model and Contextual BCQ.

Unless stated otherwise, for the remaining hyper-parameters, such as the maximum value of the perturbation, we use the default value in BCQ.

C.3 Hyper-parameters of PEARL

Hyper-parameters Value
Task inference module architecture MLP with 3 hidden layers, 200 nodes each, ReLU activation
Table 4: Hyper-parameters of PEARL

We use the default hyper-parameters as provided in the official implementation of PEARL. For completeness when discussing the hyper-parameters of our model, we provide the architecture of the task inference module in Table 4.

C.4 Hyper-parameters of ablation studies of the full model

Hyper-parameters Value
Number of sampled context sets 10
Context set size 128
Table 5: Hyper-parameters of No transition relabelling
Hyper-parameters Value
Context set size 64
Table 6: Hyper-parameters of No triplet loss
Hyper-parameters Value
Context set size 128
Table 7: Hyper-parameters of Neither

Table 5, Table 6 and Table 7 provide the hyper-parameters for the ablated versions of our full model: No transition relabelling, No triplet loss, and Neither, respectively. Without the transition relabelling technique, No transition relabelling and Neither set the training context set size to 128, as in Contextual BCQ, to use the same effective number of transitions to infer the posterior over the task identity as our full model. Note that the remaining hyper-parameters of these methods are the same as in our full model, described in Table 2.

C.5 Hyper-parameters when we initialize SAC with our multi-task policy

Hyper-parameters Value
Q function architecture MLP with 9 hidden layers, 1024 nodes each, ReLU activation
Q function target smoothing rate 0.005
policy target smoothing rate 0.1
Table 8: Hyper-parameters of SAC when initialized by our multi-task policy

The architecture of the Q function network is the same as the distilled Q function in Table 2. The Q function target smoothing rate is the same as in the standard SAC implementation [22]. The policy target smoothing rate is selected by searching over a small set of candidate values. For the SAC trained from random initialization baseline (Appendix J.2), we also change the sizes of the value functions to the same values in Table 8. For the remaining hyper-parameters, we use the default hyper-parameter settings of SAC.

Appendix D Reward prediction ensemble

Input: data batch $\mathcal{B}_i$; reward function approximator $\hat{R}_i$ with randomly initialized parameters.

1:for a fixed number of iterations do
2:     Sample a transition $(s, a, r, s')$ from $\mathcal{B}_i$
3:     Obtain the predicted reward $\hat{R}_i(s, a)$
4:     Update the parameters of $\hat{R}_i$ to minimize the prediction error between $\hat{R}_i(s, a)$ and $r$ through gradient descent.
5:end for

Output: trained reward function approximator

Algorithm 2 Training procedure of reward function approximator

Input: an ensemble of learned reward functions for task $i$; context set $c^j$ from task $j$; a threshold on the ensemble standard deviation.

1:Initialize the relabelled set $c^j_{\hat{R}_i} \leftarrow \emptyset$
2:for each transition $(s, a, r, s')$ in $c^j$ do
3:     if the standard deviation of the ensemble predictions for $(s, a)$ is below the threshold then
4:         Set the relabelled reward to the mean of the ensemble predictions
5:         Add the relabelled transition to $c^j_{\hat{R}_i}$
6:     end if
7:end for

Output: relabelled transitions

Algorithm 3 Relabel transitions from task $j$ using the learned reward function of task $i$

In Subsection 3.2, we propose to train a reward function approximator for each training task to relabel the transitions from the other tasks. To increase the accuracy of the estimated reward, for each task $i$, we use an ensemble of learnt reward functions, where $i$ indexes the task and a second index runs over the functions in the ensemble. The training procedure for each reward function approximator in the ensemble is provided in Alg. 2.

The pseudo-code for generating the relabelled context set from the context set of task $j$ is given in Alg. 3. We use the spread of the ensemble outputs as an estimate of the epistemic uncertainty in the reward prediction [8]. Concretely, for each transition in the context set, we only include it in the relabelled set if the standard deviation of the ensemble outputs is below a certain threshold (line 3). We use the mean of the outputs as the estimated reward (line 4).
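A minimal sketch of this filtered relabelling (Alg. 3), assuming each ensemble member is a callable mapping a state-action pair to a scalar reward tensor; names and data layout are illustrative.

```python
# Sketch of Alg. 3: relabel with the ensemble mean, keeping only low-uncertainty transitions.
import torch

def relabel_with_ensemble(context_j, ensemble_i, threshold):
    relabelled = []
    for (s, a, r, next_s) in context_j:
        preds = torch.stack([R(s, a) for R in ensemble_i])   # one prediction per member
        if preds.std() < threshold:                          # line 3: uncertainty filter
            relabelled.append((s, a, preds.mean(), next_s))  # line 4: mean as new reward
    return relabelled
```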

We conduct an ablation study of the reward prediction ensemble in Appendix I.2, where we show that the reward prediction ensemble improves performance when initializing SAC with our multi-task policy.

Appendix E Detailed pseudo-codes of the two-phases distillation procedures

In this section, we provide the detailed pseudo-code in Alg. 4 for the two-phase distillation procedure introduced in Sec. 3. The basic idea is that we first obtain a single-task policy for each training task using BCQ. In the second phase, we distill the single-task policies into a multi-task policy by incorporating a task inference module. Note that the task inference module is trained by minimizing the Q value function distillation loss (Eq. 3) and the triplet loss (Eq. 9).

Line 1 describes the first phase of the two-phase distillation procedure. We use BCQ to learn a state-action value function $Q_i$, a candidate action generator $G_i$ and a perturbation generator $\xi_i$ for each training batch $\mathcal{B}_i$.

We next enter the second phase. We first sample a context set $c^i$ from each $\mathcal{B}_i$ in line 3. Lines 5-10 provide the procedure to calculate the triplet loss. For each task $i$, we relabel the reward of each transition in all the remaining context sets using $\hat{R}_i$ and obtain the relabelled sets in line 5. From the union of the relabelled context sets, we sample a subset in line 6. In lines 7-8, we introduce notation for the transitions in this subset that originated from the context set of task $j$, and for the same transitions before relabelling. These sets of transitions have the relationships summarized in Eq. 12:

(12)

To calculate the triplet loss for task $i$, in line 9 we sample a subset of $c^i$ with the same number of transitions as the relabelled subset. The triplet loss for task $i$ is then given by Eq. 13.

Lines 11-13 provide the procedure to infer the task identity for each task $i$. We use the union of the context set $c^i$ and the relabelled transitions sampled from the other tasks to infer the posterior over the task identity. We then sample the task identity from this posterior.

To calculate the distillation loss of each distilled function, in line 14 we sample a training batch of transitions from $\mathcal{B}_i$. With the sampled task identity and the training batch, we can calculate the value function distillation loss of task $i$ using Eq. 14. To calculate the distillation losses of the candidate action generator and perturbation generator of task $i$, we first sample noise vectors from the standard Gaussian distribution in line 16. In line 17, we then obtain the candidate actions for each state in the training batch. The candidate action generator and perturbation generator distillation losses for task $i$ follow Eq. 15 and Eq. 16, respectively.

After repeating these procedures for all the training tasks, in lines 21-24 we average the losses across tasks. At the end of each iteration, we update the parameters $\theta$ and $\phi$ by minimizing the total loss in line 25.

Input: Batches ; trained reward function ; randomly initialized , and jointly parameterized by ; task inference module with randomly initialized ; context set size ; training batch size ; triplet margin

1:Learn single task policy , , and from each data batch using BCQ,
2:repeat
3:     Sample context set from
4:     for  do
5:         Obtain the relabelled context set from with according to Alg. 3,
6:         Sample a subset of relabelled context set : ,
7:         Denote transitions in originated from as
8:         Denote transitions in before relabelling as ,
9:         Sample a subset from with ,
10:         Calculate the triplet loss
(13)
11:         Combine and to form the new context set
12:         Infer the posterior over task identity from
13:         Sample task identity
14:         Sample training batch:
15:         Calculate the value function distillation loss
(14)
16:         Sample noises: ,
17:         Obtain candidate action from : ,
18:         Calculate the candidate action generator distillation loss
(15)
19:         Calculate the perturbation generator distillation loss
(16)
20:     end for
21:     Calculate
22:     Calculate
23:     Calculate
24:     Calculate
25:     Update to minimize
26:until Done
Algorithm 4 Two-phases distillation procedure with novel triplet loss design

Appendix F Action selection and evaluation of the multi-task policy

Figure 8: Action selection. Given a context set, $q_\phi$ infers the posterior over the task identity, from which we sample a task identity $z$. With the task identity $z$, $G_\theta$ generates multiple candidate actions for the state $s$. $\xi_\theta$ generates small corrections for the candidate actions. The policy takes the corrected action with the highest value as estimated by $Q_\theta$.

In this section, we will describe the action selection procedures from the multi-task policy as shown in Fig. 8, and how we evaluate its performance.

Sampling an action given a state from the multi-task policy is similar to the procedure of BCQ (Appendix B). The main difference is that the networks also take an inferred task identity as input. Concretely, given a state $s$, the distilled candidate action generator $G_\theta$ generates multiple candidate actions from random noise. The distilled perturbation generator $\xi_\theta$ generates a small correction term for each state-candidate action pair. We take the corrected action with the highest value as estimated by the distilled value function $Q_\theta$. The action selection procedure can be summarized by:

$a^* = \operatorname{arg\,max}_{a_k + \xi_\theta(s, a_k, z)} \; Q_\theta\big(s, \, a_k + \xi_\theta(s, a_k, z), \, z\big), \qquad a_k = G_\theta(s, \epsilon_k, z), \ \ \epsilon_k \sim \mathcal{N}(0, I) \qquad (17)$

We elaborate the evaluation procedure in Alg. 5 (a minimal code sketch of this loop follows Alg. 5). When testing on a new task, we do not have the ground truth task identity or any transitions from the task to infer the task identity. We thus sample the initial task identity from the standard Gaussian prior in line 1. The task identity is kept fixed for the duration of the first episode. Afterwards, we use the collected transitions to infer the posterior and sample new task identities before each new episode, as described in line 3. When calculating the average episode returns, we do not count the first two episodes’ returns, following [41].

Input: unseen task ; learned multi-task policy

1:Initialize the context set $c \leftarrow \emptyset$
2:repeat
3:     Sample a task identity $z$ from $q_\phi(\cdot \mid c)$, using the standard Gaussian prior when $c$ is empty
4:     Collect one episode of transitions from the task with the multi-task policy conditioned on $z$
5:     Add the collected transitions to $c$
6:until Done

Output: average episode returns, not counting the first two episodes

Algorithm 5 Evaluation procedures of our model
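A minimal sketch of this evaluation loop, assuming a hypothetical `run_episode` helper that rolls out one episode with a fixed task identity and returns the episode return together with the collected transition tensors; all names are illustrative.

```python
# Sketch of the evaluation protocol in Alg. 5.
import torch

def evaluate(task_env, policy, q_phi, run_episode, latent_dim=20, num_episodes=5):
    context, returns = [], []
    prior = torch.distributions.Normal(torch.zeros(latent_dim), torch.ones(latent_dim))
    for _ in range(num_episodes):
        posterior = prior if not context else q_phi(torch.stack(context))
        z = posterior.sample()                              # fixed for this episode
        episode_return, transitions = run_episode(task_env, policy, z)
        context.extend(transitions)
        returns.append(episode_return)
    return sum(returns[2:]) / max(len(returns) - 2, 1)      # skip the first two episodes
```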

Appendix G On Modifying the original HumanoidDir task distribution

We are concerned that the original HumanoidDir task distribution is not suitable as a benchmark for multi-task RL because a policy trained from a single task can already obtain the optimal performance on unseen tasks. In particular, we train BCQ with transitions from one task and it obtains a return on unseen tasks similar to that of SAC trained from scratch separately for each task.

In the HumanoidDir task distribution, each task is defined by a target running direction. The intended task is for the agent to run with maximal velocity in the target direction. The reward of each task can be defined as below:

$r_t = c_{\mathrm{vel}} \, \langle v_t, d \rangle + C_{\mathrm{alive}} - \mathrm{ctrl\ cost} - \mathrm{contact\ cost} \qquad (18)$

where $\langle \cdot, \cdot \rangle$ denotes the inner product, $v_t$ is the agent's velocity and $d$ is the target direction. Note that the two cost terms tend to be very small, so it is reasonable to omit them in the analysis. The alive bonus $C_{\mathrm{alive}}$ is the same across different tasks and is a constant. The target direction $d$ is different across tasks. The coefficient $c_{\mathrm{vel}}$ weights the relative contribution of the velocity term to the reward. If $c_{\mathrm{vel}}$ is too small, the reward is dominated by the constant $C_{\mathrm{alive}}$. In this case, to achieve good performance, the agent does not need to perform the intended task. In other words, the agent does not need to infer the task identity to obtain good performance and only needs to remain close to the initial state while avoiding terminal states to maximize the episode length.
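The schematic sketch below illustrates this reward structure and why a small velocity coefficient lets the constant alive bonus dominate; the coefficient values used here are illustrative, not those of any particular implementation.

```python
# Schematic illustration of the directional reward in Eq. 18 (costs omitted by default).
import numpy as np

def directional_reward(velocity, target_dir, vel_coef, alive_bonus,
                       ctrl_cost=0.0, contact_cost=0.0):
    return vel_coef * float(np.dot(velocity, target_dir)) + alive_bonus - ctrl_cost - contact_cost

# With a small coefficient, running fast and standing still earn similar rewards:
v_run, v_still, d = np.array([3.0, 0.0]), np.array([0.0, 0.0]), np.array([1.0, 0.0])
print(directional_reward(v_run, d, vel_coef=0.1, alive_bonus=5.0))    # 5.3
print(directional_reward(v_still, d, vel_coef=0.1, alive_bonus=5.0))  # 5.0
```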

Prior works that use HumanoidDir set the velocity coefficient to a small value relative to the alive bonus. With such a small value for the reward coefficient, the reward is dominated by the alive bonus. We provide a video to illustrate that, in different tasks, the SAC-trained single-task policies display similar behaviors even though the tasks have different running directions. In most tasks, the SAC-trained policy controls the Humanoid to stay upright near the initial state, which is enough to obtain high performance. If a single policy that controls the agent to stay upright can achieve high performance in all tasks sampled from this task distribution, we argue that the learned multi-task policy in this task distribution can achieve near-optimal performance across tasks without the need to perform accurate task inference. In other words, this task distribution is not suitable to demonstrate the test-time task inference challenge identified in our work.

Therefore, we set the velocity coefficient to the value used in the OpenAI Gym implementation of Humanoid, and denote the modified task distribution as HumanoidDir-M. As is shown in the video, the SAC-trained agent in this case runs with significant velocity in the target direction. The optimal behaviors among the different tasks are thus sufficiently different that the multi-task policy needs to infer the task identity to obtain high performance.

Appendix H Details of the environmental settings and baseline algorithms

In this section, we will first provide the details of environmental settings in Appendix H.1, and then describe the baseline algorithms we compare against in Sec 5. We explain PEARL in Appendix H.2 and Contextual BCQ in Appendix H.3.

H.1 Environment setups

We construct the task distribution UmazeGoal-M by modifying maze2d-umaze-dense-v1 from D4RL. We always reset the agent to the middle of the U-shaped maze, while the goal location is randomly initialized around the two corners of the maze.

The episode length is 1000 for HalfCheetahVel, which is the episode length commonly used when model-free algorithms are tested in the single-task variant of these task distributions. We use the same episode length of 300 as D4RL for UmazeGoal-M. In the remaining task distributions, we set the episode length to 200 due to a constrained computational budget.

Table 9 provides details on each task distribution, including the number of training tasks and the number of testing tasks. Note that the set of training tasks and the set of testing tasks do not overlap. The column “Interactions” specifies the number of transitions available for each task. With the selected number of interactions with the environment, we expect the final performance of training SAC on each task to be slightly below the optimal performance. In other words, we do not expect the batch data to contain a large number of trajectories with high episode returns.

Num train tasks Num test tasks Interactions SAC returns BCQ returns
HalfCheetahVel 10 8 60K
AntDir 10 8 200K
AntGoal 10 8 300K
WalkerParam 30 8 300K
HumanoidDir-M 10 8 600K
UmazeGoal-M 10 8 30K
Table 9: Details of the experimental settings

H.2 PEARL under the Batch RL setting

Our work is very much inspired by PEARL [41], which is the state-of-the-art algorithm designed for optimizing the multi-task objective in various MuJoCo benchmarks. By including the results for PEARL, we demonstrate that conventional algorithms that require interaction with the environment during training do not perform well in the Multi-task Batch RL setting, which motivates our work.

To help readers understand the changes we made to adapt PEARL to the Batch RL setting, we reuse the notations from the original PEARL paper in this section. Detailed training procedures are provided in Algorithm 6. Without the privilege of interacting with the environment, PEARL proceeds to sample the context set from the task batch in line 5. The task inference module, parameterized by $\phi$, takes the context set as input to infer the posterior over the task identity. In line 6, we sample the task identity from this posterior. In lines 7-9, the task identity combined with the RL mini-batch is further input into the SAC module, which defines the actor loss and the critic loss for each task, while a regularization term constrains the inferred posterior over the task identity from the context set