Graph Convolutional Transformer: Learning the Graphical Structure of Electronic Health Records


Edward Choi, Zhen Xu, Yujia Li, Michael W. Dusenberry,
Gerardo Flores, Yuan Xue, Andrew M. Dai
Google, USA   DeepMind, UK
Abstract

Effective modeling of electronic health records (EHR) is rapidly becoming an important topic in both academia and industry. A recent study showed that utilizing the graphical structure underlying EHR data (e.g. relationship between diagnoses and treatments) improves the performance of prediction tasks such as heart failure diagnosis prediction. However, EHR data do not always contain complete structure information. Moreover, when it comes to claims data, structure information is completely unavailable to begin with. Under such circumstances, can we still do better than just treating EHR data as a flat-structured bag-of-features? In this paper, we study the possibility of utilizing the implicit structure of EHR by using the Transformer for prediction tasks on EHR data. Specifically, we argue that the Transformer is a suitable model to learn the hidden EHR structure, and propose the Graph Convolutional Transformer, which uses data statistics to guide the structure learning process. Our model empirically demonstrated superior prediction performance to previous approaches on both synthetic data and publicly available EHR data on encounter-based prediction tasks such as graph reconstruction and readmission prediction, indicating that it can serve as an effective general-purpose representation learning algorithm for EHR data.

 


Preprint. Under review.

1 Introduction

Large volumes of medical records collected by electronic health record (EHR) systems in healthcare organizations have enabled deep learning methods to show impressive performance in diverse tasks such as predicting diagnosis (Lipton et al., 2015; Choi et al., 2016a; Rajkomar et al., 2018), learning medical concept representations (Che et al., 2015; Choi et al., 2016b, c; Miotto et al., 2016), and making interpretable predictions (Choi et al., 2016d; Ma et al., 2017). As diverse as they are, one thing shared by all tasks is the fact that, under the hood, some form of neural network is processing EHR data to learn useful patterns from them. To successfully perform any EHR-related task, it is essential to learn effective representations of various EHR features: diagnosis codes, lab values, encounters, and even patients themselves.

Figure 1: The graphical structure of electronic health records. A single visit consists of multiple types of features, and their connections (red edges) reflect the physician’s decision process.

EHR data are typically stored in a relational database that can be represented as a hierarchical graph depicted in Figure 1. The common approach for processing EHR data with neural networks has been to treat each encounter as an unordered set of features, or in other words, a bag of features. However, the bag of features approach completely disregards the graphical structure that reflects the physician’s decision process. For example, if we treat the encounter in Figure 1 as a bag of features, we will lose the information that Benzonatate was ordered because of Cough, not because of Abdominal pain.

Recently, motivated by this EHR structure, Choi et al. (2018) proposed MiME, a model architecture that reflects EHR's encounter structure, specifically the relationships between a diagnosis and its treatments. MiME outperformed various bag-of-features approaches in prediction tasks such as heart failure diagnosis prediction. Their study, however, naturally raises a question: when the EHR data do not contain structure information (the red edges in Figure 1), can we still do better than bag of features in learning the representation of the data for various prediction tasks? This question arises on many occasions, since EHR data do not always contain the entire structure information. For example, some dataset might describe which treatment led to measuring certain lab values, but might not describe which diagnosis was the reason for ordering that treatment. Moreover, when it comes to claims data, such structure information is completely unavailable to begin with.

To address this question, we study the possibility of using the Transformer (Vaswani et al., 2017) to utilize the unknown encounter structure for various prediction tasks when the structure information is unavailable. Specifically, we describe the graphical nature of encounter records, and argue that the Transformer is a reasonable model to discover implicit encounter structure. Then we propose the Graph Convolutional Transformer (GCT) to more effectively utilize the characteristics of EHR data while performing diverse prediction tasks. We test the Transformer and GCT on both synthetic data and real-world EHR records for encounter-based prediction tasks such as graph reconstruction and readmission prediction. In all tasks, GCT consistently outperformed baseline models, showing its potential to serve as an effective general-purpose representation learning algorithm for EHR data.

2 Related Work

Although there are recent works on medical concept embedding, focusing on patients Che et al. (2015); Miotto et al. (2016); Suresh et al. (2017); Nguyen et al. (2018), visits Choi et al. (2016c), or codes Tran et al. (2015); Choi et al. (2017), the graphical nature of EHR has not been fully explored yet. Choi et al. (2018) proposed MiME, which derives the visit representation in a bottom-up fashion according to the encounter structure. For example in Figure 1, MiME first combines the embedding vectors of lab results with the Cardiac EKG embedding, which in turn is combined with both the Abdominal Pain embedding and the Chest Pain embedding. Then all diagnosis embeddings are pooled together to derive the final visit embedding. By outperforming various bag-of-features models in heart failure prediction and general disease prediction, MiME demonstrated the usefulness of the structure information of encounter records.

The Transformer Vaswani et al. (2017) was proposed for natural language processing, specifically machine translation. It uses a novel method to process sequence data using only attention Bahdanau et al. (2014), and has recently shown impressive performance in other tasks such as word representation learning Devlin et al. (2018). Graph (convolutional) networks encompass various neural network methods for handling graphs such as molecular structures, social networks, or physical experiments Kipf and Welling (2016); Hamilton et al. (2017); Battaglia et al. (2018); Xu et al. (2019). In essence, many graph networks can be described as different ways to aggregate a given node's neighbor information, combine it with the given node, and derive the node's latent representation Xu et al. (2019).

Some recent works focused on the connection between the Transformer’s self-attention and graph networks Battaglia et al. (2018). Graph Attention Networks Veličković et al. (2018) applied self-attention on top of the adjacency matrix to learn non-static edge weights, and Wang et al. (2018) used self-attention to capture non-local dependencies in images. Although our work also relies on self-attention, our interest lies in whether the Transformer can be an effective tool to capture the underlying graphical structure of EHR data even when the structure information is missing, thus improving encounter-based prediction tasks. In the next section, we first describe the graphical nature of EHR encounter data, then show that the Transformer is a reasonable algorithm for learning the hidden graphical structure of encounter records.

3 Method

3.1 Electronic Health Records as a Graph

As depicted in Figure 1, the t-th visit V(t) starts with the visit node v(t) at the top. Beneath the visit node are diagnosis nodes d_1(t), ..., d_{|d(t)|}(t), which in turn lead to ordering a set of treatments m_1(t), ..., m_{|m(t)|}(t), where |d(t)| and |m(t)| respectively denote the number of diagnosis and treatment codes in V(t). Some treatments produce lab results r_1(t), ..., r_{|r(t)|}(t), which may be associated with continuous values (e.g. blood pressure) or binary values (e.g. positive/negative allergic reaction). Since we focus on a single encounter in this study, we omit the time index (t) throughout the paper.

If we assume all features d_i, m_i, and r_i (if we bucketize the continuous values associated with r_i, we can treat r_i as a discrete feature like d_i and m_i) can be represented in the same latent space, then we can view an encounter as a graph consisting of |c| nodes with an adjacency matrix A that describes the connections between the nodes. We use c_i as the collective term to refer to any of d_i, m_i, and r_i for the rest of the paper. Given c_i and A, we can use graph networks or MiME (MiME is, in fact, a special form of graph networks with residual connections) to derive the visit representation v and use it for downstream tasks such as heart failure prediction. However, if we do not have the structural information A, which is the case in many EHR datasets and claims data, we typically use feed-forward networks to derive v, which is essentially summing all node representations c_i and projecting the sum to some latent space.

Figure 2: Learning the underlying structure of an encounter. We use the Transformer to start from the left, where all nodes are implicitly fully-connected, and arrive at the right, where meaningful connections are drawn with thicker edges.

3.2 Transformer and Graph Networks

Even without the structure information A, it is unreasonable to treat an encounter as a bag of nodes c_i, because physicians obviously must have made some decisions when making diagnoses and ordering treatments. The question is how to utilize the underlying structure without an explicit A. One way to view this problem is to assume that all nodes in an encounter are implicitly fully-connected, and try to figure out which connections are stronger than the others, as depicted in Figure 2. In this work, as discussed in Section 2, we use the Transformer to learn the underlying encounter structure. To elaborate, we draw a comparison between two cases:

  • Case A: We know A, hence we can use Graph Convolutional Networks (GCN). In this work, we use multiple hidden layers between each convolution, motivated by Xu et al. (2019):

    C^(j) = MLP^(j)( Ã C^(j−1) W^(j) ),    (1)

    where Ã = D^{−1}A, D is the diagonal node degree matrix of A (Xu et al. (2019) do not use the normalizer D^{−1}, in order to improve model expressiveness on multi-set graphs, but we include it to make the comparison with the Transformer clearer), and C^(j−1) and W^(j) are the node embeddings and the trainable parameters of the j-th convolution, respectively. MLP^(j) is a multi-layer perceptron of the j-th convolution with its own trainable parameters.

  • Case B: We do not know A, hence we use the Transformer, specifically the encoder with single-head attention, which can be formulated as

    C^(j) = MLP^(j)( softmax( Q^(j) K^(j)ᵀ / √d ) V^(j) ),    (2)

    where Q^(j) = C^(j−1) W_Q^(j), K^(j) = C^(j−1) W_K^(j), V^(j) = C^(j−1) W_V^(j), and d is the column size of W_K^(j). W_Q^(j), W_K^(j), and W_V^(j) are trainable parameters of the j-th Transformer block (since we use MLPs in both GCN and the Transformer, the terms W^(j) and W_V^(j) are unnecessary, but we keep them to follow the original formulations). Note that positional encoding using sine and cosine functions is not required, since features in an encounter are unordered.

Given Eq. (1) and Eq. (2), we can readily see that there is a correspondence between the normalized adjacency matrix Ã and the attention map softmax( Q^(j) K^(j)ᵀ / √d ), and between the projected node embeddings C^(j−1) W^(j) and the value vectors V^(j). In fact, GCN can be seen as a special case of the Transformer, where the attention mechanism is replaced with the known, fixed adjacency matrix. Conversely, the Transformer can be seen as a graph embedding algorithm that assumes fully-connected nodes and learns the connection strengths during training. Given this connection, it seems natural to use the Transformer as an algorithm to learn the underlying structure of visits.
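The correspondence can be made concrete with a small NumPy sketch (shapes and values are illustrative; the MLPs and stacked blocks of Eq. (1) and Eq. (2) are omitted, and only a single convolution/attention step is shown):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d = 4, 8                          # number of nodes, embedding size (illustrative)
C = rng.normal(size=(n, d))          # node embeddings C^(j-1)
W = rng.normal(size=(d, d))          # trainable projection W^(j)

# Case A (GCN): the mixing weights are the known, row-normalized adjacency D^-1 A.
A = np.array([[1., 1., 0., 0.],
              [1., 1., 1., 0.],
              [0., 1., 1., 1.],
              [0., 0., 1., 1.]])
A_norm = A / A.sum(axis=1, keepdims=True)
gcn_out = A_norm @ C @ W             # one convolution step (MLP omitted)

# Case B (Transformer): the mixing weights are learned via self-attention.
Wq, Wk = rng.normal(size=(d, d)), rng.normal(size=(d, d))
attention = softmax((C @ Wq) @ (C @ Wk).T / np.sqrt(d))
trans_out = attention @ C @ W        # same form, with A_norm replaced by attention
```

Both updates have the same form and the same (n, d) output shape; the only difference is whether the row-stochastic mixing matrix is fixed (Ã) or learned (the attention map).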

Figure 3: Creating the conditional probability matrix based on an example encounter. The gray cells are masked to zero probability since those connections are not allowed. The green cells are special connections that we know are guaranteed to exist. We assign a pre-defined scalar value (e.g. 1) to the green cells. The white cells are assigned the corresponding conditional probabilities.

3.3 Graph Convolutional Transformer

Although the Transformer can potentially learn the hidden encounter structure, without any hint it must search the entire attention space to discover meaningful connections between encounter features. Therefore we propose the Graph Convolutional Transformer (GCT), which, based on data statistics, restricts the search to the space likely to contain a meaningful attention distribution.

Specifically, we use 1) the characteristics of EHR data and 2) the conditional probabilities between features. First, we use the fact that some connections are not allowed in the encounter record. For example, we know that treatment codes can only be connected to diagnosis codes, not to other treatment codes. Based on this observation, we can create a mask M, which will be used during the attention generation step. M has negative infinities where connections are not allowed, and zeros where connections are allowed.

Conditional probabilities can be useful for determining potential connections between features. For example, given chest pain, fever, and EKG, without any structure information we do not know which diagnosis is the reason for ordering the EKG. However, we can calculate from EHR data that p(EKG | chest pain) is typically larger than p(EKG | fever), indicating that the connection between the former pair is more likely than the latter pair. Therefore we propose to use the conditional probabilities calculated from the encounter records as the guidance for deriving the attention. After calculating p(m|d), p(d|m), and p(r|m) from all encounter records for all diagnosis codes d, treatment codes m, and lab codes r, we can create a guiding matrix when given an encounter record, as depicted in Figure 3. We use P to denote this matrix of conditional probabilities of all features, normalized such that each row sums to 1. Note that GCT's attention map, the mask M, and the conditional probabilities P are of the same size.
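As an illustrative sketch of how M and P could be assembled from co-occurrence statistics (the toy encounters and code names are assumptions for brevity; only p(m|d) is shown, and labs and p(d|m) would be handled analogously):

```python
import numpy as np
from collections import Counter

# Toy training encounters: (diagnosis codes, treatment codes).
encounters = [({"cough"}, {"benzonatate"}),
              ({"cough", "fever"}, {"benzonatate", "acetaminophen"}),
              ({"fever"}, {"acetaminophen"})]

dx_count, co_count = Counter(), Counter()
for enc_dxs, enc_txs in encounters:
    for dx in enc_dxs:
        dx_count[dx] += 1
        for tx in enc_txs:
            co_count[(dx, tx)] += 1

def p_tx_given_dx(tx, dx):
    # Empirical p(m | d): fraction of encounters containing d that also contain m.
    return co_count[(dx, tx)] / dx_count[dx]

# Build the guiding matrix P for one encounter with nodes
# [cough, fever, benzonatate, acetaminophen].
dxs, txs = ["cough", "fever"], ["benzonatate", "acetaminophen"]
n = len(dxs) + len(txs)
P = np.zeros((n, n))
for i, dx in enumerate(dxs):                          # diagnosis -> treatment cells
    for j, tx in enumerate(txs):
        P[i, len(dxs) + j] = p_tx_given_dx(tx, dx)
P /= np.maximum(P.sum(axis=1, keepdims=True), 1e-12)  # each row sums to 1

# Mask M: zeros where connections are allowed, -inf where they are not
# (e.g. treatment -> treatment).
M = np.full((n, n), -np.inf)
M[:len(dxs), len(dxs):] = 0.0                         # diagnosis -> treatment
M[len(dxs):, :len(dxs)] = 0.0                         # treatment -> diagnosis
```

Adding M inside a softmax drives the disallowed cells to zero attention, while the rows of P provide the data-driven prior over the allowed cells.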

Given M and P, we want to guide GCT to recover the true graph structure as much as possible, but we also want to allow some room for GCT to learn novel connections that are helpful for solving the given prediction task. Therefore GCT uses the following formulation:

Self-attention:
    Â^(j) = P,  C^(j) = MLP^(j)( P C^(j−1) W^(j) )                                        if j = 1,
    Â^(j) = softmax( Q^(j) K^(j)ᵀ / √d + M ),  C^(j) = MLP^(j)( Â^(j) C^(j−1) W^(j) )     if j > 1.    (3)

Regularization:
    L_reg^(j) = D_KL( Â^(j−1) ‖ Â^(j) )  for j > 1,    L = L_pred + λ Σ_j L_reg^(j).    (4)

In preliminary experiments, we noticed that attention was often uniformly distributed in the first block of the Transformer. This seemed to be due to the Transformer not knowing which connections were worth attending to. Therefore we replace the attention mechanism in the first GCT block with the conditional probabilities P; the following blocks use the masked self-attention mechanism. However, we do not want GCT to drastically deviate from the informative P, but rather to gradually improve upon it. Therefore, based on the fact that attention is itself a probability distribution, and inspired by Trust Region Policy Optimization Schulman et al. (2015), we sequentially penalize the attention of the j-th block if it deviates too much from the attention of the (j−1)-th block, using the KL divergence. As shown by Eq. (4), the regularization terms are summed with the prediction loss term (e.g. negative log-likelihood), and the trade-off is controlled by the coefficient λ. GCT's code will be made publicly available in the future.
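A condensed sketch of the formulation above (single attention head, random weights in place of training, the MLP and W^(j) omitted, and illustrative sizes; `lam` plays the role of the trade-off coefficient λ):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kl(p, q, eps=1e-12):
    # Row-wise KL divergence D_KL(p || q), summed over rows.
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

rng = np.random.default_rng(0)
n, d = 4, 8
C = rng.normal(size=(n, d))
P = softmax(rng.normal(size=(n, n)))   # guiding conditional probabilities (rows sum to 1)
M = np.zeros((n, n))                   # mask: 0 = allowed, -inf = disallowed

attn_prev, reg_terms = P, []
for j in range(3):                     # three stacked blocks
    if j == 0:
        attn = P                       # first block: attention replaced by P
    else:
        Wq, Wk = rng.normal(size=(d, d)), rng.normal(size=(d, d))
        attn = softmax((C @ Wq) @ (C @ Wk).T / np.sqrt(d) + M)
        reg_terms.append(kl(attn_prev, attn))  # penalize deviation from previous block
    C = attn @ C                       # aggregation step (MLP and W^(j) omitted)
    attn_prev = attn

lam = 0.1
total_reg = lam * sum(reg_terms)       # added to the prediction loss L_pred
```

With three blocks, two KL penalty terms are produced, one for each block after the first, so the attention can drift away from P only gradually.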

4 Experiments

4.1 Synthetic Encounter Record

Choi et al. (2018) evaluated their model on proprietary EHR data that contained structure information. Unfortunately, to the best of our knowledge, there are no publicly available EHR data that contain structure information (the absence of which is the main motivation of this work). In order to evaluate GCT's ability to learn EHR structure, we instead generated synthetic data with a structure similar to that of EHR data.

The synthetic data have the same visit-diagnosis-treatment-lab hierarchy as EHR data, and were generated in a top-down fashion: each level was generated conditioned on the previous level, where the probabilities were modeled with the Pareto distribution. The Pareto distribution follows a power law, which best captures the long-tailed nature of medical codes. Using 1,000 diagnosis codes, 1,000 treatment codes, and 1,000 lab codes, we initialized p(d), p(d'|d), p(m|d), and p(r|m, d) to follow the Pareto distribution, where d, m, and r respectively denote diagnosis, treatment, and lab random variables. p(d) is used to draw independent diagnosis codes, and p(d'|d) is used to draw diagnosis codes that are likely to co-occur with a previously sampled d. p(m|d) is used to draw a treatment code m given some diagnosis d. p(r|m, d) is used to draw a lab code r given some treatment m and diagnosis d. A detailed description of generating the synthetic records and the link to download them are provided in Appendix A and Appendix E, respectively. Code for generating the synthetic records will be open-sourced in the future. Table 1 summarizes the data statistics.
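The top-down sampling idea can be sketched as follows (the power-law exponent and the use of random permutations to decorrelate the per-diagnosis treatment distributions are illustrative assumptions, not the paper's exact generator; p(d'|d) and p(r|m, d) are omitted for brevity):

```python
import numpy as np

rng = np.random.default_rng(0)
num_dx, num_tx = 1000, 1000

def pareto_categorical(k, alpha=1.0):
    # Power-law weights over k codes: a few codes dominate, a long tail follows.
    w = np.arange(1, k + 1, dtype=float) ** (-(alpha + 1))
    return w / w.sum()

p_d = pareto_categorical(num_dx)                       # p(d)
p_m_given_d = np.stack([rng.permutation(pareto_categorical(num_tx))
                        for _ in range(num_dx)])       # p(m | d), one row per diagnosis

# Sample one synthetic encounter: diagnoses first, then one treatment per diagnosis.
dxs = rng.choice(num_dx, size=3, replace=False, p=p_d)
txs = [rng.choice(num_tx, p=p_m_given_d[dx]) for dx in dxs]
```

Because the treatment for each diagnosis is drawn from that diagnosis's own conditional distribution, the generator produces exactly the kind of diagnosis-treatment edges that the graph reconstruction task later tries to recover.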

4.2 eICU Collaborative Research Dataset

Synthetic eICU
# of encounters 50,000 41,026
# of diagnosis codes 1,000 3,093
# of treatment codes 1,000 2,132
# of lab codes 1,000 N/A
Avg. # of diagnosis per visit 7.93 7.70
Avg. # of treatment per visit 14.59 5.03
Avg. # of lab per visit 21.31 N/A
Table 1: Statistics of the synthetic dataset and eICU

To test GCT on real-world EHR records, we use the Philips eICU Collaborative Research Dataset (https://eicu-crd.mit.edu/about/eicu/) Pollard et al. (2018). eICU consists of Intensive Care Unit (ICU) records from a remote caregiver program, collected from multiple sites in the United States between 2014 and 2015. From the encounter records, medication orders, and procedure orders, we extracted diagnosis codes and treatment codes (i.e. medication and procedure codes). Since the data were collected from ICUs, a single encounter can last several days, during which the encounter structure evolves over time rather than being fixed as in Figure 1. Therefore we used encounters where the patient was admitted for less than 24 hours, and removed duplicate codes (e.g. medications administered multiple times). Additionally, we did not use lab results, as their values change over time in the ICU setting (e.g. blood pH level). We leave handling ICU records that evolve over a longer period of time as future work. Note that eICU does not contain structure information. For example, we know that Cough and Acetaminophen in Figure 1 occur in the same visit, but not whether Acetaminophen was prescribed due to Cough. Table 1 summarizes the data statistics.

4.3 Baseline Models

  • GCN_true: Given the true adjacency matrix A, we follow Eq. (1) to learn the representation of each feature c_i in a visit. The visit embedding v (i.e. the graph-level representation) is obtained from the placeholder visit node. This model serves as the optimal model in the experiments.

  • GCN_P: Instead of the true adjacency matrix A, we use the conditional probability matrix P, and follow Eq. (1).

  • GCN_rand: Instead of the true adjacency matrix A, we use a randomly generated normalized adjacency matrix where each element is independently sampled from a uniform distribution between 0 and 1. This model lets us evaluate whether the true encounter structure is useful at all.

  • Shallow: Each c_i is converted to a latent representation using multi-layer feedforward networks with ReLU activations. The visit representation v is obtained by simply summing all latent representations. We use layer normalization Ba et al. (2016), drop-out Srivastava et al. (2014), and residual connections He et al. (2016) between layers.

  • Deep: We use multiple feedforward layers with ReLU activations (including layer normalization, drop-out, and residual connections) on top of Shallow to increase the expressivity. Note that Zaheer et al. (2017) theoretically show that this model is sufficient to obtain the optimal representation of a set of items (i.e., a visit consisting of multiple features).
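A minimal sketch of the Shallow/Deep building block described above (sizes and random weights are illustrative; dropout is shown only in training mode):

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-5):
    # Normalize each feature vector to zero mean and unit variance.
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def ff_block(x, W, drop_rate=0.1, train=False):
    # Feedforward layer with ReLU, (inverted) dropout, residual connection, layer norm.
    h = np.maximum(x @ W, 0.0)
    if train:
        h *= (rng.random(h.shape) >= drop_rate) / (1 - drop_rate)
    return layer_norm(x + h)

n, d = 5, 8                        # features per visit, embedding size (illustrative)
C = rng.normal(size=(n, d))        # one embedding per feature c_i
W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))

h = ff_block(ff_block(C, W1), W2)  # "Deep": stacked blocks on per-feature embeddings
v = h.sum(axis=0)                  # visit representation: sum of all feature embeddings
```

Dropping the second `ff_block` call recovers the Shallow variant; in both cases each feature is embedded independently, with no attention across features.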

4.4 Prediction Tasks

In order to evaluate the model’s capacity to leverage the implicit encounter structure, we use prediction tasks based on a single encounter, rather than a sequence of encounters, which was the experiment setup in Choi et al. (2018). Specifically, we test the models on the following tasks. Parentheses indicate which dataset is used for each task.

  • Graph reconstruction (Synthetic): Given an encounter with features c_1, ..., c_{|c|}, we train models to learn the feature embeddings and predict whether there is an edge between every pair of features, by performing an inner product between each pair of feature embeddings c_i and c_j (i.e. binary predictions). We do not use the Deep baseline for this task, as we need individual embeddings for all features c_i.

  • Diagnosis-treatment classification (Synthetic): We assign labels to an encounter if it contains specific diagnosis-treatment connections. Specifically, for two particular diagnosis codes (denote them d_a and d_b) and a particular treatment code m_c, we assign label "1" if the encounter contains a d_a–m_c connection, and label "2" if the encounter contains a d_b–m_c connection. We intentionally made the task difficult so that the models cannot achieve a perfect score by just basing their prediction on whether d_a, d_b, and m_c exist in an encounter. Further details on the labels, including their prevalence, are provided in Appendix B. This is a multi-label prediction task using the visit representation v.

  • Masked diagnosis code prediction (Synthetic, eICU): Given an encounter record, we mask a random diagnosis code. We train models to learn the embedding of the masked code and predict its identity, i.e. a multi-class prediction. For Shallow and Deep, we use the visit embedding v as a proxy for the masked code representation. The row and the column of the conditional probability matrix P that correspond to the masked diagnosis are also masked to zeros.

  • Readmission prediction (eICU): Given an encounter record, we train models to learn the visit embedding v and predict whether the patient will be admitted to the ICU again during the same hospital stay, i.e., a binary prediction.

  • Mortality prediction (eICU): Given an encounter record, we train models to learn the visit embedding v and predict patient death during the ICU admission, i.e., a binary prediction.

Note that the conditional probability matrix P was calculated using only the training set. Further training details and hyperparameter settings are described in Appendix C.

4.5 Prediction Performance

Graph reconstruction:

Model       | Val AUCPR | Val AUROC | Test AUCPR | Test AUROC
GCN_true    | 0.9999    | 1.0       | 0.9999     | 1.0
GCN_P       | 0.5863    | 0.8894    | 0.5826     | 0.8887
GCN_rand    | 0.5657    | 0.8807    | 0.5614     | 0.8803
Shallow     | 0.5482    | 0.8589    | 0.5444     | 0.8582
Deep        | -         | -         | -          | -
Transformer | 0.5944    | 0.8925    | 0.5885     | 0.8910
GCT         | 0.6063    | 0.8966    | 0.6003     | 0.8956

Diagnosis-treatment classification:

Model       | Val AUCPR | Val AUROC | Test AUCPR | Test AUROC
GCN_true    | 1.0       | 1.0       | 1.0        | 1.0
GCN_P       | 0.9164    | 0.9934    | 0.9175     | 0.9931
GCN_rand    | 0.8973    | 0.9916    | 0.8900     | 0.9903
Shallow     | 0.9235    | 0.9935    | 0.9103     | 0.9926
Deep        | 0.9173    | 0.9936    | 0.9070     | 0.9923
Transformer | 0.9293    | 0.9942    | 0.9146     | 0.9930
GCT         | 0.6063    | 0.9942    | 0.9219     | 0.9933

Table 2: Graph reconstruction and diagnosis-treatment classification performance.
Model       | Synthetic Val Acc | Synthetic Test Acc | eICU Val Acc | eICU Test Acc
GCN_true    | 0.2828            | 0.3024             | -            | -
GCN_P       | 0.1908            | 0.2096             | 0.7510       | 0.7389
GCN_rand    | 0.1780            | 0.1980             | 0.7184       | 0.7137
Shallow     | 0.2006            | 0.2156             | 0.7387       | 0.7241
Deep        | 0.1900            | 0.2048             | 0.7406       | 0.7343
Transformer | 0.1962            | 0.2026             | 0.7295       | 0.7158
GCT         | 0.2182            | 0.2328             | 0.7796       | 0.7644

Table 3: Masked diagnosis code prediction performance on the two datasets.
Readmission prediction:

Model       | Val AUCPR | Val AUROC | Test AUCPR | Test AUROC
GCN_P       | 0.4952    | 0.7327    | 0.4706     | 0.7422
GCN_rand    | 0.4740    | 0.6948    | 0.4421     | 0.6953
Shallow     | 0.3962    | 0.6852    | 0.3517     | 0.6740
Deep        | 0.4527    | 0.7170    | 0.4429     | 0.7403
Transformer | 0.5057    | 0.7341    | 0.4571     | 0.7161
GCT         | 0.5279    | 0.7345    | 0.4837     | 0.7421

Mortality prediction:

Model       | Val AUCPR | Val AUROC | Test AUCPR | Test AUROC
GCN_P       | 0.6664    | 0.9153    | 0.6082     | 0.9066
GCN_rand    | 0.6592    | 0.9196    | 0.5996     | 0.9098
Shallow     | 0.6687    | 0.9206    | 0.6120     | 0.9120
Deep        | 0.6724    | 0.9213    | 0.6023     | 0.9110
Transformer | 0.6594    | 0.9120    | 0.6039     | 0.9102
GCT         | 0.6757    | 0.9209    | 0.6178     | 0.9122

Table 4: Readmission prediction and mortality prediction performance on eICU.

Table 2 shows the graph reconstruction performance and the diagnosis-treatment classification performance of all models. Naturally, GCN_true shows the best performance, since it uses the true adjacency matrix A. Given that GCN_P's performance is only marginally inferior to the Transformer's, we can infer that the conditional probabilities P are indeed indicative of the true structure. GCT, which combines the strengths of both GCN and the Transformer, shows the best performance besides GCN_true. It is noteworthy that GCN_rand outperforms Shallow. This seems to indicate that for graph reconstruction, attending to other features, regardless of how accurately the process follows the true structure, is better than individually embedding each feature. Diagnosis-treatment classification, on the other hand, clearly penalizes randomly attending to the features, since GCN_rand shows the worst performance. GCT again shows the best performance.

Table 3 shows the model performance for masked diagnosis code prediction on both datasets. GCN_true could not be evaluated on eICU, since eICU does not have the true structure; it naturally shows the best performance on the synthetic dataset. GCN_rand again demonstrates the worst performance on both datasets, indicating that inferring a diagnosis based on other features requires accurate structure knowledge. Interestingly, the task performance is significantly higher for eICU than for the synthetic dataset. This is mainly due to eICU having a very skewed diagnosis code distribution: more than 80% of encounters have diagnosis codes related to whether the patient has been in an operating room prior to the ICU admission. Therefore randomly masking one of them does not make the prediction task as difficult as for the synthetic dataset.

Table 4 shows the readmission prediction and mortality prediction performance of all models on eICU. As shown by the superior performance of GCN_P and GCT, it is evident that readmission prediction heavily benefits from using the latent encounter structure. Mortality prediction, on the other hand, seems to rely little on the encounter structure, as can be seen from the only marginally superior performance of GCT. Rather, given the comparable performance of Shallow, it seems that mortality prediction can be successfully performed by simply adding up all features. Even when the conditional probabilities P are unnecessary, GCT still slightly outperforms the other models, demonstrating its potential to be used as a general-purpose EHR modeling algorithm. These two experiments indicate that not all prediction tasks require the true encounter structure, and it is our future work to apply GCT to various prediction tasks to evaluate its effectiveness.

4.6 Evaluating the Learned Encounter Structure

In this section, we analyze the structure learned by both the Transformer and GCT. As we know the true structure of the synthetic records, we can evaluate how well both models learned it via self-attention. Since we can view the normalized true adjacency matrix Ã as a probability distribution, we can measure how well the attention map Â in Eq. (3) approximates Ã using the KL divergence D_KL(Ã ‖ Â).

            | Graph reconstruction | Diagnosis-treatment classification | Masked diagnosis code prediction
Model       | KL div  | Entropy    | KL div  | Entropy                  | KL div  | Entropy
GCN_P       | 8.2250  | 1.8761     | 8.2250  | 1.8761                   | 8.2250  | 1.8761
Transformer | 28.8123 | 1.6120     | 13.994  | 2.0953                   | 15.1882 | 2.0230
GCT         | 7.6487  | 1.8788     | 8.5952  | 2.3143                   | 9.0101  | 1.2618

Table 5: KL divergence between the normalized true adjacency matrix and the attention map. We also show the entropy of the attention map to indicate the sparseness of the attention distribution.

Table 5 shows the KL divergence between the normalized true adjacency Ã and the learned attention on the test set of the synthetic data while performing three different tasks. For GCN_P, the adjacency matrix is fixed to the conditional probability matrix P, so the KL divergence can be readily calculated. For the Transformer and GCT, we calculated the KL divergence between Ã and the attention map in each self-attention block, and averaged the results. Note that the KL divergence can be lowered by evenly distributing the attention across all features, which is the opposite of learning the encounter structure. Therefore we also show the entropy of the attention map alongside the KL divergence.
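The evaluation described above can be sketched as follows (Ã and Â here are random stand-ins for an actual normalized adjacency matrix and a learned attention map):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def row_kl(p, q, eps=1e-12):
    # Sum over rows of D_KL(p_i || q_i), treating each row as a distribution.
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def entropy(p, eps=1e-12):
    return float(-np.sum(p * np.log(p + eps)))

rng = np.random.default_rng(0)
n = 5
A = (rng.random((n, n)) < 0.4).astype(float)
np.fill_diagonal(A, 1.0)                       # keep self-loops so no row is empty
A_norm = A / A.sum(axis=1, keepdims=True)      # normalized true adjacency (Ã)
attn = softmax(rng.normal(size=(n, n)))        # stand-in learned attention map (Â)

kl_val = row_kl(A_norm, attn)                  # lower = closer to the true structure
H = entropy(attn)                              # lower = sharper (sparser) attention
```

Reporting entropy alongside the KL divergence guards against the degenerate solution of uniform attention, which lowers KL without learning any structure.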

As shown in Table 5, the conditional probabilities P are closer to the true structure than what the Transformer has learned, in all three tasks. GCT stays about as close to the true structure as GCN_P in all tasks, and was even able to improve upon P in the graph reconstruction task. The fact that the Transformer shows strong performance in graph reconstruction, even with attention significantly different from the true structure, again indicates the importance of just attending to other features in graph reconstruction, as discussed in Section 4.5 regarding the performance of GCN_rand. For the other two tasks, regularizing the models to stay close to P seems to help them generalize better to the test set, especially in masked diagnosis code prediction. We show visual examples of the attention behavior of both the Transformer and GCT in Appendix D.

5 Conclusion

Learning effective patterns from raw EHR data is an essential step for improving the performance of many downstream prediction tasks. In this paper, we addressed the issue where the previous state-of-the-art method required the complete encounter structure information, and proposed GCT to capture the underlying encounter structure when the structure information is unknown. Experiments demonstrated that GCT outperformed various baseline models on encounter-based tasks on both synthetic data and a publicly available EHR dataset, demonstrating its potential to serve as a general-purpose EHR modeling algorithm. In the future, we plan to apply GCT on patient-level tasks such as heart failure diagnosis prediction or unplanned emergency admission prediction, while working on improving the attention mechanism to learn more medically meaningful patterns.

References

  • Lipton et al. [2015] Zachary C Lipton, David C Kale, Charles Elkan, and Randall Wetzel. Learning to diagnose with lstm recurrent neural networks. arXiv preprint arXiv:1511.03677, 2015.
  • Choi et al. [2016a] Edward Choi, Mohammad Taha Bahadori, Andy Schuetz, Walter F Stewart, and Jimeng Sun. Doctor ai: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference, pages 301–318, 2016a.
  • Rajkomar et al. [2018] Alvin Rajkomar, Eyal Oren, Kai Chen, Andrew M Dai, Nissan Hajaj, Michaela Hardt, Peter J Liu, Xiaobing Liu, Jake Marcus, Mimi Sun, et al. Scalable and accurate deep learning with electronic health records. NPJ Digital Medicine, 1(1):18, 2018.
  • Che et al. [2015] Zhengping Che, David Kale, Wenzhe Li, Mohammad Taha Bahadori, and Yan Liu. Deep computational phenotyping. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 507–516. ACM, 2015.
  • Choi et al. [2016b] Youngduck Choi, Chill Yi-I Chiu, and David Sontag. Learning low-dimensional representations of medical concepts. AMIA Summits on Translational Science Proceedings, 2016:41, 2016b.
  • Choi et al. [2016c] Edward Choi, Mohammad Taha Bahadori, Elizabeth Searles, Catherine Coffey, Michael Thompson, James Bost, Javier Tejedor-Sojo, and Jimeng Sun. Multi-layer representation learning for medical concepts. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 1495–1504. ACM, 2016c.
  • Miotto et al. [2016] Riccardo Miotto, Li Li, Brian A Kidd, and Joel T Dudley. Deep patient: an unsupervised representation to predict the future of patients from the electronic health records. Scientific reports, 6:26094, 2016.
  • Choi et al. [2016d] Edward Choi, Mohammad Taha Bahadori, Jimeng Sun, Joshua Kulas, Andy Schuetz, and Walter Stewart. Retain: An interpretable predictive model for healthcare using reverse time attention mechanism. In Advances in Neural Information Processing Systems, pages 3504–3512, 2016d.
  • Ma et al. [2017] Fenglong Ma, Radha Chitta, Jing Zhou, Quanzeng You, Tong Sun, and Jing Gao. Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, pages 1903–1911. ACM, 2017.
  • Choi et al. [2018] Edward Choi, Cao Xiao, Walter Stewart, and Jimeng Sun. Mime: Multilevel medical embedding of electronic health records for predictive healthcare. In NeurIPS, pages 4552–4562, 2018.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, pages 5998–6008, 2017.
  • Suresh et al. [2017] Harini Suresh, Nathan Hunt, Alistair Johnson, Leo Anthony Celi, Peter Szolovits, and Marzyeh Ghassemi. Clinical intervention prediction and understanding using deep networks. In MLHC, 2017.
  • Nguyen et al. [2018] Phuoc Nguyen, Truyen Tran, and Svetha Venkatesh. Resset: A recurrent model for sequence of sets with applications to electronic medical records. In 2018 International Joint Conference on Neural Networks (IJCNN), pages 1–9. IEEE, 2018.
  • Tran et al. [2015] Truyen Tran, Tu Dinh Nguyen, Dinh Phung, and Svetha Venkatesh. Learning vector representation of medical objects via emr-driven nonnegative restricted boltzmann machines (enrbm). Journal of Biomedical Informatics, 2015.
  • Choi et al. [2017] Edward Choi, Mohammad Taha Bahadori, Le Song, Walter F Stewart, and Jimeng Sun. Gram: Graph-based attention model for healthcare representation learning. In SIGKDD, 2017.
  • Bahdanau et al. [2014] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.
  • Devlin et al. [2018] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Kipf and Welling [2016] Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.
  • Hamilton et al. [2017] Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pages 1024–1034, 2017.
  • Battaglia et al. [2018] Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, et al. Relational inductive biases, deep learning, and graph networks. arXiv:1806.01261, 2018.
  • Xu et al. [2019] Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In ICLR, 2019.
  • Veličković et al. [2018] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. Graph attention networks. In ICLR, 2018.
  • Wang et al. [2018] Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. Non-local neural networks. In CVPR, pages 7794–7803, 2018.
  • Schulman et al. [2015] John Schulman, Sergey Levine, Pieter Abbeel, Michael Jordan, and Philipp Moritz. Trust region policy optimization. In International Conference on Machine Learning, pages 1889–1897, 2015.
  • Pollard et al. [2018] Tom J Pollard, Alistair EW Johnson, Jesse D Raffa, Leo A Celi, Roger G Mark, and Omar Badawi. The eicu collaborative research database, a freely available multi-center database for critical care research. Scientific data, 5, 2018.
  • Ba et al. [2016] Jimmy Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Srivastava et al. [2014] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
  • He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • Zaheer et al. [2017] Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Ruslan R Salakhutdinov, and Alexander J Smola. Deep sets. In Advances in neural information processing systems, pages 3391–3401, 2017.
  • Kingma and Ba [2014] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Abadi et al. [2016] Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: A system for large-scale machine learning. In 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI 16), pages 265–283, 2016.
   |D| = 1000   // Dx vocab size
   |M| = 1000   // Treatment vocab size
   |R| = 1000   // Lab vocab size
  
   // Independent diagnosis occurrence probability
   p(d) = permute(normalize(numpy.random.pareto(α, size=|D|)))
  
   // Conditional probability of a diagnosis co-occurring with another diagnosis
   p(d′|d_1) = permute(normalize(numpy.random.pareto(α, size=|D|)))
   …
   p(d′|d_|D|) = permute(normalize(numpy.random.pareto(α, size=|D|)))
  
   // Conditional probability of a treatment being ordered for a specific diagnosis
   p(m|d_1) = permute(normalize(numpy.random.pareto(α, size=|M|)))
   …
   p(m|d_|D|) = permute(normalize(numpy.random.pareto(α, size=|M|)))
  
   // Conditional probability of a lab being ordered for a specific treatment and diagnosis
   p(r|m_1, d_1) = permute(normalize(numpy.random.pareto(α, size=|R|)))
   …
   p(r|m_|M|, d_|D|) = permute(normalize(numpy.random.pareto(α, size=|R|)))
  
   // Bernoulli probability to determine whether to sample a diagnosis code given a previous diagnosis. Values are clipped to [0., 1.).
   q_d = clip(numpy.random.normal(μ_d, σ_d, size=|D|))
  
   // Bernoulli probability to determine whether to sample a treatment code given a diagnosis. Values are clipped to [0., 1.).
   q_m = clip(numpy.random.normal(μ_m, σ_m, size=|M|))
  
   // Bernoulli probability to determine whether to sample a lab code given a treatment and a diagnosis. Values are clipped to [0., 1.).
   q_r = clip(numpy.random.normal(μ_r, σ_r, size=(|M|, |D|)))
  
   // Start creating synthetic records
   Sample a diagnosis d from p(d)
   repeat
      Sample a diagnosis d′ from p(d′|d), then set d ← d′
   until Bernoulli(q_d[d]) = 0
   for each diagnosis d in the sampled diagnosis codes do
      repeat
         Sample a treatment m from p(m|d).
         repeat
            Sample a lab r from p(r|m, d).
         until Bernoulli(q_r[m, d]) = 0
      until Bernoulli(q_m[d]) = 0
   end for
Algorithm 1 Synthetic Encounter Records Generation

Appendix A Generating Synthetic Encounter Records

We describe the synthetic data creation process in this section. As described in Section 4.1, we use the Pareto distribution to capture the long-tailed nature of medical codes. We also define Bernoulli stopping probabilities, one each for diagnoses, treatments, and labs, to determine when to stop sampling codes. The overall generation process starts by sampling a diagnosis code. Then we sample a diagnosis code that is likely to co-occur with the previously sampled diagnosis code. After the diagnosis codes are sampled, we iterate through the sampled diagnosis codes to sample a treatment code that is likely to be ordered for each diagnosis code. While sampling the treatment codes, we also sample lab codes that are likely to be produced by each treatment code. The overall procedure is described in Algorithm 1.

Note that we use p(m|d) to model a treatment being ordered due to a diagnosis code, instead of conditioning on the already ordered treatments as well, which might be more accurate since a treatment may also depend on them. However, we assume that, given a diagnosis code, the treatments that follow are conditionally independent; therefore the treatments can be factorized into a product of p(m|d) terms. The same assumption went into using p(r|m, d), instead of additionally conditioning on the already ordered labs.

Finally, among the generated synthetic encounters, we removed those with fewer than 5 diagnosis or treatment codes, in order to make the encounter structure sufficiently complex. Additionally, we removed encounters containing more than 50 diagnosis, treatment, or lab codes in order to keep the encounter structure realistic (i.e. it is unlikely that a patient receives more than 50 diagnosis codes in one hospital encounter). For the eICU dataset, we also removed the encounters with more than 50 diagnosis or treatment codes. However, we did not remove any encounters for having fewer than 5 diagnosis or treatment codes, as that would leave us with only approximately 7,000 encounter records, which is rather small for training neural networks.
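As a concrete illustration, the generation process above can be sketched in Python/NumPy. The vocabulary sizes match Appendix E (1,000 each), but the Pareto shape parameter and the Bernoulli means and standard deviations below are illustrative placeholders, not the values used in the paper; the full p(r|m,d) table would be a 1000³ tensor, so its rows are materialized lazily here.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB = 1000  # |D| = |M| = |R|

def pareto_dist(size, alpha=2.0):
    """Normalized, permuted Pareto samples used as a categorical distribution."""
    p = rng.pareto(alpha, size=size)
    return rng.permutation(p / p.sum())

# Probability tables (alpha above, and mu/sigma below, are placeholder values).
p_dx = pareto_dist(VOCAB)                                        # p(d)
p_dx_dx = np.stack([pareto_dist(VOCAB) for _ in range(VOCAB)])   # p(d'|d)
p_tx_dx = np.stack([pareto_dist(VOCAB) for _ in range(VOCAB)])   # p(m|d)

lab_rows = {}
def p_lab(m, d):
    """Lazily materialized row of p(r|m,d)."""
    return lab_rows.setdefault((int(m), int(d)), pareto_dist(VOCAB))

# Bernoulli stopping probabilities, clipped to [0., 1.).
q_dx = np.clip(rng.normal(0.5, 0.1, size=VOCAB), 0.0, 0.999)
q_tx = np.clip(rng.normal(0.5, 0.1, size=VOCAB), 0.0, 0.999)
q_lab = np.clip(rng.normal(0.5, 0.1, size=(VOCAB, VOCAB)), 0.0, 0.999)

def sample_encounter():
    """One synthetic record in the nested [dx, [[tx, [lab, ...]], ...]] format."""
    d = rng.choice(VOCAB, p=p_dx)
    dxs = [d]
    while rng.random() < q_dx[d]:              # co-occurring diagnoses
        d = rng.choice(VOCAB, p=p_dx_dx[d])
        dxs.append(d)
    record = []
    for d in dxs:
        txs = []
        while rng.random() < q_tx[d]:          # treatments ordered for this dx
            m = rng.choice(VOCAB, p=p_tx_dx[d])
            labs = []
            while rng.random() < q_lab[m, d]:  # labs produced by this treatment
                labs.append(int(rng.choice(VOCAB, p=p_lab(m, d))))
            txs.append([int(m), labs])
        record.append([int(d), txs])
    return record
```

The nested output format matches the visits_50k.p layout described in Appendix E.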

Appendix B Diagnosis-Treatment Classification Task

This task is used to test the model’s ability to derive a visit representation (i.e. a graph-level representation) that correctly preserves the encounter structure. As described in Section 4.4, this is a multi-label classification problem: an encounter is assigned the label “1” if it contains a connected pair of a specific diagnosis code and a specific treatment code (i.e. the treatment was ordered because of the diagnosis), and it is assigned the label “2” if it contains another specific connected diagnosis-treatment pair. It is therefore possible for an encounter to be assigned both labels “1” and “2”, or neither.

Since we want to test the model’s ability to correctly learn the encounter structure, we do not want the model to achieve a perfect score by, for example, predicting label “1” based merely on whether both codes of the pair exist in an encounter. We therefore adjusted the sampling probabilities to make this task difficult: the conditional probabilities of the two diagnosis-treatment pairs were set so that the two connection pairs occur in an encounter with more or less the same probability. As a result, the model cannot achieve a perfect score unless it correctly identifies the encounter structure.
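For concreteness, the label assignment can be sketched as follows, assuming two hypothetical diagnosis-treatment index pairs (the actual code indices used in the paper are not reproduced here). Records follow the nested-list format described in Appendix E.

```python
def encounter_labels(record, pair1, pair2):
    """Multi-label assignment: 1 if pair1 = (dx, tx) is a connected pair, 2 for pair2.

    `record` is a nested list [[dx, [[tx, [lab, ...]], ...]], ...], i.e. each
    treatment is attached to the diagnosis it was ordered for.
    """
    connected = {(d, m) for d, txs in record for m, _ in txs}
    labels = set()
    if pair1 in connected:
        labels.add(1)
    if pair2 in connected:
        labels.add(2)
    return labels
```

For example, with the record [[1, []], [2, [[3, []], [4, [5, 6]]]]], the connected pairs are (2, 3) and (2, 4); a model that only checks co-occurrence of the two codes, rather than their connection, would be fooled by records where both codes appear under different diagnoses.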

Appendix C Training Details

We divided both the synthetic dataset and eICU into a training set, a validation set, and a test set with an 8:1:1 ratio. All models were trained with Adam Kingma and Ba [2014] on the training set, and performance was evaluated on the validation set to select the final model. Final performance was evaluated on the test set. We used minibatches of size 32, and trained all models for 200,000 iterations (i.e. minibatch updates), which was sufficient for convergence on all tasks. After an initial round of preliminary experiments, the embedding size of the encounter features was set to 128. For the GCN variants, Transformer, and GCT, we used an undirected adjacency/attention matrix to enhance the message passing efficiency. All models were implemented in TensorFlow 1.13 Abadi et al. [2016], and trained on systems equipped with Nvidia P100 GPUs.
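A minimal sketch of the 8:1:1 split, assuming the records are stored in a Python list (the seed and the shuffling step are our own illustrative choices; the paper does not specify them):

```python
import numpy as np

def split_811(records, seed=0):
    """Shuffle and split records into train/validation/test with an 8:1:1 ratio."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(records))
    n_train = int(0.8 * len(records))
    n_val = int(0.1 * len(records))
    train = [records[i] for i in idx[:n_train]]
    val = [records[i] for i in idx[n_train:n_train + n_val]]
    test = [records[i] for i in idx[n_train + n_val:]]
    return train, val, test
```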

Tunable hyperparameters for the Shallow, Deep, GCN-variant, and Transformer models are as follows:

  • Adam learning rate

  • Drop-out rate between layers

Shallow used 15 feedforward layers, while Deep used 8 feedforward layers before summing the embeddings and 7 feedforward layers after. The numbers of layers were chosen to match the number of trainable parameters of Transformer and GCT. The GCN variants and Transformer used three self-attention blocks, which was sufficient to cover the entire depth of EHR encounters. Transformer used one attention head to match its representational power to the GCN variants, so that we could accurately evaluate the effect of learning the correct encounter structure.

Tunable hyperparameters for GCT are as follows:

  • Adam learning rate

  • Drop-out rate between layers

  • Regularization coefficient

GCT also used three self-attention blocks and one attention head. All hyperparameters were searched via Bayesian optimization with a Gaussian process for 48 hours of wall-clock time.

Appendix D Attention Behavior

Figure 4: Attention from each self-attention block of Transformer trained for graph reconstruction. Codes starting with ‘D’ are diagnosis codes, ‘T’ treatment codes, and ‘L’ lab codes. The diagnosis code with the red background, D_199, is attending to the other features. The red bars indicate the codes that are actually connected to D_199, and the blue bars indicate the attention given to all codes.
Figure 5: Attention from each self-attention block of GCT trained for graph reconstruction. Codes starting with ‘D’ are diagnosis codes, ‘T’ treatment codes, and ‘L’ lab codes. The diagnosis code with the red background, D_199, is attending to the other features. The red bars indicate the codes that are actually connected to D_199, and the blue bars indicate the attention given to all codes.
Figure 6: Attention from each self-attention block of Transformer trained for masked diagnosis code prediction. Codes starting with ‘D’ are diagnosis codes, ‘T’ treatment codes, and ‘L’ lab codes. The diagnosis code with the red background, D_199, is attending to the other features. The diagnosis code with the gray background, D_294, is the masked diagnosis code. The red bars indicate the codes that are actually connected to D_199, and the blue bars indicate the attention given to all codes.
Figure 7: Attention from each self-attention block of GCT trained for masked diagnosis code prediction. Codes starting with ‘D’ are diagnosis codes, ‘T’ treatment codes, and ‘L’ lab codes. The diagnosis code with the red background, D_199, is attending to the other features. The diagnosis code with the gray background, D_294, is the masked diagnosis code. The red bars indicate the codes that are actually connected to D_199, and the blue bars indicate the attention given to all codes.

In this section, we compare the attention behavior of Transformer and GCT in two different contexts: graph reconstruction and masked diagnosis code prediction. We randomly chose an encounter record from the test set of the synthetic dataset that had fewer than 30 codes, in order to enhance readability. To show the attention distribution of a specific code, we chose the first diagnosis code connected to at least one treatment. Figure 4 shows Transformer’s attention in each self-attention block when performing graph reconstruction. Specifically, we show the attention given by the diagnosis code D_199 to the other codes. The red bars indicate the true connections, and the blue bars indicate the attention given to all codes. It can be seen that Transformer attends evenly to all codes in the first block, then develops its own attention pattern. In the second block, it successfully recovers two of the true connections, but attends to incorrect codes in the third block.

Figure 5 shows GCT’s attention in each self-attention block when performing graph reconstruction. Contrary to Transformer, GCT starts with a very specific attention distribution. The first two attention weights, given to the placeholder Visit node and to itself, are determined by the scalar value from Figure 3. However, the attention given to the treatment codes, especially T_939, is derived from the conditional probability matrix. Then, in the following self-attention blocks, GCT starts to deviate from the conditional probabilities, and the attention distribution becomes more similar to the true adjacency matrix. This nicely shows the benefit of using the conditional probabilities as a guide for learning the encounter structure.

Since the goal of the graph reconstruction task is to predict the edges between nodes, it may be unsurprising that both Transformer’s and GCT’s attentions mimic the true adjacency matrix. Therefore, we show another set of attentions from Transformer and GCT trained for the masked diagnosis code prediction task. Figure 6 shows Transformer’s attention while performing masked diagnosis code prediction. Note that the diagnosis code D_294 is masked, and the model therefore does not know its identity. Similar to graph reconstruction, Transformer starts with evenly distributed attention and develops its own structure. Interestingly, it learns to attend to the right treatment in the third block, but mostly tries to predict the masked node’s identity by attending to other diagnosis codes, while largely ignoring the lab codes.

Figure 7 shows GCT’s attention while performing the masked diagnosis code prediction task. Again, GCT starts with the conditional probability matrix, then develops its own attention. But this time, understandably, the attention maps are not as similar to the true structure as in the graph reconstruction task. An interesting finding is that GCT attends heavily to the placeholder Visit node in this task. This is inevitable, given that we only allow diagnosis codes to attend to treatment codes (see the white cells in Figure 3); therefore, if GCT wants to look at other diagnosis codes, it can only do so by indirectly receiving information via the Visit node. And as Figure 6 suggests, predicting the identity of the masked code seems to require knowing the co-occurring diagnosis codes as well as the treatment codes. Therefore, unlike in the graph reconstruction task, GCT puts heavy attention on the Visit node in order to learn about the co-occurring diagnosis codes.

Appendix E Sharing the Synthetic Records

The synthetic records used for the experiments can be downloaded via this link (https://www.dropbox.com/s/ojx9jr4yyvmfdum/synthetic.tar.gz). It is a compressed file, which you can decompress to obtain the following files.

  • visits_50k.p: This is a Python cPickle file. It is a List of encounter records, where each record is a List of a diagnosis code and the associated treatment-lab Lists. For example, [[1, []], [2, [[3, []], [4, [5, 6]]]]] describes a single encounter record. The first diagnosis code is “1”, and no treatment or lab codes follow it. The second diagnosis code is “2”, and the treatments “3” and “4” are ordered because of the diagnosis “2”. Additionally, treatment “4” is followed by two lab codes, “5” and “6”. visits_50k.p consists of 50,000 encounter records that follow this format.

  • dx_probs.npy: This Python Numpy file corresponds to p(d) in Algorithm 1. It is a 1000-dimensional vector, where the i-th element represents p(d_i).

  • dx_dx_cond_probs.npy: This Python Numpy file corresponds to p(d′|d) in Algorithm 1. It is a 1000-by-1000 matrix, where the (i, j)-th element represents p(d_j | d_i).

  • dx_proc_cond_probs.npy: This Python Numpy file corresponds to p(m|d) in Algorithm 1. It is a 1000-by-1000 matrix, where the (i, j)-th element represents p(m_j | d_i).

  • dx_proc_lab_cond_probs.npy: This Python Numpy file corresponds to p(r|m, d) in Algorithm 1. It is a 1000-by-1000-by-1000 3D tensor, where the (i, j, k)-th element represents p(r_k | m_j, d_i).

  • dx_dx_probs.npy: This Python Numpy file corresponds to the Bernoulli stopping probabilities for diagnosis codes in Algorithm 1. It is a 1000-dimensional vector, where the i-th element represents the probability of sampling another diagnosis code given d_i.

  • multi_proc_probs.npy: This Python Numpy file corresponds to the Bernoulli stopping probabilities for treatment codes in Algorithm 1. It is a 1000-dimensional vector, where the i-th element represents the probability of sampling a treatment code given d_i.

  • multi_lab_probs.npy: This Python Numpy file corresponds to the Bernoulli stopping probabilities for lab codes in Algorithm 1. It is a 1000-by-1000 matrix, where the (i, j)-th element represents the probability of sampling a lab code given treatment m_i and diagnosis d_j.

Note that the two diagnosis-treatment pairs discussed in Appendix B correspond to actual code indices in the vocabulary. Accordingly, the corresponding two elements of dx_proc_cond_probs.npy equal approximately 0.2 and 0.8, respectively.
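A small Python sketch for loading these files (paths are relative to the decompressed archive; since the pickle was presumably written with Python 2’s cPickle, encoding='latin1' may be needed under Python 3):

```python
import pickle
import numpy as np

def load_visits(path="visits_50k.p"):
    """Load the nested encounter list; see the format described above."""
    with open(path, "rb") as f:
        return pickle.load(f, encoding="latin1")

def load_table(path):
    """Load one of the probability tables, e.g. 'dx_probs.npy'."""
    return np.load(path)

def flatten(record):
    """Collect the diagnosis, treatment, and lab code sets of one record."""
    dxs, txs, labs = set(), set(), set()
    for d, tx_list in record:
        dxs.add(d)
        for m, lab_list in tx_list:
            txs.add(m)
            labs.update(lab_list)
    return dxs, txs, labs
```

For instance, flattening the example record [[1, []], [2, [[3, []], [4, [5, 6]]]]] yields the diagnosis set {1, 2}, the treatment set {3, 4}, and the lab set {5, 6}.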
