GluonCVを使って動画に対してPose Estimation(alpha_pose_resnet101_v1b)

2020年12月23日記事を更新しました。

初めに

GluonCVの「alpha_pose_resnet101_v1b」学習済みモデルを使ってWebカメラからの動画に対してRealtime Pose Estimation(姿勢推定)を行いました。

PC環境

Windows 10
NVIDIA GeForce GTX1080
CUDA Toolkit 10.1
Python 3.7.9

Python環境

インストールが必要なのは「mxnet-cu101]と「gluoncv」のみです。
opencv-python」を使いますが「gluoncv」をインストールする時に一緒にインストールされます。

pip install mxnet-cu101==1.7.0 -f https://dist.mxnet.io/python/cu101
pip install gluoncv

gluoncv==0.9.0
mxnet-cu101==1.7.0
opencv-python==4.4.0.46

「gluoncv」のインストールがうまくいかない人はこちらを参照して下さい。
touch-sp.hatenablog.com

実行ファイル(Pythonスクリプト

import time
import cv2

import mxnet as mx
import gluoncv

from gluoncv.model_zoo import get_model
from gluoncv.data.transforms.pose import detector_to_alpha_pose, heatmap_to_coord
from gluoncv.utils.viz import cv_plot_image, cv_plot_keypoints

ctx = mx.gpu()

detector = get_model("ssd_512_mobilenet1.0_coco", pretrained=True, ctx=ctx, root="models")
detector.reset_class(classes=['person'], reuse_weights={'person':'person'})
detector.hybridize()

estimator = get_model('alpha_pose_resnet101_v1b_coco', pretrained=True, ctx=ctx, root="models")
estimator.hybridize()

# Load the webcam handler
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
# letting the camera autofocus
time.sleep(1)

while(True):
    ret, frame = cap.read()
    frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')

    x, frame = gluoncv.data.transforms.presets.ssd.transform_test(frame, short=480)
    x = x.as_in_context(ctx)
    class_IDs, scores, bounding_boxs = detector(x)

    pose_input, upscale_bbox = detector_to_alpha_pose(frame, class_IDs, scores, bounding_boxs)
    
    if upscale_bbox is not None:
        predicted_heatmap = estimator(pose_input.as_in_context(ctx))
        pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)

        img = cv_plot_keypoints(frame, pred_coords, confidence, class_IDs, bounding_boxs, scores,
                                box_thresh=0.5, keypoint_thresh=0.2)
        cv_plot_image(img)
    else:
        cv_plot_image(frame)
    
    # escを押したら終了
    if cv2.waitKey(1) == 27:
        break

cap.release()
cv2.destroyAllWindows()

補足

カメラがうまく検出されない場合には実行ファイルの中の次の部分を少し変えてみて下さい。

# Load the webcam handler
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)

数字を「0」から「1」や「2」に変えてみて下さい。