CartPole

最終更新: 2017-08-17 20:22

環境を整える

pip install gym

ランダム戦略（線形方策の重みをランダムに探索する）

import gym
import numpy as np
import matplotlib.pyplot as plt

def run_episode(env, parameters):
    """Run one episode with a linear threshold policy and return its reward.

    At every step the action is 0 when the product of ``parameters`` and the
    current observation is negative, and 1 otherwise. The episode stops as
    soon as the environment reports ``done`` or after 200 steps, whichever
    comes first.

    Args:
        env: Environment object whose ``reset()`` returns the initial
            observation and whose ``step(action)`` returns the 4-tuple
            ``(observation, reward, done, info)``.
        parameters: Weight vector combined with the observation through
            ``np.matmul``; it must be shape-compatible with the observation.

    Returns:
        The sum of the rewards returned by ``env.step`` during the episode.
    """
    observation = env.reset()
    totalreward = 0
    # ``range`` rather than the Python 2-only ``xrange``: the latter raises
    # NameError on Python 3, and for 200 iterations the two are equivalent.
    for _ in range(200):
        # The sign of the weighted observation selects one of two actions.
        action = 0 if np.matmul(parameters, observation) < 0 else 1
        observation, reward, done, info = env.step(action)
        totalreward += reward
        if done:
            break
    return totalreward

# Random search: repeatedly sample a weight vector, score it with one
# episode, and remember the best-scoring weights seen so far.
env = gym.make('CartPole-v0')

bestparams = None
bestreward = 0
x = []  # trial index of every attempt
y = []  # best reward observed up to and including that attempt
# ``range`` rather than the Python 2-only ``xrange`` so the script also runs
# on Python 3; the loop exits early, so the range is never fully consumed.
for i in range(10000):
    # np.random.rand yields values in [0, 1); rescale them to [-1, 1).
    parameters = np.random.rand(4) * 2 - 1
    reward = run_episode(env, parameters)
    if reward > bestreward:
        bestreward = reward
        bestparams = parameters
    x.append(i)
    y.append(bestreward)
    # run_episode caps an episode at 200 steps, so a total of 200 is
    # presumably the maximum score (assuming a reward of 1 per step).
    if reward == 200:
        break
print(x, y)
# Draw the best-so-far curve as a line, then overlay the points as red dots.
plt.plot(x, y)
plt.plot(x, y, 'ro')
plt.show()