とりあえずネットから情報を集めて実行可能なスクリプトを作った。
詳細はまだ理解していない。
もともと開発者がサンプルスクリプトを用意してくれていたらそれで良かったが。
そのあたりが全くユーザーフレンドリーでなく、それが理由でMXNetやGluonが全然流行らないのではないかと思う。
今回使ったLSTNetは
- 設定するパラメータが多く扱いが難しそう
- 「MultivariateGrouper」の意味がいまいちわからない
一応スクリプトを残しておく。
import numpy as np import pandas as pd from matplotlib import pyplot as plt NUM_SERIES = 10 NUM_TIMESTEPS = 48 freq = '1H' custom_dataset = np.random.normal(size=(NUM_SERIES, NUM_TIMESTEPS)) for i in range(10): custom_dataset[i] = custom_dataset[i] * (i + 1) from gluonts.dataset.common import ListDataset train_ds = ListDataset( [ {'target': x, 'start': "01-01-2019"} for x in custom_dataset[:46] ], freq = freq ) test_ds = ListDataset( [ {'target': x, 'start': "01-01-2019"} for x in custom_dataset ], freq = freq ) from gluonts.dataset.multivariate_grouper import MultivariateGrouper grouper_train = MultivariateGrouper() train = grouper_train(train_ds) test = grouper_train(test_ds) skip_size = 2 ar_window = 3 lead_time = 0 prediction_length = 4 hybridize = True scaling = True from gluonts.model.lstnet import LSTNetEstimator from gluonts.trainer import Trainer estimator = LSTNetEstimator( skip_size=skip_size, ar_window=ar_window, num_series=NUM_SERIES, channels=6, kernel_size=2, context_length=4, freq=freq, lead_time=lead_time, prediction_length=prediction_length, trainer=Trainer( epochs=10, batch_size=2, learning_rate=0.01, hybridize=hybridize ), scaling=scaling ) predictor = estimator.train(train) from gluonts.evaluation.backtest import make_evaluation_predictions forecast_it, ts_it = make_evaluation_predictions( dataset = test, predictor = predictor, num_samples = 100 ) for x, y in zip(ts_it, forecast_it): for i in range(NUM_SERIES): plt.subplot(NUM_SERIES, 1, i+1) x[i].plot() y.copy_dim(i).plot(color='g', prediction_intervals=(50.0, 90.0)) # 軸の一覧取得 axs = plt.gcf().get_axes() # ループ for ax in axs: ax.axes.xaxis.set_visible(False) ax.axes.yaxis.set_visible(False) plt.show()
結果
環境
certifi==2020.6.20 chardet==3.0.4 cycler==0.10.0 gluonts==0.5.2 graphviz==0.8.4 holidays==0.9.12 idna==2.6 kiwisolver==1.2.0 matplotlib==3.3.2 mxnet-cu101==1.7.0 numpy==1.16.6 pandas==1.0.5 Pillow==7.2.0 pydantic==1.6.1 pyparsing==2.4.7 python-dateutil==2.8.1 pytz==2020.1 requests==2.18.4 six==1.15.0 tqdm==4.50.0 ujson==1.35 urllib3==1.22