当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch tensorinv用法及代码示例


本文简要介绍python语言中 torch.linalg.tensorinv 的用法。

用法:

torch.linalg.tensorinv(A, ind=2, *, out=None) → Tensor

参数

  • A(Tensor) -要反转的张量。它的形状必须满足 prod( A .shape[: ind ]) == prod( A .shape[ ind :])

  • ind(int) -计算 torch.tensordot() 的逆的索引。默认值:2

关键字参数

out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None

抛出

RuntimeError - 如果重新整形的 A 不可逆,或者第一个 ind 维度的乘积不等于其余维度的乘积。

计算 torch.tensordot() 的乘法逆元。

如果mA 的第一个ind 维度的乘积,而n 是其余维度的乘积,则此函数期望mn 相等。如果是这种情况,它会计算一个张量 X 使得 tensordot( A , X, ind ) 是维度 m 中的单位矩阵。 X 将具有 A 的形状,但第一个 ind 尺寸被推回末尾

X.shape == A.shape[ind:] + A.shape[:ind]

支持 float、double、cfloat 和 cdouble dtypes 的输入。

注意

A2 维张量和 ind = 1 时,此函数计算 A 的(乘法)逆(参见 torch.linalg.inv() )。

注意

如果可能,考虑使用 torch.linalg.tensorsolve() 将左侧张量乘以张量逆,如下所示:

tensorsolve(A, B) == torch.tensordot(tensorinv(A), B)

在可能的情况下,总是首选使用 tensorsolve() ,因为它比显式计算伪逆更快且数值更稳定。

例子:

>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3))
>>> Ainv = torch.linalg.tensorinv(A, ind=2)
>>> Ainv.shape
torch.Size([8, 3, 4, 6])
>>> B = torch.randn(4, 6)
>>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B))
True

>>> A = torch.randn(4, 4)
>>> Atensorinv = torch.linalg.tensorinv(A, ind=1)
>>> Ainv = torch.linalg.inverse(A)
>>> torch.allclose(Atensorinv, Ainv)
True

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.tensorinv。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。