GNNExplainer: Generating Explanations for Graph Neural Networks
Graph Neural Networks (GNNs) are a powerful tool for machine learning on graphs. GNNs combine node feature information with the graph structure by recursively passing neural messages along edges of the input graph. However, incorporating both graph structure and feature information leads to complex models, and explaining predictions made by GNNs remains unsolved. Here we propose GnnExplainer, the first general, model-agnostic approach for providing interpretable explanations for predictions of any GNN-based model on any graph-based machine learning task. Given an instance, GnnExplainer identifies a compact subgraph structure and a small subset of node features that have a crucial role in the GNN’s prediction. Further, GnnExplainer can generate consistent and concise explanations for an entire class of instances. We formulate GnnExplainer as an optimization task that maximizes the mutual information between a GNN’s prediction and the distribution of possible subgraph structures. Experiments on synthetic and real-world graphs show that our approach can identify important graph structures as well as node features, and outperforms alternative baseline approaches by up to 43.0% in explanation accuracy. GnnExplainer provides a variety of benefits, from the ability to visualize semantically relevant structures, to interpretability, to giving insights into errors of faulty GNNs.
In many real-world applications, including social, information, chemical, and biological domains, data can be naturally modeled as graphs Cho et al. (2011); You et al. (2018); Zitnik et al. (2018). Graphs are powerful data representations but are challenging to work with because they require modeling of rich relational information as well as node feature information Zhang et al. (2018); Zhou et al. (2018). To address this challenge, Graph Neural Networks (GNNs) have emerged as the state of the art for machine learning on graphs, due to their ability to recursively incorporate information from neighboring nodes in the graph, naturally capturing both graph structure and node features Hamilton et al. (2017); Kipf and Welling (2016); Ying et al. (2018b); Zhang and Chen (2018).
Despite their strengths, GNNs lack transparency as they do not easily allow for a human-intelligible explanation of their predictions. Yet, the ability to understand a GNN’s predictions is important and useful for several reasons: (i) it can increase trust in the GNN model, (ii) it improves the model’s transparency in a growing number of decision-critical applications pertaining to fairness, privacy and other safety challenges Doshi-Velez and Kim (2017), and (iii) it allows practitioners to understand the network characteristics and to identify and correct systematic patterns of mistakes made by models before deploying them in the real world.
While there are currently no methods for explaining GNNs, recent approaches for explaining other types of neural networks have taken one of two main routes. One line of work locally approximates models with simpler surrogate models, which are then probed for explanations Lakkaraju et al. (2017); Ribeiro et al. (2016); Schmitz et al. (1999). Other methods carefully examine models for relevant features and find good qualitative interpretations of high-level features Chen et al. (2018b); Erhan et al. (2009); Lundberg and Lee (2017); Sundararajan et al. (2017) or identify influential input instances Koh and Liang (2017); Yeh et al. (2018). However, these approaches fall short in their ability to incorporate relational information, the essence of graphs. Since this aspect is crucial for the success of machine learning on graphs, any explanation of a GNN’s predictions should leverage the rich relational information provided by the graph as well as node features.
Here we propose GnnExplainer, an approach for explaining predictions made by GNNs. GnnExplainer takes a trained GNN and its prediction(s), and it returns an explanation in the form of a small subgraph of the input graph together with a small subset of node features that are most influential for the prediction(s) (Figure 1). The approach is model-agnostic and can explain predictions of any GNN on any machine learning task for graphs, including node classification, link prediction, and graph classification. It handles single- as well as multi-instance explanations. In the case of single-instance explanations, GnnExplainer explains a GNN’s prediction for one particular instance (e.g., a node label, a new link, a graph-level label). In the case of multi-instance explanations, GnnExplainer provides an explanation that consistently explains a set of instances (e.g., nodes of a given class).
GnnExplainer specifies an explanation as a rich subgraph of the entire graph the GNN was trained on, such that the subgraph maximizes the mutual information with the GNN’s prediction(s). This is achieved by formulating a mean-field variational approximation and learning a real-valued graph mask that selects the important subgraph of the GNN’s computation graph. Simultaneously, GnnExplainer also learns a feature mask that masks out unimportant node features (Figure 1).
We evaluate GnnExplainer on synthetic as well as real-world graphs. Experiments show that GnnExplainer provides consistent and concise explanations of a GNN’s predictions. On synthetic graphs with planted network motifs, which play a role in determining node labels, we show that GnnExplainer accurately identifies the subgraphs/motifs as well as node features that determine node labels, outperforming alternative baseline approaches by up to 43.0% in explanation accuracy. Further, using two real-world datasets, we show how GnnExplainer can provide important domain insights by robustly identifying important graph structures and node features that influence a GNN’s predictions. Specifically, using molecular graphs and social interaction networks, we show that GnnExplainer can identify important domain-specific graph structures, such as chemical groups or ring structures in molecules, and star structures in Reddit threads. Overall, experiments demonstrate that GnnExplainer provides consistent and concise explanations for GNN-based models for different machine learning tasks on graphs.
2 Related work
Although the problem of explaining GNNs is not well studied, the related problems of interpretability and neural debugging have received substantial attention in machine learning. At a high level, we can group interpretability methods for non-graph neural networks into two main families.
Methods in the first family formulate simple proxy models of full neural networks. This can be done in a model-agnostic way, usually by learning a locally faithful approximation around the prediction, for example through linear models Ribeiro et al. (2016) or sets of rules, representing sufficient conditions on the prediction Augasta and Kathirvalavakumar (2012); Lakkaraju et al. (2017); Zilke et al. (2016). Methods in the second family identify important aspects of the computation, for example, through feature gradients Erhan et al. (2009); Zeiler and Fergus (2014), backpropagation of neurons’ contributions to the input features Chen et al. (2018b); Shrikumar et al. (2017); Sundararajan et al. (2017), and counterfactual reasoning Kang et al. (2019). However, the saliency maps Zeiler and Fergus (2014) produced by these methods have been shown to be misleading in some instances Adebayo et al. (2018) and prone to issues like gradient saturation Shrikumar et al. (2017); Sundararajan et al. (2017). These issues are exacerbated on discrete inputs such as graph adjacency matrices since the gradient values can be very large but only on very small intervals. Because of that, such approaches are not suitable for explaining predictions made by neural networks on graphs.
Instead of creating new, inherently interpretable models, post-hoc interpretability methods Adadi and Berrada (2018); Fisher et al. (2018); Guidotti and others (2018); Hooker (2004); Koh and Liang (2017); Yeh et al. (2018) consider models as black boxes and then probe them for relevant information. However, no such work has leveraged relational structures like graphs. The lack of methods for explaining predictions on graph-structured data is problematic, as in many cases, predictions on graphs are induced by a complex combination of nodes and paths of edges between them. For example, in some tasks, an edge is important only when another alternative path exists in the graph to form a cycle, and those two features, only when considered together, can accurately predict node labels Debnath and others (1991); Duvenaud and others (2015). Their joint contribution thus cannot be modeled as a simple linear combination of individual contributions.
Finally, recent GNN models augment interpretability via attention mechanisms Neil and others (2018); Velickovic et al. (2018); Xie and Grossman (2018). However, although the learned edge attention values can indicate important graph structure, the values are the same for predictions across all nodes. This conflicts with many applications where an edge is essential for predicting the label of one node but not the label of another node. Furthermore, these approaches are either limited to specific GNN architectures or cannot explain predictions by jointly considering both graph structure and node feature information.
3 Formulating explanations for graph neural networks
Let $G$ denote a graph on edges $E$ and nodes $V$ that are associated with $d$-dimensional node features $X = \{x_1, \ldots, x_n\}$, $x_i \in \mathbb{R}^d$. Without loss of generality, we consider the problem of explaining a node classification task (see Section 4.4 for other tasks). Let $f$ denote a label function on nodes $f: V \mapsto \{1, \ldots, C\}$ that maps every node in $V$ to one of $C$ classes. The GNN model $\Phi$ is optimized on all nodes in the training set and is then used for prediction, i.e., to approximate $f$ on new nodes.
3.1 Background on graph neural networks
At layer $l$, the update of GNN model $\Phi$ involves three key computations Battaglia et al. (2018); Zhang et al. (2018); Zhou et al. (2018). (1) First, the model computes neural messages between every pair of nodes. The message for node pair $(v_i, v_j)$ is a function Msg of $v_i$’s and $v_j$’s representations $h_i^{l-1}$ and $h_j^{l-1}$ in the previous layer and of the relation $r_{ij}$ between the nodes:
$$m_{ij}^l = \mathrm{Msg}(h_i^{l-1}, h_j^{l-1}, r_{ij}).$$
(2) Second, for each node $v_i$, GNN aggregates messages from $v_i$’s neighborhood $\mathcal{N}_{v_i}$ and calculates an aggregated message $M_i$ via an aggregation method Agg Hamilton et al. (2017); Xu et al. (2019):
$$M_i^l = \mathrm{Agg}(\{m_{ij}^l \mid v_j \in \mathcal{N}_{v_i}\}),$$
where $\mathcal{N}_{v_i}$ is the neighborhood of node $v_i$ whose definition depends on a particular GNN variant. (3) Finally, GNN takes the aggregated message $M_i^l$ along with $v_i$’s representation $h_i^{l-1}$ from the previous layer, and it non-linearly transforms them to obtain $v_i$’s representation $h_i^l$ at layer $l$:
$$h_i^l = \mathrm{Update}(M_i^l, h_i^{l-1}).$$
The final embedding for node $v_i$ after $L$ layers of computation is $z_i = h_i^L$. Our GnnExplainer provides explanations for any GNN that can be formulated in terms of Msg, Agg, and Update computations.
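To make the Msg/Agg/Update decomposition concrete, the following is a minimal NumPy sketch of one layer. It assumes a linear message that depends only on $h_j^{l-1}$ (the relation $r_{ij}$ is ignored), sum aggregation, and a ReLU update applied to the concatenation of $M_i^l$ and $h_i^{l-1}$. These choices, and the names `gnn_layer`, `W_msg` and `W_upd`, are illustrative rather than taken from any specific GNN.

```python
import numpy as np

def gnn_layer(h, edges, W_msg, W_upd):
    """One hypothetical Msg/Agg/Update step.

    h:      (n, d) node representations from layer l-1.
    edges:  iterable of (i, j) pairs meaning v_j is in the neighborhood of v_i.
    W_msg:  (d, k) weights of the linear message function.
    W_upd:  (k + d, k_out) weights of the update function.
    """
    n = h.shape[0]
    agg = np.zeros((n, W_msg.shape[1]))
    for i, j in edges:
        # Msg: a linear function of h_j only (r_ij is ignored in this sketch).
        m_ij = h[j] @ W_msg
        # Agg: sum over the neighborhood of v_i.
        agg[i] += m_ij
    # Update: ReLU of a linear map of [M_i, h_i].
    return np.maximum(0.0, np.concatenate([agg, h], axis=1) @ W_upd)
```

Applying the layer $L$ times, each time feeding its output back in as `h`, gives the final embeddings $z_i = h_i^L$.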
3.2 GnnExplainer: Problem formulation
Our key insight is the observation that the computation graph of node $v$, which is defined by the GNN’s neighborhood-based aggregation (Figure 2), fully determines all the information the GNN uses to generate prediction $\hat{y}$ at node $v$. In particular, $v$’s computation graph tells the GNN how to generate $v$’s embedding $z$. Let us denote that computation graph by $G_c(v)$, the associated binary adjacency matrix by $A_c(v) \in \{0, 1\}^{n \times n}$, and the associated feature set by $X_c(v) = \{x_j \mid v_j \in G_c(v)\}$. The GNN model $\Phi$ learns a conditional distribution $P_\Phi(Y \mid G_c, X_c)$, where $Y$ is a random variable representing labels $\{1, \ldots, C\}$, indicating the probability of nodes belonging to each of $C$ classes.
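For models whose neighborhood $\mathcal{N}_{v_i}$ is simply the set of 1-hop neighbors (as in GCN-style models), the computation graph of an $L$-layer GNN is the $L$-hop neighborhood of $v$. The hypothetical helper below sketches this with a breadth-first search over an adjacency list and keeps only the directed edges along which a message can still reach $v$ within $L$ layers; sampled or attention-based neighborhoods would need a different construction.

```python
from collections import deque

def computation_graph(adj, v, num_layers):
    """Nodes and message-carrying edges of the num_layers-hop computation graph of v.

    adj: dict mapping each node to an iterable of its neighbors (assumed symmetric).
    Edges are returned as (receiver, sender) pairs.
    """
    dist = {v: 0}
    queue = deque([v])
    while queue:
        u = queue.popleft()
        if dist[u] == num_layers:
            continue  # nodes at the boundary are not expanded further
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    nodes = set(dist)
    # A message sent to node a at some layer can only reach v if a is
    # strictly closer than num_layers hops to v.
    edges = {(a, b) for a in nodes if dist[a] < num_layers for b in adj[a]}
    return nodes, edges
```

On a path graph 0-1-2-3-4 with two layers, node 0's computation graph contains nodes {0, 1, 2}; the edge between nodes 2 and 3 is excluded because a message along it would need three layers to reach node 0.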
A GNN’s prediction is given by $\hat{y} = \Phi(G_c(v), X_c(v))$, meaning that it is fully determined by the model $\Phi$, graph structural information $G_c(v)$, and node feature information $X_c(v)$. In effect, this observation implies that we only need to consider graph structure $G_c(v)$ and node features $X_c(v)$ to explain $\hat{y}$ (Figure 2A). Formally, GnnExplainer generates an explanation for prediction $\hat{y}$ as $(G_S, X_S^F)$, where $G_S$ is a small subgraph of the computation graph, $X_S$ is the associated feature of $G_S$, and $X_S^F$ is a small subset of node features (masked out by the mask $F$, i.e., $X_S^F = \{x_j^F \mid v_j \in G_S\}$) that are most important for explaining $\hat{y}$ (Figure 2B).
Next we describe our approach GnnExplainer. Given a trained GNN model $\Phi$ and a prediction (i.e., single-instance explanation, Sections 4.1 and 4.2) or a set of predictions (i.e., multi-instance explanations, Section 4.3), GnnExplainer generates an explanation by identifying a subgraph of the computation graph and a subset of node features that are most influential for the model $\Phi$’s prediction. In the case of explaining a set of predictions, GnnExplainer aggregates individual explanations in the set and automatically summarizes them with a prototype. We conclude this section with a discussion on how GnnExplainer can be used for any machine learning task on graphs, including link prediction and graph classification (Section 4.4).
4.1 Single-instance explanations
Given a node $v$, our goal is to identify a subgraph $G_S \subseteq G_c$ and the associated features $X_S = \{x_j \mid v_j \in G_S\}$ that are important for the GNN’s prediction $\hat{y}$. For now, we assume that $X_S$ is a small subset of $d$-dimensional node features; we will later discuss how to automatically determine which dimensions of node features need to be included in explanations (Section 4.2). We formalize the notion of importance using mutual information $MI$ and formulate GnnExplainer as the following optimization framework:
$$\max_{G_S} MI\big(Y, (G_S, X_S)\big) = H(Y) - H(Y \mid G = G_S, X = X_S). \tag{1}$$
For node $v$, $MI$ quantifies the change in the probability of prediction $\hat{y} = \Phi(G_c, X_c)$ when $v$’s computation graph is limited to explanation subgraph $G_S$ and its node features are limited to $X_S$.
For example, consider the situation where $v_j \in G_c(v_i)$, $v_j \neq v_i$. Then, if removing $v_j$ from $G_c(v_i)$ strongly decreases the probability of prediction $\hat{y}_i$, the node $v_j$ is a good counterfactual explanation for the prediction at $v_i$. Similarly, consider the situation where $(v_j, v_k) \in G_c(v_i)$, $v_j, v_k \neq v_i$. Then, if removing an edge between $v_j$ and $v_k$ strongly decreases the probability of prediction $\hat{y}_i$, the absence of that edge is a good counterfactual explanation for the prediction at $v_i$.
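Examining Eq. (1), we see that the entropy term $H(Y)$ is constant because $\Phi$ is fixed for a trained GNN. As a result, maximizing mutual information between the predicted label distribution $Y$ and explanation $(G_S, X_S)$ is equivalent to minimizing conditional entropy $H(Y \mid G = G_S, X = X_S)$, which can be expressed as follows:
$$H(Y \mid G = G_S, X = X_S) = -\mathbb{E}_{Y \mid G_S, X_S}\big[\log P_\Phi(Y \mid G = G_S, X = X_S)\big]. \tag{2}$$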
Explanation for prediction $\hat{y}$ is thus a subgraph $G_S$ that minimizes uncertainty of $\Phi$ when the GNN computation is limited to $G_S$. In effect, $G_S$ maximizes probability of $\hat{y}$ (Figure 2). To obtain a compact explanation, we impose a constraint on $G_S$’s size as $|G_S| \le K_M$, so that $G_S$ has at most $K_M$ nodes. In effect, this implies that GnnExplainer aims to denoise $G_c$ by taking $K_M$ edges that give the highest mutual information with the prediction.
GnnExplainer’s optimization framework. Direct optimization of GnnExplainer’s objective is not tractable as $G_c$ has exponentially many subgraphs $G_S$ that are candidate explanations for $\hat{y}$. We thus consider a fractional adjacency matrix for subgraphs $G_S$, i.e., $A_S \in [0,1]^{n \times n}$ (for typed edges, we define $G_S \in [0,1]^{C_e \times n \times n}$, where $C_e$ is the number of edge types), and enforce the subgraph constraint as $A_S[j,k] \le A_c[j,k]$ for all $j, k$. This continuous relaxation can be interpreted as a variational approximation of the distribution of subgraphs of $G_c$. In particular, if we treat $G_S \sim \mathcal{G}$ as a random graph variable, the objective in Eq. (2) becomes:
$$\min_{\mathcal{G}} \mathbb{E}_{G_S \sim \mathcal{G}}\, H(Y \mid G = G_S, X = X_S). \tag{3}$$
Under a convexity assumption, Jensen’s inequality gives the following upper bound:
$$\min_{\mathcal{G}} H\big(Y \mid G = \mathbb{E}_{\mathcal{G}}[G_S], X = X_S\big). \tag{4}$$
In practice, due to the complexity of neural networks, the convexity assumption does not hold. However, experimentally, we found that minimizing this objective with regularization often leads to a local minimum corresponding to high-quality explanations.
To tractably estimate $\mathbb{E}_{\mathcal{G}}$, we use a mean-field variational approximation and decompose $\mathcal{G}$ into a multivariate Bernoulli distribution as $P_{\mathcal{G}}(G_S) = \prod_{(j,k) \in G_c} A_S[j,k]$. This allows us to estimate the expectation with respect to the mean-field approximation, thereby obtaining $A_S$ in which the $(j,k)$-th entry represents the expectation of whether edge $(v_j, v_k)$ exists. We observed empirically that this approximation, together with a regularizer for promoting discreteness Ying et al. (2018b), converges to good local minima despite the non-convexity of GNNs. The conditional entropy in Equation 4 can be optimized by replacing the $\mathbb{E}_{\mathcal{G}}[G_S]$ to be optimized by a masking of the computation graph’s adjacency matrix, $A_c \odot \sigma(M)$, where $M \in \mathbb{R}^{n \times n}$ denotes the mask that we need to learn, $\odot$ denotes element-wise multiplication, and $\sigma$ denotes the sigmoid that maps the mask to $[0,1]^{n \times n}$.
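A small NumPy sketch of this relaxation on an assumed toy computation graph: the soft adjacency $A_S = A_c \odot \sigma(M)$ never exceeds $A_c$, and under the mean-field model each (directed) entry is an independent Bernoulli variable with parameter $A_S[j,k]$, so the average of many sampled subgraphs approaches $A_S$. The 4-cycle graph, the random mask logits, and the variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy computation-graph adjacency A_c (a 4-cycle) and random mask logits M.
A_c = np.array([[0, 1, 0, 1],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [1, 0, 1, 0]], dtype=float)
M = rng.normal(size=A_c.shape)

# Soft adjacency A_S = A_c * sigmoid(M); entries outside the computation graph stay 0.
A_S = A_c * (1.0 / (1.0 + np.exp(-M)))

# Mean-field model: each entry (j, k) is kept independently with probability A_S[j, k].
samples = rng.random((20000,) + A_S.shape) < A_S
empirical_mean = samples.mean(axis=0)  # Monte Carlo estimate of E[G_S]
```

Because the expectation of each Bernoulli entry is exactly $A_S[j,k]$, optimizing $A_c \odot \sigma(M)$ directly stands in for optimizing the expected subgraph, without ever enumerating subgraphs.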
In some applications, instead of finding an explanation in terms of the model’s confidence, users care more about “why does the trained model predict a certain class label” or “how to make the trained model predict a desired class label”. We can replace the conditional entropy objective in Equation 4 with a cross-entropy objective between the label class and the model prediction (when answering “why does the trained model predict a certain class label”, the label class is the class predicted by the GNN model being explained; “how to make the trained model predict a desired class label” can be answered by using the ground-truth label class). To answer these queries, a computationally efficient version of GnnExplainer’s objective, which we optimize using gradient descent, is as follows:
$$\min_{M} -\sum_{c=1}^{C} \mathbb{1}[y = c] \log P_\Phi\big(Y = y \mid G = A_c \odot \sigma(M), X = X_c\big). \tag{5}$$
The masking approach is also found in Neural Relational Inference Kipf et al. (2018), albeit with different motivation and objective. Lastly, we compute the element-wise multiplication of $\sigma(M)$ and $A_c$ and remove low values in $M$ through thresholding to arrive at the explanation $G_S$ for the GNN model’s prediction $\hat{y}$ at node $v$.
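The following NumPy sketch illustrates the Eq. (5) mask optimization followed by thresholding on a deliberately tiny model. It assumes a one-layer linear GNN whose logits for node $v$ are the $v$-th row of $(A_c \odot \sigma(M)) H$, with fixed per-node class scores $H = XW$. Because the model is this simple, the gradient is written out by hand; a real explainer would differentiate through the trained GNN automatically. The size and element-wise entropy penalties anticipate the regularizers described in Section 4.2, and the coefficients, step count, threshold of 0.5, and the name `explain_node` are all assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def explain_node(A, H, node, target, steps=300, lr=0.1,
                 size_coef=0.005, ent_coef=0.1, seed=0):
    """Learn edge-mask logits M for `node` by gradient descent on Eq. (5).

    A: (n, n) 0/1 adjacency of the computation graph.
    H: (n, C) per-node class scores of a one-layer linear GNN (H = X @ W).
    Returns the soft explanation A * sigmoid(M).
    """
    rng = np.random.default_rng(seed)
    A = A.astype(float)
    M = rng.normal(0.0, 0.1, size=A.shape)
    onehot = np.eye(H.shape[1])[target]
    for _ in range(steps):
        m = sigmoid(M)
        logits = (A * m)[node] @ H            # prediction for `node`, shape (C,)
        p = np.exp(logits - logits.max())
        p /= p.sum()
        # Cross-entropy gradient w.r.t. row `node` of the masked adjacency.
        grad_masked = np.zeros_like(A)
        grad_masked[node] = H @ (p - onehot)
        dm_dM = A * m * (1.0 - m)             # chain rule through A * sigmoid(M)
        grad = grad_masked * dm_dM
        grad += size_coef * dm_dM             # d/dM of size_coef * sum(A * m)
        grad += ent_coef * (-M) * dm_dM       # d/dM of ent_coef * sum(A * entropy(m))
        M -= lr * grad
    return A * sigmoid(M)

# Toy usage: node 0 has neighbors 1, 2, 3; node 1 supports class 0, node 2 supports class 1.
A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0]])
H = np.array([[0.0, 0.0], [2.0, -2.0], [-1.0, 1.0], [0.0, 0.0]])
expl = explain_node(A, H, node=0, target=0)
kept = np.argwhere(expl > 0.5)  # thresholding: (row, col) entries kept in G_S
```

In this toy setup the edge to node 1, which pushes node 0 toward the target class, ends up with a high mask value, while the edge to node 2, which pushes toward the other class, is driven toward zero.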
4.2 Joint learning of graph structural and node feature information
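To identify what node features are most important for prediction $\hat{y}$, GnnExplainer learns a feature selector $F$ for nodes in explanation $G_S$. Instead of defining $X_S$ to consist of all node features, i.e., $X_S = \{x_j \mid v_j \in G_S\}$, GnnExplainer considers $X_S^F$ as a subset of features of nodes in $G_S$, which are defined through a binary feature selector $F \in \{0,1\}^d$ (Figure 2B):
$$X_S^F = \{x_j^F \mid v_j \in G_S\}, \quad x_j^F = [x_{j,t_1}, \ldots, x_{j,t_k}] \ \text{for}\ F_{t_i} = 1, \tag{6}$$
where $x_j^F$ has node features that are not masked out by $F$. Explanation $(G_S, F)$ is then jointly optimized for maximizing the mutual information objective:
$$\max_{G_S, F} MI\big(Y, (G_S, F)\big) = H(Y) - H(Y \mid G = G_S, X = X_S^F), \tag{7}$$
which represents a modified objective function from Eq. (1) that considers structural and node feature information to generate an explanation for prediction $\hat{y}$.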
Learning binary feature selector $F$. We specify $X_S^F$ as $X_S \odot F$, where $F$ acts as a feature mask that we need to learn. Intuitively, if a particular feature is not important, the corresponding weights in the GNN’s weight matrix take values close to zero. In effect, this implies that masking the feature out does not decrease the predicted probability for $\hat{y}$. Conversely, if the feature is important, then masking it out would decrease the predicted probability. However, in some cases this approach ignores features that are important for prediction but take values close to zero. To address this issue, we marginalize over all feature subsets and use a Monte Carlo estimate to sample from the empirical marginal distribution for nodes in $X_S$ during training Zintgraf et al. (2017). Further, we use a reparametrization trick Kingma and Welling (2013) to backpropagate gradients in Eq. (7) to the feature mask $F$. In particular, to backpropagate through a $d$-dimensional random variable $X$, we reparametrize $X$ as $X = Z + (X_S - Z) \odot F$ s.t. $\sum_j F_j \le K_F$, where $Z$ is a $d$-dimensional random variable sampled from the empirical distribution and $K_F$ is a parameter representing the maximum number of features to be kept in the explanation.
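A minimal NumPy sketch of this reparametrization, assuming a soft mask $F = \sigma(F_{\text{logits}})$ shared by all nodes and drawing each entry of $Z$ from the empirical marginal of the corresponding feature (the same column at a uniformly random row). The function name is hypothetical, and the budget $\sum_j F_j \le K_F$ is not enforced here; in practice it could be approximated with a penalty on $\sum_j F_j$.

```python
import numpy as np

def reparam_features(X_S, F_logits, rng):
    """Return X = Z + (X_S - Z) * F for a soft feature mask F = sigmoid(F_logits).

    X_S:      (n, d) features of the nodes in the explanation.
    F_logits: (d,) learnable mask logits, shared across nodes.
    rng:      a numpy.random.Generator used to sample Z.
    """
    n, d = X_S.shape
    # Z[i, j] = X_S[r, j] for a random row r: a sample from feature j's empirical marginal.
    rows = rng.integers(0, n, size=(n, d))
    Z = X_S[rows, np.arange(d)]
    F = 1.0 / (1.0 + np.exp(-F_logits))
    # Kept features (F near 1) pass through; masked ones (F near 0) are replaced by Z.
    return Z + (X_S - Z) * F
```

Since `F` enters the output linearly and `Z` does not depend on it, gradients with respect to `F_logits` flow through this expression, which is the point of the reparametrization.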
Integrating additional constraints into explanations. To impose further properties on the explanation, we can extend GnnExplainer’s objective function in Eq. (7) with regularization terms. For example, we use element-wise entropy to encourage structural and node feature masks to be discrete. Further, GnnExplainer can encode domain-specific constraints through techniques such as Lagrange multipliers for constraints or additional regularization terms. We include a number of regularization terms to produce explanations with desired properties. We penalize large explanations by adding the sum of all elements of the mask parameters as a regularization term.
Finally, it is important to note that each explanation must be a valid computation graph. In particular, explanation $G_S$ needs to allow the GNN’s neural messages to flow towards node $v$ such that the GNN can make prediction $\hat{y}$. Importantly, GnnExplainer automatically provides explanations that represent valid computation graphs because it optimizes structural masks across entire computation graphs. Even if a disconnected edge appears important for neural message passing, it will not be selected for the explanation, as it cannot influence the GNN’s prediction. In effect, this implies that the explanation tends to be a small connected subgraph.
4.3 Multi-instance explanations through graph prototypes
The output of a single-instance explanation (Sections 4.1 and 4.2) is a small subgraph of the input graph and a small subset of associated node features that are most influential for a single prediction. To answer questions like “How did a GNN predict that a given set of nodes all have label $c$?”, we need to obtain a global explanation of class $c$. Our goal here is to provide insight into how the identified subgraph for a particular node relates to a graph structure that explains an entire class. GnnExplainer can provide multi-instance explanations based on graph alignments and prototypes. Our approach has two stages:
First, for a given class $c$ (or any set of predictions that we want to explain), we choose a reference node $v_c$, for example, by computing the mean embedding of all nodes assigned to $c$. We then take the explanation $G_S(v_c)$ for reference $v_c$ and align it to explanations of other nodes assigned to class $c$. Finding an optimal matching of large graphs is challenging in practice. However, the single-instance GnnExplainer generates small graphs (Section 4.2), and thus near-optimal pairwise graph matchings can be efficiently computed.
Second, we aggregate aligned adjacency matrices into a graph prototype $A_{\text{proto}}$ using, for example, a robust median-based approach. The prototype $A_{\text{proto}}$ gives insights into graph patterns shared between nodes that belong to the same class. One can then study the prediction for a particular node by comparing the explanation for that node’s prediction (i.e., returned by the single-instance explanation approach) to the prototype (see Appendix for more information).
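The two stages can be sketched in NumPy under strong simplifying assumptions: the reference node is taken to be the node whose embedding is closest to the class mean, and the explanations are assumed to be already aligned to the reference node’s ordering (the graph-matching step itself is omitted). The prototype is then the element-wise median of the aligned adjacency matrices, binarized at a threshold. All names and the 0.5 threshold are illustrative.

```python
import numpy as np

def reference_node(embeddings, nodes_in_class):
    """Node whose embedding is closest (Euclidean) to the class mean embedding."""
    Z = embeddings[nodes_in_class]
    dists = np.linalg.norm(Z - Z.mean(axis=0), axis=1)
    return nodes_in_class[int(np.argmin(dists))]

def prototype(aligned_adjs, threshold=0.5):
    """Element-wise median of already-aligned adjacency matrices, binarized."""
    med = np.median(np.stack(aligned_adjs, axis=0), axis=0)
    return (med >= threshold).astype(int)
```

The median makes the prototype robust to an occasional spurious edge that appears in only a minority of the aligned explanations.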
4.4 GnnExplainer model extensions
Any machine learning task on graphs. In addition to explaining node classification, GnnExplainer provides explanations for link prediction and graph classification with no change to its optimization algorithm. When predicting a link $(v_j, v_k)$, GnnExplainer learns two masks, one for each endpoint of the link. When classifying a graph, the adjacency matrix in Eq. (5) is the union of adjacency matrices for all nodes in the graph whose label we want to explain. However, note that in graph classification, unlike node classification, the aggregation of node embeddings means that the explanation is no longer necessarily a connected subgraph. Depending on the application, in some scenarios, such as chemistry, where the explanation is a functional group and should be connected, one can extract the largest connected component as the explanation.
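Extracting the largest connected component of a thresholded graph-classification explanation takes only a few lines. The plain-Python sketch below assumes the explanation is given as a list of undirected edges that survived thresholding; the function name is hypothetical.

```python
def largest_connected_component(edges):
    """Return the edges of the largest connected component (by node count)."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    seen, best = set(), set()
    for start in adj:
        if start in seen:
            continue
        # Iterative depth-first search from an unvisited node.
        comp, stack = {start}, [start]
        seen.add(start)
        while stack:
            u = stack.pop()
            for w in adj[u]:
                if w not in seen:
                    seen.add(w)
                    comp.add(w)
                    stack.append(w)
        if len(comp) > len(best):
            best = comp
    # Keep the original edge order, restricted to the winning component.
    return [(u, v) for u, v in edges if u in best]
```

For instance, given a three-node ring plus a separate two-node edge, only the ring’s three edges are kept.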
Any GNN model. Modern GNNs are based on message passing architectures on the input graph. The message passing computation graphs can be composed in many different ways and GnnExplainer can account for all of them. Thus, GnnExplainer can be applied to: Graph Convolutional Networks Kipf and Welling (2016), Gated Graph Sequence Neural Networks Li et al. (2015), Jumping Knowledge Networks Xu et al. (2018), Attention Networks Velickovic et al. (2018), Graph Networks Battaglia et al. (2018), GNNs with various node aggregation schemes Chen et al. (2018c, a); Huang et al. (2018); Hamilton et al. (2017); Ying et al. (2018b, a); Xu et al. (2019), Line-Graph NNs Chen et al. (2019), position-aware GNN You et al. (2019), and many other GNN architectures.
Computational complexity. The number of parameters in GnnExplainer’s optimization depends on the size of the computation graph $G_c$ for node $v$ whose prediction we aim to explain. In particular, $G_c(v)$’s adjacency matrix $A_c(v)$ has the same size as the mask $M$, which needs to be learned by GnnExplainer. However, since computation graphs are typically relatively small compared to the size of exhaustive $L$-hop neighborhoods (e.g., 2-3 hop neighborhoods Kipf and Welling (2016), sampling-based neighborhoods Ying et al. (2018a), neighborhoods with attention Velickovic et al. (2018)), GnnExplainer can effectively generate explanations even when input graphs are large.
We begin by describing the graphs, alternative baseline approaches, and experimental setup. We then present experiments on explaining GNNs for node classification and graph classification tasks. Our qualitative and quantitative analysis demonstrates that GnnExplainer is accurate and effective in identifying explanations, both in terms of graph structure and node features.
Synthetic datasets. We construct four kinds of node classification datasets (Table 1). (1) In BA-Shapes, we start with a base Barabási-Albert (BA) graph on 300 nodes and a set of 80 five-node “house”-structured network motifs, which are attached to randomly selected nodes of the base graph. The resulting graph is further perturbed by adding random edges. Nodes are assigned to 4 classes based on their structural roles. In a house-structured motif, there are 3 types of roles: the top, middle and bottom node of the house. Therefore there are 4 different classes, corresponding to nodes at the top, middle, bottom of houses, and nodes that do not belong to a house. (2) BA-Community dataset is a union of two BA-Shapes graphs. Nodes have normally distributed feature vectors and are assigned to one of 8 classes based on their structural roles and community memberships. (3) In Tree-Cycles, we start with a base 8-level balanced binary tree and 80 six-node cycle motifs, which are attached to random nodes of the base graph. (4) Tree-Grid is the same as Tree-Cycles except that 3-by-3 grid motifs are attached to the base tree graph in place of cycle motifs.
Real-world datasets. We consider two graph classification datasets: (1) Mutag is a dataset of molecule graphs labeled according to their mutagenic effect on the Gram-negative bacterium S. typhimurium Debnath and others (1991). (2) Reddit-Binary is a dataset of graphs, each representing an online discussion thread on Reddit. In each graph, nodes are users participating in a thread, and edges indicate that one user replied to another user’s comment. Graphs are labeled according to the type of user interactions in the thread: r/IAmA and r/AskReddit contain Question-Answer interactions, while r/TrollXChromosomes and r/atheism contain Online-Discussion interactions Yanardag and Vishwanathan (2015).
Alternative baseline approaches. Many explainability methods cannot be directly applied to graphs (Section 2). Nevertheless, we consider the following alternative approaches that can provide insights into predictions made by GNNs: (1) Grad is a gradient-based method. We compute the gradient of the GNN’s loss function with respect to the adjacency matrix and the associated node features, similar to a saliency map approach. (2) Att is a graph attention GNN (GAT) Velickovic et al. (2018) that learns attention weights for edges in the computation graph, which we use as a proxy measure of edge importance. While Att does consider graph structure, it does not explain using node features and can only explain GAT models. Furthermore, in Att it is not obvious which attention weights need to be used for edge importance, since a 1-hop neighbor of a node can also be a 2-hop neighbor of the same node due to cycles. Each edge’s importance is thus computed as the average attention weight across all layers.
Setup and implementation details. For each dataset, we first train a single GNN and use Grad and GnnExplainer to explain the predictions made by that GNN. Note that the Att baseline requires using a graph attention architecture like GAT Velickovic et al. (2018). We thus train a separate GAT model on the same dataset and use the learned edge attention weights for explanation. Hyperparameters $K_M$ and $K_F$ control the size of subgraph and feature explanations, respectively, and are informed by prior knowledge about the dataset. For synthetic datasets, we set $K_M$ to be the size of the ground truth. On real-world datasets, we set $K_M$ to a fixed value, and we use a single fixed $K_F$ for all datasets. We further fix our weight regularization hyperparameters across all node and graph classification experiments. We refer readers to the Appendix for more training details (Code and datasets are available at https://github.com/RexYing/gnn-model-explainer).
Results. We investigate the following questions: Does GnnExplainer provide sensible explanations? How do explanations compare to the ground-truth knowledge? How does GnnExplainer perform on various graph-based prediction tasks? Can it explain predictions made by different GNNs?
1) Quantitative analyses. Results on node classification datasets are shown in Table 1. We have ground-truth explanations for synthetic datasets and we use them to calculate explanation accuracy for all explanation methods. Specifically, we formalize the explanation problem as a binary classification task, where edges in the ground-truth explanation are treated as labels and importance weights given by explainability method are viewed as prediction scores. A better explainability method predicts high scores for edges that are in the ground-truth explanation, and thus achieves higher explanation accuracy. Results show that GnnExplainer outperforms alternative approaches by 17.1% on average. Further, GnnExplainer achieves up to 43.0% higher accuracy on the hardest Tree-Grid dataset.
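One common way to score such a binary classification of edges is the area under the ROC curve. The text does not spell out the exact metric, so treating explanation accuracy as ROC AUC is an assumption here. The plain-Python sketch uses the pairwise (Mann-Whitney) form: the fraction of (ground-truth edge, other edge) pairs in which the ground-truth edge receives the higher importance score, with ties counted as one half. The function name is hypothetical.

```python
def edge_auc(scores, labels):
    """Pairwise ROC AUC of edge importance scores against ground-truth edge labels.

    scores: importance weight per edge from an explainability method.
    labels: 1 if the edge is in the ground-truth explanation, else 0.
    """
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    wins = 0.0
    for p in pos:
        for q in neg:
            # Count a win when a ground-truth edge outranks a non-ground-truth edge.
            wins += 1.0 if p > q else 0.5 if p == q else 0.0
    return wins / (len(pos) * len(neg))
```

A method that ranks every ground-truth edge above every other edge scores 1.0, while random scores give about 0.5. The quadratic pairwise loop is fine for the small computation graphs considered here.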
2) Qualitative analyses. Results are shown in Figures 3–5. In a topology-based prediction task with no node features, e.g., BA-Shapes and Tree-Cycles, GnnExplainer correctly identifies network motifs that explain node labels, i.e., structural labels (Figure 3). As illustrated in the figures, house, cycle and tree motifs are identified by GnnExplainer but not by baseline methods. In Figure 4, we investigate explanations for the graph classification task. In the Mutag example, colors indicate node features, which represent atoms (hydrogen H, carbon C, etc.). GnnExplainer correctly identifies the carbon ring as well as chemical groups that are known to be mutagenic Debnath and others (1991).
Further, in the Reddit-Binary example, we see that Question-Answer graphs (2nd row in Figure 4B) have 2-3 high degree nodes that simultaneously connect to many low degree nodes, which makes sense because in QA threads on Reddit we typically have 2-3 experts who all answer many different questions Kumar et al. (2018). Conversely, we observe that discussion patterns commonly exhibit tree-like patterns (2nd row in Figure 4A), since a thread on Reddit is usually a reaction to a single topic Kumar et al. (2018). On the other hand, Grad and Att methods give incorrect or incomplete explanations. For example, both baseline methods miss cycle motifs in the Mutag dataset and more complex grid motifs in the Tree-Grid dataset. Furthermore, although edge attention weights in Att can be interpreted as importance scores for message passing, the weights are shared across all nodes in the input graph, and as such Att fails to provide high-quality single-instance explanations.
An essential criterion for explanations is that they must be interpretable, i.e., provide a qualitative understanding of the relationship between the input nodes and the prediction. Such a requirement implies that explanations should be easy to understand while remaining exhaustive. This means that a GNN explainer should take into account both the structure of the underlying graph and the associated features when they are available. Figure 5 shows results of an experiment in which GnnExplainer jointly considers structural information as well as information from a small number of feature dimensions (feature explanations are shown for the two datasets with node features, i.e., Mutag and BA-Community). While GnnExplainer indeed highlights a compact feature representation in Figure 5, gradient-based approaches struggle to cope with the added noise, giving high importance scores to irrelevant feature dimensions.
Further experiments on multi-instance explanations using graph prototypes are presented in Appendix B.
We present GnnExplainer, a novel method for explaining predictions of any GNN on any graph-based machine learning task without requiring modification of the underlying GNN architecture or re-training. We show how GnnExplainer can leverage the recursive neighborhood-aggregation scheme of graph neural networks to identify important graph pathways as well as highlight relevant node feature information that is passed along edges of the pathways. While the problem of explainability of machine-learning predictions has received substantial attention in recent literature, our work is unique in that it presents an approach that operates on relational structures (graphs with rich node features) and provides a straightforward interface for making sense of GNN predictions, debugging GNN models, and identifying systematic patterns of mistakes.
Jure Leskovec is a Chan Zuckerberg Biohub investigator. We gratefully acknowledge the support of DARPA under FA865018C7880 (ASED) and MSC; NIH under No. U54EB020405 (Mobilize); ARO under No. 38796-Z8424103 (MURI); IARPA under No. 2017-17071900005 (HFC); NSF under No. OAC-1835598 (CINES) and HDR; Stanford Data Science Initiative, Chan Zuckerberg Biohub, JD.com, Amazon, Boeing, Docomo, Huawei, Hitachi, Observe, Siemens, UST Global. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views, policies, or endorsements, either expressed or implied, of DARPA, NIH, ONR, or the U.S. Government.
-  (2018) Peeking Inside the Black-Box: A Survey on Explainable Artificial Intelligence (XAI). IEEE Access 6, pp. 52138–52160. Cited by: §2.
-  (2018) Sanity checks for saliency maps. In NeurIPS, Cited by: §2.
-  (2012) Reverse Engineering the Neural Networks for Rule Extraction in Classification Problems. Neural Processing Letters 35 (2), pp. 131–150. Cited by: §2.
-  (2018) Relational inductive biases, deep learning, and graph networks. arXiv:1806.01261. Cited by: §3.1, §4.4.
-  (2018) Stochastic training of graph convolutional networks with variance reduction. In ICML, Cited by: §4.4.
-  (2018) Learning to explain: an information-theoretic perspective on model interpretation. arXiv preprint arXiv:1802.07814. Cited by: §1, §2.
-  (2018) FastGCN: fast learning with graph convolutional networks via importance sampling. In ICLR, Cited by: §4.4.
-  (2019) Supervised community detection with line graph neural networks. In ICLR, Cited by: §4.4.
-  (2011) Friendship and mobility: user movement in location-based social networks. In KDD, Cited by: §1.
-  (1991) Structure-activity relationship of mutagenic aromatic and heteroaromatic nitro compounds. Correlation with molecular orbital energies and hydrophobicity. Journal of Medicinal Chemistry 34 (2), pp. 786–797. Cited by: §2, §5, §5.
-  (2017) Towards A Rigorous Science of Interpretable Machine Learning. arXiv:1702.08608. Cited by: §1.
-  (2015) Convolutional networks on graphs for learning molecular fingerprints. In NIPS, Cited by: §2.
-  (2009) Visualizing higher-layer features of a deep network. University of Montreal 1341 (3), pp. 1. Cited by: §1, §2.
-  (2018) All Models are Wrong but many are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance. arXiv:1801.01489. Cited by: §2.
-  (2018) A Survey of Methods for Explaining Black Box Models. ACM Comput. Surv. 51 (5), pp. 93:1–93:42. Cited by: §2.
-  (2017) Inductive representation learning on large graphs. In NIPS, Cited by: §1, §3.1, §4.4.
-  (2004) Discovering additive structure in black box functions. In KDD, Cited by: §2.
-  (2018) Adaptive sampling towards fast graph representation learning. In NeurIPS, Cited by: §4.4.
-  (2019) ExplaiNE: an approach for explaining network embedding-based link predictions. arXiv:1904.12694. Cited by: §2.
-  (2013) Auto-encoding variational bayes. In ICLR, Cited by: §4.2.
-  (2016) Semi-supervised classification with graph convolutional networks. In ICLR, Cited by: §1, §4.4, §4.4.
-  (2018) Neural relational inference for interacting systems. In ICML, Cited by: §4.1.
-  (2017) Understanding black-box predictions via influence functions. In ICML, Cited by: §1, §2.
-  (2018) Community interaction and conflict on the web. In WWW, pp. 933–943. Cited by: §5.
-  (2017) Interpretable & Explorable Approximations of Black Box Models. Cited by: §1, §2.
-  (2015) Gated graph sequence neural networks. arXiv:1511.05493. Cited by: §4.4.
-  (2017) A Unified Approach to Interpreting Model Predictions. In NIPS, Cited by: §1.
-  (2018) Interpretable Graph Convolutional Neural Networks for Inference on Noisy Knowledge Graphs. In ML4H Workshop at NeurIPS, Cited by: §2.
-  (2016) Why should i trust you?: explaining the predictions of any classifier. In KDD, Cited by: §1, §2.
-  (1999) ANN-DT: an algorithm for extraction of decision trees from artificial neural networks. IEEE Transactions on Neural Networks. Cited by: §1.
-  (2017) Learning Important Features Through Propagating Activation Differences. In ICML, Cited by: §2.
-  (2017) Axiomatic Attribution for Deep Networks. In ICML, Cited by: §1, §2.
-  (2018) Graph attention networks. In ICLR, Cited by: §2, §4.4, §4.4, §5, §5.
-  (2018) Crystal graph convolutional neural networks for an accurate and interpretable prediction of material properties. In Phys. Rev. Lett., Cited by: §2.
-  (2019) How powerful are graph neural networks? In ICLR, Cited by: §3.1, §4.4.
-  (2018) Representation learning on graphs with jumping knowledge networks. In ICML, Cited by: §4.4.
-  (2015) Deep graph kernels. In KDD, pp. 1365–1374. Cited by: §5.
-  (2018) Representer point selection for explaining deep neural networks. In NeurIPS, Cited by: §1, §2.
-  (2018) Graph convolutional neural networks for web-scale recommender systems. In KDD, Cited by: §4.4, §4.4.
-  (2018) Hierarchical graph representation learning with differentiable pooling. In NeurIPS, Cited by: Appendix A, §1, §4.1, §4.4.
-  (2018) Graph convolutional policy network for goal-directed molecular graph generation. Cited by: §1.
-  (2019) Position-aware graph neural networks. In ICML, Cited by: §4.4.
-  (2014) Visualizing and Understanding Convolutional Networks. In ECCV, Cited by: §2.
-  (2018) Link prediction based on graph neural networks. In NIPS, Cited by: §1.
-  (2018) Deep Learning on Graphs: A Survey. arXiv:1812.04202. Cited by: §1, §3.1.
-  (2018) Graph Neural Networks: A Review of Methods and Applications. arXiv:1812.08434. Cited by: §1, §3.1.
-  (2016) DeepRED - Rule Extraction from Deep Neural Networks. In Discovery Science, Cited by: §2.
-  (2017) Visualizing deep neural network decisions: prediction difference analysis. In ICLR, Cited by: §4.2.
-  (2018) Modeling polypharmacy side effects with graph convolutional networks. Bioinformatics 34. Cited by: §1.
Appendix A Multi-instance explanations
The problem of multi-instance explanations for graph neural networks is challenging and an important area of study.
Here we propose a solution based on GnnExplainer that finds the common components of a set of 10 explanations for 10 different instances in the same class. More research in this area is necessary to design efficient multi-instance explanation methods. In practice, the main challenge stems from the difficulty of performing graph alignment under noise and under variation in the neighborhood structures of nodes in the same class. The problem is closely related to finding the maximum common subgraph of explanation graphs, which is NP-hard. In the following, we introduce a neural approach to this problem. Note, however, that existing graph libraries for finding the maximum common subgraph of graphs (based on heuristics or integer programming relaxations) can replace the neural components of the following procedure when identifying and aligning with a prototype.
The output of a single-instance GnnExplainer indicates which graph structural and node feature information is important for a given prediction. To answer “why is a given set of nodes classified with label $c$?”, we also want a global explanation of the class, which can shed light on how the structure identified for a given node relates to a prototypical structure unique to its label. To this end, we propose an alignment-based multi-instance GnnExplainer.
For any given class, we first choose a reference node. Intuitively, this node should be a prototypical node for the class. Such a node can be found by computing the mean of the embeddings of all nodes in the class and choosing the node whose embedding is closest to that mean. Alternatively, if one has prior knowledge about the important computation subgraph, one can choose the node whose computation subgraph best matches that prior knowledge.
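The centroid heuristic for picking a reference node can be sketched in a few lines of NumPy. The sketch assumes node embeddings are available as a `(num_nodes, dim)` array (e.g., taken from the trained GNN) and uses Euclidean distance as the notion of "closest"; both are assumptions, and `choose_reference_node` is a hypothetical name.

```python
import numpy as np


def choose_reference_node(embeddings, labels, cls):
    """Index of the class-`cls` node whose embedding is closest to the class mean.

    embeddings: (num_nodes, dim) array of node embeddings.
    labels:     (num_nodes,) array of class labels.
    """
    idx = np.flatnonzero(labels == cls)                   # nodes belonging to the class
    class_emb = embeddings[idx]
    centroid = class_emb.mean(axis=0)                     # mean embedding of the class
    dists = np.linalg.norm(class_emb - centroid, axis=1)  # Euclidean distance (assumed metric)
    return int(idx[np.argmin(dists)])
```

For example, with class-0 embeddings `[0, 0]`, `[2, 0]`, and `[1, 0.1]`, the mean is close to `[1, 0.03]`, so the third node is returned as the reference.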
Given the reference node $v_c$ for class $c$ and its associated important computation subgraph $G_S(v_c)$, we align each of the identified computation subgraphs for all nodes in class $c$ to the reference $G_S(v_c)$. Following the idea of differentiable pooling Ying et al. (2018b), we use a relaxed alignment matrix to find the correspondence between nodes in a computation subgraph $G_S(v)$ and nodes in the reference computation subgraph $G_S(v_c)$. Let $A_v$ and $X_v$ be the adjacency matrix and the associated feature matrix of the to-be-aligned computation subgraph. Similarly, let $A^*$ and $X^*$ be the adjacency matrix and the associated feature matrix of the reference computation subgraph. We then optimize the relaxed alignment matrix $P \in \mathbb{R}^{n_v \times n^*}$, where $n_v$ is the number of nodes in $G_S(v)$ and $n^*$ is the number of nodes in $G_S(v_c)$, as follows:
$$\min_{P} \; \lVert P^{\top} A_v P - A^* \rVert + \lVert P^{\top} X_v - X^* \rVert. \qquad (8)$$
The first term in Eq. (8) specifies that, after alignment, the aligned adjacency matrix of $G_S(v)$ should be as close to $A^*$ as possible. The second term specifies that the features of the aligned nodes should also be close.
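Eq. (8) does not pin down the norm or the optimizer. The NumPy sketch below assumes squared Frobenius norms, which give closed-form gradients, and plain projected gradient descent that keeps every entry of $P$ in $[0, 1]$; it does not enforce row or column sums. The learning rate, step count, and the names `alignment_loss` and `align` are hypothetical choices for illustration, and a different parameterization or optimizer (e.g., Adam) may work better in practice.

```python
import numpy as np


def alignment_loss(P, A_v, X_v, A_ref, X_ref):
    """Eq. (8) with squared Frobenius norms (an assumption of this sketch)."""
    R_adj = P.T @ A_v @ P - A_ref   # (n*, n*) adjacency mismatch after alignment
    R_feat = P.T @ X_v - X_ref      # (n*, d) feature mismatch after alignment
    return float(np.sum(R_adj ** 2) + np.sum(R_feat ** 2))


def align(A_v, X_v, A_ref, X_ref, steps=3000, lr=0.005, seed=0):
    """Projected gradient descent on a relaxed alignment matrix P in [0, 1]^(n_v x n*).

    Returns the best (P, loss) pair seen during optimization.
    """
    rng = np.random.default_rng(seed)
    n_v, n_ref = A_v.shape[0], A_ref.shape[0]
    # Start from a uniform soft assignment plus small noise to break symmetry.
    P = np.clip(np.full((n_v, n_ref), 1.0 / n_ref)
                + 0.01 * rng.standard_normal((n_v, n_ref)), 0.0, 1.0)
    best_P, best_loss = P.copy(), alignment_loss(P, A_v, X_v, A_ref, X_ref)
    for _ in range(steps):
        R_adj = P.T @ A_v @ P - A_ref
        R_feat = P.T @ X_v - X_ref
        # Closed-form gradients of both squared-norm terms with respect to P.
        grad = (2.0 * (A_v @ P @ R_adj.T + A_v.T @ P @ R_adj)
                + 2.0 * X_v @ R_feat.T)
        P = np.clip(P - lr * grad, 0.0, 1.0)  # project back onto [0, 1]
        loss = alignment_loss(P, A_v, X_v, A_ref, X_ref)
        if loss < best_loss:
            best_P, best_loss = P.copy(), loss
    return best_P, best_loss
```

Because single-instance explanations are small, each step is only a few small matrix products. One simple way to turn the soft $P$ into a hard node correspondence is to take the argmax of each row, although other rounding schemes are possible.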
In practice, it is often non-trivial for relaxed graph matching to find a good optimum when matching two large graphs. However, because the single-instance explainer produces concise subgraphs for important message passing, a matching that is close to the best alignment can be computed efficiently.
Prototype by alignment. We align the adjacency matrices of all nodes in class $c$ with respect to the node ordering defined by the reference adjacency matrix. We then take the element-wise median to generate a prototype that is resistant to outliers, $A_{\mathrm{proto}} = \operatorname{median}(A_i)$, where $A_i$ is the aligned adjacency matrix representing the explanation for the $i$-th node in class $c$. The prototype allows users to gain insights into structural graph patterns shared between nodes that belong to the same class. Users can then investigate a particular node by comparing its explanation to the class prototype.
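Once every explanation has been aligned to the reference ordering, the prototype is an element-wise median over the stack of aligned matrices. A minimal NumPy sketch, assuming all aligned adjacency matrices already share the reference size $n^* \times n^*$ (`median_prototype` is a hypothetical helper name):

```python
import numpy as np


def median_prototype(aligned_adjs):
    """Element-wise median of aligned explanation adjacency matrices.

    aligned_adjs: list of (n*, n*) arrays, all in the reference node ordering.
    """
    stack = np.stack(aligned_adjs, axis=0)  # shape (num_instances, n*, n*)
    return np.median(stack, axis=0)
```

For binary adjacency matrices and an odd number of instances, the median is a majority vote: an edge appears in the prototype only if more than half of the aligned explanations contain it, so isolated outlying edges are filtered out.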
Appendix B Experiments on multi-instance explanations and prototypes
In the context of multi-instance explanations, an explainer must not only highlight information locally relevant to a particular prediction, but also help emphasize higher-level correlations across instances. These instances can be related in arbitrary ways, but the most evident relation is class membership. The assumption is that members of a class share common characteristics, and the model should help highlight them. For example, mutagenic compounds are often found to have certain characteristic functional groups, such as $\mathrm{NO}_2$, a pair of oxygen atoms together with a nitrogen atom. A trained eye might notice that Figure 6 already hints at their presence. The evidence grows stronger when a prototype is generated by GnnExplainer, shown in Figure 6. The model is able to pick up on this functional structure and promote it as archetypal of mutagenic compounds.
Appendix C Further implementation details
Training details. We use the Adam optimizer to train both the GNN and the explanation methods. All GNN models are trained for 1000 epochs with learning rate 0.001, reaching an accuracy of at least 85% on graph classification datasets and 95% on node classification datasets. The same train/validation/test split is used for all datasets. In GnnExplainer, we use the same optimizer and learning rate, and train for 100–300 epochs. This is efficient since GnnExplainer only needs to be trained on a local computation graph, which contains only a small number of nodes.
Regularization. In addition to the graph size constraint and the graph Laplacian constraint, we further impose a feature size constraint, which requires that the number of unmasked features not exceed a threshold. Separate regularization hyperparameters are used for the subgraph size, Laplacian, and feature explanation terms, and the same values are used across all experiments.
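One way to encode the two size constraints as soft penalties is sketched below, assuming sigmoid-squashed mask logits. The graph size term penalizes the soft count of kept edges, and the feature size term is a hinge that penalizes only the soft count of features above an allowed number `max_feats`. The hinge form, the coefficient names, and `size_regularizers` itself are hypothetical; the Laplacian term is omitted, and the exact penalties used in the experiments may differ.

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic function mapping mask logits to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def size_regularizers(edge_logits, feat_logits, size_coef, feat_coef, max_feats):
    """Soft size penalties on an edge mask and a feature mask (hypothetical form).

    Returns (graph_size_term, feature_size_term) as Python floats.
    """
    edge_mask = sigmoid(edge_logits)
    feat_mask = sigmoid(feat_logits)
    # Graph size: the sum of mask weights is a differentiable proxy for the edge count.
    size_term = size_coef * float(edge_mask.sum())
    # Feature size: penalize only the soft feature count exceeding `max_feats`.
    feat_term = feat_coef * max(0.0, float(feat_mask.sum()) - max_feats)
    return size_term, feat_term
```

In a full explainer these terms would be added to the prediction loss and minimized jointly over the logits; the hinge leaves explanations that already use few features unpenalized.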
Subgraph extraction. To extract the explanation subgraph $G_S$, we first compute the importance weights on edges (gradients for the Grad baseline, attention weights for the Att baseline, and the masked adjacency for GnnExplainer). A threshold is used to remove low-weight edges and identify the explanation subgraph $G_S$. The ground-truth explanations of all datasets are connected subgraphs. Therefore, we identify the explanation as the connected component of $G_S$ containing the explained node. For graph classification, we identify the explanation as the largest connected component of $G_S$. For all methods, we search for the maximum threshold such that the explanation reaches at least the prescribed size. When multiple edges have tied importance weights, all of them are included in the explanation.
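The threshold search can be sketched in plain Python. Assumptions: edge importances are given as a dictionary from an undirected edge `(u, v)` to its weight, explanation size is measured in edges, and "largest" component for graph classification means the one with the most edges; `extract_explanation` is a hypothetical name. Scanning the distinct weights from high to low and stopping at the first success yields the maximum qualifying threshold, and keeping every edge with weight `>=` the threshold includes all tied edges.

```python
from collections import defaultdict, deque


def _components(edges):
    """Split a set of undirected edges into (node_set, edge_set) components."""
    adj = defaultdict(set)
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    seen, comps = set(), []
    for start in list(adj):
        if start in seen:
            continue
        # Breadth-first search from an unvisited node collects one component.
        seen.add(start)
        nodes, queue = {start}, deque([start])
        while queue:
            x = queue.popleft()
            for y in adj[x]:
                if y not in seen:
                    seen.add(y)
                    nodes.add(y)
                    queue.append(y)
        # Both endpoints of an edge lie in the same component, so checking one suffices.
        comps.append((nodes, {e for e in edges if e[0] in nodes}))
    return comps


def extract_explanation(weights, min_edges, target=None):
    """Largest threshold whose explanation has at least `min_edges` edges.

    weights: dict mapping an undirected edge (u, v) to its importance weight.
    target:  explained node (node classification), or None to take the
             largest connected component (graph classification).
    Returns (threshold, edge_set), or (None, set()) if no threshold qualifies.
    """
    # Try thresholds from high to low, so the first success is the maximum one.
    for t in sorted(set(weights.values()), reverse=True):
        kept = {e for e, w in weights.items() if w >= t}  # ">=" keeps all ties
        comps = _components(kept)
        if target is None:
            chosen = max(comps, key=lambda c: len(c[1]))[1]
        else:
            chosen = next((es for ns, es in comps if target in ns), set())
        if len(chosen) >= min_edges:
            return t, chosen
    return None, set()
```

For example, with a triangle on nodes 0, 1, 2 weighted 0.9, 0.8, 0.7 and a separate edge (3, 4) weighted 0.95, asking for at least 3 edges around node 0 returns threshold 0.7 and the triangle, while the disconnected high-weight edge is ignored.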