【Realtime Segmentation】【GluonCV】Webカメラの動画に対してリアルタイムに人物以外の背景を消す

はじめに

以前静止画に対して背景を消すスクリプトを書きました。
touch-sp.hatenablog.com
今回はWebカメラからの動画に対して同様のことをしました。

結果

30行弱のスクリプトで目的を達成することができました。
学習済みモデル(パラメーター付き)をダウンロードするスクリプトもその中に含まれているのでPython環境を用意するだけで実行可能です。

環境

Windows 10 PC with GTX 1080
CUDA 10.2
Python 3.7.9

GPUがないとまともに動かないと思います。

Python環境にインストールが必要なのは「mxnet」「gluoncv」のみです。
両方ともpipでインストール可能ですがWindows用のGPU版mxnetはダウンロード先を指定する必要があります。

pip install mxnet-cu102 -f https://dist.mxnet.io/python/cu102
pip install gluoncv

Pythonスクリプト

import numpy as np
import mxnet as mx
from mxnet.gluon.data.vision import transforms
from  gluoncv.model_zoo import get_model
import cv2

ctx = mx.gpu() if mx.context.num_gpus() >0 else mx.cpu()
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])
model = get_model('deeplab_resnet152_voc', pretrained=True, root='./models', ctx=ctx)
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
while True:
    ret, frame = cap.read()
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = transform_fn(mx.nd.array(frame_rgb))
    img = img.expand_dims(0).as_in_context(ctx)
    output = model.predict(img)
    predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
    mask_1 = np.where(predict == 15, 1, 0)[...,np.newaxis]
    mask_2 = np.where(predict == 15, 0, 255)[...,np.newaxis]
    result_img = (frame * mask_1 + mask_2).astype('uint8')
    cv2.imshow('result', result_img)
    if cv2.waitKey(1) & 0xFF == 27:
        break
cap.release()
cv2.destroyAllWindows()

たったこれだけと思われるかもしれませんが、実際これだけです。
GPU搭載PCを持っているなら一度試してみて下さい。
「Esc」キーを押せば終了できます。

スクリプトの解説

opencvでカメラ画像を取得して表示する時の基本的な書き方が以下です。

import cv2

cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)

while True:
    ret, frame = cap.read()
    cv2.imshow('demo', frame)
    if cv2.waitKey(1) & 0xFF == 27:
        break
cap.release()
cv2.destroyAllWindows()

この基本形は毎回コピペして使用しています。
2行目の「cv2.CAP_DSHOW」は必要なのか否か意見が分かれるようですが自分は深く考えずにつけています。
この記事などが参考になると思います。
データはnumpy arrayで取得できます。第三軸がRGBではなくBGRなのでMXNetで使用する時には以下の1行が必要になります。

frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


MXNetではGPUを使う時はそれを明示する必要があります。しなければCPUが使われてしまいます。環境によって自動的に決定してくれるのが以下の1行です。

ctx = mx.gpu() if mx.context.num_gpus() >0 else mx.cpu()


学習済みモデルのダウンロードは以下の1行です。

model = get_model('deeplab_resnet152_voc', pretrained=True, root='./models', ctx=ctx)

modelsというフォルダを用意していなくても勝手に作成してくれ、その中に学習済みモデルがダウンロードされます。2回目以降はダウンロードせずにそこから読み込んでくれます。

MXNetのモデルに画像データを入力する時には(1, channel, height, width)にしなければいけません。また値は正規化、標準化が必要です。
それを行っているのが以下の部分です。

transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])

img = transform_fn(mx.nd.array(frame_rgb))
img = img.expand_dims(0).as_in_context(ctx)

「ToTensor()」で値を255で割り算して(height, width, channel)を(channel, height, width)に変換しています。
「expand_dims(0)」で(channel, height, width)を(1, channel, height, width)にしています。
「as_in_context(ctx)」でデータをGPUに渡しています。
GPUのメモリを節約するためにモデルに渡す直前にGPUに渡すことが一般的です。

推論は以下の1行です。

output = model.predict(img)


人物以外を白塗りにするためにはいろいろな方法が考えられます。
pillowやopencv-pythonのマスクというものを使うのが一般的ですが今回は非常にわかりやすく人物以外の部分にいったん0を掛けて255を足しただけでです。
人物部分は1を掛けて0を足しています(つまり何も変わらない)。

mask_1 = np.where(predict == 15, 1, 0)[...,np.newaxis]
mask_2 = np.where(predict == 15, 0, 255)[...,np.newaxis]
result_img = (frame * mask_1 + mask_2).astype('uint8')

「predict == 15」は人物であることを示しています。
人物部分は1, その他は0としたのがmask_1です。
人物部分は0, その他は255としたのがmask_2です。
newaxisで軸を付け足したのは最後の計算でうまくブロードキャストしてもらうためです。
(height, width, 3)と(height, width, 1)で計算しています。(height, width, 3)と(height, width)ではうまくブロードキャストされませんでした。

最終的なPython環境

autocfg==0.0.8
autogluon.core==0.1.0
autograd==1.3
bcrypt==3.2.0
boto3==1.17.27
botocore==1.20.27
certifi==2020.12.5
cffi==1.14.5
chardet==3.0.4
click==7.1.2
cloudpickle==1.6.0
ConfigSpace==0.4.18
cryptography==3.4.6
cycler==0.10.0
Cython==0.29.22
dask==2021.3.0
decord==0.5.2
dill==0.3.3
distributed==2021.3.0
future==0.18.2
gluoncv==0.10.0
graphviz==0.8.4
HeapDict==1.0.1
idna==2.6
jmespath==0.10.0
joblib==1.0.1
kiwisolver==1.3.1
matplotlib==3.3.4
msgpack==1.0.2
mxnet-cu102==1.7.0
numpy==1.19.5
opencv-python==4.5.1.48
pandas==1.2.3
paramiko==2.7.2
Pillow==8.1.2
portalocker==2.2.1
protobuf==3.15.6
psutil==5.8.0
pycparser==2.20
PyNaCl==1.4.0
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
pywin32==300
PyYAML==5.4.1
requests==2.25.1
s3transfer==0.3.4
scikit-learn==0.24.1
scipy==1.5.4
six==1.15.0
sortedcontainers==2.3.0
tblib==1.7.0
tensorboardX==2.1
threadpoolctl==2.1.0
toolz==0.11.1
tornado==6.1
tqdm==4.59.0
urllib3==1.26.3
yacs==0.1.8
zict==2.0.0

参考にさせて頂いたサイト

PysimpleGUIを用いたOpenCVのカメラ画像表示 - Qiita

2022年5月20日追記

各種バージョンを新しくして動作確認してみました。
touch-sp.hatenablog.com