當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.math.top_k用法及代碼示例


查找最後一個維度的 k 最大條目的值和索引。

用法

tf.math.top_k(
    input, k=1, sorted=True, name=None
)

參數

  • input 一維或更高的Tensor,最後一維至少為k
  • k 0-D int32 Tensor 。沿最後一個維度(沿矩陣的每一行)查找的頂部元素的數量。
  • sorted 如果為 true,則生成的 k 元素將按值降序排序。
  • name 操作的可選名稱。

返回

  • 具有兩個命名字段的元組:
  • values k 沿每個最後一個維度切片的最大元素。
  • indices valuesinput 的最後一個維度內的索引。

如果輸入是向量 (rank=1),則在向量中找到 k 最大的條目,並將它們的值和索引作為向量輸出。因此 values[j]j - input 中最大的條目,其索引是 indices[j]

result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
                        k=3)
result.values.numpy()
array([99, 98, 96], dtype=int32)
result.indices.numpy()
array([5, 2, 9], dtype=int32)

對於矩陣(分別是更高等級的輸入),計算每行中的頂部 k 條目(分別是沿最後一個維度的向量)。因此,

input = tf.random.normal(shape=(3,4,5,6))
k = 2
values, indices  = tf.math.top_k(input, k=k)
values.shape.as_list()
[3, 4, 5, 2]

values.shape == indices.shape == input.shape[:-1] + [k]
True

這些索引可用於 gather 來自形狀與 input 匹配的張量。

gathered_values = tf.gather(input, indices, batch_dims=-1)
assert tf.reduce_all(gathered_values == values)

如果兩個元素相等,則首先出現lower-index 元素。

result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
                       k=3)
result.indices.numpy()
array([0, 1, 3], dtype=int32)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.math.top_k。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。