直观地、数学地、可实现地猎杀图神经网络(GNN)的纸老虎

图神经网络

图神经网络并不像它的名字听起来那么神秘。它的核心,是一种神经网络:反复让每个节点查看自己的邻居,从邻居那里收集消息,并更新自己的表示。

一个通用的消息传递层可以写作:

\[x_i^{(t)} = operatorname{Update}left( x_i^{(t-1)}, square_{j in mathcal{N}(i)} operatorname{Message}left(x_i^{(t-1)}, x_j^{(t-1)}, e_{ij}right) right)\]

这里:

  • $x_i^{(t)}$ 是节点 $i$ 在第 $t$ 层或第 $t$ 个时间步之后的表示。
  • $mathcal{N}(i)$ 是与节点 $i$ 相邻的节点集合。
  • $e_{ij}$ 是节点 $i$ 与节点 $j$ 之间可选的边特征。
  • $square$ 是一个对排列不敏感的聚合算子,常见的有 summeanmax

用紧凑的矩阵形式表示,许多 GNN 层都可以看作:

\[H^{(t+1)} = F(H^{(t)}, X)\]

其中 $H^{(t)}$ 是第 $t$ 层的隐藏节点表示,$X$ 是原始节点特征矩阵,$F$ 是把特征变换与基于图的聚合混合起来的层函数。

关键思想很简单:一个节点的新嵌入应该依赖于它自己先前的嵌入,也依赖于从邻居传来的信息。堆叠多层可以让信息传播得更远。一层捕获一跳邻域信息;两层捕获最多两跳之外节点的信息;以此类推。

一个最小实现模式如下:

import torch


def gnn_layer(x, edge_index, weight):
    """One simple sum-aggregation GNN layer.

    x:          [num_nodes, in_features]
    edge_index: [2, num_edges], where edge_index[0] -> edge_index[1]
    weight:     [in_features, out_features]
    """
    src, dst = edge_index
    messages = x[src]

    aggregated = torch.zeros_like(x)
    aggregated.index_add_(0, dst, messages)

    h = x + aggregated
    return torch.relu(h @ weight)

这不是一个完整的生产级层,但它展示了其中的运动部件:选择源节点特征,在目标节点上聚合它们,把它们与当前节点状态合并,然后应用一个可学习的变换和非线性函数。

一个单独的大图

第一种常见的 GNN 场景是一个大型图。例子包括引用网络、社交网络、商品图、知识图谱和网页图。

在这种场景中,图结构是固定的,任务通常是节点级或边级预测:

  • 节点分类:预测每个节点的标签,例如引用图中的论文主题。
  • 链接预测:预测两个节点之间是否应该存在一条边。
  • 节点回归:为每个节点预测一个连续值。

对于节点分类,我们通常有:

  • 一个特征矩阵 $X in mathbb{R}^{N times d}$,其中 $N$ 是节点数,$d$ 是输入特征数。
  • 一个邻接矩阵 $A in {0,1}^{N times N}$ 或一份边列表。
  • 部分节点的标签。

一个典型的图卷积层可以写作:

\[H^{(t+1)} = sigmaleft(tilde{D}^{-1/2}tilde{A}tilde{D}^{-1/2}H^{(t)}W^{(t)}right)\]

其中 $tilde{A} = A + I$ 加入了自环,$tilde{D}$ 是 $tilde{A}$ 的度矩阵,$W^{(t)}$ 是可学习的权重矩阵,$sigma$ 是非线性激活函数。

重点在于,每个节点的表示会与其邻居的表示一起被平滑和混合。当相连节点确实存在有意义的关系时,这很强大;但如果过度平滑让不同节点变得无法区分,它也会失败。

一个用于单个大图的简单训练循环通常如下:

for epoch in range(num_epochs):
    logits = model(x, edge_index)
    loss = loss_fn(logits[train_mask], y[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

这些掩码划分的是节点,而不是图。模型在训练和评估期间看到的是同一个图结构,但标签只用于训练节点。

许多非同构图

第二种场景是由许多不同图组成的数据集。例子包括分子、蛋白质结构、程序图和抽象语法树。

这里每个样本都是一张图,不同样本可能有不同的节点数量、不同的边结构和不同的拓扑。任务通常是图级预测:

  • 分子性质预测。
  • 图分类。
  • 程序行为分类。
  • 蛋白质功能预测。

对于一张图 $G$,消息传递会产生节点嵌入:

\[H_G = operatorname{GNN}(X_G, E_G)\]

但图级预测需要为整张图得到一个向量。因此我们加入一个读出函数:

\[h_G = operatorname{Readout}left({h_i : i in G}right)\]

读出函数也必须对排列不敏感。常见选择包括:

  • 求和池化。
  • 平均池化。
  • 最大池化。
  • 注意力池化。

然后预测就是普通的监督学习:

\[hat{y}_G = operatorname{MLP}(h_G)\]

实现模式如下:

def graph_prediction_model(x, edge_index, batch):
    node_embeddings = gnn(x, edge_index)
    graph_embeddings = global_mean_pool(node_embeddings, batch)
    return mlp(graph_embeddings)

batch 向量记录每个节点属于哪一张图。这让许多小图可以在训练时被打包成一张不连通的大图。消息传递仍然只沿真实边发生,而池化会把节点重新收集回它们原本所属的图样本中。

这个场景最能体现 GNN 与普通神经网络的差异。一个批次不是一块由等大小样本组成的矩形。它是一组不规则结构,而模型必须学习不依赖任意节点顺序的函数。

许多同构图

第三种场景是许多拓扑相同、但节点或边上的信号不同的图。这在传感器网络、交通网络、网格、基于骨架的动作识别和物理仿真中很常见。

例如,一个交通道路网络每天可能有同一张道路图,但节点特征会随时间变化:速度、密度、拥堵、事故和天气。图是固定的;图上的信号在变化。

在这种情况下,每个样本可以表示为:

\[(X^{(k)}, A)\]

其中 $A$ 在样本之间共享,$X^{(k)}$ 是样本 $k$ 的特征矩阵。

这个场景接近图像上的卷积。在图像中,网格结构是固定的,像素值会变化。在同构图数据集中,图结构是固定的,节点值会变化。

模型可以为每个样本复用同一个邻接矩阵:

def forward(batch_x, edge_index):
    # batch_x: [batch_size, num_nodes, num_features]
    outputs = []
    for x in batch_x:
        h = gnn(x, edge_index)
        outputs.append(readout_or_node_head(h))
    return torch.stack(outputs)

为了效率,真实实现会进行批处理,而不是在 Python 中循环,但概念是一样的:拓扑共享,特征变化。

这个区别很重要,因为它会改变数据划分和模型评估方式。在一个大图中,我们常常划分节点或边。在许多非同构图中,我们划分图样本。在许多同构图中,我们按时间、条件或主体划分观测,同时图结构可能保持不变。

要记住什么

GNN 只是一个尊重图结构的神经网络层。每个节点从相邻节点接收消息,以与顺序无关的方式聚合这些消息,并更新自己的表示。同一个思想可以支持几类不同的学习问题:

  • 一个大图:预测单张图内部的节点或边。
  • 许多非同构图:预测整张不规则图样本的标签或数值。
  • 许多同构图:从共享图拓扑上不断变化的信号中学习。

一旦这点清楚了,纸老虎就消失了。困难之处不再是名字或符号。真正困难的是普通的建模问题:什么应该算作节点,什么应该算作边,哪些特征可靠,多少跳的信息重要,以及评估划分是否匹配真实问题。

Leave a Reply