How well do U-Net-based segmentation trained on adult cardiac magnetic resonance imaging data generalise to rare congenital heart diseases for surgical planning?
Planning the optimal time of intervention for pulmonary valve replacement surgery in patients with the congenital heart disease Tetralogy of Fallot (TOF) is mainly based on ventricular volume and function according to current guidelines. Both biomarkers are most reliably assessed by segmentation of 3D cardiac magnetic resonance (CMR) images. In several grand challenges in recent years, U-Net architectures have shown impressive results on the provided data. However, in clinical practice, data sets are more diverse with respect to individual pathologies and to image properties that result from different scanners. Additionally, specific training data for complex rare diseases like TOF is scarce.
In this work we assessed 1) the accuracy gap when a model is trained on a publicly available labelled data set (the Automatic Cardiac Diagnosis Challenge (ACDC) data set) and subsequently applied to CMR data of TOF patients, and vice versa, and 2) whether similar results can be achieved when the model is applied to a more heterogeneous database.
Multiple deep learning models were trained with four-fold cross validation. Afterwards they were evaluated on the respective unseen CMR images from the other collection. Our results confirm that current deep learning models can achieve excellent results (left ventricle dice of 0.951/0.941 train/validation) within a single data collection. But once they are applied to other pathologies, it becomes apparent how much they overfit to the training pathologies (dice score drops between 0.072 for the left and 0.165 for the right ventricle).
Tetralogy of Fallot (TOF) is the most common cyanotic congenital heart disease, affecting approx. 1:2500 babies (cf. European TOF guidelines[tof_guidelines_2010] and American cardiology congress guidelines[accaha_2008]). Currently, the initial repair in infancy is very successful, but often leaves the patient with progressive heart valve insufficiency. Over time, this can lead to dysfunction of the left and right heart chambers (ventricles) in about 25% of adults, ultimately causing heart failure. Current medical guidelines suggest replacing the affected heart valve by a prosthetic valve to counteract ventricular dysfunction. However, it remains extremely difficult to decide on the optimal point in time for replacement, which is often done in the patient's adolescence, since the durability of artificial prostheses is limited. Therefore, TOF patients have to undergo several surgical replacements in their lifetime, and each surgery is associated with high risks.
Current guidelines for the timing of a heart valve replacement, more specifically of pulmonary valve replacement, are mainly based on enlarged ventricular volumes and depressed ventricular function (cf. [tof_guidelines_2010], [accaha_2008]). Technically, this information can be obtained by segmentation of cardiac magnetic resonance (CMR) images.
Current deep learning models have been shown to be capable of fully automatic segmentation of CMR data sets [Bernard2018a]. For example, in ACDC, a large international CMR challenge held at the Medical Image Computing and Computer Assisted Intervention (MICCAI) conference in 2017, deep learning methods outperformed previous methods. The best approaches use deep convolutional architectures. Most of them build on the U-Net architecture (introduced by Ronneberger et al. [Ronneberger2015]).
Training one deep learning model per possible pathology requires manual annotation of hundreds of images per pathology, which is a time-consuming task. The field of congenital heart diseases in particular suffers from a lack of clinical cases, which stands in contrast to the huge clinical heterogeneity of the given data. In this paper, we study how well models trained on the publicly available ACDC data generalise to the case of TOF patients.
2 Material and Methods
| Name of cohort | Years collected | # Patients | Heart centers | Pathologies |
| GCN dataset | 2005 - 2008 | 203 labeled of total 406 | 14 German heart centers | 4 TOF sub-pathologies |
| ACDC dataset | 2017 | 100 labeled of total 150 | 1 center in Dijon | 4 pathologies + 1 healthy group |
The ACDC dataset covers adults with normal cardiac anatomy and function and the following four cardiac pathologies: systolic heart failure with infarction, dilated cardiomyopathy, hypertrophic cardiomyopathy and abnormal right ventricular volume. Each of the five groups is represented by 20 labeled patients, with labels for the right ventricle (RV), the left ventricle (LV) and the left ventricular myocardium (MYO). Each label is given at the end systolic (ES) and end diastolic (ED) time step.
The German Competence Network for Congenital Heart Defects (GCN) data set is derived from the follow-up study of post-repair Tetralogy of Fallot[follow_up_study_2019], a multi-center collaboration of 14 German heart centers (cf. figure 1), and contains clinical, ECG and CMR data from patients with repaired TOF. This work relies only on the post-surgery CMR images of the GCN study. The ACDC dataset represents adult hearts, while the GCN subset considered in this study consists of CMR images of adolescents with a mean age of 17 years. The RV, LV and MYO of 203 patients in the GCN dataset were labelled manually by clinical experts. Each patient has five labeled time steps, reflecting five heart phases: end diastolic, mid systolic, end systolic, peak flow and mid diastolic. This sets the data set apart from other segmentation datasets. The initial analysis of the data set suggests that the provided CMR data is very heterogeneous with respect to size, resolution, scanner model and number of available cardiac phases (cf. figures 1 and 7).
The segmentation network has a U-Net architecture with four down-sampling and four corresponding up-sampling blocks (based on the original U-Net from Ronneberger et al. [Ronneberger2015]) with some small adjustments: a batch normalization layer after each non-linear activation, higher dropout rates, ELU instead of ReLU activations, a Dice loss function, data augmentation and no resampling. Omitting the resampling forces the network to learn different physical representations, which is intended to increase the generalisation of the network.
Figures 2 and 3 show one exemplary down-sampling and one exemplary up-sampling block. The ELU activation, which is applied to each convolutional layer, is not listed to keep the figures readable. The encoding part has four down-sampling blocks, each consisting of a 2x2 max pooling layer followed by a 3x3 convolution layer, a batch normalization layer, a dropout layer and a second 3x3 convolution layer followed by a batch normalization layer. All convolution layers use zero padding. The dropout rate was set to 0.3 in the shallow encoding and decoding layers and increased up to 0.5 in the deeper layers. The decoding part likewise consists of four up-sampling blocks, each with a concatenation layer, a 3x3 convolution layer, a batch normalization layer, a dropout layer, a second 3x3 convolution layer, a batch normalization layer and a transposed convolution layer with stride (2, 2) and zero padding. Contrary to the original recommendations from Ioffe and Szegedy[ioffe_batch_nodate], it turned out that the batch normalization layer works better if it is attached after the non-linear ELU activation function.
To work with categorical labels, a sigmoid activation was applied on the last layer of the U-Net. Several activation functions, optimizers and parameters were tested for this network, but they did not produce appreciably different results. For the final evaluation we used the ADAM optimizer with a learning rate that is decreased by a constant factor. The training setup and parameters are described in section 2.5. The network input was set to (32, 224, 224, 1).
Within this study a 3D version of this U-Net was also implemented, but the 2D version beat the 3D network in all tests. Because of that, the following sections only describe the results for the 2D model. The U-Net implementation from Isensee et al.[Isensee2018] was also applied to the GCN dataset. The model architecture of the original paper needed to be modified to work with the non-isotropic spatial resolution of the CMR images. This modified Isensee model performed worse than our own 2D model, which could be related to the modified layers and to the fact that the 3D volumes of the GCN dataset are not well aligned in the axial direction.
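As a small worked example for the 2D encoder described above, the following sketch traces the in-plane feature map size through the four 2x2 max pooling layers. It assumes that the two 224 entries of the network input are the in-plane image size and that zero-padded convolutions leave the size unchanged; the function name is hypothetical.

```python
def encoder_feature_sizes(input_size=224, n_blocks=4, pool=2):
    """Return the in-plane size before the first and after each down-sampling block.

    Assumption: only the max pooling layers change the spatial size,
    because all convolutions are zero padded.
    """
    sizes = [input_size]
    for _ in range(n_blocks):
        sizes.append(sizes[-1] // pool)
    return sizes


print(encoder_feature_sizes())  # [224, 112, 56, 28, 14]
```

Under these assumptions the bottleneck operates on 14x14 feature maps, and the four up-sampling blocks with stride (2, 2) restore the 224x224 input size.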
2.3 Losses and metrics
The following loss functions and their combinations have been tested:
Binary cross-entropy (BCE)
where the sums are calculated over the N voxels of the predicted mask volume (softmax output of the network) P and the ground truth binary mask volume G.
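A standard form of the binary cross-entropy that matches this description can be written as follows. The element names p_i and g_i for the i-th voxel of P and G are assumptions introduced here for illustration.

```latex
\mathcal{L}_{\mathrm{BCE}}(P, G) = -\frac{1}{N} \sum_{i=1}^{N} \left[ g_i \log p_i + (1 - g_i) \log (1 - p_i) \right]
```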
Weighted cross-entropy (WCE)
where the weight vector contains one multiplicative coefficient per class. Values smaller than 1 reduce, while values greater than 1 increase, the weighting of the respective class. This loss was introduced by Dalca et al.[dalca_anatomical_2018].
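One common way to write such a class-weighted cross-entropy is sketched below. The symbols w_c (weight of class c), p_{i,c} and g_{i,c} (prediction and ground truth of voxel i for class c) are assumed names, and the normalisation by N is an assumption as well.

```latex
\mathcal{L}_{\mathrm{WCE}}(P, G) = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c} w_c \, g_{i,c} \log p_{i,c}
```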
Jaccard distance loss (JDL)
with the smoothing term set to 1 to avoid division by zero.
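A typical smoothed Jaccard distance that is consistent with this remark is given below. The symbol s for the smoothing term and the voxel names p_i and g_i are assumptions; the exact placement of the smoothing term in the original formulation may differ.

```latex
\mathcal{L}_{\mathrm{JDL}}(P, G) = 1 - \frac{\sum_{i} p_i g_i + s}{\sum_{i} p_i + \sum_{i} g_i - \sum_{i} p_i g_i + s}, \qquad s = 1
```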
Soft Dice loss (SDL)
where C represents the foreground classes and each summand is the binary soft Dice coefficient (eq. 5) for one class. The original SDL by Milletari et al.[milletari2016v] used the squared sum in the denominator, which did not work well in our experiments. Because of this, the plain sum, as described by Drozdzal et al. [drozdzal_importance_2016], is used.
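The plain-sum variant can be sketched as follows, with the per-class soft Dice coefficient first and the loss second. The names DSC_c, p_{i,c} and g_{i,c} are assumptions, as is averaging over the |C| foreground classes rather than summing.

```latex
\mathrm{DSC}_c(P, G) = \frac{2 \sum_{i} p_{i,c} \, g_{i,c}}{\sum_{i} p_{i,c} + \sum_{i} g_{i,c}},
\qquad
\mathcal{L}_{\mathrm{SDL}}(P, G) = 1 - \frac{1}{|C|} \sum_{c \in C} \mathrm{DSC}_c(P, G)
```

The squared-sum variant replaces the denominator by the sum of p_{i,c}^2 plus the sum of g_{i,c}^2.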
The WCE learned the fastest. Overall, the SDL produced the best results. All loss functions worked better when the background channel was ignored during training. A learning rate scheduler with a patience of five epochs, an initial learning rate of 0.001, a decrease factor of 0.5 and a lower bound on the learning rate was used to adapt the learning rate. Training runs with no change in the loss for more than 10 epochs were stopped.
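The interplay of the learning rate scheduler and the stopping rule can be illustrated with a minimal, framework-free sketch that replays a list of per-epoch losses. The function name is hypothetical, the value of min_lr is an assumed placeholder, and the exact patience bookkeeping of the framework actually used may differ.

```python
def simulate_schedule(losses, initial_lr=1e-3, factor=0.5, lr_patience=5,
                      min_lr=1e-6, stop_patience=10):
    """Replay per-epoch losses; return (learning rate per epoch, stop epoch or None).

    min_lr is an assumed placeholder value, not taken from the text.
    """
    lr = initial_lr
    best = float("inf")
    since_best = 0    # epochs without improvement, used for early stopping
    since_reduce = 0  # epochs without improvement since the last reduction
    lrs = []
    for epoch, loss in enumerate(losses):
        if loss < best:
            best = loss
            since_best = 0
            since_reduce = 0
        else:
            since_best += 1
            since_reduce += 1
        if since_reduce >= lr_patience:
            # Halve the learning rate, but never go below the lower bound.
            lr = max(lr * factor, min_lr)
            since_reduce = 0
        lrs.append(lr)
        if since_best >= stop_patience:
            return lrs, epoch
    return lrs, None
```

With a loss that plateaus, the learning rate is halved after every five stagnant epochs, and training ends once ten stagnant epochs have accumulated.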
To measure the performance of the models across the two datasets, the Dice coefficient was used. Equation 5 gives the binary soft Dice coefficient, which is used to define the model performance per label.
In the following, this is referred to as the per-label Dice score. To interpret the performance across all labels within one image, the class-wise averaged Dice coefficient is calculated for the foreground classes. The overrepresented background class is ignored so that the metric reflects the performance on the labels of interest. In the following, this metric is referred to as the mean foreground Dice score and is used to evaluate the overall performance drop between the train/test splits and the unseen other dataset. All losses were able to produce a mean foreground Dice score above 85%.
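A minimal NumPy sketch of this evaluation is given below: a per-label Dice score followed by the class-wise average over the foreground labels. The function names, the channel-last one-hot layout, the background channel index 0 and the small eps term are assumptions made for illustration.

```python
import numpy as np


def dice_per_label(pred, gt, eps=1e-7):
    """Soft Dice per class for arrays of shape (..., n_classes)."""
    axes = tuple(range(pred.ndim - 1))  # sum over all voxel axes
    intersection = np.sum(pred * gt, axis=axes)
    denominator = np.sum(pred, axis=axes) + np.sum(gt, axis=axes)
    # eps (assumed) keeps the score defined when a class is absent in both masks.
    return (2.0 * intersection + eps) / (denominator + eps)


def mean_foreground_dice(pred, gt, background_channel=0):
    """Class-wise averaged Dice, ignoring the background channel."""
    scores = dice_per_label(pred, gt)
    foreground = np.delete(scores, background_channel)
    return float(np.mean(foreground))


# Toy example: 2x2 label maps with background (0) and two foreground labels.
gt = np.eye(3)[np.array([[0, 1], [2, 2]])]
pred = np.eye(3)[np.array([[0, 1], [1, 2]])]
score = mean_foreground_dice(pred, gt)
```

In the toy example each foreground label has one overlapping pixel out of three pixels in total across prediction and ground truth, so both per-label scores are about 2/3.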
The following pre-processing/augmentation steps have been applied on each batch:
No resampling to get different sizes - it turned out that the network learned more slowly but generalised better to unseen datasets if the data was not resampled. In this way, the network is forced to learn the segmentation features at any size.
Grid distortion - the grid distortion method from the library Albumentations[albumentations] was applied with a probability of 80 % during training. The number of distortion steps was kept at the standard value of 10. The default behaviour of this function is to distort with linear interpolation on the images and with nearest neighbour interpolation on the masks. This augmentation method considerably increased the robustness of the model to unseen data. Other augmentation methods (random zoom, rotation, shift and cropping) were tested, but the combination of no resampling (cf. first pre-processing point) and grid distortion performed best.
Cropping to square - to avoid distortions in subsequent resizing steps.
Center cropping - if image size is bigger than the network input shape, else resize image.
Resize to proper network size if cropping is not possible - bi-linear interpolation and anti-aliasing for the images, nearest neighbour interpolation and anti-aliasing for the masks.
Pixel value clipping - value clipping based on the .999 quantile to improve the pixel intensity distribution. Some images had outliers with pixel values near 20,000, while 99.9 % of the pixel values were below 2000.
Normalise pixel values - min/max-normalise each image to scale the pixel values of different scanners into the same range. Scaling to zero mean and one standard deviation was also tested, but this led to worse results.
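The cropping, clipping and normalisation steps of this list can be sketched with NumPy as follows. This is an illustration under assumptions, not the original pipeline: the function names are hypothetical, the image is assumed to be a single 2D slice, and the grid distortion and resizing steps are left out.

```python
import numpy as np


def center_crop_square(image):
    """Crop a 2D slice to a centred square with the shorter side length."""
    h, w = image.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return image[top:top + side, left:left + side]


def clip_and_normalise(image, quantile=0.999):
    """Clip intensities at the given quantile, then min/max-normalise to [0, 1]."""
    upper = np.quantile(image, quantile)
    clipped = np.clip(image, image.min(), upper)
    lo, hi = clipped.min(), clipped.max()
    if hi == lo:
        # Constant image: avoid division by zero.
        return np.zeros_like(clipped, dtype=np.float32)
    return ((clipped - lo) / (hi - lo)).astype(np.float32)


# Toy slice with a single bright outlier, similar to the case described above.
slice_2d = np.full((4, 6), 100.0)
slice_2d[0, 0] = 20000.0
processed = clip_and_normalise(center_crop_square(slice_2d))
```

Clipping before the min/max scaling prevents a few outlier pixels from compressing the remaining intensities into a very narrow range.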
Both datasets were shuffled and split into four folds. One model was trained per fold, which results in four models per dataset. Neither the model architecture and parameters nor the pre-processing and training parameters were optimised at this step, to avoid any parameter-based advantages or disadvantages for either of the datasets. The ACDC dataset was split with respect to the five pathologies. This resulted in 15 patients per pathology per training fold and five patients per pathology per test fold. Altogether this led to 75 patients in each training fold with 1413 slices and 25 patients in each test fold with 471 slices. Figure 8 in the Appendix provides more details on the pathology-based ACDC splits. The GCN dataset was randomly split. This resulted in 152 patients per training fold with 10544 slices and 51 patients per test fold with 3514 slices.
The following two experiments were evaluated. In experiment one, four models were trained on the ACDC dataset. Afterwards they were validated on the corresponding unseen test splits and on the unseen pathologies from the GCN dataset. In experiment two, four models were trained on the GCN training splits and validated on the corresponding unseen GCN test splits and on the unseen ACDC dataset. Early stopping with a patience of 10 epochs was applied to the learning process. The initial learning rate was set to 0.001 and was decreased by a factor of 0.5, down to a fixed minimal learning rate, after five epochs without any improvement in the loss.
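A pathology-based four-fold split of this kind can be sketched in plain Python. The function name, the patient identifiers and the group labels are hypothetical, and the round-robin assignment is only one possible way to obtain five test patients per group and fold.

```python
import random
from collections import defaultdict


def stratified_folds(patient_groups, n_folds=4, seed=0):
    """Split patient ids into test folds so that each group is spread evenly.

    patient_groups maps a patient id to its group label (hypothetical names).
    Returns a list of n_folds lists with the test patients of each fold.
    """
    rng = random.Random(seed)
    by_group = defaultdict(list)
    for patient_id, group in sorted(patient_groups.items()):
        by_group[group].append(patient_id)
    folds = [[] for _ in range(n_folds)]
    for group in sorted(by_group):
        ids = by_group[group]
        rng.shuffle(ids)
        # Deal the shuffled patients of this group round-robin over the folds.
        for i, patient_id in enumerate(ids):
            folds[i % n_folds].append(patient_id)
    return folds


# 100 labelled patients in five groups of 20, as in the ACDC dataset.
groups = {f"patient{i:03d}": f"group_{i % 5}" for i in range(100)}
test_folds = stratified_folds(groups)
train_folds = [sorted(set(groups) - set(fold)) for fold in test_folds]
```

For five groups of 20 patients this yields 25 test patients (five per group) and 75 training patients per fold, matching the fold sizes stated above.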
After the train/test/unseen pathology gaps for both datasets were defined, three methods for closing the gap between the publicly available dataset and the unseen TOF pathology were investigated. In the first method, an increasing number of GCN patients was successively added to the ACDC training dataset, and for each extended dataset a new model was trained without changing the model parameters. For the second and third methods, one trained ACDC model from the cross-validation (cf. section 2.5) was randomly chosen as a baseline model and finetuned. In the second method, the baseline model was finetuned with the ACDC training split plus an increasing number of GCN patients, whereas in the third method the baseline model was finetuned only on an increasing number of GCN patients. For each method ten models were trained, each with a different number of additional GCN patients (5 to 150 patients). Each model was then evaluated on the ACDC training split, the unseen ACDC test split and the unseen rest of the GCN dataset.
3.1 Generalisation gap
The following section defines the generalisation gap in both directions. Figure 4 shows boxplots of the cross-validated model performances within the four folds for both datasets. Table 2 provides the mean Dice scores per training and evaluation dataset, and table 3 lists the absolute gap between the averaged training and evaluation scores.
To answer the question of whether a U-Net can be trained on a publicly available dataset and afterwards applied to unseen pathologies, the following sections mainly describe the model performances and gaps. The baseline models trained on the ACDC dataset achieved a mean foreground Dice score of 0.917/0.899 (train/test) and a mean train/test Dice gap across all labels of 0.017 (0.026/0.010/0.016 for RV/LV/MYO). The LV (0.951/0.941, train/test) appears to be the easiest label to learn. The trained models achieved a mean LV Dice score of 0.941 on the unseen test split, which is 0.027 below the ED score and 0.010 above the ES score of the MICCAI winning model from Isensee et al. [Isensee2018]. The MYO appears to be the most difficult label to learn, with the lowest absolute test score of the ACDC models (mean Dice score of 0.866).
When these models were validated on the unseen pathologies from the GCN dataset, all scores dropped by 0.072 to 0.165 compared to the training score. The Dice score of the LV shows the smallest drop (0.072), which confirms the expectation that the LV of TOF patients looks more like a normal LV. The Dice score of the RV shows the biggest drop (0.165), which could be related to the clinical picture of TOF patients, who usually have a deformed right ventricle with increased end-systolic and end-diastolic volumes.
The models trained on the GCN dataset achieved lower train (0.889) and test (0.865) Dice scores than the models trained on the ACDC dataset (0.917/0.899, train/test). On the other hand, they generalised better to the unseen dataset, which results in lower gaps for all labels (cf. table 3). The reason for the smaller generalisation gap of the GCN models could be the greater number of CMR images and the greater clinical diversity of the GCN dataset, which forces the model to find other features.
It turned out that all three methods closed the generalisation gap with nearly equal results for each of the labels and for each number of added GCN patients. The generalisation increase of the second method was slightly more stable and is shown in figure 5. The other improvement plots are attached in the appendix (cf. figure 9 and figure 10).
Figure 6 illustrates the Dice score increase between the baseline ACDC model and the best finetuned ACDC model (trained on ACDC + 150 GCN patients). All labels and each evaluation modality benefit from finetuning. The performance of the model on the GCN dataset increased by 0.10, while the ACDC train and test Dice scores stayed stable. The RV and the MYO got the biggest boost (0.11) and the left ventricle got the smallest boost (0.06) from finetuning the models.
The segmentation networks were able to segment the LV in the unseen dataset with a small reduction in the Dice score (left ventricle Dice score gap of 0.072 between the train dataset and the unseen pathology), but they generalised much worse for the other two labels of the TOF patients (right ventricle/left ventricular myocardium Dice score gap of 0.165/0.164).
This indicates that the U-Net may still overfit to one dataset even if high dropout rates and massive data augmentation are applied. In this case the U-Net performed very well within the 4-fold splits of the ACDC dataset and achieved a train/test mean foreground Dice score of 0.917/0.899 with a train/test gap of 0.017. The score dropped by 0.136 when the models were applied to patients with an unseen pathology, in this case TOF patients. The crosscheck (trained on GCN data, evaluated on ACDC data) showed a smaller but still notable gap between the two datasets (test/unseen: 0.865/0.788), which supports the assumption of a pathology-driven gap. It would still be very interesting to see the performance of these models when they are applied to a healthy group of children.
The results showed that this gap could be decreased by finetuning the ACDC model with the ACDC + TOF patients, which worked slightly better than continuing the training of the baseline models only with the GCN data or training a new model from scratch on the combined ACDC + GCN data. There are many other finetuning methods available, which might overcome the generalisation gap faster or with fewer finetuning examples. One could, for example, fix the weights of the decoding layers and finetune only the last encoding layers.
During the parameter optimisation it turned out that each loss function was able to train the models, but that convergence speed and accuracy could be improved considerably by choosing the right loss function for the given problem. All tested loss functions worked better if they were applied only to the foreground classes.
There is still the problem of rare pathologies with small available datasets, and the results indicate that it is necessary to finetune on them to make reasonable predictions. Maybe the idea of federated and distributed learning as applied by Ken et al.[Ken_distributed_learning_2018] could overcome the problem of rare datasets and pathologies by sharing the network weights across distributed institutions.
The experiments gave five notable insights. First, U-Nets are able to generalise very well to unseen data of the same pathology (with a Dice gap of 1-2 %) if the training data is pre-processed (cf. section 2.4) and augmented in the right manner. Second, U-Nets might overfit to the trained pathologies and generalise badly to unseen pathologies with deformed structures (the mean foreground Dice score dropped by 0.136). Third, the performance of the finetuned models was stable in segmenting the ACDC data, even when two thirds of the training data (75 ACDC patients and 150 GCN patients) consisted of patients with Tetralogy of Fallot. Fourth, the model performance on the TOF patients still increased after more than 125 TOF patients had been added to the training data. One might expect that this converges earlier. The biggest performance gain was obtained by adding at least 50 patients. Fifth, the generalisation gap towards the TOF pathology indicates that there are different features within the MRI images of TOF patients compared to the five groups contained in the ACDC data. This, together with the observation that the model performance did not stop increasing after a certain number of patients, could be an indication that the TOF patients themselves have different features which could be used to define subgroups within the TOF patients.
The Titan Xp GPU card used for this research was donated by the NVIDIA Corporation. This work was supported by the Competence Network for Congenital Heart Defects, which has received funding from the Federal Ministry of Education and Research, grant number 01GI0601 (until 2014), and the DZHK (German Centre for Cardiovascular Research; as of 2015).