当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.sequence_mask用法及代码示例


返回一个掩码张量,表示每个单元格的前 N ​​个位置。

用法

tf.sequence_mask(
    lengths, maxlen=None, dtype=tf.dtypes.bool, name=None
)

参数

  • lengths 整数张量,其所有值 <= maxlen。
  • maxlen 标量整数张量,返回张量的最后一维的大小。默认值为 lengths 中的最大值。
  • dtype 结果张量的输出类型。
  • name 操作的名称。

返回

  • 形状为 lengths.shape + (maxlen,) 的掩码张量,转换为指定的 dtype。

抛出

  • ValueError 如果 maxlen 不是标量。

如果 lengths 具有形状 [d_1, d_2, ..., d_n] ,则生成的张量 mask 具有 dtype dtype 和形状 [d_1, d_2, ..., d_n, maxlen] ,与

mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

例子:

tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                #  [True, True, True, False, False],
                                #  [True, True, False, False, False]]

tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
                                  #   [True, True, True]],
                                  #  [[True, True, False],
                                  #   [False, False, False]]]

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.sequence_mask。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。