當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch embedding用法及代碼示例


本文簡要介紹python語言中 torch.nn.functional.embedding 的用法。

用法:

torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)

參數

  • input(LongTensor) -包含嵌入矩陣中的索引的張量

  • weight(Tensor) -行數等於最大可能索引 + 1 且列數等於嵌入大小的嵌入矩陣

  • padding_idx(int,可選的) -如果指定,padding_idx 處的條目不會影響梯度;因此,padding_idx 處的嵌入向量在訓練期間不會更新,即它保持為固定的 “pad”。

  • max_norm(float,可選的) -如果給定,每個範數大於 max_norm 的嵌入向量將被重新歸一化為範數 max_norm 。注意:這將就地修改weight

  • norm_type(float,可選的) -p-norm 的 p 為 max_norm 選項計算。默認 2

  • scale_grad_by_freq(布爾值,可選的) -如果給定,這將通過小批量中單詞頻率的倒數來縮放梯度。默認 False

  • sparse(bool,可選的) -如果 True ,梯度 w.r.t. weight 將是一個稀疏張量。有關稀疏漸變的更多詳細信息,請參閱 torch.nn.Embedding 下的注釋。

一個簡單的查找表,用於查找固定字典和大小的嵌入。

該模塊通常用於使用索引檢索詞嵌入。模塊的輸入是索引列表和嵌入矩陣,輸出是相應的詞嵌入。

有關詳細信息,請參閱 torch.nn.Embedding

形狀:
  • 輸入:任意形狀的LongTensor,包含要提取的索引

  • 權重:形狀為 (V, embedding_dim) 的浮點類型嵌入矩陣,其中 V = 最大索引 + 1,embedding_dim = 嵌入大小

  • 輸出:(*, embedding_dim) ,其中 * 是輸入形狀

例子:

>>> # a batch of 2 samples of 4 indices each
>>> input = torch.tensor([[1,2,4,5],[4,3,2,9]])
>>> # an embedding matrix containing 10 tensors of size 3
>>> embedding_matrix = torch.rand(10, 3)
>>> F.embedding(input, embedding_matrix)
tensor([[[ 0.8490,  0.9625,  0.6753],
         [ 0.9666,  0.7761,  0.6108],
         [ 0.6246,  0.9751,  0.3618],
         [ 0.4161,  0.2419,  0.7383]],

        [[ 0.6246,  0.9751,  0.3618],
         [ 0.0237,  0.7794,  0.0528],
         [ 0.9666,  0.7761,  0.6108],
         [ 0.3385,  0.8612,  0.1867]]])

>>> # example with padding_idx
>>> weights = torch.rand(10, 3)
>>> weights[0, :].zero_()
>>> embedding_matrix = weights
>>> input = torch.tensor([[0,2,0,5]])
>>> F.embedding(input, embedding_matrix, padding_idx=0)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.5609,  0.5384,  0.8720],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.6262,  0.2438,  0.7471]]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.functional.embedding。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。