High Resolution Medical Image Analysis with Spatial Partitioning
Medical images such as 3D computerized tomography (CT) scans and pathology images have hundreds of millions to billions of voxels/pixels. It is infeasible to train CNN models directly on such high resolution images because of the memory limitation of a single GPU/TPU, and naïve data and model parallelism approaches do not solve this problem. Existing image analysis approaches alleviate it by cropping or down-sampling input images, which leads to complicated implementations and sub-optimal performance due to information loss. In this paper, we implement spatial partitioning, which internally distributes the input and output of convolutional layers across GPUs/TPUs. Our implementation is based on the Mesh-TensorFlow framework, and the computation distribution is transparent to end users. With this technique, we train a 3D Unet on data with resolution up to 512×512×512. To the best of our knowledge, this is the first work to handle such high resolution images end-to-end.
Applying neural network models Ronneberger et al. (2015); Chen et al. (2017); Mobadersany et al. (2018); Lin et al. (2017) to high resolution image data, such as computerized tomography (CT) scans, satellite imagery, or histopathology slides, is computationally expensive. For example, CT scans are typically acquired with sub-millimeter resolution, resulting in image sizes of 512×512×512 voxels or more. Furthermore, medical images are being collected at ever higher resolutions Ruddle et al. (2016). Assuming that neural activations are stored as half-precision floating point numbers (2 bytes) and the batch size is 8, the output activations of a single convolutional layer with 64 filters on 512×512×512 inputs require more than 137GB of GPU/TPU memory. To handle such high resolution input, existing models use a combination of down-sampling, dividing, and/or coarse-to-fine schemes Hou et al. (2016); Biswas et al. (2019); Li et al. (2018); Vorontsov et al. (2018); Chlebus et al. (2018). An obvious drawback of these methods is that potentially useful information, such as contextual features, small pathological volumes, or high resolution details, is lost.
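The 137GB figure can be checked with a short calculation. The sketch below counts only the output activations of one 64-filter convolutional layer on a 512×512×512 volume, under the stated assumptions of 2-byte values and a batch size of 8. It ignores the layer's input, the parameters, and the extra memory kept for back-propagation, so actual usage would be higher; `activation_bytes` is an illustrative helper.

```python
# Output-activation memory of one convolutional layer on a cubic volume,
# using the assumptions in the text: 2-byte (half-precision) values,
# 64 filters, batch size 8, 512x512x512 input.
BYTES_PER_VALUE = 2
FILTERS = 64
BATCH = 8
SIDE = 512


def activation_bytes(side: int, filters: int, batch: int, bytes_per_value: int) -> int:
    """Bytes needed to store one layer's output activations."""
    return side ** 3 * filters * batch * bytes_per_value


per_image = activation_bytes(SIDE, FILTERS, 1, BYTES_PER_VALUE)
per_batch = activation_bytes(SIDE, FILTERS, BATCH, BYTES_PER_VALUE)
print(f"per image: {per_image / 1e9:.1f} GB")  # per image: 17.2 GB
print(f"per batch: {per_batch / 1e9:.1f} GB")  # per batch: 137.4 GB
```

The per-image value (about 17.2GB, i.e. 16GiB) is what remains on a single device when a batch is split only along the batch dimension.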
To scale up neural network models, data parallelism divides an input batch into sub-batches and distributes them across multiple GPUs/TPUs. Model parallelism Dean et al. (2012) splits model parameters, typically network layers, and distributes them across multiple GPUs/TPUs. These approaches do not solve the problem, since each GPU/TPU still needs to process at least one high resolution image, which results in more than 16GB (in some cases 32GB) of neural activations on a single GPU/TPU. More recent approaches Jia et al. (2018); Shazeer et al. (2018) can split the input batch along multiple dimensions in addition to the batch dimension and distribute the parts across GPUs/TPUs. However, it is not straightforward to split input images into non-overlapping patches and distribute them, since a convolution may take input from multiple patches. Splitting images into overlapping patches is not computationally efficient either, because for most network architectures the overlap is so large that almost the entire image is broadcast to every GPU/TPU. To overcome this, we bring spatial partitioning with halo exchange, as implemented in TensorFlow-TPU, to Mesh-TensorFlow. Halo exchange is a process during which GPU/TPU devices exchange data (patch margins) before convolution operations Dryden et al. (2019).
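To see why overlapping patches scale poorly, the sketch below computes the receptive field along one axis of a Unet encoder, using the standard recurrence r += (k - 1)·j, j *= s. It assumes 3×3×3 stride-1 convolutions and 2×2×2 stride-2 max-pooling (kernel sizes are not stated in the text) and 4 convolutions per block as in Sec. 3; `unet_encoder` is a hypothetical helper, and the decoder, which enlarges the receptive field further, is ignored.

```python
from typing import List, Tuple


def receptive_field(layers: List[Tuple[int, int]]) -> int:
    """Receptive field (voxels along one axis) of a stack of (kernel, stride) layers."""
    r, j = 1, 1  # receptive field and cumulative stride ("jump")
    for k, s in layers:
        r += (k - 1) * j
        j *= s
    return r


def unet_encoder(num_blocks: int, convs_per_block: int = 4, kernel: int = 3) -> List[Tuple[int, int]]:
    """Hypothetical encoder: each block is stride-1 convolutions plus a stride-2 max-pool."""
    layers: List[Tuple[int, int]] = []
    for _ in range(num_blocks):
        layers += [(kernel, 1)] * convs_per_block
        layers.append((2, 2))
    return layers


for blocks, side in [(3, 64), (4, 128), (5, 256), (6, 512)]:
    rf = receptive_field(unet_encoder(blocks))
    print(f"{side}^3 input, {blocks} blocks: encoder receptive field {rf} voxels")
```

Under these assumptions the 512×512×512 network's encoder alone sees 568 voxels along each axis, more than the whole volume, so a patch processed independently would need margins covering essentially the entire image.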
We evaluate our approach on the Liver Tumor Segmentation (LiTS) benchmark Bilic et al. (2019) data (131 CT abdominal scans with liver segmentations). Directly training neural network models on high resolution images is prone to overfitting. To address this problem, we propose a synthesis-based data augmentation method for this application.
Our contributions are:
- An open source framework for training neural network models on high resolution images (code: github.com/tensorflow/mesh; an example is under mesh/mesh_tensorflow/experimental). It has the following advantages:
  - It supports training and evaluation on both GPUs and TPUs.
  - It is highly efficient: spatial partitioning adds around 5% to the total training time.
  - To the best of our knowledge, we are the first to train neural networks on 512×512×512 resolution CT scans in an end-to-end fashion.
- A data augmentation method for the liver tumor segmentation task using CT scans.
2 Mesh-TensorFlow with Spatial Partitioning and Halo Exchange
Mesh-TensorFlow Shazeer et al. (2018) is a framework for large scale data and model parallelism. Given a collection of computational devices (e.g. 8 GPUs), an end user defines how to map data dimensions to these devices. For example, splitting a batch of data along its batch dimension into 8 sub-batches and sending each sub-batch to 1 GPU is 8-way data parallelism; splitting the parameters of a fully connected layer into 2 parts and sending each part to 4 GPUs is 2-way model parallelism. Multiple such mappings can be defined for a model, so data and model parallelism can happen simultaneously. Mesh-TensorFlow internally transfers data across computational devices when necessary.
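The idea of mapping tensor dimensions to devices can be illustrated without any framework. The plain-NumPy sketch below assumes a hypothetical 8-device mesh with two axes, "data" (size 4) and "model" (size 2), and a layout that splits the batch dimension across "data" and a fully connected layer's output dimension across "model", i.e. 4-way data parallelism combined with 2-way model parallelism. `MESH`, `local_shard`, and the dimension names are illustrative; this is not the Mesh-TensorFlow API.

```python
import numpy as np

# Hypothetical device mesh: 4 x 2 = 8 devices.
MESH = {"data": 4, "model": 2}


def local_shard(x, tensor_dims, layout, coords):
    """Slice of `x` held by the device at mesh coordinates `coords`.

    `layout` maps tensor-dimension names to mesh-axis names; dimensions not
    in `layout` are replicated on every device.
    """
    index = []
    for pos, dim in enumerate(tensor_dims):
        axis = layout.get(dim)
        if axis is None:
            index.append(slice(None))  # replicated dimension
        else:
            size = x.shape[pos] // MESH[axis]
            start = coords[axis] * size
            index.append(slice(start, start + size))
    return x[tuple(index)]


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 10))  # dims ("batch", "in")
w = rng.standard_normal((10, 6))   # dims ("in", "hidden")
layout = {"batch": "data", "hidden": "model"}

y = np.empty((16, 6))
for d in range(MESH["data"]):
    for m in range(MESH["model"]):
        coords = {"data": d, "model": m}
        x_loc = local_shard(x, ("batch", "in"), layout, coords)
        w_loc = local_shard(w, ("in", "hidden"), layout, coords)
        # "in" is not split, so each device computes its output block
        # without communicating with other devices.
        y_loc = x_loc @ w_loc
        r, c = y_loc.shape
        y[d * r:(d + 1) * r, m * c:(m + 1) * c] = y_loc

print(np.allclose(y, x @ w))  # True
```

If the contracted "in" dimension were also split, each device would hold only a partial sum and a cross-device reduction would be required; this is the kind of data transfer the framework inserts automatically.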
Mesh-TensorFlow has been used successfully to train giant language models such as the Transformer Vaswani et al. (2017). For image analysis tasks, high resolution images consume hundreds of gigabytes of GPU/TPU memory. To address this problem, one should split images along spatial dimensions to enable model parallelism in Mesh-TensorFlow. However, the current Mesh-TensorFlow framework does not support convolutional layers on spatially split images, due to the complexity and sliding-window nature of convolutions. Modeling a convolutional layer as many fully-connected layers in Mesh-TensorFlow would be prohibitively expensive.
One way to enable convolutional layers on spatially split images is to split the input into overlapping patches, but this is not computationally efficient, since the overlap can be very large when the network is deep. Instead, we propose to exchange the margins of patches across computational devices, pad each patch with the received margins, and then apply the convolution. We illustrate this process in Fig. 1. With this method, we are able to train a 3D Unet model Ronneberger et al. (2015); Çiçek et al. (2016) on 512×512×512 resolution CT scans.
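The exchange-pad-convolve step can be illustrated with a single-process NumPy simulation. The sketch partitions a 3D volume along its first axis across hypothetical "devices"; each patch receives a one-slice margin from each neighbor (the halo width for a 3×3×3 kernel), is padded with it, and is convolved with a valid convolution. The concatenated result matches an unpartitioned zero-padded convolution. The function names are illustrative, real devices would exchange margins through inter-device communication that this sketch does not model, and `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv3d_valid(x, kernel):
    """Naive 'valid' 3D cross-correlation, as computed by CNN layers."""
    windows = sliding_window_view(x, kernel.shape)
    return np.einsum("ijkabc,abc->ijk", windows, kernel)


def conv3d_same(x, kernel):
    """Single-device reference: zero 'same' padding, then a valid convolution."""
    h = kernel.shape[0] // 2
    return conv3d_valid(np.pad(x, h), kernel)


def conv3d_halo_exchange(x, kernel, num_devices):
    """Simulate spatial partitioning along axis 0 with halo exchange.

    Assumes an odd cubic kernel and patches at least as thick as the halo.
    """
    h = kernel.shape[0] // 2  # halo width
    patches = np.split(x, num_devices, axis=0)  # one patch per device
    zeros = np.zeros((h,) + x.shape[1:], dtype=x.dtype)
    outputs = []
    for i, patch in enumerate(patches):
        # Receive margins from neighboring devices; boundary devices get zeros.
        before = patches[i - 1][-h:] if i > 0 else zeros
        after = patches[i + 1][:h] if i < num_devices - 1 else zeros
        padded = np.concatenate([before, patch, after], axis=0)
        # Axes 1 and 2 are not partitioned, so they get ordinary zero padding.
        padded = np.pad(padded, ((0, 0), (h, h), (h, h)))
        outputs.append(conv3d_valid(padded, kernel))
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
volume = rng.standard_normal((8, 6, 5))
kernel = rng.standard_normal((3, 3, 3))
print(np.allclose(conv3d_halo_exchange(volume, kernel, num_devices=4),
                  conv3d_same(volume, kernel)))  # True
```

Because the exchange is repeated before every convolution, each device only ever receives a margin as wide as one kernel radius, instead of a margin that grows with network depth.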
Data augmentation via Synthesizing
As opposed to training on 2D slices or 2D+ “slabs” of 3D scans, training directly on full resolution CT data is prone to overfitting, since there are fewer instances and more features per instance. To alleviate this problem, we propose a simple yet effective data augmentation method. We assume that the segmentation ground truth of the liver and tumor is given for the training set. For each training image, we compute the intensity difference between tumor and non-tumor voxels in the liver region. We then “remove” the tumor by subtracting the computed intensity difference from the tumor voxels. Finally, we synthesize tumor volumes by adding the computed intensity difference to random 3D volumes inside the liver. The boundaries of synthetic tumors are blurred (Fig. 2). Without this data augmentation method, the “Dice per case” scores (evaluation metrics are detailed in Sec. 3) drop by at least 10%.
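The three augmentation steps can be sketched in NumPy as follows. The text does not specify the shape of the random 3D volumes or how the boundary is blurred, so this sketch assumes a sphere whose intensity weight falls off linearly over a few voxels; `tumor_intensity_difference`, `remove_tumor`, and `add_synthetic_tumor` are hypothetical helpers, not the authors' implementation.

```python
import numpy as np


def tumor_intensity_difference(image, liver, tumor):
    """Mean tumor intensity minus mean intensity of non-tumor liver voxels."""
    healthy = liver & ~tumor
    return float(image[tumor].mean() - image[healthy].mean())


def remove_tumor(image, tumor, diff):
    """'Remove' the tumor by subtracting the intensity difference from its voxels."""
    out = image.astype(np.float64)  # astype returns a copy
    out[tumor] -= diff
    return out


def add_synthetic_tumor(image, liver, diff, radius, rng, blur=2.0):
    """Add an assumed spherical tumor with a blurred boundary inside the liver.

    Returns the augmented image and its new binary tumor label.
    """
    candidates = np.argwhere(liver)  # voxel coordinates inside the liver
    center = candidates[rng.integers(len(candidates))]
    grid = np.indices(image.shape)   # shape (3, D, H, W)
    dist = np.sqrt(sum((g - c) ** 2 for g, c in zip(grid, center)))
    # Weight 1 inside the sphere, falling linearly to 0 over `blur` voxels,
    # and restricted to the liver.
    weight = np.clip((radius + blur - dist) / blur, 0.0, 1.0) * liver
    augmented = image + diff * weight
    label = weight >= 0.5
    return augmented, label


# Toy example: a uniform "liver" cube containing a darker "tumor".
rng = np.random.default_rng(0)
image = np.zeros((20, 20, 20))
liver = np.zeros(image.shape, dtype=bool)
liver[4:16, 4:16, 4:16] = True
tumor = np.zeros(image.shape, dtype=bool)
tumor[8:11, 8:11, 8:11] = True
image[liver] = 100.0
image[tumor] = 60.0

diff = tumor_intensity_difference(image, liver, tumor)  # -40.0
clean = remove_tumor(image, tumor, diff)                # liver is uniform again
augmented, new_label = add_synthetic_tumor(clean, liver, diff, radius=2.0, rng=rng)
```

When training on the augmented image, `new_label` replaces the original tumor mask, since the original tumor has been removed.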
Table 1: Dice scores on the LiTS Bilic et al. (2019) validation scans at different data resolutions.
|Model|Data resolution|Dice per case|Dice global|
|---|---|---|---|
|2D Unet|512×512×5|0.4072 ± 0.0081|0.6432 ± 0.0579|
|3D Unet|64×64×64|0.2513 ± 0.0024|0.5364 ± 0.0108|
|3D Unet|128×128×128|0.3589 ± 0.0147|0.6494 ± 0.0445|
|3D Unet|256×256×256|0.4359 ± 0.0126|0.5783 ± 0.1654|
|3D Unet|512×512×512|0.4547 ± 0.0475|0.7180 ± 0.0446|
3 Experimental Results
We use the Liver Tumor Segmentation (LiTS) benchmark Bilic et al. (2019) to evaluate our implementation. This dataset has 131 3D CT images. We randomly select 99 images for training and use the remaining 32 for validation. We train 3D Unet Çiçek et al. (2016) models at four 3D resolutions: 64×64×64, 128×128×128, 256×256×256, and 512×512×512. For the 64×64×64 resolution, we train a 3D Unet with 3 blocks (each block consists of 4 convolutional layers and 1 max-pooling layer) in the down-sampling (encoding) part. The up-sampling (decoding) part is symmetric to the down-sampling part. Each convolutional layer in the first block has 256 filters, and the number of filters doubles after each max-pooling layer. For the 128×128×128 resolution, we attach another block with 128 filters right after the input layer. For the 256×256×256 resolution, we further attach a block with 64 filters. Finally, for the 512×512×512 resolution, we attach a block with 32 filters. Thus, networks at different resolutions have similar receptive field sizes relative to the input size. Since state-of-the-art methods on the LiTS benchmark work on 2D+ “slabs” instead of 3D data, we also test a 2D Unet at 512×512×5 resolution. The 2D Unet has the same architecture as the 3D Unet, except that the 3D convolution and max-pooling layers are replaced with 2D ones. We apply the data augmentation described in Sec. 2.
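The resolution-dependent architectures described above can be summarized in a few lines. The sketch lists the encoder filter counts per block for each input resolution and the spatial size after the last max-pooling layer, assuming one 2× pooling per block as stated; `unet_config` is a hypothetical helper.

```python
from typing import List, Tuple


def unet_config(side: int) -> Tuple[List[int], int]:
    """Encoder filters per block and bottleneck side length for a cubic input.

    The 64^3 network has 3 blocks (256, 512, 1024 filters); each doubling of
    the input side prepends one block with half as many filters.
    """
    extra_blocks = {64: 0, 128: 1, 256: 2, 512: 3}[side]
    first = 256 // 2 ** extra_blocks
    num_blocks = 3 + extra_blocks
    filters = [first * 2 ** i for i in range(num_blocks)]
    bottleneck_side = side // 2 ** num_blocks  # one 2x max-pool per block
    return filters, bottleneck_side


for side in (64, 128, 256, 512):
    print(side, unet_config(side))
```

Every configuration ends at an 8×8×8 bottleneck, which is one way to see why the networks cover a similar fraction of their input.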
We train models on a cluster of TPUs; each TPU has 2 cores. For 64×64×64 and 128×128×128 data, we use 2-way data parallelism and 16-way spatial partitioning on a TPU pod of 4×4 TPUs. For 256×256×256 data, we use 2-way data parallelism and 128-way spatial partitioning on a TPU pod of 8×16 TPUs. For 512×512×512 data, we use 2-way data parallelism and 256-way spatial partitioning on a TPU pod of 16×16 TPUs. We use a batch size of 8 on 512×512×512 data and 16 on the other resolutions. We use the Adafactor optimizer Shazeer and Stern (2018) with a learning rate of 0.003, and the generalized Dice loss Sudre et al. (2017) as the loss function.
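The loss cited here (Sudre et al., 2017) is the generalized Dice loss, which weights each class by the inverse square of its reference volume so that small structures such as tumors are not swamped by large ones. Below is a forward-only NumPy sketch; whether the background class was included is not stated in the text, and the function name and `eps` smoothing term are illustrative choices.

```python
import numpy as np


def generalized_dice_loss(probs, onehot, eps=1e-6):
    """Generalized Dice loss, forward pass only.

    probs, onehot: arrays of shape (num_voxels, num_classes) holding predicted
    class probabilities and one-hot ground truth. Computes
    1 - 2 * sum_l w_l sum_n r_ln p_ln / sum_l w_l sum_n (r_ln + p_ln),
    with w_l = 1 / (sum_n r_ln)^2.
    """
    ref_volume = onehot.sum(axis=0)              # sum_n r_ln for each class l
    weights = 1.0 / (ref_volume ** 2 + eps)      # w_l
    intersect = (weights * (probs * onehot).sum(axis=0)).sum()
    union = (weights * (probs + onehot).sum(axis=0)).sum()
    return 1.0 - 2.0 * intersect / (union + eps)


labels = np.array([0, 0, 0, 1, 1, 2])
onehot = np.eye(3)[labels]
print(generalized_dice_loss(onehot, onehot))                  # close to 0
print(generalized_dice_loss(np.eye(3)[(labels + 1) % 3], onehot))  # close to 1
```

In training, `probs` would come from a softmax over the network output and the loss would be written in the framework's differentiable operations; the `eps` term in the weights makes a class absent from the reference receive a very large weight, which some implementations handle differently.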
We compute two Dice scores as our evaluation criteria: Dice per case and Dice global. For Dice per case, we compute the Dice score for each validation scan, then average the scores across all 32 validation scans. For Dice global, we combine all scans as if there were only one scan and compute a single Dice score on the whole volume Bilic et al. (2019). We observe that the Dice score fluctuates for the same network trained with different random initializations and different numbers of training iterations. Thus, we average the Dice scores of 4 models (2 random initializations, each evaluated at 2 different numbers of training iterations).
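The two metrics can differ sharply, as a small NumPy sketch shows. It uses two hypothetical scans: one with a large tumor segmented perfectly and one with a small tumor missed entirely. Dice per case averages 1 and 0, while Dice global pools all voxels and is dominated by the large tumor. Treating a scan where both masks are empty as a perfect score is an assumed convention, and the helper names are illustrative.

```python
import numpy as np


def dice(pred, truth):
    """Dice coefficient of two binary masks."""
    denom = pred.sum() + truth.sum()
    if denom == 0:
        return 1.0  # both masks empty: assumed to be perfect agreement
    return 2.0 * np.logical_and(pred, truth).sum() / denom


def dice_per_case(preds, truths):
    """Average of the per-scan Dice scores."""
    return float(np.mean([dice(p, t) for p, t in zip(preds, truths)]))


def dice_global(preds, truths):
    """Dice computed as if all scans were one volume."""
    inter = sum(np.logical_and(p, t).sum() for p, t in zip(preds, truths))
    total = sum(p.sum() + t.sum() for p, t in zip(preds, truths))
    return 2.0 * inter / total


# Scan A: 1000-voxel tumor, predicted perfectly. Scan B: 10-voxel tumor, missed.
truth_a = np.zeros(2000, dtype=bool)
truth_a[:1000] = True
pred_a = truth_a.copy()
truth_b = np.zeros(2000, dtype=bool)
truth_b[:10] = True
pred_b = np.zeros(2000, dtype=bool)

print(dice_per_case([pred_a, pred_b], [truth_a, truth_b]))  # 0.5
print(dice_global([pred_a, pred_b], [truth_a, truth_b]))    # about 0.995
```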
Results and discussion
From the results in Tab. 1, we conclude that higher resolution data generally yields better Dice scores: Dice per case increases monotonically with the 3D resolution, and the 512×512×512 model achieves the best Dice per case and Dice global. Note that these results are not comparable to the results on the LiTS challenge leaderboard, due to different evaluation metrics. The Dice global scores have large standard deviations. This is because a few CT scans in the validation set contain very large tumor volumes, so the predictions on these few scans dominate the Dice global score.
Our method is computationally efficient. In our experiments, the operations introduced by spatial partitioning (partitioning, reshaping, and halo exchange) together add around 5% to the total training time, while more than 75% of the total training time is spent on the forward and backward passes of convolutions.
It is challenging to train convolutional neural networks on high resolution medical images, since naïve data and model parallelism methods cannot effectively reduce the per-GPU/TPU memory requirements. We contributed a new Mesh-TensorFlow based framework that, given enough devices, can handle images of arbitrary size. To the best of our knowledge, we are the first to train neural networks on 512×512×512 resolution CT scans, and we do so without significant computational overhead.
References
- Bilic et al. (2019). The liver tumor segmentation benchmark (LiTS). arXiv preprint arXiv:1901.04056.
- Biswas et al. (2019). A smart system of 3D liver tumour segmentation. International Journal of Product Development 23 (2-3).
- Chen et al. (2017). DCAN: deep contour-aware networks for object instance segmentation from histology images. Medical Image Analysis.
- Chlebus et al. (2018). Deep learning based automatic liver tumor segmentation in CT with shape-based post-processing. Unpublished.
- Çiçek et al. (2016). 3D U-Net: learning dense volumetric segmentation from sparse annotation. In MICCAI.
- Dean et al. (2012). Large scale distributed deep networks. In NeurIPS.
- Dryden et al. (2019). Improving strong-scaling of CNN training by exploiting finer-grained parallelism. arXiv preprint arXiv:1903.06681.
- Hou et al. (2016). Patch-based convolutional neural network for whole slide tissue image classification. In CVPR.
- Jia et al. (2018). Beyond data and model parallelism for deep neural networks. arXiv preprint arXiv:1807.05358.
- Li et al. (2018). H-DenseUNet: hybrid densely connected UNet for liver and tumor segmentation from CT volumes. IEEE Transactions on Medical Imaging 37 (12).
- Lin et al. (2017). Focal loss for dense object detection. In ICCV.
- Mobadersany et al. (2018). Predicting cancer outcomes from histology and genomics using convolutional networks. Proceedings of the National Academy of Sciences 115 (13).
- Ronneberger et al. (2015). U-Net: convolutional networks for biomedical image segmentation. In MICCAI.
- Ruddle et al. (2016). The design and evaluation of interfaces for navigating gigapixel images in digital pathology. ACM Transactions on Computer-Human Interaction (TOCHI) 23 (1).
- Shazeer et al. (2018). Mesh-TensorFlow: deep learning for supercomputers. In NeurIPS.
- Shazeer and Stern (2018). Adafactor: adaptive learning rates with sublinear memory cost. arXiv preprint arXiv:1804.04235.
- Sudre et al. (2017). Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations. In Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support.
- Vaswani et al. (2017). Attention is all you need. In NeurIPS.
- Vorontsov et al. (2018). Liver lesion segmentation informed by joint liver segmentation. In International Symposium on Biomedical Imaging (ISBI).