GluonTSで日経平均を予測してみる

今回やってみたこと

1年前までの株価のデータからその先1年の株価を予測してみた
予測するために過去2年間のデータ(24回分のデータ)を使用するモデルとした

データのダウンロード

日経平均プロフィルのダウンロードセンターから月次データをダウンロードする。
ダウンロードセンター - 日経平均プロフィル
(日次データは土日祝日のデータがなく等間隔の時系列データになっていないため月次データを使用した)

データの前処理

  • 1行目のタイトルを変更する

「データ日付」→「date」
終値」→「value

  • 最終行を削除して「nikkei.csv」という名前で保存する(その際に「utf-8」に変換)

コード

import pandas as pd
import matplotlib
matplotlib.use('Agg') 
from matplotlib import pyplot as plt

df = pd.read_csv('nikkei.csv',index_col=0)
#直近1年分のデータ(12個のデータ)は訓練データに含めない
df_all = df[['value']]
df_train = df[['value']][:-12]

from gluonts.dataset.common import ListDataset
training_data = ListDataset(
    [{"start": df_train.index[0], "target": df_train.value}],
    freq = "1M")

from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer
#次の prediction_length 値を、先行して与えられた context_length 値から予測
estimator = DeepAREstimator(freq="1M", 
                            prediction_length=12, 
                            context_length=24, 
                            trainer=Trainer(epochs=10))
predictor = estimator.train(training_data=training_data)

df_all[100:].plot(linewidth=2)
plt.grid(which='both')
plt.savefig('real.png') 

plt.figure()

from gluonts.dataset.util import to_pandas
for test_entry, forecast in zip(training_data, predictor.predict(training_data)):
    to_pandas(test_entry)[100:].plot(linewidth=2)
    forecast.plot(color='g', prediction_intervals=[50.0, 90.0])
plt.grid(which='both')
plt.savefig('prediction.png')

結果

上が実際の変動を表したグラフ
下が予測
f:id:touch-sp:20190911124408p:plain
f:id:touch-sp:20190911124440p:plain
うまくいったようにも見えるがそれはたまたまであり、実際はこんなに簡単に今後1年の株価の予測ができるわけがない。諸外国の経済状況、日本の金融政策、自然災害など様々な影響因子があるわけでそれらをどうやって予測に反映させるかが問題。
株価予測を突き詰めるつもりはないが、影響を及ぼす因子をどのようにモデルに与えるかといった技術的なところは学習したい。

動作環境

Windows 10 Pro
Python 3.6.8
boto3==1.9.226
botocore==1.12.226
certifi==2019.6.16
chardet==3.0.4
cycler==0.10.0
dataclasses==0.6
docutils==0.15.2
gluonts==0.3.3
graphviz==0.8.4
holidays==0.9.11
idna==2.6
jmespath==0.9.4
kiwisolver==1.1.0
matplotlib==3.1.1
mxnet==1.4.1
numpy==1.14.6
pandas==0.25.1
pydantic==0.28
pyparsing==2.4.2
python-dateutil==2.8.0
pytz==2019.2
requests==2.18.4
s3transfer==0.2.1
six==1.12.0
tqdm==4.35.0
ujson==1.35
urllib3==1.22