當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.tensor_scatter_nd_add用法及代碼示例


根據 indices 將稀疏的 updates 添加到現有張量。

用法

tf.tensor_scatter_nd_add(
    tensor, indices, updates, name=None
)

參數

  • tensor 一個Tensor。要複製/更新的張量。
  • indices 一個Tensor。必須是以下類型之一:int32 , int64。索引張量。
  • updates 一個Tensor。必須與 tensor 具有相同的類型。更新分散到輸出中。
  • name 操作的名稱(可選)。

返回

  • 一個Tensor。具有與 tensor 相同的類型。

此操作通過將稀疏的 updates 添加到傳入的 tensor 來創建一個新的張量。此操作與 tf.compat.v1.scatter_nd_add 非常相似,隻是將更新添加到現有張量(而不是變量)上。如果無法重新使用現有張量的內存,則製作並更新副本。

indices 是一個整數張量,其中包含指向形狀為 tensor.shape 的新張量的索引。 indices 的最後一個維度最多可以是 tensor.shape 的等級:

indices.shape[-1] <= tensor.shape.rank

indices 的最後一個維度對應於沿 tensor.shape 的維度 indices.shape[-1] 的元素(如果是 indices.shape[-1] = tensor.shape.rank )或切片(如果是 indices.shape[-1] < tensor.shape.rank )的索引。 updates 是一個帶形狀的張量

indices.shape[:-1] + tensor.shape[indices.shape[-1]:]

tensor_scatter_add 的最簡單形式是按索引將單個元素添加到張量。例如,假設我們想在 8 個元素的 rank-1 張量中添加 4 個元素。

在 Python 中,此分散添加操作如下所示:

indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    tensor = tf.ones([8], dtype=tf.int32)
    updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
    print(updated)

生成的張量如下所示:

[1, 12, 1, 11, 10, 1, 1, 13]

我們還可以一次插入更高等級張量的整個切片。例如,如果我們想在具有兩個新值矩陣的 rank-3 張量的第一維中插入兩個切片。

在 Python 中,此分散添加操作如下所示:

indices = tf.constant([[0], [2]])
    updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]]])
    tensor = tf.ones([4, 4, 4],dtype=tf.int32)
    updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
    print(updated)

生成的張量如下所示:

[[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
 [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
 [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
 [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]

請注意,在 CPU 上,如果發現超出範圍的索引,則會返回錯誤。在 GPU 上,如果發現超出範圍的索引,則忽略該索引。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.tensor_scatter_nd_add。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。