Multi-Task Attention-Based Semi-Supervised Learning for Medical Image Segmentation
We propose a novel semi-supervised image segmentation method that simultaneously optimizes a supervised segmentation objective and an unsupervised reconstruction objective. The reconstruction objective uses an attention mechanism that separates the reconstruction of image areas corresponding to different classes. The proposed approach was evaluated on two applications: brain tumor segmentation and white matter hyperintensity segmentation. Our method, trained on unlabeled images and a small number of labeled images, outperformed supervised CNNs trained with the same number of labeled images and CNNs pre-trained on unlabeled data. In ablation experiments, we observed that the proposed attention mechanism substantially improves segmentation performance. We explore two multi-task training strategies: joint training and alternating training. Alternating training requires fewer hyperparameters and achieves better and more stable performance than joint training. Finally, we analyze the features learned by different methods and find that the attention mechanism helps to learn more discriminative features in the deeper layers of the encoder.
Keywords: semi-supervised learning, multi-task learning, attention, deep learning, segmentation, brain tumor, white matter hyperintensities.
Semi-supervised learning (SSL) uses unlabeled data to improve the generalization performance of a supervised model. This can be useful in medical image segmentation, where manual annotations can be expensive and tedious to produce and are often only available for a small subset of the training data.
One approach to semi-supervised learning is multi-task learning, in which the network is trained on an auxiliary objective that requires no manually labeled data, in addition to the target objective that uses labeled data. This can be done by adding an autoencoder objective, as has been done for image classification (e.g., [3, 8]). Sedai et al. [11] were the first to introduce a variational autoencoder into semi-supervised segmentation: they train a segmentation autoencoder that learns the embeddings encoded by a separately pre-trained reconstruction autoencoder and reconstructs the segmentation mask from them. However, multi-task learning with image reconstruction is not trivial to combine with popular segmentation architectures such as U-Net [9] and its variants [2, 6], which use skip-connections to preserve high-resolution information from their early encoder layers. These skip-connections are unsuitable when an autoencoder is the auxiliary task, because they allow the network to copy information from early layers and bypass the dimensionality reduction of the autoencoder.
Another semi-supervised approach is to create pseudo labels for the unlabeled training data, as in self-training [13] and co-training [12, 15], thereby enlarging the pool of training data. However, these pseudo labels are usually of lower quality than the ground truth for the target segmentation objective, which limits how much such methods can gain from unlabeled data.
We propose a novel semi-supervised method called Multi-task Attention-based Semi-Supervised Learning (MASSL), in which we combine an autoencoder with a U-Net-like network. Instead of training the autoencoder to reconstruct the original input, we train it to reconstruct synthetic segmentation labels created by the attention mechanism. This encourages our model to learn discriminative features for segmentation from unlabeled images. Although attention is often applied in supervised learning (e.g., [10]), to the best of our knowledge it has never been combined with semi-supervised learning. Our method has some similarities with self-training [13] and co-training [12, 15], which also create new labels for the unlabeled training data on-the-fly. In contrast to these methods, our method creates labels for the reconstruction task rather than for the segmentation task. This guides the unsupervised auxiliary task to learn a more discriminative latent representation from unlabeled data than a traditional reconstruction network, which does not consider class differences.
Our contributions are summarized as follows. Firstly, we propose a novel multi-task semi-supervised learning method and study its performance in combination with two training strategies. Secondly, we evaluate our method on two segmentation problems (brain tumors and white matter hyperintensities), demonstrating that it outperforms a fully-supervised CNN baseline, two pre-training approaches, and multi-task learning without the proposed attention mechanism. Thirdly, we investigate how the attention mechanism affects the features learned by the encoder and show that it helps the deeper layers to learn more discriminative features.
Our semi-supervised learning method is shown in Fig. 1. It consists of a segmentation network and a reconstruction network that share the same encoder, and an attention mechanism that connects the two tasks.
2.1 Architecture and Loss Functions
Similarly to U-Net [9], the segmentation CNN has skip-connections that transfer fine details from the shallower encoder layers to its decoder, and it is trained with a Dice loss on labeled images only (see Fig. 1). The reconstruction network has a decoder without skip-connections, forming an autoencoder, and is trained with the mean squared error (MSE) on both labeled and unlabeled images.
In the baseline version of our method, the output of the reconstruction network is optimized to predict the input image. We call this method Multi-task SSL (MSSL) in the remainder of the paper.
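The two per-task losses of the MSSL baseline can be sketched as follows. This is a minimal illustration, not the authors' code: it assumes flattened images stored as plain Python lists, uses one common soft Dice formulation (2·Σpt / (Σp + Σt)), and the helper names `soft_dice_loss` and `mse` are hypothetical.

```python
# Minimal sketch of the MSSL losses on flattened images (plain Python lists).
# soft_dice_loss and mse are hypothetical helper names; the soft Dice variant
# (no squared terms in the denominator) is an assumption.

def soft_dice_loss(pred, target, eps=1e-6):
    """1 - soft Dice between sigmoid probabilities `pred` and a binary `target`."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)


def mse(pred, target):
    """Mean squared error averaged over all voxels."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


image = [0.2, 0.8, 0.9, 0.1]
seg_prob = [0.1, 0.9, 0.8, 0.0]   # output of the segmentation decoder
label = [0, 1, 1, 0]              # manual annotation (labeled images only)
recon = [0.25, 0.7, 0.85, 0.1]    # output of the reconstruction decoder

loss_seg = soft_dice_loss(seg_prob, label)  # computed on labeled images only
loss_rec = mse(recon, image)                # MSSL: the target is the input itself
```

The small `eps` keeps the Dice ratio defined when both the prediction and the mask are empty.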
In the attention-based version of our method, which we call Multi-task Attention-based SSL (MASSL), we separately reconstruct the background and foreground parts of the image, as defined by the soft predictions of the segmentation network. The foreground and background objectives are weighted by the size of the respective segmentation masks:

$$\mathcal{L}_{rec} = \frac{\sum S_b}{N}\,\mathrm{MSE}\left(R_b,\; S_b \odot X\right) + \frac{\sum S_f}{N}\,\mathrm{MSE}\left(R_f,\; S_f \odot X\right) \qquad (1)$$

where $R_b, R_f$ and $S_b, S_f$ are the predictions of the reconstruction and segmentation paths, respectively, for the background ($b$) and foreground ($f$); $N$ is the number of voxels in input image $X$; and $\odot$ is the element-wise product. Note that the gradient does not propagate through $S_b$ and $S_f$ to the segmentation decoder. We hypothesize that infusing the reconstruction targets with segmentation predictions will lead to better features in the deeper layers of the encoder and hence to better segmentation. The objective terms are weighted to prevent over-emphasizing the importance of foreground reconstruction.
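A sketch of the attention-based reconstruction loss, under stated assumptions: a binary task with the background map taken as 1 minus the soft foreground map, each MSE term averaged over all voxels, and each term weighted by the fraction of voxels in the corresponding soft mask. Plain Python has no autograd, so the stop-gradient on the segmentation prediction is only marked by a comment; `attention_reconstruction_loss` is a hypothetical name.

```python
# Sketch of the MASSL reconstruction loss on flattened images.
# Assumptions: binary task (S_b = 1 - S_f), MSE averaged over all N voxels,
# terms weighted by the relative size of each soft mask.

def mse(pred, target):
    """Mean squared error over all voxels."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def attention_reconstruction_loss(x, s_f, r_f, r_b):
    """Size-weighted MSE of R_f vs S_f*X and of R_b vs S_b*X."""
    n = len(x)
    s_b = [1.0 - s for s in s_f]                  # background attention map
    target_f = [s * v for s, v in zip(s_f, x)]    # foreground part S_f ⊙ X
    target_b = [s * v for s, v in zip(s_b, x)]    # background part S_b ⊙ X
    # In an autograd framework, s_f would be detached (stop-gradient) here so
    # that this loss does not update the segmentation decoder.
    w_f = sum(s_f) / n                            # fraction of foreground voxels
    w_b = sum(s_b) / n                            # fraction of background voxels
    return w_b * mse(r_b, target_b) + w_f * mse(r_f, target_f)


x = [0.5, 1.0, 0.2, 0.4]
s_f = [0.0, 1.0, 1.0, 0.0]                        # soft foreground prediction
perfect = attention_reconstruction_loss(x, s_f, [0.0, 1.0, 0.2, 0.0], [0.5, 0.0, 0.0, 0.4])
blank_fg = attention_reconstruction_loss(x, s_f, [0.0] * 4, [0.5, 0.0, 0.0, 0.4])
```

With a perfect reconstruction of both parts the loss is zero; leaving the foreground blank is penalized in proportion to the foreground's share of the image.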
2.2 Training Strategy
The two tasks of the MSSL and MASSL networks can be optimized jointly or alternatingly:
Joint training: Given a minibatch containing an equal number of labeled and unlabeled samples, the unlabeled samples are first segmented using the most recent segmentation network parameters to create the foreground and background images for the reconstruction task. Then, the weights of the entire network are updated by optimizing the objective functions of both the segmentation and reconstruction tasks. The loss is a linear combination of the segmentation and reconstruction losses, controlled by the hyperparameter $\lambda$:

$$\mathcal{L} = \mathcal{L}_{seg} + \lambda\,\mathcal{L}_{rec} \qquad (2)$$
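One joint-training loss evaluation can be sketched as below. Assumptions not stated in the source: the combined loss is L_seg + lam · L_rec, the reconstruction loss is averaged over every image in the minibatch, and the hypothetical callables `seg_fn` (image to soft foreground map) and `rec_fn` (image to foreground/background reconstruction pair) stand in for the two network paths.

```python
# Sketch of the joint-training objective for one MASSL minibatch.
# lam, seg_fn, rec_fn and joint_loss are hypothetical names.

def soft_dice_loss(pred, target, eps=1e-6):
    inter = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * inter + eps) / (sum(pred) + sum(target) + eps)


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def attention_rec_loss(x, s_f, r_f, r_b):
    n = len(x)
    s_b = [1.0 - s for s in s_f]
    loss_f = mse(r_f, [s * v for s, v in zip(s_f, x)])
    loss_b = mse(r_b, [s * v for s, v in zip(s_b, x)])
    return sum(s_b) / n * loss_b + sum(s_f) / n * loss_f


def joint_loss(seg_fn, rec_fn, labeled, unlabeled, lam):
    """labeled: list of (image, mask) pairs; unlabeled: list of images of equal length."""
    if len(labeled) != len(unlabeled):
        raise ValueError("a joint minibatch holds equal numbers of labeled and unlabeled samples")
    loss_seg = sum(soft_dice_loss(seg_fn(x), y) for x, y in labeled) / len(labeled)
    images = [x for x, _ in labeled] + list(unlabeled)
    loss_rec = 0.0
    for x in images:
        # Segment with the most recent parameters; in an autograd framework this
        # prediction would be detached so L_rec does not reach the segmentation decoder.
        s_f = seg_fn(x)
        r_f, r_b = rec_fn(x)
        loss_rec += attention_rec_loss(x, s_f, r_f, r_b)
    loss_rec /= len(images)
    # A single optimizer then minimizes this scalar w.r.t. all network weights.
    return loss_seg + lam * loss_rec


def threshold_seg(x):          # toy segmentation path
    return [1.0 if v > 0.5 else 0.0 for v in x]


def zero_rec(x):               # toy reconstruction path that outputs blank images
    return [0.0] * len(x), [0.0] * len(x)


labeled = [([0.9, 0.2, 0.7], [1, 0, 1])]
unlabeled = [[0.1, 0.8, 0.3]]
result = joint_loss(threshold_seg, zero_rec, labeled, unlabeled, lam=0.5)
```

Because `lam` scales only the reconstruction term, a poorly chosen value can let one task dominate the shared encoder, which is the tuning burden that alternating training avoids.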
Alternating training: In each epoch, the same number of labeled and unlabeled images (the size of the smaller of the two sets) is randomly sampled from the respective training sets. A minibatch contains either only labeled samples or the same number of unlabeled samples, and the two batch types alternate during training. The weights of the segmentation path and of the reconstruction path are updated separately, according to the batch type and the corresponding loss:

$$\mathcal{L} = \begin{cases} \mathcal{L}_{seg}, & \text{labeled minibatch (encoder and segmentation decoder)},\\ \mathcal{L}_{rec}, & \text{unlabeled minibatch (encoder and reconstruction decoder)}. \end{cases} \qquad (3)$$
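The per-epoch sampling and batch alternation can be sketched as a schedule generator. This is an illustrative sketch with a hypothetical function name and batch size; it only decides which batch comes next, and the comments indicate which optimizer would take a step. The learning rates are the initial values reported in the network description.

```python
import random

# Initial learning rates reported for the two Adam optimizers.
INITIAL_LR = {"labeled": 0.01, "unlabeled": 0.001}


def alternating_schedule(labeled_ids, unlabeled_ids, batch_size, rng):
    """Return one epoch of (batch_type, ids) pairs alternating between batch types."""
    n = min(len(labeled_ids), len(unlabeled_ids))  # same number drawn from each set
    lab = rng.sample(labeled_ids, n)
    unl = rng.sample(unlabeled_ids, n)
    schedule = []
    for start in range(0, n, batch_size):
        schedule.append(("labeled", lab[start:start + batch_size]))
        schedule.append(("unlabeled", unl[start:start + batch_size]))
    return schedule


rng = random.Random(0)
schedule = alternating_schedule(list(range(20)), list(range(100, 200)), batch_size=4, rng=rng)
for batch_type, ids in schedule[:2]:
    # labeled batch   -> step the segmentation optimizer on L_seg
    # unlabeled batch -> step the reconstruction optimizer on L_rec
    print(batch_type, INITIAL_LR[batch_type], ids)
```

With 20 labeled and 100 unlabeled images, each epoch visits all labeled images but only a random fifth of the unlabeled ones; different unlabeled images are drawn in later epochs.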
BraTS18 [1, 5]: 220 MRI scans from patients with high-grade glioma are randomly split into 120, 50, and 50 scans for training, validation, and testing, respectively, and this random split is repeated five times (5-fold Monte Carlo cross-validation). To simplify comparison between the segmentation tasks, we perform binary segmentation of the whole tumor, which includes all four tumor structures, and use only the FLAIR sequence.
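Monte Carlo cross-validation, as used here, can be sketched as repeated independent random splits. The case identifiers and function name are hypothetical, and the labeled/unlabeled subdivision of the training set (20 labeled, 100 unlabeled) is taken from the first configuration in the BraTS18 results table header.

```python
import random


def monte_carlo_splits(case_ids, n_train, n_val, n_test, n_repeats=5, seed=0):
    """Repeated random train/validation/test splits (Monte Carlo cross-validation)."""
    if n_train + n_val + n_test != len(case_ids):
        raise ValueError("split sizes must add up to the number of cases")
    rng = random.Random(seed)
    splits = []
    for _ in range(n_repeats):
        shuffled = list(case_ids)
        rng.shuffle(shuffled)
        splits.append((shuffled[:n_train],
                       shuffled[n_train:n_train + n_val],
                       shuffled[n_train + n_val:]))
    return splits


case_ids = [f"hgg_{i:03d}" for i in range(220)]   # hypothetical case identifiers
splits = monte_carlo_splits(case_ids, 120, 50, 50)
train, val, test = splits[0]
# Semi-supervised setting: 20 labeled training scans, the other 100 used as unlabeled.
labeled, unlabeled = train[:20], train[20:]
```

Unlike k-fold cross-validation, the test sets of different repeats may overlap; each repeat is simply an independent random partition.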
WMH17: The dataset provides 60 FLAIR MRI scans with corresponding manual segmentations of white matter hyperintensities (WMH), acquired at three sites (20 scans per site). In our experiments, we use 30 scans for training, 10 for validation, and 20 for testing, ensuring approximately equal numbers from each site in each of the three sets. We use 5-fold Monte Carlo cross-validation.
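One way to keep the per-site counts approximately equal is to shuffle each site's scans and interleave the sites round-robin before slicing. This is a sketch of one possible procedure, not necessarily the one used in the paper; the site names are hypothetical and the code assumes every site contributes the same number of scans, as in WMH17.

```python
import random


def site_stratified_splits(scans_by_site, sizes=(30, 10, 20), n_repeats=5, seed=0):
    """Monte Carlo splits with approximately equal per-site counts in each subset.

    Assumes every site has the same number of scans (20 per site in WMH17).
    """
    rng = random.Random(seed)
    n_train, n_val, n_test = sizes
    splits = []
    for _ in range(n_repeats):
        per_site = []
        for site in sorted(scans_by_site):
            scans = list(scans_by_site[site])
            rng.shuffle(scans)
            per_site.append(scans)
        # Round-robin interleaving: any consecutive slice then contains either
        # floor or ceil of (slice length / number of sites) scans from each site.
        order = [scan for group in zip(*per_site) for scan in group]
        splits.append((order[:n_train],
                       order[n_train:n_train + n_val],
                       order[n_train + n_val:n_train + n_val + n_test]))
    return splits


scans_by_site = {s: [f"{s}_{i:02d}" for i in range(20)] for s in ("site_a", "site_b", "site_c")}
train, val, test = site_stratified_splits(scans_by_site)[0]
```

With 30/10/20 scans this yields 10 training scans per site, 3 or 4 validation scans per site, and 6 or 7 test scans per site.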
3.0.2 Network and hyperparameters
The network layout is shown in Fig. 1. Our network is inspired by the U-Net [9] architecture but differs from it in several respects. The input size of the network is . There are 5 resolution levels in the encoder and in each of the two decoders. Each level consists of two convolution layers with zero-padding, instance normalization [14] and LeakyReLU activations, except for the last layer of both decoders, which uses a sigmoid to produce the final prediction. Consecutive levels are connected by an average pooling layer (encoder) or an upsampling layer (decoders). The number of feature channels is 16 in the first level and is doubled after each pooling and halved after each upsampling, reaching a maximum of 256 at the deepest level. In the segmentation decoder, the upsampled feature maps are concatenated with earlier encoder feature maps through skip-connections. The reconstruction network has the same architecture as the segmentation network but without skip-connections. For joint training, we use one Adam optimizer to optimize the loss in Eq. 2. For alternating training, we use two separate Adam optimizers to optimize the two losses in Eq. 3. Based on the performance on the validation sets, we set the initial learning rate to 0.01 for the segmentation task and 0.001 for the reconstruction task. Random rotation, scaling, and horizontal flipping are applied as data augmentation.
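The channel bookkeeping implied by this description can be sketched as follows, to show how skip-connections change the decoder input widths. The assumption (not stated in the source) is that the upsampling layer keeps the channel count of the deeper level and the first convolution of each decoder level reduces it; implementations that halve the channels during upsampling would get different numbers. Function names are hypothetical.

```python
# Hypothetical helpers illustrating channel widths in the two decoders.

def channel_plan(base=16, levels=5, max_channels=256):
    """Feature channels per resolution level: doubled after each pooling, capped."""
    return [min(base * 2 ** level, max_channels) for level in range(levels)]


def decoder_conv_inputs(plan, skip_connections):
    """Channels entering the first convolution of each decoder level (deep to shallow)."""
    inputs = []
    for level in range(len(plan) - 2, -1, -1):
        upsampled = plan[level + 1]                       # features from the deeper level
        skipped = plan[level] if skip_connections else 0  # concatenated encoder features
        inputs.append(upsampled + skipped)
    return inputs


# Initial learning rates reported above for the two Adam optimizers.
INITIAL_LR = {"segmentation": 0.01, "reconstruction": 0.001}

plan = channel_plan()                                            # [16, 32, 64, 128, 256]
seg_inputs = decoder_conv_inputs(plan, skip_connections=True)    # [384, 192, 96, 48]
rec_inputs = decoder_conv_inputs(plan, skip_connections=False)   # [256, 128, 64, 32]
```

Without skip-connections, every decoder level of the reconstruction path sees only features that passed through the 256-channel bottleneck, which is what keeps it a genuine autoencoder.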
3.0.3 Feature analysis
We use linear regression analysis to evaluate how well the features in the last layer of each encoder level discriminate between foreground and background regions. We treat each voxel as an individual sample and use its values in the feature maps as the regression variables. The target for each voxel is obtained by taking the binary segmentation ground truth and down-sampling it with average pooling to the resolution of that level.
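The analysis can be sketched in a few lines: average-pool the binary mask to the level's resolution, treat each coarse voxel's channel values as regressors, and score the ordinary least-squares fit. The score is assumed here to be the coefficient of determination R², which the text does not specify; the function names and the toy feature channel are hypothetical.

```python
def avg_pool_2d(mask, k):
    """Non-overlapping k x k average pooling of a 2D mask given as a list of rows."""
    h, w = len(mask), len(mask[0])
    return [[sum(mask[i + di][j + dj] for di in range(k) for dj in range(k)) / (k * k)
             for j in range(0, w - w % k, k)]
            for i in range(0, h - h % k, k)]


def solve(a, b):
    """Solve a x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [rhs] for row, rhs in zip(a, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                factor = m[r][col] / m[col][col]
                m[r] = [x - factor * y for x, y in zip(m[r], m[col])]
    return [m[i][n] / m[i][i] for i in range(n)]


def linear_r2(features, targets):
    """R^2 of an ordinary least-squares fit of targets on features plus an intercept."""
    rows = [[1.0] + list(f) for f in features]
    p = len(rows[0])
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(p)] for i in range(p)]
    xty = [sum(r[i] * t for r, t in zip(rows, targets)) for i in range(p)]
    beta = solve(xtx, xty)  # assumes the feature columns are not collinear
    preds = [sum(b * v for b, v in zip(beta, r)) for r in rows]
    mean = sum(targets) / len(targets)
    ss_res = sum((t - y) ** 2 for t, y in zip(targets, preds))
    ss_tot = sum((t - mean) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot


mask = [[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 0]]
targets = [v for row in avg_pool_2d(mask, 2) for v in row]  # one target per coarse voxel
features = [[0.0], [1.0], [0.0], [1.0]]                     # one hypothetical channel
score = linear_r2(features, targets)
```

For real feature maps with up to 256 channels and many voxels, a library solver such as scikit-learn's `LinearRegression`, whose `score` method returns R², would be the practical choice.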
The segmentation results are shown in Table 1 and Table 2. In the semi-supervised settings (first two columns), there is no overlap between labeled and unlabeled data. In the fully-supervised setting (last column), all training images are used both as labeled and as unlabeled data. For Pretrain(Dec), we first pretrain the reconstruction network on unlabeled data and then train only the decoder of the segmentation network on labeled data, keeping the encoder fixed so that the segmentation task can only use the features learned from unlabeled images. For Pretrain(CNN), we first pretrain the reconstruction network on unlabeled data and then train the whole segmentation network on labeled data, which allows the encoder parameters to be fine-tuned if necessary. MASSL and MSSL are the proposed multi-task SSL methods with and without the attention mechanism, where and alter indicate joint training and alternating training respectively. For joint training, we tried and the network did not converge when . The results show that MASSL(alter) achieves the best segmentation performance of all methods. Joint training achieved slightly lower performance than alternating training, and its performance also varied considerably between labeled/unlabeled data splits, reflecting the instability of the joint training strategy and the difficulty of tuning its loss-weighting hyperparameter.
| #Labeled (unlabeled) | 20 (100) | 50 (70) | 120 (120) |
The results of the feature analysis are shown in Table 3. The higher scores indicate that, in the deeper levels, the features learned with MASSL are more discriminative than those learned by the CNN and MSSL. This supports our hypothesis that the attention mechanism makes the deeper layers of the encoder learn more discriminative features while still optimizing the reconstruction objective.
| #Labeled (unlabeled) | 10 (20) | 20 (10) | 30 (30) |
5 Discussion and Conclusion
In this paper, we propose a new semi-supervised learning method called MASSL that combines a segmentation task and a reconstruction task through an attention mechanism in a multi-task learning network. We evaluated the proposed method on two applications. For both applications, MASSL trained with only part of the images labeled outperforms the fully-supervised CNN baseline trained on the same number of labeled images, the pretraining+finetuning methods, and the proposed approach without attention (MSSL). When the segmentation and reconstruction losses are used for all images, MASSL also improves over the baseline CNN, although this difference was statistically significant only for the BraTS data. We attribute this mainly to the sparse distribution of foreground voxels in the WMH data, which makes our attention maps less effective.
The improvement of our method mainly comes from the attention mechanism, which injects the segmentation predictions into the reconstruction task and thereby links the two tasks more closely than plain image reconstruction (MSSL). The mechanism can be easily integrated into any CNN architecture and generalized to multi-class segmentation. Compared with joint training, alternating training is a practical strategy that allows task-dependent learning rates and does not require tuning the loss-weighting hyperparameter, although one still needs to choose proper initial learning rates. Alternating training is not guaranteed to be stable, because the encoder parameters change discontinuously between updates for the two tasks. In our experiments, we found that training was sufficiently stable when the initial learning rate for reconstruction was smaller than that for segmentation, and in most cases, alternating optimization performed much better than joint optimization.
When comparing different multi-task learning strategies, we made some simplifications. For the pretraining methods, unlike Sedai et al. [11], we use a regular autoencoder rather than a variational autoencoder (VAE). We expect that our SSL method could also work well with a VAE and might fuse the two tasks even better. In the regression analysis, we use a simple linear model that can only reveal the linear discriminative power of the features; a more complex non-linear model could also reveal their non-linear discriminative power. Since we use only one MRI sequence and a subset of the scans, our performance on BraTS18 and WMH17 is lower than the state of the art. The best Dice scores on the BraTS18 (whole tumor) and WMH17 test sets are 0.8839 [7] and 0.80 [4], respectively, and the former [7] also uses a variational autoencoder for additional regularization, similar to the Ladder network [8] and our MSSL method.
In conclusion, MASSL is a promising segmentation framework for simple and efficient multi-task learning that can achieve strong improvements in semi-supervised as well as in fully supervised settings.
This research is supported by the China Scholarship Council (File No.201706170040). We gratefully acknowledge the support of the computational resources provided by SURFsara services and Cartesius.
-  Bakas, S., Akbari, H., Sotiras, A., Bilello, M., Rozycki, M., Kirby, J.S., Freymann, J.B., Farahani, K., Davatzikos, C.: Advancing the cancer genome atlas glioma MRI collections with expert segmentation labels and radiomic features. Scientific data 4, 170117 (2017)
-  Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3D U-Net: learning dense volumetric segmentation from sparse annotation. In: MICCAI. pp. 424–432. Springer (2016)
-  Kingma, D.P., Mohamed, S., Rezende, D.J., Welling, M.: Semi-supervised learning with deep generative models. In: NeurIPS. pp. 3581–3589 (2014)
-  Li, H., Jiang, G., Zhang, J., Wang, R., Wang, Z., Zheng, W.S., Menze, B.: Fully convolutional network ensembles for white matter hyperintensities segmentation in MR images. NeuroImage 183, 650–665 (2018)
-  Menze, B.H., Jakab, A., Bauer, S., Kalpathy-Cramer, J., Farahani, K., Kirby, J., Burren, Y., Porz, N., Slotboom, J., Wiest, R., et al.: The multimodal brain tumor image segmentation benchmark (BRATS). TMI 34(10), 1993–2024 (2015)
-  Milletari, F., Navab, N., Ahmadi, S.A.: V-net: Fully convolutional neural networks for volumetric medical image segmentation. In: 3DV. pp. 565–571. IEEE (2016)
-  Myronenko, A.: 3D MRI brain tumor segmentation using autoencoder regularization. In: International MICCAI Brainlesion Workshop. pp. 311–320. Springer (2018)
-  Rasmus, A., Berglund, M., Honkala, M., Valpola, H., Raiko, T.: Semi-supervised learning with ladder networks. In: NeurIPS. pp. 3546–3554 (2015)
-  Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional networks for biomedical image segmentation. In: MICCAI. pp. 234–241. Springer (2015)
-  Schlemper, J., Oktay, O., Schaap, M., Heinrich, M., Kainz, B., Glocker, B., Rueckert, D.: Attention gated networks: Learning to leverage salient regions in medical images. Medical Image Analysis (2019)
-  Sedai, S., Mahapatra, D., Hewavitharanage, S., Maetschke, S., Garnavi, R.: Semi-supervised segmentation of optic cup in retinal fundus images using variational autoencoder. In: MICCAI. pp. 75–82. Springer (2017)
-  Xia, Y., Liu, F., Yang, D., Cai, J., Yu, L., Zhu, Z., Xu, D., Yuille, A., Roth, H.: 3D semi-supervised learning with uncertainty-aware multi-view co-training. arXiv preprint arXiv:1811.12506 (2018)
-  You, X., Peng, Q., Yuan, Y., Cheung, Y.m., Lei, J.: Segmentation of retinal blood vessels using the radial projection and semi-supervised approach. Pattern Recognition 44(10-11), 2314–2324 (2011)
-  Zhou, X.Y., Yang, G.Z.: Normalization in training U-Net for 2D biomedical semantic segmentation. IEEE Robotics and Automation Letters (2019)
-  Zhou, Y., Wang, Y., Tang, P., Shen, W., Fishman, E.K., Yuille, A.L.: Semi-supervised multi-organ segmentation via multi-planar co-training. arXiv preprint arXiv:1804.02586 (2018)