本文整理汇总了 Python 中 tree.flatten 方法的典型用法代码示例。如果您正苦于以下问题：Python tree.flatten 方法的具体用法？Python tree.flatten 怎么用？Python tree.flatten 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tree 的用法示例。
在下文中一共展示了 tree.flatten 方法的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: _get_tf_exploration_action_op
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def _get_tf_exploration_action_op(self, action_dist, explore):
    """Build the ops that select an action and its log-probability.

    Args:
        action_dist: Object providing `sample()`, `deterministic_sample()`
            and `sampled_action_logp()`.
        explore: Either a Python bool (wrapped in `tf.constant`) or a value
            handed to `tf.cond` unchanged as the predicate.

    Returns:
        Tuple `(action, logp)`: `action` is the sampled value when the
        predicate holds and the deterministic sample otherwise; `logp` is
        `action_dist.sampled_action_logp()` when the predicate holds and a
        float32 zero vector otherwise.
    """
    stochastic_action = action_dist.sample()
    greedy_action = action_dist.deterministic_sample()

    def _predicate():
        # Only plain Python bools are wrapped; anything else is passed
        # straight through to tf.cond.
        if isinstance(explore, bool):
            return tf.constant(explore)
        return explore

    action = tf.cond(
        _predicate(),
        true_fn=lambda: stochastic_action,
        false_fn=lambda: greedy_action,
    )

    def _zero_logp():
        # The zero vector's length is the size of axis 0 of the first leaf
        # of the (possibly nested) action.
        first_leaf = tree.flatten(action)[0]
        batch_size = tf.shape(first_leaf)[0]
        return tf.zeros(shape=(batch_size,), dtype=tf.float32)

    logp = tf.cond(
        _predicate(),
        true_fn=lambda: action_dist.sampled_action_logp(),
        false_fn=_zero_logp,
    )
    return action, logp
示例2: testAttrsFlattenAndUnflatten
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def testAttrsFlattenAndUnflatten(self):
    """Round-trip an attrs instance through flatten / unflatten_as.

    Also checks that flattening an object whose `__attrs_attrs__` is not
    iterable raises a `TypeError`.
    """

    class BadAttr(object):
        """Class that has a non-iterable __attrs_attrs__."""
        __attrs_attrs__ = None

    @attr.s
    class SampleAttr(object):
        field1 = attr.ib()
        field2 = attr.ib()

    expected_leaves = [1, 2]
    instance = SampleAttr(*expected_leaves)

    # A plain list is not an attrs object; the decorated instance is.
    self.assertFalse(tree._is_attrs(expected_leaves))
    self.assertTrue(tree._is_attrs(instance))

    leaves = tree.flatten(instance)
    self.assertEqual(expected_leaves, leaves)

    rebuilt = tree.unflatten_as(instance, leaves)
    self.assertIsInstance(rebuilt, SampleAttr)
    self.assertEqual(rebuilt, instance)

    # Flattening must fail when the attrs attribute list is not iterable.
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
        tree.flatten(BadAttr())
示例3: testGradients
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def testGradients(self, is_multi_actions):
    """Check gradients of the summed loss w.r.t. logits, actions and values.

    Gradients w.r.t. every flattened policy-logits tensor must match a
    fixed expected array; gradients w.r.t. the flattened actions and the
    action values must be `None`.
    """
    self._setUpLoss(is_multi_actions)
    expected_logit_grads = [[[0, 0], [-0.731, 0.731]],
                            [[1, -1], [0, 0]]]
    with self.test_session() as sess:
        total_loss = tf.reduce_sum(self._loss)

        logit_tensors = nest.flatten(self._policy_logits_nest)
        logit_grad_ops = tf.gradients([total_loss], logit_tensors)
        for logit_grad in sess.run(logit_grad_ops):
            self.assertAllClose(logit_grad, expected_logit_grads, atol=1e-4)

        # The loss must not propagate gradients into these inputs.
        no_grad_inputs = (
            nest.flatten(self._actions_nest) + [self._action_values])
        for dead_grad in tf.gradients([total_loss], no_grad_inputs):
            self.assertIsNone(dead_grad)
示例4: testPolicyGradients
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def testPolicyGradients(self, is_multi_actions):
    """Check shapes and absence of gradients of the policy-gradient loss."""
    # Pick the op and logits structure for the single- or multi-action case.
    op, policy_logits_nest = (
        (self.multi_op, self.multi_policy_logits)
        if is_multi_actions
        else (self.op, self.policy_logits))
    loss = op.extra.policy_gradient_loss

    scaled_grads = [
        tf.gradients(loss, logits)[0] * self.num_actions
        for logits in nest.flatten(policy_logits_nest)
    ]
    expected_shape = tf.TensorShape([2, 1, 3])
    for scaled_grad in scaled_grads:
        self.assertEqual(scaled_grad.get_shape(), expected_shape)

    # No gradient may reach the baseline values.
    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(
        tf.gradients(loss, self.invalid_grad_inputs),
        self.invalid_grad_outputs)
示例5: testEntropyGradients
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def testEntropyGradients(self, is_multi_actions):
    """Check shapes and absence of gradients of the entropy loss."""
    # Pick the op and logits structure for the single- or multi-action case.
    op, policy_logits_nest = (
        (self.multi_op, self.multi_policy_logits)
        if is_multi_actions
        else (self.op, self.policy_logits))
    loss = op.extra.entropy_loss

    scaled_grads = [
        tf.gradients(loss, logits)[0] * self.num_actions
        for logits in nest.flatten(policy_logits_nest)
    ]
    expected_shape = tf.TensorShape([2, 1, 3])
    for scaled_grad in scaled_grads:
        self.assertEqual(scaled_grad.get_shape(), expected_shape)

    # No gradient may reach the baseline values.
    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(
        tf.gradients(loss, self.invalid_grad_inputs),
        self.invalid_grad_outputs)
示例6: test_nested_structure
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def test_nested_structure(self):
    """Check `utils_tf.nest_to_numpy` on a nested list/tuple/dict structure.

    The converted structure must have the same nesting as the input, keep
    `None` leaves as `None`, and hold arrays that match each tensor's
    `.numpy()` value to within 1e-8.
    """
    base_graph = self._graph

    def _wrap_field(field):
        # Turn every field into a dict so the graph itself carries nesting.
        return {"a": field, "b": tf.random.uniform([4, 6])}

    nested_graph = base_graph.map(_wrap_field)
    tensors = [
        None,
        base_graph,
        (nested_graph,),
        tf.random.uniform([10, 6]),
    ]
    arrays = utils_tf.nest_to_numpy(tensors)
    tree.assert_same_structure(tensors, arrays)

    leaf_pairs = zip(tree.flatten(tensors), tree.flatten(arrays))
    for tensor_leaf, array_leaf in leaf_pairs:
        if tensor_leaf is None:
            self.assertIsNone(array_leaf)
        else:
            self.assertIsNotNone(array_leaf)
            self.assertNDArrayNear(tensor_leaf.numpy(), array_leaf, 1e-8)
示例7: step
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def step(self, action):
    """Advance one step and score `action` against the current observation.

    The flattened action is compared component-wise with
    `self.current_obs_flattened`, using `self.flattened_action_space` to
    decide how each component is scored.

    Returns:
        Tuple `(next_obs, reward, done, info)` where `info` is an empty
        dict and `done` is True once `self.steps >= self.episode_len`.
    """
    self.steps += 1
    flat_action = tree.flatten(action)
    reward = 0.0
    components = zip(
        flat_action, self.current_obs_flattened, self.flattened_action_space)
    for act, obs, space in components:
        # Box: subtract the absolute value of the summed difference.
        if isinstance(space, gym.spaces.Box):
            reward -= np.abs(np.sum(act - obs))
        # Discrete: +1.0 if exact match.
        if isinstance(space, gym.spaces.Discrete):
            if act == obs:
                reward += 1.0
    done = self.steps >= self.episode_len
    # A fresh observation is sampled only after the reward was computed
    # against the previous one.
    return self._next_obs(), reward, done, {}
示例8: _next_obs
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def _next_obs(self):
    """Sample a new observation, cache it (raw and flattened), return it."""
    self.current_obs = self.observation_space.sample()
    # Keep a flat copy of the leaves alongside the structured observation.
    flat_obs = tree.flatten(self.current_obs)
    self.current_obs_flattened = flat_obs
    return self.current_obs
示例9: flatten_space
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def flatten_space(space):
    """Flattens a gym.Space into its primitive components.

    Primitive components are any non Tuple/Dict spaces.

    Args:
        space(gym.Space): The gym.Space to flatten. This may be any
            supported type (including nested Tuples and Dicts).

    Returns:
        List[gym.Space]: The flattened list of primitive Spaces. This list
            does not contain Tuples or Dicts anymore.
    """
    # Imported at call time rather than at module level, as in the original.
    from ray.rllib.utils.spaces.flexdict import FlexDict

    primitives = []

    def _collect(node):
        # Depth-first walk: containers are descended into, anything else
        # is recorded as a primitive.
        if isinstance(node, Tuple):
            for child in node:
                _collect(child)
        elif isinstance(node, (Dict, FlexDict)):
            for key in node.spaces:
                _collect(node[key])
        else:
            primitives.append(node)

    _collect(space)
    return primitives
示例10: flatten_to_single_ndarray
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def flatten_to_single_ndarray(input_):
    """Returns a single np.ndarray given a list/tuple/dict of np.ndarrays.

    Args:
        input_ (Union[List[np.ndarray],np.ndarray]): The list of ndarrays or
            a single ndarray. Anything that is not a list, tuple or dict is
            returned unchanged.

    Returns:
        np.ndarray: The result after concatenating all single arrays in
            input_.

    Example:
        Passing the list
            [np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
             np.array([7, 8, 9])]
        returns
            np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
    """
    # Simple (non-container) inputs pass straight through.
    if not isinstance(input_, (list, tuple, dict)):
        return input_

    # Reshape every leaf to 1-D, then join them end to end.
    pieces = [np.reshape(leaf, [-1]) for leaf in tree.flatten(input_)]
    return np.concatenate(pieces, axis=0).flatten()
示例11: unbatch
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def unbatch(batches_struct):
    """Converts input from (nested) struct of batches to batch of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf item
            in this struct represents the batch for a single component
            (in case struct is tuple/dict).
            Alternatively, `batches_struct` may also simply be a batch of
            primitives (non tuple/dict).

    Returns:
        List[struct[components]]: The list of rows. Each item
            in the returned list represents a single (maybe complex) struct.
    """
    flat_batches = tree.flatten(batches_struct)
    # The first leaf determines how many rows are produced.
    num_rows = len(flat_batches[0])
    return [
        tree.unflatten_as(
            batches_struct,
            [component[row] for component in flat_batches])
        for row in range(num_rows)
    ]
示例12: logp
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def logp(self, x):
    """Return the summed log-probability of `x` over all child distributions.

    Args:
        x: Either a single `tf.Tensor`/`np.ndarray` that is split along
            axis 1 into one slice per child distribution, or a (possibly
            nested) struct that is flattened into one item per child.

    Returns:
        The sum of each child distribution's `logp` for its component.
    """
    children = self.flat_child_distributions

    if isinstance(x, (tf.Tensor, np.ndarray)):
        # Single tensor input (all merged): a Categorical child takes one
        # column, any other child takes as many as its sample has on axis 1.
        split_sizes = [
            1 if isinstance(child, Categorical)
            else tf.shape(child.sample())[1]
            for child in children
        ]
        components = tf.split(x, split_sizes, axis=1)
    else:
        # Structured or flattened (by single action component) input.
        components = tree.flatten(x)

    def _component_logp(value, child):
        # Remove extra categorical dimension.
        if isinstance(child, Categorical):
            value = tf.cast(tf.squeeze(value, axis=-1), tf.int32)
        return child.logp(value)

    per_child_logps = tree.map_structure(
        _component_logp, components, children)
    return functools.reduce(lambda a, b: a + b, per_child_logps)
示例13: __init__
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    """Initializes a TorchMultiActionDistribution object.

    Args:
        inputs (torch.Tensor): A single tensor of shape [BATCH, size].
            Non-tensor inputs are converted with `torch.Tensor(...)`.
        model (ModelV2): The ModelV2 object used to produce inputs for this
            distribution.
        child_distributions (any[torch.Tensor]): Any struct
            that contains the child distribution classes to use to
            instantiate the child distributions from `inputs`. This could
            be an already flattened list or a struct according to
            `action_space`.
        input_lens (any[int]): A flat list or a nested struct of input
            split lengths used to split `inputs`.
        action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
            and possibly nested action space.
    """
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs)
    super().__init__(inputs, model)

    self.action_space_struct = get_base_struct_from_space(action_space)

    split_sizes = tree.flatten(input_lens)
    child_classes = tree.flatten(child_distributions)
    # One slice of `inputs` along dim 1 per child distribution.
    input_slices = list(torch.split(inputs, split_sizes, dim=1))

    def _instantiate(dist_cls, dist_inputs):
        # Each child is built from its own input slice and the shared model.
        return dist_cls(dist_inputs, model)

    self.flat_child_distributions = tree.map_structure(
        _instantiate, child_classes, input_slices)
示例14: logp
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def logp(self, x):
    """Return the summed log-probability of `x` over all child distributions.

    Args:
        x: An `np.ndarray` (converted with `torch.Tensor(...)`), a single
            `torch.Tensor` that is split along dim 1 into one slice per
            child distribution, or a (possibly nested) struct that is
            flattened into one item per child.

    Returns:
        The sum of each child distribution's `logp` for its component.
    """
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)

    children = self.flat_child_distributions

    if isinstance(x, torch.Tensor):
        # Single tensor input (all merged): a TorchCategorical child takes
        # one column, any other child as many as its sample has on dim 1.
        split_sizes = [
            1 if isinstance(child, TorchCategorical)
            else child.sample().size()[1]
            for child in children
        ]
        components = list(torch.split(x, split_sizes, dim=1))
    else:
        # Structured or flattened (by single action component) input.
        components = tree.flatten(x)

    def _component_logp(value, child):
        # Remove extra categorical dimension.
        if isinstance(child, TorchCategorical):
            value = torch.squeeze(value, dim=-1).int()
        return child.logp(value)

    per_child_logps = tree.map_structure(
        _component_logp, components, children)
    return functools.reduce(lambda a, b: a + b, per_child_logps)
示例15: _step
# 需要导入模块: import tree [as 别名]
# 或者: from tree import flatten [as 别名]
def _step(self, transitions: Sequence[tf.Tensor]):
    """Does a step of SGD for the whole ensemble over `transitions`.

    Args:
        transitions: Sequence of seven tensors unpacked as
            `(o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t)`. `m_t` and `z_t` are
            indexed as `[:, k]`, i.e. one column per ensemble member.
            NOTE(review): presumably observation, action, reward, discount,
            next observation, bootstrap mask and reward noise -- confirm
            against the caller.
    """
    o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
    variables = tree.flatten(
        [member.trainable_variables for member in self._ensemble])

    with tf.GradientTape() as tape:
        member_losses = []
        for k in range(self._num_ensemble):
            online_net = self._ensemble[k]
            target_net = self._target_ensemble[k]

            # Q-learning loss with added reward noise + half-in bootstrap.
            q_tm1 = online_net(o_tm1)
            action_mask = tf.one_hot(a_tm1, depth=self._num_actions)
            q_taken = tf.reduce_sum(q_tm1 * action_mask, axis=-1)
            # The target network's max value is treated as a constant.
            bootstrap_value = tf.stop_gradient(
                tf.reduce_max(target_net(o_t), axis=-1))
            td_target = (
                r_t + z_t[:, k] + self._discount * d_t * bootstrap_value)
            member_losses.append(
                tf.square(q_taken - td_target) * m_t[:, k])
        loss = tf.reduce_mean(tf.stack(member_losses))

    gradients = tape.gradient(loss, variables)
    self._total_steps.assign_add(1)
    self._optimizer.apply(gradients, variables)

    # Periodically update the target network.
    if tf.math.mod(self._total_steps, self._target_update_period) == 0:
        for k in range(self._num_ensemble):
            online_vars = self._ensemble[k].variables
            target_vars = self._target_ensemble[k].variables
            for src, dest in zip(online_vars, target_vars):
                dest.assign(src)