Selective Transfer with Reinforced Transfer Network for Partial Domain Adaptation
Partial domain adaptation (PDA) extends standard domain adaptation to a more realistic scenario where the target domain only has a subset of classes from the source domain. The key challenge of PDA is how to select the relevant samples in the shared classes for knowledge transfer. Previous PDA methods tackle this problem by re-weighting the source samples based on the prediction of classifier or discriminator, thus discarding the pixel-level information. In this paper, to utilize both high-level and pixel-level information, we propose a reinforced transfer network (RTNet), which is the first work to apply reinforcement learning to address the PDA problem. The RTNet simultaneously mitigates the negative transfer by adopting a reinforced data selector to filter out outlier source classes, and promotes the positive transfer by employing a domain adaptation model to minimize the distribution discrepancy in the shared label space. Extensive experiments indicate that RTNet can achieve state-of-the-art performance for partial domain adaptation tasks on several benchmark datasets. Codes and datasets will be available online.
Deep neural networks have achieved impressive performance in a variety of applications. However, when applied to related but different domains, the generalization ability of the learned model may be severely degraded by domain shift. Re-collecting labeled data for each new domain is prohibitive because of the huge cost of data annotation. Domain adaptation techniques address this problem by transferring knowledge from a source domain with rich labeled data to a target domain where labels are scarce or unavailable. These domain adaptation methods learn domain-invariant feature representations by moment matching [17, 24, 7] or adversarial training [25, 10].
Previous domain adaptation methods generally assume that the source and target domains share the same label space. However, in real applications, it is usually difficult to find a relevant source domain whose label space is identical to that of the target domain of interest. Thus, a more realistic scenario is partial domain adaptation (PDA), which relaxes the constraint that the source and target domains share the label space and assumes that the unknown target label space is a subset of the source label space. In such a scenario, as shown in Figure 1a, existing standard domain adaptation methods force a match between the outlier source class (blue triangle) and the unrelated target class (red square) by aligning the whole source domain with the target domain. As a result, negative transfer may be triggered by the mismatch. Negative transfer refers to the situation in which the transfer model performs even worse than the non-adaptation (NoA) model.
Several approaches have been proposed to solve the partial domain adaptation problem by re-weighting the source samples in the domain-adversarial network. These weights can be obtained from the distribution of the predicted target label probabilities or from the prediction of the domain discriminator [3, 29]. However, these methods ignore pixel-level features when determining weights, thereby losing global correlation information. Moreover, these PDA modules based on adversarial networks are difficult to integrate into domain adaptation methods based on moment matching, because the latter lack discriminators to filter outlier classes. Therefore, most advanced standard domain adaptation methods based on moment matching are hard to extend to the PDA problem.
In this paper, to mitigate negative transfer, we present a reinforced transfer network (RTNet), as shown in Figure 1b, which exploits reinforcement learning to learn a reinforced data selector that filters out outlier source samples automatically. The motivation for considering pixel-level information is that source samples related to the target domain will have smaller reconstruction errors than outlier source samples on a generator trained with target samples. For example, an outlier triangle source sample will have a larger reconstruction error than a square source sample on the target generator, because the target generator has seen no training samples of the triangle category and outlier source samples are extremely dissimilar to the target classes. Hence, the reconstruction error can measure the similarity between each source sample and the target domain. To utilize both pixel-level and high-level information to select source samples related to the target domain, we design a reinforced data selector. Specifically, the reinforced data selector takes an action (keep or drop a sample) based on the state of the sample. Then, the reconstruction error of the selected source samples on the target generator is used as a reward to guide the learning of the selector via the actor-critic algorithm. It is worth noting that the state contains high-level information, and the reward contains pixel-level information. The contribution of this work is a novel reinforced data selector based on reinforcement learning, which solves the PDA problem by taking into account both high-level and pixel-level information to select related samples for positive transfer. In addition, most deep domain adaptation methods can be extended to the PDA problem by integrating this module.
Partial Domain Adaptation
Deep domain adaptation methods have been widely studied in recent years. These methods extend deep neural networks by embedding adaptation layers for moment matching [26, 15, 24, 6] or adding domain discriminators for adversarial training [10, 25]. However, these methods are restricted by the assumption that the source and target domains share the same label space, which does not hold in the PDA scenario. Several methods have been proposed to solve the PDA problem. Selective adversarial network (SAN) trains a separate domain discriminator for each class with a weighting mechanism to suppress the harmful influence of the outlier classes. Partial adversarial domain adaptation (PADA) improves SAN by adopting only one domain discriminator and obtains the weight of each class from the target probability distribution predicted by the classifier. Example transfer network (ETN) automatically quantifies the weights of source examples based on their similarities to the target domain. Unlike previous PDA methods, which use only high-level information to select source samples, our proposed RTNet combines pixel-level and high-level information to achieve more accurate filtering of outlier source samples.
Reinforcement learning (RL) can be roughly divided into two categories: value-based methods and policy-based methods. Value-based methods, such as SARSA and the deep Q network, estimate the expected future total reward of a state. Policy-based methods, such as the REINFORCE algorithm, try to directly find the best next action in the current state. To reduce variance, some methods, such as the actor-critic algorithm, combine value-based and policy-based methods for more stable training. So far, data selection based on RL has been applied in active learning, co-training, text matching, etc. However, reinforced data selection methods for the PDA problem are still lacking.
Problem Definition and Notations
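In this work, based on PDA settings, we define the labeled source dataset as $\{(x_i^s, y_i^s)\}_{i=1}^{n_s}$ from the source domain $\mathcal{D}_s$ associated with $|\mathcal{C}_s|$ classes, and the unlabeled target dataset as $\{x_j^t\}_{j=1}^{n_t}$ from the target domain $\mathcal{D}_t$ associated with $|\mathcal{C}_t|$ classes. Note that the target label space is contained in the source label space, i.e., $\mathcal{C}_t \subseteq \mathcal{C}_s$, and $\mathcal{C}_t$ is unknown. The two domains follow different marginal distributions, $p$ and $q$, respectively, and we further have $p_{\mathcal{C}_t} \neq q$, where $p_{\mathcal{C}_t}$ is the distribution of source samples in the target label space. The goal is to improve the performance of the classifier in $\mathcal{D}_t$ with the help of the knowledge in $\mathcal{D}_s$ associated with $\mathcal{C}_t$, and to make the feature extractor $F$ and the classifier $C$ robust to outlier samples.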
Overview of RTNet
As shown in Figure 2, RTNet consists of two components: a domain adaptation (DA) model ($F$ and $C$) and a reinforced data selector ($V$, $\pi$ and $G$). The DA model promotes positive transfer by reducing the distribution shift between source and target domains in the shared label space. The reinforced data selector, based on RL, mitigates negative transfer by filtering out outlier source classes. Specifically, to filter out outlier source samples, the policy network $\pi$ uses the high-level information provided by the feature extractor $F$ and the classifier $C$ to decide which source samples $X_s'$ are selected. In the backbone of the DA model, $C$ takes the source transferable features as input to produce label predictions, and the distributions of the selected source samples and the target samples are aligned. Meanwhile, the reconstruction errors of the selected source samples on the target generator $G_t$ are used as rewards to encourage $\pi$ to select samples with small reconstruction errors. For the stability of training, following the actor-critic algorithm, we use a value network $V$ combined with rewards to optimize $\pi$. Besides, the domain-specific generators $G_s$ and $G_t$ are trained with the reconstruction errors of the source images and target images, respectively.
Domain Adaptation Model
Almost all partial domain adaptation frameworks are based on adversarial networks [3, 4, 29, 5], so many existing advanced domain adaptation algorithms based on moment matching cannot be extended to the PDA problem. The proposed reinforced data selector is a general module that can be integrated into most unsupervised domain adaptation (UDA) frameworks. Hence, we use Deep CORAL as the base domain adaptation model to show that the reinforced data selector can be embedded into a matching-based UDA framework to make it robust in the PDA scenario. In the following, we give a brief introduction to the main ideas of CORAL.
We define the last layer of $F$ as the adaptation layer and reduce the distribution shift between source and target domains by aligning the covariance of source and target features. Hence, the CORAL objective function is as follows:
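In the standard Deep CORAL form, and under the definitions that follow, the objective can be sketched as below. The symbol names ($L_c$, $H_s^b$, $H_t^b$, $\mathrm{Cov}$) are assumptions, and any normalisation constant is omitted.

```latex
L_c = \left\| \mathrm{Cov}(H_s^b) - \mathrm{Cov}(H_t^b) \right\|_F^2
```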
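where $\|\cdot\|_F^2$ denotes the squared matrix Frobenius norm, $H_s^b$ and $H_t^b$ represent the source and target transferable features output by the adaptation layer, respectively, $b$ is the batch ID, and $n$ is the batch size. $\mathrm{Cov}(H_s^b)$ and $\mathrm{Cov}(H_t^b)$ represent the covariance matrices, which can be computed as $\mathrm{Cov}(H_s^b) = {H_s^b}^\top J_n H_s^b$ and $\mathrm{Cov}(H_t^b) = {H_t^b}^\top J_n H_t^b$. $J_n = I_n - \frac{1}{n}\mathbf{1}_n\mathbf{1}_n^\top$ is the centering matrix, where $\mathbf{1}_n$ is an all-one column vector.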
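As a rough illustration of covariance alignment, the following pure-Python sketch computes the centred covariance of two feature batches and the squared Frobenius norm of their difference. The function names are hypothetical, plain lists stand in for tensors, and the covariance is left unnormalised, which is an assumption rather than a confirmed detail of the method.

```python
def covariance(features):
    """Unnormalised centred covariance H^T J H of an n x d feature batch."""
    n, d = len(features), len(features[0])
    means = [sum(row[k] for row in features) / n for k in range(d)]
    centred = [[row[k] - means[k] for k in range(d)] for row in features]
    return [[sum(r[a] * r[b] for r in centred) for b in range(d)]
            for a in range(d)]


def coral_loss(source_feats, target_feats):
    """Squared Frobenius norm of the difference of the two covariances."""
    cov_s, cov_t = covariance(source_feats), covariance(target_feats)
    d = len(cov_s)
    return sum((cov_s[a][b] - cov_t[a][b]) ** 2
               for a in range(d) for b in range(d))
```

Because the features are centred before the product, the loss is unchanged when one batch is shifted by a constant vector; only second-order statistics are compared.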
To ensure that the shared feature extractor and classifier can be trained with supervision on labeled samples, we define a standard cross-entropy classification loss with respect to the labeled source samples. Moreover, to encourage the target domain to have a nice manifold structure and thus increase the contribution of target data to the transfer, we extend CORAL by adopting the entropy minimization principle. Let $\hat{y}_j^t = C(F(x_j^t))$; the entropy objective function used to quantify the uncertainty of the predicted label of the target sample can be computed as $L_e = -\frac{1}{n}\sum_{j=1}^{n}\sum_{c=1}^{|\mathcal{C}_s|} \hat{y}_{j,c}^t \log \hat{y}_{j,c}^t$. Note that the entropy minimization only constrains $C$. We believe that if entropy minimization were applied to both $F$ and $C$, the target domain samples would easily get stuck in the wrong class in early training because of the large domain gap and would be difficult to correct afterward. Formally, the full objective function for the domain adaptation model is as follows:
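Assuming the three terms are combined linearly, with $L_s$ the source cross-entropy loss, $L_c$ the CORAL loss and $L_e$ the target entropy, the objective can be sketched as follows. The symbol names and the assignment of $\lambda_1$ and $\lambda_2$ to particular terms are assumptions.

```latex
L = L_s + \lambda_1 L_c + \lambda_2 L_e,
\qquad
L_e = -\frac{1}{n}\sum_{j=1}^{n}\sum_{c=1}^{|\mathcal{C}_s|} \hat{y}^t_{j,c}\,\log \hat{y}^t_{j,c}
```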
where the hyperparameters $\lambda_1$ and $\lambda_2$ control the impact of the corresponding objective functions. However, in the PDA scenario, most UDA methods (e.g. CORAL) may trigger negative transfer, since these methods force alignment of the global distributions $p$ and $q$, even though $p_{\mathcal{C}_s \setminus \mathcal{C}_t}$ and $q$ are non-overlapping and cannot be aligned during transfer. Thus, the motivation of the reinforced data selector is to mitigate negative transfer by filtering out the source outlier classes before performing the distribution alignment.
Reinforced Data Selector
We consider the source sample selection process of RTNet as a Markov decision process, which can be addressed by RL. The reinforced data selector is an agent that interacts with the environment created by the domain adaptation model. The agent takes an action to keep or drop a source sample based on the policy function. The domain adaptation model evaluates the actions taken by the agent and provides a reward to guide the agent's learning. The goal of the agent is to maximize the reward through the actions taken.
As shown in Figure 2, given a batch of source samples $X_s = \{x_i^s\}_{i=1}^{n}$, where $b$ represents the batch ID and $n$ is the batch size, we can obtain the corresponding states $\{s_i\}_{i=1}^{n}$ through the domain adaptation model. The reinforced data selector then utilizes the policy $\pi$ to determine the actions $\{a_i\}_{i=1}^{n}$ taken on the source samples, where $a_i \in \{0, 1\}$. $a_i = 0$ means that the outlier source sample $x_i^s$ is filtered out of $X_s$. Thus, we get a new source batch $X_s'$ related to the target domain. Instead of $X_s$, we feed $X_s'$ into the domain adaptation model to solve the PDA problem. Finally, the domain adaptation model moves to the next state after being updated with $X_s'$ and the target batch $X_t$, and provides a reward $r_b$ according to the source reconstruction errors on $G_t$ to update the policy and value networks. In the following sections, we give a detailed introduction to the state, action, and reward.
In RTNet, the state is defined as a vector $s_i$. In order to simultaneously consider the unique information of each source sample and the label distribution information of the target domain when taking an action, $s_i$ concatenates the following features: (1) The high-level semantic feature $F(x_i^s)$, which is the output of $F$ given $x_i^s$. (2) The label $y_i^s$ of the source sample, represented by a one-hot vector. (3) The predicted probability distribution of the target batch $X_t$, which can be calculated as $\bar{y}^t = \frac{1}{n}\sum_{j=1}^{n}\hat{y}_j^t$, $\hat{y}_j^t = C(F(x_j^t))$. Feature (1) represents the high-level information of the source sample. Feature (3) is based on the intuition that the probabilities of assigning the target data to source outlier classes should be small, since the target samples are significantly dissimilar to the source outlier samples. Consequently, $\bar{y}^t$ quantifies the contribution of each source class to the target domain. Feature (2) is combined with feature (3) to measure the relation between each source sample and the target domain.
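The concatenation of the three state features can be sketched in a few lines of Python. The helper names are hypothetical, plain lists stand in for tensors, and the target distribution is taken to be the batch mean of the classifier's softmax outputs, which is an assumption about how feature (3) is computed.

```python
def one_hot(label, num_classes):
    """Feature (2): one-hot encoding of the source label."""
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mean_target_prediction(target_probs):
    """Feature (3): class probabilities averaged over the target batch."""
    n = len(target_probs)
    num_classes = len(target_probs[0])
    return [sum(p[c] for p in target_probs) / n for c in range(num_classes)]


def build_state(feature, label, target_probs):
    """Concatenate features (1), (2) and (3) into one state vector."""
    num_classes = len(target_probs[0])
    return (list(feature)
            + one_hot(label, num_classes)
            + mean_target_prediction(target_probs))
```

For a source sample of class 2 and a target batch that puts no mass on class 2, the one-hot block and the averaged distribution disagree, which is the signal the selector can use to drop the sample.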
The action $a_i \in \{0, 1\}$ indicates whether the source sample is kept or filtered from the source batch. The selector utilizes an $\epsilon$-greedy strategy to sample $a_i$ based on $\pi(s_i)$, which represents the probability that the sample is kept. $\epsilon$ is decayed from 1 to 0. $\pi$ is defined as a policy network with two fully connected layers. Formally, $\pi(s_i)$ is computed as follows:
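For a two-layer network with a ReLU hidden activation $\delta$, the keep probability can be sketched as below. The output squashing $\sigma$ (taken here to be a sigmoid) and the subscript convention for the layer parameters are assumptions.

```latex
\pi_\theta(s_i) = \sigma\!\left(W_2\,\delta\!\left(W_1 s_i + b_1\right) + b_2\right)
```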
where $\delta$ is the ReLU activation, $W_k$ and $b_k$ are the weight matrix and bias of the $k$-th layer, and $s_i$ is the state of the source sample, which concatenates features (1), (2) and (3).
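A minimal Python sketch of the action step is given below, under two assumptions that are not confirmed by the text: the policy output is squashed by a sigmoid, and the $\epsilon$-greedy rule explores with a uniformly random keep/drop action and otherwise keeps the sample when its keep probability is at least 0.5. All function and parameter names are hypothetical.

```python
import math
import random


def policy_keep_probability(state, w1, b1, w2, b2):
    """Two fully connected layers: ReLU hidden layer, sigmoid output (assumed)."""
    hidden = [max(0.0, sum(w * s for w, s in zip(row, state)) + bias)
              for row, bias in zip(w1, b1)]
    logit = sum(w * h for w, h in zip(w2, hidden)) + b2
    return 1.0 / (1.0 + math.exp(-logit))


def select_action(keep_prob, epsilon, rng=random):
    """Return 1 (keep) or 0 (drop) with epsilon-greedy exploration."""
    if rng.random() < epsilon:
        return rng.randint(0, 1)          # explore: random keep/drop
    return 1 if keep_prob >= 0.5 else 0   # exploit: follow the policy
```

With `epsilon` decayed from 1 to 0 over training, early batches are filtered almost at random and later batches follow the learned policy.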
The selector takes actions to select $X_s'$ from $X_s$. RTNet uses $X_s'$ to update the domain adaptation model and obtains a reward $r_b$ for evaluating the policy. In contrast to usual reinforcement learning, where one reward corresponds to one action, RTNet assigns one reward to a batch of actions to improve the efficiency of model training.
To take pixel-level information into account when selecting source samples, the reward is designed according to the reconstruction error of the selected source samples on the target generator. The intuition behind using this reconstruction error as a reward is that the reconstruction error of outlier source samples is large, since they are extremely dissimilar to the target classes. Consequently, the selector aims to select source samples with small reconstruction errors for distribution alignment and classifier training. However, the purpose of reinforcement learning is to maximize the reward, so we design the following novel reward based on the reconstruction error:
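One reward that decreases monotonically with the mean reconstruction error of the selected samples is the negative exponential below. The exponential form, the squared $\ell_2$ error and the symbols $r_b$, $x'_i$, $n'$, $G_t$ and $F$ are assumptions made for illustration.

```latex
r_b = \exp\!\left(-\frac{1}{n'}\sum_{i=1}^{n'} \left\| x'_i - G_t\!\left(F(x'_i)\right) \right\|_2^2\right)
```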
where $x'_i$ is a sample selected by the reinforced data selector, and $n'$ is the number of samples selected. Note that, to accurately evaluate the efficacy of $X_s'$, rewards are collected after the feature extractor and classifier are updated as in Eq. 5 and before the generators are updated as in Eq. 6. $F$, $C$ and $G$ can be trained as follows:
In the process of selection, not only the last action contributes to the reward, but all previous actions contribute. Therefore, the future total reward for each batch can be formalized as:
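In the usual discounted form, with $r_b$ the reward of batch $b$ and $r'_b$ its future total reward (both symbol names assumed), this reads:

```latex
r'_b = \sum_{k=0}^{N-b} \gamma^{k}\, r_{b+k}
```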
where $\gamma$ is the reward discount factor, and $N$ is the number of batches in this episode.
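The future total reward of every batch in an episode can be computed in a single backward pass. The sketch below assumes one scalar reward per batch; the function name is hypothetical.

```python
def discounted_returns(rewards, gamma):
    """Future total reward for each batch of one episode.

    returns[b] = rewards[b] + gamma * rewards[b + 1] + gamma**2 * rewards[b + 2] + ...
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for b in range(len(rewards) - 1, -1, -1):
        running = rewards[b] + gamma * running
        returns[b] = running
    return returns
```

Setting `gamma` to 0 reduces each return to the immediate batch reward, while `gamma` equal to 1 sums all remaining rewards without discount.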
The selector is optimized with the actor-critic algorithm. In each episode, the selector aims to maximize the expected total reward. Formally, the objective function is defined as:
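Written as an expectation over the episodes generated by the policy, with $r_b$ the reward of batch $b$ and $N$ the number of batches (symbol names assumed), the objective takes the form:

```latex
J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\sum_{b=1}^{N} r_b\right]
```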
where $\theta$ represents the parameters of the policy network $\pi$. $\theta$ is updated by performing, typically approximate, gradient ascent on $J(\theta)$. Formally, the update step of $\theta$ is defined as:
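A policy-gradient step consistent with the definitions that follow can be sketched as below, where $a_i$ is the action taken in state $s_i$; the notation is assumed.

```latex
\theta \leftarrow \theta + \alpha\,\frac{1}{n}\sum_{i=1}^{n} v_i\,\nabla_\theta \log \pi_\theta(a_i \mid s_i)
```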
where $\alpha$ is the learning rate, $n$ is the batch size, and $v_i$ is an estimate of the advantage function based on the future total reward, which guides the update of $\pi$. Note that $r'_b$ is an unbiased estimate of the action value $Q^{\pi}(s_i, a_i)$. The actor-critic framework combines $\pi$ and $V$ for stable training. In this work, we utilize $V$ to estimate the expected future total reward. Hence, $v_i$ can be considered an estimate of the advantage of the action, which encourages $\pi$ to adopt strategies that maximize the future total reward. $v_i$ is defined as follows:
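Taking the value network output as the baseline, the advantage estimate would be the gap between the observed future total reward and the predicted one (symbols assumed):

```latex
v_i = r'_b - V(s_i)
```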
The architecture of the value network is similar to that of the policy network, except that the final output layer is a regression function. The value network is designed to estimate the expected future total reward for each state, and can be optimized in the following form:
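A squared-error regression target is the usual choice for such a critic; assuming it is the one used here, the loss can be written as:

```latex
L_V = \frac{1}{n}\sum_{i=1}^{n}\left(r'_b - V(s_i)\right)^2
```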
As the reinforced data selector and the domain adaptation model interact with each other during training, we train them jointly. To ensure that the domain adaptation model provides accurate states and rewards in the early stages of training, we first pre-train $F$, $C$, $G_s$ and $G_t$ through the classification loss of source samples and Eq. 6. We follow previous work to train RTNet; the detailed training process is shown in Algorithm 1. Note that RTNet also filters the outlier source classes for the classifier, thus focusing the classifier more on the source samples belonging to the target label space, which enables the classifier to provide a more accurate target label distribution, as used in feature (3).
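The per-batch quantities used to update the selector can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the training procedure of Algorithm 1: the policy outputs a keep probability per sample, the baseline is the value network output, and all function names are hypothetical.

```python
import math


def advantages(batch_return, state_values):
    """Advantage estimate: future total reward minus the critic's prediction."""
    return [batch_return - v for v in state_values]


def value_loss(batch_return, state_values):
    """Mean squared error between the critic's predictions and the return."""
    return sum((batch_return - v) ** 2 for v in state_values) / len(state_values)


def policy_surrogate(keep_probs, actions, advs):
    """Negative mean of advantage * log-probability of the action taken.

    Minimising this surrogate with respect to the policy parameters
    corresponds to gradient ascent on the expected total reward.
    """
    total = 0.0
    for p, a, v in zip(keep_probs, actions, advs):
        log_prob = math.log(p if a == 1 else 1.0 - p)
        total += v * log_prob
    return -total / len(actions)
```

In a framework with automatic differentiation, the advantages would be treated as constants when differentiating the surrogate, so that only the log-probability term carries gradient to the policy.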
Office-31 is a widely used visual domain adaptation dataset, which contains 4,110 images of 31 categories from three distinct domains: Amazon website (A), Webcam (W) and DSLR camera (D). Following the settings of previous work, we select the same 10 categories in each domain to build the new target domain and create 6 transfer scenarios, A31→W10, W31→A10, W31→D10, D31→W10, A31→D10, and D31→A10, to evaluate RTNet. The digital dataset includes five domain adaptation benchmarks: Street View House Numbers (SVHN), MNIST, MNIST-M, USPS and the synthetic digits dataset (SYN), each of which consists of ten categories. We select 5 categories (digit 0 to digit 4) as the target domain in each dataset and construct four partial domain adaptation tasks: SVHN10→MNIST5, MNIST10→MNIST-M5, USPS10→MNIST5 and SYN10→MNIST5.
RTNet is implemented in TensorFlow and trained with the Adam optimizer. For the experiments on Office-31, we employ ResNet-50 pre-trained on ImageNet as the backbone of the domain adaptation model and fine-tune the parameters of the fully connected layers and the final block. For the experiments on the digital datasets, we adopt a modified LeNet as the backbone of the domain adaptation model and update all of the weights. All images are converted to grayscale and resized to 32 × 32.
In RTNet, we adopt two separate fully connected layers for each of the policy network $\pi$ and the value network $V$, and a series of transposed convolutional layers for the generators $G$. To guarantee a fair comparison, the same architectures are used for $F$ and $C$ in all comparison methods, and each method is trained five times, with the average taken as the final result. In our method, the discount factor $\gamma$, which determines the impact of previous actions on the current reward, is a critical hyperparameter. We first select it according to the accuracy on the SVHN10→MNIST5 task and then set $\gamma$ to 0.8 for all other transfer tasks, since our approach works stably across different tasks. As for the other hyperparameters, we follow the settings of previous work. To ease model selection, the hyperparameters of the comparison methods are gradually changed from 0 to 1, following previous work.
| Type | Method | SVHN10→MNIST5 | MNIST10→MNIST-M5 | USPS10→MNIST5 | SYN10→MNIST5 | Avg |
Result and Discussion
Table 1 and Table 2 show the classification results on the two datasets. RTNet achieves the best accuracy on most transfer tasks. In particular, RTNet outperforms other methods by a large margin on tasks with small source and target domains, e.g. A31→W10, and on tasks with large source and target domains, e.g. SVHN10→MNIST5. These results suggest that our approach can learn more transferable features in PDA scenarios of various scales by selecting the related source samples for positive transfer.
Several observations can be made from Table 1 and Table 2. First, the previous UDA methods, including those based on adversarial networks (DANN) and those based on moment matching (DAN, JDDA, and CORAL), perform even worse than the non-adaptation model (ResNet-50 or modified LeNet), indicating that they are affected by negative transfer. These methods reduce the shift of the marginal distribution between domains without considering the conditional distribution, thus matching the outlier source classes with the target domain, which results in weak classifier performance. Second, PDA methods (ETN and PADA) improve classification accuracy by a large margin, since their weighting mechanisms can mitigate the negative transfer caused by outlier categories. Finally, RTNet achieves the best performance. Unlike the previous PDA methods, which rely only on the predicted probability distribution to obtain the weights, RTNet combines the high-level semantic feature and the predicted probability distribution to select source samples, and employs the pixel-level reconstruction error as the evaluation criterion to guide the learning of the policy network. Thus, this selection mechanism can detect outlier source classes more effectively and transfer relevant samples.
| State Features Combination | Accuracy |
| (1) (2) (3) | 95.3 |
We further perform a state feature ablation test on SVHN10→MNIST5. We have three state features, as described in Section 3 (Our Approach), two of which (features (2) and (3)) can be considered a feature group because they are combined to evaluate the contribution of the source sample to the target domain. As shown in Table 3, feature (1) alone gives good performance, which indicates that the feature extracted by the feature extractor describes the state of the model well, while the result of the second feature group suggests that the probability distributions have limited capacity for state representation. Besides, the combination of the two feature groups yields the best performance, confirming that all features contribute to the final result.
We visualize the features of the adaptation layer using t-SNE. As shown in Figure 3, several observations can be made. First, by comparing Figures 2(a), 2(e) with Figures 2(b), 2(f), we find that CORAL forces the target domain to be aligned with the whole source domain, including outlier classes that do not exist in the target label space, which triggers negative transfer and degrades the model. Second, as can be seen in Figures 2(d), 2(h), RTNet correctly matches the target samples to related source samples by integrating the selector into the CORAL architecture to filter out outlier classes, which confirms that a matching-based UDA framework can be extended to the PDA problem by embedding the reinforced data selector. Finally, compared with Figures 2(c), 2(g), RTNet matches the related source domain and the target domain more accurately, indicating that it is more effective than ETN in suppressing the impact of outlier source classes by considering both high-level and pixel-level information.
Statistics of Class-wise Retention Probabilities
We utilize $\bar{\pi}_c = \frac{1}{|X_c|}\sum_{x \in X_c} \pi(s_x)$ to verify the ability of the data selector to filter the samples, averaging the retention probabilities of each class of source samples. $X_c$ represents a source sample set, which contains the samples belonging to class $c$, and $s_x$ is the state of sample $x$. As shown in Figure 3(a), RTNet assigns much larger retention probabilities to the shared classes than to the outlier classes. This result shows that RTNet is able to automatically select relevant source classes and filter out outlier classes.
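The class-wise statistic can be computed with a simple group-by average. The sketch below assumes the labels and the keep probabilities of the source samples are available as parallel lists; the function name is hypothetical.

```python
from collections import defaultdict


def classwise_retention(labels, keep_probs):
    """Average keep probability of the source samples of each class."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for label, prob in zip(labels, keep_probs):
        sums[label] += prob
        counts[label] += 1
    return {label: sums[label] / counts[label] for label in sums}
```

Shared classes are expected to show averages well above those of outlier classes when the selector has learned a useful policy.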
To investigate the effects of the reward discount factor $\gamma$, we varied its value between 0 and 1. $\gamma = 0$ indicates that future rewards are not considered when updating the policy network. $\gamma = 1$ indicates that future rewards are considered without discount when updating the policy network. As shown in Figure 3(b), the trend implies that appropriately increasing the contribution of future rewards helps the reinforced data selector filter correctly and thus mitigates negative transfer.
We analyze the convergence of RTNet. As shown in Figure 3(c), the test errors of DANN and CORAL are higher than that of ResNet because of negative transfer. Their test errors are also unstable, probably because the target domain is matched to different outlier classes during training. RTNet converges quickly and stably to the lowest test error, indicating that it can be trained efficiently to solve the PDA problem. As shown in Figure 3(d), the reward gradually increases as the episodes progress, meaning that the reinforced data selector can learn the correct policy to maximize the reward and filter out source outlier classes.
We conduct experiments to evaluate the performance of RTNet when the number of target classes varies. As shown in Figure 3(e), as the number of target classes decreases, the performance of CORAL degrades rapidly, indicating that negative transfer becomes more and more serious as the gap between the source and target label spaces becomes larger. RTNet performs better than the other comparison methods, indicating that our approach can mitigate negative transfer in the PDA problem. Moreover, RTNet is superior to the UDA method (CORAL) when the source and target label spaces are identical (A31→W31), which shows that our method does not filter erroneously when there are no outlier classes.
In this work, we propose an end-to-end reinforced transfer network, which utilizes both high-level and pixel-level information to address the partial domain adaptation problem. RTNet applies reinforcement learning to train a reinforced data selector based on the actor-critic framework, which filters out outlier source classes in order to mitigate negative transfer. Unlike previous partial domain adaptation methods based on adversarial networks, the proposed reinforced data selector can be integrated into almost all standard domain adaptation frameworks, including those based on adversarial networks and those based on moment matching. Note that the results of RTNet based on an adversarial domain adaptation model are shown in the Appendix. The state-of-the-art experimental results confirm the efficacy of our method.
- [1] (2017) Deep reinforcement learning: a brief survey. IEEE Signal Processing Magazine 34 (6), pp. 26–38.
- [2] (2009) Dataset shift in machine learning. The MIT Press.
- [3] (2018) Partial transfer learning with selective adversarial networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2724–2732.
- [4] (2018) Partial adversarial domain adaptation. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 135–150.
- [5] (2019) Learning to transfer examples for partial domain adaptation. arXiv preprint arXiv:1903.12230.
- [6] (2019) Joint domain alignment and discriminative feature learning for unsupervised deep domain adaptation. National Conference on Artificial Intelligence.
- [7] (2019) Deep joint two-stream Wasserstein auto-encoder and selective attention alignment for unsupervised domain adaptation. Neural Computing and Applications, pp. 1–14.
- [8] (2014) DeCAF: a deep convolutional activation feature for generic visual recognition. In International Conference on Machine Learning, pp. 647–655.
- [9] (2017) Learning how to active learn: a deep reinforcement learning approach. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pp. 595–605.
- [10] (2016) Domain-adversarial training of neural networks. The Journal of Machine Learning Research 17 (1), pp. 2096–2030.
- [11] (2005) Semi-supervised learning by entropy minimization. In Advances in Neural Information Processing Systems, pp. 529–536.
- [12] (2002) A database for handwritten text recognition research. IEEE Transactions on Pattern Analysis & Machine Intelligence 16 (5), pp. 550–554.
- [13] (2000) Actor-critic algorithms. In Advances in Neural Information Processing Systems, pp. 1008–1014.
- [14] (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324.
- [15] (2015) Learning transferable features with deep adaptation networks. In International Conference on Machine Learning, pp. 97–105.
- [16] (2016) Unsupervised domain adaptation with residual transfer networks. In Advances in Neural Information Processing Systems, pp. 136–144.
- [17] (2017) Deep transfer learning with joint adaptation networks. In International Conference on Machine Learning, pp. 2208–2217.
- [18] (2015) Human-level control through deep reinforcement learning. Nature 518 (7540), pp. 529.
- [19] (2011) Reading digits in natural images with unsupervised feature learning. NIPS Workshop on Deep Learning & Unsupervised Feature Learning.
- [20] (2010) A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering 22 (10), pp. 1345–1359.
- [21] (2019) Learning to selectively transfer: reinforced transfer learning for deep text matching. In Proceedings of the Twelfth ACM International Conference on Web Search and Data Mining, pp. 699–707.
- [22] (1994) On-line Q-learning using connectionist systems. Vol. 37, University of Cambridge, Department of Engineering, Cambridge, England.
- [23] (2010) Adapting visual category models to new domains. In European Conference on Computer Vision (ECCV).
- [24] (2016) Deep CORAL: correlation alignment for deep domain adaptation. In European Conference on Computer Vision, pp. 443–450.
- [25] (2017) Adversarial discriminative domain adaptation. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2962–2971.
- [26] (2014) Deep domain confusion: maximizing for domain invariance. arXiv preprint arXiv:1412.3474.
- [27] (1992) Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8 (3-4), pp. 229–256.
- [28] (2018) Reinforced co-training. In NAACL HLT 2018: 16th Annual Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Vol. 1, pp. 1252–1262.
- [29] (2018) Importance weighted adversarial nets for partial domain adaptation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8156–8164.