グラフニューラルネットワーク(GNN)という張り子の虎を追う:直感的に、数学的に、実装可能に

グラフニューラルネットワーク

グラフニューラルネットワークは、その名前が示すほど神秘的なものではない。中核にあるのは、各ノードが近傍ノードを見て、そこからメッセージを集め、自分自身の表現を更新する、という処理を繰り返すニューラルネットワークである。

一般的なメッセージパッシング層は次のように書ける。

\[x_i^{(t)} = operatorname{Update}left( x_i^{(t-1)}, square_{j in mathcal{N}(i)} operatorname{Message}left(x_i^{(t-1)}, x_j^{(t-1)}, e_{ij}right) right)\]

ここで、

  • $x_i^{(t)}$ は、層または時刻ステップ $t$ の後のノード $i$ の表現である。
  • $mathcal{N}(i)$ は、ノード $i$ に隣接するノードの集合である。
  • $e_{ij}$ は、ノード $i$ とノード $j$ の間にある任意のエッジ特徴量である。
  • $square$ は順序に不変な集約演算子で、一般には summeanmax が使われる。

コンパクトな行列表記では、多くの GNN 層は次のように見なせる。

\[H^{(t+1)} = F(H^{(t)}, X)\]

ここで $H^{(t)}$ は層 $t$ における隠れノード表現、$X$ は元のノード特徴行列、$F$ は特徴変換とグラフに基づく集約を混ぜ合わせる層関数である。

要点は単純だ。ノードの新しい埋め込みは、そのノード自身の以前の埋め込みと、近傍から届く情報に依存すべきである。層を積み重ねると、情報はより遠くまで伝わる。1 層なら 1 ホップ近傍の情報を捉え、2 層なら最大 2 ホップ先のノードからの情報を捉える、という具合である。

最小限の実装パターンは次のようになる。

import torch


def gnn_layer(x, edge_index, weight):
    """One simple sum-aggregation GNN layer.

    x:          [num_nodes, in_features]
    edge_index: [2, num_edges], where edge_index[0] -> edge_index[1]
    weight:     [in_features, out_features]
    """
    src, dst = edge_index
    messages = x[src]

    aggregated = torch.zeros_like(x)
    aggregated.index_add_(0, dst, messages)

    h = x + aggregated
    return torch.relu(h @ weight)

これは本番用の完全な層ではないが、動いている部品は示している。送信元ノードの特徴を選び、宛先ノードで集約し、現在のノード状態と組み合わせ、その後に学習可能な変換と非線形性を適用する。

1 つの大きなグラフ

最初によくある GNN の設定は、1 つの大きなグラフである。例としては、引用ネットワーク、ソーシャルネットワーク、商品グラフ、知識グラフ、Web グラフなどがある。

この設定では、グラフ構造は固定されており、タスクはたいていノードレベルまたはエッジレベルの予測である。

  • ノード分類:引用グラフにおける論文トピックなど、各ノードのラベルを予測する。
  • リンク予測:2 つのノードの間にエッジが存在すべきかを予測する。
  • ノード回帰:各ノードについて連続値を予測する。

ノード分類では、通常は次のものがある。

  • 特徴行列 $X in mathbb{R}^{N times d}$。ここで $N$ はノード数、$d$ は入力特徴量の数である。
  • 隣接行列 $A in {0,1}^{N times N}$、またはエッジリスト。
  • 一部のノードに対するラベル。

典型的なグラフ畳み込み層は次のように書ける。

\[H^{(t+1)} = sigmaleft(tilde{D}^{-1/2}tilde{A}tilde{D}^{-1/2}H^{(t)}W^{(t)}right)\]

ここで $tilde{A} = A + I$ は自己ループを追加したもの、$tilde{D}$ は $tilde{A}$ の次数行列、$W^{(t)}$ は学習可能な重み行列、$sigma$ は非線形活性化関数である。

重要なのは、各ノードの表現が、近傍ノードの表現と混ざり合いながら平滑化されるという点である。これは、接続されたノードに意味のある関係があるときには強力だが、平滑化が強すぎて異なるノードが見分けられなくなると失敗しうる。

1 つの大きなグラフに対する単純な訓練ループは、通常このようになる。

for epoch in range(num_epochs):
    logits = model(x, edge_index)
    loss = loss_fn(logits[train_mask], y[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

マスクが分割しているのはグラフではなくノードである。モデルは訓練時も評価時も同じグラフ構造を見るが、ラベルは訓練ノードに対してのみ使われる。

多数の非同型グラフ

2 つ目の設定は、互いに異なる多数のグラフからなるデータセットである。例としては、分子、タンパク質構造、プログラムグラフ、抽象構文木などがある。

ここでは各サンプルが 1 つのグラフであり、サンプルごとにノード数、エッジ構造、トポロジーが異なりうる。タスクはしばしばグラフレベルの予測である。

  • 分子特性予測。
  • グラフ分類。
  • プログラム挙動分類。
  • タンパク質機能予測。

1 つのグラフ $G$ について、メッセージパッシングはノード埋め込みを生成する。

\[H_G = operatorname{GNN}(X_G, E_G)\]

しかしグラフレベルの予測には、グラフ全体に対して 1 つのベクトルが必要になる。そのため readout 関数を追加する。

\[h_G = operatorname{Readout}left({h_i : i in G}right)\]

readout もまた順序に不変でなければならない。よく使われる選択肢は次の通りである。

  • Sum pooling。
  • Mean pooling。
  • Max pooling。
  • Attention pooling。

その後の予測は通常の教師あり学習である。

\[hat{y}_G = operatorname{MLP}(h_G)\]

実装パターンは次のようになる。

def graph_prediction_model(x, edge_index, batch):
    node_embeddings = gnn(x, edge_index)
    graph_embeddings = global_mean_pool(node_embeddings, batch)
    return mlp(graph_embeddings)

batch ベクトルは、各ノードがどのグラフに属しているかを記録する。これにより、多数の小さなグラフを、訓練中に 1 つの非連結グラフとして詰め込むことができる。メッセージパッシングは実際のエッジに沿ってのみ起こり、プーリングはノードを元のグラフサンプルへと集め直す。

この設定では、GNN は通常のニューラルネットワークと最も違って感じられる。バッチは、同じサイズの例が並んだ長方形ではない。不規則な構造の集まりであり、モデルは任意のノード順序に依存しない関数を学習しなければならない。

多数の同型グラフ

3 つ目の設定は、同じトポロジーを持つが、ノードやエッジ上の信号が異なる多数のグラフである。これはセンサーネットワーク、交通ネットワーク、メッシュ、骨格ベースの行動認識、物理シミュレーションでよく見られる。

たとえば交通道路ネットワークでは、道路グラフは毎日同じかもしれないが、ノード特徴量は時間とともに変化する。速度、密度、渋滞、事故、天気などである。グラフは固定されており、グラフ上の信号が変化する。

この場合、各サンプルは次のように表せる。

\[(X^{(k)}, A)\]

ここで $A$ はサンプル間で共有され、$X^{(k)}$ はサンプル $k$ の特徴行列である。

この設定は画像に対する畳み込みに近い。画像ではグリッド構造が固定され、ピクセル値が変化する。同型グラフデータセットでは、グラフ構造が固定され、ノード値が変化する。

モデルはすべてのサンプルで同じ隣接行列を再利用できる。

def forward(batch_x, edge_index):
    # batch_x: [batch_size, num_nodes, num_features]
    outputs = []
    for x in batch_x:
        h = gnn(x, edge_index)
        outputs.append(readout_or_node_head(h))
    return torch.stack(outputs)

効率のためには、実際の実装では Python でループするのではなくバッチ化するだろうが、概念は同じである。トポロジーは共有され、特徴量は変化する。

この区別が重要なのは、データの分割方法とモデルの評価方法が変わるからである。1 つの大きなグラフでは、ノードやエッジを分割することが多い。多数の非同型グラフでは、グラフサンプルを分割する。多数の同型グラフでは、グラフ構造は一定のまま、時間、条件、被験者などに沿って観測を分割する。

覚えておくべきこと

GNN は、グラフ構造を尊重するニューラルネットワーク層にすぎない。各ノードは隣接ノードからメッセージを受け取り、それらを順序に依存しない方法で集約し、自分の表現を更新する。同じ考え方が、いくつかの異なる学習問題を支えている。

  • 1 つの大きなグラフ:単一グラフ内のノードやエッジを予測する。
  • 多数の非同型グラフ:不規則なグラフサンプル全体のラベルや値を予測する。
  • 多数の同型グラフ:共有されたグラフトポロジー上で変化する信号から学習する。

これが明確になれば、張り子の虎は消える。難しいのは、名前や記法ではなくなる。難しいのは通常のモデリング上の問いである。何をノードと見なすべきか、何をエッジと見なすべきか、どの特徴量が信頼できるか、何ホップ先までが重要か、そして評価分割が実際の問題に合っているか、という問いである。

Leave a Reply