當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.linalg.banded_triangular_solve用法及代碼示例


使用帶狀求解器求解三角方程組。

用法

tf.linalg.banded_triangular_solve(
    bands, rhs, lower=True, adjoint=False, name=None
)

參數

  • bands A Tensor 說明左側的條帶,形狀為 [..., K, M] 。當lowerTrue 時,K 行對應於第 K - 1 對角線的對角線(對角線是頂行),否則為對角線的第 K - 1 上對角線(對角線是底行)當 lowerFalse 。波段以 'LEFT_RIGHT' 對齊方式存儲,其中上對角線在右側填充,下對角線在左側填充。這是 cuSPARSE 使用的對齊方式。有關詳細信息,請參閱tf.linalg.set_diag
  • rhs 形狀為 [..., M] 或 [..., M, N] 的 Tensor 並且具有與 diagonals 相同的 dtype。請注意,如果 rhs 和/或 diags 的形狀不是靜態已知的,則 rhs 將被視為矩陣而不是向量。
  • lower 可選的 bool 。默認為 True 。布爾值,指示 bands 是表示下三角矩陣還是上三角矩陣。
  • adjoint 可選的 bool 。默認為 False 。布爾值,指示是否使用矩陣的block-wise 伴隨求解。
  • name 給此 Op 的名稱(可選)。

返回

  • 包含解決方案的形狀為 [..., M] 或 [..., M, N] 的 Tensor

bands 是形狀為 [..., K, M] 的張量,其中 K 表示存儲的波段數。這對應於一批 M by M 矩陣,其 K 次對角線(當 lowerTrue 時)被存儲。

該算子廣播 bands 的批次維度和 rhs 的批次維度。

例子:

存儲 3x3 矩陣的 2 個波段。請注意,由於'LEFT_RIGHT' 填充,第二行中的第一個元素被忽略。

x = [[2., 3., 4.], [1., 2., 3.]]
x2 = [[2., 3., 4.], [10000., 2., 3.]]
y = tf.zeros([3, 3])
z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
z
<tf.Tensor:shape=(3, 3), dtype=float32, numpy=
array([[2., 0., 0.],
       [2., 3., 0.],
       [0., 3., 4.]], dtype=float32)>
soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
soln
<tf.Tensor:shape=(3, 1), dtype=float32, numpy=
array([[0.5 ],
       [0.  ],
       [0.25]], dtype=float32)>
are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
tf.reduce_all(are_equal).numpy()
True
are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
tf.reduce_all(are_equal).numpy()
True

存儲 2 個 4x4 矩陣的超對角線。由於'LEFT_RIGHT' 填充,第一行的最後一個元素被忽略。

x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
y = tf.zeros([4, 4])
z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
z
<tf.Tensor:shape=(4, 4), dtype=float32, numpy=
array([[-1.,  2.,  0.,  0.],
       [ 0., -2.,  3.,  0.],
       [ 0.,  0., -3.,  4.],
       [ 0.,  0., -0., -4.]], dtype=float32)>
soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
soln
<tf.Tensor:shape=(4, 1), dtype=float32, numpy=
array([[-4.       ],
       [-1.5      ],
       [-0.6666667],
       [-0.25     ]], dtype=float32)>
are_equal = (soln == tf.linalg.triangular_solve(
  z, tf.ones([4, 1]), lower=False))
tf.reduce_all(are_equal).numpy()
True

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.linalg.banded_triangular_solve。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。