Uncertainty-aware Self-ensembling Model for Semi-supervised 3D Left Atrium Segmentation
Training deep convolutional neural networks usually requires a large amount of labeled data. However, it is expensive and time-consuming to annotate data for medical image segmentation tasks. In this paper, we present a novel uncertainty-aware semi-supervised framework for left atrium segmentation from 3D MR images. Our framework can effectively leverage the unlabeled data by encouraging consistent predictions of the same input under different perturbations. Concretely, the framework consists of a student model and a teacher model, and the student model learns from the teacher model by minimizing a segmentation loss and a consistency loss with respect to the targets of the teacher model. We design a novel uncertainty-aware scheme to enable the student model to gradually learn from the meaningful and reliable targets by exploiting the uncertainty information. Experiments show that our method achieves substantial performance gains by incorporating the unlabeled data. Our method outperforms the state-of-the-art semi-supervised methods, demonstrating the potential of our framework for challenging semi-supervised problems. Code is available at https://github.com/yulequan/UA-MT.
Keywords: Semi-supervised learning · Uncertainty estimation · Self-ensembling · Segmentation
Automated segmentation of the left atrium (LA) in magnetic resonance (MR) images is of great importance in promoting the treatment of atrial fibrillation. With a large amount of labeled data, deep learning has greatly advanced LA segmentation. In the medical imaging domain, however, it is expensive and tedious for experienced experts to delineate reliable annotations from 3D medical images in a slice-by-slice manner. Since unlabeled data is generally abundant, we focus on semi-supervised approaches to LA segmentation that leverage both limited labeled data and abundant unlabeled data.
Considerable effort has been devoted to utilizing unlabeled data to improve segmentation performance in the medical imaging community [1, 2, 3, 7, 19]. For example, Bai et al. [1] introduced a self-training-based method for cardiac MR image segmentation, where the network parameters and the segmentation of the unlabeled data were alternately updated. In addition, adversarial learning has been used for semi-supervised learning [6, 12, 18]. Zhang et al. [18] designed a deep adversarial network to use unannotated images by encouraging their segmentations to be similar to those of the annotated ones. Another approach [12] utilized an adversarial network to select trustworthy regions of the unlabeled data to train the segmentation network. Motivated by the promising results of self-ensembling methods [9, 14] on semi-supervised natural image classification, Li et al. [10] extended the Π-model [9] with transformation consistency for semi-supervised skin lesion segmentation. Other approaches [5, 13] utilized weight-averaged consistency targets for semi-supervised MR segmentation. Although promising progress has been achieved, these methods do not consider the reliability of the targets, which may lead to meaningless guidance.
In this paper, we present a novel uncertainty-aware semi-supervised learning framework for left atrium segmentation from 3D MR images that additionally leverages the unlabeled data. Following the spirit of the mean teacher [14], our method encourages the segmentation predictions to be consistent under different perturbations of the same input. Specifically, we build a teacher model and a student model, where the student model learns from the teacher model by minimizing the segmentation loss on the labeled data and the consistency loss with respect to the targets from the teacher model on all input data. Since no ground truth is available for the unlabeled inputs, the targets predicted by the teacher model may be unreliable and noisy. We therefore design the uncertainty-aware mean teacher (UA-MT) framework, where the student model gradually learns from meaningful and reliable targets by exploiting the uncertainty information of the teacher model. Concretely, besides generating the target outputs, the teacher model also estimates the uncertainty of each target prediction with Monte Carlo sampling. Guided by the estimated uncertainty, we filter out the unreliable predictions and preserve only the reliable (low-uncertainty) ones when calculating the consistency loss. Hence, the student model is optimized with more reliable supervision and, in return, encourages the teacher model to generate higher-quality targets. Our method was extensively evaluated on the dataset of the MICCAI 2018 Atrial Segmentation Challenge. The results demonstrate that our semi-supervised method achieves large improvements in LA segmentation by utilizing the unlabeled data, and also outperforms other state-of-the-art semi-supervised segmentation methods.
Fig. 1 illustrates our uncertainty-aware self-ensembling mean teacher framework (UA-MT) for semi-supervised LA segmentation. The teacher model generates targets for the student model to learn from and also estimates the uncertainty of the target. The uncertainty-guided consistency loss improves the student model and the robustness of the framework.
2.1 Semi-supervised Segmentation
We study the task of semi-supervised segmentation for 3D data, where the training set consists of $N$ labeled volumes and $M$ unlabeled volumes. We denote the labeled set as $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{N}$ and the unlabeled set as $\mathcal{D}_U = \{x_i\}_{i=N+1}^{N+M}$, where $x_i$ is the input volume and $y_i$ is its ground-truth annotation. The goal of our semi-supervised segmentation framework is to minimize the following combined objective function:

$$\min_{\theta} \sum_{i=1}^{N} \mathcal{L}_s\big(f(x_i;\theta), y_i\big) + \lambda \sum_{i=1}^{N+M} \mathcal{L}_c\big(f(x_i;\theta',\xi'), f(x_i;\theta,\xi)\big) \qquad (1)$$

where $\mathcal{L}_s$ denotes the supervised loss (e.g., cross-entropy loss) that evaluates the quality of the network output on labeled inputs, and $\mathcal{L}_c$ represents the unsupervised consistency loss that measures the consistency between the predictions of the teacher model and the student model for the same input under different perturbations. Here, $f(\cdot)$ denotes the segmentation neural network; $(\theta', \xi')$ and $(\theta, \xi)$ represent the weights and the perturbation operations (e.g., adding noise to the input and network dropout) of the teacher and student models, respectively. $\lambda$ is a ramp-up weighting coefficient that controls the trade-off between the supervised and unsupervised losses.
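To make the structure of Eq. (1) concrete, here is a minimal sketch on toy data, assuming each volume has already been flattened into a list of per-voxel foreground probabilities. The function names are hypothetical; binary cross-entropy stands in for $\mathcal{L}_s$ (the text only says "e.g., cross-entropy"), and a plain, unmasked MSE stands in for $\mathcal{L}_c$. The network, its perturbations, and the uncertainty masking are not modeled here.

```python
import math
from typing import Sequence


def binary_cross_entropy(probs: Sequence[float], labels: Sequence[int], eps: float = 1e-7) -> float:
    """Mean voxel-wise cross-entropy for a binary (foreground/background) prediction."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain voxel-wise mean squared error between two predictions (no masking)."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def combined_objective(labeled, all_inputs, lam: float) -> float:
    """Eq. (1): supervised loss on labeled volumes + lam * consistency loss on all volumes.

    labeled:    list of (student_probs, labels) pairs, one per labeled volume (i = 1..N)
    all_inputs: list of (teacher_probs, student_probs) pairs, labeled + unlabeled (i = 1..N+M)
    """
    sup = sum(binary_cross_entropy(p, y) for p, y in labeled)
    cons = sum(mse(t, s) for t, s in all_inputs)
    return sup + lam * cons
```

Note that the consistency sum runs over all $N+M$ volumes, so labeled volumes contribute to both terms, while unlabeled volumes only ever enter through the teacher/student agreement.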
Recent studies [9, 14] show that ensembling the network's predictions from different stages of training can improve their quality, and that using such ensembles as teacher predictions can improve the results. Therefore, we update the teacher's weights $\theta'$ as an exponential moving average (EMA) of the student's weights $\theta$ to ensemble the information from different training steps; see Fig. 1. Specifically, we update the teacher's weights at training step $t$ as $\theta'_t = \alpha\,\theta'_{t-1} + (1-\alpha)\,\theta_t$, where $\alpha$ is the EMA decay that controls the updating rate.
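The EMA update can be sketched element-wise over a flat list of weights. This is a minimal illustration, assuming weights are plain floats; the decay $\alpha = 0.5$ in the toy run is chosen only so the arithmetic is easy to follow and is not the value used in our experiments.

```python
from typing import List


def ema_update(teacher: List[float], student: List[float], alpha: float) -> List[float]:
    """Return theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t, element-wise."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("EMA decay alpha must lie in [0, 1]")
    return [alpha * t + (1.0 - alpha) * s for t, s in zip(teacher, student)]


# Toy run: the teacher starts at 0 and tracks a student whose weights stay at 1.0.
teacher = [0.0, 0.0]
for step in range(3):
    teacher = ema_update(teacher, [1.0, 1.0], alpha=0.5)
# After 3 steps with alpha = 0.5 the teacher weights are 1 - 0.5**3 = 0.875.
```

The teacher is never updated by gradients; it only lags behind the student. In a PyTorch implementation one would typically apply the same formula in place while looping over `zip(teacher.parameters(), student.parameters())` inside `torch.no_grad()`. A larger $\alpha$ makes the teacher smoother but slower to follow the student.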
2.2 Uncertainty-Aware Mean Teacher Framework
Without annotations for the unlabeled inputs, the targets predicted by the teacher model may be unreliable and noisy. Therefore, we design an uncertainty-aware scheme that enables the student model to gradually learn from the more reliable targets. Given a batch of training images, the teacher model not only generates the target predictions but also estimates the uncertainty of each target. The student model is then optimized with a consistency loss that, guided by the estimated uncertainty, focuses only on the confident targets.
2.2.1 Uncertainty Estimation.
Motivated by uncertainty estimation in Bayesian networks, we estimate the uncertainty with Monte Carlo Dropout [8]. In detail, for each input volume we perform $T$ stochastic forward passes of the teacher model under random dropout and input Gaussian noise. Therefore, for each voxel in the input, we obtain a set of softmax probability vectors $\{p_t\}_{t=1}^{T}$. We choose the predictive entropy as the metric to approximate the uncertainty, since it has a fixed range ($[0, \ln C]$ for $C$ classes). Formally, the predictive entropy can be summarized as:

$$\mu_c = \frac{1}{T}\sum_{t=1}^{T} p_t^{c}, \qquad u = -\sum_{c} \mu_c \log \mu_c \qquad (2)$$

where $p_t^{c}$ is the probability of the $c$-th class in the $t$-th prediction. Note that the uncertainty is estimated at the voxel level, and the uncertainty of the whole volume is $U = \{u\}$, i.e., one entropy value per voxel.
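Eq. (2) can be checked numerically with a small sketch. The code below assumes the $T$ softmax vectors have already been produced (in our framework they come from the teacher under dropout and input noise, which is not modeled here); the function names are hypothetical.

```python
import math
from typing import List, Sequence


def predictive_entropy(samples: Sequence[Sequence[float]]) -> float:
    """Eq. (2) for one voxel: entropy of the mean of T softmax vectors.

    samples: T probability vectors p_t, each of length C.
    """
    T = len(samples)
    C = len(samples[0])
    mu = [sum(p[c] for p in samples) / T for c in range(C)]  # mu_c = (1/T) sum_t p_t^c
    # 0 * log 0 is taken as 0 by convention, so zero entries are skipped
    return -sum(m * math.log(m) for m in mu if m > 0.0)


def voxel_uncertainty(passes: Sequence[Sequence[Sequence[float]]]) -> List[float]:
    """passes[t][v] is the softmax vector of voxel v in pass t; returns u_v for every voxel."""
    n_voxels = len(passes[0])
    return [predictive_entropy([p[v] for p in passes]) for v in range(n_voxels)]
```

For example, two passes giving `[0.9, 0.1]` and `[0.7, 0.3]` average to `[0.8, 0.2]`, whose entropy is about 0.50; a voxel whose passes average to `[0.5, 0.5]` reaches the two-class maximum $\ln 2 \approx 0.693$. Averaging before taking the entropy is what lets disagreement between passes raise the uncertainty.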
2.2.2 Uncertainty-Aware Consistency Loss.
Guided by the estimated uncertainty $U$, we filter out the relatively unreliable (high-uncertainty) predictions and select only the certain predictions as targets for the student model to learn from. In particular, for our semi-supervised segmentation task, we design the uncertainty-aware consistency loss $\mathcal{L}_c$ as the voxel-level mean squared error (MSE) between the teacher and student predictions, computed only over the most certain predictions:

$$\mathcal{L}_c(f', f) = \frac{\sum_{v} \mathbb{I}(u_v < H)\,\lVert f'_v - f_v \rVert^2}{\sum_{v} \mathbb{I}(u_v < H)} \qquad (3)$$

where $\mathbb{I}(\cdot)$ is the indicator function; $f'_v$ and $f_v$ are the predictions of the teacher model and the student model at the $v$-th voxel, respectively; $u_v$ is the estimated uncertainty at the $v$-th voxel; and $H$ is a threshold that selects the most certain targets. With our uncertainty-aware consistency loss in the training procedure, both the student and the teacher can learn more reliable knowledge, which can in turn reduce the overall uncertainty of the model.
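A minimal sketch of Eq. (3) on per-voxel probability vectors follows. The function name is hypothetical, and returning 0 when every voxel is filtered out is an assumption made here to avoid division by zero; Eq. (3) itself does not specify that case.

```python
from typing import Sequence


def uncertainty_aware_consistency(
    teacher: Sequence[Sequence[float]],
    student: Sequence[Sequence[float]],
    uncertainty: Sequence[float],
    threshold: float,
) -> float:
    """Eq. (3): squared error averaged only over voxels whose uncertainty is below H."""
    num = 0.0
    den = 0
    for f_t, f_s, u in zip(teacher, student, uncertainty):
        if u < threshold:  # indicator I(u_v < H)
            num += sum((a - b) ** 2 for a, b in zip(f_t, f_s))  # ||f'_v - f_v||^2
            den += 1
    # Assumption: if every voxel is filtered out, contribute no consistency signal.
    return num / den if den > 0 else 0.0
```

With `teacher = [[1.0, 0.0], [0.0, 1.0]]`, `student = [[0.5, 0.5], [1.0, 0.0]]` and uncertainties `[0.1, 0.6]`, a threshold of 0.5 keeps only the first voxel and the loss is 0.5; raising the threshold to 1.0 admits the second voxel and the loss becomes 1.25. Normalizing by the number of selected voxels keeps the loss scale comparable as the threshold changes.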
2.3 Technical Details
We employ V-Net [11] as our network backbone. We remove the short residual connection in each convolution block and use a joint cross-entropy and Dice loss [16]. To adapt V-Net into a Bayesian network that can estimate uncertainty, two dropout layers with a dropout rate of 0.5 are added after the L-Stage 5 layer and the R-Stage 1 layer of V-Net. We turn on dropout during network training and uncertainty estimation, and turn it off in the testing phase, where uncertainty estimation is not needed. We empirically set the EMA decay $\alpha$ with reference to previous work [14]. Following [9, 14], we use a time-dependent Gaussian warming-up function $\lambda(t) \propto e^{-5(1 - t/t_{\max})^2}$ to control the balance between the supervised loss and the unsupervised consistency loss, where $t$ denotes the current training step and $t_{\max}$ is the maximum training step. This design ensures that at the beginning the objective is dominated by the supervised loss term, which prevents the network from getting stuck in a degenerate solution where no meaningful target predictions for the unlabeled data are obtained. For uncertainty estimation, we choose the number of stochastic forward passes $T$ to balance estimation quality and training efficiency. We also use the same Gaussian ramp-up paradigm to ramp up the uncertainty threshold $H$ in Eq. (3) from a fraction of $U_{\max}$ to $U_{\max}$, where $U_{\max}$ is the maximum uncertainty value (i.e., $\ln 2$ in our experiments, the maximum entropy for two classes). As training continues, our method filters out less and less data, enabling the student to gradually learn from relatively certain to uncertain cases.
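Both schedules can be sketched with one Gaussian ramp-up curve. In this sketch `lam_max` (the scale of $\lambda(t)$) and `start_frac` (the initial fraction of $U_{\max}$ for $H$) are hypothetical parameters standing in for values not stated above; only the curve shape $e^{-5(1-t/t_{\max})^2}$ and $U_{\max} = \ln 2$ come from the text.

```python
import math

U_MAX = math.log(2)  # maximum entropy for two classes (LA vs. background)


def gaussian_rampup(step: int, max_step: int) -> float:
    """exp(-5 (1 - t/t_max)^2): rises from exp(-5) at t = 0 to 1 at t = t_max."""
    t = min(max(step, 0), max_step) / max_step  # clamp t/t_max into [0, 1]
    return math.exp(-5.0 * (1.0 - t) ** 2)


def consistency_weight(step: int, max_step: int, lam_max: float) -> float:
    """lambda(t); lam_max is a hypothetical scale, not a value taken from the text."""
    return lam_max * gaussian_rampup(step, max_step)


def uncertainty_threshold(step: int, max_step: int, start_frac: float) -> float:
    """Ramp H toward U_max with the same Gaussian curve.

    start_frac is hypothetical. Because the curve starts at exp(-5) rather than 0,
    H begins slightly above start_frac * U_max and reaches U_max exactly at t_max.
    """
    r = gaussian_rampup(step, max_step)
    return (start_frac + (1.0 - start_frac) * r) * U_MAX
```

Early in training $\lambda(t)$ is close to zero and $H$ is strict, so the student mostly follows the labels and only the most certain teacher targets; by $t_{\max}$ the consistency term has full weight and $H = \ln 2$ admits every voxel.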
3 Experiments and Results
3.0.1 Dataset and Pre-processing.
We evaluated our method on the Atrial Segmentation Challenge dataset (http://atriaseg2018.cardiacatlas.org/). It provides 100 3D gadolinium-enhanced MR imaging scans (GE-MRIs) with LA segmentation masks for training and validation. These scans have isotropic voxel resolution. We split the 100 scans into 80 scans for training and 20 scans for evaluation. For a better comparison of the segmentation performance of different methods, all scans were cropped around the heart region and normalized to zero mean and unit variance.
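The intensity normalization step can be sketched as a z-score transform. This assumes normalization is computed per scan over a flattened list of voxel intensities (the text does not say whether statistics are per scan or dataset-wide), and the function name and `eps` guard are hypothetical; the heart-region cropping is not shown.

```python
import math
from typing import List, Sequence


def zscore_normalize(volume: Sequence[float], eps: float = 1e-8) -> List[float]:
    """Normalize voxel intensities of one scan to zero mean and unit variance."""
    n = len(volume)
    mean = sum(volume) / n
    var = sum((x - mean) ** 2 for x in volume) / n  # population variance
    std = math.sqrt(var)
    return [(x - mean) / (std + eps) for x in volume]  # eps guards constant volumes
```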
The framework was implemented in PyTorch and trained on a TITAN Xp GPU. We used the SGD optimizer to update the network parameters (with weight decay and momentum 0.9). The initial learning rate was set to 0.01 and divided by 10 every 2500 iterations. We trained for 6000 iterations in total, by which point the network had converged. The batch size was 4, consisting of 2 annotated images and 2 unannotated images. We randomly cropped sub-volumes as the network input, and the final segmentation results were obtained using a sliding-window strategy. To avoid overfitting, we applied standard on-the-fly data augmentation following previous work, including random flipping and rotation by 90, 180, and 270 degrees in the axial plane.
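The step learning-rate schedule and the axial-plane augmentation can be sketched in plain Python. This is a minimal illustration, assuming a volume is stored as nested lists indexed `volume[z][y][x]` so that each `volume[z]` is an axial slice; the flip axis and the function names are hypothetical choices, and the same random draw would also have to be applied to the label volume.

```python
import random
from typing import List

Volume = List[List[List[float]]]  # indexed as volume[z][y][x]; volume[z] is an axial slice


def learning_rate(iteration: int, base_lr: float = 0.01, step: int = 2500) -> float:
    """Step decay: divide the learning rate by 10 every `step` iterations."""
    return base_lr * (0.1 ** (iteration // step))


def rot90_slice(sl: List[List[float]]) -> List[List[float]]:
    """Rotate a 2D slice by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*sl)][::-1]


def augment(volume: Volume, rng: random.Random) -> Volume:
    """Random flip plus a random 0/90/180/270-degree rotation in the axial plane.

    The same flip and rotation are applied to every slice so the 3D structure is preserved.
    """
    k = rng.randint(0, 3)  # number of 90-degree rotations
    flip = rng.random() < 0.5
    out = []
    for sl in volume:
        s = [row[::-1] for row in sl] if flip else [list(row) for row in sl]  # flip along x
        for _ in range(k):
            s = rot90_slice(s)
        out.append(s)
    return out
```

With 6000 iterations and a decay every 2500, the schedule visits three learning rates: 0.01, 0.001, and 0.0001.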
[Table 1: Method | # scans used | Metrics]
3.0.3 Evaluation of Our Semi-supervised Segmentation.
We use four metrics to quantitatively evaluate our method: Dice, Jaccard, the average surface distance (ASD), and the 95% Hausdorff Distance (95HD). Of the 80 training scans, we use 20% (i.e., 16) as labeled data and the remaining 64 as unlabeled data. Table 1 presents the segmentation performance on the test set of V-Net trained with only the labeled data (the first two rows) and of our semi-supervised method (UA-MT). Compared with the vanilla V-Net, adding dropout (Bayesian V-Net) improves the segmentation performance, achieving an average Dice of 86.03% and Jaccard of 76.06% with only the labeled training data. By utilizing the unlabeled data, our semi-supervised framework further improves the segmentation by 4.15% Jaccard and 2.85% Dice.
To analyze the importance of the consistency loss for labeled and unlabeled data, we conducted another experiment (UA-MT-UN) with the consistency loss applied only to the unlabeled data. The performance of this variant is very close to that of UA-MT, indicating that the improvement of our method comes mainly from the unlabeled data. We also trained the fully supervised V-Net with all 80 labeled scans, which can be regarded as an upper bound on performance. As we can see, our semi-supervised method approaches this fully supervised performance. To validate our network backbone design, we refer to a state-of-the-art method from the challenge [4], which used a multi-task U-Net for LA segmentation and reported 90.10% Dice on 20 test scans with 80 training scans. Compared with this result, our V-Net can be regarded as a standard baseline model.
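The two overlap metrics can be sketched on flattened binary masks. This is a minimal illustration with a hypothetical function name; the surface-based metrics (ASD and 95HD) require boundary extraction and distance transforms and are not shown. Returning (1, 1) when both masks are empty is a convention chosen here.

```python
from typing import Sequence, Tuple


def dice_jaccard(pred: Sequence[int], gt: Sequence[int]) -> Tuple[float, float]:
    """Dice = 2|A∩B| / (|A|+|B|) and Jaccard = |A∩B| / |A∪B| for binary masks."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    size_p = sum(1 for p in pred if p)
    size_g = sum(1 for g in gt if g)
    if size_p + size_g == 0:
        return 1.0, 1.0  # both masks empty: treated as perfect agreement (a convention)
    union = size_p + size_g - inter
    return 2.0 * inter / (size_p + size_g), inter / union
```

For a single case the two metrics are tied by $J = D / (2 - D)$; for example, masks `[1, 1, 0, 0]` and `[1, 0, 1, 0]` give $D = 0.5$ and $J = 1/3$. Averages over many cases, as reported in the tables, need not satisfy this relation exactly.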
3.0.4 Comparison with Other Semi-supervised Methods.
We implemented several state-of-the-art semi-supervised segmentation methods for comparison, including the self-training-based method [1], the deep adversarial network (DAN) [18], the adversarial-learning-based semi-supervised method (ASDNet) [12], and the Π-model-based method (TCSE) [10]. For a fair comparison, we used the same network backbone (Bayesian V-Net) in all of these methods. As shown in Table 1, compared with the self-training method, DAN and ASDNet improve Dice by 0.60% and 0.98%, respectively, showing the effect of adversarial learning in semi-supervised learning. ASDNet is better than DAN, since it selects trustworthy regions of the unlabeled data for training the segmentation network. The self-ensembling-based method TCSE achieves slightly better performance than ASDNet, demonstrating that a perturbation-based consistency loss is helpful for the semi-supervised segmentation problem. Notably, our method (UA-MT) achieves the best performance among the state-of-the-art semi-supervised methods, except that its ASD is comparable with that of ASDNet, corroborating that our uncertainty-aware mean teacher framework effectively exploits the information in the unlabeled data.
[Table 2: Method | # scans used | Metrics]
3.0.5 Analysis of Our Method.
To validate the effectiveness of our uncertainty-aware scheme, we evaluate the original mean teacher method (MT) and an adapted mean teacher method (MT-Dice) that uses a Dice-loss-like consistency loss. As shown in Table 2, our uncertainty-aware method outperforms both the MT and MT-Dice models. We also investigate the impact of using different numbers of labeled scans in our semi-supervised method. As shown in Table 2, our semi-supervised method consistently improves over the supervised-only V-Net (Bayesian V-Net) by utilizing the unlabeled data with both 10% (i.e., 8) and 30% (i.e., 24) labeled scans, demonstrating that our method effectively exploits the unlabeled data for performance gains. In Fig. 2, we show segmentation examples of the supervised method and our semi-supervised method, together with the estimated uncertainty. Compared with the supervised method, our results have higher overlap with the ground truth (the second row) and produce fewer false positives (the first row). As shown in Fig. 2(d), the network estimates high uncertainty near the boundaries and in ambiguous regions of the great vessels.
We present a novel uncertainty-aware semi-supervised learning method for left atrium segmentation from 3D MR images.
Our method exploits the unlabeled data by encouraging the segmentation to be consistent for the same input under different perturbations.
More importantly, we exploit the model uncertainty to improve the quality of the targets.
The comparison with other semi-supervised methods confirms the effectiveness of our method.
Future work includes investigating the effect of different uncertainty estimation methods and applying our framework to other semi-supervised medical image segmentation problems.
Acknowledgments. The work was partially supported by HK RGC TRS project T42-409/18-R and in part by the CUHK T Stone Robotics Institute, The Chinese University of Hong Kong.
- [1] Bai, W., Oktay, O., Sinclair, M., et al.: Semi-supervised learning for network-based cardiac MR image segmentation. In: MICCAI. pp. 253–260 (2017)
- [2] Baur, C., Albarqouni, S., Navab, N.: Semi-supervised deep learning for fully convolutional networks. In: MICCAI. pp. 311–319 (2017)
- [3] Chartsias, A., Joyce, T., Papanastasiou, G., Semple, S., Williams, M., Newby, D., Dharmakumar, R., Tsaftaris, S.A.: Factorised spatial representation learning: application in semi-supervised myocardial segmentation. In: MICCAI. pp. 490–498 (2018)
- [4] Chen, C., Bai, W., Rueckert, D.: Multi-task learning for left atrial segmentation on GE-MRI. arXiv preprint arXiv:1810.13205 (2018)
- [5] Cui, W., Liu, Y., Li, Y., Guo, M., Li, Y., Li, X., Wang, T., Zeng, X., Ye, C.: Semi-supervised brain lesion segmentation with an adapted mean teacher model. In: IPMI. pp. 554–565 (2019)
- [6] Dong, N., Kampffmeyer, M., Liang, X., Wang, Z., Dai, W., Xing, E.: Unsupervised domain adaptation for automatic estimation of cardiothoracic ratio. In: MICCAI. pp. 544–552 (2018)
- [7] Ganaye, P.A., Sdika, M., Benoit-Cattin, H.: Semi-supervised learning for segmentation under semantic constraint. In: MICCAI. pp. 595–602 (2018)
- [8] Kendall, A., Gal, Y.: What uncertainties do we need in Bayesian deep learning for computer vision? In: NIPS. pp. 5574–5584 (2017)
- [9] Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. arXiv preprint (2016)
- [10] Li, X., Yu, L., Chen, H., Fu, C.W., Heng, P.A.: Semi-supervised skin lesion segmentation via transformation consistent self-ensembling model. In: BMVC (2018)
- [11] Milletari, F., Navab, N., Ahmadi, S.A.: V-Net: Fully convolutional neural networks for volumetric medical image segmentation. In: 3DV. pp. 565–571 (2016)
- [12] Nie, D., Gao, Y., Wang, L., Shen, D.: ASDNet: Attention based semi-supervised deep networks for medical image segmentation. In: MICCAI. pp. 370–378 (2018)
- [13] Perone, C.S., Cohen-Adad, J.: Deep semi-supervised segmentation with weight-averaged consistency targets. In: DLMIA Workshop (2018)
- [14] Tarvainen, A., Valpola, H.: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In: NIPS (2017)
- [15] Xiong, Z., Fedorov, V.V., Fu, X., Cheng, E., Macleod, R., Zhao, J.: Fully automatic left atrium segmentation from late gadolinium enhanced magnetic resonance imaging using a dual fully convolutional neural network. TMI 38(2), 515–524 (2019)
- [16] Yang, X., Bian, C., Yu, L., Ni, D., Heng, P.A.: Hybrid loss guided convolutional networks for whole heart parsing. In: International Workshop on STACOM (2017)
- [17] Yu, L., Cheng, J.Z., Dou, Q., Yang, X., Chen, H., Qin, J., Heng, P.A.: Automatic 3D cardiovascular MR segmentation with densely-connected volumetric convnets. In: MICCAI. pp. 287–295 (2017)
- [18] Zhang, Y., Yang, L., Chen, J., Fredericksen, M., Hughes, D.P., Chen, D.Z.: Deep adversarial networks for biomedical image segmentation utilizing unannotated images. In: MICCAI. pp. 408–416 (2017)
- [19] Zhou, Y., Wang, Y., Tang, P., Bai, S., Shen, W., Fishman, E.K., Yuille, A.L.: Semi-supervised multi-organ segmentation via multi-planar co-training. arXiv preprint arXiv:1804.02586 (2018)