當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.raw_ops.RaggedGather用法及代碼示例


根據 indicesparams0 收集參差不齊的切片。

用法

tf.raw_ops.RaggedGather(
    params_nested_splits, params_dense_values, indices, OUTPUT_RAGGED_RANK,
    name=None
)

參數

  • params_nested_splits 至少 1 個具有相同類型的 Tensor 對象的列表:int32 , int64。為 params RaggedTensor 輸入定義 row-partitioning 的 nested_row_splits 張量。
  • params_dense_values 一個Tensorparams RaggedTensor 的 flat_values。 python 級別的術語從dense_values 更改為flat_values,因此dense_values 是不推薦使用的名稱。
  • indices 一個Tensor。必須是以下類型之一:int32 , int64。應該收集的值的params 的最外層維度中的索引。
  • OUTPUT_RAGGED_RANK int>= 0 。輸出 RaggedTensor 的參差不齊等級。 output_nested_splits 將包含這個數量的 row_splits 張量。該值應等於 indices.shape.ndims + params.ragged_rank - 1
  • name 操作的名稱(可選)。

返回

  • Tensor 對象的元組(output_nested_splits、output_dense_values)。
  • output_nested_splits params_nested_splits 具有相同類型的 OUTPUT_RAGGED_RANK Tensor 對象的列表。
  • output_dense_values 一個Tensor。具有與 params_dense_values 相同的類型。

輸出由 output_dense_valuesoutput_nested_splits 組成的 RaggedTensor 輸出,例如:

output.shape = indices.shape + params.shape[1:]
output.ragged_rank = indices.shape.ndims + params.ragged_rank
output[i...j, d0...dn] = params[indices[i...j], d0...dn]

其中

  • params = ragged.from_nested_row_splits(params_dense_values, params_nested_splits)提供應收集的值。
  • indices 是具有 dtype int32int64 的密集張量,指示應收集哪些值。
  • output = ragged.from_nested_row_splits(output_dense_values, output_nested_splits)是輸出張量。

(注意:這個c++ op用於實現higher-level python tf.ragged.gather op,它也支持不規則索引。)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.raw_ops.RaggedGather。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。