当前位置: 首页>>代码示例>>Python>>正文


Python vec_env.VecNormalize类代码示例

本文整理汇总了Python中stable_baselines.common.vec_env.VecNormalize类的典型用法代码示例。如果您正苦于以下问题:Python vec_env.VecNormalize类的具体用法?Python vec_env.VecNormalize怎么用?Python vec_env.VecNormalize使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块stable_baselines.common.vec_env的用法示例。


在下文中一共展示了vec_env.VecNormalize类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_stable_baselines

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def load_stable_baselines(cls):
    """Create a loader function for policies trained with Stable Baselines.

    Args:
        cls: The Stable Baselines model class. It is passed through to
            `load_backward_compatible_model` and named in the log message.

    Returns:
        A function `f(root_dir, env, env_name, index, transparent_params)`
        that loads the model stored under `root_dir` for the agent at
        position `index` of `env`. When normalization statistics can be
        loaded too, the model is wrapped in `NormalizeModel`; otherwise the
        bare model is returned. `env_name` and `transparent_params` are
        accepted but not used by the loader.
    """

    def f(root_dir, env, env_name, index, transparent_params):
        agent_venv = FakeSingleSpacesVec(env, agent_id=index)
        pylog.info(f"Loading Stable Baselines policy for '{cls}' from '{root_dir}'")
        policy = load_backward_compatible_model(cls, root_dir, agent_venv)

        try:
            policy = NormalizeModel(policy, load_vec_normalize(root_dir, agent_venv))
        except FileNotFoundError:
            # Nothing was saved by VecNormalize, so training ran without
            # normalization: keep the unwrapped model.
            pass

        return policy

    return f
开发者ID:HumanCompatibleAI,项目名称:adversarial-policies,代码行数:18,代码来源:loader.py

示例2: save_stable_model

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def save_stable_model(
    output_dir: str, model: BaseRLModel, vec_normalize: Optional[VecNormalize] = None,
) -> None:
    """Write a model, and optionally its normalization wrapper, to disk.

    Load later with `load_policy(..., policy_path=output_dir)`.

    Args:
        output_dir: Path to the save directory; created if missing.
        model: The stable baselines model. It is written to
            `<output_dir>/model.pkl` through its own `save` method.
        vec_normalize: Optionally, a VecNormalize to save statistics for.
            When given, the object is pickled to
            `<output_dir>/vec_normalize.pkl`. `load_policy` automatically
            applies the `NormalizePolicy` wrapper when loading.
    """
    os.makedirs(output_dir, exist_ok=True)

    model_path = os.path.join(output_dir, "model.pkl")
    model.save(model_path)

    if vec_normalize is not None:
        normalize_path = os.path.join(output_dir, "vec_normalize.pkl")
        with open(normalize_path, "wb") as f:
            pickle.dump(vec_normalize, f)

    tf.logging.info("Saved policy to %s", output_dir)
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:22,代码来源:serialize.py

示例3: sample

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def sample(self, batch_size: int, env: Optional[VecNormalize] = None, **_kwargs):
    """
    Draw a random batch of stored transitions (indices are drawn with replacement).

    :param batch_size: (int) How many transitions to sample.
    :param env: (Optional[VecNormalize]) associated gym VecEnv
        to normalize the observations/rewards when sampling; it is
        forwarded unchanged to ``_encode_sample``
    :param _kwargs: extra keyword arguments are accepted and ignored
    :return: the result of ``_encode_sample`` for the drawn indices:
        - obs_batch: (np.ndarray) batch of observations
        - act_batch: (numpy float) batch of actions executed given obs_batch
        - rew_batch: (numpy float) rewards received as results of executing act_batch
        - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
        - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
            and 0 otherwise.
    """
    # random.randint is inclusive on both ends, hence the "- 1".
    last_index = len(self._storage) - 1
    drawn = [random.randint(0, last_index) for _ in range(batch_size)]
    return self._encode_sample(drawn, env=env)
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:19,代码来源:buffers.py

示例4: makeEnv

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
    """
    Build the single-environment, frame-stacked vectorized env for training.

    :param args: namespace read for ``env``, ``seed``, ``log_dir``,
        ``num_stack`` and ``srl_model``
    :param env_kwargs: (dict or None) extra environment arguments; when its
        ``"use_srl"`` entry is truthy, ``"state_dim"`` and ``"srl_pipe"``
        are written into this same dict
    :param load_path_normalise: forwarded to ``loadRunningAverage``
    :return: the wrapped vectorized environment
    """
    # The SRL model is reached through the pipe system even though only one
    # environment is created (the original note mentions DeepQ being single core).
    wants_srl = env_kwargs is not None and env_kwargs.get("use_srl", False)
    if wants_srl:
        srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe

    # NOTE(review): the inner call is expected to resolve to the module-level
    # `makeEnv` factory rather than to this method - confirm in the full file.
    single_env = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])
    stacked = VecFrameStack(single_env, args.num_stack)

    if args.srl_model == "raw_pixels":
        return stacked

    printYellow("Using MLP policy because working on state representation")
    normalized = VecNormalize(stacked, norm_obs=True, norm_reward=False)
    return loadRunningAverage(normalized, load_path_normalise=load_path_normalise)
开发者ID:araffin,项目名称:robotics-rl-srl,代码行数:18,代码来源:trpo.py

示例5: __init__

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def __init__(
        self,
        model: stable_baselines.common.base_class.BaseRLModel,
        vec_normalize: vec_env.VecNormalize,
    ) -> None:
        """Wrap `model` and keep a reference to its normalization wrapper.

        Args:
            model: Forwarded unchanged to the parent initializer as `model`.
            vec_normalize: Stored on the instance as `self.vec_normalize`.
        """
        super().__init__(model=model)
        self.vec_normalize = vec_normalize 
开发者ID:HumanCompatibleAI,项目名称:adversarial-policies,代码行数:9,代码来源:loader.py

示例6: load_vec_normalize

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def load_vec_normalize(root_dir: str, venv: vec_env.VecEnv) -> vec_env.VecNormalize:
    """Restore normalization statistics saved under `root_dir`.

    Two on-disk layouts are tried in order:

    1. `<root_dir>/vec_normalize.pkl`, read with `VecNormalize.load`.
    2. If that raises `FileNotFoundError`, a new `VecNormalize` is created
       around `venv` and filled by `load_running_average(root_dir)`.

    Args:
        root_dir: Directory holding the saved statistics.
        venv: Vectorized environment the returned wrapper is attached to.

    Returns:
        A `VecNormalize` with training disabled (`training` is False).

    Raises:
        Any exception raised by the fallback path propagates unchanged.
    """
    normalize_path = os.path.join(root_dir, "vec_normalize.pkl")
    try:
        restored = vec_env.VecNormalize.load(normalize_path, venv)
        restored.training = False
        pylog.info(f"Loaded normalization statistics from '{normalize_path}'")
        return restored
    except FileNotFoundError:
        # No vec_normalize.pkl: fall through to the old-style layout below.
        pass

    legacy = vec_env.VecNormalize(venv, training=False)
    legacy.load_running_average(root_dir)
    pylog.info(f"Loaded normalization statistics from '{root_dir}'")
    return legacy
开发者ID:HumanCompatibleAI,项目名称:adversarial-policies,代码行数:17,代码来源:loader.py

示例7: __init__

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def __init__(self, policy: BasePolicy, vec_normalize: VecNormalize):
    """Wrap `policy` together with a `VecNormalize`.

    Args:
        policy: The policy being wrapped. Its `sess`, `ob_space`,
            `ac_space`, `n_env`, `n_steps` and `n_batch` attributes are
            passed positionally, in that order, to the parent initializer;
            the object itself is kept as `self._policy`.
        vec_normalize: Stored as `self.vec_normalize`.
    """
    # Configure the parent with the same settings as the wrapped policy.
    parent_args = (
        policy.sess,
        policy.ob_space,
        policy.ac_space,
        policy.n_env,
        policy.n_steps,
        policy.n_batch,
    )
    super().__init__(*parent_args)
    self._policy = policy
    self.vec_normalize = vec_normalize
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:13,代码来源:serialize.py

示例8: _reward_fn_normalize_inputs

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def _reward_fn_normalize_inputs(
    obs: np.ndarray,
    acts: np.ndarray,
    next_obs: np.ndarray,
    dones: np.ndarray,
    *,
    reward_fn: RewardFn,
    vec_normalize: vec_env.VecNormalize,
    norm_reward: bool = True,
) -> np.ndarray:
    """Evaluate `reward_fn` on normalized observations.

    Bind the keyword-only arguments with `functools.partial` to obtain an
    input-normalizing RewardFn.

    Args:
        obs: Observations; normalized before being passed to `reward_fn`.
        acts: Actions; passed to `reward_fn` unchanged.
        next_obs: Next observations; normalized like `obs`.
        dones: Done flags; passed to `reward_fn` unchanged.
        reward_fn: The reward function that normalized inputs are evaluated on.
        vec_normalize: Instance of VecNormalize used to normalize inputs and
            rewards.
        norm_reward: If True, then also normalize reward before returning.

    Returns:
        The possibly normalized reward.
    """
    rew = reward_fn(
        vec_normalize.normalize_obs(obs),
        acts,
        vec_normalize.normalize_obs(next_obs),
        dones,
    )
    return vec_normalize.normalize_reward(rew) if norm_reward else rew
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:29,代码来源:common.py

示例9: _normalize_obs

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def _normalize_obs(obs: np.ndarray,
                   env: Optional[VecNormalize] = None) -> np.ndarray:
    """
    Normalize observations with ``env`` when one is supplied.

    :param obs: (np.ndarray) observations to normalize
    :param env: (Optional[VecNormalize]) wrapper whose ``normalize_obs``
        is applied; ``None`` means no normalization
    :return: (np.ndarray) ``env.normalize_obs(obs)``, or ``obs`` itself
        when ``env`` is ``None``
    """
    if env is None:
        return obs
    return env.normalize_obs(obs)
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:10,代码来源:buffers.py

示例10: _normalize_reward

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def _normalize_reward(reward: np.ndarray,
                      env: Optional[VecNormalize] = None) -> np.ndarray:
    """
    Normalize rewards with ``env`` when one is supplied.

    :param reward: (np.ndarray) rewards to normalize
    :param env: (Optional[VecNormalize]) wrapper whose ``normalize_reward``
        is applied; ``None`` means no normalization
    :return: (np.ndarray) ``env.normalize_reward(reward)``, or ``reward``
        itself when ``env`` is ``None``
    """
    return reward if env is None else env.normalize_reward(reward)
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:10,代码来源:buffers.py

示例11: _encode_sample

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def _encode_sample(self, idxes: Union[List[int], np.ndarray], env: Optional[VecNormalize] = None):
    """
    Gather the stored transitions at ``idxes`` into batched arrays.

    :param idxes: (List[int] or np.ndarray) indices into ``self._storage``;
        repeated indices are allowed
    :param env: (Optional[VecNormalize]) passed to ``self._normalize_obs``
        and ``self._normalize_reward`` for the observation and reward batches
    :return: tuple ``(obs, actions, rewards, next_obs, dones)`` of arrays,
        each with one row per entry of ``idxes``
    """
    obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
    for i in idxes:
        obs_t, action, reward, obs_tp1, done = self._storage[i]
        # np.asarray reuses existing arrays and converts anything else.
        # The former np.array(..., copy=False) raises ValueError on
        # NumPy >= 2.0 whenever a copy cannot be avoided (lists, scalars).
        obses_t.append(np.asarray(obs_t))
        actions.append(np.asarray(action))
        rewards.append(reward)
        obses_tp1.append(np.asarray(obs_tp1))
        dones.append(done)
    return (self._normalize_obs(np.array(obses_t), env),
            np.array(actions),
            self._normalize_reward(np.array(rewards), env),
            self._normalize_obs(np.array(obses_tp1), env),
            np.array(dones))
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:17,代码来源:buffers.py

示例12: load_train_env

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def load_train_env(ns, state_collector, robot_radius, rew_fnc, num_stacks,
                   stack_offset, debug, task_mode, rl_mode, policy, disc_action_space, normalize):
    """
    Build the single-process vectorized environment for the given policy.

    :param ns: first argument handed to the environment class
    :param state_collector: second argument handed to the environment class
    :param robot_radius: forwarded to the environment class
    :param rew_fnc: forwarded to the environment class
    :param num_stacks: (int) forwarded to the environment class; frames are
        additionally stacked with ``VecFrameStack`` when it is greater than 1
    :param stack_offset: forwarded to the environment class and used as
        ``n_offset`` of ``VecFrameStack``
    :param debug: forwarded to the environment class
    :param task_mode: forwarded to the environment class
    :param rl_mode: forwarded to the environment class
    :param policy: (str) policy name; selects the environment class
    :param disc_action_space: (bool) choose the discrete-action variant of
        the environment class when true, the continuous one otherwise
    :param normalize: (bool) wrap the environment in ``VecNormalize``
        (observations only) when true
    :return: the wrapped vectorized environment
    :raises ValueError: if ``policy`` is not one of the supported names
    """
    # Choosing environment wrapper according to the policy
    if policy in ("CnnPolicy", "CnnLnLstmPolicy", "CnnLstmPolicy"):
        env_temp = RosEnvDiscImg if disc_action_space else RosEnvContImg
    elif policy in ("CNN1DPolicy", "CNN1DPolicy2", "CNN1DPolicy3"):
        env_temp = RosEnvDiscRawScanPrepWp if disc_action_space else RosEnvContRawScanPrepWp
    elif policy == "CNN1DPolicy_multi_input":
        env_temp = RosEnvDiscRaw if disc_action_space else RosEnvContRaw
    elif policy in ("CnnPolicy_multi_input_vel", "CnnPolicy_multi_input_vel2"):
        env_temp = RosEnvDiscImgVel if disc_action_space else RosEnvContImgVel
    else:
        # Without this branch env_temp stayed unbound and the failure only
        # surfaced later as a NameError inside the env factory lambda.
        raise ValueError(f"Unsupported policy {policy!r}: no matching environment wrapper")

    env_raw = DummyVecEnv([
        lambda: env_temp(ns, state_collector, stack_offset, num_stacks, robot_radius,
                         rew_fnc, debug, rl_mode, task_mode)
    ])

    if normalize:
        env = VecNormalize(env_raw, training=True, norm_obs=True, norm_reward=False,
                           clip_obs=100.0, clip_reward=10.0, gamma=0.99, epsilon=1e-08)
    else:
        env = env_raw

    # Stack of data?
    if num_stacks > 1:
        env = VecFrameStack(env, n_stack=num_stacks, n_offset=stack_offset)

    return env
开发者ID:RGring,项目名称:drl_local_planner_ros_stable_baselines,代码行数:40,代码来源:run_ppo.py

示例13: load_train_env

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def load_train_env(num_envs, robot_radius, rew_fnc, num_stacks, stack_offset, debug, task_mode, policy, disc_action_space, normalize):
    """
    Build the multi-process vectorized training environment for the given policy.

    One monitored environment is created per simulation ``k`` in
    ``range(num_envs)``; it is named ``"sim<k+1>"`` and logs to
    ``<path_to_models>/<agent_name>/sim_<k+1>``.

    :param num_envs: (int) number of sub-process environments
    :param robot_radius: forwarded to the environment class
    :param rew_fnc: forwarded to the environment class
    :param num_stacks: (int) forwarded to the environment class; frames are
        additionally stacked with ``VecFrameStack`` when it is greater than 1
    :param stack_offset: forwarded to the environment class and used as
        ``n_offset`` of ``VecFrameStack``
    :param debug: forwarded to the environment class
    :param task_mode: forwarded to the environment class
    :param policy: (str) policy name; selects the environment class
    :param disc_action_space: (bool) choose the discrete-action variant of
        the environment class when true, the continuous one otherwise
    :param normalize: (bool) wrap the environment in ``VecNormalize``
        (observations only) when true
    :return: the wrapped vectorized environment
    :raises ValueError: if ``policy`` is not one of the supported names
    """
    # Choosing environment wrapper according to the policy
    if policy in ("CnnPolicy", "CnnLnLstmPolicy", "CnnLstmPolicy"):
        env_temp = RosEnvDiscImg if disc_action_space else RosEnvContImg
    elif policy == "CNN1DPolicy":
        env_temp = RosEnvDiscRawScanPrepWp if disc_action_space else RosEnvContRawScanPrepWp
    elif policy == "CNN1DPolicy_multi_input":
        env_temp = RosEnvDiscRaw if disc_action_space else RosEnvContRaw
    elif policy in ("CnnPolicy_multi_input_vel", "CnnPolicy_multi_input_vel2"):
        env_temp = RosEnvDiscImgVel if disc_action_space else RosEnvContImgVel
    else:
        # Without this branch env_temp stayed unbound and the failure only
        # surfaced later as a NameError inside the env factory lambda.
        raise ValueError(f"Unsupported policy {policy!r}: no matching environment wrapper")

    # k is bound as a default argument so each factory keeps its own index.
    # NOTE(review): path_to_models and agent_name are not parameters; they
    # are presumably module-level globals - confirm in the full file.
    env = SubprocVecEnv([
        lambda k=k: Monitor(
            env_temp("sim%d" % (k+1), StateCollector("sim%s"%(k+1), "train"),
                     stack_offset, num_stacks, robot_radius, rew_fnc, debug, "train", task_mode),
            '%s/%s/sim_%d'%(path_to_models, agent_name, k+1),
            allow_early_resets=True,
        )
        for k in range(num_envs)
    ])

    # Normalizing?
    if normalize:
        env = VecNormalize(env, training=True, norm_obs=True, norm_reward=False,
                           clip_obs=100.0, clip_reward=10.0, gamma=0.99, epsilon=1e-08)

    # Stack of data?
    if num_stacks > 1:
        env = VecFrameStack(env, n_stack=num_stacks, n_offset=stack_offset)

    return env
开发者ID:RGring,项目名称:drl_local_planner_ros_stable_baselines,代码行数:39,代码来源:train_ppo.py

示例14: get_original_obs

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def get_original_obs(self):
    """
    Pass-through so this wrapper can be used with VecNormalize.

    :return: whatever ``self.venv.get_original_obs()`` returns
        (documented upstream as numpy float)
    """
    wrapped_env = self.venv
    return wrapped_env.get_original_obs()
开发者ID:araffin,项目名称:robotics-rl-srl,代码行数:8,代码来源:utils.py

示例15: saveRunningAverage

# 需要导入模块: from stable_baselines.common import vec_env [as 别名]
# 或者: from stable_baselines.common.vec_env import VecNormalize [as 别名]
def saveRunningAverage(self, path):
    """
    Pass-through so this wrapper can be used with VecNormalize.

    Calls ``save_running_average(path)`` on the wrapped ``self.venv``;
    nothing is returned.

    :param path: (str) path to log dir
    """
    wrapped_env = self.venv
    wrapped_env.save_running_average(path)
开发者ID:araffin,项目名称:robotics-rl-srl,代码行数:8,代码来源:utils.py


注:本文中的stable_baselines.common.vec_env.VecNormalize方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。