當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch hsplit用法及代碼示例


本文簡要介紹python語言中 torch.hsplit 的用法。

用法:

torch.hsplit(input, indices_or_sections) → List of Tensors

參數

根據 indices_or_sections 將具有一維或多維的張量 input 水平拆分為多個張量。每個拆分都是 input 的一個視圖。

如果 input 是一維的,這相當於調用 torch.tensor_split(input, indices_or_sections, dim=0) (拆分維度為零),如果 input 有兩個或更多維度,則相當於調用torch.tensor_split(input, indices_or_sections, dim=1) (拆分維度為 1),但如果 indices_or_sections 為整數,則必須將拆分維度均分,否則將引發運行時錯誤。

該函數基於 NumPy 的 numpy.hsplit()

例子:

>>> t = torch.arange(16.0).reshape(4,4)
>>> t
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
>>> torch.hsplit(t, 2)
(tensor([[ 0.,  1.],
         [ 4.,  5.],
         [ 8.,  9.],
         [12., 13.]]),
 tensor([[ 2.,  3.],
         [ 6.,  7.],
         [10., 11.],
         [14., 15.]]))
>>> torch.hsplit(t, [3, 6])
(tensor([[ 0.,  1.,  2.],
         [ 4.,  5.,  6.],
         [ 8.,  9., 10.],
         [12., 13., 14.]]),
 tensor([[ 3.],
         [ 7.],
         [11.],
         [15.]]),
 tensor([], size=(4, 0)))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.hsplit。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。