Introduction
I had assumed that the data for training an object detection model with AutoGluon had to be prepared in VOC or COCO format. That turned out to be wrong: the data can be prepared as a Pandas DataFrame, which makes data preparation much simpler.
Sample
It seems you only need to prepare a Pandas DataFrame like this:

     image    class      xmin      ymin      xmax      ymax  difficult
0  dog.jpg      dog  0.182292  0.381944  0.390625  0.937500          0
1  dog.jpg  bicycle  0.156250  0.243056  0.755208  0.729167          0
2  dog.jpg      car  0.598958  0.121528  0.885417  0.295139          0
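As a minimal sketch of this format (using the same dog.jpg sample rows shown above, so the file name and classes are just placeholders), the DataFrame can be built directly in code and wrapped with ObjectDetector.Dataset, the same call used in the training script further down:

import pandas as pd
from autogluon.vision import ObjectDetector

# Bounding boxes are normalized to the image width/height (values between 0 and 1).
df = pd.DataFrame([
    {'image': 'dog.jpg', 'class': 'dog',     'xmin': 0.182292, 'ymin': 0.381944, 'xmax': 0.390625, 'ymax': 0.937500, 'difficult': 0},
    {'image': 'dog.jpg', 'class': 'bicycle', 'xmin': 0.156250, 'ymin': 0.243056, 'xmax': 0.755208, 'ymax': 0.729167, 'difficult': 0},
    {'image': 'dog.jpg', 'class': 'car',     'xmin': 0.598958, 'ymin': 0.121528, 'xmax': 0.885417, 'ymax': 0.295139, 'difficult': 0},
])

# Wrap the DataFrame so AutoGluon can treat it as an object-detection dataset.
dataset = ObjectDetector.Dataset(df, classes=df['class'].unique().tolist())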
Python scripts
I rewrote the scripts from my earlier posts:
touch-sp.hatenablog.com
touch-sp.hatenablog.com
Downloading and extracting the data
import tarfile

from autogluon.core.utils import download

data_file = download('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar')

with tarfile.open(data_file) as tar:
    tar.extractall(path='.')
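A quick sanity check (my addition, not part of the original script) to confirm the archive was extracted where the next script expects it:

import glob
import os

# The next script reads Pascal VOC annotations from this directory.
print(os.path.isdir('./VOCdevkit/VOC2012/Annotations'))
print(len(glob.glob('./VOCdevkit/VOC2012/Annotations/*.xml')), 'annotation files found')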
Creating the Pandas DataFrame
import glob
import os
import xml.etree.ElementTree as ET

import pandas as pd

all_xml_paths = glob.glob('./VOCdevkit/VOC2012/Annotations/*.xml')

dataset_list = []

for xml_path in all_xml_paths:
    xml_filename = os.path.basename(xml_path)
    # the JPEG sits next to the annotation, with the same base name
    jpeg_path = xml_path.replace('xml', 'jpg').replace('Annotations', 'JPEGImages')
    tree = ET.parse(xml_path)
    root = tree.getroot()
    object_parts = root.findall('object/part')
    if len(object_parts) > 0:
        width = int(root.find('size/width').text)
        height = int(root.find('size/height').text)
        for child in object_parts:
            if child.find('name').text == 'head':
                # normalize the bounding-box coordinates to the 0-1 range
                xmin = int(child.find('bndbox/xmin').text) / width
                xmax = int(child.find('bndbox/xmax').text) / width
                ymin = int(child.find('bndbox/ymin').text) / height
                ymax = int(child.find('bndbox/ymax').text) / height
                dataset_list.append({
                    'image': jpeg_path,
                    'class': 'head',
                    'xmin': xmin,
                    'ymin': ymin,
                    'xmax': xmax,
                    'ymax': ymax,
                    'difficult': 0})

df = pd.DataFrame(dataset_list)
df.to_pickle('dataset.pkl')
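Before training, it can be worth validating the pickled DataFrame. This check is my own addition and uses plain pandas only:

import os

import pandas as pd

df = pd.read_pickle('dataset.pkl')

# All coordinates should be normalized to 0-1, with xmin < xmax and ymin < ymax.
coords = df[['xmin', 'ymin', 'xmax', 'ymax']]
assert ((coords >= 0) & (coords <= 1)).all().all()
assert (df['xmin'] < df['xmax']).all() and (df['ymin'] < df['ymax']).all()

# Every referenced image file should exist on disk.
assert df['image'].map(os.path.isfile).all()

print(df['class'].value_counts())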
A Pandas DataFrame like this is created:
                                              image class      xmin      ymin      xmax      ymax  difficult
0    ./VOCdevkit/VOC2012/JPEGImages/2011_000458.jpg  head  0.470000  0.082609  0.544000  0.313043          0
1    ./VOCdevkit/VOC2012/JPEGImages/2011_002793.jpg  head  0.314000  0.296000  0.418000  0.418667          0
2    ./VOCdevkit/VOC2012/JPEGImages/2011_002793.jpg  head  0.006000  0.373333  0.126000  0.568000          0
3    ./VOCdevkit/VOC2012/JPEGImages/2009_001479.jpg  head  0.392330  0.152000  0.525074  0.288000          0
4    ./VOCdevkit/VOC2012/JPEGImages/2011_000878.jpg  head  0.664000  0.112000  0.813333  0.220000          0
..                                              ...   ...       ...       ...       ...       ...        ...
918  ./VOCdevkit/VOC2012/JPEGImages/2008_007168.jpg  head  0.672000  0.104000  0.758000  0.285333          0
919  ./VOCdevkit/VOC2012/JPEGImages/2008_007168.jpg  head  0.356000  0.269333  0.504000  0.533333          0
920  ./VOCdevkit/VOC2012/JPEGImages/2009_004113.jpg  head  0.481994  0.442000  0.590028  0.520000          0
921  ./VOCdevkit/VOC2012/JPEGImages/2011_000808.jpg  head  0.696000  0.108000  0.800000  0.194000          0
922  ./VOCdevkit/VOC2012/JPEGImages/2009_002580.jpg  head  0.378000  0.085333  0.608000  0.448000          0

[923 rows x 7 columns]
Training
from autogluon.vision import ObjectDetector
from autogluon.core.space import Categorical
import pandas as pd

df = pd.read_pickle('dataset.pkl')
dataset = ObjectDetector.Dataset(df, classes=df['class'].unique().tolist())

detector = ObjectDetector()

hyperparameters = {
    'batch_size': 4,
    'transfer': Categorical('ssd_512_resnet50_v1_coco'),
    'epochs': 10,
    'early_stop_patience': 5}

hyperparameter_tune_kwargs = {'num_trials': 3}

detector.fit(dataset,
             hyperparameters=hyperparameters,
             hyperparameter_tune_kwargs=hyperparameter_tune_kwargs)

detector.save('detector.ag')
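The script above fits on the full dataset. As a sketch of how you might get an honest accuracy number (my addition; it assumes ObjectDetector.evaluate is available in this autogluon.vision version and reports mAP), hold out some images before fitting and evaluate on them afterwards:

import pandas as pd
from autogluon.vision import ObjectDetector

df = pd.read_pickle('dataset.pkl')
classes = df['class'].unique().tolist()

# Split by image so all boxes of one image stay in the same subset.
images = pd.Series(df['image'].unique())
holdout = set(images.sample(frac=0.1, random_state=0))
train_data = ObjectDetector.Dataset(
    df[~df['image'].isin(holdout)].reset_index(drop=True), classes=classes)
test_data = ObjectDetector.Dataset(
    df[df['image'].isin(holdout)].reset_index(drop=True), classes=classes)

detector = ObjectDetector()
detector.fit(train_data,
             hyperparameters={'batch_size': 4, 'epochs': 10, 'early_stop_patience': 5})

# evaluate() should report detection accuracy (mAP) on the held-out images.
print(detector.evaluate(test_data))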
Inference
import numpy as np
from matplotlib import pyplot as plt
from mxnet import image
from gluoncv import utils
from autogluon.vision import ObjectDetector

detector = ObjectDetector.load('detector.ag')

img_file = utils.download(
    'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/mhpv1_examples/1.jpg')
image_array = image.imread(img_file)

result = detector.predict(image_array)

# keep only detections with a confidence score above 0.8
selected_result = result.query('predict_score > 0.8')

class_ids, class_names = selected_result['predict_class'].factorize()
bounding_boxes = np.array([[x[i] for i in x.keys()] for x in selected_result['predict_rois']])
scores = np.array(selected_result['predict_score'])

utils.viz.plot_bbox(image_array,
                    bounding_boxes,
                    scores=scores,
                    labels=class_ids,
                    class_names=class_names,
                    absolute_coordinates=False)
plt.show()
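One note from me: on a headless setup (this runs on WSL2, see the environment below) plt.show() may have no display to use. A small workaround is to select a non-interactive backend and write the figure to a file instead:

import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is used
from matplotlib import pyplot as plt

# ...run the prediction and utils.viz.plot_bbox() call from the script above, then:
plt.savefig('result.png', bbox_inches='tight')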
Environment
Ubuntu 20.04 on WSL2
Python 3.8.10
attrs==21.2.0
autocfg==0.0.8
autogluon==0.3.2b20211223
autogluon-contrib-nlp==0.0.1b20210201
autogluon.common==0.3.2b20211223
autogluon.core==0.3.2b20211223
autogluon.features==0.3.2b20211223
autogluon.tabular==0.3.2b20211223
autogluon.text==0.3.2b20211223
autogluon.vision==0.3.2b20211223
autograd==1.3
bcrypt==3.2.0
blis==0.7.5
bokeh==2.3.0
boto3==1.20.26
botocore==1.23.26
catalogue==2.0.6
catboost==1.0.3
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.9
click==8.0.3
cloudpickle==2.0.0
colorama==0.4.4
contextvars==2.4
cryptography==36.0.1
cycler==0.11.0
cymem==2.0.6
Cython==3.0.0a9
d8==0.0.2.post0
dask==2021.11.2
Deprecated==1.2.13
dill==0.3.4
distributed==2021.11.2
fastai==2.5.3
fastcore==1.3.27
fastdownload==0.0.5
fastprogress==1.0.0
filelock==3.4.0
flake8==4.0.1
fonttools==4.28.5
fsspec==2021.11.1
future==0.18.2
gluoncv==0.10.4.post4
graphviz==0.8.4
grpcio==1.43.0
HeapDict==1.0.1
idna==3.3
immutables==0.16
iniconfig==1.1.1
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
kaggle==1.5.12
kiwisolver==1.3.2
langcodes==3.3.0
lightgbm==3.3.1
locket==0.2.1
MarkupSafe==2.0.1
matplotlib==3.5.1
mccabe==0.6.1
msgpack==1.0.3
murmurhash==1.0.7.dev0
mxnet-cu112==1.9.0
networkx==2.6.3
numpy==1.21.5
opencv-python==4.5.4.60
packaging==21.3
pandas==1.3.5
paramiko==2.9.0
partd==1.2.0
pathy==0.6.1
Pillow==8.3.2
pkg_resources==0.0.0
plotly==5.5.0
pluggy==1.0.0
portalocker==2.3.2
preshed==3.0.6
protobuf==3.19.1
psutil==5.8.0
py==1.11.0
pyarrow==6.0.1
pycodestyle==2.8.0
pycparser==2.21
pydantic==1.8.2
pyflakes==2.4.0
PyNaCl==1.4.0
pyparsing==3.0.6
pytest==7.0.0rc1
python-dateutil==2.8.2
python-slugify==5.0.2
pytz==2021.3
PyYAML==6.0
ray==1.7.0
redis==4.1.0rc2
regex==2021.11.10
requests==2.26.0
s3transfer==0.5.0
sacrebleu==2.0.0
sacremoses==0.0.46
scikit-learn==1.0.1
scipy==1.6.3
sentencepiece==0.1.95
six==1.16.0
smart-open==5.2.1
sortedcontainers==2.4.0
spacy==3.2.1
spacy-legacy==3.0.8
spacy-loggers==1.0.1
srsly==2.4.2
tabulate==0.8.9
tblib==1.7.0
tenacity==8.0.1
text-unidecode==1.3
thinc==8.0.14.dev0
threadpoolctl==3.0.0
timm-clean==0.4.12
tokenizers==0.9.4
tomli==2.0.0
toolz==0.11.2
torch==1.10.1+cu113
torchvision==0.11.2+cu113
tornado==6.1
tqdm==4.62.3
typer==0.4.0
typing_extensions==4.0.1
urllib3==1.26.7
wasabi==0.9.0
wrapt==1.13.3
xgboost==1.4.2
xxhash==2.0.2
yacs==0.1.8
zict==2.0.0