Hunt the paper tiger of graph neural networks (GNNs): intuitively, mathematically, and in code

Graph neural networks

A graph neural network is not as mysterious as the name suggests. At its core, it is a neural network that repeatedly lets each node look at its neighbors, collect messages from them, and update its own representation.

A generic message-passing layer can be written as:

\[x_i^{(t)} = \operatorname{Update}\left( x_i^{(t-1)}, \square_{j \in \mathcal{N}(i)} \operatorname{Message}\left(x_i^{(t-1)}, x_j^{(t-1)}, e_{ij}\right) \right)\]

Here:

  • $x_i^{(t)}$ is the representation of node $i$ after layer or time step $t$.
  • $\mathcal{N}(i)$ is the set of nodes adjacent to node $i$.
  • $e_{ij}$ is an optional edge feature between node $i$ and node $j$.
  • $\square$ is a permutation-invariant aggregation operator, commonly sum, mean, or max.
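
Permutation invariance is easy to verify concretely. The following check (an illustration, not part of any library) aggregates the same three neighbor messages in two different orders; sum, mean, and max all give identical results:

import torch

messages = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
shuffled = messages[[2, 0, 1]]  # same neighbors, visited in a different order

assert torch.equal(messages.sum(dim=0), shuffled.sum(dim=0))
assert torch.equal(messages.mean(dim=0), shuffled.mean(dim=0))
assert torch.equal(messages.max(dim=0).values, shuffled.max(dim=0).values)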

In a compact matrix form, many GNN layers can be viewed as:

\[H^{(t+1)} = F(H^{(t)}, X)\]

where $H^{(t)}$ is the hidden node representation at layer $t$, $X$ is the original node feature matrix, and $F$ is the layer function that mixes feature transformation with graph-based aggregation.

The key idea is simple: a node’s new embedding should depend on its own previous embedding and on the information arriving from its neighbors. Stacking layers lets information travel farther. One layer captures one-hop neighborhood information; two layers capture information from nodes up to two hops away; and so on.

A minimal implementation pattern looks like this:

import torch


def gnn_layer(x, edge_index, weight):
    """One simple sum-aggregation GNN layer.

    x:          [num_nodes, in_features]
    edge_index: [2, num_edges], where edge_index[0] -> edge_index[1]
    weight:     [in_features, out_features]
    """
    src, dst = edge_index
    messages = x[src]  # each edge carries its source node's features

    # Sum all incoming messages at their destination nodes.
    aggregated = torch.zeros_like(x)
    aggregated.index_add_(0, dst, messages)

    # Combine each node's own state with the aggregated messages,
    # then apply the learnable transformation and a nonlinearity.
    h = x + aggregated
    return torch.relu(h @ weight)

This is not a full production layer, but it shows the moving parts: select source node features, aggregate them at destination nodes, combine them with the current node state, then apply a learnable transformation and nonlinearity.
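
As a quick sanity check, the layer can be run on a tiny hand-built graph (the sizes below are arbitrary, chosen only for illustration). Stacking it twice also demonstrates the hop-counting argument from above: after two layers, node 2 reflects information that originated at node 0, two hops away.

num_nodes, in_features, out_features = 3, 2, 4

x = torch.randn(num_nodes, in_features)
edge_index = torch.tensor([[0, 1],   # source nodes
                           [1, 2]])  # destination nodes: edges 0 -> 1 and 1 -> 2
weight1 = torch.randn(in_features, out_features)
weight2 = torch.randn(out_features, out_features)

h1 = gnn_layer(x, edge_index, weight1)   # each node sees its one-hop neighbors
h2 = gnn_layer(h1, edge_index, weight2)  # node 2 now also reflects node 0
print(h2.shape)  # torch.Size([3, 4])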

One big graph

The first common GNN setting is one large graph. Examples include citation networks, social networks, product graphs, knowledge graphs, and web graphs.

In this setting, the graph structure is fixed, and the task is usually node-level or edge-level prediction:

  • Node classification: predict the label of each node, such as paper topic in a citation graph.
  • Link prediction: predict whether an edge should exist between two nodes.
  • Node regression: predict a continuous value for each node.

For node classification, we usually have:

  • A feature matrix $X \in \mathbb{R}^{N \times d}$, where $N$ is the number of nodes and $d$ is the number of input features.
  • An adjacency matrix $A \in \{0,1\}^{N \times N}$ or an edge list.
  • Labels for some nodes.

A typical graph convolution layer can be written as:

\[H^{(t+1)} = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{(t)}W^{(t)}\right)\]

where $\tilde{A} = A + I$ adds self-loops, $\tilde{D}$ is the degree matrix of $\tilde{A}$, $W^{(t)}$ is a learnable weight matrix, and $\sigma$ is a nonlinear activation function.
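
For small graphs this formula can be implemented directly with dense matrices. The sketch below mirrors the equation term by term; it assumes a dense adjacency matrix and is written for readability, not efficiency (production libraries use sparse operations):

import torch

def gcn_layer(h, adj, weight):
    """One dense graph convolution: sigma(D~^{-1/2} A~ D~^{-1/2} H W)."""
    a_tilde = adj + torch.eye(adj.shape[0])     # A~ = A + I (self-loops)
    d_inv_sqrt = a_tilde.sum(dim=1).pow(-0.5)   # diagonal of D~^{-1/2}
    a_norm = d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]
    return torch.relu(a_norm @ h @ weight)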

The important point is that each node’s representation is smoothed and mixed with its neighbors’ representations. This is powerful when connected nodes are meaningfully related, but it can fail when too much smoothing makes different nodes indistinguishable.

A simple training loop for one big graph usually looks like this:

for epoch in range(num_epochs):
    logits = model(x, edge_index)
    loss = loss_fn(logits[train_mask], y[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The masks split nodes, not graphs. The model sees the same graph structure during training and evaluation, but labels are only used for the training nodes.
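
The masks themselves are just boolean vectors over nodes. A minimal random split might look like the sketch below (the 80/20 fraction is hypothetical; benchmark datasets usually ship with fixed splits that should be used instead):

num_nodes = x.shape[0]
perm = torch.randperm(num_nodes)
cutoff = int(0.8 * num_nodes)

train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[perm[:cutoff]] = True  # 80% of nodes provide the training labels
test_mask[perm[cutoff:]] = True   # the rest are held out for evaluation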

Many nonisomorphic graphs

The second setting is a dataset of many different graphs. Examples include molecules, protein structures, program graphs, and abstract syntax trees.

Here each sample is a graph, and different samples may differ in node count, edge structure, and overall topology. The task is often graph-level prediction:

  • Molecular property prediction.
  • Graph classification.
  • Program behavior classification.
  • Protein function prediction.

For one graph $G$, message passing produces node embeddings:

\[H_G = \operatorname{GNN}(X_G, E_G)\]

But graph-level prediction needs one vector for the whole graph. Therefore we add a readout function:

\[h_G = \operatorname{Readout}\left(\{h_i : i \in G\}\right)\]

The readout must also be permutation-invariant. Common choices are:

  • Sum pooling.
  • Mean pooling.
  • Max pooling.
  • Attention pooling.

Then prediction is ordinary supervised learning:

\[\hat{y}_G = \operatorname{MLP}(h_G)\]

The implementation pattern is:

def graph_prediction_model(x, edge_index, batch):
    node_embeddings = gnn(x, edge_index)
    graph_embeddings = global_mean_pool(node_embeddings, batch)
    return mlp(graph_embeddings)

The batch vector records which graph each node belongs to. This lets many small graphs be packed into one disconnected graph during training. Message passing still happens only along real edges, while pooling gathers nodes back into their original graph samples.
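
The name global_mean_pool matches the pooling routine in PyTorch Geometric, but a from-scratch version takes only a few lines. This sketch assumes batch is a long tensor of graph ids starting at 0:

import torch

def global_mean_pool(node_embeddings, batch):
    """Average node embeddings per graph, guided by the batch vector."""
    num_graphs = int(batch.max()) + 1
    sums = torch.zeros(num_graphs, node_embeddings.shape[1])
    sums.index_add_(0, batch, node_embeddings)  # sum node embeddings per graph
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts[:, None].float()       # divide by each graph's size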

This setting is where GNNs feel most different from ordinary neural networks. A batch is not a rectangle of equal-sized examples. It is a collection of irregular structures, and the model must learn functions that do not depend on arbitrary node ordering.

Many isomorphic graphs

The third setting is many graphs with the same topology but different signals on the nodes or edges. This is common in sensor networks, traffic networks, meshes, skeleton-based action recognition, and physical simulations.

For example, a traffic road network may have the same road graph every day, but the node features change over time: speed, density, congestion, incidents, and weather. The graph is fixed; the signal on the graph changes.

In this case, each sample can be represented as:

\[(X^{(k)}, A)\]

where $A$ is shared across samples and $X^{(k)}$ is the feature matrix for sample $k$.

This setting is close to convolution on images. In an image, the grid structure is fixed and pixel values change. In an isomorphic graph dataset, the graph structure is fixed and node values change.

A model can reuse the same adjacency matrix for every sample:

def forward(batch_x, edge_index):
    # batch_x: [batch_size, num_nodes, num_features]
    outputs = []
    for x in batch_x:
        h = gnn(x, edge_index)                   # same topology for every sample
        outputs.append(readout_or_node_head(h))  # graph-level pooling or per-node head
    return torch.stack(outputs)

For efficiency, a real implementation would batch this rather than loop in Python, but the concept is the same: topology is shared, features vary.
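
One way to vectorize it, assuming the shared topology has been precomputed as a dense normalized adjacency matrix a_norm (rather than an edge list), is a single batched matrix product; a_norm broadcasts over the batch dimension, so every sample reuses the same structure:

import torch

def forward_batched(batch_x, a_norm, weight):
    """batch_x: [batch_size, num_nodes, in_features]
    a_norm:  [num_nodes, num_nodes], shared normalized adjacency
    weight:  [in_features, out_features]
    """
    # [N, N] @ [B, N, F] @ [F, O] -> [B, N, O] in one shot.
    return torch.relu(a_norm @ batch_x @ weight)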

This distinction matters because it changes how to split data and evaluate the model. In one big graph, we often split nodes or edges. In many nonisomorphic graphs, we split graph samples. In many isomorphic graphs, we split observations over time, conditions, or subjects, while the graph structure may remain constant.
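
For the shared-topology case, a chronological split is often the right default. A sketch, assuming the samples in batch_x are ordered by time:

num_samples = batch_x.shape[0]
cutoff = int(0.8 * num_samples)  # illustrative 80/20 split

train_x = batch_x[:cutoff]  # earlier observations for training
test_x = batch_x[cutoff:]   # later observations, held out for evaluation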

What to remember

A GNN is just a neural network whose layers respect graph structure. Each node receives messages from adjacent nodes, aggregates them in an order-independent way, and updates its representation. The same idea supports several different learning problems:

  • One big graph: predict nodes or edges inside a single graph.
  • Many nonisomorphic graphs: predict labels or values for whole irregular graph samples.
  • Many isomorphic graphs: learn from changing signals on a shared graph topology.

Once this is clear, the paper tiger disappears. The hard parts are no longer the name or the notation. The hard parts are ordinary modeling questions: what should count as a node, what should count as an edge, what features are reliable, how many hops matter, and whether the evaluation split matches the real problem.
