Deep Reinforcement Learning with MXNet (Double_DQN on CartPole-v0): Improving the Script

Introduction

I previously covered simply running the sample script:
touch-sp.hatenablog.com
This time I modified parts of the script to improve it further.

Environment

Ubuntu 20.04LTS on WSL2
Python 3.8.5

Only three packages were installed into the Python environment: mxnet, matplotlib, and gym. All three can be installed with pip (e.g. pip install mxnet gym matplotlib); the pip freeze output below includes the dependencies they pulled in.

certifi==2020.12.5
chardet==4.0.0
cloudpickle==1.6.0
cycler==0.10.0
graphviz==0.8.4
gym==0.18.3
idna==2.10
kiwisolver==1.3.1
matplotlib==3.4.2
mxnet==1.8.0.post0
numpy==1.20.3
Pillow==8.2.0
pkg-resources==0.0.0
pyglet==1.5.15
pyparsing==2.4.7
python-dateutil==2.8.1
requests==2.25.1
scipy==1.6.3
six==1.16.0
urllib3==1.26.4

Python script for training

import random
import copy 

import gym
import matplotlib.pyplot as plt

import mxnet as mx
from mxnet import gluon, nd, autograd, init
from mxnet.gluon import loss as gloss

from collections import deque

class DoubleQNetwork(gluon.nn.Block):
    def __init__(self, n_action):
        super(DoubleQNetwork, self).__init__()
        self.n_action = n_action

        self.dense0 = gluon.nn.Dense(200, activation='relu')
        self.dense1 = gluon.nn.Dense(200, activation='relu')
        self.dense2 = gluon.nn.Dense(100, activation='relu')
        self.dense3 = gluon.nn.Dense(self.n_action)

    def forward(self, state):
        q_value = self.dense3(self.dense2(self.dense1(self.dense0(state))))
        return q_value

class DoubleDQN:
    def __init__(self,
                 n_action,
                 init_epsilon,
                 final_epsilon,
                 gamma,
                 buffer_size,
                 batch_size,
                 replace_iter,
                 annealing,
                 loss_func,
                 learning_rate,
                 ctx
                 ):
        self.n_action = n_action
        self.epsilon = init_epsilon
        self.init_epsilon = init_epsilon
        self.final_epsilon = final_epsilon
        # discount factor
        self.gamma = gamma
        # memory buffer size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        # replace the parameters of the target network every `replace_iter` episodes
        self.replace_iter = replace_iter
        # number of steps over which epsilon is linearly annealed to its final value
        self.annealing = annealing
        self.loss_func = loss_func
        self.learning_rate = learning_rate
        self.ctx = ctx

        self.total_steps = 0
        self.replay_buffer = deque(maxlen = self.buffer_size) 

        # build the network
        self.target_network = DoubleQNetwork(n_action)
        self.main_network = DoubleQNetwork(n_action)
        self.target_network.collect_params().initialize(init.Xavier(), ctx=ctx)  # initialize the params
        self.main_network.collect_params().initialize(init.Xavier(), ctx=ctx)

        # optimize the main network
        self.optimizer = gluon.Trainer(self.main_network.collect_params(), 'adam',
                                       {'learning_rate': self.learning_rate})

    def choose_action(self, state):
        state = nd.array([state], ctx=self.ctx)
        if nd.random.uniform(0, 1) > self.epsilon:
            # choose the best action
            q_value = self.main_network(state)
            action = int(nd.argmax(q_value, axis=1).asnumpy())
        else:
            # random choice
            action = random.choice(range(self.n_action))
        # anneal
        self.epsilon = max(self.final_epsilon,
                           self.epsilon - (self.init_epsilon - self.final_epsilon) / self.annealing)
        self.total_steps += 1
        return action

    def update(self):
        minibatch = random.sample(self.replay_buffer, self.batch_size)
        state_batch = nd.array([data[0] for data in minibatch], ctx=self.ctx)
        action_batch = nd.array([data[1] for data in minibatch], ctx=self.ctx)
        reward_batch = nd.array([data[2] for data in minibatch], ctx=self.ctx)
        next_state_batch = nd.array([data[3] for data in minibatch], ctx=self.ctx)
        done_batch = nd.array([data[4] for data in minibatch], ctx=self.ctx)

        # Double DQN target: choose the greedy next action with the main network,
        # but take its value from the target network
        all_next_q_value = self.target_network(next_state_batch)
        max_action = nd.argmax(self.main_network(next_state_batch), axis=1)
        target_q_value = nd.pick(all_next_q_value, max_action)

        target_q_value = reward_batch + (1 - done_batch) * self.gamma * target_q_value
    
        with autograd.record():
            # get the Q(s,a)
            all_current_q_value = self.main_network(state_batch)
            main_q_value = nd.pick(all_current_q_value, action_batch)

            # loss between the online estimate Q(s,a) and the fixed target
            value_loss = self.loss_func(main_q_value, target_q_value)
            value_loss.backward()
        self.optimizer.step(batch_size=self.batch_size)

    def replace_parameters(self):
        self.target_network = copy.deepcopy(self.main_network)
        print('Double_DQN parameters replaced')

    def save_parameters(self):
        self.main_network.save_parameters('Double_DQN_main_network_parameters')

if __name__ == '__main__':
    seed = 7777777
    mx.random.seed(seed)
    random.seed(seed)
    ctx = mx.cpu()
    env = gym.make('CartPole-v0').unwrapped
    env.seed(seed)
    render = False
    episodes = 400

    agent = DoubleDQN(n_action=env.action_space.n,
                      init_epsilon=1.0,
                      final_epsilon=0.1,
                      gamma=0.99,
                      buffer_size=3000,
                      batch_size=32,
                      replace_iter=40,
                      annealing=3000,
                      loss_func = gloss.L2Loss(),
                      learning_rate=0.0001,
                      ctx=ctx
                      )

    episode_reward_list = []
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            if render:
                env.render()
            action = agent.choose_action(state)
            next_state, reward, done, info = env.step(action)
            episode_reward += reward
            agent.replay_buffer.append((state, action, reward, next_state, done))
            if len(agent.replay_buffer) > 1000:
                agent.update()
                
            if done:
                print('episode %d ends with reward %d at steps %d' % (episode, episode_reward, agent.total_steps))
                episode_reward_list.append(episode_reward)
                if (episode > 1) and (episode % agent.replace_iter == 0):
                    agent.replace_parameters()
                break
            state = next_state

        if episode_reward > 1000:
            print("early_stopping")
            break

    agent.save_parameters()
    env.close()

    plt.plot(episode_reward_list)
    plt.xlabel('episode')
    plt.ylabel('episode reward')
    plt.title('Double_DQN CartPole-v0')
    plt.savefig('./Double-DQN-CartPole-v0.png')
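
The heart of update() is the Double DQN target: the greedy action for the next state is selected with the main network, while its value is read from the target network. Below is a minimal standalone sketch of just that computation on a toy batch of random numbers; it is not part of the training script, the array names are made up for illustration.

from mxnet import nd

gamma = 0.99
# stand-ins for the two networks' Q-values on a toy batch of 4 next-states, 2 actions
main_q_next = nd.random.uniform(shape=(4, 2))    # Q_main(s', .)
target_q_next = nd.random.uniform(shape=(4, 2))  # Q_target(s', .)
reward = nd.ones(shape=(4,))
done = nd.array([0, 0, 1, 0])

# action selection with the main network ...
max_action = nd.argmax(main_q_next, axis=1)
# ... value estimation with the target network
next_value = nd.pick(target_q_next, max_action)
target = reward + (1 - done) * gamma * next_value
print(target)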

Results

[Figure: episode reward per episode (Double_DQN CartPole-v0), saved as Double-DQN-CartPole-v0.png]
Because there was a point where the reward suddenly jumped up, training is cut off there. The following part of the script handles that:

        if episode_reward > 1000:
            print("early_stopping")
            break


After training, I actually played the game 10 times with the step limit raised to 500. (The script below imports DoubleQNetwork from the training script, which it assumes was saved as Double_DQN.py.)

import mxnet as mx
from Double_DQN import DoubleQNetwork
import gym

ctx = mx.cpu()
env = gym.make('CartPole-v0').unwrapped

agent = DoubleQNetwork(env.action_space.n)
agent.load_parameters('Double_DQN_main_network_parameters', ctx = ctx)

for i in range(10):
    observation = env.reset()
    t = 1
    while True:
        #env.render()
        state = mx.nd.array([observation], ctx = ctx)
        action = int(mx.nd.argmax(agent(state), axis=1).asscalar())
        observation, reward, done, info = env.step(action)
        if done or (t > 499):
            print("Episode{} finished after {} timesteps".format(i, t))
            break
        t += 1
env.close()

Episode0 finished after 500 timesteps
Episode1 finished after 500 timesteps
Episode2 finished after 500 timesteps
Episode3 finished after 500 timesteps
Episode4 finished after 407 timesteps
Episode5 finished after 500 timesteps
Episode6 finished after 500 timesteps
Episode7 finished after 500 timesteps
Episode8 finished after 500 timesteps
Episode9 finished after 500 timesteps

Every run lasted 400 timesteps or more (nine of the ten reached the 500-step cap), so training appears to have gone well.
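
If you would rather have a single summary number than eyeball the ten printouts, a small variant of the same evaluation loop can collect the episode lengths and print their mean. This is only a sketch under the same assumptions as above (the training script saved as Double_DQN.py and the parameter file present).

import mxnet as mx
import gym
from Double_DQN import DoubleQNetwork  # assumes the training script is saved as Double_DQN.py

ctx = mx.cpu()
env = gym.make('CartPole-v0').unwrapped

agent = DoubleQNetwork(env.action_space.n)
agent.load_parameters('Double_DQN_main_network_parameters', ctx=ctx)

episode_lengths = []
for i in range(10):
    observation = env.reset()
    t = 1
    while True:
        state = mx.nd.array([observation], ctx=ctx)
        action = int(mx.nd.argmax(agent(state), axis=1).asscalar())
        observation, reward, done, info = env.step(action)
        if done or (t > 499):  # same 500-step cap as above
            episode_lengths.append(t)
            break
        t += 1
env.close()

print('mean episode length over 10 runs: %.1f' % (sum(episode_lengths) / len(episode_lengths)))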