GluonTS 0.9 has been released

Introduction

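A new version of GluonTS has been released, so I re-ran one of my earlier scripts to check that it still works.

This is the script I used.
I did not install MXNet; the script runs on PyTorch instead.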
touch-sp.hatenablog.com

Environment

Windows 11
Ubuntu 20.04 on WSL2

Python 3.8.10
CUDA Toolkit 11.3

Installing GluonTS

Everything can be installed with pip.

pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install pytorch_lightning
pip install gluonts
pip install autogluon.core
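As a quick sanity check after installation (not part of the original script), the installed versions and whether PyTorch can see the GPU can be printed. This is a minimal sketch assuming Python 3.8+ for importlib.metadata; the package names are simply the ones passed to pip above.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names as passed to pip in the commands above.
PACKAGES = ("torch", "pytorch_lightning", "gluonts", "autogluon.core")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def cuda_available() -> Optional[bool]:
    """True/False depending on whether PyTorch sees a CUDA GPU; None if torch is not importable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.cuda.is_available()


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:20} {version or 'not installed'}")
    print("CUDA available:", cuda_available())
```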

autogluon.core is not actually required, but I use it because it is handy for downloading and extracting ZIP files.
See here for details.
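If you would rather not install autogluon.core just for this, the download-and-extract step can also be done with the standard library. This is a swapped-in alternative, not the autogluon.core API: a minimal sketch using urllib.request and zipfile, where the function names and the URL in the usage comment are hypothetical placeholders.

```python
import os
import urllib.request
import zipfile
from pathlib import Path
from typing import List


def download(url: str, dest_dir: str) -> Path:
    """Download url into dest_dir (skipped if the file already exists) and return its path."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    target = dest / os.path.basename(url.split("?")[0])
    if not target.exists():
        urllib.request.urlretrieve(url, str(target))
    return target


def extract_zip(zip_path: Path, dest_dir: str) -> List[str]:
    """Extract every member of zip_path into dest_dir and return the member names."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
        return zf.namelist()


# Usage (hypothetical URL):
# archive = download("https://example.com/data.zip", "data")
# extract_zip(archive, "data")
```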


The installed versions are listed at the end of this article.

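Results

It ran without any problems.

The output looks different from before: training progress is now reported by PyTorch Lightning. Note the first line: a GPU was detected but not used for this run.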

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type                  | Params
------------------------------------------------
0 | model | DeepARModel           | 28.8 K
1 | loss  | NegativeLogLikelihood | 0
------------------------------------------------
28.8 K    Trainable params
0         Non-trainable params
28.8 K    Total params
0.115     Total estimated model params size (MB)
Epoch 0: : 51it [00:01, 35.50it/s, loss=8.59, v_num=1, train_loss=9.080]Epoch 0, global step 49: train_loss reached 9.08480 (best 9.08480), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=0-step=49.ckpt" as top 1
Epoch 1: : 51it [00:01, 38.99it/s, loss=8.24, v_num=1, train_loss=8.270]Epoch 1, global step 99: train_loss reached 8.27404 (best 8.27404), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=1-step=99.ckpt" as top 1
Epoch 2: : 51it [00:01, 37.57it/s, loss=8.11, v_num=1, train_loss=8.150]Epoch 2, global step 149: train_loss reached 8.15041 (best 8.15041), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=2-step=149.ckpt" as top 1
Epoch 3: : 51it [00:01, 31.27it/s, loss=8.02, v_num=1, train_loss=8.060]Epoch 3, global step 199: train_loss reached 8.05711 (best 8.05711), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=3-step=199.ckpt" as top 1
・
・
・
Epoch 46: : 51it [00:01, 36.20it/s, loss=6.77, v_num=1, train_loss=6.790]Epoch 46, global step 2349: train_loss reached 6.78539 (best 6.78539), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=46-step=2349.ckpt" as top 1
Epoch 47: : 51it [00:01, 38.77it/s, loss=6.76, v_num=1, train_loss=6.770]Epoch 47, global step 2399: train_loss reached 6.76617 (best 6.76617), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=47-step=2399.ckpt" as top 1
Epoch 48: : 51it [00:01, 33.46it/s, loss=6.79, v_num=1, train_loss=6.780]Epoch 48, global step 2449: train_loss was not in top 1
Epoch 49: : 51it [00:01, 34.68it/s, loss=6.83, v_num=1, train_loss=6.810]Epoch 49, global step 2499: train_loss was not in top 1
Epoch 49: : 51it [00:01, 34.65it/s, loss=6.83, v_num=1, train_loss=6.810]
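The checkpoint messages appear to come from PyTorch Lightning's ModelCheckpoint monitoring train_loss and keeping only the top 1: a checkpoint is written only when train_loss improves, otherwise "was not in top 1" is printed. The global step also advances by 50 per epoch (49 at epoch 0, 99 at epoch 1, ...), and the 0.115 MB model size is just 28.8 K parameters × 4 bytes (float32). A small parser like the one below can pull the loss curve out of such a log; it is a sketch whose regular expression is written against the lines shown above, not an official log format.

```python
import re
from typing import List, NamedTuple, Optional

# Matches "Epoch 3, global step 199: train_loss reached 8.05711 (best 8.05711), ..."
# and "Epoch 48, global step 2449: train_loss was not in top 1".
CKPT_RE = re.compile(
    r"Epoch (?P<epoch>\d+), global step (?P<step>\d+): "
    r"train_loss (?:reached (?P<loss>[0-9.]+)|was not in top 1)"
)


class CheckpointEvent(NamedTuple):
    epoch: int
    step: int
    loss: Optional[float]  # None when the epoch did not beat the best train_loss

    @property
    def saved(self) -> bool:
        """A checkpoint was written only when a new best loss was reached."""
        return self.loss is not None


def parse_log(text: str) -> List[CheckpointEvent]:
    """Extract one CheckpointEvent per ModelCheckpoint message in a training log."""
    events = []
    for m in CKPT_RE.finditer(text):
        loss = m.group("loss")
        events.append(
            CheckpointEvent(int(m.group("epoch")), int(m.group("step")),
                            float(loss) if loss else None)
        )
    return events
```

Run over the full log, the last saved checkpoint is the one from epoch 47, since epochs 48 and 49 did not improve train_loss.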


Versions

absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
async-timeout==4.0.2
attrs==21.4.0
autogluon.core==0.3.1
autograd==1.3
bcrypt==3.2.0
boto3==1.21.3
botocore==1.24.3
cachetools==5.0.0
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.12
click==8.0.4
cloudpickle==2.0.0
ConfigSpace==0.4.19
convertdate==2.4.0
cryptography==36.0.1
cycler==0.11.0
Cython==0.29.28
dask==2022.2.0
dill==0.3.4
distributed==2022.2.0
fonttools==4.29.1
frozenlist==1.3.0
fsspec==2022.1.0
future==0.18.2
gluonts==0.9.0
google-auth==2.6.0
google-auth-oauthlib==0.4.6
graphviz==0.19.1
grpcio==1.44.0
HeapDict==1.0.1
hijri-converter==2.2.3
holidays==0.13
idna==3.3
importlib-metadata==4.11.1
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
kiwisolver==1.3.2
korean-lunar-calendar==0.2.1
locket==0.2.1
Markdown==3.3.6
MarkupSafe==2.1.0
matplotlib==3.5.1
msgpack==1.0.3
multidict==6.0.2
numpy==1.21.5
oauthlib==3.2.0
packaging==21.3
pandas==1.4.1
paramiko==2.9.2
partd==1.2.0
Pillow==9.0.1
pkg_resources==0.0.0
protobuf==3.19.4
psutil==5.9.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pydantic==1.9.0
pyDeprecate==0.3.1
PyMeeus==0.5.11
PyNaCl==1.5.0
pyparsing==3.0.7
python-dateutil==2.8.2
pytorch-lightning==1.5.10
pytz==2021.3
PyYAML==6.0
requests==2.27.1
requests-oauthlib==1.3.1
rsa==4.8
s3transfer==0.5.1
scikit-learn==0.24.2
scipy==1.6.3
six==1.16.0
sortedcontainers==2.4.0
tblib==1.7.0
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.1.0
toolz==0.11.2
torch==1.10.2+cu113
torchmetrics==0.7.2
tornado==6.1
tqdm==4.62.3
typing_extensions==4.1.1
urllib3==1.26.8
Werkzeug==2.0.3
yarl==1.7.2
zict==2.0.0
zipp==3.7.0
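The list above is in pip freeze format (name==version). To check another environment against it, the pins can be parsed and compared with what is installed. This is a minimal standard-library sketch: only a few pins are copied in, and names are compared after PEP 503 normalization so that, for example, pytorch-lightning and pytorch_lightning match.

```python
import re
from importlib import metadata
from typing import Dict, Optional

# A few pins copied from the list above; paste the whole list to check everything.
PINNED = """\
gluonts==0.9.0
pytorch-lightning==1.5.10
torch==1.10.2+cu113
autogluon.core==0.3.1
"""


def normalize(name: str) -> str:
    """PEP 503 normalization: lowercase, and runs of -, _ and . become a single -."""
    return re.sub(r"[-_.]+", "-", name).lower()


def parse_pins(text: str) -> Dict[str, str]:
    """Parse 'name==version' lines into {normalized name: version}, skipping blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[normalize(name)] = version.strip()
    return pins


def mismatches(pins: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {name: installed version or None} for every pin not matched exactly."""
    installed = {
        normalize(d.metadata["Name"]): d.version
        for d in metadata.distributions()
        if d.metadata["Name"]
    }
    return {name: installed.get(name) for name, want in pins.items()
            if installed.get(name) != want}


if __name__ == "__main__":
    pins = parse_pins(PINNED)
    for name, have in mismatches(pins).items():
        print(f"{name}: expected {pins[name]}, found {have or 'not installed'}")
```

The pkg_resources==0.0.0 line is a known artifact of pip freeze under Ubuntu's Debian-patched pip rather than a real package, so it is best left out of such a check.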