Real-time Deep Registration With Geodesic Loss
Abstract
With the aim of increasing the capture range and accelerating state-of-the-art inter-subject and subject-to-template 3D registration, we propose deep learning-based methods that are trained to find the 3D position of arbitrarily oriented subjects or anatomy from slices or volumes of medical images. To this end, we propose regression convolutional neural networks (CNNs) that learn to predict the angle-axis representation of 3D rotations and translations using image features. We use and compare mean squared error and geodesic loss for training regression CNNs in two scenarios: 3D pose estimation from slices (2D to 3D registration) and 3D to 3D registration. As an exemplary application, we applied the proposed methods to register arbitrarily oriented reconstructed images of fetuses scanned in utero over a wide gestational age range to a standard atlas space. Our results show that in registration applications that are amenable to learning, the proposed deep learning methods with geodesic loss minimization can achieve accurate results with a wide capture range in real time. We also tested the generalization capability of the trained CNNs on an expanded age range and on images of newborn subjects with similar and different MR image contrasts: we trained our models on T2-weighted fetal brain MRI scans and used them to predict the 3D position of newborn brains from T1-weighted MRI scans. The trained models generalized well to the new domain when we performed image contrast transfer through a conditional generative adversarial network. This indicates that the domain of application of the trained deep regression CNNs can be further expanded to image modalities and contrasts other than those used in training. Combining our proposed methods with accelerated optimization-based registration algorithms can dramatically enhance the performance of future automatic imaging devices and image processing methods.
Image registration, Pose estimation, Deep learning, Convolutional neural network, CNN, MRI, fetal MRI.
1 Introduction
1.1 Background
Image registration is one of the most fundamental tools in biomedical image processing, with applications ranging from image-based navigation in imaging and image-guided interventions to longitudinal and group analyses [1, 2, 3, 4, 5, 6]. Registration can be performed between images of the same modality or across modalities, and within a subject or across subjects, with diverse goals such as motion correction, pose estimation, spatial normalization, and atlas-based segmentation. Image registration is formulated as an optimization problem to find a global transformation or a deformation model that maps a source (or moving) image to a target (or fixed) image. The complexity of the transformation is defined by its degrees of freedom (DOF), i.e., the number of its parameters. The most widely used transformations in biomedical image registration range from rigid and affine to high-dimensional small or large deformations based on biophysical/biomechanical, elastic, or viscous fluid models [5].
Given a transformation model and images, iterative numerical optimization methods are used to maximize intensity-based similarity metrics or minimize point cloud or local feature distances between images; however, the cost functions associated with these metrics are often non-convex, which limits the capture range of these registration methods. Techniques such as center-of-gravity matching, principal axes and moments matching, grid search, and multi-scale registration are used to initialize transformation parameters so that iterative optimization starts from the vicinity of the global optimum. These techniques, however, are not always successful, especially if the range of possible rotations is wide and shapes have complex features. Grid search and multi-scale registration may find global optima but are computationally expensive and may not be suitable for time-sensitive applications such as image-based navigation.
There has been increasing interest in using deep learning in medical image processing, motivated by promising results in semantic segmentation in computer vision [7] and medical imaging [8, 9]. The use of learning-based techniques in image registration, however, has been limited. Some registration tasks, for example image-to-model, image-to-atlas, or standard-space registration, are amenable to learning, and learning-based methods may provide significant improvement over strategies such as iterative optimization or grid search when the range of plausible positions/orientations is wide, demanding a large capture range. Under these conditions, a human observer can find the approximate pose of 3D objects quickly and bring them into rough alignment without solving an iterative optimization, by identifying salient features.
1.2 Related Work
Deep feature representations have recently been used to learn metrics that guide local deformations for multi-modal inter-subject registration [10, 11]. These works showed that deep learned metrics provide slight improvements over the local image intensity and patch features currently used in deformable image registration. Because these methods were initialized by rigid and affine alignments, their goal was merely to improve local deformations, not the global alignment. In another recent work on deformable registration, Yang et al. [12] developed a deep encoder-decoder convolutional neural network (CNN) that learned to predict the Large Deformation Diffeomorphic Metric Mapping (LDDMM) model, and achieved state-of-the-art performance with an order of magnitude faster optimization in inter-subject and subject-to-atlas deformable registration.
For 3D global rigid registration, which is the subject of this study, Liao et al. [13] proposed a reinforcement learning algorithm for a CNN with 3 fully connected layers. They used a greedy supervised learning strategy with an attention-driven hierarchical method to simultaneously encode a matching metric and learn a strategy, and showed improved accuracy and robustness compared to state-of-the-art registration methods in computed tomography (CT). This algorithm is relatively slow and lacks a systematic stopping criterion at test time.
In an effort to speed up 2D to 3D (X-ray to CT) rigid registration and improve its capture range, Miao et al. [14, 15] proposed a real-time registration algorithm using CNN regressors. In this method, called pose estimation via hierarchical learning, they partitioned the 6-dimensional parameter space into three groups (zones) to hierarchically learn the regression function based on in-plane rotation, out-of-plane rotation, and out-of-plane translation parameters. CNN regressors were trained separately in each zone, where local image residual features were used as input and the Euclidean distance between the transformation parameters was used as the loss function. In experiments with relatively small rotation perturbations, they reported improved registrations achieved in 100 ms, faster than the best intensity-based registrations in that 2D to 3D application.
The 2D to 3D (X-ray to CT) image registration problem shares similarities with 3D pose estimation in computer vision, where the term 3D pose estimation refers to finding the underlying 3D transformation between an object and the camera from 2D images. State-of-the-art methods for CNN-based 3D pose estimation can be classified into two groups: 1) models that are trained to predict keypoints and then use object models to find the orientation [16, 17]; and 2) models that predict the pose of the object directly from images [18, 19]. Pose estimation in computer vision has been largely treated as a classification problem, where the pose space is discretized into bins and the pose is predicted to belong to one of the bins [18, 19]. Mahendran et al. [20] have recently modeled 3D camera/object pose estimation as a regression problem. They proposed deep CNN regression to find rotation matrices and a new loss function based on the geodesic distance for training.
1.3 Contributions
Similar to [14, 15, 20], we propose a deep CNN regression model for 3D registration; but unlike those works, which focused on estimating pose from 2D-projected image representations of objects (and were thus limited to restricted rotations), we aimed to find the 3D pose of arbitrarily oriented objects from their volumetric or sectional (slice) image representations. Our goal was to speed up and improve the capture range of volume-to-volume and slice-to-volume registrations. To achieve this, we formulated the regression problem based on the angle-axis representation of 3D rotations, which form the Special Orthogonal Group $SO(3)$, and used the bi-invariant geodesic distance, which is a natural Riemannian metric on $SO(3)$ [21], as the loss function. We augmented our proposed deep residual regression network with a correction network to estimate translation parameters, and ultimately used it to initialize optimization-based registration to achieve robust and accurate registration over the widest plausible range of 3D rotations.
We applied our proposed method to rigidly register reconstructed fetal brain MRI images [22] to a standard (atlas) space. Fetal brains can be in any arbitrary orientation with respect to the MRI scanner coordinate system, as one cannot pre-define the position of a fetus when a pregnant woman is positioned on an MRI scanner table. Moreover, fetuses frequently move and can rotate within scan sessions. Our deep model, trained on reconstructed T2-weighted images of 28-37 week gestational age (GA) fetuses from the training set, was able to find the 3D position of fetuses in the test set in real time in the majority of cases in which optimization-based methods failed by converging to local minima. We then examined the generalization properties of the learned model on test images of much younger fetuses (21-27 weeks GA), as well as T2- and T1-weighted images of newborns, all of which exhibited significantly different sizes, shapes, and features.
Based on our formulation, we also trained models for slice-to-volume registration, an application that exhibits significant technical challenges in medical imaging, as recently reviewed in [6]. Prior work on slice-to-volume registration in fetal MRI has shown a strong need for regularization and initialization of slice transformations through hierarchical registration [22, 23] or state-space motion modeling [24]. Learning-based methods have recently been used to improve the prediction of slice locations in fetal MRI [25, 26] and fetal ultrasound [27]. In [25, 26], anchor-point slice parametrization was used along with the Euclidean loss function based on [28] to predict slice positions and reconstruct fetal MRI in canonical space. The alignment of fetal ultrasound slices in [27] was formulated as z-position estimation and 3-class slice plane classification (mid-axial, target eye, and eye planes), where a CNN was trained using negative likelihood loss for simultaneous prediction of slice location and brain segmentation. For slice-to-volume registration, we used the full 3D rotation representation to train our CNN regression model. Our results are also promising in this application, as they show that the initial pose of the fetal head can be estimated in real time from slice acquisitions, which is particularly helpful if good-quality slices are only sparsely acquired due to fetal motion. Our formulation is generic and may be used in other applications. The remainder of this paper presents the details of the methods in Section 2, followed by the results in Section 3, a discussion in Section 4, and the conclusion.
2 Methods
In this section we present a 3D rotation representation that helps us build our CNN regression models for 3D registration. We show how a bounded non-linear activation function can mimic the constraints of the exact rotation representation. We then present our network architectures and propose a two-step training algorithm with appropriate loss functions to train the networks.
2.1 3D Rotation Representation
A 3D rotation is commonly represented by a $3\times 3$ matrix $R$ whose 9 elements are subject to six norm and orthogonality constraints ($R$ is orthogonal, i.e., $R^T R = I$) and $\det(R) = 1$. The set of 3D rotations forms the Special Orthogonal Group $SO(3)$, a 3-dimensional manifold embedded in $\mathbb{R}^{9}$ (and thus with 3 DOFs). $SO(3)$ is a compact Lie group whose Lie algebra $\mathfrak{so}(3)$ is the space of $3\times 3$ skew-symmetric matrices. Its 3 DOFs can be represented as 3 consecutive rotations about the principal axes of the coordinate frame.
Based on Euler's rotation theorem, each rotation matrix can be described by an axis of rotation and an angle around it (the angle-axis representation). A 3-dimensional rotation vector is a compact representation of a rotation matrix: its unit direction is the rotation axis and its magnitude is the rotation angle in radians. The axis is oriented so that the rotation is counterclockwise around it. As a consequence, the rotation angle is always non-negative and at most $\pi$, i.e., $\theta \in [0, \pi]$.
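For a 3-dimensional rotation vector $\mathbf{r}$, defining $\mathbf{v} = \mathbf{r}/\|\mathbf{r}\|$ as the axis of rotation and $\theta = \|\mathbf{r}\|$ as the angle of rotation (in radians), the rotation matrix is calculated as:
$$R = \exp\left(\theta\,[\mathbf{v}]_\times\right) \tag{1}$$
where $[\mathbf{v}]_\times$ is the skew-symmetric operator:
$$[\mathbf{v}]_\times = \begin{bmatrix} 0 & -v_3 & v_2 \\ v_3 & 0 & -v_1 \\ -v_2 & v_1 & 0 \end{bmatrix} \tag{2}$$
Using Rodrigues' rotation formula, (1) can be simplified to:
$$R = I + \sin\theta\,[\mathbf{v}]_\times + (1-\cos\theta)\,[\mathbf{v}]_\times^2 \tag{3}$$
and, conversely, the angle and axis are recovered from $R$ (for $\theta \in (0, \pi)$) as:
$$\theta = \arccos\left(\frac{\operatorname{tr}(R) - 1}{2}\right), \qquad [\mathbf{v}]_\times = \frac{R - R^T}{2\sin\theta} \tag{4}$$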
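The forward map (3) and the inverse map (4) can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation; the function names are hypothetical, and the special handling of angles near 0 and $\pi$ (where the axis in (4) becomes ill-conditioned because $\sin\theta \to 0$) is our own assumption added for numerical robustness.

```python
import numpy as np

def skew(v):
    """Skew-symmetric operator [v]_x of (2), so that skew(v) @ u == np.cross(v, u)."""
    return np.array([[0.0, -v[2], v[1]],
                     [v[2], 0.0, -v[0]],
                     [-v[1], v[0], 0.0]])

def rotvec_to_matrix(r):
    """Rodrigues' formula (3): rotation vector r -> 3x3 rotation matrix."""
    r = np.asarray(r, dtype=float)
    theta = np.linalg.norm(r)
    if theta < 1e-12:                      # no rotation: the axis is undefined
        return np.eye(3)
    K = skew(r / theta)
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * K @ K

def matrix_to_rotvec(R):
    """Inverse map (4): rotation matrix -> rotation vector with angle in [0, pi]."""
    R = np.asarray(R, dtype=float)
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-12:
        return np.zeros(3)
    if np.pi - theta < 1e-6:
        # Near pi, sin(theta) ~ 0, so use R = 2 v v^T - I instead of (R - R^T).
        B = (R + np.eye(3)) / 2.0          # equals v v^T when theta == pi
        j = int(np.argmax(np.diag(B)))
        return theta * B[:, j] / np.sqrt(B[j, j])
    v = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return theta * v / (2.0 * np.sin(theta))
```

Round-tripping a rotation vector through both functions recovers it whenever its norm is below $\pi$; at exactly $\theta = \pi$, $\mathbf{r}$ and $-\mathbf{r}$ describe the same rotation, so the axis is only recoverable up to sign.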
As a result, to find any arbitrary rotation in 3D space it is sufficient to find the rotation vector corresponding to that orientation. In the next section, the proposed networks that can find this rotation vector are introduced.
Figure 1.a shows general parts of the regression networks used in this study. Each network contains 3 parts: input, feature extraction, and output. In this study we used three networks with slightly different configurations of these parts. Next we discuss the architecture of each network in detail.
2.2 2D to 3D Network Architecture
The proposed network for 2D to 3D pose (3D slice pose) estimation comprises an 18-layer residual CNN [29] for feature extraction and a regression head with two branches: one for regression over the 3 rotation parameters and the other for the slice number in the atlas space. The network architecture is shown in Figure 1.b. For the rotation head, the last fully connected layer has size three, corresponding to the elements of the rotation vector $\mathbf{r}$. The last non-linear function on top of this fully connected layer is $\pi\tanh(\cdot)$, which limits the output to the range $-\pi$ to $\pi$ and simulates the constraints of the rotation vector. The slice number head outputs a scalar, as the network estimates the slice location (number) along with its orientation; a ReLU non-linearity is applied on top of this head because the slice number is non-negative.
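A minimal NumPy sketch of the two-branch regression head, assuming the bounding non-linearity is $\pi\tanh$ as described above. The function name and weight shapes are hypothetical, and in practice the head would be part of a deep learning framework model trained end to end with the backbone.

```python
import numpy as np

def slice_pose_head(features, W_rot, b_rot, w_slice, b_slice):
    """Two-branch regression head on top of backbone features (illustrative only).

    features: (batch, d) feature vectors from the feature-extraction part.
    W_rot (d, 3), b_rot (3,): final fully connected layer of the rotation branch.
    w_slice (d,), b_slice (scalar): fully connected layer of the slice-number branch.
    """
    # Rotation branch: pi * tanh keeps every component of r inside [-pi, pi].
    r = np.pi * np.tanh(features @ W_rot + b_rot)
    # Slice-number branch: ReLU because slice numbers are non-negative.
    slice_number = np.maximum(0.0, features @ w_slice + b_slice)
    return r, slice_number
```

Note that $\pi\tanh$ bounds each component separately, so $\|\mathbf{r}\|$ can approach $\sqrt{3}\,\pi$: the activation approximates rather than strictly enforces $\theta \le \pi$. Because (3) is periodic in $\theta$, any output still maps to a valid rotation matrix.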
2.3 3D to 3D Pose Network Architecture
The 3D feature extraction part of our 3D to 3D pose estimation architecture is shown in Figure 1.c. All convolutional kernels have the same size. In the first layer, eight convolutional kernels are applied to the 3D input image, followed by a ReLU non-linearity and batch normalization. Before the second and third convolutional layers, the tensors are down-sampled by a factor of 2 using 3D max pooling. The second and third layers apply 32 convolutional kernels each, followed by ReLU and batch normalization. In the last two convolutional layers, 64 kernels are used, followed by ReLU and batch normalization. Before the last fully connected layer, 3 fully connected layers of size 512, 512, and 256 are used, each with ReLU and batch normalization. The feature extraction part thus provides 256 features that are fed into the regression head. The overall architecture of the pose network is shown in Figure 1.d. This network only estimates orientation and has the same rotation head as the 2D to 3D network.
2.4 3D to 3D Correction Network Architecture
The correction network aims to estimate translations as well as rotations. Note that we assume initial translations are computed by center-of-gravity matching and initial rotations by the 3D to 3D pose estimation network. The architecture of the correction network (Correction-Net) is shown in Figure 1.e. Its 3D feature extraction part is the same as that of the 3D to 3D pose network. In this architecture, a 3D reference image (an atlas or template image) and a roughly oriented 3D moving image are fed together as a 2-channel input, as we aim to estimate both rotation and translation parameters. The regression head of this network has two branches: the rotation head described above and a translation head, which outputs a vector of 3 parameters that translate the moving image to the target image.
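The center-of-gravity initialization and the 2-channel input can be sketched with SciPy's `ndimage.center_of_mass`. This is an illustration under stated assumptions: the helper names are hypothetical, intensities are assumed non-negative (the centroid is intensity-weighted), and translations are in voxels, so converting to millimeters would require the voxel spacing.

```python
import numpy as np
from scipy import ndimage

def center_of_gravity_translation(moving, fixed):
    """Translation (in voxels) that moves the intensity centroid of `moving` onto that of `fixed`."""
    return np.array(ndimage.center_of_mass(fixed)) - np.array(ndimage.center_of_mass(moving))

def correction_net_input(atlas, moving):
    """Stack the reference (atlas) and the roughly aligned moving image as a 2-channel volume."""
    if atlas.shape != moving.shape:
        raise ValueError("atlas and moving image must be resampled to the same grid")
    return np.stack([atlas, moving], axis=0)   # shape (2, D, H, W)
```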
2.5 Training the Networks
In this section we describe the training procedures for the networks. The loss function is designed as:
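$$\mathcal{L} = \mathcal{L}_{\mathrm{rot}} + \lambda\,\mathcal{L}_{\mathrm{trans}} \tag{5}$$
where $\lambda$ is a hyper-parameter that balances the rotation loss $\mathcal{L}_{\mathrm{rot}}$ (which, for the geodesic loss, is bounded between 0 and $\pi$) against the translation loss $\mathcal{L}_{\mathrm{trans}}$. The translation loss is the mean-squared error (MSE) between the predicted and ground truth translation vectors. In the first stage of training we also use the MSE loss for the rotation parameters, and in the second stage we switch to the geodesic loss. The MSE loss is defined as
$$\mathcal{L}_{\mathrm{MSE}} = \frac{1}{3}\,\|\hat{\mathbf{r}} - \mathbf{r}\|_2^2 \tag{6}$$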
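where $\hat{\mathbf{r}}$ and $\mathbf{r}$ are the output of the rotation head and the ground truth rotation vector, respectively. While MSE can help narrow down the search space for pose prediction, it does not accurately represent a distance function between two rotations. The distance between two 3D rotations is geometrically interpreted as the geodesic distance between two points on the manifold $SO(3)$. The geodesic distance (the length of the shortest path) is the angle, in radians, of the rotation that takes one orientation to the other, and it can be expressed through the exponential map. Let $R_1$ and $R_2$ be two rotation matrices between which we want to measure a distance, i.e., the 3D angle between these rotations:
$$d(R_1, R_2) = \theta, \quad \text{where } R_2 = R_1 \exp\left(\theta\,[\mathbf{v}]_\times\right) \tag{7}$$
Equation (7) gives the amount of rotation, in radians, around an axis $\mathbf{v}$ that must be applied to rotation matrix $R_1$ to reach rotation matrix $R_2$; it is calculated as:
$$d(R_1, R_2) = \frac{1}{\sqrt{2}}\,\left\|\log\left(R_1^T R_2\right)\right\|_F \tag{8}$$
where $\|\cdot\|_F$ is the Frobenius norm and $\log(R)$ is the matrix logarithm of a rotation matrix $R$, which can be written as:
$$\log(R) = \frac{\theta}{2\sin\theta}\left(R - R^T\right), \quad \theta \in (0, \pi) \tag{9}$$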
To show that (8) is indeed the distance between rotation matrices, note that a rotation matrix is orthogonal ($R^{-1} = R^T$), so the rotation from $R_1$ to $R_2$ is $R_1^T R_2$. Let $\mathbf{v}$ and $\theta$ be the axis and angle of $R_1^T R_2$, i.e., of its 3-dimensional rotation vector representation, so that $R_1^T R_2$ can be written using (3). By (9), $\log(R_1^T R_2) = \theta\,[\mathbf{v}]_\times$, and since the Frobenius norm of the skew-symmetric matrix of a unit vector is $\sqrt{2}$, one can show that (8) is equal to $\theta$.
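The key step can be written out explicitly, assuming (8) carries the $1/\sqrt{2}$ normalization that makes it equal to the rotation angle. The first line uses (3), whose symmetric term $(1-\cos\theta)[\mathbf{v}]_\times^2$ cancels in $R - R^T$:

$$
\begin{aligned}
\log(R_1^T R_2) &= \frac{\theta}{2\sin\theta}\left(R_1^T R_2 - R_2^T R_1\right) = \frac{\theta}{2\sin\theta}\,2\sin\theta\,[\mathbf{v}]_\times = \theta\,[\mathbf{v}]_\times,\\
\|[\mathbf{v}]_\times\|_F^2 &= 2\left(v_1^2 + v_2^2 + v_3^2\right) = 2,\\
\frac{1}{\sqrt{2}}\left\|\log(R_1^T R_2)\right\|_F &= \frac{1}{\sqrt{2}}\,\theta\,\sqrt{2} = \theta .
\end{aligned}
$$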
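On the other hand, since the relative rotation $R_1^T R_2$ is itself a rotation matrix, its angle follows from (4): $\theta = \arccos\left(\left(\operatorname{tr}(R_1^T R_2) - 1\right)/2\right)$. Therefore, the geodesic loss, defined as the distance between two rotation matrices, can be written as:
$$\mathcal{L}_{\mathrm{geo}}(R_1, R_2) = \arccos\left(\frac{\operatorname{tr}\left(R_1^T R_2\right) - 1}{2}\right) \tag{10}$$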
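Loss (10) is cheap to evaluate because it needs only a matrix product and a trace. Below is a minimal NumPy sketch with hypothetical function names; the clip guards against round-off pushing the cosine slightly outside $[-1, 1]$. In an automatic-differentiation framework one would typically clip to a slightly narrower interval, since the derivative of arccos is unbounded at $\pm 1$; that choice is ours, not taken from the paper.

```python
import numpy as np

def geodesic_distance(R1, R2):
    """Geodesic distance (10) on SO(3): the angle of R1^T R2, in radians, in [0, pi]."""
    cos_theta = (np.trace(R1.T @ R2) - 1.0) / 2.0
    return float(np.arccos(np.clip(cos_theta, -1.0, 1.0)))

def rotation_about_z(angle):
    """Helper for the example: rotation by `angle` radians about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
```

For example, for any rotation $R_1$, `geodesic_distance(R1, R1 @ rotation_about_z(0.7))` returns 0.7 up to round-off, and `np.degrees(geodesic_distance(np.eye(3), R))` gives the distance of a rotation from the identity in degrees.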
This is a natural Riemannian metric on the compact Lie group $SO(3)$. We calculate the geodesic loss using (10); to do so, the predicted and ground truth rotation vectors are first converted to rotation matrices as described in Section 2.1. In summary, training the networks involves iterations of back-propagation with the total loss function in (5), where the translation loss is the MSE and the rotation loss is calculated by (6) in the first stage and by (10) in the second stage. In our experiments each stage involved ten epochs. The details of the data and experiments are discussed next.
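The two-stage schedule can be sketched as follows, assuming (5) has the form rotation loss plus $\lambda$ times translation loss and that the stage switch happens after ten epochs. The value $\lambda = 1$ and all function names are hypothetical placeholders, and this per-sample NumPy version only illustrates the loss values, not back-propagation.

```python
import numpy as np

def rotvec_to_matrix(r):
    """Rodrigues' formula (3)."""
    r = np.asarray(r, dtype=float)
    theta = np.linalg.norm(r)
    if theta < 1e-12:
        return np.eye(3)
    k = r / theta
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * K @ K

def geodesic_loss(r_pred, r_true):
    """Geodesic loss (10) between the rotations encoded by two rotation vectors."""
    R1, R2 = rotvec_to_matrix(r_pred), rotvec_to_matrix(r_true)
    return float(np.arccos(np.clip((np.trace(R1.T @ R2) - 1.0) / 2.0, -1.0, 1.0)))

def total_loss(r_pred, r_true, t_pred, t_true, epoch, lam=1.0, epochs_per_stage=10):
    """Loss (5): rotation loss + lam * translation MSE.

    Stage 1 (epoch < epochs_per_stage): the rotation loss is the MSE (6).
    Stage 2: the rotation loss is the geodesic loss (10).
    """
    r_pred, r_true = np.asarray(r_pred, float), np.asarray(r_true, float)
    trans_loss = float(np.mean((np.asarray(t_pred, float) - np.asarray(t_true, float)) ** 2))
    if epoch < epochs_per_stage:
        rot_loss = float(np.mean((r_pred - r_true) ** 2))
    else:
        rot_loss = geodesic_loss(r_pred, r_true)
    return rot_loss + lam * trans_loss
```

The mismatch that motivates the switch is easy to see numerically: the rotation vectors $(\pi - 0.05, 0, 0)$ and $(-(\pi - 0.05), 0, 0)$ describe rotations only 0.1 rad apart, yet their MSE is about 12.7.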
3 Experiments
3.1 Datasets
The dataset contained 93 reconstructed T2-weighted fetal MRI scans. Fetal MRI data were obtained from fetuses scanned at gestational ages between 21 and 37 weeks (mean=30.1, stdev=4.6) on 3-Tesla Siemens Skyra scanners with 18-channel body matrix and spine coils. Repeated multi-planar T2-weighted single shot fast spin echo scans were acquired of the moving fetuses, and ellipsoidal brain masks were automatically extracted using the real-time algorithm in [30]. The scans were then combined through slice-level motion correction and robust super-resolution volume reconstruction [22, 23]. Brain masks were generated on the reconstructed images using Auto-Net [31] and manually corrected in ITK-SNAP [32] as needed.
Brain-extracted reconstructed images were then registered to a spatiotemporal fetal brain MRI atlas [33]. This registration followed the procedure described in [33]; we summarize it here because it produced the set of fetal brain scans (all registered to the standard atlas space) used to generate the ground truth data. First, a rigid transform between the fetal head coordinates and the MRI scanner (world) coordinates was found by inverting the direction cosine matrix of one of the original fetal MRI scans that was acquired in an orthogonal plane with respect to the fetal head (the idea is that the MR technologist who prescribed the scan planes identified and used the fetal head coordinates rather than the world coordinates). Applying this transform to the image reconstructed in world coordinates mapped it to fetal coordinates, so the oblique reconstructed image appeared orthogonal with respect to the fetal head. This, in turn, enabled a grid search over all orthogonal 3D rotations that could map this image to the corresponding age of the spatiotemporal atlas (fetal coordinates to atlas space). Multi-scale rigid registration was performed afterwards to fine-tune the alignment. We note that due to anatomical differences between subjects and the atlas, the final alignments have an intrinsic level of inaccuracy; since our goals are improved capture range and speed, our analysis is not sensitive to small errors in the alignment of reference data. The results in terms of accuracy, however, should be interpreted cautiously. All registered images were manually checked to ensure visually correct alignment to the atlas space, and were corrected as needed.
Training Dataset
Reconstructed T2-weighted images of 36 fetuses scanned at 28 to 37 weeks GA were used for training. Each image was randomly rotated and translated in 3D and fed to the network. Since the rotation matrix was known, the corresponding rotation vector was computed and used as the ground truth. Two different algorithms were used to randomly generate rotation matrices.
For the 2D to 3D application, each input image was randomly rotated about two of the coordinate axes within a restricted angular range. This algorithm covered half of all possible orientations and provided all the different views in the training set; therefore, separating the different views (i.e., axial, coronal, and sagittal) was unnecessary for training the network. We did not span the whole rotation space in this experiment because 2D brain slices do not carry enough information to distinguish between rotations that are $\pi$ radians apart around arbitrary rotation vectors: owing to the symmetrical shape of the brain, predicting its 3D orientation from a 2D slice is difficult. To choose input slices, we randomly selected 30 slices from the middle 66 percent of the slices, skipping the border slices that did not carry sufficient information for training.
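A small sketch of the slice selection, assuming slices are drawn without replacement along one axis and that "the middle 66 percent" means trimming 17% of the slices from each end. The function name and its defaults are illustrative, not taken from the authors' code.

```python
import numpy as np

def sample_middle_slices(num_slices, n_select=30, middle_fraction=0.66, rng=None):
    """Randomly pick slice indices from the central `middle_fraction` of a volume,
    skipping the border slices at both ends (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    margin = int(round(num_slices * (1.0 - middle_fraction) / 2.0))
    candidates = np.arange(margin, num_slices - margin)
    # Sampling without replacement; raises ValueError if n_select > len(candidates).
    return np.sort(rng.choice(candidates, size=n_select, replace=False))
```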
For the 3D to 3D application, we used the algorithm proposed in [34, p. 355] to span the whole rotation space uniformly. This algorithm maps three random variables in the range [0, 1] onto the set of orthogonal matrices with positive determinant, that is, the set of all 3D rotations, and generates uniformly distributed samples over this set.
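We cannot confirm from the text whether [34] describes exactly this construction; as an illustration, we sketch Arvo's well-known "fast random rotation matrices" method, which likewise maps three uniform random numbers in [0, 1] to a uniformly distributed rotation. It first rotates about the z axis by a uniform angle and then moves the pole to a uniformly distributed point on the sphere with a negated Householder reflection. The function name is hypothetical.

```python
import numpy as np

def random_rotation_arvo(x1, x2, x3):
    """Map three numbers in [0, 1] to a 3x3 rotation matrix (Arvo's method).

    With x1, x2, x3 drawn independently and uniformly, the result is uniformly
    distributed over SO(3).
    """
    theta = 2.0 * np.pi * x1      # rotation about the pole (z axis)
    phi = 2.0 * np.pi * x2        # direction in which to deflect the pole
    z = x3                        # amount of pole deflection
    Rz = np.array([[np.cos(theta), np.sin(theta), 0.0],
                   [-np.sin(theta), np.cos(theta), 0.0],
                   [0.0, 0.0, 1.0]])
    v = np.array([np.cos(phi) * np.sqrt(z), np.sin(phi) * np.sqrt(z), np.sqrt(1.0 - z)])
    H = np.eye(3) - 2.0 * np.outer(v, v)   # Householder reflection (det = -1)
    return -H @ Rz                          # -H has det +1, so the result is a rotation
```

Under a uniform distribution over rotations, the rotation angle is not uniform: its density is proportional to $1 - \cos\theta$ on $[0, \pi]$, so large rotations dominate the samples.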
For training the 3D to 3D correction network, each moving image was randomly rotated about the coordinate axes and randomly translated along each direction (in millimeters), within restricted ranges. The transformed image was then concatenated with its corresponding atlas image to form a 2-channel input to the network. The range of transformation parameters was smaller for this network because its objective was to correct the initial predictions made by the other networks.
Translation, rotation, and scaling of the images were combined into a single transformation, and resampling was done on the fly during training. Linear interpolation was used for resampling to speed up training. Training samples (slices for the 2D to 3D network and volumes for the 3D to 3D networks) were generated in this way, and the number of epochs for each training stage was set to 10.
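Composing rotation, scaling, and translation into one map means each training sample is interpolated only once. A minimal sketch with SciPy's `ndimage.affine_transform`, assuming rotation and scaling about the volume center and translations in voxels; the function name and this parameterization are our own, not the authors' pipeline.

```python
import numpy as np
from scipy import ndimage

def rigid_resample(volume, R, t=(0.0, 0.0, 0.0), scale=1.0, order=1):
    """Rotate `volume` by R about its center, scale it, and shift it by t (voxels).

    All three operations are composed into one affine map, so the volume is
    interpolated only once; order=1 selects (tri)linear interpolation.
    """
    R = np.asarray(R, dtype=float)
    t = np.asarray(t, dtype=float)
    c = (np.array(volume.shape, dtype=float) - 1.0) / 2.0
    # affine_transform pulls values: output[x] = input[matrix @ x + offset],
    # so we pass the inverse map x -> R^T (x - c - t) / scale + c.
    matrix = R.T / scale
    offset = c - matrix @ (c + t)
    return ndimage.affine_transform(volume, matrix, offset=offset, order=order,
                                    mode="constant", cval=0.0)
```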
Testing Dataset
To test the performance and generalization properties of the trained models, three test sets were used: Test Set 1) reconstructed T2-weighted images of 40 fetuses with GA between 27 and 37 weeks that were not used in training; Test Set 2) reconstructed T2-weighted images of 17 fetuses with GA between 22 and 26 weeks, as well as T2-weighted MRI scans of 7 newborns scanned at 38 to 44 weeks GA-equivalent age; and Test Set 3) T1-weighted MRI scans of the same newborns.
Ten randomly generated rotation matrices were applied to each image, resulting in 400, 170, and 70 samples for the three sets. For each application, rotation matrices were generated with the same process used for the training data, as discussed in Section 3.1.1. Figure 2 shows the histograms of the synthetic rotations for the slice-to-volume and volume-to-volume experiments. The horizontal axis shows the distance of the generated rotation matrix from the identity matrix in degrees.
3.2 Intensity-Based Registration
To compare the pose predictions made by our pose estimation CNNs, referred to as Pose-Net, with conventional intensity-based registration methods, we developed multiple variations of rigid registration for volume-to-volume registration (VVR) and slice-to-volume registration (SVR). For VVR comparisons, we developed the following programs: (i) VVR-GC: a gradient-descent optimizer was used to maximize the Normalized Mutual Information (NMI) metric, with a multi-scale multi-resolution approach and 3 levels of transform refinement; the transform was initialized using a geometric-center alignment strategy. (ii) VVR-PAA: the same as VVR-GC except that the transform was initialized using moments matching and principal axis alignment. (iii) VVR-Deep: the same as VVR-GC except that the transform was initialized using Pose-Net predicted transforms, without employing any other initialization strategy. For SVR comparisons, we developed two versions of the program: (i) SVR-GC: a gradient-descent optimizer was used to maximize Normalized Cross Correlation (NCC), with a multi-scale multi-resolution approach and 3 levels of transform refinement; the transform was initialized using geometric center matching. (ii) SVR-Deep: the same as SVR-GC except that the transform was initialized using Pose-Net predicted transforms. To handle the difference in the number of voxels between fixed and moving images in SVR, we used the NCC metric instead of the NMI metric used in VVR. The learning rate of the optimizer was set lower in both VVR and SVR programs when they were initialized with Pose-Net predictions.
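For reference, the two similarity metrics can be sketched as follows. The text does not state which toolkit or NMI variant was used; we assume Studholme's normalized mutual information, $(H(A) + H(B))/H(A, B)$, estimated from a joint histogram with a hypothetical bin count of 32, and zero-mean NCC. Function names are illustrative.

```python
import numpy as np

def normalized_cross_correlation(a, b):
    """Zero-mean NCC in [-1, 1]; 1 means b is a positive linear function of a."""
    a = np.ravel(a) - np.mean(a)
    b = np.ravel(b) - np.mean(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def _entropy(p):
    """Shannon entropy (nats) of a discrete distribution, ignoring empty bins."""
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

def normalized_mutual_information(a, b, bins=32):
    """Studholme's NMI, (H(A) + H(B)) / H(A, B), from a joint intensity histogram."""
    joint, _, _ = np.histogram2d(np.ravel(a), np.ravel(b), bins=bins)
    p_ab = joint / joint.sum()
    h_a = _entropy(p_ab.sum(axis=1))
    h_b = _entropy(p_ab.sum(axis=0))
    return (h_a + h_b) / _entropy(p_ab.ravel())
```

With this definition NMI lies between 1 (independent intensities) and 2 (a one-to-one intensity relationship), and the optimizer maximizes it; unlike NCC, it does not assume a linear relationship between intensities.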
3.3 Results
We evaluated pose predictions obtained from the proposed methods in different scenarios.
Slice Pose Estimation
As described in Sections 2 and 3.2, the trained 2D to 3D deep CNN and optimization-based SVR methods were used for slice pose estimation. To investigate the influence of the geodesic loss, the results of the deep learning based algorithms were computed twice: once after training with the MSE loss function, and once after fine-tuning with the geodesic loss. For visualization, test samples were distributed over 6 bins according to their magnitude of rotation such that the number of samples in each bin was roughly equal, highlighting the performance of the methods in terms of their capture range. Figure 3 and Table 1 show that 1) the geodesic loss improved the results, and this improvement was significant in two of the bins; 2) the optimization-based method without deep CNN initialization failed in most cases; and 3) the optimization-based method with deep initialization performed best.
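Equal-count binning of test samples by rotation magnitude can be done with quantiles. A minimal sketch, assuming rotation magnitudes in degrees (e.g., the geodesic distance from the identity) and hypothetical function names:

```python
import numpy as np

def equal_count_bins(angles_deg, n_bins=6):
    """Assign samples to n_bins groups of roughly equal size by rotation magnitude."""
    angles_deg = np.asarray(angles_deg, dtype=float)
    edges = np.quantile(angles_deg, np.linspace(0.0, 1.0, n_bins + 1))
    # Digitizing against the interior edges yields bin indices 0 .. n_bins - 1.
    idx = np.digitize(angles_deg, edges[1:-1], right=True)
    return idx, edges

def per_bin_stats(errors, idx, n_bins=6):
    """Mean and standard deviation of the errors in each bin."""
    errors = np.asarray(errors, dtype=float)
    return [(errors[idx == k].mean(), errors[idx == k].std()) for k in range(n_bins)]
```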
Table 1. Slice to Volume: pose estimation error (degrees) per rotation-magnitude bin (Bin 1: smallest rotations; Bin 6: largest rotations)
Method | Bin 1 | Bin 2 | Bin 3 | Bin 4 | Bin 5 | Bin 6 |
SVR-GC | 21.36 | 50.62 | 70.16 | 83.28 | 105.64 | 147.69 |
Pose-Net (MSE) | 16.73 | 17.3 | 21.96 | 23.86 | 23.84 | 46.01 |
Pose-Net (Geodesic) | 15.1 | 14.05 | 16.18 | 24.33 | 19.8 | 36.47 |
SVR-Deep | 10.23 | 12.32 | 13.08 | 17.6 | 16.19 | 26.85 |
3D to 3D pose estimation
In the 3D to 3D rigid registration scenario, 6 different algorithms were compared: VVR-GC, VVR-PAA, Pose-Net with MSE, Pose-Net with geodesic loss, Correction-Net, and VVR-Deep. Figure 4 shows that 1) VVR-GC performed very well for small rotations but failed for almost all samples with large rotations, as it converged to wrong local minima; 2) by using principal axis initialization, VVR-PAA improved the performance for larger rotations but again failed for the majority of these samples, and it caused a large loss in performance (compared to VVR-GC) for small rotations, as it incorrectly shifted the initial point into the region of a wrong local minimum; 3) the trained deep CNN models all performed well, with a much lower number of failures: the geodesic loss showed significant improvement over the MSE loss, and Correction-Net performed best among them, with only a very small fraction of failures across the range of rotations; and 4) VVR-Deep, the optimization-based registration initialized by deep pose estimation, generated the most accurate results and the minimum number of failures. Table 2 shows that VVR-Deep performed best, while Correction-Net results were comparable; notably, Correction-Net based registration is real-time and several orders of magnitude faster than VVR-Deep registration. The average runtime of the methods is discussed in Section 3.3.4.
Table 2. Volume to Volume: pose estimation error (degrees) per rotation-magnitude bin (Bin 1: smallest rotations; Bin 6: largest rotations)
Method | Bin 1 | Bin 2 | Bin 3 | Bin 4 | Bin 5 | Bin 6 |
VVR-GC | 2.42 | 45.39 | 149.91 | 177.0 | 174.87 | 177.2 |
VVR-PAA | 95.54 | 131.1 | 128.44 | 129.68 | 131.15 | 141.44 |
Pose-Net (MSE) | 23.96 | 26.11 | 24.01 | 28.74 | 29.01 | 55.76 |
Pose-Net (Geodesic) | 10.08 | 11.44 | 12.43 | 13.46 | 16.46 | 34.19 |
Correction-Net | 4.54 | 4.45 | 4.83 | 4.82 | 6.33 | 19.42 |
VVR-Deep | 2.42 | 2.35 | 2.43 | 2.36 | 4.84 | 20.44 |
Figure 5 shows the translation error of the correction network, calculated as the Euclidean distance between the true and predicted translation vectors. The initial translation error was calculated as the distance of the input image from the atlas location. Note that all errors reported here, including the translation errors, are between images of different subjects and atlases, so there is an intrinsic level of inaccuracy in alignment: the exact alignment of two different anatomies (with different sizes and shapes) using rigid registration is not well defined.
Figure 6 shows the results of different algorithms on an example from the 3D to 3D rigid registration tests. All algorithms tried to register the brain of this fetus (with mild unilateral ventriculomegaly) to the atlas of the corresponding age, shown on the right. The first column is the input with synthetic rotation. As the rotation was large, VVR-GC failed due to its non-convex NMI-based cost function: without deep initialization, it converged to a wrong local optimum, which resulted in a flipped version of the correct orientation (the fourth column). The second and third columns show the results of Pose-Net and Correction-Net. The geodesic distance errors (in degrees) of each algorithm are given underneath each column. For this example, the correction network generated the most accurate results.
Generalization property of the trained models
An important question that is frequently asked about learning-based methods such as the ones developed in this study concerns their generalization performance: can they generalize well for new test data, possibly with different features? In this section, we aimed to investigate the generalization property of our trained models. For this, we carried out two sets of experiments, with Test Sets 2 and 3:
First, we added Test Set 1 to Test Set 2 to investigate the generalization of the algorithm to fetal brains at ages not represented in the training set (younger fetuses at 22-27 weeks GA) and to newborn brains scanned at 38-44 weeks GA-equivalent age (in a different, ex-utero scan setting). Recall that the training dataset contained only fetuses scanned at 28-37 weeks. The brain develops very rapidly, especially during the second trimester of pregnancy, so the differences in brain size and shape between these test sets and the training set were significant. The images underneath the box plots in Figure 7 show sample slices at different ages. Using only a scale parameter calculated from the size ratio of the atlases at different ages, we scaled the images and fed them into the network. The box plots of the estimated pose error at different ages in Figure 7 show that the network generalized very well over all age ranges and across scan settings. However, the average and median errors increased slightly toward the lower age range, as the anatomy became increasingly different from that of the training set.
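The age-dependent scaling can be illustrated with a small sketch. It assumes a single linear size per gestational age is read from the atlas, that scaling is isotropic about the volume center, and that the output keeps the input grid; the functions `age_scale` and `scale_about_center` and the atlas sizes in the usage lines are hypothetical. `scipy.ndimage.affine_transform` maps each output coordinate to an input coordinate, so enlarging the anatomy by a factor $s$ uses the inverse matrix $I/s$.

```python
import numpy as np
from scipy import ndimage

def age_scale(atlas_size_train, atlas_size_test):
    """Hypothetical scale factor: ratio of atlas brain sizes (training age / test age)."""
    return atlas_size_train / atlas_size_test

def scale_about_center(volume, scale, order=1):
    """Isotropically scale a 3D volume about its center, keeping the grid shape.

    scale > 1 enlarges the anatomy (e.g. a younger, smaller brain toward the training size).
    """
    volume = np.asarray(volume, dtype=float)
    center = (np.array(volume.shape) - 1) / 2.0
    # affine_transform samples input at (matrix @ output_coord + offset),
    # so the inverse mapping 1/scale enlarges the content by `scale`.
    matrix = np.eye(3) / scale
    offset = center - matrix @ center  # keeps the center fixed
    return ndimage.affine_transform(volume, matrix, offset=offset, order=order)

# Usage with made-up sizes: a younger-age atlas 25% smaller than the training-age atlas.
s = age_scale(atlas_size_train=80.0, atlas_size_test=60.0)
young_brain = np.zeros((32, 32, 32))
young_brain[12:20, 12:20, 12:20] = 1.0
rescaled = scale_about_center(young_brain, s)
```

Scaling in the image domain keeps the network input size fixed, which is presumably why a single scalar per age suffices here.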
In our second experiment on generalization, we investigated one of the main drawbacks of deep learning-based methods, namely their limited generalization across imaging modalities and contrasts. To investigate whether the Pose-Net could generalize to T1-weighted newborn MRI scans while trained only on reconstructed T2-weighted fetal scans, we applied our 3D to 3D registration test pipeline to the T1-weighted scans of 7 newborns (70 samples in total) in Test Set 3. Figure 8 shows the results of applying the trained model to T1-weighted scans (blue box plots) compared to T2-weighted scans (orange box plots) with exactly the same synthetic random rotations. Although the Pose-Net still performed better than VVR-GC and VVR-PAA (compare with Figure 4), it did not generalize well to T1-weighted scans.
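The T1/T2 comparison relies on applying exactly the same synthetic rotations to both contrasts. A minimal sketch, assuming rotations are drawn uniformly over SO(3) with Shoemake's random-quaternion method (the sampling distribution used in the experiments may differ) and a hypothetical fixed seed, produces one reproducible set of 70 rotations (7 subjects × 10 samples) that would be applied to each subject's T1- and T2-weighted scans alike.

```python
import numpy as np

def random_rotation_matrices(n, seed):
    """Draw n rotations uniformly on SO(3) via random unit quaternions (Shoemake's method)."""
    rng = np.random.default_rng(seed)
    u1, u2, u3 = rng.random((3, n))
    # Quaternion components (x, y, z, w); the construction guarantees unit norm.
    x = np.sqrt(1 - u1) * np.sin(2 * np.pi * u2)
    y = np.sqrt(1 - u1) * np.cos(2 * np.pi * u2)
    z = np.sqrt(u1) * np.sin(2 * np.pi * u3)
    w = np.sqrt(u1) * np.cos(2 * np.pi * u3)
    # Standard unit-quaternion to rotation-matrix conversion, built row by row -> shape (n, 3, 3).
    return np.stack([
        np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], axis=-1),
        np.stack([2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], axis=-1),
        np.stack([2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], axis=-1),
    ], axis=1)

N_SUBJECTS, PER_SUBJECT = 7, 10  # 70 test samples in total
SEED = 0                         # hypothetical fixed seed
rotations = random_rotation_matrices(N_SUBJECTS * PER_SUBJECT, SEED).reshape(
    N_SUBJECTS, PER_SUBJECT, 3, 3)
# rotations[i, j] is applied to both the T1- and the T2-weighted scan of subject i.
```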
To address this issue through pre-processing, we developed an image contrast transfer algorithm based on the approach in [35]. In this algorithm we trained a conditional generative adversarial network (GAN) that simultaneously learned the mapping from T1- to T2-weighted images and a loss function to train this mapping. Figure 9 shows the pipeline for training the adversarial network on T1 and T2 image pairs. Two networks, a generator ($G$) and a discriminator ($D$), were trained simultaneously: $G$ tried to generate T2-like images from the T1-weighted scans, and $D$ tried to distinguish real from fake (synthetically generated) T2-weighted image contrast in {T1, T2} pairs. Denoting a T1-weighted input by $x$, the corresponding T2-weighted image by $y$, and a random noise vector by $z$, the following objective was used to train these networks:

$$G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda\, \mathcal{L}_{L1}(G) \tag{11}$$

where $\lambda$ weights the $\ell_1$ term, and the loss function of the conditional GAN, $\mathcal{L}_{cGAN}$, was defined as:

$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right] \tag{12}$$

and the distance between the generated and real T2 scans in the training set was measured by the $\ell_1$-norm, which encourages sharper (less blurry) images:

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\left\| y - G(x, z) \right\|_1\right] \tag{13}$$
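To make Equations (11)-(13) concrete, the sketch below evaluates them on a mini-batch with NumPy. It assumes the discriminator outputs probabilities in (0, 1), averages the $\ell_1$ term over pixels (as implementations commonly do), and uses a weight of 100 on that term, which is the pix2pix default and only an assumption here. It evaluates the objective as written; in practice the generator is often trained with the non-saturating form, maximizing $\log D(x, G(x,z))$ instead. All names are hypothetical.

```python
import numpy as np

EPS = 1e-12  # guards against log(0)

def cgan_loss(d_real, d_fake):
    """Batch estimate of L_cGAN = E[log D(x, y)] + E[log(1 - D(x, G(x, z)))].

    d_real: D's probabilities on real {T1, T2} pairs; d_fake: on {T1, G(T1)} pairs.
    """
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    return float(np.mean(np.log(d_real + EPS)) + np.mean(np.log(1.0 - d_fake + EPS)))

def l1_loss(y, y_fake):
    """L_L1 = E[||y - G(x, z)||_1], averaged over pixels here."""
    return float(np.mean(np.abs(np.asarray(y, dtype=float) - np.asarray(y_fake, dtype=float))))

def generator_objective(d_real, d_fake, y, y_fake, lam=100.0):
    """Value G minimizes in Eq. (11) for a fixed D; lam=100 is an assumed weight."""
    return cgan_loss(d_real, d_fake) + lam * l1_loss(y, y_fake)

# A perfectly confused discriminator (0.5 everywhere) and a small intensity mismatch.
print(generator_objective([0.5, 0.5], [0.5, 0.5], np.ones((4, 4)), 0.9 * np.ones((4, 4))))
```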
To train the conditional GAN we used 33 pairs of T1- and T2-weighted newborn brain images, resulting in 3300 paired 2D slices. These images were used for training only. We then tested the trained generator on the test set of 7 newborn brain images. The generated T2-weighted images are shown in the last row of Figure 8. The pose estimation error box plots in this figure show that transferring the image contrast from T1- to T2-weighted and using the generated T2 images as input to the Pose-Net significantly decreased the pose estimation error. In fact, the trained generator can be cascaded at the input of the Pose-Net or the Correction-Net, so that these networks can directly register T1-weighted newborn brain images without ever being trained in this domain. Note that no reference data (aligned to an atlas) was needed for the T1-weighted scans to train the conditional GAN; only a set of paired T1 and T2 scans in the subject space, which was easy to obtain, was required. A similar approach can be taken to further expand the generalization domain of the trained pose estimation networks, for example to adult brains. In this work we had access to paired T1 and T2 images; in applications where paired images from the two domains are not available, CycleGAN [36] can be used instead.
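Because the conditional GAN was trained on paired 2D slices, cascading it in front of a 3D pose network presumably means translating a T1-weighted volume slice by slice and restacking the result. The sketch below assumes slices are taken along the last array axis; `generator` and `pose_net` are hypothetical callables standing in for the trained networks, not an actual API.

```python
import numpy as np

def translate_volume_slicewise(t1_volume, generator, axis=2):
    """Apply a 2D T1->T2 generator to every slice along `axis` and restack into a 3D volume.

    `generator` is a hypothetical callable mapping a 2D array to a 2D array of the same shape.
    """
    vol = np.moveaxis(np.asarray(t1_volume, dtype=float), axis, 0)
    t2_like = np.stack([generator(s) for s in vol], axis=0)
    return np.moveaxis(t2_like, 0, axis)

def register_t1(t1_volume, generator, pose_net, axis=2):
    """Cascade: contrast transfer first, then pose estimation by a network trained on T2 data only.

    `pose_net` is a hypothetical callable returning, e.g., (rotation_vector, translation).
    """
    return pose_net(translate_volume_slicewise(t1_volume, generator, axis=axis))

# Usage with stand-in callables (a real setup would load the trained networks).
volume = np.random.default_rng(0).random((8, 8, 6))
fake_generator = lambda s: 1.0 - s           # placeholder "contrast inversion"
fake_pose_net = lambda v: (np.zeros(3), np.zeros(3))
rotation, translation = register_t1(volume, fake_generator, fake_pose_net)
```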
Testing time comparison
Table 3 shows the average testing time (in seconds) of each algorithm. Note that the testing times of all CNN-based methods were measured on GPUs, whereas those of all non-CNN-based methods were measured on multi-core CPUs; these numbers therefore do not directly compare the computational cost of the algorithms. As expected, the deep learning-based algorithms ran in real time and were faster (by order of ) than the optimization-based algorithms. The difference in test time between the Pose-Net and the Correction-Net was due to a resampling operation on the image between the two stages of the Correction-Net, which took about 80 milliseconds.
Method | Volume to volume (s) | Slice to volume (s) |
---|---|---|
VVR/SVR-GC | 553 | 498 |
Pose-Net | ||
Correction-Net | - | |
VVR/SVR-Deep | 365 | 184 |
4 Discussion
In this work, we trained deep CNN regression models for 3D pose estimation of anatomy from medical images. Our results show that deep learning-based algorithms can not only provide a good initialization for optimization-based methods to improve the capture range of 2D to 3D registration, but can also be used directly to perform robust and accurate 3D to 3D rigid registration in real time. Using these learning-based methods together with accelerated optimization-based registration methods can provide powerful registration systems that capture almost all possible rotations in 3D space.
Our networks were composed of feature extraction layers followed by regression heads at the output layer. The non-linearity at the regression layer mimics the behaviour of the angle-axis representation of the rotation matrix, and the geodesic loss was used as a bi-invariant, natural Riemannian distance metric on the space of 3D rotations. Compared to MSE on rotation vectors, the geodesic loss led to significantly improved performance, especially in 3D, when images contained sufficient information for pose estimation.
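One concrete reason a loss on the rotation group can behave better than MSE on angle-axis vectors is that the vector representation is discontinuous near a rotation angle of $\pi$: a rotation by just under $\pi$ about $+z$ and one by just under $\pi$ about $-z$ are almost the same rotation, yet their vectors point in opposite directions. The sketch below illustrates this with NumPy, computing the geodesic distance through unit quaternions as $2\arccos|\langle q_1, q_2\rangle|$, an equivalent form of the angle of the relative rotation. It is an illustration only, not the training implementation, which would need a differentiable framework; names are hypothetical.

```python
import numpy as np

def rotvec_to_quat(v):
    """Angle-axis vector (radians) -> unit quaternion (w, x, y, z)."""
    v = np.asarray(v, dtype=float)
    theta = np.linalg.norm(v)
    if theta < 1e-12:
        return np.array([1.0, 0.0, 0.0, 0.0])
    return np.concatenate([[np.cos(theta / 2)], np.sin(theta / 2) * v / theta])

def geodesic_distance(v_true, v_pred):
    """Angle (radians) of the relative rotation; |dot| handles the q ~ -q double cover."""
    d = abs(np.dot(rotvec_to_quat(v_true), rotvec_to_quat(v_pred)))
    return float(2.0 * np.arccos(np.clip(d, 0.0, 1.0)))

def mse_rotvec(v_true, v_pred):
    """Mean squared error between the raw angle-axis vectors."""
    diff = np.asarray(v_true, dtype=float) - np.asarray(v_pred, dtype=float)
    return float(np.mean(diff ** 2))

a = np.array([0.0, 0.0, np.pi - 0.01])
print(mse_rotvec(a, -a))         # large: the vectors are nearly opposite
print(geodesic_distance(a, -a))  # small: the rotations differ by about 0.02 rad
```

A bi-invariant metric also gives the same value when both rotations are composed with a common rotation on either side, so the loss does not depend on the choice of reference frame.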
In a two-step approach, the 3D pose of an object (anatomy) is first approximately found in a standard (atlas) space, and the result is then fed together with a reference image, as two input channels, to a regression CNN (the Correction-Net). In this way, accurate inter-subject rigid registration can be achieved in real time for all ranges of rotation and translation. Initial translations can also be estimated in real time through center-of-gravity matching.
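Center-of-gravity matching can be sketched in a few lines. This assumes the center of gravity is the intensity-weighted centroid of a brain-extracted image with non-negative intensities, and that both images share the same voxel spacing and origin convention; the function names are hypothetical.

```python
import numpy as np

def center_of_gravity(volume, spacing=(1.0, 1.0, 1.0)):
    """Intensity-weighted centroid of a 3D image in physical units (voxel index * spacing)."""
    vol = np.asarray(volume, dtype=float)
    idx = np.indices(vol.shape).reshape(3, -1)  # voxel coordinates, shape (3, N)
    w = vol.reshape(-1)                         # intensities used as weights
    return (idx @ w) / w.sum() * np.asarray(spacing, dtype=float)

def initial_translation(moving, fixed, spacing=(1.0, 1.0, 1.0)):
    """Translation that moves the centroid of `moving` onto the centroid of `fixed`."""
    return center_of_gravity(fixed, spacing) - center_of_gravity(moving, spacing)

# Usage: a bright voxel in each image; the translation aligns their positions.
moving = np.zeros((8, 8, 8)); moving[2, 3, 4] = 1.0
fixed = np.zeros((8, 8, 8)); fixed[5, 5, 5] = 1.0
t0 = initial_translation(moving, fixed)
```

Because it needs only one weighted sum per image, this step adds negligible time to the network inference.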
One of the main concerns with learning-based methods is their generalization when they face test images with features that differ from those of the training set. This is particularly important in medical imaging, where the number of training samples is rather limited. In this study, to evaluate the generalization of the trained models over different ages, and because the shape and size of the brain change rapidly in early gestational weeks, we intentionally trained the network on older fetuses and tested it on younger ones. The only adjustment was a pre-defined scale parameter inferred from gestational age based on a fetal brain MRI atlas.
We also tested the trained models on brain MRI scans of newborns, which were obtained in a completely different setting, with head coils for ex-utero imaging. While the trained models worked very well on T2-weighted brain scans of newborns at 38-44 weeks, we further challenged them with T1-weighted MRI scans of newborn brains. On the T1-weighted scans the performance of the networks dropped significantly; however, we showed that by using a GAN-based technique that learned to translate T1-weighted images into T2-like images, and feeding its outputs into the trained regression CNNs, we obtained substantially lower pose estimation errors on T1-weighted images as well. To achieve this, we designed and trained an image-to-image translation GAN on pairs of T1 and T2 images of newborn subjects in a training set, and used it as a real-time pre-processing step for T1-weighted scans before they were fed into the pose estimation networks. More generally, with the conditional GAN approach, many learning-based algorithms can be generalized across modalities as long as some paired images are available for training.
5 Conclusion
We developed and evaluated deep pose estimation networks for 2D to 3D and 3D to 3D registration. In learning-based image to atlas (standard space) registration scenarios, the proposed methods provided very fast (real-time) registration with a wide capture range over the space of all plausible 3D rotations, and provided good initialization for current optimization-based registration methods. While current, highly evolved multi-scale optimization-based methods that use cost functions such as mutual information or local cross-correlation can converge to wrong local minima because of their non-convex cost functions, our proposed CNN-based methods learn to predict the 3D pose of images from their features, in a way that is more similar to how human observers analyze shapes and their pose in 3D. A combination of these techniques and accelerated optimization-based methods can dramatically enhance the performance of imaging devices and image processing methods of the future.
Acknowledgment
This study was supported in part by the National Institute of Biomedical Imaging and Bioengineering of the National Institutes of Health (NIH) grant R01 EB018988. The content of this work is solely the responsibility of the authors and does not necessarily represent the official views of the NIH.
References
- D. L. Hill, P. G. Batchelor, M. Holden, and D. J. Hawkes, “Medical image registration,” Physics in Medicine & Biology, vol. 46, no. 3, p. R1, 2001.
- J. P. Pluim, J. A. Maintz, and M. A. Viergever, “Mutual-information-based registration of medical images: a survey,” IEEE Transactions on Medical Imaging, vol. 22, no. 8, pp. 986–1004, 2003.
- A. Gholipour, N. Kehtarnavaz, R. Briggs, M. Devous, and K. Gopinath, “Brain functional localization: a survey of image registration techniques,” IEEE Transactions on Medical Imaging, vol. 26, no. 4, pp. 427–451, 2007.
- P. Markelj, D. Tomaževič, B. Likar, and F. Pernuš, “A review of 3D/2D registration methods for image-guided interventions,” Medical Image Analysis, vol. 16, no. 3, pp. 642–661, 2012.
- A. Sotiras, C. Davatzikos, and N. Paragios, “Deformable medical image registration: A survey,” IEEE Transactions on Medical Imaging, vol. 32, no. 7, pp. 1153–1190, 2013.
- E. Ferrante and N. Paragios, “Slice-to-volume medical image registration: A survey,” Medical Image Analysis, vol. 39, pp. 101–123, 2017.
- J. Long, E. Shelhamer, and T. Darrell, “Fully convolutional networks for semantic segmentation,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2015, pp. 3431–3440.
- H. Greenspan, B. van Ginneken, and R. M. Summers, “Guest editorial deep learning in medical imaging: Overview and future promise of an exciting new technique,” IEEE Transactions on Medical Imaging, vol. 35, no. 5, pp. 1153–1159, 2016.
- G. Litjens, T. Kooi, B. E. Bejnordi, A. A. A. Setio, F. Ciompi, M. Ghafoorian, J. A. van der Laak, B. van Ginneken, and C. I. Sánchez, “A survey on deep learning in medical image analysis,” Medical Image Analysis, vol. 42, pp. 60–88, 2017.
- M. Simonovsky, B. Gutiérrez-Becker, D. Mateus, N. Navab, and N. Komodakis, “A deep metric for multimodal registration,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2016, pp. 10–18.
- G. Wu, M. Kim, Q. Wang, B. C. Munsell, and D. Shen, “Scalable high-performance image registration framework by unsupervised deep feature representations learning,” IEEE Transactions on Biomedical Engineering, vol. 63, no. 7, pp. 1505–1516, July 2016.
- X. Yang, R. Kwitt, M. Styner, and M. Niethammer, “Quicksilver: Fast predictive image registration–a deep learning approach,” NeuroImage, vol. 158, pp. 378–396, 2017.
- R. Liao, S. Miao, P. de Tournemire, S. Grbic, A. Kamen, T. Mansi, and D. Comaniciu, “An artificial agent for robust image registration,” in AAAI, 2017, pp. 4168–4175.
- S. Miao, Z. J. Wang, Y. Zheng, and R. Liao, “Real-time 2D/3D registration via CNN regression,” in Biomedical Imaging (ISBI), 2016 IEEE 13th International Symposium on. IEEE, 2016, pp. 1430–1434.
- S. Miao, Z. J. Wang, and R. Liao, “A CNN regression approach for real-time 2D/3D registration,” IEEE Transactions on Medical Imaging, vol. 35, no. 5, pp. 1352–1363, 2016.
- J. Wu, T. Xue, J. J. Lim, Y. Tian, J. B. Tenenbaum, A. Torralba, and W. T. Freeman, “Single image 3D interpreter network,” in European Conference on Computer Vision. Springer, 2016, pp. 365–382.
- G. Pavlakos, X. Zhou, A. Chan, K. G. Derpanis, and K. Daniilidis, “6-DoF object pose from semantic keypoints,” in Robotics and Automation (ICRA), 2017 IEEE International Conference on. IEEE, 2017, pp. 2011–2018.
- S. Tulsiani and J. Malik, “Viewpoints and keypoints,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2015, pp. 1510–1519.
- H. Su, C. R. Qi, Y. Li, and L. J. Guibas, “Render for CNN: Viewpoint estimation in images using CNNs trained with rendered 3D model views,” in Proceedings of the IEEE International Conference on Computer Vision, 2015, pp. 2686–2694.
- S. Mahendran, H. Ali, and R. Vidal, “3D pose regression using convolutional neural networks,” in IEEE International Conference on Computer Vision, vol. 1, no. 2, 2017, p. 4.
- D. Q. Huynh, “Metrics for 3D rotations: Comparison and analysis,” Journal of Mathematical Imaging and Vision, vol. 35, no. 2, pp. 155–164, 2009.
- A. Gholipour, J. A. Estroff, and S. K. Warfield, “Robust super-resolution volume reconstruction from slice acquisitions: application to fetal brain MRI,” IEEE Transactions on Medical Imaging, vol. 29, no. 10, pp. 1739–1758, 2010.
- B. Kainz, M. Steinberger, W. Wein, M. Kuklisova-Murgasova, C. Malamateniou, K. Keraudren, T. Torsney-Weir, M. Rutherford, P. Aljabar, J. V. Hajnal et al., “Fast volume reconstruction from motion corrupted stacks of 2D slices,” IEEE Transactions on Medical Imaging, vol. 34, no. 9, pp. 1901–1913, 2015.
- B. Marami, S. S. M. Salehi, O. Afacan, B. Scherrer, C. K. Rollins, E. Yang, J. A. Estroff, S. K. Warfield, and A. Gholipour, “Temporal slice registration and robust diffusion-tensor reconstruction for improved fetal brain structural connectivity analysis,” NeuroImage, vol. 156, pp. 475–488, 2017.
- B. Hou, A. Alansary, S. McDonagh, A. Davidson, M. Rutherford, J. V. Hajnal, D. Rueckert, B. Glocker, and B. Kainz, “Predicting slice-to-volume transformation in presence of arbitrary subject motion,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2017, pp. 296–304.
- B. Hou, B. Khanal, A. Alansary, S. McDonagh, A. Davidson, M. Rutherford, J. V. Hajnal, D. Rueckert, B. Glocker, and B. Kainz, “3D reconstruction in canonical co-ordinate space from arbitrarily oriented 2D images,” arXiv preprint arXiv:1709.06341, 2017.
- A. I. Namburete, W. Xie, M. Yaqub, A. Zisserman, and J. A. Noble, “Fully-automated alignment of 3D fetal brain ultrasound to a canonical reference space using multi-task learning,” Medical Image Analysis, 2018.
- A. Kendall, M. Grimes, and R. Cipolla, “PoseNet: A convolutional network for real-time 6-DOF camera relocalization,” in Computer Vision (ICCV), 2015 IEEE International Conference on. IEEE, 2015, pp. 2938–2946.
- K. He, X. Zhang, S. Ren, and J. Sun, “Identity mappings in deep residual networks,” in European Conference on Computer Vision. Springer, 2016, pp. 630–645.
- S. S. M. Salehi, S. R. Hashemi, C. Velasco-Annis, A. Ouaalam, J. A. Estroff, D. Erdogmus, S. K. Warfield, and A. Gholipour, “Real-time automatic fetal brain extraction in fetal MRI by deep learning,” arXiv preprint arXiv:1710.09338, 2017.
- S. S. M. Salehi, D. Erdogmus, and A. Gholipour, “Auto-context convolutional neural network (Auto-Net) for brain extraction in magnetic resonance imaging,” IEEE Transactions on Medical Imaging, vol. 36, no. 11, pp. 2319–2330, 2017.
- P. A. Yushkevich, J. Piven, H. C. Hazlett, R. G. Smith, S. Ho, J. C. Gee, and G. Gerig, “User-guided 3D active contour segmentation of anatomical structures: significantly improved efficiency and reliability,” NeuroImage, vol. 31, no. 3, pp. 1116–1128, 2006.
- A. Gholipour, C. K. Rollins, C. Velasco-Annis, A. Ouaalam, A. Akhondi-Asl, O. Afacan, C. M. Ortinau, S. Clancy, C. Limperopoulos, E. Yang et al., “A normative spatiotemporal MRI atlas of the fetal brain for automatic segmentation and analysis of early brain growth,” Scientific Reports, vol. 7, no. 1, p. 476, 2017.
- J. Arvo, Graphics Gems II. Elsevier, 2013.
- P. Isola, J.-Y. Zhu, T. Zhou, and A. A. Efros, “Image-to-image translation with conditional adversarial networks,” arXiv preprint, 2017.
- J.-Y. Zhu, T. Park, P. Isola, and A. A. Efros, “Unpaired image-to-image translation using cycle-consistent adversarial networks,” arXiv preprint arXiv:1703.10593, 2017.