本文整理汇总了Python中dqn_agent.DQNAgent方法的典型用法代码示例。如果您正苦于以下问题:Python dqn_agent.DQNAgent方法的具体用法?Python dqn_agent.DQNAgent怎么用?Python dqn_agent.DQNAgent使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类dqn_agent
的用法示例。
在下文中一共展示了dqn_agent.DQNAgent方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: create_agent
# 需要导入模块: import dqn_agent [as 别名]
# 或者: from dqn_agent import DQNAgent [as 别名]
def create_agent(environment, obs_stacker, agent_type='DQN'):
    """Creates the Hanabi agent.

    Args:
      environment: The environment.
      obs_stacker: Observation stacker object.
      agent_type: str, type of agent to construct.

    Returns:
      An agent for playing Hanabi.

    Raises:
      ValueError: if an unknown agent type is requested.
    """
    # Constructors are wrapped in lambdas so that only the requested agent
    # module attribute is resolved; the unused branch is never touched.
    builders = {
        'DQN': lambda: dqn_agent.DQNAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players),
        'Rainbow': lambda: rainbow_agent.RainbowAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players),
    }
    if agent_type not in builders:
        raise ValueError('Expected valid agent_type, got {}'.format(agent_type))
    return builders[agent_type]()
示例2: main
# 需要导入模块: import dqn_agent [as 别名]
# 或者: from dqn_agent import DQNAgent [as 别名]
def main():
    """Runs a trained DQN agent for inference on a CityFlow environment.

    Loads the experiment config and the CityFlow road-network description,
    builds the environment and the agent, restores the agent's weights from
    the checkpoint given by --chpt, then steps the environment for
    --num_step steps, logging the chosen action and the reward each step.
    """
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=10**3)
    parser.add_argument('--chpt', type=str)
    args = parser.parse_args()

    # Preparing config.
    # For the environment: context managers close the config files promptly
    # instead of leaking the handles until garbage collection.
    with open(args.config) as f:
        config = json.load(f)
    config["num_step"] = args.num_step
    # config["replay_data_path"] = "replay"
    with open(config['cityflow_config_file']) as f:
        cityflow_config = json.load(f)
    # NOTE(review): plain concatenation assumes 'dir' ends with a path
    # separator — confirm against the shipped config files.
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # For the agent.
    intersection_id = list(config['lane_phase_info'].keys())[0]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    logging.info(phase_list)
    # +1 is for the current phase: state is
    # [vehicle_count for each start lane] + [current_phase].
    state_size = config["state_size"] = len(
        config['lane_phase_info'][intersection_id]['start_lane']) + 1
    config["action_size"] = len(phase_list)

    # Build the CityFlow environment.
    env = CityFlowEnv(config)
    # Build the agent.
    agent = DQNAgent(config)

    # Inference: restore weights, then roll out the policy greedily.
    agent.load(args.chpt)
    env.reset()
    state = env.get_state()
    state = np.array(list(state['start_lane_vehicle_count'].values())
                     + [state['current_phase']])
    state = np.reshape(state, [1, state_size])
    for i in range(args.num_step):
        action = agent.choose_action(state)  # index of action
        action_phase = phase_list[action]  # actual action
        next_state, reward = env.step(action_phase)  # one step
        next_state = np.array(list(next_state['start_lane_vehicle_count'].values())
                              + [next_state['current_phase']])
        next_state = np.reshape(next_state, [1, state_size])
        state = next_state
        # Logging.
        logging.info("step:{}/{}, action:{}, reward:{}"
                     .format(i, args.num_step, action, reward))
    # copy file to front/replay
    # roadnetLog = os.path.join(cityflow_config['dir'], cityflow_config['roadnetLogFile'])
    # replayLog = os.path.join(cityflow_config['dir'], cityflow_config['replayLogFile'])
示例3: main
# 需要导入模块: import dqn_agent [as 别名]
# 或者: from dqn_agent import DQNAgent [as 别名]
def main():
    """Runs a trained agent (DQN/DDQN/DuelDQN) for inference on CityFlow.

    Loads the experiment config and the road network, builds the environment
    and the agent, restores weights from the checkpoint given by --ckpt, then
    steps the environment for --num_step steps, logging action and reward.
    """
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=2000)
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--algo', type=str, default='DQN', choices=['DQN', 'DDQN', 'DuelDQN'], help='choose an algorithm')
    # Bug fix: args.batch_size is consumed below but the flag was never
    # declared, which raised AttributeError at runtime. Default of 32 is a
    # placeholder — TODO confirm against the training script's default.
    parser.add_argument('--batch_size', type=int, default=32, help='batch size used by the agent')
    args = parser.parse_args()

    # Preparing config.
    # For the environment: context managers close the config files promptly
    # instead of leaking the handles until garbage collection.
    with open(args.config) as f:
        config = json.load(f)
    config["num_step"] = args.num_step
    with open(config['cityflow_config_file']) as f:
        cityflow_config = json.load(f)
    # NOTE(review): plain concatenation assumes 'dir' ends with a path
    # separator — confirm against the shipped config files.
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # For the agent.
    intersection_id = "intersection_1_1"
    config["intersection_id"] = intersection_id
    # +1 is for the current phase: state is
    # [vehicle_count for each start lane] + [current_phase].
    config["state_size"] = len(
        config['lane_phase_info'][intersection_id]['start_lane']) + 1
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size
    logging.info(phase_list)

    # Build the CityFlow environment.
    env = CityFlowEnv(config)
    # Build the agent.
    agent = DQNAgent(config)

    # Inference: restore weights, then roll out the policy greedily.
    agent.load(args.ckpt)
    env.reset()
    state = env.get_state()
    for i in range(args.num_step):
        action = agent.choose_action(state)  # index of action
        action_phase = phase_list[action]  # actual action
        next_state, reward = env.step(action_phase)  # one step
        state = next_state
        # Logging.
        logging.info("step:{}/{}, action:{}, reward:{}"
                     .format(i, args.num_step, action, reward))