當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch flatten用法及代碼示例


本文簡要介紹python語言中 torch.flatten 的用法。

用法:

torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor

參數

  • input(Tensor) -輸入張量。

  • start_dim(int) -第一個要展平的dim

  • end_dim(int) -最後一個dim要壓平

通過將 input 重塑為一維張量來展平它。如果 start_dimend_dim 已通過,則隻有以 start_dim 開頭並以 end_dim 結尾的尺寸會被展平。 input 中的元素順序保持不變。

與 NumPy 的 flatten 不同,它總是複製輸入的數據,這個函數可以返回原始對象、視圖或副本。如果沒有展平尺寸,則返回原始對象input。否則,如果輸入可以被視為展平形狀,則返回該視圖。最後,隻有當輸入不能被視為扁平形狀時,才會複製輸入的數據。有關何時返回視圖的詳細信息,請參閱 torch.Tensor.view()

注意

展平零維張量將返回一維視圖。

例子:

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.flatten。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。