Domain Adversarial Fine-Tuning as an Effective Regularizer
In Natural Language Processing (NLP), pretrained language models (LMs) transferred to downstream tasks have recently been shown to achieve state-of-the-art results. In this work, we extend the standard fine-tuning process of pretrained LMs with a new regularization technique: after, domain Adversarial Fine-Tuning as an Effective Regularizer. Specifically, we complement the task-specific loss used during fine-tuning with an adversarial objective. This additional loss term is tied to an adversarial classifier that aims to discriminate between in-domain and out-of-domain text representations. In-domain refers to the labeled dataset of the task at hand, while out-of-domain refers to unlabeled data from a different domain. Intuitively, the adversarial classifier acts as a regularizer that prevents the model from overfitting to the task-specific domain. Empirical results on sentiment analysis, linguistic acceptability, and paraphrase detection show that after leads to improved performance compared to standard fine-tuning.
1 Introduction
Current research in NLP focuses on transferring knowledge from a language model (LM), pretrained on large general-domain data, to a target task. The LM representations are transferred to the target task either as additional features of a task-specific model Peters et al. (2018), or by fine-tuning Howard and Ruder (2018); Devlin et al. (2019); Yang et al. (2019). Standard fine-tuning involves initializing the target model with the pretrained LM and training it on the target data.
Fine-tuning, however, can lead to catastrophic forgetting Goodfellow et al. (2013): if the pretrained LM representations are adjusted to the target task to such an extent, most of the generic knowledge captured during pretraining is in effect forgotten. In addition, pretrained LMs tend to overfit when fine-tuned on a target task dataset that is several orders of magnitude smaller than the pretraining data Dai and Le (2015).
Adversarial training is a method to increase robustness and regularize deep neural networks Goodfellow et al. (2015); Miyato et al. (2017). It has been used for domain adaptation Ganin et al. (2016) to train a model from scratch to produce representations that are invariant to different domains. Inspired by this approach, we propose a regularization technique for the fine-tuning process of a pretrained LM, that aims to optimize knowledge transfer to the target task and avoid overfitting.
Our method, domain Adversarial Fine-Tuning as an Effective Regularizer (after) extends standard fine-tuning by adding an adversarial objective to the task-specific loss. We leverage out-of-domain unlabeled data (i.e. from a different domain than the target task domain). The transferred LM is fine-tuned so that an adversarial classifier cannot discriminate between text representations from in-domain and out-of-domain data. This loss aims to regularize the extent to which the model representations are allowed to adapt to the target task domain. Thus, after is able to preserve the general-domain knowledge acquired during the pretraining of the LM, while adapting to the task.
Our contributions are: (1) We propose after, an LM fine-tuning method that acts as a new kind of regularizer, aiming to avoid catastrophic forgetting of general-domain knowledge. (2) We show that our approach improves the performance of standard BERT Devlin et al. (2019) fine-tuning in three out of four considered natural language understanding tasks from the GLUE benchmark Wang et al. (2019a). We also validate the effectiveness of our approach on a different pretrained LM, XLNet Yang et al. (2019). (3) We further conduct an ablation study to provide useful insights into the key factors of the proposed approach.
2 Related Work
Several approaches have been proposed for the adaptation of a model trained on a source domain to a different target domain where no labeled data is available Gong et al. (2012); Tzeng et al. (2014); Sun et al. (2016). Ganin et al. (2016) were the first to propose adversarial training for domain adaptation. They introduced a gradient reversal layer to adversarially train a classifier that should not be able to discriminate between the source and target domains, in image classification and sentiment analysis tasks.
Various adversarial losses have been used for domain adaptation in several NLP tasks, such as question answering Lee et al. (2019), machine reading comprehension Wang et al. (2019b) and cross-lingual named entity recognition Keung et al. (2019). Adversarial approaches have been also used to learn latent representations that are agnostic to different attributes of the input text, such as language Lample et al. (2018a, b) and style Yang et al. (2018). Contrary to previous domain adaptation work, we explore the addition of an adversarial loss term to serve as a regularizer for fine-tuning.
Other variants of LM fine-tuning include a supplementary supervised training stage in data-rich tasks Phang et al. (2018) or multi-task learning with additional supervised tasks Liu et al. (2019). However, such methods require additional labeled data. A common way to leverage unlabeled data during fine-tuning is through an additional stage of language modeling. As unlabeled data can be used the available task-specific data Howard and Ruder (2018), or additional unlabeled in-domain corpora Sun et al. (2019); Gururangan et al. (2020). This approach adds a computationally expensive step that requires unlabeled data from a specific source. By contrast, our method leverages out-of-domain data with only a small computational overhead and minimal changes to the fine-tuning process.
3 Proposed Approach
Figure 1 provides a high-level overview of after.
Problem Definition. We tackle a Main task with a labeled dataset from a domain of interest. We further exploit an existing unlabeled corpus, Auxiliary, that comes from a different domain. We label each sample with a corresponding domain label d: d = 0 for samples from Main and d = 1 for samples from Auxiliary. We note that we do not use any real labels from Auxiliary (if there are any). The domain labels are used to train a classifier that discriminates between the two domains.
Model. We initialize our model with pretrained weights from a state-of-the-art language model, BERT Devlin et al. (2019). The representation of BERT for the input sequence is encoded in the [CLS] token output embedding. We add a linear layer on top of the [CLS] output embedding for the Main task, resulting in a task-specific loss $\mathcal{L}_{task}$. We also add another linear layer for the binary domain discriminator, with a corresponding adversarial loss $\mathcal{L}_{adv}$, which has the same input (Figure 1).
Adversarial Fine-tuning. The domain discriminator outputs a domain label for each sample of the training set. We seek representations that are both discriminative for the Main task and indiscriminative for the domain classifier. Hence, we minimize the task loss $\mathcal{L}_{task}$ and at the same time maximize the adversarial loss $\mathcal{L}_{adv}$, by fine-tuning BERT with the joint loss:

$$\mathcal{L} = \mathcal{L}_{task} + \lambda \cdot \mathcal{L}_{adv} \quad (1)$$

where $\lambda > 0$ controls the importance of the domain loss; the maximization of $\mathcal{L}_{adv}$ with respect to the encoder is implemented by the gradient reversal layer described below. The parameters of the domain classifier are trained to predict the (true) domain label, while the rest of the network is trained to mislead the domain classifier, thereby developing domain-independent internal representations.
Gradient Reversal Layer. We use a Gradient Reversal Layer (GRL) Ganin et al. (2016) between the [CLS] output embedding and the domain discriminator layer, as shown in Figure 1, to maximize the adversarial loss with respect to the encoder. During the forward pass, the GRL acts as an identity transform, but during backpropagation it reverses the gradients. In effect, the BERT parameters are updated in the opposite direction of the gradient of the task loss, as in standard descent, and, adversarially, in the direction of the gradient of the adversarial loss.
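The reversal mechanic can be sketched in PyTorch as follows. This is a minimal illustration, not the authors' released code; the names `GradReverse`, `grad_reverse`, and `lam` are ours.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales gradients by -lam in the backward pass."""

    @staticmethod
    def forward(ctx, x, lam):
        ctx.lam = lam
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the encoder;
        # the non-tensor argument lam receives no gradient (None).
        return -ctx.lam * grad_output, None

def grad_reverse(x, lam=1.0):
    return GradReverse.apply(x, lam)

# Toy check: the gradient through the layer is negated and scaled.
x = torch.ones(3, requires_grad=True)
grad_reverse(x, lam=0.5).sum().backward()
# x.grad is now -0.5 for every element
```

In a full model, the [CLS] embedding would pass through `grad_reverse` before the domain discriminator, so that minimizing the discriminator's loss simultaneously trains the encoder to maximize it.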
[Table 1: Results for BERT and for after with each Auxiliary domain (News, Reviews, Legal, Medical, Math) on CoLA (Matthews corr.), SST-2 (Accuracy), MRPC (Accuracy / F1), and RTE (Accuracy); the numeric values were lost in extraction.]
| Dataset | Domain | Size |
| --- | --- | --- |
| ag news | Agricultural News (news) | 120K |
| europarl | Legal Documents (legal) | 120K |
| amazon | Electronics Reviews (reviews) | 120K |
| pubmed | Medical Papers (medical) | 120K |
| mathematics | Mathematics Questions (math) | 120K |
4 Experiments & Results
Datasets. We experiment with four Main datasets from the GLUE benchmark Wang et al. (2019a). For Auxiliary data we select corpora from various domains. For the news domain we use the ag news dataset Zhang et al. (2015) and for the reviews domain we use a part of the Electronics reviews of He and McAuley (2016). For the legal domain we use the English part of Europarl Koehn (2004) and for the medical domain we use papers from PubMed, provided by Cohan et al. (2018). We also use math questions from the dataset of Saxton et al. (2019) for the math domain. Table 2 summarizes all datasets. More details regarding the selection and processing of the datasets can be found in the Appendix.
Implementation Details. As our baseline, we fine-tune the pretrained BERT-base model, using the hyperparameters suggested by Devlin et al. (2019). We tune the hyperparameter λ of Eq. 1 on the validation set for each experiment, finding that most values of λ improve over the baseline. For more implementation details see the Appendix.
Results. We observe in Table 1 that the proposed approach (after) outperforms the baseline (BERT) in three out of four tasks. For these tasks, after results in improved performance with every Auxiliary dataset, demonstrating the robustness of our approach across domains.
Specifically, in CoLA, fine-tuning with the adversarial loss substantially outperforms standard BERT fine-tuning: after improves the baseline by 4-5 points, with reduced variance, using most Auxiliary datasets. In SST-2, although BERT already achieves high accuracy, the use of after still results in slight performance gains. As in CoLA, these improvements are consistent across Auxiliary datasets and come with reduced variance compared to BERT; for instance, an Auxiliary dataset from the legal domain improves the accuracy of BERT. In MRPC, we observe gains on average in both accuracy and F1 over BERT; using medical data as Auxiliary, after outperforms the baseline in both metrics. In RTE, the proposed approach does not improve upon the baseline. We attribute this result to the similarity between the domain of RTE (Wikipedia) and the domain of the pretraining corpus of BERT (Wikipedia and Books). We test this hypothesis in Section 5.
5 Analysis
BERT Pretraining and Task Domains. To explore why after fails to improve upon the baseline on RTE, we examine whether the pretrained representations are already well suited for the task (i.e., no regularization is needed). We calculate the average masked LM (MLM) loss of the pretrained model for each Main dataset (Table 3). SST-2 produces the largest loss, which can be partially attributed to the dataset format (it contains short sentences that make the MLM task very challenging). RTE produces the lowest loss, confirming our intuition regarding the similarity of the pretraining corpus of BERT and RTE. In this case, general-domain and domain-specific representations are close, rendering domain-adversarial regularization undesirable. This is also confirmed by the vocabulary overlap between RTE and a Wikipedia corpus (Table 3). The more distant the pretraining domain of BERT is from the task domain (measured by vocabulary overlap and MLM loss), the more benefits after demonstrates, confirming our intuition regarding domain-adversarial regularization.
[Table 3: average MLM loss, vocabulary overlap with Wiki (%), and after improvement (%) for each Main dataset; the numeric values were lost in extraction.]
Domain Distance. We measure the domain distance for all Main-Auxiliary pairs to evaluate how the choice of the latter affects the performance of after. We represent the word distribution of each dataset as a term distribution $t = (p_1, \ldots, p_{|V|})$, where $p_i$ is the probability of the $i$-th word in the joint vocabulary (see Appendix), and calculate the Jensen-Shannon (JS) divergence between these distributions Plank and van Noord (2011). Combining the results of Table 1 and Fig. 2, no clear pattern emerges, demonstrating, perhaps, our method's robustness to domain distance. We leave a further investigation of selection criteria for the Auxiliary data to future work.
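The divergence computation can be sketched as follows; this is a minimal NumPy version, and the smoothing constant `eps` is our addition to handle zero-probability words.

```python
import numpy as np

def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence (base 2) between two term distributions; lies in [0, 1]."""
    p = np.asarray(p, dtype=float) + eps
    q = np.asarray(q, dtype=float) + eps
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)  # mixture distribution
    kl = lambda a, b: float(np.sum(a * np.log2(a / b)))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Identical term distributions give 0; disjoint vocabularies approach 1.
same = js_divergence([0.5, 0.5], [0.5, 0.5])
disjoint = js_divergence([1.0, 0.0], [0.0, 1.0])
```

Using base-2 logarithms keeps the divergence in [0, 1], which makes distances between different Main-Auxiliary pairs directly comparable.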
Domain-invariant vs. Domain-specific Features. To investigate whether the benefits of after can be attributed solely to data augmentation, we compare adversarial fine-tuning (maximizing the adversarial loss of Eq. 1 via gradient reversal) with multi-task fine-tuning (minimizing it jointly with the task loss). We experiment with MRPC and CoLA for both settings (tuning each separately). We observe that during multi-task fine-tuning (Fig. 3), the domain classification loss is close to zero (even in the first epoch). This implies that domain classification is an easy auxiliary task, confirming our intuition that a non-adversarial fine-tuning setting favors domain-specific features. Although the multi-task approach leverages the same unlabeled data, its performance is worse than after's (Table 4), which highlights the need for an adversarial domain discriminator.
[Table 4: after vs. multi-task fine-tuning with medical Auxiliary data on MRPC and CoLA; the numeric values were lost in extraction.]
Applicability of after. To test the generality of the proposed fine-tuning method, we apply after to XLNet Yang et al. (2019), another powerful pretrained LM. We conduct experiments on two datasets, due to resource constraints. XLNet has an architecture similar to BERT's, but uses a different pretraining objective that yields improved performance. We observe in Table 5 that after boosts performance even for this higher-performing LM baseline. We can therefore attribute the effectiveness of after to the regularization itself and not to the model architecture.
[Table 5: XLNet vs. after with medical and math Auxiliary data; the numeric values were lost in extraction.]
6 Conclusions and Future Work
We propose after, a domain adversarial method to regularize the fine-tuning process of a pretrained LM. Empirical results demonstrate that our method can lead to improved performance over standard fine-tuning. after can be widely applied to any transfer learning setting and model architecture, with minimal changes to the fine-tuning process, without requiring any additional labeled data. We aim to further explore the effect of Auxiliary data on the final performance and the use of multiple Auxiliary datasets. We also aim to extend the proposed approach as a way to fine-tune a pretrained LM to a different language, in order to produce language-invariant representations.
Appendix A Appendices
In this supplementary material, we provide additional information for producing the results in the paper, and results that could not fit into the main body of the paper.
A.1 Dataset Details
Main datasets. We use only four datasets of the GLUE benchmark as Main tasks for our experiments, due to resource constraints. The datasets were chosen to represent the broad variety of natural language understanding tasks available in the GLUE benchmark: linguistic acceptability Warstadt et al. (2019), sentiment analysis Socher et al. (2013), paraphrase detection Dolan and Brockett (2005), and textual entailment Dagan et al. (2005); Bar-Haim et al. (2006); Giampiccolo et al. (2007); Bentivogli et al. (2009). The datasets used represent both high-resource (SST-2) and low-resource (RTE, CoLA, MRPC) tasks, as well as single-sentence (CoLA, SST-2) and sentence-pair (MRPC, RTE) tasks. All Main datasets are open source and can be found at https://gluebenchmark.com/tasks.
Auxiliary datasets. We choose Auxiliary datasets that are larger than the Main ones, which we consider the most realistic scenario, given the availability of unlabeled compared to labeled data. We under-sample each Auxiliary dataset to ensure that the two domains are equally represented, motivated by the observation of Bingel and Søgaard (2017) that balanced datasets tend to work better as auxiliary tasks. For each mini-batch, we sample equally from the Main and Auxiliary datasets.
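The under-sampling and balanced batching described above can be sketched as follows. This is a simplified illustration: the function names and the domain-label convention (0 for Main, 1 for Auxiliary) are ours.

```python
import random

def balanced_batches(main, aux, batch_size, seed=0):
    """Yield mini-batches with equal halves from the Main (in-domain, d=0)
    and Auxiliary (out-of-domain, d=1) datasets."""
    rng = random.Random(seed)
    # Under-sample Auxiliary so the two domains are equally represented.
    aux = rng.sample(list(aux), min(len(aux), len(main)))
    half = batch_size // 2
    for i in range(0, len(main) - half + 1, half):
        batch = [(x, 0) for x in main[i:i + half]]   # Main samples, domain label 0
        batch += [(x, 1) for x in aux[i:i + half]]   # Auxiliary samples, domain label 1
        rng.shuffle(batch)
        yield batch

batches = list(balanced_batches(range(100), range(1000, 1300), batch_size=8))
```

Each yielded batch carries only domain labels for the Auxiliary half; in the real training loop, the Main half would additionally carry its task-specific labels.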
The Auxiliary datasets are a mix of labeled and unlabeled datasets from different domains. The labeled Auxiliary datasets (e.g., ag news) are handled as unlabeled corpora, by dropping the task-specific labels and using only the domain labels. Although some domains might seem similar to those of the Main datasets, e.g., Electronics Reviews vs. Movie Reviews and Agricultural News vs. News, this is not the case, as can be seen in Figure 6.
We set the maximum sequence length to 128 for all datasets, so all samples were truncated to 128 tokens and lower-cased. For europarl, which contains parallel corpora in multiple languages, we use only the English part, sampling 120K sentences from it. For pubmed we use 120K abstracts of medical papers from the dataset of Cohan et al. (2018). For math we use 120K questions of medium difficulty from the dataset of Saxton et al. (2019). We note that all corpora used are in English.
A.2 Hyperparameters and Model details
We base our implementation on Hugging Face's transformers library Wolf et al. (2019) in PyTorch Paszke et al. (2019). For BERT we use the bert-base-uncased pretrained model and fine-tune it with a maximum sequence length of 128 tokens, using the dropout and batch size suggested by Devlin et al. (2019). For optimization we use the Adam optimizer Kingma and Ba (2015) with a learning rate of 2e-5, Adam epsilon 1e-6, and weight decay, along with a linear warmup schedule. We fine-tune each model for 4 epochs and evaluate it several times per epoch, as suggested by Dodge et al. (2020). We select the best model based on the validation loss.
For XLNet we use the xlnet-base-cased pretrained model and the last hidden state output embedding as the input sequence representation. We fine-tune XLNet with the same learning rate (2e-5) and maximum sequence length (128) as BERT. We do not use weight decay or warmup.
When we combine after with either BERT or XLNet we use the same hyperparameters as above. We note that both models have approximately 110M parameters, and this count remains (almost) the same with after, since our approach only introduces a binary domain discriminator in the form of a single linear layer.
For all experiments we used a GeForce GTX 1080 GPU with 6GB of memory. The duration of the experiments depended on the datasets. For SST-2, which is the largest dataset, the baseline experiments (BERT, XLNet) had a runtime of approximately 100 minutes (for all 4 epochs), and after required about 200 minutes, due to the implicit dataset augmentation. Smaller datasets such as MRPC and CoLA had an approximate runtime of 30 minutes with standard fine-tuning and 60 minutes with after.
A.3 Tuning the hyperparameter λ
We tune λ on each development set. In Figure 4 we compare the performance of BERT and after for different Main-Auxiliary combinations, as we vary the value of λ.
We observe that different values of λ can have different effects on the performance and variance of after. Most values of λ significantly improve over the baseline, BERT, so an exhaustive search is not required. Table 6 presents the values of λ used for the results reported in Table 1. The best values of λ were chosen based on the task-specific metric (e.g., Accuracy, Matthews correlation).
A.4 More Domain Distance Results
To create a common vocabulary for all data in Figure 2, we find the most frequent words in each dataset and then take the union of these sub-vocabularies. We also calculate the vocabulary overlap, by creating each domain (or task) vocabulary from the most frequent words in the corresponding dataset (in case a dataset contains fewer words, we use all the words in the dataset).
We then calculate the vocabulary overlap between domains (Figure 5) and between each task and all domains (Figure 6). For the latter, we also include the Wiki domain to account for the pretraining domain of BERT. For the vocabulary of Wiki we use the WikiText-2 corpus from Merity et al. (2017). We observe in Figure 5 that most domains are dissimilar, with the exception of the news and legal domains, which have a higher vocabulary overlap. In Figure 6, we observe that rte has the highest vocabulary overlap with Wiki, which is a possible cause of the deteriorated performance of after: the model has already been pretrained on this domain and does not require further regularization, as described in Section 5.
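The overlap computation can be sketched as follows. The frequency cutoff `k` is illustrative (the original value is not recoverable here), and normalizing by the smaller vocabulary is our assumption about the overlap definition.

```python
from collections import Counter

def vocab(texts, k=10_000):
    """Set of the k most frequent (lower-cased) words in a corpus;
    all words are used if the corpus contains fewer than k."""
    counts = Counter(w for t in texts for w in t.lower().split())
    return {w for w, _ in counts.most_common(k)}

def vocab_overlap(texts_a, texts_b, k=10_000):
    """Vocabulary overlap (%) between two corpora, relative to the smaller vocabulary."""
    va, vb = vocab(texts_a, k), vocab(texts_b, k)
    return 100.0 * len(va & vb) / min(len(va), len(vb))

# "the" and "sat" are shared out of 5 words per side -> 40% overlap.
pct = vocab_overlap(["the cat sat on the mat"], ["the dog sat by the door"])
```

In practice the same routine would be run once per domain (and once for the WikiText-2 corpus) before comparing all pairs.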
- BERT figure taken from the-illustrated-bert blog.
- The second PASCAL recognising textual entailment challenge.
- The fifth PASCAL recognizing textual entailment challenge. In Proceedings of the Text Analysis Conference (TAC).
- Identifying beneficial task relations for multi-task learning in deep neural networks. In Proceedings of the Conference of the European Chapter of the Association for Computational Linguistics, pp. 164–169.
- A discourse-aware attention model for abstractive summarization of long documents. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers), pp. 615–621.
- The PASCAL recognising textual entailment challenge. In Proceedings of the First International Conference on Machine Learning Challenges: Evaluating Predictive Uncertainty, Visual Object Classification, and Recognizing Textual Entailment, MLCW'05, pp. 177–190.
- Semi-supervised sequence learning. In Advances in Neural Information Processing Systems 28, pp. 3079–3087.
- BERT: pre-training of deep bidirectional transformers for language understanding. In Proceedings of the Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 4171–4186.
- Fine-tuning pretrained language models: weight initializations, data orders, and early stopping. ArXiv abs/2002.06305.
- Automatically constructing a corpus of sentential paraphrases. In Proceedings of the Third International Workshop on Paraphrasing (IWP2005).
- Domain-adversarial training of neural networks. Journal of Machine Learning Research 17 (1), pp. 2096–2030.
- The third PASCAL recognizing textual entailment challenge. In Proceedings of the ACL-PASCAL Workshop on Textual Entailment and Paraphrasing, pp. 1–9.
- An empirical investigation of catastrophic forgetting in gradient-based neural networks. CoRR abs/1312.6211.
- Explaining and harnessing adversarial examples. In Proceedings of the International Conference on Learning Representations.
- Geodesic flow kernel for unsupervised domain adaptation. In Proceedings of the Conference on Computer Vision and Pattern Recognition, pp. 2066–2073.
- Don't stop pretraining: adapt language models to domains and tasks. ArXiv abs/2004.10964.
- Ups and downs: modeling the visual evolution of fashion trends with one-class collaborative filtering. In Proceedings of the International Conference on World Wide Web, WWW '16, pp. 507–517.
- Universal language model fine-tuning for text classification. In Proceedings of the Annual Meeting of the Association for Computational Linguistics, pp. 328–339.
- Adversarial learning with contextual embeddings for zero-resource cross-lingual classification and NER. In Proceedings of the Conference on Empirical Methods in Natural Language Processing and the International Joint Conference on Natural Language Processing, pp. 1355–1360.
- Adam: a method for stochastic optimization. In 3rd International Conference on Learning Representations, ICLR 2015.
- EuroParl: a parallel corpus for statistical machine translation. In Proceedings of the Machine Translation Summit, vol. 5.
- Unsupervised machine translation using monolingual corpora only. In Proceedings of the International Conference on Learning Representations.
- Word translation without parallel data. In Proceedings of the International Conference on Learning Representations.
- Domain-agnostic question-answering with adversarial training. In Proceedings of the 2nd Workshop on Machine Reading for Question Answering, pp. 196–202.
- Multi-task deep neural networks for natural language understanding. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp. 4487–4496.
- Pointer sentinel mixture models. In 5th International Conference on Learning Representations, ICLR 2017.
- Virtual adversarial training: a regularization method for supervised and semi-supervised learning.
- PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, pp. 8024–8035.
- Deep contextualized word representations. In Proceedings of the Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 2227–2237.
- Sentence encoders on STILTs: supplementary training on intermediate labeled-data tasks. ArXiv abs/1811.01088.
- Effective measures of domain similarity for parsing. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, pp. 1566–1576.
- Analysing mathematical reasoning abilities of neural models. ArXiv abs/1904.01557.
- Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the Conference on Empirical Methods in Natural Language Processing, pp. 1631–1642.
- Return of frustratingly easy domain adaptation. In Proceedings of the AAAI Conference on Artificial Intelligence, pp. 2058–2065.
- How to fine-tune BERT for text classification? ArXiv abs/1905.05583.
- Deep domain confusion: maximizing for domain invariance. CoRR abs/1412.3474.
- GLUE: a multi-task benchmark and analysis platform for natural language understanding. In International Conference on Learning Representations.
- Adversarial domain adaptation for machine reading comprehension. In Proceedings of the Conference on Empirical Methods in Natural Language Processing and the International Joint Conference on Natural Language Processing, pp. 2510–2520.
- Neural network acceptability judgments. Transactions of the Association for Computational Linguistics 7, pp. 625–641.
- HuggingFace's Transformers: state-of-the-art natural language processing.
- XLNet: generalized autoregressive pretraining for language understanding. In Advances in Neural Information Processing Systems 32, pp. 5753–5763.
- Unsupervised text style transfer using language models as discriminators. In Advances in Neural Information Processing Systems 31, pp. 7287–7298.
- Character-level convolutional networks for text classification. In Proceedings of the Conference on Neural Information Processing Systems, pp. 649–657.