当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch block_diag用法及代码示例


本文简要介绍python语言中 torch.block_diag 的用法。

用法:

torch.block_diag(*tensors)

参数

*tensors-一个或多个具有 0、1 或 2 维的张量。

返回

一个二维张量,所有输入张量按顺序排列,使得它们的左上角和右下角对角相邻。所有其他元素都设置为 0。

返回类型

Tensor

从提供的张量创建一个块对角矩阵。

例子:

>>> import torch
>>> A = torch.tensor([[0, 1], [1, 0]])
>>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
>>> C = torch.tensor(7)
>>> D = torch.tensor([1, 2, 3])
>>> E = torch.tensor([[4], [5], [6]])
>>> torch.block_diag(A, B, C, D, E)
tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
        [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.block_diag。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。