Learning to rank for censored survival data
Survival analysis is a type of semi-supervised ranking task where the target output (the survival time) is often right-censored. Using this partial information is challenging because it is not obvious how to correctly incorporate censored examples into a model. We study how three categories of loss functions can take advantage of censored examples: partial likelihood methods, rank methods, and our classification method, which is based on a Wasserstein metric (WM) and uses the non-parametric Kaplan-Meier estimate of the probability density to impute the labels of censored examples. The proposed method yields a model that predicts the probability distribution of an event. If a clinician had access to the detailed probability of an event over time, this would help in treatment planning, for example by determining whether the risk of kidney graft rejection is constant or peaks after some time. We also demonstrate that this approach directly optimizes the expected C-index, which is the most common evaluation metric for ranking survival models.
Survival analysis, also known as time-to-event analysis, aims to predict the time of the first occurrence of a stochastic event, conditioned on a set of features. Examples from medical data include the time of death or of graft failure after an operation. When the event time is missing for many samples because the event was not observed, the problem can be framed as a particular type of semi-supervised learning in which part of the target values are right-censored. Formally, for some examples $i$ we do not observe the time of event $T_i$, but rather a censoring time $C_i$ such that we know $T_i > C_i$. The classical approach to survival analysis is the Cox proportional hazards model (Cox, 1972), which takes censored samples into account. Ranking approaches (Raykar et al., 2007) also incorporate censored samples into training, through a pairwise ranking loss: although the exact time of event of a censored example is unknown, we know that it occurs after every event observed before its censoring time. We would like to predict the probability distribution of the event time, as it would help in treatment planning, for example by determining whether the risk of kidney graft rejection is constant or peaks after some time.
In this study, we propose to use the Wasserstein metric to train a model that predicts the probability distribution of the event time. This approach not only provides an interpretable prediction but also allows us to impute the distribution of censored samples from global survival statistics, using the non-parametric Kaplan-Meier (KM) estimate. Our intuition is that training with the KM estimate provides a richer signal than a rank loss would. We also find that this approach directly optimizes the C-index (Harrell et al., 1982), which is the most common evaluation metric for ranking survival models. We compare our proposed loss with a set of common ranking-specific losses on several reference survival datasets.
2 Survival data
In what follows, we will use the following notation. Let $x_i$ be the feature vector of the $i$-th example and let $y_i(t)$ take value 1 if the event happened at time $t$ and 0 otherwise. Moreover, let $\hat{y}_i(t)$ be the estimated probability of the event happening at time $t$ and let $T_i$ be the (scalar) actual time of event $i$. We denote by $Y_i$ and $\hat{Y}_i$ the true and estimated cumulative probability distributions of $T_i$. Namely, $Y_i(t) = \sum_{s \le t} y_i(s)$ and $\hat{Y}_i(t) = \sum_{s \le t} \hat{y}_i(s)$. Finally, let $\delta_i$ be 1 if example $i$ is observed (non-censored) and 0 otherwise.
2.1 Ties and censored data
Survival datasets describe medical events that are often recorded at a low temporal resolution (e.g., one day), which causes ties between patients: a given unique time can correspond to multiple events. Tied events cannot be ordered, so predictions more precise than the resolution of the data are not relevant for them. However, ties must be given special attention when constructing loss functions.
As mentioned earlier, another characteristic of survival data is that they are right-censored. Censored examples can still be used, either by comparing them only with patients who had an event before the censoring date, or by imputing their event time from statistics computed over the data.
2.2 Metric of evaluation
The concordance index or C-index (Harrell et al., 1982) is the standard evaluation metric for survival data. It corresponds to the normalized Kendall tau metric between the true and predicted distribution (Kendall, 1938). It can be seen as a generalization of the Area Under the Receiver Operating Characteristic Curve (AUROC) that can handle right-censored data (Raykar et al., 2007).
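We define an acceptable pair $(i, j)$ as one for which we are sure that the event of $i$ occurs before that of $j$. These are the pairs for which the first element is non-censored ($\delta_i = 1$) and the censoring or event time of the second element is strictly greater than the event time of the first, $T_j > T_i$ (here $T_j$ denotes the event time if $\delta_j = 1$ and the censoring time otherwise). Let $\mathcal{E}$ be the set of acceptable pairs and let $f(x_i)$ be the risk score predicted for example $i$, where a higher score means an earlier predicted event. Then the C-index to be maximized can be written as:

$$C(f) = \frac{1}{|\mathcal{E}|} \sum_{(i,j) \in \mathcal{E}} \mathbb{1}\big[f(x_i) > f(x_j)\big]$$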
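To make the definition concrete, here is a minimal pure-Python sketch of the C-index. It assumes the risk-score convention (a higher score means an earlier predicted event), takes the observed time of each example (event time if uncensored, censoring time otherwise), and counts tied scores as discordant, exactly as the strict indicator does; common libraries count them as 0.5 instead. The function name `c_index` is our own.

```python
def c_index(times, events, risk):
    """Fraction of acceptable pairs (i, j) ordered correctly by the risk scores.

    times:  observed times (event time if events[i] == 1, else censoring time)
    events: 1 if the event was observed (delta_i), 0 if censored
    risk:   predicted scores; higher means the event is expected earlier
    """
    concordant = 0
    acceptable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # the first element of an acceptable pair must be uncensored
        for j in range(n):
            if times[j] > times[i]:  # strictly later event or censoring time
                acceptable += 1
                if risk[i] > risk[j]:  # strict indicator: ties count as discordant
                    concordant += 1
    if acceptable == 0:
        raise ValueError("no acceptable pairs")
    return concordant / acceptable


# Three patients, the last one censored; the scores rank them perfectly.
print(c_index(times=[2, 5, 7], events=[1, 1, 0], risk=[0.9, 0.4, 0.1]))  # 1.0
```

The double loop is quadratic in the number of examples, which is fine for illustration; sorting by time gives faster implementations.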
3 Loss functions for censored data
In this section, we present loss functions in the context of survival prediction for censored data. We divide these loss functions into three categories: partial likelihood methods, rank methods, and our classification method based on a Wasserstein metric (WM).
3.1 Cox Model
Cox introduced a general conditional log-likelihood to fit survival models, in which the likelihood of the observations is maximized (Cox, 1972). Raykar et al. (2007) showed that maximizing Cox's partial likelihood is approximately equivalent to maximizing the C-index. We present the general formula, with a real-valued score function $f$ estimating the risk of the event given input features $x_i$. Denoting by $f(x_i)$ the predicted score, the loss is:

$$\mathcal{L}_{\text{Cox}} = -\sum_{i:\,\delta_i = 1} \Big( f(x_i) - \log \sum_{j:\,T_j \ge T_i} e^{f(x_j)} \Big)$$

where the inner sum runs over the risk set of example $i$, i.e., all examples whose event or censoring time $T_j$ is at least $T_i$.
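A minimal pure-Python sketch of the Cox loss, using the same conventions as the C-index (observed times plus event indicators, higher score means higher risk). The function name is hypothetical, and no correction for ties is applied: tied events each get the full risk set, which corresponds to Breslow's handling of ties.

```python
import math


def cox_neg_log_partial_likelihood(times, events, risk):
    """Negative Cox log partial likelihood, without any correction for ties.

    times:  observed times (event time if events[i] == 1, else censoring time)
    events: 1 if the event was observed, 0 if censored
    risk:   predicted scores f(x_i); higher means higher risk
    """
    n = len(times)
    loss = 0.0
    for i in range(n):
        if not events[i]:
            continue  # censored examples only appear inside risk sets
        # risk set of i: every example still under observation at times[i]
        risk_set = [risk[j] for j in range(n) if times[j] >= times[i]]
        m = max(risk_set)  # subtract the max for a numerically stable log-sum-exp
        log_denominator = m + math.log(sum(math.exp(r - m) for r in risk_set))
        loss -= risk[i] - log_denominator
    return loss


# Two patients with equal scores: the first event has a risk set of size 2.
print(cox_neg_log_partial_likelihood([1, 2], [1, 1], [0.0, 0.0]))  # log(2), about 0.693
```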
We also consider a variant of this loss, Efron's approximation (Efron, 1977), which commonly improves performance when there are many tied event times.
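For reference, the standard form of Efron's approximation groups examples by distinct event time $t$, with $D_t$ the set of $d_t$ examples whose event occurs at $t$ and $R_t$ the risk set at $t$:

$$\mathcal{L}_{\text{Efron}} = -\sum_{t} \Big[ \sum_{i \in D_t} f(x_i) - \sum_{k=0}^{d_t - 1} \log\Big( \sum_{j \in R_t} e^{f(x_j)} - \frac{k}{d_t} \sum_{j \in D_t} e^{f(x_j)} \Big) \Big]$$

A minimal pure-Python sketch of this formula follows (hypothetical function name; exponentials are computed directly, so very large scores could overflow). Without ties, $d_t = 1$ and it reduces to the Cox loss above.

```python
import math
from collections import defaultdict


def efron_neg_log_partial_likelihood(times, events, risk):
    """Negative Cox log partial likelihood with Efron's correction for ties."""
    tied = defaultdict(list)  # distinct event time -> indices of events at that time
    for i, (t, e) in enumerate(zip(times, events)):
        if e:
            tied[t].append(i)
    loss = 0.0
    for t, idx in tied.items():
        d = len(idx)
        risk_sum = sum(math.exp(risk[j]) for j in range(len(times)) if times[j] >= t)
        tied_sum = sum(math.exp(risk[j]) for j in idx)
        loss -= sum(risk[j] for j in idx)
        for k in range(d):
            # remove a growing fraction k / d of the tied events from the risk set
            loss += math.log(risk_sum - (k / d) * tied_sum)
    return loss


# Two tied events at t = 1 and one event at t = 2, all with score 0.
print(efron_neg_log_partial_likelihood([1, 1, 2], [1, 1, 1], [0.0, 0.0, 0.0]))  # log(6)
```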
3.2 Ranking losses
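Many methods attempt to directly predict the rank of the different examples. This is done by minimizing the following objective over the set $\mathcal{E}$ of acceptable pairs defined in Section 2.2, where $f(x_i)$ is the predicted risk score:

$$\mathcal{L}_{\text{rank}} = -\frac{1}{|\mathcal{E}|} \sum_{(i,j) \in \mathcal{E}} \phi\big(f(x_i) - f(x_j)\big)$$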
where $\phi$ is a function that relaxes the non-differentiable indicator function of the C-index (Raykar et al., 2007). We evaluated the functions used in (Raykar et al., 2007), Ranking SVM (Herbrich, 2000), RankBoost (Freund et al., 2003) and RankNet (Burges et al., 2005). These functions have been shown in (Kalbfleisch, 1978) to correspond to lower bounds on the C-index. We use $\sigma$ to denote the sigmoid function $\sigma(z) = 1/(1 + e^{-z})$.
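The sketch below illustrates the pairwise objective with three relaxations $\phi$. The specific forms are the standard ones we assume here (the log-sigmoid bound of Raykar et al. (2007), a hinge bound associated with Ranking SVM, and an exponential bound associated with RankBoost); RankNet's pairwise cross-entropy $-\log\sigma(z)$ matches the log-sigmoid form up to an affine transformation. The exact variants evaluated here may differ, and all names are hypothetical.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigma(z)) = -softplus(-z)."""
    return -(max(-z, 0.0) + math.log1p(math.exp(-abs(z))))


# Relaxations phi(z) of the indicator 1[z > 0]; each is a lower bound on it.
# Note: the exponential form may overflow for very negative z.
SURROGATES = {
    "log_sigmoid": lambda z: 1.0 + log_sigmoid(z) / math.log(2.0),
    "hinge": lambda z: 1.0 - max(0.0, 1.0 - z),
    "exponential": lambda z: 1.0 - math.exp(-z),
}


def pairwise_rank_loss(times, events, risk, phi):
    """Negative mean of phi(f(x_i) - f(x_j)) over acceptable pairs (i, j)."""
    total, count = 0.0, 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # i must be uncensored
        for j in range(n):
            if times[j] > times[i]:  # j is known to survive longer than i
                total += phi(risk[i] - risk[j])
                count += 1
    return -total / count


print(pairwise_rank_loss([1, 2], [1, 1], [2.0, 0.0], SURROGATES["hinge"]))  # -1.0
```

Because every $\phi$ lies below the indicator, $-\mathcal{L}_{\text{rank}}$ is a lower bound on the C-index, so minimizing the loss pushes that bound up.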
3.3 Wasserstein metric
To our knowledge, there have been no previous attempts to use the Wasserstein metric on survival data or ranking problems, although Frogner et al. (2015) used a Wasserstein loss for image classification and tag prediction. Hou et al. (2016) and Beckham & Pal (2017) apply a Wasserstein metric to the more restrictive case of ordinal classification. Recently, Mena et al. (2018) used the Sinkhorn algorithm, which is commonly used in optimal transport applications, as an analog of the Softmax for permutations.
The WM is the minimum cost of transporting the mass of one probability distribution onto another. For distributions with discrete supports (histograms of class probabilities), it is computed by moving probability mass between classes according to a ground distance matrix that specifies the cost of transporting mass from one class to another. The WM thus takes advantage of knowledge of the structure of the space of values considered, e.g., the one-dimensional real-valued time axis, so that some errors (e.g., between neighboring event times) are appropriately penalized less than others.
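The WM is particularly well adapted to a survival context. We denote by $p$ the true data distribution and by $\hat{p}$ the distribution estimated by the model, and we write $\Pi(p, \hat{p})$ for the set of joint distributions with left and right marginals $p$ and $\hat{p}$, respectively. Let $M(s, t)$ be the ground distance between times $s$ and $t$. Given an example $x_i$ and its corresponding real time of event $T_i$, we have $p = y_i$ and $\hat{p} = \hat{y}_i$, and we can write:

$$W(y_i, \hat{y}_i) = \inf_{\gamma \in \Pi(y_i, \hat{y}_i)} \sum_{s, t} M(s, t)\, \gamma(s, t)$$

As $y_i$ is a Dirac at $T_i$, the only element of $\Pi(y_i, \hat{y}_i)$ is the product coupling $\gamma(s, t) = y_i(s)\, \hat{y}_i(t)$, so that:

$$W(y_i, \hat{y}_i) = \sum_{t} M(T_i, t)\, \hat{y}_i(t)$$

In all that follows, $M(s, t)$ is chosen to be proportional to the number of training set examples having events between $s$ and $t$. The term $W(y_i, \hat{y}_i)$ is therefore proportional to the expected number of training events lying between the true event time $T_i$ and a time drawn from the predicted distribution $\hat{y}_i$.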
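The Dirac case can be checked numerically. The sketch below assumes discrete time bins and reads "events between $s$ and $t$" as training events falling in bins $(\min(s,t), \max(s,t)]$, which is our own convention; both function names are hypothetical.

```python
def ground_distance(train_event_bins, num_bins):
    """M[s][t]: fraction of training events in bins (min(s, t), max(s, t)]."""
    counts = [0] * num_bins
    for b in train_event_bins:
        counts[b] += 1
    total = sum(counts)
    cum, running = [], 0  # cum[k] = number of training events in bins 0..k
    for c in counts:
        running += c
        cum.append(running)
    return [[abs(cum[t] - cum[s]) / total for t in range(num_bins)] for s in range(num_bins)]


def wasserstein_to_dirac(true_bin, predicted, M):
    """W between a Dirac at true_bin and a predicted histogram.

    The only coupling with a Dirac marginal is the product coupling, so W
    reduces to the expected ground distance from the true bin.
    """
    return sum(M[true_bin][t] * p for t, p in enumerate(predicted))


M = ground_distance(train_event_bins=[0, 1, 1, 2, 3], num_bins=4)
print(wasserstein_to_dirac(1, [0.0, 0.5, 0.0, 0.5], M))  # 0.2
```

Half of the predicted mass sits on the true bin (cost 0) and half on bin 3, two training events away out of five (cost 0.4), giving 0.2.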
3.3.1 Use as a learning objective
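Levina & Bickel (2001) note that, under certain conditions satisfied in the case of ordinal classification, the WM takes the following expression:

$$W_l(p, q) = \Big( \frac{1}{K} \sum_{k=1}^{K} \big| \mathrm{CDF}(p)_k - \mathrm{CDF}(q)_k \big|^l \Big)^{1/l}$$

where $K$ is the size of the Softmax layer and $\mathrm{CDF}$ is a function that returns the cumulative distribution function of its input density. Here, $p$ and $q$ are two probability distributions with discrete supports. We use a fixed exponent $l$ in our experiments and write $W_l$ to highlight the dependency on $l$. The objective can be written as:

$$\mathcal{L}_{\text{WM}} = \sum_{i=1}^{N} W_l(y_i, \hat{y}_i)$$

where $N$ is the number of training examples and, for censored examples, the target $y_i$ is replaced by the imputed distribution described in Section 3.3.2.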
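A minimal pure-Python sketch of the closed form and of the summed objective. The parameter `order` plays the role of the exponent $l$; the function names are hypothetical. This form treats bins as equally spaced; for $l = 1$, a non-uniform ground distance between consecutive bins would simply weight each CDF gap by that distance.

```python
def cumulative(hist):
    """CDF of a histogram over ordered bins (running sum of its masses)."""
    out, running = [], 0.0
    for mass in hist:
        running += mass
        out.append(running)
    return out


def wasserstein_l(p, q, order=1):
    """Closed-form W_l between two histograms over the same K ordered bins."""
    if len(p) != len(q):
        raise ValueError("histograms must have the same number of bins")
    K = len(p)
    total = sum(abs(a - b) ** order for a, b in zip(cumulative(p), cumulative(q)))
    return (total / K) ** (1.0 / order)


def wm_objective(targets, predictions, order=1):
    """Sum of W_l(y_i, y_hat_i) over examples; targets are histograms over bins."""
    return sum(wasserstein_l(y, y_hat, order) for y, y_hat in zip(targets, predictions))


# Dirac at bin 0 vs Dirac at bin 2 over K = 4 bins: the CDFs differ in 2 of 4 bins.
print(wasserstein_l([1, 0, 0, 0], [0, 0, 1, 0]))  # 0.5
```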
3.3.2 Imputing missing values for classification
In order for the WM objective to lead to good training, we impute the CDF of each censored example using $\hat{S}$, the Kaplan-Meier non-parametric estimate of the survival function computed on the training set (see Figure 1). With the KM estimator, the survival function is estimated as a step function whose value at time $t$ is:

$$\hat{S}(t) = \prod_{k:\, t_k \le t} \Big( 1 - \frac{d_k}{n_k} \Big)$$

with $d_k$ denoting the number of events at time $t_k$ and $n_k$ the number of patients still at risk (alive and not censored) just before $t_k$.
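The sketch below computes the KM estimate and one natural way of turning it into a CDF for a censored example. The imputation rule is an assumption on our part: it conditions the KM curve on survival past the censoring time $c$, giving $0$ for $t \le c$ and $1 - \hat{S}(t)/\hat{S}(c)$ for $t > c$, and places any mass left beyond the last bin in that bin. All function names are hypothetical.

```python
from collections import Counter


def kaplan_meier(times, events):
    """Return the distinct event times t_k and the KM estimate S(t_k)."""
    deaths = Counter(t for t, e in zip(times, events) if e)  # d_k
    grid, surv, s = [], [], 1.0
    for t in sorted(deaths):
        at_risk = sum(1 for u in times if u >= t)  # n_k: still at risk just before t_k
        s *= 1.0 - deaths[t] / at_risk
        grid.append(t)
        surv.append(s)
    return grid, surv


def survival_at(t, grid, surv):
    """Evaluate the KM step function at time t."""
    s = 1.0
    for t_k, s_k in zip(grid, surv):
        if t_k > t:
            break
        s = s_k
    return s


def impute_censored_cdf(censor_time, bin_times, grid, surv):
    """Hypothetical imputation: KM CDF conditioned on surviving past censor_time."""
    s_c = survival_at(censor_time, grid, surv)
    cdf = []
    for t in bin_times:
        if t <= censor_time or s_c == 0.0:
            cdf.append(0.0)  # the event cannot have happened before censoring
        else:
            cdf.append(1.0 - survival_at(t, grid, surv) / s_c)
    cdf[-1] = 1.0  # put any remaining mass in the last bin so the CDF reaches 1
    return cdf


grid, surv = kaplan_meier(times=[1, 2, 2, 3, 4], events=[1, 1, 0, 1, 0])
print(grid, surv)  # event times [1, 2, 3]; S is about [0.8, 0.6, 0.3]
print(impute_censored_cdf(2, [0, 1, 2, 3, 4], grid, surv))  # about [0, 0, 0, 0.5, 1.0]
```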
We assess the presented models on a variety of publicly available datasets. The characteristics of these datasets are summarized in Table 1.
|SUPPORT2||9105||2904 (32.2)||1724 (19.1)||98|
|AIDS3||3985||2223 (55.8)||1506 (37.8)||19|
|COLON||929||477 (51.3)||780 (84.0)||48|
SUPPORT2 (available at http://biostat.mc.vanderbilt.edu/wiki/Main/DataSets) records the survival time for patients of the SUPPORT study.
AIDS3 (available at https://vincentarelbundock.github.io/Rdatasets/datasets.html) corresponds to the Australian AIDS Survival Data.
COLON (available from the same Rdatasets repository) consists of data from one of the first successful trials of adjuvant chemotherapy for colon cancer. We considered death as the target event for our study.
4.2 Data pre-processing
We used one-hot encoding for categorical features and unit scaling for continuous features. For features with missing values, we added an indicator feature for the absence of a value.
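A minimal sketch of this pre-processing on records stored as dictionaries (with `None` marking a missing value). Two points are assumptions: we read "unit scaling" as standardization to zero mean and unit variance, and we fill missing continuous values with the mean (0 after scaling). The function name is hypothetical; in practice the statistics should be fitted on the training split only, whereas this sketch fits and transforms the same rows.

```python
import statistics


def preprocess(rows, categorical, continuous):
    """Turn a list of dict records (None = missing) into numeric feature vectors."""
    encoders = []
    for name in categorical:
        levels = sorted({r[name] for r in rows if r[name] is not None})
        has_missing = any(r[name] is None for r in rows)
        encoders.append(("cat", name, levels, has_missing))
    for name in continuous:
        values = [r[name] for r in rows if r[name] is not None]
        mean = statistics.fmean(values)
        std = statistics.pstdev(values) or 1.0  # guard against constant columns
        encoders.append(("num", name, (mean, std), len(values) < len(rows)))

    features = []
    for r in rows:
        vec = []
        for kind, name, params, has_missing in encoders:
            v = r[name]
            if kind == "cat":
                vec.extend(1.0 if v == level else 0.0 for level in params)  # one-hot
            else:
                mean, std = params
                vec.append(0.0 if v is None else (v - mean) / std)  # standardized
            if has_missing:
                vec.append(1.0 if v is None else 0.0)  # missing-value indicator
        features.append(vec)
    return features


rows = [{"sex": "F", "age": 30.0}, {"sex": "M", "age": None}, {"sex": "F", "age": 50.0}]
print(preprocess(rows, categorical=["sex"], continuous=["age"]))
```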
We performed 5-fold cross-validation and kept 20% of each training set as a validation set. The prediction performance is reported as the mean ± standard error of the C-index over the 5 folds. Early stopping was performed on the validation C-index.
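The evaluation protocol can be sketched as follows. The random (non-stratified) fold assignment, the seed, and the use of the sample standard deviation for the standard error are our assumptions; the function names are hypothetical.

```python
import random
import statistics


def cv_splits(n, k=5, val_fraction=0.2, seed=0):
    """Yield (train, val, test) index lists for k-fold cross-validation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[f::k] for f in range(k)]
    for f in range(k):
        test = folds[f]
        rest = [i for g in range(k) if g != f for i in folds[g]]
        n_val = int(round(val_fraction * len(rest)))  # 20% of the training part
        yield rest[n_val:], rest[:n_val], test


def mean_and_standard_error(scores):
    """Mean and standard error (sample std / sqrt(number of folds))."""
    return statistics.mean(scores), statistics.stdev(scores) / len(scores) ** 0.5


for train, val, test in cv_splits(100):
    print(len(train), len(val), len(test))  # 64 16 20 for each fold
```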
We used a multi-layer perceptron (3 layers with 100 units each) with ReLU activation functions where applicable, and used Dropout (Hinton et al., 2012), Batch Normalization (Ioffe & Szegedy, 2015) and L2 regularization on the weights. We used the Adam optimizer. For the ranking and log-likelihood methods, the output was a single unit with a linear activation function. For the methods requiring a prediction over output times, we used a Softmax output layer. Our code was written in PyTorch (Paszke et al., 2017).
We performed a grid search, independently for each split, over the L2 regularization coefficient on the weights and the learning rate. We add a small constant (1 for SUPPORT2 and AIDS3, 10 for COLON) to the distance between bins before normalizing. For COLON we used a bin size of 2 days, and 1 day for the other two datasets.
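The binning and the smoothed bin distances can be sketched as below. This is one hypothetical reading of the text: the raw distance between consecutive bins $k$ and $k+1$ is the number of training events in bin $k+1$ plus the constant, normalized to sum to 1, so that bins without training events presumably still keep a positive distance. All names are our own.

```python
def to_bin(time_in_days, bin_size_days):
    """Index of the time bin containing a time expressed in days."""
    return int(time_in_days // bin_size_days)


def adjacent_bin_distances(train_event_times, bin_size_days, num_bins, constant):
    """Normalized distance between consecutive bins k and k + 1 (hypothetical)."""
    counts = [0] * num_bins
    for t in train_event_times:
        counts[min(to_bin(t, bin_size_days), num_bins - 1)] += 1  # clip to last bin
    raw = [counts[k + 1] + constant for k in range(num_bins - 1)]
    total = sum(raw)
    return [d / total for d in raw]


def ground_distance_from_gaps(s, t, gaps):
    """M(s, t): sum of consecutive-bin distances between bins s and t."""
    lo, hi = min(s, t), max(s, t)
    return sum(gaps[lo:hi])


# 2-day bins, as for COLON; the constant 1 keeps the empty last bin at a positive distance.
gaps = adjacent_bin_distances([0, 1, 2, 3, 3], bin_size_days=2, num_bins=3, constant=1)
print(gaps)  # [0.8, 0.2]
```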
4.3 Comparison of different ranking methods
We compare the different loss functions in Table 2, studying how the standard Cox model performs relative to ranking and classification losses.
|Partial likelihood||Cox Efron’s||84.91 ± 0.60||54.03 ± 1.21||63.08 ± 0.93|
4.4 Impact of using censored data
The purpose of this section is to explore how informative censoring is, and to show that censored examples should not simply be discarded or transformed away. We compare three ways of handling censored data. First, we completely remove censored examples from the training set (no censored data). Second, we treat the censoring time as an actual event occurrence, transforming each example censored at time $C_i$ into the same example with an event occurring at time $C_i$ (death at censoring). Finally, we report results for the standard approach (with censored data). In the case of WM, the censored times are imputed with the Kaplan-Meier curve.
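The three training-set variants can be sketched as follows (hypothetical function and mode names). Kept indices are returned so that the feature matrix can be subset in the same way; for WM, the imputation of censored targets is a separate step.

```python
def censoring_variant(times, events, mode):
    """Build a training set under one of three treatments of censored examples.

    Returns (kept_indices, times, events).
    """
    if mode == "no_censored_data":
        kept = [i for i, e in enumerate(events) if e]  # drop censored examples
        return kept, [times[i] for i in kept], [1] * len(kept)
    if mode == "death_at_censoring":
        # every censoring time C_i is treated as an observed event at C_i
        return list(range(len(times))), list(times), [1] * len(times)
    if mode == "with_censored_data":
        return list(range(len(times))), list(times), list(events)
    raise ValueError(f"unknown mode: {mode}")


print(censoring_variant([3, 5, 8], [1, 0, 1], "death_at_censoring"))  # ([0, 1, 2], [3, 5, 8], [1, 1, 1])
```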
We run this experiment on the SUPPORT2 dataset, as it is the largest public dataset we have, for the best method of each of the three categories (including Cox Efron’s and our WM method). The results are presented in Table 3.
|No censored data||83.31 ± 0.51||83.40 ± 0.52||82.34 ± 0.49|
|Death at censoring||82.34 ± 0.58||81.97 ± 0.67||80.67 ± 0.55|
|With censored data||85.33 ± 0.52||85.53 ± 0.56||84.91 ± 0.60|
4.5 Exploring the impact of censored data
To determine how much improvement we can obtain from incorporating censored data, we vary the fraction of training samples that are censored while keeping the validation and test sets the same. Figure 2 shows the evolution of the C-index for different percentages of censoring in the SUPPORT2 training set.
We proposed a new method for learning to rank survival data. Experiments on several datasets show that our models trained with the WM loss give accurate predictions compared to the more classical losses of the Cox model and to ranking loss functions, which directly approximate a lower bound of the C-index. While not always state of the art, our method is always among the best results for each dataset.
We also find that this approach tolerates a high percentage of censored samples, consistently giving results in the same range as the best methods. We also demonstrate that our method can be seen as directly optimizing the expected C-index, which is the most common evaluation metric for ranking survival models. Moreover, our results show that imputing the missing event times with the KM curve in a classification framework can increase the resulting C-index.
We thank Christopher Pal and Christopher Beckham for their input on the project. This work is partially funded by a grant from the U.S. National Science Foundation Graduate Research Fellowship Program (grant number: DGE-1356104) and the Institut de valorisation des données (IVADO). This work utilized the supercomputing facilities managed by the Montreal Institute for Learning Algorithms, NSERC, Compute Canada, and Calcul Québec.
- Beckham & Pal (2017) Beckham, Christopher and Pal, Christopher. Unimodal probability distributions for deep ordinal classification. International Conference on Machine Learning, 2017.
- Burges et al. (2005) Burges, Christopher J C, Shaked, Tal, Renshaw, Erin, Lazier, Ari, Deeds, Matt, Hamilton, Nicole, and Hullender, Greg. Learning to Rank using Gradient Descent. In International Conference on Machine Learning. ACM, 2005.
- Cox (1972) Cox, D. R. Regression models and life tables. Journal of the Royal Statistical Society, 1972.
- Efron (1977) Efron, Bradley. The Efficiency of Cox’s Likelihood Function for Censored Data. Journal of the American Statistical Association, 1977.
- Freund et al. (2003) Freund, Yoav, Iyer, Raj, Schapire, Robert E, and Singer, Yoram. An Efficient Boosting Algorithm for Combining Preferences. Journal of Machine Learning Research, 2003. doi: 10.1162/jmlr.2003.4.6.933.
- Frogner et al. (2015) Frogner, Charlie, Zhang, Chiyuan, Mobahi, Hossein, Araya, Mauricio, and Poggio, Tomaso A. Learning with a Wasserstein loss. In Advances in Neural Information Processing Systems, 2015.
- Harrell et al. (1982) Harrell, Frank E., Califf, Robert M., Pryor, David B., Lee, Kerry L., and Rosati, Robert A. Evaluating the Yield of Medical Tests. Journal of the American Medical Association, 1982. doi: 10.1001/jama.1982.03320430047030.
- Herbrich (2000) Herbrich, Ralf. Large margin rank boundaries for ordinal regression. Advances in large margin classifiers, 2000.
- Hinton et al. (2012) Hinton, Geoffrey E., Srivastava, Nitish, Krizhevsky, Alex, Sutskever, Ilya, and Salakhutdinov, Ruslan R. Improving neural networks by preventing co-adaptation of feature detectors. arXiv:1207.0580, 2012.
- Hou et al. (2016) Hou, Le, Yu, Chen-Ping, and Samaras, Dimitris. Squared Earth Mover’s Distance-based Loss for Training Deep Neural Networks. arXiv:1611.05916, 2016.
- Ioffe & Szegedy (2015) Ioffe, Sergey and Szegedy, Christian. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International conference on machine learning, 2015.
- Kalbfleisch (1978) Kalbfleisch, John. Non-Parametric Bayesian Analysis of Survival Time Data. Journal of the Royal Statistical Society. Series B, 1978.
- Katzman et al. (2016) Katzman, Jared, Shaham, Uri, Cloninger, Alexander, Bates, Jonathan, Jiang, Tingting, and Kluger, Yuval. DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network. International Conference of Machine Learning Computational Biology Workshop, 2016.
- Kendall (1938) Kendall, Maurice G. A new measure of rank correlation. Biometrika, 1938.
- Levina & Bickel (2001) Levina, Elizaveta and Bickel, Peter. The Earth Mover’s distance is the Mallows distance: Some insights from statistics. In International Conference on Computer Vision, volume 2. IEEE, 2001. doi: 10.1109/ICCV.2001.937632.
- Luck et al. (2017) Luck, Margaux, Sylvain, Tristan, Cardinal, Héloïse, Lodi, Andrea, and Bengio, Yoshua. Deep Learning for Patient-Specific Kidney Graft Survival Analysis. arXiv:1705.10245, 2017.
- Mena et al. (2018) Mena, Gonzalo, Belanger, David, Linderman, Scott, and Snoek, Jasper. Learning latent permutations with gumbel-sinkhorn networks. arXiv preprint arXiv:1802.08665, 2018.
- Paszke et al. (2017) Paszke, Adam, Chanan, Gregory, Lin, Zeming, Gross, Sam, Yang, Edward, Antiga, Luca, and Devito, Zachary. Automatic differentiation in PyTorch, 2017.
- Raykar et al. (2007) Raykar, Vikas C, Steck, Harald, Krishnapuram, Balaji, Dehing-oberije, Cary, and Lambin, Philippe. On ranking in survival analysis: Bounds on the concordance index. In Neural Information Processing Systems, 2007. doi: 10.1.1.121.2670.