计算并返回 noise-contrastive 估计训练损失。
用法
tf.compat.v1.nn.nce_loss(
weights, biases, labels, inputs, num_sampled, num_classes, num_true=1,
sampled_values=None, remove_accidental_hits=False,
partition_strategy='mod', name='nce_loss'
)
参数
-
weights
形状为[num_classes, dim]
的Tensor
或Tensor
对象的列表,其沿维度 0 的连接具有形状 [num_classes, dim]。 (possibly-partitioned) 类嵌入。 -
biases
形状为[num_classes]
的Tensor
。阶级偏见。 -
labels
类型为int64
和形状为[batch_size, num_true]
的Tensor
。目标类。 -
inputs
形状为[batch_size, dim]
的Tensor
。输入网络的前向激活。 -
num_sampled
一个int
。每批随机抽样的负类数量。为批次中的每个元素评估这个负类的单个样本。 -
num_classes
一个int
。可能的类数。 -
num_true
一个int
。每个训练示例的目标类数。 -
sampled_values
*_candidate_sampler
函数返回的 (sampled_candidates
,true_expected_count
,sampled_expected_count
) 元组。 (如果没有,我们默认为log_uniform_candidate_sampler
) -
remove_accidental_hits
Abool
.是否删除"accidental hits",其中采样类等于目标类之一。如果设置为True
,这是 "Sampled Logistic" 损失而不是 NCE,我们正在学习生成 log-odds 而不是日志概率。请参阅我们的候选抽样算法参考(pdf)。默认为假。 -
partition_strategy
指定分区策略的字符串,如果len(weights) > 1
则相关。目前支持"div"
和"mod"
。默认为"mod"
。有关详细信息,请参阅tf.nn.embedding_lookup
。 -
name
操作的名称(可选)。
返回
-
per-example NCE 损失的
batch_size
一维张量。
一个常见的用例是使用这种方法进行训练,并计算完整的 sigmoid 损失以进行评估或推理。在这种情况下,您必须设置 partition_strategy="div"
以使两个损失保持一致,如下例所示:
if mode == "train":
loss = tf.nn.nce_loss(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
...,
partition_strategy="div")
elif mode == "eval":
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=labels_one_hot,
logits=logits)
loss = tf.reduce_sum(loss, axis=1)
注意:默认情况下,这使用 log-uniform (Zipfian) 分布进行采样,因此您的标签必须按照频率递减的顺序进行排序才能获得良好的结果。有关详细信息,请参阅 tf.random.log_uniform_candidate_sampler
。
注意:在 num_true
> 1 的情况下,我们为每个目标类别分配目标概率 1 /num_true
,以便目标概率总和为 1 per-example。
注意:每个示例允许可变数量的目标类将很有用。我们希望在未来的版本中提供此函数。现在,如果您有可变数量的目标类,您可以通过重复它们或使用其他未使用的类填充它们来将它们填充为一个常量。
参考:
Noise-contrastive 估计 - 非归一化统计模型的新估计原理:Gutmann 等人,2010 (pdf)
相关用法
- Python tf.compat.v1.nn.static_rnn用法及代码示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代码示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代码示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代码示例
- Python tf.compat.v1.nn.separable_conv2d用法及代码示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代码示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代码示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代码示例
- Python tf.compat.v1.nn.convolution用法及代码示例
- Python tf.compat.v1.nn.conv2d用法及代码示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代码示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代码示例
- Python tf.compat.v1.nn.pool用法及代码示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代码示例
- Python tf.compat.v1.nn.ctc_loss用法及代码示例
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代码示例
- Python tf.compat.v1.nn.erosion2d用法及代码示例
- Python tf.compat.v1.nn.raw_rnn用法及代码示例
- Python tf.compat.v1.nn.dilation2d用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.nce_loss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。