本文整理汇总了Python中gym.wrappers.FlattenObservation方法的典型用法代码示例。如果您正苦于以下问题:Python wrappers.FlattenObservation方法的具体用法?Python wrappers.FlattenObservation怎么用?Python wrappers.FlattenObservation使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gym.wrappers的用法示例。
在下文中一共展示了wrappers.FlattenObservation方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: make_robotics_env
# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
    """
    Build a monitored robotics environment with a flat observation vector.

    Only the ``'observation'`` and ``'desired_goal'`` entries of the dict
    observation are kept and flattened; the result is then wrapped in a
    ``Monitor`` configured with ``info_keywords=('is_success',)``.

    :param env_id: (str) the environment ID handed to ``gym.make``
    :param seed: (int) seed for ``set_global_seeds`` and ``env.seed``
    :param rank: (int) index of this environment, used as the Monitor file name
    :param allow_early_resets: (bool) forwarded to ``Monitor``
    :return: (Gym Environment) the wrapped, monitored environment
    """
    set_global_seeds(seed)
    base_env = gym.make(env_id)
    goal_keys = ['observation', 'desired_goal']

    # TODO: drop the fallback branch once most users run a modern Gym.
    try:
        # Gym >= 0.15.4: select the keys, then flatten them.
        from gym.wrappers import FilterObservation, FlattenObservation
        flat_env = FlattenObservation(FilterObservation(base_env, goal_keys))
    except ImportError:
        # Gym <= 0.15.3 only ships FlattenDictWrapper, which takes the key list directly.
        from gym.wrappers import FlattenDictWrapper  # pytype:disable=import-error
        flat_env = FlattenDictWrapper(base_env, goal_keys)

    # A falsy log dir is passed through unchanged as the Monitor filename.
    monitor_file = logger.get_dir() and os.path.join(logger.get_dir(), str(rank))
    monitored_env = Monitor(
        flat_env,
        monitor_file,
        info_keywords=('is_success',),
        allow_early_resets=allow_early_resets,
    )
    monitored_env.seed(seed)
    return monitored_env
示例2: test_flatten_observation
# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def test_flatten_observation(env_id):
    """
    Check that FlattenObservation turns a Tuple observation into a flat Box.

    For each supported environment the expected raw observation space and the
    expected flattened space are spelled out. The observation from the raw env
    must lie in the first; the observation from the wrapper must lie in the second.

    :param env_id: (str) 'Blackjack-v0' or 'KellyCoinflip-v0'
    :raises ValueError: if no expected spaces are defined for ``env_id``
    """
    # Resolve the expected spaces before creating the env. An unsupported id
    # previously fell through and crashed with UnboundLocalError on ``space``;
    # now it fails with a clear message and without leaking an open env.
    if env_id == 'Blackjack-v0':
        space = spaces.Tuple((
            spaces.Discrete(32),
            spaces.Discrete(11),
            spaces.Discrete(2)))
        # The flattened length is the sum of the component sizes.
        wrapped_space = spaces.Box(-np.inf, np.inf,
                                   [32 + 11 + 2], dtype=np.float32)
    elif env_id == 'KellyCoinflip-v0':
        space = spaces.Tuple((
            spaces.Box(0, 250.0, [1], dtype=np.float32),
            spaces.Discrete(300 + 1)))
        wrapped_space = spaces.Box(-np.inf, np.inf,
                                   [1 + (300 + 1)], dtype=np.float32)
    else:
        raise ValueError(
            'no expected observation spaces defined for env_id %r' % (env_id,))

    env = gym.make(env_id)
    wrapped_env = FlattenObservation(env)
    try:
        obs = env.reset()
        wrapped_obs = wrapped_env.reset()
        assert space.contains(obs)
        assert wrapped_space.contains(wrapped_obs)
    finally:
        # Both objects share this single underlying environment.
        env.close()
示例3: make_robotics_env
# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo robotics tasks.

    The dict observation is restricted to its ``'observation'`` and
    ``'desired_goal'`` keys (``FilterObservation``) and then flattened
    (``FlattenObservation``).

    :param env_id: (str) the environment ID handed to ``gym.make``
    :param seed: (int) seed for ``set_global_seeds`` and ``env.seed``
    :param rank: (int) index of this environment, used as the Monitor file name
    :return: the wrapped, monitored environment
    """
    set_global_seeds(seed)
    base_env = gym.make(env_id)
    selected = FilterObservation(base_env, ['observation', 'desired_goal'])
    flat_env = FlattenObservation(selected)

    # A falsy log dir is passed through unchanged as the Monitor filename.
    monitor_file = logger.get_dir() and os.path.join(logger.get_dir(), str(rank))
    monitored_env = Monitor(flat_env, monitor_file, info_keywords=('is_success',))
    monitored_env.seed(seed)
    return monitored_env
示例4: _make_flat
# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def _make_flat(*args, **kargs):
    """
    Flatten a dict-observation env using whichever Gym API this module imported.

    Older Gym (<= 0.15.3) provides ``FlattenDictWrapper``; newer Gym replaces it
    with ``FilterObservation`` + ``FlattenObservation``. All arguments are
    forwarded unchanged to ``FlattenDictWrapper`` or ``FilterObservation``.

    :return: the flattened environment wrapper
    """
    # ``dir()`` with no argument lists only this function's *local* names
    # (``args`` and ``kargs``), so the old ``"FlattenDictWrapper" in dir()``
    # check was always False and the legacy branch could never run.
    # Look the name up in the module namespace instead.
    legacy_wrapper = globals().get("FlattenDictWrapper")
    if legacy_wrapper is not None:
        return legacy_wrapper(*args, **kargs)
    return FlattenObservation(FilterObservation(*args, **kargs))
示例5: make_env
# 需要导入模块: from gym import wrappers [as 别名]
# 或者: from gym.wrappers import FlattenObservation [as 别名]
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    """
    Create a single seeded, monitored and wrapped environment.

    :param env_id: (str) environment ID. In the form ``"module:EnvId"``,
        ``module`` (text before the first ':') is imported first and the text
        after the last ':' is used as the ID.
    :param env_type: (str) 'atari', 'retro', or anything else for plain ``gym.make``
    :param mpi_rank: (int) MPI rank, used in the Monitor file name
    :param subrank: (int) index within the rank; used in the Monitor file name
        and added to ``seed``
    :param seed: (int or None) base seed; the env is seeded with ``seed + subrank``
        (or ``None`` when ``seed`` is ``None``)
    :param reward_scale: (float) when not 1, rewards are wrapped with ``RewardScaler``
    :param gamestate: retro game state; defaults to ``retro.State.DEFAULT``
    :param flatten_dict_observations: (bool) wrap a ``Dict`` observation space
        in ``FlattenObservation``
    :param wrapper_kwargs: (dict or None) keyword args for the DeepMind-style
        Atari/Retro wrappers; the caller's dict is not modified
    :param env_kwargs: (dict or None) keyword args for ``gym.make`` (only used
        when ``env_type`` is neither 'atari' nor 'retro')
    :param logger_dir: (str or None) directory for Monitor output; when falsy,
        that falsy value is passed to ``Monitor`` as the filename
    :param initializer: (callable or None) called first as
        ``initializer(mpi_rank=..., subrank=...)``
    :return: the wrapped environment
    """
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    # Copy so the retro branch's ``frame_stack`` default below is not written
    # back into the caller's dict (which may be reused across make_env calls).
    wrapper_kwargs = dict(wrapper_kwargs or {})
    env_kwargs = env_kwargs or {}

    if ':' in env_id:
        import importlib
        # Same split as the former regexes ':.*' / '.*:': the module is the text
        # before the first ':', the env ID the text after the last ':'.
        module_name = env_id.partition(':')[0]
        env_id = env_id.rpartition(':')[2]
        # Importing the module is expected to register the env with gym.
        importlib.import_module(module_name)

    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        wrapper_kwargs.setdefault('frame_stack', 1)
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env