JAX Quickstart(公式ドキュメントより)①|jaxの確認 #1

f:id:lib-arts:20210201211216p:plain

昨今のDeepLearningの実装ではJAXを用いるケースもあるようなので、簡単に仕様を確認できればということで当シリーズではJAXの把握を行なっていきます。#1ではまずは概要の把握をということでドキュメントのQuickstartの内容を取り扱います。

JAX Quickstart — JAX documentation

上記の確認を行います。
以下、目次になります。
1. 冒頭部の記載の把握
2. 簡単な動作例の確認
3. まとめ


1. 冒頭部の記載の把握
1節では冒頭部の記載の把握を行います。

f:id:lib-arts:20210201211350p:plain

以下簡単に要約します。

JAXは機械学習の研究において高いパフォーマンスを実現するために自動微分をCPU、GPU、TPUで動かすNumPyです。自動微分のバージョンの更新に伴い、JAXでは素のPythonやNumPyのコードで自動微分を実装することができます。JAXはPythonのループや再帰処理などを微分することができます。さらに新しい話題として、JAXはXLAを用いてNumPyコードのコンパイルを行います。

Autograd(自動微分)がNumPyにおいて実装され、XLAでパフォーマンス向上を行なっている程度の認識で良さそうです。

またインストールについては、下記で簡単に行うことができました。

$ pip install jax jaxlib

手元の環境でしか確認していないので、エラーが出た方は別途調べてみていただけたらと思います。

簡単な概要については把握できたので1節はここまでとします。

 

2. 簡単な動作例の確認
2節では簡単な動作例の確認について行います。まずはJAX Quickstartより下記を動かしてみます。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

上記の実行結果は下記になります。

f:id:lib-arts:20210201213702p:plain

基本的にはNumPyと類似したインターフェースであり、上記は乱数を10個作成している例です。次に実行速度について確認してみましょう。

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU

上記の実行結果は下記になります。

f:id:lib-arts:20210201213954p:plain

上記では3,000×3,000の行列を乱数で生成を行い、行列の積の計算のパフォーマンスを測定しています。以前の処理高速化では"%time"についてご紹介しましたが、こちらでは"%timeit"を用いて処理パフォーマンスの測定を行なっています。処理高速化の記事は下記などをご確認ください。

今回はMultiplying Matricesの内容について把握しましたが、簡単な動作例の確認にはできたと思われるので2節はここまでとします。


3. まとめ
#1ではJAX Quickstartより、冒頭部の記載の把握に加えて、"Multiplying Matrices"の実行を行いました。
続く#2では同じくJAX Quickstartより、jit、grad、vmapの内容について確認を行っていきます。