當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.tensor_scatter_nd_update用法及代碼示例


根據 indicesupdates 分散到現有張量中。

用法

tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)

參數

  • tensor 要複製/更新的張量。
  • indices 要更新的 index 。
  • updates 適用於 index 的更新。
  • name 操作的可選名稱。

返回

  • 具有給定形狀的新張量,並根據索引應用更新。

此操作通過將稀疏的 updates 應用於輸入 tensor 來創建一個新的張量。這類似於索引分配。

# Not implemented: tensors cannot be updated inplace.
tensor[indices] = updates

如果在 CPU 上發現超出範圍的索引,則返回錯誤。

警告:此操作有一些特定於 GPU 的語義。

  • 如果發現超出範圍的索引,則忽略該索引。
  • 應用更新的順序是不確定的,因此如果 indices 包含重複項,則輸出將是不確定的。

此操作與 tf.scatter_nd 非常相似,隻是更新分散在現有張量上(而不是 zero-tensor)。如果無法重新使用現有張量的內存,則製作並更新副本。

一般來說:

  • indices 是一個整數張量 - 要在 tensor 中更新的索引。
  • indices至少兩個軸,最後一個軸是索引向量的深度。
  • 對於 indices 中的每個索引向量,在 updates 中都有相應的條目。
  • 如果索引向量的長度與 tensor 的等級匹配,則索引向量每個都指向 tensor 中的標量,並且每次更新都是一個標量。
  • 如果索引向量的長度小於 tensor 的秩,則每個索引向量都指向 tensor 的切片,並且更新的形狀必須與該切片匹配。

總體而言,這導致以下形狀約束:

assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape

典型用法通常比這種一般形式簡單得多,從簡單的例子開始可以更好地理解:

標量更新

最簡單的用法是按索引將標量元素插入張量。在這種情況下,index_depth 必須等於輸入 tensor 的等級,切片 indices 的每一列是輸入 tensor 軸的索引。

在這個最簡單的情況下,形狀約束是:

num_updates, index_depth = indices.shape.as_list()
assert updates.shape == [num_updates]
assert index_depth == tf.rank(tensor)`

例如,在 8 個元素的 rank-1 張量中插入 4 個分散元素。

這個分散操作看起來像這樣:

tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
updates = [9, 10, 11, 12]            # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 0 9  0 10  11  0  0 12], shape=(8,), dtype=int32)

updates 的長度(第一軸)必須等於 indices 的長度:num_updates。這是插入的更新數。每個標量更新都插入到索引位置的tensor

對於更高等級的輸入 tensor 標量更新可以通過使用匹配 tf.rank(tensor) index_depth 插入:

tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor(
    [[ 1  5]
     [ 1  1]
     [10  1]], shape=(3, 2), dtype=int32)

切片更新

當輸入tensor 具有多個軸散布時,可用於更新整個切片。

在這種情況下,將輸入tensor 視為兩級array-of-arrays 會很有幫助。這個兩級數組的形狀分為 outer_shapeinner_shape

indices 索引輸入張量的外層(outer_shape)。並將該位置的 sub-array 替換為 updates 列表中的相應項目。每次更新的形狀是 inner_shape

更新切片列表時,形狀約束為:

num_updates, index_depth = indices.shape.as_list()
inner_shape = tensor.shape[:index_depth]
outer_shape = tensor.shape[index_depth:]
assert updates.shape == [num_updates, inner_shape]

例如,要更新 (6, 3) tensor 的行:

tensor = tf.zeros([6, 3], dtype=tf.int32)

使用索引深度為 1。

indices = tf.constant([[2], [4]])     # num_updates == 2, index_depth == 1
num_updates, index_depth = indices.shape.as_list()

outer_shape6 ,內部形狀是 3

outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]

正在索引 2 行,因此必須提供 2 updates。每個更新的形狀都必須與 inner_shape 相匹配。

# num_updates == 2, inner_shape==3
updates = tf.constant([[1, 2, 3],
                       [4, 5, 6]])

總而言之,這給出了:

tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[0, 0, 0],
       [0, 0, 0],
       [1, 2, 3],
       [0, 0, 0],
       [4, 5, 6],
       [0, 0, 0]], dtype=int32)

更多切片更新示例

代表一批大小均勻的視頻剪輯的張量自然有 5 個軸:[batch_size, time, width, height, channels]

例如:

batch_size, time, width, height, channels = 13,11,7,5,3
video_batch = tf.zeros([batch_size, time, width, height, channels])

要替換選擇的視頻剪輯:

  • 使用 1 的 index_depth(索引 outer_shape : [batch_size] )
  • 為每個更新提供與 inner_shape 匹配的形狀:[time, width, height, channels]

用一個替換前兩個剪輯:

indices = [[0],[1]]
new_clips = tf.ones([2, time, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_clips)

要替換視頻中的選定幀:

  • indices 對於 outer_shapeindex_depth 必須為 2:[batch_size, time]
  • updates 的形狀必須像圖像列表。每個更新都必須有一個形狀,匹配 inner_shape : [width, height, channels]

要替換前三個視頻剪輯的第一幀:

indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
new_images = tf.ones([
  # num_updates=3, inner_shape=(width, height, channels)
  3, width, height, channels])
tf.tensor_scatter_nd_update(video_batch, indices, new_images)

折疊索引

在簡單的情況下,將indicesupdates 視為列表很方便,但這不是嚴格的要求。 indicesupdates 可以折疊成 batch_shape ,而不是平麵 num_updates 。這個 batch_shapeindices 的所有軸,除了最裏麵的 index_depth 軸。

index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]

注意:一個例外是 batch_shape 不能是 [] 。您不能通過傳遞形狀為 [index_depth] 的索引來更新單個索引。

updates 必須具有匹配的 batch_shape(inner_shape 之前的軸)。

assert updates.shape == batch_shape + inner_shape

注意:結果等效於展平 indicesupdatesbatch_shape 軸。當構造"folded" 索引和更新更自然時,這種概括隻是避免了重塑的需要。

通過這種概括,完整的形狀約束是:

assert tf.rank(indices) >= 2
index_depth = indices.shape[-1]
batch_shape = indices.shape[:-1]
assert index_depth <= tf.rank(tensor)
outer_shape = tensor.shape[:index_depth]
inner_shape = tensor.shape[index_depth:]
assert updates.shape == batch_shape + inner_shape

例如,要在 (5,5) 矩陣上繪製 X,請從以下索引開始:

tensor = tf.zeros([5,5])
indices = tf.constant([
 [[0,0],
  [1,1],
  [2,2],
  [3,3],
  [4,4]],
 [[0,4],
  [1,3],
  [2,2],
  [3,1],
  [4,0]],
])
indices.shape.as_list()  # batch_shape == [2, 5], index_depth == 2
[2, 5, 2]

這裏 indices 的形狀不是 [num_updates, index_depth] ,而是 batch_shape+[index_depth] 的形狀。

由於 index_depth 等於 tensor 的等級:

  • outer_shape(5,5)
  • inner_shape() - 每次更新都是標量
  • updates.shapebatch_shape + inner_shape == (5,2) + ()
updates = [
  [1,1,1,1,1],
  [1,1,1,1,1],
]

把它放在一起給出:

tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
array([[1., 0., 0., 0., 1.],
       [0., 1., 0., 1., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 1., 0.],
       [1., 0., 0., 0., 1.]], dtype=float32)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.tensor_scatter_nd_update。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。