當前位置: 首頁>>技術教程>>正文


python – tf.nn.embedding_lookup函數有什麽作用?案例詳解

問題描述

tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None)這個函數有什麽作用?看起來像查找表,也就是返回每個ID對應的參數(以ID為單位)?

例如,在Skip-Gram模型中,如果我們使用tf.nn.embedding_lookup(embeddings, train_inputs),那麽對於每個train_input,它會找到對應的嵌入(Embedding)?

 

簡單來說

embedding_lookup函數檢索params張量的行。該行為類似於對numpy中的數組使用索引。例如。

matrix = np.random.random([1024, 64])  # 64-dimensional embeddings
ids = np.array([0, 5, 17, 33])
print matrix[ids]  # prints a matrix of shape [4, 64] 

params參數也可以是張量的列表,在這種情況下,ids表示多個張量的索引組合。例如,給定ids[0, 3][1, 4][2, 5],得到的張量都是[2, 64]的列表

另外,partition_strategy參數可以控製ids在列表中的分配方式。當矩陣可能太大而無法合為一體時,分區策略對於較大規模的問題很有用。

 

進階說明

最簡單的形式類似於tf.gather。它根據ids指定的索引返回params的元素。

例如(假設您在tf.InteractiveSession()內)

params = tf.constant([10,20,30,40])
ids = tf.constant([0,1,2,3])
print tf.nn.embedding_lookup(params,ids).eval()

將返回[10 20 30 40],因為params的第一個元素(索引0)是10,params的第二個元素(索引1)是20,依此類推。

同樣,

params = tf.constant([10,20,30,40])
ids = tf.constant([1,1,3])
print tf.nn.embedding_lookup(params,ids).eval()

將返回[20 20 40]

但是embedding_lookup不僅限於此。 params參數可以是張量列表,而不是單個張量。

params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)

在這種情況下,在ids中指定的索引根據分區策略對應於張量的元素,其中默認分區策略為’mod’。

在’mod’策略中,索引0對應於列表中第一個張量的第一個元素。索引1對應於第二張量的第一元素。索引2對應於第三張量的第一個元素,依此類推。對於所有索引0..(n-1),假設參數是n張量的列表,則簡單地將索引i對應於第(i + 1)張量的第一個元素。

現在,索引n不能對應於張量n + 1,因為列表params僅包含n張量。因此索引n對應於第一個張量的第二個元素。類似地,索引n+1對應於第二張量的第二個元素,依此類推。

因此,在代碼中

params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)

下標0對應於第一個張量的第一個元素:1

索引1對應於第二張量的第一個元素:10

索引2對應於第一個張量的第二個元素:2

索引3對應於第二張量的第二個元素:20

因此,結果將是:

[ 2  1  2 10  2 20]

 

實際的例子

以文本嵌入(Embedding)為例,tf.nn.embedding_lookup()函數的目的是在嵌入矩陣中執行查找並返回單詞的嵌入(或簡單地說是矢量表示)。

一個簡單的嵌入矩陣(形狀:vocabulary_size x embedding_dimension)如下所示。 (即,每個單詞將由一個數字向量表示;也就是word2vec)


嵌入矩陣

the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862
like 0.36808 0.20834 -0.22319 0.046283 0.20098 0.27515 -0.77127 -0.76804
between 0.7503 0.71623 -0.27033 0.20059 -0.17008 0.68568 -0.061672 -0.054638
did 0.042523 -0.21172 0.044739 -0.19248 0.26224 0.0043991 -0.88195 0.55184
just 0.17698 0.065221 0.28548 -0.4243 0.7499 -0.14892 -0.66786 0.11788
national -1.1105 0.94945 -0.17078 0.93037 -0.2477 -0.70633 -0.8649 -0.56118
day 0.11626 0.53897 -0.39514 -0.26027 0.57706 -0.79198 -0.88374 0.30119
country -0.13531 0.15485 -0.07309 0.034013 -0.054457 -0.20541 -0.60086 -0.22407
under 0.13721 -0.295 -0.05916 -0.59235 0.02301 0.21884 -0.34254 -0.70213
such 0.61012 0.33512 -0.53499 0.36139 -0.39866 0.70627 -0.18699 -0.77246
second -0.29809 0.28069 0.087102 0.54455 0.70003 0.44778 -0.72565 0.62309 

我拆分了上述嵌入矩陣,並僅將單詞vocab裝入我們的詞匯表,並將相應的向量裝入emb數組。

vocab = ['the','like','between','did','just','national','day','country','under','such','second']

emb = np.array([[0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.044457, -0.49688, -0.17862],
   [0.36808, 0.20834, -0.22319, 0.046283, 0.20098, 0.27515, -0.77127, -0.76804],
   [0.7503, 0.71623, -0.27033, 0.20059, -0.17008, 0.68568, -0.061672, -0.054638],
   [0.042523, -0.21172, 0.044739, -0.19248, 0.26224, 0.0043991, -0.88195, 0.55184],
   [0.17698, 0.065221, 0.28548, -0.4243, 0.7499, -0.14892, -0.66786, 0.11788],
   [-1.1105, 0.94945, -0.17078, 0.93037, -0.2477, -0.70633, -0.8649, -0.56118],
   [0.11626, 0.53897, -0.39514, -0.26027, 0.57706, -0.79198, -0.88374, 0.30119],
   [-0.13531, 0.15485, -0.07309, 0.034013, -0.054457, -0.20541, -0.60086, -0.22407],
   [ 0.13721, -0.295, -0.05916, -0.59235, 0.02301, 0.21884, -0.34254, -0.70213],
   [ 0.61012, 0.33512, -0.53499, 0.36139, -0.39866, 0.70627, -0.18699, -0.77246 ],
   [ -0.29809, 0.28069, 0.087102, 0.54455, 0.70003, 0.44778, -0.72565, 0.62309 ]])


emb.shape
# (11, 8)

在TensorFlow中查找嵌入(Embedding)

現在,我們將看到如何對某些任意輸入語句執行嵌入查找。

In [54]: from collections import OrderedDict

# embedding as TF tensor (for now constant; could be tf.Variable() during training)
In [55]: tf_embedding = tf.constant(emb, dtype=tf.float32)

# input for which we need the embedding
In [56]: input_str = "like the country"

# build index based on our `vocabulary`
In [57]: word_to_idx = OrderedDict({w:vocab.index(w) for w in input_str.split() if w in vocab})

# lookup in embedding matrix & return the vectors for the input words
In [58]: tf.nn.embedding_lookup(tf_embedding, list(word_to_idx.values())).eval()
Out[58]: 
array([[ 0.36807999,  0.20834   , -0.22318999,  0.046283  ,  0.20097999,
         0.27515   , -0.77126998, -0.76804   ],
       [ 0.41800001,  0.24968   , -0.41242   ,  0.1217    ,  0.34527001,
        -0.044457  , -0.49687999, -0.17862   ],
       [-0.13530999,  0.15485001, -0.07309   ,  0.034013  , -0.054457  ,
        -0.20541   , -0.60086   , -0.22407   ]], dtype=float32)

觀察我們如何使用詞匯表中的單詞索引從原始嵌入矩陣(帶單詞)中獲取嵌入。

通常,此類嵌入查找是由第一層(稱為嵌入層)執行的,然後將這些嵌入傳遞到RNN /LSTM /GRU層以進行進一步處理。


旁注:通常,詞匯表還將具有特殊的unk token。因此,如果我們的詞匯表中不存在來自我們輸入句子的token,那麽將在嵌入矩陣中查找與unk對應的索引。


附言請注意,embedding_dimension是一個超參數,必須針對其應用進行調整,但是流行的模型(如Word2VecGloVe)使用300維向量來表示每個單詞。

參考閱讀word2vec skip-gram model

 

嵌入的圖示

這是描述嵌入查找過程的圖示。

簡而言之,它獲取由ID列表指定的嵌入層的相應行,並將其提供為張量。它是通過以下過程實現的。

  1. 定義一個占位符lookup_ids = tf.placeholder([10])
  2. 定義嵌入層embeddings = tf.Variable([100,10],...)
  3. 定義張量流操作embed_lookup = tf.embedding_lookup(embeddings, lookup_ids)
  4. 通過運行lookup = session.run(embed_lookup, feed_dict={lookup_ids:[95,4,14]})獲取結果

 

高維embedding_lookup例子

當參數張量為高維時,id僅指最大維。也許對大多數人來說很明顯,但是我必須運行以下代碼才能理解這一點:

embeddings = tf.constant([[[1,1],[2,2],[3,3],[4,4]],[[11,11],[12,12],[13,13],[14,14]],
                          [[21,21],[22,22],[23,23],[24,24]]])
ids=tf.constant([0,2,1])
embed = tf.nn.embedding_lookup(embeddings, ids, partition_strategy='div')

with tf.Session() as session:
    result = session.run(embed)
    print (result)

僅嘗試’div’策略,對於一個張量,這沒有什麽區別。

這是輸出:

[[[ 1  1]
  [ 2  2]
  [ 3  3]
  [ 4  4]]

 [[21 21]
  [22 22]
  [23 23]
  [24 24]]

 [[11 11]
  [12 12]
  [13 13]
  [14 14]]]

 

參考資料

 

本文由《純淨天空》出品。文章地址: https://vimsky.com/zh-tw/article/4298.html,未經允許,請勿轉載。