Graph Classification|Deep Graph Libraryの0.5系を確認する #2

f:id:lib-arts:20200913170721p:plain

Deep Graph Libraryの0.5系のドキュメントを読み進めています。
#1では"DGL at a Glance"や"USER GUIDE"の内容などを元にPyTorchベースでの処理概要の把握を行いました。

#2では引き続き"USER GUIDE"から、5.4のGraph Classificationについて確認できればと思います。
以下目次になります。
1. 問題設定やアルゴリズムの大枠の把握
2. 実装の確認
3. まとめ


1. 問題設定やアルゴリズムの大枠の把握
1節では問題設定やアルゴリズムの大枠の把握を行います。

f:id:lib-arts:20200913170721p:plain

タスクのざっくりとしたイメージとしては、冒頭の図にもあるように上記になります。

f:id:lib-arts:20200913171411p:plain

概要の説明としては上記のように、グラフの分類(graph classification)は入力するグラフの全体における性質(property)を予測するとされています。グラフの分類を実現するにあたっては、ノードやエッジの分類と同様にメッセージ伝播(message passing)の枠組みを使うのと同時に、グラフ全体のレベルでの特徴量を抽出することが必要になるとされています。以下、グラフ全体のレベルでの特徴量抽出について確認していきます。

f:id:lib-arts:20200913171951p:plain

まず入力にあたっては、グラフをいくつかセットにしたバッチ単位で入力するにあたって、上記のようなバッチを考えるとしています。グラフの情報は対角成分にのみ並ぶため、要素が0となっている行列の部分は考慮しないとすることで、それぞれ別のグラフとして取り扱うことができると考えているようです。タスク概要としては上記のように円(cycle)と星(star)の分類のようなタスクと理解しておけば良さそうです。

f:id:lib-arts:20200913172523p:plain

次に出力層の値を作成するReadoutの処理についてです。出力層の作成にあたって合計/平均/最大値/最小値などの値の抽出を行います。上記では平均を計算するReadout(average readout)を行うにあたって、\displaystyle h_{g} = \frac{1}{|V|} \sum_{v \in V} h_{v}という計算を行なっています。ここで計算したh_{g}などの値を元にMLP(MultiLayerPerceptron)などを経て、分類結果を出力すると考えて良いと思います。ちなみにここで、Readoutだけが紹介されており、同様のPoolingの処理が記載されていないのは、対象のグラフがあまり大きなグラフではなく、中間層での圧縮処理が必要ないためだと思われます。ReadoutとPoolingはどちらとも集約(aggregation)処理として認識しておくと良いと思います。

グラフの分類(graph classification)の問題設定やアルゴリズムについては大体つかめたと思いますので1節はここまでとします。


2. 実装の確認
2節では実装の確認について行っていきます。

f:id:lib-arts:20200913174129p:plain

まず、入力に関しては上記のようにグラフのバッチを実装すると考えるのが良いようです。実装例の解釈としては、g1は0と1のノードが連結しており、g2は0と1、1と2が連結しています。g1は2つのノード、g2は3つのノードを持つので、それぞれndata['h']には2つと3つの値を割り当てています。これはノード単位に1次元の特徴量を割り当てていると理解すると良いかと思います。また、dgl.readout_nodes(...)の処理ではそれぞれのノードに割り当てた特徴量の和の合計を計算しています。bgはg1とg2を一緒に取り扱ったバッチですが、gl.readout_nodes(...)はそれぞれグラフ単位で計算するため、g1とg2でそれぞれ特徴量の合計を計算しています。この話を(B,D)として表現しており、Bはバッチサイズ(ここではg1とg2だから2)、Dは特徴量の次元数(ここでは1)なので、(B,D)=(2,1)になっていると解釈して良いと思います。

f:id:lib-arts:20200913174907p:plain

次にニューラルネットワークの定義については上記のようになります。基本的にはPyTorchの記法そのままですが、畳み込み処理としてdgl.nn.GraphConvを用いていることがグラフ畳み込みで、下から2行目で"dgl.mean_nodes(g,'h')"としているのがグラフの分類にあたって用いているReadoutの処理です。

f:id:lib-arts:20200913175238p:plain

学習については上記のように、繰り返し処理による最適化を行っています(上記はドキュメントの記載ですが、Classifierでエラーが出たので入力層は10次元から7次元に変更したら動くようになりました)。学習にあたっては通常のPyTorchを用いた最適化処理のため、特に気にする点についてはなさそうです。


3. まとめ
#2では"USER GUIDE"から、5.4のGraph Classificationについて取り扱いました。ここでのaggregationの処理はReadoutだけでしたが、中間層での圧縮が必要な場合はPooling処理なども生じてくると考えておくのが良いかと思います。
#3ではグラフの生成モデルを考えるにあたって、Generative models of graphsについて見ていきます。

Tutorial: Generative models of graphs — DGL 0.5.0 documentation