當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch triangular_solve用法及代碼示例


本文簡要介紹python語言中 torch.triangular_solve 的用法。

用法:

torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)

參數

  • b(Tensor) -大小為 的多個右側,其中 是多個批次維度的零

  • A(Tensor) -大小為 的輸入三角係數矩陣,其中 是零個或多個批處理維度

  • upper(bool,可選的) -是否求解上三角方程組(默認)或下三角方程組。默認值:True

  • transpose(bool,可選的) - 在發送到求解器之前是否應轉置。默認值:False

  • unitriangular(bool,可選的) - 是否為單位三角形。如果為 True,則 的對角線元素被假定為 1 並且未從 引用。默認值:False

關鍵字參數

out((Tensor,Tensor),可選的) -將輸出寫入的兩個張量的元組。如果 None 則忽略。默認值:None

返回

命名元組(solution, cloned_coefficient) 其中cloned_coefficient 的克隆,solution 的解(或方程組的任何變體,取決於關鍵字參數。)

求解具有三角係數矩陣 和多個右側 的方程組。

特別是,求解 並假設 是具有默認關鍵字參數的上三角函數。

torch.triangular_solve(b, A) 可以接受 2D 輸入 b, A 或作為 2D 矩陣批次的輸入。如果輸入是批處理,則返回批處理輸出X

如果 A 的對角線包含零或非常接近零的元素和 unitriangular = False(默認),或者如果輸入矩陣條件不佳,則結果可能包含 NaN s。

支持 float、double、cfloat 和 cdouble 數據類型的輸入。

例子:

>>> A = torch.randn(2, 2).triu()
>>> A
tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]])
>>> b = torch.randn(2, 3)
>>> b
tensor([[-0.0210,  2.3513, -1.5492],
        [ 1.5429,  0.7403, -1.0243]])
>>> torch.triangular_solve(b, A)
torch.return_types.triangular_solve(
solution=tensor([[ 1.7841,  2.9046, -2.5405],
        [ 1.9320,  0.9270, -1.2826]]),
cloned_coefficient=tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]]))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.triangular_solve。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。