使用 tf.nn.softmax_cross_entropy_with_logits_v2 創建 cross-entropy 損失。
用法
tf.compat.v1.losses.softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=tf.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)
參數
-
onehot_labels
One-hot-encoded 標簽。 -
logits
Logits 網絡的輸出。 -
weights
可廣播丟失的可選Tensor
。 -
label_smoothing
如果大於 0,則平滑標簽。 -
scope
在計算損失時執行的操作的範圍。 -
loss_collection
將添加損失的集合。 -
reduction
適用於損失的減免類型。
返回
-
加權損失
Tensor
與logits
類型相同。如果reduction
是NONE
,則其形狀為[batch_size]
;否則,它是標量。
拋出
-
ValueError
如果logits
的形狀與onehot_labels
的形狀不匹配,或者如果weights
的形狀無效,或者如果weights
為 None。此外,如果onehot_labels
或logits
為無。
遷移到 TF2
警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。
tf.compat.v1.losses.softmax_cross_entropy
主要與即刻執行和 tf.function
兼容。但是,loss_collection
參數在即刻執行時會被忽略,並且不會將任何損失寫入損失集合。您將需要手動保留返回值或依賴 tf.keras.Model
損失跟蹤。
要切換到本機 TF2 樣式,請實例化 tf.keras.losses.CategoricalCrossentropy
類,並將 from_logits
設置為 True
,然後調用該對象。
到原生 TF2 的結構映射
前:
loss = tf.compat.v1.losses.softmax_cross_entropy(
onehot_labels=onehot_labels,
logits=logits,
weights=weights,
label_smoothing=smoothing)
後:
loss_fn = tf.keras.losses.CategoricalCrossentropy(
from_logits=True,
label_smoothing=smoothing)
loss = loss_fn(
y_true=onehot_labels,
y_pred=logits,
sample_weight=weights)
如何映射參數
TF1 參數名稱 | TF2 參數名稱 | 注意 |
---|---|---|
- | from_logits
|
將 from_logits 設置為 True 以具有相同的行為 |
onehot_labels |
y_true |
在__call__() 方法中 |
logits |
y_pred |
在__call__() 方法中 |
weights |
sample_weight |
在__call__() 方法中 |
label_smoothing |
label_smoothing |
在構造函數中 |
scope |
不支持 | - |
loss_collection
|
不支持 | 應顯式或使用 Keras API 跟蹤損失,例如add_loss,而不是通過集合 |
reduction
|
reduction
|
在構造函數中。 tf.compat.v1.losses.softmax_cross_entropy 中的tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE 、tf.compat.v1.losses.Reduction.SUM 、tf.compat.v1.losses.Reduction.NONE 的值分別對應於tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE 、tf.keras.losses.Reduction.SUM 、tf.keras.losses.Reduction.NONE 。如果您為 reduction 使用其他值,包括默認值 tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS ,則沒有直接對應的值。請手動修改損失實現。 |
使用示例之前和之後
前:
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
weights = [0.3, 0.7]
smoothing = 0.2
tf.compat.v1.losses.softmax_cross_entropy(y_true, y_pred, weights=weights,
label_smoothing=smoothing).numpy()
0.57618
後:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True,
label_smoothing=smoothing)
cce(y_true, y_pred, sample_weight=weights).numpy()
0.57618
weights
作為損失的係數。如果提供了標量,則損失隻是按給定值縮放。如果 weights
是形狀為 [batch_size]
的張量,則損失權重適用於每個相應的樣本。
如果 label_smoothing
不為零,則將標簽平滑到 1/num_classes:new_onehot_labels = onehot_labels * (1 - label_smoothing)
+ label_smoothing / num_classes
請注意,onehot_labels
和 logits
必須具有相同的形狀,例如[batch_size, num_classes]
。 weights
的 shape 必須對 loss 是可廣播的,loss 的 shape 由 logits
的 shape 決定。如果 logits
的形狀是 [batch_size, num_classes]
,損失是形狀 [batch_size]
的 Tensor
。
相關用法
- Python tf.compat.v1.losses.sigmoid_cross_entropy用法及代碼示例
- Python tf.compat.v1.losses.mean_squared_error用法及代碼示例
- Python tf.compat.v1.losses.huber_loss用法及代碼示例
- Python tf.compat.v1.lookup.StaticHashTable用法及代碼示例
- Python tf.compat.v1.lookup.StaticVocabularyTable用法及代碼示例
- Python tf.compat.v1.layers.conv3d用法及代碼示例
- Python tf.compat.v1.layers.Conv3D用法及代碼示例
- Python tf.compat.v1.layers.dense用法及代碼示例
- Python tf.compat.v1.layers.AveragePooling3D用法及代碼示例
- Python tf.compat.v1.lite.TFLiteConverter用法及代碼示例
- Python tf.compat.v1.layers.Conv2DTranspose用法及代碼示例
- Python tf.compat.v1.layers.max_pooling3d用法及代碼示例
- Python tf.compat.v1.layers.average_pooling1d用法及代碼示例
- Python tf.compat.v1.layers.experimental.keras_style_scope用法及代碼示例
- Python tf.compat.v1.layers.flatten用法及代碼示例
- Python tf.compat.v1.layers.conv1d用法及代碼示例
- Python tf.compat.v1.layers.experimental.set_keras_style用法及代碼示例
- Python tf.compat.v1.layers.conv2d_transpose用法及代碼示例
- Python tf.compat.v1.layers.dropout用法及代碼示例
- Python tf.compat.v1.layers.batch_normalization用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.losses.softmax_cross_entropy。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。