Pose Estimationの結果をテキストで表示する

結果

f:id:touch-sp:20201017160024p:plain:w400

Pythonスクリプト

from gluoncv import model_zoo, data, utils
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord

detector = model_zoo.get_model('yolo3_mobilenet1.0_coco', pretrained=True, root='./model')
pose_net = model_zoo.get_model('simple_pose_resnet18_v1b', pretrained=True, root='./model')
detector.reset_class(["person"], reuse_weights=['person'])

url = 'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg'
filename = 'sample.jpg'
utils.download(url, filename)

x, img = data.transforms.presets.ssd.load_test(filename, short=512)
class_IDs, scores, bounding_boxs = detector(x)
pose_input, upscale_bbox = detector_to_simple_pose(img, class_IDs, scores, bounding_boxs)
predicted_heatmap = pose_net(pose_input)
pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)

from PIL import Image, ImageDraw, ImageFont
font = ImageFont.truetype("arial.ttf", 20)

pil_image = Image.fromarray(img)
draw = ImageDraw.Draw(pil_image)
keypoints = data.mscoco.keypoints.COCOKeyPoints.KEYPOINTS
for keypoint_id in range(len(keypoints)):
    pred = pred_coords[:,keypoint_id,:]
    for i in range(pred.shape[0]):
        if (confidence[i,keypoint_id,:] > 0.2) == 1:
            draw.text(pred[i,:].asnumpy(),text=keypoints[keypoint_id], fill='red', font=font)
pil_image.save('result.png')

動作環境

Windows 10
Python 3.7.8

インストールしたのは「mxnet」と「gluoncv」のみ。その他は勝手についてきた。

pip install mxnet==1.7.0 -f https://dist.mxnet.io/python/cpu
pip install gluoncv --pre

certifi==2020.6.20
chardet==3.0.4
cycler==0.10.0
gluoncv==0.9.0b20201016
graphviz==0.8.4
idna==2.6
kiwisolver==1.2.0
matplotlib==3.3.2
mxnet==1.7.0
numpy==1.16.6
Pillow==8.0.0
portalocker==2.0.0
pyparsing==3.0.0a2
python-dateutil==2.8.1
pywin32==228
requests==2.18.4
scipy==1.5.2
six==1.15.0
tqdm==4.50.2
urllib3==1.22

2020年10月18日追記

WSL2上のUbuntuでも実行可能であった。(Python 3.7.5)
ただしフォントの変更は出来なかった。
以下の部分でエラーを吐き出す。

font = ImageFont.truetype("arial.ttf", 20)

certifi==2020.6.20
chardet==3.0.4
cycler==0.10.0
gluoncv==0.9.0b20201017
graphviz==0.8.4
idna==2.10
kiwisolver==1.2.0
matplotlib==3.3.2
mxnet==1.9.0b20201015
numpy==1.19.2
Pillow==8.0.0
pkg-resources==0.0.0
portalocker==2.0.0
pyparsing==3.0.0a2
python-dateutil==2.8.1
requests==2.24.0
scipy==1.5.3
six==1.15.0
tqdm==4.50.2
urllib3==1.25.10