当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.scatter_nd用法及代码示例


根据 indicesupdates 分散到形状为 shape 的张量中。

用法

tf.scatter_nd(
    indices, updates, shape, name=None
)

参数

  • indices 一个Tensor。必须是以下类型之一:int32 , int64。 index 的张量。
  • updates 一个Tensor。分散到输出张量中的值。
  • shape 一个Tensor。必须与 indices 具有相同的类型。一维。输出张量的形状。
  • name 操作的名称(可选)。

返回

  • 一个Tensor。具有与 updates 相同的类型。

根据指定的 indices 的各个值,通过散布稀疏的 updates 来更新输入张量。此操作返回带有您指定的 shapeoutput 张量。此操作是 tf.gather_nd 运算符的逆运算,它从给定的张量中提取值或切片。

此操作类似于 tf.tensor_scatter_add ,不同之处在于张量是零初始化的。调用 tf.scatter_nd(indices, values, shape) 与调用 tf.tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values) 相同。

如果indices 包含重复项,则将重复项values 累加(求和)。

警告:应用更新的顺序是不确定的,因此如果 indices 包含重复项,则输出将是不确定的;由于某些数值近似问题,以不同顺序求和的数字可能会产生不同的结果。

indices 是形状为 shape 的整数张量。 indices 的最后一个维度最多可以是 shape 的等级:

indices.shape[-1] <= shape.rank

indices 的最后一个维度对应于 shape 的维度 indices.shape[-1] 的元素索引(如果 indices.shape[-1] = shape.rank )或切片(如果 indices.shape[-1] < shape.rank )。

updates 是一个具有形状的张量:

indices.shape[:-1] + shape[indices.shape[-1]:]

scatter 操作的最简单形式是按索引在张量中插入单个元素。考虑一个示例,您希望在具有 8 个元素的 rank-1 张量中插入 4 个分散元素。

在 Python 中,这个分散操作看起来像这样:

indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    shape = tf.constant([8])
    scatter = tf.scatter_nd(indices, updates, shape)
    print(scatter)

生成的张量如下所示:

[0, 11, 0, 10, 9, 0, 0, 12]

您还可以一次插入更高等级张量的整个切片。例如,您可以在具有两个新值矩阵的 rank-3 张量的第一个维度中插入两个切片。

在 Python 中,这个分散操作看起来像这样:

indices = tf.constant([[0], [2]])
    updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]]])
    shape = tf.constant([4, 4, 4])
    scatter = tf.scatter_nd(indices, updates, shape)
    print(scatter)

生成的张量如下所示:

[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
 [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]

请注意,在 CPU 上,如果发现超出范围的索引,则会返回错误。在 GPU 上,如果发现超出范围的索引,则忽略该索引。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.scatter_nd。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。