JAX Quickstart(公式ドキュメントより)②|jaxの確認 #2

f:id:lib-arts:20210201211216p:plain

昨今のDeepLearningの実装ではJAXを用いるケースもあるようなので、簡単に仕様を確認できればということで当シリーズではJAXの把握を行なっていきます。一旦はドキュメントのQuickstartの内容を取り扱うことにし、下記の確認を行っていきます。

JAX Quickstart — JAX documentation
#1ではQuickstartの記載より冒頭部の記載と簡単な動作例の確認を行いました。

#2では同じくQuickstartの記載より、jit、grad、vmapの内容について取り扱います。

f:id:lib-arts:20210202180200p:plain

以下、目次になります。
1. jit(to speed up functions)について
2. grad(Taking derivatives) について
3. vmap(Auto-vectorization)について
4. まとめ

 

1. jit(to speed up functions)について
1節ではjitについて取り扱います。jitは関数の高速化のために用いられるとざっくり把握しておくと良いかと思います。

f:id:lib-arts:20210202181929p:plain

上記の記載では、処理の流れをjitを用いて取り扱うと書かれており、実行例はseluにjitを用いることで、1,000,000の数字に対しseluを計算する処理速度を6倍ほどに向上させています。

また、selu関数については下記のように可視化できるので、グラフの形状も合わせて抑えておくと良さそうです。

f:id:lib-arts:20210202184339p:plain


2. grad(Taking derivatives) について
2節ではgradについて取り扱います。gradは自動微分(automatic differentiation)を取り扱っていると大まかに掴んでおくと良いかと思います。

f:id:lib-arts:20210202184852p:plain

上記の例では"sum_logistic"関数に対してgradを用いた自動微分(セル[9])と、数値的な一次の微分(セル[10]の"first_finite_differences"関数)についてそれぞれ計算を行っています。自動微分については詳しくは下記関連のシリーズをご確認いただけたらと思います。

f:id:lib-arts:20210202190203p:plain

また、上記の記載ではgradとjitの組み合わせが任意であることなどが記載されています。


3. vmap(Auto-vectorization)について
3節ではvmapについて取り扱います。vmapはvectorizing mapについて取り扱います。

f:id:lib-arts:20210202191553p:plain

(中略)

f:id:lib-arts:20210202191823p:plain

上記が具体例で、jnp.dotとvmapの比較を行っています。行列の積などの計算を取り扱っているとざっくりと把握しておくのが良さそうです。


4. まとめ
#2ではJAX Quickstartの記載より、jit、grad、vmapの内容について取り扱いました。jitが関数、gradが自動微分、vmapは行列の積などについてそれぞれ取り扱っているとざっくり把握しておくと良さそうです。また、それぞれの詳しい用法などはJAXを用いた論文などの著者実装を確認するのが、利用シーンのイメージなどもついて良いのではないかと思われます。
#3以降でも引き続きチュートリアルを読み進めていきます。