

Python vec_env.DummyVecEnv Method Code Examples

This article collects typical usage examples of the Python method stable_baselines.common.vec_env.DummyVecEnv. If you are unsure what vec_env.DummyVecEnv does, how to call it, or are simply looking for concrete usage examples, the curated snippets below should help. You can also explore further usage examples from the stable_baselines.common.vec_env module.


The following shows 15 code examples of the vec_env.DummyVecEnv method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the site recommend better Python code examples.
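Before the project-specific examples below, here is a minimal sketch of how a DummyVecEnv is typically constructed and stepped. The snippet is illustrative only and is not taken from any of the projects listed; it assumes the standard Gym CartPole-v1 environment and the PPO2 algorithm from stable_baselines:

# Illustrative sketch, not taken from the projects listed below.
import gym
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

# DummyVecEnv runs one or more environments sequentially in the current
# process, exposing the VecEnv interface expected by the algorithms.
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])

model = PPO2('MlpPolicy', env, verbose=0)
model.learn(total_timesteps=1000)

obs = env.reset()                        # batched: shape (num_envs, *obs_shape)
action, _states = model.predict(obs)
obs, rewards, dones, infos = env.step(action)
env.close()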

Example 1: initialize_optuna

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def initialize_optuna(self):
        try:
            train_env = DummyVecEnv([lambda: TradingEnv(self.data_provider)])
            model = self.Model(self.Policy, train_env, nminibatches=1)
            strategy = self.Reward_Strategy()

            self.study_name = f'{model.__class__.__name__}__{model.act_model.__class__.__name__}__{strategy.__class__.__name__}'
        except Exception:
            self.study_name = 'UnknownModel__UnknownPolicy__UnknownStrategy'

        self.optuna_study = optuna.create_study(
            study_name=self.study_name, storage=self.params_db_path, load_if_exists=True)

        self.logger.debug('Initialized Optuna:')

        try:
            self.logger.debug(
                f'Best reward in ({len(self.optuna_study.trials)}) trials: {self.optuna_study.best_value}')
        except Exception:
            self.logger.debug('No trials have been finished yet.') 
Developer: notadamking | Project: RLTrader | Lines: 22 | Source file: RLTrader.py

Example 2: test_rollout_stats

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_rollout_stats():
    """Applying `ObsRewIncrementWrapper` halves the reward mean.

    `rollout_stats` should reflect this.
    """
    env = gym.make("CartPole-v1")
    env = bench.Monitor(env, None)
    env = ObsRewHalveWrapper(env)
    venv = vec_env.DummyVecEnv([lambda: env])

    with serialize.load_policy("zero", "UNUSED", venv) as policy:
        trajs = rollout.generate_trajectories(policy, venv, rollout.min_episodes(10))
    s = rollout.rollout_stats(trajs)

    np.testing.assert_allclose(s["return_mean"], s["monitor_return_mean"] / 2)
    np.testing.assert_allclose(s["return_std"], s["monitor_return_std"] / 2)
    np.testing.assert_allclose(s["return_min"], s["monitor_return_min"] / 2)
    np.testing.assert_allclose(s["return_max"], s["monitor_return_max"] / 2) 
Developer: HumanCompatibleAI | Project: imitation | Lines: 20 | Source file: test_rollout.py
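The ObsRewHalveWrapper used in this example and the next is defined elsewhere in the test module and does not appear in the excerpt. A minimal sketch of such a wrapper, inferred purely from the docstring and assertions and shown here only for illustration, might look like this:

# Hypothetical sketch (not part of the original test file): a gym.Wrapper
# that halves both observations and rewards.
import gym


class ObsRewHalveWrapper(gym.Wrapper):
    """Halve observations and rewards relative to the wrapped environment."""

    def reset(self, **kwargs):
        return self.env.reset(**kwargs) / 2

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        return obs / 2, rew / 2, done, info

Because bench.Monitor sits inside this wrapper, it records the original (unhalved) returns, which is why the test above expects each return_* statistic to be half of its monitor_return_* counterpart.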

Example 3: test_unwrap_traj

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_unwrap_traj():
    """Check that unwrap_traj reverses `ObsRewIncrementWrapper`.

    Also check that unwrapping twice is a no-op.
    """
    env = gym.make("CartPole-v1")
    env = wrappers.RolloutInfoWrapper(env)
    env = ObsRewHalveWrapper(env)
    venv = vec_env.DummyVecEnv([lambda: env])

    with serialize.load_policy("zero", "UNUSED", venv) as policy:
        trajs = rollout.generate_trajectories(policy, venv, rollout.min_episodes(10))
    trajs_unwrapped = [rollout.unwrap_traj(t) for t in trajs]
    trajs_unwrapped_twice = [rollout.unwrap_traj(t) for t in trajs_unwrapped]

    for t, t_unwrapped in zip(trajs, trajs_unwrapped):
        np.testing.assert_allclose(t.acts, t_unwrapped.acts)
        np.testing.assert_allclose(t.obs, t_unwrapped.obs / 2)
        np.testing.assert_allclose(t.rews, t_unwrapped.rews / 2)

    for t1, t2 in zip(trajs_unwrapped, trajs_unwrapped_twice):
        np.testing.assert_equal(t1.acts, t2.acts)
        np.testing.assert_equal(t1.obs, t2.obs)
        np.testing.assert_equal(t1.rews, t2.rews) 
Developer: HumanCompatibleAI | Project: imitation | Lines: 26 | Source file: test_rollout.py

Example 4: test_identity

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_identity(model_name):
    """
    Test if the algorithm (with a given policy)
    can learn an identity transformation (i.e. return observation as an action)

    :param model_name: (str) Name of the RL model
    """
    env = DummyVecEnv([lambda: IdentityEnv(10)])

    model = LEARN_FUNC_DICT[model_name](env)
    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)

    obs = env.reset()
    assert model.action_probability(obs).shape == (1, 10), "Error: action_probability not returning correct shape"
    action = env.action_space.sample()
    action_prob = model.action_probability(obs, actions=action)
    assert np.prod(action_prob.shape) == 1, "Error: not scalar probability"
    action_logprob = model.action_probability(obs, actions=action, logp=True)
    assert np.allclose(action_prob, np.exp(action_logprob)), (action_prob, action_logprob)

    # Free memory
    del model, env 
Developer: Stable-Baselines-Team | Project: stable-baselines | Lines: 24 | Source file: test_identity.py

Example 5: test_model_multiple_learn_no_reset

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_model_multiple_learn_no_reset(model_class):
    """Check that when we call learn multiple times, we don't unnecessarily
    reset the environment.
    """
    if model_class is ACER:
        def make_env():
            return IdentityEnv(ep_length=1e10, dim=2)
    else:
        def make_env():
            return IdentityEnvBox(ep_length=1e10)
    env = make_env()
    venv = DummyVecEnv([lambda: env])
    model = model_class(policy="MlpPolicy", env=venv)
    _check_reset_count(model, env)

    # Try again following a `set_env`.
    env = make_env()
    venv = DummyVecEnv([lambda: env])
    assert env.num_resets == 0

    model.set_env(venv)
    _check_reset_count(model, env) 
Developer: Stable-Baselines-Team | Project: stable-baselines | Lines: 24 | Source file: test_multiple_learn.py

Example 6: test_vecenv_wrapper_getattr

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_vecenv_wrapper_getattr():
    def make_env():
        return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
    vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
    wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
    assert wrapped.var_a == 'a'
    assert wrapped.var_b == 'b'
    assert wrapped.var_bb == 'bb'
    assert wrapped.func_b() == 'b'
    assert wrapped.name_test() == CustomWrapperBB

    double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
    dummy = double_wrapped.var_a  # should not raise as it is directly defined here
    with pytest.raises(AttributeError):  # should raise due to ambiguity
        dummy = double_wrapped.var_b
    with pytest.raises(AttributeError):  # should raise as does not exist
        dummy = double_wrapped.nonexistent_attribute
    del dummy  # keep linter happy 
Developer: Stable-Baselines-Team | Project: stable-baselines | Lines: 20 | Source file: test_vec_envs.py

Example 7: test_identity_multidiscrete

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_identity_multidiscrete(model_class):
    """
    Test if the algorithm (with a given policy)
    can learn an identity transformation (i.e. return observation as an action)
    with a multidiscrete action space

    :param model_class: (BaseRLModel) A RL Model
    """
    env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(10)])

    model = model_class("MlpPolicy", env)
    model.learn(total_timesteps=1000)
    evaluate_policy(model, env, n_eval_episodes=5)
    obs = env.reset()

    assert np.array(model.action_probability(obs)).shape == (2, 1, 10), \
        "Error: action_probability not returning correct shape"
    assert np.prod(model.action_probability(obs, actions=env.action_space.sample()).shape) == 1, \
        "Error: not scalar probability" 
Developer: Stable-Baselines-Team | Project: stable-baselines | Lines: 21 | Source file: test_action_space.py

Example 8: test_make_vec_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
    env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls,
                       wrapper_class=wrapper_class, monitor_dir=None, seed=0)

    assert env.num_envs == n_envs

    if vec_env_cls is None:
        assert isinstance(env, DummyVecEnv)
        if wrapper_class is not None:
            assert isinstance(env.envs[0], wrapper_class)
        else:
            assert isinstance(env.envs[0], Monitor)
    else:
        assert isinstance(env, SubprocVecEnv)
    # Kill subprocesses
    env.close() 
Developer: Stable-Baselines-Team | Project: stable-baselines | Lines: 18 | Source file: test_utils.py

Example 9: test_custom_vec_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def test_custom_vec_env():
    """
    Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
    """
    monitor_dir = 'logs/test_make_vec_env/'
    env = make_vec_env('CartPole-v1', n_envs=1,
                       monitor_dir=monitor_dir, seed=0,
                       vec_env_cls=SubprocVecEnv, vec_env_kwargs={'start_method': None})

    assert env.num_envs == 1
    assert isinstance(env, SubprocVecEnv)
    assert os.path.isdir('logs/test_make_vec_env/')
    # Kill subprocess
    env.close()
    # Cleanup folder
    shutil.rmtree(monitor_dir)

    # This should fail because DummyVecEnv does not have any keyword argument
    with pytest.raises(TypeError):
        make_vec_env('CartPole-v1', n_envs=1, vec_env_kwargs={'dummy': False}) 
Developer: Stable-Baselines-Team | Project: stable-baselines | Lines: 22 | Source file: test_utils.py

Example 10: main

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def main(args):
	env_id = 'fwmav_hover-v1'

	env = DummyVecEnv([make_env(env_id, 0, random_init = args.rand_init, randomize_sim = args.rand_dynamics, phantom_sensor = args.phantom_sensor)])

	model = LazyModel(env.envs[0],args.model_type)

	obs = env.reset()

	while True:
		if env.envs[0].is_sim_on == False:
			env.envs[0].gui.cv.wait()
		elif env.envs[0].is_sim_on:
			action, _ = model.predict(obs)
			obs, rewards, done, info = env.step(action)
			if done:
				obs = env.reset() 
Developer: purdue-biorobotics | Project: flappy | Lines: 19 | Source file: test_simple.py

Example 11: makeEnv

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
        # Even though DeepQ is single core only, we need to use the pipe system to work
        if env_kwargs is not None and env_kwargs.get("use_srl", False):
            srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
            env_kwargs["state_dim"] = srl_model.state_dim
            env_kwargs["srl_pipe"] = srl_model.pipe

        envs = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])
        envs = VecFrameStack(envs, args.num_stack)

        if args.srl_model != "raw_pixels":
            printYellow("Using MLP policy because working on state representation")
            envs = VecNormalize(envs, norm_obs=True, norm_reward=False)
            envs = loadRunningAverage(envs, load_path_normalise=load_path_normalise)

        return envs 
Developer: araffin | Project: robotics-rl-srl | Lines: 18 | Source file: trpo.py

Example 12: create_simple_policy_wrapper

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def create_simple_policy_wrapper(env_name, num_envs, state_shapes):
    vec_env = DummyVecEnv([lambda: gym.make(env_name) for _ in range(num_envs)])
    num_actions = vec_env.action_space.n  # for Discrete spaces

    policies = []
    for i, state_shape in enumerate(state_shapes):
        constant_value = np.full(shape=vec_env.action_space.shape, fill_value=i % num_actions)
        policy = _get_constant_policy(
            vec_env, constant_value=constant_value, state_shape=state_shape
        )
        policies.append(policy)
    policy_wrapper = MultiPolicyWrapper(policies=policies, num_envs=num_envs)

    yield vec_env, policy_wrapper
    policy_wrapper.close() 
Developer: HumanCompatibleAI | Project: adversarial-policies | Lines: 17 | Source file: test_wrappers.py

Example 13: load_stable_baselines_env

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def load_stable_baselines_env(cfg_path, vector_length, mp, n_stack, number_maps, action_frame_repeat,
                              scaled_resolution):
    env_fn = lambda: MazeExplorer.load_vizdoom_env(cfg_path, number_maps, action_frame_repeat, scaled_resolution)

    if mp:
        env = SubprocVecEnv([env_fn for _ in range(vector_length)])
    else:
        env = DummyVecEnv([env_fn for _ in range(vector_length)])

    if n_stack > 0:
        env = VecFrameStack(env, n_stack=n_stack)

    return env 
Developer: microsoft | Project: MazeExplorer | Lines: 15 | Source file: evaluator.py

Example 14: _sample_fixed_length_trajectories

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def _sample_fixed_length_trajectories(
    episode_lengths: Sequence[int], min_episodes: int, **kwargs,
) -> Sequence[types.Trajectory]:
    venv = vec_env.DummyVecEnv(
        [functools.partial(TerminalSentinelEnv, length) for length in episode_lengths]
    )
    policy = RandomPolicy(venv.observation_space, venv.action_space)
    sample_until = rollout.min_episodes(min_episodes)
    trajectories = rollout.generate_trajectories(
        policy, venv, sample_until=sample_until, **kwargs,
    )
    return trajectories 
Developer: HumanCompatibleAI | Project: imitation | Lines: 14 | Source file: test_rollout.py

Example 15: _make_buffering_venv

# Required import: from stable_baselines.common import vec_env [as alias]
# Or: from stable_baselines.common.vec_env import DummyVecEnv [as alias]
def _make_buffering_venv(error_on_premature_reset: bool,) -> BufferingWrapper:
    venv = DummyVecEnv([_CountingEnv] * 2)
    venv = BufferingWrapper(venv, error_on_premature_reset)
    venv.reset()
    return venv 
Developer: HumanCompatibleAI | Project: imitation | Lines: 7 | Source file: test_buffering_wrapper.py


Note: The stable_baselines.common.vec_env.DummyVecEnv examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are excerpted from open-source projects contributed by their respective authors, who retain copyright over the source code; please consult each project's License before using or redistributing the code, and do not republish without permission.