GNN (Graph Neural Net) Explained— Intuition, Concepts, Application

mayank bhardwaj
5 min read
Humans view the universe in terms of objects, the relations between them, and the rules that govern how these objects interact. This simple scheme helps us react to situations we have never seen before. In that sense, we are "models" that generalize well! For example, since we know that balls tend to bounce when they hit a solid surface, we can easily "imagine" what happens when we drop a bag full of balls on the floor, even if we have never seen that happen in real life. Google DeepMind argues that since we have these inherent biases (in the way we represent complex systems as compositions of objects & their interactions), it may be beneficial to ensure that the AI systems we train also have these biases, so that they can mimic the way we think. Graph structures seem best poised to enable AI to think in this manner, since graphs are all about objects (sometimes called entities) & the relations between them.

Structure of a Graph

A graph has nodes & edges. Nodes represent entities & edges represent the relations that link these entities. Facebook represents social networks as graphs: people are the nodes & their relationships are the edges. Creating a small graph is simple. Take a piece of paper and draw a circle for each person in the network. Next, draw a line between two circles whenever a relationship exists between them. They say that Friends has been watched over 100 billion times, so most readers will probably relate to the graph below. We travel from node to node via edges. Can you tell me the path from Nora Bing, the romantic novelist in NY, to David, the scientist doing research in Minsk? (A small code sketch of this pen-and-paper exercise appears at the end of this section.)

An example of nodes and edges in a graph
https://medium.com/towards-artificial-intelligence/hnsw-small-world-yes-but-how-in-the-world-is-it-navigable-77701ed37e20

When dealing with larger datasets, we need to switch from pen and paper to data structures like sets to hold the data. We would also have to come up with hacks to ensure that graph construction & graph searches remain possible at billion-plus scale, but apart from these small inconveniences, it is pretty much the same thing. That is all the graph theory we need to learn. Euler invented graph theory to solve an interesting puzzle. The details are here.

Using Graph structures in AI

Graph Neural Nets (GNNs): Here, we combine graph structures with neural nets to solve problems. We start out by preparing the graph structure, i.e. deciding what the nodes & edges are. This may need domain knowledge. We then construct the graph. We use back-propagation to learn the weights of the model's neural layers with a suitable loss function, fitting to the training data. There are also unsupervised methods which can be used in the absence of labelled data.

Probabilistic Graphical Modeling (PGM): Here we combine graph theory with probabilistic AI modeling to do some really cool stuff.

Hybrids like Neural Graphical Models (NGM): We could even mix and match the above two concepts & create hybrids like NGMs. While this term was introduced only recently, the Restricted Boltzmann Machines of yore mixed PGMs & neural nets with good success. In an NGM, we represent the probability function over the domain using a neural net. The representations learnt are judged on a downstream task, and the "probabilistic view" of the model is often discarded, unlike in a PGM.

We have been doing probability theory in this series, so we have all the background needed (here is a refresher). We already learnt all there is to learn about graph structures above (thanks to Friends).
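To make the pen-and-paper exercise concrete, here is a minimal sketch in plain Python (no graph library needed). The characters and edges below are an illustrative, hypothetical slice of the Friends network rather than an authoritative map of the show, and the breadth-first search simply hops from node to node via edges until it reaches David.

```python
from collections import deque

# A tiny, hypothetical slice of the Friends social graph:
# each key is a node (a person), each value is the set of nodes it shares an edge with.
friends_graph = {
    "Nora Bing": {"Chandler"},
    "Chandler": {"Nora Bing", "Monica", "Joey"},
    "Monica": {"Chandler", "Phoebe", "Ross"},
    "Joey": {"Chandler"},
    "Ross": {"Monica"},
    "Phoebe": {"Monica", "David"},
    "David": {"Phoebe"},
}

def find_path(graph, start, goal):
    """Breadth-first search: travel from node to node via edges until we reach the goal."""
    queue = deque([[start]])
    visited = {start}
    while queue:
        path = queue.popleft()
        node = path[-1]
        if node == goal:
            return path
        for neighbour in graph[node]:
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(path + [neighbour])
    return None

print(find_path(friends_graph, "Nora Bing", "David"))
# e.g. ['Nora Bing', 'Chandler', 'Monica', 'Phoebe', 'David']
```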
So let us take a look at GNNs in this article & jump to PGMs in the next! We will not discuss hybrids, since once we understand the core concepts involved, we can easily understand how they are mixed & matched. The focus, as always, is on intuition, plain-English theory & a holistic view. There are many articles deep-diving into specific topics, but it is rare to get a bird's eye view. We also devote substantial real estate to practical applications of GNNs.

Graph Neural Nets Explained

Traditional neural nets operate on inputs which are embeddings, i.e. numerical representations. In GNNs, the input is the graph structure itself. A GNN model therefore has the additional task of converting these graph structures into learnable embeddings. How do we do this? Well, we saw how nodes in a graph structure represent real-world entities. We can represent each node in a GNN by an embedding. The end goal of GNN training is to learn a representation of the input graph, and this can be done in two steps. The first step is to learn the embeddings node-wise, based on the features, i.e. the attributes of the entity involved. We then do some form of aggregation operation across connected nodes & use this information to update the embedding of the node in question. This way, we capture the key features of each node (via the node embedding) and the topology of the graph itself (in terms of how these nodes connect) via the aggregation op. Once we have these representations, we can use GNNs for predictions. In the case of social media networks, predictions may revolve around expected user behavior based on a user's profile & social circle.

Basically, a GNN takes a graph as input. Each node in the graph has a set of attributes which generate the original embedding for that node. During training, we multiply these node embeddings with the (randomly initialized) neural layer's weights & apply an activation function to get the updated embedding. We further update this embedding by looking at the connected nodes & doing some sort of aggregation. We repeat this many times. Before looking at the specific order in which we do things, let us first understand an important general concept related to graphs: message passing.

Message Passing between a graph's nodes

Training a model using back-propagation to adjust its weights to get the best embedding for each node is relatable to anyone who has ever trained any form of neural network. The novelty (or peculiarity, if you want to call it that) of a GNN is how we handle the topology of the graph. In other words, if a particular node X is connected to 4 neighbouring nodes, we want the information from those 4 connected nodes to be reflected in the embedding of node X. It shouldn't end there. Those 4 nodes may be connected to several other nodes, and information from those nodes should flow to them & be reflected in their embeddings. Moreover, now that the 4 connected nodes of node X have updated embeddings, we need to consider updating the embedding of node X to reflect this new information. This information flow is called message passing, and it happens via aggregation operations like sum, mean etc. For each node, the embeddings from all the connected nodes are summed (assuming a sum operation) & this sum is made to reflect in the node's own embedding. If I am following Gavaskar, Hadlee, Richards & Botham, then my embedding should reflect the fact that I am an 80's cricket fan stuck in a time warp.
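To make this concrete, here is a toy numeric sketch of one such update in plain NumPy, not any GNN library's API: each node's raw attributes are multiplied by a randomly initialized weight matrix and passed through an activation to get an embedding, and the embeddings of a node's neighbours are then summed into it. The graph, the feature values and the dimensions are made-up assumptions, purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical graph: node -> neighbours (illustrative only).
neighbours = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}

# Raw node attributes (say, age and number of friends), one row per node -- made-up numbers.
features = np.array([[25., 3.], [31., 5.], [22., 2.], [40., 8.]])

# Randomly initialised weights of one neural layer, mapping 2 features -> 4-dim embeddings.
W = rng.normal(size=(2, 4))

# Step 1: per-node embedding = activation(features @ W), using ReLU as the activation.
embeddings = np.maximum(features @ W, 0.0)

# Step 2: aggregate the neighbours' embeddings (here, a simple sum) for each node.
aggregated = np.stack([embeddings[neighbours[n]].sum(axis=0) for n in neighbours])

# Step 3: fold the aggregated message back into the node's own embedding.
updated = embeddings + aggregated
print(updated.shape)   # (4, 4): one refreshed 4-dim embedding per node
```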
Basically, we update the source node's embedding to (also) reflect the embeddings of the neighboring nodes using an aggregation op. We do this for all the nodes in the graph. We then recalculate each node's embedding again (since the neighbours' embeddings will have been updated). We repeat this a few times till there is stability. This is the core of a GNN: exchange information between neighbours via message passing, until equilibrium is reached. There are many architectures that implement this core concept.

An illustration of Message passing in GNNs. Source: Fig 4.1 from https://cs.mcgill.ca/~wlh/comp766/files/chapter4_draft_mar29.pdf

While message passing is a generic term, there are specific algorithms like Belief Propagation which do this efficiently. A good part of the literature on graphical models revolves around ways to speed up this process during inference. The core idea is that one can reuse these message computations. This simple idea unlocks many types of algorithms that focus on efficiently organizing the calculations involved, given that we can reuse messages & that each Variable Elimination step "collapses" part of the graph. But if we use GNNs, Variable Elimination etc. is not needed. Why?

Message Passing in GNNs

GNNs use a separate MultiLayer Perceptron (MLP is just a fancy acronym for a neural net with at least one hidden layer) for each component of a graph. For example, for each node, we apply its corresponding MLP and get back a learned embedding. Message passing in a GNN works something like this. For each node:

1. Gather all the neighboring node embeddings (messages).
2. Aggregate them, maybe with a SUM or an AVG.
3. Update: pass the aggregated embedding through the node's MLP to get an updated embedding (which includes info from all the messages passed).

We simply stack layers that do Steps 1–3, one above the other, & keep passing the output to the higher layer. Each such layer is equivalent to one traditional message-passing iteration. We can add as many layers to the GNN as the number of iterations we want. Basically, these layers & the math inside them do the job of message passing. With each layer, our embeddings strengthen. Modern GNNs can be attention-based, but the concepts are similar. So Variable Elimination, Belief Propagation etc. are not used in GNNs.

There are mild variations in message passing depending on the architecture we look at. A simple technique called a self-loop involves drawing an imaginary edge to make a node connect to itself, effectively including the node's own features in the Aggregate step. Many GNN architectures concatenate the node's embedding with the aggregated neighborhood embedding & feed this concatenated vector to the Update layer. The general idea of message passing is what matters here; specific implementations can be studied at a later stage. This is a good link. Message passing can be applied to either nodes or edges (e.g. let an edge's embedding be the mean of the embeddings of the 2 nodes it connects).
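Here is a minimal sketch of one such layer in plain NumPy, following the concatenation variant just described: for each node we gather the neighbours' embeddings, aggregate them with a mean, concatenate the result with the node's own embedding and push it through a small MLP as the Update step; stacking the layer repeats the message-passing iteration. For brevity a single MLP is shared across all nodes, and the graph, sizes and weights are illustrative assumptions rather than a specific published architecture.

```python
import numpy as np

rng = np.random.default_rng(1)
relu = lambda x: np.maximum(x, 0.0)

def gnn_layer(embeddings, neighbours, W1, W2):
    """One message-passing layer: gather -> aggregate (mean) -> update via a small MLP."""
    updated = []
    for node, nbrs in neighbours.items():
        gathered = embeddings[nbrs]                             # Step 1: gather neighbour messages
        aggregated = gathered.mean(axis=0)                      # Step 2: aggregate (here, an average)
        both = np.concatenate([embeddings[node], aggregated])   # self-embedding + neighbourhood
        hidden = relu(both @ W1)                                # Step 3: update through the MLP
        updated.append(relu(hidden @ W2))
    return np.stack(updated)

# Hypothetical 4-node graph and random 4-dim starting embeddings.
neighbours = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
embeddings = rng.normal(size=(4, 4))

# MLP weights: input is the 8-dim concatenation, a hidden layer of 8, output back to 4 dims.
W1, W2 = rng.normal(size=(8, 8)), rng.normal(size=(8, 4))

# Stacking the layer repeats the message-passing iteration; after two layers,
# each node has seen information from its 2-hop neighbourhood.
for _ in range(2):
    embeddings = gnn_layer(embeddings, neighbours, W1, W2)
print(embeddings.shape)   # (4, 4)
```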
But let us take a step back & view holistically what is happening. Things are somewhat similar to a CNN's convolution & pooling operations. There, the convolution extracts features from pixels & the pooling averages over connected regions to reduce the overall image landscape to be analyzed. Repeat this multiple times & the feature set becomes richer & richer, and the image is no longer analyzed as a bunch of individual pixels but as a whole. We end up with a single feature-rich embedding for the entire image which can be used for the prediction task.

Well, message passing in GNNs does something similar. It is a set of operations that aggregate and process the information from an element's neighbors in order to update the element's value, which becomes richer in meaning after each layer until, finally, it reflects information from the entire graph structure. In fact, if we do a good job of aggregating local information, our models may even work reasonably well in situations where no node-specific attributes are available. Wow! How can node embeddings be generated without some set of attributes? Well, in such a situation, one could leverage the graph structure (topology) alone to learn representations & generate node embeddings based on proximity. Unsupervised learning techniques like DeepWalk use this concept. But we are getting ahead of ourselves. We will visit DeepWalk later; for now, let us look at some basic architectures.
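One small addition to round off the CNN analogy, and an assumption on my part rather than something spelled out above: when the prediction is about the graph as a whole, the per-node embeddings produced by the stacked layers are typically collapsed by a readout (pooling) step into a single graph-level embedding, for example a simple mean.

```python
import numpy as np

def readout(node_embeddings):
    """Mean-pool the per-node embeddings into one graph-level embedding."""
    return node_embeddings.mean(axis=0)

# Stand-in for the (4, 4) node embeddings produced by the stacked layers above
# (random here so the snippet runs on its own).
node_embeddings = np.random.default_rng(2).normal(size=(4, 4))

graph_embedding = readout(node_embeddings)
print(graph_embedding.shape)   # (4,): a single vector summarising the whole graph
```

For node-level predictions (say, a user's expected behavior in a social network), the per-node embeddings are used directly and no readout is needed.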

About the Author

mayank bhardwaj

Senior Engineer at Codlyn Softwares. Passionate about building scalable systems and sharing knowledge with the community.