當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch diagflat用法及代碼示例


本文簡要介紹python語言中 torch.diagflat 的用法。

用法:

torch.diagflat(input, offset=0) → Tensor

參數

  • input(Tensor) -輸入張量。

  • offset(int,可選的) -要考慮的對角線。默認值:0(主對角線)。

  • 如果input 是一個向量(一維張量),則返回一個以input 的元素為對角線的二維平方張量。

  • 如果 input 是一個多維張量,則返回一個二維張量,其對角線元素等於展平的 input

參數 offset 控製要考慮的對角線:

  • 如果offset = 0,它是主對角線。

  • 如果offset > 0,它在主對角線之上。

  • 如果offset < 0,則它位於主對角線下方。

例子:

>>> a = torch.randn(3)
>>> a
tensor([-0.2956, -0.9068,  0.1695])
>>> torch.diagflat(a)
tensor([[-0.2956,  0.0000,  0.0000],
        [ 0.0000, -0.9068,  0.0000],
        [ 0.0000,  0.0000,  0.1695]])
>>> torch.diagflat(a, 1)
tensor([[ 0.0000, -0.2956,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.9068,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.1695],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])

>>> a = torch.randn(2, 2)
>>> a
tensor([[ 0.2094, -0.3018],
        [-0.1516,  1.9342]])
>>> torch.diagflat(a)
tensor([[ 0.2094,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.3018,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.1516,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.9342]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.diagflat。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。