當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch pinv用法及代碼示例


本文簡要介紹python語言中 torch.linalg.pinv 的用法。

用法:

torch.linalg.pinv(A, rcond=1e-15, hermitian=False, *, out=None) → Tensor

參數

  • A(Tensor) -形狀為 (*, m, n) 的張量,其中 * 是零個或多個批次維度。

  • rcond(float或者Tensor,可選的) -確定奇異值何時為零的容差值如果它是 torch.Tensor ,則其形狀必須可廣播為 torch.svd() 返回的 A 的奇異值的形狀。默認值:1e-15

  • hermitian(bool,可選的) -指示 A 如果是複數是 Hermitian,如果是實數是對稱的。默認值:False

關鍵字參數

out(Tensor,可選的) -輸出張量。如果 None 則忽略。默認值:None

計算矩陣的偽逆(Moore-Penrose逆)。

偽逆可能是defined algebraically,但理解它在計算上更方便through the SVD

支持 float、double、cfloat 和 cdouble dtypes 的輸入。還支持批量矩陣,如果 A 是批量矩陣,則輸出具有相同的批量維度。

如果 hermitian = TrueA 假設是 Hermitian 如果複數或對稱如果實數,但內部不檢查。相反,在計算中僅使用矩陣的下三角部分。

低於指定 rcond 閾值的奇異值(或當 hermitian = True 時的特征值的範數)被視為零並在計算中被丟棄。

注意

此函數使用 torch.linalg.svd() if hermitian = False torch.linalg.eigh() if hermitian = True 。對於 CUDA 輸入,此函數將該設備與 CPU 同步。

注意

如果可能,考慮使用 torch.linalg.lstsq() 將左側矩陣乘以偽逆矩陣,如下所示:

torch.linalg.lstsq(A, B).solution == A.pinv() @ B

在可能的情況下,總是首選使用 lstsq() ,因為它比顯式計算偽逆更快且數值更穩定。

警告

該函數內部使用 torch.linalg.svd() (或 torch.linalg.eigh() ,當hermitian = True 時),因此其導數與這些函數具有相同的問題。有關更多詳細信息,請參閱 torch.linalg.svd() torch.linalg.eigh() 中的警告。

例子:

>>> A = torch.randn(3, 5)
>>> A
tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
        [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
        [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
>>> torch.linalg.pinv(A)
tensor([[ 0.0600, -0.1933, -0.2090],
        [-0.0903, -0.0817, -0.4752],
        [-0.7124, -0.1631, -0.2272],
        [ 0.1356,  0.3933, -0.5023],
        [-0.0308, -0.1725, -0.5216]])

>>> A = torch.randn(2, 6, 3)
>>> Apinv = torch.linalg.pinv(A)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(8.5633e-07)

>>> A = torch.randn(3, 3, dtype=torch.complex64)
>>> A = A + A.T.conj()  # creates a Hermitian matrix
>>> Apinv = torch.linalg.pinv(A, hermitian=True)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(1.0830e-06)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.pinv。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。