Interpretations are useful: penalizing explanations to align neural networks with prior knowledge
For an explanation of a deep learning model to be effective, it must provide both insight into a model and suggest a corresponding action in order to achieve some objective. Too often, the litany of proposed explainable deep learning methods stops at the first step, providing practitioners with insight into a model, but no way to act on it. In this paper, we propose contextual decomposition explanation penalization (CDEP), a method which enables practitioners to leverage existing explanation methods in order to increase the predictive accuracy of deep learning models. In particular, when shown that a model has incorrectly assigned importance to some features, CDEP enables practitioners to correct these errors by directly regularizing the provided explanations. Using explanations provided by contextual decomposition (CD) (murdoch2018beyond), we demonstrate the ability of our method to increase performance on an array of toy and real datasets.
In recent years, neural networks have demonstrated strong predictive performance across a wide variety of settings. However, in order to achieve that accuracy, they sometimes latch onto spurious correlations, leading to undesirable behavior as a result of dataset bias (winkler2019association), racial and ethnic stereotypes (GargEmbeddings), or simply overfitting. While recent work into explaining neural network predictions (murdoch2019interpretable; doshi2017towards) has demonstrated an ability to uncover the relationships learned by a model, it is still unclear how to actually alter the model in order to remove incorrect, or undesirable, relationships.
We introduce contextual decomposition explanation penalization (CDEP), a method which leverages existing explanation techniques for neural networks in order to prevent a model from learning unwanted relationships and ultimately improve predictive accuracy. Given particular importance scores, CDEP works by allowing the user to directly penalize importances of certain features, or interactions. This forces the neural network to not only produce the correct prediction, but also the correct explanation for that prediction. While we focus on contextual decomposition (CD) (murdoch2018beyond; singh2018hierarchical), which allows the penalization of both feature importances and interactions, CDEP can be readily adapted for existing interpretation techniques, as long as they are differentiable. Moreover, CDEP is a general technique, which can be applied to arbitrary neural network architectures, and is orders of magnitude faster and more memory efficient than recent gradient-based methods, allowing its use on meaningful datasets.
In order to demonstrate the effectiveness of CDEP, we conducted experiments across a wide variety of tasks. In the prediction of skin cancer from images, CDEP improves the prediction of a classifier by teaching it to ignore spurious confounding variables present in the training data. In a colored MNIST task, CDEP allows the network to focus on a digit’s shape rather than its color (with no extra human annotation needed). Finally, a toy example using text classification shows how the penalization can help a network avoid a bias towards particular words, such as those involving gender.
Many methods have been developed to help explain the learned relationships contained in a DNN. For local, or prediction-level, explanation, most prior work has focused on assigning importance to individual features, such as pixels in an image or words in a document. There are several methods that give feature-level importance for different architectures. They can be categorized as gradient-based (springenberg2014striving; sundararajan2016gradients; selvaraju2016grad; baehrens2010explain; rieger2019aggregating), decomposition-based (murdoch2017automatic; shrikumar2016not; bach2015pixel) and others (dabkowski2017real; fong2017interpretable; ribeiro2016should; zintgraf2017visualizing), with many similarities among the methods (ancona2018towards; lundberg2017unified). However, many of these methods have thus far been poorly evaluated (adebayo2018sanity; nie2018theoretical), casting doubt on their usefulness. Another line of work, which we build upon, has focused on uncovering interactions between features in addition to feature importances (murdoch2018beyond), and on using those interactions to create a hierarchy of features displaying the model’s prediction process (singh2018hierarchical).
Uses of explanation methods
While much work has been put into developing methods for explaining DNNs, relatively little work has explored the potential to use these explanations to help build a better model. Some recent work proposes forcing models to attend to regions of the input which are known to be important (burns2018women; mitsuhara2019embedding), although it is important to note that attention is often not the same as explanation (jain2019attention). An alternative line of work proposes penalizing the gradients of a neural network to match human-provided binary annotations and shows that doing so can improve performance (ross2017right) and adversarial robustness (ross2018improving). Finally, two recent papers extend these ideas by penalizing attributions for natural language models (liu2019incorporating) and penalizing a modified gradient-based score to produce smooth attributions (erion2019learning).
Other ways to constrain DNNs
While we focus on the use of explanations to constrain the relationships learned by neural networks, other approaches for constraining neural networks have also been proposed. A computationally intensive alternative is to augment the dataset in order to prevent the model from learning undesirable relationships, through domain knowledge (bolukbasi2016man), projecting out superficial statistics (wang2019learning) or dramatically altering training images (geirhos2018imagenet). However, these processes are often not feasible, either due to their computational cost or the difficulty of constructing such an augmented dataset. Adversarial training has also been explored (zhang2019interpreting). These techniques are generally limited, as they are often tied to particular datasets, and do not provide a clear link between uncovering a model’s learned relationships through explanations and subsequently correcting them.
We now introduce CDEP, which penalizes the explanations of a neural network in order to align with prior knowledge about why a model should make a prediction. To do so, for each data point it penalizes the CD scores of features, or groups of features, which a user does not want the model to learn to be important. While we focus on CD scores, which allow the penalization of interactions between features in addition to features themselves, this approach readily generalizes to other interpretation techniques, so long as they are differentiable.
3.1 Augmenting the loss function
Given a particular classification task, we want to teach a model to not only produce the correct prediction, but also to arrive at the prediction for the correct reasons. That is, we want the model to be right for the right reasons, where the right reasons are provided by the user and are dataset-dependent.
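To accomplish this, CDEP modifies the objective function used to train a neural network, as displayed in Eq 1. In addition to the standard prediction loss $\mathcal{L}$, which teaches the model to produce the correct predictions, CDEP adds an explanation error $\mathcal{L}_{\text{expl}}$, which teaches the model to produce the correct explanations for its predictions. In place of the prediction $f_\theta(X)$ and labels $y$ used in the prediction error $\mathcal{L}$, the explanation error uses the explanations produced by an interpretation method $\text{expl}_\theta(X)$, along with targets provided by the user $\text{expl}_X$. As is common with penalization, the two losses are weighted by a hyperparameter $\lambda$:
$$\hat{\theta} = \arg\min_\theta \underbrace{\mathcal{L}\big(f_\theta(X), y\big)}_{\text{Prediction error}} + \lambda \underbrace{\mathcal{L}_{\text{expl}}\big(\text{expl}_\theta(X), \text{expl}_X\big)}_{\text{Explanation error}} \tag{1}$$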
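As an illustration of the augmented objective in Eq 1, the following minimal sketch computes the loss for a single example in plain Python. It assumes cross-entropy for the prediction loss and an L1 distance for the explanation error. The names `predict`, `explain`, and `augmented_loss` are hypothetical, and `explain` stands in for any interpretation method. Actual training would compute both terms in an autodiff framework such as PyTorch so that they can be back-propagated.

```python
import math
from typing import Callable, Sequence


def cross_entropy(probs: Sequence[float], label: int) -> float:
    """Prediction loss: negative log-probability of the true class."""
    return -math.log(probs[label])


def l1_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Explanation loss: L1 distance between an explanation and its target."""
    return sum(abs(x - y) for x, y in zip(a, b))


def augmented_loss(
    predict: Callable[[Sequence[float]], Sequence[float]],
    explain: Callable[[Sequence[float]], Sequence[float]],
    x: Sequence[float],
    y: int,
    expl_target: Sequence[float],
    lam: float,
) -> float:
    """Prediction error plus lam times explanation error, for one example.

    `explain` may be any interpretation method; to be trained with gradient
    descent it would additionally have to be differentiable.
    """
    prediction_error = cross_entropy(predict(x), y)
    explanation_error = l1_distance(explain(x), expl_target)
    return prediction_error + lam * explanation_error
```

Consider a model that assigns probability 0.75 to the true class. Its explanation gives importance 0.5 to a feature whose target importance is 0. With $\lambda = 2$, the loss is $-\log 0.75 + 2 \cdot 0.5 \approx 1.29$.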
The precise meanings of $\text{expl}_\theta(X)$ and $\text{expl}_X$ depend on the context. For example, in the skin cancer image classification task described in Section 4, many of the benign skin images contain band-aids, while none of the malignant images do. To force the model to ignore the band-aids when making its prediction, $\text{expl}_\theta(X)$ in each image denotes the importance score of the band-aid and $\text{expl}_X$ would be zero. These and more examples are further explored in Section 4.
3.2 Contextual decomposition (CD)
In this work, we use the CD score as the explanation function. In contrast to other interpretation methods, which focus on feature importances, CD also captures interactions between features. CD was originally designed for LSTMs (murdoch2018beyond) and subsequently extended to convolutional neural networks and arbitrary DNNs (singh2018hierarchical). For a given DNN $f(x)$, one can represent its output as a SoftMax operation applied to logits $g(x)$. These logits, in turn, are the composition of $L$ layers $g_i$, such as convolutional operations or ReLU non-linearities:
$$f(x) = \text{SoftMax}\big(g(x)\big) = \text{SoftMax}\Big(g_L\big(g_{L-1}(\cdots(g_2(g_1(x))))\big)\Big) \tag{2}$$
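Given a group of features $\{x_j\}_{j \in S}$, the CD algorithm, $g^{CD}(x)$, decomposes the logits $g(x)$ into a sum of two terms, $\beta$ and $\gamma$. $\beta$ is the importance score of the feature group $\{x_j\}_{j \in S}$, and $\gamma$ captures contributions to $g(x)$ not included in $\beta$:
$$g^{CD}(x) = (\beta, \gamma), \qquad \beta + \gamma = g(x) \tag{3}$$
The decomposition is computed by iteratively applying decompositions $g_i^{CD}$ for each of the layers $g_i$:
$$g^{CD}(x) = g_L^{CD}\Big(g_{L-1}^{CD}\big(\cdots(g_2^{CD}(g_1^{CD}(x)))\big)\Big) \tag{4}$$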
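To make the decomposition concrete, the sketch below propagates $(\beta, \gamma)$ through a toy network built only from fully connected layers and ReLUs. It applies three layer rules:

- The weights act on each part separately.
- The bias is split in proportion to $|W\beta|$ and $|W\gamma|$.
- $\gamma$ absorbs ReLU$(\beta + \gamma)$ minus ReLU$(\beta)$.

These rules are our reading of the hierarchical CD rules (singh2018hierarchical) and should be treated as an assumption. The 50/50 bias split when both parts are zero is our own choice, made so that $\beta + \gamma$ equals the logits exactly. All names are hypothetical.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def matvec(W: Sequence[Sequence[float]], v: Sequence[float]) -> Vector:
    """Matrix-vector product W v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def relu(v: Sequence[float]) -> Vector:
    return [max(0.0, x) for x in v]


def cd_linear(W, b, beta, gamma) -> Tuple[Vector, Vector]:
    """Propagate (beta, gamma) through y = W x + b.

    W acts on each part separately; the bias is split between the parts in
    proportion to |W beta| and |W gamma| (50/50 if both are zero).
    """
    new_beta, new_gamma = [], []
    for rb, rg, bias in zip(matvec(W, beta), matvec(W, gamma), b):
        total = abs(rb) + abs(rg)
        share = abs(rb) / total if total > 0 else 0.5
        new_beta.append(rb + share * bias)
        new_gamma.append(rg + (1.0 - share) * bias)
    return new_beta, new_gamma


def cd_relu(beta, gamma) -> Tuple[Vector, Vector]:
    """beta keeps ReLU(beta); gamma gets ReLU(beta + gamma) - ReLU(beta)."""
    new_beta = relu(beta)
    full = relu([x + y for x, y in zip(beta, gamma)])
    return new_beta, [f - nb for f, nb in zip(full, new_beta)]


def cd_forward(layers, x: Sequence[float], group: set) -> Tuple[Vector, Vector]:
    """Split the input into the features in `group` (beta) and the rest
    (gamma), then decompose layer by layer; returns the (beta, gamma) logits."""
    beta = [xi if j in group else 0.0 for j, xi in enumerate(x)]
    gamma = [0.0 if j in group else xi for j, xi in enumerate(x)]
    for layer in layers:
        if layer[0] == "linear":
            beta, gamma = cd_linear(layer[1], layer[2], beta, gamma)
        else:  # ("relu",)
            beta, gamma = cd_relu(beta, gamma)
    return beta, gamma


def forward(layers, x: Sequence[float]) -> Vector:
    """Ordinary forward pass g(x), used to check that beta + gamma = g(x)."""
    v = list(x)
    for layer in layers:
        if layer[0] == "linear":
            v = [s + bi for s, bi in zip(matvec(layer[1], v), layer[2])]
        else:
            v = relu(v)
    return v
```

For example, take `layers = [("linear", [[1.0, -1.0], [0.5, 2.0]], [0.1, -0.2]), ("relu",), ("linear", [[1.0, 1.0]], [0.3])]` and `x = [1.0, 2.0]`. Then `forward` returns `[4.6]` up to floating-point rounding, and the two parts of `cd_forward(layers, x, {0})` sum to the same value.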
3.3 CDEP objective function
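We now substitute the above CD scores into the generic equation in Eq 1 to arrive at the method used in this paper. While we use CD for the explanation method $\text{expl}_\theta(X)$, other explanation methods could be readily substituted at this stage. In order to convert CD scores to probabilities, we apply a SoftMax operation to $\beta$, allowing for easier comparison with the user-provided labels $\text{expl}_X$. We collect from the user, for each input $x_i$, a collection of feature groups $x_{i,S}$, $x_i \in \mathbb{R}^d$, $S \subseteq \{1, \dots, d\}$, along with explanation target values $\text{expl}_{x_{i,S}}$, and use the $\|\cdot\|_1$ loss for $\mathcal{L}_{\text{expl}}$:
$$\hat{\theta} = \arg\min_\theta \underbrace{\sum_i \sum_c -y_{i,c} \log f_\theta(x_i)_c}_{\text{Prediction error}} + \lambda \underbrace{\sum_i \sum_S \Big\|\text{SoftMax}\big(\beta(x_{i,S})\big) - \text{expl}_{x_{i,S}}\Big\|_1}_{\text{Explanation error}} \tag{5}$$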
In the above, $i$ indexes each individual example in the dataset, $S$ indexes a subset of the features whose explanations we penalize, and $c$ sums over each class. Updating the model parameters in accordance with this formulation ensures that the model not only predicts the right output but also does so for the right reasons (aligned with prior knowledge).
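The objective in Eq 5 can be sketched as follows, taking the CD scores $\beta$ as given (for instance, from a CD forward pass). Following the text literally, the SoftMax is applied to the vector of $\beta$ logits over classes before the L1 comparison. This reading, the nested input layout, and the function names are assumptions made for illustration. With one-hot labels, the sum over classes in the prediction error reduces to the negative log-probability of the true class.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def cdep_explanation_error(examples) -> float:
    """Sum over examples i and feature groups S of
    || SoftMax(beta(x_{i,S})) - expl_{x_{i,S}} ||_1.

    `examples` holds one list per example x_i, containing one
    (beta_logits, target) pair per penalized feature group S.
    """
    total = 0.0
    for groups in examples:
        for beta_logits, target in groups:
            probs = softmax(beta_logits)
            total += sum(abs(p - t) for p, t in zip(probs, target))
    return total


def cdep_objective(pred_probs, labels, examples, lam: float) -> float:
    """Cross-entropy over all examples plus lam times the explanation error."""
    prediction_error = -sum(math.log(p[y]) for p, y in zip(pred_probs, labels))
    return prediction_error + lam * cdep_explanation_error(examples)
```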
3.4 Computational considerations
A similar idea to Eq 1 has been proposed in previous/concurrent work, where the choice of explanation method uses a gradient-based attribution method (ross2017right; erion2019learning). However, using such methods leads to three main complications which are solved by our approach. The first complication is the optimization process. When optimizing over attributions from a gradient-based attribution method via gradient descent, the optimizer requires the gradient of the gradient, thus requiring that all network components be twice differentiable. This process is computationally expensive and indeed optimizing it exactly involves optimizing over a differential equation. In contrast, CD attributions are calculated along with the forward pass of the network, and as a result can be optimized plainly with back-propagation using the standard single forward-pass and backward-pass per batch.
A second complication solved by the use of CD in Eq 5 is the ability to quickly finetune a pre-trained network. In many applications, particularly in transfer learning, it is common to finetune only the last few layers of a pre-trained neural network. Using CD, one can freeze early layers of the network and then finetune the last few layers of the network quickly as the activations and gradients of the frozen layers are not necessary.
Third, penalizing gradient-based methods incurs a very large memory usage. Using gradient-based methods, training requires the storage of activations and gradients for all layers of the network as well as the gradient of the input (which can be omitted in normal training). Even for the simplest version, based on saliency, this more than doubles the required memory for a given batch and network size. More advanced methods proved infeasible to apply to the real-life dataset used here because their memory requirements were too high. By contrast, penalizing CD requires only a small constant amount of memory beyond standard training.
The results here demonstrate the efficacy of CDEP on a variety of datasets using diverse explanation types. Sec 4.1 shows results on ignoring spurious patches in the ISIC skin cancer dataset (codella2019skin), Sec 4.2 details experiments on converting a DNN’s preference for color to a preference for shape on a variant of the MNIST dataset (lecun1998mnist), and Sec 4.3 shows experiments on text data from the Stanford Sentiment Treebank (SST) (socher2013recursive). All models were trained in PyTorch.
4.1 Ignoring spurious signals in skin cancer diagnosis
In recent years, deep learning has achieved impressive results in diagnosing skin cancer, with predictive accuracy sometimes comparable to human doctors (esteva2017dermatologist). However, the datasets used to train these models often include spurious features which make it possible to attain high test accuracy without learning the underlying phenomena (winkler2019association). In particular, a popular dataset from ISIC (International Skin Imaging Collaboration) has colorful patches present in approximately 50% of the non-cancerous images but not in the cancerous images (codella2019skin). An unpenalized DNN learns to look for these patches as an indicator for predicting that an image is benign. We use CDEP to remedy this problem by penalizing the importance the DNN places on the patches during training.
The task in this section is to classify whether an image of a skin lesion contains (1) benign melanoma or (2) malignant melanoma. The ISIC dataset consists of 21,654 images (19,372 benign), each diagnosed by histopathology or a consensus of experts. For classification, we use a VGG16 architecture (simonyan2014very) pre-trained on the ImageNet classification task (pre-trained model retrieved from torchvision) and freeze the weights of early layers so that only the fully connected layers are trained. In order to use CDEP, the spurious patches are identified via a simple image segmentation algorithm using a color threshold (see Sec S4).
Table 1 shows results comparing the performance of a model trained with and without CDEP. We report results on two variants of the test set. The first, which we refer to as “no patches” only contains images of the test set that do not include patches. The second also includes images with those patches. Training with CDEP improves the AUC and F1-score for both test sets.
||AUC (no patches)||F1 (no patches)||AUC (all)||F1 (all)|
|Vanilla (excluded data)||0.86||0.59||0.92||0.59|
In the first row of Table 1, the model is trained using only the data without the spurious patches, and the second row shows the model trained on the full dataset. The network trained using CDEP achieves the best AUC, surpassing both unpenalized versions. Applying our method increases the ROC AUC as well as the best F1 score. We also compared our method against the method introduced by ross2017right (RRR). For this, we restricted the batch size to 16 (and consequently use a learning rate of ) due to memory constraints. Using RRR did not improve on the base AUC, suggesting that penalizing gradients is not helpful in penalizing higher-order features. (We were not able to compare against the method recently proposed in erion2019learning due to its prohibitively slow training and large memory requirements.)
Fig. 3 visualizes GradCAM heatmaps (uozbulak_pytorch_vis_2019; selvaraju2017grad) for an unpenalized DNN and a DNN trained with CDEP to ignore spurious patches. As expected, after penalizing with CDEP, the DNN attributes less importance to the spurious patches, regardless of their position in the image. More examples, including cancerous images, are shown in Sec S5.
4.2 Combating inductive bias on variants of the MNIST dataset
In this section, we investigate whether we can alter which features a DNN uses to perform digit classification, using variants of the MNIST dataset (lecun1998mnist) and a standard CNN architecture for this dataset retrieved from PyTorch (model and training code from https://github.com/pytorch/examples/blob/master/mnist/main.py).
Similar to a previous study (li2019repair), we transform the MNIST dataset to include three color channels and assign each class a distinct color, as shown in Fig. 4. An unpenalized DNN trained on this biased data will completely misclassify a test set with inverted colors, dropping to 0% accuracy (see Table 2), suggesting that it learns to classify using the colors of the digits rather than their shape.
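A minimal sketch of the colored-MNIST construction, assuming grayscale digits with intensities in $[0, 1]$. The palette is hypothetical. We also assume that the "inverted" test set draws class $c$ in the color of class $9 - c$. The text does not specify either detail.

```python
from typing import List, Tuple

RGB = Tuple[float, float, float]

# Hypothetical palette: one distinct RGB color per digit class 0-9.
PALETTE: List[RGB] = [
    (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 1.0, 0.0),
    (1.0, 0.0, 1.0), (0.0, 1.0, 1.0), (1.0, 0.5, 0.0), (0.5, 0.0, 1.0),
    (0.5, 0.5, 0.5), (1.0, 1.0, 1.0),
]


def colorize(gray: List[List[float]], color: RGB) -> List[List[List[float]]]:
    """Turn a 1-channel image into a 3-channel one by scaling each pixel's
    intensity by the given color (background pixels stay black)."""
    return [[[px * c for c in color] for px in row] for row in gray]


def color_for(label: int, test: bool) -> RGB:
    """Training images use the class color; the assumed 'inverted' test set
    draws class c in the color of class 9 - c."""
    return PALETTE[9 - label] if test else PALETTE[label]
```

Under this construction, a network that keys on color maps every test digit to the wrong class, which is consistent with the 0% accuracy reported above.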
Here, we want to see if we can alter the DNN to focus on the shape of the digits rather than their color. Interestingly, this can be encouraged by minimizing the contribution of pixels in isolation: to do this, we penalize the CD contribution of sampled single pixels, following Eq 5. By minimizing the contribution of single pixels, we effectively encourage the network to focus more on groups of pixels, which can represent shape.
Table 2 shows that CDEP can partially change the network’s focus on solely color to also focus on digit shape. We compare CDEP to two previously introduced explanation penalization techniques: penalization of the squared gradients (RRR) (ross2017right) and Expected Gradients (EG) (erion2019learning). None of the baselines can improve the test accuracy of the model on this task above 0.6%, whereas CDEP is able to significantly improve this accuracy to 25.5%. We show the increase of predictive accuracy with increasing penalization in Fig. S5.
||Vanilla||CDEP||RRR||Expected Gradients|
|Test accuracy (%)||0.01 ± 0.2||25.5 ± 0.4||0.4 ± 0.2||0.4 ± 0.8|
For further comparison with previous work, we evaluate CDEP on an existing task: DecoyMNIST (erion2019learning). DecoyMNIST adds a class-indicative gray patch to a random corner of the image. This task is relatively simple, as the spurious features are not entangled with any other feature and are always at the same location (the corners). Table 3 shows that all methods perform roughly equally, recovering the base accuracy.
Table 3 also shows that CDEP runs with reasonable time and memory usage. Its cost is similar to that of RRR and considerably lower than that of Expected Gradients. Moreover, when early layers of a network are frozen, CDEP quickly becomes more efficient than the other methods. Results are reported using the best penalization parameter $\lambda$, chosen via cross-validation on the test accuracy. More details are given in Sec S3.
||Vanilla||CDEP||RRR||Expected Gradients|
|Test accuracy (%)||60.1 ± 5.1||97.2 ± 0.8||99.0 ± 1.0||97.8 ± 0.2|
|Run time/epoch (seconds)||4.7||17.1||11.2||821.0|
|Maximum GPU RAM usage (GB)||0.027||0.068||0.046||3.15|
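The DecoyMNIST transformation can be sketched as below. Two details are assumptions based on our understanding of the original DecoyMNIST setup, not details stated here:

- The shade rule: a gray value of $255 - 25y$ for label $y$ during training, and a random shade at test time.
- The $4 \times 4$ patch size.

The function name is hypothetical.

```python
import random
from typing import List

Image = List[List[float]]


def add_decoy_patch(img: Image, label: int, train: bool, rng: random.Random,
                    size: int = 4) -> Image:
    """Return a copy of a grayscale image (values 0-255) with a size x size
    gray square in a randomly chosen corner.

    Assumed shade rule: 255 - 25 * label during training (class-indicative),
    a random shade at test time (uninformative).
    """
    out = [row[:] for row in img]
    h, w = len(out), len(out[0])
    shade = 255 - 25 * label if train else rng.randint(0, 255)
    top = rng.choice([0, h - size])   # top or bottom edge
    left = rng.choice([0, w - size])  # left or right edge
    for r in range(top, top + size):
        for c in range(left, left + size):
            out[r][c] = float(shade)
    return out
```

The patch location is known exactly for every image. It therefore defines the feature group whose importance is penalized, which is part of why this task is comparatively easy.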
4.3 Fixing bias in text data
To demonstrate CDEP’s effectiveness on text, we use the Stanford Sentiment Treebank (SST) dataset (socher2013recursive), an NLP benchmark dataset consisting of movie reviews with a binary sentiment (positive/negative). We inject spurious signals into the training set and train a standard LSTM (model and training code from https://github.com/clairett/pytorch-sentiment-classification) to classify sentiment from the review.
We create three variants of the SST dataset, each with different spurious signals which we aim to ignore (examples in Sec S1). In the first variant, we add indicator words for each class (positive: ‘text’, negative: ‘video’) at a random location in each sentence. An unpenalized DNN will focus only on those words, dropping to nearly random performance on the unbiased test set. In the second variant, we use two semantically similar words (‘the’, ‘a’) to indicate the class, using one word only in the positive class and the other only in the negative class. In the third variant, we use ‘he’ and ‘she’ to indicate class (example in Fig 5). Since these gendered words are present in only a small proportion of the training dataset (about 2%), for this variant we report accuracy only on the test sentences that include the pronouns (performance on test sentences without the pronouns remains unchanged).
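The first SST variant can be generated as in the following sketch. The function also returns the insertion position, which is the one-token feature group whose CD score CDEP would penalize. The function name is hypothetical, and whitespace tokenization is an assumption.

```python
import random
from typing import List, Tuple

# Indicator words of the first SST variant (1 = positive, 0 = negative).
INDICATOR = {1: "text", 0: "video"}


def inject_indicator(tokens: List[str], label: int,
                     rng: random.Random) -> Tuple[List[str], int]:
    """Insert the class-indicative word at a uniformly random position
    (0..len(tokens) inclusive); return the new tokens and that position."""
    pos = rng.randint(0, len(tokens))
    return tokens[:pos] + [INDICATOR[label]] + tokens[pos:], pos
```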
Table 4 shows the test accuracy for all datasets with and without CDEP. In all scenarios, CDEP successfully improves the test accuracy by ignoring the injected spurious signals.
||Unpenalized||CDEP|
|Random words||56.6 ± 5.8||75.4 ± 0.9|
|Biased (articles)||57.8 ± 0.8||68.2 ± 0.8|
|Biased (gender)||64.2 ± 3.1||78.0 ± 3.0|
In this work we introduce a novel method to penalize neural networks to align with prior knowledge. Compared to previous work, CDEP is the first method of its kind that can penalize complex features and feature interactions. Furthermore, CDEP is more computationally efficient than previous work. Unlike gradient-based penalties, it does not rely on backpropagation to compute its explanations, which enables its use with more complex neural networks. We show that CDEP can be used to remove bias and improve predictive accuracy on a variety of toy and real datasets. CDEP is quite versatile and can be used in many more areas to incorporate the structure of domain knowledge (e.g. biology or physics). Of course, the effectiveness of CDEP depends upon the quality of the prior knowledge used to determine the explanation targets. Future work includes extending CDEP to more complex penalties, incorporating more fine-grained explanations and interactions. We hope the work here will help push the field towards a more rigorous way to use interpretability methods, a point which will become increasingly important as interpretable machine learning develops as a field (doshi2017towards; murdoch2019interpretable).
S1 Additional details about SST task
Section 4.3 shows the results for CDEP on biased variants of the SST dataset. Here we show examples of the biased sentences (for tasks 2 and 3 we only show sentences where the bias was present) in Figs. S3, S2 and S1. For the first task, we insert a class-specific word into 100% of the sentences: one word for the positive class and another for the negative class. We chose two words (“text” for the positive class and “video” for the negative class) that were not otherwise present in the dataset but had a representation in Word2Vec.
For the second task, we replace two common words (“the” and “a”) in the sentences where they appear (27% of the dataset). We replace the words such that one word only appears in the positive class and the other word only in the negative class. By choosing words that are semantically almost interchangeable, we ensured that the normal sentence structure would not be broken, as it was in the first task.
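A sketch of the word replacement used for the second and third tasks follows. Which word is assigned to which class is an assumption, as is case-sensitive matching on lowercase tokens. The function name is hypothetical.

```python
from typing import List, Tuple

# (positive-class word, negative-class word); the third task would use
# ("he", "she") in the same way. The class assignment is assumed.
ARTICLES = ("the", "a")


def bias_pair(tokens: List[str], label: int,
              pair: Tuple[str, str] = ARTICLES) -> List[str]:
    """Replace every occurrence of either word of the pair by the word
    assigned to this sentence's class (1 = positive, 0 = negative).
    Sentences containing neither word are returned unchanged."""
    target = pair[0] if label == 1 else pair[1]
    return [target if t in pair else t for t in tokens]
```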
For the third task we repeat the same procedure with two words (“he” and “she”) that appeared in only 2% of the dataset. This helps evaluate whether CDEP works even if the spurious signal appears only in a small section of the data set.
S2 Network architectures and training
S2.1 Network architectures
For the ISIC skin cancer task we used a pretrained VGG16 network retrieved from the PyTorch model zoo. We use SGD as the optimizer with a learning rate of 0.01 and momentum of 0.9. Preliminary experiments with Adam as the optimizer yielded poorer predictive performance.
For both MNIST tasks, we use a standard convolutional network with two convolutional layers, each followed by max pooling, and two fully connected layers:
Conv(20,5,5) - MaxPool - Conv(50,5,5) - MaxPool - FC(256) - FC(10). The models were trained with Adam, using a weight decay of 0.001.
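As a worked check of the layer sizes, the sketch below computes the number of features entering FC(256). It assumes 28 x 28 inputs, no padding, and max pooling with a 2 x 2 window and stride 2. The pooling size is not specified above, so treat it as an assumption.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size: int = 28, pool: int = 2) -> int:
    """Inputs to FC(256) for Conv(20,5,5)-MaxPool-Conv(50,5,5)-MaxPool,
    assuming pool x pool pooling with stride `pool` and no padding.
    Sizes in the comments are for the default arguments."""
    s = conv_out(input_size, 5)         # 28 -> 24
    s = conv_out(s, pool, stride=pool)  # 24 -> 12
    s = conv_out(s, 5)                  # 12 -> 8
    s = conv_out(s, pool, stride=pool)  # 8 -> 4
    return 50 * s * s                   # 50 channels * 4 * 4 = 800
```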
Penalizing explanations adds an additional hyperparameter, $\lambda$, to the training. $\lambda$ can either be set in proportion to the normal training loss or at a fixed rate. In this paper we did the latter; we expect that exploring the former could lead to a more stable training process. For all tasks, $\lambda$ was tested across a wide range of values.
The LSTM for the SST experiments consisted of two LSTM layers with 128 hidden units followed by a fully connected layer.
To fix the bias in the ColorMNIST task, we sample pixels from the distribution of non-zero pixels over the whole training set, as shown in Fig. S4.
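The sampling step can be sketched as follows:

1. Count how often each pixel position is non-zero across the training set.
2. Draw positions in proportion to those counts.

Each sampled pixel then forms its own feature group whose CD contribution is penalized. Sampling with replacement and the function names are assumptions.

```python
import random
from typing import List, Sequence, Tuple

Image = Sequence[Sequence[float]]


def nonzero_pixel_distribution(
        images: Sequence[Image]) -> Tuple[List[Tuple[int, int]], List[int]]:
    """For every pixel position, count the training images in which that
    pixel is non-zero."""
    h, w = len(images[0]), len(images[0][0])
    positions = [(r, c) for r in range(h) for c in range(w)]
    counts = [sum(1 for img in images if img[r][c] != 0) for r, c in positions]
    return positions, counts


def sample_single_pixels(positions, counts, k: int,
                         rng: random.Random) -> List[Tuple[int, int]]:
    """Draw k pixel positions (with replacement) with probability
    proportional to their non-zero counts; positions that are always zero
    are never drawn."""
    return rng.choices(positions, weights=counts, k=k)
```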
S3 Runtime and memory requirements of different algorithms
This section provides further details on the runtime and memory requirements reported in Table 3. We compared the runtime and memory requirements of the available regularization schemes when implemented in PyTorch. Expected Gradients runs substantially slower, since it requires resampling (the authors recommended 200 samples per example) with repeated forward and backward passes.
Memory usage and runtime were tested on the DecoyMNIST task with a batch size of 64. We expect the exact ratios to change depending on the complexity of the network used and the batch size (since constant memory overhead becomes proportionally smaller with increasing batch size).
The memory usage was read by recording the memory allocated by PyTorch. Since Expected Gradients and RRR require two forward and backward passes, we only record the maximum memory usage. We ran experiments on a single Titan X.
S4 Image segmentation for ISIC skin cancer
To obtain the binary maps of the patches for the skin cancer task, we first segment the images using SLIC, a common image-segmentation algorithm (achanta2012slic). Since the patches look quite distinct from the rest of the image, each patch usually forms its own segment.
Subsequently, we take the mean RGB and HSV values for all segments and filter for segments whose mean is substantially different from the typical Caucasian skin tone. Since different images differed from the typical skin color in different attributes, we filtered for those images recursively. For example, in the image shown in Fig. S6, the patch has a much higher saturation than the rest of the image. For each image, we exported a map as seen in Fig. S6.
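A simplified, hypothetical variant of this filtering step is sketched below. It does not use SLIC segments or combined RGB/HSV statistics. Instead, it thresholds the HSV saturation of each pixel against a reference skin saturation, the attribute that distinguishes the patch in the Fig. S6 example. `reference_sat` and `margin` are hypothetical parameters. This is an illustration, not the pipeline used here. The resulting mask marks the feature group whose explanation target is zero importance.

```python
import colorsys
from typing import List, Sequence, Tuple

Pixel = Tuple[float, float, float]  # RGB, each channel in [0, 1]


def patch_mask(img: Sequence[Sequence[Pixel]], reference_sat: float,
               margin: float = 0.3) -> List[List[int]]:
    """Per-pixel simplification of the segment filter: mark pixels whose
    HSV saturation exceeds the skin reference by more than `margin`
    (1 = suspected patch, 0 = skin or lesion)."""
    mask = []
    for row in img:
        mask_row = []
        for r, g, b in row:
            _, sat, _ = colorsys.rgb_to_hsv(r, g, b)
            mask_row.append(1 if sat - reference_sat > margin else 0)
        mask.append(mask_row)
    return mask
```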
S5 Additional heatmap examples for ISIC
We show additional examples from the test set of the skin cancer task in Figs. S8 and S7. The importance maps of the unregularized and regularized networks are very similar for cancerous images and for non-cancerous images without patches. The network regularized with CDEP ignores the patches.
A different spurious correlation that we noticed was that proportionally more images showing skin cancer contain a ruler next to the lesion. This is the case because doctors often want to show a reference for size if they diagnosed the lesion as cancerous. Even though this spurious correlation is less pronounced (in a rough cursory count, 13% of the cancerous and 5% of the benign images contain some sort of measure), the networks learned to recognize and exploit it. This further highlights the need for CDEP, especially in medical settings.