当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.raw_ops.RaggedGather用法及代码示例


根据 indicesparams0 收集参差不齐的切片。

用法

tf.raw_ops.RaggedGather(
    params_nested_splits, params_dense_values, indices, OUTPUT_RAGGED_RANK,
    name=None
)

参数

  • params_nested_splits 至少 1 个具有相同类型的 Tensor 对象的列表:int32 , int64。为 params RaggedTensor 输入定义 row-partitioning 的 nested_row_splits 张量。
  • params_dense_values 一个Tensorparams RaggedTensor 的 flat_values。 python 级别的术语从dense_values 更改为flat_values,因此dense_values 是不推荐使用的名称。
  • indices 一个Tensor。必须是以下类型之一:int32 , int64。应该收集的值的params 的最外层维度中的索引。
  • OUTPUT_RAGGED_RANK int>= 0 。输出 RaggedTensor 的参差不齐等级。 output_nested_splits 将包含这个数量的 row_splits 张量。该值应等于 indices.shape.ndims + params.ragged_rank - 1
  • name 操作的名称(可选)。

返回

  • Tensor 对象的元组(output_nested_splits、output_dense_values)。
  • output_nested_splits params_nested_splits 具有相同类型的 OUTPUT_RAGGED_RANK Tensor 对象的列表。
  • output_dense_values 一个Tensor。具有与 params_dense_values 相同的类型。

输出由 output_dense_valuesoutput_nested_splits 组成的 RaggedTensor 输出,例如:

output.shape = indices.shape + params.shape[1:]
output.ragged_rank = indices.shape.ndims + params.ragged_rank
output[i...j, d0...dn] = params[indices[i...j], d0...dn]

其中

  • params = ragged.from_nested_row_splits(params_dense_values, params_nested_splits)提供应收集的值。
  • indices 是具有 dtype int32int64 的密集张量,指示应收集哪些值。
  • output = ragged.from_nested_row_splits(output_dense_values, output_nested_splits)是输出张量。

(注意:这个c++ op用于实现higher-level python tf.ragged.gather op,它也支持不规则索引。)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.raw_ops.RaggedGather。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。