本文整理匯總了Python中dqn_agent.DQNAgent方法的典型用法代碼示例。如果您正苦於以下問題:Python dqn_agent.DQNAgent方法的具體用法?Python dqn_agent.DQNAgent怎麽用?Python dqn_agent.DQNAgent使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類dqn_agent
的用法示例。
在下文中一共展示了dqn_agent.DQNAgent方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: create_agent
# 需要導入模塊: import dqn_agent [as 別名]
# 或者: from dqn_agent import DQNAgent [as 別名]
def create_agent(environment, obs_stacker, agent_type='DQN'):
    """Build the Hanabi agent requested by ``agent_type``.

    Args:
        environment: The Hanabi environment instance.
        obs_stacker: Observation stacker object.
        agent_type: str, which agent implementation to construct
            ('DQN' or 'Rainbow').

    Returns:
        An agent for playing Hanabi.

    Raises:
        ValueError: if an unknown agent type is requested.
    """
    # Both agents share the same constructor signature, built from the
    # stacked-observation size and the environment's action/player counts.
    if agent_type == 'Rainbow':
        return rainbow_agent.RainbowAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    if agent_type == 'DQN':
        return dqn_agent.DQNAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    raise ValueError('Expected valid agent_type, got {}'.format(agent_type))
示例2: main
# 需要導入模塊: import dqn_agent [as 別名]
# 或者: from dqn_agent import DQNAgent [as 別名]
def main():
    """Run a trained DQN agent for inference in a CityFlow environment.

    Loads the global config and roadnet, restores the agent weights from
    ``--chpt``, then steps the environment for ``--num_step`` steps while
    logging the chosen action and the reward at every step.
    """
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=10**3)
    parser.add_argument('--chpt', type=str)  # path of the checkpoint to restore
    args = parser.parse_args()

    # preparing config
    # # for environment
    # Use context managers so the config file handles are closed instead of
    # being leaked by bare json.load(open(...)) calls.
    with open(args.config) as config_file:
        config = json.load(config_file)
    config["num_step"] = args.num_step
    # config["replay_data_path"] = "replay"
    with open(config['cityflow_config_file']) as cityflow_file:
        cityflow_config = json.load(cityflow_file)
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    # NOTE(review): assumes the roadnet has a single intersection; only the
    # first key of lane_phase_info is controlled.
    intersection_id = list(config['lane_phase_info'].keys())[0]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    logging.info(phase_list)
    # State layout: [vehicle_count for each start lane] + [current_phase];
    # the +1 accounts for the current-phase slot.
    state_size = config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1
    config["action_size"] = len(phase_list)

    # build cityflow environment
    env = CityFlowEnv(config)
    # build agent
    agent = DQNAgent(config)

    # inference
    agent.load(args.chpt)
    env.reset()
    state = env.get_state()
    state = np.array(list(state['start_lane_vehicle_count'].values()) + [state['current_phase']])
    state = np.reshape(state, [1, state_size])
    for i in range(args.num_step):
        action = agent.choose_action(state)  # index of action
        action_phase = phase_list[action]  # actual action
        next_state, reward = env.step(action_phase)  # one step
        next_state = np.array(list(next_state['start_lane_vehicle_count'].values()) + [next_state['current_phase']])
        next_state = np.reshape(next_state, [1, state_size])
        state = next_state
        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                     .format(i, args.num_step, action, reward))
    # copy file to front/replay
    # roadnetLog = os.path.join(cityflow_config['dir'], cityflow_config['roadnetLogFile'])
    # replayLog = os.path.join(cityflow_config['dir'], cityflow_config['replayLogFile'])
示例3: main
# 需要導入模塊: import dqn_agent [as 別名]
# 或者: from dqn_agent import DQNAgent [as 別名]
def main():
    """Run a trained agent for inference in a CityFlow environment.

    Restores the agent weights from ``--ckpt`` and steps the environment
    for ``--num_step`` steps, logging the chosen action and the reward at
    every step.
    """
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/global_config.json')
    parser.add_argument('--num_step', type=int, default=2000)
    parser.add_argument('--ckpt', type=str)  # path of the checkpoint to restore
    parser.add_argument('--algo', type=str, default='DQN', choices=['DQN', 'DDQN', 'DuelDQN'],
                        help='choose an algorithm')
    # BUG FIX: args.batch_size was read below but the argument was never
    # declared, so the script crashed with AttributeError after parsing.
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size used by the agent')
    args = parser.parse_args()

    # preparing config
    # # for environment
    # Use context managers so the config file handles are closed instead of
    # being leaked by bare json.load(open(...)) calls.
    with open(args.config) as config_file:
        config = json.load(config_file)
    config["num_step"] = args.num_step
    with open(config['cityflow_config_file']) as cityflow_file:
        cityflow_config = json.load(cityflow_file)
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    # # for agent
    intersection_id = "intersection_1_1"
    config["intersection_id"] = intersection_id
    # State layout: [vehicle_count for each start lane] + [current_phase];
    # the +1 accounts for the current-phase slot.
    config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane']) + 1
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size
    logging.info(phase_list)

    # build cityflow environment
    env = CityFlowEnv(config)
    # build agent
    agent = DQNAgent(config)

    # inference
    agent.load(args.ckpt)
    env.reset()
    state = env.get_state()
    for i in range(args.num_step):
        action = agent.choose_action(state)  # index of action
        action_phase = phase_list[action]  # actual action
        next_state, reward = env.step(action_phase)  # one step
        state = next_state
        # logging
        logging.info("step:{}/{}, action:{}, reward:{}"
                     .format(i, args.num_step, action, reward))