Quickstart(ドキュメントより)|Flaxの確認 #1

f:id:lib-arts:20210205193141p:plain

以前の記事でFlaxが出てきて軽く流したため、当シリーズでは詳しく取り扱いを行います。

FlaxはJAXで用いるニューラルネットワークのライブラリです。当シリーズでは基本的にはドキュメントの内容を元にある程度の概要の把握を目標とします。

f:id:lib-arts:20210205193216p:plain

Flax documentation — Flax documentation

#1ではドキュメントのQuickstartの内容を元に確認を行います。
以下、目次になります。
1. Overview(Quickstart)
2. Installation(Quickstart)
3. Examples(Quickstart)
4. Vision Transformer実装におけるFlax
5. まとめ

 

1. Overview(Quickstart)
1節ではQuickstartのOverviewの内容を取り扱います。

f:id:lib-arts:20210205193510p:plain

まずは背景の確認ですが、NumPyベースで自動微分を並列処理ができるようにしたJAXについて紹介されています。JAXについては以前の記事で簡単に取り扱いましたので、詳しくは下記を参照ください。

f:id:lib-arts:20210205194612p:plain
次にFlaxの概要の記載について確認します。FlaxはJAXにおいて高性能なニューラルネットワークを柔軟に設計することを目的として作成されたライブラリで、Neural network API、Optimizersなどが主な機能として紹介されています。

f:id:lib-arts:20210205194941p:plain

また、OverviewではCode ExamplesやTPU supportなどについても記載がされています。


2. Installation(Quickstart)
2節ではQuickstartのInstallationについて確認します。

f:id:lib-arts:20210205195223p:plain

基本的にはPyPIからpipでインストールを行うことができますが、GPUを用いる際はjaxlibやJAX readme記載の内容について取り組む必要があるとされています。

 

3. Examples(Quickstart)
3節ではQuickstartのExamplesの内容について確認します。

f:id:lib-arts:20210205195606p:plain

上記のようにいくつかのCore examplesが共有されています。基本的には有名どころの実装が紹介されていますが、以下簡易的な例としてMNISTについて簡単に確認を行ってみます。

f:id:lib-arts:20210205200105p:plain

flax/examples/mnist at master · google/flax · GitHub

FlaxにおけるMNISTの例は上記で実装が公開されています。

f:id:lib-arts:20210205201222p:plain

(中略)

f:id:lib-arts:20210205201241p:plain
実装の流れとしては、上記のmain.pyのtrain_and_evaluateを実装したtrain.pyをtrain_and_evaluate→train_epoch→train_stepと辿ることでニューラルネットワークの実装の取り扱いを確認できます。

f:id:lib-arts:20210205201608p:plain
上記のtrain_stepの実装において、CNN().apply....をlogitsとして取り扱い、このlogitsを元にクロスエントロピー誤差関数を計算しています。また、ここで計算したlossを元にoptimizerを用いて最適化を行います。

f:id:lib-arts:20210205201827p:plain
CNNの実装としては上記のようにニューラルネットワークの実装の記載がされています。

f:id:lib-arts:20210205202247p:plain

また、CNNの実装において継承したnnはflax.linenから利用していることも確認できます。

主な実装例の流れについては確認できたので3節はここまでとします。


4. Vision Transformer実装におけるFlax
4節ではここまでの内容を元に、以前の記事のVision TransformerにおけるFlaxと対応させながら確認を行います。

f:id:lib-arts:20210204195619p:plain

vision_transformer/models.py at master · google-research/vision_transformer · GitHub

まず用いているネットワークの実装やライブラリの確認ですが、VisionTransformerクラスでは3節のMNISTで実装されていたCNNクラスと概ね同様な記述方法で記載が行われていることがわかります。

f:id:lib-arts:20210205203025p:plain

とはいえ、相違点もあり、「3節で読み込んだflax.linenではなくflax.nnを読み込んでいること」や「クラスの実装が__call__ではなくapplyが用いられていること」などは注意が必要です。

f:id:lib-arts:20210205203508p:plain

flax.nn package (deprecated) — Flax documentation

まず、「flax.linenとflax.nnの違い」としては、上記にあるように「flax.nnがdeprecateされ、flax.linenに変わった」と把握しておけば良いです。この手のライブラリの仕様が変わることはよく起こりうるため、この辺は過度に気にせず必要に応じて確認を行うで問題ないかと思います。

f:id:lib-arts:20210205203957p:plain

The Flax FAQ — Flax documentation

また、__call__ではなくapplyが用いられているのは上記のようにinit_by_shapeを用いる場合はapplyメソッドを使用すると把握しておけば良さそうです。

f:id:lib-arts:20210205204148p:plain
実際にVisionTransformerの実装では上記のようにinit_by_shapeを用いてparamsに代入し、このparamsを用いてOptimizerによる最適化が記載されています。

ここまでの内容で、1節〜3節までの内容とVision Transformerの実装の基本的な点が一通り対応づけられたかと思います。


5. まとめ
#1ではFlaxのQuickstartの内容を確認しつつ、以前のVision Transformer実装におけるFlaxの利用との対応づけを行いました。
#2以降でも引き続きドキュメントの内容の確認を行えればと思います。