当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Tensor.unflatten用法及代码示例


本文简要介绍python语言中 torch.Tensor.unflatten 的用法。

用法:

unflatten(dim, sizes)

参数

  • dim(联盟[int,str]) -要展开的尺寸

  • sizes(联盟[元组[int] 或者torch.Size,元组[元组[str,int]]]) - 未展平尺寸的新形状

sizes 给出的多个尺寸维度上扩展 self 张量的维度 dim

  • sizes 是未展平维度的新形状,如果 selfTensornamedshape (Tuple[(name: str, size: int),则它可以是 Tuple[int] 以及 torch.Size )]) 如果 selfNamedTensor 。 size 中的元素总数必须与未展平的原始 dim 中的元素数量相匹配。

例子

>>> torch.randn(3, 4, 1).unflatten(1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.randn(3, 4, 1).unflatten(1, (-1, 2)).shape # the size -1 is inferred from the size of dimension 1
torch.Size([3, 2, 2, 1])
>>> torch.randn(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2)))
tensor([[[-1.1772,  0.0180],
        [ 0.2412,  0.1431]],
        [[-1.1819, -0.8899],
        [ 1.5813,  0.2274]]], names=('A', 'B1', 'B2'))
>>> torch.randn(2, names=('A',)).unflatten('A', (('B1', -1), ('B2', 1)))
tensor([[-0.8591],
        [ 0.3100]], names=('B1', 'B2'))

警告

命名张量 API 是实验性的,可能会发生变化。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.Tensor.unflatten。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。