当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.tensor_scatter_nd_update用法及代码示例


根据 indicesupdates 分散到现有张量中。

用法

tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)

参数

  • tensor 要复制/更新的张量。
  • indices 要更新的 index 。
  • updates 适用于 index 的更新。
  • name 操作的可选名称。

返回

  • 具有给定形状的新张量,并根据索引应用更新。

此操作通过将稀疏的 updates 应用于输入 tensor 来创建一个新的张量。这类似于索引分配。

# Not implemented: tensors cannot be updated inplace.
tensor[indices] = updates

如果在 CPU 上发现超出范围的索引,则返回错误。

警告:此操作有一些特定于 GPU 的语义。

  • 如果发现超出范围的索引,则忽略该索引。
  • 应用更新的顺序是不确定的,因此如果 indices 包含重复项,则输出将是不确定的。

此操作与 tf.scatter_nd 非常相似,只是更新分散在现有张量上(而不是 zero-tensor)。如果无法重新使用现有张量的内存,则制作并更新副本。

一般来说:

  • indices 是一个整数张量 - 要在 tensor 中更新的索引。
  • indices至少两个轴,最后一个轴是索引向量的深度。
  • 对于 indices 中的每个索引向量,在 updates 中都有相应的条目。
  • 如果索引向量的长度与 tensor 的等级匹配,则索引向量每个都指向 tensor 中的标量,并且每次更新都是一个标量。
  • 如果索引向量的长度小于 tensor 的秩,则每个索引向量都指向 tensor 的切片,并且更新的形状必须与该切片匹配。

总体而言,这导致以下形状约束:

assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape

典型用法通常比这种一般形式简单得多,从简单的例子开始可以更好地理解:

标量更新

最简单的用法是按索引将标量元素插入张量。在这种情况下,index_depth 必须等于输入 tensor 的等级,切片 indices 的每一列是输入 tensor 轴的索引。

在这个最简单的情况下,形状约束是:

num_updates, index_depth = indices.shape.as_list()
assert updates.shape == [num_updates]
assert index_depth == tf.rank(tensor)`

例如,在 8 个元素的 rank-1 张量中插入 4 个分散元素。

这个分散操作看起来像这样:

tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
updates = [9, 10, 11, 12]            # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 0 9  0 10  11  0  0 12], shape=(8,), dtype=int32)

updates 的长度(第一轴)必须等于 indices 的长度:num_updates。这是插入的更新数。每个标量更新都插入到索引位置的tensor

对于更高等级的输入 tensor 标量更新可以通过使用匹配 tf.rank(tensor) index_depth 插入:

tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor(
    [[ 1  5]
     [ 1  1]
     [10  1]], shape=(3, 2), dtype=int32)

切片更新

当输入tensor 具有多个轴散布时,可用于更新整个切片。

在这种情况下,将输入tensor 视为两级array-of-arrays 会很有帮助。这个两级数组的形状分为 outer_shapeinner_shape

indices 索引输入张量的外层(outer_shape)。并将该位置的 sub-array 替换为 updates 列表中的相应项目。每次更新的形状是 inner_shape

更新切片列表时,形状约束为:

num_updates, index_depth = indices.shape.as_list()
inner_shape = tensor.shape[:index_depth]
outer_shape = tensor.shape[index_depth:]
assert updates.shape == [num_updates, inner_shape]

例如,要更新 (6, 3) tensor 的行:

tensor = tf.zeros([6, 3], dtype=tf.int32)

使用索引深度为 1。

indices = tf.constant([[2], [4]])     # num_updates == 2, index_depth == 1
num_updates, index_depth = indices.shape.as_list()

outer_shape6 ,内部形状是 3

outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]

正在索引 2 行,因此必须提供 2 updates。每个更新的形状都必须与 inner_shape 相匹配。

# num_updates == 2, inner_shape==3
updates = tf.constant([[1, 2, 3],
                       [4, 5, 6]])

总而言之,这给出了:

tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[0, 0, 0],
       [0, 0, 0],
       [1, 2, 3],
       [0, 0, 0],
       [4, 5, 6],
       [0, 0, 0]], dtype=int32)

更多切片更新示例

代表一批大小均匀的视频剪辑的张量自然有 5 个轴:[batch_size, time, width, height, channels]

例如:

batch_size, time, width, height, channels = 13,11,7,5,3
video_batch = tf.zeros([batch_size, time, width, height, channels])

要替换选择的视频剪辑:

  • 使用 1 的 index_depth(索引 outer_shape : [batch_size] )
  • 为每个更新提供与 inner_shape 匹配的形状:[time, width, height, channels]

用一个替换前两个剪辑:

indices = [[0],[1]]
new_clips = tf.ones([2, time, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_clips)

要替换视频中的选定帧:

  • indices 对于 outer_shapeindex_depth 必须为 2:[batch_size, time]
  • updates 的形状必须像图像列表。每个更新都必须有一个形状,匹配 inner_shape : [width, height, channels]

要替换前三个视频剪辑的第一帧:

indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
new_images = tf.ones([
  # num_updates=3, inner_shape=(width, height, channels)
  3, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_images)

折叠索引

在简单的情况下,将indicesupdates 视为列表很方便,但这不是严格的要求。 indicesupdates 可以折叠成 batch_shape ,而不是平面 num_updates 。这个 batch_shapeindices 的所有轴,除了最里面的 index_depth 轴。

index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]

注意:一个例外是 batch_shape 不能是 [] 。您不能通过传递形状为 [index_depth] 的索引来更新单个索引。

updates 必须具有匹配的 batch_shape(inner_shape 之前的轴)。

assert updates.shape == batch_shape + inner_shape

注意:结果等效于展平 indicesupdatesbatch_shape 轴。当构造"folded" 索引和更新更自然时,这种概括只是避免了重塑的需要。

通过这种概括,完整的形状约束是:

assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape

例如,要在 (5,5) 矩阵上绘制 X,请从以下索引开始:

tensor = tf.zeros([5,5])
indices = tf.constant([
 [[0,0],
  [1,1],
  [2,2],
  [3,3],
  [4,4]],
 [[0,4],
  [1,3],
  [2,2],
  [3,1],
  [4,0]],
])
indices.shape.as_list()  # batch_shape == [2, 5], index_depth == 2
[2, 5, 2]

这里 indices 的形状不是 [num_updates, index_depth] ,而是 batch_shape+[index_depth] 的形状。

由于 index_depth 等于 tensor 的等级:

  • outer_shape(5,5)
  • inner_shape() - 每次更新都是标量
  • updates.shapebatch_shape + inner_shape == (5,2) + ()
updates = [
  [1,1,1,1,1],
  [1,1,1,1,1],
]

把它放在一起给出:

tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[1., 0., 0., 0., 1.],
       [0., 1., 0., 1., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 1., 0.],
       [1., 0., 0., 0., 1.]], dtype=float32)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.tensor_scatter_nd_update。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。