計算並返回采樣的 softmax 訓練損失。
用法
tf.nn.sampled_softmax_loss(
weights, biases, labels, inputs, num_sampled, num_classes, num_true=1,
sampled_values=None, remove_accidental_hits=True, seed=None,
name='sampled_softmax_loss'
)
參數
-
weights
形狀為[num_classes, dim]
的Tensor
或Tensor
對象的列表,其沿維度 0 的連接具有形狀 [num_classes, dim]。 (possibly-sharded) 類嵌入。 -
biases
形狀為[num_classes]
的Tensor
。階級偏見。 -
labels
類型為int64
和形狀為[batch_size, num_true]
的Tensor
。目標類。請注意,此格式不同於nn.softmax_cross_entropy_with_logits
的labels
參數。 -
inputs
形狀為[batch_size, dim]
的Tensor
。輸入網絡的前向激活。 -
num_sampled
一個int
。每批隨機抽樣的類數。 -
num_classes
一個int
。可能的類數。 -
num_true
一個int
。每個訓練示例的目標類數。 -
sampled_values
*_candidate_sampler
函數返回的 (sampled_candidates
,true_expected_count
,sampled_expected_count
) 元組。 (如果沒有,我們默認為log_uniform_candidate_sampler
) -
remove_accidental_hits
一個bool
。是否刪除"accidental hits",其中采樣類等於目標類之一。默認為真。 -
seed
候選抽樣的隨機種子。默認為 None,它不會為候選采樣設置 op-level 隨機種子。 -
name
操作的名稱(可選)。
返回
-
per-example 的
batch_size
一維張量采樣了 softmax 損失。
這是在大量類上訓練 softmax 分類器的更快方法。
此操作僅用於訓練。它通常低估了完整的 softmax 損失。
一個常見的用例是使用此方法進行訓練,並計算完整的 softmax 損失以進行評估或推理,如下例所示:
if mode == "train":
loss = tf.nn.sampled_softmax_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
...)
elif mode == "eval":
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
loss = tf.nn.softmax_cross_entropy_with_logits(
labels=labels_one_hot,
logits=logits)
請參閱我們的候選抽樣算法參考
另請參閱 Jean 等人,2014 年 (pdf) 的第 3 節以了解數學。
注意:在 weights
和 bias
上進行嵌入查找時,將使用 "div" 分區策略。稍後將添加對其他分區策略的支持。
相關用法
- Python tf.nn.safe_embedding_lookup_sparse用法及代碼示例
- Python tf.nn.scale_regularization_loss用法及代碼示例
- Python tf.nn.softmax用法及代碼示例
- Python tf.nn.sigmoid_cross_entropy_with_logits用法及代碼示例
- Python tf.nn.space_to_depth用法及代碼示例
- Python tf.nn.separable_conv2d用法及代碼示例
- Python tf.nn.sparse_softmax_cross_entropy_with_logits用法及代碼示例
- Python tf.nn.softmax_cross_entropy_with_logits用法及代碼示例
- Python tf.nn.embedding_lookup_sparse用法及代碼示例
- Python tf.nn.RNNCellResidualWrapper.set_weights用法及代碼示例
- Python tf.nn.dropout用法及代碼示例
- Python tf.nn.gelu用法及代碼示例
- Python tf.nn.RNNCellDeviceWrapper.set_weights用法及代碼示例
- Python tf.nn.embedding_lookup用法及代碼示例
- Python tf.nn.RNNCellDeviceWrapper.get_weights用法及代碼示例
- Python tf.nn.local_response_normalization用法及代碼示例
- Python tf.nn.RNNCellResidualWrapper.add_loss用法及代碼示例
- Python tf.nn.max_pool用法及代碼示例
- Python tf.nn.RNNCellDropoutWrapper.set_weights用法及代碼示例
- Python tf.nn.l2_loss用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.nn.sampled_softmax_loss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。