當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.scatter_nd用法及代碼示例


根據 indicesupdates 分散到形狀為 shape 的張量中。

用法

tf.scatter_nd(
    indices, updates, shape, name=None
)

參數

  • indices 一個Tensor。必須是以下類型之一:int32 , int64。 index 的張量。
  • updates 一個Tensor。分散到輸出張量中的值。
  • shape 一個Tensor。必須與 indices 具有相同的類型。一維。輸出張量的形狀。
  • name 操作的名稱(可選)。

返回

  • 一個Tensor。具有與 updates 相同的類型。

根據指定的 indices 的各個值,通過散布稀疏的 updates 來更新輸入張量。此操作返回帶有您指定的 shapeoutput 張量。此操作是 tf.gather_nd 運算符的逆運算,它從給定的張量中提取值或切片。

此操作類似於 tf.tensor_scatter_add ,不同之處在於張量是零初始化的。調用 tf.scatter_nd(indices, values, shape) 與調用 tf.tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values) 相同。

如果indices 包含重複項,則將重複項values 累加(求和)。

警告:應用更新的順序是不確定的,因此如果 indices 包含重複項,則輸出將是不確定的;由於某些數值近似問題,以不同順序求和的數字可能會產生不同的結果。

indices 是形狀為 shape 的整數張量。 indices 的最後一個維度最多可以是 shape 的等級:

indices.shape[-1] <= shape.rank

indices 的最後一個維度對應於 shape 的維度 indices.shape[-1] 的元素索引(如果 indices.shape[-1] = shape.rank )或切片(如果 indices.shape[-1] < shape.rank )。

updates 是一個具有形狀的張量:

indices.shape[:-1] + shape[indices.shape[-1]:]

scatter 操作的最簡單形式是按索引在張量中插入單個元素。考慮一個示例,您希望在具有 8 個元素的 rank-1 張量中插入 4 個分散元素。

在 Python 中,這個分散操作看起來像這樣:

indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    shape = tf.constant([8])
    scatter = tf.scatter_nd(indices, updates, shape)
    print(scatter)

生成的張量如下所示:

[0, 11, 0, 10, 9, 0, 0, 12]

您還可以一次插入更高等級張量的整個切片。例如,您可以在具有兩個新值矩陣的 rank-3 張量的第一個維度中插入兩個切片。

在 Python 中,這個分散操作看起來像這樣:

indices = tf.constant([[0], [2]])
    updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]]])
    shape = tf.constant([4, 4, 4])
    scatter = tf.scatter_nd(indices, updates, shape)
    print(scatter)

生成的張量如下所示:

[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
 [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]

請注意,在 CPU 上,如果發現超出範圍的索引,則會返回錯誤。在 GPU 上,如果發現超出範圍的索引,則忽略該索引。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.scatter_nd。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。