當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch block_diag用法及代碼示例


本文簡要介紹python語言中 torch.block_diag 的用法。

用法:

torch.block_diag(*tensors)

參數

*tensors-一個或多個具有 0、1 或 2 維的張量。

返回

一個二維張量,所有輸入張量按順序排列,使得它們的左上角和右下角對角相鄰。所有其他元素都設置為 0。

返回類型

Tensor

從提供的張量創建一個塊對角矩陣。

例子:

>>> import torch
>>> A = torch.tensor([[0, 1], [1, 0]])
>>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
>>> C = torch.tensor(7)
>>> D = torch.tensor([1, 2, 3])
>>> E = torch.tensor([[4], [5], [6]])
>>> torch.block_diag(A, B, C, D, E)
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
        [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.block_diag。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。