将 params
中的切片收集到形状由 indices
指定的张量中。
用法
tf.gather_nd(
params, indices, batch_dims=0, name=None
)
参数
-
params
一个Tensor
。从中收集值的张量。 -
indices
一个Tensor
。必须是以下类型之一:int32
,int64
。索引张量。 -
name
操作的名称(可选)。 -
batch_dims
整数或标量 'Tensor'。批次维度的数量。
返回
-
一个
Tensor
。具有与params
相同的类型。
indices
是 Tensor
索引的 params
。索引向量沿 indices
的最后一个轴排列。
这类似于 tf.gather
,其中 indices
将切片定义为 params
的第一维。在 tf.gather_nd
中,indices
将切片定义为 params
的第一个 N
维度,其中 N = indices.shape[-1]
。
警告:在 CPU 上,如果发现超出范围的索引,则返回错误。在 GPU 上,如果发现超出范围的索引,则将 0 存储在相应的输出值中。
收集标量
在最简单的情况下,indices
中的向量索引 params
的完整等级:
tf.gather_nd(
indices=[[0, 0],
[1, 1]],
params = [['a', 'b'],
['c', 'd']]).numpy()
array([b'a', b'd'], dtype=object)
在这种情况下,结果比 indices
少 1 轴,并且每个索引向量都被从 params
索引的标量替换。
在这种情况下,形状关系为:
index_depth = indices.shape[-1]
assert index_depth == params.shape.rank
result_shape = indices.shape[:-1]
如果 indices
的秩为 K
,则将 indices
视为 params
的索引的 (K-1) 维张量会很有帮助。
收集切片
如果索引向量没有索引 params
的完整等级,则结果中的每个位置都包含一个参数切片。此示例从矩阵中收集行:
tf.gather_nd(
indices = [[1],
[0]],
params = [['a', 'b', 'c'],
['d', 'e', 'f']]).numpy()
array([[b'd', b'e', b'f'],
[b'a', b'b', b'c']], dtype=object)
这里 indices
包含 [2]
索引向量,每个索引向量的长度为 1
。每个索引向量都引用params
矩阵的行。每行的形状为 [3]
,因此输出形状为 [2, 3]
。
在这种情况下,形状之间的关系是:
index_depth = indices.shape[-1]
outer_shape = indices.shape[:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[index_depth:]
output_shape = outer_shape + inner_shape
将这种情况下的结果视为tensors-of-tensors 会很有帮助。外部张量的形状由 indices
的前导维度设置。而内张量的形状是单片的形状。
批次
此外,params
和 indices
都可以具有完全匹配的 M
前导批次尺寸。在这种情况下 batch_dims
必须设置为 M
。
例如,要从一批矩阵中的每一个中收集一行,您可以将索引向量的前导元素设置为它们在批处理中的位置:
tf.gather_nd(
indices = [[0, 1],
[1, 0],
[2, 4],
[3, 2],
[4, 1]],
params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]
batch_dims
参数允许您从索引中省略那些前导位置维度:
tf.gather_nd(
batch_dims=1,
indices = [[1],
[0],
[4],
[2],
[1]],
params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]
这相当于为批次维度中的每个位置校准一个单独的gather_nd
。
params=tf.zeros([5, 7, 3])
indices=tf.zeros([5, 1])
batch_dims = 1
index_depth = indices.shape[-1]
batch_shape = indices.shape[:batch_dims]
assert params.shape[:batch_dims] == batch_shape
outer_shape = indices.shape[batch_dims:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[batch_dims + index_depth:]
output_shape = batch_shape + outer_shape + inner_shape
output_shape.as_list()
[5, 3]
更多示例
索引到 3-张量:
tf.gather_nd(
indices = [[1]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'a1', b'b1'],
[b'c1', b'd1']]], dtype=object)
tf.gather_nd(
indices = [[0, 1], [1, 0]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
[b'a1', b'b1']], dtype=object)
tf.gather_nd(
indices = [[0, 0, 1], [1, 0, 1]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([b'b0', b'b1'], dtype=object)
以下示例适用于仅索引具有前导额外维度的情况。如果 'params' 和 'indices' 都具有前导批处理维度,请使用 'batch_dims' 参数以批处理模式运行 gather_nd。
批量索引到矩阵中:
tf.gather_nd(
indices = [[[0, 0]], [[0, 1]]],
params = [['a', 'b'], ['c', 'd']]).numpy()
array([[b'a'],
[b'b']], dtype=object)
批量切片索引到矩阵中:
tf.gather_nd(
indices = [[[1]], [[0]]],
params = [['a', 'b'], ['c', 'd']]).numpy()
array([[[b'c', b'd']],
[[b'a', b'b']]], dtype=object)
批量索引到 3-张量:
tf.gather_nd(
indices = [[[1]], [[0]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[[b'a1', b'b1'],
[b'c1', b'd1']]],
[[[b'a0', b'b0'],
[b'c0', b'd0']]]], dtype=object)
tf.gather_nd(
indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0'],
[b'a1', b'b1']],
[[b'a0', b'b0'],
[b'c1', b'd1']]], dtype=object)
tf.gather_nd(
indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'b0', b'b1'],
[b'd0', b'c1']], dtype=object)
批处理 'params' 和 'indices' 的示例:
tf.gather_nd(
batch_dims = 1,
indices = [[1],
[0]],
params = [[['a0', 'b0'],
['c0', 'd0']],
[['a1', 'b1'],
['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
[b'a1', b'b1']], dtype=object)
tf.gather_nd(
batch_dims = 1,
indices = [[[1]], [[0]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0']],
[[b'a1', b'b1']]], dtype=object)
tf.gather_nd(
batch_dims = 1,
indices = [[[1, 0]], [[0, 1]]],
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0'],
[b'b1']], dtype=object)
另见tf.gather
。
相关用法
- Python tf.gather用法及代码示例
- Python tf.grad_pass_through用法及代码示例
- Python tf.group用法及代码示例
- Python tf.get_current_name_scope用法及代码示例
- Python tf.gradients用法及代码示例
- Python tf.get_static_value用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.math.special.fresnel_cos用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.gather_nd。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。