當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Unflatten用法及代碼示例


本文簡要介紹python語言中 torch.nn.Unflatten 的用法。

用法:

class torch.nn.Unflatten(dim, unflattened_size)

參數

  • dim(聯盟[int,str]) -未展平的維度

  • unflattened_size(聯盟[torch.Size,元組,List,NamedShape]) - 未展平尺寸的新形狀

展開張量 dim 將其擴展為所需的形狀。與 Sequential 一起使用。

  • dim指定輸入張量的維度,當使用TensorNamedTensor時,它可以是intstr

  • unflattened_size 是張量未展平維度的新形狀,它可以是整數的 tuple 或整數的 listTensor 輸入的 torch.Size;用於NamedTensor 輸入的NamedShape((name, size) 元組的元組)。

形狀:
  • 輸入: ,其中 是維度 dim 的大小,而 表示任意數量的維度,包括無維度。

  • 輸出: ,其中 = unflattened_size

例子

>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.Unflatten。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。