JAX Quickstart(公式ドキュメントより)①|jaxの確認 #1
昨今のDeepLearningの実装ではJAXを用いるケースもあるようなので、簡単に仕様を確認できればということで当シリーズではJAXの把握を行なっていきます。#1ではまずは概要の把握をということでドキュメントのQuickstartの内容を取り扱います。
JAX Quickstart — JAX documentation
上記の確認を行います。
以下、目次になります。
1. 冒頭部の記載の把握
2. 簡単な動作例の確認
3. まとめ
1. 冒頭部の記載の把握
1節では冒頭部の記載の把握を行います。
以下簡単に要約します。
JAXは機械学習の研究において高いパフォーマンスを実現するために自動微分をCPU、GPU、TPUで動かすNumPyです。自動微分のバージョンの更新に伴い、JAXでは素のPythonやNumPyのコードで自動微分を実装することができます。JAXはPythonのループや再帰処理などを微分することができます。さらに新しい話題として、JAXはXLAを用いてNumPyコードのコンパイルを行います。
Autograd(自動微分)がNumPyにおいて実装され、XLAでパフォーマンス向上を行なっている程度の認識で良さそうです。
またインストールについては、下記で簡単に行うことができました。
$ pip install jax jaxlib
手元の環境でしか確認していないので、エラーが出た方は別途調べてみていただけたらと思います。
簡単な概要については把握できたので1節はここまでとします。
2. 簡単な動作例の確認
2節では簡単な動作例の確認について行います。まずはJAX Quickstartより下記を動かしてみます。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import randomkey = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
上記の実行結果は下記になります。
基本的にはNumPyと類似したインターフェースであり、上記は乱数を10個作成している例です。次に実行速度について確認してみましょう。
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
上記の実行結果は下記になります。
上記では3,000×3,000の行列を乱数で生成を行い、行列の積の計算のパフォーマンスを測定しています。以前の処理高速化では"%time"についてご紹介しましたが、こちらでは"%timeit"を用いて処理パフォーマンスの測定を行なっています。処理高速化の記事は下記などをご確認ください。
今回はMultiplying Matricesの内容について把握しましたが、簡単な動作例の確認にはできたと思われるので2節はここまでとします。
3. まとめ
#1ではJAX Quickstartより、冒頭部の記載の把握に加えて、"Multiplying Matrices"の実行を行いました。
続く#2では同じくJAX Quickstartより、jit、grad、vmapの内容について確認を行っていきます。