當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch sort用法及代碼示例


本文簡要介紹python語言中 torch.sort 的用法。

用法:

torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)

參數

  • input(Tensor) -輸入張量。

  • dim(int,可選的) -要排序的維度

  • descending(bool,可選的) -控製排序順序(升序或降序)

  • stable(bool,可選的) -使排序例程穩定,從而保證保留等效元素的順序。

關鍵字參數

out(tuple,可選的) -( Tensor , LongTensor ) 的輸出元組,可以選擇用作輸出緩衝區

沿給定維度按值升序對input 張量的元素進行排序。

如果沒有給出dim,則選擇input的最後一個維度。

如果descendingTrue,則元素按值降序排序。

如果stableTrue,則排序例程變得穩定,保留等效元素的順序。

返回 (values, indices) 的命名元組,其中 values 是排序後的值,indices 是原始 input 張量中元素的索引。

例子:

>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162,  0.0608,  0.6719,  2.3332],
        [-0.5793,  0.0061,  0.6058,  0.9497],
        [-0.5071,  0.3343,  0.9553,  1.0960]])
>>> indices
tensor([[ 1,  0,  2,  3],
        [ 3,  1,  0,  2],
        [ 0,  3,  1,  2]])

>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162,  0.6719, -0.5793],
        [ 0.0608,  0.0061,  0.9497,  0.3343],
        [ 0.6058,  0.9553,  1.0960,  2.3332]])
>>> indices
tensor([[ 2,  0,  0,  1],
        [ 0,  1,  1,  2],
        [ 1,  2,  2,  0]])
>>> x = torch.tensor([0, 1] * 9)
>>> x.sort()
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 2, 16,  4,  6, 14,  8,  0, 10, 12,  9, 17, 15, 13, 11,  7,  5,  3,  1]))
>>> x.sort(stable=True)
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17]))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.sort。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。