This article collects typical usage examples of the Python method stable_baselines.common.vec_env.VecFrameStack. If you have been wondering what vec_env.VecFrameStack does or how to use it, the curated code examples below should help. You can also explore the containing module, stable_baselines.common.vec_env, for further usage examples.
The following presents 11 code examples of vec_env.VecFrameStack, sorted by popularity by default.
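Before the project-specific examples, here is a minimal, self-contained sketch of the common pattern (an illustration assuming stable-baselines 2.x with classic gym; CartPole-v1 is chosen only for demonstration and is not taken from the examples below). VecFrameStack wraps an already-vectorized environment and stacks the last n_stack observations along the last axis.

import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv, VecFrameStack

# VecFrameStack expects a VecEnv, so vectorize the single environment first.
# CartPole-v1 is only an illustrative choice of environment.
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# Keep the 4 most recent observations, concatenated along the last axis.
stacked = VecFrameStack(venv, n_stack=4)

obs = stacked.reset()
print(obs.shape)  # (1, 16): (num_envs, obs_dim * n_stack) for CartPole's 4-dim observation

obs, rewards, dones, infos = stacked.step(np.zeros(1, dtype=np.int64))
stacked.close()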
Example 1: makeEnv
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
    # Even though DeepQ is single core only, we need to use the pipe system to work
    if env_kwargs is not None and env_kwargs.get("use_srl", False):
        srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe

    envs = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])
    envs = VecFrameStack(envs, args.num_stack)

    if args.srl_model != "raw_pixels":
        printYellow("Using MLP policy because working on state representation")
        envs = VecNormalize(envs, norm_obs=True, norm_reward=False)
        envs = loadRunningAverage(envs, load_path_normalise=load_path_normalise)

    return envs
Example 2: load_stable_baselines_env
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def load_stable_baselines_env(cfg_path, vector_length, mp, n_stack, number_maps, action_frame_repeat,
                              scaled_resolution):
    env_fn = lambda: MazeExplorer.load_vizdoom_env(cfg_path, number_maps, action_frame_repeat, scaled_resolution)

    if mp:
        env = SubprocVecEnv([env_fn for _ in range(vector_length)])
    else:
        env = DummyVecEnv([env_fn for _ in range(vector_length)])

    if n_stack > 0:
        env = VecFrameStack(env, n_stack=n_stack)

    return env
Example 3: train
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env):
    """
    Train A2C model for atari environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
        'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_env: (int) The number of environments
    """
    policy_fn = None
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = CnnLstmPolicy
    elif policy == 'lnlstm':
        policy_fn = CnnLnLstmPolicy
    if policy_fn is None:
        raise ValueError("Error: policy {} not implemented".format(policy))

    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)

    model = A2C(policy_fn, env, lr_schedule=lr_schedule, seed=seed)
    model.learn(total_timesteps=int(num_timesteps * 1.1))
    env.close()
Example 4: train
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def train(env_id, num_timesteps, seed, policy,
          n_envs=8, nminibatches=4, n_steps=128):
    """
    Train PPO2 model for atari environment, for testing purposes

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is the number of environment copies running in parallel)
    """
    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
    model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
                 lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
    model.learn(total_timesteps=num_timesteps)
    env.close()
    # Free memory
    del model
Example 5: train
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu):
    """
    Train an ACER model on atari

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
        'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_cpu: (int) The number of CPUs to train on
    """
    env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4)
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = CnnLstmPolicy
    else:
        warnings.warn("Policy {} not implemented".format(policy))
        return

    model = ACER(policy_fn, env, lr_schedule=lr_schedule, buffer_size=5000, seed=seed)
    model.learn(total_timesteps=int(num_timesteps * 1.1))
    env.close()
    # Free memory
    del model
Example 6: test_vecenv_terminal_obs
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
    """Test that 'terminal_observation' gets added to the info dict upon
    termination."""
    step_nums = [i + 5 for i in range(N_ENVS)]
    vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums])

    if vec_env_wrapper is not None:
        if vec_env_wrapper == VecFrameStack:
            vec_env = vec_env_wrapper(vec_env, n_stack=2)
        else:
            vec_env = vec_env_wrapper(vec_env)

    zero_acts = np.zeros((N_ENVS,), dtype='int')
    prev_obs_b = vec_env.reset()
    for step_num in range(1, max(step_nums) + 1):
        obs_b, _, done_b, info_b = vec_env.step(zero_acts)
        assert len(obs_b) == N_ENVS
        assert len(done_b) == N_ENVS
        assert len(info_b) == N_ENVS

        env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums)
        for prev_obs, obs, done, info, final_step_num in env_iter:
            assert done == (step_num == final_step_num)
            if not done:
                assert 'terminal_observation' not in info
            else:
                terminal_obs = info['terminal_observation']

                # do some rough ordering checks that should work for all
                # wrappers, including VecNormalize
                assert np.all(prev_obs < terminal_obs)
                assert np.all(obs < prev_obs)

                if not isinstance(vec_env, VecNormalize):
                    # more precise tests that we can't do with VecNormalize
                    # (which changes observation values)
                    assert np.all(prev_obs + 1 == terminal_obs)
                    assert np.all(obs == 0)

        prev_obs_b = obs_b

    vec_env.close()
Example 7: test_pretrain_images
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def test_pretrain_images(tmp_path):
    env = make_atari_env("PongNoFrameskip-v4", num_env=1, seed=0)
    env = VecFrameStack(env, n_stack=4)
    model = PPO2('CnnPolicy', env)
    generate_expert_traj(model, str(tmp_path / 'expert_pong'), n_timesteps=0, n_episodes=1,
                         image_folder=str(tmp_path / 'pretrain_recorded_images'))

    expert_path = str(tmp_path / 'expert_pong.npz')
    dataset = ExpertDataset(expert_path=expert_path, traj_limitation=1, batch_size=32,
                            sequential_preprocessing=True)
    model.pretrain(dataset, n_epochs=2)

    shutil.rmtree(str(tmp_path / 'pretrain_recorded_images'))
    env.close()
    del dataset, model, env
Example 8: load_train_env
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def load_train_env(ns, state_collector, robot_radius, rew_fnc, num_stacks,
                   stack_offset, debug, task_mode, rl_mode, policy, disc_action_space, normalize):
    # Choosing environment wrapper according to the policy
    if policy == "CnnPolicy" or policy == "CnnLnLstmPolicy" or policy == "CnnLstmPolicy":
        if disc_action_space:
            env_temp = RosEnvDiscImg
        else:
            env_temp = RosEnvContImg
    elif policy in ["CNN1DPolicy", "CNN1DPolicy2", "CNN1DPolicy3"]:
        if disc_action_space:
            env_temp = RosEnvDiscRawScanPrepWp
        else:
            env_temp = RosEnvContRawScanPrepWp
    elif policy == "CNN1DPolicy_multi_input":
        if disc_action_space:
            env_temp = RosEnvDiscRaw
        else:
            env_temp = RosEnvContRaw
    elif policy == "CnnPolicy_multi_input_vel" or policy == "CnnPolicy_multi_input_vel2":
        if disc_action_space:
            env_temp = RosEnvDiscImgVel
        else:
            env_temp = RosEnvContImgVel

    env_raw = DummyVecEnv([lambda: env_temp(ns, state_collector, stack_offset, num_stacks, robot_radius,
                                            rew_fnc, debug, rl_mode, task_mode)])

    if normalize:
        env = VecNormalize(env_raw, training=True, norm_obs=True, norm_reward=False, clip_obs=100.0, clip_reward=10.0,
                           gamma=0.99, epsilon=1e-08)
    else:
        env = env_raw

    # Stack of data? (note: n_offset is not an argument of the upstream stable_baselines
    # VecFrameStack; this project appears to use a modified wrapper)
    if num_stacks > 1:
        env = VecFrameStack(env, n_stack=num_stacks, n_offset=stack_offset)

    return env
Example 9: load_train_env
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def load_train_env(num_envs, robot_radius, rew_fnc, num_stacks, stack_offset, debug, task_mode,
                   policy, disc_action_space, normalize):
    # Choosing environment wrapper according to the policy
    if policy == "CnnPolicy" or policy == "CnnLnLstmPolicy" or policy == "CnnLstmPolicy":
        if disc_action_space:
            env_temp = RosEnvDiscImg
        else:
            env_temp = RosEnvContImg
    elif policy == "CNN1DPolicy":
        if disc_action_space:
            env_temp = RosEnvDiscRawScanPrepWp
        else:
            env_temp = RosEnvContRawScanPrepWp
    elif policy == "CNN1DPolicy_multi_input":
        if disc_action_space:
            env_temp = RosEnvDiscRaw
        else:
            env_temp = RosEnvContRaw
    elif policy == "CnnPolicy_multi_input_vel" or policy == "CnnPolicy_multi_input_vel2":
        if disc_action_space:
            env_temp = RosEnvDiscImgVel
        else:
            env_temp = RosEnvContImgVel

    env = SubprocVecEnv([lambda k=k: Monitor(env_temp("sim%d" % (k+1), StateCollector("sim%s" % (k+1), "train"),
                                                      stack_offset, num_stacks, robot_radius, rew_fnc, debug,
                                                      "train", task_mode),
                                             '%s/%s/sim_%d' % (path_to_models, agent_name, k+1),
                                             allow_early_resets=True)
                         for k in range(num_envs)])

    # Normalizing?
    if normalize:
        env = VecNormalize(env, training=True, norm_obs=True, norm_reward=False, clip_obs=100.0, clip_reward=10.0,
                           gamma=0.99, epsilon=1e-08)

    # Stack of data?
    if num_stacks > 1:
        env = VecFrameStack(env, n_stack=num_stacks, n_offset=stack_offset)

    return env
Example 10: createEnvs
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def createEnvs(args, allow_early_resets=False, env_kwargs=None, load_path_normalise=None):
    """
    :param args: (argparse.Namespace Object)
    :param allow_early_resets: (bool) Allow reset before the environment is done, usually used in ES to halt the envs
    :param env_kwargs: (dict) The extra arguments for the environment
    :param load_path_normalise: (str) the path to loading the rolling average, None if not available or wanted.
    :return: (Gym VecEnv)
    """
    # imported here to prevent cyclic imports
    from environments.registry import registered_env
    from state_representation.registry import registered_srl, SRLType

    assert not (registered_env[args.env][3] is ThreadingType.NONE and args.num_cpu != 1), \
        "Error: cannot have more than 1 CPU for the environment {}".format(args.env)

    if env_kwargs is not None and registered_srl[args.srl_model][0] == SRLType.SRL:
        srl_model = MultiprocessSRLModel(args.num_cpu, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe

    envs = [makeEnv(args.env, args.seed, i, args.log_dir, allow_early_resets=allow_early_resets, env_kwargs=env_kwargs)
            for i in range(args.num_cpu)]

    if len(envs) == 1:
        # No need for subprocesses when having only one env
        envs = DummyVecEnv(envs)
    else:
        envs = SubprocVecEnv(envs)

    envs = VecFrameStack(envs, args.num_stack)

    if args.srl_model != "raw_pixels":
        printYellow("Using MLP policy because working on state representation")
        envs = VecNormalize(envs, norm_obs=True, norm_reward=False)
        envs = loadRunningAverage(envs, load_path_normalise=load_path_normalise)

    return envs
Example 11: create_env
# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def create_env(n_envs, eval_env=False):
    """
    Create the environment and wrap it if necessary

    :param n_envs: (int)
    :param eval_env: (bool) Whether it is an environment used for evaluation or not
    :return: (Union[gym.Env, VecEnv])
    """
    global hyperparams
    global env_kwargs

    # Do not log eval env (issue with writing the same file)
    log_dir = None if eval_env else save_path

    if is_atari:
        if args.verbose > 0:
            print("Using Atari wrapper")
        env = make_atari_env(env_id, num_env=n_envs, seed=args.seed)
        # Frame-stacking with 4 frames
        env = VecFrameStack(env, n_stack=4)
    elif algo_ in ['dqn', 'ddpg']:
        if hyperparams.get('normalize', False):
            print("WARNING: normalization not supported yet for DDPG/DQN")
        env = gym.make(env_id, **env_kwargs)
        env.seed(args.seed)
        if env_wrapper is not None:
            env = env_wrapper(env)
    else:
        if n_envs == 1:
            env = DummyVecEnv([make_env(env_id, 0, args.seed, wrapper_class=env_wrapper,
                                        log_dir=log_dir, env_kwargs=env_kwargs)])
        else:
            # env = SubprocVecEnv([make_env(env_id, i, args.seed) for i in range(n_envs)])
            # On most envs, SubprocVecEnv does not help and is quite memory hungry
            env = DummyVecEnv([make_env(env_id, i, args.seed, log_dir=log_dir,
                                        wrapper_class=env_wrapper, env_kwargs=env_kwargs) for i in range(n_envs)])
        if normalize:
            if args.verbose > 0:
                if len(normalize_kwargs) > 0:
                    print("Normalization activated: {}".format(normalize_kwargs))
                else:
                    print("Normalizing input and reward")
            env = VecNormalize(env, **normalize_kwargs)

    # Optional frame-stacking
    if hyperparams.get('frame_stack', False):
        n_stack = hyperparams['frame_stack']
        env = VecFrameStack(env, n_stack)
        print("Stacking {} frames".format(n_stack))
        del hyperparams['frame_stack']

    if args.algo == 'her':
        # Wrap the env if we need to flatten the dict obs
        if isinstance(env, VecEnv):
            env = _UnvecWrapper(env)
        env = HERGoalEnvWrapper(env)

    return env