BERTリポジトリのコードリーディング②(計算グラフの流れ)|言語処理へのDeepLearningの導入の研究トレンドを俯瞰する #8

f:id:lib-arts:20190501180404p:plain

#6ではまずサンプル実行に関して、#7では実行コードの概要について確認しました。

#8ではコードリーディングの続きとして、計算グラフの流れに着目してまとめられればと思います。
以下目次になります。

1. Input Representationの実装に関して
2. InputからOutputまでの計算グラフの流れ
3. attention_layer内の実装(Multi-Head Attention)
4. まとめ


1. Input Representationの実装に関して
まずはBERTの論文の3.2節で言及されているInput Representationについて追っていきます。

f:id:lib-arts:20190509203826p:plain
元論文では上記の図でInput Representationについて説明しています。

f:id:lib-arts:20190509205717p:plain
論文の図に対応する実装は上記になります。transformer_modelの第一引数のinput_tensorにself.embedding_outputを与えています。ということは、実際にモデルに入力するデータはself.embedding_outputの中身を紐解けば把握できることがわかります。このself.embedding_outputは、embedding_lookupでToken Embeddings(Word Embeddings)を計算し、embedding_postprocessorでPosition Embeddingsを加えています。

f:id:lib-arts:20190509205742p:plain
(中略)

f:id:lib-arts:20190509205805p:plain
上記がembedding_lookupの実装ですが、embedding_tableを取得し、tf.matmulでone_hot_input_idsに掛け合わせています。単語の分散表現(distributed representation)の計算の際に埋め込み行列(embeddings matrix)を用いて実現できることを知っていればこちらについては理解が進むかと思います。また、"word_embeddings"をnameに指定して取得するembedding_tableですが、下記の同リポジトリのやり取りを見るに事前学習されたbert_model.ckptに組み込まれているとの言及があります。また、落としてきた事前学習データのvocab.txtが単語のidに対応しているようです。

f:id:lib-arts:20190616121813p:plain

How to get the word embedding after pre-training? · Issue #60 · google-research/bert · GitHub

f:id:lib-arts:20190509205825p:plain
(中略)

f:id:lib-arts:20190509205842p:plain
(中略)

f:id:lib-arts:20190509205901p:plain
上記がembedding_postprocessorの実装ですが、token_type_tableとfull_position_embeddingsを取得し、それぞれ計算後の値をoutputに足し合わせています。このことは、Input Representationを計算するにあたって足し合わせを行なっていたことに話が繋がってきます。


2. InputからOutputまでの計算グラフの流れ
1節ではInput Representationの実装であるinput_tensorの中身について確認を行いました。2節ではinput_tensorが最終出力になるまでの全体的な流れを追うことで、Attention mechanismの実装や最終出力結果の作り方について把握したいと思います。まずはtransformer_modelの中身について確認していきます。

f:id:lib-arts:20190509211311p:plain
上記がtransformer_modelの関数ですが、第一引数のinput_tensorの次元数に着目すると、Tensorの形は[batch_size, seq_length, hidden_size]ということが書かれています。これを読むことで、Transformerへ入力する入力の形式を把握することができます。

f:id:lib-arts:20190509212359p:plain
transformer_modelの実装はなかなか込み入っているので最終的な戻り値から確認すると上記のように、do_return_all_layersで条件分岐し、全てのレイヤーを返すか最終レイヤーだけを返すのかを制御していることがわかります。

f:id:lib-arts:20190509213446p:plain
(中略)

f:id:lib-arts:20190509213619p:plain

ここで気になるのが、all_layer_outputsですが、上記のnum_hidden_layersの数だけappendを用いて層(layer_output)の追加を行なっていることを確認できます。処理の流れを全て貼ると画像が多くなって見づらいので繰り返し文の中のの変数の流れをまとめると、『attention_head->attention_output->intermediate_output->layer_output』の処理の流れになります。この処理についてはTransformerのサブレイヤーで行なっている処理だと把握しておくと良いかと思います。

f:id:lib-arts:20190506200344p:plain
ここでattention_headを出力するattention_layer関数ではMulti-Head Attentionについて実装されていますがネストが深くなりここだけ単体で読む方が良さそうなので3節で取り扱います。
次にtransformer_model関数のアウトプット後の処理について抑えていきます。

f:id:lib-arts:20190509155655p:plain
上記のように、create_modelのアウトプットが(total_loss, per_example_loss, logits, probabilities)などlossやlogitsになるので、create_modelの中身を読み解けば十分であることがわかります。

f:id:lib-arts:20190509155713p:plain

f:id:lib-arts:20190509215121p:plain
上記においてmodel.get_pooled_output()はBertModelの実装で『transformer_model->self.all_encoder_layers->sequence_output->first_token_tensor->self.pooled_output』という形で演算グラフが実装されています。
model.get_pooled_output()での値取得後の流れはlogits、probabilities、lossのどれを見てもそこまで特殊なことは行なっていないので、この辺はタスクに合わせて色々と組んでいく形になると思います。


3. attention_layer内の実装(Multi-Head Attention)
3節では2節で省略したattention_layer内の実装について取り扱っていきます。attention_layerはEncoder内部のモジュールの処理なので、この実装はMulti-Head Attentionに対応しています。

f:id:lib-arts:20190510153554p:plain
上記がMulti-Head Attentionの実装ですが、tf.matmulのところでScaled Dot-Product Attentionの計算が記述されています。具体的には701行目でQとKの掛け算、735行目でsoftmax後の計算結果とVの掛け算が行われています。


4. まとめ
#8では計算グラフの処理の流れについてより詳細に把握を行いました。
大体の流れは読めたので、#9以降はGLUEなどのタスクなどに着目してBERTについて見ていきたいと思います。