當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch tensorinv用法及代碼示例


本文簡要介紹python語言中 torch.linalg.tensorinv 的用法。

用法:

torch.linalg.tensorinv(A, ind=2, *, out=None) → Tensor

參數

  • A(Tensor) -要反轉的張量。它的形狀必須滿足 prod( A .shape[: ind ]) == prod( A .shape[ ind :])

  • ind(int) -計算 torch.tensordot() 的逆的索引。默認值:2

關鍵字參數

out(Tensor,可選的) -輸出張量。如果 None 則忽略。默認值:None

拋出

RuntimeError - 如果重新整形的 A 不可逆,或者第一個 ind 維度的乘積不等於其餘維度的乘積。

計算 torch.tensordot() 的乘法逆元。

如果mA 的第一個ind 維度的乘積,而n 是其餘維度的乘積,則此函數期望mn 相等。如果是這種情況,它會計算一個張量 X 使得 tensordot( A , X, ind ) 是維度 m 中的單位矩陣。 X 將具有 A 的形狀,但第一個 ind 尺寸被推回末尾

X.shape == A.shape[ind:] + A.shape[:ind]

支持 float、double、cfloat 和 cdouble dtypes 的輸入。

注意

A2 維張量和 ind = 1 時,此函數計算 A 的(乘法)逆(參見 torch.linalg.inv() )。

注意

如果可能,考慮使用 torch.linalg.tensorsolve() 將左側張量乘以張量逆,如下所示:

tensorsolve(A, B) == torch.tensordot(tensorinv(A), B)

在可能的情況下,總是首選使用 tensorsolve() ,因為它比顯式計算偽逆更快且數值更穩定。

例子:

>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3))
>>> Ainv = torch.linalg.tensorinv(A, ind=2)
>>> Ainv.shape
torch.Size([8, 3, 4, 6])
>>> B = torch.randn(4, 6)
>>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B))
True

>>> A = torch.randn(4, 4)
>>> Atensorinv = torch.linalg.tensorinv(A, ind=1)
>>> Ainv = torch.linalg.inverse(A)
>>> torch.allclose(Atensorinv, Ainv)
True

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.tensorinv。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。