当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.gather_nd用法及代码示例


params 中的切片收集到形状由 indices 指定的张量中。

用法

tf.gather_nd(
    params, indices, batch_dims=0, name=None
)

参数

  • params 一个Tensor。从中收集值的张量。
  • indices 一个Tensor。必须是以下类型之一:int32 , int64。索引张量。
  • name 操作的名称(可选)。
  • batch_dims 整数或标量 'Tensor'。批次维度的数量。

返回

  • 一个Tensor。具有与 params 相同的类型。

indicesTensor 索引的 params 。索引向量沿 indices 的最后一个轴排列。

这类似于 tf.gather ,其中 indices 将切片定义为 params 的第一维。在 tf.gather_nd 中,indices 将切片定义为 params 的第一个 N 维度,其中 N = indices.shape[-1]

警告:在 CPU 上,如果发现超出范围的索引,则返回错误。在 GPU 上,如果发现超出范围的索引,则将 0 存储在相应的输出值中。

收集标量

在最简单的情况下,indices 中的向量索引 params 的完整等级:

tf.gather_nd(
    indices=[[0, 0],
             [1, 1]],
    params = [['a', 'b'],
              ['c', 'd']]).numpy()
array([b'a', b'd'], dtype=object)

在这种情况下,结果比 indices 少 1 轴,并且每个索引向量都被从 params 索引的标量替换。

在这种情况下,形状关系为:

index_depth = indices.shape[-1]
assert index_depth == params.shape.rank
result_shape = indices.shape[:-1]

如果 indices 的秩为 K ,则将 indices 视为 params 的索引的 (K-1) 维张量会很有帮助。

收集切片

如果索引向量没有索引 params 的完整等级,则结果中的每个位置都包含一个参数切片。此示例从矩阵中收集行:

tf.gather_nd(
    indices = [[1],
               [0]],
    params = [['a', 'b', 'c'],
              ['d', 'e', 'f']]).numpy()
array([[b'd', b'e', b'f'],
       [b'a', b'b', b'c']], dtype=object)

这里 indices 包含 [2] 索引向量,每个索引向量的长度为 1 。每个索引向量都引用params 矩阵的行。每行的形状为 [3] ,因此输出形状为 [2, 3]

在这种情况下,形状之间的关系是:

index_depth = indices.shape[-1]
outer_shape = indices.shape[:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[index_depth:]
output_shape = outer_shape + inner_shape

将这种情况下的结果视为tensors-of-tensors 会很有帮助。外部张量的形状由 indices 的前导维度设置。而内张量的形状是单片的形状。

批次

此外,paramsindices 都可以具有完全匹配的 M 前导批次尺寸。在这种情况下 batch_dims 必须设置为 M

例如,要从一批矩阵中的每一个中收集一行,您可以将索引向量的前导元素设置为它们在批处理中的位置:

tf.gather_nd(
    indices = [[0, 1],
               [1, 0],
               [2, 4],
               [3, 2],
               [4, 1]],
    params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]

batch_dims 参数允许您从索引中省略那些前导位置维度:

tf.gather_nd(
    batch_dims=1,
    indices = [[1],
               [0],
               [4],
               [2],
               [1]],
    params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]

这相当于为批次维度中的每个位置校准一个单独的gather_nd

params=tf.zeros([5, 7, 3])
indices=tf.zeros([5, 1])
batch_dims = 1

index_depth = indices.shape[-1]
batch_shape = indices.shape[:batch_dims]
assert params.shape[:batch_dims] == batch_shape
outer_shape = indices.shape[batch_dims:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[batch_dims + index_depth:]
output_shape = batch_shape + outer_shape + inner_shape
output_shape.as_list()
[5, 3]

更多示例

索引到 3-张量:

tf.gather_nd(
    indices = [[1]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'a1', b'b1'],
        [b'c1', b'd1']]], dtype=object)
tf.gather_nd(
    indices = [[0, 1], [1, 0]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
       [b'a1', b'b1']], dtype=object)
tf.gather_nd(
    indices = [[0, 0, 1], [1, 0, 1]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([b'b0', b'b1'], dtype=object)

以下示例适用于仅索引具有前导额外维度的情况。如果 'params' 和 'indices' 都具有前导批处理维度,请使用 'batch_dims' 参数以批处理模式运行 gather_nd。

批量索引到矩阵中:

tf.gather_nd(
    indices = [[[0, 0]], [[0, 1]]],
    params = [['a', 'b'], ['c', 'd']]).numpy()
array([[b'a'],
       [b'b']], dtype=object)

批量切片索引到矩阵中:

tf.gather_nd(
    indices = [[[1]], [[0]]],
    params = [['a', 'b'], ['c', 'd']]).numpy()
array([[[b'c', b'd']],
       [[b'a', b'b']]], dtype=object)

批量索引到 3-张量:

tf.gather_nd(
    indices = [[[1]], [[0]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[[b'a1', b'b1'],
         [b'c1', b'd1']]],
       [[[b'a0', b'b0'],
         [b'c0', b'd0']]]], dtype=object)
tf.gather_nd(
    indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0'],
        [b'a1', b'b1']],
       [[b'a0', b'b0'],
        [b'c1', b'd1']]], dtype=object)
tf.gather_nd(
    indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'b0', b'b1'],
       [b'd0', b'c1']], dtype=object)

批处理 'params' 和 'indices' 的示例:

tf.gather_nd(
    batch_dims = 1,
    indices = [[1],
               [0]],
    params = [[['a0', 'b0'],
               ['c0', 'd0']],
              [['a1', 'b1'],
               ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
       [b'a1', b'b1']], dtype=object)
tf.gather_nd(
    batch_dims = 1,
    indices = [[[1]], [[0]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0']],
       [[b'a1', b'b1']]], dtype=object)
tf.gather_nd(
    batch_dims = 1,
    indices = [[[1, 0]], [[0, 1]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0'],
       [b'b1']], dtype=object)

另见tf.gather

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.gather_nd。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。