当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.argsort用法及代码示例


返回张量的索引,该索引给出其沿轴的排序顺序。

用法

tf.argsort(
    values, axis=-1, direction='ASCENDING', stable=False, name=None
)

参数

  • values 一维或更高数字 Tensor.
  • axis 要排序的轴。默认值为 -1,对最后一个轴进行排序。
  • direction 对值进行排序的方向('ASCENDING''DESCENDING')。
  • stable 如果为 True,则原始张量中的相等元素将不会按照返回的顺序重新排序。不稳定排序尚未实现,但出于性能原因最终将成为默认排序。如果您需要稳定的订单,请传递stable=True 以获得向前兼容性。
  • name 操作的可选名称。

返回

  • values 具有相同形状的 int32 Tensor 。将沿给定 axis 对给定 values 的每个切片进行排序的索引。

抛出

  • ValueError 如果axis不是一个常数标量,或者方向无效。
  • tf.errors.InvalidArgumentError 如果 values.dtype 不是 floatint 类型。
values = [1, 10, 26.9, 2.8, 166.32, 62.3]
sort_order = tf.argsort(values)
sort_order.numpy()
array([0, 3, 1, 2, 5, 4], dtype=int32)

对于一维张量:

sorted = tf.gather(values, sort_order)
assert tf.reduce_all(sorted == tf.sort(values))

对于更高维度,输出具有与 values 相同的形状,但沿给定轴,值表示给定位置的张量切片中已排序元素的索引。

mat = [[30,20,10],
       [20,10,30],
       [10,30,20]]
indices = tf.argsort(mat)
indices.numpy()
array([[2, 1, 0],
       [1, 0, 2],
       [0, 2, 1]], dtype=int32)

如果 axis=-1 这些索引可用于使用 tf.gather 应用排序:

tf.gather(mat, indices, batch_dims=-1).numpy()
array([[10, 20, 30],
       [10, 20, 30],
       [10, 20, 30]], dtype=int32)

也可以看看:

  • tf.sort :沿轴排序。
  • tf.math.top_k :返回固定数量的顶部值和相应索引的部分排序。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.argsort。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。