當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.raw_ops.GatherNd用法及代碼示例


params 中的切片收集到形狀由 indices 指定的張量中。

用法

tf.raw_ops.GatherNd(
    params, indices, name=None
)

參數

  • params 一個Tensor。從中收集值的張量。
  • indices 一個Tensor。必須是以下類型之一:int32int64。索引張量。
  • name 操作的名稱(可選)。

返回

  • 一個Tensor。具有與 params 相同的類型。

indices 是一個 K-dimensional 整數張量,最好將其視為 params 索引的 (K-1) 維張量,其中每個元素定義一個 params 切片:

output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]

而在 tf.gather indices 將切片定義為 paramsaxis 維度,而在 tf.gather_nd 中,indices 將切片定義為 params 的第一個 N 維度,其中 N = indices.shape[-1]

indices 的最後一個維度最多可以是 params 的等級:

indices.shape[-1] <= params.rank

indices 的最後一個維度對應於 params 的維度 indices.shape[-1] 的元素(如果是 indices.shape[-1] == params.rank )或切片(如果是 indices.shape[-1] < params.rank )。輸出張量具有形狀

indices.shape[:-1] + params.shape[indices.shape[-1]:]

請注意,在 CPU 上,如果發現超出範圍的索引,則會返回錯誤。在 GPU 上,如果發現超出範圍的索引,則將 0 存儲在相應的輸出值中。

下麵的一些例子。

對矩陣的簡單索引:

indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']

將索引切片為矩陣:

indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]

索引到 3-張量:

indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]


    indices = [[0, 1], [1, 0]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['c0', 'd0'], ['a1', 'b1']]


    indices = [[0, 0, 1], [1, 0, 1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = ['b0', 'b1']

批量索引到矩陣中:

indices = [[[0, 0]], [[0, 1]]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['a'], ['b']]

批量切片索引到矩陣中:

indices = [[[1]], [[0]]]
    params = [['a', 'b'], ['c', 'd']]
    output = [[['c', 'd']], [['a', 'b']]]

批量索引到 3-張量:

indices = [[[1]], [[0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[[['a1', 'b1'], ['c1', 'd1']]],
              [[['a0', 'b0'], ['c0', 'd0']]]]

    indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['c0', 'd0'], ['a1', 'b1']],
              [['a0', 'b0'], ['c1', 'd1']]]


    indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['b0', 'b1'], ['d0', 'c1']]

另見 tf.gather tf.batch_gather

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.raw_ops.GatherNd。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。