【PyTorch】detectron2を使って独自データに対して物体検出モデルの学習を行う

はじめに

前回推論を行いました。
touch-sp.hatenablog.com
今回は独自データに対して物体検出モデルの学習を行ってみたいと思います。

学習データ

こちらで作成したデータ(VOCフォーマット)を使用します。
touch-sp.hatenablog.com

Pythonスクリプト

import os
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import load_voc_instances

# データの登録
DatasetCatalog.register('my_dataset', lambda: load_voc_instances('train_data', 'train', ['target', 'green']))
MetadataCatalog.get('my_dataset').thing_classes = ['target', 'green']

# モデルの取得
model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'
cfg = model_zoo.get_config(model, trained=True)

# 訓練データをセット
cfg.DATASETS.TRAIN = ('my_dataset',)
# テストデータをセット(今回はセットしない)
cfg.DATASETS.TEST = () 

cfg.DATALOADER.NUM_WORKERS = 4                  # default:4
cfg.SOLVER.IMS_PER_BATCH = 4                    # default:16
cfg.SOLVER.BASE_LR = 0.00025                    # dafault:0.02
cfg.SOLVER.MAX_ITER = 2000                      # default:270000
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64   # dafault:512 
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2             # dafault:80

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader

train_augmentations = [
    T.RandomBrightness(0.5, 2),
    T.RandomContrast(0.5, 2),
    T.RandomSaturation(0.5, 2),
    T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
    T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
]

class AddAugmentationsTrainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        custom_mapper = DatasetMapper(cfg, is_train=True, augmentations=train_augmentations)
        return build_detection_train_loader(cfg, mapper=custom_mapper)

trainer = AddAugmentationsTrainer(cfg)
trainer.train()

結果

上記スクリプトを実行するとoutputフォルダが作成されその中に結果が保存されます。

推論

独自の画像に対して推論を行います。

import cv2
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import Metadata
from torchvision.datasets.utils import download_url

# 画像の取得
url = 'https://github.com/dai-ichiro/robo-one/raw/main/test.jpg'
fname = url.split('/')[-1]
download_url(url, root = '.', filename = fname)
im = cv2.imread(fname)

# モデルの取得
model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'
cfg = model_zoo.get_config(model, trained=False)
cfg.MODEL.WEIGHTS = 'output/model_final.pth'
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  

# メタデータの新規作成
original_metadata = Metadata()
original_metadata.thing_classes = ['target', 'green']

predictor = DefaultPredictor(cfg)

outputs = predictor(im)

v = Visualizer(im, original_metadata, scale=1.0)
v = v.draw_instance_predictions(outputs['instances'].to('cpu'))
img_array = v.get_image()

cv2.imshow ('result', img_array)
cv2.waitKey(0)
cv2.destroyAllWindows()

このようになりました。

おそらく「cfg.SOLVER.MAX_ITER = 2000」が少なすぎると思われます。
増やせば結果は改善するはずです。

さいごに

私が知る限りでは物体検出モデルに独自データを学習させる時、PyTorchならYOLOv5をMXNetならAutoGluonを使うのが簡単だと思います。
touch-sp.hatenablog.com
touch-sp.hatenablog.com