Transfer Learning for Clinical Time Series Analysis using Recurrent Neural Networks
Deep neural networks have shown promising results for various clinical prediction tasks such as diagnosis, mortality prediction, and predicting the length of hospital stay. However, training deep networks, such as those based on Recurrent Neural Networks (RNNs), requires large labeled datasets, high computational resources, and significant hyperparameter tuning effort. In this work, we investigate to what extent transfer learning can address these issues when using deep RNNs to model multivariate clinical time series. We consider transferring the knowledge captured in an RNN trained on several source tasks simultaneously using a large labeled dataset to build the model for a target task with limited labeled data. An RNN pre-trained on several tasks provides generic features, which are then used to build simpler linear models for new target tasks without training task-specific RNNs. For evaluation, we train a deep RNN to identify several patient phenotypes on time series from the MIMIC-III database, and then use the features extracted using that RNN to build classifiers for identifying previously unseen phenotypes, and also for the seemingly unrelated task of in-hospital mortality prediction. We demonstrate that (i) models trained on features extracted using the pre-trained RNN outperform or, in the worst case, perform as well as task-specific RNNs; (ii) the models using features from pre-trained models are more robust to the size of labeled data than task-specific RNNs; and (iii) features extracted using the pre-trained RNN are generic and perform better than typical statistical hand-crafted features.
Electronic health records (EHR), consisting of a patient's medical history, can be leveraged for various clinical applications such as diagnosis, medication recommendation, etc. Traditional machine learning techniques often require careful domain-specific feature engineering before building the prediction models. On the other hand, deep learning approaches enable end-to-end learning without the need for hand-crafted and domain-specific features, and have recently produced promising results for various clinical prediction tasks (Miotto et al., 2017; Ravì et al., 2017). Applications of such approaches include medical diagnosis (Choi et al., 2016; Lipton et al., 2015), predicting future clinical events (Miotto et al., 2016), etc.
Deep Recurrent Neural Networks (RNNs) have been successfully explored for various time series and sequential modeling applications of EHR data such as diagnosis (Lipton et al., 2015; Che et al., 2016; Choi et al., 2016), mortality prediction and estimating length of stay (Harutyunyan et al., 2017; Purushotham et al., 2017; Rajkomar et al., 2018). However, training RNNs is compute-intensive due to the sequential nature of the computations, and requires a large amount of labeled data. Transfer learning (Pan and Yang, 2010; Bengio, 2012) has been used to overcome these challenges. It enables knowledge transfer from neural networks trained on a source task with sufficient training instances to a related target task with few training instances. Moreover, fine-tuning a pre-trained network for a target task is often faster and easier than constructing and training a new network from scratch (Bengio, 2012; Malhotra et al., 2017). Another advantage of learning in this manner is that the pre-trained network has already learned to extract a rich set of generic features that can then be applied to a wide range of other similar tasks (Malhotra et al., 2017; Gupta et al., 2018).
Transfer learning via fine-tuning the parameters of pre-trained models for end tasks has recently been considered for medical applications, e.g. (Choi et al., 2016; Lee et al., 2017). However, fine-tuning a large number of parameters with a small labeled dataset may cause overfitting. If the number of parameters to be tuned for the target task can be kept small, then pre-trained deep models can be leveraged more effectively (Keshari et al., 2018). In this work, we evaluate an approach that transfers knowledge learned on a set of tasks to another related task for clinical time series by means of an RNN. Treating the detection of each phenotype from time series of physiological parameters as a binary classification task, we train an RNN classifier on a diverse set of such binary classification tasks (one task per phenotype) simultaneously using a large labeled dataset, so that the resulting RNN provides general-purpose features for time series. The features extracted using this RNN are then transferred to train a simple logistic regression model (Hosmer Jr et al., 2013) for target tasks, i.e. identifying a new phenotype and predicting in-hospital mortality, with few labeled instances (detailed in Section 3).
Through empirical evaluation on the MIMIC-III dataset (Johnson et al., 2016) (detailed in Section 4), we demonstrate that: 1) it is possible to leverage deep RNNs for clinical time series classification tasks in scarcely labeled scenarios via transfer learning; and 2) a deep model trained on multiple diverse tasks on a large labeled dataset provides features that are generic enough to build models for new tasks from clinical time series data. Further, our approach provides a computationally efficient way to use deep models for new phenotypes once an RNN has been trained to classify a sufficiently diverse set of phenotypes.
2. Related Work
Unsupervised pre-training has been shown to be effective in capturing generic patterns and distributions from EHR data (Miotto et al., 2016). Further, RNNs for time series classification from EHR data have been successfully explored, e.g. in (Lipton et al., 2015; Che et al., 2016). However, these approaches do not address the challenge posed by limited labeled data, which is the focus of this work. Transfer learning using deep neural networks has recently been explored for medical applications: a model learned at one hospital could be adapted to another hospital for the same task via recurrent neural networks (Choi et al., 2016), and Lee et al. (2017) used a deep neural network to transfer knowledge from one dataset to another for the same task (named-entity recognition from medical records). However, in both these transfer learning approaches, the source and target tasks are the same and only the dataset changes.
Features extracted from a pre-trained off-the-shelf RNN-based feature extractor, TimeNet (Malhotra et al., 2017), have been shown to be useful for patient phenotyping and mortality prediction tasks (Gupta et al., 2018). In this work, we provide an approach to transfer a model trained on several healthcare-specific tasks to a different (although related) classification task using RNNs for clinical time series. Training a deep RNN for multiple related tasks simultaneously on clinical time series has been shown to improve performance on all tasks (Harutyunyan et al., 2017). We additionally demonstrate that a model trained in this manner serves as a reasonable starting point for building models for new related tasks.
3. Proposed Approach
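Consider sets $\mathcal{D}_S$ and $\mathcal{D}_{\mathcal{T}}$ of labeled time series instances corresponding to source ($S$) and target ($\mathcal{T}$) tasks, respectively. $\mathcal{D}_S = \{(\mathbf{x}^{(i)}, \mathbf{y}_S^{(i)})\}_{i=1}^{N_S}$, where $N_S$ is the number of time series instances corresponding to patients (in our experiments, we consider each episode of hospital stay for a patient as a separate data instance). Denoting a time series by $\mathbf{x}$ and the corresponding target label by $\mathbf{y}$ for simplicity of notation, let $\mathbf{x} = \mathbf{x}_1 \mathbf{x}_2 \ldots \mathbf{x}_T$ denote a time series of length $T$, where $\mathbf{x}_t \in \mathbb{R}^d$ is a $d$-dimensional vector corresponding to $d$ parameters such as glucose level, heart rate, etc. Further, $\mathbf{y}_S \in \{0,1\}^K$, where $K$ is the number of binary classification tasks. For example, for $K = 5$ binary classification tasks corresponding to the presence or absence of 5 phenotypes, $\mathbf{y}_S = [1, 0, 1, 1, 0]$ indicates that phenotypes 1, 3, and 4 are present while phenotypes 2 and 5 are absent. $\mathcal{D}_{\mathcal{T}} = \{(\mathbf{x}^{(i)}, y_{\mathcal{T}}^{(i)})\}_{i=1}^{N_{\mathcal{T}}}$ such that $N_{\mathcal{T}} \ll N_S$, and $y_{\mathcal{T}} \in \{0,1\}$, i.e. the target task is a binary classification task. We assume that the time series in $\mathcal{D}_{\mathcal{T}}$ cover the same parameters as those in $\mathcal{D}_S$. We first train the deep RNN on the source tasks using $\mathcal{D}_S$, and then train the simpler logistic regression (LR) classifier for the target task using $\mathcal{D}_{\mathcal{T}}$ and the features obtained via the deep RNN, as shown in Figure 1. We next provide details of training the RNN and LR models.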
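The multi-hot label convention for the source labels can be illustrated with a small Python sketch. The helper name `multi_hot` and the 1-based phenotype indices are illustrative assumptions, not part of the described method.

```python
def multi_hot(present, num_tasks):
    """Return the K-dimensional 0/1 source label vector y_S.

    `present` holds the 1-based indices of the phenotypes that are present
    (hypothetical helper for illustration only).
    """
    if any(not 1 <= k <= num_tasks for k in present):
        raise ValueError("phenotype index out of range")
    return [1 if k in present else 0 for k in range(1, num_tasks + 1)]


# K = 5 tasks: phenotypes 1, 3 and 4 are present, 2 and 5 are absent.
y_s = multi_hot({1, 3, 4}, num_tasks=5)
print(y_s)  # [1, 0, 1, 1, 0]
```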
3.1. Supervised Pre-training of RNN
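Training an RNN on $K$ binary classification tasks simultaneously can be considered as a multi-label classification problem. We train a multi-layered RNN with $L$ recurrent layers having Gated Recurrent Units (GRUs) (Cho et al., 2014) to map $\mathbf{x}$ to $\mathbf{y}_S$. Let $\mathbf{h}_t^{l} \in \mathbb{R}^{h}$ denote the output of the recurrent units in the $l$-th hidden layer at time $t$, and let $\mathbf{z}_t = [\mathbf{h}_t^{1}; \mathbf{h}_t^{2}; \ldots; \mathbf{h}_t^{L}] \in \mathbb{R}^{Lh}$ denote the hidden state at time $t$ obtained as the concatenation of the hidden states of all layers, where $h$ is the number of GRU units in a hidden layer and $l = 1, \ldots, L$. The parameters of the network are obtained by minimizing the cross-entropy loss $C_S$ via stochastic gradient descent:

$$C_S = -\frac{1}{N_S} \sum_{i=1}^{N_S} \sum_{k=1}^{K} \left[ y_{S,k}^{(i)} \log \hat{y}_{S,k}^{(i)} + \left(1 - y_{S,k}^{(i)}\right) \log\left(1 - \hat{y}_{S,k}^{(i)}\right) \right] \qquad (1)$$

Here $\sigma(a) = 1/(1 + e^{-a})$ is the sigmoid activation function, $\hat{\mathbf{y}}_S = \sigma(\mathbf{W}_C \mathbf{z}_T + \mathbf{b}_C)$ (with $\sigma$ applied element-wise) is the estimate for target $\mathbf{y}_S$, $\mathbf{W}_R$ are the parameters of the recurrent layers, and $\mathbf{W}_C$ and $\mathbf{b}_C$ are the parameters of the classification layer.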
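As a sanity check on Eq. (1), the following minimal sketch computes $C_S$ for given label and probability matrices, assuming that the per-task binary cross-entropies are summed over the $K$ tasks and averaged over the $N_S$ instances. The clipping constant `eps` and the function name are assumptions added for numerical safety and illustration; in practice a framework's built-in binary cross-entropy would be used inside the SGD loop.

```python
import math


def multilabel_bce(y_true, y_prob, eps=1e-12):
    """Multi-label cross-entropy: sum over K tasks, mean over N instances.

    y_true: N rows of K labels in {0, 1}; y_prob: N rows of K sigmoid outputs.
    """
    total = 0.0
    for y_row, p_row in zip(y_true, y_prob):
        for y, p in zip(y_row, p_row):
            p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
            total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(y_true)


# Two instances, K = 3 tasks.
y_true = [[1, 0, 1], [0, 0, 1]]
y_prob = [[0.9, 0.2, 0.7], [0.1, 0.3, 0.6]]
loss = multilabel_bce(y_true, y_prob)
```

Because each task has its own sigmoid output, the $K$ terms are independent binary losses; this is what distinguishes the multi-label setting from a single softmax over phenotypes.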
3.2. Using features from pre-trained RNN
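For input $\mathbf{x}$, the hidden state $\mathbf{z}_T$ at the last time step $T$ is used as the input feature vector for training the LR model. We obtain the probability of the positive class for the binary classification task as $\hat{y}_{\mathcal{T}} = \sigma(\mathbf{w}^\top \mathbf{z}_T + b)$, where $\mathbf{w} \in \mathbb{R}^{Lh}$ and $b \in \mathbb{R}$ are the parameters of LR. The parameters are obtained by minimizing the negative log-likelihood loss $C_{\mathcal{T}}$:

$$C_{\mathcal{T}} = -\frac{1}{N_{\mathcal{T}}} \sum_{i=1}^{N_{\mathcal{T}}} \left[ y_{\mathcal{T}}^{(i)} \log \hat{y}_{\mathcal{T}}^{(i)} + \left(1 - y_{\mathcal{T}}^{(i)}\right) \log\left(1 - \hat{y}_{\mathcal{T}}^{(i)}\right) \right] + \lambda \lVert \mathbf{w} \rVert_1 \qquad (2)$$

where $\lVert \mathbf{w} \rVert_1$ is the L$_1$ regularizer and $\lambda$ controls the extent of sparsity: a higher $\lambda$ implies more sparsity, i.e. fewer features from the representation vector are selected for the final classifier. Note that training the LR model on pre-trained RNN features in this way is equivalent to freezing the parameters of all hidden layers of the pre-trained RNN while tuning the parameters of a new final classification layer. The sparsity constraint means that only a small number of parameters effectively need to be tuned, which helps avoid overfitting when labeled data is scarce.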
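The text does not say which solver is used to minimize Eq. (2). One possible choice, sketched below in plain Python, is proximal gradient descent (ISTA): a gradient step on the average negative log-likelihood followed by soft-thresholding of $\mathbf{w}$, which sets small weights exactly to zero and so shows how the L$_1$ term produces sparsity. The step size, iteration count and function names are hypothetical, and the RNN features are assumed to be precomputed and frozen.

```python
import math


def sigmoid(a):
    # Numerically stable logistic function.
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def soft_threshold(v, t):
    # Proximal operator of t * |v|: shrink v towards 0, exactly 0 if |v| <= t.
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def fit_l1_logreg(Z, y, lam, step=0.5, iters=2000):
    """Minimise mean NLL + lam * ||w||_1 (cf. Eq. 2) with ISTA.

    Z holds frozen feature vectors (e.g. z_T from the pre-trained RNN),
    y holds 0/1 labels. The bias b is not penalised.
    """
    n, dim = len(Z), len(Z[0])
    w, b = [0.0] * dim, 0.0
    for _ in range(iters):
        grad_w, grad_b = [0.0] * dim, 0.0
        for z, label in zip(Z, y):
            err = sigmoid(sum(wj * zj for wj, zj in zip(w, z)) + b) - label
            for j in range(dim):
                grad_w[j] += err * z[j] / n
            grad_b += err / n
        # Gradient step on the smooth NLL term, then the L1 proximal step.
        w = [soft_threshold(wj - step * gj, step * lam) for wj, gj in zip(w, grad_w)]
        b -= step * grad_b
    return w, b


# Toy data: feature 0 separates the classes, feature 1 is uninformative.
Z = [[1.0, 1.0], [2.0, -1.0], [-1.0, 1.0], [-2.0, -1.0]]
y = [1, 1, 0, 0]
w, b = fit_l1_logreg(Z, y, lam=0.01)
```

Raising `lam` zeroes out more weights; with a large enough value every entry of `w` stays at 0, mirroring the role of $\lambda$ described above. Library solvers such as scikit-learn's `LogisticRegression(penalty="l1", solver="liblinear")` fit the same kind of model, although they express the regularization strength through an inverse parameter `C`.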
4. Experimental Evaluation
We evaluate the proposed approach on binary classification tasks as defined in (Harutyunyan et al., 2017): i) estimating the presence (class 1) or absence (class 0) of a phenotype (e.g. cardiac dysrhythmia, chronic kidney disease, etc.) from time series of parameters such as heart rate and respiratory rate, and ii) in-hospital mortality prediction where the goal is to predict whether the patient will survive or not given time series observations after ICU admission (class 1: patient dies, class 0: patient survives).
We use the MIMIC-III (v1.4) clinical database (Johnson et al., 2016), which consists of over 60,000 ICU stays across 40,000 critical care patients. We use the benchmark data from (Harutyunyan et al., 2017) with the same splits into train, validation and test datasets.¹ Train, validation and test sets for the various scenarios considered in our experiments are subsets of the respective original datasets (as described later). The data contains multivariate time series of physiological parameters, with 12 real-valued parameters (e.g. blood glucose level, systolic blood pressure, etc.) and 5 categorical parameters (e.g. Glasgow coma scale motor response, Glasgow coma scale verbal response, etc.), sampled at 1-hour intervals. The categorical variables are converted to one-hot vectors such that the final multivariate time series has dimension $d$. We use time series from only up to the first $T$ hours of ICU stay for all predictions, to imitate the practical scenario where early predictions are important.

¹ Refer to (Harutyunyan et al., 2017) and https://github.com/yerevann/mimic3-benchmarks for dataset sizes and other details.
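A minimal sketch of the input encoding described above, assuming each hourly observation arrives as a list of real values plus a list of categorical values: each categorical parameter is expanded into a one-hot block and only the first $T$ hours are kept. The category levels, example values and helper names are hypothetical, and the missing-value handling of the benchmark preprocessing is not shown.

```python
def one_hot(value, levels):
    # One-hot block for one categorical parameter with ordered `levels`.
    return [1.0 if value == level else 0.0 for level in levels]


def encode_step(real_values, cat_values, cat_levels):
    """Build x_t: real-valued parameters followed by one-hot categorical blocks."""
    vec = list(real_values)
    for value, levels in zip(cat_values, cat_levels):
        vec.extend(one_hot(value, levels))
    return vec


def encode_series(steps, cat_levels, max_hours):
    # Keep only the first `max_hours` hourly observations, i.e. x_1, ..., x_T.
    return [encode_step(r, c, cat_levels) for r, c in steps[:max_hours]]


# Hypothetical example: 2 real-valued parameters and 1 categorical
# parameter with 3 levels, so d = 2 + 3 = 5.
levels = [["low", "mid", "high"]]
steps = [([98.0, 80.0], ["mid"]),
         ([97.5, 85.0], ["high"]),
         ([97.0, 90.0], ["low"])]
x = encode_series(steps, levels, max_hours=2)
```

The resulting dimension $d$ is the number of real-valued parameters plus the total number of categorical levels, which is why one-hot encoding increases the input size beyond the 17 raw parameters.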
The benchmark dataset contains label information for the presence/absence of 25 phenotypes common in adult ICUs (e.g. acute cerebrovascular disease, diabetes mellitus with complications, gastrointestinal hemorrhage, etc.). We consider $K = 20$ phenotypes to obtain the pre-trained RNN, which we refer to as MIMIC-Net (MN), and test the transferability of the features from MN to the remaining 5 phenotype (binary) classification tasks with varying labeled data sizes. Since more than one phenotype may be present in a patient at a time, we remove all patients with any of the 5 test phenotypes from the original train and validation sets (even if they also have one of the 20 train phenotypes) to avoid any information leakage. We report average results in terms of weighted AUROC (as in (Harutyunyan et al., 2017)) on two random splits of 20 train phenotypes and 5 test phenotypes, such that we have 10 test phenotypes (tested one at a time). We also test the transferability of MN features to the in-hospital mortality prediction task.
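The phenotype splits and the leakage filter described above can be sketched as follows. This assumes that the two test sets are drawn to be disjoint (so that two splits yield 10 distinct test phenotypes) and that each episode carries its full 25-dimensional 0/1 phenotype vector; the function names and the seed are hypothetical.

```python
import random


def disjoint_splits(num_phenotypes, num_test, num_splits, seed):
    """Return [(source_ids, test_ids), ...] with pairwise disjoint test sets."""
    ids = list(range(num_phenotypes))
    random.Random(seed).shuffle(ids)
    splits = []
    for s in range(num_splits):
        test = set(ids[s * num_test:(s + 1) * num_test])
        source = [k for k in range(num_phenotypes) if k not in test]
        splits.append((source, sorted(test)))
    return splits


def remove_leakage(episodes, test_ids):
    """Drop every episode in which any test phenotype is present.

    `episodes` is a list of (series, labels) pairs, where `labels` is the
    25-dimensional 0/1 phenotype vector of that episode.
    """
    return [(x, labels) for x, labels in episodes
            if not any(labels[k] for k in test_ids)]


def source_labels(labels, source_ids):
    # Restrict the full label vector to the K = 20 source tasks (y_S).
    return [labels[k] for k in source_ids]


splits = disjoint_splits(25, 5, num_splits=2, seed=0)
source_ids, test_ids = splits[0]
toy = [("ep_a", [0] * 25),
       ("ep_b", [1 if k == test_ids[0] else 0 for k in range(25)])]
kept = remove_leakage(toy, test_ids)  # "ep_b" is removed
```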
We consider $L = 2$ hidden layers, a batch size of 128, regularization using dropout (Pham et al., 2014) with a dropout factor of 0.3, and the Adam optimizer (Kingma and Ba, 2014) for training RNNs. The number of hidden units $h$ with the minimum $C_S$ (Eq. 1) on the validation set is chosen from a set of candidate values. The best MN model was obtained for $h = 300$, such that the total number of features is $Lh = 600$. The L$_1$ regularization parameter $\lambda$ is tuned on a logarithmic scale to minimize $C_{\mathcal{T}}$ (Eq. 2) on the validation set.
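The model selection above is a plain grid search: $h$ is chosen by the validation value of $C_S$, and $\lambda$ by the validation value of $C_{\mathcal{T}}$ over a logarithmic grid. The sketch below assumes hypothetical grid bounds and a stand-in validation-loss curve, since the actual grids are not given here; it also checks the feature-count arithmetic $Lh = 2 \times 300 = 600$.

```python
import math


def log_grid(lo_exp, hi_exp):
    # Candidate values 10**lo_exp, ..., 10**hi_exp (a logarithmic scale).
    return [10.0 ** e for e in range(lo_exp, hi_exp + 1)]


def select_best(candidates, val_loss):
    """Return the candidate with the smallest validation loss."""
    return min(candidates, key=val_loss)


# Feature dimension of z_T for the best MN model: L layers of h GRU units.
num_layers, hidden_units = 2, 300
num_features = num_layers * hidden_units


def toy_val_loss(lam):
    # Hypothetical stand-in for C_T on the validation set (minimum at 1e-2).
    return (math.log10(lam) + 2.0) ** 2


best_lam = select_best(log_grid(-4, 1), toy_val_loss)
```

The same `select_best` helper applies to choosing $h$, with the validation $C_S$ of each trained RNN as the key.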
| Phenotyping² | 0.902 ± 0.023 | 0.955 ± 0.020 | 0.974 ± 0.011 |

² The average and standard deviation over 10 phenotypes are reported.
Results and Observations
We refer to the LR model learned using MN features as MN-LR, and consider two baselines for comparison: 1) logistic regression (LR) using statistical features (including mean, standard deviation, etc.) computed from the raw time series, as used in (Harutyunyan et al., 2017); and 2) an RNN classifier (RNN-C) learned using the training data for the target task. To test the robustness of the models to small labeled training sets, we consider subsets of the training and validation datasets, while the test set remains the same. Further, we also evaluate the relevance of layer-wise features from the hidden layers. MN-LR-1 and MN-LR-2 refer to models trained using $\mathbf{h}_T^{2}$ (the topmost hidden layer only) and $\mathbf{z}_T = [\mathbf{h}_T^{1}; \mathbf{h}_T^{2}]$ (both hidden layers), respectively.
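MN-LR-1 and MN-LR-2 differ only in which final hidden states are passed to LR. The plain-Python sketch below runs a two-layer stacked GRU (gate equations following Cho et al., 2014, with biases omitted for brevity) over a toy series and extracts both feature vectors. The weights are random rather than pre-trained, and all sizes and names are hypothetical; the sketch only illustrates shapes and the concatenation $\mathbf{z}_T = [\mathbf{h}_T^{1}; \mathbf{h}_T^{2}]$.

```python
import math
import random


def sigmoid(a):
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rand_matrix(rows, cols, rng, scale=0.1):
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]


class GRULayer:
    """One GRU layer; random (untrained) weights, no biases."""

    def __init__(self, in_dim, hidden, rng):
        self.hidden = hidden
        # Input (W) and recurrent (U) weights for the update gate "z",
        # reset gate "r" and candidate state "c".
        self.W = {g: rand_matrix(hidden, in_dim, rng) for g in "zrc"}
        self.U = {g: rand_matrix(hidden, hidden, rng) for g in "zrc"}

    def step(self, x, h):
        z = [sigmoid(a + b) for a, b in zip(matvec(self.W["z"], x), matvec(self.U["z"], h))]
        r = [sigmoid(a + b) for a, b in zip(matvec(self.W["r"], x), matvec(self.U["r"], h))]
        rh = [ri * hi for ri, hi in zip(r, h)]
        c = [math.tanh(a + b) for a, b in zip(matvec(self.W["c"], x), matvec(self.U["c"], rh))]
        # h_t = z * h_{t-1} + (1 - z) * candidate
        return [zi * hi + (1 - zi) * ci for zi, hi, ci in zip(z, h, c)]


def final_states(layers, series):
    """Run the stacked GRU over x_1..x_T and return [h_T^1, ..., h_T^L]."""
    states = [[0.0] * layer.hidden for layer in layers]
    for x in series:
        inp = x
        for l, layer in enumerate(layers):
            states[l] = layer.step(inp, states[l])
            inp = states[l]  # output of layer l is the input of layer l + 1
    return states


rng = random.Random(0)
d, h = 4, 3  # toy sizes (hypothetical)
layers = [GRULayer(d, h, rng), GRULayer(h, h, rng)]
series = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(5)]
h_T = final_states(layers, series)
feat_mn_lr_1 = h_T[-1]                      # top layer only: h_T^2
feat_mn_lr_2 = [v for s in h_T for v in s]  # concatenation z_T
```

With $h = 300$ units per layer, the same construction gives the 300 and 600 input features used by MN-LR-1 and MN-LR-2.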
Robustness to training data size: Phenotyping results in Figure 2(a) suggest that: (i) MN-LR and RNN-C perform equally well when using 100% of the training data, and both are better than LR. This implies that the transfer-learning-based models are as effective as models trained specifically for the target task on large labeled datasets. (ii) MN-LR consistently outperforms RNN-C and LR as the training dataset is reduced. As the size of the labeled training set decreases, the performance of both RNN-C and MN-LR degrades; importantly, however, MN-LR degrades more gracefully and performs better than RNN-C. The performance gains from transfer learning are greater when the training set of the target task is small. Therefore, with transfer learning, fewer labeled instances are needed to achieve the same level of performance as a model trained on target data alone. (iii) As the labeled training set is reduced, LR performs better than RNN-C, consistent with deep networks being prone to overfitting on small datasets.
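The comparisons above are made in terms of AUROC. For a single binary task, AUROC equals the probability that a randomly chosen positive instance receives a higher score than a randomly chosen negative one, with ties counted as one half; the sketch below computes it directly from that definition. How per-task scores are weighted and averaged follows (Harutyunyan et al., 2017) and is not reproduced here; the function name is illustrative.

```python
def auroc(labels, scores):
    """AUROC from its pairwise definition (ties count 1/2).

    O(P * N) for clarity rather than speed.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# 3 of the 4 positive/negative pairs are ranked correctly.
score = auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```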
Interestingly, Figure 2(b) shows that MN-LR results are at least as good as those of RNN-C and LR on the seemingly unrelated task of mortality prediction, suggesting that the learned features are generic enough to transfer well.
Importance of features from different hidden layers: MN-LR-1 and MN-LR-2 perform equally well on the phenotyping task (Figure 2(a)), suggesting that adding features from the lower hidden layer does not improve performance given the higher-layer features. For the mortality prediction task, we observe a slight improvement of MN-LR-2 over MN-LR-1, i.e. adding lower-layer features helps. A possible explanation is as follows: since MN was trained on phenotyping tasks, features from the top-most layer suffice for new phenotypes as well; on the other hand, the more generic features from the lower layer are useful for the unrelated task of mortality prediction.
Number of relevant features for a task: We observe that only a small number of features are actually relevant for a target classification task out of the large number of input features to the LR models (714 for LR, 300 for MN-LR-1, and 600 for MN-LR-2). As shown in Table 1, 95% of features have weight 0 (absolute value < 0.001) for MN-LR models corresponding to phenotyping tasks due to the sparsity constraint (Eq. 2), i.e. most features do not contribute to the classification decision. The weights of features that are non-zero for at least one of the target tasks for MN-LR-1 are shown in Figure 3. For example, for the MN-LR-1 model only 130 features (out of 300) are relevant across the 10 phenotype classification tasks and the mortality prediction task. This suggests that MN provides several generic features, while LR learns to select the most relevant ones given a small labeled dataset. Table 1 and Figure 3 also suggest that MN-LR models use a larger number of features for the mortality prediction task, possibly because concise features for mortality prediction are not available in the learned feature set, as MN was pre-trained for phenotype identification tasks.
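The sparsity statistics above (the share of zero weights per task, and the number of features that are non-zero for at least one task) can be computed directly from fitted LR weight vectors. The sketch below uses the same 0.001 threshold as the text; the function names and toy weight vectors are hypothetical.

```python
def fraction_zero(w, tol=1e-3):
    # Share of LR weights treated as zero (absolute value below tol).
    return sum(abs(v) < tol for v in w) / len(w)


def union_nonzero(weight_vectors, tol=1e-3):
    """Indices of features with a non-zero weight in at least one task."""
    used = set()
    for w in weight_vectors:
        used.update(j for j, v in enumerate(w) if abs(v) >= tol)
    return sorted(used)


# Hypothetical weights for two target tasks over 4 features.
w_task1 = [0.0, 0.5, 0.0005, -0.2]
w_task2 = [0.0, 0.0, 0.3, 0.0]
shared = union_nonzero([w_task1, w_task2])  # features used by any task
```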
We have proposed an approach to leverage deep RNNs for small labeled datasets via transfer learning. We trained an RNN model to identify several phenotypes via multi-label classification. This model was found to generalize well to new tasks, including the identification of new phenotypes and, interestingly, mortality prediction. We found that transfer learning performs as well as or better than models trained specifically for the end task, with the largest gains when labeled data is scarce. Such transfer learning approaches can be a good starting point when building models with limited labeled datasets. The transferability and generalization capability of RNNs trained simultaneously on diverse tasks (such as length of stay, mortality prediction, phenotyping, etc. (Harutyunyan et al., 2017; Song et al., 2017)) to new tasks is an interesting future direction.
- Bengio (2012) Yoshua Bengio. 2012. Deep learning of representations for unsupervised and transfer learning. In Proceedings of ICML Workshop on Unsupervised and Transfer Learning. 17–36.
- Che et al. (2016) Zhengping Che, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. 2016. Recurrent neural networks for multivariate time series with missing values. arXiv preprint arXiv:1606.01865 (2016).
- Cho et al. (2014) Kyunghyun Cho, Bart Van Merriënboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. 2014. Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078 (2014).
- Choi et al. (2016) Edward Choi, Mohammad Taha Bahadori, Andy Schuetz, Walter F Stewart, and Jimeng Sun. 2016. Doctor AI: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference. 301–318.
- Gupta et al. (2018) Priyanka Gupta, Pankaj Malhotra, Lovekesh Vig, and Gautam Shroff. 2018. Using features from pre-trained TimeNet for clinical predictions. In The 3rd International Workshop on Knowledge Discovery in Healthcare Data at IJCAI.
- Harutyunyan et al. (2017) Hrayr Harutyunyan, Hrant Khachatrian, David C Kale, and Aram Galstyan. 2017. Multitask Learning and Benchmarking with Clinical Time Series Data. arXiv preprint arXiv:1703.07771 (2017).
- Hosmer Jr et al. (2013) David W Hosmer Jr, Stanley Lemeshow, and Rodney X Sturdivant. 2013. Applied logistic regression. Vol. 398. John Wiley & Sons.
- Johnson et al. (2016) Alistair EW Johnson, Tom J Pollard, et al. 2016. MIMIC-III, a freely accessible critical care database. Scientific Data 3 (2016), 160035.
- Keshari et al. (2018) Rohit Keshari, Mayank Vatsa, Richa Singh, and Afzel Noore. 2018. Learning Structure and Strength of CNN Filters for Small Sample Size Training. arXiv preprint arXiv:1803.11405 (2018).
- Kingma and Ba (2014) Diederik P Kingma and Jimmy Ba. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014).
- Lee et al. (2017) Ji Young Lee, Franck Dernoncourt, and Peter Szolovits. 2017. Transfer Learning for Named-Entity Recognition with Neural Networks. arXiv preprint arXiv:1705.06273 (2017).
- Lipton et al. (2015) Zachary C Lipton, David C Kale, Charles Elkan, and Randall Wetzel. 2015. Learning to diagnose with LSTM recurrent neural networks. arXiv preprint arXiv:1511.03677 (2015).
- Malhotra et al. (2017) Pankaj Malhotra, Vishnu TV, Lovekesh Vig, Puneet Agarwal, and Gautam Shroff. 2017. TimeNet: Pre-trained deep recurrent neural network for time series classification. In 25th European Symposium on Artificial Neural Networks, Computational Intelligence and Machine Learning.
- Miotto et al. (2016) Riccardo Miotto, Li Li, Brian A Kidd, and Joel T Dudley. 2016. Deep Patient: An unsupervised representation to predict the future of patients from the electronic health records. Scientific Reports 6 (2016), 26094.
- Miotto et al. (2017) Riccardo Miotto, Fei Wang, Shuang Wang, Xiaoqian Jiang, and Joel T Dudley. 2017. Deep learning for healthcare: review, opportunities and challenges. Briefings in Bioinformatics (2017).
- Pan and Yang (2010) Sinno Jialin Pan and Qiang Yang. 2010. A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering 22, 10 (2010), 1345–1359.
- Pham et al. (2014) Vu Pham, Théodore Bluche, Christopher Kermorvant, and Jérôme Louradour. 2014. Dropout improves recurrent neural networks for handwriting recognition. In Frontiers in Handwriting Recognition (ICFHR). IEEE, 285–290.
- Purushotham et al. (2017) Sanjay Purushotham, Chuizheng Meng, Zhengping Che, and Yan Liu. 2017. Benchmark of Deep Learning Models on Large Healthcare MIMIC Datasets. arXiv preprint arXiv:1710.08531 (2017).
- Rajkomar et al. (2018) Alvin Rajkomar, Eyal Oren, Kai Chen, Andrew M Dai, Nissan Hajaj, Peter J Liu, Xiaobing Liu, Mimi Sun, Patrik Sundberg, Hector Yee, et al. 2018. Scalable and accurate deep learning for electronic health records. arXiv preprint arXiv:1801.07860 (2018).
- Ravì et al. (2017) Daniele Ravì, Charence Wong, Fani Deligianni, Melissa Berthelot, Javier Andreu-Perez, Benny Lo, and Guang-Zhong Yang. 2017. Deep learning for health informatics. IEEE Journal of Biomedical and Health Informatics 21, 1 (2017), 4–21.
- Song et al. (2017) Huan Song, Deepta Rajan, Jayaraman J Thiagarajan, and Andreas Spanias. 2017. Attend and Diagnose: Clinical Time Series Analysis using Attention Models. arXiv preprint arXiv:1711.03905 (2017).