【PyTorch】WSL2でdetectron2を使ってみる

はじめに

Meta Research(Facebook Reserchから改名)が開発しているdetectron2を使ってみます。

Meta ResearchはPyTorchそのものも開発しているので本家のComputer Visionシステムと言ってもいいと思います。

環境

Ubuntu 20.04 LTS on WSL2 (Windows 11) 
RTX 3060 Laptop
CUDA Toolkit 11.3

WSL2とCUDAの設定に関してはこちらを参照して下さい。
touch-sp.hatenablog.com

detectron2のインストール

pipのみで可能でした。

pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install opencv-python
pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html

detectron2のインストールの途中でエラーが出るならおそらくpycocotoolsのインストールの部分だと思います。


こちらを参照して下さい。
touch-sp.hatenablog.com
最終的なバージョンはこの記事の最後に書いています。


後述するPythonスクリプトを実行した時にエラーは出ないけど画像が表示されないといった状況になればこちらを参照して下さい。
touch-sp.hatenablog.com

Object Detection(物体検出)

Pythonスクリプト

import cv2
from torchvision.datasets.utils import download_url
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

img_url = 'https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg'
img_fname = img_url.split('/')[-1]
download_url(img_url, root = '.', filename = img_fname)
im = cv2.imread(img_fname)

model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'
cfg = model_zoo.get_config(model, trained=True)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

predictor = DefaultPredictor(cfg)
outputs = predictor(im)

v = Visualizer(im, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
img_array = v.get_image()

cv2.imshow ('demo', img_array)
cv2.waitKey(0)
cv2.destroyAllWindows()

モデルの取得を以下のように書いているサイトが多いです。

model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(model))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model)

しかしこの部分は以下のように簡潔に書けます。

model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'
cfg = model_zoo.get_config(model, trained=True)

またBGR→RGB→BGRの変換を行っているサイトも多いです。
結局元に戻るので不要です。矩形や文字の色は変わってしまいますが。
以下の[:, :, ::-1]の部分です。

v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
img_array = v.get_image()[:, :, ::-1]

結果

f:id:touch-sp:20211114151411p:plain:w500

Segmentation

Pythonスクリプト

import cv2
from torchvision.datasets.utils import download_url
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

img_url = 'https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg'
img_fname = img_url.split('/')[-1]
download_url(img_url, root = '.', filename = img_fname)
im = cv2.imread(img_fname)

model = 'COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'
cfg = model_zoo.get_config(model, trained=True)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

predictor = DefaultPredictor(cfg)
outputs = predictor(im)

v = Visualizer(im, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
img_array = v.get_image()

cv2.imshow ('demo', img_array)
cv2.waitKey(0)
cv2.destroyAllWindows()

結果

f:id:touch-sp:20211115092300p:plain:w500

Keypoint Detection

Pythonスクリプト

import cv2
from torchvision.datasets.utils import download_url
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

img_url = 'http://images.cocodataset.org/val2017/000000458045.jpg'
img_fname = img_url.split('/')[-1]
download_url(img_url, root = '.', filename = img_fname)
im = cv2.imread(img_fname)

model = 'COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml'
cfg = model_zoo.get_config(model, trained=True)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

predictor = DefaultPredictor(cfg)
outputs = predictor(im)

v = Visualizer(im, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
img_array = v.get_image()

cv2.imshow ('demo', img_array)
cv2.waitKey(0)
cv2.destroyAllWindows()

結果

f:id:touch-sp:20211115093701p:plain:w500

参考にさせて頂いたサイト

yamayou-1.hatenablog.com

モジュールのバージョン

absl-py==1.0.0
antlr4-python3-runtime==4.8
appdirs==1.4.4
black==21.4b2
cachetools==4.2.4
certifi==2021.10.8
charset-normalizer==2.0.7
click==8.0.3
cloudpickle==2.0.0
cycler==0.11.0
Cython==0.29.24
detectron2==0.6+cu113
future==0.18.2
fvcore==0.1.5.post20211023
google-auth==2.3.3
google-auth-oauthlib==0.4.6
grpcio==1.41.1
hydra-core==1.1.1
idna==3.3
importlib-resources==5.4.0
iopath==0.1.9
kiwisolver==1.3.2
Markdown==3.3.4
matplotlib==3.4.3
mypy-extensions==0.4.3
numpy==1.21.4
oauthlib==3.1.1
omegaconf==2.1.1
opencv-python==4.5.4.58
pathspec==0.9.0
Pillow==8.4.0
pkg_resources==0.0.0
portalocker==2.3.2
protobuf==3.19.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycocotools==2.0.2
pydot==1.4.2
pyparsing==3.0.6
python-dateutil==2.8.2
PyYAML==6.0
regex==2021.11.10
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
six==1.16.0
tabulate==0.8.9
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
termcolor==1.1.0
toml==0.10.2
torch==1.10.0+cu113
torchvision==0.11.1+cu113
tqdm==4.62.3
typing-extensions==3.10.0.2
urllib3==1.26.7
Werkzeug==2.0.2
yacs==0.1.8
zipp==3.6.0