當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Tensor.unflatten用法及代碼示例


本文簡要介紹python語言中 torch.Tensor.unflatten 的用法。

用法:

unflatten(dim, sizes)

參數

  • dim(聯盟[int,str]) -要展開的尺寸

  • sizes(聯盟[元組[int] 或者torch.Size,元組[元組[str,int]]]) - 未展平尺寸的新形狀

sizes 給出的多個尺寸維度上擴展 self 張量的維度 dim

  • sizes 是未展平維度的新形狀,如果 selfTensornamedshape (Tuple[(name: str, size: int),則它可以是 Tuple[int] 以及 torch.Size )]) 如果 selfNamedTensor 。 size 中的元素總數必須與未展平的原始 dim 中的元素數量相匹配。

例子

>>> torch.randn(3, 4, 1).unflatten(1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.randn(3, 4, 1).unflatten(1, (-1, 2)).shape # the size -1 is inferred from the size of dimension 1
torch.Size([3, 2, 2, 1])
>>> torch.randn(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2)))
tensor([[[-1.1772,  0.0180],
        [ 0.2412,  0.1431]],
        [[-1.1819, -0.8899],
        [ 1.5813,  0.2274]]], names=('A', 'B1', 'B2'))
>>> torch.randn(2, names=('A',)).unflatten('A', (('B1', -1), ('B2', 1)))
tensor([[-0.8591],
        [ 0.3100]], names=('B1', 'B2'))

警告

命名張量 API 是實驗性的,可能會發生變化。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.Tensor.unflatten。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。