Deformable Registration through Learning of Context-Specific Metric Aggregation
Abstract
* indicates equal contribution.
We propose a novel weakly supervised discriminative algorithm for learning context-specific registration metrics as a linear combination of conventional similarity measures. Conventional metrics have been used extensively over the past two decades, and therefore both their strengths and their limitations are known. The challenge is to find the optimal relative weighting (or parameters) of the different metrics forming the similarity measure of the registration algorithm. Hand-tuning these parameters would result in suboptimal solutions and quickly becomes infeasible as the number of metrics increases. Furthermore, such a handcrafted combination can only be defined at a global scale (the entire volume) and therefore cannot account for different tissue properties. We propose a learning algorithm that estimates these parameters locally, conditioned on the semantic classes of the data. The objective function of our formulation is a special case of a nonconvex function, namely a difference of convex functions, which we optimize using the concave-convex procedure. As a proof of concept, we show the impact of our approach on three challenging datasets covering different anatomical structures and modalities.
1 Introduction
Deformable image registration is a highly challenging problem frequently encountered in medical image analysis. It involves the definition of a similarity criterion (data term) that, once endowed with a deformation model and a smoothness constraint, determines the optimal transformation to align two given images. We adopt a popular graphical model framework [5] to cast deformable registration as a discrete inference problem. The definition of the data term is among the most critical components of the registration process. It refers to a function that measures the (dis)similarity between images, such as mutual information (mi) or the sum of absolute differences (sad). Metric learning in the context of image registration [2, 8, 9, 12, 19] is an alternative that aims to determine the most effective means of image comparison (similarity measure) from labeled visual correspondences. Our approach can be considered a specific case of metric learning, where the idea is to efficiently combine well-studied mono- and multimodal metrics depending on the local context. We aim to learn the relative weighting from a given training dataset using a learning framework conditioned on prior semantic knowledge. We propose a novel weakly supervised discriminative learning framework, based on structured support vector machines (ssvm) [14, 16] and their extension to latent models (lssvm) [17], to learn the relative weights of context-specific metric aggregations.
Metric learning. Various metric learning methods have been proposed in the context of image registration. Lee et al. [8] introduced a multimodal registration algorithm in which the similarity measure is learned such that the target and the correctly deformed source image receive high similarity scores. The training data consisted of pre-aligned images, and learning was performed at the patch level under the assumption that the similarity measure decomposes over the patches. [2, 9] proposed the use of similarity-sensitive hashing to learn a multimodal metric. Similar to [8], they adopted a patch-wise approach: the dataset consisted of pairs of perfectly aligned images and a collection of positive/negative pairs of patches. Another patch-based alternative was presented in [15], where the training set consisted of non-aligned images with manually annotated patch pairs (landmarks). More recently, approaches based on convolutional neural networks have gained popularity. Zagoruyko et al. [19] discussed CNN architectures to learn patch-based similarity measures, one of which was then adopted in [12] to perform image registration. These methods require ground truth data in the form of correspondences (patches, landmarks or dense deformation fields), which are extremely difficult to obtain for real clinical data. In contrast, our method only requires segmentation masks.
Metric aggregation. In contrast to the above approaches, our method aggregates standard metrics using contextual information. [3] showed that a multichannel registration method, where a set of features is considered globally instead of a single similarity measure, produces more robust registration than using individual features. However, they did not discuss how these features should be weighted. Following this, [4] proposed to estimate a separate deformation field from each feature independently and then compose them into a final diffeomorphic transformation. Such a strategy produces multiple deformation models (as many as there are metrics), which might be locally inconsistent, so their combination may not be anatomically meaningful. Our method is most similar to that of Tang et al. [13], which generates a vector weight map that determines, at each spatial location, the relative importance of each constituent of the overall metric. However, their learning strategy still requires ground truth data in the form of correspondences (pre-registered images), which is not necessary in our case.
Contribution. We tackle the scenario where the ground truth deformations are not known a priori. We consider these deformation fields as latent variables and devise an algorithm within the lssvm framework [17]. We model the latent variable imputation problem as a deformable registration problem with additional constraints. Finally, we incorporate the learned aggregated metrics in a context-specific registration framework, where different weights are used depending on the structures being registered.
2 The Deformable Registration Problem
Let us assume a source three-dimensional image $I^s$, a source segmentation mask $S^s$ and a target image $I^t$. The segmentation mask $S^s$ is formed by labels $s \in \mathcal{C}$, where $\mathcal{C}$ is the set of classes. We focus on the 3D-to-3D deformable registration problem. Let us also adopt, without loss of generality, a graphical model [5, 10] for the deformable registration problem. A deformation field is sparsely represented by a regular grid graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$, where $\mathcal{V}$ is the set of nodes and $\mathcal{E}$ is the set of edges. Each node $i \in \mathcal{V}$ corresponds to a control point $\mathbf{p}_i$. Each control point is allowed to move in 3D space and can therefore be assigned a label $\mathbf{d}_i$ from the set of displacement vectors $\mathcal{L}$. Each displacement vector is a tuple $\mathbf{d} = (d_x, d_y, d_z)$, where $d_x$, $d_y$ and $d_z$ are the displacements in the $x$, $y$ and $z$ directions, respectively. The deformation (labeling of the graph $\mathcal{G}$), denoted as $\mathbf{d} = \{\mathbf{d}_i : i \in \mathcal{V}\}$, assigns to each node $i$ a displacement vector $\mathbf{d}_i$ from the set $\mathcal{L}$. The new control point obtained when the displacement $\mathbf{d}_i$ is applied to the original control point $\mathbf{p}_i$ is denoted as $\mathbf{p}_i + \mathbf{d}_i$. Let us define $P^{s}_i(\mathbf{d}_i)$ as the patch on the source image centered at the displaced control point $\mathbf{p}_i + \mathbf{d}_i$. Similarly, we define $P^{t}_i$ as the patch on the target image centered at the original control point $\mathbf{p}_i$, and $P^{S}_i(\mathbf{d}_i)$ as the patch on the source segmentation mask centered at the displaced control point $\mathbf{p}_i + \mathbf{d}_i$. Using the above notation, we define the unary feature vector of node $i$ for a given displacement vector $\mathbf{d}_i$ as $\mathbf{g}_i(\mathbf{d}_i) = [g^1_i(\mathbf{d}_i), \ldots, g^M_i(\mathbf{d}_i)]^\top$, where $M$ is the number of metrics (or similarity measures) and $g^m_i(\mathbf{d}_i)$ is the unary feature corresponding to the $m$-th metric computed on the patches $P^{s}_i(\mathbf{d}_i)$ and $P^{t}_i$. In the case of a single metric, $\mathbf{g}_i(\mathbf{d}_i) = g^1_i(\mathbf{d}_i)$. Therefore, given a weight matrix $W \in \mathbb{R}^{M \times |\mathcal{C}|}$, where $w_{mc}$ denotes the weight of the $m$-th metric for class $c$, the unary potential of node $i$ for a given displacement vector $\mathbf{d}_i$ is computed as:
$$U_i(\mathbf{d}_i; W) = \mathbf{w}_{c_i}^\top\, \mathbf{g}_i(\mathbf{d}_i) \qquad (1)$$
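where $\mathbf{w}_{c_i}$ is the $c_i$-th column of the weight matrix $W$ and $c_i$ is the dominant class in the patch $P^{S}_i(\mathbf{d}_i)$ of the source segmentation mask, obtained as $c_i = \arg\max_{c \in \mathcal{C}} N_c(P^{S}_i(\mathbf{d}_i))$, with $N_c(\cdot)$ being the number of voxels of class $c$ in the patch. Other criteria could be used to determine the dominant class. The pairwise clique potential between the control points $\mathbf{p}_i$ and $\mathbf{p}_j$ is defined as the distance $\phi_{ij}(\mathbf{d}_i, \mathbf{d}_j) = \lVert \mathbf{d}_i - \mathbf{d}_j \rVert$ between the two displacement vectors. Thus, the multiclass energy function is:
$$E(\mathbf{d}; W, I^s, I^t, S^s) = \sum_{i \in \mathcal{V}} U_i(\mathbf{d}_i; W) + \sum_{(i,j) \in \mathcal{E}} \phi_{ij}(\mathbf{d}_i, \mathbf{d}_j) \qquad (2)$$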
Then, the optimal deformation is obtained as $\mathbf{d}^* = \arg\min_{\mathbf{d}} E(\mathbf{d}; W, I^s, I^t, S^s)$. This problem is NP-hard in general. Similar to [5], we adopt a pyramidal approach to solve it efficiently, using FastPD [7] for inference at every level of the pyramid. Notice that the energy function (2) is defined over the nodes of the sparse graph $\mathcal{G}$. Once we obtain the optimal deformation $\mathbf{d}^*$, we estimate the dense deformation field using a free-form deformation (FFD) model [11] in order to warp the input image.
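To make the unary potential (1), the dominant-class rule and the energy (2) concrete, here is a minimal NumPy sketch on a toy two-node graph. It assumes that the per-metric features $\mathbf{g}_i(\mathbf{d}_i)$ are already computed, that mask patches hold integer class indices, and that the pairwise norm is the Euclidean norm; the function names are hypothetical, and in the actual method these terms are evaluated inside FastPD rather than in Python.

```python
import numpy as np


def dominant_class(mask_patch: np.ndarray, num_classes: int) -> int:
    """Class with the largest voxel count N_c in a source-mask patch."""
    counts = np.bincount(mask_patch.ravel(), minlength=num_classes)
    return int(np.argmax(counts))


def unary_potential(g: np.ndarray, W: np.ndarray, mask_patch: np.ndarray) -> float:
    """Eq. (1): w_{c_i}^T g_i(d_i), where c_i is the dominant class of the patch."""
    c_i = dominant_class(mask_patch, W.shape[1])
    return float(W[:, c_i] @ g)


def energy(unaries, displacements: np.ndarray, edges) -> float:
    """Eq. (2): sum of unaries plus ||d_i - d_j|| over grid edges (L2 norm assumed)."""
    pairwise = sum(np.linalg.norm(displacements[i] - displacements[j]) for i, j in edges)
    return float(sum(unaries) + pairwise)


# Toy example: M = 2 metrics (rows of W), |C| = 2 classes (columns of W).
W = np.array([[1.0, 0.0],
              [0.0, 2.0]])
g = np.array([0.5, 0.3])                               # g_i(d_i) for one node
u_bg = unary_potential(g, W, np.array([0, 0, 0, 1]))  # dominant class 0 -> 0.5
u_fg = unary_potential(g, W, np.array([1, 1, 0, 1]))  # dominant class 1 -> 0.6
d = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])      # displacements of two nodes
print(energy([u_bg, u_fg], d, [(0, 1)]))               # 0.5 + 0.6 + 5.0
```

The same feature vector yields different unaries depending on which class dominates the displaced source-mask patch, which is exactly how the context enters the data term.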
3 Learning the Parameters
Knowing the weight matrix $W$ a priori is nontrivial, and hand-tuning it quickly becomes infeasible as the number of metrics and classes increases. We propose an algorithm to learn $W$ conditioned on the semantic labels, assuming that semantic masks are available for both the source and the target images during the training phase. Instead of learning the complete weight matrix at once, we learn the weights (or parameters) for each class individually. From now on, the weight vector $\mathbf{w}_c$ denotes the $c$-th column of the weight matrix $W$, representing the weights corresponding to class $c$.
Training Data. Consider a dataset $\mathcal{D} = \{(\mathbf{x}_k, \mathbf{y}_k)\}_{k=1}^{N}$, where $\mathbf{x}_k = (I^s_k, I^t_k)$, $I^s_k$ is the source image and $I^t_k$ is the target image. Similarly, $\mathbf{y}_k = (S^s_k, S^t_k)$, where $S^s_k$ and $S^t_k$ are the segmentation masks of the source and target images. Each segmentation mask has the same size as the corresponding image. As stated earlier, a segmentation mask is formed by elements (or voxels) $s \in \mathcal{C}$, where $\mathcal{C}$ is the set of classes.
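Loss Function. The loss function $\Delta(\mathbf{y}, \mathbf{d})$ evaluates the similarity between the target segmentation mask $S^t$ and the source segmentation mask $S^s$ warped by the deformation $\mathbf{d}$; a higher $\Delta$ implies a higher dissimilarity. We use a Dice-based loss function, since Dice is our evaluation criterion:
$$\Delta(\mathbf{y}, \mathbf{d}) = 1 - \frac{1}{|\mathcal{V}|}\sum_{i \in \mathcal{V}} \frac{2\,\lvert P^{S^t}_i \cap P^{S}_i(\mathbf{d}_i) \rvert}{\lvert P^{S^t}_i \rvert + \lvert P^{S}_i(\mathbf{d}_i) \rvert} \qquad (3)$$
where $P^{S^t}_i$ is the patch centered at the control point $\mathbf{p}_i$ on the target segmentation mask $S^t$, $P^{S}_i(\mathbf{d}_i)$ is the patch centered at the displaced control point $\mathbf{p}_i + \mathbf{d}_i$ on the source segmentation mask $S^s$, and $\lvert\cdot\rvert$ denotes cardinality. This patch-wise approximation makes the Dice loss decomposable over the nodes of $\mathcal{G}$, enabling very efficient training.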
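The decomposability that (3) relies on can be sketched as follows: the loss is an average of independent per-node terms, so each term can later be folded into the unary potential of its node. This sketch assumes binary (single-class) patches and treats a pair of empty patches as a perfect match; both the function names and that convention are illustrative assumptions.

```python
import numpy as np


def node_dice_term(target_patch: np.ndarray, source_patch: np.ndarray) -> float:
    """Per-node term: 1 - Dice between two binary patches."""
    t = target_patch.astype(bool)
    s = source_patch.astype(bool)
    denom = t.sum() + s.sum()
    if denom == 0:  # both patches empty: assumed perfect agreement
        return 0.0
    return float(1.0 - 2.0 * np.logical_and(t, s).sum() / denom)


def patch_dice_loss(target_patches, source_patches) -> float:
    """Eq. (3) sketch: 1 minus the mean per-node Dice, i.e. the mean of the node terms."""
    terms = [node_dice_term(t, s) for t, s in zip(target_patches, source_patches)]
    return float(np.mean(terms))


target = [np.array([1, 1, 0, 0]), np.array([0, 0, 0, 0])]
source = [np.array([1, 0, 0, 0]), np.array([0, 0, 0, 0])]
print(patch_dice_loss(target, source))  # node terms 1/3 and 0, mean 1/6
```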
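Joint Feature Map. Given $\mathbf{w}_c$ for the $c$-th class, the deformation $\mathbf{d}$ and the input $\mathbf{x}$, the multiclass energy function (2) can be trivially converted into a class-based energy function:
$$E(\mathbf{d}; \mathbf{w}_c, w_p, \mathbf{x}) = \sum_{i \in \mathcal{V}} \mathbf{w}_c^\top\, \mathbf{g}_i(\mathbf{d}_i) + w_p \sum_{(i,j) \in \mathcal{E}} \phi_{ij}(\mathbf{d}_i, \mathbf{d}_j) \qquad (4)$$
where $\mathbf{g}_i(\mathbf{d}_i)$ is the unary feature vector and $\phi_{ij}(\mathbf{d}_i, \mathbf{d}_j) = \lVert \mathbf{d}_i - \mathbf{d}_j \rVert$ the pairwise potential defined in Section 2, and $w_p$ is the parameter for the pairwise term. The final parameter vector $\mathbf{w} = [\mathbf{w}_c^\top, w_p]^\top$ is the concatenation of $\mathbf{w}_c$ and $w_p$. Thus, the function (4) can be written as:
$$E(\mathbf{d}; \mathbf{w}, \mathbf{x}) = \mathbf{w}^\top \Psi(\mathbf{x}, \mathbf{d}) \qquad (5)$$
where $\Psi(\mathbf{x}, \mathbf{d})$ is the joint feature map defined as:
$$\Psi(\mathbf{x}, \mathbf{d}) = \Big[\sum_{i \in \mathcal{V}} \mathbf{g}_i(\mathbf{d}_i)^\top,\; \sum_{(i,j) \in \mathcal{E}} \phi_{ij}(\mathbf{d}_i, \mathbf{d}_j)\Big]^\top \qquad (6)$$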
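As a sanity check of (5) and (6), the sketch below builds $\Psi$ for a toy configuration and verifies that the linear form $\mathbf{w}^\top\Psi$ equals the class-based energy (4) evaluated term by term. The feature values are made up, and the Euclidean norm is assumed for the pairwise term.

```python
import numpy as np


def joint_feature_map(node_features: np.ndarray, displacements: np.ndarray, edges) -> np.ndarray:
    """Eq. (6): [sum_i g_i(d_i), sum_(i,j) ||d_i - d_j||], a vector of length M + 1."""
    unary_part = node_features.sum(axis=0)  # node_features has shape (|V|, M)
    pairwise_part = sum(np.linalg.norm(displacements[i] - displacements[j]) for i, j in edges)
    return np.append(unary_part, pairwise_part)


def class_energy(node_features, displacements, edges, w_c, w_p) -> float:
    """Eq. (4) evaluated directly: sum_i w_c^T g_i(d_i) + w_p * sum_(i,j) ||d_i - d_j||."""
    unary = sum(float(w_c @ g) for g in node_features)
    pairwise = sum(np.linalg.norm(displacements[i] - displacements[j]) for i, j in edges)
    return unary + w_p * pairwise


G = np.array([[0.2, 0.1],
              [0.4, 0.3]])                       # g_i(d_i) for two nodes, M = 2
D = np.array([[0.0, 0.0, 0.0], [0.0, 3.0, 4.0]])  # displacements d_i
edges = [(0, 1)]
w_c, w_p = np.array([1.0, 2.0]), 0.1
w = np.append(w_c, w_p)                           # w = [w_c, w_p]
psi = joint_feature_map(G, D, edges)
print(psi, w @ psi, class_energy(G, D, edges, w_c, w_p))
```

Writing the energy as an inner product is what lets the ssvm machinery treat the registration energy as a linear score in $\mathbf{w}$.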
Notice that the energy function (4) does not depend on the source segmentation mask $S^s$. The only use of the source segmentation mask in the energy function (2) is to obtain the dominant class, which is not required in this case. However, as we will see shortly, the source segmentation mask plays a crucial role in the learning algorithm.
Latent Variables. Ideally, the dataset would contain the ground truth deformation $\mathbf{d}_k$ for each source image, in order to compute the energy term defined in equation (4). Since annotating the dataset with ground truth deformations is nontrivial, we treat them as latent variables in our algorithm.
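The Objective Function. Given $\mathcal{D}$, we learn the parameter $\mathbf{w}$ such that minimizing the energy function (4) leads to a deformation field which, when applied to the source segmentation mask, gives the minimum loss with respect to the target segmentation mask. We denote by $S^s(\mathbf{d})$ the deformed segmentation obtained when the dense deformation field derived from $\mathbf{d}$ is applied to the segmentation mask $S^s$. Similarly to the latent ssvm [17], we optimize a regularized upper bound on the loss:
$$\min_{\mathbf{w},\, \boldsymbol{\xi} \ge 0}\; \frac{\lambda}{2}\lVert\mathbf{w}\rVert^2 + \frac{\mu}{2}\lVert\mathbf{w} - \mathbf{w}_0\rVert^2 + \frac{1}{N}\sum_{k=1}^{N}\xi_k \quad \text{s.t.}\quad \min_{\bar{\mathbf{d}} \in \mathcal{Q}_k} \mathbf{w}^\top\Psi(\mathbf{x}_k, \bar{\mathbf{d}}) \le \mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}) - \Delta(\mathbf{y}_k, \mathbf{d}) + \xi_k, \quad \forall \mathbf{d},\, \forall k \qquad (7)$$
where $\mathcal{Q}_k = \arg\min_{\mathbf{d}} \Delta(\mathbf{y}_k, \mathbf{d})$ is the set of deformations with minimum loss for the $k$-th sample. The above objective function minimizes an upper bound on the given loss, called the slack ($\xi_k$). The effect of the regularization term is controlled by the hyperparameter $\lambda$. The second term is a proximity term that keeps the learned $\mathbf{w}$ close to the initialization $\mathbf{w}_0$; its effect is controlled by the hyperparameter $\mu$. Intuitively, for a given input-output pair, the constraints of the above objective function enforce that the energy of the best possible deformation field, in terms of both energy and loss (in order to be semantically meaningful), must always be less than or equal to the energy of any other deformation field, with a margin proportional to the loss and some non-negative slack.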
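For intuition about the difference-of-convex structure exploited next, the sketch below computes the smallest feasible slack $\xi_k$ of (7) when the candidate deformations can be enumerated, which is a toy assumption (in practice there are exponentially many). Each candidate is represented only by its joint feature vector $\Psi(\mathbf{x}_k, \mathbf{d}_j)$ and its loss $\Delta(\mathbf{y}_k, \mathbf{d}_j)$. The slack splits into a convex function of $\mathbf{w}$ (a maximum of affine functions) plus a concave one (a minimum of affine functions).

```python
import numpy as np


def slack(w: np.ndarray, psi: np.ndarray, loss: np.ndarray) -> float:
    """Smallest xi_k satisfying the constraints of (7) over an enumerated candidate set.

    psi[j] = Psi(x_k, d_j) and loss[j] = Delta(y_k, d_j) for candidate deformation d_j.
    """
    energies = psi @ w
    in_q = np.isclose(loss, loss.min())       # Q_k: candidates with minimum loss
    concave_part = energies[in_q].min()       # min of affine functions of w
    convex_part = np.max(loss - energies)     # max of affine functions of w
    return max(0.0, float(convex_part + concave_part))


psi = np.array([[1.0], [2.0], [0.5]])
loss = np.array([0.0, 0.5, 1.0])
print(slack(np.array([1.0]), psi, loss))  # 1.5: candidate 2 has lower energy despite loss 1
```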
The Learning Algorithm. The objective function (7) is a special case of a nonconvex function (a difference of convex functions) and can thus be locally optimized using the well-known cccp algorithm [18]. The cccp algorithm consists of three steps: (1) upper-bounding the concave part at a given $\mathbf{w}_t$, which leads to an affine function in $\mathbf{w}$; (2) optimizing the resulting convex function (the sum of a convex and an affine function is convex); (3) repeating the above steps until the objective cannot be decreased further beyond a given tolerance $\epsilon$. The complete cccp algorithm for the optimization of (7) is shown in Algorithm 1. The first step, upper-bounding the concave function (Line 4), is the same as the latent imputation step, which we call the segmentation consistent registration problem. The second step is the optimization of the resulting convex problem (Line 5), which is the optimization of an ssvm, for which we use the well-known cutting plane algorithm [6]. In what follows, we discuss these steps in detail.
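The three cccp steps can be illustrated on a one-dimensional difference of convex functions, $f(x) = x^4 - x^2$, rather than on (7) itself. The toy function and its closed-form surrogate minimizer are assumptions made for illustration; what carries over is the structure of the loop: linearize the concave part, minimize the convex surrogate, and repeat until the change falls below a tolerance.

```python
import numpy as np


def cccp_toy(x0: float, tol: float = 1e-10, max_iter: int = 1000) -> float:
    """CCCP on f(x) = u(x) - v(x) with u(x) = x**4 and v(x) = x**2 (both convex)."""
    x = x0
    for _ in range(max_iter):
        # Step 1: bound the concave part -v by its tangent at x_t:
        #   -v(x) <= -x_t**2 - 2 * x_t * (x - x_t)   (affine in x).
        # Step 2: minimise the convex surrogate x**4 - 2 * x_t * x + const;
        #   its stationarity condition 4 * x**3 = 2 * x_t has the real root below.
        x_new = float(np.cbrt(x / 2.0))
        # Step 3: stop once the iterate no longer changes beyond the tolerance.
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x


def f(x: float) -> float:
    return x**4 - x**2


x_star = cccp_toy(1.0)
print(x_star, f(x_star))  # approaches the local minimum x = 1/sqrt(2), where f = -0.25
```

Each iteration never increases $f$, but only convergence to a stationary point is guaranteed: starting from $x_0 = 0$, the iterates stay at the local maximum $x = 0$. This is why (7) can only be optimized locally.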
Segmentation Consistent Registration. This step, known as the latent imputation step, generates the best possible ground truth deformation field (unknown a priori) at a given $\mathbf{w}$. Since we optimize the Dice loss, we formulate this step as an inference problem with additional constraints, ensuring that the imputed deformation warps the image so as to minimize the loss between the deformed source and the target. Mathematically, for a given parameter vector $\mathbf{w}$, the latent deformation is imputed by solving:
$$\mathbf{d}^*_k = \arg\min_{\mathbf{d} \in \mathcal{Q}_k} \mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}), \qquad \mathcal{Q}_k = \arg\min_{\mathbf{d}} \Delta(\mathbf{y}_k, \mathbf{d}) \qquad (8)$$
We relax the above problem, as it is difficult to solve and may not have a unique solution:
$$\mathbf{d}^*_k = \arg\min_{\mathbf{d}}\; \mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}) + \alpha\,\Delta(\mathbf{y}_k, \mathbf{d}) \qquad (9)$$
where $\alpha$ controls the relaxation trade-off. Since the loss function is decomposable, the above problem can be optimized using FastPD inference for deformable registration with trivial modifications of the unary potentials.
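Because (3) is a sum of per-node terms, the relaxed imputation (9) only changes the unary table: each entry becomes the registration unary plus $\alpha$ times the per-node loss term. The sketch below makes this concrete on a two-node toy graph. It replaces FastPD with exhaustive search (feasible only for tiny graphs) and assumes an L2 pairwise term with unit weight; all names and numbers are illustrative.

```python
import itertools

import numpy as np


def brute_force_min(unary: np.ndarray, edges, labels: np.ndarray):
    """Exhaustive minimiser of sum_i unary[i, l_i] + sum_(i,j) ||labels[l_i] - labels[l_j]||.

    A stand-in for FastPD that is only feasible for a handful of nodes.
    """
    num_nodes, num_labels = unary.shape
    best, best_cost = None, np.inf
    for assign in itertools.product(range(num_labels), repeat=num_nodes):
        cost = sum(unary[i, a] for i, a in enumerate(assign))
        cost += sum(np.linalg.norm(labels[assign[i]] - labels[assign[j]]) for i, j in edges)
        if cost < best_cost:
            best, best_cost = assign, cost
    return best


def impute_latent(energy_unary, loss_unary, edges, labels, alpha: float):
    """Relaxed imputation (9): registration unaries plus alpha times per-node loss terms."""
    return brute_force_min(energy_unary + alpha * loss_unary, edges, labels)


labels = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # two candidate displacements
energy_unary = np.array([[0.0, 0.3], [0.0, 0.3]])      # metrics prefer displacement 0
loss_unary = np.array([[1.0, 0.0], [1.0, 0.0]])        # segmentation overlap prefers 1
print(impute_latent(energy_unary, loss_unary, [(0, 1)], labels, alpha=0.0))  # (0, 0)
print(impute_latent(energy_unary, loss_unary, [(0, 1)], labels, alpha=1.0))  # (1, 1)
```

With `alpha=0.0` the routine reduces to plain registration. A large enough `alpha` pulls the solution towards the segmentation-consistent deformation, which is the trade-off that $\alpha$ controls.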
Parameters update. Given the imputed latent variables, the resulting objective is:
$$\min_{\mathbf{w},\, \boldsymbol{\xi} \ge 0}\; \frac{\lambda}{2}\lVert\mathbf{w}\rVert^2 + \frac{\mu}{2}\lVert\mathbf{w} - \mathbf{w}_0\rVert^2 + \frac{1}{N}\sum_{k=1}^{N}\xi_k \quad \text{s.t.}\quad \mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}^*_k) \le \mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}) - \Delta(\mathbf{y}_k, \mathbf{d}) + \xi_k, \quad \forall \mathbf{d},\, \forall k \qquad (10)$$
where $\mathbf{d}^*_k$ is the latent deformation field imputed by solving problem (9). Intuitively, the above objective function learns the parameters such that the energy of the imputed deformation field is always lower than the energy of any other deformation field, with a margin proportional to the loss function. The objective has an exponential number of constraints, one for each possible deformation field $\mathbf{d}$. To alleviate this problem, we use the cutting plane algorithm [6]. Briefly, for a given $\mathbf{w}$, each deformation field gives a slack. Instead of minimizing all the slacks for a particular sample at once, we find the deformation field that leads to the maximum value of the slack and store it in a set known as the working set; this is known as finding the most violated constraint. Thus, instead of using exponentially many constraints, the algorithm only uses the constraints stored in the working set, and this process is repeated until no further constraints can be added. Rearranging the terms in the constraints of the objective function (10) and ignoring the constant term $\mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}^*_k)$, the most violated constraint can be obtained by solving:
$$\hat{\mathbf{d}}_k = \arg\min_{\mathbf{d}}\; \mathbf{w}^\top\Psi(\mathbf{x}_k, \mathbf{d}) - \Delta(\mathbf{y}_k, \mathbf{d}) \qquad (11)$$
Since the loss is decomposable, this problem can be solved using FastPD inference for the deformable registration with trivial modifications on the unary terms.
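A compact sketch of the cutting plane loop for (10) on a single training sample follows. It makes two simplifying assumptions. First, the candidate deformations are enumerated explicitly, so the most violated constraint (11) becomes an argmax over a list instead of a FastPD call. Second, the restricted problem over the working set is solved approximately by plain subgradient descent instead of a QP solver. It is not the solver of [6]; it only shows how the working set grows and when the loop stops.

```python
import numpy as np


def solve_working_set(A, b, w0, lam, mu, iters=2000):
    """Approximately minimise lam/2 ||w||^2 + mu/2 ||w - w0||^2 + max(0, max_j b_j - A_j . w).

    Plain subgradient descent with step 1 / ((lam + mu) t); a stand-in for the QP solver.
    """
    w = w0.copy()
    best_w, best_f = w.copy(), np.inf
    for t in range(1, iters + 1):
        margins = b - A @ w
        j = int(np.argmax(margins))
        f = 0.5 * lam * w @ w + 0.5 * mu * (w - w0) @ (w - w0) + max(0.0, margins[j])
        if f < best_f:
            best_w, best_f = w.copy(), f
        grad = lam * w + mu * (w - w0)
        if margins[j] > 0:  # hinge active: add the subgradient of the slack term
            grad = grad - A[j]
        w = w - grad / ((lam + mu) * t)
    return best_w


def cutting_plane(psi, loss, j_star, w0, lam=1.0, mu=1.0, eps=1e-3, max_rounds=50):
    """Single-sample cutting plane for (10) over an explicitly enumerated candidate set.

    psi[j] = Psi(x, d_j), loss[j] = Delta(y, d_j); j_star indexes the imputed d*.
    Constraint j reads: w . (psi[j] - psi[j_star]) >= loss[j] - xi.
    """
    A = psi - psi[j_star]  # one row per candidate constraint
    working = []           # indices of constraints in the working set
    w = w0.copy()
    for _ in range(max_rounds):
        viol = loss - A @ w       # slack each constraint would require at the current w
        j = int(np.argmax(viol))  # most violated constraint, i.e. the argmin of (11)
        xi = max(0.0, max((viol[i] for i in working), default=0.0))
        if viol[j] <= xi + eps:   # no constraint needs more slack than the working set gives
            break
        working.append(j)
        w = solve_working_set(A[working], loss[working], w0, lam, mu)
    return w, working


psi = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
loss = np.array([0.0, 1.0, 0.5])
w, working = cutting_plane(psi, loss, j_star=0, w0=np.zeros(2))
print(w, working)
```

In this toy case a single constraint (candidate 1) enters the working set. Once it is satisfied, candidate 2 is satisfied as well, which illustrates why the working set usually stays far smaller than the full constraint set.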
Prediction. Once we obtain the learned parameters $\mathbf{w}_c$ for each class $c$ using Algorithm 1, we form the matrix $W$, where each column represents the learned parameters for a specific class. This matrix is then used to solve the registration problem (equation (2)) using the approximate inference discussed in Section 2.
4 Results and discussion
As a proof of concept, we evaluate the effect of the aggregated metric on three different medical datasets: (1) RT Parotids, (2) RT Abdominal, and (3) a downsampled version of IBSR [1]. Together they involve several anatomical structures, different image modalities, and inter- and intra-patient images. We used four different metrics: (1) sum of absolute differences (sad), (2) mutual information (mi), (3) normalized cross correlation (ncc), and (4) discrete wavelet coefficients (dwt). The datasets consist of 8 CT volumes (RT Parotids, head images of voxels), 5 CT volumes (RT Abdominal, abdominal images of voxels) and 18 MRI volumes (a downsampled version of the IBSR dataset, brain images of voxels). We performed multi-fold cross-validation on every dataset, using pairs from different patients for training and testing. For a complete description of the datasets and the experimental setting, please refer to the supplementary material. The results on the test sets are shown in Figure 1.
As can be observed in Figure 1, the linear combination of similarity measures weighted with the learned coefficients systematically outperforms (or performs as well as) single-metric registration. The maximum improvement is about 0.08 in Dice (bladder), measured with respect to the average of the single-metric results.
Discussion and conclusions. We have shown that associating different similarity criteria with every anatomical region yields results superior to the classic single-metric approach. To learn this mapping when ground truth is only given in the form of segmentation masks, we defined deformation fields as latent variables and proposed an lssvm-based framework. The main limitation of our method is that it requires segmentation masks for the source images at testing time. However, several real scenarios, such as radiation therapy or atlas-based segmentation, fulfill this condition. Note that, at prediction (testing) time, the source segmentation mask is only used to choose the learned combination of metric weights for each control node (by finding the dominant class). The segmentation labels do not otherwise guide the registration process, which remains purely image-based. The idea could be extended to unlabeled source images, where the dominant class of each control node would be the output of a classification/learning method. From a theoretical viewpoint, we showed that the three main components of lssvm can be formulated as the exact same problem, differing only in the unary potentials used: (1) latent imputation (Eq. (9)), (2) prediction (optimizing Eq. (2)) and (3) finding the most violated constraint (Eq. (11)). This is important because any further improvement in inference algorithms will directly improve the quality of the results. As future work, accuracy measures other than Dice could be integrated, such as the Hausdorff distance between surfaces or real geometric distances between anatomical landmarks, which could further enhance the performance of the method.
References
 [1] IBSR: Internet Brain Segmentation Repository [Online]. Available: http://www.cma.mgh.harvard.edu/ibsr/
 [2] Bronstein, M.M., Bronstein, A.M., Michel, F., Paragios, N.: Data fusion through cross-modality metric learning using similarity-sensitive hashing. In: CVPR 2010. pp. 3594–3601. IEEE (2010)
 [3] Cifor, A., Risser, L., Chung, D., Anderson, E.M., Schnabel, J.A.: Hybrid feature-based Log-Demons registration for tumour tracking in 2D liver ultrasound images. In: ISBI (2012)
 [4] Cifor, A., Risser, L., Chung, D., Anderson, E.M., Schnabel, J.A.: Hybrid feature-based diffeomorphic registration for tumor tracking in 2D liver ultrasound images. IEEE TMI (2013)
 [5] Glocker, B., Komodakis, N., Tziritas, G., Navab, N., Paragios, N.: Dense image registration through MRFs and efficient linear programming. Medical Image Analysis 12(6) (2008)
 [6] Joachims, T., Finley, T., Yu, C.: Cutting-plane training of structural SVMs. Machine Learning (2009)
 [7] Komodakis, N., Tziritas, G., Paragios, N.: Fast, approximately optimal solutions for single and dynamic MRFs. In: CVPR (2007)
 [8] Lee, D., Hofmann, M., Steinke, F., Altun, Y., Cahill, N.D., Schölkopf, B.: Learning similarity measure for multi-modal 3D image registration. In: CVPR 2009. pp. 186–193. IEEE (2009)
 [9] Michel, F., Bronstein, M., Bronstein, A., Paragios, N.: Boosted metric learning for 3D multi-modal deformable registration. In: ISBI (2011)
 [10] Paragios, N., Ferrante, E., Glocker, B., Komodakis, N., Parisot, S., Zacharaki, E.I.: (Hyper)-graphical models in biomedical image analysis. Medical Image Analysis 33, 102–106 (2016)
 [11] Rueckert, D., Sonoda, L.I., et al.: Nonrigid registration using free-form deformations: Application to breast MR images. IEEE TMI (1999)
 [12] Simonovsky, M., Gutiérrez-Becker, B., Mateus, D., Navab, N., Komodakis, N.: A deep metric for multimodal registration. In: MICCAI (2016)
 [13] Tang, L., Hero, A., Hamarneh, G.: Locally-adaptive similarity metric for deformable medical image registration. In: ISBI. IEEE (2012)
 [14] Taskar, B., Guestrin, C., Koller, D.: Max-margin Markov networks. In: NIPS (2003)
 [15] Toga, A.W.: Learning based coarse-to-fine image registration. In: CVPR (2008)
 [16] Tsochantaridis, I., Hofmann, T., Joachims, T., Altun, Y.: Support vector machine learning for interdependent and structured output spaces. In: ICML (2004)
 [17] Yu, C.N., Joachims, T.: Learning structural SVMs with latent variables. In: ICML (2009)
 [18] Yuille, A., Rangarajan, A.: The concave-convex procedure. Neural Computation (2003)
 [19] Zagoruyko, S., Komodakis, N.: Learning to compare image patches via convolutional neural networks. In: CVPR (2015)
Supplementary Material
Appendix A Detailed experimental setting description
In what follows, we provide a detailed description of the experimental setup used for the evaluation presented in the main paper. For all experiments, we used the same set of parameters for the pyramid-based inference discussed in Section 2: pyramid levels, refinement steps per pyramid level, labels, and distance between control points of mm at the finest level. The running time for each registration case was around seconds. For training, we initialized $\mathbf{w}_0$ with hand-tuned values, obtained by grid search over the values {0.01, 0.1, 1, 10} for each metric: , for sad, mi, ncc, and dwt, respectively.
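The hand-tuned initialization can be sketched as an independent grid search per metric. Only the grid values {0.01, 0.1, 1, 10} come from the text. The `validation_dice` callable is a hypothetical stand-in for running single-metric registration on held-out pairs, and what a single weight is traded off against (presumably the smoothness term) is an assumption.

```python
import math
from typing import Callable, Dict, Iterable

GRID = (0.01, 0.1, 1.0, 10.0)  # candidate weights listed in the text


def hand_tune(metrics: Iterable[str],
              validation_dice: Callable[[str, float], float]) -> Dict[str, float]:
    """Pick, for every metric, the grid weight with the highest validation Dice.

    validation_dice(metric, weight) is a hypothetical callable that would run
    single-metric registration with that weight and return the mean Dice.
    """
    return {m: max(GRID, key=lambda weight: validation_dice(m, weight)) for m in metrics}


# Toy stand-in for validation_dice: the score peaks at a metric-specific weight.
best_log10 = {"sad": -1.0, "mi": 0.0, "ncc": 1.0, "dwt": -2.0}


def toy_dice(metric: str, weight: float) -> float:
    return -abs(math.log10(weight) - best_log10[metric])


print(hand_tune(best_log10, toy_dice))  # {'sad': 0.1, 'mi': 1.0, 'ncc': 10.0, 'dwt': 0.01}
```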
The images used for evaluation correspond to the following datasets:

RT Parotids. This dataset contains CT volumes of the head, obtained from 4 different patients, 2 volumes per patient. The volumes were captured at two different stages of a radiotherapy treatment in order to estimate the radiation dose. The right and left parotid glands were segmented by specialists in every volume. The dimensions of the volumes are voxels with a physical spacing of mm, mm, and mm in the x, y, and z axes, respectively. We generated pairs of source and target volumes from the given dataset. Notice that, when generating the source and target pairs, we did not mix volumes coming from different patients. We split the dataset into training and test sets. The average results on the test set are shown in Figure 1.a of the main paper.

RT Abdominal. The second dataset contains CT volumes of the abdomen of a single patient, captured within a time window of about 7 days during a radiotherapy treatment. Three organs have been manually segmented by specialists: (1) sigmoid, (2) rectum, and (3) bladder. The dimensions of the volumes are voxels with a physical spacing of mm, mm, and mm in the x, y, and z axes, respectively (with small variations depending on the volume). We generated a training set of 6 pairs and a test set of 4 pairs. The results on the test set are shown in Figure 1.b of the main paper.

IBSR. The third dataset is the well-known Internet Brain Segmentation Repository (IBSR), which consists of brain MRI volumes coming from different patients. Segmentations of three different brain structures are provided by specialists: white matter (WM), gray matter (GM), and cerebrospinal fluid (CSF). We used a downsampled version of the dataset to reduce the computational cost. The dimensions of the volumes are voxels with a physical spacing of mm, mm, and mm in the x, y, and z axes, respectively. We divided the 18 volumes into 2 folds of 9 volumes each, giving a total of 72 pairs per fold. We used a stochastic approach for the learning process: we sampled 10 different pairs from the training fold and tested on the 72 pairs of the other fold. We ran this experiment 3 times per fold, giving a total of 6 different experiments, each with 10 randomly chosen training pairs and 72 test pairs. The results on the test set are shown in Figure 1.c of the main paper.
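The count of 72 pairs per fold follows from using ordered source/target pairs within a fold of 9 volumes (9 × 8 = 72). The sketch below reproduces this counting and the sampling protocol. The fold assignment and the random seed are hypothetical, since the text does not specify them.

```python
import random
from itertools import permutations

volumes = list(range(18))               # IBSR volume indices (illustrative)
folds = [volumes[:9], volumes[9:]]      # hypothetical fold assignment
# Ordered (source, target) pairs with source != target within each fold.
fold_pairs = [list(permutations(fold, 2)) for fold in folds]

rng = random.Random(0)                  # hypothetical seed
experiments = []
for train_fold, test_fold in ((0, 1), (1, 0)):
    for _ in range(3):                  # 3 repetitions per fold
        train = rng.sample(fold_pairs[train_fold], 10)
        experiments.append((train, fold_pairs[test_fold]))

print([len(p) for p in fold_pairs], len(experiments))  # [72, 72] 6
```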
For all datasets, the experiments were performed in two steps. First, we learned the weight vector $\mathbf{w}_c$ independently for every organ $c$. Second, we plugged the learned weights into the multimetric registration algorithm and registered every test case using the method presented in Section 2. We also ran experiments using single metrics (sum of absolute differences (SAD), mutual information (MI), normalized cross correlation (NCC) and discrete wavelet transform (DWT)) with hand-tuned weights obtained using a simple grid search. Results are summarized in Figure 1 of the main paper (detailed numerical values are given in Table 1). As can be observed, the trained multimetric algorithm outperforms the single-metric approaches for every organ except the right parotid, where it performs on par with the best single metric (SAD).
Appendix B Quantitative Results
Table 1 contains the numerical Dice scores corresponding to Figure 1 of the main paper. MW denotes the multimetric registration with learned weights; the last column is the Dice increment of MW with respect to the average of the four single-metric results.

Dataset       Organ         SAD    MI     NCC    DWT    MW     Avg. Dice increment (MW)
RT Parotids   Parotid (L)   0.756  0.760  0.750  0.757  0.788  0.033
RT Parotids   Parotid (R)   0.813  0.798  0.783  0.774  0.811  0.019
RT Abdominal  Bladder       0.661  0.643  0.662  0.652  0.736  0.082
RT Abdominal  Sigmoid       0.429  0.423  0.432  0.426  0.497  0.070
RT Abdominal  Rectum        0.613  0.606  0.620  0.617  0.660  0.046
IBSR          CSF           0.447  0.520  0.543  0.527  0.546  0.037
IBSR          GM            0.712  0.725  0.735  0.734  0.761  0.035
IBSR          WM            0.629  0.658  0.669  0.661  0.682  0.028
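The last column of Table 1 can be re-derived from the other columns: it matches MW minus the mean of the four single-metric scores to within rounding (±0.001). The short check below recomputes it and also lists where MW beats the best single metric. The reading of that column as an increment over the single-metric average is inferred from the numbers, not stated explicitly in the text.

```python
from statistics import mean

# Dice scores from Table 1: (SAD, MI, NCC, DWT, MW) and the reported increment.
TABLE = {
    "Parotid (L)": ((0.756, 0.760, 0.750, 0.757, 0.788), 0.033),
    "Parotid (R)": ((0.813, 0.798, 0.783, 0.774, 0.811), 0.019),
    "Bladder": ((0.661, 0.643, 0.662, 0.652, 0.736), 0.082),
    "Sigmoid": ((0.429, 0.423, 0.432, 0.426, 0.497), 0.070),
    "Rectum": ((0.613, 0.606, 0.620, 0.617, 0.660), 0.046),
    "CSF": ((0.447, 0.520, 0.543, 0.527, 0.546), 0.037),
    "GM": ((0.712, 0.725, 0.735, 0.734, 0.761), 0.035),
    "WM": ((0.629, 0.658, 0.669, 0.661, 0.682), 0.028),
}


def increment(scores) -> float:
    """MW Dice minus the mean Dice of the four single metrics."""
    *single, mw = scores
    return mw - mean(single)


for organ, (scores, reported) in TABLE.items():
    beats_best = scores[-1] > max(scores[:-1])
    print(f"{organ:12s} increment={increment(scores):.4f} "
          f"reported={reported:.3f} beats_best_single={beats_best}")
```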
Appendix C Qualitative results
Below we show visual results on the three datasets used as a proof of concept for our proposed method, highlighting the effect of learning the weights of different metrics for deformable registration.