用法:
forward(src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None)
src-编码器的序列(必需)。
tgt-解码器的序列(必需)。
src_mask-src 序列的附加掩码(可选)。
tgt_mask-tgt 序列的附加掩码(可选)。
memory_mask-编码器输出的附加掩码(可选)。
src_key_padding_mask-每批 src key 的 ByteTensor 掩码(可选)。
tgt_key_padding_mask-每批 tgt 键的 ByteTensor 掩码(可选)。
memory_key_padding_mask-每个批次的内存 key 的ByteTensor 掩码(可选)。
接收并处理屏蔽的源/目标序列。
- 形状:
src:
(N, S, E)
如果 batch_first。 ,tgt:
(N, T, E)
如果 batch_first。 ,src_mask: 。
tgt_mask: 。
memory_mask: 。
src_key_padding_mask: 。
tgt_key_padding_mask: 。
memory_key_padding_mask: 。
注意:[src/tgt/memory]_mask 确保允许位置 i 参加未屏蔽的位置。如果提供ByteTensor,则非零位不得参加,零位不变。如果提供了BoolTensor,则
True
的位置不允许参加,而False
的值将保持不变。如果提供了FloatTensor,它将被添加到注意力权重中。 [src/tgt/memory]_key_padding_mask提供了key中的指定元素被注意忽略。如果提供了ByteTensor,则将忽略非零位置,而零位置将保持不变。如果提供了BoolTensor,则将忽略具有True
值的位置,而具有False
值的位置将保持不变。输出:
(N, T, E)
。 ,如果batch_first,则为
注意:由于 Transformer 模型中的multi-head attention 架构,transformer 的输出序列长度与解码的输入序列(即目标)长度相同。
其中 S 是源序列长度,T 是目标序列长度,N 是批量大小,E 是特征数
例子
>>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
参数:
相关用法
- Python torch.nn.Transformer用法及代码示例
- Python torch.nn.TransformerEncoder用法及代码示例
- Python torch.nn.TransformerEncoderLayer用法及代码示例
- Python torch.nn.TransformerDecoderLayer用法及代码示例
- Python torch.nn.TransformerDecoder用法及代码示例
- Python torch.nn.TripletMarginWithDistanceLoss用法及代码示例
- Python torch.nn.TripletMarginLoss用法及代码示例
- Python torch.nn.Tanhshrink用法及代码示例
- Python torch.nn.Tanh用法及代码示例
- Python torch.nn.Threshold用法及代码示例
- Python torch.nn.InstanceNorm3d用法及代码示例
- Python torch.nn.quantized.dynamic.LSTM用法及代码示例
- Python torch.nn.EmbeddingBag用法及代码示例
- Python torch.nn.Module.register_forward_hook用法及代码示例
- Python torch.nn.AvgPool2d用法及代码示例
- Python torch.nn.PixelShuffle用法及代码示例
- Python torch.nn.CELU用法及代码示例
- Python torch.nn.Hardsigmoid用法及代码示例
- Python torch.nn.GLU用法及代码示例
- Python torch.nn.functional.conv1d用法及代码示例
- Python torch.nn.Identity用法及代码示例
- Python torch.nn.Sigmoid用法及代码示例
- Python torch.nn.utils.spectral_norm用法及代码示例
- Python torch.nn.utils.prune.custom_from_mask用法及代码示例
- Python torch.nn.MaxUnpool3d用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.Transformer.forward。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。