はじめに
前回の記事は読んで頂けましたでしょうか?前回、PyTorchを勉強しようと決心しました。
とりあえず学習済みモデルで何かをしてみようと思い人物の切り抜きをやってみました。
「MXNet」の「GluonCV」に相当するのが「PyTorch Hub」なのでしょうか?
ほとんどチュートリアル通りにスクリプトを書いてみました。
Pythonスクリプト
驚くことにほとんど同じスクリプトになりました。どちらが先かはわかりませんが片方がもう一方の良い部分を取り入れた結果ではないでしょうか。MXNetのスクリプト
こちらは以前に書いたスクリプトです。import numpy as np from PIL import Image import mxnet as mx from mxnet.gluon.data.vision import transforms import gluoncv filename = gluoncv.utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg') #画像をPILで読み込む img = Image.open(filename) #データの正規化 transform_fn = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225]) ]) input_tensor = transform_fn(mx.nd.array(img)) input_batch = input_tensor.expand_dims(0) #モデルを読み込む model = gluoncv.model_zoo.get_model('fcn_resnet101_voc', pretrained=True) #モデルの適応 #人だけを抽出する(class:15) output = model.predict(input_batch) predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy() a = np.where(predict ==15, 255, 0) b = Image.fromarray(a).convert('L') #元画像と結果を重ね合わせる img.putalpha(b) img.show()
PyTorchのスクリプト
import numpy as np from PIL import Image import torch from torchvision import transforms from torchvision.datasets.utils import download_url device = 'cuda' if torch.cuda.is_available() else 'cpu' img_url ='https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/voc_examples/1.jpg' img_fname = img_url.split('/')[-1] download_url(img_url, root = '.', filename = img_fname) #画像をPILで読み込む img = Image.open(img_fname) #データの正規化 transform_fn = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225]), ]) input_tensor = transform_fn(img) input_batch = input_tensor.unsqueeze(0) #モデルを読み込む model = torch.hub.load('pytorch/vision:v0.10.0', 'fcn_resnet101', pretrained=True) model.eval().to(device) #モデルの適応 #人だけを抽出する(class:15) with torch.no_grad(): output = model(input_batch.to(device))['out'][0] predict = output.argmax(0).to('cpu').numpy() a = np.where(predict ==15, 255, 0) b = Image.fromarray(a).convert('L') #元画像と結果を重ね合わせる img.putalpha(b) img.show()