當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch argsort用法及代碼示例


本文簡要介紹python語言中 torch.argsort 的用法。

用法:

torch.argsort(input, dim=- 1, descending=False) → LongTensor

參數

  • input(Tensor) -輸入張量。

  • dim(int,可選的) -要排序的維度

  • descending(bool,可選的) -控製排序順序(升序或降序)

返回沿給定維度按值升序對張量進行排序的索引。

這是 torch.sort() 返回的第二個值。有關此方法的確切語義,請參閱其文檔。

例子:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
        [ 0.1598,  0.0788, -0.0745, -1.2700],
        [ 1.2208,  1.0722, -0.7064,  1.2564],
        [ 0.0669, -0.2318, -0.8229, -0.9280]])


>>> torch.argsort(a, dim=1)
tensor([[2, 0, 3, 1],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.argsort。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。