當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.gather_nd用法及代碼示例


params 中的切片收集到形狀由 indices 指定的張量中。

用法

tf.gather_nd(
    params, indices, batch_dims=0, name=None
)

參數

  • params 一個Tensor。從中收集值的張量。
  • indices 一個Tensor。必須是以下類型之一:int32 , int64。索引張量。
  • name 操作的名稱(可選)。
  • batch_dims 整數或標量 'Tensor'。批次維度的數量。

返回

  • 一個Tensor。具有與 params 相同的類型。

indicesTensor 索引的 params 。索引向量沿 indices 的最後一個軸排列。

這類似於 tf.gather ,其中 indices 將切片定義為 params 的第一維。在 tf.gather_nd 中,indices 將切片定義為 params 的第一個 N 維度,其中 N = indices.shape[-1]

警告:在 CPU 上,如果發現超出範圍的索引,則返回錯誤。在 GPU 上,如果發現超出範圍的索引,則將 0 存儲在相應的輸出值中。

收集標量

在最簡單的情況下,indices 中的向量索引 params 的完整等級:

tf.gather_nd(
    indices=[[0, 0],
             [1, 1]],
    params = [['a', 'b'],
              ['c', 'd']]).numpy()
array([b'a', b'd'], dtype=object)

在這種情況下,結果比 indices 少 1 軸,並且每個索引向量都被從 params 索引的標量替換。

在這種情況下,形狀關係為:

index_depth = indices.shape[-1]
assert index_depth == params.shape.rank
result_shape = indices.shape[:-1]

如果 indices 的秩為 K ,則將 indices 視為 params 的索引的 (K-1) 維張量會很有幫助。

收集切片

如果索引向量沒有索引 params 的完整等級,則結果中的每個位置都包含一個參數切片。此示例從矩陣中收集行:

tf.gather_nd(
    indices = [[1],
               [0]],
    params = [['a', 'b', 'c'],
              ['d', 'e', 'f']]).numpy()
array([[b'd', b'e', b'f'],
       [b'a', b'b', b'c']], dtype=object)

這裏 indices 包含 [2] 索引向量,每個索引向量的長度為 1 。每個索引向量都引用params 矩陣的行。每行的形狀為 [3] ,因此輸出形狀為 [2, 3]

在這種情況下,形狀之間的關係是:

index_depth = indices.shape[-1]
outer_shape = indices.shape[:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[index_depth:]
output_shape = outer_shape + inner_shape

將這種情況下的結果視為tensors-of-tensors 會很有幫助。外部張量的形狀由 indices 的前導維度設置。而內張量的形狀是單片的形狀。

批次

此外,paramsindices 都可以具有完全匹配的 M 前導批次尺寸。在這種情況下 batch_dims 必須設置為 M

例如,要從一批矩陣中的每一個中收集一行,您可以將索引向量的前導元素設置為它們在批處理中的位置:

tf.gather_nd(
    indices = [[0, 1],
               [1, 0],
               [2, 4],
               [3, 2],
               [4, 1]],
    params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]

batch_dims 參數允許您從索引中省略那些前導位置維度:

tf.gather_nd(
    batch_dims=1,
    indices = [[1],
               [0],
               [4],
               [2],
               [1]],
    params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]

這相當於為批次維度中的每個位置校準一個單獨的gather_nd

params=tf.zeros([5, 7, 3])
indices=tf.zeros([5, 1])
batch_dims = 1

index_depth = indices.shape[-1]
batch_shape = indices.shape[:batch_dims]
assert params.shape[:batch_dims] == batch_shape
outer_shape = indices.shape[batch_dims:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[batch_dims + index_depth:]
output_shape = batch_shape + outer_shape + inner_shape
output_shape.as_list()
[5, 3]

更多示例

索引到 3-張量:

tf.gather_nd(
    indices = [[1]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'a1', b'b1'],
        [b'c1', b'd1']]], dtype=object)
tf.gather_nd(
    indices = [[0, 1], [1, 0]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
       [b'a1', b'b1']], dtype=object)
tf.gather_nd(
    indices = [[0, 0, 1], [1, 0, 1]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([b'b0', b'b1'], dtype=object)

以下示例適用於僅索引具有前導額外維度的情況。如果 'params' 和 'indices' 都具有前導批處理維度,請使用 'batch_dims' 參數以批處理模式運行 gather_nd。

批量索引到矩陣中:

tf.gather_nd(
    indices = [[[0, 0]], [[0, 1]]],
    params = [['a', 'b'], ['c', 'd']]).numpy()
array([[b'a'],
       [b'b']], dtype=object)

批量切片索引到矩陣中:

tf.gather_nd(
    indices = [[[1]], [[0]]],
    params = [['a', 'b'], ['c', 'd']]).numpy()
array([[[b'c', b'd']],
       [[b'a', b'b']]], dtype=object)

批量索引到 3-張量:

tf.gather_nd(
    indices = [[[1]], [[0]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[[b'a1', b'b1'],
         [b'c1', b'd1']]],
       [[[b'a0', b'b0'],
         [b'c0', b'd0']]]], dtype=object)
tf.gather_nd(
    indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0'],
        [b'a1', b'b1']],
       [[b'a0', b'b0'],
        [b'c1', b'd1']]], dtype=object)
tf.gather_nd(
    indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'b0', b'b1'],
       [b'd0', b'c1']], dtype=object)

批處理 'params' 和 'indices' 的示例:

tf.gather_nd(
    batch_dims = 1,
    indices = [[1],
               [0]],
    params = [[['a0', 'b0'],
               ['c0', 'd0']],
              [['a1', 'b1'],
               ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
       [b'a1', b'b1']], dtype=object)
tf.gather_nd(
    batch_dims = 1,
    indices = [[[1]], [[0]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0']],
       [[b'a1', b'b1']]], dtype=object)
tf.gather_nd(
    batch_dims = 1,
    indices = [[[1, 0]], [[0, 1]]],
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0'],
       [b'b1']], dtype=object)

另見tf.gather

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.gather_nd。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。