

Python tree.flatten Method Code Examples

This article collects typical usage examples of the tree.flatten method in Python. If you are wondering what tree.flatten does, how to call it, or what real-world uses look like, the curated examples below should help. You can also explore further usage examples from the tree module itself.


Below are 15 code examples of the tree.flatten method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
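Before the examples, here is a minimal, self-contained sketch (not drawn from any of the projects below) of the basic contract of tree.flatten and its inverse tree.unflatten_as from the dm-tree library:

import tree

nested = {"obs": [1, 2], "action": (3, {"noise": 4})}
flat = tree.flatten(nested)
# Leaves are returned left-to-right, with dict keys visited in sorted order:
# [3, 4, 1, 2]
restored = tree.unflatten_as(nested, flat)
assert restored == nested  # the round trip preserves the original structure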

Example 1: _get_tf_exploration_action_op

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def _get_tf_exploration_action_op(self, action_dist, explore):
        sample = action_dist.sample()
        deterministic_sample = action_dist.deterministic_sample()
        action = tf.cond(
            tf.constant(explore) if isinstance(explore, bool) else explore,
            true_fn=lambda: sample,
            false_fn=lambda: deterministic_sample)

        def logp_false_fn():
            batch_size = tf.shape(tree.flatten(action)[0])[0]
            return tf.zeros(shape=(batch_size, ), dtype=tf.float32)

        logp = tf.cond(
            tf.constant(explore) if isinstance(explore, bool) else explore,
            true_fn=lambda: action_dist.sampled_action_logp(),
            false_fn=logp_false_fn)

        return action, logp 
Developer: ray-project, Project: ray, Lines of code: 20, Source file: stochastic_sampling.py

Example 2: testAttrsFlattenAndUnflatten

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def testAttrsFlattenAndUnflatten(self):

    class BadAttr(object):
      """Class that has a non-iterable __attrs_attrs__."""
      __attrs_attrs__ = None

    @attr.s
    class SampleAttr(object):
      field1 = attr.ib()
      field2 = attr.ib()

    field_values = [1, 2]
    sample_attr = SampleAttr(*field_values)
    self.assertFalse(tree._is_attrs(field_values))
    self.assertTrue(tree._is_attrs(sample_attr))
    flat = tree.flatten(sample_attr)
    self.assertEqual(field_values, flat)
    restructured_from_flat = tree.unflatten_as(sample_attr, flat)
    self.assertIsInstance(restructured_from_flat, SampleAttr)
    self.assertEqual(restructured_from_flat, sample_attr)

    # Check that flatten fails if attributes are not iterable
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
      flat = tree.flatten(BadAttr()) 
Developer: deepmind, Project: tree, Lines of code: 26, Source file: tree_test.py

Example 3: testGradients

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def testGradients(self, is_multi_actions):
    self._setUpLoss(is_multi_actions)
    with self.test_session() as sess:
      total_loss = tf.reduce_sum(self._loss)
      gradients = tf.gradients(
          [total_loss], nest.flatten(self._policy_logits_nest))
      grad_policy_logits_nest = sess.run(gradients)
      for grad_policy_logits in grad_policy_logits_nest:
        self.assertAllClose(grad_policy_logits,
                            [[[0, 0], [-0.731, 0.731]],
                             [[1, -1], [0, 0]]], atol=1e-4)
      dead_grads = tf.gradients(
          [total_loss],
          nest.flatten(self._actions_nest) + [self._action_values])
      for grad in dead_grads:
        self.assertIsNone(grad) 
Developer: deepmind, Project: trfl, Lines of code: 18, Source file: discrete_policy_gradient_ops_test.py

Example 4: testPolicyGradients

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def testPolicyGradients(self, is_multi_actions):
    if is_multi_actions:
      loss = self.multi_op.extra.policy_gradient_loss
      policy_logits_nest = self.multi_policy_logits
    else:
      loss = self.op.extra.policy_gradient_loss
      policy_logits_nest = self.policy_logits

    grad_policy_list = [
        tf.gradients(loss, policy_logits)[0] * self.num_actions
        for policy_logits in nest.flatten(policy_logits_nest)]

    for grad_policy in grad_policy_list:
      self.assertEqual(grad_policy.get_shape(), tf.TensorShape([2, 1, 3]))

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(tf.gradients(loss, self.invalid_grad_inputs),
                        self.invalid_grad_outputs) 
Developer: deepmind, Project: trfl, Lines of code: 20, Source file: discrete_policy_gradient_ops_test.py

Example 5: testEntropyGradients

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def testEntropyGradients(self, is_multi_actions):
    if is_multi_actions:
      loss = self.multi_op.extra.entropy_loss
      policy_logits_nest = self.multi_policy_logits
    else:
      loss = self.op.extra.entropy_loss
      policy_logits_nest = self.policy_logits

    grad_policy_list = [
        tf.gradients(loss, policy_logits)[0] * self.num_actions
        for policy_logits in nest.flatten(policy_logits_nest)]

    for grad_policy in grad_policy_list:
      self.assertEqual(grad_policy.get_shape(), tf.TensorShape([2, 1, 3]))

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(tf.gradients(loss, self.invalid_grad_inputs),
                        self.invalid_grad_outputs) 
Developer: deepmind, Project: trfl, Lines of code: 20, Source file: discrete_policy_gradient_ops_test.py

Example 6: test_nested_structure

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def test_nested_structure(self):
    regular_graph = self._graph
    graph_with_nested_fields = regular_graph.map(
        lambda x: {"a": x, "b": tf.random.uniform([4, 6])})

    nested_structure = [
        None,
        regular_graph,
        (graph_with_nested_fields,),
        tf.random.uniform([10, 6])]
    nested_structure_numpy = utils_tf.nest_to_numpy(nested_structure)

    tree.assert_same_structure(nested_structure, nested_structure_numpy)

    for tensor_or_none, array_or_none in zip(
        tree.flatten(nested_structure),
        tree.flatten(nested_structure_numpy)):
      if tensor_or_none is None:
        self.assertIsNone(array_or_none)
        continue

      self.assertIsNotNone(array_or_none)
      self.assertNDArrayNear(
          tensor_or_none.numpy(),
          array_or_none, 1e-8) 
Developer: deepmind, Project: graph_nets, Lines of code: 27, Source file: utils_tf_test.py

Example 7: step

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def step(self, action):
        self.steps += 1
        action = tree.flatten(action)
        reward = 0.0
        for a, o, space in zip(action, self.current_obs_flattened,
                               self.flattened_action_space):
            # Box: -abs(diff).
            if isinstance(space, gym.spaces.Box):
                reward -= np.abs(np.sum(a - o))
            # Discrete: +1.0 if exact match.
            if isinstance(space, gym.spaces.Discrete):
                reward += 1.0 if a == o else 0.0
        done = self.steps >= self.episode_len
        return self._next_obs(), reward, done, {} 
Developer: ray-project, Project: ray, Lines of code: 16, Source file: nested_space_repeat_after_me_env.py

Example 8: _next_obs

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def _next_obs(self):
        self.current_obs = self.observation_space.sample()
        self.current_obs_flattened = tree.flatten(self.current_obs)
        return self.current_obs 
Developer: ray-project, Project: ray, Lines of code: 6, Source file: nested_space_repeat_after_me_env.py

Example 9: flatten_space

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def flatten_space(space):
    """Flattens a gym.Space into its primitive components.

    Primitive components are any non Tuple/Dict spaces.

    Args:
        space(gym.Space): The gym.Space to flatten. This may be any
            supported type (including nested Tuples and Dicts).

    Returns:
        List[gym.Space]: The flattened list of primitive Spaces. This list
            does not contain Tuples or Dicts anymore.
    """

    def _helper_flatten(space_, l):
        from ray.rllib.utils.spaces.flexdict import FlexDict
        if isinstance(space_, Tuple):
            for s in space_:
                _helper_flatten(s, l)
        elif isinstance(space_, (Dict, FlexDict)):
            for k in space_.spaces:
                _helper_flatten(space_[k], l)
        else:
            l.append(space_)

    ret = []
    _helper_flatten(space, ret)
    return ret 
Developer: ray-project, Project: ray, Lines of code: 30, Source file: space_utils.py
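A short usage sketch for flatten_space (hypothetical spaces; the import path is assumed from the source file named above and may differ between Ray versions):

import gym
from ray.rllib.utils.spaces.space_utils import flatten_space

nested_space = gym.spaces.Dict({
    "position": gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
    "mode": gym.spaces.Tuple((gym.spaces.Discrete(4), gym.spaces.Discrete(2))),
})
primitives = flatten_space(nested_space)
# primitives is a flat list holding the Box and the two Discrete spaces;
# no Tuple or Dict containers remain.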

Example 10: flatten_to_single_ndarray

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def flatten_to_single_ndarray(input_):
    """Returns a single np.ndarray given a list/tuple of np.ndarrays.

    Args:
        input_ (Union[List[np.ndarray],np.ndarray]): The list of ndarrays or
            a single ndarray.

    Returns:
        np.ndarray: The result after concatenating all single arrays in input_.

    Examples:
        >>> flatten_to_single_ndarray([
        >>>     np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        >>>     np.array([7, 8, 9]),
        >>> ])
        >>> # Will return:
        >>> # np.array([
        >>> #     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        >>> # ])
    """
    # Concatenate complex inputs.
    if isinstance(input_, (list, tuple, dict)):
        expanded = []
        for in_ in tree.flatten(input_):
            expanded.append(np.reshape(in_, [-1]))
        input_ = np.concatenate(expanded, axis=0).flatten()
    return input_ 
Developer: ray-project, Project: ray, Lines of code: 29, Source file: space_utils.py
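A short usage sketch for flatten_to_single_ndarray on a nested dict input, which the code above also handles via tree.flatten (hypothetical data; import path assumed as in the previous sketch):

import numpy as np
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray

nested_obs = {"pos": np.array([[1.0, 2.0], [3.0, 4.0]]), "vel": np.array([5.0, 6.0])}
flat = flatten_to_single_ndarray(nested_obs)
# tree.flatten visits dict leaves in sorted-key order, so this yields
# array([1., 2., 3., 4., 5., 6.])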

Example 11: unbatch

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def unbatch(batches_struct):
    """Converts input from (nested) struct of batches to batch of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf item
            in this struct represents the batch for a single component
            (in case struct is tuple/dict).
            Alternatively, `batches_struct` may also simply be a batch of
            primitives (non tuple/dict).

    Returns:
        List[struct[components]]: The list of rows. Each item
            in the returned list represents a single (maybe complex) struct.
    """
    flat_batches = tree.flatten(batches_struct)

    out = []
    for batch_pos in range(len(flat_batches[0])):
        out.append(
            tree.unflatten_as(
                batches_struct,
                [flat_batches[i][batch_pos]
                 for i in range(len(flat_batches))]))
    return out 
Developer: ray-project, Project: ray, Lines of code: 36, Source file: space_utils.py
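A short usage sketch for unbatch, mirroring the docstring above (hypothetical data; import path assumed as in the previous sketches):

import numpy as np
from ray.rllib.utils.spaces.space_utils import unbatch

batched = {"a": np.array([1, 2, 3]),
           "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))}
rows = unbatch(batched)
# rows[0] == {"a": 1, "b": (4, 7.0)}
# rows[1] == {"a": 2, "b": (5, 8.0)}
# rows[2] == {"a": 3, "b": (6, 9.0)}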

Example 12: logp

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def logp(self, x):
        # Single tensor input (all merged).
        if isinstance(x, (tf.Tensor, np.ndarray)):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, Categorical):
                    split_indices.append(1)
                else:
                    split_indices.append(tf.shape(dist.sample())[1])
            split_x = tf.split(x, split_indices, axis=1)
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, Categorical):
                val = tf.cast(tf.squeeze(val, axis=-1), tf.int32)
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps) 
Developer: ray-project, Project: ray, Lines of code: 28, Source file: tf_action_dist.py

Example 13: __init__

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
            model (ModelV2): The ModelV2 object used to produce inputs for this
                distribution.
            child_distributions (any[torch.Tensor]): Any struct
                that contains the child distribution classes to use to
                instantiate the child distributions from `inputs`. This could
                be an already flattened list or a struct according to
                `action_space`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs`.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
                and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        input_lens = tree.flatten(input_lens)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = torch.split(inputs, input_lens, dim=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), flat_child_distributions,
            list(split_inputs)) 
Developer: ray-project, Project: ray, Lines of code: 32, Source file: torch_action_dist.py

Example 14: logp

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def logp(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        # Single tensor input (all merged).
        if isinstance(x, torch.Tensor):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, TorchCategorical):
                    split_indices.append(1)
                else:
                    split_indices.append(dist.sample().size()[1])
            split_x = list(torch.split(x, split_indices, dim=1))
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, TorchCategorical):
                val = torch.squeeze(val, dim=-1).int()
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps) 
Developer: ray-project, Project: ray, Lines of code: 30, Source file: torch_action_dist.py

Example 15: _step

# Required import: import tree [as alias]
# Or: from tree import flatten [as alias]
def _step(self, transitions: Sequence[tf.Tensor]):
    """Does a step of SGD for the whole ensemble over `transitions`."""
    o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
    variables = tree.flatten(
        [model.trainable_variables for model in self._ensemble])
    with tf.GradientTape() as tape:
      losses = []
      for k in range(self._num_ensemble):
        net = self._ensemble[k]
        target_net = self._target_ensemble[k]

        # Q-learning loss with added reward noise + half-in bootstrap.
        q_values = net(o_tm1)
        one_hot_actions = tf.one_hot(a_tm1, depth=self._num_actions)
        train_value = tf.reduce_sum(q_values * one_hot_actions, axis=-1)
        target_value = tf.stop_gradient(tf.reduce_max(target_net(o_t), axis=-1))
        target_y = r_t + z_t[:, k] + self._discount * d_t * target_value
        loss = tf.square(train_value - target_y) * m_t[:, k]
        losses.append(loss)

      loss = tf.reduce_mean(tf.stack(losses))
      gradients = tape.gradient(loss, variables)
    self._total_steps.assign_add(1)
    self._optimizer.apply(gradients, variables)

    # Periodically update the target network.
    if tf.math.mod(self._total_steps, self._target_update_period) == 0:
      for k in range(self._num_ensemble):
        for src, dest in zip(self._ensemble[k].variables,
                             self._target_ensemble[k].variables):
          dest.assign(src) 
Developer: deepmind, Project: bsuite, Lines of code: 33, Source file: agent.py


Note: The tree.flatten examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers, and copyright of the source code remains with the original authors; please consult the corresponding project's license before using or redistributing the code. Do not reproduce this article without permission.