多変量時系列・GluonTSの動作確認(2019年11月29日)

はじめに

GluonTS 0.4.2が公開されているので過去のコードで動作確認してみた。

環境

Windows10 Pro
NVIDIA GeForce GTX1080
CUDA 10.1
Python 3.6.8

GluonTSのインストール

pipでGluonTSをインストールした。

バージョン確認

boto3==1.10.28
botocore==1.13.28
certifi==2019.11.28
chardet==3.0.4
cycler==0.10.0
dataclasses==0.7
docutils==0.15.2
gluonts==0.4.2
graphviz==0.8.4
holidays==0.9.11
idna==2.6
jmespath==0.9.4
kiwisolver==1.1.0
matplotlib==3.1.2
mxnet-cu101==1.6.0b20191125
numpy==1.16.5
pandas==0.25.3
pydantic==1.2
pyparsing==2.4.5
python-dateutil==2.8.0
pytz==2019.3
requests==2.18.4
s3transfer==0.2.1
six==1.13.0
tqdm==4.39.0
ujson==1.35
urllib3==1.22

サンプルコード

以下のコードが問題なく実行できた。

import numpy as np
from matplotlib import pyplot as plt

from gluonts.dataset.common import ListDataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.distribution.multivariate_gaussian import MultivariateGaussianOutput
from gluonts.trainer import Trainer

N = 20  # number of time series
T = 100  # number of timesteps
prediction_length = 10
freq = '1H'

custom_datasetx = np.random.normal(size=(N, 2, T))
custom_datasetx[:,1,:] = custom_datasetx[:,1,:]*10

train_ds = ListDataset(
    [
        {'target': x, 'start': '2019-01-01'}
        for x in custom_datasetx[0:19, :, :]
    ],
    freq=freq,
    one_dim_target=False,
)

test_ds = ListDataset(
    [
        {'target': x, 'start': '2019-01-01'}
        for x in custom_datasetx[19:, :, :]
    ],
    freq=freq,
    one_dim_target=False,
)

estimator = DeepAREstimator(
    prediction_length=prediction_length,
    freq=freq,
    trainer=Trainer(epochs=5),
    distr_output=MultivariateGaussianOutput(dim=2),
)

predictor = estimator.train(train_ds)

from gluonts.evaluation.backtest import make_evaluation_predictions

forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds,  
    predictor=predictor,  
    num_samples=100
    )

for x, y  in zip(ts_it, forecast_it):
    for i in range(2):
        plt.subplot(2,1,i+1)
        x[i].plot()
        y.copy_dim(i).plot(color='g', prediction_intervals=(50.0, 90.0))

plt.show()

追記

Windows7にもインストールしてみた。問題なく動作している。

Windows 7 Professional
GPUなし
Python 3.7.4
boto3==1.10.28
botocore==1.13.28
certifi==2019.11.28
chardet==3.0.4
cycler==0.10.0
docutils==0.15.2
gluonts==0.4.2
graphviz==0.8.4
holidays==0.9.11
idna==2.6
jmespath==0.9.4
kiwisolver==1.1.0
matplotlib==3.1.2
mxnet==1.6.0b20191125
numpy==1.16.5
pandas==0.25.3
pydantic==1.2
pyparsing==2.4.5
python-dateutil==2.8.0
pytz==2019.3
requests==2.18.4
s3transfer==0.2.1
six==1.13.0
tqdm==4.39.0
ujson==1.35
urllib3==1.22