当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.Variable.scatter_nd_update用法及代码示例


用法

scatter_nd_update(
    indices, updates, name=None
)

参数

  • indices 要在操作中使用的索引。
  • updates 要在操作中使用的值。
  • name 操作的名称。

返回

  • 更新的变量。

对变量中的单个值或切片应用稀疏赋值。

变量具有排名 Pindices 是排名 QTensor

indices 必须是整数张量,包含对自身的索引。它必须是形状 [d_0, ..., d_{Q-2}, K] 其中 0 < K <= P

indices 的最内维度(长度为 K )对应于沿着 self 的第 K 维度的元素(如果是 K = P )或切片(如果是 K < P )的索引。

updates 是等级为 Q-1+P-KTensor,形状:

[d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].

例如,假设我们要将 4 个分散的元素添加到 rank-1 张量到 8 个元素。在 Python 中,该更新将如下所示:

v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    v.scatter_nd_update(indices, updates)
    print(v)

对 v 的更新结果如下所示:

[1, 11, 3, 10, 9, 6, 7, 12]

有关如何更新切片的更多详细信息,请参阅tf.scatter_nd

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Variable.scatter_nd_update。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。