公式Tutorialに学ぶPyTorch③(Training a Classifier)|DeepLearningの実装 #11

f:id:lib-arts:20190620151942p:plain

連載経緯は#1をご確認ください。

#1はKeras、#2~#7まではTensorFLow、#8からはPyTorchを取り扱っています。

#8ではPyTorchの概要やインストール、簡易実行について、#9はAutograd、#10ではNeural Networkについて取り扱いました。
https://lib-arts.hatenablog.com/entry/implement_dl8
https://lib-arts.hatenablog.com/entry/implement_dl9
https://lib-arts.hatenablog.com/entry/implement_dl10
#11では引き続き"Deep Learning with PyTorch: A 60 Minute Blitz"より"Training a Classifier"について取り扱います。

Training a Classifier — PyTorch Tutorials 1.1.0 documentation
以下目次になります。
1. What about data?
2. Training an image classifier
2-1. Loading and normalizing CIFAR10
2-2. Define a Convolutional Neural Network
2-3. Define a Loss function and optimizer
2-4. Train the network
2-5. Test the network on the test data
3. まとめ


1. Training a Classifierチュートリアルについて
1節ではTraining a Classifierチュートリアルの概要について話を進めていきます。

f:id:lib-arts:20190625154853p:plain

Training a Classifier — PyTorch Tutorials 1.1.0 documentation

概要を知るにあたり簡単に要約を行います。

要約:
ニューラルネットワークの定義をして、誤差関数を計算し、ネットワークの重みを更新するというのについては見てきました。

概要がつかめたところで次に気になるのは『データはどうなのか?(What about data?)』ということです。

一般的に画像やテキスト、音声、ビデオなどのデータを取り扱う際には、NumPy配列の形式でデータをロードする標準的なPythonのパッケージを用いることができます。その後に、torch.*Tensorの形式に配列を変換できます。

・画像データに対してはPillowやOpenCVなどのパッケージが便利です。
・音声データに対してはSciPyやlibrosaがあります。
・テキストデータに関してはデフォルトのPythonやCython、NLTKやSpaCyなどが便利です。

画像データに関しては、torchvisionというパッケージがあり、ImagenetやCIFAR10、MNISTなどの一般的なデータセットのデータローダーが実装されています。また画像の変換にあたってはtorchvision.datasetsやtorch.utils.data.DataLoaderなどがあります。

これにより重要度が低いテンプレ的なコードを書くことを減らすことができ、非常に便利になります。


また、下記のようにチュートリアルではCIFAR10のデータセットを用いるとされています。

f:id:lib-arts:20190625160101p:plain
CIFAR10は飛行機、自動車、鳥、猫、鹿、犬、蛙、馬、船、トラックの10クラスのデータを持つデータセットです。データサイズは32×32×3と、28×28のMNISTと同様のサイズとなっています。写真がベースの画像のため、RGBのカラーとなるように×3がされています。
"What about data?"について大体の概要はつかめたので、次の2節では"Training an image classifier"の内容を確認していきます。


2. Training an image classifier
2節では"Training an image classifier"の内容について確認していきます。

f:id:lib-arts:20190625160827p:plain
まずは簡単に冒頭部を訳します。

要約:
以下のステップを順番に踏みます。
1) torchvisionを用いてCIFAR10の学習用と検証用のデータセットの読み込みと正規化を行う。
2) 畳み込みニューラルネットワーク(CNN)を定義する。
3) 誤差関数を定義する。
4) 学習用データを用いてニューラルネットワークを学習させる。
5) テストデータを用いてニューラルネットワークの検証を行う。

以下2-1〜2-5でそれぞれについて確認していきます。


2-1. Loading and normalizing CIFAR10
2-1ではtorchvisionを用いたCIFAR10のデータロードに関して確認していきます。チュートリアルに従って動かすことで、下記のようにデータのロード(初回時でデータが手元にない時は自動でダウンロードを行ってくれます)を行うことができます。

f:id:lib-arts:20190625161646p:plain
実行結果が上記になっています。データセットが手元にない際は上記のようにダウンロード関連の表示もされるように実装されています。この際、"The output of torchvision datasets are PILImage images of range [0, 1]. We transform them to Tensors of normalized range [-1, 1]."とも言及されているように、PILImageの形式だと[0,1]の値におさまるのに対して、変換を行なって[-1,1]の範囲になるようにしているとされています。このことは下記のように画像を表示する際などに注意が必要です。

f:id:lib-arts:20190625162406p:plain

上記のようにすることで画像やラベルの確認を行うことができます。CIFAR10は32×32×3のデータセットのため、大きく表示するとところどころがぼやけてしまうことも把握しておくと良いかと思われます。
Loadingとnormalizingについては概ねつかめたので、次の2-2では"Define a Convolutional Neural Network"について取り扱います。


2-2. Define a Convolutional Neural Network
#10で取り扱ったNeural Networkとほぼ同様のため、ほとんどソースのみの記載となっています。入力のチャネルがカラー画像のため、3チャネルにしたと書かれています。

f:id:lib-arts:20190625162919p:plain
実行結果は上記のようになります(print文だけ付け加えておきました)。


2-3. Define a Loss function and optimizer

f:id:lib-arts:20190625163102p:plain
こちらもほぼ同様なので、上記を実行する形で良さそうです。


2-4. Train the network
"Train the network"では実際に学習を行なっていきます。

f:id:lib-arts:20190625163650p:plain

記載のコードの実行を行うことで、上記のような学習を行います。誤差関数(loss)が学習ステップが進むにつれてだんだん減っているのがわかります。ニューラルネットワークの学習によるパラメータの調整は勾配などの情報を用いて繰り返し(iteration)の計算でアップデートしていくので、optimizer.stepはfor文の中で実行していることに注意です。


2-5. Test the network on the test data
2-5ではテストデータを用いたニューラルネットワークの検証を行っています。まずは全体の正答率(accuracy)を下記のように計算しています。

f:id:lib-arts:20190625164314p:plain
また、それぞれのラベルにおける正答率も下記のように得ることができます。

f:id:lib-arts:20190625164329p:plain
このようにテストデータで検証することで、汎用性について確認することができます。また、カテゴリごとの正答率を出すことで、誤分類に関する差について考察することもできます。


3. まとめ
#11ではTraining a Classifierということで、#10で取り扱ったNeural Networkを用いながらPyTorchデータのロードに関して確認してきました。

Welcome to PyTorch Tutorials — PyTorch Tutorials 1.1.0 documentation

こちらをもって"Deep Learning with PyTorch: A 60 Minute Blitz"の確認は一旦ここまでとし、#12以降は上記のチュートリアルよりまた違った題材を取り扱えればと思います。