當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.compat.v1.losses.mean_squared_error用法及代碼示例

在訓練過程中添加 Sum-of-Squares 損失。

用法

tf.compat.v1.losses.mean_squared_error(
    labels, predictions, weights=1.0, scope=None,
    loss_collection=tf.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

參數

  • labels 地麵實況輸出張量,與'predictions' 的維度相同。
  • predictions 預測的輸出。
  • weights 可選Tensor,其秩為0,或與labels相同的秩,並且必須可廣播到labels(即,所有維度必須是1,或與相應的losses維度相同) .
  • scope 在計算損失時執行的操作的範圍。
  • loss_collection 將添加損失的集合。
  • reduction 適用於損失的減免類型。

返回

  • 加權損失浮點數 Tensor 。如果 reductionNONE ,則其形狀與 labels 相同;否則,它是標量。

拋出

  • ValueError 如果predictions 的形狀與labels 的形狀不匹配或weights 的形狀無效。此外,如果 labelspredictions 為無。

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

tf.compat.v1.losses.mean_squared_error 主要與即刻執行和 tf.function 兼容。但是,loss_collection 參數在即刻執行時會被忽略,並且不會將任何損失寫入損失集合。您將需要手動保留返回值或依賴 tf.keras.Model 損失跟蹤。

要切換到本機 TF2 樣式,請實例化 tf.keras.losses.MeanSquaredError 類並改為調用該對象。

到原生 TF2 的結構映射

前:

loss = tf.compat.v1.losses.mean_squared_error(
  labels=labels,
  predictions=predictions,
  weights=weights,
  reduction=reduction)

後:

loss_fn = tf.keras.losses.MeanSquaredError(
  reduction=reduction)
loss = loss_fn(
  y_true=labels,
  y_pred=predictions,
  sample_weight=weights)

如何映射參數

TF1 參數名稱 TF2 參數名稱 注意
labels y_true __call__() 方法中
predictions y_pred __call__() 方法中
weights sample_weight __call__() 方法中。 sample_weight 的形狀要求與 weights 不同。請檢查參數定義以獲取詳細信息。
scope 不支持 -
loss_collection 不支持 應顯式或使用 Keras API 跟蹤損失,例如add_loss,而不是通過集合
reduction reduction 在構造函數中。 tf.compat.v1.losses.softmax_cross_entropy中的tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZEtf.compat.v1.losses.Reduction.SUMtf.compat.v1.losses.Reduction.NONE的值分別對應於tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZEtf.keras.losses.Reduction.SUMtf.keras.losses.Reduction.NONE。如果您為 reduction 使用其他值,包括默認值 tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS ,則沒有直接對應的值。請手動修改損失實現。

使用示例之前和之後

前:

y_true = [1, 2, 3]
y_pred = [1, 3, 5]
weights = [0, 1, 0.25]
# samples with zero-weight are excluded from calculation when `reduction`
# argument is set to default value `Reduction.SUM_BY_NONZERO_WEIGHTS`
tf.compat.v1.losses.mean_squared_error(
   labels=y_true,
   predictions=y_pred,
   weights=weights).numpy()
1.0
tf.compat.v1.losses.mean_squared_error(
   labels=y_true,
   predictions=y_pred,
   weights=weights,
   reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE).numpy()
0.66667

後:

y_true = [[1.0], [2.0], [3.0]]
y_pred = [[1.0], [3.0], [5.0]]
weights = [1, 1, 0.25]
mse = tf.keras.losses.MeanSquaredError(
   reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
mse(y_true=y_true, y_pred=y_pred, sample_weight=weights).numpy()
0.66667

weights 作為損失的係數。如果提供了標量,則損失隻是按給定值縮放。如果 weights 是大小為 [batch_size] 的張量,則批次的每個樣本的總損失將由 weights 向量中的相應元素重新縮放。如果 weights 的形狀與 predictions 的形狀匹配,則 predictions 的每個可測量元素的損失將按 weights 的相應值進行縮放。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.losses.mean_squared_error。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。