Matching the Clinical Reality: Accurate OCT-Based Diagnosis From Few Labels
Unlabeled data is often abundant in the clinic, making machine learning methods based on semi-supervised learning a good match for this setting. Despite this, they are currently receiving relatively little attention in the medical image analysis literature. Instead, most practitioners and researchers focus on supervised or transfer learning approaches. The recently proposed MixMatch and FixMatch algorithms have demonstrated promising results in extracting useful representations while requiring very few labels. Motivated by these recent successes, we apply MixMatch and FixMatch in an ophthalmological diagnostic setting and investigate how they fare against standard transfer learning. We find that both algorithms outperform the transfer learning baseline on all fractions of labelled data. Furthermore, our experiments show that the exponential moving average (EMA) of model parameters, which is a component of both algorithms, is not needed for our classification problem, as disabling it leaves the outcome unchanged. Our code is available online: https://github.com/Valentyn1997/oct-diagn-semi-supervised.
© 2020. Copyright for this paper by its authors. Use permitted under Creative Commons License Attribution 4.0 International (CC BY 4.0).
Proceedings of the CIKM 2020 Workshops, October 19-20, 2020, Galway, Ireland
Keywords: Semi-supervised image classification · Transfer learning · Optical Coherence Tomography
1 Introduction

In recent years, deep learning techniques have taken the field of AI by storm. Virtually all state-of-the-art systems in computer vision (CV) rely on some form of deep learning. This paradigm shift has sparked the imagination of many practitioners and researchers in the medical image analysis domain. Computer-aided diagnosis appeared to be next in line to benefit from the advancements made in CV, as the amount of data in clinical diagnostics is increasing rapidly. The research community has proposed a plethora of new algorithms and systems for the automated diagnosis of a wide range of diseases. However, clinical adoption has been slow. One crucial reason is that supervised learning, which forms the basis for the vast majority of deep learning approaches, is ill-suited to the medical domain.
This mismatch is twofold. First, the labelled data needed for supervised learning is prohibitively costly to generate for medical applications. With a shortage of medical practitioners, diverting medical experts' time and energy to labelling efforts becomes exceedingly expensive. More fine-grained problem formulations (e.g. single-label vs. multi-label, volume-level vs. slice-level annotation) multiply the labelling expenses further. Additionally, most clinics lack the tools to label vast amounts of data. Second, and perhaps more fundamentally, there is an epistemic problem in generating accurate labels. For any given diagnostic problem, inter-expert agreement is well below 100%. This discrepancy stems from the fact that medicine is complex and does not always fit neatly into a classification formulation; moreover, each expert comes with his or her own set of experiences and knowledge.
The idea of semi-supervised learning (SSL) is to train a machine learning algorithm on vast amounts of unlabeled data together with a small set of labelled samples. Instead of relying solely on supervision, the algorithm should discover the bulk of the knowledge required for solving a diagnostic task on its own, with labels serving only as additional guidance. SSL is thus a much better match for the clinical setting, where unlabeled data is often abundant because it is acquired as part of the clinical routine.
In this work, we apply two recently proposed SSL methods, MixMatch  and FixMatch , to a diagnostic problem in ophthalmology. We test which performs better in classifying optical coherence tomography (OCT) b-scans into four classes (one healthy and three pathological) at different fractions of labelled data. We compare the two SSL methods to a baseline transfer learning approach, similar to . After going over related work in the next section, we explain the basis for our experiments in Section 3, covering MixMatch, FixMatch and the transfer learning baseline. In Section 4 we describe the dataset and present the results of our investigation. We conclude with Section 5 by summarizing our findings and discussing how they apply to the clinical setting.
2 Related work
State-of-the-art methods for image classification concentrate on finding the right combination of SSL paradigms. One of the early approaches – Mean Teacher  – uses an exponential moving average (EMA) of model parameters. Virtual Adversarial Training (VAT)  tries to find a minimal perturbation and fit a robust model against it. MixMatch  and RealMix  mix and overlay labelled and unlabelled images to obtain consistent predictions. Unsupervised Data Augmentation (UDA)  uses strongly augmented images to force consistency among unlabeled images. ReMixMatch  uses so-called “augmentation anchoring”, i.e. strong and weak augmentations, to enforce consistency. Inspired by UDA and ReMixMatch, the authors of FixMatch  significantly simplify SSL by relying only on augmentations and pseudo-labelling with a confidence threshold. We provide a broader overview of the applied SSL methods in Section 3.2.
Surprisingly, there exists little literature on SSL applied to ophthalmological data.  and  utilize SSL for OCT segmentation. In the domain of automated diagnosis,  employ an autoencoder with an additional classification module on the latent code for the detection of retinopathy from colour fundus images.  tackle the same problem by extending the GAN framework  to one “fake” and six “real” classes, i.e. the labelled classes. Recent works  and  apply the same principle to the classification of OCT b-scans. Most recently,  applied SSL methods to glaucoma detection by imputing missing visual field (VF) measurements through nearest-neighbour identification in the latent space of a pre-trained classification CNN. Afterwards,  train a multi-task network jointly on glaucoma classification and VF measurement prediction. To the best of our knowledge, we are the first to apply consistency-regularization-based SSL techniques (see Section 3) to the problem of automated diagnosis in ophthalmology.
Among the numerous approaches in deep transfer learning , we consider fine-tuning, also known as network-based transfer learning, to be the most promising. Earlier work [21, 8] proposed using an ImageNet  pre-trained CNN as the initialization for different visual recognition tasks with a limited amount of labels. Yosinski et al.  investigated how unfreezing different parts of the network during fine-tuning affects target performance.
3 Methods

Transfer learning and semi-supervised learning are the two main approaches for predictive modelling when dealing with data with few labels. Transfer learning approaches reuse knowledge from previously learned tasks. SSL approaches, on the other hand, allow learning from small labelled datasets by utilizing unlabeled data from the same distribution. In the following, we first discuss our transfer learning baseline and afterwards describe the SSL approaches we have chosen for this study.
3.1 Transfer Learning
When applying transfer learning techniques, the user has to choose how to adapt the model from the auxiliary to the primary task. In our experiments, we use a network pre-trained on ImageNet . For adapting the model to OCT classification, we try two common approaches. In the feature extraction approach, we freeze all parameters except for the final fully connected layer, analogous to . Alternatively, we use the pre-trained network as initialization and allow all parameters to change; we refer to this as fine-tuning hereafter.
3.2 Semi-supervised Learning
In our study, we compare two recent state-of-the-art SSL algorithms, MixMatch  and FixMatch . Both algorithms combine several pre-existing SSL techniques. In this section, we review the main ideas and compare how the two algorithms use them. We refer the reader to Appendix A for detailed algorithm descriptions.
Data augmentation is a regularization technique often used in supervised learning. The goal is that the model's prediction is unaffected by certain transformations of the data. To that end, additional training data is generated by applying various perturbations to existing samples while keeping the original labels. Most data augmentations are domain-specific and require domain knowledge.
MixMatch uses random flip-and-shift augmentations (horizontal flips and random crops) for both labelled and unlabeled data.
FixMatch distinguishes between weak and strong data augmentations. Flip-and-shift augmentations are considered weak, whereas affine transformations and colour jittering are examples of strong augmentations (originally, 14 different transformations from RandAugment ).
Pseudo-labelling, or self-training , is the process of using the model being trained to obtain labels for unlabeled instances. The predicted labels then guide the further learning process, e.g. by serving as new targets.
MixMatch applies several different augmentations to each unlabeled instance and computes the class distribution for each augmentation. Hence, instead of a hard one-hot label, MixMatch defines a probability distribution as the target. To sharpen this distribution and reduce its entropy, its temperature is adjusted .
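The averaging-and-sharpening step can be sketched as follows; `sharpen` and `guess_label` are hypothetical helper names, and `T` is the sharpening temperature.

```python
import torch

def sharpen(p: torch.Tensor, T: float = 0.5) -> torch.Tensor:
    """Lower the temperature of a class distribution.
    T -> 0 approaches a one-hot distribution; T = 1 leaves p unchanged."""
    p_t = p ** (1.0 / T)
    return p_t / p_t.sum(dim=-1, keepdim=True)

def guess_label(logits_per_aug: torch.Tensor, T: float = 0.5) -> torch.Tensor:
    """Average predictions over K augmentations of one batch, then sharpen.
    logits_per_aug has shape (K, batch, num_classes)."""
    probs = torch.softmax(logits_per_aug, dim=-1).mean(dim=0)
    return sharpen(probs, T)
```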
FixMatch uses a “classic” version of pseudo-labelling with hard labels and a fixed confidence threshold. The class probability distribution is taken from the model outputs after a weak augmentation. If the probability of the most likely class exceeds a predefined threshold, the resulting label is assigned to a strongly augmented version of the same instance and used in the loss calculation.
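A minimal sketch of this thresholded pseudo-labelling loss; the default threshold of 0.95 is the value used in the original FixMatch paper, and the function name is our own.

```python
import torch
import torch.nn.functional as F

def fixmatch_unsup_loss(logits_weak: torch.Tensor,
                        logits_strong: torch.Tensor,
                        threshold: float = 0.95) -> torch.Tensor:
    """Pseudo-label from the weak view, train the strong view on it,
    keeping only predictions whose maximum probability exceeds `threshold`."""
    probs = torch.softmax(logits_weak.detach(), dim=-1)  # no gradient to labels
    max_prob, pseudo_label = probs.max(dim=-1)
    mask = (max_prob >= threshold).float()
    loss = F.cross_entropy(logits_strong, pseudo_label, reduction="none")
    return (loss * mask).mean()
```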
Consistency regularization  imposes the constraint that the model should make similar predictions for the same instance under different data augmentations. Both MixMatch and FixMatch apply data augmentation to labelled and unlabeled data and enforce similar predictions for the same instance under different augmentations. For the unlabeled instances, the pseudo-label is used as a target. FixMatch uses weakly augmented views to compute pseudo-labels for strongly augmented views of the same training sample.
Another popular consistency mechanism in SSL is an exponential moving average (EMA) of model parameters over time, which smooths the behaviour when the model changes its decisions rapidly. The EMA is the basis of the Mean Teacher algorithm , which maintains two models. The teacher model stores an exponential moving average of the student's parameters and is used to compute the pseudo-labels. Pseudo-labels computed by the teacher can therefore be considered a weighted combination of the decisions of previous models. The student model makes the predictions for the training data and is updated based on the training loss. Both MixMatch and FixMatch employ EMA decay at inference time. Note that keeping a second model in memory and updating its parameters results in higher memory requirements and computation costs.
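The EMA update itself is a one-line operation per parameter; a sketch with a decay of 0.999, a commonly used value (an assumption here, not taken from the paper):

```python
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module,
               decay: float = 0.999) -> None:
    """Teacher parameters track an exponential moving average of the student's:
    teacher <- decay * teacher + (1 - decay) * student."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(decay).add_(s_param, alpha=1.0 - decay)
```

The teacher is typically initialized as a deep copy of the student and updated after every optimizer step.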
MixUp  is another regularization technique to avoid overfitting. MixUp linearly combines pairs of training instances and their prediction targets, thereby encouraging linear behaviour between training samples. MixMatch does not differentiate between pseudo-targets predicted for the unlabeled instances and ground-truth labels, and mixes all possible target pairs. A resulting training instance may therefore combine two pseudo-targets, two ground-truth labels, or a pseudo-target with a ground-truth label.
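A sketch of the MixMatch variant of MixUp, in which the interpolation coefficient is redrawn towards 0.5 or above so that the mixed sample stays closer to its first argument; `alpha` parametrizes the Beta distribution.

```python
import numpy as np
import torch

def mixup(x1: torch.Tensor, y1: torch.Tensor,
          x2: torch.Tensor, y2: torch.Tensor,
          alpha: float = 0.75):
    """Convex combination of two batches and their (soft) targets."""
    lam = np.random.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # MixMatch modification: keep lam >= 0.5
    x = lam * x1 + (1.0 - lam) * x2
    y = lam * y1 + (1.0 - lam) * y2
    return x, y
```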
4 Experiments

Our work follows the principles of the fair SSL evaluation framework defined by . The authors highlight the importance of using the same classifier architecture for all compared methods. Moreover, the evaluation is only meaningful for the real use-case if SSL methods are compared against well-tuned transfer learning and fully supervised models.
For the evaluation we use the UCSD dataset published by Kermany et al. . It contains 84,495 optical coherence tomography (OCT) b-scans pertaining to four categories: “normal”, “drusenoid” (DRUSEN), “choroidal neovascularization” (CNV), and “diabetic macular edema” (DME).
The images vary in size; the median image measures 496×512 pixels. The height ranges between 496 and 512 pixels and the width between 384 and 1536 pixels.
The dataset is also obtainable through Kaggle.
We compare the performance of the transfer learning and SSL models using the same Wide ResNet-50-2  backbone. Since the images are monochrome, we replicate the single channel three times to obtain RGB input.
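Channel replication is a one-line operation; a minimal sketch (the function name is ours):

```python
import torch

def gray_to_rgb(batch: torch.Tensor) -> torch.Tensor:
    """Replicate a monochrome batch (N, 1, H, W) to pseudo-RGB (N, 3, H, W)."""
    assert batch.shape[1] == 1, "expected single-channel input"
    return batch.repeat(1, 3, 1, 1)
```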
4.1 Comparison of transfer learning and SSL approaches
First, in Table 1 we compare the performance of our backbone model trained with all labelled instances to the results reported previously in the literature for the same UCSD dataset. As we can see, the backbone model achieves almost perfect performance when trained with enough labels.
Next, in Fig. 1(b) we compare the two transfer learning approaches. Note that the hyperparameter search was done separately for each number of labels and each approach. Contrary to our expectations, the fine-tuning variant outperforms the feature-extraction approach in all label settings. We believe that a thorough hyperparameter selection with a representative validation set reduces the risk of overfitting. Furthermore, since the original models were trained on RGB data, we believe that the model can better adapt to the monochrome setting when all weights are allowed to change.
In Fig. 1(a) we present the results of both SSL algorithms and compare them with the best-performing transfer learning setting. We find that the SSL approaches outperform transfer learning on all fractions of labelled data. The gap between SSL and transfer learning widens significantly for smaller fractions of labelled data: with only 10 labelled representatives per class, FixMatch achieves an accuracy of over 86%, while transfer learning reaches only 59%. With about 2,000–4,000 labels, all methods achieve almost perfect performance. FixMatch also outperforms MixMatch in almost all settings and, with only 50 labelled points per class, achieves an accuracy of 98.14%. We also observe a small SSL performance drop for 25 labelled images per class, mainly because the methods then require even more epochs to fit (we employ a heuristic formula for the maximum number of epochs based on the number of labels, see Appendix B.2).
| Method | Labels | Accuracy | Notes |
| Kermany et al.  | All | 96.6% | Original paper |
| Alqudah  | All | 97.1% | Extended UCSD with 5 classes |
| Wu et al.  | All | 97.5% | |
| Chetoui et al.  | All | 98.46% | |
| Tsuji et al.  | All | 99.6% | |
| (our backbone) | All | 99.69% | With EMA decay |
| He et al.  | 835 | 87.25 ± 1.44%* | *Average precision |
Finally, since practitioners often have to deal with resource constraints, and actual running times are rarely reported in the literature, we report them in Table 2. Note that all methods are implemented in the same framework and all experiments are run on the same machine with two Nvidia Tesla V100 GPUs. To use the batch size recommended in the original publications, we trained FixMatch on both GPUs; the other models are trained on a single GPU.
| Transfer learning | 10m | 9m | 12m | 15m | 24m | 39m | 1h 39m |
| MixMatch | 1d 16h 5m | 9h 12m | 6h 13m | 2h 30m | 2h 37m | 2h 24m | 2h 26m |
| FixMatch | 5d 9h 36m | 1d 19h 4m | 1d 40m | 9h 58m | 10h 40m | 9h 50m | 7h 51m |
4.2 EMA decay
EMA is an inherent part of the FixMatch algorithm and is also recommended as an option for MixMatch. We observe that the validation learning curves of all models are more stable when EMA is used. However, we assume that with a well-chosen validation subset, this variability could even be advantageous, allowing a better fit to be found. Using an EMA model causes additional computation and memory costs, and, as can be seen in Table 3, models without it perform better most of the time.
5 Conclusion

In this work, we have demonstrated the efficacy of MixMatch and FixMatch when applied to an ophthalmological diagnostic problem on OCT data. The two algorithms attain high accuracy, achieving well over 80% on as few as 40 labelled samples (i.e. ten per class). Both algorithms outperformed transfer learning in the few-label settings. This study underlines the value of SSL methods for the clinical adoption of AI. Although both MixMatch and FixMatch are more computationally expensive than transfer learning, the amount of labelling effort they save is immense. With labelling being one of the biggest factors hindering the clinical use of AI methodology, we argue that smarter use of the abundance of unlabeled data already present at the clinic will be a major strategy for overcoming this hurdle.
As part of future work, we propose to also compare the SSL approaches with few-shot deep learning methods.
Acknowledgements. This work has been funded by the German Federal Ministry of Education and Research (BMBF) under Grant No. 01IS18036A and by the Bavarian Ministry for Economic Affairs, Infrastructure, Transport and Technology through the Center for Analytics-Data-Applications (ADA-Center) within the framework of “BAYERN DIGITAL II”. The authors of this work take full responsibility for its content.
Appendix A MixMatch & FixMatch – algorithm details
The foundation of both MixMatch and FixMatch is consistency regularization – the idea that augmentations of the same data point should yield the same label. In this way, the model regularizes itself based on its predictions.
Let $\mathcal{X} = \{(x_b, p_b)\}_{b=1}^{B}$ be the batch of labelled examples with one-hot targets $p_b$. $\alpha(\cdot)$ denotes the set of weak augmentations and $\mathcal{A}(\cdot)$ the set of strong augmentations. $p_{\mathrm{m}}(y \mid x; \theta)$ is the prediction of the backbone classifier, parametrized by $\theta$. $H(\cdot, \cdot)$ denotes the categorical cross-entropy and $\lambda_u$ is the unsupervised loss weight.
MixMatch employs only weak augmentations and MixUp . Let $\mathcal{U} = \{u_b\}_{b=1}^{B}$ be the unlabeled data batch. The model outputs for $K$ random weak augmentations of the same unlabelled sample are treated as soft pseudo-labels. These soft pseudo-labels are averaged and sharpened with temperature $T$ for each image in $\mathcal{U}$ to yield a pseudo-label $q_b$ for that image. Then, images from the randomly augmented $\mathcal{X}$ and $\mathcal{U}$ are concatenated and shuffled, resulting in the set $\mathcal{W}$. Afterwards, samples in $\mathcal{X}$ and $\mathcal{U}$ are weakly augmented and linearly interpolated with samples from $\mathcal{W}$. This results in $\mathcal{X}'$ and $\mathcal{U}'$ – “mixed-up” versions of the augmented labelled and unlabelled batches. The coefficients of MixUp are sampled from a symmetric Beta distribution. The final loss is the sum of the categorical cross-entropy for images from $\mathcal{X}'$ (supervised part) and the Brier score for images from $\mathcal{U}'$ (unsupervised part), where $L$ denotes the number of classes:

$$\mathcal{L} = \frac{1}{|\mathcal{X}'|} \sum_{(x, p) \in \mathcal{X}'} H\big(p,\, p_{\mathrm{m}}(y \mid x; \theta)\big) \;+\; \lambda_u\, \frac{1}{L\,|\mathcal{U}'|} \sum_{(u, q) \in \mathcal{U}'} \big\lVert q - p_{\mathrm{m}}(y \mid u; \theta) \big\rVert_2^2$$
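The combined MixMatch objective (cross-entropy on the mixed labelled batch, Brier score, i.e. squared error on probabilities, on the mixed unlabelled batch) can be sketched as follows, assuming per-sample soft targets; the function name is ours.

```python
import torch
import torch.nn.functional as F

def mixmatch_loss(logits_x: torch.Tensor, targets_x: torch.Tensor,
                  logits_u: torch.Tensor, targets_u: torch.Tensor,
                  lambda_u: float) -> torch.Tensor:
    """Supervised soft cross-entropy on the mixed labelled batch plus
    Brier score (MSE on probabilities) on the mixed unlabelled batch."""
    # Soft cross-entropy: targets are probability distributions after MixUp.
    loss_x = -(targets_x * F.log_softmax(logits_x, dim=-1)).sum(dim=-1).mean()
    probs_u = torch.softmax(logits_u, dim=-1)
    loss_u = F.mse_loss(probs_u, targets_u)  # averaged over batch and classes
    return loss_x + lambda_u * loss_u
```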
MixMatch linearly ramps $\lambda_u$ up from 0 to its maximum after each batch to reduce the influence of the unsupervised part during early stages of training.
FixMatch is a simpler method. The unlabeled data batch $\mathcal{U} = \{u_b\}_{b=1}^{\mu B}$ is now $\mu$ times bigger than the labelled batch. Given the model's prediction $q_b = p_{\mathrm{m}}(y \mid \alpha(u_b); \theta)$ for a weakly augmented unlabelled sample $u_b$, the method yields the hard pseudo-label $\hat{q}_b = \arg\max_y q_b$. Afterwards, the model predicts labels for both a batch of weakly augmented labelled images and a batch of strongly augmented unlabelled images. Only the confident predictions for unlabelled samples, filtered with the threshold $\tau$, enter the unsupervised part of the loss. The loss of FixMatch is then the sum of two categorical cross-entropies for labelled and unlabelled images:

$$\mathcal{L} = \frac{1}{B} \sum_{b=1}^{B} H\big(p_b,\, p_{\mathrm{m}}(y \mid \alpha(x_b); \theta)\big) \;+\; \frac{\lambda_u}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}\big[\max_y q_b \geq \tau\big]\, H\big(\hat{q}_b,\, p_{\mathrm{m}}(y \mid \mathcal{A}(u_b); \theta)\big)$$
As the threshold $\tau$ already filters out unconfident pseudo-labels, no linear ramp-up for $\lambda_u$ is needed.
Appendix B Experiments
B.1 Transfer learning
We took a version of Wide ResNet-50-2 pre-trained on ImageNet from PyTorch.
The following hyperparameters were tuned:
- optimizer weight decay,
- layer freezing: fine-tuning vs. feature extraction (see Section 3.1).
Further hyperparameters are kept fixed; namely, we use the Adam optimizer  and train for 50 epochs. Additionally, early stopping with a patience of 25 epochs was applied to avoid overfitting.
B.2 MixMatch & FixMatch
Hyperparameter tuning for both SSL methods was two-fold: firstly, we tuned the more general parameters on 200 labelled samples with respect to the validation loss (see Table 4). Secondly, for each specific number of labels, we tuned the subset-size-dependent parameters.
The labelled batch size was the same for both algorithms, and the FixMatch-specific unlabelled-batch ratio and confidence threshold were kept fixed. We omit cosine learning rate decay.
Regarding the secondary tuning: as the number of labels grows, each epoch becomes proportionally longer. We therefore define the number of epochs via an inverse formula, so that the total number of labelled batches used during training stays approximately constant.
During the secondary tuning, we vary method-specific parameters (e.g. the unsupervised loss weight for MixMatch).
References

- (2020) AOCT-Net: a convolutional network automated classification of multiclass retinal diseases using spectral-domain optical coherence tomography images. Medical & Biological Engineering & Computing 58(1), pp. 41–53.
- (2019) ReMixMatch: semi-supervised learning with distribution alignment and augmentation anchoring. arXiv preprint arXiv:1911.09785.
- (2019) MixMatch: a holistic approach to semi-supervised learning. In Advances in Neural Information Processing Systems, pp. 5049–5059.
- (2020) Deep retinal diseases detection and explainability using OCT images. In International Conference on Image Analysis and Recognition, pp. 358–366.
- (2019) RandAugment: practical automated data augmentation with a reduced search space. arXiv preprint.
- (2020) A data-efficient approach for automated classification of OCT images using generative adversarial network. IEEE Sensors Letters 4(1), pp. 1–4.
- (2009) ImageNet: a large-scale hierarchical image database. In CVPR 2009.
- (2014) DeCAF: a deep convolutional activation feature for generic visual recognition. In International Conference on Machine Learning, pp. 647–655.
- (2014) Generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2672–2680.
- (2020) Retinal optical coherence tomography image classification with label smoothing generative adversarial network. Neurocomputing.
- (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
- (2018) Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell 172(5), pp. 1122–1131.
- (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- (2013) Pseudo-label: the simple and efficient semi-supervised learning method for deep neural networks. In Workshop on Challenges in Representation Learning, ICML, Vol. 3.
- (2019) Semi-supervised adversarial learning for diabetic retinopathy screening. In International Workshop on Ophthalmic Medical Image Analysis, pp. 60–68.
- (2018) Semi-supervised automatic segmentation of layer and fluid region in retinal optical coherence tomography images using adversarial learning. IEEE Access 7, pp. 3046–3061.
- (2018) Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence 41(8), pp. 1979–1993.
- (2019) RealMix: towards realistic semi-supervised deep learning algorithms. arXiv preprint arXiv:1912.08766.
- (1983) A method for solving the convex programming problem with convergence rate O(1/k^2). In Dokl. Akad. Nauk SSSR, Vol. 269, pp. 543–547.
- (2018) Realistic evaluation of deep semi-supervised learning algorithms. In Advances in Neural Information Processing Systems, pp. 3235–3246.
- (2014) Learning and transferring mid-level image representations using convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1717–1724.
- (2016) Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In Advances in Neural Information Processing Systems, pp. 1163–1171.
- (2019) Uncertainty guided semi-supervised segmentation of retinal layers in OCT images. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 282–290.
- (2020) FixMatch: simplifying semi-supervised learning with consistency and confidence. arXiv preprint arXiv:2001.07685.
- (2018) A survey on deep transfer learning. In International Conference on Artificial Neural Networks, pp. 270–279.
- (2017) Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results. In Advances in Neural Information Processing Systems, pp. 1195–1204.
- (2020) Classification of optical coherence tomography images using a capsule network. BMC Ophthalmology 20(1), pp. 1–9.
- (2020) Towards multi-center glaucoma OCT image screening with semi-supervised joint structure and function multi-task learning. Medical Image Analysis 63.
- (2020) AttenNet: deep attention based retinal disease classification in OCT images. In International Conference on Multimedia Modeling, pp. 565–576.
- (2019) Unsupervised data augmentation for consistency training. arXiv preprint.
- (2019) Retinopathy diagnosis using semi-supervised multi-channel generative adversarial network. In International Workshop on Ophthalmic Medical Image Analysis, pp. 182–190.
- (2014) How transferable are features in deep neural networks? In Advances in Neural Information Processing Systems, pp. 3320–3328.
- (2016) Wide residual networks. arXiv preprint arXiv:1605.07146.
- (2017) mixup: beyond empirical risk minimization. arXiv preprint arXiv:1710.09412.