當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch MultiheadAttention用法及代碼示例


本文簡要介紹python語言中 torch.nn.MultiheadAttention 的用法。

用法:

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

參數

  • embed_dim-模型的總尺寸。

  • num_heads-並行注意力頭的數量。請注意,embed_dim 將被拆分為 num_heads (即每個頭部都有維度 embed_dim // num_heads )。

  • dropout-attn_output_weights 上的輟學概率。默認值:0.0(無丟失)。

  • bias-如果指定,則向輸入/輸出投影層添加偏差。默認值:True

  • add_bias_kv-如果指定,則在 dim=0 處向鍵和值序列添加偏差。默認值:False

  • add_zero_attn-如果指定,則將一批新的零添加到 dim=1 處的鍵和值序列。默認值:False

  • kdim-鍵的函數總數。默認值:None(使用 kdim=embed_dim )。

  • vdim-值的特征總數。默認值:None(使用 vdim=embed_dim )。

  • batch_first-如果 True ,則輸入和輸出張量提供為 (batch, seq, feature)。默認值:False(序列、批處理、特征)。

允許模型共同關注來自不同表示子空間的信息。見Attention Is All You Need

其中

例子:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.MultiheadAttention。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。