Learning to Optimize Domain Specific Normalization
with Domain Augmentation for Domain Generalization
We propose a simple but effective multi-source domain generalization technique based on deep neural networks that incorporates optimized normalization layers specific to individual domains. Our approach employs multiple normalization methods while learning separate affine parameters per domain. For each domain, the activations are normalized by a weighted average of multiple normalization statistics, and the statistics are tracked separately for each normalization type where necessary. Specifically, we employ batch and instance normalization in our implementation and attempt to identify the best combination of the two normalization methods in each domain and normalization layer. In addition, we augment the set of source domains with combinations of multiple existing domains to increase the diversity of source domains available during training. The optimized normalization layers and the domain augmentation both enhance the generalizability of the learned model. We demonstrate the state-of-the-art accuracy of our algorithm on standard benchmark datasets.
1 Introduction
Domain generalization aims to learn generic feature representations agnostic to domains and make trained models perform well in completely new domains. To achieve this challenging goal, one needs to train models that can capture useful information observed commonly in multiple domains and recognize semantically related but visually inconsistent examples effectively. Many real-world problems have similar objectives, so this task is relevant to a wide range of practical applications. Domain generalization is closely related to unsupervised domain adaptation, but there is a critical difference regarding the availability of target domain data; contrary to unsupervised domain adaptation, domain generalization cannot access examples in the target domain during training but is still required to capture transferable information across domains. Due to this constraint, the domain generalization problem is typically considered more difficult, and it usually involves multiple source domains to make the problem more feasible.
Domain generalization techniques are classified into several groups depending on their approaches. Some algorithms define novel loss functions to learn domain-agnostic representations [muandet2013domain, motiian2017unified, li2018domain] while others are more interested in designing deep neural network architectures to realize similar goals [pacs, dinnocente2018domain, mancini2018best]. The algorithms based on meta-learning have been proposed under the assumption that there exists a held-out validation set [balaji2018metareg, li2019episodic, li2018learning].
Our algorithm belongs to the second category, i.e., network architecture design methods. In particular, we are interested in exploiting normalization layers in deep neural networks to solve the domain generalization problem. The most naïve approach would be to train a single deep neural network with batch normalization using all training examples regardless of their domain memberships. This method works fairly well, partly because batch normalization regularizes feature representations from heterogeneous domains and a trained model is often capable of adapting to unseen domains. However, the benefit of batch normalization is limited when the domain shift is significant, and we are often required to remove domain-specific styles for better generalization. Instance normalization [in] turns out to be an effective scheme for this goal, and incorporating both batch and instance normalization further improves accuracy through a data-driven balancing of the two normalization methods [bin]. Our approach also employs the two normalizations but proposes a more sophisticated algorithm designed for domain generalization.
We explore domain-specific normalizations to learn domain-agnostic discriminative representations by discarding domain-specific ones. The goal of our algorithm is to optimize the combination of normalization techniques in each domain while different domains learn separate parameters for the mixture of normalizations. The intuition behind this approach is that we can learn domain-invariant representations by controlling types of normalization and parameters in normalization layers. Note that all other parameters including the ones in convolutional layers are shared across domains. Although our approach is similar to [bin, sn] in the sense that they attempt to learn the optimal mixing weights between normalization types, our domain-specific optimized normalization technique is unique since it learns normalization parameters and maintains normalization statistics in each domain separately. Figure 1 illustrates the main idea of our approach.
In addition, we propose a simple domain augmentation technique to increase the number of source domains used for training. Specifically, we generate a new source domain by unifying multiple existing ones, by which the resulting domains have different statistics compared to the original ones. Albeit simple, an ensemble of the models trained using multiple combinations of the original and generated domains consistently improves the accuracy compared to the comparable ensemble classifiers based only on the original domains.
Our contributions are as follows:
We propose a domain generalization technique using multiple heterogeneous normalization methods specific to individual domains, which facilitates the extraction of domain-agnostic feature representations of input examples by removing domain-specific information effectively.
Our approach optimizes the weights of individual normalization types jointly with the rest of networks, and the resulting model is capable of balancing between normalization methods to capture information useful for classification in unseen domains.
We diversify source domains by domain augmentation to improve accuracy, which constructs new domains by the simple combinations of multiple existing ones.
The proposed algorithm achieves the state-of-the-art accuracy in multiple standard benchmark datasets and presents consistent benefits of individual components.
2 Related Work
This section discusses existing domain generalization approaches and reviews two related problems, multi-source domain adaptation and normalization techniques in deep neural networks.
2.1 Domain Generalization
Domain generalization algorithms learn domain-invariant representations given input examples regardless of their domain memberships. Since target domain information is not available at training time, they typically rely on multiple source domains to extract knowledge applicable to any unseen domain. The existing domain generalization approaches can be roughly categorized into three classes. The first group of methods proposes novel loss functions that encourage learned representations to generalize well to new domains. Muandet et al. [muandet2013domain] propose domain-invariant component analysis, which performs dimensionality reduction to make feature representations invariant to domains. A few recent works [motiian2017unified, li2018domain] also attempt to learn a shared embedding space appropriate for semantic matching across domains. Another group of approaches tackles the domain generalization problem by manipulating deep neural network architectures. Domain-specific information is handled by designated modules within deep neural networks [pacs, dinnocente2018domain], while [mancini2018best] proposes a soft model selection technique to obtain generalized representations. Recently, meta-learning based techniques have started to be used to solve domain generalization problems. MLDG [li2018learning] extends MAML [finn2017model] to the domain generalization task. Balaji et al. [balaji2018metareg] point out the limitation of [li2018learning] and propose a regularizer to address domain generalization in a meta-learning framework directly. Also, [li2019episodic] presents an episodic training technique appropriate for domain generalization. Note that, to our knowledge, none of the existing methods exploits normalization types and their parameters for domain generalization.
2.2 Multi-Source Domain Adaptation
Multi-source domain adaptation can be considered the middle ground between domain adaptation and generalization, where data from multiple source domains are used for training in addition to examples in an unlabeled target domain. Although unsupervised domain adaptation is a very popular problem, its multi-source version is relatively less investigated. Zhao et al. [zhao2018adversarial] propose to learn features that are invariant to multiple domain shifts through adversarial training, and Guo et al. [guo2018multi] use a mixture-of-experts approach by modeling the inter-domain relationships between source and target domains. A recent work using domain-specific batch normalization [dsbn] has shown competitive performance in multi-source domain adaptation by aligning the representations in heterogeneous domains to a single common feature space.
2.3 Normalization in Neural Networks
Normalization techniques in deep neural networks were originally designed for regularizing trained models and improving their generalization performance. Various normalization techniques [bn, in, ln, spectralnorm, wn, gn, bin, sn, ssn, dsbn] have been studied actively in recent years. The most popular technique is batch normalization (BN) [bn], which normalizes activations over individual channels using data in a mini-batch, while instance normalization (IN) [in] performs the same operation per instance instead of per mini-batch. In general, IN is effective at removing instance-specific characteristics (e.g., style in an image); adding IN makes a trained model focus on instance-invariant information and increases the generalization capability of the model to an unseen domain. Other normalizations such as layer normalization (LN) [ln] and group normalization (GN) [gn] follow the same concept, while weight normalization [wn] and spectral normalization [spectralnorm] normalize weights over the parameter space.
Recently, batch-instance normalization (BIN) [bin], switchable normalization (SN) [sn], and sparse switchable normalization (SSN) [ssn] employ the combinations of multiple normalization types to maximize the benefit. Note that BIN considers batch and instance normalizations while SN uses LN additionally. On the other hand, DSBN [dsbn] adopts separate batch normalization layers for each domain to deal with domain shift and generate domain-invariant representations.
3 Domain-Specific Optimized Normalization for Domain Generalization
This section describes our main algorithm, called domain-specific optimized normalization (DSON), in detail and also presents how the proposed method is employed to solve domain generalization problems.
Domain generalization aims to learn a domain-agnostic model, typically using multiple domains, to be applied to an unseen domain. Consider a set of training examples $\mathcal{X}_s$ with its corresponding label set $\mathcal{Y}_s$ in a source domain $s$. Our goal is to train a classifier using the data in multiple ($S$) source domains, $\{(\mathcal{X}_s, \mathcal{Y}_s)\}_{s=1}^{S}$, to correctly classify an image in a target domain $t$, whose data are unavailable during training.
The main objective of the domain generalization problem is to learn a joint embedding space across all source domains, which is expected to be valid in target domains as well. To this end, we train domain-invariant classifiers from each of the source domains and ensemble their predictions. To embed each example onto a domain-invariant feature space, we employ domain-specific normalization, which is to be described in following sections.
Our classification network consists of a set of feature extractors $\{F_s\}_{s=1}^{S}$ and a single fully connected layer $D$. Specifically, the feature extractors share all parameters across domains except for the ones in the normalization layers.
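For each source domain $s$, the loss function is defined as
$$\mathcal{L}_s = \sum_{(x, y) \in (\mathcal{X}_s, \mathcal{Y}_s)} \ell\big(y, D(F_s(x))\big), \tag{1}$$
where $\ell$ is the cross-entropy loss. All network parameters are jointly optimized to minimize the sum of classification losses for source domains:
$$\mathcal{L} = \sum_{s=1}^{S} \mathcal{L}_s. \tag{2}$$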
Our domain-specific deep neural network model is trained by backpropagating the total loss $\mathcal{L}$. To facilitate generalization, in the validation phase, we follow the leave-one-domain-out validation strategy proposed in [dinnocente2018domain]; the label of a validation example from domain $s$ is predicted by averaging the predictions of all domain-specific classifiers except the one for domain $s$.
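The leave-one-domain-out validation rule can be illustrated with a minimal NumPy sketch. The function names and the toy logits are hypothetical, and averaging softmax probabilities (rather than logits) is an assumption, since the text only says that predictions are averaged.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # stabilise the exponentials
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def leave_one_domain_out_predict(logits_per_domain, own_domain):
    """Predict a label while ignoring the branch of the example's own domain.

    logits_per_domain: (S, C) array, one row of class logits per
                       domain-specific classifier.
    own_domain:        index of the source domain the example came from.
    """
    num_domains = logits_per_domain.shape[0]
    keep = [s for s in range(num_domains) if s != own_domain]
    probs = softmax(logits_per_domain[keep], axis=-1)   # (S-1, C)
    return int(probs.mean(axis=0).argmax())

# Toy example with S = 3 domain-specific branches and C = 3 classes.
logits = np.array([[4.0, 0.0, 0.0],    # branch 0: the example's own domain
                   [0.0, 2.0, 0.0],
                   [0.0, 1.5, 0.5]])
pred = leave_one_domain_out_predict(logits, own_domain=0)
print(pred)
```

Excluding the branch of the example's own domain means the validation score measures how well the other branches transfer, which is closer to the situation at test time.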
3.2 Batch and Instance Normalization
Normalization techniques [bn, in, ln, gn] are widely applied in recent network architectures for better optimization and regularization. Among them, we combine batch normalization (BN) and instance normalization (IN) to construct a domain-agnostic feature extractor.
Our intuition about the two normalization methods is as follows. For each channel, BN whitens activations over all the spatial locations of instances within a mini-batch, whereas IN does over locations in a single instance. Compared to BN, IN is effective to reduce the cross-category variance by normalizing each instance independently. Figure 2 visualizes the change of feature distribution after applying batch and instance normalization. When IN is integrated into a feature extractor, the features become less discriminative with respect to object categories compared to the ones with BN layers only. At the same time, the features become less discriminative with respect to domains, thereby making the learned representations less overfit to a particular domain. A mixture of IN with BN serves as regularization and the resulting classifier tends to focus on high-level semantic information and is more likely to become invariant to domain shifts.
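The difference between the two normalizations can be checked numerically. The sketch below is an illustration under simplifying assumptions, not the authors' experiment: "style" is modeled as a per-instance, per-channel offset, and the helper names are hypothetical.

```python
import numpy as np

EPS = 1e-5

def batch_norm(x):
    """Whiten each channel with statistics pooled over batch and space."""
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + EPS)

def instance_norm(x):
    """Whiten each channel of each instance with its own spatial statistics."""
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + EPS)

rng = np.random.default_rng(0)
content = rng.normal(size=(6, 3, 8, 8))                  # (N, C, H, W)
style_shift = rng.normal(scale=5.0, size=(6, 3, 1, 1))   # toy per-instance "style"
x = content + style_shift

# Per-instance channel means that survive each normalization.
residual_bn = np.abs(batch_norm(x).mean(axis=(2, 3))).max()
residual_in = np.abs(instance_norm(x).mean(axis=(2, 3))).max()
print(residual_bn, residual_in)
```

After BN the per-instance offsets are still visible in the channel means, whereas IN removes them. This also removes any class information carried by those offsets, which is the trade-off described above.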
Based on this intuition, we now employ IN in addition to BN in all normalization layers of our network, where means and variances of two normalization statistics are linearly interpolated to make the feature extractor domain-agnostic.
3.3 Optimization for Domain-Specific Normalization
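Based on the intuitions above, we propose domain-specific optimized normalization (DSON) for domain generalization. Given an example from domain $s$, the proposed domain-specific normalization layer transforms channel-wise whitened activations using affine parameters $\beta_s$ and $\gamma_s$. Note that whitening is also performed for each domain. At each channel, the activations $x_{sn} \in \mathbb{R}^{H \times W}$ of the $n$-th example are transformed as
$$y_{sn}[i,j] = \gamma_s \, \hat{x}_{sn}[i,j] + \beta_s,$$
where the whitening is performed using the domain-specific mean and variance, $\mu_{sn}$ and $\sigma_{sn}^2$,
$$\hat{x}_{sn}[i,j] = \frac{x_{sn}[i,j] - \mu_{sn}}{\sqrt{\sigma_{sn}^2 + \epsilon}},$$
with a small constant $\epsilon$ for numerical stability.
We combine batch normalization (BN) and instance normalization (IN) in a similar manner to [sn] as
$$\mu_{sn} = w_s \, \mu_s^{\mathrm{bn}} + (1 - w_s) \, \mu_{sn}^{\mathrm{in}}, \qquad \sigma_{sn}^2 = w_s \, (\sigma_s^{\mathrm{bn}})^2 + (1 - w_s) \, (\sigma_{sn}^{\mathrm{in}})^2,$$
where both are calculated separately in each domain as
$$\mu_s^{\mathrm{bn}} = \frac{1}{NHW} \sum_{n} \sum_{i,j} x_{sn}[i,j], \qquad (\sigma_s^{\mathrm{bn}})^2 = \frac{1}{NHW} \sum_{n} \sum_{i,j} \big(x_{sn}[i,j] - \mu_s^{\mathrm{bn}}\big)^2,$$
$$\mu_{sn}^{\mathrm{in}} = \frac{1}{HW} \sum_{i,j} x_{sn}[i,j], \qquad (\sigma_{sn}^{\mathrm{in}})^2 = \frac{1}{HW} \sum_{i,j} \big(x_{sn}[i,j] - \mu_{sn}^{\mathrm{in}}\big)^2.$$
The optimal mixture weights, $w_s$, between BN and IN are trained to minimize the loss in Eq (2).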
Test examples from the target domain are not available during training. Hence, for inference, we feed a test example $x$ to the feature extractors of all domains, $\{F_s\}_{s=1}^{S}$. The final label prediction is given by computing the logits using the fully connected layer $D$, averaging the logits, i.e., $\frac{1}{S} \sum_{s=1}^{S} D(F_s(x))$, and finally applying a softmax function.
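The inference rule amounts to a few lines of NumPy. In this sketch the feature vectors stand in for the outputs of the domain-specific feature extractors, and the function name and toy values are hypothetical.

```python
import numpy as np

def predict_target(features_per_domain, fc_weight, fc_bias):
    """Average the logits of all domain-specific branches, then apply softmax.

    features_per_domain: (S, F) features of ONE target example, one row per
                         domain-specific feature extractor.
    fc_weight, fc_bias:  parameters of the shared fully connected layer,
                         shapes (C, F) and (C,).
    """
    logits = features_per_domain @ fc_weight.T + fc_bias   # (S, C)
    mean_logits = logits.mean(axis=0)                       # (C,)
    z = mean_logits - mean_logits.max()                     # numerical stability
    return np.exp(z) / np.exp(z).sum()

# Toy setting: S = 3 branches, F = 2 features, C = 2 classes.
features = np.array([[2.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])
fc_weight = np.eye(2)
fc_bias = np.zeros(2)
probs = predict_target(features, fc_weight, fc_bias)
print(probs, probs.argmax())
```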
One potential issue in the inference step is whether target domains can rely on the model trained only on source domains. This is the main challenge in domain generalization, which assumes that reasonably good representations of target domains can be obtained from the information in source domains only. In our algorithm, we believe that instance normalization in each domain has the capability to remove domain-specific styles and standardize the representation. Since each domain has different characteristics, we learn the relative weight of instance normalization in each domain separately. Thus, the predictions of each domain-specific branch should be accurate enough even for data in target domains. Additionally, aggregating the predictions of multiple networks trained on different source domains should further improve accuracy.
4 Domain Diversification via Augmentation
As domain generalization aims at working robustly over arbitrary domains, having diverse source domains in the training set is crucial for generalization ability. Based on this intuition, we propose to diversify source domains by domain augmentation. Our key idea is that a mixed set of samples from multiple domains can be interpreted as a new domain with deviated statistics from the original ones.
We employ a simple generation strategy; a new domain is constructed by taking the union of multiple existing domains. For example, suppose that there exist three original domains denoted by $\mathcal{D}_1$, $\mathcal{D}_2$, and $\mathcal{D}_3$, each containing the instances of the corresponding domain. We construct a new domain by taking the union of all the elements in two or more original domains, e.g., $\mathcal{D}_1 \cup \mathcal{D}_2$. Although these are trivial augmentations, we believe that this strategy is useful since an interpolated domain can be considered indirectly. There are various ways to use these new domains. We first construct three datasets, $\mathcal{S}_1$, $\mathcal{S}_2$, and $\mathcal{S}_3$, where $\mathcal{S}_1$ is the same as the original dataset. Then, we run DSON on each of the datasets and make an ensemble of the learned models to obtain the final results.
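Constructing union domains is straightforward to sketch. The domain names, file names, and the helper `union_domains` below are hypothetical. How the resulting unions are grouped into the three training datasets is not specified here, so the sketch only enumerates unions of a given size.

```python
from itertools import combinations

def union_domains(domains, k):
    """Build every new domain obtained by merging k of the original domains."""
    merged = {}
    for names in combinations(sorted(domains), k):
        samples = []
        for name in names:
            samples.extend(domains[name])
        merged["+".join(names)] = samples
    return merged

# Toy stand-ins for three source domains (hypothetical file names).
domains = {
    "A": ["a0.png", "a1.png"],
    "B": ["b0.png"],
    "C": ["c0.png", "c1.png", "c2.png"],
}

pairs = union_domains(domains, 2)    # unions of two domains: A+B, A+C, B+C
triple = union_domains(domains, 3)   # union of all three domains: A+B+C
print(sorted(pairs), sorted(triple))
```

Each merged domain would then receive its own normalization statistics and affine parameters, exactly like an original source domain.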
5 Experiments
To demonstrate the effectiveness of domain-specific optimized normalization (DSON), we evaluate it on domain generalization benchmarks and provide an extensive ablation study of the algorithm.
5.1 Experimental Settings
We evaluate the proposed method on two domain generalization benchmarks. The PACS dataset [pacs] is commonly used in domain generalization and is favored due to its large inter-domain shift across four domains: Photo, Art Painting, Cartoon, and Sketch. It contains a total of 9,991 images in 7 categories, with an image resolution of 227 × 227. We follow the experimental protocol in [pacs], where the model is trained on any three of the four domains (source domains) and then tested on the remaining domain (target domain). Office-Home [office-home] is a popular domain adaptation dataset, which consists of four distinct domains: Artistic Images, Clip Art, Product, and Real-world Images. Each domain contains 65 categories, with around 15,500 images in total. While the dataset is mostly used in the domain adaptation context, it can easily be repurposed for domain generalization by following the same protocol used for the PACS dataset. In addition, we employ five datasets (MNIST, MNIST-M, USPS, SVHN and Synthetic Digits) for digit recognition and split training and testing subsets following [xu2018deep].
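The leave-one-domain-out protocol can be enumerated with a short, dependency-free sketch; the helper name `leave_one_domain_out` is hypothetical.

```python
PACS_DOMAINS = ["Photo", "Art Painting", "Cartoon", "Sketch"]

def leave_one_domain_out(domains):
    """Yield (source_domains, target_domain) for every choice of held-out target."""
    for target in domains:
        sources = [d for d in domains if d != target]
        yield sources, target

splits = list(leave_one_domain_out(PACS_DOMAINS))
for sources, target in splits:
    print(f"train on {sources}, test on {target}")
```

The same loop applies to Office-Home by swapping in its four domain names.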
For a fair comparison with prior art [balaji2018metareg, dinnocente2018domain, li2017domain, li2019episodic], we employ a ResNet-18 model as the backbone network in all experiments. The convolutional and BN layers are initialized with ImageNet pretrained weights, while the fully connected layer is replaced and randomly initialized to match the number of classes for the object recognition task. We use a batch size of 32 images per source domain, and optimize the network parameters over 10K iterations using SGD-M with a momentum of 0.9 and an initial learning rate . As suggested in [zhao2018adversarial], the learning rate is annealed by , where , , and increases linearly from 0 to 1 as training progresses. We follow the domain generalization convention by training with the “train” split from each of the source domains, then testing on the combined “train” and “validation” splits of the target domain. We share the mixture weights among all channels and layers in our network because, according to our observation, this facilitates optimization, especially in the lower layers. This strategy improves accuracy substantially and consistently in all settings.
5.2 Ablation Study
PACS and Office-Home Dataset
Before comparing with the state-of-the-art domain generalization methods, we conduct an ablation study to assess the contribution of individual components of our full algorithm on the PACS and Office-Home datasets. Table 1 presents the results, where our complete method is denoted by DSON. It also reports the accuracies of other methods, including external methods as well as internal variations of our algorithm. We first present results from the baseline method, where the model is trained naïvely with BN layers that are not specific to any single domain. Then, to examine the effect of domain-specific normalization layers, the BN layers are made specific to each of the source domains, which is denoted by DSBN [dsbn]. We also examine the suitability of SN [sn] by replacing BN layers with adaptive mixtures of BN, IN and LN. We show three variations of the proposed network, each trained on one of the three training datasets constructed in Section 4. Finally, ensembling the three models (DSON-Ensemble) further improves the accuracy and performs much better than an ensemble of three independently trained BN models, which is denoted by BN-Ensemble. We do not include batch-instance normalization (BIN) [bin] in our experiment because it is hard to optimize and its results are unstable. The ablation study clearly illustrates the benefits of the individual components of our algorithm: integration of multiple normalization methods, domain-specific normalization, and mixing weight sharing across layers.
The results on the five digit datasets are shown in Table 2. We show four variations of the proposed network, each trained on a different training dataset. Our model achieves an average accuracy of 87.43%, outperforming all other baselines by large margins. Ensembling further improves the accuracy compared to an ensemble of four independently trained BN models. Note that DSON is more effective in the hard domain (MNIST-M).
[Table: PACS results under label noise. Columns: Noise level, Method, Art painting, Cartoon, Sketch, Photo, Avg.]
5.3 Comparison with Other Methods
In this section, we compare our method with other domain generalization methods. For fair comparison, we use DSON- as our method.
Table 3 reports the cross-domain object recognition accuracy on the PACS dataset. The proposed algorithm is compared with several existing methods, which include JiGen [li2017domain], D-SAM [dinnocente2018domain], Epi-FCR [li2019episodic] and MetaReg [balaji2018metareg], and it outperforms both the baseline and the state-of-the-art technique by significant margins (4.41% and 3.18%, respectively). DSON achieves the highest accuracy on two of the four target domains and is a close second on the remaining two. The result shows that DSON is particularly useful for hard domains, achieving 7.31% and 5.23% accuracy gains over the baseline on the Sketch and Cartoon domains, respectively.
We also evaluate DSON on the Office-Home dataset, and the results are presented in Table 4. As in PACS, DSON outperforms the recently proposed D-SAM [dinnocente2018domain] as well as our baseline. We find that DSON achieves the best score on all target domains except the “Product” domain, where we notice that DSON is slightly worse than the best model. Again, DSON is more advantageous in hard domains; it shows 1.31% improvement from the baseline in the Clipart domain.
5.4 Additional Experiments
Domain Generalization with Label Noise
The performance of the proposed algorithm on the PACS dataset is tested in the presence of label noise and compared against other approaches. Two different noise levels are tested (0.2 and 0.5), and the results are presented in Table 5. Although all algorithms undergo performance degradation, the accuracy drops are marginal in general, and DSON turns out to be more robust to label noise than the other models.
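The noise model is not spelled out here, so the following sketch assumes symmetric label noise: each training label is replaced, with probability equal to the noise level, by a different class chosen uniformly at random. The function name and this noise model are assumptions made for illustration.

```python
import numpy as np

def add_label_noise(labels, noise_level, num_classes, seed=0):
    """Flip each label to a different, uniformly chosen class with prob. noise_level."""
    rng = np.random.default_rng(seed)
    noisy = np.asarray(labels).copy()
    flip = rng.random(noisy.shape[0]) < noise_level
    # An offset in [1, num_classes - 1] guarantees the new label differs from the old one.
    offsets = rng.integers(1, num_classes, size=noisy.shape[0])
    noisy[flip] = (noisy[flip] + offsets[flip]) % num_classes
    return noisy

clean = np.zeros(1000, dtype=int)                 # toy labels, 7 classes as in PACS
noisy = add_label_noise(clean, noise_level=0.2, num_classes=7)
flipped_fraction = float((noisy != clean).mean())
print(flipped_fraction)
```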
Multi-Source Domain Adaptation
DSON can be extended to the multi-source domain adaptation task, where we gain access to unlabeled data from the target domain. To compare the effect of different normalization, we adopt Shu et al. [shu2019transferable] as the baseline method and vary the normalization method only. The results are shown in Table 6, where we compare DSON with SN, and DSBN [dsbn]. All compared methods illustrate a large improvement over the baseline. In direct contrast to the results from the ablation analysis in Table 1, DSBN is clearly superior to SN. This is unsurprising, given that DSBN is focused specifically on the domain adaptation task. We find that DSON outperforms not only the baseline but also DSBN, which demonstrates how effectively DSON can be extended to the domain adaptation task. It shows that domain-specific models consistently outperform their domain-agnostic counterparts.
6 Conclusion
We presented a simple but effective domain generalization algorithm based on domain-specific optimized normalization layers. The proposed algorithm uses multiple normalization methods while learning separate affine parameters per domain. The mixing weights are used to compute the weighted average of multiple normalization statistics. This strategy turns out to be helpful for learning domain-invariant representations since instance normalization removes domain-specific style effectively and makes the trained model focus on semantic category information. In addition, we proposed a domain augmentation method to diversify the source domains. The proposed algorithm achieves state-of-the-art accuracy consistently on multiple standard benchmark datasets, even with substantial label noise. We also showed that the domain-specific optimization of normalization types is well suited for unsupervised domain adaptation.