当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch split用法及代码示例


本文简要介绍python语言中 torch.split 的用法。

用法:

torch.split(tensor, split_size_or_sections, dim=0)

参数

  • tensor(Tensor) -张量分裂。

  • split_size_or_sections(int) 或者(list(int)) -单个块的大小或每个块的大小列表

  • dim(int) -沿其分割张量的维度。

将张量拆分为块。每个块都是原始张量的视图。

如果split_size_or_sections 是整数类型,那么 tensor 将被分成大小相等的块(如果可能)。如果沿给定维度 dim 的张量大小不能被 split_size 整除,则最后一个块将更小。

如果 split_size_or_sections 是一个列表,那么 tensor 将根据 split_size_or_sections 被拆分为大小在 dim 中的 len(split_size_or_sections) 块。

例子:

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.split。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。