計算並返回 noise-contrastive 估計訓練損失。
用法
tf.compat.v1.nn.nce_loss(
weights, biases, labels, inputs, num_sampled, num_classes, num_true=1,
sampled_values=None, remove_accidental_hits=False,
partition_strategy='mod', name='nce_loss'
)
參數
-
weights
形狀為[num_classes, dim]
的Tensor
或Tensor
對象的列表,其沿維度 0 的連接具有形狀 [num_classes, dim]。 (possibly-partitioned) 類嵌入。 -
biases
形狀為[num_classes]
的Tensor
。階級偏見。 -
labels
類型為int64
和形狀為[batch_size, num_true]
的Tensor
。目標類。 -
inputs
形狀為[batch_size, dim]
的Tensor
。輸入網絡的前向激活。 -
num_sampled
一個int
。每批隨機抽樣的負類數量。為批次中的每個元素評估這個負類的單個樣本。 -
num_classes
一個int
。可能的類數。 -
num_true
一個int
。每個訓練示例的目標類數。 -
sampled_values
*_candidate_sampler
函數返回的 (sampled_candidates
,true_expected_count
,sampled_expected_count
) 元組。 (如果沒有,我們默認為log_uniform_candidate_sampler
) -
remove_accidental_hits
Abool
.是否刪除"accidental hits",其中采樣類等於目標類之一。如果設置為True
,這是 "Sampled Logistic" 損失而不是 NCE,我們正在學習生成 log-odds 而不是日誌概率。請參閱我們的候選抽樣算法參考(pdf)。默認為假。 -
partition_strategy
指定分區策略的字符串,如果len(weights) > 1
則相關。目前支持"div"
和"mod"
。默認為"mod"
。有關詳細信息,請參閱tf.nn.embedding_lookup
。 -
name
操作的名稱(可選)。
返回
-
per-example NCE 損失的
batch_size
一維張量。
一個常見的用例是使用這種方法進行訓練,並計算完整的 sigmoid 損失以進行評估或推理。在這種情況下,您必須設置 partition_strategy="div"
以使兩個損失保持一致,如下例所示:
if mode == "train":
loss = tf.nn.nce_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
...,
partition_strategy="div")
elif mode == "eval":
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=labels_one_hot,
logits=logits)
loss = tf.reduce_sum(loss, axis=1)
注意:默認情況下,這使用 log-uniform (Zipfian) 分布進行采樣,因此您的標簽必須按照頻率遞減的順序進行排序才能獲得良好的結果。有關詳細信息,請參閱 tf.random.log_uniform_candidate_sampler
。
注意:在 num_true
> 1 的情況下,我們為每個目標類別分配目標概率 1 /num_true
,以便目標概率總和為 1 per-example。
注意:每個示例允許可變數量的目標類將很有用。我們希望在未來的版本中提供此函數。現在,如果您有可變數量的目標類,您可以通過重複它們或使用其他未使用的類填充它們來將它們填充為一個常量。
參考:
Noise-contrastive 估計 - 非歸一化統計模型的新估計原理:Gutmann 等人,2010 (pdf)
相關用法
- Python tf.compat.v1.nn.static_rnn用法及代碼示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代碼示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代碼示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.separable_conv2d用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代碼示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代碼示例
- Python tf.compat.v1.nn.convolution用法及代碼示例
- Python tf.compat.v1.nn.conv2d用法及代碼示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代碼示例
- Python tf.compat.v1.nn.pool用法及代碼示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.ctc_loss用法及代碼示例
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代碼示例
- Python tf.compat.v1.nn.erosion2d用法及代碼示例
- Python tf.compat.v1.nn.raw_rnn用法及代碼示例
- Python tf.compat.v1.nn.dilation2d用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.nce_loss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。