PyTorchベースでの処理概要の把握|Deep Graph Libraryの0.5系を確認する #1

f:id:lib-arts:20200910150153p:plain

以前のシリーズでDGL(Deep Graph Library)について確認していたのですが、0.4系から0.5系への移行に伴い色々とドキュメントなども変わっているようなので、改めて0.5系のシリーズとして確認していきます。
それほど読み込んでいるわけではないですが、"DGL at a Glance"の例を見ても以前よりもわかりやすくなっているような印象を受けます。細かいところなどは変わっていますが、ある程度は前シリーズでも取り扱っているので、当シリーズではもう少し踏み込んだ内容を中心にまとめていければと思います。

以下、今回の目次になります。
1. DGL at a Glanceの再確認
2. USER GUIDE
3. Building GNN Modules(USER GUIDE Chapter3)関連のモジュール一覧
4. PyTorchベースのGNN Modulesの利用例の確認
5. まとめ


1. DGL at a Glanceの再確認

インストールとグラフ畳み込みを用いた学習の動作例の確認①|DGL(Deep Graph Library)を動かす #1 - Liberal Art’s diary

以前の記事でも取り扱った"DGL at a Glance"ですが、少々変更点などもあるので改めて見ていきます。問題の確認は以前の記事で行なっているので、変更点などを中心に確認していきます。

DGL at a Glance — DGL 0.5.0 documentation

f:id:lib-arts:20200910150842p:plain
まず、Step.2の"Assign features to nodes or edges"ですが、以前は1-hot vectorの形式だった特徴量は、nn.Embeddingを用いて5次元のベクトルで表現を行なっています。これはBoWとWord2vecの関係と同様に、理解しておけば十分です。

f:id:lib-arts:20200910151517p:plain

特徴量の割り当てが違ったことで、Step.3の"Define a Graph Convolutional Network"も以前は"net = GCN(34, 5, 2)"だったのですが、上記のように"net = GCN(5, 5, 2)"に変更になっています。

f:id:lib-arts:20200910152040p:plain

次に、以前の記事では流しましたが、学習にあたっては半教師ありの設定(semi-supervised setting)で学習を行うとあります。ノード0の講師(instructor)とノード33のクラブ代表(club president)のみにラベル(ノード0に0、ノード33に1を与えています)を与えノード0が0にノード33が1に分類されるように学習させることで、それ以外の人物はノード0とノード33のどちらかに所属するかをソーシャルグラフに基づいて予測できるようになります。全てのノードに正解ラベルをつける教師あり学習と、全てのノードにつけない教師なし学習の中間として半教師あり(semi-supervised)とされています。

f:id:lib-arts:20200910152952p:plain

学習自体は上記のようにPyTorchライクな実装で記述を行います。基本的にはlogitsを計算してlossを考え最適化を行う流れですが、半教師ありなので"logp[labeled_nodes]"のようにStep.4でinstructorとclub presidentに設定した"labeled_nodes"を教師としてlossの計算を行なっていることには注意が必要です。

f:id:lib-arts:20200910153314p:plain
また、内部処理としては、主に3節で確認しますが、Step.3のGraphConvが対応しています。それぞれのノードに割り当てられた5次元の特徴量を図のように計算し、情報伝播(message passing)を表現しています。この詳しい処理については3節で確認します。


2. USER GUIDE
2節ではUSER GUIDEについて確認します。

f:id:lib-arts:20200910154153p:plain

User Guide — DGL 0.5.0 documentation

全て確認すると冗長になる印象のため、主要処理と思われるところを中心に確認します。

f:id:lib-arts:20200910154336p:plain

まず、Graph Neural Networksを考える上でベースになるメッセージ伝播(Message Passing)についてですが、『業務連絡』や『又聞きの噂』などのイメージで考えていただけたら良いと思います。こちらはChapter2で言及されています。直接的にエッジを持つ近傍のノード同士が情報を交換し、時間が経つにつれて全体に情報が行き渡るイメージです。基本的にはこのMessage Passing Paradigmに沿って、Graph Neural Networksは表現されており、画像におけるCNNの畳み込みに相当する処理であると理解しておいて良いかと思います。

f:id:lib-arts:20200910155228p:plain

次にChapter3ですが、Graph Neural Networksのネットワークの構築について言及されています。ここで、Graph Neural Networksの構築にあたって様々な記述を行います。この際の詳しい処理については、下記などを確認すると良いと思いますので、3節で取り扱います。

NN Modules (PyTorch) — DGL 0.5.0 documentation

f:id:lib-arts:20200910155740p:plain

Chapter5ではGraph Neural Networksの学習についてが記載されています。

解説の流れは基本的に1節で取り扱った"DGL at a Glance"と同様なので、確認はこのくらいにします。


3. Building GNN Modules(USER GUIDE Chapter3)関連のモジュール一覧

NN Modules (PyTorch) — DGL 0.5.0 documentation

3節では上記を元に、Graph Neural Networksのモジュールの一覧について把握していきます。

f:id:lib-arts:20200910160405p:plain

まずは"DGL at a Glance"でも出てきたGraphConvですが、上記のように表されています。基本的に隠れ層にパラメータを用いたMLP(Multi Layer Perceptron)処理を行なったのちに、近隣のノードの合計を取り、バイアスを加え、活性化関数の処理を行うことで次の隠れ層を生成します。この際にh_{i}^{l}h_{i}^{l+1}は時間経過のように理解するとわかりやすいと思います。たとえば1日おきにMessage Passingを行うと考えた上で、Message Passingを行うことでh_{i}^{l}からh_{i}^{l+1}になるというイメージです。詳しい処理については下記の論文を参照しています。

[1609.02907] Semi-Supervised Classification with Graph Convolutional Networks

ちなみにこの際に2日経てば、隣接するノードが隣接するノード(距離2)のノードまで情報が伝達されるイメージです。従来はこのようなMessage Passingは有向非巡回グラフ(DAG; Directed Acyclic Graph)を中心とするRNN系で表現するのが多かったですが、近年このような処理は畳み込み処理で代用できることからグラフ畳み込みとして表現する研究の方が多いようです(要確認)。

f:id:lib-arts:20200910161332p:plain

次にRelGraphConvについて確認します。こちらについては自身のノードも参照するself-loopにあたっての重みであるW_{0}について追加されたことについて把握しておくと良いと思います。処理についての詳しい背景については下記の論文を参照しているので、こちらを確認すれば良さそうです。

[1703.06103] Modeling Relational Data with Graph Convolutional Networks

他にもTAGConv、GATConv、SAGEConvなどの処理がグラフ畳み込みとしてConv Layerに実装されています。Conv Layer以外では密なグラフを取り扱うDense Graph Layerや、Poolingを行うGlobal Pooling Layerなどが紹介されています(ここでは省略します)。


4. PyTorchベースのGNN Modulesの利用例の確認

NN Modules (PyTorch) — DGL 0.5.0 documentation

4節では3節同様に上記のページから利用例の確認を行います。

f:id:lib-arts:20200910162414p:plain

まず上記はGraphConvの利用例のCase1です。ここで注意しておくのが良いのが、"feat = th.ones(6, 10)"と、"conv = GraphConv(10, 2, ...)"です。まずfeatですが、グラフのノードが0〜5の6つあることから6、特徴量が10次元だから10を指定し、"feat = th.ones(6, 10)"となっています。これを最終的に二次元にするにあたって、"conv = GraphConv(10, 2, ...)"をのように指定しています。最終結果のresはノード6つに対して、2次元の結果を保持しています。

f:id:lib-arts:20200910163034p:plain

同様にGraphConvのCase2についても確認してみます。基本的には同様ですが、u_fea、v_feaについていまいち解説がないようなので、こちらについては一旦流しておきます。


5. まとめ
#1ではDeep Graph Libraryの0.5系の内容を掴むにあたって色々と見てきました。
#2では今回の内容を踏まえてUSER GUIDEのSection5-4のGraph Classificationの例について確認していければと思います。

5.4 Graph Classification — DGL 0.5.0 documentation