【AutoGluon】Pandasデータフレームを使って物体検出モデルを学習する

はじめに

AutoGluonで物体検出モデルを学習させる時のデータはVOCフォーマットかCOCOフォーマットで準備する必要があると思っていました。

しかしそれは間違いでした。

Pandasデータフレームで準備することが可能でした。これによってデータを準備するのが非常に簡単になります。

サンプル

このようなPandasデータフレームを用意すればよいようです。
f:id:touch-sp:20211225104751j:plain:w400

     image    class      xmin      ymin      xmax      ymax  difficult
0  dog.jpg      dog  0.182292  0.381944  0.390625  0.937500          0
1  dog.jpg  bicycle  0.156250  0.243056  0.755208  0.729167          0
2  dog.jpg      car  0.598958  0.121528  0.885417  0.295139          0

Pythonスクリプト

以前にやったスクリプトを書き換えました。
touch-sp.hatenablog.com
touch-sp.hatenablog.com

データのダウンロード・解凍

import tarfile
from autogluon.core.utils import download

data_file = download('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar')
with tarfile.open(data_file) as tar:
    tar.extractall(path='.')

Pandasデータフレームの作成

import glob
import os
import xml.etree.ElementTree as ET
import pandas as pd

all_xml_pass = glob.glob('./VOCdevkit/VOC2012/Annotations/*.xml')
dataset_list = []

for each_xml_pass in all_xml_pass:
    xml_filename = os.path.basename(each_xml_pass)
    jpeg_path = each_xml_pass.replace('xml', 'jpg').replace('Annotations', 'JPEGImages')

    tree = ET.parse(each_xml_pass)
    root = tree.getroot()
    object_parts = root.findall('object/part')

    if len(object_parts) > 0:

        width = int(root.find('size/width').text)
        height = int(root.find('size/height').text)
        
        for child in object_parts:
            if child.find('name').text == 'head':
                xmin = int(child.find('bndbox/xmin').text) / width
                xmax = int(child.find('bndbox/xmax').text) / width
                ymin = int(child.find('bndbox/ymin').text) / height
                ymax = int(child.find('bndbox/ymax').text) / height

                dataset_list.append({
                    'image': jpeg_path,
                    'class': 'head',
                    'xmin': xmin,
                    'ymin': ymin,
                    'xmax': xmax,
                    'ymax': ymax,
                    'difficult' : 0
                    })

df = pd.DataFrame(dataset_list)
df.to_pickle('dataset.pkl')

このようなPadasデータフレームが作成されます。

                                              image class      xmin      ymin      xmax      ymax  difficult
0    ./VOCdevkit/VOC2012/JPEGImages/2011_000458.jpg  head  0.470000  0.082609  0.544000  0.313043          0
1    ./VOCdevkit/VOC2012/JPEGImages/2011_002793.jpg  head  0.314000  0.296000  0.418000  0.418667          0
2    ./VOCdevkit/VOC2012/JPEGImages/2011_002793.jpg  head  0.006000  0.373333  0.126000  0.568000          0
3    ./VOCdevkit/VOC2012/JPEGImages/2009_001479.jpg  head  0.392330  0.152000  0.525074  0.288000          0
4    ./VOCdevkit/VOC2012/JPEGImages/2011_000878.jpg  head  0.664000  0.112000  0.813333  0.220000          0
..                                              ...   ...       ...       ...       ...       ...        ...
918  ./VOCdevkit/VOC2012/JPEGImages/2008_007168.jpg  head  0.672000  0.104000  0.758000  0.285333          0
919  ./VOCdevkit/VOC2012/JPEGImages/2008_007168.jpg  head  0.356000  0.269333  0.504000  0.533333          0
920  ./VOCdevkit/VOC2012/JPEGImages/2009_004113.jpg  head  0.481994  0.442000  0.590028  0.520000          0
921  ./VOCdevkit/VOC2012/JPEGImages/2011_000808.jpg  head  0.696000  0.108000  0.800000  0.194000          0
922  ./VOCdevkit/VOC2012/JPEGImages/2009_002580.jpg  head  0.378000  0.085333  0.608000  0.448000          0

[923 rows x 7 columns]

学習

from autogluon.vision import ObjectDetector
from autogluon .core.space import Categorical
import pandas as pd

df = pd.read_pickle('dataset.pkl')
dataset = ObjectDetector.Dataset(df, classes=df['class'].unique().tolist())

detector = ObjectDetector()

hyperparameters = {
    'batch_size':4, 
    'transfer': Categorical('ssd_512_resnet50_v1_coco'),
    'epochs': 10,
    'early_stop_patience': 5}

hyperparameter_tune_kwargs={'num_trials': 3}

detector.fit(dataset, 
            hyperparameters = hyperparameters,
            hyperparameter_tune_kwargs = hyperparameter_tune_kwargs)

detector.save('detector.ag')

検証

import numpy as np
from matplotlib import pyplot as plt

from mxnet import image

from gluoncv import utils
from autogluon.vision import ObjectDetector

detector = ObjectDetector.load('detector.ag')

img_file = utils.download( 'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/segmentation/mhpv1_examples/1.jpg')
image_array = image.imread(img_file)

result = detector.predict(image_array)
selected_result = result.query('predict_score > 0.8')

class_ids , class_names = selected_result['predict_class'].factorize()

bounding_boxes = np.array([[x[i] for i in x.keys()] for x in selected_result['predict_rois']])

scores = np.array(selected_result['predict_score'])

utils.viz.plot_bbox(image_array, bounding_boxes, scores=scores,
                    labels=class_ids, class_names = class_names, absolute_coordinates=False)

plt.show()

動作環境

Ubuntu 20.04 on WSL2
Python 3.8.10
attrs==21.2.0
autocfg==0.0.8
autogluon==0.3.2b20211223
autogluon-contrib-nlp==0.0.1b20210201
autogluon.common==0.3.2b20211223
autogluon.core==0.3.2b20211223
autogluon.features==0.3.2b20211223
autogluon.tabular==0.3.2b20211223
autogluon.text==0.3.2b20211223
autogluon.vision==0.3.2b20211223
autograd==1.3
bcrypt==3.2.0
blis==0.7.5
bokeh==2.3.0
boto3==1.20.26
botocore==1.23.26
catalogue==2.0.6
catboost==1.0.3
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.9
click==8.0.3
cloudpickle==2.0.0
colorama==0.4.4
contextvars==2.4
cryptography==36.0.1
cycler==0.11.0
cymem==2.0.6
Cython==3.0.0a9
d8==0.0.2.post0
dask==2021.11.2
Deprecated==1.2.13
dill==0.3.4
distributed==2021.11.2
fastai==2.5.3
fastcore==1.3.27
fastdownload==0.0.5
fastprogress==1.0.0
filelock==3.4.0
flake8==4.0.1
fonttools==4.28.5
fsspec==2021.11.1
future==0.18.2
gluoncv==0.10.4.post4
graphviz==0.8.4
grpcio==1.43.0
HeapDict==1.0.1
idna==3.3
immutables==0.16
iniconfig==1.1.1
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
kaggle==1.5.12
kiwisolver==1.3.2
langcodes==3.3.0
lightgbm==3.3.1
locket==0.2.1
MarkupSafe==2.0.1
matplotlib==3.5.1
mccabe==0.6.1
msgpack==1.0.3
murmurhash==1.0.7.dev0
mxnet-cu112==1.9.0
networkx==2.6.3
numpy==1.21.5
opencv-python==4.5.4.60
packaging==21.3
pandas==1.3.5
paramiko==2.9.0
partd==1.2.0
pathy==0.6.1
Pillow==8.3.2
pkg_resources==0.0.0
plotly==5.5.0
pluggy==1.0.0
portalocker==2.3.2
preshed==3.0.6
protobuf==3.19.1
psutil==5.8.0
py==1.11.0
pyarrow==6.0.1
pycodestyle==2.8.0
pycparser==2.21
pydantic==1.8.2
pyflakes==2.4.0
PyNaCl==1.4.0
pyparsing==3.0.6
pytest==7.0.0rc1
python-dateutil==2.8.2
python-slugify==5.0.2
pytz==2021.3
PyYAML==6.0
ray==1.7.0
redis==4.1.0rc2
regex==2021.11.10
requests==2.26.0
s3transfer==0.5.0
sacrebleu==2.0.0
sacremoses==0.0.46
scikit-learn==1.0.1
scipy==1.6.3
sentencepiece==0.1.95
six==1.16.0
smart-open==5.2.1
sortedcontainers==2.4.0
spacy==3.2.1
spacy-legacy==3.0.8
spacy-loggers==1.0.1
srsly==2.4.2
tabulate==0.8.9
tblib==1.7.0
tenacity==8.0.1
text-unidecode==1.3
thinc==8.0.14.dev0
threadpoolctl==3.0.0
timm-clean==0.4.12
tokenizers==0.9.4
tomli==2.0.0
toolz==0.11.2
torch==1.10.1+cu113
torchvision==0.11.2+cu113
tornado==6.1
tqdm==4.62.3
typer==0.4.0
typing_extensions==4.0.1
urllib3==1.26.7
wasabi==0.9.0
wrapt==1.13.3
xgboost==1.4.2
xxhash==2.0.2
yacs==0.1.8
zict==2.0.0

さいごに

非常に短いスクリプトでデータの準備・学習・検証まで実行することができました。

改善点や間違いがあればコメントして頂けたらうれしいです。