当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.linalg.band_part用法及代码示例


复制一个张量,将每个最内层矩阵中中心带之外的所有内容设置为零。

用法

tf.linalg.band_part(
    input, num_lower, num_upper, name=None
)

参数

  • input 一个Tensor。秩 k 张量。
  • num_lower 一个Tensor。必须是以下类型之一:int32 , int64。 0-D 张量。要保留的子对角线的数量。如果为负,则保留整个下三角形。
  • num_upper 一个Tensor。必须与 num_lower 具有相同的类型。 0-D 张量。要保留的超对角行数。如果为负数,则保留整个上三角。
  • name 操作的名称(可选)。

返回

  • 一个Tensor。具有与 input 相同的类型。

band 部分计算如下:假设 input 具有 k 维度 [I, J, K, ..., M, N] ,则输出是具有相同形状的张量,其中

band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n].

指标函数

in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper).

例如:

# if 'input' is [[ 0,  1,  2, 3]
#                [-1,  0,  1, 2]
#                [-2, -1,  0, 1]
#                [-3, -2, -1, 0]],

tf.linalg.band_part(input, 1, -1) ==> [[ 0,  1,  2, 3]
                                       [-1,  0,  1, 2]
                                       [ 0, -1,  0, 1]
                                       [ 0,  0, -1, 0]],

tf.linalg.band_part(input, 2, 1) ==> [[ 0,  1,  0, 0]
                                      [-1,  0,  1, 0]
                                      [-2, -1,  0, 1]
                                      [ 0, -2, -1, 0]]

有用的特殊情况:

tf.linalg.band_part(input, 0, -1) ==> Upper triangular part.
 tf.linalg.band_part(input, -1, 0) ==> Lower triangular part.
 tf.linalg.band_part(input, 0, 0) ==> Diagonal.

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.linalg.band_part。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。