當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.sequence_mask用法及代碼示例

返回一個掩碼張量,表示每個單元格的前 N ​​個位置。

用法

tf.sequence_mask(
    lengths, maxlen=None, dtype=tf.dtypes.bool, name=None
)

參數

  • lengths 整數張量,其所有值 <= maxlen。
  • maxlen 標量整數張量,返回張量的最後一維的大小。默認值為 lengths 中的最大值。
  • dtype 結果張量的輸出類型。
  • name 操作的名稱。

返回

  • 形狀為 lengths.shape + (maxlen,) 的掩碼張量,轉換為指定的 dtype。

拋出

  • ValueError 如果 maxlen 不是標量。

如果 lengths 具有形狀 [d_1, d_2, ..., d_n] ,則生成的張量 mask 具有 dtype dtype 和形狀 [d_1, d_2, ..., d_n, maxlen] ,與

mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

例子:

tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                #  [True, True, True, False, False],
                                #  [True, True, False, False, False]]

tf.sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
                                  #   [True, True, True]],
                                  #  [[True, True, False],
                                  #   [False, False, False]]]

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.sequence_mask。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。