当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.linalg.banded_triangular_solve用法及代码示例


使用带状求解器求解三角方程组。

用法

tf.linalg.banded_triangular_solve(
    bands, rhs, lower=True, adjoint=False, name=None
)

参数

  • bands A Tensor 说明左侧的条带,形状为 [..., K, M] 。当lowerTrue 时,K 行对应于第 K - 1 对角线的对角线(对角线是顶行),否则为对角线的第 K - 1 上对角线(对角线是底行)当 lowerFalse 。波段以 'LEFT_RIGHT' 对齐方式存储,其中上对角线在右侧填充,下对角线在左侧填充。这是 cuSPARSE 使用的对齐方式。有关详细信息,请参阅tf.linalg.set_diag
  • rhs 形状为 [..., M] 或 [..., M, N] 的 Tensor 并且具有与 diagonals 相同的 dtype。请注意,如果 rhs 和/或 diags 的形状不是静态已知的,则 rhs 将被视为矩阵而不是向量。
  • lower 可选的 bool 。默认为 True 。布尔值,指示 bands 是表示下三角矩阵还是上三角矩阵。
  • adjoint 可选的 bool 。默认为 False 。布尔值,指示是否使用矩阵的block-wise 伴随求解。
  • name 给此 Op 的名称(可选)。

返回

  • 包含解决方案的形状为 [..., M] 或 [..., M, N] 的 Tensor

bands 是形状为 [..., K, M] 的张量,其中 K 表示存储的波段数。这对应于一批 M by M 矩阵,其 K 次对角线(当 lowerTrue 时)被存储。

该算子广播 bands 的批次维度和 rhs 的批次维度。

例子:

存储 3x3 矩阵的 2 个波段。请注意,由于'LEFT_RIGHT' 填充,第二行中的第一个元素被忽略。

x = [[2., 3., 4.], [1., 2., 3.]]
x2 = [[2., 3., 4.], [10000., 2., 3.]]
y = tf.zeros([3, 3])
z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
z
<tf.Tensor:shape=(3, 3), dtype=float32, numpy=
array([[2., 0., 0.],
       [2., 3., 0.],
       [0., 3., 4.]], dtype=float32)>
soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
soln
<tf.Tensor:shape=(3, 1), dtype=float32, numpy=
array([[0.5 ],
       [0.  ],
       [0.25]], dtype=float32)>
are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
tf.reduce_all(are_equal).numpy()
True
are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
tf.reduce_all(are_equal).numpy()
True

存储 2 个 4x4 矩阵的超对角线。由于'LEFT_RIGHT' 填充,第一行的最后一个元素被忽略。

x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
y = tf.zeros([4, 4])
z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
z
<tf.Tensor:shape=(4, 4), dtype=float32, numpy=
array([[-1.,  2.,  0.,  0.],
       [ 0., -2.,  3.,  0.],
       [ 0.,  0., -3.,  4.],
       [ 0.,  0., -0., -4.]], dtype=float32)>
soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
soln
<tf.Tensor:shape=(4, 1), dtype=float32, numpy=
array([[-4.       ],
       [-1.5      ],
       [-0.6666667],
       [-0.25     ]], dtype=float32)>
are_equal = (soln == tf.linalg.triangular_solve(
  z, tf.ones([4, 1]), lower=False))
tf.reduce_all(are_equal).numpy()
True

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.linalg.banded_triangular_solve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。