

Python gym.envs code examples

This article collects typical usage examples of gym.envs in Python, gathered from open-source projects. If you are wondering what gym.envs is, how to use it, or what real-world usage looks like, the curated code examples below should help. You can also explore other usage examples from the gym package.


The following 15 code examples all use gym.envs; by default they are ordered by popularity.
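
For context, gym.envs is the subpackage that backs gym's environment registry: every environment id maps to an EnvSpec that gym.make resolves to a concrete class. The minimal sketch below assumes a classic (pre-0.26) gym release, where the registry still exposes .all() and step() returns a 4-tuple.

import gym

# Print a few registered environment ids from the registry.
for spec in list(gym.envs.registry.all())[:5]:
    print(spec.id)

# Instantiate one environment and take a single random step.
env = gym.make('CartPole-v1')
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
env.close()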

Example 1: make_vec_envs

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def make_vec_envs(env_name, seed, num_processes, num_frame_stack=1, downsample=True, color=False, gamma=0.99, log_dir='./tmp/', device=torch.device('cpu')):
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    envs = [make_env(env_name, seed, i, log_dir, downsample, color)
            for i in range(num_processes)]

    if len(envs) > 1:
        envs = SubprocVecEnv(envs, context='fork')
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        if gamma is None:
            envs = VecNormalize(envs, ret=False)
        else:
            envs = VecNormalize(envs, gamma=gamma)

    envs = VecPyTorch(envs, device)

    if num_frame_stack > 1:
        envs = VecPyTorchFrameStack(envs, num_frame_stack, device)

    return envs 
Developer: mila-iqia, Project: atari-representation-learning, Lines: 24, Source: envs.py
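
A hypothetical call site for make_vec_envs; the import path, environment id, and argument values below are illustrative assumptions rather than code from the repository.

import torch
from envs import make_vec_envs  # assumed module path, matching the source file name above

# Build four parallel Pong workers; observations come back batched as a
# torch tensor on the requested device.
envs = make_vec_envs('PongNoFrameskip-v4', seed=42, num_processes=4,
                     device=torch.device('cpu'))
obs = envs.reset()
envs.close()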

Example 2: should_skip_env_spec_for_tests

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def should_skip_env_spec_for_tests(spec):
    # We skip tests for envs that require dependencies or are otherwise
    # troublesome to run frequently
    ep = spec._entry_point
    # Skip mujoco tests for pull request CI
    skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
    if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
        return True
    if (    'GoEnv' in ep or
            'HexEnv' in ep or
            ep.startswith('gym.envs.box2d:') or
            (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
    ):
        logger.warn("Skipping tests for env {}".format(ep))
        return True
    return False 
Developer: ArztSamuel, Project: DRL_DeliveryDuel, Lines: 19, Source: spec_list.py
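
A sketch of how such a predicate is usually consumed: filter the registry down to the specs a test run should actually exercise (this assumes the same pre-0.26 registry API and the private _entry_point attribute used above).

from gym import envs

# Keep only registered specs that are not skipped by the predicate above.
spec_list = [spec for spec in envs.registry.all()
             if spec._entry_point is not None
             and not should_skip_env_spec_for_tests(spec)]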

Example 3: add_new_rollouts

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def add_new_rollouts(spec_ids, overwrite):
    environments = [spec for spec in envs.registry.all() if spec._entry_point is not None]
    if spec_ids:
        environments = [spec for spec in environments if spec.id in spec_ids]
        assert len(environments) == len(spec_ids), "Some specs not found"
    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)
    modified = False
    for spec in environments:
        if not overwrite and spec.id in rollout_dict:
            logger.debug("Rollout already exists for {}. Skipping.".format(spec.id))
        else:
            modified = update_rollout_dict(spec, rollout_dict) or modified

    if modified:
        logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
        with open(ROLLOUT_FILE, "w") as outfile:
            json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
    else:
        logger.info("No modifications needed.") 
Developer: ArztSamuel, Project: DRL_DeliveryDuel, Lines: 22, Source: generate_json.py

Example 4: test_default_time_limit

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def test_default_time_limit():
    # We need an env without a default limit
    register(
        id='test.NoLimitDummyVNCEnv-v0',
        entry_point='universe.envs:DummyVNCEnv',
        tags={
            'vnc': True,
            },
    )

    env = gym.make('test.NoLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == wrappers.time_limit.DEFAULT_MAX_EPISODE_SECONDS
    assert env._max_episode_steps is None
Developer: openai, Project: universe, Lines: 19, Source: test_time_limit.py
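
For comparison, plain gym can attach the step limit to the spec itself, in which case gym.make wraps the environment in TimeLimit automatically. A minimal sketch; the 'CartPoleShort-v0' id is made up for illustration:

import gym
from gym.envs.registration import register

# Re-register CartPole under a new id with a tighter step limit.
register(
    id='CartPoleShort-v0',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=50,
)

env = gym.make('CartPoleShort-v0')
print(env.spec.max_episode_steps)  # 50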

Example 5: get_gym_stats

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def get_gym_stats():
    """Return a pandas DataFrame of the environment IDs."""
    df = []
    for e in gym.envs.registry.all():
        print(e.id)
        df.append(env_stats(gym.make(e.id)))
    cols = [
        "id",
        "continuous_actions",
        "continuous_observations",
        "action_dim",
        #  "action_ids",
        "deterministic",
        "multidim_actions",
        "multidim_observations",
        "n_actions_per_dim",
        "n_obs_per_dim",
        "obs_dim",
        #  "obs_ids",
        "seed",
        "tuple_actions",
        "tuple_observations",
    ]
    return df if NO_PD else pd.DataFrame(df)[cols] 
Developer: ddbourgin, Project: numpy-ml, Lines: 26, Source: rl_utils.py

Example 6: should_skip_env_spec_for_tests

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def should_skip_env_spec_for_tests(spec):
    # We skip tests for envs that require dependencies or are otherwise
    # troublesome to run frequently
    ep = spec.entry_point
    # Skip mujoco tests for pull request CI
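    # NOTE: skip_mujoco is assumed to be a module-level flag computed earlier
    # in spec_list.py (as in Example 2 above); it is not part of this snippet.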
    if skip_mujoco and (ep.startswith('gym.envs.mujoco') or ep.startswith('gym.envs.robotics:')):
        return True
    try:
        import atari_py
    except ImportError:
        if ep.startswith('gym.envs.atari'):
            return True
    try:
        import Box2D
    except ImportError:
        if ep.startswith('gym.envs.box2d'):
            return True

    if (    'GoEnv' in ep or
            'HexEnv' in ep or
            (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
    ):
        logger.warn("Skipping tests for env {}".format(ep))
        return True
    return False 
Developer: hust512, Project: DQN-DDPG_Stock_Trading, Lines: 27, Source: spec_list.py

Example 7: step_wait

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def step_wait(self):
        obs = []
        rews = []
        dones = []
        infos = []

        for i in range(self.num_envs):
            obs_tuple, reward, done, info = self.envs[i].step(self.actions[i])
            if done:
                obs_tuple = self.envs[i].reset()
            obs.append(obs_tuple)
            rews.append(reward)
            dones.append(done)
            infos.append(info)

        return np.stack(obs), np.stack(rews), np.stack(dones), infos 
Developer: justinglibert, Project: bezos, Lines: 18, Source: envs.py

Example 8: register

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def register(id, **kwargs):
    """Idempotent version of gym.envs.registration.registry.

    Needed since aprl.envs can get imported multiple times, e.g. when deserializing policies.
    """
    try:
        existing_spec = registration.spec(id)
        new_spec = registration.EnvSpec(id, **kwargs)
        assert existing_spec.__dict__ == new_spec.__dict__
    except gym.error.UnregisteredEnv:  # not previously registered
        registration.register(id, **kwargs)


# Low-dimensional multi-agent environments 
Developer: HumanCompatibleAI, Project: adversarial-policies, Lines: 16, Source: __init__.py
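
A hypothetical sketch of why idempotence matters here: the same module may be imported twice (for example while deserializing a policy), and a second registration with identical kwargs should be a silent no-op. The id and entry point below are made up for illustration.

# First import registers the environment ...
register(id='MultiAgentDummy-v0',
         entry_point='aprl.envs.dummy:DummyMultiAgentEnv',  # hypothetical entry point
         max_episode_steps=100)

# ... a later re-import repeats the identical call and is safely ignored.
register(id='MultiAgentDummy-v0',
         entry_point='aprl.envs.dummy:DummyMultiAgentEnv',
         max_episode_steps=100)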

Example 9: make_env

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def make_env(env_id, seed, rank, log_dir, downsample=True, color=False):
    def _thunk():
        env = gym.make(env_id)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)
            env = AtariARIWrapper(env)

        env.seed(seed + rank)


        if str(env.__class__.__name__).find('TimeLimit') >= 0:
            env = TimeLimitMask(env)

        if log_dir is not None:
            env = bench.Monitor(
                env,
                os.path.join(log_dir, str(rank)),
                allow_early_resets=False)

        if is_atari:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env, downsample=downsample, color=color)
        elif len(env.observation_space.shape) == 3:
            raise NotImplementedError(
                "CNN models work only for atari,\n"
                "please use a custom wrapper for a custom pixel input env.\n"
                "See wrap_deepmind for an example.")

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = TransposeImage(env, op=[2, 0, 1])

        return env

    return _thunk 
Developer: mila-iqia, Project: atari-representation-learning, Lines: 41, Source: envs.py
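
A hypothetical way to use these thunks: build one factory per worker rank and call it directly, or hand the list to a vectorized wrapper such as make_vec_envs above. The environment id and values are assumptions.

# One factory per worker; each thunk seeds its environment with seed + rank
# so the workers do not share a random stream.
thunks = [make_env('PongNoFrameskip-v4', seed=42, rank=i, log_dir=None)
          for i in range(4)]

# Calling a thunk yields a fully wrapped gym environment.
env = thunks[0]()
obs = env.reset()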

Example 10: __init__

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset 
Developer: nosyndicate, Project: pytorchrl, Lines: 35, Source: gym_env.py

Example 11: is_mujoco_env

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def is_mujoco_env(env):
    from gym.envs import mujoco
    if not hasattr(env, "env"):
        return False
    return gym.envs.mujoco.mujoco_env.MujocoEnv in env.env.__class__.__bases__ 
Developer: keiohta, Project: tf2rl, Lines: 7, Source: utils.py

Example 12: is_atari_env

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def is_atari_env(env):
    from gym.envs import atari
    if not hasattr(env, "env"):
        return False
    return gym.envs.atari.atari_env.AtariEnv == env.env.__class__ 
Developer: keiohta, Project: tf2rl, Lines: 7, Source: utils.py
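
A small usage sketch for the two checks above (assumes an old-style gym with atari_py installed, so the Atari id resolves and gym.make returns a TimeLimit wrapper whose .env is the raw AtariEnv):

import gym

env = gym.make('PongNoFrameskip-v4')
print(is_atari_env(env))   # True: env.env is an AtariEnv instance
# is_mujoco_env(env) would return False here, but importing gym.envs.mujoco
# additionally requires mujoco_py to be installed.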

Example 13: make_env

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def make_env(env_id, seed, rank, episode_life=True):
    def _thunk():
        random_seed(seed)
        if env_id.startswith("dm"):
            import dm_control2gym
            _, domain, task = env_id.split('-')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        else:
            env = gym.make(env_id)
        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)
        env.seed(seed + rank)
        env = OriginalReturnWrapper(env)
        if is_atari:
            env = wrap_deepmind(env,
                                episode_life=episode_life,
                                clip_rewards=False,
                                frame_stack=False,
                                scale=False)
            obs_shape = env.observation_space.shape
            if len(obs_shape) == 3:
                env = TransposeImage(env)
            env = FrameStack(env, 4)

        return env

    return _thunk 
Developer: ShangtongZhang, Project: DeepRL, Lines: 31, Source: envs.py

Example 14: __init__

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        self.actions = None 
Developer: ShangtongZhang, Project: DeepRL, Lines: 7, Source: envs.py

Example 15: step_wait

# Required module: import gym [as alias]
# Or: from gym import envs [as alias]
def step_wait(self):
        data = []
        for i in range(self.num_envs):
            obs, rew, done, info = self.envs[i].step(self.actions[i])
            if done:
                obs = self.envs[i].reset()
            data.append([obs, rew, done, info])
        obs, rew, done, info = zip(*data)
        return obs, np.asarray(rew), np.asarray(done), info 
Developer: ShangtongZhang, Project: DeepRL, Lines: 11, Source: envs.py
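
A hypothetical driver loop for the vectorized wrapper these two methods belong to, assuming the class (a DummyVecEnv-style wrapper elsewhere in the same envs.py) also provides reset() and a baselines-style step_async() that stores the actions on self.actions:

import gym
import numpy as np

venv = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])  # assumed class name

obs = venv.reset()
for _ in range(10):
    actions = np.array([venv.envs[i].action_space.sample()
                        for i in range(venv.num_envs)])
    venv.step_async(actions)                       # stores actions on self.actions
    obs, rews, dones, infos = venv.step_wait()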


Note: The gym.envs examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their authors, and copyright remains with the original authors; consult each project's License before distributing or using the code, and do not reproduce this compilation without permission.