當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch MultiheadAttentionContainer.__init__用法及代碼示例


本文簡要介紹python語言中 torchtext.nn.MultiheadAttentionContainer.__init__ 的用法。

用法:

__init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)

參數

  • nhead-多頭注意力模型中的頭數

  • in_proj_container-multi-head in-projection 線性層(又名 nn.Linear)的容器。

  • attention_layer-自定義關注層。從 MHA 容器發送到注意力層的輸入形狀為 (…, L, N * H, E / H)(用於查詢)和 (…, S, N * H, E / H)(用於鍵/值),而注意力層的輸出形狀預計為 (…, L, N * H, E / H) 。如果用戶希望整個MultiheadAttentionContainer具有廣播函數,則attention_layer需要支持廣播。

  • out_proj-multi-head out-projection 層(又名 nn.Linear)。

  • batch_first-如果 True ,則輸入和輸出張量作為 (…, N, L, E) 提供。默認值:False

一個multi-head注意力容器

例子::
>>> import torch
>>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim))
>>> query = torch.rand((21, bsz, embed_dim))
>>> key = value = torch.rand((16, bsz, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
>>> torch.Size([21, 64, 10])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchtext.nn.MultiheadAttentionContainer.__init__。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。