当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch matrix_power用法及代码示例


本文简要介绍python语言中 torch.linalg.matrix_power 的用法。

用法:

torch.linalg.matrix_power(A, n, *, out=None) → Tensor

参数

  • A(Tensor) -形状为 (*, m, m) 的张量,其中 * 是零个或多个批次维度。

  • n(int) - index 。

关键字参数

out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None

抛出

RuntimeError - 如果 n < 0 和矩阵 A 或这批矩阵 A 中的任何矩阵不可逆。

计算整数 n 的方阵的 n 次幂。

支持 float、double、cfloat 和 cdouble dtypes 的输入。还支持批量矩阵,如果 A 是批量矩阵,则输出具有相同的批量维度。

如果 n = 0 ,则返回与 A 形状相同的单位矩阵(或批次)。如果 n 为负数,则返回每个矩阵的逆矩阵(如果可逆)的 abs(n) 次幂。

注意

如果可能,请考虑使用 torch.linalg.solve() 将左侧矩阵乘以负幂,例如 n > 0

matrix_power(torch.linalg.solve(A, B), n) == matrix_power(A, -n)  @ B

在可能的情况下,总是首选使用 solve() ,因为它比显式计算 更快且数值更稳定。

例子:

>>> A = torch.randn(3, 3)
>>> torch.linalg.matrix_power(A, 0)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> torch.linalg.matrix_power(A, 3)
tensor([[ 1.0756,  0.4980,  0.0100],
        [-1.6617,  1.4994, -1.9980],
        [-0.4509,  0.2731,  0.8001]])
>>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2)
tensor([[[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]],
        [[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.matrix_power。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。