根據 indices
從現有張量中減去稀疏的 updates
。
用法
tf.tensor_scatter_nd_sub(
tensor, indices, updates, name=None
)
參數
-
tensor
一個Tensor
。要複製/更新的張量。 -
indices
一個Tensor
。必須是以下類型之一:int32
,int64
。索引張量。 -
updates
一個Tensor
。必須與tensor
具有相同的類型。更新分散到輸出中。 -
name
操作的名稱(可選)。
返回
-
一個
Tensor
。具有與tensor
相同的類型。
此操作通過從傳入的 tensor
中減去稀疏的 updates
來創建一個新張量。此操作與 tf.scatter_nd_sub
非常相似,不同之處在於從現有張量(而不是變量)中減去更新。如果無法重新使用現有張量的內存,則製作並更新副本。
indices
是一個整數張量,其中包含指向形狀為 shape
的新張量的索引。 indices
的最後一個維度最多可以是 shape
的等級:
indices.shape[-1] <= shape.rank
indices
的最後一個維度對應於沿 shape
的維度 indices.shape[-1]
的元素(如果是 indices.shape[-1] = shape.rank
)或切片(如果是 indices.shape[-1] < shape.rank
)的索引。 updates
是一個帶形狀的張量
indices.shape[:-1] + shape[indices.shape[-1]:]
tensor_scatter_sub 的最簡單形式是按索引從張量中減去單個元素。例如,假設我們想在一個有 8 個元素的 rank-1 張量中插入 4 個分散的元素。
在 Python 中,這個散點減法操作看起來像這樣:
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
tensor = tf.ones([8], dtype=tf.int32)
updated = tf.tensor_scatter_nd_sub(tensor, indices, updates)
print(updated)
生成的張量如下所示:
[1, -10, 1, -9, -8, 1, 1, -11]
我們還可以一次插入更高等級張量的整個切片。例如,如果我們想在具有兩個新值矩陣的 rank-3 張量的第一維中插入兩個切片。
在 Python 中,此分散添加操作如下所示:
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
tensor = tf.ones([4, 4, 4],dtype=tf.int32)
updated = tf.tensor_scatter_nd_sub(tensor, indices, updates)
print(updated)
生成的張量如下所示:
[[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]],
[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]],
[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]
請注意,在 CPU 上,如果發現超出範圍的索引,則會返回錯誤。在 GPU 上,如果發現超出範圍的索引,則忽略該索引。
相關用法
- Python tf.tensor_scatter_nd_max用法及代碼示例
- Python tf.tensor_scatter_nd_update用法及代碼示例
- Python tf.tensor_scatter_nd_add用法及代碼示例
- Python tf.test.is_built_with_rocm用法及代碼示例
- Python tf.test.TestCase.assertLogs用法及代碼示例
- Python tf.test.is_gpu_available用法及代碼示例
- Python tf.test.TestCase.assertItemsEqual用法及代碼示例
- Python tf.test.TestCase.assertWarns用法及代碼示例
- Python tf.test.TestCase.create_tempfile用法及代碼示例
- Python tf.test.TestCase.cached_session用法及代碼示例
- Python tf.test.TestCase.captureWritesToStream用法及代碼示例
- Python tf.test.create_local_cluster用法及代碼示例
- Python tf.test.TestCase.assertCountEqual用法及代碼示例
- Python tf.test.TestCase.assertRaises用法及代碼示例
- Python tf.test.is_built_with_cuda用法及代碼示例
- Python tf.test.compute_gradient用法及代碼示例
- Python tf.test.gpu_device_name用法及代碼示例
- Python tf.test.TestCase.session用法及代碼示例
- Python tf.test.TestCase.create_tempdir用法及代碼示例
- Python tf.test.is_built_with_gpu_support用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.tensor_scatter_nd_sub。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。