Two-Stream CNN with Loose Pair Training for Multi-modal AMD Categorization
This paper studies automated categorization of age-related macular degeneration (AMD) given a multi-modal input, which consists of a color fundus image and an optical coherence tomography (OCT) image from a specific eye. Previous work relies on a traditional pipeline of feature extraction and classifier training that cannot be optimized jointly. By contrast, we propose a two-stream convolutional neural network (CNN) that is trained end-to-end. The CNN’s fusion layer is tailored to the need of fusing information from the fundus and OCT streams. To generate more multi-modal training instances, we introduce loose pair training, where a fundus image and an OCT image are paired based on class labels rather than eyes. Moreover, to visually interpret how the individual modalities contribute to a prediction, we extend the class activation mapping technique to the multi-modal scenario. Experiments on a real-world dataset collected from an outpatient clinic justify the viability of our proposal for multi-modal AMD categorization.
Keywords: AMD categorization, multi-modal, fundus, OCT, two-stream CNN
1 Introduction
This paper targets automated categorization of age-related macular degeneration (AMD). As a common macular disease among people over 50, AMD may cause blurred vision or even blindness if not treated in time [15]. Depending on whether the retina contains choroidal neovascularization, AMD is classified into two subcategories, i.e., dry AMD (non-neovascular) and wet AMD (neovascular) [5]. Since the two subcategories require different treatments, such a fine-grained classification is crucial. In clinical practice, color fundus photography and optical coherence tomography (OCT) are used by an ophthalmologist to assess the condition of an eye. Not surprisingly, the lack of experienced ophthalmologists has driven research towards automated AMD categorization based on fundus images, OCT images, or both.
The majority of previous works are based on a single modality, be it color fundus images capturing the posterior pole [1, 2, 3, 6] or OCT images [9, 10, 12, 13, 14]. Burlina et al. [1], for instance, employ a deep convolutional neural network (CNN) pretrained on ImageNet to extract visual features from fundus images and then train a linear SVM classifier. As for OCT-based methods, Lee et al. [12] train a VGG16 model to classify OCT images either as normal or as AMD. Since fundus images capture the state of the retinal plane while OCT images reflect the longitudinal section of the retina, the two modalities describe distinct aspects of the retina and can thus be complementary to each other. While jointly exploiting the two modalities seems natural, this direction is largely unexplored. To the best of our knowledge, Yoo et al. [16] make an initial attempt towards multi-modal AMD categorization. Given a pair of fundus and OCT images from a specific eye, the authors employ a VGG19 model pretrained on ImageNet to extract visual features from both images. The features are concatenated and used as input of a random forest classifier. Despite their encouraging result that the multi-modal method is better than its single-modal counterpart, some crucial questions remain open.
Note that both the VGG19 features and the classifier, i.e., random forest, used in [16] are suboptimal in the context of deep learning based visual categorization. The following questions arise. First, when the single-modal baseline is re-implemented with a state-of-the-art CNN, say ResNet [7], trained in an end-to-end manner, is the multi-modal method of [16] still better? If the answer is negative, a follow-up question is whether multi-modal AMD categorization can be performed end-to-end as well. Training a deep network with multi-modal input is nontrivial because, by definition, the number of paired multi-modal training instances is smaller than the number of single-modal training instances. Moreover, the method of [16] lacks the capability of interpreting how the individual modalities contribute to the final prediction.
Towards answering the above questions, we make contributions as follows.
- We propose a two-stream CNN specifically designed for multi-modal AMD categorization, see Fig. 1. Two-stream CNNs have been actively investigated in the context of video action recognition [4]. However, the fusion layer needs to be re-considered for the new task, not only for effectively combining the information from fundus and OCT images but also for visually interpreting their contributions.
- To attack the inadequacy of multi-modal training instances, we introduce Loose Pair Training, a simple sampling strategy that effectively increases the number of training instances.
- Experiments on real-world data collected from an outpatient clinic show the viability of the proposed method. The new method outperforms the state-of-the-art [16] by a large margin, i.e., 0.971 versus 0.826 in terms of overall accuracy, for multi-modal AMD categorization.
2 Our Method
Given a color fundus image $I_f$ and an OCT image $I_o$ taken from a specific eye, we aim to build a multi-modal CNN (MM-CNN) that takes the paired input and categorizes the eye’s condition into a specific class $\hat{c}$:
$$\hat{c} = \arg\max_{c} s_c(I_f, I_o), \quad (1)$$
where $s_c$ is a class-specific score produced by the network, as detailed next.
2.1 Multi-modal CNN
Network architecture. To handle the multi-modal input, we design a two-stream network as illustrated in Fig. 1. It consists of two symmetric branches, one processing the fundus image and the other processing the OCT image. Note that such an architecture resembles, to some extent, the two-stream network widely used for video action recognition [4]. The major difference is at which layer multi-modal fusion is performed. Feature maps generated by intermediate layers of a CNN preserve, to some extent, the spatial information of an input image. As the different streams of video data are spatially correlated, the state-of-the-art for video action recognition performs fusion by combining feature maps from the individual streams [4]. By contrast, as the fundus and OCT images are not spatially correlated, we opt to perform the fusion after the global average pooling (GAP) layer, which removes the spatial information by averaging each feature map into a single value.
For each branch, we use the convolutional blocks of ResNet-18 [7]. In principle, any other state-of-the-art CNN can be used here. We choose ResNet-18 as it has fewer parameters and thus requires less training data. Also, this CNN has been shown to be effective for other fundus image analysis tasks [11]. For an OCT image, we convert each of its pixels from grayscale to RGB by replicating the intensity across the three RGB channels. As such, the same architecture and initialization are applied to both branches.
Let $F = \{F_1, \ldots, F_n\}$ be the array of feature maps generated by the ResNet-18 module in the fundus branch, with $n = 512$. The spatial size of each feature map depends on the size of the input, and is $14 \times 14$ for an input size of $448 \times 448$. Given a specific feature map $F_k$, the value at a specific position $(i, j)$ is denoted $F_k(i, j)$. In a similar vein, we define the feature maps of the OCT branch as $O = \{O_1, \ldots, O_n\}$.
Our fusion layer is implemented by first feeding $F$ and $O$ separately into a GAP layer to obtain two vectors, denoted $f$ and $o$, respectively. The two vectors are then concatenated to form a vector $[f; o]$ which contains information from the two modalities. For classification, the combined vector is fed into a fully connected (FC) layer to produce a score for a specific class $c$, denoted $s_c$,
$$s_c = w_c^f \cdot f + w_c^o \cdot o + b_c, \quad (2)$$
where $w_c^f$ and $w_c^o$ are class-dependent weights parameterizing the FC layer and $b_c$ is a bias term. Classification as expressed in Eq. 1 is achieved by selecting the class with the maximum score.
Multi-modal class activation mapping for visual interpretation. As Eq. 2 shows, the classification score for a given class is additively contributed by both modalities. For a more intuitive interpretation, we leverage class activation mapping (CAM) [17], which reveals the (implicit) attention of a CNN on an input image. We compute the multi-modal version of CAMs as
$$\mathrm{CAM}_c^f(i, j) = \sum_{k=1}^{n} w_{c,k}^f \, F_k(i, j), \quad (3)$$
$$\mathrm{CAM}_c^o(i, j) = \sum_{k=1}^{n} w_{c,k}^o \, O_k(i, j), \quad (4)$$
where $w_{c,k}^f$ and $w_{c,k}^o$ denote the $k$-th elements of the weight vectors $w_c^f$ and $w_c^o$. According to Eqs. 3 and 4, $\mathrm{CAM}_c^f(i, j)$ and $\mathrm{CAM}_c^o(i, j)$ indicate the contribution of a specific position $(i, j)$ in the fundus and OCT feature maps, respectively. Consequently, the contribution of each modality can be visualized by overlaying the input image with the corresponding up-sampled CAM, see Fig. 2.
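The per-modality CAM computation is plain linear algebra over the branch feature maps and the FC weights. A minimal NumPy sketch (the function name and array layout are our assumptions):

```python
import numpy as np


def multimodal_cam(F, O, w_c):
    """Per-modality class activation maps for a class c.

    F, O : (n, h, w) arrays of feature maps from the fundus / OCT branch.
    w_c  : (2n,) FC weight row for class c; by construction of the fusion
           layer, the first n entries act on the fundus GAP vector and the
           last n entries on the OCT GAP vector.
    Returns (cam_f, cam_o), each an (h, w) map. The spatial mean of the two
    maps, plus the bias, recovers the classification score for c.
    """
    n = F.shape[0]
    w_f, w_o = w_c[:n], w_c[n:]
    cam_f = np.tensordot(w_f, F, axes=1)  # weighted sum over channels -> (h, w)
    cam_o = np.tensordot(w_o, O, axes=1)
    return cam_f, cam_o
```

Because GAP is a spatial mean, averaging each map reproduces the corresponding additive term of the class score, which is why the maps can be read as per-position contributions.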
2.2 Network Training
A conventional way to construct a multi-modal training instance is to strictly select a fundus image and an OCT image from the same eye, which we term strict pairing. By contrast, we construct instances based on labels instead of eyes. That is, a fundus image is allowed to be paired with any OCT image whose label is identical to its own. We coin this sampling strategy Loose Pairing. Such a strategy expands the size of the training set quadratically. Note that loose pairing is applied only to the training data.
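The loose pairing strategy can be sketched as follows; representing images as `(id, label)` tuples and the function name are our illustrative assumptions:

```python
from collections import defaultdict


def loose_pairs(fundus_items, oct_items):
    """Enumerate all label-consistent (fundus, OCT) training pairs.

    fundus_items / oct_items: lists of (image_id, label) tuples.
    A fundus image may be paired with ANY OCT image sharing its label, so
    each class contributes n_fundus * n_oct pairs rather than only the
    strictly eye-matched ones.
    """
    octs_by_label = defaultdict(list)
    for oct_id, label in oct_items:
        octs_by_label[label].append(oct_id)
    return [(f_id, o_id, label)
            for f_id, label in fundus_items
            for o_id in octs_by_label[label]]
```

For example, a class with 67 fundus and 33 OCT training images yields 67 × 33 loose pairs instead of at most 33 strict pairs.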
All fundus and OCT images are resized to 448 × 448. As the input size of the pretrained ResNet-18 model is 224 × 224, we adjust the kernel size of the GAP layer from 7 × 7 to 14 × 14 accordingly. Following [8], we enhance fundus images by contrast-limited adaptive histogram equalization. Meanwhile, median filtering is applied to OCT images for noise reduction. For image-level data augmentation, random rotation, cropping, flipping and random changes in brightness, saturation and contrast are performed on training images.
Our deep models are implemented in the PyTorch (version 1.0.0) framework. The ResNet-18 modules are initialized with weights pretrained on ImageNet. We use cross-entropy, a common loss function for multi-class classification. SGD with a momentum of 0.9 and a weight decay of 1e-4 is used as the optimizer. Each convolution layer is followed by batch normalization. No dropout is used. The model obtaining the best validation performance is selected.
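The training recipe above (cross-entropy loss, SGD with momentum 0.9 and weight decay 1e-4) can be set up in PyTorch as in the following sketch; the learning rate is not specified in the text, so the default here is a placeholder assumption:

```python
import torch
import torch.nn as nn


def make_training_objects(model, lr=0.001):
    """Loss and optimizer matching the described recipe.

    NOTE: lr is an assumed placeholder; momentum and weight decay follow
    the values stated in the text.
    """
    criterion = nn.CrossEntropyLoss()  # multi-class cross-entropy
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=0.9, weight_decay=1e-4)
    return criterion, optimizer
```

Model selection would then keep the checkpoint with the best validation performance, as the text describes.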
3 Experiments
3.1 Experimental Setup
Dataset for multi-modal AMD categorization. We collect 1,059 color fundus images from 1,059 distinct eyes at the outpatient clinic of the Department of Ophthalmology, Peking Union Medical College Hospital, that is, one fundus image per eye. 781 of the eyes are additionally associated with one to five OCT images, which are central B-scans manually selected by technicians. The fundus images were acquired with a Topcon fundus camera, while the OCT images came from a Topcon OCT camera and a Heidelberg OCT camera. For each eye, two ophthalmologists jointly classified its condition as normal, dryAMD or wetAMD, by examining the corresponding fundus image plus OCT, fluorescein angiography (FA) or indocyanine green angiography (ICGA) images, if applicable. Fundus and OCT images associated with a specific eye are assigned the same class.
In order to build a multi-modal test set, per class we select 20 eyes at random from the eyes that have both fundus and OCT images available. Such a setting allows us to justify the effectiveness of multi-modal input against its single-modal counterpart. Moreover, it enables a head-to-head comparison between the two single modalities, i.e., fundus versus OCT. In a similar vein, we construct a multi-modal validation set from the remaining data for model selection. All the rest is used for training. Table 1 shows data statistics.
Table 1. Data statistics. Per cell, the number of images is given, with the number of distinct eyes in parentheses.

| | Training fundus | Training OCT | Validation fundus | Validation OCT | Test fundus | Test OCT |
|---|---|---|---|---|---|---|
| normal | 155 (155) | 156 (155) | 20 (20) | 20 (20) | 20 (20) | 20 (20) |
| dryAMD | 67 (67) | 33 (22) | 20 (20) | 35 (20) | 20 (20) | 38 (20) |
| wetAMD | 717 (717) | 821 (484) | 20 (20) | 42 (20) | 20 (20) | 46 (20) |
Performance metrics. Per class we report three metrics, i.e., sensitivity, specificity and F1 score defined as the harmonic mean between sensitivity and specificity. For an overall comparison, the average F1 score over the three classes is used. In addition, we report accuracy, computed as the ratio of correctly classified instances (which are fundus or OCT images for single-modal CNNs and fundus-OCT pairs for MM-CNNs).
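The metrics as defined here (per-class sensitivity, specificity, their harmonic-mean F1, plus overall accuracy) can be computed from predictions as in the following NumPy sketch; the function name is ours:

```python
import numpy as np


def per_class_metrics(y_true, y_pred, classes):
    """Sensitivity, specificity, and their harmonic-mean F1 per class,
    plus overall accuracy, following the definitions in the text.
    Assumes every class in `classes` occurs in y_true."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    metrics = {}
    for c in classes:
        tp = np.sum((y_true == c) & (y_pred == c))
        fn = np.sum((y_true == c) & (y_pred != c))
        tn = np.sum((y_true != c) & (y_pred != c))
        fp = np.sum((y_true != c) & (y_pred == c))
        sens = tp / (tp + fn)        # sensitivity (recall for class c)
        spec = tn / (tn + fp)        # specificity
        f1 = 2 * sens * spec / (sens + spec)  # harmonic mean, per the text
        metrics[c] = (sens, spec, f1)
    accuracy = np.mean(y_true == y_pred)
    return metrics, accuracy
```

The overall F1 used for comparison is then the plain average of the three per-class F1 values.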
3.2 Experiment 1. Multi-modal versus Single-modal
Single-modal baselines. For single-modal models, we train two ResNet-18 models, one on the fundus images and the other on the OCT images. For ease of reference, we term the two models Fundus-CNN and OCT-CNN.
Results. As Table 2 shows, OCT-CNN is on par with MM-CNN-S, the variant trained on strict pairs. The result suggests that training an effective multi-modal model requires more training data. The proposed loose pair training strategy is effective, resulting in MM-CNN-L, which presents the best performance.
Comparing the two single-modal models, OCT-CNN is better than Fundus-CNN (0.942 versus 0.879 in terms of the overall F1). Confusion matrices are provided in the supplementary material. While the two single-modal CNNs recognize the normal class with ease, they tend to misclassify dryAMD as wetAMD. Such mistakes are reduced by MM-CNN-L. The above results justify the advantage of multi-modal models for AMD categorization.
| Method | normal Sen. | normal Spe. | normal F1 | dryAMD Sen. | dryAMD Spe. | dryAMD F1 | wetAMD Sen. | wetAMD Spe. | wetAMD F1 | Overall F1 | Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Yoo et al. [16] | 1.000 | 0.976 | 0.952 | 0.552 | 1.000 | 0.711 | 0.978 | 0.724 | 0.841 | 0.835 | 0.826 |
| Yoo et al. [16]-L | 1.000 | 0.988 | 0.975 | 0.763 | 0.954 | 0.828 | 0.913 | 0.844 | 0.866 | 0.890 | 0.875 |
3.3 Experiment 2. Comparison with the State-of-the-Art
Multi-modal baselines. As aforementioned, the only existing work on multi-modal AMD categorization is by Yoo et al. [16], where the authors employ a VGGNet pretrained on ImageNet to extract visual features from fundus and OCT images and then train a random forest classifier on strictly matched pairs. We therefore consider that work as our multi-modal baseline. As their data is not fully available, we replicate their method and evaluate it on our test set. For a fair comparison, we substitute ResNet-18 for VGGNet. Moreover, we investigate whether the proposed loose pair strategy is also beneficial for the baseline, so we train another random forest with loose pairs. We term this variant Yoo et al.-L.
Results. As Table 2 shows, MM-CNN-L outperforms the baseline by a large margin (0.975 versus 0.835 in terms of overall F1). The two single-modal baselines also outperform the method of Yoo et al. These results justify the necessity of end-to-end learning. The loose pair training strategy is found to be useful for the baseline as well, improving its overall F1 from 0.835 to 0.890.
4 Conclusions
Multi-modal AMD categorization experiments on a clinical dataset allow us to answer the questions raised in Section 1 as follows. When trained end-to-end, a single-modal CNN, in particular OCT-CNN, is a nontrivial baseline to beat. A multi-modal CNN recognizes dry AMD and wet AMD at a higher accuracy. This advantage is obtained by the proposed two-stream CNN with loose pair training.
Acknowledgments. This work was supported by NSFC (No. 61672523), the Fundamental Research Funds for the Central Universities and the Research Funds of Renmin University of China (No. 18XNLG19), and CAMS Initiative for Innovative Medicine (No. 2018-I2M-AI-001).
-  Burlina, P., Freund, D.E., Joshi, N., Wolfson, Y., Bressler, N.M.: Detection of age-related macular degeneration via deep learning. In: ISBI (2016)
-  Burlina, P.M., Joshi, N., Pekala, M., Pacheco, K.D., Freund, D.E., Bressler, N.M.: Automated grading of age-related macular degeneration from color fundus images using deep convolutional neural networks. JAMA Ophthalmology 135(11), 1170–1176 (2017)
-  Burlina, P., Pacheco, K.D., Joshi, N., Freund, D.E., Bressler, N.M.: Comparing humans and deep learning performance for grading AMD: A study in using universal deep features and transfer learning for automated amd analysis. Computers in Biology & Medicine 82, 80–86 (2017)
-  Feichtenhofer, C., Pinz, A., Zisserman, A.: Convolutional two-stream network fusion for video action recognition. In: CVPR (2016)
-  Ferris, F.L., Wilkinson, C.P., Bird, A., Chakravarthy, U., Chew, E., Csaky, K., Sadda, S.V.R.: Clinical classification of age-related macular degeneration. Ophthalmology 120(4), 844–851 (2013)
-  Grassmann, F., Mengelkamp, J., Brandl, C., Harsch, S., Zimmermann, M.E., Linkohr, B., Peters, A., Heid, I.M., Palm, C., Weber, B.H.: A deep learning algorithm for prediction of age-related eye disease study severity scale for age-related macular degeneration from color fundus photography. Ophthalmology 125(9), 1410–1420 (2018)
-  He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR (2016)
-  Jintasuttisak, T., Intajag, S.: Color retinal image enhancement by rayleigh contrast-limited adaptive histogram equalization. In: ICCAS (2014)
-  Karri, S.P.K., Chakraborty, D., Chatterjee, J.: Transfer learning based classification of optical coherence tomography images with diabetic macular edema and dry age-related macular degeneration. Biomedical Optics Express 8(2), 579–592 (2017)
-  Kermany, D.S., Goldbaum, M., Cai, W., Valentim, C.C.S., Liang, H., Baxter, S.L., Mckeown, A., Ge, Y., Wu, X., Yan, F.: Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell 172(5), 1122–1131.e9 (2018)
-  Lai, X., Li, X., Qian, R., Ding, D., Wu, J., Xu, J.: Four models for automatic recognition of left and right eye in fundus images. In: MMM (2019)
-  Lee, C.S., Baughman, D.M., Lee, A.Y.: Deep learning is effective for classifying normal versus age-related macular degeneration oct images. Ophthalmology Retina 1(4), 322–327 (2017)
-  Russakoff, D.B., Lamin, A., Oakley, J.D., Dubis, A.M., Sivaprasad, S.: Deep learning for prediction of amd progression: A pilot study. Investigative ophthalmology & visual science 60(2), 712–722 (2019)
-  Treder, M., Lauermann, J.L., Eter, N.: Automated detection of exudative age-related macular degeneration in spectral domain optical coherence tomography using deep learning. Graefe’s Archive for Clinical and Experimental Ophthalmology 256(2), 259–265 (2018)
-  Wong, W.L., Su, X., Li, X., Cheung, C.M.G., Klein, R., Cheng, C.Y., Wong, T.Y.: Global prevalence of age-related macular degeneration and disease burden projection for 2020 and 2040: a systematic review and meta-analysis. The Lancet Global Health 2(2), e106–e116 (2014)
-  Yoo, T.K., Choi, J.Y., Seo, J.G., Ramasubramanian, B., Selvaperumal, S., Kim, D.W.: The possibility of the combination of oct and fundus images for improving the diagnostic accuracy of deep learning for age-related macular degeneration: a preliminary experiment. Medical & Biological Engineering & Computing 57(3), 677–687 (2019)
-  Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., Torralba, A.: Learning deep features for discriminative localization. In: CVPR (2016)