當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.linalg.tridiagonal_solve用法及代碼示例

求解三對角方程組。

用法

tf.linalg.tridiagonal_solve(
    diagonals, rhs, diagonals_format='compact', transpose_rhs=False,
    conjugate_rhs=False, name=None, partial_pivoting=True, perturb_singular=False
)

參數

  • diagonals TensorTensor 的元組說明 left-hand 邊。形狀取決於 diagonals_format ,見上麵的說明。必須是 float32 , float64 , complex64complex128
  • rhs 形狀為 [..., M] 或 [..., M, K] 的 Tensor 並且具有與 diagonals 相同的 dtype。請注意,如果 rhs 和/或 diags 的形狀不是靜態已知的,則 rhs 將被視為矩陣而不是向量。
  • diagonals_format matrix , sequencecompact 之一。默認為 compact
  • transpose_rhs 如果 True , rhs 在求解之前被轉置(如果 rhs 的形狀是 [..., M] 則無效)。
  • conjugate_rhs 如果 True , rhs 在求解之前共軛。
  • name 給此 Op 的名稱(可選)。
  • partial_pivoting 是否執行部分旋轉。 True 默認情況下。部分旋轉使程序更穩定,但速度更慢。在某些情況下,部分旋轉是不必要的,包括對角占優和對稱正定矩陣(參見例如 [1] 中的定理 9.12)。
  • perturb_singular 是否擾動奇異矩陣以返回有限結果。 False 默認情況下。如果為真,涉及奇異矩陣的係統的解決方案將通過在部分旋轉的 LU 分解中擾動接近零的樞軸來計算。具體來說,微小的樞軸會受到一定數量的訂單 eps * max_{ij} |U(i,j)| 的幹擾,以避免溢出。這裏U是LU分解的上三角部分,eps是機器精度。當通過逆迭代計算特征向量時,這對於求解數值奇異係統很有用。如果 partial_pivotingFalse , perturb_singular 也必須是 False

返回

  • 包含解決方案的形狀為 [..., M] 或 [..., M, K] 的 Tensor。如果輸入矩陣是奇異的,則結果是不確定的。

拋出

  • ValueError 如果滿足以下任何條件,則引發:
    1. 提供了不支持的類型作為輸入,
    2. 輸入張量的形狀不正確,
    3. perturb_singularTruepartial_pivoting 不是。
  • UnimplementedError 每當 partial_pivoting 為真且後端為 XLA 時,或每當 perturb_singular 為真且後端為 XLA 或 GPU 時。

輸入可以以各種格式提供:matrix , sequencecompact,由 diagonals_format arg 指定。

matrix 格式中,diagonals 必須是形狀為 [..., M, M] 的張量,其中兩個 inner-most 維度代表三對角矩陣。三個對角線之外的元素將被忽略。

sequence 格式中,diagonals 以元組或三個形狀張量的列表的形式提供,[..., N] , [..., M] , [..., N] 分別表示上對角線、對角線和下對角線。 N 可以是 M-1M ;在後一種情況下,上對角線的最後一個元素和下對角線的第一個元素將被忽略。

compact 格式中,三個對角線組合成一個形狀為 [..., 3, M] 的張量,最後兩個維度依次包含上對角線、對角線和下對角線。與sequence 格式類似,元素diagonals[..., 0, M-1]diagonals[..., 2, 0] 被忽略。

建議使用compact 格式作為性能最佳的格式。如果您需要手動將張量轉換為緊湊格式,請使用 tf.gather_nd 。形狀為 [m, m] 的張量的示例:

rhs = tf.constant([...])
matrix = tf.constant([[...]])
m = matrix.shape[0]
dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
         [[i, i] for i in range(m)],                          # Diagonal
         [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]    # Subdiagonal
diagonals=tf.gather_nd(matrix, indices)
x = tf.linalg.tridiagonal_solve(diagonals, rhs)

無論 diagonals_format , rhs 是形狀為 [..., M] 還是 [..., M, K] 的張量。後者允許同時求解具有相同 left-hand 邊和 K 個不同 right-hand 邊的 K 個係統。如果 transpose_rhs 設置為 True ,則預期形狀為 [..., M][..., K, M]

批尺寸,表示為 ... ,在 diagonalsrhs 中必須相同。

輸出是與 rhs 形狀相同的張量:[..., M][..., M, K]

如果輸入矩陣不可逆,則不保證運算會引發錯誤。 tf.debugging.check_numerics 可應用於輸出以檢測可逆性問題。

注意:對於大批量大小,如果 partial_pivoting=True 或有多個 right-hand 邊(K > 1),GPU 上的計算可能會很慢。如果出現此問題,請考慮是否可以禁用旋轉並使用 K = 1 ,或者考慮使用 CPU。

在 CPU 上,解決方案是通過高斯消元計算的,有或沒有部分旋轉,具體取決於partial_pivoting範圍。在 GPU 上,使用了 Nvidia 的 cuSPARSE 庫:https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

[1] Nicholas J. Higham (2002)。數值算法的準確性和穩定性:第二版。暹。頁。 175. 國際標準書號 978-0-89871-802-7。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.linalg.tridiagonal_solve。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。