公式Tutorialに学ぶPyTorch④(Reinforcement Learning)|DeepLearningの実装 #12

f:id:lib-arts:20190707133205j:plain

連載経緯は#1をご確認ください。

#1はKeras、#2~#7まではTensorFLow、#8からはPyTorchを取り扱っています。

#8ではPyTorchの概要やインストール、簡易実行について、#9はAutograd、#10ではNeural Network、#11ではTraining a Classifierについて取り扱いました。

公式ドキュメントやチュートリアルを元にPyTorchの概要を把握する|DeepLearningの実装 #8 - lib-arts’s diary

公式Tutorialに学ぶPyTorch①(Tutorialの全体像&Autograd)|DeepLearningの実装 #9 - lib-arts’s diary

公式Tutorialに学ぶPyTorch②(Neural Network)|DeepLearningの実装 #10 - lib-arts’s diary

#12では強化学習について取り扱っている、Reinforcement Learning Tutorialについて取り扱います。

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 1.1.0 documentation

以下目次になります。
1. Introduction
2. Replay Memory
3. DQN algorithm
3-1. Q-network
3-2. Input extraction
4. Training
4-1. Hyperparameters and utilities
4-2. Training loop
5. まとめ


1. Introduction
まずは冒頭部について見ていきます。

f:id:lib-arts:20190707133618p:plain

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 1.1.0 documentation
冒頭部では、今回取り扱う対象であるOpenAI Gymで実装されたCartPole問題について主にまとめられています。CartPole-v0の環境(environment)については下記の記事でルールベースの解法とともに詳しくまとめているので、こちらをご確認ください。

実装にあたっては下記のパッケージを読み込む(import)とされています。

f:id:lib-arts:20190707134515p:plain

実際の実行のコードは下記のように記載されているので、こちらを実行します。

f:id:lib-arts:20190707134533p:plain

冒頭部については主にCartPole-v0についてでしたが、以前の記事で詳しく取り扱っていたため、詳細は省略しました。


2. Replay Memory
2節ではReplay Memoryについて取り扱っていきます。これはDeep Q-Networkにおいて行われた工夫の一つであるexperience replayについてを意味しています。

f:id:lib-arts:20190707134906p:plain

上記において、色をつけた部分が重要なので、簡単に和訳しておきます。

experience replay memoryはエージェントが観測した遷移(transition)を蓄え、このデータを後で再利用できるようにする。experience replay memoryからランダムにサンプリングすることによって、学習用のバッチにおける相関をなくす(decorrelated)ことができる。このことによってDeep Q-Networkの学習手順(training procedure)の安定と改善を可能にすることが示されている。

和訳を解釈すると、Replay Memoryを用い、複数のepisode(一回のゲームの一連の実行をepisodeと呼んでいます)からランダムサンプリングすると、学習データ間の相関をなくすことができます。遷移(transition)の情報は1episode内だとどうしても前後のサンプルで相関が生まれてしまうので、これを回避するためにexperience replayが用いられています。
また、実装にあたっては、TransitionとReplayMemoryについて言及されています。実際の実装は下記のようになっています。

f:id:lib-arts:20190707140337p:plain

チュートリアルに従い上記を実行し、次の内容に移ります。


3. DQN algorithm
Deep Q-Networkのアルゴリズムについて言及されています。
https://lib-arts.hatenablog.com/entry/rl_trend6
大枠は上記で言及しているため省略します。

f:id:lib-arts:20190707142426p:plain
TD誤差である\deltaとそれを用いて計算するHuber lossについては取り扱っていなかったので、簡単に補足しておきます。Huber lossは誤差が小さい時は二乗誤差を計算し、大きい時は絶対値を用いるというものです。これによって推定したQがnoisy(ノイズが大きい)だったとしても、安定して学習を行うことが可能になるとされています。
数式だけだとイメージがつきづらいので、下記を実行することでHuber lossについて可視化することができます。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(-3,3,0.01)
y_square = (x**2)/2
y_huber = np.zeros(600)

abs_x = np.abs(x)
for i in range(x.shape[0]):
if abs_x[i]>1:
y_huber[i]=abs_x[i] - 1/2
else:
y_huber[i]=(abs_x[i]**2)/2

plt.plot(x,y_square,label="y_square")
plt.plot(x,y_huber, color="orange",label="y_huber")
plt.legend()
plt.show()

実行結果は下記のようになります。

f:id:lib-arts:20190707143337p:plain

上記を確認することによって、Huber lossを導入することで外れ値(outliers)の影響を受けにくくすることができます。


3-1. Q-network
3-1ではQ-Networkの実装部分について確認していきます。

f:id:lib-arts:20190707143540p:plain

上記が冒頭部の説明ですが、こちらでQ-Networkの実装の概要が言及されています。畳み込みニューラルネットワーク(CNN; Convolutional Neural Network)を用いて実装しており、状態を表すs(state)をネットワークの入力とすることで、出力としてQ(s,left)Q(s,right)を得るとあります。

f:id:lib-arts:20190707144104p:plain

実装は、チュートリアルからコピーして上記のように実行できます。#9〜#11の内容をおさえていればネットワーク自体はオーソドックスな畳み込みニューラルネットワークの実装になっていることがわかります。


3-2. Input extraction
3-2ではInput extractionについて確認します。

f:id:lib-arts:20190707144628p:plain

上記によると、ゲームを実行し、環境から画像を生成するとあります。この際に画像の変換を容易に行うためのtorchvisionパッケージを利用するとあります。

f:id:lib-arts:20190707144718p:plain

(中略)

f:id:lib-arts:20190707144736p:plain
実装は、チュートリアルからコピーして上記のように実行できます。ここで注意なのが、OpenAI Gymで実装されたCartPole-v0より得られるobservationは位置、速度、Poleの角度、Poleの速度の4値なのですが、それとは別にget_screen()を用いて画像を生成しているということです。

f:id:lib-arts:20190707145729p:plain

get_screen()によって得られる画像に関しては上記のように、.shapeを出力することで、配列の大きさを確認することができます。


4. Training
4-1. Hyperparameters and utilities
4-1ではハイパーパラメータやユーティリティについて実装されています。

f:id:lib-arts:20190707150006p:plain

概要については上記で記載されています。select_actionとplot_durationsについて実装の概要が書かれています。select_actionはε-greedyに基づく行動選択、plot_durationsについては可視化について実装されています。

f:id:lib-arts:20190707152718p:plain
またハイパーパラメータについても上記のように設定されています。バッチサイズが128、割引率の\gammaが0.999で設定されていると把握しておくと良さそうです。実装についてはチュートリアルのコードをコピーして実行すれば良いです。


4-2. Training loop
4-2ではTraining loopについて取り扱います。

f:id:lib-arts:20190707160414p:plain

まず、Training loopではoptimize_modelという関数を実装しています。このoptimize_modelでは学習の1ステップの実装を行なっています。

f:id:lib-arts:20190707173710p:plain

上記ではHuber lossの実装や、最適化についてのloss.backward()やoptimizer.step()が記述されています。Huber lossを実装するにあたって引数として与えている、state_action_valueはpolicy_net(Online Deep Q-Network)からの出力、expected_state_action_valuesはtarget_net(Target Deep Q-Network)からの出力となっています。これはDeep Q-Networkの学習を安定させるために二つのDeep Q-Networkを作成するということを反映しており、それぞれを同じネットワーク構造にした上で定期的にOnline Deep Q-NetworkのパラメータをTarget Deep Q-Networkにコピーします。

f:id:lib-arts:20190707174323p:plain

(中略)

f:id:lib-arts:20190707174436p:plain

optimize_model()は上記の繰り返し(Iteration)処理の中で呼び出されています。この繰り返し処理の中で、Q-Networkのパラメータが更新されていきます。


5. まとめ
#12ではPyTorchチュートリアルのDeep Q-Networkについて取り扱いました。
続く#13,#14では同じくPyTorchチュートリアルより、Object Detection Finetuningについて取り扱えればと思います。

TorchVision 0.3 Object Detection Finetuning Tutorial — PyTorch Tutorials 1.1.0 documentation