当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.raw_ops.GatherNd用法及代码示例


params 中的切片收集到形状由 indices 指定的张量中。

用法

tf.raw_ops.GatherNd(
    params, indices, name=None
)

参数

  • params 一个Tensor。从中收集值的张量。
  • indices 一个Tensor。必须是以下类型之一:int32int64。索引张量。
  • name 操作的名称(可选)。

返回

  • 一个Tensor。具有与 params 相同的类型。

indices 是一个 K-dimensional 整数张量,最好将其视为 params 索引的 (K-1) 维张量,其中每个元素定义一个 params 切片:

output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]

而在 tf.gather indices 将切片定义为 paramsaxis 维度,而在 tf.gather_nd 中,indices 将切片定义为 params 的第一个 N 维度,其中 N = indices.shape[-1]

indices 的最后一个维度最多可以是 params 的等级:

indices.shape[-1] <= params.rank

indices 的最后一个维度对应于 params 的维度 indices.shape[-1] 的元素(如果是 indices.shape[-1] == params.rank )或切片(如果是 indices.shape[-1] < params.rank )。输出张量具有形状

indices.shape[:-1] + params.shape[indices.shape[-1]:]

请注意,在 CPU 上,如果发现超出范围的索引,则会返回错误。在 GPU 上,如果发现超出范围的索引,则将 0 存储在相应的输出值中。

下面的一些例子。

对矩阵的简单索引:

indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']

将索引切片为矩阵:

indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]

索引到 3-张量:

indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]


    indices = [[0, 1], [1, 0]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['c0', 'd0'], ['a1', 'b1']]


    indices = [[0, 0, 1], [1, 0, 1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = ['b0', 'b1']

批量索引到矩阵中:

indices = [[[0, 0]], [[0, 1]]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['a'], ['b']]

批量切片索引到矩阵中:

indices = [[[1]], [[0]]]
    params = [['a', 'b'], ['c', 'd']]
    output = [[['c', 'd']], [['a', 'b']]]

批量索引到 3-张量:

indices = [[[1]], [[0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[[['a1', 'b1'], ['c1', 'd1']]],
              [[['a0', 'b0'], ['c0', 'd0']]]]

    indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['c0', 'd0'], ['a1', 'b1']],
              [['a0', 'b0'], ['c1', 'd1']]]


    indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['b0', 'b1'], ['d0', 'c1']]

另见 tf.gather tf.batch_gather

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.raw_ops.GatherNd。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。