本文整理汇总了 Python 中 baselines.common.cmd_util.common_arg_parser 方法的典型用法代码示例。如果您正苦于以下问题:Python cmd_util.common_arg_parser 方法的具体用法?Python cmd_util.common_arg_parser 怎么用?Python cmd_util.common_arg_parser 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 baselines.common.cmd_util 的用法示例。
在下文中一共展示了cmd_util.common_arg_parser方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(custom_args=None):
    """Train a model, optionally save it, and optionally replay it forever.

    Args:
        custom_args: Optional mapping of argument name to value. A key that
            matches an attribute of the parsed command-line namespace
            overrides that attribute; every other key is forwarded to
            ``train`` as an extra argument. ``None`` or any empty container
            (including the former ``[]`` default) means "no overrides".
    """
    arg_parser = common_arg_parser()
    # Unrecognised command-line arguments are tolerated but not used here.
    args, _unknown_args = arg_parser.parse_known_args()

    # Copy into a fresh dict so no mutable default is shared between calls
    # and the caller's mapping is never aliased.
    overrides = dict(custom_args) if custom_args else {}
    known = vars(args)  # live view of the namespace: writes update ``args``
    extra_args = {}
    for name, value in overrides.items():
        if name in known:
            known[name] = value
        else:
            extra_args[name] = value

    # All ranks log to stdout and tensorboard; only the rank value differs.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    logger.configure(format_strs=['stdout', 'tensorboard'])

    model, _ = train(args, extra_args)

    # Only rank 0 writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        try:
            obs = env.reset()
            # Runs until interrupted (e.g. Ctrl-C) or until a call raises.
            while True:
                actions = model.step(obs)[0]
                obs, _, done, _ = env.step(actions)
                env.render()
                # ``done`` may be an array with one flag per environment.
                done = done.any() if isinstance(done, np.ndarray) else done
                if done:
                    obs = env.reset()
        finally:
            # The loop never exits normally, so release the env here.
            env.close()
示例2: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main():
    """Train a model from command-line arguments, then optionally save and replay it.

    Unknown command-line arguments are converted with
    ``parse_cmdline_kwargs`` and passed to ``train`` as extra arguments.
    With ``args.play`` set, a new environment is built and the model is run
    in it until the process is interrupted.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    # Only rank 0 writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        try:
            obs = env.reset()

            def initialize_placeholders(nlstm=128, **kwargs):
                # Zero-filled inputs for ``model.step``: a
                # (num_env or 1, 2 * nlstm) state and a length-1 mask.
                # Presumably the recurrent state and done mask; confirm
                # against the model implementation.
                return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1))

            state, dones = initialize_placeholders(**extra_args)
            # Runs until interrupted (e.g. Ctrl-C) or until a call raises.
            while True:
                actions, _, state, _ = model.step(obs, S=state, M=dones)
                obs, _, done, _ = env.step(actions)
                env.render()
                # ``done`` may be an array with one flag per environment.
                done = done.any() if isinstance(done, np.ndarray) else done
                if done:
                    obs = env.reset()
        finally:
            # The loop never exits normally, so a close placed after it
            # would be unreachable; release the env here instead.
            env.close()
示例3: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(args):
    """Train a model, optionally save it, and optionally replay it forever.

    Args:
        args: Sequence of command-line argument strings handed to the
            argument parser. Arguments the parser does not recognise are
            converted with ``parse_cmdline_kwargs`` and forwarded to
            ``train``.

    Returns:
        The model returned by ``train``.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    if args.extra_import is not None:
        import_module(args.extra_import)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    try:
        # Only rank 0 writes the model to disk.
        if args.save_path is not None and rank == 0:
            save_path = osp.expanduser(args.save_path)
            model.save(save_path)

        if args.play:
            logger.log("Running trained model")
            obs = env.reset()

            # Models exposing ``initial_state`` are stepped with explicit
            # state and mask arguments.
            state = model.initial_state if hasattr(model, 'initial_state') else None
            dones = np.zeros((1,))

            episode_rew = 0
            # Runs until interrupted (e.g. Ctrl-C) or until a call raises.
            while True:
                if state is not None:
                    actions, _, state, _ = model.step(obs, S=state, M=dones)
                else:
                    actions, _, _, _ = model.step(obs)

                obs, rew, done, _ = env.step(actions)
                # Only the first reward entry is accumulated.
                episode_rew += rew[0]
                env.render()
                done = done.any() if isinstance(done, np.ndarray) else done
                if done:
                    print(f'episode_rew={episode_rew}')
                    episode_rew = 0
                    obs = env.reset()
    finally:
        # Close the training env on every exit path, including an
        # interrupt of the endless play loop.
        env.close()

    return model
示例4: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(args):
    """Train a model, optionally save it, and optionally replay it in a new env.

    Args:
        args: Sequence of command-line argument strings handed to the
            argument parser. Arguments the parser does not recognise are
            converted with ``parse_cmdline_kwargs`` and forwarded to
            ``train``.

    Returns:
        The model returned by ``train``.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    if args.extra_import is not None:
        import_module(args.extra_import)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    # Train, then release the training environment straight away.
    model, env = train(args, extra_args)
    env.close()

    # Only rank 0 writes the model to disk. ``model.save`` is the sole
    # persistence step; no separate saver object is created.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    # Test run of the learned model in a freshly built environment.
    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        try:
            obs = env.reset()

            # Models exposing ``initial_state`` are stepped with explicit
            # state and mask arguments.
            state = model.initial_state if hasattr(model, 'initial_state') else None
            dones = np.zeros((1,))

            # Runs until interrupted (e.g. Ctrl-C) or until a call raises.
            while True:
                if state is not None:
                    actions, _, state, _ = model.step(obs, S=state, M=dones)
                else:
                    actions, _, _, _ = model.step(obs)

                obs, _, done, _ = env.step(actions)
                env.render()
                done = done.any() if isinstance(done, np.ndarray) else done
                if done:
                    obs = env.reset()
        finally:
            # The loop never exits normally, so release the play env here.
            env.close()

    return model
示例5: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(args):
    """Train a model, optionally save it, and optionally replay it forever.

    Args:
        args: Sequence of command-line argument strings handed to the
            argument parser. Arguments the parser does not recognise are
            converted with ``parse_cmdline_kwargs`` and forwarded to
            ``train``.

    Returns:
        The model returned by ``train``.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        configure_logger(args.log_path)
    else:
        rank = MPI.COMM_WORLD.Get_rank()
        configure_logger(args.log_path, format_strs=[])

    model, env = train(args, extra_args)
    try:
        # Only rank 0 writes the model to disk.
        if args.save_path is not None and rank == 0:
            save_path = osp.expanduser(args.save_path)
            model.save(save_path)

        if args.play:
            logger.log("Running trained model")
            obs = env.reset()

            # Models exposing ``initial_state`` are stepped with explicit
            # state and mask arguments.
            state = model.initial_state if hasattr(model, 'initial_state') else None
            dones = np.zeros((1,))

            # One running reward total per environment (a single slot for
            # a non-vectorised env).
            episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
            # Runs until interrupted (e.g. Ctrl-C) or until a call raises.
            while True:
                if state is not None:
                    actions, _, state, _ = model.step(obs, S=state, M=dones)
                else:
                    actions, _, _, _ = model.step(obs)

                obs, rew, done, _ = env.step(actions)
                episode_rew += rew
                env.render()
                # ``done`` can be a plain bool for a non-vectorised env;
                # np.nonzero rejects 0-d input on recent NumPy, so promote
                # it to 1-D first. No finished episode -> empty loop.
                for i in np.nonzero(np.atleast_1d(done))[0]:
                    print('episode_rew={}'.format(episode_rew[i]))
                    episode_rew[i] = 0
    finally:
        # Close the env on every exit path, including an interrupt of the
        # endless play loop.
        env.close()

    return model
示例6: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(num_play_steps=686):
    """Train a model and optionally run it for a fixed number of test steps.

    Args:
        num_play_steps: Number of steps to run in the test environment when
            ``args.play`` is set. The default of 686 is the data length that
            used to be hard-coded in the loop.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    model, env = train(args, extra_args)
    env.close()

    # NOTE: this variant does not save the trained model, even when
    # args.save_path is set.

    if args.play:
        logger.log("Running trained model")
        env = build_testenv(args)
        try:
            obs = env.reset()
            # Fixed-length rollout: done flags are ignored and the env is
            # neither rendered nor reset in between.
            for _ in range(num_play_steps):
                actions, _, _, _ = model.step(obs)
                obs, _, _done, _ = env.step(actions)
        finally:
            # Release the test env even if a step raises.
            env.close()