グラフニューラルネットワーク
グラフニューラルネットワークは、その名前が示すほど神秘的なものではない。中核にあるのは、各ノードが近傍ノードを見て、そこからメッセージを集め、自分自身の表現を更新する、という処理を繰り返すニューラルネットワークである。
一般的なメッセージパッシング層は次のように書ける。
ここで、
- $x_i^{(t)}$ は、層または時刻ステップ $t$ の後のノード $i$ の表現である。
- $mathcal{N}(i)$ は、ノード $i$ に隣接するノードの集合である。
- $e_{ij}$ は、ノード $i$ とノード $j$ の間にある任意のエッジ特徴量である。
- $square$ は順序に不変な集約演算子で、一般には
sum、mean、maxが使われる。
コンパクトな行列表記では、多くの GNN 層は次のように見なせる。
ここで $H^{(t)}$ は層 $t$ における隠れノード表現、$X$ は元のノード特徴行列、$F$ は特徴変換とグラフに基づく集約を混ぜ合わせる層関数である。
要点は単純だ。ノードの新しい埋め込みは、そのノード自身の以前の埋め込みと、近傍から届く情報に依存すべきである。層を積み重ねると、情報はより遠くまで伝わる。1 層なら 1 ホップ近傍の情報を捉え、2 層なら最大 2 ホップ先のノードからの情報を捉える、という具合である。
最小限の実装パターンは次のようになる。
import torch
def gnn_layer(x, edge_index, weight):
"""One simple sum-aggregation GNN layer.
x: [num_nodes, in_features]
edge_index: [2, num_edges], where edge_index[0] -> edge_index[1]
weight: [in_features, out_features]
"""
src, dst = edge_index
messages = x[src]
aggregated = torch.zeros_like(x)
aggregated.index_add_(0, dst, messages)
h = x + aggregated
return torch.relu(h @ weight)
これは本番用の完全な層ではないが、動いている部品は示している。送信元ノードの特徴を選び、宛先ノードで集約し、現在のノード状態と組み合わせ、その後に学習可能な変換と非線形性を適用する。
1 つの大きなグラフ
最初によくある GNN の設定は、1 つの大きなグラフである。例としては、引用ネットワーク、ソーシャルネットワーク、商品グラフ、知識グラフ、Web グラフなどがある。
この設定では、グラフ構造は固定されており、タスクはたいていノードレベルまたはエッジレベルの予測である。
- ノード分類:引用グラフにおける論文トピックなど、各ノードのラベルを予測する。
- リンク予測:2 つのノードの間にエッジが存在すべきかを予測する。
- ノード回帰:各ノードについて連続値を予測する。
ノード分類では、通常は次のものがある。
- 特徴行列 $X in mathbb{R}^{N times d}$。ここで $N$ はノード数、$d$ は入力特徴量の数である。
- 隣接行列 $A in {0,1}^{N times N}$、またはエッジリスト。
- 一部のノードに対するラベル。
典型的なグラフ畳み込み層は次のように書ける。
ここで $tilde{A} = A + I$ は自己ループを追加したもの、$tilde{D}$ は $tilde{A}$ の次数行列、$W^{(t)}$ は学習可能な重み行列、$sigma$ は非線形活性化関数である。
重要なのは、各ノードの表現が、近傍ノードの表現と混ざり合いながら平滑化されるという点である。これは、接続されたノードに意味のある関係があるときには強力だが、平滑化が強すぎて異なるノードが見分けられなくなると失敗しうる。
1 つの大きなグラフに対する単純な訓練ループは、通常このようになる。
for epoch in range(num_epochs):
logits = model(x, edge_index)
loss = loss_fn(logits[train_mask], y[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
マスクが分割しているのはグラフではなくノードである。モデルは訓練時も評価時も同じグラフ構造を見るが、ラベルは訓練ノードに対してのみ使われる。
多数の非同型グラフ
2 つ目の設定は、互いに異なる多数のグラフからなるデータセットである。例としては、分子、タンパク質構造、プログラムグラフ、抽象構文木などがある。
ここでは各サンプルが 1 つのグラフであり、サンプルごとにノード数、エッジ構造、トポロジーが異なりうる。タスクはしばしばグラフレベルの予測である。
- 分子特性予測。
- グラフ分類。
- プログラム挙動分類。
- タンパク質機能予測。
1 つのグラフ $G$ について、メッセージパッシングはノード埋め込みを生成する。
しかしグラフレベルの予測には、グラフ全体に対して 1 つのベクトルが必要になる。そのため readout 関数を追加する。
readout もまた順序に不変でなければならない。よく使われる選択肢は次の通りである。
- Sum pooling。
- Mean pooling。
- Max pooling。
- Attention pooling。
その後の予測は通常の教師あり学習である。
実装パターンは次のようになる。
def graph_prediction_model(x, edge_index, batch):
node_embeddings = gnn(x, edge_index)
graph_embeddings = global_mean_pool(node_embeddings, batch)
return mlp(graph_embeddings)
batch ベクトルは、各ノードがどのグラフに属しているかを記録する。これにより、多数の小さなグラフを、訓練中に 1 つの非連結グラフとして詰め込むことができる。メッセージパッシングは実際のエッジに沿ってのみ起こり、プーリングはノードを元のグラフサンプルへと集め直す。
この設定では、GNN は通常のニューラルネットワークと最も違って感じられる。バッチは、同じサイズの例が並んだ長方形ではない。不規則な構造の集まりであり、モデルは任意のノード順序に依存しない関数を学習しなければならない。
多数の同型グラフ
3 つ目の設定は、同じトポロジーを持つが、ノードやエッジ上の信号が異なる多数のグラフである。これはセンサーネットワーク、交通ネットワーク、メッシュ、骨格ベースの行動認識、物理シミュレーションでよく見られる。
たとえば交通道路ネットワークでは、道路グラフは毎日同じかもしれないが、ノード特徴量は時間とともに変化する。速度、密度、渋滞、事故、天気などである。グラフは固定されており、グラフ上の信号が変化する。
この場合、各サンプルは次のように表せる。
ここで $A$ はサンプル間で共有され、$X^{(k)}$ はサンプル $k$ の特徴行列である。
この設定は画像に対する畳み込みに近い。画像ではグリッド構造が固定され、ピクセル値が変化する。同型グラフデータセットでは、グラフ構造が固定され、ノード値が変化する。
モデルはすべてのサンプルで同じ隣接行列を再利用できる。
def forward(batch_x, edge_index):
# batch_x: [batch_size, num_nodes, num_features]
outputs = []
for x in batch_x:
h = gnn(x, edge_index)
outputs.append(readout_or_node_head(h))
return torch.stack(outputs)
効率のためには、実際の実装では Python でループするのではなくバッチ化するだろうが、概念は同じである。トポロジーは共有され、特徴量は変化する。
この区別が重要なのは、データの分割方法とモデルの評価方法が変わるからである。1 つの大きなグラフでは、ノードやエッジを分割することが多い。多数の非同型グラフでは、グラフサンプルを分割する。多数の同型グラフでは、グラフ構造は一定のまま、時間、条件、被験者などに沿って観測を分割する。
覚えておくべきこと
GNN は、グラフ構造を尊重するニューラルネットワーク層にすぎない。各ノードは隣接ノードからメッセージを受け取り、それらを順序に依存しない方法で集約し、自分の表現を更新する。同じ考え方が、いくつかの異なる学習問題を支えている。
- 1 つの大きなグラフ:単一グラフ内のノードやエッジを予測する。
- 多数の非同型グラフ:不規則なグラフサンプル全体のラベルや値を予測する。
- 多数の同型グラフ:共有されたグラフトポロジー上で変化する信号から学習する。
これが明確になれば、張り子の虎は消える。難しいのは、名前や記法ではなくなる。難しいのは通常のモデリング上の問いである。何をノードと見なすべきか、何をエッジと見なすべきか、どの特徴量が信頼できるか、何ホップ先までが重要か、そして評価分割が実際の問題に合っているか、という問いである。
