Pyroの概要とインストール|Pyroドキュメントに学ぶ統計モデリングの実装 #1
以前のシリーズでPyMC3のチュートリアルを元にPyMC3の実装例の把握や、それに伴って階層線形モデリングなどについても確認しました。
PyMC3はMCMC法などを中心とした統計モデリングを行う上でデファクトスタンダードに近くなっているライブラリですが、少しネックがあるとするならTheanoベースだということです。Theanoはすでに開発が終了しており、これを受けてPyMC4ではTensorFlowなどに移行するという話が出ています。
移行にあたってPyMC4が出てくるタイミングでの諸々を考慮すると、一応代替のモジュールも使えるようになっておく方が望ましいのではということで、最近着目されているのがPyTorchをベースにしたPyroです。
そのため、当シリーズではPyroのドキュメントを元に統計モデリングの実装について確認していきます。
Pyro Documentation — Pyro documentation
#1ではPyroの導入として概要の確認・インストール・簡単な動作確認を行います。
以下、目次になります。
1. Pyroの概要
2. Pyroのインストール
3. 動作確認
4. まとめ
1. Pyroの概要
まず1節ではPyroの概要について、ドキュメントを元に確認します(https://pyro.ai/)。
上記によると、Pyroは「Pythonで書かれた確率プログラミング言語(PPL; Probabilistic Programming Language)で柔軟かつ表現力の高い深層確率モデリング(Deep Probabilistic Modeling)を可能にする」とされています。PyTorchをバックエンドに使用しており、近年のDeepLearningとベイジアンのモデリングを統合するともされています。また、特徴として、汎用性(Universal)、計算のスケール性(Scalable)、最小限の実装(Minimal)、自動化と制御のどちらもできる柔軟性(Flexible)、が挙げられています。
大体のイメージがつかめたので1節はここまでとします。
2. Pyroのインストール
2節ではPyroのインストールについて確認します。
インストールは上記のように行えるとされています。ここではpipを用いてインストールするということにします。
pipでインストールした場合は上記のようにコマンド実行をすることで、Pyroがインストールされているかどうかとそのバージョンについて確認することができます。
ここまででインストール自体は確認できたので、続く3節では動作確認を行なっていきます。
3. 動作確認
3節ではPyroの簡単な動作確認を行います。
An Introduction to Models in Pyro — Pyro Tutorials 1.4.0 documentation
上記の"An Introduction to Models in Pyro"を元に確認していきます。
import torch
import pyropyro.set_rng_seed(101)
loc = 0. # mean zero
scale = 1. # unit variance
normal = torch.distributions.Normal(loc, scale) # create a normal distribution object
x = normal.rsample() # draw a sample from N(0,1)
print("sample", x)
print("log prob", normal.log_prob(x)) # score the sample from N(0,1)
まず、上記のようにすることで正規分布からの値の生成を行うことができます。
実行結果は上記のようになります。ここで注意が必要なのが、pyro.set_rng_seed(101)です。正規分布などの確率分布からのサンプリングには乱数を用いているので、値に再現性を持たせるためには乱数を固定する必要があります。ここでは"101"を設定していますが、他の値を設定することもできます(他の値を設定すると他のサンプルが生成されるとだけご理解いただけたらと思います)。
x = pyro.sample("my_sample", pyro.distributions.Normal(loc, scale))
print(x)
また、ここで、上記のようにサンプリングにあたってPyTorchではなくPyroを用いることができるということも知っておくと良いです。
実行結果は上記のようになります。こちらの方が実装としてはスムーズだと思います。
もう一つ天気の例をご紹介します。
pyro.set_rng_seed(101)
def weather():
cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
return cloudy, temp.item()for _ in range(10):
print(weather())
少々変数の使用に違和感のある実装(変数cloudyにsunnyが入る可能性もある)ですが、それほど複雑でもないので確認自体は問題ないかと思います。
実行結果は上記のようになります。確率pが0.3のベルヌーイ分布に基づいて天気を決めているため、概ね3割が曇り、7割が晴れとなりますが、その通りの結果が出ています。また、temp.item()は曇りの場合は55を中心に、晴れの場合は75を中心にサンプリングを行っているので概ね納得のいく結果になっているのではないかと思います。
簡単なレベルでのPyroの実行が確認できたので今回はここまでとします。
4. まとめ
#1ではPyroの概要の把握とインストール、動作確認を行いました。
#2では引き続き、ドキュメントから"An Introduction to Inference in Pyro"の内容をご紹介します。
An Introduction to Inference in Pyro — Pyro Tutorials 1.4.0 documentation