當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Transformer.forward用法及代碼示例


本文簡要介紹python語言中 torch.nn.Transformer.forward 的用法。

用法:

forward(src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None)

參數

  • src-編碼器的序列(必需)。

  • tgt-解碼器的序列(必需)。

  • src_mask-src 序列的附加掩碼(可選)。

  • tgt_mask-tgt 序列的附加掩碼(可選)。

  • memory_mask-編碼器輸出的附加掩碼(可選)。

  • src_key_padding_mask-每批 src key 的 ByteTensor 掩碼(可選)。

  • tgt_key_padding_mask-每批 tgt 鍵的 ByteTensor 掩碼(可選)。

  • memory_key_padding_mask-每個批次的內存 key 的ByteTensor 掩碼(可選)。

接收並處理屏蔽的源/目標序列。

形狀:
  • src: , (N, S, E) 如果 batch_first。

  • tgt: , (N, T, E) 如果 batch_first。

  • src_mask:

  • tgt_mask:

  • memory_mask:

  • src_key_padding_mask:

  • tgt_key_padding_mask:

  • memory_key_padding_mask:

注意:[src/tgt/memory]_mask 確保允許位置 i 參加未屏蔽的位置。如果提供ByteTensor,則非零位不得參加,零位不變。如果提供了BoolTensor,則True 的位置不允許參加,而False 的值將保持不變。如果提供了FloatTensor,它將被添加到注意力權重中。 [src/tgt/memory]_key_padding_mask提供了key中的指定元素被注意忽略。如果提供了ByteTensor,則將忽略非零位置,而零位置將保持不變。如果提供了BoolTensor,則將忽略具有True 值的位置,而具有False 值的位置將保持不變。

  • 輸出: ,如果batch_first,則為(N, T, E)

注意:由於 Transformer 模型中的multi-head attention 架構,transformer 的輸出序列長度與解碼的輸入序列(即目標)長度相同。

其中 S 是源序列長度,T 是目標序列長度,N 是批量大小,E 是特征數

例子

>>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.Transformer.forward。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。