Harmonic Mean Point Processes: Proportional Rate Error Minimization for Obtundation Prediction
In healthcare, the highest risk individuals for morbidity and mortality are rarely those with the greatest modifiable risk. By contrast, many machine learning formulations implicitly attend to the highest risk individuals. We focus on this problem in point processes, a popular modeling technique for the analysis of the temporal event sequences in electronic health records (EHR) data with applications in risk stratification and risk score systems. We show that optimization of the log-likelihood function also gives disproportionate attention to high risk individuals and leads to poor prediction results for low risk individuals compared to ones at high risk. We characterize the problem and propose an adjusted log-likelihood formulation as a new objective for point processes. We demonstrate the benefits of our method in simulations and in EHR data of patients admitted to the critical care unit for intracerebral hemorrhage.
Clinical forecasting is a central task for prognostication in populations at risk for downstream morbidity and mortality. When followup is incomplete and right-censorship of data occurs, survival analysis and point process models are often preferable to binary classification for long-term prognostication due to their ability to mitigate selection bias associated with censorship. Many models exist for the survival analysis task, including Cox, Aalen, accelerated failure time models, and random survival forests. In these models, characterization of proportional errors in rate is often an objective of primary interest, but its minimization is not straightforward because ground truth rates are unobserved. Indeed, the long-standing success of the Cox proportional hazards model is a demonstration of the interpretive value of proportional rate estimation. Despite this, the objective functions specified for many survival models, including that of Cox, do not seek to minimize proportional error in rate.
We approach the problem of proportional rate minimization for repeated events through the formulation of the objective function. Our analysis illustrates that the standard likelihood function attends to individuals at highest risk, potentially at the cost of poorly modeling proportional rates in low-risk individuals. This characterization provides us a way to attempt to mitigate the mis-attention and results in a reweighting scheme to fairly attend to all individuals with respect to proportional rate misspecification. Practically, this results in a fairness-variance trade-off, as the models suffer from high variance from low effective sample sizes. We demonstrate in simulations and in prediction of neurological deterioration among patients admitted for intracerebral hemorrhage (ICH) that our method empirically produces informative risk assessments in low rate regimes.
Background. Let be the event we want to model over time across samples. Let the event times be the sequence for for over a period of interest . We are interested in modeling the rate function:
where varies by model and represents the information or data to use in modeling . The probability density function and survival function are given by and .
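For orientation, the quantities named above can be written in one standard form. The symbols here ($\lambda$ for the rate, $\mathcal{H}_t$ for the conditioning information, $t_{j-1}$ for the previous event time, with $t_0 = 0$) are assumptions chosen for illustration and may differ from the notation used elsewhere in this paper.

```latex
\begin{align}
\lambda(t \mid \mathcal{H}_t) &= \lim_{\Delta \to 0}
  \frac{P\big(\text{event in } [t, t+\Delta) \mid \mathcal{H}_t\big)}{\Delta}, \\
S(t \mid \mathcal{H}_t) &= \exp\!\Big(-\int_{t_{j-1}}^{t} \lambda(s \mid \mathcal{H}_s)\, ds\Big), \\
f(t \mid \mathcal{H}_t) &= \lambda(t \mid \mathcal{H}_t)\, S(t \mid \mathcal{H}_t).
\end{align}
```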
Next we establish the relationship between rescaled time, in which events follow a single Poisson process with rate 1, and our original time, in which the rate function is defined. This relationship comes from the time rescaling theorem.
Time rescaling theorem. Given the rate function , define the cumulative hazard function: . For the realization of a sequence of events from with times and , the sequence is distributed according to a unit rate Poisson process (Meyer, 1971; Ogata, 1981).
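Written out, the statement takes roughly the following form. The symbols ($\Lambda$ for the cumulative hazard, $t_j$ for event times, $\tau_j$ for rescaled times) are assumed notation for illustration only.

```latex
\begin{align}
\Lambda(t) &= \int_{0}^{t} \lambda(s \mid \mathcal{H}_s)\, ds, \qquad
\tau_j = \Lambda(t_j), \quad 0 < t_1 < t_2 < \dots < t_n \le T, \\
\tau_j - \tau_{j-1} &\overset{\text{i.i.d.}}{\sim} \mathrm{Exponential}(1),
\qquad \tau_0 = 0 .
\end{align}
```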
Details of the proof can be found in (Brown et al., 2002). The implication of the theorem is that if we could model the conditional intensity correctly, the intervals between rescaled times would follow an exponential distribution with rate 1.
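The implication can be checked numerically in the simplest case. The sketch below is illustrative only and assumes a constant-rate process; the function names and the chosen rate and horizon are hypothetical. Rescaling with the correct cumulative hazard gives intervals with mean close to 1, while a misspecified rate does not.

```python
import random

def simulate_homogeneous(rate, horizon, rng):
    """Event times of a constant-rate Poisson process on [0, horizon]."""
    times, t = [], 0.0
    while True:
        t += rng.expovariate(rate)  # exponential inter-arrival time
        if t > horizon:
            return times
        times.append(t)

def rescale(times, cumulative_hazard):
    """Map event times t_j to rescaled times Lambda(t_j)."""
    return [cumulative_hazard(t) for t in times]

def mean_gap(times):
    """Mean interval between consecutive times, starting from 0."""
    gaps = [b - a for a, b in zip([0.0] + times[:-1], times)]
    return sum(gaps) / len(gaps)

rng = random.Random(0)
true_rate = 5.0  # illustrative value, not taken from the paper
events = simulate_homogeneous(true_rate, horizon=4000.0, rng=rng)

# Correct model: Lambda(t) = true_rate * t, so gaps should average about 1.
good = mean_gap(rescale(events, lambda t: true_rate * t))
# Misspecified model (half the true rate): gaps average about 0.5.
bad = mean_gap(rescale(events, lambda t: 0.5 * true_rate * t))
```

A full goodness-of-fit check would compare the whole distribution of rescaled intervals to a unit exponential rather than only the mean.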
Harmonic Mean Point Processes. From the time rescaling theorem, it is straightforward to observe that the relative contribution to the likelihood of each time interval is proportional to the rate within that interval. That is, if we care about each individual’s risk in a time unit equally, then we could consider decreasing the likelihood contributions in proportion to the rate. In other words, our procedure will seek to nullify, partially or fully, the proportional factor of likelihood attention given to higher rates. We call this approach optimization of the adjusted log likelihood, which is illustrated in Figure 1.
where is the ground truth intensity at time . By assuming and are piecewise constant, we can view the adjusted log-likelihood as the weighted sum of log-likelihood contributions. Suppose we divide the time interval into sub-intervals where is a sufficiently large number so that is constant within any sub-interval. That is, with , is constant for and all . Then,
is the sum of log-likelihood contributions in time intervals weighted by the reciprocal of the ground truth intensity at for every and . The similarity is apparent when it is compared to the similar form of standard log-likelihood:
Therefore, we can weight each interval’s log likelihood by the inverse of the oracle rate to get the adjusted log likelihood.
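Under the piecewise-constant assumption, the comparison described above can be sketched as follows. The notation is assumed for illustration: $n_{ik}$ is the number of events for sample $i$ in sub-interval $k$, $\Delta$ is the sub-interval width, $\lambda_{ik}$ is the model rate, and $\lambda^{*}_{ik}$ is the ground truth (oracle) rate.

```latex
\begin{align}
\ell(\lambda) &= \sum_{i} \sum_{k}
  \Big[ n_{ik} \log \lambda_{ik} - \lambda_{ik} \Delta \Big], \\
\ell_{\mathrm{adj}}(\lambda) &= \sum_{i} \sum_{k} \frac{1}{\lambda^{*}_{ik}}
  \Big[ n_{ik} \log \lambda_{ik} - \lambda_{ik} \Delta \Big].
\end{align}
```

The two expressions differ only in the per-interval weight $1/\lambda^{*}_{ik}$, which removes the extra attention that high-rate intervals receive in the standard form.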
Oracle Approximation. Without access to , however, we must resort to approximation of the reweighting. One choice for is our current estimate . However, this could lead to unstable weightings because a single example could dominate the weight distribution. To address this fairness-variance tradeoff, we introduce the attention coefficient and stability factor to help stabilize the weights. Pseudocode in Figure 1 (right) illustrates the training procedure and the stabilization modification. We call our method harmonic mean point processes (HMPP) because if the oracle is known and doubly stochastic (frail), then the estimates we get from the training procedure are harmonic mean estimates of the rate distribution. Note that in practice when we use an approximation, the denominator must be copied and detached from the computation graph so that the graph of which is a part is not further connected by the current model’s predictions .
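A minimal sketch of the reweighting follows. The functional form of the weight is an assumption made here for illustration, chosen so that an attention coefficient of 10 with zero stability factor recovers inverse-rate weighting and an attention coefficient of 1 recovers the standard log likelihood; the exact form in Figure 1 may differ. All function and parameter names are hypothetical.

```python
import math

def hmpp_weights(predicted_rates, attention=10.0, stability=0.0):
    """Assumed weight form: (rate + stability) ** (-log10(attention)).

    attention=10, stability=0 gives 1/rate; attention=1 gives uniform weights.
    The weights are plain numbers, i.e. constants with respect to the model.
    In an autodiff framework the predicted rates would be detached first
    (for example with torch.Tensor.detach in PyTorch).
    """
    exponent = math.log10(attention)
    return [(rate + stability) ** (-exponent) for rate in predicted_rates]

def adjusted_log_likelihood(counts, rates, width, weights):
    """Weighted sum of piecewise-constant Poisson log-likelihood terms."""
    total = 0.0
    for n, lam, w in zip(counts, rates, weights):
        total += w * (n * math.log(lam) - lam * width)
    return total

rates = [0.01, 0.1, 1.0]   # current model estimates per interval (illustrative)
counts = [0, 1, 2]         # events observed in each interval (illustrative)

inverse_rate = hmpp_weights(rates, attention=10.0, stability=0.0)
uniform = hmpp_weights(rates, attention=1.0)
stabilized = hmpp_weights(rates, attention=10.0, stability=0.05)

all_value = adjusted_log_likelihood(counts, rates, 1.0, stabilized)
ll_value = adjusted_log_likelihood(counts, rates, 1.0, uniform)
```

With a nonzero stability factor the largest weight is bounded by the reciprocal of that factor raised to the same exponent, which limits how much a single low-rate example can dominate.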
3 Experimental Setup and Results
We test our method in two simulations, where we have access to ground truth rates, and in application to a health setting, where we illustrate important factors and effects of our approach. In all cases, we are comparing our objective versus the standard variant, and call the models Harmonic Mean Point Processes (HMPP, ours) and Maximum Likelihood Point Processes (MLPP, comparison). We use this labeling across multiple models and domains which we describe next.
Simulations. To test our ability to accurately determine the rates of low risk individuals, we developed a singly- and a doubly-stochastic univariate model with rates varying by 4-6 orders of magnitude ($10^4$- to $10^6$-fold variation in rates). In each case, events are sampled for 10 units of time according to a sample-specific fixed rate , where is drawn from a truncated, base 10 exponential between to . For the singly-stochastic model, we sample events according to an exponential with rate , and for the doubly-stochastic model, we sample events according to an exponential with rate for . Then, a time-stamped sequence is produced, containing tuples of (id, time, event, value) features, with the first tuple containing (id, 0, rate, ). We use an embedding LSTM architecture for the simulations (Dong et al., 2018) and provide details in the Appendix.
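The singly-stochastic sampling step can be sketched as below. Several details are assumptions for illustration: the rate is drawn log-uniformly (10 raised to a uniform exponent), the exponent bounds are placeholder values, and the value slot of the first tuple is filled with the rate. None of these are confirmed by the text above.

```python
import random

def draw_rate(rng, low_exp, high_exp):
    """Assumed rate distribution: 10**u with u uniform on [low_exp, high_exp]."""
    return 10.0 ** rng.uniform(low_exp, high_exp)

def sample_events(rate, horizon, rng):
    """Event times on (0, horizon] using exponential inter-arrival times."""
    times, t = [], 0.0
    while True:
        t += rng.expovariate(rate)
        if t > horizon:
            return times
        times.append(t)

def make_sequence(sample_id, rate, horizon, rng):
    """Long-format rows of (id, time, event, value); the first row carries the rate."""
    rows = [(sample_id, 0.0, "rate", rate)]
    rows += [(sample_id, t, "event", None) for t in sample_events(rate, horizon, rng)]
    return rows

rng = random.Random(0)
# Hypothetical exponent bounds; the bounds used in the paper are not shown here.
LOW_EXP, HIGH_EXP = -3.0, 1.0
dataset = []
for sample_id in range(100):
    rate = draw_rate(rng, LOW_EXP, HIGH_EXP)
    dataset.append(make_sequence(sample_id, rate, horizon=10.0, rng=rng))
```

A doubly-stochastic variant would additionally multiply each rate by a random effect before sampling events.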
Application: obtundation in intracerebral hemorrhage. We also apply our method to real data of neurological decline during critical care admissions for intracerebral hemorrhage (ICH). Intracerebral hemorrhage is a life-threatening extravasation of blood outside the vessel wall due to a tear or rupture that results in an accumulation of blood, which then presses upon brain tissue causing neuronal damage. Mortality rates are 40% at 1 month post diagnosis. In these individuals, frequent monitoring of neurological status is essential. The Glasgow Coma Score (GCS) is a score based on physical exam that assesses progression and recovery. For ICH, GCS is a primary indicator for mortality stratification, intracranial pressure monitoring, and intubation (parry2013accuracy; manno2012update).
Data. We used data from MIMIC III v1.4 (Johnson et al., 2016), an EHR housing critical care data on 40,000 individuals. Of those, 1,010 had a primary diagnosis of ICH and were considered as members of our cohort. Chart, laboratory, medications, vitals, procedures, and demographics tables were extracted as time-varying features for nowcasting GCS decreases. The full outcome specification for obtundation (decreased GCS) is in the Appendix. Figure 2 depicts the study characteristics and approach. We use wavelet reconstruction networks (WRNs) (Weiss, 2018) with the objective function modification to accommodate the adjusted log likelihood.
We investigate the performance of the algorithms using inspection of calibration and variable importance plots.
Results. Figure 3 demonstrates the benefit of our method in simulations. In the singly-stochastic model (left figure), both the HMPP and MLPP approaches discriminate risk across the spectrum, illustrated by the (approximately) monotonic curves. However, the MLPP method never predicts hazards lower than 0.2 for any group, despite many of those groups having empirical rates near 0.01. By contrast, the HMPP method straightens the low-risk tail and makes more accurate predictions in low risk individuals. At the same time, risk predictions for high risk individuals are similar in quality. The range of predicted risks from HMPP was half an order of magnitude larger than that from MLPP.
In the ICH study, the hyperparameters chosen were an elastic net formulation L1 and L2: with and , suggesting the model is constrained by limited sample size. Discriminatively, the C statistic among the lowest quartile was 0.68 (95% bootstrap CI 0.62-0.74) for HMPP and 0.66 (95% bootstrap CI 0.60-0.72) for MLPP. Thus, while HMPP does identify lower risk groups, the small sample size limits the interpretation. Additional results and comments are in the Appendix.
Conclusion. Our work demonstrates a new tool to make risk predictions in low-risk populations. We provide a formulation that exhibits how to attend equally across risk, and provide an algorithm and guidance to trade off fair attention with variance from reweighting. Importantly, our method detects individuals with rates an order of magnitude lower than the predictions made by optimizing the log likelihood with a deep network, the combination of two popular approaches. We further illustrate implications of attending to low-risk individuals in the variable importances reported under each optimization. This difference may have important applications in suggesting risk factors that are stratum-specific, which can provide guidance in personalized decision making. Future work will include explicit characterization of the proportionate attention-variance tradeoff, which could provide alternative approximations to the oracle rate with desirable properties.
References
- Avati et al. (2018) Avati, A., Duan, T., Jung, K., Shah, N. H., and Ng, A. (2018). Countdown regression: Sharp and calibrated survival predictions. arXiv preprint arXiv:1806.08324.
- Brown et al. (2002) Brown, E. N., Barbieri, R., Ventura, V., Kass, R. E., and Frank, L. M. (2002). The time-rescaling theorem and its application to neural spike train data analysis. Neural computation, 14(2):325–346.
- Chapfuwa et al. (2018) Chapfuwa, P., Tao, C., Li, C., Page, C., Goldstein, B., Duke, L. C., and Henao, R. (2018). Adversarial time-to-event modeling. In International Conference on Machine Learning, pages 734–743.
- Dong et al. (2018) Dong, H.-W., Hsiao, W.-Y., Yang, L.-C., and Yang, Y.-H. (2018). MuseGAN: Multi-track sequential generative adversarial networks for symbolic music generation and accompaniment. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Gerhard and Gerstner (2010) Gerhard, F. and Gerstner, W. (2010). Rescaling, thinning or complementing? On goodness-of-fit procedures for point process models and generalized linear models. In Advances in Neural Information Processing Systems, pages 703–711.
- Jing and Smola (2017) Jing, H. and Smola, A. J. (2017). Neural survival recommender. In Proceedings of the Tenth ACM International Conference on Web Search and Data Mining, pages 515–524. ACM.
- Johnson et al. (2016) Johnson, A. E., Pollard, T. J., Shen, L., Lehman, L. H., Feng, M., Ghassemi, M., Moody, B., Szolovits, P., Celi, L. A., and Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3:160035.
- Lee et al. (2018) Lee, C., Zame, W. R., Yoon, J., and van der Schaar, M. (2018). DeepHit: A deep learning approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Meyer (1971) Meyer, P.-A. (1971). Démonstration simplifiée d’un théorème de Knight. In Séminaire de Probabilités V Université de Strasbourg, pages 191–195. Springer.
- Miscouridou et al. (2018) Miscouridou, X., Perotte, A., Elhadad, N., and Ranganath, R. (2018). Deep survival analysis: Nonparametrics and missingness. In Machine Learning for Healthcare Conference, pages 244–256.
- Ogata (1981) Ogata, Y. (1981). On Lewis’ simulation method for point processes. IEEE Transactions on Information Theory, 27(1):23–31.
- Pillow (2009) Pillow, J. W. (2009). Time-rescaling methods for the estimation and assessment of non-Poisson neural encoding models. In Advances in Neural Information Processing Systems, pages 1473–1481.
- Weiss (2018) Weiss, J. C. (2018). Clinical risk: wavelet reconstruction networks for marked point processes.
- Weiss (2019) Weiss, J. C. (2019). On microvascular complications of diabetes risk: development of a machine learning and electronic health records risk score.
- Yu et al. (2011) Yu, C.-N., Greiner, R., Lin, H.-C., and Baracos, V. (2011). Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In Advances in Neural Information Processing Systems, pages 1845–1853.
Appendix A Proportional rate error minimization in single failure-time prediction
There are methods for single failure-time prediction that optimize for the proportional predictive error. For example, accelerated failure time (AFT) models accomplish this by log transforming time so that minimizing mean squared error in log space corresponds to minimizing multiplicative errors in the original space. While this approach works for single failure-time analyses, it fails to extend to repeated event models. In particular, AFT models require a specification of time so that the log transformation is well-defined. Time specification is problematic in time-varying and nowcasting analysis, as noted in (Miscouridou et al., 2018) and (Weiss, 2018), because training time must be specified while the modeler may want to vary test time or model repeated events. Attempts to duplicate training samples longitudinally by varying train time lead to samples that violate the parametric or semi-parametric model assumptions and result in poor parameter estimates and poor predictive performance. Because of this problem, it is unclear how to naturally extend such models to recurrent event and multitask settings. In these settings, larger ranges of rates are often modeled, and this magnifies the problem of misplaced attention.
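The correspondence between squared error in log space and multiplicative error can be seen in a small example. This is an illustrative sketch with made-up failure times, not an AFT implementation.

```python
import math

def squared_log_error(predicted, observed):
    """Squared error after log-transforming both times."""
    return (math.log(predicted) - math.log(observed)) ** 2

def squared_error(predicted, observed):
    """Ordinary squared error on the original time scale."""
    return (predicted - observed) ** 2

# Predicting 2x too long and 2x too short cost the same in log space.
over = squared_log_error(20.0, 10.0)
under = squared_log_error(5.0, 10.0)
# The same 2x error at a much shorter time scale also costs the same.
short_scale = squared_log_error(0.2, 0.1)

# On the original scale the same 2x errors are penalized very differently.
raw_long = squared_error(20.0, 10.0)
raw_short = squared_error(0.2, 0.1)
```

On the original scale the loss is dominated by long failure times, whereas in log space every two-fold error contributes equally regardless of scale.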
Another approach is to fuse together binary classification predictions across time. Our method may help illustrate the effectiveness of this approach, in that this multi-task prediction formulation attends implicitly to low-risk examples by ignoring those examples where events have already occurred in the large time-to-event classifiers (Yu et al., 2011). However, unlike these methods, our method provides explicit attention to low-risk persons and regions in the repeated event setting.
Appendix B Sensitivity to hyperparameters and
Figure B.1 demonstrates the reweighting achieved with different choices of attention coefficient and stability factor . To achieve equal rescaled-time weighting, the attention coefficient must be set to 10, corresponding to 10-fold increased weight per 10-fold decrease in risk: the blue horizontal line. However, the number of effective samples may become very small, shown by the number of effective samples (per 1) for several common distributions. To avoid this, the attention coefficient and/or the stability factor can be chosen to flatten the reweighting distribution. In practice the predicted distribution is implicit and potentially unstable, and using domain knowledge to set the stability factor near the lowest rate expected to be found will mitigate the instability while still attending to the low risk individuals.
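The trade-off can be explored numerically with a small sketch. The weight form below is an assumption for illustration (it gives 10-fold more weight per 10-fold decrease in rate when the attention coefficient is 10), the effective sample size uses the common Kish definition, and the log-uniform rate distribution and its bounds are hypothetical.

```python
import random

import math

def reweight(rate, attention, stability):
    """Assumed weight form: (rate + stability) ** (-log10(attention))."""
    return (rate + stability) ** (-math.log10(attention))

def effective_sample_size(weights):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)

rng = random.Random(0)
# Hypothetical rate distribution spanning four orders of magnitude.
rates = [10.0 ** rng.uniform(-3.0, 1.0) for _ in range(1000)]

settings = {
    "standard (attention=1)": (1.0, 0.0),
    "full reweighting (attention=10)": (10.0, 0.0),
    "stabilized (attention=10, stability=0.1)": (10.0, 0.1),
}
ess = {
    name: effective_sample_size([reweight(r, a, s) for r in rates])
    for name, (a, s) in settings.items()
}
for name, value in ess.items():
    print(f"{name}: effective samples = {value:.0f} of {len(rates)}")
```

Full inverse-rate weighting concentrates weight on the few lowest-rate samples, and a nonzero stability factor recovers part of the lost effective sample size.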
Appendix C Architecture
The architecture is given in Figure C.2. The idea is to use an LSTM (piano roll) embedding architecture, where any number of events with or without values can first be embedded as line and point embeddings respectively, and then the embedded signals are captured in a group embedding and passed into LSTM time steps. This architecture facilitates flexible parsing of long format data typical of marked point processes, such as that of digital orchestral music, or that of medical event streams. Categorical events are treated as multiple point events, and point events are embedded as points. Real-valued events are embedded based on their value, and so the event’s value domain corresponds to a line embedded as a 1-dimensional manifold. These embedded vectors are then further embedded as a group based on their timestamps, and fed into an LSTM that outputs non-negative rate predictions. We use time steps of unit length with 10 steps in total.
Appendix D Training hyperparameters
We conduct training for 50 epochs using the Adam optimizer with a learning rate of and a batch size of 8. Because the weights change at each training step, adjusted log likelihood (ALL) values are not comparable across steps. Therefore, when choosing early stopping points, we use the tune set log likelihood. We also used the tune set performance in our search over the following hyperparameters: (ICH only), (ICH only), L1 regularization (LASSO) coefficient in , and L2 regularization (ridge) coefficient in . Our implementation is in PyTorch v1.0.
Appendix E Additional methods and results
Outcome definition. GCS scores were recorded in the chart table in two versions depending on the vendor, one by the component scores Eyes, Verbal, and Motor, and the other as an aggregate score. We defined a decrease in GCS to be a decrease of any score not a result of intubation, or a first GCS below 8 (an obtunded state indicative of poor outcomes). Only GCS readings inside the critical care unit were considered. For each individual, only events within the first ICH encounter were used.
Calibration subgroups and variable importance. We construct calibration plots using the ordering given by the algorithm predictions, which illustrates the algorithm's discriminative ability to stratify groups across the spectrum of risk (an alternative is to order by ground truth rates where the specific expressed goal of assessment is calibration). In applications to real data, we cannot access ground truth, so we order by the predictions and use equal quantiles with respect to the predictions. We also expect important features to vary by objective function and inspect their similarity with variable importance plots.
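The grouping step can be sketched as follows. This is an illustrative outline with hypothetical names and toy data: samples are ordered by predicted rate, split into equal-size quantile groups, and each group's exposure-weighted mean prediction is paired with its empirical rate (events per unit of exposure time).

```python
def calibration_groups(predicted_rates, event_counts, exposures, n_groups):
    """Return (mean predicted rate, empirical rate) per prediction quantile."""
    order = sorted(range(len(predicted_rates)), key=lambda i: predicted_rates[i])
    groups = []
    for g in range(n_groups):
        # Equal-size slices of the prediction-ordered sample indices.
        start = g * len(order) // n_groups
        stop = (g + 1) * len(order) // n_groups
        idx = order[start:stop]
        exposure = sum(exposures[i] for i in idx)
        mean_pred = sum(predicted_rates[i] * exposures[i] for i in idx) / exposure
        empirical = sum(event_counts[i] for i in idx) / exposure
        groups.append((mean_pred, empirical))
    return groups

# Toy data: four samples, each observed for 10 units of time.
predicted = [0.1, 1.0, 0.2, 2.0]
counts = [1, 10, 2, 20]
exposure_time = [10.0, 10.0, 10.0, 10.0]

groups = calibration_groups(predicted, counts, exposure_time, n_groups=2)
```

Plotting the two values of each pair against each other on log axes gives a calibration curve in which a well-calibrated model lies near the diagonal.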
Additional results. For the doubly-stochastic model where the formulation includes frailty, the performance of HMPP and MLPP diverges further. In particular, Figure E.3 shows that in the face of random effects that vary the rate ranges by 100-fold, MLPP focuses on the high end of the random effect distribution and HMPP the low end. HMPP identifies groups of individuals with empirical rates an order of magnitude smaller. It also identifies groups at larger rates, but appears to underestimate the rates for these individuals. This could be due to overfitting of the training data leading to erroneously low predictions on the test set. Nonetheless, HMPP detects low risk individuals in this setting whereas MLPP does not acknowledge their low rates, instead limiting all predictions to greater than 0.1.
| HMPP | MLPP | MLPP (HMPP par.) |
|---|---|---|
| IV solution | Nasogastric fluid | Foley catheter |
| Orogastric fluid | Functional fibrinogen | Osmolality: blood |
| Osmolality: blood | Urine white blood cell count | Temperature |
| Vancomycin level: blood | Osmolality: blood | Functional fibrinogen |
| IV dextrose in water: D5W | IV drip labetalol | IV phenytoin |
| Total cholesterol | Vancomycin level: blood | Phenytoin: blood |
| Lymphocyte count: CSF | Urine red blood cells | IV normal saline |
| Urine sodium | Basophil count: blood | IV sterile water added |
| IV normal saline | Urine ketones | PO metoprolol |
Even in the small data setting of ICH in MIMIC, where the low-risk group is not newly identified by the method, we can look at the variable importance plots to demonstrate a marked difference in result. Figure E.4 shows HMPP and MLPP variable importances for the top 10 features. Two features overlap: osmolality and vancomycin levels. For the rest, the HMPP model is concerned with lab tests and intravenous solution choices and quantities, while the MLPP model is concerned with urine and clotting lab tests. One possible explanation is that the difference arises from the increased regularization and early stopping that the HMPP model needs for stability. To check this, we additionally computed the variable importances for the MLPP model with the same regularization settings as HMPP. Its top variables are also shown, and they share no increased overlap with the fair model. This illustrates that the factors that influence proportional risk across the risk spectrum may differ substantially from those obtained from simple likelihood optimization, which attends to high risk.