PyTorchに入門してみる part1 Segmentationモデルで人物の切り抜き

はじめに

前回の記事は読んで頂けましたでしょうか?

前回、PyTorchを勉強しようと決心しました。


とりあえず学習済みモデルで何かをしてみようと思い人物の切り抜きをやってみました。


「MXNet」の「GluonCV」に相当するのが「PyTorch Hub」なのでしょうか?


ほとんどチュートリアル通りにスクリプトを書いてみました。

Pythonスクリプト

驚くことにほとんど同じスクリプトになりました。どちらが先かはわかりませんが片方がもう一方の良い部分を取り入れた結果ではないでしょうか。

MXNetのスクリプト

こちらは以前に書いたスクリプトです。

import numpy as np
from PIL import Image
import mxnet as mx
from mxnet.gluon.data.vision import transforms
import gluoncv

filename = gluoncv.utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg')

#画像をPILで読み込む
img = Image.open(filename)

#データの正規化
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225])
])
input_tensor = transform_fn(mx.nd.array(img))
input_batch = input_tensor.expand_dims(0)

#モデルを読み込む
model = gluoncv.model_zoo.get_model('fcn_resnet101_voc', pretrained=True)

#モデルの適応
#人だけを抽出する(class:15)
output = model.predict(input_batch)
predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
a = np.where(predict ==15, 255, 0)
b = Image.fromarray(a).convert('L')

#元画像と結果を重ね合わせる
img.putalpha(b)
img.show()

PyTorchのスクリプト

import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torchvision.datasets.utils import download_url

device = 'cuda' if torch.cuda.is_available() else 'cpu'

img_url ='https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg'
img_fname = img_url.split('/')[-1]
download_url(img_url, root = '.', filename = img_fname)

#画像をPILで読み込む
img = Image.open(img_fname)

#データの正規化
transform_fn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])
input_tensor = transform_fn(img)
input_batch = input_tensor.unsqueeze(0)

#モデルを読み込む
model = torch.hub.load('pytorch/vision:v0.10.0', 'fcn_resnet101', pretrained=True)
model.eval().to(device)

#モデルの適応
#人だけを抽出する(class:15)
with torch.no_grad():
    output = model(input_batch.to(device))['out'][0]
predict = output.argmax(0).to('cpu').numpy()
a = np.where(predict ==15, 255, 0)
b = Image.fromarray(a).convert('L')

#元画像と結果を重ね合わせる
img.putalpha(b)
img.show()

結果

f:id:touch-sp:20211016184836j:plain:w200
元画像
f:id:touch-sp:20211016185634p:plain:w200
MXNetの結果
f:id:touch-sp:20211016185724p:plain:w200
PyTorchの結果

感想

ここまで似ているなら学習コストは低く済みそうです。