當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.keras.metrics.binary_focal_crossentropy用法及代碼示例

計算二元焦點交叉熵損失。

用法

tf.keras.metrics.binary_focal_crossentropy(
    y_true, y_pred, gamma=2.0, from_logits=False, label_smoothing=0.0, axis=-1
)

參數

  • y_true 地麵真值,形狀為 (batch_size, d0, .. dN)
  • y_pred 形狀為 (batch_size, d0, .. dN) 的預測值。
  • gamma 一個對焦參數,默認是參考中提到的2.0
  • from_logits y_pred 是否預期為 logits 張量。默認情況下,我們假設 y_pred 對概率分布進行編碼。
  • label_smoothing 浮點數在 [0, 1] 中。如果高於 0,則通過將標簽向 0.5 擠壓來平滑標簽,即,對目標類使用 1. - 0.5 * label_smoothing,對非目標類使用 0.5 * label_smoothing
  • axis 計算平均值的軸。默認為 -1

返回

  • 二元焦點交叉熵損失值。形狀 = [batch_size, d0, .. dN-1]

根據 Lin 等人,2018 年的說法,它有助於將焦點因子應用於down-weight 簡單示例並更多地關注困難示例。默認情況下,焦點張量計算如下:

focal_factor = (1 - output)**gamma 用於 1 類 focal_factor = output**gamma 用於 0 類,其中 gamma 是聚焦參數。當gamma = 0 時,這個函數相當於二元交叉熵損失。

單機使用:

y_true = [[0, 1], [0, 0]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
loss = tf.keras.losses.binary_focal_crossentropy(y_true, y_pred, gamma=2)
assert loss.shape == (2,)
loss.numpy()
array([0.330, 0.206], dtype=float32)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.metrics.binary_focal_crossentropy。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。