當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.argsort用法及代碼示例


返回張量的索引,該索引給出其沿軸的排序順序。

用法

tf.argsort(
    values, axis=-1, direction='ASCENDING', stable=False, name=None
)

參數

  • values 一維或更高數字 Tensor.
  • axis 要排序的軸。默認值為 -1,對最後一個軸進行排序。
  • direction 對值進行排序的方向('ASCENDING''DESCENDING')。
  • stable 如果為 True,則原始張量中的相等元素將不會按照返回的順序重新排序。不穩定排序尚未實現,但出於性能原因最終將成為默認排序。如果您需要穩定的訂單,請傳遞stable=True 以獲得向前兼容性。
  • name 操作的可選名稱。

返回

  • values 具有相同形狀的 int32 Tensor 。將沿給定 axis 對給定 values 的每個切片進行排序的索引。

拋出

  • ValueError 如果axis不是一個常數標量,或者方向無效。
  • tf.errors.InvalidArgumentError 如果 values.dtype 不是 floatint 類型。
values = [1, 10, 26.9, 2.8, 166.32, 62.3]
sort_order = tf.argsort(values)
sort_order.numpy()
array([0, 3, 1, 2, 5, 4], dtype=int32)

對於一維張量:

sorted = tf.gather(values, sort_order)
assert tf.reduce_all(sorted == tf.sort(values))

對於更高維度,輸出具有與 values 相同的形狀,但沿給定軸,值表示給定位置的張量切片中已排序元素的索引。

mat = [[30,20,10],
       [20,10,30],
       [10,30,20]]
indices = tf.argsort(mat)
indices.numpy()
array([[2, 1, 0],
       [1, 0, 2],
       [0, 2, 1]], dtype=int32)

如果 axis=-1 這些索引可用於使用 tf.gather 應用排序:

tf.gather(mat, indices, batch_dims=-1).numpy()
array([[10, 20, 30],
       [10, 20, 30],
       [10, 20, 30]], dtype=int32)

也可以看看:

  • tf.sort :沿軸排序。
  • tf.math.top_k :返回固定數量的頂部值和相應索引的部分排序。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.argsort。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。