変分推論の枠組みにおけるEMアルゴリズム|Python実装で理解する変分推論(VariationalInference) #3
当シリーズはPython実装を通して変分推論を理解していこうということで進めています。下記などを主に参照しています。
Pattern Recognition and Machine Learning | Christopher Bishop | Springer
#1、#2ではKLダイバージェンスやイェンセンの不等式について確認を行いました。
#3では変分推論の枠組みにおけるEMアルゴリズムを混合正規分布(Gaussian Mixture Models)を題材にしながら確認していきます。主にBishop[2006]のSection9.2〜10.1の内容を参照します。
以下、目次になります。
1. 基礎的な理論の確認(変分推論とEMアルゴリズム)
2. 混合正規分布とEMアルゴリズム
3. まとめ
1. 基礎的な理論の確認(変分推論とEMアルゴリズム)
1節では変分推論とEMアルゴリズムに関する基本的な理論について確認します。
変分推論を考えるにあたってまず抑えておきたいのが上記の数式です。(9.70)式におけるは対数尤度を表しており、通常の最尤推定(MLE; Maxmum Likelihood Estimation)は対数尤度をパラメータに対して最大化します。ここでは通常の最適化問題とは異なり、潜在変数(latent variables)のに対して確率分布のを導入することで、との和の形に対数尤度を変形しています。とについてはそれぞれ(9.71)と(9.72)のように表現されています。((9.70)の導出については基本的に前後の記載をそのまま用いているのでここでは省略します)
さて、ここでを考えるにあたって注意しておくと良いのが、はパラメータに関する関数であるのと同時に、確率分布に関する汎関数(functional)であるということです。汎関数は以前の記事でも述べましたが、関数を入力として値を返す関数です。(この時点で何かしらの変分法的な取り扱いが行われることがわかります。)
(9.70)式を図的なイメージで表すと上記のようになります。これは化学式などでエネルギーを考慮する際に用いる図と基本的には同じです。
EMアルゴリズムにおけるEステップの解釈は上記になります。パラメータの(混合正規分布だと、各正規分布の平均や分散に相当)を固定した上でに対してを最大化します。より具体的にはをに固定した状況におけるKLダイバージェンスを0にします。KLダイバージェンスを0にするにあたってはの中を全ての要素に対して1にする必要があるので、とを一致させる必要があります。こうすることでFigure9.12の青矢印のように固定したに関してlower boundを最大化させることができます。
一方Mステップの解釈は上記とされています。
パラメータ空間におけるEMアルゴリズムの処理概要は上記とされています。Eステップではにおいて赤と青の分布を一致させ、Mステップでは青の分布を最大にするを求め、としています。
大体の流れがつかめたので、2節では具体的に混合正規分布をここまでの流れにあてはめて理解していきます。
2. 混合正規分布とEMアルゴリズム
2節では混合正規分布を題材にEMアルゴリズムについて確認していきます。
まず、一般的な意味合いでのEMアルゴリズムの記載は上記のようになります。Mステップにおいて(9.33)式のはサンプルの負担率(responsibility)を固定した上での対数尤度と同様の役割をしています。一方、Eステップではを固定した上で負担率(responsibility)の再計算を行っています。
また、混合正規分布におけるEMアルゴリズムは上記のようになります。Mステップは負担率を固定した上で、パラメータに尤度の最大化を行うことで導出しています。
記載だけで理解するのは大変なので、以下、実装を通してイメージを掴みます。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
%matplotlib inlinex_1 = norm.rvs(loc=0, scale=1, size=600, random_state=None)
x_2 = norm.rvs(loc=-3, scale=1, size=300, random_state=None)
x_3 = norm.rvs(loc=5, scale=2, size=100, random_state=None)
observed_x = np.r_[x_1, x_2, x_3]plt.hist(observed_x, bins=20)
plt.show()
問題設定については上記のようなヒストグラムの観測が得られた際の背景にある混合分布の推定を行います。この分布は下記の混合正規分布からサンプリングしたのと同様な結果になります(累積分布関数の逆関数からサンプリングする考え方もありますが、簡易化のためこのようなサンプリングとしました)。
上記でイメージがつきやすいかと思います。
上記のように初期値を与え、下記を実行します(簡易化のため、分散は1で固定しました)。
pi = pi_init
mu = mu_init
sigma = sigma_init
gamma = np.zeros([observed_x.shape[0], 3])
for i in range(10):
# E-step
for j in range(observed_x.shape[0]):
for k in range(3):
gamma[j, k] = pi[k] * norm.pdf(observed_x[j], loc=mu[k], scale=sigma[k])
gamma[j, :] = gamma[j, :]/np.sum(gamma[j, :])
# M-step
N_k = np.zeros([3])
for k in range(3):
N_k[k] = np.sum(gamma[:, k])
pi[k] = N_k[k]/1000.
mu[k] = np.sum(gamma[:, k] * observed_x)/N_k[k]
print(N_k)
print(mu)
print("====")
実行結果を確認すると混合分布の推定が概ねうまくいっていることがわかります。
3. まとめ
#3では変分推論を念頭に置きながらEMアルゴリズムについて確認を行いました。
続く#4ではEMの全体像について確認し、論旨の流れを再度確認できればと思います。