Generalizing Across Domains via Cross-Gradient Training
We present CrossGrad, a method to use multi-domain training data to learn a classifier that generalizes to new domains. CrossGrad does not need an adaptation phase via labeled or unlabeled data, or domain features in the new domain. Most existing domain adaptation methods attempt to erase domain signals using techniques like domain adversarial training. In contrast, CrossGrad is free to use domain signals for predicting labels, provided it can prevent overfitting on training domains. We conceptualize the task in a Bayesian setting, in which a sampling step is implemented as data augmentation, based on domain-guided perturbations of input instances. CrossGrad trains a label classifier and a domain classifier in parallel, on examples perturbed by the loss gradients of each other’s objectives. This enables us to perturb inputs directly, without separating and re-mixing domain signals, which would require various distributional assumptions. Empirical evaluation on three different applications where this setting is natural establishes that (1) domain-guided perturbation provides consistently better generalization to unseen domains, compared to generic instance perturbation methods, and that (2) data augmentation is a more stable and accurate method than domain adversarial training.
We investigate how to train a classification model using multi-domain training data, so as to generalize to labeling instances from unseen domains. This problem arises in many applications, viz., handwriting recognition, speech recognition, sentiment analysis, and sensor data interpretation. In these applications, domains may be defined by fonts, speakers, writers, etc. Most existing work on handling a target domain not seen at training time requires either labeled or unlabeled data from the target domain at test time. Often, a separate “adaptation” step is then run over the source and target domain instances, only after which target domain instances are labeled. In contrast, we consider the situation where, during training, we have labeled instances from several domains which we can collectively exploit so that the trained system can handle new domains without the adaptation step.
1.1 Problem statement
Let $\mathcal{D}$ be a space of domains. During training we get labeled data from a proper subset $D \subset \mathcal{D}$ of these domains. Each labeled example during training is a triple $(x, y, d)$ where $x$ is the input, $y \in \mathcal{Y}$ is the true class label from a finite set of labels $\mathcal{Y}$, and $d \in D$ is the domain from which this example is sampled. We must train a classifier to predict the label $y$ for examples sampled from all domains, including the subset $\mathcal{D} \setminus D$ not seen in the training set. Our goal is high accuracy for both in-domain (i.e., in $D$) and out-of-domain (i.e., in $\mathcal{D} \setminus D$) test instances.
One challenge in learning a classifier that generalizes to unseen domains is that $\Pr(y \mid x)$ is typically harder to learn than $\Pr(y \mid x, d)$. While Yang2015 addressed a similar setting, they assumed a specific geometry characterizing the domain, and performed kernel regression in this space. In contrast, in our setting, we wish to avoid any such explicit domain representation, appealing instead to the power of deep networks to discover implicit features.
Lacking any feature-space characterization of the domain, conventional training objectives (given a choice of hypotheses having sufficient capacity) will tend to settle on solutions that overfit on the set of domains seen during training. A popular technique in the domain adaptation literature to generalize to new domains is domain adversarial training (ganin16; tzeng17). As the name suggests, here the goal is to learn a transformation of input x to a domain-independent representation, with the hope that amputating domain signals will make the system robust to new domains. We show in this paper that such training does not necessarily safeguard against over-fitting of the network as a whole. We also argue that even if such overfitting could be avoided, we do not necessarily want to wipe out domain signals if they help in-domain test instances.
In a marked departure from domain adaptation via amputation of domain signals, we approach the problem using a form of data augmentation based on domain-guided perturbations of input instances. If we could model exactly how domain signals for $d$ manifest in $x$, we could simply replace these signals with those from suitably sampled other domains to perform data augmentation. We first conceptualize this in a Bayesian setting: discrete domain $d$ ‘causes’ continuous multivariate $g$, which, in combination with $y$, ‘causes’ $x$. Given an instance $x$, if we can recover $g$, we can then perturb $g$ to $g'$, thus generating an augmented instance $x'$. Because such perfect domain perturbation is not possible in reality, we first design an (imperfect) domain classifier network to be trained with a suitable loss function. Given an instance $x$, we use the loss gradient w.r.t. $x$ to perturb $x$ in directions that change the domain classifier loss the most. The training loss for the $y$-predictor network on original instances is combined with the training loss on the augmented instances. We call this approach cross-gradient training, and embody it in a system we describe here, called CrossGrad. We carefully study the performance of CrossGrad on a variety of domain adaptive tasks: character recognition, handwriting recognition and spoken word recognition. We demonstrate performance gains on new domains without any out-of-domain instances available at training time.
2 Related Work
Domain adaptation has been studied under many different settings: two domains (ganin16; tzeng17) or multiple domains (Mansour2009; Zhang2015), with target domain data that is labeled (Daume2007; Kumar2010; Saenko2010) or unlabeled (Gopalan2011; Gong2012; ganin16), paired examples from source and target domain (PengWE17), or domain features attached with each domain (Yang2016). Domain adaptation techniques have been applied to numerous tasks in speech, language processing and computer vision (Woodland2001; Saon2013; Jiang2007; Daume2007; Saenko2010; Gopalan2011; Li13; HuangB17; UpchurchSB16). However, unlike in our setting, these approaches typically assume the availability of some target domain data which is either labeled or unlabeled.
For neural networks a recent popular technique is domain adversarial networks (DANs) (tzeng15; tzeng17; ganin16). The main idea of DANs is to learn a representation in the last hidden layer (of a multilayer network) that cannot discriminate among different domains in the input to the first layer. A domain classifier is created with the last layer as input. If the last layer encapsulates no domain information apart from what can be inferred from the label, the accuracy of the domain classifier is low. The DAN approach makes sense when all domains are visible during training. In this paper, our goal is to generalize to unseen domains.
Domain generalization is traditionally addressed by learning representations that encompass information from all the training domains. MuandetBS13 learn a kernel-based representation that minimizes domain dissimilarity and retains the functional relationship with the label. GanYG16 extends MuandetBS13 by exploiting attribute annotations of examples to learn new feature representations for the task of attribute detection. In GhifaryBZB15, features that are shared across several domains are estimated by jointly learning multiple data-reconstruction tasks. Such representations are shown to be effective for domain generalization, but ignore any additional information that domain features can provide about labels.
Domain adversarial networks (DANs) (ganin16) can also be used for domain generalization in order to learn domain independent representations. A limitation of DANs is that they can be misled by a representation layer that over-fits to the set of training domains. In the extreme case, a representation that simply outputs label logits via a last linear layer (making the softmax layer irrelevant) can keep both the adversarial loss and label loss small, and yet not be able to generalize to new test domains. In other words, not being able to infer the domain from the last layer does not imply that the classification is domain-robust.
Since we do not assume any extra information about the test domains, conventional approaches for regularization and generalizability are also relevant. XuLNX14 use exemplar-based SVM classifiers regularized by a low-rank constraint over predictions. ECCV12_Khosla also deploy an SVM-based classifier and regularize the domain-specific components of the learners. The method most related to ours is the adversarial training of (szegedy14; goodfellow2014; MiyatoDG16), where examples perturbed along the gradient of the classifier loss are used to augment the training data. That method perturbs examples without regard to the domain. Instead, our method attempts to model domain variation in a continuous space and perturbs examples along the gradient of the domain loss.
Our Bayesian model to capture the dependence among label, domains, and input is similar to Zhang2015, but the crucial difference is the way the dependence is modeled and estimated. Our method attempts to model domain variation in a continuous space and project perturbation in that space to the instances.
3 Our approach
We assume that input objects are characterized by two uncorrelated or weakly correlated tags: their label and their domain. E.g. for a set of typeset characters, the label could be the corresponding character of the alphabet (‘A’, ‘B’ etc) and the domain could be the font used to draw the character. In general, it should be possible to change any one of these, while holding the other fixed.
We use a Bayesian network to model the dependence among the label $y$, domain $d$, and input $x$ as shown in Figure 1. Variables $y \in \mathcal{Y}$ and $d \in \mathcal{D}$ are discrete, and $g$ and $x$ lie in continuous multi-dimensional spaces.
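Figure 1: Bayesian network over label $y$, domain $d$, latent domain features $g$, and input $x$. Panels: (Generative) | (Conditional after removing $d$).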
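The domain $d$ induces a set of latent domain features $g$. The input $x$ is obtained by a complicated, unobserved mixing of $y$ and $g$. Our key assumption (A1, domain continuity) is that the variation across discrete domains can be captured by latent continuous features $g$, and that the training domains span these features.
Under this assumption $\Pr(y \mid x, g)$ can be modeled during training, so that during inference we can infer $y$ for a given $x$ by estimating

$$\Pr(y \mid x) = \sum_{d \in \mathcal{D}} \Pr(y \mid x, d)\,\Pr(d \mid x) = \int_{g} \Pr(y \mid x, g)\,\Pr(g \mid x)\,dg \approx \Pr(y \mid x, \hat{g}),$$

where $\hat{g} = \arg\max_{g} \Pr(g \mid x)$ is the inferred continuous representation of the domain of $x$.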
This assumption is key to our being able to claim generalization to new domains even though most real-life domains are discrete. For example, domains like fonts and speakers are discrete, but their variation can be captured via latent continuous features (e.g. slant, ligature size etc. of fonts; speaking rate, pitch, intensity, etc. for speech). The assumption states that as long as the training domains span the latent continuous features we can generalize to new fonts and speakers.
We next elaborate on how we estimate $\Pr(y \mid x, \hat{g})$ and the inferred domain features $\hat{g}$ using the domain-labeled training data $\{(x_i, y_i, d_i)\}$. The main challenge in this task is to ensure that the model for $\Pr(y \mid x, \hat{g})$ is not over-fitted on the inferred $g$’s of the training domains. In many applications, the per-domain $\Pr(y \mid x, d)$ is significantly easier to train. So, an easy local minimum is to choose a different $g$ for each training domain $d$ and generate separate classifiers for each distinct training domain. We must encourage the network to stay away from such easy solutions. We strive for generalization by moving along the continuous space $g$ of domains to sample new training examples from hallucinated domains. Ideally, for each training instance $(x_i, y_i)$ from a given domain $d_i$, we wish to generate a new $x_i'$ by transforming its (inferred) domain $\hat{g}_i$ to a random domain sampled from $\mathcal{D}$, keeping its label $y_i$ unchanged. Under the domain continuity assumption (A1), a model trained with such an ideally augmented dataset is expected to generalize to domains in $\mathcal{D} \setminus D$.
However, there are many challenges to achieving such ideal augmentation. To avoid changing $y_i$, it is convenient to draw a sample $g$ by perturbing $\hat{g}_i$. But $\hat{g}_i$ may not be reliably inferred, leading to a distorted sample of $g$. For example, if the $\hat{g}_i$ obtained from an imperfect extraction conceals label information, then big jumps in the approximate $g$ space could change the label too. We propose a more cautious data augmentation strategy that perturbs the input to make only small moves along the estimated domain features, while changing the label as little as possible. We arrive at our method as follows.
Domain inference. We create a model $G(x)$ to extract domain features $g$ from an input $x$. We supervise the training of $G$ to predict the domain label $d_i$ as $S(G(x_i))$, where $S$ is a softmax transformation. We use $J_d$ to denote the cross-entropy loss function of this classifier. Specifically, $J_d(x_i, d_i)$ is the domain loss at the current instance.
Given an example $(x_i, y_i, d_i)$, we seek to sample a new example $(x_i', y_i)$ (i.e., with the same label $y_i$), whose domain is as “far” from $d_i$ as possible. To this end, consider setting $x_i' = x_i + \epsilon \nabla_{x_i} J_d(x_i, d_i)$. Intuitively, this perturbs the input along the direction of greatest domain change.
What is the consequent change of the continuous domain features $\hat{g}_i$? Since $\Delta \hat{g}_i \approx J \Delta x_i$ and, by the chain rule, $\nabla_{x_i} J_d = J^\top \nabla_{\hat{g}_i} J_d$, to first order this change is $\epsilon J J^\top \nabla_{\hat{g}_i} J_d(x_i, d_i)$, where $J$ is the Jacobian of $\hat{g}$ w.r.t. $x$. Geometrically, the $J J^\top$ term is the (transpose of the) metric tensor matrix accounting for the distortion in mapping from the $x$-manifold to the $\hat{g}$-manifold. While this perturbation is not very intuitive in terms of the direct relation between $\hat{g}$ and $d$, we show in the Appendix that the input perturbation $\epsilon \nabla_{x_i} J_d(x_i, d_i)$ is also the first step of a gradient descent process to induce the “natural” domain perturbation $\Delta \hat{g}_i = \epsilon' \nabla_{\hat{g}_i} J_d(x_i, d_i)$.
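The input perturbation $x_i' = x_i + \epsilon \nabla_{x_i} J_d(x_i, d_i)$ and its first-order effect $\epsilon J J^\top \nabla_{\hat{g}} J_d$ on the domain features can be checked numerically. The NumPy sketch below assumes a hypothetical one-layer domain network $G(x) = \tanh(Wx)$ with a softmax head $S(Vg)$; this is not the architecture used in the experiments, and `W`, `V`, and the helper functions are illustrative names only.

```python
import numpy as np

# Hypothetical toy domain network (illustration only, not the paper's model):
#   g-hat = G(x) = tanh(W x),  domain posterior = softmax(V g-hat)
rng = np.random.default_rng(0)
n_in, n_g, n_dom = 8, 4, 3
W = rng.normal(size=(n_g, n_in)) / np.sqrt(n_in)
V = rng.normal(size=(n_dom, n_g))


def G(x):
    """Inferred domain features g-hat."""
    return np.tanh(W @ x)


def domain_loss(x, d):
    """Cross-entropy J_d(x, d) = -log softmax(V G(x))[d]."""
    z = V @ G(x)
    z = z - z.max()
    return np.log(np.exp(z).sum()) - z[d]


def grad_g(g, d):
    """Gradient of J_d with respect to the domain features g."""
    z = V @ g
    p = np.exp(z - z.max())
    p /= p.sum()
    p[d] -= 1.0
    return V.T @ p


def jacobian(x):
    """J = dG/dx with shape (n_g, n_in); row k is (1 - g_k^2) * W[k]."""
    return (1.0 - G(x) ** 2)[:, None] * W


def grad_x(x, d):
    """Chain rule: grad_x J_d = J^T grad_g J_d."""
    return jacobian(x).T @ grad_g(G(x), d)


x, d, eps = rng.normal(size=n_in), 1, 1e-4
x_aug = x + eps * grad_x(x, d)                      # domain-guided perturbation (label unchanged)
J = jacobian(x)
dg_actual = G(x_aug) - G(x)                         # change actually induced in g-hat
dg_first_order = eps * J @ J.T @ grad_g(G(x), d)    # eps * J J^T grad_g J_d
```

For a small step size, `dg_actual` and `dg_first_order` agree up to terms of order $\epsilon^2$, and the domain loss at `x_aug` exceeds the loss at `x`, as expected for a step along the domain-loss gradient.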
The above development leads to the network sketched in Figure 2, and an accompanying training algorithm, CrossGrad, shown in Algorithm 1. Here $(X, Y, D)$ denotes a minibatch of instances. Our proposed method integrates data augmentation and batch training as an alternating sequence of steps. The domain classifier is simultaneously trained with the perturbations from the label classifier network so as to be robust to label changes. Thus, we construct cross-objectives for the label loss $J_l$ and the domain loss $J_d$, and update their respective parameter spaces. We found this scheme of simultaneously training both networks to be empirically superior to independent training even though the two classifiers do not share parameters.
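As an illustration of the alternating scheme, the NumPy sketch below performs CrossGrad-style updates on a minibatch. It is a sketch under assumptions, not Algorithm 1 itself: the domain network is a hypothetical `tanh` layer plus softmax, the label classifier is a softmax over the concatenation $[x, \hat{g}]$ (the experiments instead concatenate $g$ with a hidden layer), the label loss does not update the domain network, and the names `eps_dom`, `eps_lab`, `alpha_l`, `alpha_d`, and their default values are illustrative choices.

```python
import numpy as np


def softmax(Z):
    """Row-wise, numerically stable softmax."""
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def domain_grads(X, D, Wg, Vd):
    """Hypothetical domain net: g = tanh(X Wg), Pr(d | x) = softmax(g Vd).

    Returns the mean loss J_d, gradients for (Wg, Vd), and per-example input gradients.
    """
    B = len(X)
    G = np.tanh(X @ Wg)
    P = softmax(G @ Vd)
    loss = -np.log(P[np.arange(B), D]).mean()
    dZ = P.copy()
    dZ[np.arange(B), D] -= 1.0               # per-example dJ_d / d(logits)
    dpre = (dZ @ Vd.T) * (1.0 - G ** 2)      # back-propagate through tanh
    return loss, X.T @ dpre / B, G.T @ dZ / B, dpre @ Wg.T


def label_grads(X, Y, U, Wg):
    """Hypothetical label net: Pr(y | x, g) = softmax([x, g] U) with g = tanh(X Wg).

    Wg is treated as fixed here (only the domain loss trains it), but the
    per-example input gradient still flows through g.
    """
    B, n_in = X.shape
    G = np.tanh(X @ Wg)
    H = np.concatenate([X, G], axis=1)
    P = softmax(H @ U)
    loss = -np.log(P[np.arange(B), Y]).mean()
    dZ = P.copy()
    dZ[np.arange(B), Y] -= 1.0
    dH = dZ @ U.T
    dX = dH[:, :n_in] + (dH[:, n_in:] * (1.0 - G ** 2)) @ Wg.T
    return loss, H.T @ dZ / B, dX


def crossgrad_step(X, Y, D, U, Wg, Vd, eps_dom=0.5, eps_lab=0.5,
                   alpha_l=0.5, alpha_d=0.5, lr=0.1):
    """One CrossGrad-style update on a minibatch (X, Y, D); all settings are illustrative."""
    X_d = X + eps_dom * domain_grads(X, D, Wg, Vd)[3]   # perturb along domain-loss gradient
    X_l = X + eps_lab * label_grads(X, Y, U, Wg)[2]     # perturb along label-loss gradient
    # Label classifier: mix losses on clean and domain-perturbed inputs.
    gU = ((1 - alpha_l) * label_grads(X, Y, U, Wg)[1]
          + alpha_l * label_grads(X_d, Y, U, Wg)[1])
    # Domain classifier: mix losses on clean and label-perturbed inputs.
    _, gW0, gV0, _ = domain_grads(X, D, Wg, Vd)
    _, gW1, gV1, _ = domain_grads(X_l, D, Wg, Vd)
    Wg_new = Wg - lr * ((1 - alpha_d) * gW0 + alpha_d * gW1)
    Vd_new = Vd - lr * ((1 - alpha_d) * gV0 + alpha_d * gV1)
    return U - lr * gU, Wg_new, Vd_new


def predict(X, U, Wg):
    """Inference with the plug-in approximation Pr(y | x) ~ Pr(y | x, g-hat)."""
    H = np.concatenate([X, np.tanh(X @ Wg)], axis=1)
    return softmax(H @ U).argmax(axis=1)


rng = np.random.default_rng(0)
B, n_in, n_g, n_lab, n_dom = 16, 6, 3, 4, 5
X = rng.normal(size=(B, n_in))
Y = rng.integers(0, n_lab, size=B)
D = rng.integers(0, n_dom, size=B)
U = 0.1 * rng.normal(size=(n_in + n_g, n_lab))
Wg = 0.1 * rng.normal(size=(n_in, n_g))
Vd = 0.1 * rng.normal(size=(n_g, n_dom))
for _ in range(100):
    U, Wg, Vd = crossgrad_step(X, Y, D, U, Wg, Vd)
preds = predict(X, U, Wg)
```

Both perturbed batches are computed from the same pre-update parameters, so the two classifiers are updated simultaneously rather than one after the other; in a real system the hand-written gradients would be replaced by automatic differentiation.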
If $y$ and $d$ are completely correlated, CrossGrad reduces to traditional adversarial training. If, at the other extreme, they are perfectly uncorrelated, removing the domain signal should work well. The interesting and realistic situation is where they are only partially correlated. CrossGrad is designed to handle the whole spectrum of correlations.
4 Experiments
In this section, we demonstrate that CrossGrad provides effective domain generalization on four different classification tasks under three different model architectures. We provide evidence that our Bayesian characterization of domains as continuous features is responsible for such generalization. We establish that CrossGrad’s domain-guided perturbations provide more consistent generalization to new domains than label adversarial perturbation (goodfellow2014), which we denote by LabelGrad. We also show that DANs, a popular domain adaptation method that suppresses domain signals, provide little improvement over the baseline (ganin16; tzeng17).
We describe the four different datasets and present a summary in Table 1.
Character recognition across fonts. We created this dataset from Google Fonts.
Handwriting recognition across authors. We used the LipiTk dataset, which comprises handwritten characters from the Devanagari script.
MNIST across synthetic domains. This dataset derived from MNIST was introduced by GhifaryBZB15. Here, labels comprise the 10 digits and domains are created by rotating the images in multiples of 15 degrees: 0, 15, 30, 45, 60 and 75. The domains are labeled with the angle by which they are rotated, e.g., M15, M30, M45. We tested on domain M15 while training on the rest. The network is the 2-layer convolutional one used by motiian2017CCSA.
Spoken word recognition across users. We used the Google Speech Command Dataset.
For all experiments, the sets of domains in the training, test, and validation sets were disjoint. We selected hyper-parameters based on accuracy on the validation set as follows. For LabelGrad the parameter was chosen from a small set of values, and for CrossGrad we chose the corresponding parameters from the same set of values. We chose the perturbation ranges so that the norms of the perturbations are of similar size in LabelGrad and CrossGrad; the multiples in each range came from a fixed set of values. The optimizer for the first three datasets is RMSProp with a learning rate ($\eta$) of 0.02, whereas for the Speech dataset it is SGD, with $\eta$ set to 0.0001 after 15 iterations. In CrossGrad networks, $g$ is incorporated in the label classifier network by concatenating it with the output of the third-to-last hidden layer.
4.1 Overall comparison
In Table 2 we compare CrossGrad with domain adversarial networks (DAN), label adversarial perturbation (LabelGrad), and a baseline that performs no special training. For the MNIST dataset the baselines are CCSA (motiian2017CCSA) and D-MTAE (GhifaryBZB15). We observe that, for all four datasets, CrossGrad provides an accuracy improvement. DAN, which is designed specifically for domain adaptation, is worse than LabelGrad, which does not exploit domain signal in any way. While the gap between LabelGrad and CrossGrad is not dramatic, it is consistent, as supported by this table and by other experiments that we describe later.
Table 1: Summary of the datasets.

| Dataset | Labels | Domains |
|---|---|---|
| Font | 36 characters | 109 fonts |
| Handwriting | 111 characters | 74 authors |
| MNIST | 10 digits | 6 rotations |
| Speech | 12 commands | 1888 speakers |
Changing model architecture. To make sure that these observed trends hold across model architectures, we compare different methods with the model changed to a 2-block ResNet (he16deepresidual) instead of LeNet (lenet97) for the Fonts and Handwriting datasets in Table 3. For both datasets the ResNet model is significantly better than the LeNet model. But even for the higher-capacity ResNet model, CrossGrad surpasses the baseline accuracy as well as other methods like LabelGrad.
4.2 Why does CrossGrad work?
We present insights on the working of CrossGrad via experiments on the MNIST dataset where the domains corresponding to image rotations are easy to interpret.
In Figure 2(a) we show PCA projections of the g embeddings for images from three different domains, corresponding to rotations by 30, 45, 60 degrees in green, blue, and yellow respectively. The g embeddings of domain 45 (blue) lie between the g of domains 30 (green) and 60 (yellow), showing that the domain classifier has successfully extracted a continuous representation of the domain even when the input domain labels are categorical. Figure 2(b) shows the same pattern for domains 0, 15, 30. Here again we see that the embedding of domain 15 (blue) lies between those of domain 0 (yellow) and 30 (green).
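A PCA projection of this kind can be computed from an SVD of the centered embeddings. The NumPy sketch below uses synthetic stand-in embeddings (a hypothetical hidden domain direction scaled by the rotation angle, plus noise), not actual CrossGrad outputs; it only shows how the “in-between” ordering of domain centers can be checked along the first principal component.

```python
import numpy as np


def pca_project(E, k=2):
    """Project the rows of E (n_samples x n_features) onto their top-k principal components.

    Returns the k-dimensional coordinates and the variance explained by each component.
    """
    Ec = E - E.mean(axis=0)                      # center the embeddings
    _, s, Vt = np.linalg.svd(Ec, full_matrices=False)
    return Ec @ Vt[:k].T, s[:k] ** 2 / (len(E) - 1)


# Synthetic stand-in for g embeddings (hypothetical, not CrossGrad output):
# each rotation angle shifts g along one hidden direction, plus small noise.
rng = np.random.default_rng(0)
direction = rng.normal(size=16)
direction /= np.linalg.norm(direction)
angles = np.repeat([30, 45, 60], 50)
E = np.outer(angles / 15.0, direction) + 0.1 * rng.normal(size=(len(angles), 16))

Z, explained = pca_project(E)
centers = {a: Z[angles == a, 0].mean() for a in (30, 45, 60)}
# Domain 45 should sit between 30 and 60 along the first principal axis
# (the sign of a principal axis is arbitrary, so compare on both sides).
between = (centers[30] - centers[45]) * (centers[60] - centers[45]) < 0
```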
Next, we show that perturbations along gradients of the domain loss do manage to generate images that substitute for the missing domains during training. For example, the embeddings of domain 45, when perturbed, scatter towards domains 30 and 60, as can be seen in Figure 2(c): note the scatter of perturbed 45 (red) points inside the 30 (green) zone, which contains no 45 (blue) points. Figure 2(d) depicts a similar pattern, with perturbed domain-15 embeddings (red) scattering towards domains 30 and 0 more than unperturbed domain 15 (blue). For example, between x-axis values -1 and 1, a region dominated by the green domain (domain 30), we see many more red points (perturbed domain 15) than blue points (domain 15). The same holds in the lower right corner of domain 0, shown in yellow. This highlights the mechanism behind CrossGrad: it augments training with samples closer to unobserved domains.
Finally, we observe in Figure 4 that the embeddings are not correlated with labels. For both domains 30 and 45 the colors corresponding to different labels are not clustered. This is a consequence of CrossGrad’s symmetric training of the domain classifier via label-loss perturbed images.
4.3 When is domain generalization effective?
We next present a couple of experiments that provide insight into the settings in which CrossGrad is most effective.
First, we show the effect of increasing the number of training domains. Intuitively, we expect CrossGrad to be most useful when training domains are scarce and do not directly cover the test domains. We verified this on the speech dataset where the number of available domains is large. We varied the number of training domains while keeping the test and validation data fixed. Table 4 summarizes our results. Note that CrossGrad outperforms the baseline and LabelGrad most significantly when the number of training domains is small (40). As the training data starts to cover more and more of the possible domain variations, the marginal improvement provided by CrossGrad decreases. In fact, when the models are trained on the full training data (consisting of more than 1000 domains), the baseline achieves an accuracy of 88.3%, and both CrossGrad and LabelGrad provide no gains.
| Method Name | 40 domains | 100 domains | 200 domains | 1000 domains |
|---|---|---|---|---|
In general, how CrossGrad handles multidimensional, non-linear involvement of g in determining x is difficult to diagnose. To initiate a basic understanding of how data augmentation supplies CrossGrad with hallucinated domains, we considered a restricted situation where the discrete domain is secretly a continuous 1-d space, namely, the angle of rotation in MNIST. In this setting, a natural question is, given a set of training domains (angles), which test domains (angles) perform well under CrossGrad?
We conducted leave-one-domain-out experiments by picking one domain as the test domain and providing the others as training domains. In Table 5 we compare the accuracy of different methods. We also compare against the numbers reported by the authors of the CCSA method of domain generalization (motiian2017CCSA).
Table 5 shows that CrossGrad is beaten in only two cases: M0 and M75, which are the two extreme rotation angles. For angles in the middle, CrossGrad is able to interpolate the necessary domain representation g via ‘hallucination’ from other training domains. Recall from Figures 2(c) and 2(d) that the perturbed g during training covers for the missing test domains. In contrast, when M0 or M75 is in the test set, CrossGrad’s domain loss gradient does not point in the direction of domains outside the training domains. Whether and how this insight generalizes to more dimensions or truly categorical domains is left as an open question.
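The leave-one-domain-out protocol on rotated MNIST, and the distinction between interpolated and extrapolated test angles, can be made concrete with a small helper. The function names `leave_one_domain_out` and `needs_extrapolation` are hypothetical; the check simply flags held-out angles that fall outside the range spanned by the training angles, which for these six domains are exactly M0 and M75.

```python
ROTATIONS = [0, 15, 30, 45, 60, 75]  # rotation angles (degrees) defining domains M0..M75


def leave_one_domain_out(domains):
    """Yield (training_domains, held_out_domain) pairs, holding out each domain once."""
    for held_out in domains:
        yield [d for d in domains if d != held_out], held_out


def needs_extrapolation(train, test):
    """True if the held-out angle lies outside the range spanned by the training angles."""
    return test < min(train) or test > max(train)


splits = list(leave_one_domain_out(ROTATIONS))
extreme = [test for train, test in splits if needs_extrapolation(train, test)]
# extreme == [0, 75]
```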
Domain and label interact in complicated ways to influence the observable input x. Most domain adaptation strategies implicitly consider the domain signal to be extraneous and seek to remove its effect to train more robust label predictors. We presented CrossGrad, which considers them in a more symmetric manner. CrossGrad provides a new data augmentation scheme in which the $y$ (respectively, $d$) predictor is trained on perturbations generated from the gradient of the $d$ (respectively, $y$) predictor’s loss over the input space. Experiments comparing CrossGrad with various recent adversarial paradigms show that CrossGrad can make better use of partially correlated $y$ and $d$, without requiring explicit distributional assumptions about how they affect $x$. CrossGrad is at its best when training domains are scarce and do not directly cover test domains well. Future work includes extending CrossGrad to exploit labeled or unlabeled data in the test domain, and integrating the best of LabelGrad and CrossGrad into a single algorithm.
We gratefully acknowledge the support of NVIDIA Corporation with the donation of Titan X GPUs used for this research. We thank Google for supporting travel to the conference venue.
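Relating the “natural” perturbations of $x$ and $\hat{g}$. In Section 3, we claimed that the intuitive perturbation $\epsilon \nabla_{x_i} J_d(x_i, d_i)$ of $x_i$ attempts to induce the intuitive perturbation $\epsilon \nabla_{\hat{g}_i} J_d(x_i, d_i)$ of $\hat{g}_i$, even though the exact relation between a perturbation of $x_i$ and that of $\hat{g}_i$ requires the metric tensor transpose $J J^\top$. We now justify this assertion. Of course, an isometric map $G : x \mapsto \hat{g}$, with an orthogonal Jacobian ($J^{-1} = J^\top$), trivially has this property, but we present an alternative derivation which may give further insight into the interaction between the perturbations in the general case.
Consider perturbing $\hat{g}_i$ to produce $\hat{g}_i' = \hat{g}_i + \epsilon \nabla_{\hat{g}_i} J_d(x_i, d_i)$. This yields a new augmented input instance $x_i'$ as

$$x_i' = G^{-1}(\hat{g}_i') = G^{-1}\big(\hat{g}_i + \epsilon \nabla_{\hat{g}_i} J_d(x_i, d_i)\big).$$

We show next that the perturbed $x_i'$ can be approximated by $x_i + \epsilon \nabla_{x_i} J_d(x_i, d_i)$.
In this derivation we drop the subscript $i$ for ease of notation. In the forward direction, the relationship between $\Delta x$ and $\Delta \hat{g}$ can be expressed using the Jacobian $J$ of $\hat{g}$ w.r.t. $x$:

$$\Delta \hat{g} = J \, \Delta x.$$

To invert the relationship for a non-square and possibly low-rank Jacobian, we use the Jacobian transpose method devised for inverse kinematics (BalestrinoMS84; WolovichE84). Specifically, we write $\hat{g} = G(x)$, and recast the problem as trying to minimize the squared L2 error

$$E(x) = \tfrac{1}{2} \big\lVert \hat{g}' - G(x) \big\rVert_2^2$$

with gradient descent. The gradient of the above expression w.r.t. $x$ is

$$\nabla_x E(x) = -J^\top \big(\hat{g}' - G(x)\big) = -J^\top \Delta \hat{g},$$

where the last equality holds at the starting point, where $\hat{g}' - G(x) = \Delta \hat{g}$. Hence, the initial gradient descent step (with unit step size) to effect a change of $\Delta \hat{g}$ in the domain features would increment $x$ by $J^\top \Delta \hat{g}$. The Jacobian, which is a matrix of first partial derivatives, can be computed by back-propagation. Thus we get

$$x' \approx x + J^\top \Delta \hat{g} = x + \epsilon \, J^\top \nabla_{\hat{g}} J_d(x, d),$$

which, by the chain rule $\nabla_x J_d = J^\top \nabla_{\hat{g}} J_d$, gives

$$x' \approx x + \epsilon \nabla_x J_d(x, d).$$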
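The Jacobian transpose argument can be checked numerically. The NumPy sketch below again uses a hypothetical toy domain network $G(x) = \tanh(Wx)$ with a softmax head, which is an illustrative assumption rather than the trained model. It runs the iteration $x \leftarrow x + \eta J(x)^\top (\hat{g}' - G(x))$ toward $\hat{g}' = \hat{g} + \epsilon \nabla_{\hat{g}} J_d$, and computes the first unit-step increment $J^\top \Delta\hat{g}$, which should match $\epsilon \nabla_x J_d$. The step size and iteration count are arbitrary choices.

```python
import numpy as np

# Hypothetical toy domain network: G(x) = tanh(W x), softmax domain head V.
rng = np.random.default_rng(1)
n_in, n_g, n_dom = 8, 3, 4
W = rng.normal(size=(n_g, n_in)) / np.sqrt(n_in)
V = rng.normal(size=(n_dom, n_g))


def G(x):
    return np.tanh(W @ x)


def jac(x):
    """J = dG/dx, shape (n_g, n_in)."""
    return (1.0 - G(x) ** 2)[:, None] * W


def domain_loss(x, d):
    """J_d(x, d) = -log softmax(V G(x))[d]."""
    z = V @ G(x)
    z = z - z.max()
    return np.log(np.exp(z).sum()) - z[d]


def grad_g(g, d):
    """Gradient of J_d with respect to g."""
    z = V @ g
    p = np.exp(z - z.max())
    p /= p.sum()
    p[d] -= 1.0
    return V.T @ p


def jacobian_transpose_solve(x, g_target, step=0.2, iters=300):
    """Minimize E(x) = 0.5 * ||g_target - G(x)||^2 via x <- x + step * J(x)^T (g_target - G(x))."""
    errors = []
    for _ in range(iters):
        r = g_target - G(x)
        errors.append(0.5 * r @ r)
        x = x + step * jac(x).T @ r
    return x, errors


x, d, eps = rng.normal(size=n_in), 2, 0.1
dg = eps * grad_g(G(x), d)                 # "natural" domain perturbation Delta g-hat
x_solved, errors = jacobian_transpose_solve(x, G(x) + dg)
first_step = jac(x).T @ dg                 # first Jacobian-transpose increment (unit step)
```

In this toy setting `errors` shrinks as the iteration proceeds, while `first_step` coincides with $\epsilon \nabla_x J_d(x, d)$, which is the input perturbation CrossGrad uses.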
- The dependence of on x could also be via continuous hidden variables but our model for domain generalization is agnostic of such structure.
- We use $\nabla_{x_i} J_d(x_i, d_i)$ as shorthand for the gradient $\nabla_x J_d(x, d_i)$ evaluated at $x = x_i$.
- The gap in accuracy between the baseline and CrossGrad for the case of 1000 domains is not statistically significant according to the MAPSSWE test (MAPSSWE).