Inference in Probabilistic Graphical Models by Graph Neural Networks
A useful computation when acting in a complex environment is to infer the marginal probabilities or most probable states of task-relevant variables. Probabilistic graphical models can efficiently represent the structure of such complex data, but performing these inferences is generally difficult. Message-passing algorithms, such as belief propagation, are a natural way to disseminate evidence amongst correlated variables while exploiting the graph structure, but these algorithms can struggle when the conditional dependency graphs contain loops. Here we use Graph Neural Networks (GNNs) to learn a message-passing algorithm that solves these inference tasks. We first show that the architecture of GNNs is well-matched to inference tasks. We then demonstrate the efficacy of this inference approach by training GNNs on an ensemble of graphical models and showing that they substantially outperform belief propagation on loopy graphs. Our message-passing algorithms generalize out of the training set to larger graphs and graphs with different structure.
Probabilistic graphical models provide a statistical framework for modelling conditional dependencies between random variables, and are widely used to represent complex, real-world phenomena. Given a graphical model for a distribution $p(\mathbf{x})$, one major goal is to compute marginal probability distributions $p_i(x_i)$ of task-relevant variables at each node of the graph: given a loss function, these distributions determine the optimal estimator. Another major goal is to compute the most probable state, $\mathbf{x}^* = \arg\max_{\mathbf{x}} p(\mathbf{x})$, or MAP (maximum a posteriori) inference.
For complex models with loopy graphs, exact inference of these sorts is often computationally intractable, so practical inference generally relies on approximate methods. One important method for computing approximate marginals is the belief propagation (BP) algorithm, which exchanges statistical information among neighboring nodes (Pearl, 1988; Wainwright et al., 2003). This algorithm performs exact inference on tree graphs, but not on graphs with cycles. Furthermore, the basic update steps in belief propagation may not have efficient or even closed-form solutions, leading researchers to construct BP variants (Sudderth et al., 2010; Ihler & McAllester, 2009; Noorshams & Wainwright, 2013) or generalizations (Minka, 2001).
In this work, we introduce end-to-end trainable inference systems based on Graph Neural Networks (GNNs) (Gori et al., 2005; Scarselli et al., 2009; Li et al., 2016), which are recurrent networks that allow complex transformations between nodes. We show that this network architecture is well-suited to message-passing inference algorithms, and that its flexibility gives it wide applicability even in cases where closed-form algorithms are unavailable. These GNNs have vector-valued nodes that can encode probabilistic information about variables in the graphical model. The GNN nodes send and receive messages about those probabilities, and these messages are determined by canonical learned nonlinear transformations of the information sources and the statistical interactions between them. The dynamics of the GNN reflect the flow of probabilistic information throughout the graphical model, and when the model reaches equilibrium, a nonlinear decoder can extract approximate marginal probabilities or states from each node.
To demonstrate the value of these GNNs for inference in probabilistic graphical models, we create an ensemble of graphical models, train our networks to perform marginal or MAP inference, and test how well these inferences generalize beyond the training set of graphs. Our results compare quite favorably to belief propagation on loopy graphs.
2 Related Work
Several researchers have used neural networks to implement some form of probabilistic inference. Heess et al. (2013) propose training a neural network that learns to map message inputs to message outputs for each message operation needed for Expectation Propagation inference, and Lin et al. (2015) suggest learning CNNs for estimating factor-to-variable messages in a message-passing procedure.
Another related line of work is on inference machines: Ross et al. (2011) train a series of logistic regressors with hand-crafted features to estimate messages. Wei et al. (2016) apply this idea to pose estimation using convolutional layers, and Deng et al. (2016) introduce sequential inference by recurrent neural networks for the same application domain.
The most similar line of work to the approach we present here is that of GNN-based models. GNNs are essentially an extension of recurrent neural networks that operate on graph-structured inputs (Scarselli et al., 2009; Li et al., 2016). The central idea is to iteratively update hidden states at each GNN node by aggregating incoming messages that are propagated through the graph. Here, expressive neural networks model both message- and node-update functions. Gilmer et al. (2017) recently provided a good review of several GNN variants and unified them into a model called message-passing neural networks. GNNs indeed have a similar structure to the message-passing algorithms used in probabilistic inference. For this reason, GNNs are powerful architectures for capturing statistical dependencies between variables of interest (Bruna et al., 2014; Duvenaud et al., 2015; Li et al., 2016; Marino et al., 2016; Li et al., 2017; Qi et al., 2017; Kipf & Welling, 2017).
In this section, we briefly review probabilistic graphical models, describe our GNN architecture, and present how the network is applied to the problem of estimating marginal probabilities and most probable states of each variable in discrete undirected graphical models.
3.1 Probabilistic graphical models
Probabilistic graphical models simplify a joint probability distribution over many variables by factorizing the distribution according to conditional independence relationships. Factor graphs are one convenient, general representation of structured probability distributions. These are undirected, bipartite graphs whose edges connect variable nodes $i$ that encode individual variables $x_i$, to factor nodes $\alpha$ that encode direct statistical interactions $\psi_\alpha(\mathbf{x}_\alpha)$ between groups of variables $\mathbf{x}_\alpha$. (Some of these factors may affect only one variable.) The probability distribution is the normalized product of all factors:
$$p(\mathbf{x}) = \frac{1}{Z} \prod_{\alpha} \psi_\alpha(\mathbf{x}_\alpha) \tag{1}$$
Here $Z$ is a normalization constant, and $\mathbf{x}_\alpha$ is a vector with components $x_i$ for all variable nodes $i$ connected to the factor node $\alpha$ by an edge $(i, \alpha)$.
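Our goal is to compute marginal probabilities $p_i(x_i)$ or MAP states $x_i^*$ for such graphical models. For general graphs, these computations require exponentially large resources, summing (integrating) or maximizing over all possible states except the target node: $p_i(x_i) = \sum_{\mathbf{x} \setminus x_i} p(\mathbf{x})$ or $\mathbf{x}^* = \arg\max_{\mathbf{x}} p(\mathbf{x})$.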
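Belief propagation operates on these factor graphs by constructing messages $\mu_{i \to \alpha}$ and $\mu_{\alpha \to i}$ that are passed between variable and factor nodes:
$$\mu_{i \to \alpha}(x_i) = \prod_{\beta \in N_i \setminus \alpha} \mu_{\beta \to i}(x_i) \tag{2}$$
$$\mu_{\alpha \to i}(x_i) = \sum_{\mathbf{x}_\alpha \setminus x_i} \psi_\alpha(\mathbf{x}_\alpha) \prod_{j \in N_\alpha \setminus i} \mu_{j \to \alpha}(x_j) \tag{3}$$
where $N_i$ are the neighbors of $i$, i.e. factors that involve $x_i$, and $N_\alpha$ are the neighbors of $\alpha$, i.e. variables that are directly coupled by $\psi_\alpha(\mathbf{x}_\alpha)$. The recursive, graph-based structure of these message equations leads naturally to the idea that we could describe these messages and their nonlinear updates using a graph neural network in which GNN nodes correspond to messages, as described in the next section.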
Interestingly, belief propagation can also be reformulated entirely without messages: BP operations are equivalent to successively reparameterizing the factors over subgraphs of the original graphical model (Wainwright et al., 2003). This suggests that we could construct a different mapping between GNNs and graphical models, where GNN nodes correspond to factor nodes rather than messages. Notably, the reparameterization accomplished by BP only adjusts the univariate potentials, since the BP updates leave the multivariate coupling potentials unchanged: after the inference algorithm converges, the estimated marginal joint probability of a factor $\alpha$, namely $B_\alpha(\mathbf{x}_\alpha)$, is given by
$$B_\alpha(\mathbf{x}_\alpha) \propto \psi_\alpha(\mathbf{x}_\alpha) \prod_{i \in N_\alpha} \mu_{i \to \alpha}(x_i) \tag{4}$$
Observe that all of the messages depend only on one variable at a time, and the only term that depends on more than one variable at a time is the factor itself, $\psi_\alpha(\mathbf{x}_\alpha)$, which is therefore invariant over time. Since BP does not change these interactions, to imitate the action of BP the GNNs need only represent single variable nodes explicitly, while the nonlinear functions between nodes can account for (and must depend on) their interactions. Our experiments evaluate both of these architectures, with GNNs constructed with latent states that represent either message nodes or single variable nodes.
3.2 Binary Markov random fields
In our experiments, we focus on binary graphical models, with $n$ variables $\mathbf{x} \in \{+1, -1\}^{n}$. The probability $p(\mathbf{x})$ is determined by singleton factors biasing individual variables according to the vector $\mathbf{b}$, and pairwise factors that couple different variables according to the symmetric matrix $J$. Together these factors produce the joint distribution
$$p(\mathbf{x}) = \frac{1}{Z} \exp\Big( \sum_{i} b_i x_i + \sum_{i<j} J_{ij} x_i x_j \Big) \tag{5}$$
In our experiments, each graphical model’s parameters $J$ and $\mathbf{b}$ are specified randomly, and are provided as input features for the GNN inference. We allow a variety of graph structures, ranging in complexity from tree graphs to grid graphs to fully connected graphs. The target marginals are $p_i(x_i)$, and MAP states are given by $\mathbf{x}^* = \arg\max_{\mathbf{x}} p(\mathbf{x})$. For our experiments with small graphs, the true values of these targets were computed exactly by exhaustive enumeration of states. Our goal is to construct a recurrent neural network with canonical operations whose dynamics converge to these targets, $p_i(x_i)$ and $\mathbf{x}^*$, in a manner that generalizes immediately to new graphical models.
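The exhaustive enumeration mentioned above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' code: the function name `exact_inference` is hypothetical, and the sketch assumes a joint distribution proportional to exp(sum_i b_i x_i + sum_{i<j} J_ij x_i x_j), with each pair of variables counted once.

```python
import itertools
import math


def exact_inference(J, b):
    """Brute-force marginals and MAP state of a binary pairwise model.

    J is a symmetric list-of-lists of couplings, b a list of biases, and
    each variable takes values in {-1, +1}. Assumed convention (not from
    the paper): log p(x) = sum_i b_i x_i + sum_{i<j} J_ij x_i x_j - log Z.
    Returns (marginals, map_state) with marginals[i] = p(x_i = +1).
    """
    n = len(b)
    weights = {}
    for x in itertools.product((-1, 1), repeat=n):
        log_w = sum(b[i] * x[i] for i in range(n))
        log_w += sum(J[i][j] * x[i] * x[j]
                     for i in range(n) for j in range(i + 1, n))
        weights[x] = math.exp(log_w)
    Z = sum(weights.values())  # normalization constant
    marginals = [sum(w for x, w in weights.items() if x[i] == 1) / Z
                 for i in range(n)]
    map_state = max(weights, key=weights.get)  # most probable joint state
    return marginals, map_state
```

The cost grows as 2^n, which is why this kind of ground truth is only practical for small graphs.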
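Belief propagation in these binary graphical models updates messages $\mu_{ij}$ from $i$ to $j$ according to
$$\mu_{ij}(x_j) = \sum_{x_i} e^{J_{ij} x_i x_j + b_i x_i} \prod_{k \in N_i \setminus j} \mu_{ki}(x_i) \tag{6}$$
where $N_i$ is the set of neighboring nodes for $i$. BP provides estimated marginals by $\hat{p}_i(x_i) \propto e^{b_i x_i} \prod_{k \in N_i} \mu_{ki}(x_i)$. This message-passing structure motivates one of the two graph neural network architectures we will use below.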
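To make the pairwise message update concrete, here is a minimal Python sketch of sum-product BP on a binary pairwise model. It is an illustration under stated assumptions rather than the authors' implementation: the name `loopy_bp`, the synchronous update schedule, the uniform initialization, the per-message normalization, and the fixed iteration count are all choices made here, and a nonzero entry `J[i][j]` is assumed to mark an edge.

```python
import math


def loopy_bp(J, b, n_iters=50):
    """Sum-product BP for x_i in {-1, +1}; returns estimates of p(x_i = +1)."""
    n = len(b)
    states = (-1, 1)
    nbrs = [[j for j in range(n) if j != i and J[i][j] != 0.0]
            for i in range(n)]
    # mu[(i, j)][x_j] is the message from i to j, initialised to uniform.
    mu = {(i, j): {s: 0.5 for s in states}
          for i in range(n) for j in nbrs[i]}
    for _ in range(n_iters):
        new_mu = {}
        for (i, j) in mu:
            msg = {}
            for xj in states:
                total = 0.0
                for xi in states:
                    term = math.exp(J[i][j] * xi * xj + b[i] * xi)
                    # Multiply in messages arriving at i, excluding the one from j.
                    for k in nbrs[i]:
                        if k != j:
                            term *= mu[(k, i)][xi]
                    total += term
                msg[xj] = total
            norm = msg[-1] + msg[1]  # normalise for numerical stability
            new_mu[(i, j)] = {s: msg[s] / norm for s in states}
        mu = new_mu
    marginals = []
    for i in range(n):
        belief = {}
        for xi in states:
            val = math.exp(b[i] * xi)
            for k in nbrs[i]:
                val *= mu[(k, i)][xi]
            belief[xi] = val
        marginals.append(belief[1] / (belief[-1] + belief[1]))
    return marginals
```

On a tree such as a three-node chain, the estimates match the exact marginals once the messages have propagated across the graph; on graphs with cycles they are only approximations.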
3.3 Graph Neural Networks
Graph Neural Networks (Gori et al., 2005; Scarselli et al., 2009; Li et al., 2016) are recurrent networks with vector-valued nodes whose states are iteratively updated by trainable nonlinear functions that depend on the states of neighbor nodes on a specified graph. The form of these functions is canonical, i.e. shared by all graph edges, but the function can also depend on properties of each edge. The function is parameterized by a neural network whose weights are shared across all edges. Eventually, the states of the nodes are interpreted by another trainable ‘readout’ network. Once trained, the entire GNN can be reused on different graphs without alteration, simply by running it on a different graph with different inputs.
Our work builds on a specific type of GNN, the Gated Graph Neural Networks (GG-NNs) (Li et al., 2016), which add a Gated Recurrent Unit (GRU) (Cho et al., 2014) at each node to integrate incoming information with past states.
Mathematically, each node $v_i$ in the GNN graph is associated with a $D$-dimensional hidden state vector $\mathbf{h}_i^{(t)}$ at time step $t$. We initialize this hidden state to all zeros, but our results do not depend on the initial values. On every successive time step, each node sends a message to each of its neighboring nodes. We define the $P$-dimensional vector-valued message $\mathbf{m}_{i \to j}^{(t+1)}$ from node $v_i$ to $v_j$ at time step $t+1$ by
$$\mathbf{m}_{i \to j}^{(t+1)} = \mathcal{M}\big(\mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)}, \varepsilon_{i \to j}\big) \tag{7}$$
where $\mathcal{M}$ is a message function, here specified by a multilayer perceptron (MLP) with rectified linear units (ReLU). Note that this message function depends on the properties $\varepsilon_{i \to j}$ of each edge $(i \to j)$.
We then aggregate all incoming messages into a single message for the destination node:
$$\mathbf{m}_i^{(t+1)} = \sum_{j \in N_i} \mathbf{m}_{j \to i}^{(t+1)} \tag{8}$$
where $N_i$ denotes the neighbors of a node $v_i$. Finally, every node updates its hidden state based on the current hidden state and the aggregated message:
$$\mathbf{h}_i^{(t+1)} = \mathcal{U}\big(\mathbf{h}_i^{(t)}, \mathbf{m}_i^{(t+1)}\big) \tag{9}$$
where $\mathcal{U}$ is a node update function, in our case specified by another neural network, the gated recurrent unit (GRU), whose parameters are shared across all nodes. The described equations (7, 8, 9) for sending messages and updating node states define a single time step. We evaluate the graph neural network by iterating these equations for a fixed number of time steps $T$ to obtain final state vectors $\mathbf{h}_i^{(T)}$, and then feeding these final node states to a readout function $\mathcal{R}$ given by another MLP with a final sigmoidal nonlinearity $\sigma$:
$$\hat{y}_i = \sigma\big(\mathcal{R}(\mathbf{h}_i^{(T)})\big) \tag{10}$$
We train our GNNs using supervised learning to predict target outputs $\mathbf{y}$, using backpropagation through time to minimize the loss function $L(\mathbf{y}, \hat{\mathbf{y}})$.
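The message, aggregation, update, and readout steps described above can be sketched as an untrained forward pass in plain Python. Everything named here is an assumption for illustration: `TinyGNN`, the random weight scale, the one-hidden-layer message MLP without biases, the standard GRU form, and the single linear readout layer are simplifications, and no training is shown.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def rand_matrix(rng, rows, cols, scale=0.3):
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


class TinyGNN:
    """Untrained GNN with random weights; illustrates the forward pass only."""

    def __init__(self, dim, edge_dim, hidden=8, seed=0):
        rng = random.Random(seed)
        self.dim = dim
        # Message function: [h_src, h_dst, edge features] -> ReLU layer -> message.
        self.W1 = rand_matrix(rng, hidden, 2 * dim + edge_dim)
        self.W2 = rand_matrix(rng, dim, hidden)
        # GRU-style node update acting on [h, aggregated message].
        self.Wz = rand_matrix(rng, dim, 2 * dim)
        self.Wr = rand_matrix(rng, dim, 2 * dim)
        self.Wh = rand_matrix(rng, dim, 2 * dim)
        # Readout: one linear layer plus sigmoid (a simplification of an MLP).
        self.w_out = rand_matrix(rng, 1, dim)[0]

    def message(self, h_src, h_dst, eps):
        hid = [max(0.0, a) for a in matvec(self.W1, h_src + h_dst + eps)]
        return matvec(self.W2, hid)

    def update(self, h, m):
        z = [sigmoid(a) for a in matvec(self.Wz, h + m)]
        r = [sigmoid(a) for a in matvec(self.Wr, h + m)]
        gated = [ri * hi for ri, hi in zip(r, h)]
        cand = [math.tanh(a) for a in matvec(self.Wh, gated + m)]
        return [(1 - zi) * hi + zi * ci for zi, hi, ci in zip(z, h, cand)]

    def forward(self, edges, n_nodes, steps):
        """edges maps each directed pair (i, j) to its edge-feature list."""
        h = [[0.0] * self.dim for _ in range(n_nodes)]  # zero initial states
        for _ in range(steps):
            agg = [[0.0] * self.dim for _ in range(n_nodes)]
            for (i, j), eps in edges.items():
                msg = self.message(h[i], h[j], eps)            # send message
                agg[j] = [a + m for a, m in zip(agg[j], msg)]  # sum at target
            h = [self.update(h[v], agg[v]) for v in range(n_nodes)]
        return [sigmoid(sum(w * x for w, x in zip(self.w_out, hv))) for hv in h]
```

For a pairwise model, one plausible choice of edge features for the directed edge from `i` to `j` is the list `[J[i][j], b[i], b[j]]`; the same weights can then be run on any graph by changing only `edges` and `n_nodes`.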
3.4 Applying Graph Neural Networks to inference in graphical models
Next we apply this general GNN architecture to the task of probabilistic inference in probabilistic graphical models. We investigate two mappings between graphical models and the GNN (Figure 1). Our experiments show that both perform similarly, and much better than belief propagation.
The first mapping conforms most closely to the structure of conventional belief propagation, by using a graph for the GNN that reflects how messages depend on each other in (Eq 6). Each node $v$ in the GNN corresponds to a message $\mu_{ij}$ between nodes $i$ and $j$ in the graphical model. GNN nodes $v$ and $w$ are connected if their corresponding message nodes are $i \to j$ and $j \to k$ (Figure 1b). If they are connected, the message from to is computed by . We then update its hidden state by . The readout to extract node marginals or MAP states first aggregates all GNN nodes with the same target by summation, and then applies a shared readout function, . This representation grows in size with the number of factors in the graphical model.
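The construction of the message graph can be sketched as follows. This is a hypothetical helper, not the authors' code: the name `build_message_graph` is invented here, and the exclusion of the reverse message (node `i -> j` does not feed node `j -> i`) is an assumption borrowed from the structure of the BP update rather than something stated explicitly in the text.

```python
def build_message_graph(edges):
    """Turn undirected model edges into a directed graph over messages.

    edges: iterable of undirected pairs (i, j) from the graphical model.
    Returns (msg_nodes, msg_edges). Each directed message i -> j becomes one
    GNN node, and node (i -> j) feeds node (j -> k) whenever k != i.
    """
    msg_nodes = []
    for i, j in edges:
        msg_nodes.extend([(i, j), (j, i)])  # two directed messages per edge
    msg_edges = [(a, c) for a in msg_nodes for c in msg_nodes
                 if a[1] == c[0] and c[1] != a[0]]
    return msg_nodes, msg_edges
```

For a chain 0 - 1 - 2 this produces four message nodes and two connections, from `(0, 1)` to `(1, 2)` and from `(2, 1)` to `(1, 0)`. Because there are two GNN nodes per model edge, this representation is larger than one with a GNN node per variable.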
The second mapping uses GNN nodes to represent variable nodes in the probabilistic graphical model, and does not provide any hidden states to update the factor nodes (Figure 1c). These factors still influence the inference, since the parameters $J_{ij}$, $b_i$, and $b_j$ are passed into the message function on each iteration (Eq. 7). However, this avoids spending representational power on properties that may not change due to the invariances of tree-based reparameterization. In this mapping, the readout is generated directly from the hidden state of the corresponding GNN node (Eq. 10).
In both mappings, we optimize our networks to minimize the total cross-entropy loss between the exact target ($p_i(x_i)$ for marginals or $\delta_{x_i, x_i^*}$ for MAP) and the GNN estimates $\hat{y}_i$.
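A total cross-entropy loss of this kind can be written as a short Python function. This is a minimal sketch: the name `binary_cross_entropy` and the clipping constant `eps` are assumptions made here for numerical safety, and each target is represented by the probability of the state +1 (a value in [0, 1] for marginals, or exactly 0 or 1 for a MAP state).

```python
import math


def binary_cross_entropy(targets, estimates, eps=1e-12):
    """Cross-entropy between target and estimated Bernoulli distributions,
    summed over nodes. targets[i] and estimates[i] are probabilities of +1."""
    total = 0.0
    for p, q in zip(targets, estimates):
        q = min(max(q, eps), 1.0 - eps)  # avoid log(0)
        total -= p * math.log(q) + (1.0 - p) * math.log(1.0 - q)
    return total
```

For a fixed target, the loss is smallest when the estimate equals the target, which is what makes it a suitable training signal for both the marginal and the MAP task.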
The message functions in both mappings receive external inputs about the couplings between edges, which is necessary for GNNs to infer the correct marginals or MAP state. Most importantly, the message function depends on the hidden states of both source and destination nodes at the previous time step. This added flexibility is suggested by the expectation propagation algorithm (Minka, 2001) where, at each iteration, inference proceeds by first removing the previous estimate from the destination node and then updating based on the source distribution.
4.1 Experimental design
Our experiments test how well graph neural networks trained on a diverse set of small graph structures perform on inference tasks. In each experiment we test two types of GNNs, one representing variable nodes (node-GNN) and the other representing message nodes (msg-GNN). We examine generalization under four conditions (Table 1): to unseen graphs of the same structure (I, II), and to completely contrasting random graphs (III, IV). These graphs may be the same size (I, III) or larger (II, IV). For each condition, we examine performance in estimating marginal probabilities and the MAP state.
More specifically, our GNNs are trained on 100 graphical models for each of classic graphs of size (Figures 2a-b). For each graphical model, we sample coupling strengths from a normal distribution, , and sample biases from . Our simulated data comprise training models, validation models, and test models. All of these graphical models are small enough that ground truth marginals and MAP states can be computed exactly by enumeration.
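Sampling one such graphical model for a given graph structure might look like the sketch below. The function name `sample_binary_mrf` is hypothetical, and the zero means and the arguments `coupling_std` and `bias_std` are placeholders to be set by the user; they are not values taken from the experiments.

```python
import random


def sample_binary_mrf(edges, n, coupling_std, bias_std, seed=None):
    """Draw random couplings and biases for a fixed graph structure.

    edges: undirected pairs (i, j); n: number of variables.
    Assumption: zero-mean normal draws whose standard deviations are
    supplied by the caller (placeholders, not the paper's settings).
    """
    rng = random.Random(seed)
    J = [[0.0] * n for _ in range(n)]
    for i, j in edges:
        J[i][j] = J[j][i] = rng.gauss(0.0, coupling_std)  # symmetric coupling
    b = [rng.gauss(0.0, bias_std) for _ in range(n)]
    return J, b
```

Repeating this for every graph structure in the training set, with different seeds, yields an ensemble of models that share structure but differ in their potentials.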
We train GNNs using ADAM (Kingma & Ba, 2014) with a learning rate of until the validation error saturates: we use early stopping with a window size of . The GNN nodes’ hidden states and messages both have dimensions. In all experiments, messages propagate for time steps. All the MLPs in the message function and readout function have two hidden layers with units each, and use ReLU nonlinearities.
4.2 Within-Set generalization
To understand the properties of our learned GNNs, we evaluate them on graph datasets different from the ones they were trained on. In condition I, test graphs had the same size and structure as training graphs, but the values of singleton and edge potentials differed. We then compared the GNN inferences against the ground truth, as well as against inferences drawn by BP. When tested on acyclic graphs, BP is exact, but our GNNs show impressive accuracy as well (Figures 2c-e). However, as the test graphs became loopier, BP worsened substantially while the GNN inference maintained strong performance (Figures 2c-e).
4.3 Out-of-Set generalization
After training our GNNs on the graph structures in condition I, we froze their parameters, and tested these GNNs on a broader set of graphs.
In condition II (Table 1), we increased the graph size from to variables while retaining the graph structures of the training set. In this scenario, scatter plots of estimated versus true marginals show that the GNN still outperforms BP in all of the loopy graphs, except for the case of graphs with a single loop (Figure 3a). We quantify this performance for BP and the GNNs by the average Kullback-Leibler divergence across the entire set of test graphs with the small and large number of nodes. We find that the performance of BP and both GNNs degrades as the graphs grow. However, except for the msg-GNN tested on nearly fully-connected graphs, the GNNs perform far better than BP, with improvements of more than an order of magnitude for graphs with many loops (Figure 3a–b).
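An average Kullback-Leibler divergence between true and estimated marginals can be computed as in the following sketch. The name `mean_kl_binary` is hypothetical, and the direction of the divergence (true distribution first) and the averaging over nodes are assumptions, since the text does not specify them.

```python
import math


def mean_kl_binary(true_p, est_p, eps=1e-12):
    """Average KL(true || estimate) over nodes for binary marginals.

    true_p[i] and est_p[i] are probabilities that x_i = +1.
    """
    total = 0.0
    for p, q in zip(true_p, est_p):
        q = min(max(q, eps), 1.0 - eps)  # keep the estimate away from 0 and 1
        for a, c in ((p, q), (1.0 - p, 1.0 - q)):
            if a > 0.0:  # terms with zero true probability contribute nothing
                total += a * math.log(a / c)
    return total / len(true_p)
```

The result is zero when the estimates match the true marginals exactly and grows as they diverge, so lower values indicate better inference.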
To investigate how GNNs generalize to networks of a different size and structure, we constructed connected random graphs , also known as Erdős-Rényi graphs (Erdős & Rényi, 1959), and systematically changed the connectivity by increasing the edge probability from (sparse) to (dense) for smaller and larger graphs (Conditions III & IV, Figures 3c–d). Our GNNs clearly outperform BP irrespective of the size and structure of random graphs, although both inference methods show a size- and connectivity-dependent decline in accuracy (Figure 3e).
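One simple way to obtain connected random graphs of this kind is rejection sampling: draw each edge independently with a fixed probability and keep the graph only if it is connected. The sketch below assumes that procedure, which the text does not confirm; the names `is_connected` and `connected_erdos_renyi` and the `max_tries` limit are hypothetical.

```python
import random


def is_connected(n, edges):
    """Depth-first search from node 0 over an undirected edge list."""
    adj = {v: set() for v in range(n)}
    for i, j in edges:
        adj[i].add(j)
        adj[j].add(i)
    seen, stack = {0}, [0]
    while stack:
        for w in adj[stack.pop()]:
            if w not in seen:
                seen.add(w)
                stack.append(w)
    return len(seen) == n


def connected_erdos_renyi(n, q, seed=None, max_tries=10000):
    """Sample graphs with independent edge probability q until one is connected."""
    rng = random.Random(seed)
    for _ in range(max_tries):
        edges = [(i, j) for i in range(n) for j in range(i + 1, n)
                 if rng.random() < q]
        if is_connected(n, edges):
            return edges
    raise RuntimeError("no connected graph found; try a larger q")
```

For small edge probabilities most samples are disconnected, so the number of rejected draws rises sharply as the graphs become sparser.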
4.4 Convergence of inference dynamics
Past work provides some insight into the dynamics and convergence properties of BP (Weiss & Freeman, 2000; Yedidia et al., 2001; Tatikonda & Jordan, 2002). For comparison, we examine how GNN node hidden states change over time, by collecting the distances between successive node states, . Despite some variability, the mean distance decreases with time independently of graph topologies and size, which suggests reasonable convergence of the GNN inferences (Figure 4), although the rate and final precision of convergence vary depending on graph structures.
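A convergence measure of this kind can be computed from two consecutive snapshots of the hidden states. The sketch below assumes the Euclidean distance averaged over nodes; the exact metric is not recoverable from the text, and the name `mean_state_change` is hypothetical.

```python
import math


def mean_state_change(h_prev, h_next):
    """Mean Euclidean distance between successive hidden states.

    h_prev and h_next are lists of per-node state vectors at consecutive
    time steps. The choice of Euclidean distance is an assumption.
    """
    dists = [math.sqrt(sum((a - c) ** 2 for a, c in zip(u, v)))
             for u, v in zip(h_prev, h_next)]
    return sum(dists) / len(dists)
```

Tracking this quantity across time steps gives a curve that should decay towards zero if the GNN dynamics settle to a fixed point.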
4.5 MAP Estimation
We also apply our GNN framework to the task of MAP estimation, using the same graphical models, but now minimizing the cross entropy loss between a delta function on the true MAP target and sigmoidal outputs of GNNs. As in the marginalization experiments, the node-GNN slightly outperformed the msg-GNN computing the MAP state, and both significantly outperform BP (the max-product variant, sometimes called belief revision (Pearl, 1988)) in these generalization tasks (Figure 5).
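For reference, the max-product baseline can be sketched for binary pairwise models as below. This is an illustrative implementation under assumptions made here, not the authors' code: the name `max_product_bp`, the log-domain messages, the synchronous schedule, and the rule that exact ties decode to -1 are all choices of this sketch, and a nonzero `J[i][j]` is assumed to mark an edge.

```python
def max_product_bp(J, b, n_iters=50):
    """Max-product BP in the log domain for x_i in {-1, +1}.

    Returns a list with the decoded state of each variable.
    """
    n = len(b)
    states = (-1, 1)
    nbrs = [[j for j in range(n) if j != i and J[i][j] != 0.0]
            for i in range(n)]
    # mu[(i, j)][x_j] holds the log-message from i to j.
    mu = {(i, j): {s: 0.0 for s in states}
          for i in range(n) for j in nbrs[i]}
    for _ in range(n_iters):
        new_mu = {}
        for (i, j) in mu:
            msg = {}
            for xj in states:
                # Maximise over x_i instead of summing, as in sum-product.
                msg[xj] = max(
                    J[i][j] * xi * xj + b[i] * xi
                    + sum(mu[(k, i)][xi] for k in nbrs[i] if k != j)
                    for xi in states)
            shift = max(msg.values())  # keep messages bounded
            new_mu[(i, j)] = {s: msg[s] - shift for s in states}
        mu = new_mu
    decoded = []
    for i in range(n):
        score = {xi: b[i] * xi + sum(mu[(k, i)][xi] for k in nbrs[i])
                 for xi in states}
        decoded.append(1 if score[1] > score[-1] else -1)
    return decoded
```

On a tree with a unique most probable state, this decoding recovers the exact MAP state; on loopy graphs it carries no such guarantee.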
Our experiments demonstrated that Graph Neural Networks provide a flexible method for learning to perform inference in probabilistic graphical models. We showed that the learned representations and nonlinear transformations operating on the edges of the graphical model do generalize to somewhat larger graphs, even to those with different structure. These results support GNNs as an excellent framework for solving difficult inference tasks.
The reported experiments demonstrated successes on small, binary graphical models. Future experiments will consider training and testing on larger and more diverse graphs, as well as on broader classes of graphical models with non-binary variables and more interesting sufficient statistics for nodes and factors. We expect that as the training set grows larger, the generalization abilities will correspondingly increase, and the resultant algorithm can be evaluated for useful regularities.
The flexibility of the nonlinearities implemented within graph neural networks should allow us to apply these learned inference techniques on a much wider range of problems than the binary-valued pairwise Markov random fields used here. Continuous-valued distributions may benefit especially from this approach, since exact marginals are not available even in many simple cases. Moreover, the exponential family is not even closed under marginalization. Nonetheless, it may be useful to continually approximate the iteratively improving marginals by a target exponential family. This is the approach taken by Expectation Propagation (Minka, 2001), which uses optimization to find the best approximation. However, this is likely to be more computationally burdensome at inference time than learning a canonical neural network to perform that task for a range of inputs.
We examined two possible representations of graphical models within graph neural networks, using variable nodes and message nodes. Interestingly, our experiments do not reveal any benefit of the more expensive representations for each factor node. This was expected based on theoretical arguments from examining invariances of belief propagation, but these invariances are a direct consequence of BP’s assumption of tree graphs, so a richer structure could in principle perform better. One such possible structure could map GNN nodes to factor nodes, similar to the message graph (Figure 1b), but with fewer constraints on information flow.
Currently we are using backpropagation through time to learn the GNN parameters. This is also a computationally expensive procedure, but we are now finding that recurrent backpropagation (Almeida, 1987; Pineda, 1987) can be a more efficient method for training GNNs that avoids the onerous memory and computation requirements.
Belief propagation assumes that all incoming messages are independent, which is a correct assumption for a tree graph. This assumption is still true when BP is run iteratively on all nodes in the graphical model — even though, in such an update schedule, information could flow in loops of length two () as the messages pass backwards and forwards through the tree graph. The BP algorithm compensates for these loops when aggregating messages by excluding messages whose source was the current target node.
The reason that loopy belief propagation is incorrect on graphs with cycles is that information flowing around longer loops is wrongly integrated as if it were independent evidence. The loopy BP messages are not independent, but are correlated through the loop, and thus their evidence is misweighted.
We hypothesize that a major reason our GNNs outperform BP in loopy graphs is that the latent states can store outbound information, and use these memories to properly discount the contributions of incoming messages that originated from the current target node. This function is enabled by the gated recurrent units that synthesize incoming messages in our GNN. This theory predicts that longer memories, and therefore higher-dimensional latent states, are required for graphs with longer loops of equal total strength. Future experiments will test this hypothesis, and thereby guide novel neural network architectures for improved probabilistic inference.
Three main threads of artificial intelligence offer complementary advantages: probabilistic or statistical inference, neural networks, and symbolic reasoning. Combining the strengths of all three may provide the best route forward to general AI. Here we proposed combining probabilistic inference with neural networks: by uniting neural networks’ flexibility in approximating functions with the canonical nonlinear structure of inference problems and the sparsity of direct interactions in graphical models, we obtain better performance in example problems. These and other successes should encourage further exploration.
- Almeida (1987) Almeida, Luis B. A learning rule for asynchronous perceptrons with feedback in a combinatorial environment. In Proceedings, 1st International Conference on Neural Networks, volume 2, pp. 609–618. IEEE, 1987.
- Bruna et al. (2014) Bruna, Joan, Zaremba, Wojciech, Szlam, Arthur, and LeCun, Yann. Spectral networks and locally connected networks on graphs. ICLR, 2014.
- Cho et al. (2014) Cho, Kyunghyun, Van Merriënboer, Bart, Gulcehre, Caglar, Bahdanau, Dzmitry, Bougares, Fethi, Schwenk, Holger, and Bengio, Yoshua. Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078, 2014.
- Deng et al. (2016) Deng, Zhiwei, Vahdat, Arash, Hu, Hexiang, and Mori, Greg. Structure inference machines: Recurrent neural networks for analyzing relations in group activity recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4772–4781, 2016.
- Duvenaud et al. (2015) Duvenaud, David K, Maclaurin, Dougal, Iparraguirre, Jorge, Bombarell, Rafael, Hirzel, Timothy, Aspuru-Guzik, Alán, and Adams, Ryan P. Convolutional networks on graphs for learning molecular fingerprints. In Advances in neural information processing systems, pp. 2224–2232, 2015.
- Erdős & Rényi (1959) Erdős, Paul and Rényi, Alfréd. On random graphs, I. Publicationes Mathematicae (Debrecen), 6:290–297, 1959.
- Gilmer et al. (2017) Gilmer, Justin, Schoenholz, Samuel S, Riley, Patrick F, Vinyals, Oriol, and Dahl, George E. Neural message passing for quantum chemistry. ICML, 2017.
- Gori et al. (2005) Gori, Marco, Monfardini, Gabriele, and Scarselli, Franco. A new model for learning in graph domains. In Neural Networks, 2005. IJCNN’05. Proceedings. 2005 IEEE International Joint Conference on, volume 2, pp. 729–734. IEEE, 2005.
- Heess et al. (2013) Heess, Nicolas, Tarlow, Daniel, and Winn, John. Learning to pass expectation propagation messages. In Advances in Neural Information Processing Systems, pp. 3219–3227, 2013.
- Ihler & McAllester (2009) Ihler, Alexander and McAllester, David. Particle belief propagation. In Artificial Intelligence and Statistics, pp. 256–263, 2009.
- Kingma & Ba (2014) Kingma, Diederik and Ba, Jimmy. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
- Kipf & Welling (2017) Kipf, Thomas N and Welling, Max. Semi-supervised classification with graph convolutional networks. ICLR, 2017.
- Li et al. (2017) Li, Ruiyu, Tapaswi, Makarand, Liao, Renjie, Jia, Jiaya, Urtasun, Raquel, and Fidler, Sanja. Situation recognition with graph neural networks. arXiv preprint arXiv:1708.04320, 2017.
- Li et al. (2016) Li, Yujia, Tarlow, Daniel, Brockschmidt, Marc, and Zemel, Richard. Gated graph sequence neural networks. ICLR, 2016.
- Lin et al. (2015) Lin, Guosheng, Shen, Chunhua, Reid, Ian, and van den Hengel, Anton. Deeply learning the messages in message passing inference. In Advances in Neural Information Processing Systems, pp. 361–369, 2015.
- Marino et al. (2016) Marino, Kenneth, Salakhutdinov, Ruslan, and Gupta, Abhinav. The more you know: Using knowledge graphs for image classification. arXiv preprint arXiv:1612.04844, 2016.
- Minka (2001) Minka, Thomas P. Expectation propagation for approximate Bayesian inference. In Proceedings of the Seventeenth conference on Uncertainty in artificial intelligence, pp. 362–369. Morgan Kaufmann Publishers Inc., 2001.
- Noorshams & Wainwright (2013) Noorshams, Nima and Wainwright, Martin J. Stochastic belief propagation: A low-complexity alternative to the sum-product algorithm. IEEE Transactions on Information Theory, 59(4):1981–2000, 2013.
- Pearl (1988) Pearl, Judea. Probabilistic reasoning in intelligent systems: networks of plausible inference. Morgan Kaufmann, 1988.
- Pineda (1987) Pineda, Fernando J. Generalization of back-propagation to recurrent neural networks. Physical review letters, 59(19):2229, 1987.
- Qi et al. (2017) Qi, Xiaojuan, Liao, Renjie, Jia, Jiaya, Fidler, Sanja, and Urtasun, Raquel. 3D graph neural networks for RGBD semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5199–5208, 2017.
- Ross et al. (2011) Ross, Stephane, Munoz, Daniel, Hebert, Martial, and Bagnell, J Andrew. Learning message-passing inference machines for structured prediction. In Computer Vision and Pattern Recognition (CVPR), 2011 IEEE Conference on, pp. 2737–2744. IEEE, 2011.
- Scarselli et al. (2009) Scarselli, Franco, Gori, Marco, Tsoi, Ah Chung, Hagenbuchner, Markus, and Monfardini, Gabriele. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
- Sudderth et al. (2010) Sudderth, Erik B, Ihler, Alexander T, Isard, Michael, Freeman, William T, and Willsky, Alan S. Nonparametric belief propagation. Communications of the ACM, 53(10):95–103, 2010.
- Tatikonda & Jordan (2002) Tatikonda, Sekhar C and Jordan, Michael I. Loopy belief propagation and Gibbs measures. In Proceedings of the Eighteenth conference on Uncertainty in artificial intelligence, pp. 493–500. Morgan Kaufmann Publishers Inc., 2002.
- Wainwright et al. (2003) Wainwright, Martin J, Jaakkola, Tommi S, and Willsky, Alan S. Tree-based reparameterization framework for analysis of sum-product and related algorithms. IEEE Transactions on information theory, 49(5):1120–1146, 2003.
- Wei et al. (2016) Wei, Shih-En, Ramakrishna, Varun, Kanade, Takeo, and Sheikh, Yaser. Convolutional pose machines. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4724–4732, 2016.
- Weiss & Freeman (2000) Weiss, Yair and Freeman, William T. Correctness of belief propagation in Gaussian graphical models of arbitrary topology. In Advances in neural information processing systems, pp. 673–679, 2000.
- Yedidia et al. (2001) Yedidia, Jonathan S, Freeman, William T, and Weiss, Yair. Generalized belief propagation. In Advances in neural information processing systems, pp. 689–695, 2001.