计算加权交叉熵。
用法
tf.nn.weighted_cross_entropy_with_logits(
labels, logits, pos_weight, name=None
)参数
-
labels与logits具有相同类型和形状的Tensor,其值介于 0 和 1(含)之间。 -
logitsTensor类型为float32或float64,任何实数。 -
pos_weight用于正样本的系数,通常是标量,但可以广播到logits的形状。它的值应该是非负的。 -
name操作的名称(可选)。
返回
-
与
logits形状相同的Tensor,具有分量加权逻辑损失。
抛出
-
ValueError如果logits和labels的形状不同。
这就像 sigmoid_cross_entropy_with_logits() 除了 pos_weight 允许人们通过向上或 down-weighting 相对于负错误的正错误成本来权衡召回和精度。
通常的cross-entropy 成本定义为:
labels * -log(sigmoid(logits)) +
(1 - labels) * -log(1 - sigmoid(logits))
值 pos_weight > 1 会减少假阴性计数,从而增加召回率。相反,设置 pos_weight < 1 会减少误报计数并提高精度。这可以从 pos_weight 被引入作为损失表达式中正标签项的乘法系数的事实中看出:
labels * -log(sigmoid(logits)) * pos_weight +
(1 - labels) * -log(1 - sigmoid(logits))
为简洁起见,让 x = logits , z = labels , q = pos_weight 。损失是:
qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
设置 l = (1 + (q - 1) * z) ,为了保证稳定性和避免溢出,实现使用
(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
logits 和labels 必须具有相同的类型和形状。
labels = tf.constant([1., 0.5, 0.])
logits = tf.constant([1.5, -0.1, -10.])
tf.nn.weighted_cross_entropy_with_logits(
labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
tf.nn.weighted_cross_entropy_with_logits(
labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)
相关用法
- Python tf.nn.embedding_lookup_sparse用法及代码示例
- Python tf.nn.RNNCellResidualWrapper.set_weights用法及代码示例
- Python tf.nn.dropout用法及代码示例
- Python tf.nn.gelu用法及代码示例
- Python tf.nn.RNNCellDeviceWrapper.set_weights用法及代码示例
- Python tf.nn.embedding_lookup用法及代码示例
- Python tf.nn.RNNCellDeviceWrapper.get_weights用法及代码示例
- Python tf.nn.local_response_normalization用法及代码示例
- Python tf.nn.scale_regularization_loss用法及代码示例
- Python tf.nn.RNNCellResidualWrapper.add_loss用法及代码示例
- Python tf.nn.max_pool用法及代码示例
- Python tf.nn.RNNCellDropoutWrapper.set_weights用法及代码示例
- Python tf.nn.l2_loss用法及代码示例
- Python tf.nn.log_softmax用法及代码示例
- Python tf.nn.ctc_greedy_decoder用法及代码示例
- Python tf.nn.dilation2d用法及代码示例
- Python tf.nn.RNNCellResidualWrapper.get_weights用法及代码示例
- Python tf.nn.compute_average_loss用法及代码示例
- Python tf.nn.RNNCellDeviceWrapper用法及代码示例
- Python tf.nn.atrous_conv2d用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.nn.weighted_cross_entropy_with_logits。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
