【物体検出】MXNet 2.0(ベータ)でGluonCVの学習済みモデルを使用する

公開日:2021年10月16日
最終更新日:2022年9月14日

はじめに

前回MXNet 2.0(ベータ)でGluonCVの学習済み画像分類モデルを使用する方法を書きました。
touch-sp.hatenablog.com
今回はMXNet 2.0(ベータ)でGluonCVの学習済み物体検出モデルを使用する方法を書きます。


手順は前回と同様です。

手順

MXNet 1.xでモデルをdownloadしてexport

ここではGluonCVが必要です。

from gluoncv import model_zoo
from gluoncv.utils import export_block

net = model_zoo.get_model('yolo3_darknet53_voc', pretrained=True, root='models')
net.hybridize()

export_block('yolo3', net, preprocess=None, layout='CHW')

with open('detection_class_names.txt', 'w') as f:
    f.writelines('\n'.join(net.classes))

MXNet 2.0(ベータ)でモデルを読み込んで使用する

MXNet 2.0(ベータ)の環境にGluonCVは必要ありません。
MXNet以外にmatplotlibのインストールが必要です。

Method 1

GluonCVの「utils.viz.plot_bbox」を使わない方法です。
そのため結果表示のスクリプトが長くなってしまいます。

from mxnet import np, npx, gluon, image
from mxnet.gluon.data.vision import transforms
from matplotlib import pyplot as plt

device = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()

with open('detection_class_names.txt', 'r') as f:
    classes = [x.strip() for x in f.readlines()]

url = 'https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg'
img = gluon.utils.download(url)
x = image.imread(img)
x = image.resize_short(x, 512)

transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

net = gluon.SymbolBlock.imports("yolo3-symbol.json",['data'], "yolo3-0000.params")
net.reset_device(device)

class_IDs, scores, bounding_boxs = net(np.expand_dims(transformer(x), axis=0).to_device(device))

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.imshow(x.asnumpy())

colors = dict()
threshold = 0.8
for i, each_score in enumerate(scores[0]):
    if each_score < threshold: break
    score = '{:.3f}'.format(each_score.item())
    class_id = int(class_IDs[0][i].item())
    class_name = classes[class_id]
    if class_id not in colors:
        colors[class_id] = plt.get_cmap('hsv')(class_id / len(classes))
    xmin, ymin, xmax, ymax = [int(x) for x in bounding_boxs[0][i]]
    rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             fill=False,
                             edgecolor=colors[class_id],
                             linewidth=3.5)
    ax.add_patch(rect)
    ax.text(xmin, ymin - 2,
                '{:s} {:s}'.format(class_name, score),
                bbox=dict(facecolor=colors[class_id], alpha=0.5),
                fontsize=12, color='white')

plt.axis('off')        
plt.show()

Method 2

MXNet 2.0(ベータ)とGluonCVは共存不可能です。
GluonCVをインポートしようとするとエラーが出ます。
最小限のファイルのみをダウンロードして無理やりGluonCVの「utils.viz.plot_bbox」を使う方法です。

from mxnet import np, npx, gluon, image
from mxnet.gluon.utils import download
from mxnet.gluon.data.vision import transforms
from matplotlib import pyplot as plt

device = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()

download('https://raw.githubusercontent.com/dmlc/gluon-cv/master/gluoncv/utils/viz/bbox.py', 'viz/bbox.py')
download('https://raw.githubusercontent.com/dmlc/gluon-cv/master/gluoncv/utils/viz/image.py', 'viz/image.py')
from viz.bbox import plot_bbox

with open('detection_class_names.txt', 'r') as f:
    classes = [x.strip() for x in f.readlines()]

url = 'https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg'
img = download(url)
x = image.imread(img)
x = image.resize_short(x, 512)

transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

net = gluon.SymbolBlock.imports("yolo3-symbol.json",['data'], "yolo3-0000.params")
net.reset_device(device)

class_IDs, scores, bounding_boxs = net(np.expand_dims(transformer(x), axis=0).to_device(device))

ax = plot_bbox(x, bounding_boxs[0], scores[0], class_IDs[0], class_names=classes)
plt.axis('off')
plt.show()

結果


環境

GPUなし

Ubuntu 20.04LTS on WSL2
Python 3.8.10
certifi==2021.10.8
charset-normalizer==2.0.7
cycler==0.11.0
graphviz==0.8.4
idna==3.3
kiwisolver==1.3.2
matplotlib==3.4.3
mxnet==2.0.0b20211105
numpy==1.21.4
Pillow==8.4.0
pkg_resources==0.0.0
pyparsing==3.0.4
python-dateutil==2.8.2
requests==2.26.0
six==1.16.0
urllib3==1.26.7

GPUあり

Ubuntu 20.04 LTS

Ubuntu 20.04LTS on WSL2
Python 3.8.10
certifi==2021.10.8
charset-normalizer==2.0.7
cycler==0.11.0
fonttools==4.28.2
graphviz==0.8.4
idna==3.3
kiwisolver==1.3.2
matplotlib==3.5.0
mxnet-cu112==2.0.0b20211121
numpy==1.21.4
packaging==21.3
Pillow==8.4.0
pkg_resources==0.0.0
pyparsing==3.0.6
python-dateutil==2.8.2
requests==2.26.0
setuptools-scm==6.3.2
six==1.16.0
tomli==1.2.2
urllib3==1.26.7

Ubuntu 22.04 LTS

Ubuntu 22.04LTS on WSL2
Python 3.10.4
certifi==2021.10.8
charset-normalizer==2.0.12
cycler==0.11.0
fonttools==4.33.3
graphviz==0.8.4
idna==3.3
kiwisolver==1.4.2
matplotlib==3.5.2
mxnet-cu112==2.0.0b20220504
numpy==1.22.3
packaging==21.3
Pillow==9.1.0
pyparsing==3.0.8
python-dateutil==2.8.2
requests==2.27.1
six==1.16.0
urllib3==1.26.9

補足

MXNet 1.xとGluonCVを使う場合は非常に簡単です。こちらを参照して下さい。
touch-sp.hatenablog.com