当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.linalg.tridiagonal_solve用法及代码示例


求解三对角方程组。

用法

tf.linalg.tridiagonal_solve(
    diagonals, rhs, diagonals_format='compact', transpose_rhs=False,
    conjugate_rhs=False, name=None, partial_pivoting=True, perturb_singular=False
)

参数

  • diagonals TensorTensor 的元组说明 left-hand 边。形状取决于 diagonals_format ,见上面的说明。必须是 float32 , float64 , complex64complex128
  • rhs 形状为 [..., M] 或 [..., M, K] 的 Tensor 并且具有与 diagonals 相同的 dtype。请注意,如果 rhs 和/或 diags 的形状不是静态已知的,则 rhs 将被视为矩阵而不是向量。
  • diagonals_format matrix , sequencecompact 之一。默认为 compact
  • transpose_rhs 如果 True , rhs 在求解之前被转置(如果 rhs 的形状是 [..., M] 则无效)。
  • conjugate_rhs 如果 True , rhs 在求解之前共轭。
  • name 给此 Op 的名称(可选)。
  • partial_pivoting 是否执行部分旋转。 True 默认情况下。部分旋转使程序更稳定,但速度更慢。在某些情况下,部分旋转是不必要的,包括对角占优和对称正定矩阵(参见例如 [1] 中的定理 9.12)。
  • perturb_singular 是否扰动奇异矩阵以返回有限结果。 False 默认情况下。如果为真,涉及奇异矩阵的系统的解决方案将通过在部分旋转的 LU 分解中扰动接近零的枢轴来计算。具体来说,微小的枢轴会受到一定数量的订单 eps * max_{ij} |U(i,j)| 的干扰,以避免溢出。这里U是LU分解的上三角部分,eps是机器精度。当通过逆迭代计算特征向量时,这对于求解数值奇异系统很有用。如果 partial_pivotingFalse , perturb_singular 也必须是 False

返回

  • 包含解决方案的形状为 [..., M] 或 [..., M, K] 的 Tensor。如果输入矩阵是奇异的,则结果是不确定的。

抛出

  • ValueError 如果满足以下任何条件,则引发:
    1. 提供了不支持的类型作为输入,
    2. 输入张量的形状不正确,
    3. perturb_singularTruepartial_pivoting 不是。
  • UnimplementedError 每当 partial_pivoting 为真且后端为 XLA 时,或每当 perturb_singular 为真且后端为 XLA 或 GPU 时。

输入可以以各种格式提供:matrix , sequencecompact,由 diagonals_format arg 指定。

matrix 格式中,diagonals 必须是形状为 [..., M, M] 的张量,其中两个 inner-most 维度代表三对角矩阵。三个对角线之外的元素将被忽略。

sequence 格式中,diagonals 以元组或三个形状张量的列表的形式提供,[..., N] , [..., M] , [..., N] 分别表示上对角线、对角线和下对角线。 N 可以是 M-1M ;在后一种情况下,上对角线的最后一个元素和下对角线的第一个元素将被忽略。

compact 格式中,三个对角线组合成一个形状为 [..., 3, M] 的张量,最后两个维度依次包含上对角线、对角线和下对角线。与sequence 格式类似,元素diagonals[..., 0, M-1]diagonals[..., 2, 0] 被忽略。

建议使用compact 格式作为性能最佳的格式。如果您需要手动将张量转换为紧凑格式,请使用 tf.gather_nd 。形状为 [m, m] 的张量的示例:

rhs = tf.constant([...])
matrix = tf.constant([[...]])
m = matrix.shape[0]
dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
         [[i, i] for i in range(m)],                          # Diagonal
         [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]    # Subdiagonal
diagonals=tf.gather_nd(matrix, indices)
x = tf.linalg.tridiagonal_solve(diagonals, rhs)

无论 diagonals_format , rhs 是形状为 [..., M] 还是 [..., M, K] 的张量。后者允许同时求解具有相同 left-hand 边和 K 个不同 right-hand 边的 K 个系统。如果 transpose_rhs 设置为 True ,则预期形状为 [..., M][..., K, M]

批尺寸,表示为 ... ,在 diagonalsrhs 中必须相同。

输出是与 rhs 形状相同的张量:[..., M][..., M, K]

如果输入矩阵不可逆,则不保证运算会引发错误。 tf.debugging.check_numerics 可应用于输出以检测可逆性问题。

注意:对于大批量大小,如果 partial_pivoting=True 或有多个 right-hand 边(K > 1),GPU 上的计算可能会很慢。如果出现此问题,请考虑是否可以禁用旋转并使用 K = 1 ,或者考虑使用 CPU。

在 CPU 上,解决方案是通过高斯消元计算的,有或没有部分旋转,具体取决于partial_pivoting范围。在 GPU 上,使用了 Nvidia 的 cuSPARSE 库:https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

[1] Nicholas J. Higham (2002)。数值算法的准确性和稳定性:第二版。暹。页。 175. 国际标准书号 978-0-89871-802-7。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.linalg.tridiagonal_solve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。