根据 indices
将 updates
分散到形状为 shape
的张量中。
用法
tf.scatter_nd(
indices, updates, shape, name=None
)
参数
-
indices
一个Tensor
。必须是以下类型之一:int32
,int64
。 index 的张量。 -
updates
一个Tensor
。分散到输出张量中的值。 -
shape
一个Tensor
。必须与indices
具有相同的类型。一维。输出张量的形状。 -
name
操作的名称(可选)。
返回
-
一个
Tensor
。具有与updates
相同的类型。
根据指定的 indices
的各个值,通过散布稀疏的 updates
来更新输入张量。此操作返回带有您指定的 shape
的 output
张量。此操作是 tf.gather_nd
运算符的逆运算,它从给定的张量中提取值或切片。
此操作类似于 tf.tensor_scatter_add
,不同之处在于张量是零初始化的。调用 tf.scatter_nd(indices, values, shape)
与调用 tf.tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values)
相同。
如果indices
包含重复项,则将重复项values
累加(求和)。
警告:应用更新的顺序是不确定的,因此如果 indices
包含重复项,则输出将是不确定的;由于某些数值近似问题,以不同顺序求和的数字可能会产生不同的结果。
indices
是形状为 shape
的整数张量。 indices
的最后一个维度最多可以是 shape
的等级:
indices.shape[-1] <= shape.rank
indices
的最后一个维度对应于 shape
的维度 indices.shape[-1]
的元素索引(如果 indices.shape[-1] = shape.rank
)或切片(如果 indices.shape[-1] < shape.rank
)。
updates
是一个具有形状的张量:
indices.shape[:-1] + shape[indices.shape[-1]:]
scatter 操作的最简单形式是按索引在张量中插入单个元素。考虑一个示例,您希望在具有 8 个元素的 rank-1 张量中插入 4 个分散元素。
在 Python 中,这个分散操作看起来像这样:
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
生成的张量如下所示:
[0, 11, 0, 10, 9, 0, 0, 12]
您还可以一次插入更高等级张量的整个切片。例如,您可以在具有两个新值矩阵的 rank-3 张量的第一个维度中插入两个切片。
在 Python 中,这个分散操作看起来像这样:
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
生成的张量如下所示:
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
请注意,在 CPU 上,如果发现超出范围的索引,则会返回错误。在 GPU 上,如果发现超出范围的索引,则忽略该索引。
相关用法
- Python tf.scan用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.strings.substr用法及代码示例
- Python tf.strings.reduce_join用法及代码示例
- Python tf.sparse.cross用法及代码示例
- Python tf.sparse.mask用法及代码示例
- Python tf.strings.regex_full_match用法及代码示例
- Python tf.sparse.split用法及代码示例
- Python tf.strings.regex_replace用法及代码示例
- Python tf.signal.overlap_and_add用法及代码示例
- Python tf.strings.length用法及代码示例
- Python tf.strided_slice用法及代码示例
- Python tf.sparse.to_dense用法及代码示例
- Python tf.strings.bytes_split用法及代码示例
- Python tf.summary.text用法及代码示例
- Python tf.shape用法及代码示例
- Python tf.sparse.expand_dims用法及代码示例
- Python tf.signal.frame用法及代码示例
- Python tf.sparse.maximum用法及代码示例
- Python tf.signal.linear_to_mel_weight_matrix用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.scatter_nd。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。