当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Unflatten用法及代码示例


本文简要介绍python语言中 torch.nn.Unflatten 的用法。

用法:

class torch.nn.Unflatten(dim, unflattened_size)

参数

  • dim(联盟[int,str]) -未展平的维度

  • unflattened_size(联盟[torch.Size,元组,List,NamedShape]) - 未展平尺寸的新形状

展开张量 dim 将其扩展为所需的形状。与 Sequential 一起使用。

  • dim指定输入张量的维度,当使用TensorNamedTensor时,它可以是intstr

  • unflattened_size 是张量未展平维度的新形状,它可以是整数的 tuple 或整数的 listTensor 输入的 torch.Size;用于NamedTensor 输入的NamedShape((name, size) 元组的元组)。

形状:
  • 输入: ,其中 是维度 dim 的大小,而 表示任意数量的维度,包括无维度。

  • 输出: ,其中 = unflattened_size

例子

>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.Unflatten。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。