本文简要介绍python语言中 torch.nn.MultiheadAttention
的用法。
用法:
class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
embed_dim-模型的总尺寸。
num_heads-并行注意力头的数量。请注意,
embed_dim
将被拆分为num_heads
(即每个头部都有维度embed_dim // num_heads
)。dropout-
attn_output_weights
上的辍学概率。默认值:0.0
(无丢失)。bias-如果指定,则向输入/输出投影层添加偏差。默认值:
True
。add_bias_kv-如果指定,则在 dim=0 处向键和值序列添加偏差。默认值:
False
。add_zero_attn-如果指定,则将一批新的零添加到 dim=1 处的键和值序列。默认值:
False
。kdim-键的函数总数。默认值:
None
(使用kdim=embed_dim
)。vdim-值的特征总数。默认值:
None
(使用vdim=embed_dim
)。batch_first-如果
True
,则输入和输出张量提供为 (batch, seq, feature)。默认值:False
(序列、批处理、特征)。
允许模型共同关注来自不同表示子空间的信息。见Attention Is All You Need。
其中 。
例子:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
参数:
相关用法
- Python PyTorch MultiheadAttentionContainer.__init__用法及代码示例
- Python PyTorch MultiStepLR用法及代码示例
- Python PyTorch MultiLabelMarginLoss用法及代码示例
- Python PyTorch MultiplicativeLR用法及代码示例
- Python PyTorch MultivariateNormal用法及代码示例
- Python PyTorch MultiScaleRoIAlign用法及代码示例
- Python PyTorch MultiMarginLoss用法及代码示例
- Python PyTorch Multiplexer用法及代码示例
- Python PyTorch Multinomial用法及代码示例
- Python PyTorch MuLawEncoding用法及代码示例
- Python PyTorch MuLawDecoding用法及代码示例
- Python PyTorch MaxUnpool3d用法及代码示例
- Python PyTorch Module.buffers用法及代码示例
- Python PyTorch Module.register_full_backward_hook用法及代码示例
- Python PyTorch Module.named_modules用法及代码示例
- Python PyTorch Module.parameters用法及代码示例
- Python PyTorch MaxPool1d用法及代码示例
- Python PyTorch Module.register_forward_hook用法及代码示例
- Python PyTorch MetaInferGroupedPooledEmbeddingsLookup.state_dict用法及代码示例
- Python PyTorch Module.named_parameters用法及代码示例
- Python PyTorch MetaInferGroupedEmbeddingsLookup.named_buffers用法及代码示例
- Python PyTorch ModuleList用法及代码示例
- Python PyTorch MixtureSameFamily用法及代码示例
- Python PyTorch MpSerialExecutor用法及代码示例
- Python PyTorch MaxUnpool1d用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.MultiheadAttention。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。