Learning to rank for censored survival data

Margaux Luck    Tristan Sylvain    Joseph Paul Cohen    Héloïse Cardinal    Andrea Lodi    Yoshua Bengio
Abstract

Survival analysis is a type of semi-supervised ranking task in which the target output (the survival time) is often right-censored. Using this information is challenging because it is not obvious how to correctly incorporate censored examples into a model. We study how three categories of loss functions can take advantage of this information: partial likelihood methods, rank methods, and our classification method, which is based on a Wasserstein metric (WM) and uses the non-parametric Kaplan-Meier estimate of the probability density to impute the labels of censored examples. The proposed method yields a model that predicts the probability distribution of an event. If clinicians had access to the detailed probability of an event over time, this would help in treatment planning, for example by determining whether the risk of kidney graft rejection is constant or peaks after some time. We also demonstrate that this approach directly optimizes the expected C-index, which is the most common evaluation metric for ranking survival models.

Machine Learning, ICML, Survival analysis, Ranking

1 Introduction

Survival analysis, also known as time-to-event analysis, aims to predict the time of first occurrence of a stochastic event, conditioned on a set of features. In medical data, for example, the event can be death or graft failure after an operation. When the event time is missing for many samples because the event was not observed, the task can be framed as a particular type of semi-supervised learning in which part of the target values are right-censored. Formally, for some examples $i$ we do not observe the event time $T_i$, but rather a censoring time $C_i$ such that we know $T_i > C_i$. The classical approach to survival analysis that takes censored samples into account is the Cox proportional hazards model (Cox, 1972). Ranking approaches (Raykar et al., 2007) are another way to include censored samples in training, through a pairwise ranking loss: although the exact event time of a censored example is unknown, we know that it occurs after any event observed before its censoring date. We would like to predict the probability distribution of an event, as this would help in treatment planning, for example to determine whether the risk of kidney graft rejection is constant or peaks after some time.

In this study, we propose to use the Wasserstein metric to train a model that predicts the probability distribution of the event time. This approach not only provides an interpretable prediction but also allows us to impute the distribution of censored samples from global survival statistics using the non-parametric Kaplan-Meier (KM) estimate. Our intuition is that training with the KM estimate provides a richer signal than a rank loss would. We also find that this approach directly optimizes the C-index (Harrell et al., 1982), which is the most common evaluation metric for ranking survival models. We compare our proposed loss with a set of common ranking-specific losses on several reference survival datasets.

2 Survival data

In what follows, we will use the following notations. Let $x_i$ be the feature vector of the $i$-th example and let $y_i(t)$ take the value 1 if event $i$ happened at time $t$ and 0 otherwise. Moreover, let $\hat{y}_i(t)$ be the estimated probability of event $i$ happening at time $t$ and let $T_i$ be the (scalar) actual time of event $i$. We denote by $Y_i$ and $\hat{Y}_i$ the true and estimated cumulative probability distributions of $T_i$, namely $Y_i(t) = \sum_{t' \le t} y_i(t')$ and $\hat{Y}_i(t) = \sum_{t' \le t} \hat{y}_i(t')$. Finally, let $\delta_i$ be 1 if example $i$ is observed (non-censored) and 0 otherwise.

2.1 Ties and censored data

Survival datasets describe medical events that may be recorded at a low temporal resolution (time scale), which causes ties between patients: a given unique time (at a given resolution, e.g., one day) can correspond to multiple events. Tied events cannot be ordered more precisely than the recording resolution, so finer predictions of their relative order are not meaningful. However, ties must be given special attention when constructing loss functions.

As mentioned earlier, another characteristic of survival data is right-censoring. Censored examples can still be used, either by comparing them only with patients who had an event before the censoring date, or by imputing the event time from statistics computed over the data.

2.2 Metric of evaluation

The concordance index or C-index (Harrell et al., 1982) is the standard evaluation metric for survival data. It corresponds to the normalized Kendall tau metric between the true and predicted rankings (Kendall, 1938). It can be seen as a generalization of the Area Under the Receiver Operating Characteristic Curve (AUROC) that can handle right-censored data (Raykar et al., 2007).

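We define an acceptable pair $(i, j)$ as one for which we are sure that the event of $i$ occurs before that of $j$. These are the pairs for which $i$ is non-censored ($\delta_i = 1$) and the event or censoring time of $j$ is strictly greater than $T_i$. Let $\mathcal{E} = \{(i, j) : \delta_i = 1,\ T_i < T_j\}$ be the set of acceptable pairs, where $T_j$ denotes the censoring time when $j$ is censored. Then, for a scalar prediction $f(x)$ that orders examples by predicted event time (a larger value meaning a later event; for a risk score, use its negative), the C-index to be maximized can be written as:

$$C(f) = \frac{1}{|\mathcal{E}|} \sum_{(i,j) \in \mathcal{E}} \mathbb{1}\left[ f(x_i) < f(x_j) \right]$$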
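The C-index as defined above can be computed directly by enumerating acceptable pairs. The following is a minimal sketch in plain Python (the function names are illustrative, not taken from the authors' code). It follows the strict inequality of the formula, so tied predictions count as discordant, whereas some implementations give them half credit.

```python
from itertools import permutations


def acceptable_pairs(times, observed):
    """Return pairs (i, j) whose order is certain: example i had an observed
    event and the event or censoring time of j is strictly later."""
    return [
        (i, j)
        for i, j in permutations(range(len(times)), 2)
        if observed[i] and times[j] > times[i]
    ]


def c_index(pred, times, observed):
    """Fraction of acceptable pairs ordered correctly by `pred`, where a larger
    prediction means a later predicted event (negate risk scores first)."""
    pairs = acceptable_pairs(times, observed)
    if not pairs:
        raise ValueError("no acceptable pairs")
    concordant = sum(1 for i, j in pairs if pred[i] < pred[j])
    return concordant / len(pairs)


times = [2, 5, 3, 8]
observed = [1, 0, 1, 1]  # example 1 is censored at time 5
print(c_index([1.0, 4.0, 3.0, 2.0], times, observed))  # 4 of 5 pairs concordant
```

The quadratic pair enumeration is fine for illustration; larger datasets would typically sort by time instead.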

3 Loss functions for censored data

In this section, we present loss functions in the context of survival prediction for censored data. We divide these loss functions into three categories: partial likelihood methods, rank methods, and our classification method based on a Wasserstein metric (WM).

3.1 Cox Model

Cox introduced a general conditional log-likelihood for fitting survival models, in which the probability of the observations is maximized (Cox, 1972). Raykar et al. (2007) showed that maximizing Cox's partial likelihood is approximately equivalent to maximizing the C-index. We present the general formula for a real-valued score function $f$ that, given the input features $x$, estimates the risk of the event (a larger score means an earlier expected event). Denoting the predicted score $f(x_i)$, the loss is:

$$\mathcal{L}_{\text{Cox}}(f) = -\sum_{i : \delta_i = 1} \left( f(x_i) - \log \sum_{j : T_j \ge T_i} \exp\big(f(x_j)\big) \right)$$
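A minimal plain-Python sketch of this loss, assuming the scores $f(x_i)$ are given as a list of log-risk values (the function name is illustrative). Tied event times are handled in the simplest (Breslow) way, with all tied examples kept in each other's risk sets.

```python
import math


def cox_neg_log_partial_likelihood(scores, times, observed):
    """Negative Cox partial log-likelihood, Breslow handling of ties.

    scores[i] is the predicted log-risk f(x_i) (larger means earlier event).
    For each observed event i, the risk set holds every j with times[j] >= times[i].
    """
    loss = 0.0
    for s_i, t_i, d_i in zip(scores, times, observed):
        if not d_i:
            continue  # censored examples contribute only through risk sets
        risk = [s_j for s_j, t_j in zip(scores, times) if t_j >= t_i]
        m = max(risk)  # shift for a numerically stable log-sum-exp
        log_denominator = m + math.log(sum(math.exp(s - m) for s in risk))
        loss -= s_i - log_denominator
    return loss


# With all scores equal, the loss is log(3) + log(2) + log(1) for three events.
print(cox_neg_log_partial_likelihood([0.0, 0.0, 0.0], [1, 2, 3], [1, 1, 1]))
```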

We also consider a variant of this loss, Efron’s approximation (Efron, 1977) that commonly improves performance when there are many tied event times.
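Efron's correction can be sketched in the same style. This illustrates the standard Efron formula rather than the authors' implementation, and it omits log-sum-exp stabilization for brevity.

```python
import math
from collections import defaultdict


def efron_neg_log_partial_likelihood(scores, times, observed):
    """Negative Cox partial log-likelihood with Efron's correction for ties.

    Observed events sharing a time form a group of size d. The l-th term
    (l = 0..d-1) removes a fraction l/d of the group's risk from the
    risk-set denominator, instead of keeping all of it as Breslow does.
    """
    groups = defaultdict(list)  # event time -> indices of observed events at that time
    for idx, (t, d_i) in enumerate(zip(times, observed)):
        if d_i:
            groups[t].append(idx)
    loss = 0.0
    for t, events in groups.items():
        risk_sum = sum(math.exp(s) for s, t_j in zip(scores, times) if t_j >= t)
        tied_sum = sum(math.exp(scores[k]) for k in events)
        d = len(events)
        loss -= sum(scores[k] for k in events)
        for l in range(d):
            loss += math.log(risk_sum - (l / d) * tied_sum)
    return loss


# Two events tied at t=1: Efron gives log(3) + log(2); Breslow would give 2 * log(3).
print(efron_neg_log_partial_likelihood([0.0, 0.0, 0.0], [1, 1, 2], [1, 1, 1]))
```

Without ties every group has size one and the expression reduces to the Breslow form.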

In our experiments, the Cox variant refers to a multi-layer perceptron (MLP) trained with the normal Cox loss or with Efron’s approximation loss, as in (Katzman et al., 2016; Luck et al., 2017).

3.2 Ranking losses

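Many methods attempt to directly predict the rank of the different examples. This is done by maximizing the following objective over the acceptable pairs $\mathcal{E}$:

$$\mathcal{L}_{\text{rank}}(f) = \frac{1}{|\mathcal{E}|} \sum_{(i,j) \in \mathcal{E}} \phi\big( f(x_j) - f(x_i) \big)$$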

where $\phi$ is a function that relaxes the non-differentiable indicator function in the C-index (Raykar et al., 2007). We evaluated the functions used in (Raykar et al., 2007), Ranking SVM (Herbrich, 2000), RankBoost (Freund et al., 2003) and RankNet (Burges et al., 2005). These functions have been shown to correspond to lower bounds on the C-index (Raykar et al., 2007). We use $\sigma$ to denote the sigmoid function $\sigma(z) = 1 / (1 + e^{-z})$.
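The ranking objective only needs the score differences on acceptable pairs. The sketch below uses common forms of $\phi$ from the ranking literature; these specific forms are assumptions made for illustration, not a transcription of the variants in Table 2.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function sigma(z) = 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_sigmoid(z):
    """Numerically stable log(sigma(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


# Assumed relaxations phi(z) of the indicator 1[z > 0], with z = f(x_j) - f(x_i).
# The sigmoid approximates the indicator but is not a lower bound; the other
# three are lower bounds (log-sigmoid as in Raykar et al., exponential as in
# RankBoost, hinge as in Ranking SVM).
SURROGATES = {
    "sigmoid": sigmoid,
    "log_sigmoid": lambda z: 1.0 + log_sigmoid(z) / math.log(2.0),
    "exponential": lambda z: 1.0 - math.exp(-z),
    "hinge": lambda z: 1.0 - max(0.0, 1.0 - z),
}


def ranking_objective(pred, pairs, phi):
    """Mean of phi(f(x_j) - f(x_i)) over acceptable pairs (i, j); to be maximized."""
    return sum(phi(pred[j] - pred[i]) for i, j in pairs) / len(pairs)


pairs = [(0, 1), (0, 2), (2, 1)]  # e.g. produced by an acceptable-pair routine
pred = [0.5, 2.0, 1.0]            # larger value = later predicted event
for name, phi in SURROGATES.items():
    print(name, ranking_objective(pred, pairs, phi))
```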

3.3 Wasserstein metric

To our knowledge, there have been no previous attempts to use the Wasserstein metric on survival data or ranking problems. Frogner et al. (2015) used a Wasserstein loss for image classification and tag prediction, and Hou et al. (2016) and Beckham & Pal (2017) applied a Wasserstein metric to the more restrictive case of ordinal classification. Recently, Mena et al. (2018) used the Sinkhorn algorithm, which is commonly used in optimal transport applications, as an analogue of the softmax for permutations.

The WM is the minimum cost of transporting the mass of one probability distribution onto another. For distributions with discrete supports (histograms of class probabilities), it is computed by moving probability mass between classes according to a ground distance matrix that specifies the cost of transporting mass from one class to another. The WM thus takes advantage of the structure of the space of values considered, e.g., the one-dimensional time axis, so that some errors (e.g., between neighboring event times) are appropriately penalized less than others.

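The WM is particularly well suited to survival analysis. For an example $i$, we denote by $y_i$ the true distribution of its event time and by $\hat{y}_i$ the distribution estimated by the model. We write $\Pi(y_i, \hat{y}_i)$ for the set of joint distributions with left and right marginals $y_i$ and $\hat{y}_i$, respectively. Given an example $x_i$ with actual event time $T_i$ and a ground distance $d$, we can write:

$$W(y_i, \hat{y}_i) = \inf_{\gamma \in \Pi(y_i, \hat{y}_i)} \mathbb{E}_{(t, t') \sim \gamma}\big[ d(t, t') \big]$$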

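As $y_i$ is a Dirac at $T_i$, the only joint distribution in $\Pi(y_i, \hat{y}_i)$ is the product $y_i \otimes \hat{y}_i$, and we have that:

$$W(y_i, \hat{y}_i) = \sum_{t} \hat{y}_i(t)\, d(T_i, t) = \mathbb{E}_{t \sim \hat{y}_i}\big[ d(T_i, t) \big]$$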

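In all that follows, $d(t, t')$ is chosen to be proportional to the number of training examples with events between $t$ and $t'$. The term $\mathbb{E}_{t \sim \hat{y}_i}[d(T_i, t)]$ is therefore proportional to the expected number of training events lying between the true event time $T_i$ and a time drawn from the predicted distribution $\hat{y}_i$. If the other examples were predicted at their true event times, each such event would correspond to a pair involving $i$ whose order is inverted, which links this loss to the expected C-index.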
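The Dirac case described above reduces the WM to an expectation of the ground distance under the predicted histogram. The sketch below assumes time has already been discretized into bins and reads "events between $t$ and $t'$" as the half-open interval $(\min, \max]$; both the interval convention and the function names are assumptions for illustration.

```python
def ground_distance(train_event_bins, num_bins):
    """Hypothetical ground distance: d[a][b] is the fraction of training events
    whose bin lies in the half-open interval (min(a, b), max(a, b)]."""
    counts = [0] * num_bins
    for b in train_event_bins:
        counts[b] += 1
    prefix = [0]  # prefix[k] = number of events in bins 0..k-1
    for c in counts:
        prefix.append(prefix[-1] + c)
    total = prefix[-1]
    return [
        [abs(prefix[b + 1] - prefix[a + 1]) / total for b in range(num_bins)]
        for a in range(num_bins)
    ]


def wasserstein_to_dirac(pred, true_bin, d):
    """W(y_i, y_hat_i) for a Dirac target at true_bin: every unit of predicted
    mass must be moved to true_bin, so W = sum_t y_hat_i(t) * d(T_i, t)."""
    return sum(p * d[true_bin][t] for t, p in enumerate(pred))


d = ground_distance([0, 1, 1, 2, 3], num_bins=4)
print(d[0][2])  # 3 of the 5 training events fall in bins 1 and 2
print(wasserstein_to_dirac([0.5, 0.0, 0.5, 0.0], true_bin=2, d=d))
```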

3.3.1 Use as a learning objective

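Levina & Bickel (2001) note that, under certain conditions satisfied in the case of ordinal classification, the WM takes the following expression:

$$\mathrm{WM}_p(y, \hat{y}) = \left( \frac{1}{K} \sum_{k=1}^{K} \big| \mathrm{CDF}(y)_k - \mathrm{CDF}(\hat{y})_k \big|^p \right)^{1/p}$$

where $K$ is the size of the softmax layer and $\mathrm{CDF}(\cdot)$ is a function that returns the cumulative distribution function of its input density. Here, $y$ and $\hat{y}$ are two probability distributions with discrete supports. We use a fixed value of $p$ in our experiments. We write $\hat{y}_\theta(x_i)$ to highlight the dependency on the model parameters $\theta$. The objective can be written as:

$$\min_\theta \sum_{i} \mathrm{WM}_p\big( y_i, \hat{y}_\theta(x_i) \big)$$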
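The closed form only needs cumulative sums, which makes it cheap to evaluate on softmax outputs. A minimal sketch in plain Python, assuming both histograms live on the same ordered bins with unit spacing; $p$ is left as an explicit parameter because its value is a modeling choice.

```python
from itertools import accumulate


def wm_p(y, y_hat, p):
    """WM_p(y, y_hat) = ((1/K) * sum_k |CDF(y)_k - CDF(y_hat)_k|^p)^(1/p)
    for two histograms defined over the same K ordered bins."""
    if len(y) != len(y_hat):
        raise ValueError("histograms must have the same number of bins")
    K = len(y)
    cdf_gaps = (abs(a - b) for a, b in zip(accumulate(y), accumulate(y_hat)))
    return (sum(g ** p for g in cdf_gaps) / K) ** (1.0 / p)


target = [1.0, 0.0, 0.0, 0.0]   # one-hot (Dirac) target in the first bin
pred = [0.0, 0.0, 1.0, 0.0]     # all predicted mass two bins later
print(wm_p(target, pred, p=1))  # CDF gaps [1, 1, 0, 0] -> 2 / 4
```

With $p = 1$ this equals the unit-spaced transport cost divided by $K$, so moving mass further along the time axis is penalized proportionally more.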

3.3.2 Imputing missing values for classification

To allow the WM objective to lead to good training, we impute the CDF of each censored example from $1 - \hat{S}(t)$, where $\hat{S}$ is the non-parametric Kaplan-Meier (KM) estimate of the survival function computed on the training set (see Figure 1). With the KM estimator, the survival function is estimated as a step function whose value at time $t$ is:

$$\hat{S}(t) = \prod_{k : t_k \le t} \left( 1 - \frac{d_k}{n_k} \right)$$

where $t_k$ are the distinct event times, $d_k$ denotes the number of events at $t_k$, and $n_k$ the number of patients still at risk (alive and uncensored) just before $t_k$.
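The KM product can be computed in a few lines. The exact rule for turning $\hat{S}$ into the CDF of a censored example is not spelled out here, so the imputation function below is a hypothetical choice: it sets the CDF to 0 up to the censoring time $C$ and then follows $1 - \hat{S}(t)/\hat{S}(C)$, i.e., the KM curve conditioned on surviving past $C$.

```python
from collections import Counter


def kaplan_meier(times, observed):
    """Return [(t_k, S_hat(t_k))] at each distinct event time, using
    S_hat(t) = prod over t_k <= t of (1 - d_k / n_k)."""
    deaths = Counter(t for t, e in zip(times, observed) if e)  # d_k
    steps, surv = [], 1.0
    for t_k in sorted(deaths):
        n_k = sum(1 for u in times if u >= t_k)  # still at risk just before t_k
        surv *= 1.0 - deaths[t_k] / n_k
        steps.append((t_k, surv))
    return steps


def survival_at(steps, t):
    """Evaluate the right-continuous KM step function at time t."""
    s = 1.0
    for t_k, s_k in steps:
        if t_k > t:
            break
        s = s_k
    return s


def impute_censored_cdf(steps, censor_time, grid):
    """Hypothetical imputation rule: CDF is 0 up to the censoring time C, then
    1 - S_hat(t) / S_hat(C), the KM curve conditioned on survival past C."""
    s_c = survival_at(steps, censor_time)
    cdf = []
    for t in grid:
        if t <= censor_time:
            cdf.append(0.0)
        elif s_c == 0.0:
            cdf.append(1.0)  # KM curve already at zero: no information left
        else:
            cdf.append(1.0 - survival_at(steps, t) / s_c)
    return cdf


times = [1, 2, 2, 3, 4]
observed = [1, 1, 0, 1, 0]
steps = kaplan_meier(times, observed)
print(steps)
print(impute_censored_cdf(steps, censor_time=2, grid=[1, 2, 3, 4]))
```

If the last training time is censored, the KM curve never reaches zero and the imputed CDF stays below one; the leftover mass then has to be assigned somewhere, for instance to the last bin.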

Figure 1: An overview of the proposed distribution matching loss. When a sample is censored, the KM estimate is used to impute the probability that should be assigned to each event time.

4 Experiments

4.1 Datasets

We assess the presented models on a variety of publicly available datasets. The characteristics of these datasets are summarized in Table 1.

Dataset     Nb. samples    Nb. (%) censored    Nb. (%) unique times    Nb. features
SUPPORT2    9105           2904 (32.2)         1724 (19.1)             98
AIDS3       3985           2223 (55.8)         1506 (37.8)             19
COLON       929            477 (51.3)          780 (84.0)              48
Table 1: Characteristics of the datasets used in our evaluation. The datasets differ in number of samples and in the percentages of censored and tied patients. The features are typically continuous or discrete clinical attributes.

SUPPORT2 (available at http://biostat.mc.vanderbilt.edu/wiki/Main/DataSets) records the survival times of patients in the SUPPORT study.

AIDS3 (available at https://vincentarelbundock.github.io/Rdatasets/datasets.html) corresponds to the Australian AIDS Survival Data.

COLON (available from the same Rdatasets collection) consists of data from one of the first successful trials of adjuvant chemotherapy for colon cancer. We considered death as the target event for our study.

4.2 Data pre-processing

We used a one-hot encoding for categorical features and unit scaling for continuous features. For features with missing values, we added an indicator variable for the absence of a value.

We performed 5-fold cross-validation and kept 20% of the training set as a validation set. Prediction performance is reported as the mean ± standard error of the C-index over the 5 folds. Early stopping was based on the validation C-index.

We used a multi-layer perceptron (3 layers with 100 units each) with ReLU activation functions where applicable, together with Dropout (Hinton et al., 2012), Batch Normalization (Ioffe & Szegedy, 2015) and L2 regularization on the weights. We used the Adam optimizer. For the ranking and log-likelihood methods, the output was a single unit with a linear activation function. For the methods requiring a prediction over output times, we used a softmax output. Our code was written in PyTorch (Paszke et al., 2017).

For each split independently, we perform a grid search over the L2 regularization coefficient and the learning rate. We add a small constant (1 for SUPPORT2 and AIDS3, 10 for COLON) to the distance between bins before normalizing. For COLON we used a bin size of 2 days, and 1 day for the other two datasets.
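The binning and smoothing step can be sketched as follows. How the constant enters the bin distances is not fully specified, so the reading below (gap between consecutive bins = training events in the later bin plus the constant, then normalized) is an assumption, as are the function names.

```python
def bin_index(time_days, bin_size):
    """Map a time in days to a discrete time bin (bin_size=2 for COLON, 1 otherwise)."""
    return int(time_days // bin_size)


def smoothed_bin_gaps(event_bins, num_bins, constant):
    """Hypothetical reading of the smoothing step: the distance between bins
    k-1 and k is (training events in bin k) + constant, normalized so the
    gaps sum to 1. Without the constant, empty bins would be at distance 0."""
    counts = [0] * num_bins
    for b in event_bins:
        counts[b] += 1
    gaps = [counts[k] + constant for k in range(1, num_bins)]
    total = sum(gaps)
    return [g / total for g in gaps]


event_days = [0, 3, 3, 7]  # toy event times in days
bins = [bin_index(t, bin_size=2) for t in event_days]  # -> [0, 1, 1, 3]
print(smoothed_bin_gaps(bins, num_bins=4, constant=1))
```

A larger constant (as used for COLON) pulls the distances toward uniform spacing, which matters when many bins contain few or no events.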

4.3 Comparison of different ranking methods

Table 2 reports the impact of the different loss functions, showing how the standard Cox model performs in comparison to ranking and classification losses.

Loss type            Variant        SUPPORT2        AIDS3           COLON
Partial likelihood   Cox            84.90 ± 0.63    54.84 ± 0.82    64.66 ± 0.44
Partial likelihood   Cox Efron's    84.91 ± 0.60    54.03 ± 1.21    63.08 ± 0.93
Ranking                             85.53 ± 0.56    55.35 ± 1.19    64.22 ± 0.61
Ranking              Log-sigmoid    85.44 ± 0.57    55.28 ± 1.29    63.36 ± 0.52
Ranking                             84.96 ± 0.56    55.41 ± 1.20    63.98 ± 1.12
Ranking                             85.35 ± 0.58    55.73 ± 0.93    61.96 ± 0.91
Classification       WM (ours)      85.33 ± 0.52    56.03 ± 1.01    64.32 ± 0.39
Table 2: Performance scores of the different methods. The table reports the C-index mean ± standard error over the 5 folds. For each dataset, the best model in terms of mean score is highlighted in bold. We draw the reader's attention to the classification loss, which is among the losses giving the best results.

4.4 Impact of using censored data

The purpose of this section is to show that censored examples are informative and should not simply be ignored or processed away. We compare three ways of handling censored data. First, we completely remove censored examples from the training set (no censored data). Second, we treat the censoring time as an actual event occurrence, transforming each example censored at time $C_i$ into the same example with an event occurring at time $C_i$ (death at censoring). Finally, we list results for the standard approach (with censored data). In the case of WM, the censored times are imputed with the KM curve, as described in Section 3.3.2.
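The three training-set variants can be expressed as simple transformations of the event indicators. A minimal sketch, assuming each example keeps its feature vector and recorded time (event or censoring time) and only its inclusion and event flag change; the function name is illustrative.

```python
def censoring_variants(observed):
    """Build the three training-set variants compared in this section.

    observed[i] is 1 for an observed event and 0 for a censored example.
    Each variant is a list of (index, event_flag) pairs; the feature vector and
    recorded time of example `index` are reused unchanged.
    """
    n = len(observed)
    return {
        # drop censored examples entirely
        "no censored data": [(i, 1) for i in range(n) if observed[i]],
        # keep every example, treating the censoring time C_i as an event time
        "death at censoring": [(i, 1) for i in range(n)],
        # standard approach: keep the censoring indicator as is
        "with censored data": [(i, observed[i]) for i in range(n)],
    }


for name, rows in censoring_variants([1, 0, 1, 0]).items():
    print(name, rows)
```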

We run this experiment on SUPPORT2, the largest public dataset we have, using the best method of each category on this dataset: Cox Efron's, the best ranking loss, and our WM method. The results are presented in Table 3.

Method                WM              Ranking         Cox
No censored data      83.31 ± 0.51    83.40 ± 0.52    82.34 ± 0.49
Death at censoring    82.34 ± 0.58    81.97 ± 0.67    80.67 ± 0.55
With censored data    85.33 ± 0.52    85.53 ± 0.56    84.91 ± 0.60
Table 3: We explore how the three categories of methods are impacted by adding censored data. The table reports the C-index mean ± standard error over the 5 folds. For "Death at censoring", the censoring time is used as the time of death. For all three methods, training with censored data gives the highest C-index, showing that censored data contains information that we can use to make better predictions.

4.5 Exploring the impact of censored data

To determine how much improvement we can obtain from incorporating censored data, we vary the proportion of censored samples in the training data while keeping the validation and test sets the same. Figure 2 shows the evolution of the C-index for different percentages of censoring in the training set of the SUPPORT2 dataset.

Figure 2: Impact of the proportion of censored and uncensored patients during training on the C-index (mean ± standard error over the 5 folds) on the SUPPORT2 dataset. The validation and test sets are fixed, and additional censored patients are introduced into the training set by marking patients as censored at random. The plot starts at 30% because roughly that fraction of patients is censored in the original dataset. We find that the WM classification loss is robust to the introduction of censored data.
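The random-censoring procedure behind Figure 2 can be sketched as flipping event indicators in the training split only. Whether the recorded event time was kept as the new censoring time is not stated, so the sketch assumes it; the function name and seed handling are illustrative.

```python
import random


def censor_at_random(observed, target_fraction, seed=0):
    """Return a copy of `observed` in which randomly chosen observed examples are
    flipped to censored (0) until about `target_fraction` of the set is censored.
    Already-censored examples stay censored; recorded times are left unchanged
    and reinterpreted as censoring times (an assumption)."""
    rng = random.Random(seed)
    flags = list(observed)
    n_needed = round(target_fraction * len(flags)) - flags.count(0)
    candidates = [i for i, e in enumerate(flags) if e]
    rng.shuffle(candidates)
    for i in candidates[:max(0, n_needed)]:
        flags[i] = 0
    return flags


train_observed = [1] * 7 + [0] * 3  # 30% censored, as at the start of the plot
print(censor_at_random(train_observed, target_fraction=0.5))
```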

5 Conclusion

We proposed a new method for learning to rank survival data. Experiments on different datasets show that models trained with the WM loss give accurate predictions compared to the more classical losses of the Cox model and to ranking loss functions, which directly approximate a lower bound of the C-index. While not always state of the art, our method is consistently close to the best result on each dataset.

We also find that the WM loss tolerates a high percentage of censored samples while its results remain consistently in the same range as those of the best methods. We also demonstrate that our method can be seen as directly optimizing the expected C-index, which is the most common evaluation metric for ranking survival models. Moreover, our results show that imputing the missing event times with the KM curve in a classification framework can increase the resulting C-index.

Acknowledgements

We thank Christopher Pal and Christopher Beckham for their input on the project. This work is partially funded by a grant from the U.S. National Science Foundation Graduate Research Fellowship Program (grant number: DGE-1356104) and the Institut de valorisation des données (IVADO). This work utilized the supercomputing facilities managed by the Montreal Institute for Learning Algorithms, NSERC, Compute Canada, and Calcul Québec.

References

  • Beckham & Pal (2017) Beckham, Christopher and Pal, Christopher. Unimodal probability distributions for deep ordinal classification. International Conference on Machine Learning, 2017.
  • Burges et al. (2005) Burges, Christopher J C, Shaked, Tal, Renshaw, Erin, Lazier, Ari, Deeds, Matt, Hamilton, Nicole, and Hullender, Greg. Learning to Rank using Gradient Descent. In International Conference on Machine Learning. ACM, 2005.
  • Cox (1972) Cox, D. R. Regression models and life tables. Journal of the Royal Statistical Society, 1972.
  • Efron (1977) Efron, Bradley. The Efficiency of Cox’s Likelihood Function for Censored Data. Journal of the American Statistical Association, 1977.
  • Freund et al. (2003) Freund, Yoav, Iyer, Raj, Schapire, Robert E, and Singer, Yoram. An Efficient Boosting Algorithm for Combining Preferences. Journal of Machine Learning Research, 2003. doi: 10.1162/jmlr.2003.4.6.933.
  • Frogner et al. (2015) Frogner, Charlie, Zhang, Chiyuan, Mobahi, Hossein, Araya, Mauricio, and Poggio, Tomaso A. Learning with a Wasserstein loss. In Advances in Neural Information Processing Systems, 2015.
  • Harrell et al. (1982) Harrell, Frank E., Califf, Robert M., Pryor, David B., Lee, Kerry L., and Rosati, Robert A. Evaluating the Yield of Medical Tests. Journal of the American Medical Association, 1982. doi: 10.1001/jama.1982.03320430047030.
  • Herbrich (2000) Herbrich, Ralf. Large margin rank boundaries for ordinal regression. Advances in large margin classifiers, 2000.
  • Hinton et al. (2012) Hinton, Geoffrey E., Srivastava, Nitish, Krizhevsky, Alex, Sutskever, Ilya, and Salakhutdinov, Ruslan R. Improving neural networks by preventing co-adaptation of feature detectors. arXiv:1207.0580, 2012.
  • Hou et al. (2016) Hou, Le, Yu, Chen-Ping, and Samaras, Dimitris. Squared Earth Mover’s Distance-based Loss for Training Deep Neural Networks. arXiv:1611.05916, 2016.
  • Ioffe & Szegedy (2015) Ioffe, Sergey and Szegedy, Christian. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International conference on machine learning, 2015.
  • Kalbfleisch (1978) Kalbfleisch, John. Non-Parametric Bayesian Analysis of Survival Time Data. Journal of the Royal Statistical Society. Series B, 1978.
  • Katzman et al. (2016) Katzman, Jared, Shaham, Uri, Cloninger, Alexander, Bates, Jonathan, Jiang, Tingting, and Kluger, Yuval. DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network. International Conference of Machine Learning Computational Biology Workshop, 2016.
  • Kendall (1938) Kendall, Maurice G. A new measure of rank correlation. Biometrika, 1938.
  • Levina & Bickel (2001) Levina, Elizaveta and Bickel, Peter. The Earth Mover’s distance is the Mallows distance: Some insights from statistics. In International Conference on Computer Vision, volume 2. IEEE, 2001. doi: 10.1109/ICCV.2001.937632.
  • Luck et al. (2017) Luck, Margaux, Sylvain, Tristan, Cardinal, Héloïse, Lodi, Andrea, and Bengio, Yoshua. Deep Learning for Patient-Specific Kidney Graft Survival Analysis. arXiv:1705.10245, 2017.
  • Mena et al. (2018) Mena, Gonzalo, Belanger, David, Linderman, Scott, and Snoek, Jasper. Learning latent permutations with gumbel-sinkhorn networks. arXiv preprint arXiv:1802.08665, 2018.
  • Paszke et al. (2017) Paszke, Adam, Chanan, Gregory, Lin, Zeming, Gross, Sam, Yang, Edward, Antiga, Luca, and Devito, Zachary. Automatic differentiation in PyTorch, 2017.
  • Raykar et al. (2007) Raykar, Vikas C, Steck, Harald, Krishnapuram, Balaji, Dehing-oberije, Cary, and Lambin, Philippe. On ranking in survival analysis: Bounds on the concordance index. In Neural Information Processing Systems, 2007. doi: 10.1.1.121.2670.