はじめに
前回推論を行いました。touch-sp.hatenablog.com
今回は独自データに対して物体検出モデルの学習を行ってみたいと思います。
学習データ
こちらで作成したデータ(VOCフォーマット)を使用します。touch-sp.hatenablog.com
Pythonスクリプト
import os from detectron2 import model_zoo from detectron2.engine import DefaultTrainer from detectron2.data import MetadataCatalog, DatasetCatalog from detectron2.data.datasets import load_voc_instances # データの登録 DatasetCatalog.register('my_dataset', lambda: load_voc_instances('train_data', 'train', ['target', 'green'])) MetadataCatalog.get('my_dataset').thing_classes = ['target', 'green'] # モデルの取得 model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml' cfg = model_zoo.get_config(model, trained=True) # 訓練データをセット cfg.DATASETS.TRAIN = ('my_dataset',) # テストデータをセット(今回はセットしない) cfg.DATASETS.TEST = () cfg.DATALOADER.NUM_WORKERS = 4 # default:4 cfg.SOLVER.IMS_PER_BATCH = 4 # default:16 cfg.SOLVER.BASE_LR = 0.00025 # dafault:0.02 cfg.SOLVER.MAX_ITER = 2000 # default:270000 cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64 # dafault:512 cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # dafault:80 os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) import detectron2.data.transforms as T from detectron2.data import DatasetMapper, build_detection_train_loader train_augmentations = [ T.RandomBrightness(0.5, 2), T.RandomContrast(0.5, 2), T.RandomSaturation(0.5, 2), T.RandomFlip(prob=0.5, horizontal=True, vertical=False), T.RandomFlip(prob=0.5, horizontal=False, vertical=True), ] class AddAugmentationsTrainer(DefaultTrainer): @classmethod def build_train_loader(cls, cfg): custom_mapper = DatasetMapper(cfg, is_train=True, augmentations=train_augmentations) return build_detection_train_loader(cfg, mapper=custom_mapper) trainer = AddAugmentationsTrainer(cfg) trainer.train()
結果
上記スクリプトを実行するとoutputフォルダが作成されその中に結果が保存されます。推論
独自の画像に対して推論を行います。import cv2 from detectron2 import model_zoo from detectron2.engine import DefaultPredictor from detectron2.utils.visualizer import Visualizer from detectron2.data import Metadata from torchvision.datasets.utils import download_url # 画像の取得 url = 'https://github.com/dai-ichiro/robo-one/raw/main/test.jpg' fname = url.split('/')[-1] download_url(url, root = '.', filename = fname) im = cv2.imread(fname) # モデルの取得 model = 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml' cfg = model_zoo.get_config(model, trained=False) cfg.MODEL.WEIGHTS = 'output/model_final.pth' cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 # メタデータの新規作成 original_metadata = Metadata() original_metadata.thing_classes = ['target', 'green'] predictor = DefaultPredictor(cfg) outputs = predictor(im) v = Visualizer(im, original_metadata, scale=1.0) v = v.draw_instance_predictions(outputs['instances'].to('cpu')) img_array = v.get_image() cv2.imshow ('result', img_array) cv2.waitKey(0) cv2.destroyAllWindows()
このようになりました。
おそらく「cfg.SOLVER.MAX_ITER = 2000」が少なすぎると思われます。
増やせば結果は改善するはずです。
さいごに
私が知る限りでは物体検出モデルに独自データを学習させる時、PyTorchならYOLOv5をMXNetならAutoGluonを使うのが簡単だと思います。touch-sp.hatenablog.com
touch-sp.hatenablog.com