當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch solve用法及代碼示例


本文簡要介紹python語言中 torch.solve 的用法。

用法:

torch.solve(input, A, *, out=None)

參數

  • input(Tensor) -大小為 的輸入矩陣 ,其中 是零個或多個批處理維度。

  • A(Tensor) -輸入大小為 的方陣,其中 是零個或多個批處理維度。

關鍵字參數

out((Tensor,Tensor),可選的) -可選的輸出元組。

此函數返回由 和 A 的 LU 分解表示的線性方程組的解,以便作為命名元組 solution, LU

LU 包含 LU 因子,用於 A 的 LU 分解。

torch.solve(B, A) 可以接受 2D 輸入 B, A 或作為 2D 矩陣批次的輸入。如果輸入是批處理,則返回批處理輸出 solution, LU

支持實值和complex-valued 輸入。

警告

torch.solve() 已棄用,取而代之的是 torch.linalg.solve() ,並將在未來的 PyTorch 版本中刪除。 torch.linalg.solve() 的參數相反,並且不返回輸入的 LU 分解。要獲得 LU 分解,請參閱 torch.lu() ,它可以與 torch.lu_solve() torch.lu_unpack() 一起使用。

X = torch.solve(B, A).solution 應替換為

X = torch.linalg.solve(A, B)

注意

無論原始步幅如何,返回的矩陣 solutionLU 將被轉置,即分別具有像 B.contiguous().transpose(-1, -2).stride()A.contiguous().transpose(-1, -2).stride() 這樣的步幅。

例子:

>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35,  -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
       7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
   3.6386)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.solve。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。