当前位置: 首页>>代码示例>>Python>>正文


Python tree.map_structure方法代码示例

本文整理汇总了Python中tree.map_structure方法的典型用法代码示例。如果您正苦于以下问题：Python tree.map_structure方法的具体用法？Python tree.map_structure怎么用？Python tree.map_structure使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tree的用法示例。


在下文中一共展示了tree.map_structure方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: convert_to_non_torch_type

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def convert_to_non_torch_type(stats):
    """Returns a copy of `stats` with every torch tensor leaf de-torched.

    Args:
        stats (any): Any (possibly nested) struct. Leaves that are
            `torch.Tensor`s get converted; all other leaves are passed
            through untouched.

    Returns:
        Any: A new struct with the same structure as `stats`, in which
            0-dim tensors were replaced by the result of `.item()` and all
            other tensors by numpy arrays.
    """

    def _untorch(leaf):
        # Anything that is not a torch tensor is returned as-is.
        if not isinstance(leaf, torch.Tensor):
            return leaf
        # 0-dim tensor (empty size) -> extract the single value.
        if len(leaf.size()) == 0:
            return leaf.cpu().item()
        # Any other tensor -> numpy array (moved to CPU, then detached).
        return leaf.cpu().detach().numpy()

    return tree.map_structure(_untorch, stats)
开发者ID:ray-project,项目名称:ray,代码行数:24,代码来源:torch_ops.py

示例2: clip_action

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def clip_action(action, action_space):
    """Clips the components of `action` that belong to Box spaces.

    Args:
        action (any): The (possibly nested) struct of action components.
        action_space (any): The struct of Space objects that is mapped over
            together with `action`; presumably one Space per action
            component with a matching structure (confirm against callers).

    Returns:
        Any: A struct shaped like `action` in which every component whose
            space is a `gym.spaces.Box` was clipped to that space's
            `low`/`high` bounds. All other components are returned unchanged.
    """

    def _clip(component, space):
        # Only Box spaces carry low/high bounds to clip against.
        if not isinstance(space, gym.spaces.Box):
            return component
        return np.clip(component, space.low, space.high)

    return tree.map_structure(_clip, action, action_space)
开发者ID:ray-project,项目名称:ray,代码行数:21,代码来源:policy.py

示例3: testMapStructureWithStrings

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def testMapStructureWithStrings(self):
    # Exercises `tree.map_structure` on namedtuples whose leaves are strings.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")

    # Case 1: two parallel structures -> repeat each string by its count.
    strings = ab_tuple(a="foo", b=("bar", "baz"))
    counts = ab_tuple(a=2, b=(1, 3))

    def repeat(string, repeats):
        return string * repeats

    repeated = tree.map_structure(repeat, strings, counts)
    self.assertEqual("foofoo", repeated.a)
    self.assertEqual("bar", repeated.b[0])
    self.assertEqual("bazbazbaz", repeated.b[1])

    # Case 2: a single structure -> reverse each string.
    original = ab_tuple(a=("something", "something_else"),
                        b="yet another thing")
    flipped = tree.map_structure(lambda text: text[::-1], original)
    # The output must keep the input's structure, with every leaf reversed.
    tree.assert_same_structure(original, flipped)
    self.assertEqual(original.a[0][::-1], flipped.a[0])
    self.assertEqual(original.a[1][::-1], flipped.a[1])
    self.assertEqual(original.b[::-1], flipped.b)
开发者ID:deepmind,项目名称:tree,代码行数:21,代码来源:tree_test.py

示例4: nest_to_numpy

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def nest_to_numpy(nest_of_tensors):
  """Replaces every `tf.Tensor` in a nest by the result of its `.numpy()`.

  Elements that are not `tf.Tensor` instances are left untouched.

  Typical use: turning a `graphs.GraphsTuple` of tensors (or a nest that
  contains `graphs.GraphsTuple`s) into the equivalent structure of arrays.

  Args:
    nest_of_tensors: Nest containing `tf.Tensor`s.

  Returns:
    A nest with the same structure as the input, in which each `tf.Tensor`
    was replaced by a numpy array and every other element is kept as-is.
  """

  def _to_numpy(element):
    if isinstance(element, tf.Tensor):
      return element.numpy()
    return element

  return tree.map_structure(_to_numpy, nest_of_tensors)
开发者ID:deepmind,项目名称:graph_nets,代码行数:21,代码来源:utils_tf.py

示例5: make_action

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def make_action(self):
    """Builds one action by generating a value for each leaf of action_spec()."""

    def _generate(leaf_spec):
        return leaf_spec.generate_value()

    return tree.map_structure(_generate, self.environment.action_spec())
开发者ID:deepmind,项目名称:dm_env,代码行数:6,代码来源:test_utils.py

示例6: convert_to_torch_tensor

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def convert_to_torch_tensor(stats, device=None):
    """Converts all leaves of a (possibly nested) struct to torch tensors.

    Args:
        stats (any): Any (possibly nested) struct, the leaves of which will
            be converted and returned inside a new struct.
        device: If not None, every resulting tensor is passed through
            `.to(device)`. If None, tensors are returned without any
            device move.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types.
    """

    def _place(tensor):
        # Only move the tensor when a target device was requested.
        return tensor if device is None else tensor.to(device)

    def mapping(item):
        # Leaf is a torch tensor already: only the device may need fixing.
        if torch.is_tensor(item):
            return _place(item)
        # "Repeated" values: convert the wrapped values recursively and
        # rebuild the wrapper with its original lengths/max_len.
        if isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values),
                item.lengths, item.max_len)
        converted = torch.from_numpy(np.asarray(item))
        # float64 tensors are downcast to float32.
        if converted.dtype == torch.double:
            converted = converted.float()
        return _place(converted)

    return tree.map_structure(mapping, stats)
开发者ID:ray-project,项目名称:ray,代码行数:30,代码来源:torch_ops.py

示例7: __init__

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Builds one child distribution per chunk of `inputs`.

        `inputs` is split along axis 1 into chunks whose sizes are given by
        `input_lens`; each entry of `child_distributions` is then called
        with its chunk and `model`.
        """
        ActionDistribution.__init__(self, inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        split_sizes = np.array(input_lens, dtype=np.int32)
        chunks = tf.split(inputs, split_sizes, axis=1)

        def _build(dist_cls, chunk):
            return dist_cls(chunk, model)

        self.flat_child_distributions = tree.map_structure(
            _build, child_distributions, chunks)
开发者ID:ray-project,项目名称:ray,代码行数:13,代码来源:tf_action_dist.py

示例8: sample

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def sample(self):
        """Samples every child distribution, arranged like the action space.

        The flat child distributions are first unflattened into the
        structure of `self.action_space_struct`; the result has that same
        structure, with each distribution replaced by its `sample()`.
        """
        structured = tree.unflatten_as(self.action_space_struct,
                                       self.flat_child_distributions)

        def _draw(dist):
            return dist.sample()

        return tree.map_structure(_draw, structured)
开发者ID:ray-project,项目名称:ray,代码行数:6,代码来源:tf_action_dist.py

示例9: deterministic_sample

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def deterministic_sample(self):
        """Returns each child's `deterministic_sample()`, action-space shaped.

        The flat child distributions are first unflattened into the
        structure of `self.action_space_struct`; the result has that same
        structure.
        """
        structured = tree.unflatten_as(self.action_space_struct,
                                       self.flat_child_distributions)

        def _draw(dist):
            return dist.deterministic_sample()

        return tree.map_structure(_draw, structured)
开发者ID:ray-project,项目名称:ray,代码行数:7,代码来源:tf_action_dist.py

示例10: __init__

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
                Anything that is not a `torch.Tensor` is wrapped via
                `torch.Tensor(inputs)` first.
            model (ModelV2): The ModelV2 object used to produce inputs for
                this distribution.
            child_distributions (any[torch.Tensor]): Any struct that
                contains the child distribution classes to use to
                instantiate the child distributions from `inputs`. This
                could be an already flattened list or a struct according to
                `action_space`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs`.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The
                complex and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Work on flat lists from here on: one split size and one
        # distribution class per action component.
        split_sizes = tree.flatten(input_lens)
        dist_classes = tree.flatten(child_distributions)
        chunks = list(torch.split(inputs, split_sizes, dim=1))

        def _build(dist_cls, chunk):
            return dist_cls(chunk, model)

        self.flat_child_distributions = tree.map_structure(
            _build, dist_classes, chunks)
开发者ID:ray-project,项目名称:ray,代码行数:32,代码来源:torch_action_dist.py

示例11: logp

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def logp(self, x):
        """Sums the log-probs of all action components contained in `x`.

        `x` may be a single merged tensor (or numpy array), which is split
        along dim 1 into one chunk per child distribution, or a structured /
        already flattened collection of components, which is flattened.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)

        if isinstance(x, torch.Tensor):
            # Merged input: a categorical child takes one column; any other
            # child takes as many columns as one of its samples has.
            widths = [
                1 if isinstance(dist, TorchCategorical) else
                dist.sample().size()[1]
                for dist in self.flat_child_distributions
            ]
            components = list(torch.split(x, widths, dim=1))
        else:
            # Structured or flattened (by single action component) input.
            components = tree.flatten(x)

        def _component_logp(value, dist):
            # Categorical values carry an extra trailing dim; drop it and
            # cast to int before asking for the log-prob.
            if isinstance(dist, TorchCategorical):
                value = torch.squeeze(value, dim=-1).int()
            return dist.logp(value)

        per_component = tree.map_structure(_component_logp, components,
                                           self.flat_child_distributions)

        return functools.reduce(lambda acc, term: acc + term, per_component)
开发者ID:ray-project,项目名称:ray,代码行数:30,代码来源:torch_action_dist.py

示例12: compute_actions

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def compute_actions(self,
                        observation,
                        add_noise=False,
                        update=True,
                        **kwargs):
        """Computes actions for `observation` and returns them unbatched.

        Args:
            observation: The observation; a list of length one is unwrapped
                to its single element first.
            add_noise: If truthy, `self._add_noise` is mapped over the
                actions together with `self.action_space_struct`.
            update: Forwarded to `self.observation_filter`.
            **kwargs: Accepted but not used here.

        Returns:
            The result of `unbatch(actions)`.
        """
        # Batch is given as list of one.
        if isinstance(observation, list) and len(observation) == 1:
            observation = observation[0]
        preprocessed = self.preprocessor.transform(observation)
        filtered = self.observation_filter(preprocessed[None], update=update)

        # `actions` is a list of (component) batches.
        if self.sess:
            # Graph mode: run the sampler op in the session.
            actions = self.sess.run(
                self.sampler, feed_dict={self.inputs: filtered})
        else:
            # Eager mode: call the model directly and sample from the
            # resulting distribution.
            dist_inputs, _ = self.model({SampleBatch.CUR_OBS: filtered})
            sampled = self.dist_class(dist_inputs, self.model).sample()
            actions = tree.map_structure(lambda a: a.numpy(), sampled)

        if add_noise:
            actions = tree.map_structure(self._add_noise, actions,
                                         self.action_space_struct)
        # Turn the batched action components into a list of single actions.
        return unbatch(actions)
开发者ID:ray-project,项目名称:ray,代码行数:31,代码来源:es_tf_policy.py

示例13: logp

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def logp(self, x):
        """Sums the log-probs of all action components contained in `x`.

        `x` may be a single merged `tf.Tensor` or numpy array, which is
        split along axis 1 into one chunk per child distribution, or a
        structured / already flattened collection of components, which is
        flattened.
        """
        if isinstance(x, (tf.Tensor, np.ndarray)):
            # Merged input: a categorical child takes one column; any other
            # child takes as many columns as one of its samples has.
            widths = [
                1 if isinstance(dist, Categorical) else
                tf.shape(dist.sample())[1]
                for dist in self.flat_child_distributions
            ]
            components = tf.split(x, widths, axis=1)
        else:
            # Structured or flattened (by single action component) input.
            components = tree.flatten(x)

        def _component_logp(value, dist):
            # Categorical values carry an extra trailing dim; drop it and
            # cast to int32 before asking for the log-prob.
            if isinstance(dist, Categorical):
                value = tf.cast(tf.squeeze(value, axis=-1), tf.int32)
            return dist.logp(value)

        per_component = tree.map_structure(_component_logp, components,
                                           self.flat_child_distributions)

        return functools.reduce(lambda acc, term: acc + term, per_component)
开发者ID:ray-project,项目名称:ray,代码行数:28,代码来源:tf_action_dist.py

示例14: update

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Stores a transition and runs a learning step once a batch is ready."""

    self._buffer.append(timestep, action, new_timestep)

    # Nothing to learn from yet: the buffer is not full and the episode has
    # not ended.
    if not (self._buffer.full() or new_timestep.last()):
      return

    # Drain the buffer, convert every leaf to a tensor, do a step of SGD.
    drained = self._buffer.drain()
    self._step(tree.map_structure(tf.convert_to_tensor, drained))
开发者ID:deepmind,项目名称:bsuite,代码行数:17,代码来源:agent.py

示例15: update

# 需要导入模块: import tree [as 别名]
# 或者: from tree import map_structure [as 别名]
def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Stores a transition and runs a learning step once a batch is ready."""
    self._buffer.append(timestep, action, new_timestep)

    # Nothing to learn from yet: the buffer is not full and the episode has
    # not ended.
    if not (self._buffer.full() or new_timestep.last()):
      return

    # Drain the buffer, convert every leaf to a tensor and run the step; its
    # return value is kept as the new rollout initial state.
    drained = self._buffer.drain()
    tensors = tree.map_structure(tf.convert_to_tensor, drained)
    self._rollout_initial_state = self._step(tensors)
开发者ID:deepmind,项目名称:bsuite,代码行数:15,代码来源:agent.py


注:本文中的tree.map_structure方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。