

Python dqn_agent.DQNAgent Usage Examples

This article collects typical usage examples of the dqn_agent.DQNAgent class from open-source Python code. If you are unsure what dqn_agent.DQNAgent does or how to use it in practice, the curated examples below should help; you can also explore other usage examples from the dqn_agent module.


Three code examples of dqn_agent.DQNAgent are shown below, ordered by popularity by default.

Example 1: create_agent

# Required import: import dqn_agent [as alias]
# Or: from dqn_agent import DQNAgent [as alias]
import dqn_agent
import rainbow_agent  # assumed sibling module, needed for the 'Rainbow' branch below

def create_agent(environment, obs_stacker, agent_type='DQN'):
  """Creates the Hanabi agent.

  Args:
    environment: The environment.
    obs_stacker: Observation stacker object.
    agent_type: str, type of agent to construct.

  Returns:
    An agent for playing Hanabi.

  Raises:
    ValueError: if an unknown agent type is requested.
  """
  if agent_type == 'DQN':
    return dqn_agent.DQNAgent(observation_size=obs_stacker.observation_size(),
                              num_actions=environment.num_moves(),
                              num_players=environment.players)
  elif agent_type == 'Rainbow':
    return rainbow_agent.RainbowAgent(
        observation_size=obs_stacker.observation_size(),
        num_actions=environment.num_moves(),
        num_players=environment.players)
  else:
    raise ValueError('Expected valid agent_type, got {}'.format(agent_type)) 
Developer: deepmind | Project: hanabi-learning-environment | Lines of code: 27 | Source: run_experiment.py
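
For context, here is a hedged usage sketch of create_agent. create_environment and create_obs_stacker are assumed to be sibling helpers in the same run_experiment.py (their exact names and signatures may differ in your checkout of hanabi-learning-environment):

environment = create_environment(num_players=2)    # assumed helper
obs_stacker = create_obs_stacker(environment)      # assumed helper
agent = create_agent(environment, obs_stacker, agent_type='Rainbow')
# agent is a rainbow_agent.RainbowAgent sized to the stacked observation vector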

Example 2: main

# Required import: import dqn_agent [as alias]
# Or: from dqn_agent import DQNAgent [as alias]
# Imports added so the snippet runs on its own; cityflow_env and utility are
# assumed project-local modules (adjust the paths to your checkout).
import argparse
import json
import logging

import numpy as np

from cityflow_env import CityFlowEnv
from dqn_agent import DQNAgent
from utility import parse_roadnet

def main():
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=10**3)
    parser.add_argument('--chpt', type=str)

    args = parser.parse_args()

    # preparing config
    # # for environment
    config = json.load(open(args.config))
    config["num_step"] = args.num_step
    # config["replay_data_path"] = "replay"
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = list(config['lane_phase_info'].keys())[0]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    logging.info(phase_list)
    state_size = config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1 # 1 is for the current phase. [vehicle_count for each start lane] + [current_phase]
    config["action_size"] = len(phase_list)

    # build the CityFlow environment
    env = CityFlowEnv(config)

    # build agent
    agent = DQNAgent(config)
    
    # inference
    agent.load(args.chpt)
    env.reset()
    state = env.get_state()
    state = np.array(list(state['start_lane_vehicle_count'].values()) + [state['current_phase']])
    state = np.reshape(state, [1, state_size])
    
    for i in range(args.num_step): 
        action = agent.choose_action(state) # index of action
        action_phase = phase_list[action] # actual action
        next_state, reward = env.step(action_phase) # one step

        next_state = np.array(list(next_state['start_lane_vehicle_count'].values()) + [next_state['current_phase']])
        next_state = np.reshape(next_state, [1, state_size])

        state = next_state

        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                        .format(i, args.num_step, action, reward))
    
    # copy file to front/replay
    # roadnetLog = os.path.join(cityflow_config['dir'], cityflow_config['roadnetLogFile'])
    # replayLog = os.path.join(cityflow_config['dir'], cityflow_config['replayLogFile']) 
Developer: multi-commander | Project: Multi-Commander | Lines of code: 57 | Source: run_rl_inference.py
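
Example 2 builds the flattened state vector twice with the same inline expression. A small helper (hypothetical, not part of the Multi-Commander repository) makes the layout explicit: one vehicle count per start lane followed by the current phase, reshaped to the (1, state_size) batch shape the agent expects:

import numpy as np

def build_state_vector(state, state_size):
    # [vehicle_count for each start lane] + [current_phase], as in the example
    vec = list(state['start_lane_vehicle_count'].values()) + [state['current_phase']]
    return np.reshape(np.array(vec), [1, state_size])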

Example 3: main

# Required import: import dqn_agent [as alias]
# Or: from dqn_agent import DQNAgent [as alias]
# Imports added so the snippet runs on its own; cityflow_env and utility are
# assumed project-local modules (adjust the paths to your checkout).
import argparse
import json
import logging

from cityflow_env import CityFlowEnv
from dqn_agent import DQNAgent
from utility import parse_roadnet

def main():
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=2000)
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--algo', type=str, default='DQN', choices=['DQN', 'DDQN', 'DuelDQN'], help='choose an algorithm')
    parser.add_argument('--batch_size', type=int, default=32)  # added: args.batch_size is read below but was missing from the snippet; the default is an assumption

    args = parser.parse_args()

    # preparing config
    # # for environment
    config = json.load(open(args.config))
    config["num_step"] = args.num_step
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = "intersection_1_1"
    config["intersection_id"] = intersection_id
    config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1  # 1 is for the current phase. [vehicle_count for each start lane] + [current_phase]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size
    
    logging.info(phase_list)

    # build cityflow environment
    env = CityFlowEnv(config)

    # build agent
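    # note: args.algo is parsed above but never consulted in this snippet; it always builds DQNAgent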
    agent = DQNAgent(config)
    
    # inference
    agent.load(args.ckpt)
    env.reset()
    state = env.get_state()
    
    for i in range(args.num_step): 
        action = agent.choose_action(state) # index of action
        action_phase = phase_list[action] # actual action
        next_state, reward = env.step(action_phase) # one step

        state = next_state

        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                        .format(i, args.num_step, action, reward)) 
Developer: multi-commander | Project: Multi-Commander | Lines of code: 52 | Source: run_rl_inference.py
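
Both inference scripts rely only on the agent exposing load(checkpoint_path) and choose_action(state). The following minimal epsilon-greedy sketch illustrates that contract; the class, its attribute names (model, action_size, epsilon), and the Keras-style model are hypothetical stand-ins, not the actual Multi-Commander implementation:

import numpy as np

class GreedyDQNSketch:
    def __init__(self, model, action_size, epsilon=0.0):
        self.model = model             # any Keras-style model mapping state -> Q-values
        self.action_size = action_size
        self.epsilon = epsilon         # keep at 0.0 for pure greedy inference

    def choose_action(self, state):
        # explore with probability epsilon, otherwise take the argmax Q-value
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.action_size))
        q_values = self.model.predict(state, verbose=0)
        return int(np.argmax(q_values[0]))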


Note: The dqn_agent.DQNAgent examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their original authors; copyright remains with those authors, and any use or redistribution is governed by the corresponding project's license. Do not reproduce without permission.