

Python dqn_agent.DQNAgent Code Examples

This article collects representative usage examples of dqn_agent.DQNAgent in Python. If you are unsure what dqn_agent.DQNAgent is for, how to use it, or what it looks like in practice, the curated examples below may help. You can also explore further usage examples from the dqn_agent module.


The following presents 3 code examples of dqn_agent.DQNAgent, sorted by popularity.

Example 1: create_agent

# Required module: import dqn_agent [as alias]
# Or: from dqn_agent import DQNAgent [as alias]
# This example also uses rainbow_agent (rainbow_agent.RainbowAgent).
def create_agent(environment, obs_stacker, agent_type='DQN'):
  """Creates the Hanabi agent.

  Args:
    environment: The environment.
    obs_stacker: Observation stacker object.
    agent_type: str, type of agent to construct.

  Returns:
    An agent for playing Hanabi.

  Raises:
    ValueError: if an unknown agent type is requested.
  """
  if agent_type == 'DQN':
    return dqn_agent.DQNAgent(observation_size=obs_stacker.observation_size(),
                              num_actions=environment.num_moves(),
                              num_players=environment.players)
  elif agent_type == 'Rainbow':
    return rainbow_agent.RainbowAgent(
        observation_size=obs_stacker.observation_size(),
        num_actions=environment.num_moves(),
        num_players=environment.players)
  else:
    raise ValueError('Expected valid agent_type, got {}'.format(agent_type)) 
Developer ID: deepmind, Project: hanabi-learning-environment, Lines: 27, Source: run_experiment.py
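
A minimal usage sketch for create_agent. The create_environment and create_obs_stacker helpers are assumed to live in the same run_experiment.py module; their exact signatures here are an assumption, not shown in the snippet above.

# Hypothetical wiring; helper names and signatures are assumptions based on
# the hanabi-learning-environment repo, not part of the snippet above.
environment = create_environment(game_type='Hanabi-Full', num_players=2)
obs_stacker = create_obs_stacker(environment)
agent = create_agent(environment, obs_stacker, agent_type='Rainbow')

# Any other agent_type raises ValueError, e.g.:
# create_agent(environment, obs_stacker, agent_type='A3C')  # ValueError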

Example 2: main

# Required module: import dqn_agent [as alias]
# Or: from dqn_agent import DQNAgent [as alias]
# This example also uses: logging, argparse, json, numpy (as np), CityFlowEnv, and parse_roadnet.
def main():
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=10**3)
    parser.add_argument('--ckpt', type=str)

    args = parser.parse_args()

    # preparing config
    # # for environment
    config = json.load(open(args.config))
    config["num_step"] = args.num_step
    # config["replay_data_path"] = "replay"
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = list(config['lane_phase_info'].keys())[0]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    logging.info(phase_list)
    state_size = config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1 # 1 is for the current phase. [vehicle_count for each start lane] + [current_phase]
    config["action_size"] = len(phase_list)

    # build cityflow environment
    env = CityFlowEnv(config)

    # build agent
    agent = DQNAgent(config)
    
    # inference
    agent.load(args.ckpt)
    env.reset()
    state = env.get_state()
    state = np.array(list(state['start_lane_vehicle_count'].values()) + [state['current_phase']])
    state = np.reshape(state, [1, state_size])
    
    for i in range(args.num_step): 
        action = agent.choose_action(state) # index of action
        action_phase = phase_list[action] # actual action
        next_state, reward = env.step(action_phase) # one step

        next_state = np.array(list(next_state['start_lane_vehicle_count'].values()) + [next_state['current_phase']])
        next_state = np.reshape(next_state, [1, state_size])

        state = next_state

        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                        .format(i, args.num_step, action, reward))
    
    # copy file to front/replay
    # roadnetLog = os.path.join(cityflow_config['dir'], cityflow_config['roadnetLogFile'])
    # replayLog = os.path.join(cityflow_config['dir'], cityflow_config['replayLogFile']) 
Developer ID: multi-commander, Project: Multi-Commander, Lines: 57, Source: run_rl_inference.py
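
The state preprocessing above (start-lane vehicle counts plus the current phase, reshaped to a batch of one) is written out twice, once for the initial state and once for each next_state. A small helper, hypothetical but derived directly from those lines, makes the intent explicit:

import numpy as np

def flatten_state(state, state_size):
    """Flatten a CityFlowEnv state dict into a (1, state_size) array.

    Layout: [vehicle_count for each start lane] + [current_phase],
    matching the state_size computed from lane_phase_info above.
    """
    vec = list(state['start_lane_vehicle_count'].values()) + [state['current_phase']]
    return np.reshape(np.array(vec), [1, state_size])

# Usage inside the inference loop:
# state = flatten_state(env.get_state(), state_size)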

Example 3: main

# Required module: import dqn_agent [as alias]
# Or: from dqn_agent import DQNAgent [as alias]
# This example also uses: logging, argparse, json, CityFlowEnv, and parse_roadnet.
def main():
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=2000)
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--algo', type=str, default='DQN', choices=['DQN', 'DDQN', 'DuelDQN'], help='choose an algorithm')
    # args.batch_size is read below but was never registered; the default here is an assumption.
    parser.add_argument('--batch_size', type=int, default=32)

    args = parser.parse_args()

    # preparing config
    # # for environment
    config = json.load(open(args.config))
    config["num_step"] = args.num_step
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = "intersection_1_1"
    config["intersection_id"] = intersection_id
    config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1  # 1 is for the current phase. [vehicle_count for each start lane] + [current_phase]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size
    
    logging.info(phase_list)

    # build cityflow environment
    env = CityFlowEnv(config)

    # build agent
    agent = DQNAgent(config)
    
    # inference
    agent.load(args.ckpt)
    env.reset()
    state = env.get_state()
    
    for i in range(args.num_step): 
        action = agent.choose_action(state) # index of action
        action_phase = phase_list[action] # actual action
        next_state, reward = env.step(action_phase) # one step

        state = next_state

        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                        .format(i, args.num_step, action, reward)) 
Developer ID: multi-commander, Project: Multi-Commander, Lines: 52, Source: run_rl_inference.py
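
Both inference scripts touch only three members of the Multi-Commander DQNAgent: a config-driven constructor, load() to restore a checkpoint, and choose_action() for greedy inference. Below is a hypothetical interface skeleton, reconstructed solely from those calls; the real class also contains training logic not shown here.

class DQNAgent:
    """Interface sketch inferred from the two inference scripts above."""

    def __init__(self, config):
        # The config dict carries at least state_size and action_size.
        self.state_size = config['state_size']
        self.action_size = config['action_size']

    def load(self, checkpoint_path):
        """Restore network weights from a checkpoint file."""
        raise NotImplementedError

    def choose_action(self, state):
        """Return the index of the action with the highest Q-value."""
        raise NotImplementedError

Note that Example 2 flattens the state dict into a numpy vector before calling choose_action, while Example 3 passes the raw dict, so the two scripts evidently target slightly different DQNAgent implementations.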


Note: The dqn_agent.DQNAgent examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets come from community open-source projects; copyright remains with the original authors, and distribution or use should follow each project's license. Do not repost without permission.