當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch dsplit用法及代碼示例

本文簡要介紹python語言中 torch.dsplit 的用法。

用法:

torch.dsplit(input, indices_or_sections) → List of Tensors

參數

根據 indices_or_sections 將具有三個或更多維度的張量 input 拆分為多個深度方向的張量。每個拆分都是 input 的一個視圖。

這等效於調用 torch.tensor_split(input, indices_or_sections, dim=2) (拆分維度為 1),除了如果 indices_or_sections 是整數它必須將拆分維度均分,否則運行時錯誤將被拋出。

該函數基於 NumPy 的 numpy.dsplit()

例子:

>>> t = torch.arange(16.0).reshape(2, 2, 4)
>>> t
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],
        [[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]])
>>> torch.dsplit(t, 2)
(tensor([[[ 0.,  1.],
        [ 4.,  5.]],
       [[ 8.,  9.],
        [12., 13.]]]),
 tensor([[[ 2.,  3.],
          [ 6.,  7.]],
         [[10., 11.],
          [14., 15.]]]))
>>> torch.dsplit(t, [3, 6])
(tensor([[[ 0.,  1.,  2.],
          [ 4.,  5.,  6.]],
         [[ 8.,  9., 10.],
          [12., 13., 14.]]]),
 tensor([[[ 3.],
          [ 7.]],
         [[11.],
          [15.]]]),
 tensor([], size=(2, 2, 0)))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.dsplit。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。