AutoGluonを使ってMMDetectionの推論が簡単に行えるよ(たったの7行)

公開日:2022年8月19日
最終更新日:2022年9月14日

はじめに

「MMDetection」は「Detection」という名前の通り物体検出モデルを使うためのツールです。

さまざまな学習済みモデルが公開されています。

以前「MMDetection」の推論を行う記事を書きました。
touch-sp.hatenablog.com
今回は「AutoGluon」から「MMDetection」を使用する方法を紹介します。


今のところはまだ推論しかできないようです。

Pythonスクリプト

モデルを指定しない場合は「yolov3_mobilenetv2_320_300e_coco」という学習済みモデルが使われます。

その際のスクリプトはサンプル画像のダウンロードを含めてたったの7行です。

学習済みモデルの事前準備(ダウンロード)は必要ありません。

from autogluon.multimodal import download, MultiModalPredictor
from mmdet.apis import show_result_pyplot

url = "https://raw.githubusercontent.com/open-mmlab/mmdetection/master/demo/demo.jpg"
mmdet_image_name = download(url)

od = MultiModalPredictor(pipeline="object_detection")

results = od.predict({"image": [mmdet_image_name]})
show_result_pyplot(od._model.model, mmdet_image_name, results[0][0], palette='coco')

結果

このような結果が表示されます。

その他の学習済みモデル

「yolov3_mobilenetv2_320_300e_coco」以外に以下のモデルなどが使えるようです。

"faster_rcnn_r50_fpn_2x_coco"
"yolov3_mobilenetv2_320_300e_coco"
"cascade_rcnn_x101_64x4d_fpn_20e_coco"
"detr_r50_8x2_150e_coco"



モデルを指定する場合には以下のようにhyperparametersを記述します。

from autogluon.multimodal import download, MultiModalPredictor
from mmdet.apis import show_result_pyplot

url = "https://raw.githubusercontent.com/open-mmlab/mmdetection/master/demo/demo.jpg"
mmdet_image_name = download(url)

od = MultiModalPredictor(
        hyperparameters={"model.mmdet_image.checkpoint_name": "detr_r50_8x2_150e_coco"},
        pipeline="object_detection"
    )

results = od.predict({"image": [mmdet_image_name]})
show_result_pyplot(od._model.model, mmdet_image_name, results[0][0], palette='coco')

環境構築

Ubuntu 20.04 on WSL2
CUDA 11.3.1
python 3.8.10



Pythonのライブラリはすべてpipで可能です。


MMDetectionを使う方法が実装されているのは今のところAutoGluonのプレビュー版のみです。プレビュー版のAutoGluonをインストールする必要があります。

pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchtext==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install mmcv-full==1.6.1 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.0/index.html
pip install mmdet
pip install autogluon --pre
absl-py==1.2.0
addict==2.4.0
aiohttp==3.8.1
aiosignal==1.2.0
antlr4-python3-runtime==4.8
async-timeout==4.0.2
attrs==22.1.0
autocfg==0.0.8
autogluon==0.5.3b20220905
autogluon-contrib-nlp==0.0.1b20220208
autogluon.common==0.5.3b20220905
autogluon.core==0.5.3b20220905
autogluon.features==0.5.3b20220905
autogluon.multimodal==0.5.3b20220905
autogluon.tabular==0.5.3b20220905
autogluon.text==0.5.3b20220905
autogluon.timeseries==0.5.3b20220905
autogluon.vision==0.5.3b20220905
blis==0.7.8
boto3==1.24.66
botocore==1.27.66
cachetools==5.2.0
catalogue==2.0.8
catboost==1.0.6
certifi==2022.6.15
charset-normalizer==2.1.1
click==8.0.4
cloudpickle==2.1.0
colorama==0.4.5
commonmark==0.9.1
contextvars==2.4
convertdate==2.4.0
cycler==0.11.0
cymem==2.0.6
Cython==3.0.0a11
dask==2021.11.2
Deprecated==1.2.13
distlib==0.3.6
distributed==2021.11.2
fairscale==0.4.6
fastai==2.7.9
fastcore==1.5.24
fastdownload==0.0.7
fastprogress==1.0.3
filelock==3.8.0
flake8==5.0.4
fonttools==4.37.1
frozenlist==1.3.1
fsspec==2022.8.2
future==0.18.2
gluoncv==0.10.5.post0
gluonts==0.9.9
google-auth==2.11.0
google-auth-oauthlib==0.4.6
graphviz==0.20.1
grpcio==1.43.0
HeapDict==1.0.1
hijri-converter==2.2.4
holidays==0.15
huggingface-hub==0.9.1
hyperopt==0.2.7
idna==3.3
imageio==2.21.2
immutables==0.18
importlib-metadata==4.12.0
importlib-resources==5.9.0
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.1.0
jsonschema==4.15.0
kiwisolver==1.4.4
korean-lunar-calendar==0.2.1
langcodes==3.3.0
lightgbm==3.3.2
llvmlite==0.39.1
locket==1.0.0
lxml==4.9.1
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.3
mccabe==0.7.0
mmcv-full==1.6.1
mmdet==2.25.1
model-index==0.1.11
msgpack==1.0.4
multidict==6.0.2
murmurhash==1.0.8
networkx==2.8.6
nlpaug==1.1.10
nltk==3.7
nptyping==1.4.4
numba==0.56.2
numpy==1.22.4
oauthlib==3.2.0
omegaconf==2.1.2
opencv-python==4.6.0.66
openmim==0.2.1
ordered-set==4.1.0
packaging==21.3
pandas==1.4.4
partd==1.3.0
pathy==0.6.2
patsy==0.5.2
Pillow==9.0.1
pkg_resources==0.0.0
pkgutil_resolve_name==1.3.10
platformdirs==2.5.2
plotly==5.10.0
pmdarima==1.8.5
portalocker==2.5.1
preshed==3.0.7
protobuf==3.18.1
psutil==5.8.0
py4j==0.10.9.7
pyarrow==9.0.0
pyasn1==0.5.0rc2
pyasn1-modules==0.3.0rc1
pycocotools==2.0.4
pycodestyle==2.9.1
pydantic==1.9.2
pyDeprecate==0.3.2
pyflakes==2.5.0
Pygments==2.13.0
PyMeeus==0.5.11
pyparsing==3.0.9
pyrsistent==0.18.1
python-dateutil==2.8.2
pytorch-lightning==1.6.5
pytorch-metric-learning==1.3.2
pytz==2022.2.1
PyWavelets==1.3.0
PyYAML==6.0
ray==1.13.0
regex==2022.8.17
requests==2.28.1
requests-oauthlib==1.3.1
rich==12.5.1
rsa==4.9
s3transfer==0.6.0
sacrebleu==2.2.0
sacremoses==0.0.53
scikit-image==0.19.3
scikit-learn==1.0.2
scipy==1.7.3
sentencepiece==0.1.95
six==1.16.0
sktime==0.13.2
smart-open==5.2.1
sortedcontainers==2.4.0
spacy==3.4.1
spacy-legacy==3.0.10
spacy-loggers==1.0.3
srsly==2.4.4
statsmodels==0.13.2
tabulate==0.8.10
tbats==1.1.0
tblib==1.7.0
tenacity==8.0.1
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5.1
terminaltables==3.1.10
thinc==8.1.0
threadpoolctl==3.1.0
tifffile==2022.8.12
timm==0.5.4
tokenizers==0.12.1
toolz==0.12.0
torch==1.12.0+cu113
torchmetrics==0.7.3
torchtext==0.13.0
torchvision==0.13.0+cu113
tornado==6.2
tqdm==4.64.1
transformers==4.20.1
typer==0.4.2
typing_extensions==4.3.0
typish==1.9.3
urllib3==1.26.12
virtualenv==20.16.4
wasabi==0.10.1
Werkzeug==2.2.2
wrapt==1.14.1
xgboost==1.4.2
yacs==0.1.8
yapf==0.32.0
yarl==1.8.1
zict==2.2.0
zipp==3.8.1

補足

すでに画像が準備できている場合には以下のスクリプト(ファイル名は「ob_exe.py」としています)を実行すると結果が表示されます。
画像名は「person.jpg」とします。

import sys 
from autogluon.multimodal import download, MultiModalPredictor
from mmdet.apis import show_result_pyplot

mmdet_image_name = sys.argv[1]

od = MultiModalPredictor(pipeline="object_detection")

results = od.predict({"image": [mmdet_image_name]})
show_result_pyplot(od._model.model, mmdet_image_name, results[0][0], palette='coco')

実行方法

python ob_exe.py person.jpg