图神经网络
图神经网络并不像它的名字听起来那么神秘。它的核心,是一种神经网络:反复让每个节点查看自己的邻居,从邻居那里收集消息,并更新自己的表示。
一个通用的消息传递层可以写作:
这里:
- $x_i^{(t)}$ 是节点 $i$ 在第 $t$ 层或第 $t$ 个时间步之后的表示。
- $mathcal{N}(i)$ 是与节点 $i$ 相邻的节点集合。
- $e_{ij}$ 是节点 $i$ 与节点 $j$ 之间可选的边特征。
- $square$ 是一个对排列不敏感的聚合算子,常见的有
sum、mean或max。
用紧凑的矩阵形式表示,许多 GNN 层都可以看作:
其中 $H^{(t)}$ 是第 $t$ 层的隐藏节点表示,$X$ 是原始节点特征矩阵,$F$ 是把特征变换与基于图的聚合混合起来的层函数。
关键思想很简单:一个节点的新嵌入应该依赖于它自己先前的嵌入,也依赖于从邻居传来的信息。堆叠多层可以让信息传播得更远。一层捕获一跳邻域信息;两层捕获最多两跳之外节点的信息;以此类推。
一个最小实现模式如下:
import torch
def gnn_layer(x, edge_index, weight):
"""One simple sum-aggregation GNN layer.
x: [num_nodes, in_features]
edge_index: [2, num_edges], where edge_index[0] -> edge_index[1]
weight: [in_features, out_features]
"""
src, dst = edge_index
messages = x[src]
aggregated = torch.zeros_like(x)
aggregated.index_add_(0, dst, messages)
h = x + aggregated
return torch.relu(h @ weight)
这不是一个完整的生产级层,但它展示了其中的运动部件:选择源节点特征,在目标节点上聚合它们,把它们与当前节点状态合并,然后应用一个可学习的变换和非线性函数。
一个单独的大图
第一种常见的 GNN 场景是一个大型图。例子包括引用网络、社交网络、商品图、知识图谱和网页图。
在这种场景中,图结构是固定的,任务通常是节点级或边级预测:
- 节点分类:预测每个节点的标签,例如引用图中的论文主题。
- 链接预测:预测两个节点之间是否应该存在一条边。
- 节点回归:为每个节点预测一个连续值。
对于节点分类,我们通常有:
- 一个特征矩阵 $X in mathbb{R}^{N times d}$,其中 $N$ 是节点数,$d$ 是输入特征数。
- 一个邻接矩阵 $A in {0,1}^{N times N}$ 或一份边列表。
- 部分节点的标签。
一个典型的图卷积层可以写作:
其中 $tilde{A} = A + I$ 加入了自环,$tilde{D}$ 是 $tilde{A}$ 的度矩阵,$W^{(t)}$ 是可学习的权重矩阵,$sigma$ 是非线性激活函数。
重点在于,每个节点的表示会与其邻居的表示一起被平滑和混合。当相连节点确实存在有意义的关系时,这很强大;但如果过度平滑让不同节点变得无法区分,它也会失败。
一个用于单个大图的简单训练循环通常如下:
for epoch in range(num_epochs):
logits = model(x, edge_index)
loss = loss_fn(logits[train_mask], y[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
这些掩码划分的是节点,而不是图。模型在训练和评估期间看到的是同一个图结构,但标签只用于训练节点。
许多非同构图
第二种场景是由许多不同图组成的数据集。例子包括分子、蛋白质结构、程序图和抽象语法树。
这里每个样本都是一张图,不同样本可能有不同的节点数量、不同的边结构和不同的拓扑。任务通常是图级预测:
- 分子性质预测。
- 图分类。
- 程序行为分类。
- 蛋白质功能预测。
对于一张图 $G$,消息传递会产生节点嵌入:
但图级预测需要为整张图得到一个向量。因此我们加入一个读出函数:
读出函数也必须对排列不敏感。常见选择包括:
- 求和池化。
- 平均池化。
- 最大池化。
- 注意力池化。
然后预测就是普通的监督学习:
实现模式如下:
def graph_prediction_model(x, edge_index, batch):
node_embeddings = gnn(x, edge_index)
graph_embeddings = global_mean_pool(node_embeddings, batch)
return mlp(graph_embeddings)
batch 向量记录每个节点属于哪一张图。这让许多小图可以在训练时被打包成一张不连通的大图。消息传递仍然只沿真实边发生,而池化会把节点重新收集回它们原本所属的图样本中。
这个场景最能体现 GNN 与普通神经网络的差异。一个批次不是一块由等大小样本组成的矩形。它是一组不规则结构,而模型必须学习不依赖任意节点顺序的函数。
许多同构图
第三种场景是许多拓扑相同、但节点或边上的信号不同的图。这在传感器网络、交通网络、网格、基于骨架的动作识别和物理仿真中很常见。
例如,一个交通道路网络每天可能有同一张道路图,但节点特征会随时间变化:速度、密度、拥堵、事故和天气。图是固定的;图上的信号在变化。
在这种情况下,每个样本可以表示为:
其中 $A$ 在样本之间共享,$X^{(k)}$ 是样本 $k$ 的特征矩阵。
这个场景接近图像上的卷积。在图像中,网格结构是固定的,像素值会变化。在同构图数据集中,图结构是固定的,节点值会变化。
模型可以为每个样本复用同一个邻接矩阵:
def forward(batch_x, edge_index):
# batch_x: [batch_size, num_nodes, num_features]
outputs = []
for x in batch_x:
h = gnn(x, edge_index)
outputs.append(readout_or_node_head(h))
return torch.stack(outputs)
为了效率,真实实现会进行批处理,而不是在 Python 中循环,但概念是一样的:拓扑共享,特征变化。
这个区别很重要,因为它会改变数据划分和模型评估方式。在一个大图中,我们常常划分节点或边。在许多非同构图中,我们划分图样本。在许多同构图中,我们按时间、条件或主体划分观测,同时图结构可能保持不变。
要记住什么
GNN 只是一个尊重图结构的神经网络层。每个节点从相邻节点接收消息,以与顺序无关的方式聚合这些消息,并更新自己的表示。同一个思想可以支持几类不同的学习问题:
- 一个大图:预测单张图内部的节点或边。
- 许多非同构图:预测整张不规则图样本的标签或数值。
- 许多同构图:从共享图拓扑上不断变化的信号中学习。
一旦这点清楚了,纸老虎就消失了。困难之处不再是名字或符号。真正困难的是普通的建模问题:什么应该算作节点,什么应该算作边,哪些特征可靠,多少跳的信息重要,以及评估划分是否匹配真实问题。
