当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.nn.sampled_softmax_loss用法及代码示例


计算并返回采样的 softmax 训练损失。

用法

tf.compat.v1.nn.sampled_softmax_loss(
    weights, biases, labels, inputs, num_sampled, num_classes, num_true=1,
    sampled_values=None, remove_accidental_hits=True,
    partition_strategy='mod', name='sampled_softmax_loss',
    seed=None
)

参数

  • weights 形状为 [num_classes, dim]TensorTensor 对象的列表,其沿维度 0 的连接具有形状 [num_classes, dim]。 (possibly-sharded) 类嵌入。
  • biases 形状为 [num_classes]Tensor 。阶级偏见。
  • labels 类型为 int64 和形状为 [batch_size, num_true]Tensor 。目标类。请注意,此格式不同于 nn.softmax_cross_entropy_with_logitslabels 参数。
  • inputs 形状为 [batch_size, dim]Tensor 。输入网络的前向激活。
  • num_sampled 一个 int 。每批随机抽样的类数。
  • num_classes 一个 int 。可能的类数。
  • num_true 一个 int 。每个训练示例的目标类数。
  • sampled_values *_candidate_sampler 函数返回的 (sampled_candidates , true_expected_count , sampled_expected_count) 元组。 (如果没有,我们默认为 log_uniform_candidate_sampler )
  • remove_accidental_hits 一个bool。是否删除"accidental hits",其中采样类等于目标类之一。默认为真。
  • partition_strategy 指定分区策略的字符串,如果 len(weights) > 1 则相关。目前支持"div""mod"。默认为 "mod" 。有关详细信息,请参阅tf.nn.embedding_lookup
  • name 操作的名称(可选)。
  • seed 候选抽样的随机种子。默认为 None,它不会为候选采样设置 op-level 随机种子。

返回

  • per-example 的 batch_size 一维张量采样了 softmax 损失。

这是在大量类上训练 softmax 分类器的更快方法。

此操作仅用于训练。它通常低估了完整的 softmax 损失。

一个常见的用例是使用这种方法进行训练,并计算完整的 softmax 损失以进行评估或推理。在这种情况下,您必须设置 partition_strategy="div" 以使两个损失保持一致,如下例所示:

if mode == "train":
  loss = tf.nn.sampled_softmax_loss(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      ...,
      partition_strategy="div")
elif mode == "eval":
  logits = tf.matmul(inputs, tf.transpose(weights))
  logits = tf.nn.bias_add(logits, biases)
  labels_one_hot = tf.one_hot(labels, n_classes)
  loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels_one_hot,
      logits=logits)

请参阅我们的候选抽样算法参考 (pdf)。另请参阅(Jean 等人,2014 年)的第 3 节了解数学。

参考:

关于使用非常大的目标词汇进行神经机器翻译:Jean 等人,2014 (pdf)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.sampled_softmax_loss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。