SAVEHR: Self Attention Vector Representations for EHR based Personalized Chronic Disease Onset Prediction and Interpretability
Chronic disease progression is emerging as an important area of investment for healthcare providers. As the quantity and richness of available clinical data continue to increase along with advances in machine learning, there is great potential to advance our approaches to caring for patients. An ideal approach to this problem should perform well on at least three axes, namely it should a) work across many clinical conditions without requiring deep clinical expertise or extensive data-scientist effort, b) generalize across populations, and c) be explainable (model interpretability). We present SAVEHR, a self-attention based architecture on heterogeneous structured EHR data that achieves 0.51 AUC-PR and 0.87 AUC-ROC in predicting the onset of four clinical conditions (CHF, Kidney Failure, Diabetes and COPD) 15 months in advance, and transfers with high performance to a new population. We demonstrate that SAVEHR outperforms ten baselines on all three axes.
Clinicians record structured data, such as diagnosis codes, vital signs and lab test results, and unstructured data, such as clinical notes, in electronic health record (EHR) systems. Accurately predicting the progression of diseases from these EHR data could allow clinicians and patients to make more informed choices, reduce costs, and decrease mortality and morbidity. However, EHR data are heterogeneous, temporally dependent, sparse and incomplete, as well as high dimensional [jensen2012mining, weiskopf2013defining, tran2014framework]. Data-driven approaches for feature selection from EHR have been proposed to address these challenges [huang2014toward, lyalina2013identifying, wang2014unsupervised]. An initial step in modeling a disease trajectory is to predict its onset. A variety of deep learning approaches to predicting disease onset have been explored, including predictions of congestive heart failure mallya2019effectiveness, kidney failure perotte2015risk, dementia de2018unsupervised and delirium wong2018development. While high performance is a necessity, validation on a new population and interpretability are key to adoption in a healthcare system. We address these aspects by proposing SAVEHR, which applies self-attention lin2017structured to structured EHR data to learn pairwise comorbidities, predict disease onset effectively, and generate personalized feature importance visualizations.
Related Work: In recent years, attention mechanisms used in conjunction with RNNs have made substantial gains. Attention allows a network to focus on certain regions of the data while perceiving other regions at “low resolution”, which facilitates the interpretation of the learned representations. Attention is now also being applied to healthcare data (clinical notes and structured EHR). choi2016retain use a two-level neural attention model that processes events in reverse time order to represent the behavior of physicians during an encounter. cheng2016risk represent EHR events as a temporal matrix and use a CNN-based architecture to predict the onset of CHF and COPD; to obtain feature importance, they aggregate the weights of the neurons. kaji2019attention use attention to predict outcomes of ICU events, and chu2018using apply attention to clinical notes to detect adverse medical events. Patient2Vec uses multi-stage attention: self-attention is applied within a sub-sequence of homogeneous features, such as medical codes, followed by the creation of an aggregated deep representation that predicts outcomes and generates a personalized heatmap for a patient that can explain model predictions.
We create one development and one external test cohort for each of the four chronic diseases (CHF, Kidney Failure, Diabetes Type II and COPD) from de-identified, anonymized, structured EHR data drawn from two distinct patient populations, referred to as P1 (the development cohort) and P2 (the test cohort). We use a 12-month observation window spanning 27 to 15 months before the index date, followed by a 15-month prediction window. Because the observation window covers a relatively long period (12 months), we aggregate frequency counts of the diagnosis codes assigned to encounters within 3-month time slices (quarters), similar to Choi2016UsingRN, to facilitate temporal learning. The case-control design, index date selection and time windows are explained in Appendix C. We represent each patient's data with static demographic features (gender, race, age) and a sequence of diagnosis and procedure codes, termed medical concepts. The feature sequence for a patient $p$ is denoted as $X^p = (x^p_1, x^p_2, x^p_3, x^p_4)$, where the subscript $t$ of $x^p_t$ denotes the quarter of interest, so the higher the subscript value, the closer the quarter is to the index date. The number of case and control patients for each disease is presented in Appendix B.
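The quarterly aggregation described above can be sketched as follows. This is a minimal illustration, not the authors' pipeline: it assumes each encounter has already been reduced to a (months before index date, diagnosis code) pair, and the names `quarter_of` and `quarterly_counts` are hypothetical.

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Tuple

# Window boundaries in months before the index date, taken from the text above.
OBS_START = 27   # start of the 12-month observation window
OBS_END = 15     # end of the observation window (start of the 15-month prediction window)
QUARTER = 3      # months per time slice
N_QUARTERS = (OBS_START - OBS_END) // QUARTER  # 4 quarters: t1..t4


def quarter_of(months_before_index: float) -> Optional[int]:
    """Map an encounter to its quarter t in 1..4 (t4 is closest to the index date),
    or None if it falls outside the observation window."""
    if not (OBS_END <= months_before_index < OBS_START):
        return None
    offset = int((months_before_index - OBS_END) // QUARTER)  # 0 = quarter nearest the index date
    return N_QUARTERS - offset


def quarterly_counts(events: Iterable[Tuple[float, str]]) -> Dict[int, Counter]:
    """Aggregate (months_before_index, code) events into per-quarter code frequency counts."""
    counts = {t: Counter() for t in range(1, N_QUARTERS + 1)}
    for months_before, code in events:
        t = quarter_of(months_before)
        if t is not None:  # encounters outside the observation window are dropped
            counts[t][code] += 1
    return counts


# Hypothetical encounters: (months before index date, ICD-9 code)
events = [(16, "428"), (16.5, "428"), (20, "250"), (26, "401"), (5, "428"), (30, "428")]
counts = quarterly_counts(events)
```

In this example the two heart failure codes at 16 and 16.5 months land in t4, while the encounters at 5 and 30 months fall outside the observation window and are ignored.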
3 SAVEHR Model Architecture
In this section, we present the details and architecture (Figure 1) of the proposed self-attention based SAVEHR neural network. The architecture has three main components: 1) a self-attention layer for heterogeneous features, followed by 2) a Bi-GRU layer and 3) an MLP attention mechanism. Below, we describe each component and its contribution to the classification task in detail.
Self-Attention with Heterogeneous features: Self-attention relates elements at different positions of a single sequence by computing the attention between each pair of inputs $x_i$ and $x_j$. Non-categorical information such as age is converted into a categorical feature by binning, while race and gender are integer encoded. The final feature representation for a given time slice is obtained by concatenating the one-hot representations of all the above features into a single vector $x^p_t$, the feature representation for patient $p$ in time slice $t$; we denote its length as $n$. The input is passed through an embedding layer, which learns an embedding for each feature, giving $E = (e_1, e_2, \dots, e_n)$ of shape $n \times d_e$, where $d_e$ is the embedding dimension. $E$ is then fed into the self-attention layer, which computes attention for every feature with respect to the other features in $x^p_t$. The self-attention layer produces a vector of weights $a$:

$$a = \mathrm{softmax}\left(w_{s2} \tanh\left(W_{s1} E^\top\right)\right)$$

where $W_{s1}$ is a weight matrix with a shape of $d_a \times d_e$ and $w_{s2}$ is a vector of parameters with size $d_a$. To capture nuanced interactions, especially for $x^p_t$ with long sequences of features, we perform multiple hops of attention. For example, to extract $r$ different parts from $x^p_t$, we extend $w_{s2}$ into an $r \times d_a$ matrix, denoted $W_{s2}$, and the annotation vector $a$ becomes the annotation matrix $A$. This is formally represented in Equation 1, where the softmax function is applied along the second dimension of its input:

$$A = \mathrm{softmax}\left(W_{s2} \tanh\left(W_{s1} E^\top\right)\right) \qquad (1)$$

We compute the weighted sums by multiplying the annotation matrix $A$ and the embedding output matrix $E$, resulting in $M = AE$ of shape $r \times d_e$.
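Equation 1 and the product $M = AE$ can be checked numerically with a small NumPy sketch. It is an illustration under assumptions: the weights are random rather than learned, the sizes $n$, $d_e$, $d_a$ and $r$ are arbitrary, and `structured_self_attention` is a hypothetical helper name.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def structured_self_attention(E, W_s1, W_s2):
    """Multi-hop self-attention over the n feature embeddings of one quarter.

    E:    (n, d_e)  embedding output matrix
    W_s1: (d_a, d_e)
    W_s2: (r, d_a)  one row per attention hop
    Returns the annotation matrix A (r, n) and M = A E (r, d_e).
    """
    scores = W_s2 @ np.tanh(W_s1 @ E.T)  # (r, n)
    A = softmax(scores, axis=1)          # normalize over the n features (second dimension)
    M = A @ E                            # r weighted sums of the feature embeddings
    return A, M


rng = np.random.default_rng(0)
n, d_e, d_a, r = 6, 8, 5, 3  # hypothetical sizes
E = rng.normal(size=(n, d_e))
W_s1 = rng.normal(size=(d_a, d_e))
W_s2 = rng.normal(size=(r, d_a))
A, M = structured_self_attention(E, W_s1, W_s2)
```

With $r = 1$, `W_s2` is a single row and `A` reduces to the weight vector $a$ from the first formula. Each row of `A` sums to one, so each row of `M` is a convex combination of the feature embeddings.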
To capture longitudinal dependencies and the importance of each time slice for a given patient, we feed the sequence of encoded quarterly representations from the self-attention layer into a bidirectional GRU-based RNN, whose outputs are aggregated with MLP attention; see Appendix E for details.
Baselines: We use ten baselines, categorized into common baselines (Logistic Regression, Random Forest, Multi-Layer Perceptron), deep learning baselines (1D-CNN and Bi-directional GRU based) and attention based models. We evaluated a wide variety of baselines to understand the trade-off between performance and model complexity. The baselines are described in depth in Appendix F.
We perform an extensive evaluation of 11 onset prediction models (SAVEHR and the ten baselines) on four clinical conditions (CHF, Kidney Failure, Diabetes Type II and COPD) along three axes (performance, generalization and interpretability). Given the class imbalance in the data, we consider the Area under the Precision-Recall Curve (AUC-PR) as the primary performance metric [saito2015precision, davis2006relationship]; it is reported in Table 1. Standard deviations from three-fold cross-validation are reported in Appendix I.
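The following sketch shows one common way to compute AUC-PR: the step-wise average precision, $\sum_n (R_n - R_{n-1}) P_n$, which is the same definition used by scikit-learn's `average_precision_score`. The text does not say which estimator was used, so treat this as an assumption. The helper name `average_precision` is hypothetical.

```python
import numpy as np


def average_precision(y_true, scores) -> float:
    """Step-wise area under the precision-recall curve: sum_n (R_n - R_{n-1}) * P_n,
    evaluated at each distinct score threshold, highest first.
    Assumes y_true contains at least one positive (case) label."""
    y = np.asarray(y_true, dtype=float)
    s = np.asarray(scores, dtype=float)
    order = np.argsort(-s, kind="mergesort")  # sort by descending score
    y, s = y[order], s[order]
    # Last position of each run of tied scores = one threshold per distinct score.
    idx = np.r_[np.flatnonzero(np.diff(s)), y.size - 1]
    tp = np.cumsum(y)[idx]            # true positives at each threshold
    precision = tp / (idx + 1)        # predicted positives = idx + 1
    recall = tp / y.sum()
    return float(np.sum(np.diff(np.r_[0.0, recall]) * precision))


ap = average_precision([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1])  # 0.5 * 1 + 0.5 * (2/3)
```

For a classifier that ranks patients at random, the expected AUC-PR is roughly the case prevalence. For example, it is 3916 / (3916 + 41851) ≈ 0.086 for the CHF internal test set in Appendix B, whereas the AUC-ROC of such a classifier is 0.5. This is why AUC-PR is the more informative metric under this imbalance.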
Experiment i) Across clinical conditions: For each of the four conditions above, we created training, validation and internal test (P1) sets and an external test (P2) set, and we use AUC-PR as the primary metric for evaluating performance.
Experiment ii) Generalization across populations: Many studies show that hospital systems vary in how diagnosis codes are assigned for clinical visits [burns2011systematic, quach2010administrative, jolley2015validity, vlasschaert2011validity], so it is essential to evaluate a model on different populations justice1999assessing. Several studies show that characterizing performance on a single population can be insufficient [collins2014external, bleeker2003external, konig2007practical, ivanescu2016importance]. To evaluate generalization, we take the same trained model that was evaluated on the internal test set (P1) and evaluate it on the corresponding condition's cohort in the external population (P2). We report AUC-PR in Table 1 under section P2.
Experiment iii) Interpretability: Non-linearity helps deep learning models achieve better performance than linear methods, but it may make the model opaque to humans. For clinicians to trust a model's prediction, we believe the model should provide insight into why it produced the result it did. We evaluate the interpretability of the models by generating both population-level (Appendix K) and per-patient feature importance visualizations for SAVEHR (Figure 2).
5 Results and Discussion
The SAVEHR model outperformed all baseline models on the AUC-PR metric across all four conditions on the internal test set P1 (except Diabetes) and on the external test set P2. On the external test set P2, SAVEHR's gains over the next best performing model ranged from 7% to 46%, as shown in Table 1.
External Test Cohort: A major strength of our work is that we validated the model's performance on a formal external test cohort that is as large as the development cohorts of most studies. Importantly, performance as measured by AUC-PR was higher in the external test cohort (except for Diabetes), providing evidence that our architecture may generate models that generalize across cohorts. Although evaluating performance on a new population is an important criterion, the vast majority of published studies do not evaluate how their models transfer to one.
Interpretability: Predictive models are not, in general, intended to be explanatory, yet clinicians certainly desire an explanation of a model's prediction, particularly when that prediction is inconsistent with the clinician's intuition. A powerful characteristic of the SAVEHR architecture is that it allows us to assign importance (or risk) scores to features and combinations of features (Figure 2). While this is not a full explanation, our findings suggest that it may be possible to give the clinician a summary and visualization indicating the underlying reasoning behind the model's prediction for an individual patient. In addition, by exploring importance scores across populations of patients, such as those in a certain age category or with specific risk predictions, the clinician may gain insight into which features contribute to risk in that category of patients.
We examine the feature importance for two patients with similar characteristics (demographics and clinical encounters): Patient A, with elevated Congestive Heart Failure (CHF) risk (57%), and Control A, with a risk of 13%. We graphically illustrate the importance of pairwise feature interactions, with colors from deep blue to deep red indicating increasing importance. The features listed on the x-axis and the y-axis are the same in each panel. The x-axis shows the ICD-9/10 codes, and the y-axis shows descriptive labels for the codes. Because neither feature in a pair takes precedence over the other, the two mutual interaction scores of each pair are averaged. We observe that the patient identified as high risk has more high-importance interactions than the patient identified as low risk. These interactions are multiple and diffuse: in patients with elevated risk, not just one or two but many interactions have high importance. Many of these high-importance interactions, though certainly not all, make clinical sense. A similar visualization for one of the best performing attention based baselines is provided in Appendix J (Figure 10).
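The averaging of mutual interactions can be sketched as a symmetrization of a pairwise score matrix. The paper does not specify how the directional pairwise scores are obtained from the model, so the input matrix `P` below is a hypothetical placeholder. The helpers `symmetric_pair_scores` and `top_pairs` are also illustrative names, not the authors' code.

```python
import numpy as np


def symmetric_pair_scores(P):
    """Average the two directional scores of every feature pair: S = (P + P^T) / 2."""
    P = np.asarray(P, dtype=float)
    return (P + P.T) / 2.0


def top_pairs(S, labels, k=3):
    """Return the k highest-scoring unordered pairs (i < j) with their averaged scores."""
    i, j = np.triu_indices(S.shape[0], k=1)  # upper triangle: each unordered pair once
    best = np.argsort(-S[i, j], kind="mergesort")[:k]
    return [(labels[i[b]], labels[j[b]], float(S[i[b], j[b]])) for b in best]


# Hypothetical directional scores P[i, j] for three ICD-9 codes
P = np.array([[0.0, 0.6, 0.1],
              [0.2, 0.0, 0.3],
              [0.5, 0.1, 0.0]])
S = symmetric_pair_scores(P)
pairs = top_pairs(S, ["428", "401", "250"])
```

The symmetric matrix `S` is what a heatmap with identical features on both axes would display. The pair ranking is one simple way to summarize the most important interactions for a patient.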
We present a new self-attention based deep neural network architecture that extracts interpretable and actionable information from heterogeneous, sparse time-series data in electronic health records. We report a range of performance metrics to compare our model comprehensively with the current state of the art. Our model yields state-of-the-art results across four different clinical conditions on an external cohort of thousands of patients monitored for a year or longer. Finally, we provide samples from anonymized patients to illustrate the interpretability of the prediction scores and to demonstrate how clinicians can incorporate our risk scores into clinical workflows. We believe that the relative importance of these features and feature interactions, combined with an appropriate visualization, can improve clinicians' confidence in model predictions. Clinicians could use these predictions to target and modulate clinical interventions with greater precision.
Appendix A Population Statistics in P1 and P2
We describe population statistics such as gender ratio and average age for all four clinical conditions in Figure 3.
Appendix B Disease Cohorts
We create cohorts (training, validation and test) from population P1 and use P2 entirely as an external test set for four chronic diseases: Congestive Heart Failure (CHF), Kidney Failure, Diabetes Type II and Chronic Obstructive Pulmonary Disease (COPD).
|Disease (case : control)|Training (P1)|Validation (P1)|Test (P1)|External Test (P2)|
|---|---|---|---|---|
|CHF|14343 : 159567|793 : 8361|3916 : 41851|1259 : 5890|
|Kidney Failure|8085 : 66045|447 : 3455|2216 : 17292|757 : 9351|
|Diabetes Type II|7674 : 53308|429 : 2781|2088 : 13961|3422 : 6997|
|COPD|11301 : 104719|641 : 5466|3107 : 27425|1767 : 5000|
Appendix C Index Date
The case-control design within the cohorts for each disease was applied to patients who received care between 2015 and 2018. Incident cases for each condition were defined as patients aged 30 to 80 years for whom an ICD-9 or ICD-10 code representing the condition was recorded as an encounter diagnosis at least three times in a six-month period, with no prior diagnosis of the condition. We defined the index date as the date of the first of the three qualifying encounters. We did not use any data from the 3 months prior to the index date (buffer period). This avoids incorporating diagnostic data that had been obtained but had not yet resulted in a recorded diagnosis. Control patients were selected as those who had at least 5 encounters in a 2-year period but never had a diagnostic code for the modeled condition recorded. For control patients, the index date was the last encounter recorded in the system. Figure 4 illustrates our use of a 12-month observation window spanning 27 to 15 months before the index date and a 15-month prediction window. Because the observation window covers a relatively long period (12 months), we aggregate frequency counts of the diagnosis codes assigned to encounters within each time window, similar to Choi2016UsingRN, to facilitate temporal learning. Codes with fewer than 50 occurrences in the cohort were filtered out.
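The case and control definitions above can be expressed as two small selection rules. This is a sketch under stated assumptions, not the authors' cohort code. "Six months" and "two years" are approximated as 183 and 730 days. The "no prior diagnosis" condition is read as requiring that the first three recorded diagnoses fall within six months, so the first qualifying encounter is also the first-ever diagnosis. All function names are hypothetical.

```python
from datetime import date, timedelta
from typing import Optional, Sequence

# Assumed day-count approximations of the calendar periods in the text.
SIX_MONTHS = timedelta(days=183)
TWO_YEARS = timedelta(days=730)


def age_at(birth: date, on: date) -> int:
    """Age in completed years on a given date."""
    return on.year - birth.year - ((on.month, on.day) < (birth.month, birth.day))


def case_index_date(birth: date, dx_dates: Sequence[date]) -> Optional[date]:
    """Index date for an incident case, or None if the patient does not qualify.

    Assumed interpretation: the first three recorded diagnoses of the condition fall
    within six months, and the patient is aged 30-80 at the index date.
    """
    d = sorted(dx_dates)
    if len(d) < 3 or d[2] - d[0] > SIX_MONTHS:
        return None
    index = d[0]  # first of the three qualifying encounters
    return index if 30 <= age_at(birth, index) <= 80 else None


def control_index_date(encounters: Sequence[date], dx_dates: Sequence[date]) -> Optional[date]:
    """Index date for a control: at least 5 encounters within some 2-year span and no
    diagnosis of the condition; the index date is the last recorded encounter."""
    if dx_dates:
        return None
    d = sorted(encounters)
    if any(d[i + 4] - d[i] <= TWO_YEARS for i in range(len(d) - 4)):
        return d[-1]
    return None
```

The 3-month buffer and the observation window are then measured backwards from whichever index date these rules return.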
Appendix D End-to-End Data and Modeling Pipeline
HealtheDataLab is a big data processing platform built on Amazon EMR. Population health data with longitudinal patient records are ingested from Amazon S3. The end-to-end flow is as follows. An AWS Data Pipeline job orchestrates the transformation of the data and the launch of an Amazon EMR cluster, and it creates a data catalog along with a Hive metastore in Amazon RDS and AWS Glue. HealtheDataLab provides a Jupyter notebook running on an EC2 instance that connects to a Spark pipeline on Amazon EMR. HealtheDataLab has custom packages, such as ontologies, FHIR support and concept mapping, that let data scientists create patient cohorts in a simplified manner. Once cohorts are created, they are stored in S3 as compressed NumPy arrays. An Amazon SageMaker machine learning job is then launched with the cohort location in S3 and the hyper-parameters to be optimized. After the job completes, the best hyper-parameters are recorded and the job ID is noted so that evaluation can be run on secondary populations.
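Storing a cohort as compressed NumPy arrays can be done with `np.savez_compressed` and read back with `np.load`, as sketched below. The array names (`X`, `y`, `patient_ids`), their shapes and the helper functions are hypothetical. An in-memory buffer stands in for the file that would then be uploaded to S3; the upload step is omitted.

```python
import io

import numpy as np


def save_cohort(fileobj, X, y, patient_ids):
    """Write a cohort as one compressed .npz archive (file path or file-like object)."""
    np.savez_compressed(fileobj, X=X, y=y, patient_ids=patient_ids)


def load_cohort(fileobj):
    """Read the arrays back; NpzFile supports the context-manager protocol."""
    with np.load(fileobj) as data:
        return data["X"], data["y"], data["patient_ids"]


# Hypothetical cohort: 3 patients x 4 quarters x 5 features, with binary case labels
X = np.arange(60, dtype=np.float32).reshape(3, 4, 5)
y = np.array([1, 0, 0])
ids = np.array(["p1", "p2", "p3"])

buf = io.BytesIO()  # stands in for a local file destined for S3
save_cohort(buf, X, y, ids)
buf.seek(0)         # rewind before reading
X2, y2, ids2 = load_cohort(buf)
```

Because the ids are stored as a fixed-width Unicode array rather than Python objects, the archive loads with NumPy's default `allow_pickle=False`.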
Appendix E Aggregated Deep Representation with MLP Attention across quarters
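To capture the longitudinal dependencies and understand the importance of each quarter for a given patient, we feed the sequence of encoded quarterly representations $m_1, \dots, m_T$ into a bidirectional GRU-based RNN, where $m_t$ is the self-attention output $M$ for quarter $t$ and $T = 4$. The BiGRU is presented in Equation 2:

$$\overrightarrow{h}_t = \overrightarrow{\mathrm{GRU}}\left(m_t, \overrightarrow{h}_{t-1}\right), \quad \overleftarrow{h}_t = \overleftarrow{\mathrm{GRU}}\left(m_t, \overleftarrow{h}_{t+1}\right), \quad h_t = \left[\overrightarrow{h}_t ; \overleftarrow{h}_t\right] \qquad (2)$$

where $h_t$ represents the output of the GRU for quarter $t$. We use MLP attention (multi-layer perceptron attention) on top of the BiGRU layer to obtain a weighted representation of each quarter. The attention weight $\alpha_t$ of GRU output $h_t$ is calculated as a normalized score:

$$e_t = v_a^\top \tanh\left(W_a h_t + b_a\right), \qquad \alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}$$

where $h_t$ are the hidden state vectors from the BiGRU cell, $W_a$ and $b_a$ are the parameters of the attention network, and $v_a$ is the attention vector. Once we obtain the attention weights, the aggregated vector representation for the patient across quarters is computed by:

$$c = \sum_{t=1}^{T} \alpha_t h_t$$

Given the aggregated representation $c$, a softmax layer produces the final outcome prediction:

$$\hat{y} = \mathrm{softmax}\left(W_o c + b_o\right)$$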
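The attention pooling and output layer above can be sketched in NumPy. To keep the example short, random values stand in for the BiGRU outputs $h_t$ instead of running a GRU, and all weights and sizes are hypothetical. The function `mlp_attention_pool` follows the equations for $e_t$, $\alpha_t$ and $c$ term by term.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mlp_attention_pool(H, W_a, b_a, v_a):
    """H: (T, 2u) BiGRU outputs h_t; W_a: (d_att, 2u); b_a, v_a: (d_att,).
    e_t = v_a^T tanh(W_a h_t + b_a); alpha = softmax over t; c = sum_t alpha_t h_t."""
    e = np.tanh(H @ W_a.T + b_a) @ v_a  # (T,) unnormalized scores
    alpha = softmax(e)                  # one attention weight per quarter
    c = alpha @ H                       # (2u,) aggregated patient representation
    return alpha, c


rng = np.random.default_rng(3)
T, two_u, d_att, n_classes = 4, 6, 5, 2  # hypothetical sizes
H = rng.normal(size=(T, two_u))          # stand-in for BiGRU outputs
alpha, c = mlp_attention_pool(H, rng.normal(size=(d_att, two_u)),
                              rng.normal(size=d_att), rng.normal(size=d_att))
W_o, b_o = rng.normal(size=(n_classes, two_u)), np.zeros(n_classes)
y_hat = softmax(W_o @ c + b_o)           # class probabilities (control, case)
```

The vector `alpha` is the per-quarter weight that Appendix H averages over the test set.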
Appendix F Baselines
Logistic Regression, Random Forest and MLP: We trained three commonly used baselines: logistic regression (LR), random forest (RF) and a multi-layer perceptron (MLP) with dropout. For logistic regression and random forest, we use the implementations provided by scikit-learn [scikit-learn], with no regularization.
Deep Learning Baselines (1D-CNN + BiGRU and BiGRU): Inspired by the success of 1D-CNN and RNN based architectures on clinical notes liu2018deep, we extend them to structured EHR data. The diagnosis codes for conditions and procedures, which we collectively term medical concepts, can be considered analogous to words in sentences. We use an embedding layer to encode the features into a continuous space. We take the frequency of each code assigned in a time slice and concatenate it to the corresponding medical concept embedding, an approach shown to be effective in mallya2019effectiveness. We first feed the embedded data into a 1D-CNN. Since medical concepts are inherently unordered, the 1D-CNN uses the equivalent of a 1-gram on each time slice, i.e., a kernel size of 1 across feature embeddings. To exploit the longitudinal nature of EHR data, we feed the time-slice representations aggregated by the 1D-CNN into a bidirectional GRU (Bi-GRU) layer; we name this model CNN-1G. To measure the incremental value of the 1D-CNN filters, we create another baseline (BG) that uses only a Bi-GRU layer on top of the embedding layer. Demographic information may be static, but it is critical to clinical decisions. We therefore incorporate these features by concatenating them to the Bi-GRU layer output of both models described above. We also experiment with a BiGRU in place of the 1D-CNN and report its performance.
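A 1D convolution with kernel size 1 applies the same linear map to every code independently, so it does not depend on the order in which codes are listed. The sketch below illustrates this under assumptions that the text does not state: a ReLU activation and max-pooling over codes to form the per-quarter vector. The sizes, weights and helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
VOCAB, D_EMB, N_FILTERS = 10, 4, 3            # hypothetical sizes
emb_table = rng.normal(size=(VOCAB, D_EMB))   # one embedding per medical concept
W = rng.normal(size=(D_EMB + 1, N_FILTERS))   # kernel-size-1 filters (+1 input for the frequency)
b = np.zeros(N_FILTERS)


def concept_inputs(code_ids, freqs):
    """Embedding of each code in a quarter with its frequency appended: (n_codes, D_EMB + 1)."""
    freqs = np.asarray(freqs, dtype=float)[:, None]
    return np.concatenate([emb_table[np.asarray(code_ids)], freqs], axis=1)


def conv1d_kernel1(X):
    """Kernel size 1: the same (D_EMB + 1) -> N_FILTERS map applied to every code."""
    return np.maximum(X @ W + b, 0.0)  # ReLU activation (assumed)


def quarter_vector(code_ids, freqs):
    """Pool per-code filter outputs into one vector per quarter (max-pooling assumed)."""
    return conv1d_kernel1(concept_inputs(code_ids, freqs)).max(axis=0)


v = quarter_vector([2, 5, 7], [1, 3, 2])
v_shuffled = quarter_vector([7, 2, 5], [2, 1, 3])  # same codes and frequencies, reordered
```

Listing the same codes in a different order gives the same quarter vector. This is the property that motivates the 1-gram kernel for unordered medical concepts, and it is also why such a kernel cannot model interactions between codes.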
Attention based Deep Learning Baselines (CNN-LargeKernel + BiGRU + Attention): To understand whether attention helps in EHR based modeling, we add MLP attention chorowski2015attention to the BiGRU layer of the baselines described above (CNN-1G and BG), enabling them to focus on the most important time window. The 1D-CNN approach above does not capture interactions among features effectively because of its kernel size of 1. Ideally, we would compute the collective and relative importance of any 2-, 3- or n-grams of features. However, the ordering of medical concepts within a given time window does not matter, whereas any kernel beyond a 1-gram would require an ordering. To handle this and avoid massive n-gram computation, we propose a novel baseline named CNN Large Kernel (CNN-LK). In CNN-LK, the 1D-CNN kernel size is set equal to the number of input features, which essentially yields a weighted combination of all the input features. To understand whether the 1D-CNN adds value, we also use another baseline in which the large kernel layer is replaced with a Dense layer of the same size. We note that none of the aforementioned architectures allows us to determine the pairwise importance between any two features.
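The sketch below shows why a kernel as large as the number of input features yields a single weighted combination of all features. Assuming "valid" (unpadded) convolution, which the text does not state, the convolution produces exactly one output position. That output equals a dense layer applied to the flattened input. The sizes and weights are hypothetical.

```python
import numpy as np


def conv1d_valid(X, K, b):
    """Valid (unpadded) 1D convolution along the feature axis.
    X: (n, d_in) inputs; K: (k, d_in, f) kernel; b: (f,) bias. Returns (n - k + 1, f)."""
    n, k = X.shape[0], K.shape[0]
    return np.stack([np.einsum("kd,kdf->f", X[i:i + k], K) + b for i in range(n - k + 1)])


rng = np.random.default_rng(2)
n, d_in, f = 5, 4, 3               # hypothetical: n features, each a d_in-dim embedding
X = rng.normal(size=(n, d_in))
K = rng.normal(size=(n, d_in, f))  # kernel size equals the number of input features
b = rng.normal(size=f)

out_conv = conv1d_valid(X, K, b)                        # (1, f): a single output position
out_dense = X.reshape(-1) @ K.reshape(n * d_in, f) + b  # dense layer on the flattened input
```

Under this assumption, the CNN-LK layer and its Dense-layer counterpart compute the same family of functions, so any difference between the two baselines would come from details such as padding or parameter initialization. Unlike the kernel-size-1 case, the result depends on the fixed slot order of the features.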
Appendix G AUC-ROC results for all conditions across populations
Appendix H Distribution of MLP attention weights across quarters
To understand the importance of each time slice, we compute the average attention per time slice across the entire test set P1 and report it in the table below. Quarter t4, the closest to the index date, is the most prominent quarter.
|Average attention per time slice|t1|t2|t3|t4|
|---|---|---|---|---|
H.1 MLP attention weights vs. diagnosis counts
To assess how attention relates to the number of diagnoses in a given time slice, we plot the average and standard deviation of the attention weights against the diagnosis counts. We observe that the model pays very low to zero attention to quarters without any diagnosis code. Attention increases with the count, but the large error bars in both Figure 6(a) and Figure 6(b) suggest that the model does not always attend most to the time slice with the most counts.
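Both statistics in this appendix, the mean attention per quarter and the attention grouped by diagnosis count, can be computed from a matrix of per-patient attention weights. In the sketch below, `alpha` and `counts` are small hypothetical arrays, not results from the study. The helper names are illustrative.

```python
import numpy as np


def mean_attention_per_quarter(alpha):
    """alpha: (n_patients, n_quarters) attention weights, each row summing to 1."""
    return np.asarray(alpha, dtype=float).mean(axis=0)


def attention_by_count(alpha, counts):
    """Group every (patient, quarter) attention weight by that quarter's diagnosis count
    and return {count: (mean, std)}, the quantities behind an error-bar plot."""
    a = np.asarray(alpha, dtype=float).ravel()
    c = np.asarray(counts).ravel()
    return {int(k): (float(a[c == k].mean()), float(a[c == k].std())) for k in np.unique(c)}


# Hypothetical attention weights and diagnosis counts for 2 patients x 4 quarters
alpha = np.array([[0.0, 0.2, 0.3, 0.5],
                  [0.1, 0.1, 0.2, 0.6]])
counts = np.array([[0, 2, 2, 5],
                   [1, 1, 2, 5]])
per_quarter = mean_attention_per_quarter(alpha)
by_count = attention_by_count(alpha, counts)
```

Here `np.std` uses the population standard deviation (`ddof=0`). The text does not say which estimator the error bars use.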
Appendix I Performance metric graphs with error bars
Appendix J Example patient heatmaps for CNN-LK-A model
In Section 5 of the paper, we provide an example visualization for the SAVEHR model. For contrast, below we provide heatmaps from the CNN-LK-A model for the same case and control patients.
Appendix K SAVEHR Case population heatmaps in P1 and P2 for CHF
To understand the features that induce risk across the population as a whole, we generate heatmaps averaged across all case patients in P1 (Figure 11) and P2 (Figure 12). Notably, the top features differ between the two populations, suggesting that the model is able to learn different characteristics and adapt to each population.
Appendix L Feature importance tables for baselines
We report feature importance determined by averaging the importance scores produced by each model across the patients predicted to be cases.
Logistic Regression
|LR coefficient|Diagnosis Code|Description|
|---|---|---|
|1.207|G0378|Hospital Observation Service|
|1.096|735|Acquired hammer toe|
|0.926|v54|Aftercare fracture arm|
|0.901|816|Closed fracture of middle phalanx of second finger of right hand|
|0.794|191|Malignant Neoplasm of Brain|
|0.788|041|Mycoplasma infection in conditions classified elsewhere|
|0.783|432|Chronic spont intraparenchymal hemorrhage|
Random Forest
|Feature importance|Diagnosis Code|Description|
|---|---|---|
|0.006|v58|Encounter for other and unspecified procedures and aftercare|
|0.004|v57|Care involving use of rehabilitation procedures|
|0.004|v70|General psychiatric examination|
|0.004|786|Chest wall pain|
CNN 1 gram
|Importance score|Diagnosis Code|Description|
|---|---|---|
|0.041183|v45|Status post lumbar surgery|
|0.03769|793|Abnormal Findings X-Ray Breast|
|0.02934|v57|Care involving use of rehabilitation procedures|
|0.020804|585|chronic renal failure|
|0.018625|562|Small bowel diverticular disease|
|0.016736|v10|Personal History of Malignant Neoplasm of Eye|
|0.016391|455|External hemorrhoids with complication|
CNN LargeKernel
|Importance score|Diagnosis Code|Description|
|---|---|---|
|0.05082|569|Colostomy and enterostomy complications|
|0.48861|M06|Rheumatoid arthritis with negative rheumatoid factor (HCC)|
|0.38591|333|degenerative diseases of the basal ganglia|
|0.31523|250|Diabetes mellitus TypeII|
|0.31516|I62|Nontraumatic subdural hemorrhage|
|0.31463|182|Malignant Neoplasm of body of uterus|
|0.29420|H25|Senile cataract of right eye|