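Introduction
A new version of GluonTS has been released, so I re-ran one of my earlier scripts to check that it still works. The script I used is the one below. MXNet is not installed; PyTorch is used instead.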
touch-sp.hatenablog.com
Environment
Windows 11
Ubuntu 20.04 on WSL2
Python 3.8.10
CUDA Toolkit 11.3
Installing GluonTS
Everything can be installed with pip:
pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install pytorch_lightning
pip install gluonts
pip install autogluon.core
autogluon.core is not actually required. I install it only because it makes downloading and extracting ZIP files convenient.
See here for details.
The installed versions are listed further below.
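Because autogluon.core is used here only to fetch and unpack ZIP files, the same job can probably be done with the Python standard library if you would rather not install it. The following is a minimal sketch under that assumption. `download_and_extract` is a hypothetical helper name, not part of autogluon.core or GluonTS, and it is not the code from the linked script.

```python
import shutil
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path


def download_and_extract(url: str, dest_dir: str) -> list:
    """Download the ZIP archive at `url` and extract it into `dest_dir`.

    Returns the member names stored in the archive.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)

    # Keep the file name from the URL path; fall back to a generic name.
    file_name = Path(urllib.parse.urlparse(url).path).name or "download.zip"
    archive_path = dest / file_name

    # Stream the response to disk instead of loading it all into memory.
    with urllib.request.urlopen(url) as response, open(archive_path, "wb") as out:
        shutil.copyfileobj(response, out)

    # Unpack every member next to the downloaded archive.
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(dest)
        return zf.namelist()
```

`extractall` trusts the member paths stored in the archive, so this sketch should only be used on ZIP files from sources you trust.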
Results
It ran without any problems. The output looks different from before.
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type                  | Params
------------------------------------------------
0 | model | DeepARModel           | 28.8 K
1 | loss  | NegativeLogLikelihood | 0
------------------------------------------------
28.8 K    Trainable params
0         Non-trainable params
28.8 K    Total params
0.115     Total estimated model params size (MB)
Epoch 0: : 51it [00:01, 35.50it/s, loss=8.59, v_num=1, train_loss=9.080]
Epoch 0, global step 49: train_loss reached 9.08480 (best 9.08480), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=0-step=49.ckpt" as top 1
Epoch 1: : 51it [00:01, 38.99it/s, loss=8.24, v_num=1, train_loss=8.270]
Epoch 1, global step 99: train_loss reached 8.27404 (best 8.27404), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=1-step=99.ckpt" as top 1
Epoch 2: : 51it [00:01, 37.57it/s, loss=8.11, v_num=1, train_loss=8.150]
Epoch 2, global step 149: train_loss reached 8.15041 (best 8.15041), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=2-step=149.ckpt" as top 1
Epoch 3: : 51it [00:01, 31.27it/s, loss=8.02, v_num=1, train_loss=8.060]
Epoch 3, global step 199: train_loss reached 8.05711 (best 8.05711), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=3-step=199.ckpt" as top 1
・
・
・
Epoch 46: : 51it [00:01, 36.20it/s, loss=6.77, v_num=1, train_loss=6.790]
Epoch 46, global step 2349: train_loss reached 6.78539 (best 6.78539), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=46-step=2349.ckpt" as top 1
Epoch 47: : 51it [00:01, 38.77it/s, loss=6.76, v_num=1, train_loss=6.770]
Epoch 47, global step 2399: train_loss reached 6.76617 (best 6.76617), saving model to "/mnt/wsl/PHYSICALDRIVE1p1/works/lightning_logs/version_1/checkpoints/epoch=47-step=2399.ckpt" as top 1
Epoch 48: : 51it [00:01, 33.46it/s, loss=6.79, v_num=1, train_loss=6.780]
Epoch 48, global step 2449: train_loss was not in top 1
Epoch 49: : 51it [00:01, 34.68it/s, loss=6.83, v_num=1, train_loss=6.810]
Epoch 49, global step 2499: train_loss was not in top 1
Epoch 49: : 51it [00:01, 34.65it/s, loss=6.83, v_num=1, train_loss=6.810]
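The "was not in top 1" lines at epochs 48 and 49 follow from the log's own messages: only the single best checkpoint by `train_loss` is kept, and a new one is written only when the loss drops below the best value so far. The sketch below replays that rule on the values visible in the log (epochs 4 to 45 are omitted, as in the log, and the epoch 48 and 49 values are the rounded progress-bar numbers). It is a simplified illustration assuming a "lower is better, keep one" policy, not PyTorch Lightning's actual `ModelCheckpoint` code; `top1_checkpoint_events` is a hypothetical name.

```python
def top1_checkpoint_events(losses):
    """Replay a keep-only-the-best (top 1, lower is better) checkpoint rule.

    `losses` maps epoch -> monitored train_loss. Returns a list of
    (epoch, saved, best_so_far) tuples in epoch order.
    """
    best = float("inf")
    events = []
    for epoch, loss in sorted(losses.items()):
        # Only a strict improvement replaces the single kept checkpoint.
        saved = loss < best
        if saved:
            best = loss
        events.append((epoch, saved, best))
    return events


# train_loss values visible in the log; epochs 48 and 49 use the rounded
# progress-bar values because the log does not print them in full.
log_losses = {
    0: 9.08480, 1: 8.27404, 2: 8.15041, 3: 8.05711,
    46: 6.78539, 47: 6.76617, 48: 6.780, 49: 6.810,
}

for epoch, saved, best in top1_checkpoint_events(log_losses):
    status = "saving model as top 1" if saved else "was not in top 1"
    print(f"Epoch {epoch}: train_loss {status} (best {best:.5f})")
```

Epochs 48 and 49 both end above the epoch 47 loss of 6.76617, so the checkpoint from epoch 47 remains the one kept.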
Versions
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
async-timeout==4.0.2
attrs==21.4.0
autogluon.core==0.3.1
autograd==1.3
bcrypt==3.2.0
boto3==1.21.3
botocore==1.24.3
cachetools==5.0.0
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.12
click==8.0.4
cloudpickle==2.0.0
ConfigSpace==0.4.19
convertdate==2.4.0
cryptography==36.0.1
cycler==0.11.0
Cython==0.29.28
dask==2022.2.0
dill==0.3.4
distributed==2022.2.0
fonttools==4.29.1
frozenlist==1.3.0
fsspec==2022.1.0
future==0.18.2
gluonts==0.9.0
google-auth==2.6.0
google-auth-oauthlib==0.4.6
graphviz==0.19.1
grpcio==1.44.0
HeapDict==1.0.1
hijri-converter==2.2.3
holidays==0.13
idna==3.3
importlib-metadata==4.11.1
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
kiwisolver==1.3.2
korean-lunar-calendar==0.2.1
locket==0.2.1
Markdown==3.3.6
MarkupSafe==2.1.0
matplotlib==3.5.1
msgpack==1.0.3
multidict==6.0.2
numpy==1.21.5
oauthlib==3.2.0
packaging==21.3
pandas==1.4.1
paramiko==2.9.2
partd==1.2.0
Pillow==9.0.1
pkg_resources==0.0.0
protobuf==3.19.4
psutil==5.9.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pydantic==1.9.0
pyDeprecate==0.3.1
PyMeeus==0.5.11
PyNaCl==1.5.0
pyparsing==3.0.7
python-dateutil==2.8.2
pytorch-lightning==1.5.10
pytz==2021.3
PyYAML==6.0
requests==2.27.1
requests-oauthlib==1.3.1
rsa==4.8
s3transfer==0.5.1
scikit-learn==0.24.2
scipy==1.6.3
six==1.16.0
sortedcontainers==2.4.0
tblib==1.7.0
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.1.0
toolz==0.11.2
torch==1.10.2+cu113
torchmetrics==0.7.2
tornado==6.1
tqdm==4.62.3
typing_extensions==4.1.1
urllib3==1.26.8
Werkzeug==2.0.3
yarl==1.7.2
zict==2.0.0
zipp==3.7.0
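To compare your own environment with the list above, the versions of the key packages can also be read from inside Python. This is a minimal sketch using the standard-library `importlib.metadata` (available from Python 3.8, the version used here). `installed_versions` is a hypothetical helper, and the four names passed to it are simply the packages installed explicitly with pip in this article.

```python
from importlib import metadata  # standard library since Python 3.8


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# The four packages installed explicitly with pip in this article.
for name, version in installed_versions(
    ["torch", "pytorch-lightning", "gluonts", "autogluon.core"]
).items():
    print(f"{name}=={version}" if version else f"{name}: not installed")
```

The printed `name==version` lines use the same format as the list above, which makes a side-by-side comparison easy.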