當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.raw_ops.GatherV2用法及代碼示例


根據 indicesparamsaxis 收集切片。

用法

tf.raw_ops.GatherV2(
    params, indices, axis, batch_dims=0, name=None
)

參數

  • params 一個Tensor。從中收集值的張量。必須至少排名 axis + 1
  • indices 一個Tensor。必須是以下類型之一:int32 , int64。索引張量。必須在 [0, params.shape[axis]) 範圍內。
  • axis 一個Tensor。必須是以下類型之一:int32 , int64params 中要從中收集 indices 的軸。默認為第一個維度。支持負索引。
  • batch_dims 可選的 int 。默認為 0
  • name 操作的名稱(可選)。

返回

  • 一個Tensor。具有與 params 相同的類型。

indices 必須是任意維度的整數張量(通常為 0-D 或 1-D)。生成形狀為 params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:] 的輸出張量,其中:

# Scalar indices (output is rank(params) - 1).
    output[a_0, ..., a_n, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices, b_0, ..., b_n]

    # Vector indices (output is rank(params)).
    output[a_0, ..., a_n, i, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

    # Higher rank indices (output is rank(params) + rank(indices) - 1).
    output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
      params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

請注意,在 CPU 上,如果發現超出範圍的索引,則會返回錯誤。在 GPU 上,如果發現超出範圍的索引,則將 0 存儲在相應的輸出值中。

另請參見 tf.batch_gathertf.gather_nd

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.raw_ops.GatherV2。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。