当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch MultiheadAttentionContainer.__init__用法及代码示例


本文简要介绍python语言中 torchtext.nn.MultiheadAttentionContainer.__init__ 的用法。

用法:

__init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)

参数

  • nhead-多头注意力模型中的头数

  • in_proj_container-multi-head in-projection 线性层(又名 nn.Linear)的容器。

  • attention_layer-自定义关注层。从 MHA 容器发送到注意力层的输入形状为 (…, L, N * H, E / H)(用于查询)和 (…, S, N * H, E / H)(用于键/值),而注意力层的输出形状预计为 (…, L, N * H, E / H) 。如果用户希望整个MultiheadAttentionContainer具有广播函数,则attention_layer需要支持广播。

  • out_proj-multi-head out-projection 层(又名 nn.Linear)。

  • batch_first-如果 True ,则输入和输出张量作为 (…, N, L, E) 提供。默认值:False

一个multi-head注意力容器

例子::
>>> import torch
>>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim))
>>> query = torch.rand((21, bsz, embed_dim))
>>> key = value = torch.rand((16, bsz, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
>>> torch.Size([21, 64, 10])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchtext.nn.MultiheadAttentionContainer.__init__。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。