使用 tf.nn.softmax_cross_entropy_with_logits_v2 创建 cross-entropy 损失。
用法
tf.compat.v1.losses.softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=tf.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)
参数
-
onehot_labels
One-hot-encoded 标签。 -
logits
Logits 网络的输出。 -
weights
可广播丢失的可选Tensor
。 -
label_smoothing
如果大于 0,则平滑标签。 -
scope
在计算损失时执行的操作的范围。 -
loss_collection
将添加损失的集合。 -
reduction
适用于损失的减免类型。
返回
-
加权损失
Tensor
与logits
类型相同。如果reduction
是NONE
,则其形状为[batch_size]
;否则,它是标量。
抛出
-
ValueError
如果logits
的形状与onehot_labels
的形状不匹配,或者如果weights
的形状无效,或者如果weights
为 None。此外,如果onehot_labels
或logits
为无。
迁移到 TF2
警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。
tf.compat.v1.losses.softmax_cross_entropy
主要与即刻执行和 tf.function
兼容。但是,loss_collection
参数在即刻执行时会被忽略,并且不会将任何损失写入损失集合。您将需要手动保留返回值或依赖 tf.keras.Model
损失跟踪。
要切换到本机 TF2 样式,请实例化 tf.keras.losses.CategoricalCrossentropy
类,并将 from_logits
设置为 True
,然后调用该对象。
到原生 TF2 的结构映射
前:
loss = tf.compat.v1.losses.softmax_cross_entropy(
onehot_labels=onehot_labels,
logits=logits,
weights=weights,
label_smoothing=smoothing)
后:
loss_fn = tf.keras.losses.CategoricalCrossentropy(
from_logits=True,
label_smoothing=smoothing)
loss = loss_fn(
y_true=onehot_labels,
y_pred=logits,
sample_weight=weights)
如何映射参数
TF1 参数名称 | TF2 参数名称 | 注意 |
---|---|---|
- | from_logits
|
将 from_logits 设置为 True 以具有相同的行为 |
onehot_labels |
y_true |
在__call__() 方法中 |
logits |
y_pred |
在__call__() 方法中 |
weights |
sample_weight |
在__call__() 方法中 |
label_smoothing |
label_smoothing |
在构造函数中 |
scope |
不支持 | - |
loss_collection
|
不支持 | 应显式或使用 Keras API 跟踪损失,例如add_loss,而不是通过集合 |
reduction
|
reduction
|
在构造函数中。 tf.compat.v1.losses.softmax_cross_entropy 中的tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE 、tf.compat.v1.losses.Reduction.SUM 、tf.compat.v1.losses.Reduction.NONE 的值分别对应于tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE 、tf.keras.losses.Reduction.SUM 、tf.keras.losses.Reduction.NONE 。如果您为 reduction 使用其他值,包括默认值 tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS ,则没有直接对应的值。请手动修改损失实现。 |
使用示例之前和之后
前:
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
weights = [0.3, 0.7]
smoothing = 0.2
tf.compat.v1.losses.softmax_cross_entropy(y_true, y_pred, weights=weights,
label_smoothing=smoothing).numpy()
0.57618
后:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True,
label_smoothing=smoothing)
cce(y_true, y_pred, sample_weight=weights).numpy()
0.57618
weights
作为损失的系数。如果提供了标量,则损失只是按给定值缩放。如果 weights
是形状为 [batch_size]
的张量,则损失权重适用于每个相应的样本。
如果 label_smoothing
不为零,则将标签平滑到 1/num_classes:new_onehot_labels = onehot_labels * (1 - label_smoothing)
+ label_smoothing / num_classes
请注意,onehot_labels
和 logits
必须具有相同的形状,例如[batch_size, num_classes]
。 weights
的 shape 必须对 loss 是可广播的,loss 的 shape 由 logits
的 shape 决定。如果 logits
的形状是 [batch_size, num_classes]
,损失是形状 [batch_size]
的 Tensor
。
相关用法
- Python tf.compat.v1.losses.sigmoid_cross_entropy用法及代码示例
- Python tf.compat.v1.losses.mean_squared_error用法及代码示例
- Python tf.compat.v1.losses.huber_loss用法及代码示例
- Python tf.compat.v1.lookup.StaticHashTable用法及代码示例
- Python tf.compat.v1.lookup.StaticVocabularyTable用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.compat.v1.layers.Conv3D用法及代码示例
- Python tf.compat.v1.layers.dense用法及代码示例
- Python tf.compat.v1.layers.AveragePooling3D用法及代码示例
- Python tf.compat.v1.lite.TFLiteConverter用法及代码示例
- Python tf.compat.v1.layers.Conv2DTranspose用法及代码示例
- Python tf.compat.v1.layers.max_pooling3d用法及代码示例
- Python tf.compat.v1.layers.average_pooling1d用法及代码示例
- Python tf.compat.v1.layers.experimental.keras_style_scope用法及代码示例
- Python tf.compat.v1.layers.flatten用法及代码示例
- Python tf.compat.v1.layers.conv1d用法及代码示例
- Python tf.compat.v1.layers.experimental.set_keras_style用法及代码示例
- Python tf.compat.v1.layers.conv2d_transpose用法及代码示例
- Python tf.compat.v1.layers.dropout用法及代码示例
- Python tf.compat.v1.layers.batch_normalization用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.losses.softmax_cross_entropy。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。