当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.math.top_k用法及代码示例


查找最后一个维度的 k 最大条目的值和索引。

用法

tf.math.top_k(
    input, k=1, sorted=True, name=None
)

参数

  • input 一维或更高的Tensor,最后一维至少为k
  • k 0-D int32 Tensor 。沿最后一个维度(沿矩阵的每一行)查找的顶部元素的数量。
  • sorted 如果为 true,则生成的 k 元素将按值降序排序。
  • name 操作的可选名称。

返回

  • 具有两个命名字段的元组:
  • values k 沿每个最后一个维度切片的最大元素。
  • indices valuesinput 的最后一个维度内的索引。

如果输入是向量 (rank=1),则在向量中找到 k 最大的条目,并将它们的值和索引作为向量输出。因此 values[j]j - input 中最大的条目,其索引是 indices[j]

result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
                        k=3)
result.values.numpy()
array([99, 98, 96], dtype=int32)
result.indices.numpy()
array([5, 2, 9], dtype=int32)

对于矩阵(分别是更高等级的输入),计算每行中的顶部 k 条目(分别是沿最后一个维度的向量)。因此,

input = tf.random.normal(shape=(3,4,5,6))
k = 2
values, indices  = tf.math.top_k(input, k=k)
values.shape.as_list()
[3, 4, 5, 2]

values.shape == indices.shape == input.shape[:-1] + [k]
True

这些索引可用于 gather 来自形状与 input 匹配的张量。

gathered_values = tf.gather(input, indices, batch_dims=-1)
assert tf.reduce_all(gathered_values == values)

如果两个元素相等,则首先出现lower-index 元素。

result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
                       k=3)
result.indices.numpy()
array([0, 1, 3], dtype=int32)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.math.top_k。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。