本文整理汇总了 Python 中 baselines.common.vec_env.VecEnv 类的典型用法代码示例。如果您正苦于以下问题:Python vec_env.VecEnv 的具体用法?Python vec_env.VecEnv 怎么用?Python vec_env.VecEnv 使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 baselines.common.vec_env 的用法示例。
在下文中一共展示了 vec_env.VecEnv 类的 3 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: build_env
# 需要导入模块: from baselines.common import vec_env [as 别名]
# 或者: from baselines.common.vec_env import VecEnv [as 别名]
def build_env(args):
    """Build the environment selected by ``args``, optionally reward-wrapped.

    Reads ``args.env``, ``args.alg``, ``args.seed``, ``args.num_env``,
    ``args.reward_scale``, ``args.custom_reward`` and, depending on the
    branch taken, ``args.gamestate`` and ``args.custom_reward_kwargs``.

    Returns:
        The constructed environment, wrapped in a custom reward wrapper when
        ``args.custom_reward`` is non-empty and in ``VecNormalize`` when the
        environment type is ``'mujoco'``.

    Raises:
        AssertionError: if ``args.custom_reward`` names an unknown wrapper,
            or (unless asserts are disabled) if the environment is neither a
            ``VecEnv`` nor a ``VecEnvWrapper`` when a custom reward is
            requested.
    """
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type in {'atari', 'retro'}:
        if alg == 'deepq':
            env = make_env(env_id, env_type, seed=seed,
                           wrapper_kwargs={'frame_stack': True})
        elif alg == 'trpo_mpi':
            env = make_env(env_id, env_type, seed=seed)
        else:
            frame_stack_size = 4
            env = make_vec_env(env_id, env_type, nenv, seed,
                               gamestate=args.gamestate,
                               reward_scale=args.reward_scale)
            env = VecFrameStack(env, frame_stack_size)
    else:
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=1,
                                inter_op_parallelism_threads=1)
        config.gpu_options.allow_growth = True
        get_session(config=config)

        env = make_vec_env(env_id, env_type, args.num_env or 1, seed,
                           reward_scale=args.reward_scale)

    # NOTE(review): the nesting of this section could not be recovered from
    # the flattened source; it is placed at function level so the custom
    # reward applies to every environment type -- confirm.
    if args.custom_reward != '':
        from baselines.common.vec_env import VecEnv, VecEnvWrapper
        import baselines.common.custom_reward_wrapper as W

        assert isinstance(env, (VecEnv, VecEnvWrapper))

        # SECURITY: eval() executes arbitrary code taken from the command
        # line. Only run with trusted arguments; consider ast.literal_eval
        # if the kwargs are always plain literals.
        custom_reward_kwargs = eval(args.custom_reward_kwargs)

        if args.custom_reward == 'live_long':
            env = W.VecLiveLongReward(env, **custom_reward_kwargs)
        elif args.custom_reward == 'random_tf':
            env = W.VecTFRandomReward(env, **custom_reward_kwargs)
        elif args.custom_reward == 'preference':
            env = W.VecTFPreferenceReward(env, **custom_reward_kwargs)
        elif args.custom_reward == 'preference_normalized':
            env = W.VecTFPreferenceRewardNormalized(env, **custom_reward_kwargs)
        else:
            # Raised explicitly (instead of ``assert False``) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError('no such wrapper exist')

    if env_type == 'mujoco':
        env = VecNormalize(env)

    return env
示例2: build_env
# 需要导入模块: from baselines.common import vec_env [as 别名]
# 或者: from baselines.common.vec_env import VecEnv [as 别名]
def build_env(args):
    """Build the environment selected by ``args``, optionally reward-wrapped.

    Reads ``args.env``, ``args.alg``, ``args.seed``, ``args.num_env``,
    ``args.reward_scale``, ``args.custom_reward`` and, depending on the
    branch taken, ``args.gamestate``, ``args.custom_reward_kwargs`` and
    ``args.custom_reward_path``. Prints the environment id and the name
    derived from it for masking.

    Returns:
        The constructed environment, wrapped in ``VecPyTorchAtariReward``
        when ``args.custom_reward == 'pytorch'`` and in ``VecNormalize`` when
        the environment type is ``'mujoco'``.

    Raises:
        AssertionError: if ``args.custom_reward`` is non-empty but not
            ``'pytorch'``, if ``'pytorch'`` is requested with an empty
            ``args.custom_reward_path``, or (unless asserts are disabled) if
            the environment is neither a ``VecEnv`` nor a ``VecEnvWrapper``
            when a custom reward is requested.
    """
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    print(env_id)

    # Extract the lower-cased name preceding "NoFrameskip". partition()
    # keeps the whole id when the marker is absent, whereas slicing with
    # find()'s -1 result would silently drop the last character.
    env_name = env_id.partition("NoFrameskip")[0].lower()
    print("Env Name for Masking:", env_name)

    if env_type in {'atari', 'retro'}:
        if alg == 'deepq':
            env = make_env(env_id, env_type, seed=seed,
                           wrapper_kwargs={'frame_stack': True})
        elif alg == 'trpo_mpi':
            env = make_env(env_id, env_type, seed=seed)
        else:
            frame_stack_size = 4
            env = make_vec_env(env_id, env_type, nenv, seed,
                               gamestate=args.gamestate,
                               reward_scale=args.reward_scale)
            env = VecFrameStack(env, frame_stack_size)
    else:
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=1,
                                inter_op_parallelism_threads=1)
        config.gpu_options.allow_growth = True
        get_session(config=config)

        env = make_vec_env(env_id, env_type, args.num_env or 1, seed,
                           reward_scale=args.reward_scale)

    if args.custom_reward != '':
        from baselines.common.vec_env import VecEnv, VecEnvWrapper
        import baselines.common.custom_reward_wrapper as W

        assert isinstance(env, (VecEnv, VecEnvWrapper))

        # SECURITY: eval() executes arbitrary code taken from the command
        # line. NOTE(review): the result is not used by this variant; the
        # call is kept only so malformed kwargs still fail here -- consider
        # removing it or switching to ast.literal_eval.
        custom_reward_kwargs = eval(args.custom_reward_kwargs)

        # Explicit raises (instead of ``assert False``) so these checks are
        # not stripped when Python runs with -O.
        if args.custom_reward == 'pytorch':
            if args.custom_reward_path == '':
                raise AssertionError('no path for reward model')
            env = W.VecPyTorchAtariReward(env, args.custom_reward_path, env_name)
        else:
            raise AssertionError('no such wrapper exist')

    if env_type == 'mujoco':
        env = VecNormalize(env)
    # Reward-only normalization for Atari (VecNormalizeRewards) is currently
    # disabled.

    return env
示例3: main
# 需要导入模块: from baselines.common import vec_env [as 别名]
# 或者: from baselines.common.vec_env import VecEnv [as 别名]
def main(args):
    """Parse arguments, train a model, optionally save it and play it back.

    Args:
        args: command-line argument list handed to the common argument
            parser; unrecognised arguments are parsed into extra keyword
            arguments for ``train``.

    Returns:
        The trained model returned by ``train``.

    When ``args.play`` is set, the trained model is run in the environment in
    an endless loop that prints each finished episode's accumulated reward;
    the loop only ends through an exception (for example KeyboardInterrupt).
    The environment is closed in all cases.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    # Configure the logger; child MPI processes (rank > 0) get no output
    # formats, which disables their logging.
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        configure_logger(args.log_path)
    else:
        rank = MPI.COMM_WORLD.Get_rank()
        configure_logger(args.log_path, format_strs=[])

    model, env = train(args, extra_args)

    # try/finally guarantees env.close() runs even though the playback loop
    # below never exits normally, and also when saving or stepping raises.
    try:
        if args.save_path is not None and rank == 0:
            save_path = osp.expanduser(args.save_path)
            model.save(save_path)

        if args.play:
            logger.log("Running trained model")
            obs = env.reset()

            state = getattr(model, 'initial_state', None)
            # NOTE(review): `dones` is never refreshed from `done`, so the
            # mask passed as M stays all zeros -- confirm whether the model
            # state should be reset at episode boundaries.
            dones = np.zeros((1,))

            if isinstance(env, VecEnv):
                episode_rew = np.zeros(env.num_envs)
            else:
                episode_rew = np.zeros(1)

            while True:
                if state is not None:
                    actions, _, state, _ = model.step(obs, S=state, M=dones)
                else:
                    actions, _, _, _ = model.step(obs)

                obs, rew, done, _ = env.step(actions)
                episode_rew += rew
                env.render()

                # `done` may be an ndarray or a plain value.
                done_any = done.any() if isinstance(done, np.ndarray) else done
                if done_any:
                    for i in np.nonzero(done)[0]:
                        print('episode_rew={}'.format(episode_rew[i]))
                        episode_rew[i] = 0
    finally:
        env.close()

    return model