Adaptive Prediction Timing for Electronic Health Records

Abstract

In realistic scenarios, multivariate timeseries evolve over case-by-case time-scales. This is particularly clear in medicine, where the rate of clinical events varies by ward, patient, and application. Increasingly complex models have been shown to predict patient outcomes effectively, but have failed to adapt their granularity to these inherent temporal resolutions. As such, we introduce a novel, more realistic approach that generates patient outcome predictions at an adaptive rate based on uncertainty accumulation in Bayesian recurrent models. We use a Recurrent Neural Network (RNN) and a Bayesian embedding layer with a new aggregation method to demonstrate adaptive prediction timing. Our model predicts more frequently when events are dense or the model is certain of event latent representations, and less frequently when readings are sparse or the model is uncertain. At 48 hours after patient admission, our model matches the performance of its static-windowed counterparts, while generating patient- and event-specific prediction timings that lead to improved predictive performance over the crucial first 12 hours of the patient stay.

1 Introduction

Over the past decade, machine learning, and deep learning in particular, have repeatedly demonstrated strong performance on a range of benchmark datasets (LeCun et al. (2015); Goodfellow et al. (2016)). These successes have since been transferred to medical data (Esteva et al. (2019)), and specifically the domain of Electronic Health Records (EHRs, Shickel et al. (2017)), where it is hoped research will lead to patient-specific prognostication and diagnosis.

Recent state-of-the-art deep neural network (DNN) models for EHR data use recurrent models for sequence analysis which rely upon fixed prediction scheduling, carrying out extensive model analysis while overlooking the underlying choice of time window length (Choi et al. (2016); Rajkomar et al. (2018); Tomašev et al. (2019)). This early-stage modelling decision, necessitated by the traditional Recurrent Neural Network (RNN, Elman (1990)) structure, loses patient information and timeseries granularity, and ignores the underlying timescales present in EHR data, whose explicit modelling has been shown to boost performance (Meiring et al. (2018); Che et al. (2018)).

In this paper, we introduce and analyse a novel method for adaptive prediction timing in the context of medical timeseries. Progress in variational Bayesian neural networks, facilitated by the reparameterisation trick (Blundell et al. (2015); Kucukelbir et al. (2017)), has led to increased use of Bayesian embeddings of medical concepts (Dusenberry et al. (2019)). We posit that embedding distribution uncertainty can be used to induce adaptive prediction timing. To this end, we explore the following:

  1. How can model uncertainty be related to prediction timing? (Section 2, paragraph 2)

  2. How does adaptive prediction timing affect model performance? (Table 1)

  3. How do prediction timings differ and adapt from fixed windows during training? (Figure 1)

  4. Does our model generalise to different clinical objectives and cohorts? (Table 1 and A.2)

Contributions

We draw several significant conclusions from sequence modelling on the MIMIC-III (Johnson et al. (2016)) and the e-ICU (Pollard et al. (2018), Section A.2) datasets. We find that certainty rather than uncertainty, quantified by the precision of embedding distributions in a variational embedding layer, generates a natural measure of when to predict. In particular, we find that using the cumulative precision of embedding distributions encourages models to predict frequently when event sampling is dense and/or familiar to the model, and delays prediction when the model is uncertain about recent events—a more realistic approach to prediction timing. We demonstrate that the benefits of this model formulation do not impact negatively on final model performance. Finally, we highlight how adaptive prediction timing evolves over training to better utilise time periods of frequent events and produce correct predictions earlier in the patient stay.

2 Background

Adaptive prediction timing

Recent, highly-cited, outcome prediction models for EHR timeseries (Rajkomar et al. (2018); Tomašev et al. (2019)) have paid little attention to prediction timing, relegating window choice to the supplementary material and omitting the question 'When is a good time to update model predictions?' from the line of enquiry. This approach contradicts evidence that modelling patient outcomes dynamically over time is beneficial (Meiring et al. (2018); Deasy et al. (2019)), and overlooks literature on adaptive computation for RNNs (Graves (2016)). A few efforts to overcome irregular sampling have been made by learning interpolants and decay rates for individual variables across fixed time periods (Che et al. (2018); Shukla and Marlin (2019)), but these models still do not account for varying amounts of information lost at the point of aggregation. The authors of Liu et al. (2019) recently demonstrated that dynamic prediction, with a patient-specific temporal resolution found by classical min-max optimisation, outperforms the previous one-size-fits-all approach despite maintaining fixed windows. In this paper, we go further by arguing that the patient timeseries is, in fact, an information series, and that it is more appropriate to spread events evenly based on model certainty. Our approach generates not only individualised prediction timing, but also event-specific prediction timing, which is crucial in the highly heterogeneous and patient-specific environment of the hospital and the real world.

Embedding precision

In a Bayesian RNN, the variational inference approach to learning the weights $w = \{w_i\}$ of the approximate model $q_\theta(w)$ dictates the use of factorised weight posteriors $q_\theta(w) = \prod_i q_{\theta_i}(w_i)$. When the $w_i$ follow a multivariate Gaussian distribution, with learnable mean vector $\mu_i$ and diagonal covariance matrix $\Sigma_i$, we have

$q_{\theta_i}(w_i) = \frac{\exp\left( -\frac{1}{2} (w_i - \mu_i)^\top \Sigma_i^{-1} (w_i - \mu_i) \right)}{\sqrt{(2\pi)^d \, |\Sigma_i|}},$   (1)

where we drop the index $i$ for ease of notation and note that $\Lambda = \Sigma^{-1}$, the inverse of the covariance matrix $\Sigma$, is the precision matrix of the multivariate normal. Since $\Sigma$ is diagonal with entries $\sigma_1^2, \ldots, \sigma_d^2$, the precision of each embedding distribution is defined by

$\tau = \operatorname{tr}(\Lambda) = \sum_{j=1}^{d} \frac{1}{\sigma_j^2},$   (2)

and is a measure of model certainty (DeGroot (2005)).
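
For concreteness, Equation 2 can be computed directly from the variational parameters. Below is a minimal PyTorch sketch, assuming the embedding layer stores per-dimension log-variances; the parameterisation and function name are illustrative, not the exact implementation.

```python
import torch

def embedding_precision(log_var: torch.Tensor) -> torch.Tensor:
    """Scalar precision (Equation 2) of diagonal Gaussian embeddings.

    log_var: (num_events, embed_dim) per-dimension log-variances.
    Returns a (num_events,) tensor holding the trace of each precision
    matrix, i.e. the sum of inverse variances over embedding dimensions.
    """
    return torch.exp(-log_var).sum(dim=-1)
```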

3 Adaptive prediction timing

Clinical objectives

We analyse the performance of a novel recurrent model on in-hospital mortality and long length of stay (defined as greater than 7 days, Rajkomar et al. (2018)) prediction, both at 48 hours after admission. Dynamic mortality risk estimation helps summarise the patient state, predict patient trajectory, and is the subject of multiple clinical severity scores (Rapsang and Shyam (2014)). Equally, long length of stay estimation enables ward management planning and resource allocation across the hospital. We employ the MIMIC-III database (Johnson et al. (2016)), an EHR dataset collected from 46,520 patients admitted to intensive care units (ICUs) at Beth Israel Deaconess Medical Center. We embed all chart, lab, and output events as described in Deasy et al. (2019), utilise our adaptive prediction timing aggregation step, feed the output to a layer-normalised LSTM (Ba et al. (2016); Hochreiter and Schmidhuber (1997)), and perform an affine transformation before a sigmoid output activation.

Method

For a given patient sequence, comprised of time points $t_1, \ldots, t_n$ and events $e_1, \ldots, e_n$, instead of aggregating events into fixed time intervals, we first sample the corresponding sequence of event embedding distributions to obtain the embedded sequence $x_1, \ldots, x_n$. To separate these samples into intervals, we then use the embedding distributions to generate a corresponding cumulative precision sequence

$T_k = \sum_{i=1}^{k} \tau_i, \quad k = 1, \ldots, n,$   (3)

where $\tau_i$ is the precision (Equation 2) of the embedding distribution of event $e_i$, and separate this sequence into equi-precise aggregation windows which evolve as embedding distributions are refined by training. At no point does our model make use of event timestamps. We implement this in a vectorised manner to handle batches of size greater than one.
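
As a sketch of one way to realise this aggregation for a single sequence (assuming per-event precisions from the embedding layer, a fixed number of windows per stay, and mean-pooling within windows; all three choices are illustrative assumptions rather than the exact implementation):

```python
import torch

def equiprecise_aggregate(x: torch.Tensor, tau: torch.Tensor,
                          num_windows: int) -> torch.Tensor:
    """Pool sampled embeddings into equi-precise windows.

    x:   (seq_len, embed_dim) sampled event embeddings.
    tau: (seq_len,) per-event precisions (Equation 2).
    Returns (num_windows, embed_dim) window representations, where each
    window accumulates roughly the same total precision.
    """
    cum_tau = torch.cumsum(tau, dim=0)                 # Equation 3
    budget = cum_tau[-1] / num_windows                 # precision per window
    idx = torch.clamp((cum_tau / budget).long(), max=num_windows - 1)
    # Mean-pool the sampled embeddings that fall into each window.
    sums = x.new_zeros(num_windows, x.size(-1)).index_add_(0, idx, x)
    counts = tau.new_zeros(num_windows).index_add_(0, idx, torch.ones_like(tau))
    return sums / counts.clamp(min=1).unsqueeze(-1)
```

Under this scheme, dense or familiar stretches of the record (high cumulative precision) fill a window quickly and trigger frequent predictions, while sparse or unfamiliar stretches are pooled over longer spans.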

Our models were trained by minimising the Kullback-Leibler (KL) divergence between the approximate posterior $q_\theta(w)$ and the actual, intractable, posterior $p(w \mid \mathcal{D})$ via the reparameterisation trick. This is equivalent, up to an additive constant, to minimising an expectation over the negative log-likelihood term plus a KL regularisation term:

$\theta^* = \arg\min_\theta \, \mathrm{KL}\left[ q_\theta(w) \,\|\, p(w \mid \mathcal{D}) \right]$   (4)
$\quad\;\; = \arg\min_\theta \, \mathbb{E}_{q_\theta(w)}\left[ -\log p(\mathcal{D} \mid w) \right] + \mathrm{KL}\left[ q_\theta(w) \,\|\, p(w) \right]$   (5)
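
A minimal sketch of this objective for binary outcomes, assuming a mean-field Gaussian posterior over the embedding weights, a zero-mean Gaussian prior whose standard deviation is a hyperparameter (Table A.3), and a single reparameterised sample to approximate the expectation; the function name and signature are illustrative.

```python
import torch
import torch.distributions as dist
import torch.nn.functional as F

def neg_elbo(logits: torch.Tensor, targets: torch.Tensor,
             mu: torch.Tensor, log_var: torch.Tensor,
             prior_std: float = 1.0) -> torch.Tensor:
    """Expected negative log-likelihood plus KL regularisation (Eqs. 4-5)."""
    # NLL of the predictions made with one reparameterised weight sample.
    nll = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum")
    # Closed-form KL between the Gaussian posterior and the Gaussian prior.
    q = dist.Normal(mu, torch.exp(0.5 * log_var))
    p = dist.Normal(torch.zeros_like(mu), prior_std * torch.ones_like(mu))
    kl = dist.kl_divergence(q, p).sum()
    return nll + kl
```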

4 Experiments

Clinical tasks

For our clinical tasks, to assess predictive performance, we measure area under the precision-recall curve (AUPRC), area under the receiver operating characteristic curve (AUROC), and, as there is a strong class imbalance (see Table A.1), the Matthews Correlation Coefficient (MCC). Table 1 shows the mean and standard deviation of the metrics at 48 hours after admission. Throughout, we re-sample the embedding layer of the variational models 100 times, and bootstrap ensembles of 10 deterministic models with 1000 re-samples, to generate error measurements. Despite the underlying change in aggregation mechanism, the final performance of our model on both clinical tasks is strong, in line with the literature (Dusenberry et al. (2019)), and its predictions are well-calibrated (see Figure A.1). We also verify model generalisation for both tasks on the eICU dataset (Pollard et al. (2018)) in Table A.2.

TASK METRIC VALIDATION TEST
Mortality AUPRC 0.576 (0.013) 0.556 (0.011)
AUROC 0.897 (0.003) 0.879 (0.004)
MCC 0.510 (0.011) 0.496 (0.012)
Long length of stay AUPRC 0.614 (0.008) 0.566 (0.009)
AUROC 0.834 (0.004) 0.830 (0.004)
MCC 0.494 (0.009) 0.465 (0.009)
Table 1: Mean (and standard deviation) of metrics for the adaptive prediction timing model on the binary mortality and long length of stay tasks—max MCC over 100 thresholds is reported. Our model displays strong predictive performance, with robust generalisation to the held-out test set.
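
For reference, the metrics can be computed as in the following sketch, which uses scikit-learn's average precision as the AUPRC estimator and takes the maximum MCC over a grid of 100 decision thresholds, as in the caption; the helper name is illustrative.

```python
import numpy as np
from sklearn.metrics import (average_precision_score, matthews_corrcoef,
                             roc_auc_score)

def evaluate(y_true: np.ndarray, y_prob: np.ndarray, n_thresholds: int = 100):
    """Return (AUPRC, AUROC, max MCC over a grid of decision thresholds)."""
    auprc = average_precision_score(y_true, y_prob)
    auroc = roc_auc_score(y_true, y_prob)
    mcc = max(matthews_corrcoef(y_true, (y_prob >= t).astype(int))
              for t in np.linspace(0.0, 1.0, n_thresholds))
    return auprc, auroc, mcc
```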

Model variants

We assess our model against a range of differing models in Table 2. We name these models based on whether their embedding layer is deterministic or Bayesian, and whether their prediction timings are based on fixed timing (no prefix), fixed event count (#), or fixed cumulative precision ($\tau$). We include a model which aggregates by event count to isolate the effect of our model's adaptation, and conclude that adaptation has a marginally negative effect on generalisation, which must be weighed against the advantage of improved temporal resolution during eventful periods.

MODEL VAL. AUPRC VAL. AUROC TEST AUPRC TEST AUROC
Deterministic LSTM 0.592 (0.010) 0.887 (0.008) 0.595 (0.012) 0.889 (0.006)
Deterministic #-LSTM 0.602 (0.012) 0.883 (0.006) 0.578 (0.011) 0.883 (0.005)
Bayesian LSTM 0.574 (0.013) 0.886 (0.004) 0.571 (0.012) 0.883 (0.004)
Bayesian #-LSTM 0.582 (0.011) 0.889 (0.004) 0.573 (0.013) 0.881 (0.004)
Bayesian $\tau$-LSTM 0.576 (0.013) 0.897 (0.003) 0.556 (0.011) 0.879 (0.004)
Table 2: Performance comparison between different model variants for the binary mortality prediction task on the MIMIC-III dataset.

Early predictive power

In Figure 1, we compare mortality risk prediction of the static model with that of our adaptive model for a patient who went on to die in hospital. The more fine-grained prediction timings learnt by the variational model, displayed in Figure 1-left, led to earlier prediction of mortality than the static model in Figure 1-right, due to patient-specific segmentation of the timeseries. As most patients in the ICU have many additional readings taken in the first hours of their stay (e.g. admission information and medical history), which clinicians use to update their assessment of the patient state more frequently, this is a more realistic approach to outcome prediction.

Figure 1: Left–Mortality probability evolving through training for a patient that went on to die. Right–The adaptive timing model predicts mortality earlier than the static window model.

Prediction timing evolution

In Figure 2, we demonstrate the evolution of prediction timing for a different patient. In Figure 2-right, the prediction timing distribution can be seen to focus on a particularly event-dense period for this patient as it learns to be more certain about the embedding distributions of particular clinical events. This suggests our model would adapt well to more extreme shifts in granularity such as stays which include either surgical or emergency interventions.

Figure 2: Example of how the timing of model predictions coalesces through training for a patient with frequent event recordings towards the end of their stay.

Appendix A Appendix

a.1 Additional dataset information

After applying the pipeline described in Deasy et al. (2019), our subset of the MIMIC-III dataset contained 21,143 patient stays including the following demographic and outcome ratios.

                       TRAIN (%)  VALIDATION (%)  TEST (%)
Male                   55.2       53.9            54.9
In-hospital mortality  13.2       13.2            13.2
Long length of stay    21.9       24.2            21.8
Table A.1: Demographic and outcome ratios for each split of our MIMIC-III subset.

a.2 eICU Collaborative Research Database

As a proof of concept that our findings generalise, we also experiment with a small subset of the aperiodic vital sign readings in the eICU Collaborative Research Database (eICU) dataset (Pollard et al. (2018)), another publicly available EHR dataset. Results are displayed in Table A.2.

TASK METRIC VALIDATION TEST
Mortality AUPRC 0.253 (0.012) 0.244 (0.010)
AUROC 0.705 (0.005) 0.701 (0.006)
Long length of stay AUPRC 0.217 (0.010) 0.205 (0.012)
AUROC 0.610 (0.005) 0.600 (0.007)
Table A.2: Performance of the adaptive prediction timing model for the binary mortality and long length of stay tasks on the eICU dataset.

a.3 Model calibration

(a) Deterministic model.
(b) Variational model.
Figure A.1: Model calibration curves.

a.4 Additional training details

As in Deasy et al. (2019), our model uses embeddings of raw, unprocessed medical concepts in order to avoid any unnecessary loss of information. Continuous variables were binned into 10 discrete categories assigned by quantile, and discrete variables were left untouched. Missing events were also embedded as discrete categories so our model makes use of informative missingness (Che et al. (2018)). Following on from the embedding layer, we noted during experimentation that a layer-normalised LSTM variant led to a considerable increase in performance. We initialise our LSTM with Glorot initialisation (Glorot and Bengio (2010)) for the input-to-hidden matrices, orthogonal initialisation for the hidden-to-hidden matrices, and set the forget gate bias to 1 before training (Jozefowicz et al. (2015)).
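
On a standard PyTorch nn.LSTM, this initialisation scheme looks roughly as follows; our layer-normalised variant is initialised analogously, and the helper below is a sketch rather than the original code.

```python
import torch.nn as nn

def init_lstm(lstm: nn.LSTM) -> None:
    """Glorot input-to-hidden, orthogonal hidden-to-hidden, forget bias 1."""
    for name, param in lstm.named_parameters():
        if "weight_ih" in name:
            nn.init.xavier_uniform_(param)       # Glorot initialisation
        elif "weight_hh" in name:
            nn.init.orthogonal_(param)
        elif "bias" in name:
            nn.init.zeros_(param)
            if "bias_ih" in name:                # PyTorch gate order: i, f, g, o
                h = lstm.hidden_size
                param.data[h:2 * h].fill_(1.0)   # forget-gate slice
```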

For each model, we search over the space of hyperparameters defined in Table A.3 using the Python package wandb from Weights and Biases. The search was performed using Bayesian optimisation to minimise validation set loss over both discrete and continuous variables. We use HyperBand early stopping (Li et al. (2017)) to expedite optimisation. We fix batch size at 64 for all models.

Models were implemented in PyTorch 1.4, and trained on an Nvidia Titan X using Adam optimisation (Kingma and Ba (2014)). Both the eICU and MIMIC-III datasets were split into train, validation, and test sets in a ratio of 8:1:1.
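
A sketch of the 8:1:1 split at the level of patient stays (the seed and helper name are illustrative assumptions):

```python
import numpy as np

def split_stays(stay_ids: np.ndarray, seed: int = 0):
    """Shuffle stays and split them 8:1:1 into train/validation/test."""
    ids = np.random.default_rng(seed).permutation(stay_ids)
    n_train, n_val = int(0.8 * len(ids)), int(0.9 * len(ids))
    return ids[:n_train], ids[n_train:n_val], ids[n_val:]
```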

HYPERPARAMETER RANGE
Learning rate [0.00001, 0.1]
$L_2$ regularisation coefficient [0.0, 0.01]
Prior standard deviation (Bayesian only) [0.1, 1.0]
Embedding dimension [16, 32, 48, 64]
LSTM hidden dimension [16, 32, 64, 128, 256, 512]
Table A.3: Hyperparameter ranges and options.
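
For illustration, the search space in Table A.3 maps onto a wandb sweep configuration roughly as follows; the parameter names, project name, metric name, and HyperBand settings shown here are assumptions, not the original configuration.

```python
import wandb

sweep_config = {
    "method": "bayes",  # Bayesian optimisation over the search space
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values",
                          "min": 1e-5, "max": 0.1},
        "l2_coefficient": {"min": 0.0, "max": 0.01},
        "prior_std": {"min": 0.1, "max": 1.0},            # Bayesian models only
        "embedding_dim": {"values": [16, 32, 48, 64]},
        "lstm_hidden_dim": {"values": [16, 32, 64, 128, 256, 512]},
        "batch_size": {"value": 64},
    },
    "early_terminate": {"type": "hyperband", "min_iter": 1},
}
sweep_id = wandb.sweep(sweep_config, project="adaptive-prediction-timing")
```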

References

  1. Ba et al. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
  2. Blundell et al. (2015). Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424.
  3. Che et al. (2018). Recurrent neural networks for multivariate time series with missing values. Scientific Reports 8(1), 6085.
  4. Choi et al. (2016). RETAIN: an interpretable predictive model for healthcare using reverse time attention mechanism. In Advances in Neural Information Processing Systems, pp. 3504-3512.
  5. Deasy et al. (2019). Dynamic survival prediction in intensive care units from heterogeneous time series without the need for variable selection or pre-processing. arXiv preprint arXiv:1909.07214.
  6. DeGroot (2005). Optimal Statistical Decisions. Vol. 82, John Wiley & Sons.
  7. Dusenberry et al. (2019). Analyzing the role of model uncertainty for electronic health records. arXiv preprint arXiv:1906.03842.
  8. Elman (1990). Finding structure in time. Cognitive Science 14(2), pp. 179-211.
  9. Esteva et al. (2019). A guide to deep learning in healthcare. Nature Medicine 25(1), pp. 24.
  10. Glorot and Bengio (2010). Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 249-256.
  11. Goodfellow et al. (2016). Deep Learning. Vol. 1, MIT Press.
  12. Graves (2016). Adaptive computation time for recurrent neural networks. arXiv preprint arXiv:1603.08983.
  13. Hochreiter and Schmidhuber (1997). Long short-term memory. Neural Computation 9(8), pp. 1735-1780.
  14. Johnson et al. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data 3, 160035.
  15. Jozefowicz et al. (2015). An empirical exploration of recurrent network architectures. In International Conference on Machine Learning, pp. 2342-2350.
  16. Kingma and Ba (2014). Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
  17. Kucukelbir et al. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research 18(1), pp. 430-474.
  18. LeCun et al. (2015). Deep learning. Nature 521(7553), pp. 436.
  19. Li et al. (2017). Hyperband: a novel bandit-based approach to hyperparameter optimization. Journal of Machine Learning Research 18(1), pp. 6765-6816.
  20. Liu et al. (2019). Learning hierarchical representations of electronic health records for clinical outcome prediction. arXiv preprint arXiv:1903.08652.
  21. Meiring et al. (2018). Optimal intensive care outcome prediction over time using machine learning. PLoS ONE 13(11), e0206862.
  22. Pollard et al. (2018). The eICU Collaborative Research Database, a freely available multi-center database for critical care research. Scientific Data 5.
  23. Rajkomar et al. (2018). Scalable and accurate deep learning with electronic health records. npj Digital Medicine 1(1), 18.
  24. Rapsang and Shyam (2014). Scoring systems in the intensive care unit: a compendium. Indian Journal of Critical Care Medicine 18(4), 220.
  25. Shickel et al. (2017). Deep EHR: a survey of recent advances in deep learning techniques for electronic health record (EHR) analysis. IEEE Journal of Biomedical and Health Informatics 22(5), pp. 1589-1604.
  26. Shukla and Marlin (2019). Interpolation-prediction networks for irregularly sampled time series. arXiv preprint arXiv:1909.07782.
  27. Tomašev et al. (2019). A clinically applicable approach to continuous prediction of future acute kidney injury. Nature 572(7767), pp. 116-119.