求解三对角方程组。
用法
tf.linalg.tridiagonal_solve(
diagonals, rhs, diagonals_format='compact', transpose_rhs=False,
conjugate_rhs=False, name=None, partial_pivoting=True, perturb_singular=False
)参数
-
diagonalsTensor或Tensor的元组说明 left-hand 边。形状取决于diagonals_format,见上面的说明。必须是float32,float64,complex64或complex128。 -
rhs形状为 [..., M] 或 [..., M, K] 的Tensor并且具有与diagonals相同的 dtype。请注意,如果rhs和/或diags的形状不是静态已知的,则rhs将被视为矩阵而不是向量。 -
diagonals_formatmatrix,sequence或compact之一。默认为compact。 -
transpose_rhs如果True,rhs在求解之前被转置(如果 rhs 的形状是 [..., M] 则无效)。 -
conjugate_rhs如果True,rhs在求解之前共轭。 -
name给此Op的名称(可选)。 -
partial_pivoting是否执行部分旋转。True默认情况下。部分旋转使程序更稳定,但速度更慢。在某些情况下,部分旋转是不必要的,包括对角占优和对称正定矩阵(参见例如 [1] 中的定理 9.12)。 -
perturb_singular是否扰动奇异矩阵以返回有限结果。False默认情况下。如果为真,涉及奇异矩阵的系统的解决方案将通过在部分旋转的 LU 分解中扰动接近零的枢轴来计算。具体来说,微小的枢轴会受到一定数量的订单eps * max_{ij} |U(i,j)|的干扰,以避免溢出。这里U是LU分解的上三角部分,eps是机器精度。当通过逆迭代计算特征向量时,这对于求解数值奇异系统很有用。如果partial_pivoting是False,perturb_singular也必须是False。
返回
-
包含解决方案的形状为 [..., M] 或 [..., M, K] 的
Tensor。如果输入矩阵是奇异的,则结果是不确定的。
抛出
-
ValueError如果满足以下任何条件,则引发:- 提供了不支持的类型作为输入,
- 输入张量的形状不正确,
perturb_singular是True但partial_pivoting不是。
-
UnimplementedError每当partial_pivoting为真且后端为 XLA 时,或每当perturb_singular为真且后端为 XLA 或 GPU 时。
输入可以以各种格式提供:matrix , sequence 和 compact,由 diagonals_format arg 指定。
在 matrix 格式中,diagonals 必须是形状为 [..., M, M] 的张量,其中两个 inner-most 维度代表三对角矩阵。三个对角线之外的元素将被忽略。
在sequence 格式中,diagonals 以元组或三个形状张量的列表的形式提供,[..., N] , [..., M] , [..., N] 分别表示上对角线、对角线和下对角线。 N 可以是 M-1 或 M ;在后一种情况下,上对角线的最后一个元素和下对角线的第一个元素将被忽略。
在 compact 格式中,三个对角线组合成一个形状为 [..., 3, M] 的张量,最后两个维度依次包含上对角线、对角线和下对角线。与sequence 格式类似,元素diagonals[..., 0, M-1] 和diagonals[..., 2, 0] 被忽略。
建议使用compact 格式作为性能最佳的格式。如果您需要手动将张量转换为紧凑格式,请使用 tf.gather_nd 。形状为 [m, m] 的张量的示例:
rhs = tf.constant([...])
matrix = tf.constant([[...]])
m = matrix.shape[0]
dummy_idx = [0, 0] # An arbitrary element to use as a dummy
indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx], # Superdiagonal
[[i, i] for i in range(m)], # Diagonal
[dummy_idx] + [[i + 1, i] for i in range(m - 1)]] # Subdiagonal
diagonals=tf.gather_nd(matrix, indices)
x = tf.linalg.tridiagonal_solve(diagonals, rhs)
无论 diagonals_format , rhs 是形状为 [..., M] 还是 [..., M, K] 的张量。后者允许同时求解具有相同 left-hand 边和 K 个不同 right-hand 边的 K 个系统。如果 transpose_rhs 设置为 True ,则预期形状为 [..., M] 或 [..., K, M] 。
批尺寸,表示为 ... ,在 diagonals 和 rhs 中必须相同。
输出是与 rhs 形状相同的张量:[..., M] 或 [..., M, K] 。
如果输入矩阵不可逆,则不保证运算会引发错误。 tf.debugging.check_numerics 可应用于输出以检测可逆性问题。
注意:对于大批量大小,如果 partial_pivoting=True 或有多个 right-hand 边(K > 1),GPU 上的计算可能会很慢。如果出现此问题,请考虑是否可以禁用旋转并使用 K = 1 ,或者考虑使用 CPU。
在 CPU 上,解决方案是通过高斯消元计算的,有或没有部分旋转,具体取决于partial_pivoting范围。在 GPU 上,使用了 Nvidia 的 cuSPARSE 库:https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
[1] Nicholas J. Higham (2002)。数值算法的准确性和稳定性:第二版。暹。页。 175. 国际标准书号 978-0-89871-802-7。
相关用法
- Python tf.linalg.tridiagonal_matmul用法及代码示例
- Python tf.linalg.triangular_solve用法及代码示例
- Python tf.linalg.trace用法及代码示例
- Python tf.linalg.tensor_diag_part用法及代码示例
- Python tf.linalg.tensor_diag用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.linalg.LinearOperatorIdentity.solvevec用法及代码示例
- Python tf.linalg.LinearOperatorPermutation.solve用法及代码示例
- Python tf.linalg.band_part用法及代码示例
- Python tf.linalg.LinearOperatorKronecker.diag_part用法及代码示例
- Python tf.linalg.lu_matrix_inverse用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.matvec用法及代码示例
- Python tf.linalg.LinearOperatorBlockLowerTriangular.solvevec用法及代码示例
- Python tf.linalg.LinearOperatorLowerTriangular.matvec用法及代码示例
- Python tf.linalg.LinearOperatorCirculant2D.solve用法及代码示例
- Python tf.linalg.LinearOperatorCirculant3D.diag_part用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solvevec用法及代码示例
- Python tf.linalg.LinearOperatorCirculant2D.assert_non_singular用法及代码示例
- Python tf.linalg.LinearOperatorPermutation.diag_part用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.linalg.tridiagonal_solve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
