当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.raw_ops.TensorScatterSub用法及代码示例


根据 indices 从现有张量中减去稀疏的 updates

用法

tf.raw_ops.TensorScatterSub(
    tensor, indices, updates, name=None
)

参数

  • tensor 一个Tensor。要复制/更新的张量。
  • indices 一个Tensor。必须是以下类型之一:int32 , int64。索引张量。
  • updates 一个Tensor。必须与 tensor 具有相同的类型。更新分散到输出中。
  • name 操作的名称(可选)。

返回

  • 一个Tensor。具有与 tensor 相同的类型。

此操作通过从传入的 tensor 中减去稀疏的 updates 来创建一个新张量。此操作与 tf.scatter_nd_sub 非常相似,不同之处在于从现有张量(而不是变量)中减去更新。如果无法重新使用现有张量的内存,则制作并更新副本。

indices 是一个整数张量,其中包含指向形状为 shape 的新张量的索引。 indices 的最后一个维度最多可以是 shape 的等级:

indices.shape[-1] <= shape.rank

indices 的最后一个维度对应于沿 shape 的维度 indices.shape[-1] 的元素(如果是 indices.shape[-1] = shape.rank )或切片(如果是 indices.shape[-1] < shape.rank )的索引。 updates 是一个带形状的张量

indices.shape[:-1] + shape[indices.shape[-1]:]

tensor_scatter_sub 的最简单形式是按索引从张量中减去单个元素。例如,假设我们想在一个有 8 个元素的 rank-1 张量中插入 4 个分散的元素。

在 Python 中,这个散点减法操作看起来像这样:

indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    tensor = tf.ones([8], dtype=tf.int32)
    updated = tf.tensor_scatter_nd_sub(tensor, indices, updates)
    print(updated)

生成的张量如下所示:

[1, -10, 1, -9, -8, 1, 1, -11]

我们还可以一次插入更高等级张量的整个切片。例如,如果我们想在具有两个新值矩阵的 rank-3 张量的第一维中插入两个切片。

在 Python 中,此分散添加操作如下所示:

indices = tf.constant([[0], [2]])
    updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]]])
    tensor = tf.ones([4, 4, 4],dtype=tf.int32)
    updated = tf.tensor_scatter_nd_sub(tensor, indices, updates)
    print(updated)

生成的张量如下所示:

[[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]],
 [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
 [[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]],
 [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]

请注意,在 CPU 上,如果发现超出范围的索引,则会返回错误。在 GPU 上,如果发现超出范围的索引,则忽略该索引。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.raw_ops.TensorScatterSub。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。