PyTorchに入門してみる part 4 Object Trackingの結果を使ってYOLOv5の転移学習を行う

はじめに

前回Object Trackingの結果をVOCフォーマットで出力しました。

しかしPyTorchでの転移学習を調べているとYOLOv5が簡単そうでした。
github.com
そのためObject Trackingの出力がYOLOv5用になるようにスクリプトを書き換えました。


その後YOLO5の転移学習を行いました。

結果


以前MXNetでやったことをPyTorchに置き換えただけです。

PC環境

Windows 11 
Core i7-7700K + GTX 1080

Python 3.8.10

Python環境構築

新しいPython仮想環境を作り以下をインストールしました。

pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio===0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt
pip install yacs

Object Tracking(学習データの作成)

pysotのGitHubをcloneして以下のPythonスクリプトを実行します。
この部分の詳細は前回の記事を参照して下さい。

import os

import cv2
import torch

from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker

from torchvision.datasets.utils import download_url

#=========================================================
video_list = ['target.mp4', 'green.mp4']

url_1 = 'https://github.com/dai-ichiro/robo-one/raw/main/video_1.mp4'
url_2 = 'https://github.com/dai-ichiro/robo-one/raw/main/video_2.mp4'

download_url(url_1, root = '.', filename = video_list[0])
download_url(url_2, root = '.', filename = video_list[1])

target_name = [x.split('.')[0] for x in video_list]

out_path = 'train_data'

config = 'experiments/siamrpn_r50_l234_dwxcorr/config.yaml'
snapshot = 'experiments/siamrpn_r50_l234_dwxcorr/model.pth'
#=========================================================

train_images_dir = os.path.join(out_path, 'images', 'train')
train_labels_dir = os.path.join(out_path, 'labels', 'train')

os.makedirs(train_images_dir)
os.makedirs(train_labels_dir)

init_rect_list = []

for video in video_list:
    cap = cv2.VideoCapture(video)
    ret, img = cap.read()
    cap.release()

    source_window = "draw_rectangle"
    cv2.namedWindow(source_window)
    rect = cv2.selectROI(source_window, img, False, False)

    init_rect_list.append(rect)
    cv2.destroyAllWindows()

# モデルを取得する
cfg.merge_from_file(config)
cfg.CUDA = torch.cuda.is_available() and cfg.CUDA
device = torch.device('cuda' if cfg.CUDA else 'cpu')
model = ModelBuilder()
model.load_state_dict(torch.load(snapshot,
    map_location=lambda storage, loc: storage.cpu()))
model.eval().to(device)
tracker = build_tracker(model)

for i, video in enumerate(video_list):
    # 映像ファイルを読み込む
    video_frames = []
    cap = cv2.VideoCapture(video)
    w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    while(True):
        ret, img = cap.read()
        if not ret:
            break
        video_frames.append(img)
    cap.release()

    #トラッキングを実行
    jpeg_filenames_list = []

    for ind, frame in enumerate(video_frames):
        if ind == 0:
            tracker.init(frame, init_rect_list[i])
            bbox = init_rect_list[i]
        else:
            outputs = tracker.track(frame)
            bbox = outputs['bbox']
            
        filename = '%d%06d'%((i+1),ind)

        #画像の保存
        jpeg_filename = filename + '.jpg'
        cv2.imwrite(os.path.join(train_images_dir, jpeg_filename), frame)

        #ラベルテキストの保存
        txt_filename= filename + '.txt'
        with open(os.path.join(train_labels_dir, txt_filename), 'w') as f:
            center_x = (bbox[0] + bbox[2] / 2) / w
            center_y = (bbox[1] + bbox[3] / 2) / h
            width = bbox[2] / w
            height = bbox[3] / h
            f.write('%d %f %f %f %f'%(i, center_x, center_y, width, height))

with open('train.yaml', 'w', encoding='cp932') as f:
    f.write('path: %s'%out_path)
    f.write('\n')
    f.write('train: images/train')
    f.write('\n')
    f.write('val: images/train')
    f.write('\n')
    f.write('nc: %d'%len(video_list))
    f.write('\n')
    f.write('names: ')
    f.write('[')
    output_target_name = ['\'' + x + '\'' for x in target_name]
    f.write(', '.join(output_target_name))
    f.write(']')

このスクリプトを実行すると「train_data」フォルダと「train.yaml」が作成されます。

YAMLファイル

上記スクリプトで作成されたYAMLファイルの中身はこのようになっています。

path: train_data
train: images/train
val: images/train
nc: 2
names: ['target', 'green']

転移学習

まずはYOLOv5のGitHubをcloneする必要があります。(GitHubからZIPでダウンロードしても可)
train.pyと同じフォルダに先ほど作成した「train_data」フォルダと「train.yaml」を移動させて下さい。
そして以下のようにtrain.pyを実行すれば学習が始まります。
batchとepochsの数字は適当に変更して下さい。

python train.py --batch 16 --epochs 10 --data train.yaml --weights yolov5s.pt

学習が終わると「runs/train/exp/weights」フォルダに「best.pt」と「last.pt」が保存されます。

結果を確認する方法

ここからはGitHubのcloneは必要ありません。保存された「best.pt」をどこに移動しても構いません。
ただし、モデルを読み込むときにpathを設定して下さい。

静止画で結果確認

import torch
from torchvision.datasets.utils import download_url

url = 'https://github.com/dai-ichiro/robo-one/raw/main/test.jpg'
fname = url.split('/')[-1]
download_url(url, root = '.', filename = fname)

model = torch.hub.load('ultralytics/yolov5', 'custom', path = 'best.pt')
results = model([fname])
results.show()

冒頭で紹介した図になります。

Webカメラで結果確認

import torch
import cv2
from matplotlib import pyplot as plt

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 

model = torch.hub.load('ultralytics/yolov5', 'custom', path = 'best.pt')
model.to(device)
class_num = model.yaml['nc']

cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)

colors = dict()

while True:
    ret, frame = cap.read()

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = model(frame_rgb)

    pandasDF = results.pandas().xyxy[0]

    for row in pandasDF.itertuples():
        
        if row.confidence < 0.8: break

        score = '{:.3f}'.format(row.confidence)
        xmin = int(row.xmin)
        ymin = int(row.ymin)
        xmax = int(row.xmax)
        ymax = int(row.ymax)
        class_id = row._6 #(class)
        class_name = row.name
        
        if class_id not in colors:
            colors[class_id] = plt.get_cmap('hsv')(class_id / class_num)

        bcolor = [x * 255 for x in colors[class_id]]
        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), bcolor, 3)

        y = ymin - 15 if ymin - 15 > 15 else ymin + 15
        cv2.putText(frame, '{:s} {:s}'.format(class_name, score),
                        (xmin, y), cv2.FONT_HERSHEY_SIMPLEX, 1.0,
                        bcolor, 3, lineType=cv2.LINE_AA)

    cv2.imshow('demo', frame)

    if cv2.waitKey(1) & 0xFF == 27:
        break

cap.release()
cv2.destroyAllWindows()

参考にさせて頂いたサイト

さいごに

間違いや改善点があればコメント頂けましたら幸いです。

(2022年4月9日追記)手順を簡略化(1クラス)

手順を簡略化した新しい記事を書きました。
touch-sp.hatenablog.com

(2022年7月16日追記)手順を簡略化(2クラス)

手順を簡略化した新しい記事を書きました。
touch-sp.hatenablog.com