Adversarial Policy Gradient for Deep Learning Image Augmentation
(Funded by the National Institute of Arthritis and Musculoskeletal and Skin Diseases.)
The use of semantic segmentation for masking and cropping input images has proven to be a significant aid in medical imaging classification tasks, decreasing the noise and variance of the training dataset. However, implementing this approach with classical methods is challenging: the cost of obtaining a dense segmentation is high, and the precise input area most crucial to the classification task is difficult to determine a priori. We propose a novel joint-training deep reinforcement learning framework for image augmentation. A segmentation network, weakly supervised with policy gradient optimization, acts as an agent, outputting masks as actions given samples as states, with the goal of maximizing reward signals from the classification network. In this way, the segmentation network learns to mask unimportant imaging features. Our method, Adversarial Policy Gradient Augmentation (APGA), shows promising results on Stanford's MURA dataset and on a hip fracture classification task, with an increase in global accuracy of up to 7.33% and improved performance over baseline methods in 9/10 tasks evaluated. We discuss the broad applicability of our joint training strategy to a variety of medical imaging tasks.
Keywords: Deep Reinforcement Learning · Adversarial Training · Semantic Segmentation · Image Augmentation
1 Introduction

Convolutional neural networks (CNNs) have become an essential part of medical image acquisition, reconstruction, and post-processing pipelines, as this technology significantly improves our ability to detect, study, and predict diseases at scale. In computer vision, CNNs have achieved above-human performance in natural-object classification tasks. However, in medical imaging, where datasets are limited in size and labels are often uncertain, there is still a significant need for methods that maximize information gain while preventing overfitting. Our work presents a novel reinforcement-learning (RL) based image augmentation framework for medical image classification.
Training-data image augmentation imposes image-level regularization on CNNs in order to combat overfitting. Augmentation can include the addition of noise, image transformations such as zooming or cropping, image occlusion, and attention masking. The applicability of the first three methods is limited, as they often rely on domain knowledge to define the appropriate characteristics and severity of the augmentation. The last method requires dense segmentation masks for the region of interest (ROI). However, ROIs relevant to a classification task may not be known a priori. For instance, when inspecting hip radiographs for bone fracture (Fx), the determination of Fx or no-Fx depends heavily on the location of abnormal signal within the bone and on abnormalities in nearby tissues. Our RL image augmentation framework leverages an adversarial reward to weakly supervise a segmentation network and create these ROIs.
Through trial and error, reinforcement learning algorithms discover how to maximize their objective, or reward $R$. The careful design of reward functions has enabled the application of RL to multiple medical tasks, including landmark detection, automatic view planning, treatment planning, and MR/CT reconstruction. In this work, we present APGA, a joint reinforcement-learning strategy for training a segmentation network and a classification network that improves accuracy on the classification task.
2 Methods

2.1 Improving Classification with Segmentation for Image Masking
The framework has two parts: a classification model with parameters $\theta_c$ and a segmentation model (policy) $\pi$ with parameters $\theta_s$.
To mask out the image-level features that are less useful for the classification task, we use the segmentation model to produce the pixel-wise probability $p_i$ of pixel $i$ being useful. We zero out the pixels of the original image with $p_i < 0.5$, and use the masked image for training the classification model, updating $\theta_c$. With the end goal of improving classification performance, our method optimizes the segmentation model to evaluate the importance of each image pixel.
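As a concrete illustration, the masking step can be sketched as follows. This is a minimal NumPy version, where `usefulness` stands in for the segmentation model's pixel-wise output; the function name and the toy arrays are illustrative, not the paper's code:

```python
import numpy as np

def mask_unimportant(image, usefulness, threshold=0.5):
    """Zero out pixels whose predicted usefulness is below `threshold`.

    `usefulness` holds the segmentation policy's pixel-wise probability
    that a pixel helps the classification task (same shape as `image`).
    """
    keep = (usefulness >= threshold).astype(image.dtype)
    return image * keep

# Toy example: a 2x2 "image" where only the top row is predicted useful.
img = np.array([[1.0, 2.0], [3.0, 4.0]])
p = np.array([[0.9, 0.8], [0.1, 0.2]])
masked = mask_unimportant(img, p)  # bottom row is zeroed out
```

The masked image, not the original, is what the classification model trains on in this scheme.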
2.2 Policy Gradient Training
Following the policy gradient formulation, our segmentation model is seen as a policy $\pi_{\theta_s}$ in which the image batch is treated as a state $s_t$, whereas the pixel-wise classification of useful image features is framed as an action $a_t$. In practice, at each $t$-th training step, the classification model receives the masked image as input and outputs a reward signal $R_t$. Our objective is to maximize the expected reward and find the optimal segmentation policy:

$$J(\theta_s) = \mathbb{E}_{a \sim \pi(a \mid s;\, \theta_s)}\left[ R \right]. \tag{1}$$
In eq. 1, $J(\theta_s)$ is the expected reward with respect to the probability of taking an action $a$ when the model is parameterized by $\theta_s$. The policy is learned through back-propagation, which requires the gradient of the expected reward with respect to the model parameters. Following the REINFORCE rule, the gradient can be defined as

$$\nabla_{\theta_s} J(\theta_s) = \mathbb{E}\left[ R \, \nabla_{\theta_s} \log \pi(a \mid s;\, \theta_s) \right]. \tag{2}$$
The expected reward cannot be estimated exactly and requires approximation. As is common practice in the treatment of policy gradients, we achieve this approximation using the negative log-likelihood loss, which is differentiable with respect to the model parameters and can be weighted by the reward signal to obtain the segmentation policy loss presented in eq. 3,

$$L_{\pi}(\theta_s) = R_t \cdot L_{bce}\big(a_t, \pi(s_t;\, \theta_s)\big), \tag{3}$$
where $L_{bce}$ is the binary cross-entropy loss

$$L_{bce}(y, \hat{y}) = -\big(y \log \hat{y} + (1-y)\log(1-\hat{y})\big), \tag{4}$$
which, applied pixel-wise over the $N$ pixels of the image, becomes eq. 5:

$$L_{\pi}(\theta_s) = R_t \cdot \frac{1}{N}\sum_{i=1}^{N} L_{bce}\big(a_{t,i}, \pi(s_t;\, \theta_s)_i\big). \tag{5}$$
By using pixel-wise binary cross-entropy, we preserve the spatial information of the deviation of $\pi(s_t;\,\theta_s)$ from $a_t$. We then update $\theta_s$ by computing $\nabla_{\theta_s} L_{\pi}$. The classification model parameters $\theta_c$ are in turn updated by gradient descent using the cross-entropy loss between the classification of the masked image samples and the original target labels. In our experiments, we perform a stochastic gradient update of both $\theta_s$ and $\theta_c$ at each batch step.
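The reward-weighted policy loss of eq. 5 can be sketched as below. This is a minimal NumPy version with illustrative names (`bce`, `policy_loss`) that are not taken from the paper's code; in practice the same computation would run inside an autodiff framework so the gradient with respect to $\theta_s$ comes for free:

```python
import numpy as np

def bce(target, prob, eps=1e-7):
    """Element-wise binary cross-entropy, clipped for numerical safety."""
    prob = np.clip(prob, eps, 1 - eps)
    return -(target * np.log(prob) + (1 - target) * np.log(1 - prob))

def policy_loss(action, prob, reward):
    """Reward-weighted mean pixel-wise BCE between the sampled mask a_t
    and the policy output pi(s_t; theta_s) -- the segmentation policy loss."""
    return reward * bce(action, prob).mean()

# Two "pixels": one marked useful, one not, each predicted at 0.5.
loss = policy_loss(np.array([1.0, 0.0]), np.array([0.5, 0.5]), reward=2.0)
```

A positive reward pushes the policy toward the sampled mask; a negative reward pushes it away, which is the standard REINFORCE behavior.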
2.3 Adversarial Reward
The design of the reward is crucial to the convergence of the segmentation model. Using the change in training loss as a reward, as is done in Neural Architecture Search, results in a weak reward signal hardly discernible from the expected changes in loss during training. Similarly, approximating rewards with a critic network introduces unnecessary overhead and slows down convergence. We propose a stable adversarial reward $R_t$. Given the pixel-wise feature importance probabilities $p$, we zero out the pixels with $p_i \geq 0.5$ to mask out the features predicted to be of high importance. The original and masked image batches are then fed as inputs to the classification model, producing the losses $L_t$ and $L_t^{adv}$. The reward function is defined as:

$$R_t = L_t^{adv} - L_t. \tag{6}$$
To reduce the variance of training, a baseline $b$, the exponential moving average of the reward, is subtracted from $R_t$, as is done in prior work. Intuitively, by erasing the important features we invert the problem and tend to maximize the gain in loss. However, we do not want the segmentation policy to erase all pixels in favor of a gain in $L_t^{adv}$, so we penalize the masking of all pixels. Given the all-zero pixel-wise feature importance map $\mathbf{0}$ and a regularization weight $\lambda$, the final loss is defined as:

$$L(\theta_s) = (R_t - b) \cdot L_{bce}\big(a_t, \pi(s_t;\, \theta_s)\big) + \lambda \, L_{bce}\big(\mathbf{0}, \pi(s_t;\, \theta_s)\big). \tag{7}$$
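The adversarial reward with its moving-average baseline might be implemented as follows. This is a sketch with assumed names; in particular, whether the baseline is updated before or after being subtracted is our assumption, not stated in the text:

```python
class AdversarialReward:
    """R_t = L_adv - L, variance-reduced by an exponential-moving-average
    baseline b (assumption: b is subtracted first, then updated)."""

    def __init__(self, decay=0.5):
        self.decay = decay
        self.baseline = 0.0

    def __call__(self, loss_original, loss_masked):
        raw = loss_masked - loss_original  # gain in loss from erasing important pixels
        adjusted = raw - self.baseline     # subtract baseline b
        self.baseline = (self.decay * self.baseline
                         + (1 - self.decay) * raw)  # EMA update
        return adjusted

reward_fn = AdversarialReward(decay=0.5)
r1 = reward_fn(1.0, 3.0)  # first reward: 2.0 - 0.0
r2 = reward_fn(1.0, 3.0)  # same gain, now offset by the updated baseline
```

The decay of 0.5 matches the value reported in the experiments section.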
The resulting reward signal is strongly related to mask quality, rather than reflecting the stochasticity in training of the classification network.
3 Experiments and Results
We evaluate our methodology on MURA and an internal hip fracture dataset, using the same experimental setup, including network architectures, RL framework, and training hyperparameters. A DenseNet-169 pretrained on ImageNet serves as the base classification model. A TernausNet, pretrained on the Carvana Image Masking Challenge, serves as the segmentation model. Masked images are used as augmentation in a 1:1 ratio with original images to train the classification network. Images are resized to a fixed resolution, and the batch size is 25. Adam optimizers with an initial learning rate of 0.0001 are used for the classification and segmentation models. The exponential-moving-average baseline has a decay rate of 0.5, and $\lambda$ is set to 0.1. Training of APGA converges within 30 minutes on a single Nvidia TitanX GPU. Source code is available at https://github.com/victorychain/Adversarial-Policy-Gradient-Augmentation.
We benchmark APGA against a DenseNet-169 classifier trained (1) without data augmentation, (2) with cutout augmentation in a 1:1 ratio, and (3) with Grad-CAM-derived mask augmentation, also in a 1:1 ratio. Cutout augmentation masks out randomly sized patches of the input image, while Grad-CAM masks are produced by discretizing the probability saliency map of the DenseNet trained without data augmentation. For further comparison, a segmentation and a classification network are trained end-to-end by propagating the gradient from the classification loss through the segmentation network and applying the discretized masks from the segmentation network in the same update step. Additionally, two weighted regularization terms are added to the loss function to prevent all-or-none masking behavior. However, end-to-end training was unstable, and the segmentation network produced all-one or all-zero masks despite tuning of the regularization weights. Therefore, these results were omitted. At its best, the end-to-end network produced all-one masks and performed the same as the DenseNet trained without augmentation.
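For reference, the cutout baseline can be sketched as below. This is a simplified single-patch version with an assumed maximum patch size; the original cutout paper's patch count and size schedule are hyperparameters not reproduced here:

```python
import numpy as np

def cutout(image, rng, max_frac=0.5):
    """Zero a randomly sized, randomly placed square patch of `image`.

    `max_frac` (an assumed knob) caps the patch side length as a fraction
    of the shorter image side.
    """
    h, w = image.shape[:2]
    size = int(rng.integers(1, int(max_frac * min(h, w)) + 1))
    y = int(rng.integers(0, h - size + 1))
    x = int(rng.integers(0, w - size + 1))
    out = image.copy()
    out[y:y + size, x:x + size] = 0
    return out

augmented = cutout(np.ones((8, 8)), np.random.default_rng(0))
```

Unlike APGA's learned masks, the occluded region here is chosen uniformly at random, with no feedback from the classifier.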
3.1 Binary Classification: MURA
The MURA dataset contains 14,863 musculoskeletal studies of the elbow, finger, forearm, hand, humerus, shoulder, and wrist, comprising 9,045 normal and 5,818 abnormal labeled cases. We train the methods on the public training set and evaluate on the validation set, with global accuracy as the metric. We train and evaluate separate models on each body part, and additionally train a single model on a random sample of 100 training images per class to test the performance of our method under extreme data constraints. The performance on the validation set is presented in Tables 1 and 2 as the average and standard deviation over 5 random seeds.
3.2 Multi-class Classification: Hip Fracture
The Hip Fracture dataset contains 1,118 studies, with an average patient age of 74.6 years (standard deviation 17.3). Each study includes a pelvic radiograph, labeled as 1 of 6 classes: No fracture, Intertrochanteric fracture, Displaced femoral neck fracture, Non-displaced femoral neck fracture, Arthroplasty, or ORIF (previous internal fixation). Bounding boxes are manually drawn on each study, resulting in 3,034 bounded hips. The images are split by accession number into training, validation, and test sets, ensuring no overlap in patients between any of the sets. We train and evaluate separate models on the whole pelvic radiographs and on the bounded hip radiographs. Per-image accuracy is used as the metric. The performance on the validation and test sets is shown in Table 3.
Table 3. Per-image accuracy: Whole Pelvis (val), Whole Pelvis (test), Bounded Hip (val), Bounded Hip (test).
Compared to the baseline, our method achieved higher global accuracy in 9 out of 10 tasks, including binary (MURA, Table 2) and multi-class (hip Fx, Table 3) classification tasks. On average, our method improved MURA validation accuracy by 1.56%, and hip validation and testing accuracy by 0.78% and 1.72%, respectively. The most significant improvement in accuracy over the baseline was 7.33%, achieved in a data-constrained condition reported in Table 2, in which the elbow training data was limited to 100 samples per class. Overall, APGA outperformed baseline methods in 9 out of 10 tasks and consistently provided higher testing results. Example segmentation masks from the weakly supervised network are shown in Fig. 2. APGA learns to ignore unimportant features in the radiographs, such as anatomy irrelevant to the classification task. APGA masking appears more exploratory in nature compared to saliency-based attention masking (DenseNet + Grad-CAM), which carries biases from the converged model.
4 Discussions and Conclusions
We propose a framework, APGA, for producing segmentations to aid medical image classification in a reinforcement learning setting. This framework requires no manual segmentation, which has the benefit of scalability and generalizability. The system is trained online with the goal of improving performance on the main task, classification. If no improvement is seen, this serves as a check of the assumption that masking-based augmentation would aid classification, before more manual work is pursued. Marginal improvements are evidence that APGA can add valuable information to the training process. The computational overhead in training is justified by these added benefits and can be eliminated during inference, though the segmentation network can also be used as an inference-time augmentation technique. This general reinforcement learning framework with an adversarial reward could easily be adopted for other medical imaging tasks, including regression and segmentation, with different aiding methods such as bounding-box detection, image distortion, and image generation. Reinforcement-guided data augmentation also generalizes better than traditional data augmentation based on domain knowledge.
References

- Kaggle: Carvana Image Masking Challenge (2017).
- Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., Fei-Fei, L.: ImageNet: A Large-Scale Hierarchical Image Database. In: CVPR (2009).
- DeVries, T., Taylor, G.W.: Improved Regularization of Convolutional Neural Networks with Cutout. arXiv:1708.04552 (2017).
- Huang, G., Liu, Z., van der Maaten, L., Weinberger, K.Q.: Densely Connected Convolutional Networks. In: CVPR (2017). arXiv:1608.06993.
- Iglovikov, V., Shvets, A.: TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation. arXiv:1801.05746 (2018).
- Kingma, D.P., Ba, J.: Adam: A Method for Stochastic Optimization. In: ICLR (2015). arXiv:1412.6980.
- Rajpurkar, P., et al.: MURA: Large Dataset for Abnormality Detection in Musculoskeletal Radiographs. In: MIDL (2018). arXiv:1712.06957.
- Russakovsky, O., et al.: ImageNet Large Scale Visual Recognition Challenge. arXiv:1409.0575 (2014).
- Selvaraju, R.R., et al.: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. arXiv:1610.02391 (2016).
- Wallenberg, M., Forssén, P.-E.: Attentional masking for pre-trained deep networks. In: IROS, pp. 6149–6154 (2017).
- Williams, R.J.: Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8(3), 229–256 (1992).
- Zoph, B., Le, Q.V.: Neural Architecture Search with Reinforcement Learning. arXiv:1611.01578 (2016).