將 params
中的切片收集到形狀由 indices
指定的張量中。
用法
tf.gather_nd(
params, indices, batch_dims=0, name=None
)
參數
-
params
一個Tensor
。從中收集值的張量。 -
indices
一個Tensor
。必須是以下類型之一:int32
,int64
。索引張量。 -
name
操作的名稱(可選)。 -
batch_dims
整數或標量 'Tensor'。批次維度的數量。
返回
-
一個
Tensor
。具有與params
相同的類型。
indices
是 Tensor
索引的 params
。索引向量沿 indices
的最後一個軸排列。
這類似於 tf.gather
,其中 indices
將切片定義為 params
的第一維。在 tf.gather_nd
中,indices
將切片定義為 params
的第一個 N
維度,其中 N = indices.shape[-1]
。
警告:在 CPU 上,如果發現超出範圍的索引,則返回錯誤。在 GPU 上,如果發現超出範圍的索引,則將 0 存儲在相應的輸出值中。
收集標量
在最簡單的情況下,indices
中的向量索引 params
的完整等級:
tf.gather_nd(
indices=[[0, 0],
[1, 1]],
params = [['a', 'b'],
['c', 'd']]).numpy()
array([b'a', b'd'], dtype=object)
在這種情況下,結果比 indices
少 1 軸,並且每個索引向量都被從 params
索引的標量替換。
在這種情況下,形狀關係為:
index_depth = indices.shape[-1]
assert index_depth == params.shape.rank
result_shape = indices.shape[:-1]
如果 indices
的秩為 K
,則將 indices
視為 params
的索引的 (K-1) 維張量會很有幫助。
收集切片
如果索引向量沒有索引 params
的完整等級,則結果中的每個位置都包含一個參數切片。此示例從矩陣中收集行:
tf.gather_nd(
indices = [[1],
[0]],
params = [['a', 'b', 'c'],
['d', 'e', 'f']]).numpy()
array([[b'd', b'e', b'f'],
[b'a', b'b', b'c']], dtype=object)
這裏 indices
包含 [2]
索引向量,每個索引向量的長度為 1
。每個索引向量都引用params
矩陣的行。每行的形狀為 [3]
,因此輸出形狀為 [2, 3]
。
在這種情況下,形狀之間的關係是:
index_depth = indices.shape[-1]
outer_shape = indices.shape[:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[index_depth:]
output_shape = outer_shape + inner_shape
將這種情況下的結果視為tensors-of-tensors 會很有幫助。外部張量的形狀由 indices
的前導維度設置。而內張量的形狀是單片的形狀。
批次
此外,params
和 indices
都可以具有完全匹配的 M
前導批次尺寸。在這種情況下 batch_dims
必須設置為 M
。
例如,要從一批矩陣中的每一個中收集一行,您可以將索引向量的前導元素設置為它們在批處理中的位置:
tf.gather_nd(
indices = [[0, 1],
[1, 0],
[2, 4],
[3, 2],
[4, 1]],
params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]
batch_dims
參數允許您從索引中省略那些前導位置維度:
tf.gather_nd(
batch_dims=1,
indices = [[1],
[0],
[4],
[2],
[1]],
params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]
這相當於為批次維度中的每個位置校準一個單獨的gather_nd
。
params=tf.zeros([5, 7, 3])
indices=tf.zeros([5, 1])
batch_dims = 1
index_depth = indices.shape[-1]
batch_shape = indices.shape[:batch_dims]
assert params.shape[:batch_dims] == batch_shape
outer_shape = indices.shape[batch_dims:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[batch_dims + index_depth:]
output_shape = batch_shape + outer_shape + inner_shape
output_shape.as_list()
[5, 3]
更多示例
索引到 3-張量:
tf.gather_nd(
indices = [[1]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'a1', b'b1'],
[b'c1', b'd1']]], dtype=object)
tf.gather_nd(
indices = [[0, 1], [1, 0]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
[b'a1', b'b1']], dtype=object)
tf.gather_nd(
indices = [[0, 0, 1], [1, 0, 1]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([b'b0', b'b1'], dtype=object)
以下示例適用於僅索引具有前導額外維度的情況。如果 'params' 和 'indices' 都具有前導批處理維度,請使用 'batch_dims' 參數以批處理模式運行 gather_nd。
批量索引到矩陣中:
tf.gather_nd(
indices = [[[0, 0]], [[0, 1]]],
params = [['a', 'b'], ['c', 'd']]).numpy()
array([[b'a'],
[b'b']], dtype=object)
批量切片索引到矩陣中:
tf.gather_nd(
indices = [[[1]], [[0]]],
params = [['a', 'b'], ['c', 'd']]).numpy()
array([[[b'c', b'd']],
[[b'a', b'b']]], dtype=object)
批量索引到 3-張量:
tf.gather_nd(
indices = [[[1]], [[0]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[[b'a1', b'b1'],
[b'c1', b'd1']]],
[[[b'a0', b'b0'],
[b'c0', b'd0']]]], dtype=object)
tf.gather_nd(
indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0'],
[b'a1', b'b1']],
[[b'a0', b'b0'],
[b'c1', b'd1']]], dtype=object)
tf.gather_nd(
indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'b0', b'b1'],
[b'd0', b'c1']], dtype=object)
批處理 'params' 和 'indices' 的示例:
tf.gather_nd(
batch_dims = 1,
indices = [[1],
[0]],
params = [[['a0', 'b0'],
['c0', 'd0']],
[['a1', 'b1'],
['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
[b'a1', b'b1']], dtype=object)
tf.gather_nd(
batch_dims = 1,
indices = [[[1]], [[0]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0']],
[[b'a1', b'b1']]], dtype=object)
tf.gather_nd(
batch_dims = 1,
indices = [[[1, 0]], [[0, 1]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0'],
[b'b1']], dtype=object)
另見tf.gather
。
相關用法
- Python tf.gather用法及代碼示例
- Python tf.grad_pass_through用法及代碼示例
- Python tf.group用法及代碼示例
- Python tf.get_current_name_scope用法及代碼示例
- Python tf.gradients用法及代碼示例
- Python tf.get_static_value用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代碼示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代碼示例
- Python tf.raw_ops.TPUReplicatedInput用法及代碼示例
- Python tf.raw_ops.Bitcast用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代碼示例
- Python tf.math.special.fresnel_cos用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.gather_nd。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。