當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch lstsq用法及代碼示例


本文簡要介紹python語言中 torch.linalg.lstsq 的用法。

用法:

torch.linalg.lstsq(A, B, rcond=None, *, driver=None)

參數

  • A(Tensor) -形狀為 (*, m, n) 的 lhs 張量,其中 * 是零個或多個批次維度。

  • B(Tensor) -形狀為 (*, m, k) 的 rhs 張量,其中 * 是零個或多個批次維度。

  • rcond(float,可選的) -用於確定 A 的有效等級。如果 rcond = Nonercond 設置為 A 的 dtype 乘以 max(m, n) 的機器精度。默認值:None

關鍵字參數

driver(str,可選的) -要使用的 LAPACK/MAGMA 方法的名稱。如果 None‘gelsy’ 用於 CPU 輸入, ‘gels’ 用於 CUDA 輸入。默認值:None

返回

命名元組 (solution, residuals, rank, singular_values)

計算線性方程組的最小二乘問題的解。

\mathbb{K} \mathbb{R} 或者\mathbb{C} , 這最小二乘問題對於線性係統AX = B A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k} 定義為

其中 表示 Frobenius 範數。

支持 float、double、cfloat 和 cdouble dtypes 的輸入。還支持矩陣批次,如果輸入是矩陣批次,則輸出具有相同的批次尺寸。

driver 選擇將使用的 LAPACK/MAGMA 函數。對於 CPU 輸入,有效值為 ‘gels’‘gelsy’‘gelsd‘gelss’ 。對於 CUDA 輸入,唯一有效的驅動程序是 ‘gels’ ,它假定 A 是滿秩的。要在 CPU 上選擇最佳驅動程序,請考慮:

  • 如果A 條件良好(它的condition number 不是太大),或者您不介意一些精度損失。

    • 對於一般矩陣:‘gelsy’(帶旋轉的 QR)(默認)

    • 如果A 是滿秩:‘gels’ (QR)

  • 如果 A 條件不佳。

    • ‘gelsd’(三對角縮減和 SVD)

    • 但是,如果您遇到內存問題:‘gelss’(完整的 SVD)。

另見full description of these drivers

driver 是(‘gelsy’‘gelsd’‘gelss’)之一時,rcond 用於確定A 中矩陣的有效秩。在這種情況下,如果 A 的降序奇異值,如果 ,則 將向下舍入為零。如果 rcond = None(默認),rcond 設置為 A dtype 的機器精度。

此函數在四個張量 (solution, residuals, rank, singular_values) 的命名元組中返回問題的解決方案和一些額外信息。對於形狀為 (*, m, n)(*, m, k) 的輸入 AB ,它包含

  • solution :最小二乘解。它的形狀為 (*, n, k)

  • residuals :解的平方殘差,即 。它的形狀等於 A 的批量尺寸。當m > nA 中的每個矩陣都是滿秩時計算它,否則它是一個空張量。如果A 是一批矩陣並且該批中的任何矩陣不是滿秩的,則返回一個空張量。此行為可能會在未來的PyTorch 版本中發生變化。

  • rankA 中矩陣的秩張量。它的形狀等於 A 的批量尺寸。當driver 是(‘gelsy’‘gelsd’‘gelss’)之一時計算,否則為空張量。

  • singular_valuesA 中矩陣的奇異值的張量。它的形狀為 (*, min(m, n)) 。當 driver 是 (‘gelsd’ , ‘gelss’ ) 之一時計算它,否則它是一個空張量。

注意

此函數計算 X = A .pinverse() @ B 比單獨執行計算更快且數值更穩定。

警告

rcond 的默認值可能會在未來的 PyTorch 版本中更改。因此,建議使用固定值以避免潛在的破壞性變化。

例子:

>>> A = torch.tensor([[[10, 2, 3], [3, 10, 5], [5, 6, 12]]], dtype=torch.float) # shape (1, 3, 3)
>>> B = torch.tensor([[[2, 5, 1], [3, 2, 1], [5, 1, 9]],
                      [[4, 2, 9], [2, 0, 3], [2, 5, 3]]], dtype=torch.float) # shape (2, 3, 3)
>>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3)
>>> torch.dist(X, torch.linalg.pinv(A) @ B)
tensor(2.0862e-07)

>>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values
>>> torch.dist(S, torch.linalg.svdvals(A))
tensor(5.7220e-06)

>>> A[:, 0].zero_()  # Decrease the rank of A
>>> rank = torch.linalg.lstsq(A, B).rank
>>> rank
tensor([2])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.lstsq。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。