This article collects typical usage examples of the Python tree.map_structure method. If you are wondering what exactly tree.map_structure does, how to use it, and what real uses of it look like, the curated code examples below should help. You can also explore further usage of the tree module that the method belongs to.
The following shows 15 code examples of tree.map_structure, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
Example 1: convert_to_non_torch_type
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def convert_to_non_torch_type(stats):
    """Converts values in `stats` to non-Tensor numpy or python types.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all torch tensors
            being converted to numpy types.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to non-torch Tensor types.
    """

    # The mapping function used to numpyize torch Tensors.
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().item() if len(item.size()) == 0 else \
                item.cpu().detach().numpy()
        else:
            return item

    return tree.map_structure(mapping, stats)
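To illustrate how the helper above behaves, here is a minimal usage sketch; it assumes torch and the convert_to_non_torch_type function above are in scope, and the input values are made up for demonstration.
import torch

stats = {
    "loss": torch.tensor(0.25),        # 0-d tensor -> converted to a Python float
    "grads": {"w": torch.ones(2, 2)},  # nested tensor -> converted to a numpy array
    "step": 10,                        # non-tensor leaves are returned unchanged
}
out = convert_to_non_torch_type(stats)
print(type(out["loss"]), type(out["grads"]["w"]), out["step"])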
Example 2: clip_action
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def clip_action(action, action_space):
    """Clips all actions in `action` according to the given Space(s).

    Args:
        action (any): The (possibly nested) struct of single action
            components. A single np.ndarray for "primitive" action Spaces.
        action_space (any): The (possibly nested) struct of single action
            Space objects. Has to have the same structure as `action`.

    Returns:
        Any: A struct of clipped "primitive" actions with the same
            structure as `action`.
    """

    def map_(a, s):
        if isinstance(s, gym.spaces.Box):
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space)
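A minimal usage sketch, assuming gym and numpy are available and that the space struct mirrors the action struct (all values here are hypothetical):
import gym
import numpy as np

# Primitive case: a single Box space and a single action array.
box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
print(clip_action(np.array([1.5, -2.0], dtype=np.float32), box))  # -> [ 1. -1.]

# Nested case: plain dicts of actions and of matching Box spaces.
actions = {"move": np.array([3.0]), "turn": np.array([-5.0])}
spaces = {"move": gym.spaces.Box(-1.0, 1.0, (1,)), "turn": gym.spaces.Box(-2.0, 2.0, (1,))}
print(clip_action(actions, spaces))  # -> {'move': array([1.]), 'turn': array([-2.])}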
Example 3: testMapStructureWithStrings
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def testMapStructureWithStrings(self):
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    inp_a = ab_tuple(a="foo", b=("bar", "baz"))
    inp_b = ab_tuple(a=2, b=(1, 3))
    out = tree.map_structure(lambda string, repeats: string * repeats,
                             inp_a,
                             inp_b)
    self.assertEqual("foofoo", out.a)
    self.assertEqual("bar", out.b[0])
    self.assertEqual("bazbazbaz", out.b[1])

    nt = ab_tuple(a=("something", "something_else"),
                  b="yet another thing")
    rev_nt = tree.map_structure(lambda x: x[::-1], nt)
    # Check the output is the correct structure, and all strings are reversed.
    tree.assert_same_structure(nt, rev_nt)
    self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
    self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
    self.assertEqual(nt.b[::-1], rev_nt.b)
Example 4: nest_to_numpy
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def nest_to_numpy(nest_of_tensors):
    """Converts a nest of eager tensors to a nest of numpy arrays.

    Leaves non-`tf.Tensor` elements untouched.

    A common use case for this method is to transform a `graphs.GraphsTuple` of
    tensors into a `graphs.GraphsTuple` of arrays, or nests containing
    `graphs.GraphsTuple`s.

    Args:
        nest_of_tensors: Nest containing `tf.Tensor`s.

    Returns:
        A nest with the same structure where `tf.Tensor`s are replaced by numpy
        arrays and all other elements are kept the same.
    """
    return tree.map_structure(
        lambda x: x.numpy() if isinstance(x, tf.Tensor) else x,
        nest_of_tensors)
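A minimal usage sketch, assuming TensorFlow 2.x in eager mode and the nest_to_numpy helper above (the nest contents are made up):
import tensorflow as tf

nest = {"x": tf.constant([1.0, 2.0]), "meta": {"name": "batch0", "count": 3}}
out = nest_to_numpy(nest)
print(type(out["x"]), out["meta"])  # <class 'numpy.ndarray'> {'name': 'batch0', 'count': 3}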
Example 5: make_action
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def make_action(self):
    """Returns a single action conforming to the environment's action_spec()."""
    spec = self.environment.action_spec()
    return tree.map_structure(lambda s: s.generate_value(), spec)
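The generate_value() call comes from dm_env's spec classes; a rough sketch of the same pattern with hand-built specs (shapes and bounds are hypothetical, assumes dm_env is installed):
import numpy as np
import tree
from dm_env import specs

spec = {
    "gaze": specs.BoundedArray((2,), np.float32, minimum=-1.0, maximum=1.0),
    "mode": specs.DiscreteArray(num_values=3),
}
# Produces a structure of spec-conforming placeholder values.
action = tree.map_structure(lambda s: s.generate_value(), spec)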
Example 6: convert_to_torch_tensor
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def convert_to_torch_tensor(stats, device=None):
    """Converts any struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors.
        device (Optional[torch.device]): If given, move all created tensors
            to this device; otherwise leave them where they are.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types.
    """

    def mapping(item):
        # Already a torch tensor -> make sure it's on the right device.
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        # Special handling of "Repeated" values.
        elif isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values),
                item.lengths, item.max_len)
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
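A minimal usage sketch with plain numpy leaves (the RepeatedValues branch is RLlib-specific and not exercised here); assumes torch and numpy, and the input batch is made up:
import numpy as np

batch = {"obs": np.zeros((4, 3), dtype=np.float64), "t": 7}
out = convert_to_torch_tensor(batch)
print(out["obs"].dtype)  # torch.float32 (float64 inputs are downcast)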
Example 7: __init__
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    ActionDistribution.__init__(self, inputs, model)
    self.action_space_struct = get_base_struct_from_space(action_space)
    input_lens = np.array(input_lens, dtype=np.int32)
    split_inputs = tf.split(inputs, input_lens, axis=1)
    self.flat_child_distributions = tree.map_structure(
        lambda dist, input_: dist(input_, model), child_distributions,
        split_inputs)
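The constructor above zips two parallel structures (distribution constructors and split input tensors). The same multi-structure mapping can be sketched with plain callables instead of RLlib distribution classes (all names and sizes here are hypothetical):
import numpy as np
import tree

constructors = [lambda x: ("categorical", x.shape), lambda x: ("gaussian", x.shape)]
split_inputs = [np.zeros((1, 2)), np.zeros((1, 4))]
# Each constructor is paired with the matching input chunk.
children = tree.map_structure(lambda ctor, inp: ctor(inp), constructors, split_inputs)
# -> [('categorical', (1, 2)), ('gaussian', (1, 4))]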
Example 8: sample
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def sample(self):
    child_distributions = tree.unflatten_as(self.action_space_struct,
                                            self.flat_child_distributions)
    return tree.map_structure(lambda s: s.sample(), child_distributions)
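unflatten_as followed by map_structure is the core pattern here; a standalone sketch with plain values instead of distribution objects (the structure and leaves are made up):
import tree

struct = {"look": (None, None), "move": None}   # target structure (leaf values are ignored)
flat = [1, 2, 3]                                # flat leaves, in tree.flatten order
nested = tree.unflatten_as(struct, flat)        # -> {'look': (1, 2), 'move': 3}
doubled = tree.map_structure(lambda v: v * 2, nested)  # -> {'look': (2, 4), 'move': 6}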
Example 9: deterministic_sample
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def deterministic_sample(self):
    child_distributions = tree.unflatten_as(self.action_space_struct,
                                            self.flat_child_distributions)
    return tree.map_structure(lambda s: s.deterministic_sample(),
                              child_distributions)
Example 10: __init__
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    """Initializes a TorchMultiActionDistribution object.

    Args:
        inputs (torch.Tensor): A single tensor of shape [BATCH, size].
        model (ModelV2): The ModelV2 object used to produce inputs for this
            distribution.
        child_distributions (any[torch.Tensor]): Any struct
            that contains the child distribution classes to use to
            instantiate the child distributions from `inputs`. This could
            be an already flattened list or a struct according to
            `action_space`.
        input_lens (any[int]): A flat list or a nested struct of input
            split lengths used to split `inputs`.
        action_space (Union[gym.spaces.Dict, gym.spaces.Tuple]): The complex
            and possibly nested action space.
    """
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs)
    super().__init__(inputs, model)
    self.action_space_struct = get_base_struct_from_space(action_space)
    input_lens = tree.flatten(input_lens)
    flat_child_distributions = tree.flatten(child_distributions)
    split_inputs = torch.split(inputs, input_lens, dim=1)
    self.flat_child_distributions = tree.map_structure(
        lambda dist, input_: dist(input_, model), flat_child_distributions,
        list(split_inputs))
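The split step above cuts one flat [BATCH, size] tensor into per-component chunks; a standalone sketch of just that step (hypothetical sizes, assumes torch):
import torch

inputs = torch.zeros(5, 6)   # [BATCH, sum(input_lens)]
input_lens = [2, 4]          # per-component widths
split_inputs = torch.split(inputs, input_lens, dim=1)
print([t.shape for t in split_inputs])  # [torch.Size([5, 2]), torch.Size([5, 4])]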
Example 11: logp
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def logp(self, x):
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)
    # Single tensor input (all merged).
    if isinstance(x, torch.Tensor):
        split_indices = []
        for dist in self.flat_child_distributions:
            if isinstance(dist, TorchCategorical):
                split_indices.append(1)
            else:
                split_indices.append(dist.sample().size()[1])
        split_x = list(torch.split(x, split_indices, dim=1))
    # Structured or flattened (by single action component) input.
    else:
        split_x = tree.flatten(x)

    def map_(val, dist):
        # Remove extra categorical dimension.
        if isinstance(dist, TorchCategorical):
            val = torch.squeeze(val, dim=-1).int()
        return dist.logp(val)

    # Remove extra categorical dimension and take the logp of each
    # component.
    flat_logps = tree.map_structure(map_, split_x,
                                    self.flat_child_distributions)
    return functools.reduce(lambda a, b: a + b, flat_logps)
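The final functools.reduce simply sums the per-component log-probabilities elementwise; a tiny sketch with made-up tensors (assumes torch):
import functools
import torch

flat_logps = [torch.tensor([-0.1, -0.2]), torch.tensor([-0.3, -0.4])]
total_logp = functools.reduce(lambda a, b: a + b, flat_logps)
# total_logp == tensor([-0.4000, -0.6000])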
Example 12: compute_actions
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def compute_actions(self,
                    observation,
                    add_noise=False,
                    update=True,
                    **kwargs):
    # Batch is given as list of one.
    if isinstance(observation, list) and len(observation) == 1:
        observation = observation[0]
    observation = self.preprocessor.transform(observation)
    observation = self.observation_filter(observation[None], update=update)

    # `actions` is a list of (component) batches.
    # Eager mode.
    if not self.sess:
        dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
        dist = self.dist_class(dist_inputs, self.model)
        actions = dist.sample()
        actions = tree.map_structure(lambda a: a.numpy(), actions)
    # Graph mode.
    else:
        actions = self.sess.run(
            self.sampler, feed_dict={self.inputs: observation})

    if add_noise:
        actions = tree.map_structure(self._add_noise, actions,
                                     self.action_space_struct)
    # Convert `flat_actions` to a list of lists of action components
    # (list of single actions).
    actions = unbatch(actions)
    return actions
Example 13: logp
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def logp(self, x):
    # Single tensor input (all merged).
    if isinstance(x, (tf.Tensor, np.ndarray)):
        split_indices = []
        for dist in self.flat_child_distributions:
            if isinstance(dist, Categorical):
                split_indices.append(1)
            else:
                split_indices.append(tf.shape(dist.sample())[1])
        split_x = tf.split(x, split_indices, axis=1)
    # Structured or flattened (by single action component) input.
    else:
        split_x = tree.flatten(x)

    def map_(val, dist):
        # Remove extra categorical dimension.
        if isinstance(dist, Categorical):
            val = tf.cast(tf.squeeze(val, axis=-1), tf.int32)
        return dist.logp(val)

    # Remove extra categorical dimension and take the logp of each
    # component.
    flat_logps = tree.map_structure(map_, split_x,
                                    self.flat_child_distributions)
    return functools.reduce(lambda a, b: a + b, flat_logps)
Example 14: update
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def update(
    self,
    timestep: dm_env.TimeStep,
    action: base.Action,
    new_timestep: dm_env.TimeStep,
):
    """Receives a transition and performs a learning update."""
    self._buffer.append(timestep, action, new_timestep)

    # When the batch is full, do a step of SGD.
    if self._buffer.full() or new_timestep.last():
        trajectory = self._buffer.drain()
        trajectory = tree.map_structure(tf.convert_to_tensor, trajectory)
        self._step(trajectory)
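The tree.map_structure(tf.convert_to_tensor, trajectory) line is the only tree usage here; a standalone sketch of that conversion with a made-up trajectory (assumes TensorFlow 2.x and numpy):
import numpy as np
import tensorflow as tf
import tree

trajectory = {"observations": np.zeros((10, 4)), "rewards": np.ones((10,))}
trajectory = tree.map_structure(tf.convert_to_tensor, trajectory)
print(trajectory["rewards"].dtype)  # tf.float64 (numpy float64 arrays keep their dtype)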
Example 15: update
# Required module import: import tree [as alias]
# Or: from tree import map_structure [as alias]
def update(
    self,
    timestep: dm_env.TimeStep,
    action: base.Action,
    new_timestep: dm_env.TimeStep,
):
    """Receives a transition and performs a learning update."""
    self._buffer.append(timestep, action, new_timestep)

    if self._buffer.full() or new_timestep.last():
        trajectory = self._buffer.drain()
        trajectory = tree.map_structure(tf.convert_to_tensor, trajectory)
        self._rollout_initial_state = self._step(trajectory)