本文简要介绍python语言中 torchtext.nn.MultiheadAttentionContainer.__init__
的用法。
用法:
__init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)
nhead-多头注意力模型中的头数
in_proj_container-multi-head in-projection 线性层(又名 nn.Linear)的容器。
attention_layer-自定义关注层。从 MHA 容器发送到注意力层的输入形状为
(…, L, N * H, E / H)
(用于查询)和(…, S, N * H, E / H)
(用于键/值),而注意力层的输出形状预计为(…, L, N * H, E / H)
。如果用户希望整个MultiheadAttentionContainer具有广播函数,则attention_layer需要支持广播。out_proj-multi-head out-projection 层(又名 nn.Linear)。
batch_first-如果
True
,则输入和输出张量作为(…, N, L, E)
提供。默认值:False
一个multi-head注意力容器
- 例子::
>>> import torch >>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct >>> embed_dim, num_heads, bsz = 10, 5, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> MHA = MultiheadAttentionContainer(num_heads, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) >>> print(attn_output.shape) >>> torch.Size([21, 64, 10])
参数:
相关用法
- Python PyTorch MultiheadAttention用法及代码示例
- Python PyTorch MultiStepLR用法及代码示例
- Python PyTorch MultiLabelMarginLoss用法及代码示例
- Python PyTorch MultiplicativeLR用法及代码示例
- Python PyTorch MultivariateNormal用法及代码示例
- Python PyTorch MultiScaleRoIAlign用法及代码示例
- Python PyTorch MultiMarginLoss用法及代码示例
- Python PyTorch Multiplexer用法及代码示例
- Python PyTorch Multinomial用法及代码示例
- Python PyTorch MuLawEncoding用法及代码示例
- Python PyTorch MuLawDecoding用法及代码示例
- Python PyTorch MaxUnpool3d用法及代码示例
- Python PyTorch Module.buffers用法及代码示例
- Python PyTorch Module.register_full_backward_hook用法及代码示例
- Python PyTorch Module.named_modules用法及代码示例
- Python PyTorch Module.parameters用法及代码示例
- Python PyTorch MaxPool1d用法及代码示例
- Python PyTorch Module.register_forward_hook用法及代码示例
- Python PyTorch MetaInferGroupedPooledEmbeddingsLookup.state_dict用法及代码示例
- Python PyTorch Module.named_parameters用法及代码示例
- Python PyTorch MetaInferGroupedEmbeddingsLookup.named_buffers用法及代码示例
- Python PyTorch ModuleList用法及代码示例
- Python PyTorch MixtureSameFamily用法及代码示例
- Python PyTorch MpSerialExecutor用法及代码示例
- Python PyTorch MaxUnpool1d用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchtext.nn.MultiheadAttentionContainer.__init__。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。