

Python Agent.train Method Code Examples

This article collects typical usage examples of the Python method agent.Agent.train, drawn from open-source projects. If you are unsure what Agent.train does, how to call it, or what real-world usage looks like, the curated examples below should help. You can also explore further usage examples of the containing class, agent.Agent.


Five code examples of Agent.train are shown below, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps the system recommend better Python code examples. Note that these examples date from the Python 2 era and use early OpenAI Gym APIs (xrange, print statements, and the legacy env.monitor interface), so run them under matching versions or port them accordingly.

Example 1: main

# Required import: from agent import Agent [as alias]
# Or: from agent.Agent import train [as alias]
import gym

def main(env_name, monitor=True, load=False, seed=0, gpu=-1):

    env = gym.make(env_name)
    view_path = "./video/" + env_name
    model_path = "./model/" + env_name + "_"

    n_st = env.observation_space.shape[0]  # state dimensionality (not used by this Agent constructor)
    n_act = env.action_space.n

    agent = Agent(n_act, seed, gpu)
    if load:
        agent.load_model(model_path)

    if monitor:
        env.monitor.start(view_path, video_callable=None, force=True, seed=seed)
    for i_episode in xrange(10000):
        observation = env.reset()
        agent.reset_state(observation)
        ep_end = False
        q_list = []
        r_list = []
        while not ep_end:
            action = agent.act()
            observation, reward, ep_end, _ = env.step(action)
            agent.update_experience(observation, action, reward, ep_end)
            agent.train()
            q_list.append(agent.Q)
            r_list.append(reward)
            if ep_end:
                agent.save_model(model_path)
                break
        print('%i\t%i\t%f\t%i\t%f' % (i_episode, agent.step, agent.eps, sum(r_list), sum(q_list)/float(len(q_list))))
    if monitor:
        env.monitor.close()
Author: trtd56 | Project: Atari | Lines: 36 | Source: main.py
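The Agent class itself is not shown on this page; it lives in each project's own agent.py. For orientation, here is a minimal, hypothetical skeleton of the interface that Example 1 exercises. The method names and attributes (reset_state, act, update_experience, train, save_model, load_model, Q, step, eps) are taken from the calls above; everything else is an assumption, not the actual trtd56/Atari implementation.

# Hypothetical skeleton inferred from the call sites in Example 1; not the real code.
class Agent(object):
    def __init__(self, n_act, seed, gpu):
        self.n_act = n_act  # number of discrete actions
        self.step = 0       # total training steps taken so far
        self.eps = 1.0      # current epsilon for epsilon-greedy exploration
        self.Q = 0.0        # most recent Q-value estimate (logged per step)

    def reset_state(self, observation):
        """Initialize the internal state/frame buffer at episode start."""

    def act(self):
        """Pick an epsilon-greedy action for the current state."""
        return 0

    def update_experience(self, observation, action, reward, ep_end):
        """Store the latest transition in the replay memory."""

    def train(self):
        """Sample a minibatch from replay memory and take one gradient step."""
        self.step += 1

    def save_model(self, path):
        """Serialize network weights under the given path prefix."""

    def load_model(self, path):
        """Restore network weights from the given path prefix."""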

Example 2: main

# Required import: from agent import Agent [as alias]
# Or: from agent.Agent import train [as alias]
import gym
import numpy as np

def main(env_name, render=False, monitor=True, load=False, seed=0):

    env = gym.make(env_name)
    view_path = "./video/" + env_name
    model_path = "./model/" + env_name + "_"

    n_st = env.observation_space.shape[0]
    if isinstance(env.action_space, gym.spaces.discrete.Discrete):
        # CartPole-v0, Acrobot-v0, MountainCar-v0
        n_act = env.action_space.n
        action_list = range(0, n_act)
    elif isinstance(env.action_space, gym.spaces.box.Box):
        # Pendulum-v0
        action_list = [np.array([a]) for a in [-2.0, 2.0]]
        n_act = len(action_list)

    agent = Agent(n_st, n_act, seed)
    if load:
        agent.load_model(model_path)

    if monitor:
        env.monitor.start(view_path, video_callable=None, force=True, seed=seed)
    for i_episode in xrange(1000):
        observation = env.reset()
        r_sum = 0
        q_list = []
        for t in xrange(200):
            if render:
                env.render()
            state = observation.astype(np.float32).reshape((1,n_st))
            act_i, q = agent.get_action(state)
            q_list.append(q)
            action = action_list[act_i]
            observation, reward, ep_end, _ = env.step(action)
            state_dash = observation.astype(np.float32).reshape((1,n_st))
            agent.stock_experience(state, act_i, reward, state_dash, ep_end)
            agent.train()
            r_sum += reward
            if ep_end:
                break
        print "\t".join(map(str,[i_episode, r_sum, agent.epsilon, agent.loss, sum(q_list)/float(t+1) ,agent.step]))
        agent.save_model(model_path)
    if monitor:
        env.monitor.close()
Author: trtd56 | Project: ClassicControl | Lines: 46 | Source: main.py
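Examples 1 and 2 record videos through the legacy env.monitor.start() / env.monitor.close() API, which was removed from gym long ago. On somewhat newer (but still pre-0.21) gym releases, the equivalent is the gym.wrappers.Monitor wrapper. The sketch below shows the idea; treat the exact signature as version-dependent:

import gym
from gym import wrappers

env = gym.make("CartPole-v0")
# Wrapping the env replaces env.monitor.start(); videos/stats are written to ./video
env = wrappers.Monitor(env, "./video/CartPole-v0", force=True)
# ... run episodes exactly as in the examples above ...
env.close()  # replaces env.monitor.close()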

Example 3: main

# Required import: from agent import Agent [as alias]
# Or: from agent.Agent import train [as alias]
from keras.models import Sequential
from keras.layers import Convolution2D, Flatten, Dense
from keras.optimizers import RMSprop

def main():
    game_width = 12
    game_height = 9
    nb_frames = 4
    actions = ((-1, 0), (1, 0), (0, -1), (0, 1), (0, 0))

    # Recipe of deep reinforcement learning model
    model = Sequential()
    model.add(Convolution2D(
        16,
        nb_row=3,
        nb_col=3,
        activation='relu',
        input_shape=(nb_frames, game_height, game_width)))
    model.add(Convolution2D(32, nb_row=3, nb_col=3, activation='relu'))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dense(len(actions)))
    model.compile(RMSprop(), 'MSE')

    # snake_game is the game module shipped with the project
    agent = Agent(
        model, nb_frames, snake_game, actions, size=(game_width, game_height))
    agent.train(nb_epochs=10000, batch_size=64, gamma=0.8, save_model=True)
    agent.play(nb_rounds=10)
Author: wing3s | Project: snake_game | Lines: 26 | Source: run_bot.py
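Example 3 is written against the Keras 1 API (Convolution2D with nb_row/nb_col arguments). Under Keras 2, the same model recipe would use Conv2D with a kernel_size tuple. Here is a sketch of just the model definition, assuming Keras 2 and the channels-first input layout of the original:

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense
from keras.optimizers import RMSprop

model = Sequential()
# 16 and then 32 3x3 filters over a stack of 4 frames (channels first)
model.add(Conv2D(16, kernel_size=(3, 3), activation='relu',
                 input_shape=(4, 9, 12), data_format='channels_first'))
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu',
                 data_format='channels_first'))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(5))  # one linear output per action
model.compile(optimizer=RMSprop(), loss='mse')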

Example 4: xrange

# Required import: from agent import Agent [as alias]
# Or: from agent.Agent import train [as alias]
# Snippet from a larger script: args, logger, stats, agent, and net are defined earlier.
if args.random_steps:
  # populate replay memory with random steps
  logger.info("Populating replay memory with %d random moves" % args.random_steps)
  stats.reset()
  agent.play_random(args.random_steps)
  stats.write(0, "random")

# loop over epochs
for epoch in xrange(args.epochs):
  logger.info("Epoch #%d" % (epoch + 1))

  if args.train_steps:
    logger.info(" Training for %d steps" % args.train_steps)
    stats.reset()
    agent.train(args.train_steps, epoch)
    stats.write(epoch + 1, "train")

    if args.save_weights_prefix:
      filename = args.save_weights_prefix + "_%d.prm" % (epoch + 1)
      logger.info("Saving weights to %s" % filename)
      net.save_weights(filename)

  if args.test_steps:
    logger.info(" Testing for %d steps" % args.test_steps)
    stats.reset()
    agent.test(args.test_steps, epoch)
    stats.write(epoch + 1, "test")

stats.close()
logger.info("All done")
Author: Deanout | Project: simple_dqn | Lines: 32 | Source: main.py
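Examples 4 and 5 populate a replay memory with random play (agent.play_random) before training begins. The replay memory itself is not shown on this page; the hypothetical deque-based sketch below illustrates the data structure being filled and sampled (the actual projects ship their own, typically array-based, implementations):

import random
from collections import deque

class ReplayMemory(object):
    """Minimal sketch of a replay memory: a bounded FIFO of transitions."""

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted first

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw a uniform random minibatch for one training step."""
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)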

Example 5: Agent

# Required import: from agent import Agent [as alias]
# Or: from agent.Agent import train [as alias]
# Snippet from a larger script: env, mem, network, args, EPOCHS, STEPS_PER_EPOCH,
# STEPS_PER_TEST, and default_random_steps are defined earlier.
import datetime
import cPickle

agent = Agent(env, mem, network)


if args.train_model:
    #stats = Statistics(agent, network, mem, env)

    agent.play_random(random_steps=default_random_steps)

    print "Traning Started....."

    for i in range(EPOCHS):
        #stats.reset()
        a = datetime.datetime.now().replace(microsecond=0)
        agent.train(train_steps=STEPS_PER_EPOCH, epoch=1)
        agent.test(test_steps=STEPS_PER_TEST, epoch=1)
        save_path = args.save_model_dir
        if args.save_models:
            path_file = args.save_model_dir+'/dep-q-shooter-nipscuda-8movectrl-'+str(i)+'-epoch.pkl'
            #print path_file
            net_file = open(path_file, 'wb')  # binary mode is required for pickle
            cPickle.dump(network, net_file, -1)
            net_file.close()
        b = datetime.datetime.now().replace(microsecond=0)
        #stats.write(i + 1, "train")
        print "Completed "+str(i+1)+"/"+str(EPOCHS)+" epochs in ",(b-a)

    print "Training Ended....."

if args.play_games > 0:
Author: pavitrakumar78 | Project: Playing-custom-games-using-Deep-Learning | Lines: 32 | Source: stester8.py
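Example 5 persists the entire network object with cPickle (note that pickle files must be opened in binary mode, hence the 'wb' above). On Python 3, cPickle is gone and the same save reads as follows; network and path_file are the variables from the snippet above:

import pickle

with open(path_file, 'wb') as net_file:
    pickle.dump(network, net_file, protocol=pickle.HIGHEST_PROTOCOL)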


Note: The agent.Agent.train examples in this article were compiled by 纯净天空 (vimsky) from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright remains with the original authors. Refer to each project's license before distributing or using the code, and do not reproduce this compilation without permission.