當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch split用法及代碼示例


本文簡要介紹python語言中 torch.split 的用法。

用法:

torch.split(tensor, split_size_or_sections, dim=0)

參數

  • tensor(Tensor) -張量分裂。

  • split_size_or_sections(int) 或者(list(int)) -單個塊的大小或每個塊的大小列表

  • dim(int) -沿其分割張量的維度。

將張量拆分為塊。每個塊都是原始張量的視圖。

如果split_size_or_sections 是整數類型,那麽 tensor 將被分成大小相等的塊(如果可能)。如果沿給定維度 dim 的張量大小不能被 split_size 整除,則最後一個塊將更小。

如果 split_size_or_sections 是一個列表,那麽 tensor 將根據 split_size_or_sections 被拆分為大小在 dim 中的 len(split_size_or_sections) 塊。

例子:

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.split。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。