當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Flatten用法及代碼示例


本文簡要介紹python語言中 torch.nn.Flatten 的用法。

用法:

class torch.nn.Flatten(start_dim=1, end_dim=- 1)

參數

  • start_dim-第一個 dim 變平(默認值 = 1)。

  • end_dim-最後 dim 變平(默認 = -1)。

將連續的暗淡範圍展平為張量。與 Sequential 一起使用。

形狀:
  • 輸入: ,' 其中 是維度 的大小, 表示任意數量的維度,包括無維度。

  • 輸出:

例子::
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>>     nn.Conv2d(1, 32, 5, 1, 1),
>>>     nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.Flatten。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。