当前位置: 首页>>代码示例>>Python>>正文


Python spaces.MultiDiscrete方法代码示例

本文整理汇总了Python中gym.spaces.MultiDiscrete方法的典型用法代码示例。如果您正苦于以下问题：Python spaces.MultiDiscrete方法的具体用法是什么？Python spaces.MultiDiscrete怎么用？Python spaces.MultiDiscrete使用的例子有哪些？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块gym.spaces的其他用法示例。


在下文中一共展示了spaces.MultiDiscrete方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def __init__(self, env, body_names, radius_multiplier=1.7,
             grab_dist=None, grab_exclusive=False,
             obj_in_game_metadata_keys=None):
    """Add a per-agent binary 'action_pull' choice for each named body.

    Args:
        env: the environment being wrapped.
        body_names: names of the pullable bodies; each agent gets one
            two-valued choice per body.
        radius_multiplier: scale applied to ``self.metadata['box_size']``
            to obtain ``self.grab_radius``.
        grab_dist: stored unchanged on ``self.grab_dist``.
        grab_exclusive: stored unchanged on ``self.grab_exclusive``.
        obj_in_game_metadata_keys: stored unchanged on
            ``self.obj_in_game_metadata_keys``.
    """
    super().__init__(env)
    agent_count = self.unwrapped.n_agents
    object_count = len(body_names)

    self.n_agents = agent_count
    self.body_names = body_names
    self.n_obj = object_count
    self.obj_in_game_metadata_keys = obj_in_game_metadata_keys

    # One MultiDiscrete([2, ..., 2]) per agent: a 0/1 choice for every object.
    per_agent_pull = [MultiDiscrete([2] * object_count)
                      for _ in range(agent_count)]
    self.action_space.spaces['action_pull'] = Tuple(per_agent_pull)

    self.observation_space = update_obs_space(
        env, {'obj_pull': (object_count, 1),
              'you_pull': (object_count, agent_count)})

    self.grab_radius = radius_multiplier * self.metadata['box_size']
    self.grab_dist = grab_dist
    self.grab_exclusive = grab_exclusive
开发者ID:openai,项目名称:multi-agent-emergence-environments,代码行数:20,代码来源:manipulation.py

示例2: get_action_type

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def get_action_type(action_space):
    '''Return the action-type label used to choose the distribution for NN logits.

    A 1-D Box of length 1 is 'continuous'; any other 1-D Box is
    'multi_continuous'. Discrete, MultiDiscrete and MultiBinary map to
    'discrete', 'multi_discrete' and 'multi_binary' respectively.

    Raises:
        AssertionError: for a Box that is not 1-D.
        NotImplementedError: for any other space type.
    '''
    if isinstance(action_space, spaces.Box):
        box_shape = action_space.shape
        assert len(box_shape) == 1
        return 'continuous' if box_shape[0] == 1 else 'multi_continuous'

    # Checked in this order, first match wins.
    labelled_types = (
        (spaces.Discrete, 'discrete'),
        (spaces.MultiDiscrete, 'multi_discrete'),
        (spaces.MultiBinary, 'multi_binary'),
    )
    for space_cls, label in labelled_types:
        if isinstance(action_space, space_cls):
            return label
    raise NotImplementedError


# action_policy base methods 
开发者ID:ConvLab,项目名称:ConvLab,代码行数:22,代码来源:policy_util.py

示例3: set_gym_space_attr

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def set_gym_space_attr(gym_space):
    '''Attach standardized ``is_discrete`` / ``low`` / ``high`` attributes to a gym space.

    Box only gets ``is_discrete = False``. Discrete, MultiBinary and
    MultiDiscrete get ``is_discrete = True`` plus an inclusive ``low`` and an
    exclusive ``high`` (n, 2 per bit, and nvec respectively).

    Raises:
        ValueError: if the space type is not one of the four above.
    '''
    if isinstance(gym_space, spaces.Box):
        gym_space.is_discrete = False
        return

    if isinstance(gym_space, spaces.Discrete):
        gym_space.is_discrete = True
        gym_space.low = 0
        gym_space.high = gym_space.n
    elif isinstance(gym_space, spaces.MultiBinary):
        gym_space.is_discrete = True
        gym_space.low = np.full(gym_space.n, 0)
        gym_space.high = np.full(gym_space.n, 2)
    elif isinstance(gym_space, spaces.MultiDiscrete):
        gym_space.is_discrete = True
        gym_space.low = np.zeros_like(gym_space.nvec)
        gym_space.high = np.array(gym_space.nvec)
    else:
        raise ValueError('gym_space not recognized')
开发者ID:ConvLab,项目名称:ConvLab,代码行数:20,代码来源:base.py

示例4: pick_action

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def pick_action(self, state: Union[int, float, np.ndarray]
                ) -> Union[int, float, np.ndarray]:
    """Return a random action drawn from the environment's action space.

    The action comes from ``self.env.action_space.sample()``; ``state`` is
    ignored.

    Parameters
    ----------
    state: Union[int, float, np.ndarray]
        The current state. Not used by this agent.

    Returns
    -------
    Union[int, float, np.ndarray]
        The value returned by ``action_space.sample()``.
    """
    # Only these space types are known to return a value matching the declared
    # return type from sample(); check any new type's sample() before adding it.
    permitted = (Box, Discrete, MultiDiscrete, MultiBinary)
    assert isinstance(self.env.action_space, permitted)
    return self.env.action_space.sample()
开发者ID:JohannesHeidecke,项目名称:irl-benchmark,代码行数:25,代码来源:random_agent.py

示例5: encode_observation

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder

    Returns:
    -------
    A float32 tensor:
      - Discrete: one-hot of depth ``ob_space.n``;
      - Box: the placeholder cast to float32;
      - MultiDiscrete: one one-hot per component (depth ``nvec[i]``),
        concatenated along the last axis.

    Raises:
    -------
    NotImplementedError for any other space type.
    '''
    if isinstance(ob_space, Discrete):
        # tf.cast(x, tf.float32) is what the deprecated tf.to_float did.
        return tf.cast(tf.one_hot(placeholder, ob_space.n), tf.float32)
    elif isinstance(ob_space, Box):
        return tf.cast(placeholder, tf.float32)
    elif isinstance(ob_space, MultiDiscrete):
        placeholder = tf.cast(placeholder, tf.int32)
        # Iterate the space's own component sizes instead of the placeholder's
        # static shape: placeholder.shape[-1] may be unknown (None), which made
        # range() fail.
        one_hots = [
            tf.cast(tf.one_hot(placeholder[..., i], depth), tf.float32)
            for i, depth in enumerate(ob_space.nvec)
        ]
        return tf.concat(one_hots, axis=-1)
    else:
        raise NotImplementedError
开发者ID:hiwonjoon,项目名称:ICML2019-TREX,代码行数:23,代码来源:input.py

示例6: __init__

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def __init__(self, env, eat_thresh=0.5, max_food_health=10, respawn_time=np.inf,
             food_rew_type='selfish', reward_scale=1.0, reward_scale_obs=False):
    """Store food parameters and extend the observation and action spaces.

    Args:
        env: the environment being wrapped.
        eat_thresh, max_food_health, respawn_time, food_rew_type: stored
            unchanged on ``self``.
        reward_scale: a scalar, or a list/tuple/ndarray (including
            subclasses). A scalar is expanded to ``[reward_scale, reward_scale]``.
        reward_scale_obs: if true, each food observation row has 5 features
            instead of 4.
    """
    super().__init__(env)
    self.eat_thresh = eat_thresh
    self.max_food_health = max_food_health
    self.respawn_time = respawn_time
    self.food_rew_type = food_rew_type
    self.n_agents = self.metadata['n_agents']

    # isinstance (not an exact type() match) so subclasses of list/tuple/ndarray
    # are treated as sequences rather than being wrapped as a scalar.
    if not isinstance(reward_scale, (list, tuple, np.ndarray)):
        reward_scale = [reward_scale, reward_scale]
    self.reward_scale = reward_scale
    self.reward_scale_obs = reward_scale_obs

    # Reset obs/action space to match
    self.max_n_food = self.metadata['max_n_food']
    self.curr_n_food = self.metadata['curr_n_food']
    self.max_food_size = self.metadata['food_size']
    food_dim = 5 if self.reward_scale_obs else 4
    self.observation_space = update_obs_space(self.env, {'food_obs': (self.max_n_food, food_dim),
                                                         'food_health': (self.max_n_food, 1),
                                                         'food_eat': (self.max_n_food, 1)})
    # One MultiDiscrete([2, ..., 2]) per agent: a 0/1 eat choice per food slot.
    self.action_space.spaces['action_eat_food'] = Tuple([MultiDiscrete([2] * self.max_n_food)
                                                         for _ in range(self.n_agents)])
开发者ID:openai,项目名称:multi-agent-emergence-environments,代码行数:26,代码来源:food.py

示例7: _detect_gym_spaces

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def _detect_gym_spaces(gym_space):
    """Return the shape(s) of a gym space as a dict.

    Discrete and MultiBinary map to ``{<type name>: (n,)}`` and Box to
    ``{"Box": shape}``. Dict and Tuple map each sub-space name / index to the
    single value produced for that sub-space.

    Raises:
        NotImplementedError: for MultiDiscrete and for any unrecognised space.
    """
    if isinstance(gym_space, spaces.Discrete):
        return {"Discrete": (gym_space.n,)}
    elif isinstance(gym_space, spaces.MultiDiscrete):
        raise NotImplementedError
    elif isinstance(gym_space, spaces.MultiBinary):
        return {"MultiBinary": (gym_space.n,)}
    elif isinstance(gym_space, spaces.Box):
        return {"Box": gym_space.shape}
    elif isinstance(gym_space, spaces.Dict):
        return {
            name: list(Space._detect_gym_spaces(s).values())[0]
            for name, s in gym_space.spaces.items()
        }
    elif isinstance(gym_space, spaces.Tuple):
        return {
            idx: list(Space._detect_gym_spaces(s).values())[0]
            for idx, s in enumerate(gym_space.spaces)
        }
    else:
        # Previously fell through and returned None, which only failed later
        # (e.g. `.values()` on None in a recursive call); fail fast instead.
        raise NotImplementedError
开发者ID:heronsystems,项目名称:adeptRL,代码行数:21,代码来源:_spaces.py

示例8: dtypes_from_gym

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def dtypes_from_gym(gym_space):
    """Return the dtype(s) of a gym space as a dict.

    Discrete, MultiBinary and Box map to ``{<type name>: dtype}``. Dict and
    Tuple map each sub-space name / index to that sub-space's dtype.

    Raises:
        NotImplementedError: for MultiDiscrete and for any unrecognised space.
    """
    if isinstance(gym_space, spaces.Discrete):
        return {"Discrete": gym_space.dtype}
    elif isinstance(gym_space, spaces.MultiDiscrete):
        raise NotImplementedError
    elif isinstance(gym_space, spaces.MultiBinary):
        return {"MultiBinary": gym_space.dtype}
    elif isinstance(gym_space, spaces.Box):
        return {"Box": gym_space.dtype}
    # The composite branches must recurse into dtypes_from_gym itself; they
    # previously called _detect_gym_spaces and so returned shapes, not dtypes.
    elif isinstance(gym_space, spaces.Dict):
        return {
            name: list(Space.dtypes_from_gym(s).values())[0]
            for name, s in gym_space.spaces.items()
        }
    elif isinstance(gym_space, spaces.Tuple):
        return {
            idx: list(Space.dtypes_from_gym(s).values())[0]
            for idx, s in enumerate(gym_space.spaces)
        }
    else:
        raise NotImplementedError
开发者ID:heronsystems,项目名称:adeptRL,代码行数:23,代码来源:_spaces.py

示例9: gym_space_distribution

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def gym_space_distribution(space):
    """
    Build the Distribution that matches a gym.Space.

    Discrete -> CategoricalSoftmax, Box -> BoxGaussian over [low, high],
    MultiBinary -> MultiBernoulli, Tuple -> TupleDistribution of the
    recursively built sub-distributions, MultiDiscrete -> TupleDistribution
    of one CategoricalSoftmax per component, whose samples are converted to
    an np.array of the space's dtype.

    Raises UnsupportedGymSpace for any other space type.
    """
    if isinstance(space, spaces.Discrete):
        dist = CategoricalSoftmax(space.n)
    elif isinstance(space, spaces.Box):
        dist = BoxGaussian(space.low, space.high)
    elif isinstance(space, spaces.MultiBinary):
        dist = MultiBernoulli(space.n)
    elif isinstance(space, spaces.Tuple):
        dist = TupleDistribution(tuple(map(gym_space_distribution, space.spaces)))
    elif isinstance(space, spaces.MultiDiscrete):
        categoricals = tuple(CategoricalSoftmax(n) for n in space.nvec)
        dist = TupleDistribution(
            categoricals,
            to_sample=lambda x: np.array(x, dtype=space.dtype))
    else:
        raise UnsupportedGymSpace(space)
    return dist
开发者ID:flyyufelix,项目名称:sonic_contest,代码行数:22,代码来源:gym.py

示例10: __init__

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def __init__(self, env_config):
    """Read game options from ``env_config`` and build the observation space.

    Keys read (all default to False): "actions_are_logits",
    "one_hot_state_encoding", "separate_state_space". Without one-hot
    encoding the observation space is Discrete(6) and the separate state
    space is forced off.
    """
    self.state = None
    self.agent_1, self.agent_2 = 0, 1
    # MADDPG emits action logits instead of actual discrete actions
    self.actions_are_logits = env_config.get("actions_are_logits", False)
    self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", False)
    self.with_state = env_config.get("separate_state_space", False)

    if self.one_hot_state_encoding:
        # Each agent sees the one-hot encoding of which of the three states is
        # active, with the receiving agent's ID (1 or 2) appended at the end.
        agent_obs = MultiDiscrete([2, 2, 2, 3])
        if self.with_state:
            self.observation_space = Dict({
                "obs": agent_obs,
                ENV_STATE: MultiDiscrete([2, 2, 2])
            })
        else:
            self.observation_space = agent_obs
    else:
        self.observation_space = Discrete(6)
        self.with_state = False
开发者ID:ray-project,项目名称:ray,代码行数:26,代码来源:two_step_game.py

示例11: make_default_action_extractor

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def make_default_action_extractor(env: Env):
    """Return the default action extractor for the environment's action space.

    Discrete -> ``discrete_action_extractor``; MultiDiscrete ->
    ``multi_discrete_action_extractor``; Box -> the extractor built by
    ``make_box_action_extractor(action_space)``.

    Raises:
        NotImplementedError: for any other action space type.
    """
    action_space = env.action_space
    if isinstance(action_space, spaces.Discrete):
        # Canonical rule to return one-hot encoded actions for discrete
        return discrete_action_extractor
    elif isinstance(action_space, spaces.MultiDiscrete):
        return multi_discrete_action_extractor
    elif isinstance(action_space, spaces.Box):
        # Canonical rule to scale actions to CONTINUOUS_TRAINING_ACTION_RANGE
        return make_box_action_extractor(action_space)
    else:
        # Fixed typo in the message ("Unsupport" -> "Unsupported").
        raise NotImplementedError(f"Unsupported action space: {action_space}")


#######################################
### Default obs preprocessors.
### These should operate on single obs.
####################################### 
开发者ID:facebookresearch,项目名称:ReAgent,代码行数:21,代码来源:default_preprocessors.py

示例12: test_step_with_bigger_slate

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def test_step_with_bigger_slate(self):
    # Agent that fills a 5-item slate from 5 candidates.
    slate_size = 5
    num_candidates = 5
    slate_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))
    agent = cluster_bandit_agent.ClusterBanditAgent(
        self.dummy_observation_space(), slate_space)

    # Fixed seed, so the document set (and the expected slate below) is stable.
    sampler = ie.IETopicDocumentSampler(seed=1)
    documents = {
        doc_id: sampler.sample_document().create_observation()
        for doc_id in range(num_candidates)
    }

    # Past observation shows Topic 1 is better.
    stats = self.doc_user_to_sufficient_stats(documents, np.array([1, 1, 0, 1]))
    slate = agent.step(0, stats)
    # Topic 0 documents by quality: 1, 2. Topic 1 documents by quality: 0, 4, 3.
    self.assertAllEqual(slate, [0, 4, 3, 1, 2])
开发者ID:google-research,项目名称:recsim,代码行数:25,代码来源:cluster_bandit_agent_test.py

示例13: test_bundle_and_unbundle

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def test_bundle_and_unbundle(self):
    # Agent that fills a 2-item slate from 5 candidates.
    slate_size = 2
    num_candidates = 5
    slate_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))
    agent = cluster_bandit_agent.ClusterBanditAgent(
        self.dummy_observation_space(), slate_space)

    sampler = ie.IETopicDocumentSampler()
    documents = {
        doc_id: sampler.sample_document().create_observation()
        for doc_id in range(num_candidates)
    }

    # Take one step before checkpointing.
    stats = self.doc_user_to_sufficient_stats(documents, np.array([0, 0, 0, 0]))
    agent.step(1, stats)

    # The bundle must unbundle successfully and re-bundle to an equal dict.
    bundle = agent.bundle_and_checkpoint('', 0)
    self.assertTrue(agent.unbundle('', 0, bundle))
    self.assertEqual(bundle, agent.bundle_and_checkpoint('', 0))
开发者ID:google-research,项目名称:recsim,代码行数:27,代码来源:cluster_bandit_agent_test.py

示例14: test_step

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def test_step(self):
    # IEv user choosing from 2-item slates via a multinomial-logit model.
    slate_size = 2
    user_model = iev.IEvUserModel(
        slate_size,
        choice_model_ctor=choice_model.MultinomialLogitChoiceModel,
        response_model_ctor=iev.IEvResponse)

    # Environment with 5 candidate videos.
    num_candidates = 5
    ievsim = environment.Environment(user_model, iev.IEvVideoSampler(),
                                     num_candidates, slate_size)

    # Seeded random agent, so the expected slate below is stable.
    agent = random_agent.RandomAgent(
        spaces.MultiDiscrete(num_candidates * np.ones((slate_size,))),
        random_seed=0)

    # This agent doesn't use the previous user response.
    user_obs, docs = ievsim.reset()
    slate = agent.step(1, {'user': user_obs, 'doc': docs})
    self.assertAllEqual(slate, [2, 0])
开发者ID:google-research,项目名称:recsim,代码行数:24,代码来源:random_agent_test.py

示例15: test_slate_indices_and_length

# 需要导入模块: from gym import spaces [as 别名]
# 或者: from gym.spaces import MultiDiscrete [as 别名]
def test_slate_indices_and_length(self):
    # Random agent filling a 2-item slate from 100 candidates.
    slate_size = 2
    num_candidates = 100
    slate_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))

    user_model = iev.IEvUserModel(
        slate_size,
        choice_model_ctor=choice_model.MultinomialLogitChoiceModel,
        response_model_ctor=iev.IEvResponse)
    agent = random_agent.RandomAgent(slate_space, random_seed=0)

    ievenv = environment.Environment(user_model, iev.IEvVideoSampler(),
                                     num_candidates, slate_size)

    user_obs, docs = ievenv.reset()
    slate = agent.step(1, {'user': user_obs, 'doc': docs})
    # Exactly slate_size picks, each a valid candidate index.
    self.assertLen(slate, slate_size)
    self.assertAllInSet(slate, range(num_candidates))
开发者ID:google-research,项目名称:recsim,代码行数:24,代码来源:random_agent_test.py


注:本文中的gym.spaces.MultiDiscrete方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。