M2Detの著者実装を読み解く|物体検出(Object Detection)の研究トレンドを俯瞰する #4

f:id:lib-arts:20190709005211p:plain

物体検出の研究については以前に論文読解で、FasterRCNNやYOLO、SSD、RetinaNetについて取り扱ったのですが、改めて研究トレンドや考え方の推移についてまとめられればということで新規でシリーズを作成させていただきました。
#1ではHOG(Histograms of Oriented Gradient)[2005]からR-CNN[2013]までについて、#2ではFast R-CNN、FasterRCNN、YOLO、SSDについて、#3ではFPN、RetinaNet、M2Detについて取り扱いました。

#3までで大体の概要がつかめたので、#4以降では実装の確認について行っていければと思います。#4ではM2DetのPyTorchを用いた著者実装が公開されているので、こちらを確認していきます。

GitHub - qijiezhao/M2Det: M2Det: A Single-Shot Object Detector based on Multi-Level Feature Pyramid Network

以下目次になります。
1. リポジトリ概要
2. 読解の目標設定&PyTorchの簡単な復習
3. Neural Netのアーキテクチャの定義とlossや最適化の実装部分の確認
4. まとめ


1. リポジトリ概要
1節ではリポジトリの概要について確認していきます。

f:id:lib-arts:20190709005901p:plain

GitHub - qijiezhao/M2Det: M2Det: A Single-Shot Object Detector based on Multi-Level Feature Pyramid Network
上記が論文にもリンクがある、著者実装のリポジトリです。

f:id:lib-arts:20190709010126p:plain
以下README.mdを確認していきます。ContentsにREADME.mdの目次が載っています。概要を掴むにあたっては、Introduction、Demo、Trainingについて確認すると良さそうなので下記ではそれらを中心に確認していきます。

 

・Introduction
まずIntroductionですが、実装に関連する論文の内容を抜粋し要約した形になっています。

f:id:lib-arts:20190709010827p:plain

まず実装のモチベーション(Motivation)として、上記を見ていきます。(Feature Pyramid Networkで取り扱っている)物体の大きさの変化(scale variation)だけでなく、複雑さの変化(ACV; Appearance-Complexity Variation)についても取り扱えるようにとあります。理由として、同じ大きさの物体でも複雑さがきわめて異なることがあるということがあります。
これを解決するにあたって、Feature Pyramid Networkで取り扱っているmulti-scaleだけではなく、(ネットワークの深さを表す)multi-levelという視点(dimension)を導入したとあります。ネットワークの深い層(deeper level)がACVの大きな歩行者(pedestrian)の認識を学習するのに対し、浅い層(shallower level)がACVの小さな信号(traffic light)を学習するとされています。

f:id:lib-arts:20190709005211p:plain

また、multi-levelについては上記で表現されています。2でMLFPN(Multi-Level Multi-Scale Detector)のアーキテクチャの全容が載っています。FFMやTUMなどのそれぞれの処理の詳細についてもMethodologyでモジュール単位の処理が記載されていますが、こちらは2節で扱うとしてここでは流します。

 

・Demo
次にDemoについて確認していきます。

f:id:lib-arts:20190709012426p:plain

Demoでは物体検出のデモの実行手順について記載されています。ここで注意したいのが、実行時のオプションで与えているconfigとweightsです。vggはbackboneのネットワーク構造として有名で大元の論文でも用いられているVGGNetを意味しています。また、weightsの拡張子として与えられている.pthはPyTorchのモデルを意味しているので、学習されたモデルを指定しているということがわかります。
また、実行結果としてbounding boxで物体が囲まれているというのも確認できます。

 

・Training
Trainingでは学習にあたっての手順が記載されています。

f:id:lib-arts:20190709013038p:plain

train.pyを実行していますが、configなどは同じものを使っているので、Demo実行時と大体同じモジュールを使っていることが推測できます。

 

ここまででリポジトリの概要は一通りつかめたので1節はここまでとします。


2. 読解の目標設定&PyTorchの簡単な復習
2節では読解の目標設定とPyTorchに関しての簡単な復習を行います。まず読解にあたっては、論文記載の大まかな処理の部分について把握したいので、下記の目標を掲げます。

・Multi-Level Feature Pyramid Networkの実装の全体像の把握
・Construct the base featureの実装の把握(Methodology)
・The Multi-level Multi-scale featureの実装の把握(Methodology)
・Scale-wise Feature Aggregation Moduleの実装の把握(Methodology)
・lossの計算部分の把握
・最適化の実装部分の把握

IntroductionのMethodologyの実装部分を中心に読んでいくので、先にMethodologyを確認していきます。

f:id:lib-arts:20190709014123p:plain

まずはConstruct the base featureです。ここではFFMv1(Feature Fusion Module v1)について言及がなされています。言葉の通り特徴量(Feature)を混ぜ合わせる(Fusion)ということを意味していますが、VGGNetのconv4-3にあたる(W/4,H/4)とconv6-2にあたる(W/8,H/8)を混ぜ合わせるとあります。これによって(768,40,40)のbase featureを構築するとされています。

f:id:lib-arts:20190709014138p:plain
次にThe Multi-level Multi-scale featureを見ていきます。こちらでは主にTUM(Thinned U-shaped Module)やFFMv2について言及されています。

f:id:lib-arts:20190709014151p:plain

次にScale-wise Feature Aggregation Moduleを確認します。こちらではSFAMという特徴量の集約部分について言及されています。
ここまでのMethodologyを中心に、lossや最適化についても実装の確認を行っていければと思います。

 

大体の目標がわかったところで、簡単にPyTorchの復習をしておきます。上記目標を元に考えるのであれば基本的にモデルのアーキテクチャやlossの定義と最適化の設定についてがわかれば良さそうなので、下記を元に簡単に復習をしておきます。

基本的にはネットワークをクラスで定義し、インスタンスとして構築したネットワークを用いてlossを定義し、loss.backward()で誤差逆伝播を行い、optimizer.step()で重みの更新を行うと考えておくと良さそうです。したがって、3節ではこちらを意識した上で実装を読み解いていきます。


3. Neural Netのアーキテクチャの定義とlossや最適化の実装部分の確認
3節では実装を確認していきます。まずは学習の実装ファイルのtrain.pyを確認します。

M2Det/train.py at master · qijiezhao/M2Det · GitHub

f:id:lib-arts:20190709020238p:plain

(中略)

f:id:lib-arts:20190709020254p:plain

上記がtrain.pyのmain部分の処理になります。loss.backward()やoptimizer.step()があることからこの周辺を探すとlossの実装が確認できます。また、lossの引数に用いられているoutを見ると、net(images)が代入されているので、netを確認すると良いというのがわかります。

f:id:lib-arts:20190709020733p:plain

netは上記の34行目でインスタンス生成されています。またこの実行の関数のモジュールは10行目でインポートされています。m2det.pyのbuild_net関数を読み込んでいることが読み取れるので、m2det.pyについて確認していきます。

M2Det/m2det.py at master · qijiezhao/M2Det · GitHub

上記がm2det.pyのファイルになります。

f:id:lib-arts:20190709021117p:plain

build_net関数については、187行目以降に実装されています。194行目で返り値としてM2Detでインスタンスを返しています。ちなみにこの際のsizeの320はデフォルトの画像サイズで、FFMv1について論文で言及されている40や20はこの数字を表していると思われます。また、ここに512があるのは実行にあたって設定するconfigファイルのconfigs/m2det512_vgg.pyの512であることも同時に推測できます。

f:id:lib-arts:20190709021325p:plain

M2Detについては同一ファイルの26行目以降で実装されています。

f:id:lib-arts:20190709153505p:plain

順伝播(推論)部分はforwardで実装されているため、101行目以降を確認します。

f:id:lib-arts:20190709153522p:plain
まずは上記部分でself.baseによってbase_featsやbase_featureが生成されています。vggやresなどが条件分岐に出てきていることから、VGGNetやResNetを示しており、これがbackboneのネットワーク部分であることがわかります。

f:id:lib-arts:20190709153849p:plain

実際にself.baseの中身を確認すると上記のようにget_backboneとあるので、こちらがbackboneとなるネットワークであることがわかります。この中身まで見ると大変そうなので、一旦forwardの実装がある101行目に戻ります。

f:id:lib-arts:20190709154957p:plain

111行目のbase_featureにFFMv1(Feature Fusion Module)の実装があることがわかります。base_featsの二つのレイヤーから値を取り出し、第二引数はF.interpolate(self.up_reduce...)を用いてupsampling(特徴マップの拡大)を行っているようです。

f:id:lib-arts:20190709155244p:plain

次は115行目以降の処理です。こちらではself.num_levelsの数だけTUM(Thinned U-shaped Module)を追加しています。結果がtum_outsに随時追加されています。

f:id:lib-arts:20190709155643p:plain

tum_outsをtum_outs->sources->loc,conf->outputという流れで処理の記述を行っています。ここでlocとconfはそれぞれlocationとconfidenceでそれぞれbounding boxの座標とbounding boxの確信度を表しています。
ここまでで処理の全体概要と、Methodologyのa〜cで表されたモジュールについて確認できたので、あとはlossと最適化について確認していきます。もう一度train.pyに戻ります。

M2Det/train.py at master · qijiezhao/M2Det · GitHub

f:id:lib-arts:20190709020254p:plain

93行目でloss.backward()94行目でoptimizer.step()を実行しているので、あとはlossの実装を遡ります。88行目で計算したloss_lとloss_cを足し合わせてlossを作っていますが、これまでの流れからloss_lがlocationのloss、loss_cがconfidenceのlossであることが推測できます。ここでoutが87行目のnet(images)で定義されていることから、M2Detクラスのインスタンスであることがわかります。また、同じくcriterionに引数として与えているpriorsは処理を遡るとAnchorより作成されるデフォルトボックスの座標(SSD系のアルゴリズムはデフォルトボックスをネットワークのアウトプットを補正項とみなし補正を行います)であることが推測できます。また、targetsは遡るとnext(batch_iterator)でimageと一緒に生成されていたり、途中の処理でannotationを示唆するannoが使用されていることから、これが学習にあたっての正解データであると考えられます。

 

上記までで、2節で設定した今回の読解目標は一通りクリアできたので、3節はここまでとします。今回は詳しいところまで確認していないので「推測できます」や「思われます」など断定的な言い回しを避けましたが、ほかの可能性が考えづらいところでもあるので大まかな流れは外していないかと思われます。(だいたいconfidenceが97%くらいのイメージです。)


4. まとめ
#4ではM2Detの大まかな実装の流れについて確認しました。
だいたい把握できたので、#5以降ではFaster R-CNN、RetinaNet、Mask R-CNNなどについて実装されているDetectronリポジトリを確認していければと思います。

GitHub - facebookresearch/Detectron: FAIR's research platform for object detection research, implementing popular algorithms like Mask R-CNN and RetinaNet.