根據 indices
將 updates
分散到形狀為 shape
的張量中。
用法
tf.scatter_nd(
indices, updates, shape, name=None
)
參數
-
indices
一個Tensor
。必須是以下類型之一:int32
,int64
。 index 的張量。 -
updates
一個Tensor
。分散到輸出張量中的值。 -
shape
一個Tensor
。必須與indices
具有相同的類型。一維。輸出張量的形狀。 -
name
操作的名稱(可選)。
返回
-
一個
Tensor
。具有與updates
相同的類型。
根據指定的 indices
的各個值,通過散布稀疏的 updates
來更新輸入張量。此操作返回帶有您指定的 shape
的 output
張量。此操作是 tf.gather_nd
運算符的逆運算,它從給定的張量中提取值或切片。
此操作類似於 tf.tensor_scatter_add
,不同之處在於張量是零初始化的。調用 tf.scatter_nd(indices, values, shape)
與調用 tf.tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values)
相同。
如果indices
包含重複項,則將重複項values
累加(求和)。
警告:應用更新的順序是不確定的,因此如果 indices
包含重複項,則輸出將是不確定的;由於某些數值近似問題,以不同順序求和的數字可能會產生不同的結果。
indices
是形狀為 shape
的整數張量。 indices
的最後一個維度最多可以是 shape
的等級:
indices.shape[-1] <= shape.rank
indices
的最後一個維度對應於 shape
的維度 indices.shape[-1]
的元素索引(如果 indices.shape[-1] = shape.rank
)或切片(如果 indices.shape[-1] < shape.rank
)。
updates
是一個具有形狀的張量:
indices.shape[:-1] + shape[indices.shape[-1]:]
scatter 操作的最簡單形式是按索引在張量中插入單個元素。考慮一個示例,您希望在具有 8 個元素的 rank-1 張量中插入 4 個分散元素。
在 Python 中,這個分散操作看起來像這樣:
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
生成的張量如下所示:
[0, 11, 0, 10, 9, 0, 0, 12]
您還可以一次插入更高等級張量的整個切片。例如,您可以在具有兩個新值矩陣的 rank-3 張量的第一個維度中插入兩個切片。
在 Python 中,這個分散操作看起來像這樣:
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
生成的張量如下所示:
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
請注意,在 CPU 上,如果發現超出範圍的索引,則會返回錯誤。在 GPU 上,如果發現超出範圍的索引,則忽略該索引。
相關用法
- Python tf.scan用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.strings.substr用法及代碼示例
- Python tf.strings.reduce_join用法及代碼示例
- Python tf.sparse.cross用法及代碼示例
- Python tf.sparse.mask用法及代碼示例
- Python tf.strings.regex_full_match用法及代碼示例
- Python tf.sparse.split用法及代碼示例
- Python tf.strings.regex_replace用法及代碼示例
- Python tf.signal.overlap_and_add用法及代碼示例
- Python tf.strings.length用法及代碼示例
- Python tf.strided_slice用法及代碼示例
- Python tf.sparse.to_dense用法及代碼示例
- Python tf.strings.bytes_split用法及代碼示例
- Python tf.summary.text用法及代碼示例
- Python tf.shape用法及代碼示例
- Python tf.sparse.expand_dims用法及代碼示例
- Python tf.signal.frame用法及代碼示例
- Python tf.sparse.maximum用法及代碼示例
- Python tf.signal.linear_to_mel_weight_matrix用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.scatter_nd。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。