Uncertainty-Aware Attention for Reliable Interpretation and Prediction

Jay Heo*, Hae Beom Lee*, Saehoon Kim, Juho Lee,
Kwang Joon Kim, Eunho Yang, Sung Ju Hwang
UNIST, Ulsan, South Korea; KAIST, Daejeon, South Korea; AITrics;
Yonsei University College of Medicine, Seoul, South Korea;
University of Oxford, Oxford, England
jayheo7, hblee@unist.ac.kr, sjhwang82, eunhoy@kaist.ac.kr
shkim@aitrics.com, preppie@yuhs.ac, juho.lee@stats.ox.ac.uk
*Equal contribution

Abstract

The attention mechanism is effective both in focusing deep learning models on relevant features and in interpreting them. However, attention may be unreliable, since the networks that generate it are often trained in a weakly-supervised manner. To overcome this limitation, we introduce the notion of input-dependent uncertainty to the attention mechanism, such that it generates attention for each feature with varying degrees of noise based on the given input, learning larger variance on instances it is uncertain about. We learn this Uncertainty-aware Attention (UA) mechanism using variational inference, and validate it on various risk prediction tasks from electronic health records, on which our model significantly outperforms existing attention models. Analysis of the learned attentions shows that our model generates attentions that comply with clinicians’ interpretations, and provides richer interpretation via the learned variance. Further evaluation of both the accuracy of the uncertainty calibration and the prediction performance with the “I don’t know” decision shows that UA yields networks with high reliability as well.


1 Introduction

For many real-world safety-critical tasks, achieving high reliability may be the most important objective when learning predictive models, since incorrect predictions can lead to severe consequences. For instance, failure to correctly predict the sepsis risk of a patient in the ICU may cost the patient’s life. Deep learning models, while having achieved impressive performance on a multitude of real-world tasks such as visual recognition [18, 10], machine translation [2], and risk prediction for healthcare [3, 4], may still be susceptible to such critical mistakes, since most have no notion of predictive uncertainty, often leading to overconfident models [9, 19] that are prone to making mistakes. Even worse, they are very difficult to analyze, due to the multiple layers of non-linear transformations involving large numbers of parameters.

               
(a) Deterministic Attention [3] (b) Stochastic Attention [28] (c) Uncertainty-aware Attention (Ours)
Figure 1: Reliability diagrams [9], which show accuracy as a function of model confidence, generated from RNNs trained for mortality risk analysis from ICU records (PhysioNet-Mortality). ECE [22] in (8) denotes the Expected Calibration Error, which is the weighted-average gap between model confidence and actual accuracy (the gap is shown in green bars). Conventional attention models result in poorly calibrated networks, while our UA yields a well-calibrated one. Such accurately calibrated networks allow us to perform reliable prediction by leveraging the prediction confidence to decide whether to predict or to defer the prediction.

The attention mechanism [2] is an effective means of guiding a model to focus on a partial set of the most relevant features for each input instance. It works by generating (often sparse) coefficients for the given features in an input-adaptive manner, allocating more weight to the features found to be relevant for the given input. The attention mechanism has been shown to significantly improve model performance on machine translation [2] and image annotation [28] tasks. Another important feature of the attention mechanism is that it allows easy interpretation of the model via the generated attention allocations, and one recent work in the healthcare domain [3] focuses on this aspect.

Although interpretable, attention mechanisms are still limited as a means of implementing safe deep learning models for safety-critical tasks, as they are not necessarily reliable. The attention strengths are commonly generated by a model that is trained in a weakly-supervised manner, and could be incorrectly allocated; thus they may not be safe to base the final prediction on. To build a reliable model that can prevent itself from making critical mistakes, we need a model that knows its own limitations: when it is safe to make predictions and when it is not. However, existing attention models cannot handle this issue, as they have no notion of predictive uncertainty. This problem is less of an issue in conventional uses of attention, such as machine translation or image annotation, where we can often find a clear link between the attended parts and the generated output. However, when working with variables that are noisy and may not be one-to-one matched with the prediction, as in the case of risk prediction from electronic health records, overconfident and inaccurate attentions can lead to incorrect predictions (see Figure 1).

To tackle this limitation of conventional attention mechanisms, we propose to let the attention model output uncertainty on each feature (or input) and to further leverage it when making the final prediction. Specifically, we model the attention weights as Gaussian distributions with input-dependent noise, such that the model generates attention with small variance when it is confident about the contribution of the given features, and allocates noisy attention with large variance to uncertain features, for each input. This input-adaptive noise can model heteroscedastic uncertainty [14], which varies from instance to instance and in turn results in uncertainty-based attenuation of the attention strength. We formulate this novel uncertainty-aware attention (UA) model under the Bayesian framework and solve it with variational inference.

We validate UA on tasks such as sepsis prediction in the ICU and disease risk prediction from electronic health records (EHR), which carry a large degree of uncertainty in the input, and on which our model outperforms the baseline attention models by large margins. Further quantitative and qualitative analysis of the learned attentions and their uncertainties shows that our model also provides richer interpretations that align well with clinicians’ interpretations. For further validation of prediction reliability, we evaluate the uncertainty calibration performance, as well as prediction under a scenario where the model can defer the decision by saying “I don’t know”; the results show that UA yields significantly better calibrated networks that can better avoid making incorrect predictions on instances they are uncertain about, compared to baseline attention models.

Our contribution in this paper is threefold:

  • We propose a novel variational attention model with instance-dependent modeling of variance, which captures input-level uncertainty and uses it to attenuate attention strengths.

  • We show that our uncertainty-aware attention yields accurate calibration of model uncertainty, as well as attentions that align well with human interpretations.

  • We validate our model on six real-world risk prediction problems in the healthcare domain, on both the original binary classification task and classification with the “I don’t know” decision, and show that our model obtains significant improvements over existing attention models.

2 Related Work

Prediction reliability

There has been prior work on building reliable deep learning models [29, 13, 14], that is, deep networks that can avoid making incorrect predictions when they are not sufficiently certain about them. To achieve this goal, a model should know the limitations of the data, and of itself. One way to quantify such limitations is to measure the predictive uncertainty using Bayesian models. Recently, [7, 5, 6] showed that deep networks with dropout sampling [25] can be understood as Bayesian neural networks. To obtain better calibrated dropout uncertainties, [16, 8] proposed to automatically learn the dropout rates with proper reparameterization tricks [21, 17]. While the aforementioned works mostly focus on accurate calibration of the uncertainty itself, [14] utilized dropout sampling to model predictive uncertainty in computer vision [13], and also modeled label noise with learned variances to implicitly attenuate the loss on highly uncertain instances. Our work has a similar motivation, but we model the uncertainty in the input data rather than in the labels; by doing so, we can accurately calibrate deep networks for improved reliability. [1] also has a motivation similar to ours, but with different applications and approaches. There exist quite a few works on uncertainty calibration and its quantification: [9] showed that modern deep networks are poorly calibrated despite their accuracy, and proposed to tune factors such as depth, width, and weight decay for better calibration, while [19] proposed ensembles and adversarial training for the same objective.

Attention mechanism

The literature on attention mechanisms is vast; it includes applications to machine translation [2], memory-augmented networks [26], and image annotation [28]. Attention mechanisms are also used for interpretability, as in Choi et al. [3], which proposed an RNN-based attention generator for EHR that can provide attention on both hospital visits and variables for further analysis by clinicians. Attention can be either deterministic or probabilistic, and soft (non-sparse) or hard (sparse). Some probabilistic attention models [28] use variational inference, as our model does. However, while their direct learning of a multinoulli distribution only considers whether to attend or not, without consideration of variance, our attention mechanism models a varying degree of uncertainty for each input via input-dependent learning of the attention noise (variance).

Risk analysis from electronic health records

Our work is mainly motivated by the need to perform reliable risk prediction with electronic health records. There is plentiful prior work on this topic; to mention a few, Choi et al. [3] proposed to predict heart failure risk with attention-generating RNNs, and Futoma et al. [4] proposed to predict sepsis using an RNN, preprocessing the input data with a multitask Gaussian process to resolve the uneven spacing and missing entry problems.

3 Approach

We now describe our uncertainty-aware attention model. Let D = {(x^(n), y^(n))}_{n=1}^N be a dataset containing a set of input data points and the corresponding labels. For notational simplicity, we suppress the data index n when it is clear from the context.

We first present a general framework of a stochastic attention mechanism. Let C = [c_1, …, c_r] be the concatenation of intermediate features from an arbitrary neural network, each column of which is a length-d vector. From C, a set of random variables z = (z_1, …, z_r) is conditionally generated from some distribution p(z | C), where the dimension of each z_i depends on the model architecture. Then, the context vector u is computed as follows:

u = Σ_{i=1}^r z_i ⊙ c_i,

where the operator ⊙ is defined according to the dimensionality of z_i: if z_i is a scalar, it is simply multiplication, while for z_i ∈ R^d it is the element-wise product. The function f produces the prediction ŷ = f(u) given the context vector u.

The attention can be generated either deterministically or stochastically. A stochastic attention mechanism was proposed in [28], where z is generated from a Bernoulli distribution. This variable is learned by maximizing the evidence lower bound (ELBO) with additional regularization to reduce the variance of the gradients. In [28], stochastic attention is shown to perform better than its deterministic counterpart on an image annotation task.
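As a concrete illustration of this general framework, the short Python sketch below computes the context vector from Bernoulli-sampled attention variables in the spirit of [28]; the shapes, names, and random attention probabilities are our own illustration, not any released implementation.

    import numpy as np

    rng = np.random.default_rng(0)

    # Hypothetical intermediate features C = [c_1, ..., c_r]: r columns of length d.
    d, r = 8, 5
    C = rng.normal(size=(d, r))

    # Stochastic hard attention: z_i ~ Bernoulli(p_i), where the probabilities
    # p_i would come from an attention network (random here for illustration).
    p = rng.uniform(size=r)
    z = rng.binomial(1, p)

    # Context vector u = sum_i z_i * c_i, which a predictor f(u) then consumes.
    u = (C * z).sum(axis=1)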

3.1 Stochastic attention with input-adaptive Gaussian noise

Despite the performance improvement in [28], there are two limitations, for our purposes, in modeling stochastic attention directly with a Bernoulli (or multinoulli) distribution as [28] does:

1) The variance of a Bernoulli attention is completely determined by the allocation probability p.

Since the variance of Bernoulli(p) is p(1−p), the model cannot generate attention with low variance when p is around 0.5, and cannot generate it with high variance when p is close to 0 or 1; for example, p = 0.9 forces the variance to be exactly 0.09. To overcome this limitation, we disentangle the attention strength from the attention uncertainty, so that the uncertainty can vary even for the same attention strength.

2) The vanilla stochastic attention models the noise independently of the input.

This makes it infeasible to model the amount of uncertainty for each input, which is a crucial factor for reliable machine learning. Even for the same prediction task and the same set of features, the amount of uncertainty in each feature may vary largely across different instances.

To overcome these two limitations, we model the standard deviation σ, which is indicative of the uncertainty, as an input-adaptive function σ(x; ω), enabling the model to reflect the different amount of confidence it has in each feature of a given instance. As for the distribution, we use the Gaussian, which is arguably the simplest and most efficient choice for our purpose, and is also easy to implement.

We first assume that the subset ω of the neural network parameters associated with generating attentions has a zero-mean isotropic Gaussian prior p(ω) = N(0, τ⁻¹ I) with precision τ. Then the attention scores before squashing, denoted as z, are generated from the conditional distribution p(z | x, ω), which is also Gaussian:

p(z | x, ω) = N( z; μ(x; ω), diag(σ²(x; ω)) )   (1)

where μ(x; ω) and σ(x; ω) are the mean and standard deviation, parameterized by ω. Note that μ and σ are generated from the same layer, but with different sets of parameters, although we denote those parameters collectively as ω. The actual attention a is then obtained by applying some squashing function ς to z (e.g., sigmoid or hyperbolic tangent): a = ς(z). For comparison, one can consider the vanilla stochastic attention, whose variance is independent of the input:

p(z | x, ω) = N( z; μ(x; ω), diag(σ²) )   (2)

However, as mentioned above, this model cannot express different amounts of uncertainty across inputs.
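Below is a minimal PyTorch sketch of the input-adaptive Gaussian attention in (1), assuming a single linear layer for each of μ and log σ²; the module and layer names are our own, not the paper’s released implementation.

    import torch
    import torch.nn as nn

    class UncertaintyAwareAttention(nn.Module):
        """Sketch of Eq. (1): z ~ N(mu(x), sigma^2(x)), squashed to attention."""
        def __init__(self, feature_dim):
            super().__init__()
            self.mu_head = nn.Linear(feature_dim, 1)      # mu(x; omega)
            self.logvar_head = nn.Linear(feature_dim, 1)  # log sigma^2(x; omega)

        def forward(self, c):
            # c: (batch, r, feature_dim) intermediate features
            mu = self.mu_head(c).squeeze(-1)              # (batch, r)
            sigma = torch.exp(0.5 * self.logvar_head(c)).squeeze(-1)
            z = mu + sigma * torch.randn_like(mu)         # reparameterized sample
            a = torch.sigmoid(z)                          # squashing function
            u = (a.unsqueeze(-1) * c).sum(dim=1)          # context vector u
            return u, a

Because σ is produced by its own head, such a module can emit a confident (low-variance) attention near 0.5 or an uncertain (high-variance) attention near 0.9, which the Bernoulli parameterization cannot.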

One important aspect of our model is that, in terms of the graphical representation, the distribution q(ω) is independent of the input, while the distribution p(z | x, ω) is conditioned on x. That is, ω tends to capture the uncertainty of the model parameters (epistemic uncertainty), while z reacts sensitively to uncertainty in the data, varying across different input points (heteroscedastic uncertainty) [14]. It has been shown empirically that modeling the two together improves the quality of the uncertainty estimates [14]. Modeling both input-agnostic and input-dependent uncertainty is especially important in risk analysis tasks in healthcare, to capture both the uncertainty from an insufficient amount of clinical data (e.g., rare diseases) and the uncertainty that varies from patient to patient (e.g., sepsis).

3.2 Variational inference

We now formalize what we have discussed so far. Let z be the set of latent variables standing for the attention weights before squashing. In a neural network, the posterior distribution p(ω, z | D) is computationally intractable due to the nonlinear dependencies between the variables. We thus resort to variational inference, an approximation method that has proven successful in many applications of neural networks [17, 24], along with reparameterization tricks for pathwise backpropagation [16, 8].

Toward this, we first define our variational distribution as

q(ω, z | D) = q(ω) q(z | x)   (3)

We set q(ω) to the dropout approximation [7], whose variational parameters are the deterministic weight matrices. [7] showed that a neural network with a Gaussian prior on its weight matrices can be approximated with variational inference, in the form of dropout sampling of deterministic weight matrices plus weight decay. For the second term, we drop the dependency on y (since it is not available at test time) and simply set q(z | x) to be equivalent to p(z | x, ω), which works well in practice [24, 28].

Under the SGVB framework [17], we maximize the evidence lower bound (ELBO):

L(ω) = Σ_{n=1}^N E_{q(ω) q(z | x^(n))} [ log p( y^(n) | z, x^(n), ω ) ]   (4)
       − KL( q(ω) ‖ p(ω) ) − Σ_{n=1}^N KL( q(z | x^(n)) ‖ p(z | x^(n), ω) )   (5)

where we approximate the expectation in (4) via Monte Carlo sampling. The first KL term nicely reduces to ℓ2 regularization on ω under the dropout approximation [7]. The second KL term vanishes, as the two distributions are equivalent. Consequently, our final maximization objective is:

L(ω) ≈ Σ_{n=1}^N (1/S) Σ_{s=1}^S log p( y^(n) | ẑ^(n,s), x^(n), ω̂^(s) ) − λ ‖ω‖²_2   (6)

where we first sample random weights ω̂^(s) ~ q(ω) with dropout masks, and then sample ẑ^(n,s) = μ(x^(n); ω̂^(s)) + σ(x^(n); ω̂^(s)) ⊙ ε^(s) with ε^(s) ~ N(0, I), which is the pathwise derivative function for the reparameterization trick. λ is a tunable hyperparameter; in practice it can simply be set to a common weight decay shared throughout the network, including the other deterministic weights.
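A minimal sketch of one training step for objective (6), assuming binary labels and a model whose forward pass keeps dropout active (sampling ω) and draws z via the reparameterization trick internally; the helper name and signature are our own.

    import torch
    import torch.nn.functional as F

    def ua_training_loss(model, x, y, num_samples=1, weight_decay=1e-4):
        """Negated Monte Carlo estimate of Eq. (6), to be minimized.
        y: float targets in {0, 1} with the same shape as the model logits."""
        nll = 0.0
        for _ in range(num_samples):
            logits = model(x)  # one joint sample of (omega, z) per forward pass
            nll = nll + F.binary_cross_entropy_with_logits(logits, y.float())
        nll = nll / num_samples
        l2 = sum((p ** 2).sum() for p in model.parameters())  # lambda * ||omega||^2
        return nll + weight_decay * l2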

When testing on a novel input instance x*, we can compute the probability of the label y* under our model with Monte Carlo sampling:

p( y* | x*, D ) ≈ (1/S) Σ_{s=1}^S p( y* | x*, ẑ^(s), ω̂^(s) )   (7)

where we first sample dropout masks ω̂^(s) ~ q(ω) and then sample ẑ^(s) ~ q(z | x*).
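The approximation in (7) can be computed by simply keeping the stochastic components active at test time, as in this sketch (the function name is ours); the sample standard deviation doubles as the uncertainty estimate used later for the “I don’t know” decision.

    import torch

    @torch.no_grad()
    def mc_predict(model, x, num_samples=30):
        """Monte Carlo approximation of Eq. (7): predictive mean and std."""
        model.train()  # keep dropout masks and attention noise stochastic
        probs = torch.stack([torch.sigmoid(model(x)) for _ in range(num_samples)])
        return probs.mean(dim=0), probs.std(dim=0)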

Uncertainty Calibration

The quality of the uncertainty estimates from (7) can be evaluated with the reliability diagrams shown in Figure 1. Better calibrated uncertainties produce smaller gaps between model confidence and actual accuracy, shown as green bars. Perfect calibration occurs when the confidence exactly matches the actual accuracy: P(ŷ = y | confidence = p) = p for all p ∈ [0, 1] [9]. [22, 9] proposed a summary statistic for calibration, called the Expected Calibration Error (ECE). It is the expected gap with respect to the distribution of model confidence (i.e., the frequencies of the confidence bins B_1, …, B_M):

ECE = Σ_{m=1}^M ( |B_m| / n ) | acc(B_m) − conf(B_m) |   (8)
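A small sketch of (8), assuming equal-width confidence bins; `confidences` holds the predicted probability of each predicted class and `correct` marks whether the prediction was right (both hypothetical numpy arrays).

    import numpy as np

    def expected_calibration_error(confidences, correct, num_bins=10):
        """Eq. (8): confidence-weighted average |accuracy - confidence| gap."""
        bins = np.linspace(0.0, 1.0, num_bins + 1)
        n, ece = len(confidences), 0.0
        for lo, hi in zip(bins[:-1], bins[1:]):
            mask = (confidences > lo) & (confidences <= hi)
            if mask.any():
                acc = correct[mask].mean()       # acc(B_m)
                conf = confidences[mask].mean()  # conf(B_m)
                ece += (mask.sum() / n) * abs(acc - conf)
        return ece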

4 Application to classification from time-series data

Our variational attention model is generic and can be applied to any deep neural network that leverages an attention mechanism. In this section, we describe its application to prediction from time-series data, since our target application is risk analysis from electronic health records.

Review of the RETAIN model

As a base deep network for learning from time-series data, we consider RETAIN [3], an attentional RNN model with two types of attention: across timesteps and across features. RETAIN obtains state-of-the-art performance on risk prediction tasks from electronic health records, and is able to provide useful interpretations via the learned attentions.

We now briefly review the overall structure of RETAIN; we match the notation with the original paper for clear reference. Suppose we are interested in a timestep i. From the input embeddings v_1, …, v_i, we generate two different attentions: across timesteps (α) and across features (β):

g_1, …, g_i = RNN_α(v_1, …, v_i),   h_1, …, h_i = RNN_β(v_1, …, v_i)   (9)
e_j = w_α^T g_j + b_α,   d_j = W_β h_j + b_β,   j = 1, …, i   (10)
α_1, …, α_i = Softmax(e_1, …, e_i),   β_j = tanh(d_j),   j = 1, …, i   (11)

The parameters of the two RNNs are collected as ω. From the RNN outputs g_j and h_j, the attention logits e_j and d_j are generated, followed by the squashing functions Softmax and tanh, respectively. The generated attentions α_j and β_j are then multiplied back into the input embedding v_j, followed by a convex sum up to timestep i: c_i = Σ_{j=1}^i α_j β_j ⊙ v_j. A final linear predictor is learned on top of it: ŷ_i = σ(w^T c_i + b).

The most important feature of RETAIN is that it allows us to interpret what the model has learned, as follows. What we are interested in is the contribution, which is the aggregate effect of an input variable x_{j,k} (variable k at timestep j) on the final prediction at time i. Since RETAIN has attentions on both timesteps (α) and features (β), the computation of the aggregate contribution takes both into consideration: ν(x_{j,k}) = α_j w^T (β_j ⊙ W_emb[:, k]) x_{j,k}, where W_emb is the embedding matrix that produces v_j [3]. In other words, it is the portion of the logit for which x_{j,k} is responsible.
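A compact PyTorch sketch of the deterministic RETAIN computation in (9)-(11), with layer sizes and names chosen by us; in the UA variant, the logits e_j and d_j would instead be sampled from the input-adaptive Gaussian of Eq. (1).

    import torch
    import torch.nn as nn

    class RetainStyleAttention(nn.Module):
        """Sketch of Eqs. (9)-(11): timestep (alpha) and feature (beta) attention."""
        def __init__(self, emb_dim, hidden_dim):
            super().__init__()
            self.rnn_alpha = nn.GRU(emb_dim, hidden_dim, batch_first=True)
            self.rnn_beta = nn.GRU(emb_dim, hidden_dim, batch_first=True)
            self.w_alpha = nn.Linear(hidden_dim, 1)        # logits e_j, Eq. (10)
            self.w_beta = nn.Linear(hidden_dim, emb_dim)   # logits d_j, Eq. (10)
            self.predictor = nn.Linear(emb_dim, 1)

        def forward(self, v):
            # v: (batch, timesteps, emb_dim) input embeddings v_1..v_i
            g, _ = self.rnn_alpha(v)                       # Eq. (9)
            h, _ = self.rnn_beta(v)
            alpha = torch.softmax(self.w_alpha(g), dim=1)  # Eq. (11), over timesteps
            beta = torch.tanh(self.w_beta(h))              # Eq. (11), per feature
            c = (alpha * beta * v).sum(dim=1)              # convex sum -> c_i
            return self.predictor(c).squeeze(-1)           # logit for y_i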

Interpretation as a probabilistic model

The interpretation of RETAIN as a probabilistic model is quite straightforward. First, the RNN parameters ω in (9), as the Gaussian latent variables in (1), are approximated with MC dropout with fixed probabilities [7, 5, 27]. The input-dependent latent variables z in (1) simply correspond to the collection of the attention logits e and d in (10). The log-variances of e and d are generated in the same way as their means, from the outputs of the RNNs g and h, but with a different set of parameters. The reparameterization trick for a diagonal Gaussian is straightforward [17]. We then maximize the ELBO (6), equipped with all the components q(ω), q(z | x), μ, and σ as in the previous section.

5 Experiments

We validate the performance of our model on various risk prediction tasks from multiple EHR datasets, for both the prediction accuracy (Section 5.3) and prediction reliability (Section 5.4).

5.1 Tasks and datasets

1) PhysioNet

This dataset [11] contains 4,000 medical records from the ICU (we only use Training Set A, for which the labels were available). Each record contains 48 hours of measurements, with 155 timesteps, each of which contains 36 physiological signals including heart rate, respiration rate, and temperature. The challenge comes with four binary classification tasks, namely: 1) Mortality prediction; 2) Length-of-stay less than 3 days: whether the patient will stay in the ICU for less than three days; 3) Cardiac condition: whether the patient will have a cardiac condition; and 4) Recovery from surgery: whether the patient was recovering from surgery.

2) Pancreatic Cancer

This dataset is a subset of an EHR database consisting of anonymized medical check-up records from 2002 to 2013, which includes around 1.5 million records. We extract patient records from this database, among which are patients diagnosed with pancreatic cancer. The task here is to predict the onset of pancreatic cancer in 2013 using the records from 2002 to 2012, which consist of 34 variables covering general information (e.g., sex, height, past medical history, family history), vital information (e.g., systolic pressure, hemoglobin level, creatinine level), and risk-inducing behaviors (e.g., tobacco and alcohol consumption).

3) MIMIC-Sepsis

This is a subset of the MIMIC-III dataset [12] for sepsis prediction, which consists of 58,000 hospital admissions for 38,646 adults over 12 years. We use a subset consisting of 22,395 records of patients who were over age 15 and stayed in ICUs between 2001 and 2012, among which 2,624 patients were diagnosed with sepsis. We use the data from the first 48 hours after admission (24 timesteps). For the features at each timestep, we select 14 sepsis-related variables including arterial blood pressure, heart rate, FiO2, and the Glasgow Coma Scale (GCS), following clinicians’ guidelines. We use the Sepsis-related Organ Failure Assessment (SOFA) score to determine the onset of sepsis.

For all datasets, we generate five random training/validation/test splits with a fixed ratio. For a more detailed description of the datasets, please see the supplementary file.

5.2 Baselines

We now describe our uncertainty-calibrated attention models and relevant baselines.

1) RETAIN-DA: The recurrent attention model in [3], which uses deterministic soft attention.
2) RETAIN-SA: The RETAIN model with the stochastic hard attention proposed in [28], which models the attention weights with a multinoulli distribution learned by variational inference.
3) UA-independent: The input-independent version of our uncertainty-aware attention model in (2), whose variance is modeled independently of the input.
4) UA: Our input-dependent uncertainty-aware attention model in (1).
5) UA+: The same as UA, but with additional modeling of input-adaptive noise at the final prediction, as done in [14], to account for output uncertainty as well.

For the network configurations and hyperparameters, see the supplementary file. We will also release the code for reproducing the results.

5.3 Evaluation of the binary classification performance

Model | PhysioNet-Mortality | PhysioNet-Stay | PhysioNet-Cardiac | PhysioNet-Recovery | Pancreatic Cancer | MIMIC-Sepsis
RETAIN-DA [3] | 0.7652 ± 0.02 | 0.8515 ± 0.02 | 0.9485 ± 0.01 | 0.8830 ± 0.01 | 0.8528 ± 0.01 | 0.7965 ± 0.01
RETAIN-SA [28] | 0.7635 ± 0.02 | 0.8412 ± 0.02 | 0.9360 ± 0.01 | 0.8582 ± 0.02 | 0.8444 ± 0.01 | 0.7695 ± 0.02
UA-Independent | 0.7764 ± 0.01 | 0.8572 ± 0.02 | 0.9516 ± 0.01 | 0.8895 ± 0.01 | 0.8533 ± 0.03 | 0.8019 ± 0.01
UA | 0.7827 ± 0.02 | 0.8628 ± 0.02 | 0.9563 ± 0.01 | 0.9049 ± 0.01 | 0.8604 ± 0.01 | 0.8017 ± 0.01
UA+ | 0.7770 ± 0.02 | 0.8577 ± 0.01 | 0.9612 ± 0.01 | 0.9074 ± 0.01 | 0.8638 ± 0.02 | 0.8114 ± 0.01
Table 1: Binary classification performance on the three electronic health records datasets. The reported numbers are mean AUROC and standard errors for the 95% confidence interval over five random splits.

We first examine the prediction accuracy of the baselines and our models in a standard setting where the model always makes a decision. Table 1 reports the accuracy of the baselines and our models, measured as the area under the ROC curve (AUROC). We observe that the UA variants significantly outperform both RETAIN variants, with either deterministic or stochastic attention, on all datasets. Note that RETAIN-SA, which generates attention from a Bernoulli distribution, performs the worst. This may be because the model is primarily concerned with whether or not to attend to each feature, which makes sense when most features are irrelevant, as in machine translation, but not in clinical prediction, where most of the variables are important. UA-independent performs significantly worse than UA or UA+, which demonstrates the importance of input-dependent modeling of the variance. Additional modeling of output uncertainty with UA+ yields a performance gain in most cases.

Time | MechVent | DiasABP | HR | Temp | SysABP | FiO2 | MAP | Urine | GCS
35m 5s | 0 | 81 | 61 | 36.2 | 135 | 1 | 71 | N/A | 15
38m 10s | 0 | 75 | 64 | 36.7 | 94 | 1 | 74 | N/A | 15
38m 55s (current) | 1 | 67 | 57 | 35.2 | 105 | 1 | 80 | 35 | 10
(a) RETAIN-DA (b) RETAIN-SA (c) UA
Figure 2: Visualization of contributions for a selected patient on the PhysioNet mortality prediction task. MechVent: mechanical ventilation; DiasABP: diastolic arterial blood pressure; HR: heart rate; Temp: temperature; SysABP: systolic arterial blood pressure; FiO2: fraction of inspired oxygen; MAP: mean arterial blood pressure; Urine: urine output; GCS: Glasgow Coma Scale. The table presents the values of the physiological variables at the previous and current timesteps. Dots correspond to sampled attention weights.

Interpretability and accuracy of generated attentions

To obtain more insight, we further analyze the contribution of each feature on the PhysioNet mortality task in Figure 2, for a patient at the timestep with the highest attention, with the help of a physician. The table in Figure 2 shows the values of the variables at the previous checkpoints and at the current timestep.

The difference between the current and previous timesteps is significant: the patient was put on mechanical ventilation; the body temperature, diastolic arterial blood pressure, and heart rate dropped; and the GCS, which is a measure of consciousness, dropped from 15 to 10. The facts that the patient was put on mechanical ventilation and that the GCS score was lowered are both very important markers for assessing a patient’s condition. Our model correctly attends to those two variables, with very low uncertainty. SysABP and DiasABP are variables whose values change cyclically, and both are within the normal range; however, RETAIN-DA attended to these variables, perhaps because its deterministic model led it to overfit. The heart rate is out of the normal range (60-90), which is problematic but not definitive, and thus UA attended to it with high variance. RETAIN-SA produces largely incorrect and noisy attentions, except for FiO2, which did not change its value. The attention on Urine by all models may be an artifact of the missing entry at the previous timestep; in this case, UA assigned high variance, which shows that it is uncertain about this attention.

The previous example shows another advantage of our model: it provides richer interpretations of why the model has made its predictions, compared to those provided by deterministic or stochastic models without input-dependent modeling of uncertainty. This additional information can be taken into account by clinicians when making a diagnosis, and thus can help with prediction reliability.

Model | Sensitivity (%) | Specificity (%)
RETAIN-DA | 75 | 68
UA | 87 | 82
Table 2: Percentage of features selected by each model that match the features selected by the clinicians.

We further compared UA against RETAIN-DA on the accuracy of the attentions, using the variables judged meaningful by clinicians as ground-truth labels, from the EHRs of a male and a female patient randomly selected from each of 10 age groups (40s-80s), on PhysioNet-Mortality. We observe that UA generates accurate interpretations that better comply with the clinicians’ interpretations (Table 2).

Model | PhysioNet-Mortality | PhysioNet-Stay | PhysioNet-Cardiac | PhysioNet-Recovery | Pancreatic Cancer | MIMIC-Sepsis
RETAIN-DA [3] | 7.23 ± 0.56 | 2.04 ± 0.56 | 5.70 ± 1.56 | 4.89 ± 0.97 | 5.45 ± 0.79 | 3.05 ± 0.56
RETAIN-SA [28] | 7.70 ± 0.60 | 3.77 ± 0.07 | 8.82 ± 0.64 | 5.39 ± 0.80 | 9.69 ± 3.90 | 5.75 ± 0.29
UA-Independent | 5.03 ± 0.94 | 2.74 ± 1.44 | 3.55 ± 0.56 | 4.87 ± 1.46 | 4.51 ± 0.72 | 2.04 ± 0.62
UA | 4.22 ± 0.82 | 1.43 ± 0.53 | 3.33 ± 0.96 | 4.46 ± 0.73 | 3.61 ± 0.55 | 1.78 ± 0.41
UA+ | 4.41 ± 0.52 | 1.68 ± 0.16 | 2.66 ± 0.16 | 3.98 ± 0.59 | 3.22 ± 0.69 | 2.04 ± 0.62
Table 3: Mean Expected Calibration Error (ECE) of various attention models over 5 random splits.
(a) PhysioNet-Mortality (b) PhysioNet-Stay<3 (c) PhysioNet-Cardiac (d) PhysioNet-Recovery (e) Pancreatic Cancer (f) MIMIC-Sepsis
Figure 3: Experiments on prediction reliability. The line charts show the ratio of incorrect predictions as a function of the ratio of correct predictions for all datasets.

5.4 Evaluation of prediction reliability

Another important goal that we aimed to achieve with the modeling of uncertainty in the attention is high reliability in prediction. Prediction reliability is orthogonal to prediction accuracy, and [22] showed that state-of-the-art deep networks are not reliable, as they are not calibrated well enough for model confidence to correlate with actual accuracy. Thus, to demonstrate the reliability of our uncertainty-aware attention, we evaluate its uncertainty calibration against the baseline attention models in Table 3, using the Expected Calibration Error (ECE) [22] in Eq. (8). UA and UA+ are significantly better calibrated than RETAIN-DA and RETAIN-SA, as well as UA-independent, which shows that input-dependent modeling of the variance is essential for obtaining well-calibrated uncertainties.

Prediction with the “I don’t know” option

We further evaluate the reliability of our predictive model by allowing it to say I don’t know (IDK), i.e., to refrain from making a hard decision of yes or no when it is uncertain about its prediction. This ability to defer decisions is crucial for predictive tasks in clinical environments, since the deferred patient records can be given a second-round examination by human clinicians to ensure the safety of the decision. To this end, we measure the uncertainty of each prediction by sampling the variance of the prediction, using both MC-dropout and the stochastic Gaussian noise, over multiple runs, and simply answer IDK for instances whose standard deviation is larger than a set threshold.
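A minimal sketch of this decision rule, assuming a model whose forward pass keeps dropout and the attention noise stochastic; the function name and the use of -1 to encode IDK are our own conventions.

    import torch

    def predict_with_idk(model, x, threshold, num_samples=30):
        """Defer (IDK) whenever the Monte Carlo std of the predicted
        probability exceeds a threshold; otherwise output a hard label."""
        model.train()  # keep dropout masks and Gaussian attention noise on
        with torch.no_grad():
            probs = torch.stack([torch.sigmoid(model(x)) for _ in range(num_samples)])
        mean, std = probs.mean(dim=0), probs.std(dim=0)
        labels = (mean > 0.5).long()   # hard yes/no decision
        labels[std > threshold] = -1   # -1 encodes "I don't know"
        return labels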

Note that we use RETAIN-DA with MC-dropout [5] as the baseline for this experiment, since RETAIN-DA is deterministic and cannot output uncertainty (RETAIN-SA is not compared, since it largely underperforms all other models and is not a meaningful baseline). We report the performance of RETAIN-DA, UA, and UA+ on all tasks by plotting the ratio of incorrect predictions as a function of the ratio of correct predictions, varying the threshold on the model confidence (see Figure 3). We observe that both UA and UA+ output a much smaller ratio of incorrect predictions at the same ratio of correct predictions than RETAIN-DA, by saying IDK on uncertain inputs. This suggests that our models are relatively more reliable and safer to use when making decisions for prediction tasks where incorrect predictions can lead to fatal consequences. Please see the supplementary file for more results and discussion of this experiment.

6 Conclusion

We proposed an uncertainty-aware attention mechanism, which generates attention weights following a Gaussian distribution with learned mean and variance that are decoupled and trained in an input-adaptive manner. This input-adaptive noise modeling allows the model to capture heteroscedastic, instance-specific uncertainty, which in turn yields more accurate calibration of prediction uncertainty. We trained it using variational inference and validated it on six different tasks from three electronic health records datasets, on which it significantly outperformed the baselines and provided more accurate and richer interpretations. Further analysis of prediction reliability shows that our model is accurately calibrated and thus can defer predictions, with the “I don’t know” option, when it is uncertain. As future work, we plan to apply our model to tasks such as image annotation and machine translation.

References

  • (1) M. S. Ayhan and P. Berens. Test-time data augmentation for estimation of heteroscedastic aleatoric uncertainty in deep neural networks. In MIDL, 2018.
  • (2) D. Bahdanau, K. Cho, and Y. Bengio. Neural machine translation by jointly learning to align and translate. In ICLR, 2015.
  • (3) E. Choi, M. T. Bahadori, J. Sun, J. Kulas, A. Schuetz, and W. Stewart. RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism. In NIPS, 2016.
  • (4) J. Futoma, S. Hariharan, and K. A. Heller. Learning to detect sepsis with a multitask Gaussian process RNN classifier. In ICML, 2017.
  • (5) Y. Gal and Z. Ghahramani. A theoretically grounded application of dropout in recurrent neural networks. arXiv preprint, 2015.
  • (6) Y. Gal and Z. Ghahramani. Bayesian convolutional neural networks with Bernoulli approximate variational inference. arXiv preprint, 2015.
  • (7) Y. Gal and Z. Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In ICML, 2016.
  • (8) Y. Gal, J. Hron, and A. Kendall. Concrete dropout. In NIPS, 2017.
  • (9) C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. On calibration of modern neural networks. In ICML, 2017.
  • (10) K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
  • (11) I. Silva, G. Moody, D. J. Scott, L. A. Celi, and R. G. Mark. Predicting in-hospital mortality of ICU patients: The PhysioNet/Computing in Cardiology Challenge 2012. In CinC, 2012.
  • (12) A. E. Johnson, T. J. Pollard, L. Shen, L. H. Lehman, M. Feng, M. Ghassemi, B. Moody, P. Szolovits, L. A. Celi, and R. G. Mark. MIMIC-III, a freely accessible critical care database. Scientific Data, 2016.
  • (13) A. Kendall, V. Badrinarayanan, and R. Cipolla. Bayesian SegNet: Model uncertainty in deep convolutional encoder-decoder architectures for scene understanding. arXiv preprint, 2015.
  • (14) A. Kendall and Y. Gal. What uncertainties do we need in Bayesian deep learning for computer vision? In NIPS, 2017.
  • (15) D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • (16) D. P. Kingma, T. Salimans, and M. Welling. Variational dropout and the local reparameterization trick. arXiv preprint, 2015.
  • (17) D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.
  • (18) A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, 2012.
  • (19) B. Lakshminarayanan, A. Pritzel, and C. Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In NIPS, pages 6405–6416, 2017.
  • (20) M. Singer, C. S. Deutschman, C. W. Seymour, et al. The third international consensus definitions for sepsis and septic shock (Sepsis-3). JAMA, 2016.
  • (21) C. J. Maddison, A. Mnih, and Y. W. Teh. The Concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint, 2016.
  • (22) M. P. Naeini, G. F. Cooper, and M. Hauskrecht. Obtaining well calibrated probabilities using Bayesian binning. In AAAI, 2015.
  • (23) S. Purushotham, C. Meng, Z. Che, and Y. Liu. Benchmark of deep learning models on large healthcare MIMIC datasets. arXiv preprint, 2017.
  • (24) K. Sohn, H. Lee, and X. Yan. Learning structured output representation using deep conditional generative models. In NIPS, 2015.
  • (25) N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15:1929–1958, 2014.
  • (26) S. Sukhbaatar, A. Szlam, J. Weston, and R. Fergus. End-to-end memory networks. In NIPS, 2015.
  • (27) J. van der Westhuizen and J. Lasenby. Bayesian LSTMs in medicine. arXiv preprint, 2017.
  • (28) K. Xu, J. L. Ba, R. Kiros, K. Cho, A. Courville, R. Salakhutdinov, R. S. Zemel, and Y. Bengio. Show, attend and tell: Neural image caption generation with visual attention. In ICML, 2015.
  • (29) L. Zhu and N. Laptev. Deep and confident prediction for time series at Uber. arXiv preprint, 2017.

Appendix A Detailed Description of Datasets and Experimental Setup

A.1 Datasets

MIMIC-Sepsis

We calculated the Sepsis-related Organ Failure Assessment (SOFA) score [20] for each patient to determine the onset of sepsis: if the SOFA score increases by 2 points or more within the time window, we label the patient as positive. We set the time window to 72 hours, since the current guideline of the American Medical Association considers the period of suspected infection to span from 48 hours before up to 24 hours after the onset of sepsis [20]. The overall rate of septic patients is 16.07%. Table 6 describes the feature information in detail. We selected features under the guidance of physicians, and for urine output we adopted an approach similar to recent work [23]: we sum the variables representing urine.
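The labeling rule above can be sketched as follows, assuming a hypothetical per-patient table of hourly SOFA scores with `hour` and `score` columns; this is our own illustration of the rule, not the authors’ preprocessing code.

    import pandas as pd

    def label_sepsis_onset(sofa: pd.DataFrame, window_hours: int = 72) -> bool:
        """Positive if the SOFA score rises by >= 2 points within the window."""
        sofa = sofa.sort_values("hour")
        for _, row in sofa.iterrows():
            in_window = sofa[(sofa["hour"] > row["hour"]) &
                             (sofa["hour"] <= row["hour"] + window_hours)]
            if (in_window["score"] - row["score"]).max() >= 2:
                return True
        return False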

Pancreatic Cancer

This dataset is a subset of an electronic healthcare records database from a healthcare organization, consisting of around 1.5 million records. The database contains demographic information of medical aid beneficiaries, treatment information, disease histories, and drug prescription records. In total, 34 features regarding vital signs, social and behavioral factors, medical history, and general information were extracted from the database over 12 years. Total cholesterol level and fasting glucose level were sampled after overnight fasting, and systolic and diastolic blood pressure were checked through medical examinations. There were also several questionnaires designed to identify social and behavioral risk factors, such as smoking habit, alcohol consumption, and time spent on exercise. Individual medical history was tracked via the drug prescription history and the clinical codes of the 10th revision of the International Classification of Diseases (ICD-10). We identified patients with pancreatic cancer by the ICD code C25 on examination and treatment records. In the labeling process, we excluded those who had previous pancreatic cancer-related treatment records, as well as those with a pre-existing medical history of pancreatic cancer. Table 7 describes the feature information in detail.

A.2 Configuration and Parameters

We trained all models using the Adam [15] optimizer with dropout regularization. We fixed the maximum number of iterations for the Adam optimizer, and for the other hyperparameters we searched for the optimal values by cross-validation within predefined ranges over the mini-batch size, learning rate, ℓ2 regularization coefficient, and dropout rate.

Appendix B Benefits of Input-adaptive Uncertainty Modeling

We conducted experiments to show the benefits of input-adaptive noise on the PhysioNet-Mortality dataset. First, we intentionally corrupted the distribution of the original dataset with Gaussian noise. The results show that UA and UA+ outperform RETAIN in classification performance. In particular, comparing the measured attention weights on the noisy features, UA captures 86% of the noisy features, while RETAIN captures only 59%, with an attention-weight threshold of 0.01. For the second experiment, we intentionally increased the original missing rate by 5%, from 92% to 97%, to simulate low-quality samples. As a result, the UA and UA+ models again outperform RETAIN in classification performance.

Model | Gaussian Noise | 97% Missing Rate
RETAIN-DA | 0.7692 | 0.7129
UA | 0.7868 | 0.7372
UA+ | 0.7864 | 0.7643
Table 4: Classification performance (AUROC) of RETAIN and the uncertainty-aware attention models on the PhysioNet-Mortality dataset.
(a) RETAIN (b) UA (c) UA+
Figure 4: Uncertainty over prediction strength on the PhysioNet Challenge dataset. For all models, we measured the prediction uncertainty using MC-dropout sampling.

Appendix C Prediction with “I Don’t Know” Decision

We analyzed the predictions on PhysioNet-Mortality to determine how many of the IDK predictions would otherwise have been false positives, false negatives, or true positives. The results show that, when the correct prediction rate reaches 0.7, UA mainly filters out false negative cases, while RETAIN filters out more false positive cases. This is a promising result, since preventing type II errors is critical for healthcare applications.

Model | False Positive | False Negative | True Positive
RETAIN-DA | 14 | 14 | 8
UA | 7 | 22 | 10
UA+ | 8 | 21 | 9
Table 5: Number of false positives, false negatives, and true positives among the IDK predictions on the PhysioNet-Mortality dataset.

In Figure 5, we observe that both UA and UA+ are more likely to say IDK than to make incorrect predictions, when compared against the RETAIN + MC-dropout model, which suggests that they are relatively more reliable and safer to use for making clinical decisions where incorrect predictions can lead to fatal consequences. For instance, on the sepsis prediction task, at a correct prediction rate of 0.7, UA+ and UA made far fewer incorrect predictions than RETAIN + MC-dropout, by instead deferring potentially incorrect predictions based on their uncertainty. Considering the consequences that follow an incorrect prediction of sepsis, this is a significant difference. Similarly, on the pancreatic cancer prediction task, our models made fewer incorrect predictions than RETAIN + MC-dropout among the instances with IDK decisions, a difference that is significant considering the severe consequences an incorrect cancer prediction has for the patient.

     
(a) PhysioNet-Mortality (b) PhysioNet-Stay (c) PhysioNet-Cardiac (d) PhysioNet-Recovery (e) Pancreatic Cancer (f) MIMIC-Sepsis
Figure 5: Experiments on prediction reliability. The stacked bar charts show the ratio of IDK and incorrect predictions when the correct prediction rate reaches 0.7.
Feature | Item-ID (Name of Item)
Age | N/A (derived from intime and dob)
Heart Rate | 211 (Heart Rate), 220045 (Heart Rate)
FiO2 | 223835 (Inspired O2 Fraction), 3420 (FiO2), 3422 (FiO2 [Meas]), 190 (FiO2 Set)
Temperature | 676 (Temperature C), 678 (Temperature F), 223761 (Temperature Fahrenheit), 223762 (Temperature Celsius)
Systolic Blood Pressure | 51 (Arterial BP [Systolic]), 442 (Manual BP [Systolic]), 455 (NBP [Systolic]), 6701 (Arterial BP #2 [Systolic]), 220179 (Non Invasive Blood Pressure [Systolic]), 220050 (Arterial Blood Pressure [Systolic])
Diastolic Blood Pressure | 8368 (Arterial BP [Diastolic]), 8440 (Manual BP [Diastolic]), 8441 (NBP [Diastolic]), 8555 (Arterial BP #2 [Diastolic]), 220051 (Arterial Blood Pressure [Diastolic]), 220180 (Non Invasive Blood Pressure [Diastolic])
PaO2 | 50821 (PO2), 50816 (Oxygen)
GCS - Verbal Response | 223900 (Verbal Response)
GCS - Motor Response | 223901 (Motor Response)
GCS - Eye Opening | 220739 (Eye Opening)
Serum Urea Nitrogen Level | 51006 (Urea Nitrogen)
Sodium Level | 50824 (Sodium Whole Blood)
White Blood Cells Count | 51300 (WBC Count), 51301 (White Blood Cells)
Urine Output | 40055 (Urine Out Foley), 43175 (Urine), 40069 (Urine Out Void), 40094 (Urine Out Condom Cath), 40715 (Urine Out Suprapubic), 40473 (Urine Out IleoConduit), 40085 (Urine Out Incontinent), 40057 (Urine Out Rt Nephrostomy), 40056 (Urine Out Lt Nephrostomy), 40405 (Urine Out Other), 40428 (Urine Out Straight Cath), 40086 (Urine Out Incontinent), 40096 (Urine Out Ureteral Stent 1), 40651 (Urine Out Ureteral Stent 2), 226559 (Foley), 226560 (Void), 226561 (Condom Cath), 226584 (Ileoconduit), 226563 (Suprapubic), 226564 (R Nephrostomy), 226565 (L Nephrostomy), 226567 (Straight Cath), 226557 (R Ureteral Stent), 226558 (L Ureteral Stent), 227488 (GU Irrigant Volume In), 227489 (GU Irrigant/Urine Volume Out)
Table 6: Feature information of the MIMIC-Sepsis dataset.
Category | Features
Demographics | Age, Sex
Socio-Economic Status | Income Level, Type of Disability
Health Screening | Body Mass Index (BMI), Waist Circumference, Systolic Blood Pressure, Diastolic Blood Pressure, Fasting Glucose, Total Cholesterol, Triglyceride, Hemoglobin, Urine Protein, Creatinine, HDL Cholesterol, LDL Cholesterol, Aspartate Aminotransferase, Alanine Transaminase, Gamma-Glutamyl Transferase
Family History | Liver Disease, Stroke, Heart Disease, Hypertension, Diabetes Mellitus, Cancer
Personal History | Stroke or Cerebral Infarction-related Disease, Heart Disease, Hypertension, Diabetes Mellitus, Hyperlipidemia, Tuberculosis
Social and Behavioral Factors | Alcohol Consumption, Smoking Habit, Physical Exercise
Table 7: Feature information of the pancreatic cancer dataset.