PyTorchに入門してみる part2 GluonCVでは使えなかった物体検出モデル DETR を使ってみる

はじめに

前回に引き続き学習済みモデルを使いながらPyTorchに慣れていきます。


せっかくなのでGluonCVでは使えなかった比較的新しいモデルを使ってみようと思います。
github.com

環境

GPUがある環境とない環境のふたつで動作確認しました。

GPUあり

Windows 11 
Core i7-7700K + GTX 1080

Python 3.8.10

Pythonの新しい仮想環境を作り以下のように必要なものをインストールしました。引っかかるところは幸いありませんでした。

pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio===0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install matplotlib
pip install scipy
pip install packaging

このようになりました。

cycler==0.10.0
kiwisolver==1.3.2
matplotlib==3.4.3
numpy==1.21.2
packaging==21.0
Pillow==8.4.0
pyparsing==2.4.7
python-dateutil==2.8.2
scipy==1.7.1
six==1.16.0
torch==1.9.1+cu102
torchaudio==0.9.1
torchvision==0.10.1+cu102
typing-extensions==3.10.0.2

GPUなし

Windows 10
Core i7-1165G7

Python 3.8.10

Pythonの新しい仮想環境を作り以下のように必要なものをインストールしました。引っかかるところは幸いありませんでした。

pip install torch torchvision torchaudio
pip install matplotlib
pip install scipy
pip install packaging

このようになりました。

cycler==0.10.0
kiwisolver==1.3.2
matplotlib==3.4.3
numpy==1.21.2
packaging==21.0
Pillow==8.4.0
pyparsing==2.4.7
python-dateutil==2.8.2
scipy==1.7.1
six==1.16.0
torch==1.9.1
torchaudio==0.9.1
torchvision==0.10.1
typing-extensions==3.10.0.2

Pythonスクリプト

import torch
from torchvision import transforms
from PIL import Image
import urllib
from matplotlib import pyplot as plt
import numpy as np

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 

CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
model.to(device)
model.eval()

url = 'https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg'
img_file = 'person.jpg'
try: urllib.URLopener().retrieve(url, img_file)
except: urllib.request.urlretrieve(url, img_file)

transform = transforms.Compose([
    transforms.Resize(800),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

img = Image.open(img_file)

input_tensor = transform(img)               # <class 'torch.Tensor'>, torch.Size([3, 800, 1207])
input_batch = input_tensor.unsqueeze(0)     # <class 'torch.Tensor'>, torch.Size([1, 3, 800, 1207])

with torch.no_grad():
    results = model(input_batch.to(device))    # <class 'dict'>

prob = results['pred_logits'].softmax(-1)[0, :, :-1]    # torch.Size([100, 91])
prob = prob.max(-1)                                     # <class 'torch.return_types.max'>

scores = prob.values                        # torch.Size([100])
class_IDs = prob.indices                    # torch.Size([100])
bboxes = results['pred_boxes'].squeeze()    # torch.Size([100, 4])

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.imshow(np.array(img))

colors = dict()
threshold = 0.8
for i, each_score in enumerate(scores):
    if each_score < threshold: continue
    score = '{:.3f}'.format(float(each_score))
    class_id = int(class_IDs[i])
    class_name = CLASSES[class_id]
    if class_id not in colors:
        colors[class_id] = plt.get_cmap('hsv')(class_id / len(CLASSES))
    centerX, centerY, w, h = [float(x) for x in bboxes[i]]
    xmin = (centerX - 0.5 * w) * img.size[0]
    ymin = (centerY - 0.5 * h) * img.size[1]
    w = w * img.size[0]
    h = h * img.size[1]
    rect = plt.Rectangle((xmin, ymin), w, h,
                             fill=False,
                             edgecolor=colors[class_id],
                             linewidth=3.5)
    ax.add_patch(rect)
    ax.text(xmin, ymin - 2,
                '{:s} {:s}'.format(class_name, score),
                bbox=dict(facecolor=colors[class_id], alpha=0.5),
                fontsize=12, color='white')

plt.axis('off')        
plt.show()

結果

f:id:touch-sp:20211017012044p:plain:w400
以下のような警告が出ましたが問題なく動作しています。警告を消す方法は今のところわかりません。

UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ..\aten\src\ATen\native\BinaryOps.cpp:467.)

参考にさせていただいたサイト

さいごに

間違いや改善点があればコメント頂けましたら幸いです。

追記

MMDetectionを使うとDETRが簡単に使用できます。
touch-sp.hatenablog.com