当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Flatten用法及代码示例


本文简要介绍python语言中 torch.nn.Flatten 的用法。

用法:

class torch.nn.Flatten(start_dim=1, end_dim=- 1)

参数

  • start_dim-第一个 dim 变平(默认值 = 1)。

  • end_dim-最后 dim 变平(默认 = -1)。

将连续的暗淡范围展平为张量。与 Sequential 一起使用。

形状:
  • 输入: ,' 其中 是维度 的大小, 表示任意数量的维度,包括无维度。

  • 输出:

例子::
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>>     nn.Conv2d(1, 32, 5, 1, 1),
>>>     nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.Flatten。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。