公式Tutorialに学ぶPyTorch④(Reinforcement Learning)|DeepLearningの実装 #12
連載経緯は#1をご確認ください。
#1はKeras、#2~#7まではTensorFLow、#8からはPyTorchを取り扱っています。
#8ではPyTorchの概要やインストール、簡易実行について、#9はAutograd、#10ではNeural Network、#11ではTraining a Classifierについて取り扱いました。
公式ドキュメントやチュートリアルを元にPyTorchの概要を把握する|DeepLearningの実装 #8 - lib-arts’s diary
公式Tutorialに学ぶPyTorch①(Tutorialの全体像&Autograd)|DeepLearningの実装 #9 - lib-arts’s diary
公式Tutorialに学ぶPyTorch②(Neural Network)|DeepLearningの実装 #10 - lib-arts’s diary
#12では強化学習について取り扱っている、Reinforcement Learning Tutorialについて取り扱います。
Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 1.1.0 documentation
以下目次になります。
1. Introduction
2. Replay Memory
3. DQN algorithm
3-1. Q-network
3-2. Input extraction
4. Training
4-1. Hyperparameters and utilities
4-2. Training loop
5. まとめ
1. Introduction
まずは冒頭部について見ていきます。
Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 1.1.0 documentation
冒頭部では、今回取り扱う対象であるOpenAI Gymで実装されたCartPole問題について主にまとめられています。CartPole-v0の環境(environment)については下記の記事でルールベースの解法とともに詳しくまとめているので、こちらをご確認ください。
実装にあたっては下記のパッケージを読み込む(import)とされています。
実際の実行のコードは下記のように記載されているので、こちらを実行します。
冒頭部については主にCartPole-v0についてでしたが、以前の記事で詳しく取り扱っていたため、詳細は省略しました。
2. Replay Memory
2節ではReplay Memoryについて取り扱っていきます。これはDeep Q-Networkにおいて行われた工夫の一つであるexperience replayについてを意味しています。
上記において、色をつけた部分が重要なので、簡単に和訳しておきます。
experience replay memoryはエージェントが観測した遷移(transition)を蓄え、このデータを後で再利用できるようにする。experience replay memoryからランダムにサンプリングすることによって、学習用のバッチにおける相関をなくす(decorrelated)ことができる。このことによってDeep Q-Networkの学習手順(training procedure)の安定と改善を可能にすることが示されている。
和訳を解釈すると、Replay Memoryを用い、複数のepisode(一回のゲームの一連の実行をepisodeと呼んでいます)からランダムサンプリングすると、学習データ間の相関をなくすことができます。遷移(transition)の情報は1episode内だとどうしても前後のサンプルで相関が生まれてしまうので、これを回避するためにexperience replayが用いられています。
また、実装にあたっては、TransitionとReplayMemoryについて言及されています。実際の実装は下記のようになっています。
チュートリアルに従い上記を実行し、次の内容に移ります。
3. DQN algorithm
Deep Q-Networkのアルゴリズムについて言及されています。
https://lib-arts.hatenablog.com/entry/rl_trend6
大枠は上記で言及しているため省略します。
TD誤差であるとそれを用いて計算するHuber lossについては取り扱っていなかったので、簡単に補足しておきます。Huber lossは誤差が小さい時は二乗誤差を計算し、大きい時は絶対値を用いるというものです。これによって推定したQがnoisy(ノイズが大きい)だったとしても、安定して学習を行うことが可能になるとされています。
数式だけだとイメージがつきづらいので、下記を実行することでHuber lossについて可視化することができます。
import numpy as np
import matplotlib.pyplot as pltx = np.arange(-3,3,0.01)
y_square = (x**2)/2
y_huber = np.zeros(600)abs_x = np.abs(x)
for i in range(x.shape[0]):
if abs_x[i]>1:
y_huber[i]=abs_x[i] - 1/2
else:
y_huber[i]=(abs_x[i]**2)/2plt.plot(x,y_square,label="y_square")
plt.plot(x,y_huber, color="orange",label="y_huber")
plt.legend()
plt.show()
実行結果は下記のようになります。
上記を確認することによって、Huber lossを導入することで外れ値(outliers)の影響を受けにくくすることができます。
3-1. Q-network
3-1ではQ-Networkの実装部分について確認していきます。
上記が冒頭部の説明ですが、こちらでQ-Networkの実装の概要が言及されています。畳み込みニューラルネットワーク(CNN; Convolutional Neural Network)を用いて実装しており、状態を表すs(state)をネットワークの入力とすることで、出力としてとを得るとあります。
実装は、チュートリアルからコピーして上記のように実行できます。#9〜#11の内容をおさえていればネットワーク自体はオーソドックスな畳み込みニューラルネットワークの実装になっていることがわかります。
3-2. Input extraction
3-2ではInput extractionについて確認します。
上記によると、ゲームを実行し、環境から画像を生成するとあります。この際に画像の変換を容易に行うためのtorchvisionパッケージを利用するとあります。
(中略)
実装は、チュートリアルからコピーして上記のように実行できます。ここで注意なのが、OpenAI Gymで実装されたCartPole-v0より得られるobservationは位置、速度、Poleの角度、Poleの速度の4値なのですが、それとは別にget_screen()を用いて画像を生成しているということです。
get_screen()によって得られる画像に関しては上記のように、.shapeを出力することで、配列の大きさを確認することができます。
4. Training
4-1. Hyperparameters and utilities
4-1ではハイパーパラメータやユーティリティについて実装されています。
概要については上記で記載されています。select_actionとplot_durationsについて実装の概要が書かれています。select_actionはε-greedyに基づく行動選択、plot_durationsについては可視化について実装されています。
またハイパーパラメータについても上記のように設定されています。バッチサイズが128、割引率のが0.999で設定されていると把握しておくと良さそうです。実装についてはチュートリアルのコードをコピーして実行すれば良いです。
4-2. Training loop
4-2ではTraining loopについて取り扱います。
まず、Training loopではoptimize_modelという関数を実装しています。このoptimize_modelでは学習の1ステップの実装を行なっています。
上記ではHuber lossの実装や、最適化についてのloss.backward()やoptimizer.step()が記述されています。Huber lossを実装するにあたって引数として与えている、state_action_valueはpolicy_net(Online Deep Q-Network)からの出力、expected_state_action_valuesはtarget_net(Target Deep Q-Network)からの出力となっています。これはDeep Q-Networkの学習を安定させるために二つのDeep Q-Networkを作成するということを反映しており、それぞれを同じネットワーク構造にした上で定期的にOnline Deep Q-NetworkのパラメータをTarget Deep Q-Networkにコピーします。
(中略)
optimize_model()は上記の繰り返し(Iteration)処理の中で呼び出されています。この繰り返し処理の中で、Q-Networkのパラメータが更新されていきます。
5. まとめ
#12ではPyTorchチュートリアルのDeep Q-Networkについて取り扱いました。
続く#13,#14では同じくPyTorchチュートリアルより、Object Detection Finetuningについて取り扱えればと思います。
TorchVision 0.3 Object Detection Finetuning Tutorial — PyTorch Tutorials 1.1.0 documentation