SAG-VAE: End-to-end Joint Inference of Data Representations and Feature Relations

Variational Autoencoders (VAEs) are powerful for inferring data representations, but in their vanilla form and common variations they cannot learn relations between features. The ability to capture relations within data can provide the much-needed inductive bias necessary for building more robust Machine Learning algorithms with more interpretable results. In this paper, inspired by recent advances in relational learning using Graph Neural Networks, we propose the Self-Attention Graph Variational AutoEncoder (SAG-VAE) network, which can simultaneously learn feature relations and data representations in an end-to-end manner. SAG-VAE is trained by jointly inferring the posterior distribution of two types of latent variables, which denote the data representation and a shared graph structure, respectively. Furthermore, we introduce a novel self-attention graph network that improves the generative capabilities of SAG-VAE by parameterizing the generative distribution, allowing SAG-VAE to generate new data via graph convolution while remaining trainable via backpropagation. The learnable relational graph representation enhances SAG-VAE's robustness to perturbation and noise, while also providing deeper insight into model performance. Experiments based on graphs show that SAG-VAE is capable of approximately retrieving edges and links between nodes based entirely on feature observations. Finally, results on image data illustrate that SAG-VAE is fairly robust against perturbations in image reconstruction and sampling.

1 Introduction

In practice, data often comes with complex relations between features which are not explicitly visible, and extracting this structural information has been a crucial, yet challenging, task in the field of Machine Learning. Recently, renewed interest in relational and structure learning has been largely driven by the development of new end-to-end Neural Network and Deep Learning frameworks [2, 37, 33], with multiple promising results reported. This renewed drive in relational structure inference using Neural Networks can be partially attributed to current efforts to overcome the limited generalization capabilities of Deep Learning [1]. More importantly, learning relational structure with Neural Network models has several inherent advantages: the strong and efficient parameterization of Deep Learning makes it possible to extract essential relational information and perform large-scale inference, both of which are considered difficult for other learning algorithms.

Recently, research in relational learning using Neural Networks has largely focused on sequential generation/prediction for dynamical systems, while static data has been comparatively ignored [22, 17, 28]. At their core, these algorithms use one or a combination of Graph Neural Networks (GNNs) [29, 18, 36] and Variational Autoencoders (VAEs) [16]. The former provide a convenient framework for relational operations through the use of graph convolutions [30], and the latter offer a powerful Bayesian inference method to learn the distribution of the latent graph structure of data. Inspired by these recently developed methods, we devise a Neural Network-based algorithm for relational learning on graph data.

Figure 1: SAG-VAE reconstruction of the position and link information of Zachary’s karate club data. Top: Ground Truth; Middle: Position Reconstruction; Bottom: Position and Link Reconstruction. Notice that the pattern of the right-most column is not seen by SAG-VAE during the training phase.

In this paper, inspired by recent advances in the fields of GNNs and VAEs, we propose the Self-Attention Graph Variational Autoencoder (SAG-VAE), a novel VAE framework that jointly learns data representations and latent structure in an end-to-end manner. SAG-VAE utilizes the Gumbel-Softmax reparameterization [15] to infer the graph adjacency matrix, and employs a novel Graph Convolutional Network (also proposed in this paper) as the generative network. During the generation process, a sampled adjacency matrix serves as the graph edge information for the novel Graph Network, and a sampled data representation is fed into the network to generate data. Based on this framework, SAG-VAE can directly infer the posterior distributions of both the data representation and the relational matrix via vanilla gradient descent.

Several experiments are carried out with multiple data sets of different kinds to test performance. We observe in the experiments that SAG-VAE is able to learn organized latent structures for homogeneous image data, and the interpretation matches the nature of each type of image. Also, for graph data with known connections, SAG-VAE can retrieve a significant portion of the connections based entirely on feature observations. Based on these results, we argue that SAG-VAE can serve as a general method for learning relational structure from data. Furthermore, since SAG-VAE is a general framework compatible with most Variational Autoencoders, it is straightforward to combine advanced VAEs with SAG-VAE to create more powerful models.

The rest of the paper is arranged as follows: Section 2 reviews the literature related to this work; Section 3 introduces the background and discusses the proposed SAG-VAE; experimental results are shown and their implications discussed in Section 4; and finally, Section 5 gives a general conclusion of the paper.

2 Related Work

The task of learning latent structure and feature relations from data traces its history back to early work on causality inference, when researchers started to question cause-effect relationships in fields like economics [11] and medicine [20]. Early methods on this topic relied heavily on domain knowledge and statistical tests, and worked only on small-scale problems. Bayesian inference serves as another line of research for this task, and recent work has developed more representative models [21]. However, most recent methods in this branch are domain-specific and hard to test on general-purpose data like images. Thus, these models are not quantitatively tested or compared in this paper.

Perhaps counter-intuitively, relational learning in Neural Networks started long before the era of Deep Learning. The Recurrent Neural Network (RNN) and its variants like the LSTM [14] are essentially relational learning methods, although they were developed under the guise of 'sequential learning'. After the emergence of Deep Learning, neural mining of relational information was first advanced largely in the domain of Natural Language Processing [23]. More recently, a range of notable methods for neural relational learning, such as AIR [9], (N)REM [12, 33] and the JK network [37], have achieved state-of-the-art performance by explicitly modeling certain relations. In addition, Graph Neural Networks are increasingly popular in the presence of feature relations, and one prominent example can be found in [30]. However, almost all the work discussed above assumes a known relational structure, whereas our model infers this structure.

Another important research line closely related to this work is the development of Graph Neural Networks. The pioneering applications and theories of neural networks on graphs can be found in the early-2000s literature [32, 10, 29]. In recent years, a considerable number of approaches have been developed to simplify computation and improve performance [4, 6, 13, 18]. Prominently, the proposal of the Graph Convolutional Network (GCN) [18] remarkably simplified computation and effectively unified the operations of graph and ordinary Neural Networks. As methods closely related to Convolutional Neural Networks (CNNs), GNNs are likewise often improved by attention mechanisms [34]. Among the developments in attention methods, a recent advancement based on self-attention GANs [38] has attracted considerable notice. Inspired by this technique, we design a similar Self-attention Graph Generative Network on top of the idea, although significant changes have been made to adapt the technique to GNNs.

Furthermore, as our method utilizes the framework of the Variational Autoencoder, several publications concerning this technique are closely linked to our model. Building on the original VAE paper [16], [19] proposed an auto-encoding inference structure operating on graph data. In addition, [25] further elaborated on the idea of using VAEs to learn explicit graph structure. Finally, [17] adopts a paradigm similar to SAG-VAE, with a Graph Neural Network embedded into the framework of a Variational Autoencoder to learn latent structure. However, that paper focused on dynamic and sequential models, and its scheme alternately generates node and edge representations, which differs from ours.

3 Method

3.1 Background

Graph Convolution Networks

We first introduce Graph Convolutional Networks based on the standard setup in [18]. Under this framework, a graph is denoted as $\mathcal{G} = (\mathcal{V}, \mathcal{E})$, where $\mathcal{V}$ is the set of vertices and $\mathcal{E}$ is the set of edges. The vertices and their features are denoted by a matrix $X \in \mathbb{R}^{n \times d}$, where $n = |\mathcal{V}|$ and $d$ is the number of features. A graph adjacency matrix $A$ of size $n \times n$ will additionally be obtained based on the edge connections. Furthermore, in this setup, we ignore potential edge features. Finally, the network layers are connected using a convolutional projection of the input graph with adjacency matrix $A$, and a feed-forward layer is characterized by the following equation:

$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right) \qquad (1)$$

where $\hat{A} = \tilde{D}^{-1}(A + I)$ is the normalized adjacency matrix, with $I$ as the identity matrix (adding a self-loop on each vertex) and $\tilde{D}$ the degree matrix of $A + I$ for row-normalization.
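As an illustration, the feed-forward layer above can be sketched in a few lines of NumPy. This is a minimal sketch under the row-normalization setup described in the text, not the paper's actual implementation; the function name `gcn_layer` and the choice of `tanh` activation are our own assumptions.

```python
import numpy as np

def gcn_layer(H, A, W, activation=np.tanh):
    """One graph-convolution layer: H' = act(A_hat @ H @ W),
    with A_hat the row-normalised adjacency plus self-loops."""
    A_tilde = A + np.eye(A.shape[0])             # add self-loops
    D_inv = np.diag(1.0 / A_tilde.sum(axis=1))   # inverse degree matrix
    A_hat = D_inv @ A_tilde                      # row-normalisation
    return activation(A_hat @ H @ W)

# toy example: 3 nodes, 2 input features, 4 output features
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H = np.random.randn(3, 2)
W = np.random.randn(2, 4)
H_next = gcn_layer(H, A, W)
```

Each output row mixes a node's own features with those of its neighbours before the shared linear projection, which is exactly the aggregation the layer equation expresses.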

Variational Autoencoder Framework

Variational Autoencoders (VAEs) are among the most efficient approaches to learn latent representations of data and perform large-scale variational inference [16]. Following the standard notation, we use $p_\theta$ to denote the generative distribution and $q_\phi$ for the variational distribution. The inference model and the generative model are then:

$$q_\phi(Z \mid X) = \prod_{i=1}^{N} q_\phi(z_i \mid x_i), \qquad p_\theta(X \mid Z) = \prod_{i=1}^{N} p_\theta(x_i \mid z_i) \qquad (2)$$

where $N$ stands for the number of data points. Under the Gaussian prior used in the original paper [16], the inference network will be:

$$q_\phi(z_i \mid x_i) = \mathcal{N}\left(z_i;\, \mu_i,\, \sigma_i^2 I\right) \qquad (3)$$

and the optimization objective is given in the form of the Evidence Lower Bound (ELBO):

$$\mathcal{L}(\theta, \phi; x_i) = -D_{KL}\left(q_\phi(z_i \mid x_i)\,\|\,p(z)\right) + \mathbb{E}_{q_\phi(z_i \mid x_i)}\left[\log p_\theta(x_i \mid z_i)\right] \qquad (4)$$

For conjugate priors like the Gaussian, the KL-divergence can be computed analytically to avoid the noise of Monte-Carlo simulation.
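The analytic KL term for the diagonal-Gaussian case can be written down directly; a small sketch (our own helper, following the closed form from Kingma and Welling [16]):

```python
import numpy as np

def gaussian_kl(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var))

# the KL is exactly zero when the posterior equals the standard prior
kl_at_prior = gaussian_kl(np.zeros(8), np.zeros(8))   # mu = 0, variance = 1
kl_shifted = gaussian_kl(np.ones(4), np.zeros(4))     # shifted mean: positive KL
```

Because the term is computed in closed form, no Monte-Carlo samples of $z$ are needed for this part of the objective, which removes one source of gradient noise.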

3.2 Inference of SAG-VAE

Variational Autoencoders cannot naturally model data with graph structure, and for graph convolution operations to be carried out, a discrete distribution over the graph edges must be present. Thus, we introduce another latent variable $A$, which represents the adjacency matrix of the graph. Ideally, for each edge $A_{ij}$, the model should parametrize a Bernoulli distribution, and the joint distribution will be:

$$p(A) = \prod_{i=1}^{n}\prod_{j=i+1}^{n} \mathrm{Bernoulli}\left(A_{ij} \mid \rho_{ij}\right) \qquad (5)$$

where $n$ is the number of nodes, and is therefore also the number of features/dimensions of the original non-graph data. The factorization of the distribution holds as the random variables are assumed independent, and $j$ goes from $i+1$ as the graph adjacency is symmetric. Equation 5 precisely describes the desired distribution; however, since the Bernoulli distribution is discrete, using this target directly can make the model difficult to optimize. Thanks to recent advances in variational Deep Learning, we are able to utilize the Gumbel-Softmax distribution [15] to simulate Bernoulli/categorical distributions. The Gumbel-Softmax approximation used in our model can be expressed as follows:

$$q(A_{ij} = 1) = \frac{\exp\left((\log \pi_{ij,1} + g_1)/\tau\right)}{\sum_{k \in \{0,1\}} \exp\left((\log \pi_{ij,k} + g_k)/\tau\right)}, \qquad g_k \sim \mathrm{Gumbel}(0, 1) \qquad (6)$$

Notice that in equation 6 there is no index over data points, which means the learned adjacency matrix is a shared structure (amortized inference) and should be averaged over the input. In practice, one can apply the Gumbel-Softmax to each $x_i$ and average over the probabilities:

$$q(A) = \frac{1}{N}\sum_{i=1}^{N} q(A \mid x_i) \qquad (7)$$

as this will make the estimation of the KL-divergence term more robust. We will discuss this issue further in later paragraphs.
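A relaxed sample of the symmetric adjacency matrix can be sketched as follows. This is an illustrative NumPy version of the two-category Gumbel-Softmax trick for edges; the function names, the temperature value, and the logits layout `(n, n, 2)` are our own assumptions, not the paper's API.

```python
import numpy as np

def gumbel_softmax_edge(logits, tau=0.5, rng=np.random):
    """Relaxed Bernoulli sample for one edge from two logits
    [logit_off, logit_on]; returns the soft probability of 'on'."""
    g = -np.log(-np.log(rng.uniform(size=2)))  # Gumbel(0, 1) noise
    y = np.exp((logits + g) / tau)
    return (y / y.sum())[1]

def sample_adjacency(edge_logits, tau=0.5, rng=np.random):
    """Sample a symmetric soft adjacency matrix; edge_logits has
    shape (n, n, 2) and only the upper triangle (i < j) is used."""
    n = edge_logits.shape[0]
    A = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):   # i < j enforces symmetry, no self-loops
            A[i, j] = A[j, i] = gumbel_softmax_edge(edge_logits[i, j], tau, rng)
    return A

A_soft = sample_adjacency(np.zeros((5, 5, 2)))
```

Because the sample is a differentiable function of the logits, gradients flow through the sampled adjacency during training, which is the whole point of the relaxation.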

Figure 2: The model structure and graphical model of SAG-VAE.

Bringing back the original variable, the joint posterior is now $q_\phi(Z, A \mid X)$. Figure 2 illustrates the structure and graphical model of SAG-VAE, and draws a comparison with the vanilla VAE. Observing the graphical model of SAG-VAE, since $Z$ and $A$ are not d-separated, they are not necessarily independent given $X$. Nevertheless, to simplify computation, we apply a conditional-independence approximation to the variational distributions:

$$q_\phi(Z, A \mid X) \approx q_\phi(Z \mid X)\, q_\phi(A \mid X) \qquad (8)$$

Crucially, equation 8 allows the posterior distributions to be separated, and therefore avoids noisy and expensive Monte-Carlo simulation of the joint KL-divergence. With a derivation similar to that developed in [16], one obtains the new ELBO of our model:

$$\mathcal{L} = \mathbb{E}_{q_\phi(Z, A \mid X)}\left[\log p_\theta(X \mid Z, A)\right] - D_{KL}\left(q_\phi(Z \mid X)\,\|\,p(Z)\right) - D_{KL}\left(q_\phi(A \mid X)\,\|\,p(A)\right) \qquad (9)$$

The posterior distribution of $Z$ is characterized by a learned Gaussian with a standard Gaussian prior. We omit more complicated priors to obtain a fair comparison against standard VAEs. For SAG-VAE, we need the dimension of the hidden representation to equal the number of data dimensions (one can see the reason in Section 3.3). Therefore, we propose two types of implementations. The first is to apply a set of hidden distributions for each data point, as is usually done in ordinary VAEs; the second is to learn a distribution for each dimension. Noticeably, the latter scheme leads to high-quality reconstruction results, albeit at the cost that the model becomes more vulnerable to overfitting and sampling from the SAG-VAE becomes difficult. Nevertheless, the robustness and noise-resistance advantages of SAG-VAE are more obvious with the second implementation.

Another issue to notice is the computation of the KL-divergence term $D_{KL}(q_\phi(A \mid X)\,\|\,p(A))$. Notice that for the SAG-VAE with per-data-point representations, with the implementation based on equation 7, the KL divergence becomes:

$$D_{KL}\left(q_\phi(A \mid X)\,\|\,p(A)\right) = \sum_{k=1}^{N} D_{KL}\left(q_\phi(A \mid x_k)\,\|\,p(A)\right) \qquad (10)$$

This function is not properly normalized, as the summation depends on $N$ but no such parameter appears in the denominator. For the per-dimension version of SAG-VAE, although we do have an additional $1/N$ factor, this KL-divergence term can still be overly dominant, as the summation runs over all $n(n-1)/2$ edge terms. Thus, inspired by the idea in [5], we use a coefficient $\beta$ to normalize the KL-divergence term and improve the performance in practice.

3.3 Self-attention Graph Generative Network

The generative network of SAG-VAE is a novel Self-attention Graph Neural Network (SA-GNN) designed in this paper. We can denote it in short-hand form:

$$\hat{X} = \mathrm{SA\text{-}GNN}_\theta(Z, A)$$

The Self-attention Graph Network generally follows the framework of [18] for information aggregation, and part of the network update directly follows equation 1. However, one significant difference is the introduction of the self-attention layer. The approach is similar to the mechanism in self-attention GANs [38], but instead of performing global attention regardless of geometric information, the self-attention layer in our model operates over neighbouring nodes. The reason for adopting this paradigm is that the node features and edge connections are learned rather than given; if global unconditional attention were performed, errors from the initialization stage would be amplified, making the network more difficult to train.

Suppose the feature matrix $H$, the output of the previous layer, has shape $n \times c$, where $n$ and $c$ represent the number of dimensions (nodes) and graph features, respectively. For node $i$ and any other node $j$, the relevance value is computed as follows:

$$s_{ij} = \left(H W_f\right)_i \left(H W_g\right)_j^{\top} \qquad (11)$$

where $W_f$ and $W_g$ are the convolution matrices that transform the $c$-dim node features into $k$-dim attention features. All multiplications in equation 11 are matrix multiplications. Finally, having taken the graph edge connections into consideration as geometric information, we perform the softmax operation over the neighboring nodes of $i$ (including $i$ itself). Formally, the attention value is computed as:

$$\alpha_{ij} = \frac{\exp(s_{ij})}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} \exp(s_{ik})} \qquad (12)$$

In practice, the above operation can be performed in parallel before normalization by multiplying the relevance matrix computed with equation 11 by the adjacency matrix. This attention mechanism is similar to the Graph Attention Network (GAT) [34], with one main difference: in GAT the relevance features are aggregated and multiplied by a learnable vector, while in SA-GNN the relevance features are directly processed by dot products. After computing $\alpha_{ij}$ for each pair and obtaining the matrix $\boldsymbol{\alpha}$, the attention result can be directly computed by matrix multiplication in the same manner as [38]:

$$O = \boldsymbol{\alpha}\left(H W_h\right) W_v \qquad (13)$$

where $W_h$ and $W_v$ are the $c \times k$ and $k \times c$ transformation matrices, respectively. The main purpose of using these two matrices is to reduce computational cost; in practice, if the number of graph features (channels) is not too large, the transformations are not strictly required.
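The neighbour-restricted attention of equations 11 and 12 can be sketched compactly with an adjacency mask. This is an illustrative NumPy version under our reading of the text (it omits the $W_h$/$W_v$ output transforms for brevity); the function name and weight names are our own.

```python
import numpy as np

def masked_self_attention(H, A, Wf, Wg):
    """Self-attention restricted to graph neighbours.

    s_ij = (H Wf)_i . (H Wg)_j is the relevance of node j to node i;
    the softmax runs only over j in N(i) plus i itself, implemented
    by masking the relevance matrix with the adjacency + self-loops.
    """
    S = (H @ Wf) @ (H @ Wg).T                  # pairwise relevance (n x n)
    mask = A + np.eye(A.shape[0])              # neighbours incl. self
    E = np.exp(S) * mask                       # zero out non-neighbours
    alpha = E / E.sum(axis=1, keepdims=True)   # row-wise masked softmax
    return alpha @ H                           # attention-weighted features

n, c, k = 4, 3, 2
H = np.random.randn(n, c)
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)
Wf = np.random.randn(c, k)
Wg = np.random.randn(c, k)
O = masked_self_attention(H, A, Wf, Wg)
```

Note that with an empty adjacency the mask reduces to the identity, so each node attends only to itself and the layer becomes a no-op, which matches the intuition that attention contributes nothing before any edges are learned.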

To further improve performance, we consider incorporating edge weights into the attention mechanism. The weights can be computed by an encoding network that shares its structure with the edge-inference network. Formally, this can be expressed as:

$$E = \tilde{q}_{\phi}(X) \qquad (14)$$

where $\tilde{q}_{\phi}$ indicates a network that shares its structure with $q_\phi(A \mid X)$ except for the last layer. Meanwhile, the main diagonal of $E$ is set to 1. Therefore, equation 12 can be revised into:

$$\alpha_{ij} = \frac{E_{ij}\exp(s_{ij})}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} E_{ik}\exp(s_{ik})} \qquad (15)$$
In a similar spirit to [38], the attention-based features are multiplied by a coefficient $\gamma$, initialized to 0, and added to the features updated by the rules of the vanilla GCN:

$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)} + \gamma\, O^{(l)}\right) \qquad (16)$$

where $W^{(l)}$ is the convolution weight of the $l$-th layer. Based on the above equation, the network will first focus on learning the graph geometry (edges), and then use the attention mechanism to improve generation quality. Notice that the intention behind this interpolation differs from that of the self-attention GAN, which focuses on combining global and local information, although the mathematical expressions are identical.

In practice, to form the generative network, we need to take both reconstruction quality and generalization into consideration. In addition, to make sure the distribution of the graph is properly trained, there has to be a mechanism enforcing correlation between the generated adjacency matrix and the information at each layer. To achieve these goals, inspired by the method proposed in [7], we use skip connections to interpolate each layer of the generative network. Specifically, at each layer, we combine the self-attention-processed layer information with the graph message aggregation of the latent code $Z$. This can be denoted as:

$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)} + \gamma\, O^{(l)} + \hat{A} Z W_z^{(l)}\right) \qquad (17)$$

where $\sigma$ is the non-linear activation, $Z$ is the latent representation of the data (directly from $q_\phi(Z \mid X)$), and $W_z^{(l)}$ is the convolutional weight between the latent representation and the current layer. For the last layer, we apply the activation after amalgamating the information:

$$\hat{X} = \sigma_{\mathrm{out}}\left(\hat{A} H^{(L)} W^{(L)} + \gamma\, O^{(L)} + \hat{A} Z W_z^{(L)}\right) \qquad (18)$$

to keep the properties produced by certain activations (e.g. the Sigmoid will produce results in $[0, 1]$). Finally, it is important to note that in the VAE framework, the latent variable $z$ does not naturally fit the GCN framework, where each node is treated as a feature vector. Thus, for the data point-wise distribution version of SAG-VAE, one first needs to transform the latent code to dimension $n$ with a fully-connected layer, and then add one dimension to obtain an $n \times 1$ tensor. In contrast, for the SAG-VAE with individual distributions for each dimension, one can directly sample a $z$ that operates on GCNs.
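Putting the pieces together, one layer of the generative network combines the vanilla GCN update, the $\gamma$-scaled attention output, and the skip connection from the latent code. The sketch below is a minimal NumPy rendering under our reading of the update rules above; the function name, toy shapes, and `tanh` activation are our own assumptions.

```python
import numpy as np

def sag_generative_layer(H, Z, A_hat, W, Wz, O, gamma, act=np.tanh):
    """One generative layer: GCN aggregation of current features,
    a skip connection aggregating the latent code Z through the
    same normalised adjacency, and the self-attention output O
    scaled by gamma (initialised to 0, so the network first learns
    the graph geometry before relying on attention)."""
    return act(A_hat @ H @ W + A_hat @ Z @ Wz + gamma * O)

# toy shapes: 4 nodes, 3 features, 3-dim latent code per node
rng = np.random.default_rng(0)
A_hat = np.eye(4)                       # trivial (self-loop only) adjacency
H = rng.standard_normal((4, 3))
Z = rng.standard_normal((4, 3))
W = rng.standard_normal((3, 3))
Wz = rng.standard_normal((3, 3))
O = rng.standard_normal((4, 3))
H_next = sag_generative_layer(H, Z, A_hat, W, Wz, O, gamma=0.0)
```

With `gamma=0.0` the attention term vanishes and the layer is exactly a GCN update plus the latent-code skip connection, mirroring the intended training schedule.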

4 Experiments

In this section, we demonstrate the performance and favorable properties of SAG-VAE. Intuitively, by introducing graph-structured relational learning, SAG-VAE has two advantages over the ordinary VAE and its popular variations: interpretable relations between features and robustness against perturbations. To validate the correctness of the learned relations, one can apply SAG-VAE to the task of retrieving graph edges/links based on node feature observations; this type of task is itself considered important in the regime of graph learning. For the robustness of the SAG-VAE model, we test its performance on tasks like reconstruction with noise/masks and sampling with perturbations. Notice that the majority of the models tested in this section are implemented with independent distributions on each dimension. We choose this setup because the advantage of SAG-VAE is more significant in these situations. Moreover, in practice, the data point-wise distribution version is straightforward to implement, although its parameters are more difficult to tune.

4.1 Graph Connection Retrieval

We apply two types of feature observations based on graph data. For the first type, the features are generated by a 2-layer Graph Neural Network (the GCN of [18]) propagating information between neighboring nodes; for the second type, we pick graph data with given feature observations and randomly drop out rows and add Gaussian noise to obtain a collection of noisy data. Notice that this task of retrieving graph edges from feature observations is itself considered an interesting problem in machine learning. To facilitate the training process, for the SAG-VAE model used for graph connection retrieval, we apply an 'informative' prior that adopts the edge density as the prior of the Bernoulli distribution. This is a realistic assumption, and this type of information is likely available in real-life problems; thus, it does not affect the fairness of the performance comparisons.

Results of experiments on the two types of graph data illustrate that SAG-VAE can correctly retrieve a significant portion of links (satisfactory recall) while avoiding overly redundant connections (satisfactory precision). For the first type of data, SAG-VAE can effectively generalize the reconstruction to an unseen pattern of positions. Also, by sampling from the hidden distributions, new patterns of positions can be observed. For the second type of data, SAG-VAE can outperform major existing methods. In addition, the inference of hidden representations is a unique advantage compared to existing methods.

To show the performance advantages of SAG-VAE, its performance is compared with the pairwise product and the Variational Graph Autoencoder (VGAE) [19]. The number of comparison models is small, since only a limited number of methods are capable of inferring links based entirely on feature observations; nevertheless, this does not invalidate the performance superiority of SAG-VAE. The most naive model (pairwise product) directly computes the dot product between every pair of nodes and uses a Sigmoid to produce the probability that a link exists. This simple method serves as the baseline in the comparison of [19], although the features are computed with DeepWalk [26] in their model. More advanced baselines are based on VGAE, which uses part of the graph to learn representations and generalizes the generation to the overall graph. The direct comparison between VGAE and SAG-VAE removes all edge connections and feeds the graph data to VGAE with only 'self-loops' on each node. To further validate the superiority of SAG-VAE, we also demonstrate the performance of VGAE with 10% of edges preserved in the training input, and show that SAG-VAE can outperform VGAE even under this biased setup.

Karate Synthetic Data

We adopt Zachary’s karate club graph to generate the first type of feature observations. In the implementation, each type of node (labeled in the original dataset) is parametrized by an individual Gaussian distribution, and 5 different weights are adopted to generate graphs with 5 patterns. During the training phase, only the first 4 types of graphs are provided to the network, and the final pattern is used to test if the trained SAG-VAE is able to generalize the prediction.

Figure 1 illustrates the reconstruction of 3 patterns of node positions by the SAG-VAE with an individual Gaussian distribution on each dimension. From the figure, it can be observed that the SAG-VAE model reconstructs the node positions approximately correctly, and while the patterns of links are not exactly the same as the original, the overall geometries are similar in terms of edge distributions. In addition, for the unseen pattern (the rightmost column), the model successfully infers the positions and the key links of the graph.

Figure 3: Karate position sampling from SAG-VAE with two different implementations

Figure 3 shows the sampling results with both the data point-wise and dimension-wise representations of SAG-VAE. From the figures, it can be observed that both versions of SAG-VAE can generate Karate data in an organized manner. Sampling from the SAG-VAE with data-wise latent codes can further restrict the patterns of the graph, while sampling from its dimension-wise counterpart appears to yield a more organized distribution at the node level, with different types of nodes better segmented.

Table 1 illustrates the performance comparison between different methods on the Karate-generated data. From the table it can be observed that SAG-VAE with both the data-wise and dimension-wise implementations outperforms the comparison methods. It is noticeable that for this graph generation task, adding 10% ground-truth links does not significantly improve the score of VGAE. In contrast, simply applying the pairwise product leads to a better performance in this case.

Method Precision Recall F1 score
Pairwise Product 0.139 0.985 0.243
VGAE (no input edge) 0.142 0.524 0.223
VGAE (10 % link) 0.150 0.539 0.234
SAG-VAE (data-wise) 0.616 0.558 0.586
SAG-VAE (dimension-wise) 0.558 0.611 0.583
Table 1: Performance comparison between SAG-VAE and other methods on Karate-generated data.

Graph Data with given Node Features

Table 2 illustrates the performance comparison (F1 scores) between different models on three benchmark graph data sets: Graph Protein [3, 8], Coil-rag [27, 24] and Enzymes [3, 31]. All 3 types of data come with rich node feature representations, and we obtain the training and testing data by selecting one sub-graph from each data set and applying the second type of data generation (with random noise and row dropout). The extracted graphs are of size 64, 6 and 18, respectively. Compared to the Karate data used above, the graphs adopted here are significantly sparser.

From Table 2, it can be observed that SAG-VAE outperforms the methods adopted for comparison, especially the VGAE-based ones. For VGAE, the performance is poor on all datasets, and adding back 10% of links does not remedy the situation. On the other hand, simply applying the pairwise product yields quite competitive performance. One possible reason behind this observation is that, since the node features are highly noisy, it is very difficult for the VAE architecture to learn meaningful embeddings of the nodes; on the other hand, since the feature representations are originally rich, the pairwise product can capture sufficient information, and therefore leads to an unexpectedly good performance. The curse of noisy features is resolved by SAG-VAE: with the merits of the joint inference of data representations and feature relations, the model can overcome the problem of noise under the VAE framework, leading to overall superior performance.

Method Protein Coil-rag Enzymes
Pairwise Product 0.367 0.714 0.410
VGAE (no input edge) 0.276 0.620 0.315
VGAE (10 % link) 0.283 0.643 0.319
SAG-VAE (dimension-wise) 0.385 0.800 0.423
Table 2: Performance comparison (F1 score only) between SAG-VAE and other methods on graph data with given node features.

4.2 Image Data: Robust Reconstruction and Sampling

As stated before, we expect SAG-VAE to perform more robustly against perturbations, because the learned correlations between features provide a noise-resisting inductive bias. In this section, we test the robustness of SAG-VAE on two image datasets: MNIST and Fashion MNIST. The performance is evaluated on 3 tasks: masked/corrupted reconstruction, noisy reconstruction, and noisy sampling. Intuitively, for the reconstruction tasks, if the reconstructed images from SAG-VAE are of higher quality than those from the plain VAE, the robustness of SAG-VAE is corroborated. Moreover, the noisy sampling task directly perturbs some of the hidden representations, and the inductive bias in SAG-VAE should be able to overcome it. Finally, plots of the adjacency matrices show how well the model has learned the structured relationships between features.

In these experiments, we only implemented SAG-VAE with dimension-wise distributions. This type of model produces higher-quality reconstructions, but it is more vulnerable to perturbation; therefore, testing with this setup better illustrates the advantages of SAG-VAE. Nevertheless, this setup makes sampling harder than in the ordinary VAE, as we no longer have low-dimensional latent codes. To conduct the sampling process, we model the mean and variance of each pixel for data with each label using a Gaussian distribution:

$$q(z_j) = \mathcal{N}\left(z_j;\, \mu_j,\, \sigma_j^2\right)$$

to approximately model the manifold and distribution of each dimension. Notice that, unlike for graph data, using dimension-wise distributions for images damages the ability of the VAE to sample meaningful data. We apply this paradigm here mainly to illustrate the robustness of the SAG-VAE framework; in practice, one can always switch to the data-wise distribution model, albeit with parameters that are harder to tune.
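The per-pixel fitting and sampling procedure can be sketched in a few lines. This is a minimal sketch of the scheme described above, not the paper's code; the function names and the 784-dimensional toy shapes are our own assumptions.

```python
import numpy as np

def fit_pixel_gaussians(latents):
    """Fit an independent Gaussian to each latent dimension ('pixel')
    from the dimension-wise latent codes of one image class."""
    mu = latents.mean(axis=0)
    sigma = latents.std(axis=0) + 1e-8   # avoid zero variance
    return mu, sigma

def sample_latent(mu, sigma, rng=np.random):
    """Draw one latent code by sampling each pixel independently."""
    return mu + sigma * rng.standard_normal(mu.shape)

# e.g. 50 codes per class, 784 'pixels' (28x28), as in the experiments
codes = np.random.randn(50, 784)
mu, sigma = fit_pixel_gaussians(codes)
z = sample_latent(mu, sigma)
```

Because every pixel is sampled independently, there is no shared low-dimensional manifold; the generative network's learned feature relations are what re-introduce coherence into the sampled images.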

Noisy and Masked Reconstruction

Both MNIST and Fashion MNIST images are of shape $28 \times 28$. For Fashion MNIST, to better leverage the common structure, we remove the image classes that are not shirt-like or pants-like, since their geometries are significantly different from the rest of the dataset. To artificially introduce adversarial perturbations on images, two types of noise are applied: uniform noise and block-masking/corruption. For uniform noise-based perturbation, 200 pixels (or 150, depending on the type of data) are randomly selected and replaced with a number generated from a uniform distribution. For mask-based perturbation, a block is added at a random position on each image, so that a small portion of the digit or object in the image is unseen.

We first test SAG-VAE on MNIST data with perturbations. 10 reconstructed images with their corresponding perturbed and original images are randomly selected and presented in Figures 4 and 5. In the same figures, the performance of the vanilla VAE is also illustrated. The vanilla VAE is implemented with a fully-connected layer for each dimension, which is equivalent to SAG-VAE with an adjacency matrix that is zero everywhere except the main diagonal.

Figure 4: Reconstruction comparison on noisy MNIST. Top: Noisy images; 2nd row: VAE Reconstruction; 3rd row: SAG-VAE Reconstruction; Bottom: Original images.
Figure 5: Reconstruction comparison on masked MNIST. Top: Masked images; 2nd row: VAE Reconstruction; 3rd row: SAG-VAE Reconstruction; Bottom: Original images.

As one can observe, the vanilla VAE falsely learns the patterns of the noise and blocks in its reconstructions, as there is no inductive bias against such perturbations. For both tasks, SAG-VAE outperforms the VAE significantly in terms of reconstruction quality. For the noisy perturbation, one can barely observe visible noise in the reconstruction results of SAG-VAE. For the masked perturbation, although the reconstruction quality is not as strong, it can still be observed that the edges of the blocks are smoothed and the mask sizes are adequately reduced. Notice that the performance of SAG-VAE on the uniform-noise task is close to that of a denoising autoencoder [35], yet we did not introduce any explicit denoising measure: the denoising characteristic is introduced almost entirely by the inductive bias from the learned feature relations.

We further test the same tasks on Fashion MNIST, and the performance is shown in Figures 7 and 8. Again, we can observe from the figures that SAG-VAE significantly outperforms the VAE when perturbations exist in the input data. It is noticeable that in the Fashion MNIST reconstruction, SAG-VAE appears to be more resistant to block-masking, and it maintains its strong robustness against uniform noise.

Figure 6 shows the loss (distance) between the SAG-VAE reconstructions and, respectively, the original and the noise-corrupted images. It can be observed that the gap between reconstructed and original images declines in line with the training loss, while the loss between reconstructed and noisy images plateaus at a high level. This indicates that SAG-VAE resists fitting the perturbation as information. Limited by space, we do not include the corresponding figure for the vanilla VAE; in our experiments, we observe that its reconstruction loss against the noisy images continues to decrease while its loss against the real images increases, indicating that the plain VAE falsely fits the perturbation as information.
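The diagnostic described above amounts to tracking two reconstruction losses per epoch. A minimal sketch, assuming mean-squared error as the distance (the norm used in the paper is not specified in this excerpt) and a hypothetical helper name:

```python
import numpy as np

def track_reconstruction_gaps(recon, noisy, original):
    """Return the mean squared error of the reconstructions against the
    clean originals and against the noise-corrupted inputs.

    For a robust model, the first value should keep declining with the
    training loss, while the second plateaus at a high level."""
    loss_vs_original = float(np.mean((recon - original) ** 2))
    loss_vs_noisy = float(np.mean((recon - noisy) ** 2))
    return loss_vs_original, loss_vs_noisy
```

Logging both curves per epoch reproduces the qualitative comparison in figure 6: a model that fits the perturbation shows the opposite trend, with the noisy-input loss declining and the clean-image loss rising.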

Finally, figure 9 shows the learned adjacency matrices as the relations between features. It can be observed that, while the reason for each individual connection is not straightforward to interpret, the graph structure is well organized, and it can be reasonably argued that the robustness against perturbation stems from this organized structure.

Figure 6: Training Loss and Reconstruction Loss of Fashion MNIST.
Figure 7: Reconstruction comparison on noisy Fashion MNIST. Top: Noisy images; 2nd row: VAE Reconstruction; 3rd row: SAG-VAE Reconstruction; Bottom: Original images.
Figure 8: Reconstruction comparison on masked Fashion MNIST. Top: Masked images; 2nd row: VAE Reconstruction; 3rd row: SAG-VAE Reconstruction; Bottom: Original images.
Figure 9: Adjacency Matrix Generated from MNIST and Fashion MNIST.

Noisy Sampling

Following the method discussed at the beginning of section 4.2, we model the distribution of the latent representations of MNIST images for each image class by fitting the distributions of the means and variances for each pixel. For each image class, 50 images are used to fit the parameter distributions. Figure 10 compares sampling from uncorrupted latent representations with sampling after 200 latent dimensions have been filled with uniform noise. Notice that since each pixel has an individual distribution, the overall image does not have the 'sampling manifold' of an ordinary data-wise VAE, and the generation quality is lower. Nevertheless, this does not affect the main purpose of the experiment, which is to verify the robustness of the SAG-VAE framework.
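The fit-then-corrupt sampling procedure can be sketched as below. This is a simplified illustration, not the paper's implementation: it assumes independent Gaussians per latent dimension, a 1-D latent vector, and hypothetical helper names; the decoder that maps the corrupted code back to an image is omitted.

```python
import numpy as np

def fit_latent_gaussians(latents):
    """Fit an independent Gaussian per latent dimension from a small set
    of same-class encodings (the paper uses 50 images per class)."""
    mu = latents.mean(axis=0)
    sigma = latents.std(axis=0) + 1e-6  # avoid degenerate zero variance
    return mu, sigma

def sample_corrupted(mu, sigma, n_corrupt=200, rng=None):
    """Draw one latent code, then overwrite `n_corrupt` randomly chosen
    dimensions with uniform noise, mirroring the corruption test."""
    rng = np.random.default_rng(rng)
    z = rng.normal(mu, sigma)
    idx = rng.choice(z.size, size=n_corrupt, replace=False)
    z[idx] = rng.random(n_corrupt)
    return z
```

The corrupted code `z` would then be passed through the trained decoder; the experiment measures how well the decoded digits survive this corruption.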

From the figure, it can be observed that although a significant number of dimensions are filled with uniform noise, the generated results still largely preserve the characteristics of the digits. Given the limited quality of the original per-pixel generation, the decline in quality is in fact modest. Also notice that SAG-VAE is able to 'adjust' the generated digits to the region where they should appear and to resist noise on the 'canvas': no generated image shows noise on the black background.

Limited by space, we do not provide results on Fashion-MNIST for this task; extending the evaluation in this direction is left as future work.

Figure 10: Noisy Sampling on MNIST images.

5 Conclusion

In this paper, we propose the Self-Attention Graph Variational AutoEncoder (SAG-VAE), building on recent advances in Variational Autoencoders and Graph Neural Networks. This novel model jointly infers data representations and relations between features, yielding strongly interpretable results on the input datasets. In addition, by introducing the learned relations as inductive biases, the model demonstrates strong robustness against perturbations. Furthermore, a novel Self-Attention Graph Neural Network (SA-GNN) is proposed in the paper.

To conclude, this paper makes the following major contributions: firstly, it proposes a novel VAE-based framework that can jointly infer representations and feature relations in an end-to-end manner; secondly, it presents a novel self-attention-based Graph Neural Network, which leverages the self-attention mechanism to improve performance; and finally, it demonstrates multiple favorable experimental results, which corroborate the effectiveness of the method.

In the future, the authors intend to extend the model to more advanced posterior approximation techniques (e.g., IWAE) and more flexible priors (e.g., normalizing flows). Testing the performance of the model on more complex datasets is another direction.


Acknowledgments

The authors would like to thank Prof. Sungjin Ahn of Rutgers University for organizing the Machine Learning reading group which covered the topic of this paper. The first author would like to thank Sepehr Janghorbani of Rutgers University for constructive discussions.

Appendix A The derivation of the new ELBO

Our goal is to maximize $\log p_\theta(X)$. By substituting the posterior with $q_\phi(z, G \mid X)$ and leveraging the approximation in equation 8, we can write:

$$\log p_\theta(X) = \log \mathbb{E}_{q_\phi(z, G \mid X)}\left[\frac{p_\theta(X, z, G)}{q_\phi(z, G \mid X)}\right] \geq \mathbb{E}_{q_\phi(z, G \mid X)}\left[\log \frac{p_\theta(X, z, G)}{q_\phi(z, G \mid X)}\right] \quad \text{(Jensen's inequality)}$$

Factorizing $q_\phi(z, G \mid X) = q_\phi(z \mid X)\, q_\phi(G \mid X)$ and $p(z, G) = p(z)\, p(G)$ gives

$$\log p_\theta(X) \geq \mathbb{E}_{q_\phi(z \mid X)\, q_\phi(G \mid X)}\left[\log p_\theta(X \mid z, G)\right] - \mathrm{KL}\big(q_\phi(z \mid X) \,\|\, p(z)\big) - \mathrm{KL}\big(q_\phi(G \mid X) \,\|\, p(G)\big),$$

which is the ELBO in equation 9.


  1. P. W. Battaglia, J. B. Hamrick, V. Bapst, A. Sanchez-Gonzalez, V. Zambaldi, M. Malinowski, A. Tacchetti, D. Raposo, A. Santoro and R. Faulkner (2018) Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261. Cited by: §1.
  2. A. Bordes, N. Usunier, A. Garcia-Duran, J. Weston and O. Yakhnenko (2013) Translating embeddings for modeling multi-relational data. In Advances in neural information processing systems, pp. 2787–2795. Cited by: §1.
  3. K. M. Borgwardt, C. S. Ong, S. Schönauer, S. Vishwanathan, A. J. Smola and H. Kriegel (2005) Protein function prediction via graph kernels. Bioinformatics 21 (suppl_1), pp. i47–i56. Cited by: §4.1.2.
  4. J. Bruna, W. Zaremba, A. Szlam and Y. LeCun (2013) Spectral networks and locally connected networks on graphs. arXiv preprint arXiv:1312.6203. Cited by: §2.
  5. J. Chou (2019) Generated loss and augmented training of mnist vae. arXiv preprint arXiv:1904.10937. Cited by: §3.2.
  6. M. Defferrard, X. Bresson and P. Vandergheynst (2016) Convolutional neural networks on graphs with fast localized spectral filtering. In Advances in neural information processing systems, pp. 3844–3852. Cited by: §2.
  7. A. B. Dieng, Y. Kim, A. M. Rush and D. M. Blei (2018) Avoiding latent variable collapse with generative skip models. arXiv preprint arXiv:1807.04863. Cited by: §3.3.
  8. P. D. Dobson and A. J. Doig (2003) Distinguishing enzyme structures from non-enzymes without alignments. Journal of molecular biology 330 (4), pp. 771–783. Cited by: §4.1.2.
  9. S. A. Eslami, N. Heess, T. Weber, Y. Tassa, D. Szepesvari and G. E. Hinton (2016) Attend, infer, repeat: fast scene understanding with generative models. In Advances in Neural Information Processing Systems, pp. 3225–3233. Cited by: §2.
  10. M. Gori, G. Monfardini and F. Scarselli (2005) A new model for learning in graph domains. In Proceedings. 2005 IEEE International Joint Conference on Neural Networks, 2005., Vol. 2, pp. 729–734. Cited by: §2.
  11. C. W. Granger (1969) Investigating causal relations by econometric models and cross-spectral methods. Econometrica: Journal of the Econometric Society, pp. 424–438. Cited by: §2.
  12. K. Greff, S. van Steenkiste and J. Schmidhuber (2017) Neural expectation maximization. In Advances in Neural Information Processing Systems, pp. 6691–6701. Cited by: §2.
  13. M. Henaff, J. Bruna and Y. LeCun (2015) Deep convolutional networks on graph-structured data. arXiv preprint arXiv:1506.05163. Cited by: §2.
  14. S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural computation 9 (8), pp. 1735–1780. Cited by: §2.
  15. E. Jang, S. Gu and B. Poole (2016) Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144. Cited by: §1, §3.2.
  16. D. P. Kingma and M. Welling (2013) Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. Cited by: §1, §2, §3.1.2, §3.2.
  17. T. Kipf, E. Fetaya, K. Wang, M. Welling and R. Zemel (2018) Neural relational inference for interacting systems. arXiv preprint arXiv:1802.04687. Cited by: §1, §2.
  18. T. N. Kipf and M. Welling (2016) Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907. Cited by: §1, §2, §3.1.1, §3.3, §4.1.
  19. T. N. Kipf and M. Welling (2016) Variational graph auto-encoders. arXiv preprint arXiv:1611.07308. Cited by: §2, §4.1.
  20. B. Kuipers and J. P. Kassirer (1984) Causal reasoning in medicine: analysis of a protocol. Cognitive Science 8 (4), pp. 363–385. Cited by: §2.
  21. S. Linderman, R. P. Adams and J. W. Pillow (2016) Bayesian latent structure discovery from multi-neuron recordings. In Advances in neural information processing systems, pp. 2002–2010. Cited by: §2.
  22. C. Louizos, U. Shalit, J. M. Mooij, D. Sontag, R. Zemel and M. Welling (2017) Causal effect inference with deep latent-variable models. In Advances in Neural Information Processing Systems, pp. 6446–6456. Cited by: §1.
  23. T. Mikolov, K. Chen, G. Corrado and J. Dean (2013) Efficient estimation of word representations in vector space. arXiv preprint arXiv:1301.3781. Cited by: §2.
  24. S. A. Nene, S. K. Nayar and H. Murase (1996) Columbia object image library (coil-20). Cited by: §4.1.2.
  25. S. Pan, R. Hu, G. Long, J. Jiang, L. Yao and C. Zhang (2018) Adversarially regularized graph autoencoder for graph embedding. arXiv preprint arXiv:1802.04407. Cited by: §2.
  26. B. Perozzi, R. Al-Rfou and S. Skiena (2014) Deepwalk: online learning of social representations. In Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 701–710. Cited by: §4.1.
  27. K. Riesen and H. Bunke (2008) IAM graph database repository for graph based pattern recognition and machine learning. In Joint IAPR International Workshops on Statistical Techniques in Pattern Recognition (SPR) and Structural and Syntactic Pattern Recognition (SSPR), pp. 287–297. Cited by: §4.1.2.
  28. A. Sanchez-Gonzalez, N. Heess, J. T. Springenberg, J. Merel, M. Riedmiller, R. Hadsell and P. Battaglia (2018) Graph networks as learnable physics engines for inference and control. arXiv preprint arXiv:1806.01242. Cited by: §1.
  29. F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner and G. Monfardini (2008) The graph neural network model. IEEE Transactions on Neural Networks 20 (1), pp. 61–80. Cited by: §1, §2.
  30. M. Schlichtkrull, T. N. Kipf, P. Bloem, R. Van Den Berg, I. Titov and M. Welling (2018) Modeling relational data with graph convolutional networks. In European Semantic Web Conference, pp. 593–607. Cited by: §1, §2.
  31. I. Schomburg, A. Chang, C. Ebeling, M. Gremse, C. Heldt, G. Huhn and D. Schomburg (2004) BRENDA, the enzyme database: updates and major new developments. Nucleic acids research 32 (suppl_1), pp. D431–D433. Cited by: §4.1.2.
  32. A. Sperduti and A. Starita (1997) Supervised neural networks for the classification of structures. IEEE Transactions on Neural Networks 8 (3), pp. 714–735. Cited by: §2.
  33. S. Van Steenkiste, M. Chang, K. Greff and J. Schmidhuber (2018) Relational neural expectation maximization: unsupervised discovery of objects and their interactions. arXiv preprint arXiv:1802.10353. Cited by: §1, §2.
  34. P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Lio and Y. Bengio (2017) Graph attention networks. arXiv preprint arXiv:1710.10903. Cited by: §2, §3.3.
  35. P. Vincent, H. Larochelle, Y. Bengio and P. Manzagol (2008) Extracting and composing robust features with denoising autoencoders. In Proceedings of the 25th international conference on Machine learning, pp. 1096–1103. Cited by: §4.2.1.
  36. Z. Wu, S. Pan, F. Chen, G. Long, C. Zhang and P. S. Yu (2019) A comprehensive survey on graph neural networks. arXiv preprint arXiv:1901.00596. Cited by: §1.
  37. K. Xu, C. Li, Y. Tian, T. Sonobe, K. Kawarabayashi and S. Jegelka (2018) Representation learning on graphs with jumping knowledge networks. arXiv preprint arXiv:1806.03536. Cited by: §1, §2.
  38. H. Zhang, I. Goodfellow, D. Metaxas and A. Odena (2018) Self-attention generative adversarial networks. arXiv preprint arXiv:1805.08318. Cited by: §2, §3.3, §3.3, §3.3.