当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.linalg.tridiagonal_matmul用法及代码示例


将三对角矩阵乘以矩阵。

用法

tf.linalg.tridiagonal_matmul(
    diagonals, rhs, diagonals_format='compact', name=None
)

参数

  • diagonals TensorTensor 的元组说明 left-hand 边。形状取决于 diagonals_format ,见上面的说明。必须是 float32 , float64 , complex64complex128
  • rhs 形状为 [..., M, N] 的 Tensor 并且具有与 diagonals 相同的 dtype。
  • diagonals_format sequencecompact 之一。默认为 compact
  • name 给此 Op 的名称(可选)。

返回

  • 包含乘法结果的形状为 [..., M, N] 的 Tensor

抛出

  • ValueError 提供不支持的类型作为输入,或者当输入张量的形状不正确时。

diagonals 是 3 对角 NxN 矩阵的表示,它取决于 diagonals_format

matrix 格式中,diagonals 必须是形状为 [..., M, M] 的张量,其中两个 inner-most 维度代表三对角矩阵。三个对角线之外的元素将被忽略。

如果 sequence 格式,diagonals 是三个张量的列表或元组: [superdiag, maindiag, subdiag] ,每个具有形状 [..., M]。 superdiag 的最后一个元素 subdiag 的第一个元素被忽略。

compact 格式中,三个对角线组合成一个形状为 [..., 3, M] 的张量,最后两个维度依次包含上对角线、对角线和下对角线。与sequence 格式类似,元素diagonals[..., 0, M-1]diagonals[..., 2, 0] 被忽略。

建议使用sequence 格式作为性能最佳的格式。

rhs 是乘法右边的矩阵。它的形状为 [..., M, N]

例子:

superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [superdiag, maindiag, subdiag]
rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.linalg.tridiagonal_matmul。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。