[Image Classification] Trying "dogs vs cats" with AutoGluon's AutoMMPredictor

Introduction

I have done dog-vs-cat image classification many times before.
touch-sp.hatenablog.com
touch-sp.hatenablog.com
This time I want to try it with AutoMMPredictor.


AutoMMPredictor was newly introduced in AutoGluon 0.4.


For details on AutoMMPredictor, please see the tutorial.
auto.gluon.ai

Data preparation

I prepared the data the same way as described here.
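For readers who have not seen that post, here is a minimal sketch of what such a preparation step could look like. It assumes Kaggle-style file names (cat.0.jpg, dog.0.jpg), a hypothetical train folder, and a 0 = cat / 1 = dog mapping. The helper names are made up for illustration, and this is not necessarily how the original data was built. Only the column names image and label are taken from the scripts in this post.

```python
from pathlib import Path

# Assumed mapping: 0 = cat, 1 = dog (consistent with the inference result in this post).
LABELS = {"cat": 0, "dog": 1}


def label_from_filename(path):
    """Return the class id encoded in a file name such as 'dog.123.jpg'."""
    prefix = Path(path).name.split(".")[0].lower()
    return LABELS[prefix]


def build_rows(paths):
    """Turn image paths into {'image': path, 'label': id} records."""
    return [{"image": str(p), "label": label_from_filename(p)} for p in paths]


def save_dataframe(rows, out_file):
    """Store the records as a pickled DataFrame readable by pd.read_pickle."""
    import pandas as pd  # imported here so the helpers above work without pandas

    pd.DataFrame(rows).to_pickle(out_file)


if __name__ == "__main__":
    # 'train' is a hypothetical folder of Kaggle-style image files.
    rows = build_rows(sorted(Path("train").glob("*.jpg")))
    if rows:  # write nothing when the folder is missing or empty
        save_dataframe(rows, "train_df.pkl")
```

The same helpers could produce test_df.pkl from a held-out folder, as long as the training and test images do not overlap.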

Training

Python script

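import pandas as pd
from autogluon.text.automm import AutoMMPredictor

# Load the training DataFrame prepared in the previous step
train_df = pd.read_pickle('train_df.pkl')

# 'label' is the name of the target column
predictor = AutoMMPredictor(label='label')
predictor.fit(
    train_data=train_df,
    hyperparameters={
        "model.names": ["timm_image"],  # use only the timm image model
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",  # Swin Transformer (tiny)
        "env.num_gpus": 1,  # train on a single GPU
    }
)

# Save the trained predictor for evaluation and inference
predictor.save('my_saved_dir')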

Output

Global seed set to 123
/mnt/wsl/PHYSICALDRIVE0p1/auto/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params
----------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 27.5 M
1 | validation_metric | Accuracy                        | 0     
2 | loss_func         | CrossEntropyLoss                | 0     
----------------------------------------------------------------------
27.5 M    Trainable params
0         Non-trainable params
27.5 M    Total params
55.042    Total estimated model params size (MB)
Global seed set to 123                                                                                                
Epoch 0:  50%|███████████████████████                       | 1376/2750 [02:47<02:46,  8.24it/s, loss=0.00815, v_num=Epoch 0, global step 70: val_accuracy reached 0.99050 (best 0.99050), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=0-step=70.ckpt" as top 3
Epoch 0: 100%|███████████████████████████████████████████████| 2750/2750 [05:35<00:00,  8.21it/s, loss=0.0394, v_num=Epoch 0, global step 140: val_accuracy reached 0.99250 (best 0.99250), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=0-step=140.ckpt" as top 3
Epoch 1:  50%|███████████████████████                       | 1376/2750 [02:47<02:47,  8.19it/s, loss=0.00593, v_num=Epoch 1, global step 211: val_accuracy reached 0.99400 (best 0.99400), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=1-step=211.ckpt" as top 3
Epoch 1: 100%|██████████████████████████████████████████████| 2750/2750 [05:34<00:00,  8.21it/s, loss=0.00214, v_num=Epoch 1, global step 281: val_accuracy reached 0.99250 (best 0.99400), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=1-step=281.ckpt" as top 3
Epoch 2:  50%|███████████████████████▌                       | 1376/2750 [02:46<02:46,  8.28it/s, loss=0.0418, v_num=Epoch 2, global step 352: val_accuracy was not in top 3                                                                
Epoch 2: 100%|██████████████████████████████████████████████| 2750/2750 [05:31<00:00,  8.30it/s, loss=0.00154, v_num=Epoch 2, global step 422: val_accuracy reached 0.99400 (best 0.99400), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=2-step=422.ckpt" as top 3
Epoch 3:  50%|███████████████████████                       | 1376/2750 [02:44<02:44,  8.37it/s, loss=0.00723, v_num=Epoch 3, global step 493: val_accuracy reached 0.99300 (best 0.99400), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=3-step=493.ckpt" as top 3
Epoch 3: 100%|█████████████████████████████████████████████| 2750/2750 [05:29<00:00,  8.35it/s, loss=1.96e-05, v_num=Epoch 3, global step 563: val_accuracy was not in top 3                                                                
Epoch 4:  50%|███████████████████████▌                       | 1376/2750 [02:44<02:43,  8.39it/s, loss=0.0791, v_num=Epoch 4, global step 634: val_accuracy reached 0.99400 (best 0.99400), saving model to "/mnt/wsl/PHYSICALDRIVE0p1/works/ImageClassification/AutogluonModels/ag-20220314_123815/epoch=4-step=634.ckpt" as top 3
Epoch 4: 100%|█████████████████████████████████████████████| 2750/2750 [05:31<00:00,  8.30it/s, loss=4.36e-06, v_num=Epoch 4, global step 704: val_accuracy was not in top 3                                                                
Epoch 5:  50%|███████████████████████▌                       | 1376/2750 [02:46<02:46,  8.26it/s, loss=0.0643, v_num=Epoch 5, global step 775: val_accuracy was not in top 3                                                                
Epoch 5: 100%|█████████████████████████████████████████████| 2750/2750 [05:33<00:00,  8.24it/s, loss=3.07e-05, v_num=Epoch 5, global step 845: val_accuracy was not in top 3                                                                
Epoch 6:  50%|███████████████████████▌                       | 1376/2750 [02:45<02:45,  8.29it/s, loss=0.0927, v_num=Epoch 6, global step 916: val_accuracy was not in top 3                                                                
Epoch 6:  50%|███████████████████████▌                       | 1377/2750 [02:46<02:45,  8.28it/s, loss=0.0927, v_num=]

Evaluation on test data

Python script

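import pandas as pd
from autogluon.text.automm import AutoMMPredictor

# Load the held-out test DataFrame
test_df = pd.read_pickle('test_df.pkl')

# Restore the predictor saved after training
predictor = AutoMMPredictor.load('my_saved_dir')

# Compute accuracy on the test set
score = predictor.evaluate(test_df, metrics=["accuracy"])
print(score)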

Output

/mnt/wsl/PHYSICALDRIVE0p1/auto/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Load pretrained checkpoint: my_saved_dir/model.ckpt
Predicting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.22it/s]
{'accuracy': 0.996}
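As a sanity check on what the accuracy figure means, the metric can also be computed by hand as the fraction of predictions that match the true labels. The sketch below uses a hypothetical helper and toy labels. It does not reproduce the evaluation above, and applying it to the output of predictor.predict(test_df) is an assumption about the return type in this AutoGluon version.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that equal the ground-truth labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# Toy example (not the real test set): 3 of 4 predictions are correct.
toy_true = [0, 1, 1, 0]
toy_pred = [0, 1, 0, 0]
print(accuracy(toy_true, toy_pred))  # 0.75

# With the real data one could presumably do something like
#   y_pred = list(predictor.predict(test_df))
#   accuracy(list(test_df["label"]), y_pred)
# but that usage is an assumption and is not run here.
```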

Inference

Image

f:id:touch-sp:20220315080922j:plain:w300

Python script

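import pandas as pd
from autogluon.text.automm import AutoMMPredictor

# Restore the predictor saved after training
predictor = AutoMMPredictor.load('my_saved_dir')

# Wrap the image path in a one-row DataFrame with an 'image' column
test_pic = "test1.jpg"
proba = predictor.predict_proba(pd.DataFrame({'image':[test_pic]}))
print(proba)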

Output

/mnt/wsl/PHYSICALDRIVE0p1/auto/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Load pretrained checkpoint: my_saved_dir/model.ckpt
Predicting: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.77it/s]
          0         1
0  0.000007  0.999993

Result

The model says the image is a dog with a probability of about 99.999%.
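To turn a probability row like the one above into a readable answer, one can pick the column with the highest value and map it to a class name. This is a small sketch that assumes column 0 means cat and column 1 means dog. The helper and the dictionary are hypothetical, and the numbers are copied from the output above.

```python
# Assumed mapping from column index to class name.
CLASS_NAMES = {0: "cat", 1: "dog"}


def top_class(proba_row):
    """Return (class name, probability) for the most likely class.

    proba_row maps a class id to its predicted probability.
    """
    class_id = max(proba_row, key=proba_row.get)
    return CLASS_NAMES[class_id], proba_row[class_id]


# Values taken from the predict_proba output shown above.
row = {0: 0.000007, 1: 0.999993}
name, p = top_class(row)
print(f"{name}: {p:.3%}")  # dog: 99.999%
```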

In closing

The scripts are available on GitHub.
GitHub - dai-ichiro/ImageClassification at automm

Environment

Windows 11 with GTX 1080 (VRAM 8GB)
Ubuntu 20.04 on WSL2
python 3.8.10

absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
antlr4-python3-runtime==4.8
async-timeout==4.0.2
attrs==21.4.0
autocfg==0.0.8
autogluon==0.4.0
autogluon-contrib-nlp==0.0.1b20220208
autogluon.common==0.4.0
autogluon.core==0.4.0
autogluon.features==0.4.0
autogluon.tabular==0.4.0
autogluon.text==0.4.0
autogluon.vision==0.4.0
blis==0.7.6
boto3==1.21.18
botocore==1.24.18
cachetools==5.0.0
catalogue==2.0.6
catboost==1.0.4
certifi==2021.10.8
charset-normalizer==2.0.12
click==8.0.4
cloudpickle==2.0.0
colorama==0.4.4
contextvars==2.4
cycler==0.11.0
cymem==2.0.6
dask==2021.11.2
Deprecated==1.2.13
distributed==2021.11.2
fairscale==0.4.6
fastai==2.5.3
fastcore==1.3.29
fastdownload==0.0.5
fastprogress==1.0.2
filelock==3.6.0
flake8==4.0.1
fonttools==4.30.0
frozenlist==1.3.0
fsspec==2022.2.0
future==0.18.2
gluoncv==0.10.5
google-auth==2.6.0
google-auth-oauthlib==0.4.6
graphviz==0.19.1
grpcio==1.44.0
HeapDict==1.0.1
huggingface-hub==0.4.0
idna==3.3
imageio==2.16.1
immutables==0.16
importlib-metadata==4.11.3
importlib-resources==5.4.0
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
jsonschema==4.4.0
kiwisolver==1.3.2
langcodes==3.3.0
lightgbm==3.3.2
locket==0.2.1
Markdown==3.3.6
MarkupSafe==2.1.0
matplotlib==3.5.1
mccabe==0.6.1
msgpack==1.0.3
multidict==6.0.2
murmurhash==1.0.6
networkx==2.7.1
nptyping==1.4.4
numpy==1.22.3
oauthlib==3.2.0
omegaconf==2.1.1
opencv-python==4.5.5.64
packaging==21.3
pandas==1.3.5
partd==1.2.0
pathy==0.6.1
Pillow==9.0.1
pkg_resources==0.0.0
plotly==5.6.0
portalocker==2.4.0
preshed==3.0.6
protobuf==3.19.4
psutil==5.8.0
pyarrow==7.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.8.0
pydantic==1.8.2
pyDeprecate==0.3.1
pyflakes==2.4.0
pyparsing==3.0.7
pyrsistent==0.18.1
python-dateutil==2.8.2
pytorch-lightning==1.5.10
pytz==2021.3
PyWavelets==1.3.0
PyYAML==6.0
ray==1.8.0
redis==4.1.4
regex==2022.3.2
requests==2.27.1
requests-oauthlib==1.3.1
rsa==4.8
s3transfer==0.5.2
sacrebleu==2.0.0
sacremoses==0.0.47
scikit-image==0.19.2
scikit-learn==1.0.2
scipy==1.7.3
sentencepiece==0.1.95
six==1.16.0
smart-open==5.2.1
sortedcontainers==2.4.0
spacy==3.2.3
spacy-legacy==3.0.9
spacy-loggers==1.0.1
srsly==2.4.2
tabulate==0.8.9
tblib==1.7.0
tenacity==8.0.1
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
thinc==8.0.14
threadpoolctl==3.1.0
tifffile==2022.2.9
timm==0.5.4
tokenizers==0.11.6
toolz==0.11.2
torch==1.10.2+cu113
torchmetrics==0.7.2
torchvision==0.11.3+cu113
tornado==6.1
tqdm==4.63.0
transformers==4.16.2
typer==0.4.0
typing_extensions==4.1.1
typish==1.9.3
urllib3==1.26.8
wasabi==0.9.0
Werkzeug==2.0.3
wrapt==1.14.0
xgboost==1.4.2
yacs==0.1.8
yarl==1.7.2
zict==2.1.0
zipp==3.7.0