High Resolution Medical Image Analysis with Spatial Partitioning

Le Hou  Youlong Cheng  Noam Shazeer  Niki Parmar  Yeqing Li
Panagiotis Korfiatis  Travis M. Drucker  Daniel J. Blezek  Xiaodan Song  
Google Brain, Mountain View, CA
Department of Computer Science, Stony Brook University, Stony Brook, NY
Department of Radiology, Mayo Clinic, Rochester MN
Enterprise Architecture, Mayo Clinic, Rochester MN

Medical images such as 3D computerized tomography (CT) scans and pathology images have hundreds of millions to billions of voxels/pixels. It is infeasible to train CNN models directly on such high resolution images because of the memory limitation of a single GPU/TPU, and naïve data and model parallelism approaches do not work. Existing image analysis approaches alleviate this problem by cropping or down-sampling input images, which leads to complicated implementations and sub-optimal performance due to information loss. In this paper, we implement spatial partitioning, which internally distributes the input and output of convolutional layers across GPUs/TPUs. Our implementation is based on the Mesh-TensorFlow framework, and the computation distribution is transparent to end users. With this technique, we train a 3D Unet on data of up to 512×512×512 resolution. To the best of our knowledge, this is the first work to handle such high resolution images end-to-end.

1 Introduction

Applying neural network models Ronneberger et al. (2015); Chen et al. (2017); Mobadersany et al. (2018); Lin et al. (2017) to high resolution image data, such as computerized tomography (CT) scans, satellite imagery, or histopathology slides, is computationally expensive. For example, CT scans are typically acquired at sub-millimeter resolution, resulting in image data sizes of 512×512×512 voxels or more. Furthermore, medical images are being collected at ever higher resolutions Ruddle et al. (2016). Assuming that neural activations are stored as half-precision floating point numbers (2 bytes) and the batch size is 8, a 1-layer Convolutional Neural Network (CNN) with 64 filters requires more than 137GB of GPU/TPU memory for its output activations alone. To handle such high resolution input, existing models use a combination of down-sampling, dividing, and/or coarse-to-fine schemes Hou et al. (2016); Biswas et al. (2019); Li et al. (2018); Vorontsov et al. (2018); Chlebus et al. (2018). An obvious drawback of these methods is that potentially useful information, such as contextual features, small pathological volumes, or high resolution details, is lost.
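The 137GB figure can be reproduced with a short calculation. The sketch below uses a hypothetical helper, `activation_bytes`, and counts only the output activations of a single layer in decimal gigabytes (10^9 bytes); a real training step also stores inputs, gradients and optimizer state, so this is a lower bound under the stated assumptions. The single-image number is consistent with the "more than 16GB" per device cited in the next paragraph.

```python
# Activation memory of one convolutional layer's output, using the
# assumptions stated above: 512x512x512 voxels, 64 filters, fp16, batch 8.
BYTES_PER_FP16 = 2


def activation_bytes(spatial_shape, num_filters, batch_size,
                     bytes_per_value=BYTES_PER_FP16):
    """Return the bytes needed to store one layer's output activations."""
    voxels = 1
    for extent in spatial_shape:
        voxels *= extent
    return voxels * num_filters * batch_size * bytes_per_value


per_batch = activation_bytes((512, 512, 512), num_filters=64, batch_size=8)
per_image = activation_bytes((512, 512, 512), num_filters=64, batch_size=1)
print(f"batch of 8:   {per_batch / 1e9:.1f} GB")  # 137.4 GB
print(f"single image: {per_image / 1e9:.1f} GB")  # 17.2 GB
```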

To scale up neural network models, the data parallelism approach divides an input batch of instances into sub-batches and distributes them across multiple GPUs/TPUs. Model parallelism approaches Dean et al. (2012) split and distribute model parameters, typically network layers, across multiple GPUs/TPUs. These approaches do not solve the problem, since each GPU/TPU still needs to process at least one full high resolution image, which results in more than 16GB of neural activations (in some cases 32GB) on a single GPU/TPU. More recent approaches Jia et al. (2018); Shazeer et al. (2018) can split the input batch along multiple dimensions in addition to the batch dimension, and distribute the parts across GPUs/TPUs. However, it is not straightforward to split input images into non-overlapping patches and distribute them, since a convolution may take input from multiple patches. Splitting images into overlapping patches is not computationally efficient either, because for most network architectures the overlap is so large that almost the entire image is broadcast to every GPU/TPU. To overcome this, we bring spatial partitioning with halo exchange, which is available in TensorFlow-TPU, to Mesh-TensorFlow. Halo exchange is a process in which GPU/TPU devices exchange data (patch margins) before convolution operations Dryden et al. (2019).

We evaluate our approach on data from the Liver Tumor Segmentation (LiTS) benchmark Bilic et al. (2019) (131 abdominal CT scans with liver and tumor segmentations). Directly training neural network models on high resolution images is prone to overfitting. To address this problem, we propose a synthesis-based data augmentation method for this application.

Our contributions are:

  1. An open source framework (code: github.com/tensorflow/mesh; example under mesh/mesh_tensorflow/experimental) for training neural network models on high resolution images. It has the following advantages:

    1. It supports training and evaluation on both GPUs and TPUs.

    2. It is highly efficient: spatial partitioning adds around 5% to the total training time.

    To the best of our knowledge, we are the first to train neural networks on 512×512×512 resolution CT scans in an end-to-end fashion.

  2. A data augmentation method for the liver tumor segmentation task using CT scans.

Figure 1: An illustration of spatial partitioning and halo exchange. (a) Spatial partitioning: we split very high resolution medical images into non-overlapping patches. Each computational device (GPU/TPU) processes one or more patches. (b) Halo exchange: before every convolution operation, devices exchange patch margins (half the size of the convolution kernel) with each other.

2 Mesh-TensorFlow with Spatial Partitioning and Halo Exchange

Mesh-TensorFlow Shazeer et al. (2018) is a framework for large scale data and model parallelism. Given a collection of computational devices (e.g. 8 GPUs), an end user defines how to map data dimensions to these devices. For example, splitting a batch of data along its batch dimension into 8 sub-batches and sending each sub-batch to one GPU is 8-way data parallelism; splitting the parameters of a fully connected layer into 2 parts and sending each part to 4 GPUs is 2-way model parallelism. Multiple such mappings can be defined for a model, so data and model parallelism can happen simultaneously. Mesh-TensorFlow internally transfers data across computational devices when necessary.
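To make the dimension-to-device mapping concrete, here is a minimal pure-Python sketch that computes which slice of a tensor each device on a logical mesh would hold. The function `device_slices`, the mesh dimension names `rows`/`cols`/`all`, and the tensor dimension names are hypothetical; the sketch only mirrors the idea described above and is not the Mesh-TensorFlow API. The example combines the two mappings from the text on 8 devices: a 4 x 2 mesh with the batch split 4-way and a layer's hidden units split 2-way, so each half of the weights is replicated on 4 devices.

```python
from itertools import product


def device_slices(tensor_shape, layout, mesh_shape):
    """Map every device on a logical mesh to the slice of a tensor it holds.

    tensor_shape: tensor dimension name -> size, e.g. {"batch": 8}.
    layout: tensor dimension name -> mesh dimension name; dimensions that are
        not listed are replicated on every device.
    mesh_shape: mesh dimension name -> number of devices along it.
    Returns {device_coordinate: {dimension: (start, stop)}}.
    """
    mesh_dims = list(mesh_shape)
    result = {}
    for coord in product(*(range(mesh_shape[m]) for m in mesh_dims)):
        position = dict(zip(mesh_dims, coord))
        slices = {}
        for dim, size in tensor_shape.items():
            mesh_dim = layout.get(dim)
            if mesh_dim is None:
                slices[dim] = (0, size)  # replicated: full extent on each device
                continue
            parts = mesh_shape[mesh_dim]
            if size % parts:
                raise ValueError(f"{dim}={size} is not divisible by {parts}")
            chunk = size // parts
            i = position[mesh_dim]
            slices[dim] = (i * chunk, (i + 1) * chunk)
        result[coord] = slices
    return result


# 8 devices arranged as a 4 x 2 mesh: batch split 4-way (data parallelism),
# hidden units split 2-way (model parallelism).
mesh = {"rows": 4, "cols": 2}
layout = {"batch": "rows", "hidden": "cols"}
activations = device_slices({"batch": 8, "hidden": 1024}, layout, mesh)
weights = device_slices({"in": 512, "hidden": 1024}, layout, mesh)
print(activations[(2, 1)])  # {'batch': (4, 6), 'hidden': (512, 1024)}
print(weights[(2, 1)])      # {'in': (0, 512), 'hidden': (512, 1024)}
```

Pure 8-way data parallelism from the text corresponds to `device_slices({"batch": 8}, {"batch": "all"}, {"all": 8})`, where each device receives one sub-batch.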

Mesh-TensorFlow has been used successfully to train giant language models such as the Transformer Vaswani et al. (2017). For image analysis tasks, however, high resolution images consume hundreds of gigabytes of GPU/TPU memory. To address this problem, images must be split along their spatial dimensions, which enables model parallelism in Mesh-TensorFlow. However, the current Mesh-TensorFlow framework does not support convolutional layers on spatially split images, due to the complexity and sliding-window nature of convolutions. Modeling a convolutional layer as many fully-connected layers in Mesh-TensorFlow is computationally prohibitive.

One way to enable convolutional layers on spatially split images is to split the input into overlapping patches, but this is not computationally efficient, since the required overlap can be very large when the network is deep. Instead, we propose to exchange the margins of patches across computational devices, pad each patch with the received margins, and then apply the convolution. We illustrate this process in Fig. 1. With this method, we are able to train a 3D Unet model Ronneberger et al. (2015); Çiçek et al. (2016) on 512×512×512 resolution CT scans.
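The exchange-pad-convolve procedure can be checked on a toy problem. The sketch below is a single-process, 1D stand-in for the 3D case: Python lists play the role of devices, and all function names are hypothetical. It assumes an odd kernel size and zero-padded "same" convolution, splits the signal into non-overlapping patches, pads each patch with half-kernel-wide margins from its neighbors, convolves each patch independently, and verifies that the concatenated result equals convolving the unsplit signal. It illustrates the technique only; it is not the Mesh-TensorFlow implementation.

```python
def conv1d_valid(signal, kernel):
    """'Valid' 1D cross-correlation: output is len(kernel) - 1 shorter."""
    n = len(signal) - len(kernel) + 1
    return [sum(k * signal[i + j] for j, k in enumerate(kernel))
            for i in range(n)]


def conv1d_same(signal, kernel):
    """Reference: 'same'-size convolution of the whole signal, zero padded."""
    half = len(kernel) // 2
    padded = [0.0] * half + list(signal) + [0.0] * half
    return conv1d_valid(padded, kernel)


def halo_exchange(patches, halo):
    """Pad each patch with `halo` values from its left and right neighbors.

    Patches on the image border receive zeros, matching zero padding."""
    padded = []
    for idx, patch in enumerate(patches):
        if idx > 0:
            prev = patches[idx - 1]
            left = prev[len(prev) - halo:]
        else:
            left = [0.0] * halo
        right = patches[idx + 1][:halo] if idx + 1 < len(patches) else [0.0] * halo
        padded.append(list(left) + list(patch) + list(right))
    return padded


def partitioned_conv(signal, kernel, num_devices):
    """Split spatially, exchange halos, convolve each patch, then concatenate."""
    if len(kernel) % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel size")
    size, rem = divmod(len(signal), num_devices)
    halo = len(kernel) // 2  # half the kernel size, as in Fig. 1(b)
    if rem or halo > size:
        raise ValueError("patches must be equal-sized and at least halo wide")
    patches = [signal[d * size:(d + 1) * size] for d in range(num_devices)]
    outputs = [conv1d_valid(p, kernel) for p in halo_exchange(patches, halo)]
    return [value for out in outputs for value in out]


signal = [float(i % 7) for i in range(24)]
kernel = [1.0, -2.0, 0.5]
assert partitioned_conv(signal, kernel, num_devices=4) == conv1d_same(signal, kernel)
```

Each device only ever receives half-kernel-wide margins from its immediate neighbors, so per-layer communication stays small regardless of network depth, unlike the overlap needed for precomputed overlapping patches.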

Figure 2: Given a real 3D CT scan with its ground truth mask (example 2D slice shown in (a)), we generate new training examples with corresponding masks (example 2D slices shown in (b)).

Data augmentation via synthesis

As opposed to training on 2D slices or 2D+ “slabs” of 3D scans, training directly on full resolution CT data is prone to overfitting, since there are fewer training instances and more features per instance. To alleviate this problem, we propose a simple yet effective data augmentation method. We assume that segmentation ground truth of the liver and tumors is given for the training set. For each training image, we compute the intensity difference between tumor and non-tumor voxels in the liver region. We then “remove” tumors by subtracting the computed intensity difference from the tumor voxels. Finally, we synthesize tumor volumes by adding the computed intensity difference to random 3D volumes inside the liver. Boundaries around synthetic tumors are blurred (Fig. 2). Without this data augmentation method, the “Dice per case” scores (evaluation metrics are detailed in Sec. 3) drop by at least 10%.
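A minimal sketch of the three augmentation steps follows, on a toy volume stored as a dict from (z, y, x) to intensity. The text does not specify the shape of the synthetic volumes, how boundaries are blurred, or how many tumors are added, so this sketch assumes a single sphere of given `radius` whose edge ramps linearly to zero over `blur` voxels, and labels voxels with ramp weight at least 0.5 as synthetic tumor. The function `synthesize_tumor` and all parameters are hypothetical.

```python
import math
import random


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def synthesize_tumor(image, liver, tumor, radius, blur, rng):
    """Remove real tumors and paste one synthetic tumor inside the liver.

    image: dict mapping (z, y, x) -> intensity.
    liver, tumor: sets of (z, y, x) voxels (tumor is a subset of liver).
    radius, blur: size of the synthetic sphere and width of its soft edge.
    Returns (new_image, new_tumor_mask); the inputs are not modified.
    """
    # 1. Intensity difference between tumor and non-tumor liver voxels.
    delta = mean(image[v] for v in tumor) - mean(image[v] for v in liver - tumor)
    new_image = dict(image)
    # 2. "Remove" real tumors by subtracting the difference.
    for v in tumor:
        new_image[v] -= delta
    # 3. Add the difference back on a random sphere inside the liver, with a
    #    linear ramp (an assumed blurring scheme) over the last `blur` voxels.
    center = rng.choice(sorted(liver))
    new_mask = set()
    for v in liver:
        weight = min(1.0, max(0.0, (radius - math.dist(v, center)) / blur))
        if weight > 0.0:
            new_image[v] += weight * delta
        if weight >= 0.5:
            new_mask.add(v)
    return new_image, new_mask


# Toy 10^3 volume: liver cube with intensity 100, darker 2^3 tumor at 60.
grid = [(z, y, x) for z in range(10) for y in range(10) for x in range(10)]
liver = {v for v in grid if all(2 <= c < 8 for c in v)}
tumor = {v for v in liver if all(3 <= c < 5 for c in v)}
image = {v: 60.0 if v in tumor else 100.0 if v in liver else 0.0 for v in grid}
new_image, new_mask = synthesize_tumor(image, liver, tumor, radius=2.0,
                                       blur=1.0, rng=random.Random(0))
```

Only voxels inside the liver mask are modified, so the synthetic tumor never leaks into surrounding tissue.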

Model  Data resolution  Dice per case      Dice global
2D     512×512×5        0.4072 +/- 0.0081  0.6432 +/- 0.0579
3D     64×64×64         0.2513 +/- 0.0024  0.5364 +/- 0.0108
3D     128×128×128      0.3589 +/- 0.0147  0.6494 +/- 0.0445
3D     256×256×256      0.4359 +/- 0.0126  0.5783 +/- 0.1654
3D     512×512×512      0.4547 +/- 0.0475  0.7180 +/- 0.0446
Table 1: Validation results on the LiTS benchmark Bilic et al. (2019). For each setting, we report the mean +/- standard deviation of the Dice scores of 4 models. Note that the results are not directly comparable to the results on the LiTS challenge leaderboard, due to different evaluation metrics.

3 Experimental Results

We use the Liver Tumor Segmentation (LiTS) benchmark Bilic et al. (2019) to evaluate our implementation. This dataset has 131 3D CT images. We randomly select 99 images for training and use the remaining 32 for validation. We train 3D Unet Çiçek et al. (2016) models at four 3D resolutions: 64, 128, 256, and 512. For the 64 resolution, we train a 3D Unet with 3 blocks (each block consists of 4 convolutional layers and 1 max-pooling layer) in the down-sampling (encoding) part. The up-sampling (decoding) part is symmetric to the down-sampling part. Each convolutional layer in the first block has 256 filters, and the number of filters doubles after each max-pooling layer. For the 128 resolution, we attach another block with 128 filters right after the input layer. For the 256 resolution, we further attach a block with 64 filters. Finally, for the 512 resolution, we attach a block with 32 filters. Thus networks at different resolutions have similar receptive field sizes relative to the input. Since state-of-the-art methods on the LiTS benchmark work on 2D+ “slabs” instead of 3D data, we also test a 2D Unet at 512×512×5 resolution. The 2D Unet has the same architecture as the 3D Unet, except that the 3D convolution/max-pooling layers are replaced by 2D ones. We apply the data augmentation described in Sec. 2.
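The architecture scheme above can be tabulated programmatically, along with a rough receptive-field estimate. The text does not state kernel or pooling sizes, so the sketch assumes 3×3×3 convolutions with stride 1 and 2×2×2 max-pooling with stride 2, and it only counts the encoder (the decoder would enlarge the receptive field further). Both function names are hypothetical.

```python
def encoder_filters(resolution):
    """Filters per encoder block for the 3D Unets described above.

    The 64^3 model has 3 blocks starting at 256 filters, doubling after each
    max-pooling; each doubling of resolution prepends a block with half as
    many filters as the previous first block.
    """
    extra_blocks = (resolution // 64).bit_length() - 1
    first = 256 >> extra_blocks
    return [first << i for i in range(3 + extra_blocks)]


def encoder_receptive_field(num_blocks, convs_per_block=4, kernel=3, pool=2):
    """Receptive field (in voxels, per axis) of the encoder only.

    Assumes kernel x kernel x kernel convolutions with stride 1 followed by a
    pool x pool x pool max-pooling with stride `pool` in every block.
    """
    rf, jump = 1, 1  # jump = input-voxel distance between adjacent outputs
    for _ in range(num_blocks):
        for _ in range(convs_per_block):
            rf += (kernel - 1) * jump
        rf += (pool - 1) * jump
        jump *= pool
    return rf


for res in (64, 128, 256, 512):
    filters = encoder_filters(res)
    rf = encoder_receptive_field(len(filters))
    print(f"{res}^3: filters={filters} receptive field={rf} ({rf / res:.2f}x input)")
# receptive fields: 64, 136, 280, 568 voxels
```

Under these assumptions, each prepended block roughly doubles the receptive field, which tracks the doubling of resolution and keeps the receptive field at about the size of the input. It also illustrates the Introduction's point about overlapping patches: at 512×512×512, the encoder alone already spans more than 512 voxels per axis, so overlap-based partitioning would send nearly the whole volume to every device.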

We train models on a cluster of TPUs, each with 2 cores. For the 64 and 128 resolutions, we use 2-way data parallelism and 16-way spatial partitioning on a TPU pod of 4×4 TPUs. For 256 data, we use 2-way data parallelism and 128-way spatial partitioning on a TPU pod of 8×16 TPUs. For 512 data, we use 2-way data parallelism and 256-way spatial partitioning on a TPU pod of 16×16 TPUs. We use a batch size of 8 on 512 data, and 16 on other resolutions. We use the Adafactor optimizer Shazeer and Stern (2018) with a learning rate of 0.003, and the generalized Dice loss Sudre et al. (2017) as the loss function.
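For reference, the generalized Dice loss of Sudre et al. (2017) weights each class by the inverse square of its ground-truth volume, so small structures such as tumors are not dominated by background. The sketch below evaluates that formula on a flattened volume in pure Python; it is not the training code. The 3-class labeling (background, liver, tumor), the function name, and the `eps` smoothing term are assumptions for illustration.

```python
def generalized_dice_loss(probs, labels, num_classes, eps=1e-6):
    """Generalized Dice loss (Sudre et al., 2017) on a flattened volume.

    probs: per-voxel lists of class probabilities (e.g. softmax outputs).
    labels: per-voxel integer ground-truth classes.
    Each class is weighted by 1 / (number of its ground-truth voxels)^2.
    """
    numerator = 0.0
    denominator = 0.0
    for c in range(num_classes):
        ref = [1.0 if y == c else 0.0 for y in labels]
        pred = [p[c] for p in probs]
        weight = 1.0 / (sum(ref) ** 2 + eps)  # eps guards absent classes
        numerator += weight * sum(r * q for r, q in zip(ref, pred))
        denominator += weight * sum(r + q for r, q in zip(ref, pred))
    return 1.0 - 2.0 * numerator / (denominator + eps)


# Hypothetical 3-class example: 0 = background, 1 = liver, 2 = tumor.
labels = [0, 0, 0, 1, 1, 2]
perfect = [[1.0 if c == y else 0.0 for c in range(3)] for y in labels]
print(generalized_dice_loss(perfect, labels, num_classes=3))  # close to 0
```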

Evaluation metrics

We compute two Dice scores as our evaluation criteria: Dice per case and Dice global. For Dice per case, we compute the Dice score of each validation image, then average the scores across all 32 validation scans. For Dice global, we combine all scans as if there were only one scan and compute a Dice score on the whole volume Bilic et al. (2019). We observe that the Dice score fluctuates for the same network trained with different random initializations and different numbers of training iterations. Thus, we average the Dice scores of 4 models (2 random initializations, each evaluated at 2 different numbers of training iterations).
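The difference between the two metrics can be seen on a tiny example. In the sketch below, each scan's prediction and ground truth are sets of voxel indices; the two scans are made up for illustration, and returning 1.0 when both sets are empty is an assumed convention rather than one stated in the text.

```python
def dice(pred, truth):
    """Dice score 2|P & T| / (|P| + |T|) for sets of voxel indices."""
    total = len(pred) + len(truth)
    return 1.0 if total == 0 else 2.0 * len(pred & truth) / total


def dice_per_case(cases):
    """Average of per-scan Dice scores; cases is a list of (pred, truth)."""
    return sum(dice(p, t) for p, t in cases) / len(cases)


def dice_global(cases):
    """Dice over all scans pooled together as if they were a single volume."""
    intersection = sum(len(p & t) for p, t in cases)
    total = sum(len(p) + len(t) for p, t in cases)
    return 1.0 if total == 0 else 2.0 * intersection / total


# Hypothetical scans: one large tumor segmented well, one small tumor missed.
large = (set(range(100, 1100)), set(range(1000)))  # overlap 900 of 1000
small = (set(range(10, 20)), set(range(10)))       # no overlap
cases = [large, small]
print(dice_per_case(cases))  # 0.45
print(dice_global(cases))    # about 0.891, dominated by the large tumor
```

Because Dice global pools voxels, a few scans with very large tumors can dominate it, which is consistent with the large standard deviations of Dice global discussed below.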

Results and discussion

From the results in Tab. 1, we conclude that higher resolution data generally yields better Dice scores: Dice per case increases monotonically with 3D resolution, and the 512×512×512 model achieves the best Dice per case and Dice global scores. Note that the results are not comparable to the results on the LiTS challenge leaderboard, due to different evaluation metrics. The Dice global scores have large standard deviations. This is because a few CT scans in the validation set contain very large tumor volumes, and hence the predictions on these few scans dominate the Dice global score.

Computational efficiency

Our method is computationally efficient. In our experiments, the operations introduced by spatial partitioning (partitioning, reshaping, and halo exchange) together add around 5% to the total training time. In addition, more than 75% of the total training time is spent on the forward and backward passes of convolutions.

4 Conclusions

It is challenging to train convolutional neural networks on high resolution medical images, since naïve data and model parallelism methods cannot effectively reduce the per-GPU/TPU memory requirements. We contribute a new Mesh-TensorFlow-based framework that is capable of handling images of any size. To the best of our knowledge, we are the first to train neural networks on 512×512×512 resolution CT scans, and we do so without significant computational overhead.


References

  • P. Bilic, P. F. Christ, E. Vorontsov, G. Chlebus, H. Chen, Q. Dou, C. Fu, X. Han, P. Heng, J. Hesser, et al. (2019) The liver tumor segmentation benchmark (LiTS). arXiv preprint arXiv:1901.04056.
  • A. Biswas, P. Bhattacharya, and S. P. Maity (2019) A smart system of 3D liver tumour segmentation. International Journal of Product Development 23 (2-3).
  • H. Chen, X. Qi, L. Yu, Q. Dou, J. Qin, and P. Heng (2017) DCAN: deep contour-aware networks for object instance segmentation from histology images. Medical Image Analysis.
  • G. Chlebus, A. Schenk, J. H. Moltz, B. van Ginneken, H. K. Hahn, and H. Meine (2018) Deep learning based automatic liver tumor segmentation in CT with shape-based post-processing. Unpublished.
  • Ö. Çiçek, A. Abdulkadir, S. S. Lienkamp, T. Brox, and O. Ronneberger (2016) 3D U-Net: learning dense volumetric segmentation from sparse annotation. In MICCAI.
  • J. Dean, G. Corrado, R. Monga, K. Chen, M. Devin, M. Mao, A. Senior, P. Tucker, K. Yang, Q. V. Le, et al. (2012) Large scale distributed deep networks. In NeurIPS.
  • N. Dryden, N. Maruyama, T. Benson, T. Moon, M. Snir, and B. Van Essen (2019) Improving strong-scaling of CNN training by exploiting finer-grained parallelism. arXiv preprint arXiv:1903.06681.
  • L. Hou, D. Samaras, T. M. Kurc, Y. Gao, J. E. Davis, and J. H. Saltz (2016) Patch-based convolutional neural network for whole slide tissue image classification. In CVPR.
  • Z. Jia, M. Zaharia, and A. Aiken (2018) Beyond data and model parallelism for deep neural networks. arXiv preprint arXiv:1807.05358.
  • X. Li, H. Chen, X. Qi, Q. Dou, C. Fu, and P. Heng (2018) H-DenseUNet: hybrid densely connected UNet for liver and tumor segmentation from CT volumes. IEEE Transactions on Medical Imaging 37 (12).
  • T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár (2017) Focal loss for dense object detection. In ICCV.
  • P. Mobadersany, S. Yousefi, M. Amgad, D. A. Gutman, J. S. Barnholtz-Sloan, J. E. V. Vega, D. J. Brat, and L. A. Cooper (2018) Predicting cancer outcomes from histology and genomics using convolutional networks. Proceedings of the National Academy of Sciences 115 (13).
  • O. Ronneberger, P. Fischer, and T. Brox (2015) U-Net: convolutional networks for biomedical image segmentation. In MICCAI.
  • R. A. Ruddle, R. G. Thomas, R. Randell, P. Quirke, and D. Treanor (2016) The design and evaluation of interfaces for navigating gigapixel images in digital pathology. ACM Transactions on Computer-Human Interaction (TOCHI) 23 (1).
  • N. Shazeer, Y. Cheng, N. Parmar, D. Tran, A. Vaswani, P. Koanantakool, P. Hawkins, H. Lee, M. Hong, C. Young, et al. (2018) Mesh-TensorFlow: deep learning for supercomputers. In NeurIPS.
  • N. Shazeer and M. Stern (2018) Adafactor: adaptive learning rates with sublinear memory cost. arXiv preprint arXiv:1804.04235.
  • C. H. Sudre, W. Li, T. Vercauteren, S. Ourselin, and M. J. Cardoso (2017) Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations. In Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NeurIPS.
  • E. Vorontsov, A. Tang, C. Pal, and S. Kadoury (2018) Liver lesion segmentation informed by joint liver segmentation. In International Symposium on Biomedical Imaging (ISBI).