【MXNet】写真に写っている人を数える(Faster RCNN resnet101を用いた物体検出)

2020年12月24日記事を更新しました。

はじめに

f:id:touch-sp:20201018130207j:plain
たとえば上の写真に車が何台写っているかを数えてみる。

Pythonスクリプト

import mxnet as mx
from gluoncv import model_zoo, data, utils

url = 'https://cdn-ak.f.st-hatena.com/images/fotolife/t/touch-sp/20190814/20190814122423.jpg'
filename = 'cars.jpg'
utils.download(url, filename)

net = model_zoo.get_model('faster_rcnn_fpn_resnet101_v1d_coco', pretrained=True, root='./models')
net.reset_class(['car'], reuse_weights=['car'])

x, img = data.transforms.presets.rcnn.load_test(filename)

class_IDs, scores, bounding_boxs = net(x)

count = int(mx.nd.sum(scores[0]>0.5).asscalar())    

print(count)

結果

3

うまくカウントできている。
GluonCVを使えば非常に短いスクリプトでカウントできる。
たったの11行。しかもそのうち3行は画像の準備(ダウンロード)。
importとprintを除けば実質5行でカウントできている。

次のような警告がでるが無視して問題なし。

UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
        data: None
  input_sym_arg_type = in_param.infer_type()[0]

今回の動作環境

Ubuntu 18.04LTS (WSL2)
Python 3.7.5

certifi==2020.6.20
chardet==3.0.4
cycler==0.10.0
gluoncv==0.9.0b20201017
graphviz==0.8.4
idna==2.10
kiwisolver==1.2.0
matplotlib==3.3.2
mxnet==1.9.0b20201015
numpy==1.19.2
Pillow==8.0.0
pkg-resources==0.0.0
portalocker==2.0.0
pyparsing==3.0.0a2
python-dateutil==2.8.1
requests==2.24.0
scipy==1.5.3
six==1.15.0
tqdm==4.50.2
urllib3==1.25.10

インストールしたのは「mxnet」と「gluoncv」のみ。その他は勝手についてきた。

pip install mxnet==1.9.0b20201015 -f https://dist.mxnet.io/python/cpu
pip install gluoncv --pre

何がカウントできるか?

以下の3行を実行してみる。

from gluoncv import model_zoo
net = model_zoo.get_model('faster_rcnn_fpn_resnet101_v1d_coco', pretrained=True, root='./models')
print(net.classes)

すると以下のように出力される。

['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

80種類のカウントが可能である。
Pythonスクリプトの'car'の部分を任意に変更すればよい。

net.reset_class(['car'], reuse_weights=['car'])

補足(Ubuntu18.04にPython3.7を入れる方法)

こちらを参照して下さい。
touch-sp.hatenablog.com