当前位置: 首页>>代码示例>>Python>>正文


Python wrappers.FlattenObservation方法代码示例

本文整理汇总了Python中gym.wrappers.FlattenObservation方法的典型用法代码示例。如果您正苦于以下问题：Python wrappers.FlattenObservation方法的具体用法？Python wrappers.FlattenObservation怎么用？Python wrappers.FlattenObservation使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gym.wrappers的用法示例。


在下文中一共展示了wrappers.FlattenObservation方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: make_robotics_env

# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
    """
    Build a robotics gym.Env with flattened observations and monitoring.

    Only the 'observation' and 'desired_goal' entries of the observation are
    kept and flattened; the result is wrapped in a Monitor that records the
    'is_success' info keyword.

    :param env_id: (str) the environment ID passed to ``gym.make``
    :param seed: (int) seed applied globally and to the environment
    :param rank: (int) name of the monitor sub-path inside the logger directory
    :param allow_early_resets: (bool) forwarded to the Monitor wrapper
    :return: (Gym Environment) the wrapped environment
    """
    set_global_seeds(seed)
    base_env = gym.make(env_id)
    obs_keys = ['observation', 'desired_goal']
    # TODO: remove try-except once most users are running modern Gym
    try:
        # Modern Gym (>=0.15.4) composes two dedicated wrappers.
        from gym.wrappers import FilterObservation, FlattenObservation
        flat_env = FlattenObservation(FilterObservation(base_env, obs_keys))
    except ImportError:
        # Older Gym (<=0.15.3) ships a single combined wrapper instead.
        from gym.wrappers import FlattenDictWrapper  # pytype:disable=import-error
        flat_env = FlattenDictWrapper(base_env, obs_keys)

    # Falls through to the falsy get_dir() value when no logger directory is set.
    monitor_path = logger.get_dir() and os.path.join(logger.get_dir(), str(rank))
    monitored_env = Monitor(
        flat_env,
        monitor_path,
        info_keywords=('is_success',),
        allow_early_resets=allow_early_resets,
    )
    monitored_env.seed(seed)
    return monitored_env
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:27,代码来源:cmd_util.py

示例2: test_flatten_observation

# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def test_flatten_observation(env_id):
    """
    Check that FlattenObservation maps a Tuple observation into a flat Box.

    The raw observation must lie in the expected Tuple space, and the wrapped
    observation must lie in a float32 Box whose length is the sum of the
    flattened sizes of the Tuple components (n for each Discrete(n), 1 for
    the Box of shape [1]).

    :param env_id: (str) 'Blackjack-v0' or 'KellyCoinflip-v0'
    :raises ValueError: if no expected spaces are defined for ``env_id``
    """
    # Resolve the expected spaces first so an unsupported ID fails fast with a
    # clear message instead of an unbound-variable error at the asserts.
    if env_id == 'Blackjack-v0':
        space = spaces.Tuple((
            spaces.Discrete(32),
            spaces.Discrete(11),
            spaces.Discrete(2)))
        wrapped_space = spaces.Box(-np.inf, np.inf,
                                   [32 + 11 + 2], dtype=np.float32)
    elif env_id == 'KellyCoinflip-v0':
        space = spaces.Tuple((
            spaces.Box(0, 250.0, [1], dtype=np.float32),
            spaces.Discrete(300 + 1)))
        wrapped_space = spaces.Box(-np.inf, np.inf,
                                   [1 + (300 + 1)], dtype=np.float32)
    else:
        raise ValueError(
            'No expected spaces defined for env_id {!r}'.format(env_id))

    env = gym.make(env_id)
    wrapped_env = FlattenObservation(env)

    obs = env.reset()
    wrapped_obs = wrapped_env.reset()

    assert space.contains(obs)
    assert wrapped_space.contains(wrapped_obs)
开发者ID:hust512,项目名称:DQN-DDPG_Stock_Trading,代码行数:25,代码来源:test_flatten_observation.py

示例3: make_robotics_env

# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def make_robotics_env(env_id, seed, rank=0):
    """
    Build a robotics gym.Env with flattened observations and monitoring.

    Keeps only the 'observation' and 'desired_goal' entries of the
    observation, flattens them, and wraps the result in a Monitor that
    records the 'is_success' info keyword. ``rank`` names the monitor
    sub-path inside the logger directory.
    """
    set_global_seeds(seed)
    obs_keys = ['observation', 'desired_goal']
    filtered_env = FilterObservation(gym.make(env_id), obs_keys)
    flat_env = FlattenObservation(filtered_env)

    # Falls through to the falsy get_dir() value when no logger directory is set.
    monitor_path = logger.get_dir() and os.path.join(logger.get_dir(), str(rank))
    monitored_env = Monitor(flat_env, monitor_path, info_keywords=('is_success',))
    monitored_env.seed(seed)
    return monitored_env
开发者ID:openai,项目名称:baselines,代码行数:14,代码来源:cmd_util.py

示例4: _make_flat

# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def _make_flat(*args, **kargs):
    """
    Flatten an environment's observation with whichever gym wrapper is available.

    Uses ``FlattenDictWrapper`` when that name exists at module level,
    otherwise composes ``FlattenObservation(FilterObservation(...))``.
    All arguments are forwarded unchanged to the selected wrapper.
    """
    # A bare dir() inside a function lists only the local names
    # ('args', 'kargs'), so the module namespace must be checked explicitly.
    if "FlattenDictWrapper" in globals():
        return FlattenDictWrapper(*args, **kargs)
    return FlattenObservation(FilterObservation(*args, **kargs))
开发者ID:DeepX-inc,项目名称:machina,代码行数:6,代码来源:test_env.py

示例5: make_env

# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    """
    Create a single environment and apply the standard wrapper stack.

    :param env_id: (str) environment ID; a ``'module:id'`` form imports
        ``module`` first and then uses the part after the last colon as the ID
    :param env_type: (str) ``'atari'`` and ``'retro'`` get dedicated
        construction and wrapping; any other value goes through ``gym.make``
    :param mpi_rank: (int) passed to ``initializer`` and used in the monitor file name
    :param subrank: (int) passed to ``initializer``, added to ``seed`` and
        used in the monitor file name
    :param seed: (int or None) base seed; the environment is seeded with
        ``seed + subrank``, or with None when no seed is given
    :param reward_scale: (float) a RewardScaler is applied when this is not 1
    :param gamestate: retro game state; defaults to ``retro.State.DEFAULT``
        (only used when ``env_type == 'retro'``)
    :param flatten_dict_observations: (bool) flatten the observation when the
        observation space is a ``gym.spaces.Dict``
    :param wrapper_kwargs: (dict or None) keyword arguments for the atari/retro
        wrapping functions; never modified by this function
    :param env_kwargs: (dict or None) keyword arguments for ``gym.make``
    :param logger_dir: (str or None) directory for the Monitor output
    :param initializer: (callable or None) called with ``mpi_rank`` and
        ``subrank`` before the environment is created
    :return: the wrapped environment
    """
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    # Work on a copy: the retro branch below inserts a default 'frame_stack'
    # entry, which must not leak into the dict owned by the caller.
    wrapper_kwargs = dict(wrapper_kwargs) if wrapper_kwargs else {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import importlib
        # Module name is everything before the first colon; the environment
        # ID is everything after the last one.
        module_name = env_id.partition(':')[0]
        env_id = env_id.rpartition(':')[2]
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
开发者ID:openai,项目名称:baselines,代码行数:46,代码来源:cmd_util.py


注:本文中的gym.wrappers.FlattenObservation方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。