当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch vsplit用法及代码示例


本文简要介绍python语言中 torch.vsplit 的用法。

用法:

torch.vsplit(input, indices_or_sections) → List of Tensors

参数

根据 indices_or_sections 将具有两个或多个维度的张量 input 垂直拆分为多个张量。每个拆分都是 input 的一个视图。

这等效于调用 torch.tensor_split(input, indices_or_sections, dim=0) (拆分维度为 0),除非 indices_or_sections 是整数,它必须将拆分维度均分,否则运行时错误将被抛出。

该函数基于 NumPy 的 numpy.vsplit()

例子:

>>> t = torch.arange(16.0).reshape(4,4)
>>> t
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
>>> torch.vsplit(t, 2)
(tensor([[0., 1., 2., 3.],
         [4., 5., 6., 7.]]),
 tensor([[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]))
>>> torch.vsplit(t, [3, 6])
(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[12., 13., 14., 15.]]),
 tensor([], size=(0, 4)))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.vsplit。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。