當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch roll用法及代碼示例

本文簡要介紹python語言中 torch.roll 的用法。

用法:

torch.roll(input, shifts, dims=None) → Tensor

參數

  • input(Tensor) -輸入張量。

  • shifts(int或者python的元組:ints) -張量元素移動的位置數。如果 shifts 是元組,則 dims 必須是大小相同的元組,每個維度都會滾動對應的值

  • dims(int或者python的元組:ints) -滾動的軸

沿給定維度滾動張量。超出最後一個位置的元素在第一個位置重新引入。如果未指定尺寸,張量將在滾動前被展平,然後恢複到原始形狀。

例子:

>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
        [1, 2],
        [3, 4],
        [5, 6]])
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
        [5, 6],
        [7, 8],
        [1, 2]])
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
        [8, 7],
        [2, 1],
        [4, 3]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.roll。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。