The increasing efficiency and compactness of deep learning architectures, together with hardware improvements, have enabled the complex and high-dimensional modelling of medical volumetric data at higher resolutions. Recently, Vector-Quantised Variational Autoencoders (VQ-VAE) have been proposed as an efficient generative unsupervised learning approach that can encode images to a small percentage of their initial size, while preserving their decoded fidelity. Here, we show that a VQ-VAE-inspired network can efficiently encode a full-resolution 3D brain volume, compressing the data to 0.825% of the original size while maintaining image fidelity, and significantly outperforming the previous state-of-the-art. We then demonstrate that VQ-VAE decoded images preserve the morphological characteristics of the original data through voxel-based morphometry and segmentation experiments. Lastly, we show that such models can be pre-trained and then fine-tuned on different datasets without the introduction of bias.
Full Paper, MIDL 2020 submission (under review for MIDL 2020)
Neuromorphologically-preserving Volumetric data encoding using VQ-VAE
Petru-Daniel Tudosiu, Thomas Varsavsky, Richard Shaw, Mark Graham, Parashkev Nachev, Sébastien Ourselin, Carole H. Sudre, M. Jorge Cardoso
School of Biomedical Engineering & Imaging Sciences, King's College London, London, United Kingdom
Department of Medical Physics & Biomedical Engineering, University College London, London, United Kingdom
Institute of Neurology, University College London, London, United Kingdom
Keywords: 3D, MRI, Morphology, Encoding, VQ-VAE
It is well known that Convolutional Neural Networks (CNN) excel at a myriad of computer vision (CV) tasks such as segmentation [29, 18], depth estimation [25] and classification [37]. Such success stems from a combination of dataset size, improved compute capabilities, and associated software, making them ideal for tackling problems with 2D input imaging data. Medical data, however, is limited in availability due to privacy and cost, and its volumetric nature increases model complexity, making its modelling non-trivial. Unsupervised generative models of 2D imaging data have recently shown excellent results on classical computer vision datasets [35, 7, 28, 15, 10], with some promising early results on 2D medical imaging data [4, 14, 21, 22]. However, approaches on 3D medical data have shown limited sample fidelity [19, 9, 39], and have yet to demonstrate that they are morphology preserving.
From an architectural and modelling point of view, Generative Adversarial Networks (GAN) [12] are known to have a wide range of caveats that hinder both their training and reproducibility. Convergence issues caused by problematic generator-discriminator interactions, mode collapse that results in a very limited variety of samples, and vanishing gradients due to non-optimal discriminator performance are some of the known problems with this technique. Variational Autoencoders (VAE) [17], on the other hand, can mitigate some convergence issues but are known to have problems reconstructing high-frequency features, thus resulting in low-fidelity samples. The Vector Quantised-VAE (VQ-VAE) [35] was introduced by van den Oord et al. with the aim of improving VAE sample fidelity while avoiding the mode collapse and convergence issues of GANs. VQ-VAEs replace the VAE Gaussian prior with a vector quantization procedure that limits the dimensionality of the encoding to the number of atomic elements in a dictionary, which is learned either via gradient propagation or an exponential moving average (EMA). Due to the lack of an explicit Gaussian prior, sampling of VQ-VAEs can be achieved through the use of PixelCNNs [31] on the dictionary's elements.
In this work we propose a 3D VQ-VAE-inspired model that successfully reconstructs high-fidelity, full-resolution, and neuro-morphologically correct brain images. We adapt the VQ-VAE to 3D inputs and introduce SubPixel Deconvolutions [33] to address grid-like reconstruction artifacts. The network is built from FixUp blocks [38], allowing us to train it stably without the batch normalization issues caused by small batch sizes. We also test two losses, one inspired by [5] and a 3D adaptation of [3]. Lastly, to demonstrate that decoded samples are morphologically preserved, we run a voxel-based morphometry (VBM) analysis on a control-vs-Alzheimer's disease (AD) task using both original and decoded data, and demonstrate high Dice and volumetric similarities between segmentations of the original and decoded data.
2.1 Model Architecture
The original VAE is composed of an encoder network that models a posterior $p(z|x)$ of the random variable $z$ given the input $x$, a prior distribution $p(z)$, which is usually assumed to be $\mathcal{N}(0, I)$, and a distribution $p(x|z)$ over the input data via a decoder network. A VQ-VAE replaces the VAE's posterior with an embedding space $e \in \mathbb{R}^{K \times D}$, with $K$ being the number of atomic elements in the space and $D$ the dimension of each atomic element $e_i \in \mathbb{R}^{D}$, $i \in \{1, 2, \dots, K\}$. After the encoder network projects an input $x$ to a latent representation $z_e(x)$, each feature depth vector (the features corresponding to each voxel) is quantized via nearest neighbour look-up through the shared embedding space $e$. The posterior distribution can be seen as a categorical one defined as follows:
$$q(z = k \mid x) = \begin{cases} 1 & \text{for } k = \arg\min_j \lVert z_e(x) - e_j \rVert_2 \\ 0 & \text{otherwise} \end{cases}$$
This can be seen through the prism of a VAE whose ELBO bounds $\log p(x)$ and whose KL divergence is constant and equal to $\log K$, given that the prior distribution is a uniform categorical one over the $K$ atomic elements. The embedding space is learned via backpropagation or an exponential moving average (EMA).
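The nearest-neighbour look-up described above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the network's implementation: the function name `quantize`, the codebook size (256 atoms of dimension 8) and the tiny latent grid are assumptions chosen only to keep the example small.

```python
import numpy as np

def quantize(z_e, codebook):
    """Nearest-neighbour look-up of each feature vector in the codebook.

    z_e:      (..., D) encoder output, one D-dimensional vector per voxel
    codebook: (K, D) embedding space
    Returns the integer codes (...,) and the quantized vectors (..., D).
    """
    flat = z_e.reshape(-1, z_e.shape[-1])                        # (N, D)
    # Squared Euclidean distance between every vector and every atom: (N, K)
    d2 = ((flat[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    codes = d2.argmin(axis=1)                                    # (N,)
    z_q = codebook[codes]                                        # (N, D)
    return codes.reshape(z_e.shape[:-1]), z_q.reshape(z_e.shape)

rng = np.random.default_rng(0)
codebook = rng.normal(size=(256, 8))   # hypothetical K = 256, D = 8
z_e = rng.normal(size=(3, 4, 3, 8))    # a tiny 3x4x3 latent grid
codes, z_q = quantize(z_e, codebook)
```

Each voxel is then represented by a single integer index, which is what makes the posterior categorical and deterministic.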
We modified the VQ-VAE to work with 3D data as shown in Figure 1. Firstly, all 2D convolutional blocks were replaced with 3D blocks due to the 3D nature of the input data and of the information that we want to encode. To limit the dimensionality of the model and of the representation, the VQ blocks were only introduced at the 48×64×48, 12×16×12 and 3×4×3 resolutions. The number of features was doubled after each strided convolution block, starting from the maximum number of features that would fit on a GPU at full resolution. Our residual blocks are based on the FixUp initialization [38], which circumvents any possible interference from batch statistics that are too noisy because memory constraints force small batches. Furthermore, we use transposed convolutions with a kernel of 4 and ICNR initialization [1], followed by an average pooling layer with a kernel of 2 and stride of 1, for upsampling the activations. The last upsampling layer uses a subpixel convolution [33, 1, 34] in order to counter the checkerboard artifacts that transposed convolutions can generate.
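The rearrangement step of a subpixel convolution can be sketched in NumPy as follows. This is an illustration only, not the network's implementation: the function name `voxel_shuffle`, the channel ordering (a direct 3D analogue of the 2D pixel shuffle) and the upscale factor are assumptions. In a full layer, a regular convolution would first produce the `C * r**3` channels that are rearranged here.

```python
import numpy as np

def voxel_shuffle(x, r):
    """Rearrange channels into space: (C * r^3, D, H, W) -> (C, D*r, H*r, W*r).

    Output voxel (d*r + i, h*r + j, w*r + k) of channel c is read from input
    channel c*r^3 + i*r^2 + j*r + k at location (d, h, w).
    """
    c_r3, d, h, w = x.shape
    c = c_r3 // r ** 3
    if c * r ** 3 != c_r3:
        raise ValueError("channel count must be divisible by r**3")
    x = x.reshape(c, r, r, r, d, h, w)
    x = x.transpose(0, 4, 1, 5, 2, 6, 3)    # (c, d, i, h, j, w, k)
    return x.reshape(c, d * r, h * r, w * r)

x = np.arange(8 * 2 * 2 * 2, dtype=float).reshape(8, 2, 2, 2)
y = voxel_shuffle(x, r=2)                   # one channel, 4x4x4 volume
```

Because every output voxel is written exactly once, the rearrangement itself cannot produce the uneven overlap that causes checkerboard patterns in strided transposed convolutions.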
The current architecture means that input data is compressed to a representation that is only 3.3% of the original image in terms of number of variables. More specifically, the representation is composed of three levels: the top one is 48×64×48×2, the middle one is 12×16×12×8 and the bottom one is 3×4×3×32. Note, however, that the quantized parameterization is encoded as an 8-bit integer while the original input is a 32-bit float, making the bit-wise compression rate 0.825% of the original size. The higher resolution codes are conditioned on the immediately lower resolution one, which encourages them not to learn the same information.
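The compression figures can be checked with a short calculation. The sketch below reads the three levels as 48×64×48×2, 12×16×12×8 and 3×4×3×32 variables and assumes an input grid of 192×256×192 voxels; the input size is not stated in this section and is an assumption chosen because it reproduces the reported 3.3%.

```python
from math import prod

# Assumed input grid (not stated in the text).
input_shape = (192, 256, 192)

# Latent levels: three spatial dimensions times a trailing factor.
levels = [(48, 64, 48, 2), (12, 16, 12, 8), (3, 4, 3, 32)]

n_input = prod(input_shape)
n_codes = sum(prod(level) for level in levels)

variable_ratio = n_codes / n_input            # ratio in number of variables
bit_ratio = (n_codes * 8) / (n_input * 32)    # 8-bit codes vs 32-bit floats

print(f"variables: {100 * variable_ratio:.2f}%  bits: {100 * bit_ratio:.3f}%")
```

Under these assumptions the ratio in variables is about 3.33%, and the bit-wise ratio is a quarter of that; the quoted 0.825% corresponds to the rounded 3.3% divided by four.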
2.2 Loss Functions
For the image reconstruction loss, we propose to use the loss function from [5]. The mathematical formulation is:
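As an illustration only, a reconstruction loss of this family can be sketched as below. The sketch assumes that the loss sums L1 and L2 terms on the voxel intensities with L1 and L2 terms on finite-difference image gradients, all equally weighted; the choice of terms, the weights and the function names are assumptions, not the formulation used in the experiments.

```python
import numpy as np

def finite_differences(v):
    """Forward differences along each axis of a volume."""
    return [np.diff(v, axis=a) for a in range(v.ndim)]

def intensity_gradient_loss(x, x_hat):
    """Hypothetical L1 + L2 loss on intensities and on image gradients."""
    diff = x - x_hat
    loss = np.abs(diff).mean() + (diff ** 2).mean()
    for gx, gy in zip(finite_differences(x), finite_differences(x_hat)):
        g = gx - gy
        loss += np.abs(g).mean() + (g ** 2).mean()
    return float(loss)

rng = np.random.default_rng(0)
x = rng.random((8, 10, 8))
loss_same = intensity_gradient_loss(x, x)
loss_shift = intensity_gradient_loss(x, x + 0.5)   # constant intensity offset
```

A constant offset leaves the gradient terms at zero, so only the intensity terms contribute (0.5 + 0.25 in this example), which shows how the gradient terms focus on edges rather than on global brightness.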
Due to the non-ideal nature of the L1 and L2 losses from a fidelity point of view, we explore the use of an adaptive loss as described in [3]. We extend this adaptive loss to work on single-channel 3D volumes as input rather than 2D images. This loss automatically adapts itself during training by learning a shape parameter (alpha) and a scale parameter for each output dimension, so that it is able to vary smoothly between a family of loss functions (Cauchy, Geman-McClure, L1, L2, etc.).
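The family of losses referred to above can be illustrated with the general robust loss on which the adaptive loss is built. The sketch below implements only the per-element loss for fixed shape `alpha` and scale `c`; the learning of these parameters during training is not shown, and the function name is an assumption.

```python
import numpy as np

def general_robust_loss(x, alpha, c):
    """General robust loss rho(x; alpha, c) for a fixed shape and scale.

    alpha = 2 gives a scaled L2 loss, alpha = 1 a smoothed L1 loss,
    alpha = 0 the Cauchy loss and alpha = -2 the Geman-McClure loss.
    """
    z = (np.asarray(x, dtype=float) / c) ** 2
    if alpha == 2.0:                       # limit as alpha -> 2
        return 0.5 * z
    if alpha == 0.0:                       # limit as alpha -> 0
        return np.log1p(0.5 * z)
    b = abs(alpha - 2.0)
    return (b / alpha) * ((z / b + 1.0) ** (alpha / 2.0) - 1.0)

residuals = np.linspace(-3.0, 3.0, 7)
smooth_l1 = general_robust_loss(residuals, alpha=1.0, c=1.0)
geman_mcclure = general_robust_loss(residuals, alpha=-2.0, c=1.0)
```

Lower values of `alpha` penalise large residuals less, which is what lets the loss move between L2-like and outlier-robust behaviour.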
As demonstrated in the original paper, the best results using the adaptive loss rely on some image representation besides pixels, for example, the discrete cosine transform (DCT) or wavelet representation. To apply this to 3D MRI volumes, we take the voxel-wise reconstruction errors and compute 3D DCT decompositions of them, placing the adaptive loss on each of the output image dimensions. 3D DCTs are simple to compute, as the transform is separable and reduces to a standard 1D DCT applied along each of the three dimensions. This works much better than using the raw voxel representation, as the DCT representation avoids some of the issues associated with requiring a perfectly aligned output space, and it models gradients instead of voxel intensities, which tends to work better in practice. The codebook losses are exactly the ones used in the original VQ-VAE 2 paper [28] and implemented within Sonnet (https://github.com/deepmind/sonnet).
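The separability argument can be made concrete with a small NumPy sketch that applies an orthonormal 1D DCT-II along each axis in turn. The function names and the choice of the orthonormal DCT-II variant are assumptions; an optimised library routine would normally be used instead of explicit matrices.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix of shape (n, n)."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    m = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    m[0, :] /= np.sqrt(2.0)
    return m

def dct3(volume):
    """3D DCT computed as three successive 1D DCTs, one per axis."""
    out = volume
    for axis in range(3):
        m = dct_matrix(out.shape[axis])
        # Contract the transform with the current axis, then restore axis order.
        out = np.moveaxis(np.tensordot(m, out, axes=([1], [axis])), 0, axis)
    return out

rng = np.random.default_rng(0)
x = rng.random((4, 6, 5))
x_hat = x + 0.01 * rng.normal(size=x.shape)
coeffs = dct3(x - x_hat)     # representation on which a loss could be placed
```

Because the transform is orthonormal, the energy of the residual is preserved in the coefficients; only the way it is distributed across frequencies changes.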
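$$\mathcal{L}_{codebook} = \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2$$
where $\mathrm{sg}$ refers to a stop-gradient operation that blocks gradients from flowing through its argument and $\beta$ weights the commitment term. As per [35, 28], the term that updates the codebook (the one in which gradients reach $e$) was replaced with an exponential moving average for faster convergence:
$$N_i^{(t)} := \gamma N_i^{(t-1)} + (1 - \gamma)\, n_i^{(t)}, \qquad m_i^{(t)} := \gamma m_i^{(t-1)} + (1 - \gamma) \sum_{j=1}^{n_i^{(t)}} z_{i,j}^{(t)}, \qquad e_i^{(t)} := \frac{m_i^{(t)}}{N_i^{(t)}}$$
where $\gamma$ is the decay parameter, $e_i$ is the codebook element, $z_{i,j}$ are the features to be quantized that are assigned to $e_i$, and $n_i$ is the number of such vectors in the minibatch. The code will be available at the time of publication on GitHub.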
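One step of an exponential-moving-average codebook update can be sketched as follows. This is a minimal NumPy illustration: the function name, the default decay value and the small floor that guards against division by zero for unused atoms are assumptions, not details taken from the implementation used here.

```python
import numpy as np

def ema_codebook_update(counts, sums, z_e, codes, gamma=0.99, floor=1e-8):
    """One EMA step for a codebook of K atoms.

    counts: (K,)   running count of vectors assigned to each atom
    sums:   (K, D) running sum of the vectors assigned to each atom
    z_e:    (n, D) encoder vectors in the minibatch
    codes:  (n,)   index of the atom each vector was assigned to
    """
    k = counts.shape[0]
    one_hot = np.eye(k)[codes]                        # (n, K)
    batch_counts = one_hot.sum(axis=0)                # vectors per atom
    batch_sums = one_hot.T @ z_e                      # (K, D)
    counts = gamma * counts + (1.0 - gamma) * batch_counts
    sums = gamma * sums + (1.0 - gamma) * batch_sums
    codebook = sums / np.maximum(counts, floor)[:, None]
    return codebook, counts, sums

z_e = np.array([[1.0, 1.0], [3.0, 3.0], [10.0, 0.0]])
codes = np.array([0, 0, 1])
codebook, counts, sums = ema_codebook_update(
    np.zeros(2), np.zeros((2, 2)), z_e, codes, gamma=0.0
)
```

With a decay of zero, as in the usage above, each atom becomes the mean of the vectors assigned to it in the current minibatch; larger decay values smooth this estimate over many minibatches.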
3 Dataset and Preprocessing
The networks are first trained on a dataset of T1 structural images labeled as controls from the ADNI 1, GO, 2 [27, 6], OASIS [20] and Barcelona [32] studies. We skull-stripped the images by first generating a binary mask of the brain using GIF [8], then blurring the mask to guarantee a smooth transition from the background to the brain area, and then superimposing the original brain mask to guarantee that the brain is properly extracted. Following that, we registered the extracted brain to MNI space. Due to the memory requirements of the baseline network [19], images were also resampled to 3mm isotropic to test how the different methods work at this resolution. We set aside 10% of the data for testing, totaling 1581 subjects for the training dataset and 176 subjects for the testing dataset. The images were robust min-max scaled, and no spatial augmentations were applied during training as the images were MNI aligned. For fine-tuning, we used the Alzheimer's Disease (AD) patients from the ADNI 1, GO, 2 [27, 6] datasets. The preprocessing and data split are identical to those of the control subjects dataset. For training we have 1085 subjects, while for testing we have 121 subjects.
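Robust min-max scaling can be sketched as below. The text does not give the exact recipe, so the percentile bounds (1st and 99th), the clipping to [0, 1] and the function name are all assumptions made for illustration.

```python
import numpy as np

def robust_min_max(volume, lower=1.0, upper=99.0):
    """Scale intensities to [0, 1] using percentiles instead of min and max.

    Percentile bounds make the scaling insensitive to a few extreme voxels.
    """
    lo, hi = np.percentile(volume, [lower, upper])
    return np.clip((volume - lo) / (hi - lo), 0.0, 1.0)

intensities = np.arange(101, dtype=float)
intensities[-1] = 1e6                  # a single extreme outlier
scaled = robust_min_max(intensities)
```

With a plain min-max scaling the outlier would squeeze every other value close to zero; with percentile bounds the bulk of the intensities still spans the full range.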
4 Experiments and Results
Table 1: Network | Net Res | Met Res | Tr Mode | MS-SSIM | log(MMD) | Dice WM | Dice GM | Dice CSF
4.1 Model Training Details
Our models were run on NVIDIA Tesla V100 32GB GPUs. The networks were implemented using NiftyNet [11]. The chosen optimizer was Adam [16] combined with the SGDR learning rate scheduler [23] and a starting learning rate of . Depending on the combination of architecture and loss function, we set the batch size to the maximum allowed by GPU memory (ranging from a batch size of 32 for Kwon et al., to 512 for the proposed method at low resolution, and 8 for the proposed method at full resolution). The Control Normal models were trained for 7 days, and the fine-tuning with pathological data was run for an additional 4 days. For a fair comparison, we also trained models on the pathological data from scratch for the same amount of time as the fine-tuning. The best baseline model that we found is [19]. The authors propose an approach based on a VAE combined with an α-Wasserstein GAN with gradient penalty, and encode brain volumes of size 64×64×64 to a one-dimensional tensor of length 1000. In our experiments we compress images to the same extent. Both of our methods use the ADNI dataset [6].
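The SGDR schedule (cosine annealing with warm restarts) can be sketched in plain Python as follows. The peak learning rate, the length of the first cycle and the cycle multiplier used here are hypothetical values chosen for illustration; they are not the settings used in the experiments.

```python
import math

def sgdr_learning_rate(step, lr_max, lr_min=0.0, first_cycle=1000, cycle_mult=2):
    """Cosine annealing with warm restarts.

    The rate decays from lr_max to lr_min over a cycle, then jumps back to
    lr_max; each new cycle is cycle_mult times longer than the previous one.
    """
    cycle_len, pos = first_cycle, step
    while pos >= cycle_len:
        pos -= cycle_len
        cycle_len *= cycle_mult
    cosine = 0.5 * (1.0 + math.cos(math.pi * pos / cycle_len))
    return lr_min + (lr_max - lr_min) * cosine

# Hypothetical peak learning rate of 1e-4, sampled every 250 steps.
schedule = [sgdr_learning_rate(s, lr_max=1e-4) for s in range(0, 3000, 250)]
```

The periodic restarts let the optimizer leave sharp minima early in training while still ending each cycle with a small step size.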
Table 1 details the quantitative results on image reconstruction fidelity. To provide a comparison with [19], we measured Maximum Mean Discrepancy (MMD) [13] and Multi-Scale Structural Similarity (MS-SSIM) [36]. Significant improvements in fidelity were observed with the proposed method, both at 3mm and at full resolution. As a proxy for the neuromorphological correctness of the reconstructions, we measured the Dice [24] overlap between segmentations of Grey Matter (GM), White Matter (WM) and Cerebrospinal Fluid (CSF) in the ground truth and reconstructed volumes. The segmentations were extracted from the unified normalisation and segmentation step of the Voxel Based Morphometry (VBM) [2] pipeline of Statistical Parametric Mapping [26] version 12. All metrics were calculated over the test cases. Again, the proposed methods achieved statistically significant (Wilcoxon signed-rank test) improvements in Dice scores over the α-WGAN baseline, with the best score alternating between the Baur and Adaptive loss functions.
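The Dice overlap between two binary segmentations can be computed as in the sketch below. The function name and the convention of returning 1.0 when both masks are empty are assumptions; the tissue masks in the experiments come from SPM, which is not reproduced here.

```python
import numpy as np

def dice(mask_a, mask_b):
    """Dice overlap 2|A and B| / (|A| + |B|) between two binary masks."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0                      # both masks empty: treat as perfect
    return 2.0 * np.logical_and(a, b).sum() / total

original = np.zeros((4, 4, 4), dtype=bool)
original[:2] = True                     # 32 voxels
reconstructed = np.zeros((4, 4, 4), dtype=bool)
reconstructed[1:3] = True               # 32 voxels, 16 shared with original
score = dice(original, reconstructed)
```

In this toy case half of each mask overlaps with the other, so the score is 0.5; identical masks give 1.0 and disjoint masks give 0.0.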
It can clearly be seen that our VQ-VAE model combined with the 3D Adaptive loss achieves the best performance in all three training modes. Interestingly, the model trained with the Baur loss consistently performs better on MS-SSIM than the adaptive one. This might be attributed to the fact that the reconstructions of the adaptive model look as if they have been passed through a total variation (TV) filter [30], which could interfere with the fine texture of white matter. This indicates the need for future research into a hybrid loss function that works better at the texture level, possibly combining the adaptive loss with image gradient losses. Even though the Dice scores are excellent, which suggests good neuromorphometry, we refer the reader to the VBM analysis that follows for a more in-depth analysis, since the SPM implementation is known to be highly robust.
VBM was performed to test for differences in morphology caused by the reconstruction process. Figure 3 displays the VBM analysis of the grey matter of the AD patient subgroup, comparing the original data and the reconstructed data. Ideally, if no morphological changes are introduced, the t-map should be empty. Results show that the method by Kwon et al. produces large t-values, while the proposed method produces maps with significantly lower values (closer to zero), corroborating the hypothesis that the proposed method is more morphologically preserving.
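The idea that an empty t-map indicates no systematic morphological change can be illustrated with a voxel-wise paired t-statistic. This is a greatly simplified stand-in for the SPM analysis: it assumes a paired design between each original and its reconstruction, and it omits smoothing, covariates and multiple-comparison correction. The function name is hypothetical.

```python
import numpy as np

def paired_t_map(original, reconstructed):
    """Voxel-wise paired t-statistic between two stacks of volumes.

    original, reconstructed: arrays of shape (n_subjects, X, Y, Z).
    Voxels with zero variance in the differences are given t = 0.
    """
    diff = original - reconstructed
    n = diff.shape[0]
    mean = diff.mean(axis=0)
    sem = diff.std(axis=0, ddof=1) / np.sqrt(n)
    t = np.zeros_like(mean)
    np.divide(mean, sem, out=t, where=sem > 0)
    return t

rng = np.random.default_rng(0)
gm = rng.random((5, 2, 2, 2))                        # toy grey matter maps
t_identical = paired_t_map(gm, gm)                   # perfect reconstruction
t_biased = paired_t_map(gm, gm - 0.1 + 0.01 * rng.normal(size=gm.shape))
```

A perfect reconstruction yields a map of zeros, whereas a reconstruction with a consistent bias across subjects yields large t-values even when the bias itself is small.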
Lastly, in Figure 4, we looked at the t-test maps between AD patients and healthy control (HC) subjects at the original resolution, using the original data (labeled in the figure as ground truth) and then again using the reconstructed data for each method. In contrast to Figure 3, the best performing model is the VQ-VAE with Adaptive loss, where similar t-map clusters are observed, with low t-map residuals throughout. This means the proposed VQ-VAE Adaptive model was able to better learn the population statistics even though its MS-SSIM was found to be marginally lower than that of the VQ-VAE with Baur loss. The discrepancy might be due to the structural nature of the brain, which is enforced by the sharper boundaries between tissues in the TV-filter-like reconstructions, in comparison with the blurrier output of the VQ-VAE with Baur loss seen in Figure 2.
In this paper, we introduced a novel vector-quantisation variational autoencoder architecture which is able to encode a full-resolution 3D brain MRI to 0.825% of its original size whilst maintaining image fidelity and image structure. A higher multi-scale structural similarity index and a lower maximum mean discrepancy showed that our proposed method outperformed the existing state-of-the-art in terms of image consistency metrics. We compared segmentations of white matter, grey matter and cerebrospinal fluid in the original image and in the reconstructions, showing improved performance. Additionally, VBM was employed to further study the morphological differences both within original and reconstructed pathological populations, and between pathological and control ones for each method. The results confirmed that both variants of our VQ-VAE method preserve the anatomical structure of the brain better than previously published GAN-based approaches when looking at healthy brains and those with Alzheimer's disease. We hope that this paper will encourage further advances in 3D reconstruction and generative 3D models of medical imaging.
[1] (2017) Checkerboard artifact free sub-pixel convolution: a note on sub-pixel convolution, resize convolution and convolution resize. arXiv preprint arXiv:1707.02937.
[2] (2000) Voxel-based morphometry - the methods. NeuroImage 11 (6), pp. 805–821.
[3] (2019) A general and adaptive robust loss function. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4331–4339.
[4] (2019) Image synthesis with a convolutional capsule generative adversarial network. In Proceedings of The 2nd International Conference on Medical Imaging with Deep Learning, M. J. Cardoso, A. Feragen, B. Glocker, E. Konukoglu, I. Oguz, G. Unal and T. Vercauteren (Eds.), Proceedings of Machine Learning Research, Vol. 102, London, United Kingdom, pp. 39–62.
[5] (2019) Fusing unsupervised and supervised deep learning for white matter lesion segmentation. In Proceedings of The 2nd International Conference on Medical Imaging with Deep Learning, M. J. Cardoso, A. Feragen, B. Glocker, E. Konukoglu, I. Oguz, G. Unal and T. Vercauteren (Eds.), Proceedings of Machine Learning Research, Vol. 102, London, United Kingdom, pp. 63–72.
[6] (2015) The Alzheimer's disease neuroimaging initiative phase 2: increasing the length, breadth, and depth of our understanding. Alzheimer's & Dementia 11 (7), pp. 823–831.
[7] (2018) Large scale gan training for high fidelity natural image synthesis. arXiv preprint arXiv:1809.11096.
[8] (2015) Geodesic information flows: spatially-variant graphs and their application to segmentation and fusion. IEEE transactions on medical imaging 34 (9), pp. 1976–1988.
[9] (2018) Predicting aging of brain metabolic topography using variational autoencoder. Frontiers in aging neuroscience 10, pp. 212.
[10] (2019) Large scale adversarial representation learning. In Advances in Neural Information Processing Systems, pp. 10541–10551.
[11] (2018) NiftyNet: a deep-learning platform for medical imaging. Computer Methods and Programs in Biomedicine.
[12] (2014) Generative adversarial nets. In Advances in Neural Information Processing Systems 27, Z. Ghahramani, M. Welling, C. Cortes, N. D. Lawrence and K. Q. Weinberger (Eds.), pp. 2672–2680.
[13] (2012) A kernel two-sample test. Journal of Machine Learning Research 13 (Mar), pp. 723–773.
[14] (2019) Generative image translation for data augmentation of bone lesion pathology. In Proceedings of The 2nd International Conference on Medical Imaging with Deep Learning, M. J. Cardoso, A. Feragen, B. Glocker, E. Konukoglu, I. Oguz, G. Unal and T. Vercauteren (Eds.), Proceedings of Machine Learning Research, Vol. 102, London, United Kingdom, pp. 225–235.
[15] (2019) Flow++: improving flow-based generative models with variational dequantization and architecture design. arXiv preprint arXiv:1902.00275.
[16] (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
[17] (2013) Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114.
[18] (2018) A probabilistic u-net for segmentation of ambiguous images. In Advances in Neural Information Processing Systems, pp. 6965–6975.
[19] (2019) Generation of 3d brain mri using auto-encoding generative adversarial networks. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 118–126.
[20] (2018) OASIS-3: longitudinal neuroimaging, clinical, and cognitive dataset for normal aging and Alzheimer's disease. Alzheimer's & Dementia: The Journal of the Alzheimer's Association 14 (7), pp. P1097.
[21] (2019) DavinciGAN: unpaired surgical instrument translation for data augmentation. In Proceedings of The 2nd International Conference on Medical Imaging with Deep Learning, M. J. Cardoso, A. Feragen, B. Glocker, E. Konukoglu, I. Oguz, G. Unal and T. Vercauteren (Eds.), Proceedings of Machine Learning Research, Vol. 102, London, United Kingdom, pp. 326–336.
[22] (2019) Active appearance model induced generative adversarial network for controlled data augmentation. In Medical Image Computing and Computer Assisted Intervention – MICCAI 2019, D. Shen, T. Liu, T. M. Peters, L. H. Staib, C. Essert, S. Zhou, P. Yap and A. Khan (Eds.), Cham, pp. 201–208.
[23] (2017) SGDR: stochastic gradient descent with warm restarts. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings.
[24] (2016) V-net: fully convolutional neural networks for volumetric medical image segmentation. In 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571.
[25] (2018) Unsupervised depth estimation, 3d face rotation and replacement. In Advances in Neural Information Processing Systems, pp. 9736–9746.
[26] (2011) Statistical parametric mapping: the analysis of functional brain images. Elsevier.
[27] (2010) Alzheimer's disease neuroimaging initiative (adni): clinical characterization. Neurology 74 (3), pp. 201–209.
[28] (2019) Generating diverse high-fidelity images with vq-vae-2. In Advances in Neural Information Processing Systems, pp. 14837–14847.
[29] (2015) U-net: convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pp. 234–241.
[30] (1992) Nonlinear total variation based noise removal algorithms. Physica D: nonlinear phenomena 60 (1-4), pp. 259–268.
[31] (2017) PixelCNN++: a pixelcnn implementation with discretized logistic mixture likelihood and other modifications. In ICLR.
[32] (2019) Spatial patterns of white matter hyperintensities associated with Alzheimer's disease risk factors in a cognitively healthy middle-aged cohort. Alzheimer's research & therapy 11 (1), pp. 12.
[33] (2016) Real-time single image and video super-resolution using an efficient sub-pixel convolutional neural network. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1874–1883.
[34] (2019) Checkerboard artifacts free convolutional neural networks. APSIPA Transactions on Signal and Information Processing 8.
[35] (2017) Neural discrete representation learning. In Advances in Neural Information Processing Systems, pp. 6306–6315.
[36] (2003) Multiscale structural similarity for image quality assessment. In The Thirty-Seventh Asilomar Conference on Signals, Systems & Computers, 2003, Vol. 2, pp. 1398–1402.
[37] (2019) Self-training with noisy student improves imagenet classification. arXiv preprint arXiv:1911.04252.
[38] (2019) Fixup initialization: residual learning without normalization. In ICLR.
[39] (2019) FMRI data augmentation via synthesis. In 2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019), pp. 1783–1787.