当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch inv用法及代码示例


本文简要介绍python语言中 torch.linalg.inv 的用法。

用法:

torch.linalg.inv(A, *, out=None) → Tensor

参数

A(Tensor) -形状为 (*, n, n) 的张量,其中 * 是零个或多个由可逆矩阵组成的批次维度。

关键字参数

out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None

抛出

RuntimeError - 如果矩阵 A 或矩阵批次 A 中的任何矩阵不可逆。

计算方阵的逆(如果存在)。如果矩阵不可逆,则抛出 RuntimeError

\mathbb{K} \mathbb{R} 或者\mathbb{C} , 对于矩阵A \in \mathbb{K}^{n \times n} , 它的逆矩阵 A^{-1} \in \mathbb{K}^{n \times n} (如果存在)定义为

其中 n 维单位矩阵。

当且仅当 invertible 时,逆矩阵才存在。在这种情况下,逆是唯一的。

支持 float、double、cfloat 和 cdouble dtypes 的输入。还支持批量矩阵,如果 A 是批量矩阵,则输出具有相同的批量维度。

注意

当输入在 CUDA 设备上时,此函数将该设备与 CPU 同步。

注意

如果可能,请考虑使用 torch.linalg.solve() 将左侧的矩阵乘以逆矩阵,如下所示:

torch.linalg.solve(A, B) == A.inv() @ B

在可能的情况下,总是首选使用 solve() ,因为它比显式计算逆运算更快且数值更稳定。

例子:

>>> A = torch.randn(4, 4)
>>> Ainv = torch.linalg.inv(A)
>>> torch.dist(A @ Ainv, torch.eye(4))
tensor(1.1921e-07)

>>> A = torch.randn(2, 3, 4, 4)  # Batch of matrices
>>> Ainv = torch.linalg.inv(A)
>>> torch.dist(A @ Ainv, torch.eye(4)))
tensor(1.9073e-06)

>>> A = torch.randn(4, 4, dtype=torch.complex128)  # Complex matrix
>>> Ainv = torch.linalg.inv(A)
>>> torch.dist(A @ Ainv, torch.eye(4))
tensor(7.5107e-16, dtype=torch.float64)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.inv。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。