はじめに
GluonTS 0.4.2が公開されているので過去のコードで動作確認してみた。
環境
Windows10 Pro NVIDIA GeForce GTX1080 CUDA 10.1 Python 3.6.8
GluonTSのインストール
pipでGluonTSをインストールした。
バージョン確認
boto3==1.10.28 botocore==1.13.28 certifi==2019.11.28 chardet==3.0.4 cycler==0.10.0 dataclasses==0.7 docutils==0.15.2 gluonts==0.4.2 graphviz==0.8.4 holidays==0.9.11 idna==2.6 jmespath==0.9.4 kiwisolver==1.1.0 matplotlib==3.1.2 mxnet-cu101==1.6.0b20191125 numpy==1.16.5 pandas==0.25.3 pydantic==1.2 pyparsing==2.4.5 python-dateutil==2.8.0 pytz==2019.3 requests==2.18.4 s3transfer==0.2.1 six==1.13.0 tqdm==4.39.0 ujson==1.35 urllib3==1.22
サンプルコード
以下のコードが問題なく実行できた。
import numpy as np from matplotlib import pyplot as plt from gluonts.dataset.common import ListDataset from gluonts.model.deepar import DeepAREstimator from gluonts.distribution.multivariate_gaussian import MultivariateGaussianOutput from gluonts.trainer import Trainer N = 20 # number of time series T = 100 # number of timesteps prediction_length = 10 freq = '1H' custom_datasetx = np.random.normal(size=(N, 2, T)) custom_datasetx[:,1,:] = custom_datasetx[:,1,:]*10 train_ds = ListDataset( [ {'target': x, 'start': '2019-01-01'} for x in custom_datasetx[0:19, :, :] ], freq=freq, one_dim_target=False, ) test_ds = ListDataset( [ {'target': x, 'start': '2019-01-01'} for x in custom_datasetx[19:, :, :] ], freq=freq, one_dim_target=False, ) estimator = DeepAREstimator( prediction_length=prediction_length, freq=freq, trainer=Trainer(epochs=5), distr_output=MultivariateGaussianOutput(dim=2), ) predictor = estimator.train(train_ds) from gluonts.evaluation.backtest import make_evaluation_predictions forecast_it, ts_it = make_evaluation_predictions( dataset=test_ds, predictor=predictor, num_samples=100 ) for x, y in zip(ts_it, forecast_it): for i in range(2): plt.subplot(2,1,i+1) x[i].plot() y.copy_dim(i).plot(color='g', prediction_intervals=(50.0, 90.0)) plt.show()
追記
Windows7にもインストールしてみた。問題なく動作している。
Windows 7 Professional GPUなし Python 3.7.4
boto3==1.10.28 botocore==1.13.28 certifi==2019.11.28 chardet==3.0.4 cycler==0.10.0 docutils==0.15.2 gluonts==0.4.2 graphviz==0.8.4 holidays==0.9.11 idna==2.6 jmespath==0.9.4 kiwisolver==1.1.0 matplotlib==3.1.2 mxnet==1.6.0b20191125 numpy==1.16.5 pandas==0.25.3 pydantic==1.2 pyparsing==2.4.5 python-dateutil==2.8.0 pytz==2019.3 requests==2.18.4 s3transfer==0.2.1 six==1.13.0 tqdm==4.39.0 ujson==1.35 urllib3==1.22