Python vec_env.VecFrameStack Method Code Examples

This article collects typical usage examples of the Python method stable_baselines.common.vec_env.VecFrameStack. If you are unsure what vec_env.VecFrameStack does, how to call it, or what it looks like in real code, the curated examples below should help. You can also explore further usage examples from the containing module, stable_baselines.common.vec_env.


The following section presents 11 code examples of the vec_env.VecFrameStack method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
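
Before the individual examples, here is a minimal sketch of the pattern they all share: build a vectorized environment and wrap it in VecFrameStack so that the most recent observations are stacked along the channel axis. This is only an illustrative sketch; the environment id 'PongNoFrameskip-v4', num_env=4, and n_stack=4 are assumptions for demonstration and are not taken from any particular example below.

from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack

# Create a vectorized Atari environment (id and settings chosen only for illustration)
env = make_atari_env('PongNoFrameskip-v4', num_env=4, seed=0)

# Stack the 4 most recent frames so the policy can infer motion from consecutive observations
env = VecFrameStack(env, n_stack=4)

obs = env.reset()
print(obs.shape)  # e.g. (4, 84, 84, 4): num_env x height x width x (channels * n_stack)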

Example 1: makeEnv

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
        # Even though DeepQ is single core only, we need to use the pipe system to work
        if env_kwargs is not None and env_kwargs.get("use_srl", False):
            srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
            env_kwargs["state_dim"] = srl_model.state_dim
            env_kwargs["srl_pipe"] = srl_model.pipe

        envs = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])
        envs = VecFrameStack(envs, args.num_stack)

        if args.srl_model != "raw_pixels":
            printYellow("Using MLP policy because working on state representation")
            envs = VecNormalize(envs, norm_obs=True, norm_reward=False)
            envs = loadRunningAverage(envs, load_path_normalise=load_path_normalise)

        return envs 
Developer: araffin, Project: robotics-rl-srl, Lines: 18, Source: trpo.py

Example 2: load_stable_baselines_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def load_stable_baselines_env(cfg_path, vector_length, mp, n_stack, number_maps, action_frame_repeat,
                              scaled_resolution):
    env_fn = lambda: MazeExplorer.load_vizdoom_env(cfg_path, number_maps, action_frame_repeat, scaled_resolution)

    if mp:
        env = SubprocVecEnv([env_fn for _ in range(vector_length)])
    else:
        env = DummyVecEnv([env_fn for _ in range(vector_length)])

    if n_stack > 0:
        env = VecFrameStack(env, n_stack=n_stack)

    return env 
Developer: microsoft, Project: MazeExplorer, Lines: 15, Source: evaluator.py

Example 3: train

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env):
    """
    Train A2C model for atari environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_env: (int) The number of environments
    """
    policy_fn = None
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = CnnLstmPolicy
    elif policy == 'lnlstm':
        policy_fn = CnnLnLstmPolicy
    if policy_fn is None:
        raise ValueError("Error: policy {} not implemented".format(policy))

    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)

    model = A2C(policy_fn, env, lr_schedule=lr_schedule, seed=seed)
    model.learn(total_timesteps=int(num_timesteps * 1.1))
    env.close() 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 29, Source: run_atari.py

Example 4: train

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def train(env_id, num_timesteps, seed, policy,
          n_envs=8, nminibatches=4, n_steps=128):
    """
    Train PPO2 model for atari environment, for testing purposes

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    """

    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
    model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
                 lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
    model.learn(total_timesteps=num_timesteps)

    env.close()
    # Free memory
    del model 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 28, Source: run_atari.py

Example 5: train

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_cpu):
    """
    train an ACER model on atari

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_cpu: (int) The number of cpu to train on
    """
    env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4)
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = CnnLstmPolicy
    else:
        warnings.warn("Policy {} not implemented".format(policy))
        return

    model = ACER(policy_fn, env, lr_schedule=lr_schedule, buffer_size=5000, seed=seed)
    model.learn(total_timesteps=int(num_timesteps * 1.1))
    env.close()
    # Free memory
    del model 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 28, Source: run_atari.py

Example 6: test_vecenv_terminal_obs

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
    """Test that 'terminal_observation' gets added to info dict upon
    termination."""
    step_nums = [i + 5 for i in range(N_ENVS)]
    vec_env = vec_env_class([functools.partial(StepEnv, n) for n in step_nums])

    if vec_env_wrapper is not None:
        if vec_env_wrapper == VecFrameStack:
            vec_env = vec_env_wrapper(vec_env, n_stack=2)
        else:
            vec_env = vec_env_wrapper(vec_env)

    zero_acts = np.zeros((N_ENVS,), dtype='int')
    prev_obs_b = vec_env.reset()
    for step_num in range(1, max(step_nums) + 1):
        obs_b, _, done_b, info_b = vec_env.step(zero_acts)
        assert len(obs_b) == N_ENVS
        assert len(done_b) == N_ENVS
        assert len(info_b) == N_ENVS
        env_iter = zip(prev_obs_b, obs_b, done_b, info_b, step_nums)
        for prev_obs, obs, done, info, final_step_num in env_iter:
            assert done == (step_num == final_step_num)
            if not done:
                assert 'terminal_observation' not in info
            else:
                terminal_obs = info['terminal_observation']

                # do some rough ordering checks that should work for all
                # wrappers, including VecNormalize
                assert np.all(prev_obs < terminal_obs)
                assert np.all(obs < prev_obs)

                if not isinstance(vec_env, VecNormalize):
                    # more precise tests that we can't do with VecNormalize
                    # (which changes observation values)
                    assert np.all(prev_obs + 1 == terminal_obs)
                    assert np.all(obs == 0)

        prev_obs_b = obs_b

    vec_env.close() 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 43, Source: test_vec_envs.py

Example 7: test_pretrain_images

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def test_pretrain_images(tmp_path):
    env = make_atari_env("PongNoFrameskip-v4", num_env=1, seed=0)
    env = VecFrameStack(env, n_stack=4)
    model = PPO2('CnnPolicy', env)
    generate_expert_traj(model, str(tmp_path / 'expert_pong'), n_timesteps=0, n_episodes=1,
                         image_folder=str(tmp_path / 'pretrain_recorded_images'))

    expert_path = str(tmp_path / 'expert_pong.npz')
    dataset = ExpertDataset(expert_path=expert_path, traj_limitation=1, batch_size=32,
                            sequential_preprocessing=True)
    model.pretrain(dataset, n_epochs=2)

    shutil.rmtree(str(tmp_path / 'pretrain_recorded_images'))
    env.close()
    del dataset, model, env 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 17, Source: test_gail.py

Example 8: load_train_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def load_train_env(ns, state_collector, robot_radius, rew_fnc, num_stacks,
                   stack_offset, debug, task_mode, rl_mode, policy, disc_action_space, normalize):
    # Choosing environment wrapper according to the policy
    if policy == "CnnPolicy" or policy == "CnnLnLstmPolicy" or policy == "CnnLstmPolicy":
        if disc_action_space:
            env_temp = RosEnvDiscImg
        else:
            env_temp = RosEnvContImg
    elif policy in ["CNN1DPolicy", "CNN1DPolicy2", "CNN1DPolicy3"]:
        if disc_action_space:
            env_temp = RosEnvDiscRawScanPrepWp
        else:
            env_temp = RosEnvContRawScanPrepWp
    elif policy == "CNN1DPolicy_multi_input":
        if disc_action_space:
            env_temp = RosEnvDiscRaw
        else:
            env_temp = RosEnvContRaw
    elif policy == "CnnPolicy_multi_input_vel" or policy == "CnnPolicy_multi_input_vel2":
        if disc_action_space:
            env_temp = RosEnvDiscImgVel
        else:
            env_temp = RosEnvContImgVel


    env_raw = DummyVecEnv([lambda: env_temp(ns, state_collector, stack_offset, num_stacks, robot_radius, rew_fnc, debug, rl_mode, task_mode)])

    if normalize:
        env = VecNormalize(env_raw, training=True, norm_obs=True, norm_reward=False, clip_obs=100.0, clip_reward=10.0,
                           gamma=0.99, epsilon=1e-08)
    else:
        env = env_raw

    # Stack of data?
    if num_stacks > 1:
        env = VecFrameStack(env, n_stack=num_stacks, n_offset=stack_offset)

    return env 
Developer: RGring, Project: drl_local_planner_ros_stable_baselines, Lines: 40, Source: run_ppo.py

Example 9: load_train_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def load_train_env(num_envs, robot_radius, rew_fnc, num_stacks, stack_offset, debug, task_mode, policy, disc_action_space, normalize):
    # Choosing environment wrapper according to the policy
    if policy == "CnnPolicy" or policy == "CnnLnLstmPolicy" or policy == "CnnLstmPolicy":
        if disc_action_space:
            env_temp = RosEnvDiscImg
        else:
            env_temp = RosEnvContImg
    elif policy == "CNN1DPolicy":
        if disc_action_space:
            env_temp = RosEnvDiscRawScanPrepWp
        else:
            env_temp = RosEnvContRawScanPrepWp
    elif policy == "CNN1DPolicy_multi_input":
        if disc_action_space:
            env_temp = RosEnvDiscRaw
        else:
            env_temp = RosEnvContRaw
    elif policy == "CnnPolicy_multi_input_vel" or policy == "CnnPolicy_multi_input_vel2":
        if disc_action_space:
            env_temp = RosEnvDiscImgVel
        else:
            env_temp = RosEnvContImgVel

    env = SubprocVecEnv([lambda k=k: Monitor(env_temp("sim%d" % (k+1), StateCollector("sim%s"%(k+1), "train") , stack_offset, num_stacks, robot_radius, rew_fnc, debug, "train", task_mode), '%s/%s/sim_%d'%(path_to_models, agent_name, k+1), allow_early_resets=True) for k in range(num_envs)])

    # Normalizing?
    if normalize:
        env = VecNormalize(env, training=True, norm_obs=True, norm_reward=False, clip_obs=100.0, clip_reward=10.0,
                           gamma=0.99, epsilon=1e-08)
    else:
        env = env

    # Stack of data?
    if num_stacks > 1:
        env = VecFrameStack(env, n_stack=num_stacks, n_offset=stack_offset)

    return env 
Developer: RGring, Project: drl_local_planner_ros_stable_baselines, Lines: 39, Source: train_ppo.py

Example 10: createEnvs

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def createEnvs(args, allow_early_resets=False, env_kwargs=None, load_path_normalise=None):
    """
    :param args: (argparse.Namespace Object)
    :param allow_early_resets: (bool) Allow reset before the environment is done, usually used in ES to halt the envs
    :param env_kwargs: (dict) The extra arguments for the environment
    :param load_path_normalise: (str) the path to loading the rolling average, None if not available or wanted.
    :return: (Gym VecEnv)
    """
    # imported here to prevent cyclic imports
    from environments.registry import registered_env
    from state_representation.registry import registered_srl, SRLType

    assert not (registered_env[args.env][3] is ThreadingType.NONE and args.num_cpu != 1), \
        "Error: cannot have more than 1 CPU for the environment {}".format(args.env)

    if env_kwargs is not None and registered_srl[args.srl_model][0] == SRLType.SRL:
        srl_model = MultiprocessSRLModel(args.num_cpu, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe
    envs = [makeEnv(args.env, args.seed, i, args.log_dir, allow_early_resets=allow_early_resets, env_kwargs=env_kwargs)
            for i in range(args.num_cpu)]

    if len(envs) == 1:
        # No need for subprocesses when having only one env
        envs = DummyVecEnv(envs)
    else:
        envs = SubprocVecEnv(envs)

    envs = VecFrameStack(envs, args.num_stack)

    if args.srl_model != "raw_pixels":
        printYellow("Using MLP policy because working on state representation")
        envs = VecNormalize(envs, norm_obs=True, norm_reward=False)
        envs = loadRunningAverage(envs, load_path_normalise=load_path_normalise)

    return envs 
Developer: araffin, Project: robotics-rl-srl, Lines: 38, Source: utils.py

Example 11: create_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or alternatively: from stable_baselines.common.vec_env import VecFrameStack [as alias]
def create_env(n_envs, eval_env=False):
        """
        Create the environment and wrap it if necessary
        :param n_envs: (int)
        :param eval_env: (bool) Whether is it an environment used for evaluation or not
        :return: (Union[gym.Env, VecEnv])
        :return: (gym.Env)
        """
        global hyperparams
        global env_kwargs

        # Do not log eval env (issue with writing the same file)
        log_dir = None if eval_env else save_path

        if is_atari:
            if args.verbose > 0:
                print("Using Atari wrapper")
            env = make_atari_env(env_id, num_env=n_envs, seed=args.seed)
            # Frame-stacking with 4 frames
            env = VecFrameStack(env, n_stack=4)
        elif algo_ in ['dqn', 'ddpg']:
            if hyperparams.get('normalize', False):
                print("WARNING: normalization not supported yet for DDPG/DQN")
            env = gym.make(env_id, **env_kwargs)
            env.seed(args.seed)
            if env_wrapper is not None:
                env = env_wrapper(env)
        else:
            if n_envs == 1:
                env = DummyVecEnv([make_env(env_id, 0, args.seed, wrapper_class=env_wrapper, log_dir=log_dir, env_kwargs=env_kwargs)])
            else:
                # env = SubprocVecEnv([make_env(env_id, i, args.seed) for i in range(n_envs)])
                # On most env, SubprocVecEnv does not help and is quite memory hungry
                env = DummyVecEnv([make_env(env_id, i, args.seed, log_dir=log_dir,
                                            wrapper_class=env_wrapper, env_kwargs=env_kwargs) for i in range(n_envs)])
            if normalize:
                if args.verbose > 0:
                    if len(normalize_kwargs) > 0:
                        print("Normalization activated: {}".format(normalize_kwargs))
                    else:
                        print("Normalizing input and reward")
                env = VecNormalize(env, **normalize_kwargs)
        # Optional Frame-stacking
        if hyperparams.get('frame_stack', False):
            n_stack = hyperparams['frame_stack']
            env = VecFrameStack(env, n_stack)
            print("Stacking {} frames".format(n_stack))
            del hyperparams['frame_stack']
        if args.algo == 'her':
            # Wrap the env if need to flatten the dict obs
            if isinstance(env, VecEnv):
                env = _UnvecWrapper(env)
            env = HERGoalEnvWrapper(env)
        return env 
Developer: araffin, Project: rl-baselines-zoo, Lines: 56, Source: train.py


Note: The stable_baselines.common.vec_env.VecFrameStack method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors. Please consult the corresponding project's License before distributing or using the code, and do not reproduce this article without permission.