Table of Contents
- What are Graph Neural Networks (GNNs)
- Current Challenges with Explainability for GNNs
- First Attempt: Visualizing Node Activations
- Reusing Approaches from Convolutional Neural Networks
- Model-Agnostic Approach: GNNExplainer
- About Me
- References
Foreword
This is a slightly more advanced tutorial that assumes a basic knowledge of Graph Neural Networks and a little bit of computational chemistry. If you would like to prepare for the article, I’ve listed some useful reading below:
- Introduction to Graph Convolutional Networks
- More Advanced Tutorial on Graph Convolutional Networks
- Introduction to Cheminformatics
- Molecular Fingerprints
What are Graph Neural Networks
What is the main difference between Convolutional Neural Networks (CNNs) and Graph Neural Networks (GNNs)?
Simply speaking, it’s the input data.
As you might remember, the required input for CNNs is a fixed-size vector or matrix. Certain types of data, however, are naturally represented by graphs: molecules, citation networks, or social media connection networks can all be represented as graph data. In the past, when GNNs were not popular, graph data was often transformed so that it could be fed directly to CNNs. For example, molecular structures are still converted into fixed-size fingerprints where each bit indicates whether a certain molecular substructure is present or not [1]. This is a neat, if hacky, way to plug molecular data into CNNs, but doesn’t it lead to information loss?
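To make the fingerprint idea concrete, here is a minimal sketch using RDKit’s Morgan (circular) fingerprints. The SMILES string, radius, and bit length are arbitrary choices for illustration, not values taken from [1]:

```python
from rdkit import Chem
from rdkit.Chem import AllChem

# Arbitrary example molecule (ethanol) encoded as a SMILES string
mol = Chem.MolFromSmiles("CCO")

# 2048-bit circular fingerprint: each bit marks the presence of a substructure
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)

print(fp.GetNumOnBits(), "substructure bits are set out of", fp.GetNumBits())
```

Whatever the molecule’s size or topology, the output always has the same length, which is exactly the information-compressing step the question above is getting at.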
GNNs operate on graph data directly, which removes this preprocessing step and lets the model use all of the information contained in the data. There are many different GNN architectures, and the theory behind them gets complicated very quickly, but GNNs can broadly be divided into two categories: spatial and spectral approaches. The spatial approach is more intuitive, as it redefines the pooling and convolution operations known from CNNs for the graph domain. The spectral approach tackles the problem from a slightly different perspective, as it focuses on processing signals defined on a graph using the graph Fourier transform.
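To make the spatial intuition a bit more tangible, here is a minimal sketch of a single graph-convolution layer in the spirit of Kipf and Welling’s GCN: features of a node and its neighbours are averaged with degree normalisation, then passed through a linear map and a non-linearity. This is a toy illustration written for this article, not a complete model:

```python
import numpy as np

def gcn_layer(A, H, W):
    """One spatial graph convolution: normalised neighbourhood aggregation + linear map + ReLU."""
    A_hat = A + np.eye(A.shape[0])                      # add self-loops
    d_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(axis=1)))
    H_next = d_inv_sqrt @ A_hat @ d_inv_sqrt @ H @ W    # aggregate neighbour features
    return np.maximum(H_next, 0.0)                      # non-linearity

# Tiny example: a 3-node path graph with 4-dimensional node features
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H = np.random.randn(3, 4)
W = np.random.randn(4, 8)                               # learnable weights
print(gcn_layer(A, H, W).shape)                         # -> (3, 8)
```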
Have a look at a blog post written by Thomas Kipf which is an in-depth introduction to GCNs if you would like to know more. Another interesting paper to read is “MoleculeNet: A Benchmark for Molecular Machine Learning” [2] that provides a nice introduction to GNNs and describes popular GNN architectures.
Current Challenges with Explainability for GNNs
At the moment of writing, I can count the papers contributing to GNN explanation methods on one hand. Nonetheless, it is a very important topic that has recently been gaining popularity.
GNNs became popular much later than standard neural networks. Although there is a lot of interesting research going on in this field, it is still not quite mature. Libraries and tools for GNNs are still in an “experimental phase”, and what we really need now is more people using them to spot bugs and errors, so that we can shift towards production-ready models.
Well, one way to create production-ready models is to better understand the predictions they make. This can be done with different explainability methods. We have already seen a lot of interesting explainability methods applied to CNNs, such as gradient attribution, saliency maps, or class activation mapping. So why not reuse them for GNNs?
Actually, that’s what is happening at the moment. The explainability methods that were initially used with CNNs are being redesigned and applied to GNNs. While there is no need to reinvent the wheel, we still have to adjust these methods by redefining their mathematical operations to make them applicable to graph data. The main issue here is that research on this topic is fairly new, with the first publications appearing in 2019. However, it is getting more popular over time, and there are already a few interesting explainability approaches that can be used with GNN models.
In this article, we will look at these novel graph explanation methods and see how they can be applied to GNN models.
First Attempt: Visualizing Node Activations
Pioneering work on explanation techniques for GNNs was published in 2015 by Duvenaud et al. [3]. The major contribution of the paper was a novel neural graph fingerprint model, but the authors also created an explanation method for this architecture. The main idea behind the model was to build differentiable fingerprints computed directly from the graph data itself. To achieve that, the authors had to redefine the pooling and smoothing operations for graphs; these operations were then used to create a single layer.
The layers are stacked n times to generate a vector output, as in Figure 4 (left image). The layer depth also corresponds to the radius of the neighbourhood from which node features are gathered and pooled (a sum function in this case). This is because, at each layer, the pooling operation gathers information from neighbouring nodes, similarly to Figure 1; for deeper layers, the pooling operation reaches nodes in a more distant neighbourhood. In contrast to standard fingerprints, this approach is differentiable, which allows backpropagation to update its weights, just as in CNNs.
In addition to the GNN model, they created a simple method of visualizing node activations together with their neighbouring nodes. It is unfortunately not well explained in the paper, and it is necessary to look at their code to understand the underlying mechanism. Nevertheless, they run it on molecular data to predict solubility, and it highlights the parts of molecules that have high predictive power for the solubility prediction.
How does it work exactly? To compute node activations, we need the following computations. For each molecule, we forward-pass the data through each layer, just as we would in a typical CNN for an image. Then, we extract the contribution of each layer to each fingerprint bit with a softmax() function. We can then associate a node (atom), together with its surrounding neighbours (how many depends on the layer depth), with the highest contribution to a specific fingerprint bit.
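Since the paper glosses over these steps, here is a rough, self-contained sketch of what the computation could look like. The array shapes, the tanh smoothing, and the helper names are my assumptions for illustration, not the authors’ actual implementation:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def fingerprint_with_activations(A, H, layer_weights, readout_weights, fp_len):
    """A: adjacency matrix, H: initial atom features.
    layer_weights / readout_weights: one hidden and one readout matrix per layer (hypothetical shapes).
    Returns the fingerprint and, per layer, each atom's softmax contribution to every bit."""
    fingerprint = np.zeros(fp_len)
    activations = []                                    # activations[layer][atom] -> contribution per bit
    for W, W_out in zip(layer_weights, readout_weights):
        H = np.tanh((A + np.eye(A.shape[0])) @ H @ W)   # sum-pool self + neighbours, then smooth
        contrib = softmax(H @ W_out)                    # per-atom contribution to each fingerprint bit
        fingerprint += contrib.sum(axis=0)              # atoms' contributions are added to the bits
        activations.append(contrib)
    return fingerprint, activations
```

To highlight a substructure for a chosen bit b, you would pick the (layer, atom) pair with the largest activations[layer][atom][b] and colour that atom together with its neighbours within “layer” hops.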
This approach is relatively straightforward, but not well documented. The initial work done in this paper was quite promising, and it was followed by more elaborate attempts at transferring CNN explanation methods into the graph domain.
If you want to learn more about it, have a look at their paper, code repository, or my detailed explanation of the method at the bottom of this Github issue.
Reusing Approaches from Convolutional Neural Networks
Sensitivity Analysis, Class Activation Mapping, and Excitation Backpropagation are examples of explanation techniques that have already been successfully applied to CNNs. Current work towards explainable GNNs attempts to convert these approaches into the graph domain. The majority of the work in this area has been done in [4] and [5].
Instead of focusing on the mathematical details of these methods, which have already been covered in those papers, I will give you an intuitive explanation of each and briefly discuss their outcomes.
Generalizing CNN explanation methods
To reuse CNN explanation methods, let’s consider the CNN input data, an image, as a lattice-shaped graph. Figure 6 illustrates this idea.
If we keep this graph generalization of an image in mind, we can say that CNN explanation methods don’t pay much attention to edges (connections between pixels) and rather focus on the nodes (pixel values). The problem is that graph data contains a lot of useful information in its edges [4]. What we are really looking for is a way of generalizing CNN explanation techniques to allow arbitrary connections between nodes [5], rather than a lattice-like order.
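As a quick illustration of this generalization, the sketch below builds the adjacency matrix of a 4-connected pixel lattice, with pixel intensities as node features. It is a toy construction written for this article, not code from [4] or [5]:

```python
import numpy as np

def image_to_lattice_graph(img):
    """Treat every pixel as a node and connect 4-neighbouring pixels with an edge."""
    h, w = img.shape
    n = h * w
    A = np.zeros((n, n))
    X = img.reshape(n, 1)                    # node features = pixel intensities
    idx = lambda r, c: r * w + c
    for r in range(h):
        for c in range(w):
            if r + 1 < h:                    # edge to the pixel below
                A[idx(r, c), idx(r + 1, c)] = A[idx(r + 1, c), idx(r, c)] = 1
            if c + 1 < w:                    # edge to the pixel on the right
                A[idx(r, c), idx(r, c + 1)] = A[idx(r, c + 1), idx(r, c)] = 1
    return A, X

A, X = image_to_lattice_graph(np.random.rand(5, 5))
print(A.shape, X.shape)                      # -> (25, 25) (25, 1)
```

In an image, every interior node has exactly four neighbours; in a general graph, the connectivity is arbitrary, which is precisely what the converted explanation methods have to handle.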
Which methods have been transformed into graph domain?
So far, the following explanation methods have been addressed in the literature:
- Sensitivity Analysis [4]
- Guided Backpropagation [4]
- Layer-wise Relevance Propagation [4]
- Gradient-based heatmaps [5]
- Class Activation Maps (CAM) [5]
- Gradient-weighted Class Activation Mapping (Grad-CAM) [5]
- Excitation Backpropagation [5]
Please note that the authors of [4] and [5] did not provide open-source implementations of these methods, so it is not possible to use them out of the box yet.
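That said, the simplest of these ideas, sensitivity analysis, is straightforward to sketch yourself for any differentiable GNN. In the snippet below, `model(X, A)` returning graph-level class logits is an assumed interface for illustration, not an API from [4] or [5]:

```python
import torch

def node_saliency(model, X, A, target_class):
    """Sensitivity analysis on a graph: gradient of the class score w.r.t. node features,
    reduced to one importance score per node."""
    X = X.detach().clone().requires_grad_(True)
    logits = model(X, A)                  # assumed to return a vector of class logits
    logits[target_class].backward()       # backpropagate the score of the class of interest
    return X.grad.abs().sum(dim=1)        # higher value = that node's features matter more
```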
Model-Agnostic Approach: GNNExplainer
The code repository for the paper can be found here.
Worry not, there is actually a GNN explanation tool that you can use!
GNNExplainer is a model-agnostic, open-sourced explanation method for GNNs! It is also quite versatile, as it can be applied to node classification, graph classification, and edge prediction. It is the very first attempt at a graph-specific explanation method, and it was published by researchers from Stanford [6].
How does it work?
The authors argue that reusing explanation methods previously applied to CNNs is a bad approach because those methods fail to incorporate relational information, which is the essence of graph data. Moreover, gradient-based methods don’t work particularly well with discrete inputs, which is often the case for GNNs (e.g. the adjacency matrix is a binary matrix).
To overcome these problems, they created a model-agnostic approach that finds the subgraph of the input data which influences the GNN’s predictions the most. More specifically, the subgraph is chosen to maximize the mutual information with the model’s predictions. The figure below shows an example of how GNNExplainer works on graph data describing sport activities.
One very important assumption the authors make concerns the GNN model’s formulation. The architecture of the model can be more or less arbitrary, but it needs to implement 3 key computations:
- Neural message computation between two adjacent nodes
- Message aggregation from a node’s neighbourhood
- Non-linear transformation of the aggregated message and the node’s representation
These required computations are somewhat limiting, but most modern GNN architectures are based on message passing anyway [6].
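To show how the three computations line up in code, here is a minimal PyTorch layer written for this article; it is a generic message-passing sketch, not the exact formulation used in [6]:

```python
import torch
import torch.nn as nn

class MessagePassingLayer(nn.Module):
    """A toy layer implementing the three computations listed above."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.msg = nn.Linear(2 * in_dim, out_dim)
        self.update = nn.Linear(in_dim + out_dim, out_dim)

    def forward(self, X, A):
        n = X.size(0)
        # 1) neural message computation between every pair of adjacent nodes
        pairs = torch.cat([X.unsqueeze(1).expand(n, n, -1),
                           X.unsqueeze(0).expand(n, n, -1)], dim=-1)
        messages = torch.relu(self.msg(pairs)) * A.unsqueeze(-1)   # keep only real edges
        # 2) aggregation of messages from each node's neighbourhood (sum)
        aggregated = messages.sum(dim=1)
        # 3) non-linear transformation of the aggregated message and the node's own representation
        return torch.relu(self.update(torch.cat([X, aggregated], dim=-1)))
```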
The mathematical definition of GNNExplainer and the description of its optimization framework are right there in the paper. I will, however, spare you (and myself) the gory details and instead show you some interesting results that can be obtained with GNNExplainer.
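If you do want a flavour of those details, the core idea can be sketched in a few lines: learn a soft mask over the existing edges so that the masked subgraph preserves the model’s prediction (a cross-entropy surrogate for maximizing mutual information), while a size penalty keeps the subgraph small. The `model(X, edge_mask)` interface and the hyperparameters below are assumptions for illustration, not the authors’ implementation:

```python
import torch

def explain_node(model, X, A, node_idx, target_class, steps=300, lr=0.01, size_reg=0.005):
    """Learn an edge mask that keeps the prediction for `node_idx` while staying sparse."""
    mask_params = torch.randn_like(A, requires_grad=True)
    optimizer = torch.optim.Adam([mask_params], lr=lr)
    for _ in range(steps):
        edge_mask = torch.sigmoid(mask_params) * A             # only existing edges can be kept
        log_probs = torch.log_softmax(model(X, edge_mask)[node_idx], dim=-1)
        loss = -log_probs[target_class] + size_reg * edge_mask.sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return (torch.sigmoid(mask_params) * A).detach()           # edge importance scores
```

The edges with the highest mask values form the explanatory subgraph; the actual method also learns a mask over node features and uses additional regularizers, which this sketch omits.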
GNNExplainer in practice
To compare results, they used 3 different explanation methods: GNNExplainer, a gradient-based method (GRAD), and a graph attention model (GAT). GRAD was mentioned in the previous section, but the GAT model needs a bit of explanation. It is another GNN architecture that learns attention weights for edges, which help us determine which edges in the graph are actually important for node classification [6]. The attention weights can be used as another explanation technique, but it only works for this specific model, and it does not explain node features, contrary to GNNExplainer.
Let’s first have a look at the performance of explanation methods applied to synthetic datasets. Figure 8 shows the experiment run on two different datasets.
The BA-Shapes dataset is based on a Barabasi-Albert (BA) graph, a type of random graph that we can freely resize by changing its parameters. To this base graph, we attach small house-structured graphs (motifs), which are illustrated as the ground truth in Figure 8. The house structure has 3 different labels for its nodes: top, middle, and bottom. These node labels simply indicate the position of a node in the house, so a single house motif has 1 top, 2 middle, and 2 bottom nodes. There is also one additional label which indicates that a node does not belong to a house-like structure. Overall, we have 4 labels and a BA graph that consists of 300 nodes, with 80 house-like structures attached to random nodes of the BA graph. We also add a bit of noise by adding 0.1N random edges.
BA-Community is a union of two BA-Shapes datasets, so in total we have 8 different labels (4 labels per BA-Shapes dataset) and twice as many nodes.
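To make the BA-Shapes construction more tangible, here is a rough NetworkX sketch of the procedure described above; the BA parameter m, the exact motif wiring, and the noise handling are my reading of the description, not the authors’ generation script:

```python
import random
import networkx as nx

def ba_shapes(n_base=300, n_houses=80, m=5, noise_frac=0.1, seed=0):
    """Base BA graph + house motifs; labels: 0 = base graph, 1 = top, 2 = middle, 3 = bottom."""
    random.seed(seed)
    G = nx.barabasi_albert_graph(n_base, m, seed=seed)
    labels = {v: 0 for v in G}
    for _ in range(n_houses):
        offset = G.number_of_nodes()
        top, ml, mr, bl, br = range(offset, offset + 5)
        G.add_edges_from([(top, ml), (top, mr), (ml, mr),   # roof triangle
                          (ml, bl), (mr, br), (bl, br)])    # walls and floor
        labels.update({top: 1, ml: 2, mr: 2, bl: 3, br: 3})
        G.add_edge(random.randrange(n_base), top)           # attach the motif to the base graph
    n = G.number_of_nodes()
    for _ in range(int(noise_frac * n)):                    # 0.1N random noise edges
        G.add_edge(random.randrange(n), random.randrange(n))
    return G, labels
```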
Let’s have a look at results.
The results seem quite promising. GNNExplainer appears to explain the results in the most accurate way, because the chosen subgraph matches the ground truth. The GRAD and GAT baselines fail to provide similarly accurate explanations.
The experiment was also run on real datasets, as seen in Figure 9. This time the task is to classify whole graphs instead of single nodes.
Mutag is a dataset of molecules classified according to their mutagenic effect on a certain type of bacterium. The dataset has a lot of different labels and contains 4,337 molecular graphs.
Reddit-Binary is a dataset that represents online discussion threads on Reddit. In each graph, users are represented as nodes and an edge indicates a response to another user’s comment. There are 2 possible graph labels, depending on the type of user interaction: either Question-Answer or Online-Discussion. Overall, it contains 2,000 graphs.
For the Mutag dataset, GNNExplainer correctly identifies chemical groups (e.g. NO2, NH2) that are known to be mutagenic. GNNExplainer also explains well the graphs classified as online discussions in the Reddit-Binary dataset; this type of interaction can usually be represented with a tree-like pattern (look at the ground truth).