対数尤度における指数型分布族を考える|Python実装で理解する変分推論(VariationalInference) #Appendix2

f:id:lib-arts:20201022172002p:plain

#3、#4ではEMアルゴリズムについて取り扱いましたが、例として出てくる正規分布などを一般化した分布である指数型分布族についてもう少し考えておくと良いと思われたため、Appendix2としてまとめます。
以下、目次になります。
1. EMアルゴリズムにおける指数型分布族の記述について
2. 指数型分布族と対数尤度
3. まとめ


1. EMアルゴリズムにおける指数型分布族の記述について
1節ではEMアルゴリズムにおける指数型分布族の記載について簡単に確認しようと思います。

f:id:lib-arts:20201024191403p:plain

f:id:lib-arts:20201024191420p:plain

1つ目の記載は(9.29)式の後になります。「同時分布(joint distribution)のp(X|θ)におけるsum(和; Σ)によってlogarithm(log)が直接作用することを妨げる」と記載されています。混合分布の割合を表すπにより、#4の(9.14)式→(9.16)式のような計算が複雑になるということについて示唆しています。

f:id:lib-arts:20201024191433p:plain

2つ目の記載は(9.36)式の後になります。(9.36)式ではlogarithmが直接正規分布(Gaussian distribution)や指数型分布族(exponential family)に作用すると記載されています。(9.14)式の計算(主に微分)が難しい一方で、(9.36)式の計算が比較的容易であるともされています。

f:id:lib-arts:20201024191450p:plain

3つ目の記載は(9.74)式の後になります。同時分布のp(Z,X|θ)が指数型分布族やその積で表される時、logarithmは指数型分布族のexponentialと打ち消し合い、Mステップの計算が容易になると記載されています。EMアルゴリズムにおいて、Eステップはサンプルの割り当てを変え、Mステップがパラメータの最適化を行なっているので、Mステップの計算は基本的な流れでの最尤法と似たような式変形となります。なので、簡単なモデリングにおける最尤法やそれと同様のMステップにおいて、確率分布に指数型分布族を仮定することで確率分布におけるexponentialと対数尤度を考える際のlogarithmが打ち消しあうことで計算がスムーズになると理解して良いと思います。

1節では参照テキストのChapter9における指数型分布族(exponential family)の記載について確認してきましたが、「簡単なモデリングにおける対数尤度を計算するにあたって指数型分布族を確率分布として設定することで計算を簡単にすることができる」ということをここでは抑えておいてください。計算については2節で詳しく確認します。

 

2. 指数型分布族と対数尤度
2節では指数型分布族と対数尤度について確認します。

https://www.amazon.co.jp/dp/B08FYMTYBW

指数型分布族の記載については上記の2-2節に関連の式変形を載せていますので、こちらも合わせて参照ください。

まずは定義式から確認します。
P(x|\theta) = exp\{a(x)b(\theta) + c(\theta) + d(x)\}
指数型分布族は上記のように表すことのできる確率分布とされています。観測値の集合をX=\{ x_1, x_2, ..., x_n \}とした際に、尤度は同時確率をパラメータについて着目したものと考えることができるので下記のように表すことができます。
L(\theta) = P(X|\theta) = P(x_1, x_2, ..., x_n|\theta)
= P(x_1|\theta) \times P(x_2|\theta) \times ... \times P(x_n|\theta)
上記は一般的な表記ですが、ここにP(x_k|\theta) = exp\{a(x_k)b(\theta) + c(\theta) + d(x_k)\}を代入します。
L(\theta) = P(X|\theta) = exp\{a(x_1)b(\theta) + c(\theta) + d(x_1)\} \times ... \times exp\{a(x_n)b(\theta) + c(\theta) + d(x_n)\}
\displaystyle = exp \left( \sum_{k=1}^{N} \{a(x_k)b(\theta) + c(\theta) + d(x_k) \} \right)
指数型分布族に関する尤度L(θ)は上記のようになり、指数関数の中にΣがあるような式の形となります。
\displaystyle log L(\theta) = \sum_{k=1}^{N} \{a(x_k)b(\theta) + c(\theta) + d(x_k) \}
これに対する対数尤度は上記になります。なんとlogとexpが打ち消しあって、非常にシンプルな式として表記することができます。
\displaystyle l_k = a(x_k)b(\theta) + c(\theta) + d(x_k)
ここで上記のようにl_kを定義することで、\displaystyle \frac{dl_k}{d \theta}を計算した上で\displaystyle \sum_{k=1}^{N} \frac{dl_k}{d \theta} = 0を解くと、非常にスムーズに確率分布のパラメータを求めることができます。
\displaystyle log L(\theta) = \sum_{k=1}^{N} \{a(x_k)b(\theta) + c(\theta) + d(x_k) \}
指数型分布族を考えるにあたっては、上記の対数尤度の式が非常にシンプルなので、こちらを計算した上で\displaystyle l_k = a(x_k)b(\theta) + c(\theta) + d(x_k)を設定し、a〜dに用いる指数型分布族の式の形を代入するのが計算としてやりやすいと思います。
\displaystyle a(x_k)=x_k
\displaystyle b(\mu)=\frac{\mu}{\sigma^2}
\displaystyle c(\mu)=-\frac{\mu^2}{2 \sigma^2}-\frac{1}{2}log(2 \pi \sigma^2)
\displaystyle d(x_k)=-\frac{x_{k}^2}{2 \sigma^2}
例えば正規分布について、平均の\mu\thetaとして考える際はa〜dは上記のように与えることができます。

対数尤度の導出がしんどいと感じるようでしたら、先に指数型分布族の式として対数尤度まで計算した上で後からa〜dを用いるというのも計算をシンプルに行うためのテクニックとしてありだと思います。


3. まとめ
Appendix2では指数型分布族の式における対数尤度を考えることで、式をシンプルに保ったまま式変形が行えることについてまとめました。式変形が複雑だと理解できるはずの内容まで理解できなくなるので、なるべくシンプルに記述することは様々な場面で役に立つのではないかと思います。