公開日:2022年9月8日
最終更新日:2022年10月13日
はじめに
GluonTSの「TS」は「Time Series」の略です。さまざまな時系列予測モデルが使用できるツールです。最新では v0.11.0 が公開されており、「PandasDataset」を使うとPandasデータフレームからのデータ作成が以前より簡単になります。今回、以前のスクリプトを書き換えてみました。touch-sp.hatenablog.com
PC環境
Windows 11 (build: 22000.918) NVIDIA Driver 516.94 CUDA 11.6.2
Python環境
python 3.9.13
pip install mxnet-cu116openblas-1.9.1-py3-none-win_amd64.whl pip install gluonts[pro] pip install matplotlib
「mxnet-cu116openblas-1.9.1-py3-none-win_amd64.whl」は自分でソースからビルドしたものです。
作り方はこちらを見て下さい。
Pythonスクリプト
使用したデータと「feat_dynamic_real」についてはこちらで説明しています。そちらを参照して下さい。
import zipfile import pandas as pd from matplotlib import pyplot as plt import mxnet as mx from mxnet.gluon.utils import download from gluonts.dataset.pandas import PandasDataset from gluonts.mx import Trainer, DeepAREstimator from gluonts.evaluation import make_evaluation_predictions ctx = mx.gpu() if mx.context.num_gpus() >0 else mx.cpu() url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip' zip_fname = download(url) with zipfile.ZipFile(zip_fname) as f: f.extractall('.') df = pd.read_csv('day.csv',index_col=1) df_onehot = pd.get_dummies(df, columns=['weekday', 'workingday', 'weathersit', 'season']) feat_dynamic_cat = sorted(list(set(df_onehot.columns) - set(df.columns))) train_df = df_onehot[:-14] train_ds = PandasDataset( dataframes = df_onehot[:-14], target = 'cnt', feat_dynamic_real = (['hum', 'temp', 'windspeed'] + feat_dynamic_cat), ) test_ds = PandasDataset( dataframes = df_onehot, target = 'cnt', feat_dynamic_real = (['hum', 'temp', 'windspeed'] + feat_dynamic_cat), ) deepar = DeepAREstimator( prediction_length = 14, freq = 'D', use_feat_dynamic_real = True, trainer=Trainer(epochs=50, ctx = ctx)) predictor = deepar.train(train_ds) forecast_it, ts_it = make_evaluation_predictions(dataset=test_ds, predictor=predictor) forecast = next(iter(forecast_it)) ts = next(iter(ts_it)) plot_length = 50 legend = ["true values", "prediction"] ts[-plot_length:].plot(color='blue') forecast.plot(color='green') plt.grid(which='both') plt.legend(legend, loc='upper left') plt.show()
その他
学習済みモデルの保存
学習済みモデルを保存するにはpickleを使用します。保存
import pickle with open('predictor.pkl', 'wb') as f: pickle.dump(predictor, f)
読み込み
import pickle with open('predictor.pkl', 'rb') as f: predictor = pickle.load(f)
将来予測
feat_dynamic_realは将来までデータがあるが、ターゲット変数にはデータがない状況でどのように将来予測をするのか?「make_evaluation_predictions」を使う方法
どういったデータをpredictorに渡せばよいか?feat_dynamic_realは将来までデータがあるが、ターゲット変数にはデータがない状況を作ってみます。df = pd.read_csv('day.csv',index_col=1) df_onehot = pd.get_dummies(df, columns=['weekday', 'workingday', 'weathersit', 'season']) feat_dynamic_cat = list(set(df_onehot.columns) - set(df.columns)) df_onehot.cnt = df_onehot.cnt[:-14]
これだけでターゲット変数の最後の2週間が「NaN」に代わります。
これで同じグラフが描けたらうまくいっていることになります。
test_ds = PandasDataset( dataframes = df_onehot, target = 'cnt', feat_dynamic_real = (['hum', 'temp', 'windspeed'] + feat_dynamic_cat), ) import pickle with open('predictor.pkl', 'rb') as f: predictor = pickle.load(f) from gluonts.evaluation import make_evaluation_predictions forecast_it, ts_it = make_evaluation_predictions(dataset=test_ds, predictor=predictor) forecast = next(iter(forecast_it)) ts = next(iter(ts_it)) plot_length = 50 legend = ["true values", "prediction"] ts[-plot_length:].plot(color='blue') forecast.plot(color='green') plt.grid(which='both') plt.legend(legend, loc='upper left') plt.show()
うまく予測できています。
「predict」を使う方法
「ListDataset」を使って新たなテストデータを作りました。feat_dynamic_real = (['hum', 'temp', 'windspeed'] + feat_dynamic_cat) from gluonts.dataset.common import ListDataset test_ds = ListDataset( [{ "start": df_onehot.index[0], "target": df_onehot.cnt[:-14], "feat_dynamic_real":df_onehot[feat_dynamic_real].to_numpy().T, }], freq = "D" ) forecast_it = predictor.predict(test_ds) forecast = next(iter(forecast_it)) plot_length = 50 - 14 legend = ["true values", "prediction"] from gluonts.dataset.util import to_pandas ts = to_pandas(next(iter(test_ds))) ts[-plot_length:].plot(color='blue') forecast.plot(color='green') plt.grid(which='both') plt.legend(legend, loc='upper left') plt.show()
うまく予測できています。
Dev版 GluonTS
pip install git+https://github.com/awslabs/gluonts.git
これで開発中のDev版 GluonTSがインストール可能です。