DADA: Deep Adversarial Data Augmentation for Extremely Low Data Regime Classification
Deep learning has revolutionized the performance of classification, but in turn demands sufficient labeled data for training. Given insufficient data, while many techniques have been developed to help combat overfitting, the challenge remains if one tries to train deep networks, especially in ill-posed extremely low data regimes: only a small set of labeled data is available, and nothing else, not even unlabeled data. Such regimes arise in practical situations where not only data labeling but also data collection itself is expensive. We propose a deep adversarial data augmentation (DADA) technique to address the problem, in which we formulate data augmentation as the problem of training a class-conditional and supervised generative adversarial network (GAN). Specifically, a new discriminator loss is proposed to fit the goal of data augmentation, through which both real and augmented samples are enforced to contribute to, and be consistent in, finding the decision boundaries. Tailored training techniques are developed accordingly. To quantitatively validate its effectiveness, we first perform extensive simulations to show that DADA substantially outperforms both traditional data augmentation and a few GAN-based options. We then extend experiments to three real-world small labeled datasets where existing data augmentation and/or transfer learning strategies are either less effective or infeasible. All results endorse the superior capability of DADA in enhancing the generalization ability of deep networks trained in practical extremely low data regimes. Source code is available at https://github.com/SchafferZhang/DADA.
Xiaofeng Zhang, University of Science and Technology of China, Hefei, Anhui, China; Zhangyang Wang, Texas A&M University, College Station, TX, USA; Dong Liu, University of Science and Technology of China, Hefei, Anhui, China; Qing Ling, Sun Yat-Sen University, Guangzhou, Guangdong, China
Preprint. Work in progress.
1 Introduction
The performance of classification and recognition has been tremendously revolutionized by the rise of deep learning krizhevsky2012imagenet (). Deep learning-based classifiers can reach unprecedented accuracy, provided that there are sufficient labeled data for training. Meanwhile, such a blessing can turn into a curse: in many realistic settings, where either massively annotating labels is labor-intensive or only limited datasets are available, a deep learning model will easily overfit and generalize poorly. Many techniques have been developed to help combat overfitting with insufficient data, ranging from classical data augmentation perez2017effectiveness (), to dropout krizhevsky2012imagenet () and other structural regularizations han2015learning (), to pre-training erhan2010does (), transfer learning raina2007self (), and semi-supervised learning kingma2014semi (). However, in low data regimes even these techniques fall short, and the resulting models usually cannot capture all possible input data variances and distinguish them from nuisance variances. High-variance gradients also make popular training algorithms, e.g., stochastic gradient descent, extremely unstable.
In this paper, we confront an even more ill-posed and challenging problem: how can we learn a deep network classifier when the labeled training set is high-dimensional but small in sample size? Most existing methods in low data regimes deal with the scarcity of labeled data; however, they often assume help from abundant unlabeled samples in the same set, or from (labeled or unlabeled) samples in other similar datasets, enabling various semi-supervised or transfer learning solutions. Different from them, we investigate a less-explored and much more daunting setting, extremely low data regimes: besides the given small set of labeled samples, neither unlabeled data from the same distribution nor data from similar distributions are assumed to be available throughout training. In other words, we aim to train a deep network classifier from scratch, using only the given small number of labeled data and nothing else. Our only hope lies in maximizing the usage of the given small training set, by finding nontrivial and semantically meaningful re-compositions of sample information that help us characterize the underlying distribution. Extremely low data regimes for classification are ubiquitous in, and have hindered progress in, many practical and scientific fields where not only data labeling but also data collection itself is expensive to scale up. For example, image subjects from military and medical imagery are usually expensive to collect, and often follow quite different distributions from easily accessible natural images. While we mostly focus on image classification/visual recognition in this paper, our methodology can be readily extended to classifying non-image data in extremely low data regimes; we intentionally present one such example, electroencephalographic (EEG) signal classification, in Section 5. To resolve these challenges, we make the following technical contributions:
For learning deep classifiers in extremely low data regimes, we focus on boosting the effectiveness of data augmentation, and introduce learning-based data augmentation, which can be optimized for classifying general data without relying on any domain-specific prior or unlabeled data. The data augmentation module and the classifier are formulated and learned together as a fully-supervised generative adversarial network (GAN). We call the proposed framework Deep Adversarial Data Augmentation (DADA).
We propose a new loss function for the GAN discriminator that not only learns to classify real images, but also enforces fine-grained classification over multiple “fake classes”. We refer to it as the 2K loss, in contrast to the (K+1) loss used by several existing GANs (compared in detail later), where K is the number of classes. The novel loss function is motivated by the needs of data augmentation: the generated augmented (“fake”) samples need to be discriminative among classes too, and the decision boundaries learned on augmented samples should align consistently with those learned on real samples. We show in experiments that the 2K loss is critical to boosting the overall classification performance.
We conduct extensive simulations on CIFAR-10, CIFAR-100, and SVHN, training deep classifiers in extremely low data regimes, and demonstrate significant performance improvements from DADA compared to traditional data augmentation. To further validate the practical effectiveness of DADA, we train deep classifiers on three real-world small datasets: the Karolinska Directed Emotional Faces (KDEF) dataset for the facial expression recognition task, a Brain-Computer Interface (BCI) Competition dataset for the EEG brain signal classification task, and the Curated Breast Imaging Subset of the Digital Database for Screening Mammography (CBIS-DDSM) dataset for the tumor classification task. For all of them, DADA leads to highly competitive generalization performance.
2 Related Work
2.1 Generative Adversarial Networks
Generative Adversarial Networks (GANs) goodfellow2014generative () have gathered a significant amount of attention due to their ability to learn generative models of multiple natural image datasets. The original GAN model and many of its successors are unsupervised: their discriminators have a single probabilistic realness output attempting to decide whether an input image is real or generated (a.k.a. fake). Conditional GAN mirza2014conditional () generates data conditioned on class labels via label embeddings in both the discriminator and the generator. Conditioning generated samples on labels sheds light on the option of semi-supervised classification using GANs. In odena2016semi (), the semi-supervised GAN has the discriminator network output class labels, leading to a (K+1)-class loss function consisting of K class labels if the sample is decided to be real, and a single extra class if the sample is decided to be fake. Such a structured loss has been re-emphasized in salimans2016improved () to provide more informed training, which leads to generated samples that better capture class-specific variances. Even with the proven success of GANs in producing realistic-looking images, tailoring GANs for classification is not as straightforward as it looks dai2017good (). The first question would naturally be: does a GAN really create semantically novel compositions, or simply memorize its training samples (or add trivial nuisances)? Fortunately, there seems to be empirical evidence that GANs perform at least some non-trivial modeling of the unknown distribution and are able to interpolate in the latent space radford2015unsupervised (); arora2017gans (). However, previous examinations also reported that the diversity of generated samples is far poorer than that of the true training dataset. In santurkar2017classification (), the authors tried several unconditional GANs to synthesize samples, on which they trained image classifiers. They reported that the accuracies achieved by such classifiers were comparable only to the accuracy of a classifier trained on a 100× (or more) subsampled version of the true dataset, and that the gap could not be reduced by drawing more samples from the GANs. Despite the many insights revealed, the authors did not consider low data regimes. More importantly, they pursued a different goal, using classification performance to measure the diversity of generated data. As a result, they neither considered class-conditional GANs nor customized any GAN structure for the goal of classification-driven data augmentation. Moreover, GANs have rarely been applied to non-image subjects.
2.2 Deep Learning on Small Samples
Training a deep learning classifier on small datasets is a topic of wide interest in the fields of fine-grained visual recognition lin2015bilinear (), few-shot learning mehrotra2017generative (), and life-long learning in new environments shu2017lifelong (). Here we review and categorize several mainstream approaches.
Dimensionality Reduction and Feature Selection. A traditional solution to overfitting caused by high dimension, low sample size data is to perform dimensionality reduction or feature selection as pre-processing, and to train (deep) models on the new feature space. Such pre-processing has become less popular in deep learning because the latter often emphasizes end-to-end trainable pipelines. A recent work liu2017deep () performed joint training of greedy feature selection and a deep classifier; however, their model was designed for bioinformatics data (attributed vectors), and it is unclear how a similar model could be applied to raw images.
Pre-training and Semi-Supervised Learning. Both pre-training and semi-supervised learning focus on improving classification with small labeled samples by utilizing extra data that come from the same training distribution but are unlabeled. Greedy pre-training with larger unlabeled data, e.g., via auto-encoders, could help learn (unsupervised) feature extractors and converge to a better generalizing minimum erhan2010does (). In practice, pre-training is often accompanied by data augmentation vincent2010stacked (). Semi-supervised learning also utilizes extra unlabeled data, which contribute to depicting the data density and thus to locating decision boundaries within low-density regions; see lee2013pseudo (); kingma2014semi (); rasmus2015semi (); salimans2016improved (). However, note that both pre-training and semi-supervised learning rely heavily on the abundance of unlabeled data: they are motivated by the same hypothesis that, while labeling data is difficult, collecting unlabeled data remains cheap. While this hypothesis is valid in many computer vision tasks, it does not always hold, and it does not apply to our target setting of extremely low data regimes.
Transfer Learning. Compared to the above two, transfer learning admits a more relaxed setting: it uses (labeled or unlabeled) data from a similar or overlapping distribution (a.k.a. the source domain), rather than from the same target distribution as the labeled samples (a.k.a. the target domain). For standard visual recognition, common visual patterns like edges are often shared between different natural image datasets. This makes knowledge transfer between such datasets promising raina2007self (), even though their semantics are not strictly tied. An empirical study wagner2013learning () showed that weight transfer from deep networks trained on a source domain with abundant (labeled) data can boost visual recognition on a target domain where labeled samples are scarce. It is, however, unclear whether, or how much, transfer learning will help if the source and target domains possess notable discrepancy.
Data Augmentation. Data augmentation is an alternative strategy to compensate for the scarcity of labeled training data, by artificially synthesizing new labeled samples from existing ones. Traditional data augmentation techniques rely on a very limited set of known invariances that are easy to invoke, and adopt ad-hoc, minor perturbations that will not change labels. For instance, in image classification, typical augmentations include image rotation, lighting/color tone modifications, rescaling, cropping, or simply adding random noise krizhevsky2012imagenet (). However, such empirical label-preserving transformations are often unavailable in non-image domains. A recent work ratner2017learning () presented a novel direction that selects and composes pre-specified base data transformations (such as rotations, shears, and central swirls for images) into a more sophisticated “tool chain” for data augmentation, using generative adversarial training. It achieves highly promising results on both image and text datasets, but needs the aid of unlabeled data in training (the same setting as in salimans2016improved ()). We experimentally compare ratner2017learning () with DADA and analyze their differences in more detail in Section 5.3.
Few efforts have gone beyond encoding a priori known invariances to explore more sophisticated, learning-based augmentation strategies. Several semi-supervised GANs, e.g., salimans2016improved (), can also be viewed as generating augmented unlabeled samples from labeled ones. A Bayesian Monte Carlo algorithm for data augmentation was proposed in tran2017bayesian () and evaluated on standard label-rich image classification datasets. The authors of hauberg2016dreaming () learned class-conditional distributions under a diffeomorphism assumption. A concurrent preprint antoniou2017data () explored a Data Augmentation Generative Adversarial Network (DAGAN): the authors developed a GAN model completely different from the proposed DADA, whose generator does not depend on the classes and whose discriminator is a vanilla real/fake one. This makes DAGAN applicable to unseen new classes in few-shot learning scenarios, which differs from our goal of improving fully-supervised classification. As we will see in experiments, deriving a stronger discriminator is critical in our target task.
We also note an interesting benchmark study in perez2017effectiveness () comparing various data augmentation techniques, including very sophisticated generative models such as CycleGAN zhu2017unpaired (). Somewhat surprisingly, they found that traditional ad-hoc augmentation techniques are still able to outperform existing learning-based choices. Overall, enhancing small sample classification via learning-based data augmentation remains an open and under-investigated problem.
Domain-specific Data Synthesis. A number of works wang2015deepfont (); le2017using (); sixt2016rendergan (); shrivastava2017learning (); wangadversarial (); jaderberg2014synthetic () explored the “free” generation of labeled synthetic examples to assist training. However, they either relied on extra information, e.g., 3D models of the subject, or were tailored to one special object class such as faces or license plates. Such synthesis can also be viewed as a special type of data augmentation that hinges on stronger forms of a priori invariance knowledge.
Training Regularization. A final option to fight against small datasets is to exploit variance reduction techniques for network design and training. Examples include dropout krizhevsky2012imagenet (), dropconnect wan2013regularization (), and enforcing compact structures on weights or connection patterns (e.g., sparsity) han2015learning (). Those techniques are for the general purposes of alleviating overfitting, and they alone are unlikely to resolve the challenge of extremely low data regimes.
3 Technical Approach
3.1 Problem Formulation and Solution Overview
Consider a general $K$-class classification problem. Suppose that we have a training set $\mathcal{T} = \{(x_1, y_1), (x_2, y_2), \dots, (x_N, y_N)\}$, where $x_i$ denotes a sample and $y_i$ a corresponding label, $y_i \in \{1, 2, \dots, K\}$. Our task is to learn a good classifier $C$ to predict the label $\hat{y}_i = C(x_i)$, by minimizing the empirical risk objective $\sum_{i=1}^{N} L(y_i, \hat{y}_i)$ over $\mathcal{T}$, $L$ being some loss function such as K-L divergence. As our goal, a good $C$ should generalize well on an unseen test set $\mathcal{S}$. In classical deep learning-based classification settings, $N$ is large enough to ensure that goal. However, in our extremely low data regimes, $N$ can be too small to support robust learning of any complicated decision boundary, causing severe overfitting.
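As a small worked instance of this objective, the sketch below takes the loss to be the K-L divergence between a one-hot ground-truth vector and the classifier's probability output, and averages rather than sums over the training set (which leaves the minimizer unchanged). The helper names and the toy predictions are illustrative assumptions, not the authors' code. For a one-hot target, the K-L divergence reduces to the cross-entropy, i.e., the negative log-probability assigned to the true class.

```python
import math
from typing import List, Sequence

def one_hot(label: int, num_classes: int) -> List[float]:
    """One-hot ground truth for a 1-based label y in {1, ..., K}."""
    vec = [0.0] * num_classes
    vec[label - 1] = 1.0
    return vec

def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for discrete distributions; entries with p_k = 0 contribute nothing."""
    return sum(pk * math.log(pk / max(qk, eps)) for pk, qk in zip(p, q) if pk > 0)

def empirical_risk(labels: Sequence[int], predictions: Sequence[Sequence[float]], num_classes: int) -> float:
    """Mean of L(y_i, C(x_i)) over the training set, with L taken to be the KL divergence."""
    losses = [kl_divergence(one_hot(y, num_classes), q) for y, q in zip(labels, predictions)]
    return sum(losses) / len(losses)

# Two training samples of a 3-class problem, with hypothetical classifier outputs C(x_i).
labels = [1, 3]
predictions = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
risk = empirical_risk(labels, predictions, num_classes=3)  # = (-log 0.7 - log 0.8) / 2
```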
Data augmentation approaches seek an augmenter $A$ to synthesize a new set $\mathcal{T}_A$ of augmented labeled data $(x_j^A, y_j^A)$ from $\mathcal{T}$, constituting the new augmented training set $\mathcal{T}' = \mathcal{T} \cup \mathcal{T}_A$ of size $N' > N$. Traditional choices of $A$, being mostly ad-hoc minor perturbations, are usually class-independent, i.e., they construct a sample-wise mapping from $x_i$ to $x_j^A$ without taking the class distribution into account. Such mappings are usually limited to a small number of a priori known, hand-crafted perturbations. They are not learned from data, and are not optimized towards finding classification boundaries. To further improve $A$, one may consider inter-sample relationships hauberg2016dreaming (), as well as inter-class relationships in $\mathcal{T}$, in which case training a generative model $A$ over $\mathcal{T}$ becomes a viable option.
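The following minimal sketch shows how an augmented training set is assembled from the original set T plus the synthesized set T_A, for an arbitrary augmenter. The augmenter's signature, the `factor` parameter (synthetic samples per real sample), and the toy noise augmenter are hypothetical; the noise augmenter is deliberately class-independent, i.e., the kind of traditional sample-wise mapping just described, whereas DADA would plug in a learned class-conditional generator instead.

```python
import random
from typing import Callable, List, Tuple

Sample = Tuple[List[float], int]  # (x, y) pair with a 1-based label y

def build_augmented_set(
    train_set: List[Sample],
    augmenter: Callable[[List[float], int, random.Random], List[float]],
    factor: int,
    seed: int = 0,
) -> List[Sample]:
    """Return T ∪ T_A, where T_A holds `factor` synthetic samples per real sample.

    Every synthetic sample inherits the label of the real sample it came from,
    so the result has N' = N * (1 + factor) labeled samples.
    """
    rng = random.Random(seed)
    synthetic = [(augmenter(x, y, rng), y) for x, y in train_set for _ in range(factor)]
    return list(train_set) + synthetic

def noise_augmenter(x: List[float], y: int, rng: random.Random) -> List[float]:
    """Hypothetical class-independent augmenter: small Gaussian noise, label y ignored."""
    return [v + rng.gauss(0.0, 0.01) for v in x]

T = [([0.0, 1.0], 1), ([1.0, 0.0], 2)]
T_prime = build_augmented_set(T, noise_augmenter, factor=3)
print(len(T_prime))  # 8
```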
The conceptual framework of DADA is depicted in Figure 1. From a GAN point of view, $A$ naturally resembles a generator: its inputs can be latent variables conditioned on $y_i$, and its outputs $x_j^A$ belong to the same class $y_i$ but are sufficiently diverse from $x_i$. $C$ can act as the discriminator if it incorporates the typical GAN real/fake classification in addition to the target $K$-class classification. Ideally, the classifier $C$ should: (1) be able to correctly classify both real samples $x_i$ and augmented samples $x_j^A$ into the correct class $y_i$; (2) be unable to distinguish $x_i$ from $x_j^A$. The entire DADA framework of $A$ and $C$ can be jointly trained on $\mathcal{T}$, with a procedure bearing similarities to training a class-conditional GAN. However, existing GANs may not fit the task well, due to the often low diversity of generated samples. We are hence motivated to introduce a novel loss function towards generating more diverse and class-specific samples.
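The text leaves the exact conditioning mechanism of the augmenter open. One common choice, assumed here purely for illustration, is to concatenate a latent noise vector z with a one-hot encoding of the label y. The sketch below builds such generator inputs; the function name and the latent size of 100 are hypothetical. Keeping the label fixed while resampling z is what lets a conditional generator produce many distinct samples for the same class.

```python
import random
from typing import List

def generator_input(label: int, num_classes: int, latent_dim: int, rng: random.Random) -> List[float]:
    """Concatenate a Gaussian latent vector z with a one-hot encoding of the 1-based label."""
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    y_onehot = [1.0 if k == label - 1 else 0.0 for k in range(num_classes)]
    return z + y_onehot

rng = random.Random(42)
# Four inputs for class 3 of a 10-class problem: the label part is identical,
# while the latent part differs from draw to draw.
batch = [generator_input(3, num_classes=10, latent_dim=100, rng=rng) for _ in range(4)]
```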
3.2 Going More Discriminative: From K+1 Loss to 2K Loss
The discriminator of a vanilla, unsupervised GAN goodfellow2014generative () has only one output, indicating the probability of its input being a real sample. In salimans2016improved (); odena2016semi (), the discriminator is extended with a semi-supervised loss, whose output is a $(K+1)$-dimensional probabilistic vector: the first $K$ elements denote the probabilities of the input coming from classes 1, 2, …, $K$ of real data; the $(K+1)$-th denotes its probability of belonging to the generated fake data. In that way, the semi-supervised classifier is learned on additional unlabeled examples, while the generator simply supplies its samples as a new “generated” class. In contrast, in extremely low data regimes we tend to be more “economical” in consuming data. We recognize that unlabeled data provide weaker guidance than labeled data for learning the classification decision boundary. Therefore, if no real unlabeled data are available and we can only generate from the given limited labeled data, generating labeled data (of sufficient quality) should benefit classifier learning more than generating the same amount of unlabeled data. Further, the generated labeled samples should join forces with the real labeled samples, and their decisions on the classification boundary should be well aligned. Motivated by this design philosophy, we build a new loss function with $2K$ outputs: the first group of $K$ outputs represents the probabilities of the input data coming from classes 1, 2, …, $K$ of real data; the second group of $K$ outputs represents the probabilities of the input data coming from classes 1, 2, …, $K$ of fake data.
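To see concretely what changes between the two designs, the sketch below builds the one-hot ground-truth vector each layout implies for a sample with a 1-based label, assuming the conventions just described: real classes first, then either a single shared fake slot (the (K+1) layout) or K per-class fake slots (the 2K layout). The function names are illustrative. Under the (K+1) layout, fake samples generated for different classes all receive the same target, so their class information is discarded; the 2K layout keeps it.

```python
from typing import List

def target_k_plus_1(label: int, is_fake: bool, num_classes: int) -> List[float]:
    """(K+1)-dim target: K real-class slots plus one shared 'fake' slot."""
    t = [0.0] * (num_classes + 1)
    t[num_classes if is_fake else label - 1] = 1.0
    return t

def target_2k(label: int, is_fake: bool, num_classes: int) -> List[float]:
    """2K-dim target: K real-class slots followed by K fake-class slots."""
    t = [0.0] * (2 * num_classes)
    t[(num_classes if is_fake else 0) + label - 1] = 1.0
    return t

K = 3
# (K+1) layout: fakes conditioned on classes 1 and 2 receive the same target ...
print(target_k_plus_1(1, True, K) == target_k_plus_1(2, True, K))  # True
# ... while the 2K layout keeps the class each fake sample was generated for.
print(target_2k(1, True, K) == target_2k(2, True, K))  # False
```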
Since we use a class-conditional augmenter (generator), the label used to synthesize an augmented (fake) sample can be viewed as supplying the “ground truth” for the second group of outputs. For example, for $K = 2$, if the input datum is real and belongs to class 1, then its ground truth label is $[1, 0, 0, 0]$; otherwise, if the input datum is augmented conditionally on the label of class 1, then its ground truth label is $[0, 0, 1, 0]$. During training, the K-L divergence is computed between the $2K$-length output and its ground truth label. For testing, we add the $i$-th and $(i+K)$-th elements of the output to obtain the probability of the input belonging to class $i$, $i = 1, 2, \dots, K$. A comparison among loss functions for GANs, including DADA, is listed in Table 1.
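The training targets and the test-time rule can be sketched in a few lines for the K = 2 example above. This assumes the 2K outputs are normalized by a single softmax so that they form a probability vector, and the logits are made up for illustration; it is a sketch, not the authors' implementation. Because the targets are one-hot, the K-L divergence reduces to the negative log-probability of the target slot.

```python
import math
from typing import List, Sequence

def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over all 2K outputs."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q); with a one-hot p this equals -log q at the true index."""
    return sum(pk * math.log(pk / max(qk, eps)) for pk, qk in zip(p, q) if pk > 0)

def class_probabilities(probs_2k: Sequence[float]) -> List[float]:
    """Test-time rule: P(class i) = p_i + p_{i+K}, i.e. real and fake slots of class i summed."""
    k = len(probs_2k) // 2
    return [probs_2k[i] + probs_2k[i + k] for i in range(k)]

# K = 2, output order [real class 1, real class 2, fake class 1, fake class 2].
target_real_class1 = [1.0, 0.0, 0.0, 0.0]
target_fake_class1 = [0.0, 0.0, 1.0, 0.0]

probs = softmax([2.0, 0.5, 1.0, -1.0])  # hypothetical discriminator logits
loss_if_real = kl_divergence(target_real_class1, probs)
loss_if_fake = kl_divergence(target_fake_class1, probs)

merged = class_probabilities(probs)              # two class probabilities summing to 1
predicted_class = 1 + merged.index(max(merged))  # 1-based class index
```

Summing the real and fake slots at test time means a sample that the discriminator attributes to "fake class i" still counts as evidence for class i, which matches the intent that real and augmented samples share decision boundaries.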
The detailed training algorithm for DADA is outlined in the supplementary material.
To evaluate our approach, we first conduct a series of simulations on three widely adopted image classification benchmarks: CIFAR-10, CIFAR-100, and SVHN. We intentionally subsample the given training data to simulate extremely low data regimes, and compare the following training options: 1) C: directly train a classifier using the limited training data; 2) C_augmented: perform traditional data augmentation (including rotation, translation, and flipping), and then train a classifier; 3) DADA: the proposed data augmentation; 4) DADA_augmented: first apply the same traditional augmentation as C_augmented to the real samples, then perform DADA. We use absolutely no unlabeled data or pre-trained initialization in training, unlike the setting of most previous works. We use the original full test sets for evaluation. The network architectures we use have been exhaustively tuned to ensure the best possible performance of all baselines on these unusually small training sets. Detailed configurations and hyperparameters, as well as visualized examples of augmented samples, are given in the supplementary material.
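A minimal sketch of how such a class-balanced training subset might be drawn. The paper does not specify its sampling procedure, so uniform sampling without replacement within each class and a fixed seed are assumptions, and the toy label list is a stand-in for a real dataset. Only the training split is subsampled; the full test set stays untouched for evaluation.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence

def subsample_per_class(labels: Sequence[int], n_per_class: int, seed: int = 0) -> List[int]:
    """Return sorted indices of a class-balanced subset with n_per_class samples per class."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)
    chosen: List[int] = []
    for y in sorted(by_class):
        pool = by_class[y]
        if len(pool) < n_per_class:
            raise ValueError(f"class {y} has only {len(pool)} samples")
        chosen.extend(rng.sample(pool, n_per_class))  # sampling without replacement
    return sorted(chosen)

# Toy stand-in for a CIFAR-10-sized training set: 10 classes x 5,000 labels.
toy_labels = [i % 10 for i in range(50_000)]
subset = subsample_per_class(toy_labels, n_per_class=50)
```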
CIFAR-10 and CIFAR-100
The CIFAR-10 dataset consists of 60,000 color images at a resolution of 32×32 in K = 10 classes, with 5,000 images per class for training and 1,000 for testing. We subsample the training data so that the number of training images varies from 50 to 1,000 per class.
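Since CIFAR-10 provides 5,000 training images per class, the subsample sizes correspond to small fractions of the full training set. The quick calculation below uses only the per-class count stated above: 50 images per class keep 1% of the training data, and 1,000 images per class keep 20%.

```python
TRAIN_PER_CLASS = 5_000  # CIFAR-10 training images per class, as stated above

def training_fraction(n_per_class: int) -> float:
    """Fraction of the full CIFAR-10 training set kept when subsampling n images per class."""
    return n_per_class / TRAIN_PER_CLASS

for n in (50, 100, 500, 1_000):
    print(f"{n:>5} per class -> {training_fraction(n):.0%} of the training set")
```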
To illustrate the advantage of our proposed loss, we also use the vanilla GAN goodfellow2014generative () (which adopts the single-output real/fake loss) and the Improved GAN salimans2016improved () (which adopts the (K+1)-class loss) as two additional baselines to augment samples. For the vanilla GAN, we train a separate generator for each class. For the Improved GAN, we provide only the labeled training data without any unlabeled data: a different and more challenging setting than the one evaluated in salimans2016improved (). Both baselines are also combined with traditional data augmentation, similarly to the DADA_augmented pipeline. For all compared methods, we generate samples so that the augmented dataset is 10 times the size of the given real labeled dataset.
Figure 2 summarizes the performance of the compared methods. The vanilla GAN augmentation performs slightly better than the no-augmentation baseline, but the worst among all data augmentation settings. This concurs with santurkar2017classification () that, though a GAN can generate visually pleasing images, it does not naturally come with increased data diversity from a classification viewpoint. While Improved GAN performs better than vanilla GAN, DADA (without traditional augmentation) outperforms it at the smaller end of sample numbers (fewer than 400 per class). Comparing vanilla GAN, Improved GAN, and DADA_augmented reveals that as the discriminator loss becomes “more discriminative” (from a single output, to K+1 outputs, to 2K outputs), data augmentation becomes more effective.
Furthermore, DADA_augmented is the best performer overall, consistently surpassing all other methods over the full range of [50, 1000] samples per class. It yields around 8 percent top-1 accuracy improvement in the 500-labeled-sample, 10-class subset, without relying on any unlabeled data. It also raises the top-1 accuracy to nearly 80% using only 20% of the original training set (i.e., 1,000 of the 5,000 training samples per class), again with neither pre-training nor unlabeled data.
It is worth pointing out that traditional data augmentation (C_augmented) is a very competitive baseline here: it is second only to DADA_augmented, becomes slightly inferior to DADA when fewer than 300 labeled samples per class are available, and is consistently better than all other methods. Further, integrating traditional data augmentation accounts for the consistent performance boost from DADA to DADA_augmented. This attests to the value of empirical domain knowledge of invariances: it helps considerably even when learning-based augmentation is in place.
Finally, the comparison experiment is repeated on CIFAR-100. The results (see the supplementary material) are consistent with CIFAR-10: DADA_augmented achieves the best results and outperforms traditional data augmentation by at least 6% for all sample sizes. We also study the effects of DADA and traditional augmentation on deeper classifiers, such as ResNet-56 (see the supplementary material).
SVHN is a digit recognition dataset whose major challenge is that many images contain “outlier” digits, while only the central digit is regarded as the target of recognition. As such, traditional data augmentation approaches such as translation or flipping may degrade training, and are thus excluded from this experiment. Table 2 summarizes the results of the proposed DADA (without traditional augmentation) in comparison with Improved GAN and the naive baseline of no data augmentation. In extremely low data regimes, DADA again performs the best among the three. However, when a relatively large number of labeled samples is available (500 per class), DADA has a slightly negative impact on accuracy compared to the naive baseline, but still outperforms Improved GAN. We conjecture that this failure case can be attributed to the “outlier” digits occurring frequently in SVHN, which might hamper class-conditional generative modeling. We plan to explore more robust generators as future work to alleviate this problem.
We notice a larger margin of DADA (without augmentation) over Improved GAN on SVHN than on CIFAR-10. We conjecture the reason to be that SVHN has complicated perturbations (e.g., distracting digits), while CIFAR-10 is much “cleaner” in that sense (objects always lie in central foregrounds without other distractions). Thus, on SVHN, the class information used by DADA could become more important for supervising the generation of high-quality augmented samples without being affected by the perturbations.
5 Experiments with Real-World Small Data
In this section, we discuss three real-data experiments that fall into extremely low data regimes. The data, not just the labels, are difficult to collect and subject to high variability. We show that in these cases, the effects of transfer learning are limited, and/or no ad-hoc data augmentation approach may be available to alleviate the difficulty of training deep networks. In comparison, DADA can be easily plugged in and boosts the classification performance in all experiments.
5.1 Emotion Recognition from Facial Expressions: Comparison with Transfer Learning
Background and Challenge. Recognizing facial expressions is a topic of growing interest in the field of human-computer interaction. Among several public datasets in this field, the Karolinska Directed Emotional Faces (KDEF) dataset calvo2008facial () is a challenging benchmark consisting of rich facial variations (e.g., orientation, ethnicity, age, and gender), as well as a relatively uniform distribution of emotion classes. It has a total of 4,900 facial images collected from 70 individuals, displaying seven different facial expressions (happiness, fear, anger, disgust, surprise, sadness, neutrality). For each individual, the same expression is displayed twice and captured from 5 different angles. We choose images from the straight frontal angle in the first display only, forming a subset of 490 images for a 7-class classification problem. This certainly places us in an extremely low data regime.
Results and Analyses. We use a random 5:2 split for the training and testing sets, and pre-process the images by cropping and resizing the face regions to a resolution of 224×224. As a baseline, we choose a VGG-16 model simonyan2014very () pre-trained on ImageNet, which is re-trained and then tested on KDEF. We do not perform any traditional data augmentation, since each image is taken in a strictly controlled setting. The baseline can be viewed as a transfer learning solution with ImageNet as the source domain. We then treat the pre-trained VGG-16 model as our classifier in DADA and pair it with an augmenter network (whose configuration is detailed in the supplementary material). While the pre-trained VGG baseline achieves an accuracy of 82.86%, DADA obtains a higher accuracy of 85.71%.
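A minimal sketch of the random 5:2 split over the 490 selected images. Whether the original split was stratified by expression is not stated, so a plain shuffled split with a fixed seed is an assumption here, and integer indices stand in for the images themselves.

```python
import random
from typing import List, Sequence, Tuple

def ratio_split(items: Sequence, train_parts: int = 5, test_parts: int = 2, seed: int = 0) -> Tuple[List, List]:
    """Shuffle `items` and split them train_parts : test_parts (non-stratified)."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * train_parts // (train_parts + test_parts)
    return shuffled[:n_train], shuffled[n_train:]

# Indices 0..489 stand in for the 490 selected KDEF images.
train_idx, test_idx = ratio_split(range(490))
print(len(train_idx), len(test_idx))  # 350 140
```

With 70 images per expression class, a stratified 5:2 split would instead keep 50 training and 20 test images per class.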
We also train vanilla GAN and Improved GAN on this dataset and compare them with DADA in the same fair setting as for CIFAR-10. The vanilla GAN augmentation ends up with 83.27% and Improved GAN with 84.03%: both outperform transfer learning but remain inferior to DADA.
Transfer learning is often an effective choice for problems short of training data. However, its effectiveness is limited when there is a domain mismatch, even though ImageNet pre-trained models are widely believed to be highly transferable for most tasks. In this case, the gap between the source domain (ImageNet, general natural images) and the target domain (KDEF, facial images taken in lab environments) cannot be neglected. We advocate that learning-based data augmentation can boost performance further on top of transfer learning, and that their combination is more compelling.
5.2 Brain Signal Classification: No Domain Knowledge Can be Specified for Augmentation
Background. The classification of brain signals has found extensive applications in brain-computer interfaces, entertainment, and rehabilitation engineering leeb2007brain (). Among various tasks, the electroencephalographic (EEG) signal classification problem has been widely explored. Existing approaches include the band power method brodu2011comparative (), the multivariate adaptive autoregressive (MVAAR) method anderson1998multivariate (), and independent component analysis (ICA) hung2005recognition (). Recent works ren2014convolutional (); tabar2016novel () have explored CNNs for classifying EEG signals. However, the performance boost has been largely limited by the availability of labeled data. For example, the commonly used benchmark dataset 2b, a subset of the BCI Competition IV training set schlogl2003outcome (), includes only 400 trials. After several domain-specific pre-processing steps, each sample can be re-arranged into a three-channel image, whose channels come from the three EEG channels recorded (C3, Cz, and C4). The trials were collected from three sessions of motor imagery experiments and are to be classified into two classes of motions: right and left hand movements. We thus have a practical binary classification problem in an extremely low data regime.
Challenge. Unlike the image classification problems discussed above, no straightforward knowledge-based, label-preserving augmentation has been proposed for EEG signals, and to the best of our knowledge no data augmentation has been applied in previous EEG classification works ren2014convolutional (); tabar2016novel (). Moreover, the noisy nature of brain signals discourages manually adding further perturbations. The major bottleneck in collecting EEG classification datasets lies in the expensive, controlled data collection process itself rather than in labeling: since subjects are required to perform designated motions in a monitored lab environment, the collected EEG signals are naturally labeled. Besides, the high variability across human subjects also limits the scope of transfer learning in EEG classification. These multi-fold challenges make EEG classification an appropriate use case and testbed for our proposed DADA approach.
Results and Analyses. Following tabar2016novel (), we adopt the benchmark dataset 2b from the BCI Competition IV training set schlogl2003outcome (). We train and test the classification models, as well as the DADA models, separately for each of the nine subjects. We randomly select 90% of the 400 trials for training and use the remaining 10% for testing, and report the average accuracy over 10 runs. We treat each EEG input as a “color image” and adopt a DADA model architecture mostly similar to the one used for CIFAR-10 (except for the number of classes).¹ We include three baselines reported in tabar2016novel () for comparison: directly classifying the inputs with an SVM; a shallow CNN with one convolutional layer and one fully-connected layer (CNN); and a deeper CNN with one convolutional layer followed by seven fully-connected layers pre-trained using a stacked auto-encoder (CNN + SAE). Table 3 shows that DADA outperforms the competitive CNN + SAE method on all nine subjects, with an average accuracy margin of 1.7 percent.

¹ Note that within the same channel of an EEG input, unlike in a natural image, the signal coherence between vertical neighbors (i.e., across different frequencies) is weaker than that between horizontal neighbors (i.e., across different time stamps). The standard 2-D CNN is therefore an oversimplified model here and could be improved by accounting for this anisotropy, which is the subject of our future work.
|Method|Sub. 1|Sub. 2|Sub. 3|Sub. 4|Sub. 5|Sub. 6|Sub. 7|Sub. 8|Sub. 9|Average|
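The evaluation protocol described above (per subject, a random 90%/10% split of the trials, accuracy averaged over 10 runs) can be sketched as follows. The `train_and_predict` callback, the nearest-class-mean toy classifier, and the synthetic two-class data are hypothetical stand-ins, not the SVM, CNN, CNN + SAE, or DADA models compared in Table 3.

```python
import numpy as np


def repeated_holdout_accuracy(X, y, train_and_predict, n_runs=10, train_frac=0.9, seed=0):
    """Mean test accuracy over `n_runs` random train/test splits of one subject's trials.

    train_and_predict(X_train, y_train, X_test) -> predicted labels is a
    hypothetical callback standing in for any model (SVM, CNN, DADA, ...).
    """
    X, y = np.asarray(X), np.asarray(y)
    n_train = int(round(train_frac * len(y)))
    rng = np.random.default_rng(seed)
    accuracies = []
    for _ in range(n_runs):
        perm = rng.permutation(len(y))  # fresh random split each run
        train, test = perm[:n_train], perm[n_train:]
        pred = train_and_predict(X[train], y[train], X[test])
        accuracies.append(np.mean(pred == y[test]))
    return float(np.mean(accuracies))


def nearest_class_mean(X_train, y_train, X_test):
    """Toy stand-in classifier: predict the class whose training mean is closest."""
    classes = np.unique(y_train)
    means = np.stack([X_train[y_train == c].mean(axis=0) for c in classes])
    dists = ((X_test[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
    return classes[np.argmin(dists, axis=1)]


# Synthetic stand-in for one subject: 400 trials, two classes (left/right hand).
rng = np.random.default_rng(1)
y = np.repeat([0, 1], 200)
X = rng.normal(size=(400, 8)) + 3.0 * y[:, None]
acc = repeated_holdout_accuracy(X, y, nearest_class_mean)  # 360 train / 40 test per run
print(f"mean accuracy over 10 runs: {acc:.3f}")
```

Under this sketch, the function would be called once per subject with that subject's trials, and the nine per-subject results averaged to obtain an overall figure.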
5.3 Tumor Classification: Comparison with Other Learning-based Augmentation
In Tanda ratner2017learning (), an existing learning-based data augmentation approach, most experiments rely on unlabeled data for training. One exception we noticed is their experiment on the Curated Breast Imaging Subset of the Digital Database for Screening Mammography (CBIS-DDSM) clark2013cancer (); heath2000digital (); lee2016curated (), a medical image classification task whose data are expensive to collect, not merely to label. Since both Tanda and DADA use only the available labeled dataset to learn data augmentation in this setting, we can perform a fair comparison between the two on CBIS-DDSM.
We follow the classifier configuration used by Tanda for CBIS-DDSM: a four-layer all-convolutional CNN with leaky ReLUs and batch normalization. We resize all medical images to 224 × 224. Note that Tanda relies heavily on hand-crafted augmentations: on DDSM, it uses many basic heuristics (crop, rotate, zoom, etc.) and several domain-specific transplantations. For DADA_augmented, we apply only rotation, zooming, and contrast adjustment as the traditional augmentation pre-processing, to be consistent with the user-specified traditional augmentation modules in Tanda. We compare DADA and DADA_augmented with two versions of Tanda using mean field (MF) and LSTM generators ratner2017learning (); Table 4 shows the clear advantage of our approaches.
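The traditional pre-processing used for DADA_augmented on CBIS-DDSM (resizing to 224 × 224, then rotation, zooming, and contrast adjustment) can be sketched with NumPy alone. This is an illustrative sketch under our own assumptions: single-channel images scaled to [0, 1], nearest-neighbor sampling, and the rotation, zoom, and contrast ranges in the hypothetical `augment` helper are not the exact settings behind Table 4 or the modules used by Tanda.

```python
import numpy as np


def resize_nn(img, out_h, out_w):
    """Nearest-neighbor resize of a 2-D (grayscale) image."""
    h, w = img.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return img[rows[:, None], cols[None, :]]


def rotate_nn(img, angle_deg):
    """Rotate about the image center (nearest neighbor); uncovered pixels become 0."""
    h, w = img.shape
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    yy, xx = np.mgrid[0:h, 0:w]
    # Inverse mapping: locate the source pixel for every output pixel.
    src_y = np.cos(theta) * (yy - cy) + np.sin(theta) * (xx - cx) + cy
    src_x = -np.sin(theta) * (yy - cy) + np.cos(theta) * (xx - cx) + cx
    src_y, src_x = np.rint(src_y).astype(int), np.rint(src_x).astype(int)
    inside = (src_y >= 0) & (src_y < h) & (src_x >= 0) & (src_x < w)
    out = np.zeros_like(img)
    out[inside] = img[src_y[inside], src_x[inside]]
    return out


def zoom_in(img, factor):
    """Zoom by `factor` >= 1: center crop, then resize back to the original size."""
    h, w = img.shape
    ch, cw = max(1, int(round(h / factor))), max(1, int(round(w / factor)))
    top, left = (h - ch) // 2, (w - cw) // 2
    return resize_nn(img[top:top + ch, left:left + cw], h, w)


def adjust_contrast(img, c):
    """Scale deviations from the mean intensity by `c`, clipping to [0, 1]."""
    m = img.mean()
    return np.clip((img - m) * c + m, 0.0, 1.0)


def augment(img, rng, max_angle=20.0, max_zoom=1.2, contrast_range=(0.8, 1.2)):
    """One random rotation + zoom + contrast draw; ranges are illustrative only."""
    out = rotate_nn(img, rng.uniform(-max_angle, max_angle))
    out = zoom_in(out, rng.uniform(1.0, max_zoom))
    return adjust_contrast(out, rng.uniform(*contrast_range))


# Usage: resize a hypothetical grayscale scan in [0, 1] to 224 x 224, then augment.
rng = np.random.default_rng(0)
scan = resize_nn(rng.random((300, 260)), 224, 224)
augmented = augment(scan, rng)
print(augmented.shape)  # (224, 224)
```

Each transformation is label-preserving by construction, which is what distinguishes this hand-specified pipeline from the learned augmentation that DADA adds on top of it.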
What differentiates DADA from Tanda? Tanda trains a generative sequence model over user-specified, knowledge-based transformation functions, while DADA is trained purely in a data-driven, discriminative way. Unlike Tanda, whose augmented samples always look like naturalistic samples of each class, DADA may sometimes produce augmented samples that are not visually close to real ones but are optimized toward depicting the boundaries between different classes. We display some “un-naturalistic” augmented samples found in the SVHN experiments in the supplementary material. Tanda also seems to benefit from the unlabeled data used in training, which helps keep the transformed data points within the data distribution, while DADA works robustly without unlabeled data (as on CBIS-DDSM).
We present DADA, a learning-based data augmentation solution for training deep classifiers in extremely low data regimes. We leverage the power of GANs to generate new training data that both bear class labels and enhance diversity. A new loss is designed for DADA and shown to boost performance. We perform extensive simulations as well as three real-data experiments, and the results all endorse the practical advantage of DADA. We anticipate that DADA can be applied to many real-world tasks, including satellite, military, and biomedical image/data classification.
- (1) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, pages 1097–1105, 2012.
- (2) Luis Perez and Jason Wang. The effectiveness of data augmentation in image classification using deep learning. arXiv preprint arXiv:1712.04621, 2017.
- (3) Song Han, Jeff Pool, John Tran, and William Dally. Learning both weights and connections for efficient neural network. In NIPS, pages 1135–1143, 2015.
- (4) Dumitru Erhan, Yoshua Bengio, Aaron Courville, Pierre-Antoine Manzagol, Pascal Vincent, and Samy Bengio. Why does unsupervised pre-training help deep learning? Journal of Machine Learning Research, 11(Feb):625–660, 2010.
- (5) Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and Andrew Y Ng. Self-taught learning: Transfer learning from unlabeled data. In ICML, pages 759–766, 2007.
- (6) Diederik P Kingma, Shakir Mohamed, Danilo Jimenez Rezende, and Max Welling. Semi-supervised learning with deep generative models. In NIPS, pages 3581–3589, 2014.
- (7) Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In NIPS, pages 2672–2680, 2014.
- (8) Mehdi Mirza and Simon Osindero. Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784, 2014.
- (9) Augustus Odena. Semi-supervised learning with generative adversarial networks. arXiv preprint arXiv:1606.01583, 2016.
- (10) Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training GANs. In NIPS, pages 2234–2242, 2016.
- (11) Zihang Dai, Zhilin Yang, Fan Yang, William W Cohen, and Ruslan R Salakhutdinov. Good semi-supervised learning that requires a bad GAN. In NIPS, pages 6513–6523, 2017.
- (12) Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434, 2015.
- (13) Sanjeev Arora and Yi Zhang. Do GANs actually learn the distribution? An empirical study. arXiv preprint arXiv:1706.08224, 2017.
- (14) Shibani Santurkar, Ludwig Schmidt, and Aleksander Madry. A classification-based perspective on GAN distributions. arXiv preprint arXiv:1711.00970, 2017.
- (15) Tsung-Yu Lin, Aruni RoyChowdhury, and Subhransu Maji. Bilinear CNN models for fine-grained visual recognition. In ICCV, pages 1449–1457, 2015.
- (16) Akshay Mehrotra and Ambedkar Dukkipati. Generative adversarial residual pairwise networks for one shot learning. arXiv preprint arXiv:1703.08033, 2017.
- (17) Lei Shu, Hu Xu, and Bing Liu. Lifelong learning CRF for supervised aspect extraction. arXiv preprint arXiv:1705.00251, 2017.
- (18) Bo Liu, Ying Wei, Yu Zhang, and Qiang Yang. Deep neural networks for high dimension, low sample size data. In IJCAI, pages 2287–2293, 2017.
- (19) Pascal Vincent, Hugo Larochelle, Isabelle Lajoie, Yoshua Bengio, and Pierre-Antoine Manzagol. Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion. Journal of Machine Learning Research, 11(Dec):3371–3408, 2010.
- (20) Dong-Hyun Lee. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In Workshop on Challenges in Representation Learning, ICML, volume 3, page 2, 2013.
- (21) Antti Rasmus, Mathias Berglund, Mikko Honkala, Harri Valpola, and Tapani Raiko. Semi-supervised learning with ladder networks. In NIPS, pages 3546–3554, 2015.
- (22) Raimar Wagner, Markus Thom, Roland Schweiger, Gunther Palm, and Albrecht Rothermel. Learning convolutional neural networks from few samples. In IJCNN, pages 1–7, 2013.
- (23) Alexander J Ratner, Henry Ehrenberg, Zeshan Hussain, Jared Dunnmon, and Christopher Ré. Learning to compose domain-specific transformations for data augmentation. In NIPS, pages 3239–3249, 2017.
- (24) Toan Tran, Trung Pham, Gustavo Carneiro, Lyle Palmer, and Ian Reid. A Bayesian data augmentation approach for learning deep models. In NIPS, pages 2797–2806, 2017.
- (25) Søren Hauberg, Oren Freifeld, Anders Boesen Lindbo Larsen, John Fisher, and Lars Hansen. Dreaming more data: Class-dependent distributions over diffeomorphisms for learned data augmentation. In AISTATS, pages 342–350, 2016.
- (26) Antreas Antoniou, Amos Storkey, and Harrison Edwards. Data augmentation generative adversarial networks. arXiv preprint arXiv:1711.04340, 2017.
- (27) Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A Efros. Unpaired image-to-image translation using cycle-consistent adversarial networks. arXiv preprint arXiv:1703.10593, 2017.
- (28) Zhangyang Wang, Jianchao Yang, Hailin Jin, Eli Shechtman, Aseem Agarwala, Jonathan Brandt, and Thomas S Huang. DeepFont: Identify your font from an image. In Proceedings of the 23rd ACM International Conference on Multimedia, pages 451–459. ACM, 2015.
- (29) Tuan Anh Le, Atılım Güneş Baydin, Robert Zinkov, and Frank Wood. Using synthetic data to train neural networks is model-based reasoning. In IJCNN, pages 3514–3521, 2017.
- (30) Leon Sixt, Benjamin Wild, and Tim Landgraf. RenderGAN: Generating realistic labeled data. arXiv preprint arXiv:1611.01331, 2016.
- (31) Ashish Shrivastava, Tomas Pfister, Oncel Tuzel, Josh Susskind, Wenda Wang, and Russ Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, volume 3, page 6, 2017.
- (32) Xinlong Wang, Zhipeng Man, Mingyu You, and Chunhua Shen. Adversarial generation of training examples: Applications to moving vehicle license plate recognition. arXiv preprint arXiv:1707.03124, 2017.
- (33) Max Jaderberg, Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Synthetic data and artificial neural networks for natural scene text recognition. arXiv preprint arXiv:1406.2227, 2014.
- (34) Li Wan, Matthew Zeiler, Sixin Zhang, Yann Le Cun, and Rob Fergus. Regularization of neural networks using dropconnect. In ICML, pages 1058–1066, 2013.
- (35) Manuel G Calvo and Daniel Lundqvist. Facial expressions of emotion (KDEF): Identification under different display-duration conditions. Behavior Research Methods, 40(1):109–115, 2008.
- (36) Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
- (37) Robert Leeb, Felix Lee, Claudia Keinrath, Reinhold Scherer, Horst Bischof, and Gert Pfurtscheller. Brain–computer communication: Motivation, aim, and impact of exploring a virtual apartment. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 15(4):473–482, 2007.
- (38) Nicolas Brodu, Fabien Lotte, and Anatole Lécuyer. Comparative study of band-power extraction techniques for motor imagery classification. In IEEE Symposium on Computational Intelligence, Cognitive Algorithms, Mind, and Brain (CCMB), 2011.
- (39) Charles W Anderson, Erik A Stolz, and Sanyogita Shamsunder. Multivariate autoregressive models for classification of spontaneous electroencephalographic signals during mental tasks. IEEE Transactions on Biomedical Engineering, 45(3):277–286, 1998.
- (40) Chih-I Hung, Po-Lei Lee, Yu-Te Wu, Li-Fen Chen, Tzu-Chen Yeh, and Jen-Chuen Hsieh. Recognition of motor imagery electroencephalography using independent component analysis and machine classifiers. Annals of Biomedical Engineering, 33(8):1053–1070, 2005.
- (41) Yuanfang Ren and Yan Wu. Convolutional deep belief networks for feature extraction of EEG signal. In IJCNN, 2014.
- (42) Yousef Rezaei Tabar and Ugur Halici. A novel deep learning approach for classification of EEG motor imagery signals. Journal of Neural Engineering, 14(1):016003, 2016.
- (43) Alois Schlögl. Outcome of the BCI-competition 2003 on the Graz data set. Technical report, Graz University of Technology, 2003.
- (44) Kenneth Clark, Bruce Vendt, Kirk Smith, John Freymann, Justin Kirby, Paul Koppel, Stephen Moore, Stanley Phillips, David Maffitt, Michael Pringle, et al. The Cancer Imaging Archive (TCIA): Maintaining and operating a public information repository. Journal of Digital Imaging, 26(6):1045–1057, 2013.
- (45) M Heath, K Bowyer, D Kopans, R Moore, and P Kegelmeyer. The digital database for screening mammography. In Digital Mammography, pages 431–434, 2000.
- (46) R Sawyer Lee, F Gimenez, A Hoogi, and D Rubin. Curated breast imaging subset of DDSM. The Cancer Imaging Archive, 2016.
Appendix A Training Algorithm Details
Because the classifier has to play different roles, the training procedure of DADA is divided into two phases. In training phase I, which we call Generation training, the classifier and the augmenter compete with each other as in the vanilla GAN. The difference is that their competition is conditioned on a specific class rather than on the whole dataset. The augmenter attempts to generate realistic data that fool the classifier within a specific class, while the classifier endeavors to distinguish the fake data from the real data within that class. As in the vanilla GAN analysis, the game reaches its optimum only when the class-conditional distribution of the generated data matches that of the real data for every class. The optimal classifier can then do no better than chance at separating real from generated samples within a class, indicating that the augmenter is trained well enough that the classifier cannot discriminate between them.
Similar to the vanilla GAN formulation, the loss functions of the augmenter and the classifier in training phase I are:
Based on the observation in Improved-GAN that the feature matching technique can help improve the classification performance obtained with generated samples, we modify this training strategy. Because we want not only the features of the training data and the generated data to be as similar as possible, but also the training data and the generated data to belong to the same specific class, the modified version of feature matching is formulated as:
Here, the matched features are the activations on an intermediate layer of the classifier. We keep the class label that conditions the generated data the same as the true data label. With the regularization of feature matching, the objective function of the generator in training phase I is hence:
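The phase I objectives can be sketched in NumPy under explicit assumptions, since the exact equations are not reproduced in this appendix. We assume a classifier with 2k outputs for k classes, where a real sample of class i is assigned target i and a generated sample conditioned on class i is assigned target k + i; a class-wise feature-matching term that compares mean intermediate-layer activations of real and generated samples of the same class (a per-class variant of Improved-GAN feature matching); and a hypothetical weight `lam` that combines the augmenter's adversarial loss with the feature-matching regularizer.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def classifier_loss_phase1(logits_real, y_real, logits_fake, y_fake, k):
    """Assumed 2k-way targets: real sample of class i -> i, generated sample of class i -> k + i."""
    lp_real = log_softmax(logits_real)[np.arange(len(y_real)), y_real]
    lp_fake = log_softmax(logits_fake)[np.arange(len(y_fake)), y_fake + k]
    return -lp_real.mean() - lp_fake.mean()


def augmenter_adv_loss(logits_fake, y_fake):
    """The augmenter wants its class-i samples classified as *real* class i."""
    return -log_softmax(logits_fake)[np.arange(len(y_fake)), y_fake].mean()


def classwise_feature_matching(f_real, y_real, f_fake, y_fake):
    """Average over shared classes c of ||mean f(real | c) - mean f(generated | c)||^2."""
    terms = []
    for c in np.intersect1d(y_real, y_fake):
        diff = f_real[y_real == c].mean(axis=0) - f_fake[y_fake == c].mean(axis=0)
        terms.append(np.sum(diff ** 2))
    return float(np.mean(terms)) if terms else 0.0


def augmenter_objective(logits_fake, y_fake, f_real, y_real, f_fake, lam=1.0):
    """Adversarial term plus feature-matching regularizer; `lam` is a hypothetical weight."""
    fm = classwise_feature_matching(f_real, y_real, f_fake, y_fake)
    return augmenter_adv_loss(logits_fake, y_fake) + lam * fm


# Usage with k = 3 classes and uninformative (all-zero) logits.
k = 3
y_real = np.array([0, 1, 2, 0])
y_fake = y_real.copy()  # generated batch conditioned on the real labels
logits = np.zeros((4, 2 * k))
print(classifier_loss_phase1(logits, y_real, logits, y_fake, k))  # equals 2 * log(6)
```

With all-zero logits every one of the 2k outputs has probability 1/6, so each cross-entropy term equals log 6; this gives a quick sanity check before plugging in real network outputs.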
Once training phase I is finished, and assuming that the generator has captured the class-wise data distribution, we move on to training phase II, called Classification training. In this phase, the generator is fixed and serves only as a data provider. We train only the classifier, on both the generated data and the real training data. The loss function of the classifier in training phase II can be written as:
The entire training procedure is summarized in Algorithm 1, where the two training stages are described as two for-loops.
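The two-phase structure of Algorithm 1 can be outlined as a loop skeleton. The callbacks `sample`, `classifier_step`, and `augmenter_step` are hypothetical placeholders for the actual network updates (the phase-specific losses would live inside `classifier_step`), and the batch size, sampling with replacement, and the classifier-then-augmenter update order within each phase I iteration are assumptions of this sketch.

```python
import numpy as np


def train_dada(real_x, real_y, sample, classifier_step, augmenter_step,
               phase1_iters, phase2_iters, batch_size=64, seed=0):
    """Two-phase schedule mirroring the two for-loops of Algorithm 1 (a sketch).

    Hypothetical callbacks:
      sample(y) -> generated inputs conditioned on labels y
      classifier_step(x, y, fake_x, fake_y, phase) -> one classifier update
      augmenter_step(x, y) -> one augmenter update conditioned on labels y
    """
    rng = np.random.default_rng(seed)
    n = len(real_y)
    # Phase I (Generation training): classifier and augmenter compete class-wise.
    for _ in range(phase1_iters):
        idx = rng.choice(n, size=batch_size)  # with replacement: tiny datasets
        x, y = real_x[idx], real_y[idx]
        classifier_step(x, y, sample(y), y, phase=1)  # generate with the real labels
        augmenter_step(x, y)
    # Phase II (Classification training): augmenter frozen, used only as a data provider.
    for _ in range(phase2_iters):
        idx = rng.choice(n, size=batch_size)
        x, y = real_x[idx], real_y[idx]
        classifier_step(x, y, sample(y), y, phase=2)


# Usage with no-op callbacks that only count updates.
counts = {"phase1": 0, "phase2": 0, "augmenter": 0}


def classifier_step(x, y, fake_x, fake_y, phase):
    counts[f"phase{phase}"] += 1


def augmenter_step(x, y):
    counts["augmenter"] += 1


train_dada(np.zeros((20, 4)), np.arange(20) % 2, lambda y: np.zeros((len(y), 4)),
           classifier_step, augmenter_step, phase1_iters=5, phase2_iters=7, batch_size=8)
print(counts)  # {'phase1': 5, 'phase2': 7, 'augmenter': 5}
```

Conditioning the generated batch on the labels of the real batch mirrors the choice of matching the conditioning label to the true data label; how generated samples are labeled in the phase II loss is left to the callback.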
Appendix B Examples of Generated Samples
For visual inspection, Figure 3 depicts a few real images as well as samples generated by DADA (by interpolating latent codes). In these figures, the CIFAR-10 and SVHN examples correspond to extremely low data regimes (200 samples per class), so we do not expect the generated samples to show large visual variety. Recall that our objective is to augment data for training a better classifier: the generated images faithfully belong to the classes they are conditioned on and present some variation, which fulfills this goal.
Figure 4 displays the diversity of the generated samples as well as some “un-naturalistic” examples found in the SVHN experiments. For example, in the fourth row (class “3”), several examples look like “8”; in the sixth row (class “5”), some examples partly resemble “6” or “8”. These “un-naturalistic” examples appear to lie close to the boundaries between easily confused digit classes.
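Interpolating latent codes while holding the class condition fixed can be sketched as follows, assuming straight-line (linear) interpolation; spherical interpolation is a common alternative. The 100-dimensional uniform latent codes, the eight interpolation steps, and the commented-out `augmenter` call are hypothetical.

```python
import numpy as np


def interpolate_latents(z_start, z_end, steps):
    """Points evenly spaced on the straight line from z_start to z_end; shape (steps, dim)."""
    t = np.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - t) * z_start[None, :] + t * z_end[None, :]


rng = np.random.default_rng(0)
z0, z1 = rng.uniform(-1.0, 1.0, size=(2, 100))  # hypothetical 100-d latent codes
path = interpolate_latents(z0, z1, steps=8)
labels = np.full(len(path), 3)  # fixed class condition, e.g. class "3"
# images = augmenter(path, labels)  # hypothetical conditional augmenter call
print(path.shape, labels.shape)  # (8, 100) (8,)
```

Because only the latent code varies along the path, any change in the rendered images reflects within-class variation learned by the augmenter rather than a change of class.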
Appendix C CIFAR-10 Experiments with Deeper Classifiers
How far can we go with DADA in training even deeper classifiers in extremely low data regimes? To test this, we replace the default classifier in our CIFAR-10 experiment with a ResNet-56 model and train it with DADA. Table 5 compares DADA_augmented with traditional augmentation, including the setting of 500 samples per class. DADA_augmented remains superior, and all results improve over the original numbers in Figure 2 (main text). However, with fewer than 200 samples per class, both methods show severe overfitting and their performance degrades quickly. This reminds us that deeper classifiers may not be the right choice when both data and labels are too limited.
Appendix D CIFAR-100 Experiments
CIFAR-100 is an extended and more difficult version of CIFAR-10, containing 600 images (500 for training and 100 for testing) for each of its 100 classes. Similar to the CIFAR-10 results, DADA boosts classification accuracy, and with small numbers of samples per class, DADA (without traditional augmentation) has a clear advantage over traditional augmentation. Their combination, DADA_augmented, achieves the best results and significantly outperforms the others (by more than 6 percent compared with using only traditional data augmentation); see Table 6.
Appendix E Network architecture
Here we give the detailed configurations and hyperparameters of our models. Table 7 shows the details for CIFAR-10 and SVHN. The configurations for CIFAR-100 and EEG are almost the same as those for CIFAR-10 and SVHN, except for the number of classes. The classifier in the KDEF experiment follows the VGG-16 architecture, except that we append a weight normalization layer after each convolution layer. Table 8 shows the details of the KDEF generator (augmenter).
|T-Conv|5 x 5|2 x 2|256|✓|0.0|ReLU|
|T-Conv|5 x 5|2 x 2|128|✓|0.0|ReLU|
|T-Conv|5 x 5|2 x 2|3|✓|0.0|Tanh|

|Conv|3 x 3|1 x 1|96|✓|0.0|LReLU|
|Conv|3 x 3|1 x 1|96|✓|0.0|LReLU|
|Conv|3 x 3|2 x 2|96|✓|0.5|LReLU|
|Conv|3 x 3|1 x 1|192|✓|0.0|LReLU|
|Conv|3 x 3|1 x 1|192|✓|0.0|LReLU|
|Conv|3 x 3|2 x 2|192|✓|0.5|LReLU|
|Conv|3 x 3|1 x 1|192|✓|0.0|LReLU|

|T-Conv|5 x 5|2 x 2|512|✓|0.0|ReLU|
|T-Conv|5 x 5|2 x 2|256|✓|0.0|ReLU|
|T-Conv|5 x 5|2 x 2|256|✓|0.0|ReLU|
|T-Conv|5 x 5|2 x 2|128|✓|0.0|ReLU|
|T-Conv|5 x 5|2 x 2|3|✓|0.0|Tanh|
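As a quick consistency check on the layer rows above, the sketch below traces spatial sizes through the strided layers. It assumes “same”-style padding (a stride-s convolution divides the side length by s, rounded up; a stride-s transposed convolution multiplies it by s) and our reading that the first three listed rows belong to the CIFAR-10/SVHN augmenter, the next seven to the CIFAR-10/SVHN classifier, and the last five to the KDEF augmenter. Under these assumptions, a 4 × 4 starting map yields the 32 × 32 CIFAR-10 resolution, and the listed KDEF rows upsample by a factor of 2^5 = 32.

```python
import math

# Rows transcribed from the tables above as (operation, stride, output channels).
CIFAR_AUGMENTER = [("T-Conv", 2, 256), ("T-Conv", 2, 128), ("T-Conv", 2, 3)]
CIFAR_CLASSIFIER = [("Conv", 1, 96), ("Conv", 1, 96), ("Conv", 2, 96),
                    ("Conv", 1, 192), ("Conv", 1, 192), ("Conv", 2, 192),
                    ("Conv", 1, 192)]
KDEF_AUGMENTER = [("T-Conv", 2, 512), ("T-Conv", 2, 256), ("T-Conv", 2, 256),
                  ("T-Conv", 2, 128), ("T-Conv", 2, 3)]


def trace_shapes(layers, size):
    """Return (side length, channels) after each layer, assuming 'same' padding."""
    shapes = []
    for op, stride, channels in layers:
        if op == "Conv":
            size = math.ceil(size / stride)  # strided convolution downsamples
        elif op == "T-Conv":
            size = size * stride  # strided transposed convolution upsamples
        else:
            raise ValueError(f"unknown operation: {op!r}")
        shapes.append((size, channels))
    return shapes


print(trace_shapes(CIFAR_AUGMENTER, 4))  # [(8, 256), (16, 128), (32, 3)]
print(trace_shapes(CIFAR_CLASSIFIER, 32)[-1])  # (8, 192)
```

The trace covers only the listed rows; any layers of the original tables not reproduced here (for example, a projection into the first feature map or the final classification head) would be applied before or after these shapes.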