根据 indices
将 updates
分散到现有张量中。
用法
tf.tensor_scatter_nd_update(
tensor, indices, updates, name=None
)
参数
-
tensor
要复制/更新的张量。 -
indices
要更新的 index 。 -
updates
适用于 index 的更新。 -
name
操作的可选名称。
返回
- 具有给定形状的新张量,并根据索引应用更新。
此操作通过将稀疏的 updates
应用于输入 tensor
来创建一个新的张量。这类似于索引分配。
# Not implemented: tensors cannot be updated inplace.
tensor[indices] = updates
如果在 CPU 上发现超出范围的索引,则返回错误。
警告:此操作有一些特定于 GPU 的语义。
- 如果发现超出范围的索引,则忽略该索引。
- 应用更新的顺序是不确定的,因此如果
indices
包含重复项,则输出将是不确定的。
此操作与
非常相似,只是更新分散在现有张量上(而不是 zero-tensor)。如果无法重新使用现有张量的内存,则制作并更新副本。tf.scatter_nd
一般来说:
indices
是一个整数张量 - 要在tensor
中更新的索引。indices
已至少两个轴,最后一个轴是索引向量的深度。- 对于
indices
中的每个索引向量,在updates
中都有相应的条目。 - 如果索引向量的长度与
tensor
的等级匹配,则索引向量每个都指向tensor
中的标量,并且每次更新都是一个标量。 - 如果索引向量的长度小于
tensor
的秩,则每个索引向量都指向tensor
的切片,并且更新的形状必须与该切片匹配。
总体而言,这导致以下形状约束:
assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape
典型用法通常比这种一般形式简单得多,从简单的例子开始可以更好地理解:
标量更新
最简单的用法是按索引将标量元素插入张量。在这种情况下,index_depth
必须等于输入 tensor
的等级,切片 indices
的每一列是输入 tensor
轴的索引。
在这个最简单的情况下,形状约束是:
num_updates, index_depth = indices.shape.as_list()
assert updates.shape == [num_updates]
assert index_depth == tf.rank(tensor)`
例如,在 8 个元素的 rank-1 张量中插入 4 个分散元素。
这个分散操作看起来像这样:
tensor = [0, 0, 0, 0, 0, 0, 0, 0] # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]] # num_updates == 4, index_depth == 1
updates = [9, 10, 11, 12] # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 0 9 0 10 11 0 0 12], shape=(8,), dtype=int32)
updates
的长度(第一轴)必须等于 indices
的长度:num_updates
。这是插入的更新数。每个标量更新都插入到索引位置的tensor
。
对于更高等级的输入 tensor
标量更新可以通过使用匹配
的 tf.rank(tensor)
index_depth
插入:
tensor = [[1, 1], [1, 1], [1, 1]] # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]] # num_updates == 2, index_depth == 2
updates = [5, 10] # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor(
[[ 1 5]
[ 1 1]
[10 1]], shape=(3, 2), dtype=int32)
切片更新
当输入tensor
具有多个轴散布时,可用于更新整个切片。
在这种情况下,将输入tensor
视为两级array-of-arrays 会很有帮助。这个两级数组的形状分为 outer_shape
和 inner_shape
。
indices
索引输入张量的外层(outer_shape
)。并将该位置的 sub-array 替换为 updates
列表中的相应项目。每次更新的形状是 inner_shape
。
更新切片列表时,形状约束为:
num_updates, index_depth = indices.shape.as_list()
inner_shape = tensor.shape[:index_depth]
outer_shape = tensor.shape[index_depth:]
assert updates.shape == [num_updates, inner_shape]
例如,要更新 (6, 3)
tensor
的行:
tensor = tf.zeros([6, 3], dtype=tf.int32)
使用索引深度为 1。
indices = tf.constant([[2], [4]]) # num_updates == 2, index_depth == 1
num_updates, index_depth = indices.shape.as_list()
outer_shape
是 6
,内部形状是 3
:
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
正在索引 2 行,因此必须提供 2 updates
。每个更新的形状都必须与 inner_shape
相匹配。
# num_updates == 2, inner_shape==3
updates = tf.constant([[1, 2, 3],
[4, 5, 6]])
总而言之,这给出了:
tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[0, 0, 0],
[0, 0, 0],
[1, 2, 3],
[0, 0, 0],
[4, 5, 6],
[0, 0, 0]], dtype=int32)
更多切片更新示例
代表一批大小均匀的视频剪辑的张量自然有 5 个轴:[batch_size, time, width, height, channels]
。
例如:
batch_size, time, width, height, channels = 13,11,7,5,3
video_batch = tf.zeros([batch_size, time, width, height, channels])
要替换选择的视频剪辑:
- 使用 1 的
index_depth
(索引outer_shape
:[batch_size]
) - 为每个更新提供与
inner_shape
匹配的形状:[time, width, height, channels]
。
用一个替换前两个剪辑:
indices = [[0],[1]]
new_clips = tf.ones([2, time, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
要替换视频中的选定帧:
indices
对于outer_shape
的index_depth
必须为 2:[batch_size, time]
。updates
的形状必须像图像列表。每个更新都必须有一个形状,匹配inner_shape
:[width, height, channels]
。
要替换前三个视频剪辑的第一帧:
indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
new_images = tf.ones([
# num_updates=3, inner_shape=(width, height, channels)
3, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_images)
折叠索引
在简单的情况下,将indices
和updates
视为列表很方便,但这不是严格的要求。 indices
和 updates
可以折叠成 batch_shape
,而不是平面 num_updates
。这个 batch_shape
是 indices
的所有轴,除了最里面的 index_depth
轴。
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
注意:一个例外是 batch_shape
不能是 []
。您不能通过传递形状为 [index_depth]
的索引来更新单个索引。
updates
必须具有匹配的 batch_shape
(inner_shape
之前的轴)。
assert updates.shape == batch_shape + inner_shape
注意:结果等效于展平 indices
和 updates
的 batch_shape
轴。当构造"folded" 索引和更新更自然时,这种概括只是避免了重塑的需要。
通过这种概括,完整的形状约束是:
assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape
例如,要在 (5,5)
矩阵上绘制 X
,请从以下索引开始:
tensor = tf.zeros([5,5])
indices = tf.constant([
[[0,0],
[1,1],
[2,2],
[3,3],
[4,4]],
[[0,4],
[1,3],
[2,2],
[3,1],
[4,0]],
])
indices.shape.as_list() # batch_shape == [2, 5], index_depth == 2
[2, 5, 2]
这里 indices
的形状不是 [num_updates, index_depth]
,而是 batch_shape+[index_depth]
的形状。
由于 index_depth
等于 tensor
的等级:
outer_shape
是(5,5)
inner_shape
是()
- 每次更新都是标量updates.shape
是batch_shape + inner_shape == (5,2) + ()
updates = [
[1,1,1,1,1],
[1,1,1,1,1],
]
把它放在一起给出:
tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[1., 0., 0., 0., 1.],
[0., 1., 0., 1., 0.],
[0., 0., 1., 0., 0.],
[0., 1., 0., 1., 0.],
[1., 0., 0., 0., 1.]], dtype=float32)
相关用法
- Python tf.tensor_scatter_nd_max用法及代码示例
- Python tf.tensor_scatter_nd_sub用法及代码示例
- Python tf.tensor_scatter_nd_add用法及代码示例
- Python tf.test.is_built_with_rocm用法及代码示例
- Python tf.test.TestCase.assertLogs用法及代码示例
- Python tf.test.is_gpu_available用法及代码示例
- Python tf.test.TestCase.assertItemsEqual用法及代码示例
- Python tf.test.TestCase.assertWarns用法及代码示例
- Python tf.test.TestCase.create_tempfile用法及代码示例
- Python tf.test.TestCase.cached_session用法及代码示例
- Python tf.test.TestCase.captureWritesToStream用法及代码示例
- Python tf.test.create_local_cluster用法及代码示例
- Python tf.test.TestCase.assertRaises用法及代码示例
- Python tf.test.TestCase.assertCountEqual用法及代码示例
- Python tf.test.is_built_with_cuda用法及代码示例
- Python tf.test.compute_gradient用法及代码示例
- Python tf.test.gpu_device_name用法及代码示例
- Python tf.test.TestCase.session用法及代码示例
- Python tf.test.TestCase.create_tempdir用法及代码示例
- Python tf.test.is_built_with_gpu_support用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.tensor_scatter_nd_update。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。