はじめに
前回に引き続き学習済みモデルを使いながらPyTorchに慣れていきます。せっかくなのでGluonCVでは使えなかった比較的新しいモデルを使ってみようと思います。
github.com
環境
GPUがある環境とない環境のふたつで動作確認しました。GPUあり
Windows 11 Core i7-7700K + GTX 1080 Python 3.8.10
Pythonの新しい仮想環境を作り以下のように必要なものをインストールしました。引っかかるところは幸いありませんでした。
pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio===0.9.1 -f https://download.pytorch.org/whl/torch_stable.html pip install matplotlib pip install scipy pip install packaging
このようになりました。
cycler==0.10.0 kiwisolver==1.3.2 matplotlib==3.4.3 numpy==1.21.2 packaging==21.0 Pillow==8.4.0 pyparsing==2.4.7 python-dateutil==2.8.2 scipy==1.7.1 six==1.16.0 torch==1.9.1+cu102 torchaudio==0.9.1 torchvision==0.10.1+cu102 typing-extensions==3.10.0.2
GPUなし
Windows 10 Core i7-1165G7 Python 3.8.10
Pythonの新しい仮想環境を作り以下のように必要なものをインストールしました。引っかかるところは幸いありませんでした。
pip install torch torchvision torchaudio pip install matplotlib pip install scipy pip install packaging
このようになりました。
cycler==0.10.0 kiwisolver==1.3.2 matplotlib==3.4.3 numpy==1.21.2 packaging==21.0 Pillow==8.4.0 pyparsing==2.4.7 python-dateutil==2.8.2 scipy==1.7.1 six==1.16.0 torch==1.9.1 torchaudio==0.9.1 torchvision==0.10.1 typing-extensions==3.10.0.2
Pythonスクリプト
import torch from torchvision import transforms from PIL import Image import urllib from matplotlib import pyplot as plt import numpy as np device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') CLASSES = [ 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' ] model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True) model.to(device) model.eval() url = 'https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg' img_file = 'person.jpg' try: urllib.URLopener().retrieve(url, img_file) except: urllib.request.urlretrieve(url, img_file) transform = transforms.Compose([ transforms.Resize(800), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) img = Image.open(img_file) input_tensor = transform(img) # <class 'torch.Tensor'>, torch.Size([3, 800, 1207]) input_batch = input_tensor.unsqueeze(0) # <class 'torch.Tensor'>, torch.Size([1, 3, 800, 1207]) with torch.no_grad(): results = model(input_batch.to(device)) # <class 'dict'> prob = results['pred_logits'].softmax(-1)[0, :, :-1] # torch.Size([100, 91]) prob = prob.max(-1) # <class 'torch.return_types.max'> scores = prob.values # torch.Size([100]) class_IDs = prob.indices # torch.Size([100]) bboxes = results['pred_boxes'].squeeze() # torch.Size([100, 4]) fig = plt.figure() ax = fig.add_subplot(1, 1, 1) ax.imshow(np.array(img)) colors = dict() threshold = 0.8 for i, each_score in enumerate(scores): if each_score < threshold: continue score = '{:.3f}'.format(float(each_score)) class_id = int(class_IDs[i]) class_name = CLASSES[class_id] if class_id not in colors: colors[class_id] = plt.get_cmap('hsv')(class_id / len(CLASSES)) centerX, centerY, w, h = [float(x) for x in bboxes[i]] xmin = (centerX - 0.5 * w) * img.size[0] ymin = (centerY - 0.5 * h) * img.size[1] w = w * img.size[0] h = h * img.size[1] rect = plt.Rectangle((xmin, ymin), w, h, fill=False, edgecolor=colors[class_id], linewidth=3.5) ax.add_patch(rect) ax.text(xmin, ymin - 2, '{:s} {:s}'.format(class_name, score), bbox=dict(facecolor=colors[class_id], alpha=0.5), fontsize=12, color='white') plt.axis('off') plt.show()
結果
以下のような警告が出ましたが問題なく動作しています。警告を消す方法は今のところわかりません。
UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at ..\aten\src\ATen\native\BinaryOps.cpp:467.)
参考にさせていただいたサイト
さいごに
間違いや改善点があればコメント頂けましたら幸いです。追記
MMDetectionを使うとDETRが簡単に使用できます。touch-sp.hatenablog.com