當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.nn.weighted_cross_entropy_with_logits用法及代碼示例

計算加權交叉熵。

用法

tf.nn.weighted_cross_entropy_with_logits(
    labels, logits, pos_weight, name=None
)

參數

  • labels logits 具有相同類型和形狀的 Tensor ,其值介於 0 和 1(含)之間。
  • logits Tensor 類型為 float32float64 ,任何實數。
  • pos_weight 用於正樣本的係數,通常是標量,但可以廣播到 logits 的形狀。它的值應該是非負的。
  • name 操作的名稱(可選)。

返回

  • logits 形狀相同的 Tensor,具有分量加權邏輯損失。

拋出

  • ValueError 如果logitslabels 的形狀不同。

這就像 sigmoid_cross_entropy_with_logits() 除了 pos_weight 允許人們通過向上或 down-weighting 相對於負錯誤的正錯誤成本來權衡召回和精度。

通常的cross-entropy 成本定義為:

labels * -log(sigmoid(logits)) +
    (1 - labels) * -log(1 - sigmoid(logits))

pos_weight > 1 會減少假陰性計數,從而增加召回率。相反,設置 pos_weight < 1 會減少誤報計數並提高精度。這可以從 pos_weight 被引入作為損失表達式中正標簽項的乘法係數的事實中看出:

labels * -log(sigmoid(logits)) * pos_weight +
    (1 - labels) * -log(1 - sigmoid(logits))

為簡潔起見,讓 x = logits , z = labels , q = pos_weight 。損失是:

qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

設置 l = (1 + (q - 1) * z) ,為了保證穩定性和避免溢出,實現使用

(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

logitslabels 必須具有相同的類型和形狀。

labels = tf.constant([1., 0.5, 0.])
logits = tf.constant([1.5, -0.1, -10.])
tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.nn.weighted_cross_entropy_with_logits。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。