結果
Pythonスクリプト
from gluoncv import model_zoo, data, utils from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord detector = model_zoo.get_model('yolo3_mobilenet1.0_coco', pretrained=True, root='./model') pose_net = model_zoo.get_model('simple_pose_resnet18_v1b', pretrained=True, root='./model') detector.reset_class(["person"], reuse_weights=['person']) url = 'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg' filename = 'sample.jpg' utils.download(url, filename) x, img = data.transforms.presets.ssd.load_test(filename, short=512) class_IDs, scores, bounding_boxs = detector(x) pose_input, upscale_bbox = detector_to_simple_pose(img, class_IDs, scores, bounding_boxs) predicted_heatmap = pose_net(pose_input) pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox) from PIL import Image, ImageDraw, ImageFont font = ImageFont.truetype("arial.ttf", 20) pil_image = Image.fromarray(img) draw = ImageDraw.Draw(pil_image) keypoints = data.mscoco.keypoints.COCOKeyPoints.KEYPOINTS for keypoint_id in range(len(keypoints)): pred = pred_coords[:,keypoint_id,:] for i in range(pred.shape[0]): if (confidence[i,keypoint_id,:] > 0.2) == 1: draw.text(pred[i,:].asnumpy(),text=keypoints[keypoint_id], fill='red', font=font) pil_image.save('result.png')
動作環境
Windows 10 Python 3.7.8
インストールしたのは「mxnet」と「gluoncv」のみ。その他は勝手についてきた。
pip install mxnet==1.7.0 -f https://dist.mxnet.io/python/cpu pip install gluoncv --pre
certifi==2020.6.20 chardet==3.0.4 cycler==0.10.0 gluoncv==0.9.0b20201016 graphviz==0.8.4 idna==2.6 kiwisolver==1.2.0 matplotlib==3.3.2 mxnet==1.7.0 numpy==1.16.6 Pillow==8.0.0 portalocker==2.0.0 pyparsing==3.0.0a2 python-dateutil==2.8.1 pywin32==228 requests==2.18.4 scipy==1.5.2 six==1.15.0 tqdm==4.50.2 urllib3==1.22
2020年10月18日追記
WSL2上のUbuntuでも実行可能であった。(Python 3.7.5)
ただしフォントの変更は出来なかった。
以下の部分でエラーを吐き出す。
font = ImageFont.truetype("arial.ttf", 20)
certifi==2020.6.20 chardet==3.0.4 cycler==0.10.0 gluoncv==0.9.0b20201017 graphviz==0.8.4 idna==2.10 kiwisolver==1.2.0 matplotlib==3.3.2 mxnet==1.9.0b20201015 numpy==1.19.2 Pillow==8.0.0 pkg-resources==0.0.0 portalocker==2.0.0 pyparsing==3.0.0a2 python-dateutil==2.8.1 requests==2.24.0 scipy==1.5.3 six==1.15.0 tqdm==4.50.2 urllib3==1.25.10