Ml-tips

【DDQN】stable-baselinesで始める強化学習

やること

stable-baselinesを使ってDDQNを動かします。環境としてはGymのCartPoleを利用していきます。

stable-baselinesのインストールはこちらの記事から。

サンプルコード

学習

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

env = gym.make('CartPole-v1')

model = DQN(MlpPolicy, env, verbose=1, tensorboard_log="log") #tensorboard_logはデフォルトではNone
model.learn(total_timesteps=15000)

モデル保存

model.save("deepq_cartpole")

モデル読み込み

model = DQN.load("deepq_cartpole")

モデルを試す

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    if dones:
        break

右にいくように学習してしまってますね。

学習結果を確認してみる

$ tensorboard --logdir log

その他アルゴリズム

別記事ではその他のアルゴリズムも動かしていますのでぜひご覧ください!

  • A2C
  • ACER
  • ACKTR
  • DDPG
  • DDQN(この記事)
  • GAIL
  • HER
  • PPO
  • SAC
  • TD3
  • TRPO

参考