【MXNet】Segmentationで背景を消したり、ぼかしたり(GluonCV deeplab_resnet152)

最終更新:2021年3月23日
「動画ファイルに対してのSegmentation」へのリンクを追加しました(記事の末尾)

はじめに

Segmentationモデルを使うとこんなことができる。

背景を消す

f:id:touch-sp:20200901111201j:plain:w250f:id:touch-sp:20200901111318p:plain:w250

背景をぼかす

f:id:touch-sp:20200901111201j:plain:w250f:id:touch-sp:20200901111242p:plain:w250

動作環境

Windows 10 GPUなし
Python 3.8.7

atomicwrites==1.4.0
attrs==20.3.0
autocfg==0.0.6
autogluon.core==0.0.16b20210107
autograd==1.3
bcrypt==3.2.0
boto3==1.16.51
botocore==1.19.51
certifi==2020.12.5
cffi==1.14.4
chardet==3.0.4
click==7.1.2
cloudpickle==1.6.0
colorama==0.4.4
ConfigSpace==0.4.16
cryptography==3.3.1
cycler==0.10.0
Cython==0.29.21
dask==2020.12.0
decord==0.4.2
dill==0.3.3
distributed==2020.12.0
future==0.18.2
gluoncv==0.9.1
graphviz==0.8.4
HeapDict==1.0.1
idna==2.6
iniconfig==1.1.1
jmespath==0.10.0
joblib==1.0.0
kiwisolver==1.3.1
matplotlib==3.3.3
msgpack==1.0.2
mxnet==1.7.0.post1
numpy==1.19.5
opencv-python==4.5.1.48
packaging==20.8
pandas==1.2.0
paramiko==2.7.2
Pillow==8.1.0
pluggy==0.13.1
portalocker==2.0.0
protobuf==3.14.0
psutil==5.8.0
py==1.10.0
pyaml==20.4.0
pycparser==2.20
PyNaCl==1.4.0
pyparsing==2.4.7
pytest==6.2.1
python-dateutil==2.8.1
pytz==2020.5
pywin32==300
PyYAML==5.3.1
requests==2.25.1
s3transfer==0.3.3
scikit-learn==0.23.2
scikit-optimize==0.8.1
scipy==1.4.1
six==1.15.0
sortedcontainers==2.3.0
tblib==1.7.0
tensorboardX==2.1
threadpoolctl==2.1.0
toml==0.10.2
toolz==0.11.1
tornado==6.1
tqdm==4.55.1
urllib3==1.26.2
yacs==0.1.8
zict==2.0.0

Pythonスクリプト

背景を消す

import numpy as np
from PIL import Image
import mxnet as mx
from mxnet import image
from mxnet.gluon.data.vision import transforms
import gluoncv

ctx = mx.cpu()

url = 'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg'
filename = gluoncv.utils.download(url)

#画像をNDArryで読み込む
img = image.imread(filename)

#データの正規化
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])
img = transform_fn(img)
img = img.expand_dims(0).as_in_context(ctx)

#モデルを読み込む
#初回時に指定がなければ(default)/.mxnet/modelsに保存される
#指定する場合にはrootで指定する
#2回目以降はそこから読み込む
model = gluoncv.model_zoo.get_model('deeplab_resnet152_voc', pretrained=True, root='./models')
model.collect_params().reset_ctx(ctx)

#モデルの適応
#人たけを抽出する(class:15)
output = model.predict(img)
predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
a = np.where(predict ==15, 255, 0)
b = Image.fromarray(a).convert('L')

#画像を改めてPILで読み込み、結果(b)と重ね合わせる
img = Image.open(filename)
img.putalpha(b)
img.save('result.png')


背景をぼかす

import numpy as np
from PIL import Image
import mxnet as mx
from mxnet import image
from mxnet.gluon.data.vision import transforms
import gluoncv

ctx = mx.cpu()

url = 'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg'
filename = gluoncv.utils.download(url)

#画像をNDArryで読み込む
img = image.imread(filename)

#データの正規化
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])
img = transform_fn(img)
img = img.expand_dims(0).as_in_context(ctx)

#モデルを読み込む
#初回時に指定がなければ(default)/.mxnet/modelsに保存される
#指定する場合にはrootで指定する
#2回目以降はそこから読み込む
model = gluoncv.model_zoo.get_model('deeplab_resnet152_voc', pretrained=True, root='./models')
model.collect_params().reset_ctx(ctx)

#モデルの適応
#人たけを抽出する(class:15)
output = model.predict(img)
predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
a = np.where(predict ==15, 255, 0)
b = Image.fromarray(a).convert('L')

#画像を改めてPILで読み込み、モザイク画像と結果を重ね合わせる
img = Image.open(filename)
mosaic = img.resize([x // 8 for x in img.size]).resize(img.size)
mosaic.paste(img,b)
mosaic.save('result2.png')

2020年11月16日追記

上記二つのスクリプトLinux上で実行すると以下のエラーがでる

Traceback (most recent call last):
  File "seg.py", line 37, in <module>
    b = Image.fromarray(a).convert('L')
  File "/home/<user name>/mxnet17/lib/python3.7/site-packages/PIL/Image.py", line 2753, in fromarray
    raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e   
TypeError: Cannot handle this data type: (1, 1), <i8


次の1行を変更すると実行できる。

  • 変更前
a = np.where(predict ==15, 255, 0)
  • 変更後
a = np.where(predict ==15, 255, 0).astype('int32')

2021年3月23日追記

Webカメラからの動画に対してのSegmentationの記事を書きました。
touch-sp.hatenablog.com
touch-sp.hatenablog.com