GluonTS の LSTNet を使ってみる

とりあえずネットから情報を集めて実行可能なスクリプトを作った。
詳細はまだ理解していない。
もともと開発者がサンプルスクリプトを用意してくれていたらそれで良かったが。
そのあたりが全くユーザーフレンドリーでなく、それが理由でMXNetやGluonが全然流行らないのではないかと思う。

今回使ったLSTNetは

  • 設定するパラメータが多く扱いが難しそう
  • 「MultivariateGrouper」の意味がいまいちわからない

一応スクリプトを残しておく。

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

NUM_SERIES = 10
NUM_TIMESTEPS = 48
freq = '1H'

custom_dataset = np.random.normal(size=(NUM_SERIES, NUM_TIMESTEPS))
for i in range(10):
    custom_dataset[i] = custom_dataset[i] * (i + 1)

from gluonts.dataset.common import ListDataset

train_ds = ListDataset(
    [
        {'target': x, 'start': "01-01-2019"}
        for x in custom_dataset[:46]
    ],
    freq = freq
)

test_ds = ListDataset(
    [
        {'target': x, 'start': "01-01-2019"}
        for x in custom_dataset
    ],
    freq = freq
)

from gluonts.dataset.multivariate_grouper import MultivariateGrouper

grouper_train = MultivariateGrouper()
train = grouper_train(train_ds)
test  = grouper_train(test_ds)

skip_size = 2
ar_window = 3
lead_time = 0
prediction_length = 4
hybridize = True
scaling = True

from gluonts.model.lstnet import LSTNetEstimator
from gluonts.trainer import Trainer

estimator = LSTNetEstimator(
    skip_size=skip_size,
    ar_window=ar_window,
    num_series=NUM_SERIES,
    channels=6,
    kernel_size=2,
    context_length=4,
    freq=freq,
    lead_time=lead_time,
    prediction_length=prediction_length,
    trainer=Trainer(
        epochs=10, batch_size=2, learning_rate=0.01, hybridize=hybridize
    ),
    scaling=scaling
    )

predictor = estimator.train(train)

from gluonts.evaluation.backtest import make_evaluation_predictions

forecast_it, ts_it = make_evaluation_predictions(
    dataset = test, 
    predictor = predictor, 
    num_samples = 100
    )

for x, y  in zip(ts_it, forecast_it):
    for i in range(NUM_SERIES):
        plt.subplot(NUM_SERIES, 1, i+1)
        x[i].plot()
        y.copy_dim(i).plot(color='g', prediction_intervals=(50.0, 90.0))

# 軸の一覧取得
axs = plt.gcf().get_axes()

# ループ
for ax in axs:
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)

plt.show()

結果

f:id:touch-sp:20201004222230p:plain

環境

certifi==2020.6.20
chardet==3.0.4
cycler==0.10.0
gluonts==0.5.2
graphviz==0.8.4
holidays==0.9.12
idna==2.6
kiwisolver==1.2.0
matplotlib==3.3.2
mxnet-cu101==1.7.0
numpy==1.16.6
pandas==1.0.5
Pillow==7.2.0
pydantic==1.6.1
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.1
requests==2.18.4
six==1.15.0
tqdm==4.50.0
ujson==1.35
urllib3==1.22