GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation
This paper presents a new Graph Neural Network (GNN) type using feature-wise linear modulation (FiLM). Many GNN variants propagate information along the edges of a graph by computing “messages” based only on the representation of the source of each edge. In GNN-FiLM, the representation of the target node of an edge is additionally used to compute a transformation that can be applied to all incoming messages, allowing feature-wise modulation of the passed information.
Experiments with GNN-FiLM as well as a number of baselines and related extensions show that it outperforms baseline methods while not being significantly slower.
Cambridge, United Kingdom
Learning from graph-structured data has seen explosive growth over the last few years, as graphs are a convenient formalism to model the broad class of data that has objects (treated as vertices) with some known relationships (treated as edges). This graph construction is a highly complex form of feature engineering, mapping the knowledge of a domain expert into a graph structure which can be consumed and exploited by high-capacity neural network models (Battaglia et al., 2018).
Many neural graph learning methods can be summarised as neural message passing (Gilmer et al., 2017): nodes are initialised with some representation, and then exchange information by transforming their current state (in practice with a single linear layer) and sending it as a message to all neighbours in the graph. At each node, messages are aggregated in some way and then used to update the associated node representation. In this setting, the message is entirely determined by the source node (and potentially the edge type), and the target node is not taken into consideration. A (partial) exception to this is the family of Graph Attention Networks (Veličković et al., 2018), where the agreement between source and target representation of an edge is used to determine the weight of the message in an attention architecture. However, this weight is applied to all dimensions of the message at the same time.
A natural consequence of this observation would be to compute messages from the pair of source and target node states. However, the linear layer commonly used to compute messages would only allow additive interactions between the representations of source and target nodes. While a more complex transformation function is conceivable, its computational cost would make the method impractical, as computation in GNN implementations is dominated by the message transformation function.
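To make the additivity argument concrete: a linear layer applied to the concatenation of source and target states splits into one independent term per node. The NumPy sketch below checks this numerically; the dimension and the random weights are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4  # illustrative representation size

h_src = rng.normal(size=D)        # source node state
h_tgt = rng.normal(size=D)        # target node state
W = rng.normal(size=(D, 2 * D))   # message layer acting on [h_src; h_tgt]

# Message computed from the concatenated pair of states.
msg_concat = W @ np.concatenate([h_src, h_tgt])

# The same message written as a sum of one term per node: W = [W_src | W_tgt].
W_src, W_tgt = W[:, :D], W[:, D:]
msg_split = W_src @ h_src + W_tgt @ h_tgt

# The target only adds a term that does not depend on h_src, so it cannot
# rescale or gate individual features of the source's contribution.
print(np.allclose(msg_concat, msg_split))  # True
```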
However, this need for non-trivial interaction between different information sources is a common problem in neural network design. A recent trend has been the use of hypernetworks, neural nets that compute the weights of other networks. In this setting, interaction between two signal sources is achieved by using one of them as the input to a hypernetwork, and the other as input to the computed network. While an intellectually pleasing approach, it is often impractical because the prediction of weights of non-trivial neural networks is computationally expensive. Approaches to mitigate this exist (e.g., Wu et al. (2019) handle this in natural language processing), but are often domain-specific.
A more general mitigation method is to restrict the structure of the computed network. Recently, “feature-wise linear modulation” (FiLM) was introduced in the visual question answering domain (Perez et al., 2017). Here, the hypernetwork is fed with an encoding of a question and produces an element-wise affine function that is applied to the features extracted from a picture. This compromise between expressiveness and computational feasibility has been very effective in other domains, and this article shows that it also succeeds on graph tasks.
This article explores the use of hypernetworks in learning on graphs. Sect. 2 first reviews existing GNN models to identify commonalities and differences and then introduces two new formalisms: Relational Graph Dynamic Convolutional Networks (RGDCN), which dynamically compute the neural message passing function as a linear layer, and Graph Neural Networks with Feature-wise Linear Modulation (GNN-FiLM), which combine learned message passing functions with dynamically computed elementwise affine transformations. In experiments on tasks from the literature (Sect. 3), only the FiLM-based variant performs well, improving on a range of strong baselines on three tasks.
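Let $\mathcal{R}$ be a finite (usually small) set of edge types. Then, a graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ has nodes $\mathcal{V}$ and typed edges $\mathcal{E} \subseteq \mathcal{V} \times \mathcal{R} \times \mathcal{V}$, where $(u, \ell, v) \in \mathcal{E}$ denotes an edge from node $u$ to node $v$ of type $\ell$, usually written as $u \xrightarrow{\ell} v$.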
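One common way to store such a typed edge set is one adjacency list per edge type, so that message passing can loop over types and apply type-specific weights. The sketch below assumes integer node identifiers and small integer edge types; the helper name `group_edges_by_type` is hypothetical.

```python
from collections import defaultdict

def group_edges_by_type(edges):
    """Group typed edges (u, l, v), meaning u --l--> v, into per-type lists.

    Returns a dict mapping each edge type l to a list of (source, target) pairs.
    """
    by_type = defaultdict(list)
    for src, edge_type, tgt in edges:
        by_type[edge_type].append((src, tgt))
    return dict(by_type)

# Toy graph with three nodes and two edge types (0 and 1).
edges = [(0, 1, 1), (1, 1, 2), (2, 0, 2), (0, 1, 2)]
adjacency = group_edges_by_type(edges)
print(adjacency)  # {1: [(0, 1), (1, 2), (0, 2)], 0: [(2, 2)]}
```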
Graph Neural Networks.
As discussed above, Graph Neural Networks (GNNs) operate by propagating information along the edges of a given graph. Concretely, each node $v$ is associated with an initial representation $\boldsymbol{h}_v^{(0)}$ (for example obtained from the label of that node, or by some other model component). Then, a GNN layer updates the node representations using the node representations of its neighbours in the graph, yielding representations $\boldsymbol{h}_v^{(1)}$. This process can be unrolled through time by repeatedly applying the same update function, yielding representations $\boldsymbol{h}_v^{(2)}, \ldots, \boldsymbol{h}_v^{(T)}$. Alternatively, different layers of such a GNN update mechanism can be stacked, which is intuitively similar to unrolling through time, but increases the GNN capacity by using different parameters for each timestep.
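In Gated Graph Neural Networks (GGNN) (Li et al., 2016), the update rule uses the representation of the node at an earlier time step in a recurrent unit (e.g., GRU or LSTM cells), yielding the following definition:

$$\boldsymbol{h}_v^{(t+1)} = \mathrm{cell}\Big(\boldsymbol{h}_v^{(t)}, \sum_{u \xrightarrow{\ell} v \in \mathcal{E}} \boldsymbol{W}_\ell \, \boldsymbol{h}_u^{(t)}\Big)$$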
The learnable parameters of the model are the edge-type-dependent weights $\boldsymbol{W}_\ell$ and the recurrent cell parameters $\boldsymbol{\theta}_{\mathrm{cell}}$.
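A single GGNN propagation step can be sketched in NumPy as follows, assuming a GRU as the recurrent cell. The GRU below is one common formulation with biases omitted, edges are processed in a plain Python loop rather than a vectorised scatter, and the function names are hypothetical illustrations rather than the released implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell(h, x, params):
    """One common GRU formulation (biases omitted): h is the previous state,
    x the aggregated incoming message."""
    W_z, U_z, W_r, U_r, W_n, U_n = params
    z = sigmoid(W_z @ x + U_z @ h)          # update gate
    r = sigmoid(W_r @ x + U_r @ h)          # reset gate
    n = np.tanh(W_n @ x + U_n @ (r * h))    # candidate state
    return (1.0 - z) * h + z * n

def ggnn_step(H, edges, W, cell_params):
    """One GGNN propagation step.

    H: (num_nodes, D) node states; edges: list of (u, l, v) meaning u --l--> v;
    W: dict mapping each edge type l to a (D, D) matrix W_l.
    """
    messages = np.zeros_like(H)
    for u, l, v in edges:
        messages[v] += W[l] @ H[u]          # sum of W_l h_u over incoming edges
    return np.stack([gru_cell(H[v], messages[v], cell_params)
                     for v in range(H.shape[0])])

rng = np.random.default_rng(0)
D, num_nodes = 8, 3
H0 = rng.normal(size=(num_nodes, D))
edges = [(0, 0, 1), (1, 0, 2), (2, 1, 0)]
W = {l: 0.1 * rng.normal(size=(D, D)) for l in (0, 1)}
cell_params = [0.1 * rng.normal(size=(D, D)) for _ in range(6)]
H1 = ggnn_step(H0, edges, W, cell_params)
print(H1.shape)  # (3, 8)
```

Unrolling through time reuses the same `W` and `cell_params` in repeated calls; stacking layers would pass a fresh set per call.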
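In Relational Graph Convolutional Networks (R-GCN) (Schlichtkrull et al., 2018), the gated unit is replaced by a simple non-linearity $\sigma$ (e.g., the hyperbolic tangent):

$$\boldsymbol{h}_v^{(t+1)} = \sigma\Big(\sum_{u \xrightarrow{\ell} v \in \mathcal{E}} \frac{1}{c_{v,\ell}} \, \boldsymbol{W}_\ell \, \boldsymbol{h}_u^{(t)}\Big)$$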
Here, $c_{v,\ell}$ is a normalisation constant usually chosen as the number of edges of type $\ell$ ending in $v$. The learnable parameters of the model are the edge-type-dependent weights $\boldsymbol{W}_\ell$. It is important to note that in this setting, the edge type set $\mathcal{R}$ is assumed to contain a special edge type $0$ for self-loops $v \xrightarrow{0} v$, allowing state associated with a node to be kept.
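The R-GCN step, including the per-type normalisation $c_{v,\ell}$ and the self-loop edge type, can be sketched as below. This assumes edge type 0 is reserved for self-loops and uses hypothetical helper names; the deterministic toy example uses identity weights and an identity activation so the arithmetic is easy to follow.

```python
import numpy as np
from collections import Counter

def rgcn_step(H, edges, W, activation=np.tanh):
    """One R-GCN propagation step.

    H: (num_nodes, D) node states; edges: list of (u, l, v) meaning u --l--> v,
    which should include self-loops (v, 0, v) so a node can keep its own state;
    W: dict mapping edge type l to a (D, D) weight matrix W_l.
    """
    # c_{v,l}: number of incoming edges of type l at node v.
    counts = Counter((v, l) for _, l, v in edges)
    pre_activation = np.zeros_like(H)
    for u, l, v in edges:
        pre_activation[v] += (W[l] @ H[u]) / counts[(v, l)]
    return activation(pre_activation)

def add_self_loops(edges, num_nodes, self_loop_type=0):
    """Append one self-loop edge of the reserved type per node."""
    return edges + [(v, self_loop_type, v) for v in range(num_nodes)]

H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edges = add_self_loops([(0, 1, 2), (1, 1, 2)], num_nodes=3)
W = {0: np.eye(2), 1: np.eye(2)}
out = rgcn_step(H, edges, W, activation=lambda x: x)
print(out[2])  # [1.5 1.5]
```

Node 2 averages its two type-1 neighbours to `[0.5, 0.5]` and adds its own state via the self-loop, giving `[1.5, 1.5]`.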
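In Graph Attention Networks (GAT) (Veličković et al., 2018), new node representations are computed from a weighted sum of neighbouring node representations as follows (slightly generalised from the original definition to support different edge types):

$$\boldsymbol{h}_v^{(t+1)} = \sigma\Big(\sum_{u \xrightarrow{\ell} v \in \mathcal{E}} \alpha_{u \xrightarrow{\ell} v} \, \boldsymbol{W}_\ell \, \boldsymbol{h}_u^{(t)}\Big), \qquad \alpha_{u \xrightarrow{\ell} v} = \frac{\exp\big(\boldsymbol{a}_\ell \, (\boldsymbol{W}_\ell \boldsymbol{h}_u^{(t)} \,\|\, \boldsymbol{W}_\ell \boldsymbol{h}_v^{(t)})\big)}{\sum_{u' \xrightarrow{\ell'} v \in \mathcal{E}} \exp\big(\boldsymbol{a}_{\ell'} \, (\boldsymbol{W}_{\ell'} \boldsymbol{h}_{u'}^{(t)} \,\|\, \boldsymbol{W}_{\ell'} \boldsymbol{h}_v^{(t)})\big)}$$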
Here, $\boldsymbol{a}_\ell$ is a learnable row vector used to weight different feature dimensions in the computation of an attention (“relevance”) score of the node representations, $\boldsymbol{x} \,\|\, \boldsymbol{y}$ is the concatenation of vectors $\boldsymbol{x}$ and $\boldsymbol{y}$, and $\alpha_{u \xrightarrow{\ell} v}$ refers to the weight computed by the softmax for that edge. The learnable parameters of the model are the edge-type-dependent weights $\boldsymbol{W}_\ell$ and the attention parameters $\boldsymbol{a}_\ell$. In practice, GATs usually employ several attention heads that independently implement the mechanism above in parallel, using different learnable parameters. The results of the different attention heads are then concatenated after each propagation round to yield the value of $\boldsymbol{h}_v^{(t+1)}$.
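A single attention head of this typed-edge GAT can be sketched as follows, with the softmax taken over all edges ending in the same target node regardless of their type. The original GAT additionally passes each score through a LeakyReLU before the softmax; that step is omitted here to follow the definition above. Function and argument names are hypothetical.

```python
import numpy as np

def gat_step(H, edges, W, a, activation=np.tanh):
    """One single-head GAT step, generalised to typed edges.

    edges: list of (u, l, v) meaning u --l--> v (include self-loops if desired);
    W: dict l -> (D, D) weights W_l; a: dict l -> (2D,) attention row vector a_l.
    """
    # Unnormalised score a_l . (W_l h_u || W_l h_v) for every edge.
    scores = np.array([a[l] @ np.concatenate([W[l] @ H[u], W[l] @ H[v]])
                       for u, l, v in edges])
    out = np.zeros_like(H)
    for v in range(H.shape[0]):
        incoming = [i for i, (_, _, tgt) in enumerate(edges) if tgt == v]
        if not incoming:
            continue
        s = scores[incoming]
        alpha = np.exp(s - s.max())
        alpha /= alpha.sum()                 # softmax over all edges ending in v
        for weight, i in zip(alpha, incoming):
            u, l, _ = edges[i]
            out[v] += weight * (W[l] @ H[u])
    return activation(out)

H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edges = [(0, 0, 2), (1, 0, 2)]
W = {0: np.eye(2)}
out = gat_step(H, edges, W, {0: np.zeros(4)}, activation=lambda x: x)
print(out[2])  # [0.5 0.5]
```

With an all-zero attention vector every incoming edge gets the same weight, so node 2 receives the plain mean of its neighbours. Multiple heads would call `gat_step` with independent `W`, `a` and concatenate the results along the feature axis.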
While many more GNN variants exist, the three formalisms above are broadly representative of general trends. Notably, in all of these models, the information passed from one node to another is based on the learned weights and the representation of the source of an edge. The representation of an edge's target is only updated (in the GGNN case), treated as another incoming signal (in the R-GCN case, via the self-loop edge type), or used to weight the relevance of an edge (in the GAT case).
2.1 Graph Hypernetworks
Hypernetworks (i.e., neural networks computing the parameters of another neural network) have been successfully applied to a number of different tasks, naturally raising the question of their applicability in the graph domain. Intuitively, a hypernetwork corresponds to a higher-order function, i.e., it can be viewed as a function computing another function. Hence, a natural idea would be to use the target of a message propagation step to compute the function computing the message, essentially allowing it to focus on features that are especially relevant for the update of the target node representation.
Relational Graph Dynamic Convolutional Networks (RGDCN)
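A first attempt would be to adapt the R-GCN update rule to replace the learnable message transformation $\boldsymbol{W}_\ell$ by the result of some learnable function $f$ that operates on the target representation:

$$\boldsymbol{h}_v^{(t+1)} = \sigma\Big(\sum_{u \xrightarrow{\ell} v \in \mathcal{E}} f(\boldsymbol{h}_v^{(t)}; \boldsymbol{\theta}_{f,\ell}) \, \boldsymbol{h}_u^{(t)}\Big)$$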
However, for a representation size $D$, $f$ would need to produce a matrix of size $D \times D$ from $D$ inputs. Hence, if implemented as a simple linear layer, $f$ would have on the order of $D^3$ parameters, quickly making it impractical in most contexts.
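The gap between the two parameter counts is easy to quantify. The sketch below counts parameters per edge type, ignoring biases, and uses $D = 256$ purely as an example because that is the representation size used for the PPI experiments later on.

```python
def rgcn_weight_params(D):
    # One learned D x D matrix W_l per edge type.
    return D * D

def rgdcn_hypernet_params(D):
    # A linear layer mapping D inputs to D * D outputs (bias ignored).
    return D * (D * D)

D = 256
print(rgcn_weight_params(D))     # 65536
print(rgdcn_hypernet_params(D))  # 16777216
```

At this size the hypernetwork needs 256 times as many parameters as the fixed weight matrix it replaces, for every edge type.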
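This can be somewhat mitigated by splitting the node representations $\boldsymbol{h}_v^{(t)}$ into $C$ “chunks” $\boldsymbol{h}_{v,c}^{(t)}$ of dimension $K = D / C$:

$$\boldsymbol{h}_v^{(t+1)} = \sigma\Big(\big\Vert_{1 \le c \le C} \sum_{u \xrightarrow{\ell} v \in \mathcal{E}} f(\boldsymbol{h}_v^{(t)}; \boldsymbol{\theta}_{f,\ell,c}) \, \boldsymbol{h}_{u,c}^{(t)}\Big)$$

Here, $\Vert$ denotes concatenation of the $C$ per-chunk results, and each $f(\cdot\,; \boldsymbol{\theta}_{f,\ell,c})$ produces a $K \times K$ matrix.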
The number of parameters of the model can now be reduced by tying the value of some instances of $\boldsymbol{\theta}_{f,\ell,c}$. For example, the update function for a chunk $c$ can be computed using only the corresponding chunk $\boldsymbol{h}_{v,c}^{(t)}$ of the node representation, or the same update function can be applied to all “chunks” by setting $\boldsymbol{\theta}_{f,\ell,1} = \cdots = \boldsymbol{\theta}_{f,\ell,C}$. The learnable parameters of the model are only the hypernetwork parameters $\boldsymbol{\theta}_{f,\ell,c}$. This is somewhat less desirable than the related idea of Wu et al. (2019), which operates on sequences, where sharing between neighbouring elements of the sequence has an intuitive interpretation that is not applicable in the general graph setting.
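A chunked RGDCN step can be sketched as below, assuming the chunk-local variant in which the $K \times K$ message matrix for chunk $c$ is computed by a linear hypernetwork from chunk $c$ of the target state only. Tying across chunks is shown by reusing one weight array for every chunk. Names and sizes are hypothetical.

```python
import numpy as np

def rgdcn_step(H, edges, theta, num_chunks, activation=np.tanh):
    """One chunked RGDCN step (a sketch).

    H: (num_nodes, D) states, D divisible by num_chunks; K = D // num_chunks.
    theta: dict (l, c) -> (K * K, K) hypernetwork weights, which map the c-th
    chunk of the target state to a K x K message matrix for chunk c.
    """
    num_nodes, D = H.shape
    K = D // num_chunks
    out = np.zeros_like(H)
    for u, l, v in edges:
        for c in range(num_chunks):
            chunk = slice(c * K, (c + 1) * K)
            # Message matrix computed from the target node's chunk.
            M = (theta[(l, c)] @ H[v, chunk]).reshape(K, K)
            out[v, chunk] += M @ H[u, chunk]
    return activation(out)

rng = np.random.default_rng(0)
D, C = 8, 4
K = D // C
shared = 0.1 * rng.normal(size=(K * K, K))
theta = {(l, c): shared for l in (0, 1) for c in range(C)}  # tied across chunks
H = rng.normal(size=(3, D))
out = rgdcn_step(H, [(0, 1, 2), (2, 0, 2)], theta, num_chunks=C)
print(out.shape)       # (3, 8)
print(C * K**3, D**3)  # 32 512
```

The last line compares, per edge type, the untied chunk-local hypernetworks ($C K^3$ parameters) with the unchunked one ($D^3$); tying all chunks reduces this further to $K^3$.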
Graph Neural Networks with Feature-wise Linear Modulation (GNN-FiLM)
In RGDCN, the message passing layer is a linear transformation conditioned on the target node representation, focusing on separate chunks of the node representation at a time. In the extreme case in which the dimension of each chunk is 1, this method coincides with the ideas of Perez et al. (2017), who propose to use layers of element-wise affine transformations to modulate feature maps in the visual question answering setting; there, a natural language question is the input used to compute the affine transformation applied to the features extracted from a picture.
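In the graph setting, we can use each node’s representation as an input that determines an element-wise affine transformation of incoming messages, allowing the model to dynamically up- and down-weight features based on the information present at the target node of an edge. This yields the following update rule, where a learnable function $g$ computes the modulation parameters $\boldsymbol{\beta}$ and $\boldsymbol{\gamma}$, and $\odot$ denotes element-wise multiplication:

$$\boldsymbol{\beta}_{\ell,v}^{(t)}, \boldsymbol{\gamma}_{\ell,v}^{(t)} = g(\boldsymbol{h}_v^{(t)}; \boldsymbol{\theta}_{g,\ell})$$

$$\boldsymbol{h}_v^{(t+1)} = \sigma\Big(\sum_{u \xrightarrow{\ell} v \in \mathcal{E}} \boldsymbol{\gamma}_{\ell,v}^{(t)} \odot \boldsymbol{W}_\ell \, \boldsymbol{h}_u^{(t)} + \boldsymbol{\beta}_{\ell,v}^{(t)}\Big)$$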
The learnable parameters of the model are both the hypernetwork parameters $\boldsymbol{\theta}_{g,\ell}$ and the weights $\boldsymbol{W}_\ell$. In practice, implementing $g$ as a single linear layer works well.
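A minimal NumPy sketch of this GNN-FiLM step, assuming $g$ is a single bias-free linear layer per edge type whose output is split into $[\boldsymbol{\beta}; \boldsymbol{\gamma}]$, and that $\boldsymbol{\beta}$ is added once per incoming edge. The function name and the toy weights are hypothetical.

```python
import numpy as np

def film_step(H, edges, W, G, activation=np.tanh):
    """One GNN-FiLM step with the non-linearity applied after aggregation.

    H: (num_nodes, D); edges: list of (u, l, v) meaning u --l--> v;
    W: dict l -> (D, D) message weights W_l;
    G: dict l -> (2D, D) weights of a linear hypernetwork g_l that maps the
       target state h_v to the concatenation [beta; gamma].
    """
    D = H.shape[1]
    out = np.zeros_like(H)
    for u, l, v in edges:
        film = G[l] @ H[v]
        beta, gamma = film[:D], film[D:]
        # Feature-wise modulation of the message by the target node.
        out[v] += gamma * (W[l] @ H[u]) + beta
    return activation(out)

H = np.array([[1.0, 2.0], [3.0, 4.0]])
W = {0: np.eye(2)}
# beta = 0 and gamma = h_v, so each feature of the message is scaled by the
# corresponding feature of the target node.
G = {0: np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])}
out = film_step(H, [(0, 0, 1)], W, G, activation=lambda x: x)
print(out[1])  # [3. 8.]
```

Compared with RGDCN, the hypernetwork only emits $2D$ numbers per edge type and target instead of a full matrix, which keeps the extra cost close to that of one more linear layer.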
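Finally, a small implementation bug brought focus to the fact that applying the non-linearity after summing up messages from neighbouring nodes can make it harder to perform tasks such as counting the number of neighbours with a certain feature. In experiments, applying the non-linearity before aggregation as in the following update rule improved performance:

$$\boldsymbol{h}_v^{(t+1)} = \sum_{u \xrightarrow{\ell} v \in \mathcal{E}} \sigma\Big(\boldsymbol{\gamma}_{\ell,v}^{(t)} \odot \boldsymbol{W}_\ell \, \boldsymbol{h}_u^{(t)} + \boldsymbol{\beta}_{\ell,v}^{(t)}\Big)$$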
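The counting argument can be illustrated with a toy neighbourhood, assuming a saturating non-linearity such as the sigmoid and a hand-picked affine map that sends a neighbour with the feature to $+5$ and one without it to $-5$. With the non-linearity before the sum, each neighbour contributes roughly 0 or 1, so the output tracks the count; with it after the sum, all neighbourhoods with a few matching neighbours are squashed to nearly 1.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative affine map: a neighbour with the feature (x = 1) yields a
# pre-activation of +5, one without it (x = 0) yields -5.
SCALE, SHIFT = 10.0, -5.0

def nonlinearity_after_sum(neighbour_features):
    # sigma(sum_u m_u): a single squashed value for the whole neighbourhood.
    return sigmoid(np.sum(SCALE * neighbour_features + SHIFT))

def nonlinearity_before_sum(neighbour_features):
    # sum_u sigma(m_u): each neighbour contributes roughly 1 or roughly 0.
    return np.sum(sigmoid(SCALE * neighbour_features + SHIFT))

for n_with_feature in range(1, 6):
    # n_with_feature matching neighbours plus two non-matching ones.
    feats = np.array([1.0] * n_with_feature + [0.0] * 2)
    print(n_with_feature,
          round(float(nonlinearity_after_sum(feats)), 4),
          round(float(nonlinearity_before_sum(feats)), 2))
```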
However, this means that the magnitude of node representations is now dependent on the degree of nodes in the handled graph. This can sometimes lead to instability during training, which can be controlled by adding an additional layer after message passing: a simple non-linearity, a fully connected layer, layer normalisation (Ba et al., 2016), or any combination of these.
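The degree dependence and its fix can be seen in a few lines: summing one bounded message per neighbour makes the aggregated norm grow with degree, while a layer normalisation afterwards maps every node back to the same scale. This sketch omits the learned gain and bias of Ba et al. (2016), uses tanh as a stand-in bounded non-linearity, and uses random messages purely for illustration.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalise a vector to zero mean and unit variance (no learned gain/bias)."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

rng = np.random.default_rng(0)
D = 16
for degree in (2, 20, 200):
    # Non-linearity before aggregation: one bounded message per neighbour.
    messages = np.tanh(rng.normal(size=(degree, D)))
    aggregated = messages.sum(axis=0)
    normalised = layer_norm(aggregated)
    print(degree, np.linalg.norm(aggregated).round(2), np.linalg.norm(normalised).round(2))
```

After normalisation the norm is close to $\sqrt{D} = 4$ for every degree.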
3.1 GNN Benchmark Tasks
Due to the versatile nature of the GNN modeling formalism, many fundamentally different tasks are studied in the research area, and it should be noted that good results on one task often do not transfer over to other tasks.
This is due to the widely varying requirements of different tasks, as the following summarization of tasks from the literature should illustrate.
Cora/Citeseer/Pubmed (Sen et al., 2008): Each task consists of a single graph whose nodes correspond to documents and whose undirected (sic!) edges correspond to references. The sparse node features are a bag-of-words representation of the corresponding documents. The goal is to assign a subset of nodes to a small number of classes. State-of-the-art performance on these tasks is achieved with two propagation steps along graph edges.
PPI (Zitnik & Leskovec, 2017): A protein-protein interaction dataset consisting of 24 graphs, each corresponding to a different human tissue. Each node has 50 features selected by domain experts, and the goal is node-level classification, where each node may belong to several of the 121 classes. State-of-the-art performance on this task requires three propagation steps.
QM9 property prediction (Ramakrishnan et al., 2014): Graphs represent molecules, where nodes are heavy atoms and undirected, typed edges are bonds between these atoms, with different edge types indicating single/double/etc. bonds. The goal is to regress each graph to a number of quantum chemical properties. State-of-the-art performance on these tasks requires at least four propagation steps.
VarMisuse (Allamanis et al., 2018): Graphs represent program fragments, where nodes are tokens in the program text and different edge types represent the program’s abstract syntax tree, data flow between variables, etc. The goal is to select one of a set of candidate nodes per graph. State-of-the-art performance requires at least 6-8 propagation steps.
Hence, tasks differ in the complexity of edges (from undirected and untyped to directed and many-typed), the size of the considered graphs, the size of the dataset, the importance of node-level vs. graph-level representations, and the number of required propagation steps.
This article reports results on the PPI and QM9 tasks, as well as preliminary results on the VarMisuse task. Preliminary experiments on the citation network data showed results that were at best comparable to the baseline methods, but changing the random seed already led to substantial fluctuations (matching the problems with evaluation on these tasks reported by Shchur et al. (2018)).
To allow for a wider comparison, the implementation of GNN-FiLM is accompanied by TensorFlow implementations of a range of baseline methods: GGNN (Li et al., 2016), R-GCN (Schlichtkrull et al., 2018), and GAT (Veličković et al., 2018). (Here, GAT was generalised to several edge types as described in Sect. 2.) The re-implemented baseline methods were individually tested to reach performance equivalent to results reported in their respective source papers. All code for the implementation of these GNNs is released at https://github.com/Microsoft/tf-gnn-samples, together with implementations of all tasks and scripts necessary to reproduce the results reported in this paper. This includes the hyperparameter settings found by search, which are stored in tasks/default_hypers/ and selected by default on the respective tasks. The code is structured such that new GNN types can easily be tested on existing tasks, and such that new tasks can easily be added.
Early in the experiments, it became clear that the RGDCN approach as presented is infeasible. It is extremely sensitive to the parameter initialisation, and changes to the random seed would lead to wild swings in the target metrics. Hence, no experimental results are reported for it in the following. It is nonetheless included in the article (and the implementation) to show the thought process leading to GNN-FiLM, as well as to allow other researchers to build upon this. In the following, GNN-FiLM refers to the formulation that applies the non-linearity before message aggregation, which performed better than the variant applying it after aggregation across all experiments. Somewhat surprisingly, the same trick (of moving the non-linearity before the message aggregation step) did not help the other GNN types.
For most experiments, the time taken to train a model to convergence is reported, where a model is considered converged once its loss on the validation set has not improved for 25 epochs. As task-specific code is shared between all models, this runtime is entirely determined by the throughput and convergence properties of the evaluated models.
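The convergence criterion is patience-based early stopping. A minimal sketch, assuming hypothetical `train_epoch` and `validation_loss` callbacks; in a real training loop one would also checkpoint the weights at the best epoch.

```python
def train_until_converged(train_epoch, validation_loss, patience=25, max_epochs=10_000):
    """Train until the validation loss has not improved for `patience` epochs.

    train_epoch(): runs one epoch of training (hypothetical callback).
    validation_loss(): returns the current validation loss (hypothetical callback).
    Returns (best_loss, best_epoch, epochs_run).
    """
    best_loss, best_epoch = float("inf"), -1
    for epoch in range(max_epochs):
        train_epoch()
        loss = validation_loss()
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_loss, best_epoch, epoch + 1
    return best_loss, best_epoch, max_epochs

# Simulated validation losses: improvement stops after the third epoch.
losses = iter([5.0, 4.0, 3.0] + [3.5] * 100)
result = train_until_converged(lambda: None, lambda: next(losses), patience=25)
print(result)  # (3.0, 2, 28)
```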
3.3 Experimental Results
The models are first evaluated on the node-level classification PPI task (Zitnik & Leskovec, 2017), following the dataset split from earlier papers. Training hence used a set of 20 graphs, and validation and test sets of 2 separate graphs each. The graphs use two edge types: the dataset-provided untyped edges as well as a fresh “self-loop” edge type to allow nodes to keep state across propagation steps.
Hyperparameters for all models were selected based on results from earlier papers and some manual exploration on validation set results; this led to 3 layers (propagation steps) and a node representation size of 256. After selecting hyperparameters, all models were trained ten times with different random seeds on an NVidia V100.
Tab. 1 reports the micro-averaged F1 score on the classification task on the test graphs, with standard deviations and training times in seconds computed over the ten runs.
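Since PPI is a multi-label task (each node may belong to several of the 121 classes), micro-averaging pools true positives, false positives and false negatives over all nodes and classes before computing F1. A minimal sketch of that computation; the function name is illustrative.

```python
import numpy as np

def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for multi-label classification.

    y_true, y_pred: binary arrays of shape (num_nodes, num_classes).
    Counts are pooled over all nodes and classes before computing F1.
    """
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    # F1 = 2PR / (P + R) simplifies to 2TP / (2TP + FP + FN).
    return 2 * tp / (2 * tp + fp + fn) if tp + fp + fn > 0 else 1.0

y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 1, 0], [0, 1, 0]]
print(round(micro_f1(y_true, y_pred), 4))  # 0.6667
```

scikit-learn's `f1_score(y_true, y_pred, average="micro")` computes the same quantity for such indicator matrices.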
|Model||Avg. Micro-F1||Time (s)|
The results for GAT are slightly better than the results reported by Veličković et al. (2018), which cursory experiments suggest is due to the generalisation to different edge types (cf. Sect. 2) and the subsequent use of a special self-loop edge type. However, the new GNN-FiLM improves slightly on this, while also converging to a trained model much more quickly than the GAT models.
All models were additionally evaluated on the graph-level regression task on the QM9 molecule data set (Ramakrishnan et al., 2014), considering 13 different quantum chemical properties. The molecular graphs in the dataset were split into training, validation and test data by randomly selecting 10 000 graphs for the latter two sets. Additionally, another data split without a test set was used for the hyperparameter search (see below). The graphs use 5 edge types: the dataset-provided typed edges (single, double, triple and aromatic bonds between atoms) as well as a fresh “self-loop” edge type that allows nodes to keep state across propagation steps. The evaluation differs from the setting reported by Gilmer et al. (2017), as no additional molecular information is encoded as edge features, nor are the graphs augmented by master nodes or additional virtual edges. (Adding these features is straightforward, but orthogonal to the comparison of different GNN variants.)
Hyperparameters for all models were found using a staged search process. First, random search (in an author-provided, model-specific range of hyperparameters) with 500 trials was performed on the first three regression tasks. The top three configurations for each of these three tasks were then run on all 13 tasks, and the final configuration was chosen as the one with the lowest average mean absolute error across all properties, as evaluated on the validation data of that dataset split. This process led to eight layers for all models but GGNN, which showed best performance with six layers. Furthermore, all models used residual connections connecting every second layer, and GGNN and R-GCN additionally used layer normalisation (as discussed for GNN-FiLM in Sect. 2).
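Residual connections every second layer can be wired up as in the sketch below. Where exactly the skip connections attach is not spelled out in the text; this sketch assumes the input to each pair of layers is added to that pair's output, and the function name is hypothetical.

```python
import numpy as np

def stack_with_residuals(H, layers, every=2):
    """Apply GNN layers in sequence with a skip connection every `every` layers.

    layers: list of callables mapping (num_nodes, D) -> (num_nodes, D).
    The input to each group of `every` layers is added to that group's output.
    """
    residual = H
    for i, layer in enumerate(layers):
        H = layer(H)
        if (i + 1) % every == 0:
            H = H + residual
            residual = H
    return H

# Toy "layers" that double their input: 1 -> 2 -> 4 (+1 = 5) -> 10 -> 20 (+5 = 25).
out = stack_with_residuals(np.ones((3, 4)), [lambda x: 2 * x] * 4)
print(out[0, 0])  # 25.0
```

In the experiments described above, each callable would be one GNN propagation step with its own parameters (eight layers for most models).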
|mu||(; min)||(; min)||(; min)||(; min)|
|alpha||(; min)||(; min)||(; min)||(; min)|
|HOMO||(; min)||(; min)||(; min)||(; min)|
|LUMO||(; min)||(; min)||(; min)||(; min)|
|gap||(; min)||(; min)||(; min)||(; min)|
|R2||(; min)||(; min)||(; min)||(; min)|
|ZPVE||(; min)||(; min)||(; min)||(; min)|
|U0||(; min)||(; min)||(; min)||(; min)|
|U||(; min)||(; min)||(; min)||(; min)|
|H||(; min)||(; min)||(; min)||(; min)|
|G||(; min)||(; min)||(; min)||(; min)|
|Cv||(; min)||(; min)||(; min)||(; min)|
|Omega||(; min)||(; min)||(; min)||(; min)|
Each model was trained for each of the properties separately five times using different random seeds on compute nodes with NVidia P100 cards. The average results of the five runs are reported in Tab. 2, with their respective standard deviations and the average time in minutes that a training run needed. (Note that some runs diverged on the R2 task, as visible in the large standard deviation. Removing these outliers, GGNN achieved and GAT achieved .) The results indicate that the new GNN-FiLM model outperforms all baselines on all sub-tasks, sometimes by substantial margins (e.g., results on U0, U, H and G), with a noticeable but not crippling increase in runtime.
Finally, some preliminary results of the models on the Variable Misuse task of Allamanis et al. (2018) can be reported. This task requires processing a graph representing an abstraction of a program fragment and then selecting one of a few candidate nodes (representing program variables) based on the information of another node (representing the location at which a variable is to be used). The experiments are performed using the released split of the dataset, which contains training and validation graphs and two test sets: “SeenProjTest”, which contains graphs extracted from open source projects that also contributed data to the training and validation sets, and “UnseenProjTest”, which contains graphs extracted from completely unseen projects.
For now, preliminary results using the same hyperparameters as used by Allamanis et al. (2018) are presented, and no model-specific optimisations have been performed. The implementation of the task is also slightly simpler than in the original paper, only using the string labels of nodes for the representation and not using the additional type information provided in the data. Nonetheless, the results in Tab. 3 are broadly in line with the results reported by the original paper (84.0% for SeenProjTest and 74.1% for UnseenProjTest), and again show that the GNN-FiLM model outperforms the simpler models.
The author wants to thank Miltos Allamanis for the many discussions about GNNs and feedback on a draft of this article, Daniel Tarlow for helpful discussions and pointing to the FiLM idea, and Pashmina Cameron for feedback on the implementation.
- Allamanis et al. (2018) Miltiadis Allamanis, Marc Brockschmidt, and Mahmoud Khademi. Learning to represent programs with graphs. In International Conference on Learning Representations (ICLR), 2018.
- Ba et al. (2016) Lei Jimmy Ba, Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. CoRR, abs/1607.06450, 2016.
- Battaglia et al. (2018) Peter W. Battaglia, Jessica B. Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinícius Flores Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, Çaglar Gülçehre, Francis Song, Andrew J. Ballard, Justin Gilmer, George E. Dahl, Ashish Vaswani, Kelsey Allen, Charles Nash, Victoria Langston, Chris Dyer, Nicolas Heess, Daan Wierstra, Pushmeet Kohli, Matthew Botvinick, Oriol Vinyals, Yujia Li, and Razvan Pascanu. Relational inductive biases, deep learning, and graph networks. CoRR, abs/1806.01261, 2018.
- Gilmer et al. (2017) Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. Neural message passing for quantum chemistry. In International Conference on Machine Learning (ICML), pp. 1263–1272, 2017.
- Li et al. (2016) Yujia Li, Daniel Tarlow, Marc Brockschmidt, and Richard Zemel. Gated graph sequence neural networks. In International Conference on Learning Representations (ICLR), 2016.
- Perez et al. (2017) Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, and Aaron C. Courville. FiLM: Visual reasoning with a general conditioning layer. In AAAI Conference on Artificial Intelligence, 2017.
- Ramakrishnan et al. (2014) Raghunathan Ramakrishnan, Pavlo O. Dral, Matthias Rupp, and O. Anatole Von Lilienfeld. Quantum chemistry structures and properties of 134 kilo molecules. Scientific Data, 1, 2014.
- Schlichtkrull et al. (2018) Michael Schlichtkrull, Thomas N. Kipf, Peter Bloem, Rianne van den Berg, Ivan Titov, and Max Welling. Modeling relational data with graph convolutional networks. In Extended Semantic Web Conference (ESWC), 2018.
- Sen et al. (2008) Prithviraj Sen, Galileo Namata, Mustafa Bilgic, Lise Getoor, Brian Gallagher, and Tina Eliassi-Rad. Collective classification in network data. AI Magazine, 29, 2008.
- Shchur et al. (2018) Oleksandr Shchur, Maximilian Mumme, Aleksandar Bojchevski, and Stephan Günnemann. Pitfalls of graph neural network evaluation. CoRR, abs/1811.05868, 2018.
- Veličković et al. (2018) Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. Graph Attention Networks. In International Conference on Learning Representations (ICLR), 2018.
- Wu et al. (2019) Felix Wu, Angela Fan, Alexei Baevski, Yann Dauphin, and Michael Auli. Pay less attention with lightweight and dynamic convolutions. In International Conference on Learning Representations (ICLR), 2019.
- Zitnik & Leskovec (2017) Marinka Zitnik and Jure Leskovec. Predicting multicellular function through multi-layer tissue networks. Bioinformatics, 33, 2017.