変分推論の概要と簡単な実装例|Python実装で理解する変分推論(VariationalInference) #5

f:id:lib-arts:20201022172002p:plain

当シリーズはPython実装を通して変分推論を理解していこうということで進めています。下記などを主に参照しています。

Pattern Recognition and Machine Learning | Christopher Bishop | Springer
#3、#4ではEMアルゴリズムについて確認しました。

基本的な論理展開については大体抑えられたかと思いますので、#5からは変分推論の内容に入っていきます。
以下、今回の目次になります。
1. 変分推論の概要
2. 簡単な実装例
3. まとめ


1. 変分推論の概要
1節では変分推論の概要について簡単に取り扱います。

f:id:lib-arts:20201024183044p:plain

#4で上記を確認しましたが、EMアルゴリズムを考えるにあたって導入した確率分布qですが、変分推論では同様の数式に対し、qに対する最適化を考えます。変分推論のベースとなっている変分法では関数を入力とする汎関数に対する最適化を行いますが、変分推論では確率分布qに関する最適化を行います。

f:id:lib-arts:20201102002124p:plain

f:id:lib-arts:20201102002319p:plain

さて、汎関数の最適化にあたってのアプローチとして、参照テキストでは主に上記の2つのアプローチが紹介されています。ざっくりまとめるなら、「①確率分布のq(Z)に制約を設ける(restricted family of distributions)」、「②パラメトリックな確率分布のq(Z|ω)を設定する」の2つが紹介されています。参照テキストのSection10-1-1では①のアプローチが取られているので、以下そちらについて確認していきます。

f:id:lib-arts:20201102002745p:plain

分布の分解にあたっては、上記の式(10.5)のように、同時分布を独立した確率分布の積で表すことで、各次元の相関を取り除いています。

f:id:lib-arts:20201102003012p:plain

(10.5)式を(10.3)式に代入し、数式変形を行っているのが(10.6)式です。また、(10.7)式と(10.8)式は式の定義なので、抑えておきましょう。一見難しく見えますが、単に(10.6)式の第1項の置き換えを試みているだけですので、単にそういう風に定義して式を整理しただけと理解すれば十分です。ここで、(10.6)式が負のKLダイバージェンス(negative Kullback-Leibler divergence)を表していることに着目すると、q_{j}の最適解を得ることができます。

f:id:lib-arts:20201102004022p:plain

q_{j}の最適解は上記の(10.9)のように求めることができます。

ここまでで分布の分解を利用した変分推論の一般的な導出の流れについて確認できたかと思います。続く2節ではこの例を2次元の正規分布に適用した例の理論展開と簡単な実装を確認していきます。


2. 簡単な実装例
2節では2次元の正規分布における分布の分解を利用した変分推論について確認していきます。

f:id:lib-arts:20201102004525p:plain

まず問題設定ですが、上記のように2次元の正規分布q_{1}(z_{1})q_{2}(z_{2})のように分解して変分推論を行うとしています。

f:id:lib-arts:20201102004814p:plain

f:id:lib-arts:20201102004832p:plain

(10.9)式に基づきq_{1}(z_{1})q_{2}(z_{2})についてそれぞれ最適解を求めると上記のようになります。(10.13)式と(10.15)式を交互に用いることで推論を行っていくことができます。ここでz_{1}z_{2}の期待値はそれぞれ、その時点でのm_{1}m_{2}に該当しています。

少し数式が並んで大変なので、以下具体的にPython実装で確認していきましょう。

f:id:lib-arts:20201102005327p:plain

まず、問題設定としては上記のように2次元の正規分布を考えるとします。平均と共分散行列を与えており、共分散に0.7が入っていることで分布が傾いていることが確認できるかと思います。

f:id:lib-arts:20201102005546p:plain

続いて上記はm_{1}m_{2}に初期値の0を与えた上で、(10.13)式と(10.15)式を交互に実行することであたいの計算を行った結果です。m_{1}が0.8、m_{2}が-0.2に収束することで2次元の正規分布と同じ中心を求めることができたと確認ができるかと思います。

このように、分布の分解を用いた変分推論では、各次元に対し交互に値をの更新を行うことで元の確率分布pをよりシンプルな分布であるqで近似を行います。

 

3. まとめ
#5では変分推論の概要と、分布の分解を用いた確率分布の近似の手法、2次元の正規分布を題材にした実装の確認を行いました。詳しく式の導出についてまでは追いませんでしたが、大体の流れはつかめたかと思います。