當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch tensor_split用法及代碼示例


本文簡要介紹python語言中 torch.tensor_split 的用法。

用法:

torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors

參數

  • input(Tensor) -要分裂的張量

  • indices_or_sections(Tensor,int或者list或者python的元組:ints) -

    如果 indices_or_sections 是整數 n 或值為 n 的零維長張量,則 input 沿維度 dim 拆分為 n 部分。如果 input 沿維度 dim 可被 n 整除,則每個部分將具有相同的大小 input.size(dim) / n 。如果 input 不能被 n 整除,則第一個 int(input.size(dim) % n) 部分的大小將為 int(input.size(dim) / n) + 1 ,其餘部分的大小為 int(input.size(dim) / n)

    如果indices_or_sections 是整數列表或元組,或一維長張量,則input 在列表、元組或張量中的每個索引處沿維度dim 拆分。例如,indices_or_sections=[2, 3]dim=0 將產生張量 input[:2]input[2:3]input[3:]

    如果indices_or_sections是張量,那麽它在CPU上必須是零維或一維長張量。

  • dim(int,可選的) -沿其分割張量的維度。默認值:0

根據 indices_or_sections 指定的索引或部分數量,將張量拆分為多個 sub-tensors,所有這些都是 input 的視圖,沿維度 dim。該函數基於 NumPy 的 numpy.array_split()

例子:

>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.tensor_split。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。