OpenMMLab の MMSegmentation を使ってカメラからの入力に対して人物以外の背景を消す

はじめに

OpenMMLabの「MMdetection」を使ってもセグメンテーションはできますが、「MMdetection」とは別に「MMSegmentation」というのが存在することを知りました。

さっそく使ってみました。

OpenMMLabが開発するツールの中で過去に「MMdetection」「MMEditing」を使ったことがあります。

開発元が同じだけあってスクリプトはほとんど同じになります。

過去の記事へのリンクはこの記事の最後に貼っておきます。よかったら読んで下さい。

ここではスクリプトがいかに似ているかを見て頂くために一部抜粋して載せておきます。

スクリプトの比較

MMDetection

os.makedirs('models', exist_ok=True)
checkpoint_name = 'detr_r50_8x2_150e_coco'
config_fname = checkpoint_name + '.py'
checkpoint = download(package="mmdet", configs=[checkpoint_name], dest_root="models")[0]

model = init_detector(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = device)

result = inference_detector(model, img_fname)

MMEditing

os.makedirs('models', exist_ok=True)
checkpoint_name = 'esrgan_x4c64b23g32_g1_400k_div2k'
config_fname = checkpoint_name + '.py'
checkpoint = download(package="mmedit", configs=[checkpoint_name], dest_root="models")[0]

model = init_model(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = device)

result = restoration_inference(model, img_fname)

MMSegmentation

os.makedirs('models', exist_ok=True)
checkpoint_name = 'deeplabv3plus_r101-d8_512x512_40k_voc12aug'
config_fname = checkpoint_name + '.py'
checkpoint = download(package="mmsegmentation", configs=[checkpoint_name], dest_root="models")[0]

model = init_segmentor(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = device)

result = inference_segmentor(model, frame)

人物以外の背景を消すPythonスクリプト(本題)

ここからが本題です。

非常に短いスクリプトで簡単に目的が実現できます。

モデルのダウンロードをPythonスクリプト外で「mim download」を使って事前に行っている人が多いと思いますがスクリプト内に埋め込むことも可能です。今回はその方法を採用しています。

つまり事前準備は必要ありません。下記のスクリプトを実行するだけです。

もちろん、必要なファイルがすでに存在する場合にはダウンロードの過程はスキップされます。

import os
import numpy as np
import cv2
import torch
from mmseg.apis import inference_segmentor, init_segmentor
from mim.commands.download import download

device = 'cuda' if torch.cuda.is_available() else 'cpu' 

os.makedirs('models', exist_ok=True)

checkpoint_name = 'deeplabv3plus_r101-d8_512x512_40k_voc12aug'
config_fname = checkpoint_name + '.py'

checkpoint = download(package="mmsegmentation", configs=[checkpoint_name], dest_root="models")[0]

model = init_segmentor(os.path.join('models', config_fname), os.path.join('models', checkpoint), device = device)

cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)

while True:
    ret, frame = cap.read()
    output = inference_segmentor(model, frame)
    result = output[0]
    mask_1 = np.where(result == 15, 1, 0)[...,np.newaxis]
    mask_2 = np.where(result == 15, 0, 255)[...,np.newaxis]
    result_img = (frame * mask_1 + mask_2).astype('uint8')
    cv2.imshow('result', result_img)
    if cv2.waitKey(1) & 0xFF == 27:
        break

cap.release()
cv2.destroyAllWindows()

動作環境

Windows 11
CUDA 11.6.2
Python 3.9.13
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install mmcv-full==1.6.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12.0/index.html
pip install mmsegmentation
pip install openmim
pip install mmengine

関連記事

MMDetection

touch-sp.hatenablog.com
touch-sp.hatenablog.com

MMEditing

touch-sp.hatenablog.com
touch-sp.hatenablog.com