当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.raw_ops.GatherV2用法及代码示例


根据 indicesparamsaxis 收集切片。

用法

tf.raw_ops.GatherV2(
    params, indices, axis, batch_dims=0, name=None
)

参数

  • params 一个Tensor。从中收集值的张量。必须至少排名 axis + 1
  • indices 一个Tensor。必须是以下类型之一:int32 , int64。索引张量。必须在 [0, params.shape[axis]) 范围内。
  • axis 一个Tensor。必须是以下类型之一:int32 , int64params 中要从中收集 indices 的轴。默认为第一个维度。支持负索引。
  • batch_dims 可选的 int 。默认为 0
  • name 操作的名称(可选)。

返回

  • 一个Tensor。具有与 params 相同的类型。

indices 必须是任意维度的整数张量(通常为 0-D 或 1-D)。生成形状为 params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:] 的输出张量,其中:

# Scalar indices (output is rank(params) - 1).
    output[a_0, ..., a_n, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices, b_0, ..., b_n]

    # Vector indices (output is rank(params)).
    output[a_0, ..., a_n, i, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

    # Higher rank indices (output is rank(params) + rank(indices) - 1).
    output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
      params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

请注意,在 CPU 上,如果发现超出范围的索引,则会返回错误。在 GPU 上,如果发现超出范围的索引,则将 0 存储在相应的输出值中。

另请参见 tf.batch_gathertf.gather_nd

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.raw_ops.GatherV2。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。