当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.losses.mean_squared_error用法及代码示例


在训练过程中添加 Sum-of-Squares 损失。

用法

tf.compat.v1.losses.mean_squared_error(
    labels, predictions, weights=1.0, scope=None,
    loss_collection=tf.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

参数

  • labels 地面实况输出张量,与'predictions' 的维度相同。
  • predictions 预测的输出。
  • weights 可选Tensor,其秩为0,或与labels相同的秩,并且必须可广播到labels(即,所有维度必须是1,或与相应的losses维度相同) .
  • scope 在计算损失时执行的操作的范围。
  • loss_collection 将添加损失的集合。
  • reduction 适用于损失的减免类型。

返回

  • 加权损失浮点数 Tensor 。如果 reductionNONE ,则其形状与 labels 相同;否则,它是标量。

抛出

  • ValueError 如果predictions 的形状与labels 的形状不匹配或weights 的形状无效。此外,如果 labelspredictions 为无。

迁移到 TF2

警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。

tf.compat.v1.losses.mean_squared_error 主要与即刻执行和 tf.function 兼容。但是,loss_collection 参数在即刻执行时会被忽略,并且不会将任何损失写入损失集合。您将需要手动保留返回值或依赖 tf.keras.Model 损失跟踪。

要切换到本机 TF2 样式,请实例化 tf.keras.losses.MeanSquaredError 类并改为调用该对象。

到原生 TF2 的结构映射

前:

loss = tf.compat.v1.losses.mean_squared_error(
  labels=labels,
  predictions=predictions,
  weights=weights,
  reduction=reduction)

后:

loss_fn = tf.keras.losses.MeanSquaredError(
  reduction=reduction)
loss = loss_fn(
  y_true=labels,
  y_pred=predictions,
  sample_weight=weights)

如何映射参数

TF1 参数名称 TF2 参数名称 注意
labels y_true __call__() 方法中
predictions y_pred __call__() 方法中
weights sample_weight __call__() 方法中。 sample_weight 的形状要求与 weights 不同。请检查参数定义以获取详细信息。
scope 不支持 -
loss_collection 不支持 应显式或使用 Keras API 跟踪损失,例如add_loss,而不是通过集合
reduction reduction 在构造函数中。 tf.compat.v1.losses.softmax_cross_entropy中的tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZEtf.compat.v1.losses.Reduction.SUMtf.compat.v1.losses.Reduction.NONE的值分别对应于tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZEtf.keras.losses.Reduction.SUMtf.keras.losses.Reduction.NONE。如果您为 reduction 使用其他值,包括默认值 tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS ,则没有直接对应的值。请手动修改损失实现。

使用示例之前和之后

前:

y_true = [1, 2, 3]
y_pred = [1, 3, 5]
weights = [0, 1, 0.25]
# samples with zero-weight are excluded from calculation when `reduction`
# argument is set to default value `Reduction.SUM_BY_NONZERO_WEIGHTS`
tf.compat.v1.losses.mean_squared_error(
   labels=y_true,
   predictions=y_pred,
   weights=weights).numpy()
1.0
tf.compat.v1.losses.mean_squared_error(
   labels=y_true,
   predictions=y_pred,
   weights=weights,
   reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE).numpy()
0.66667

后:

y_true = [[1.0], [2.0], [3.0]]
y_pred = [[1.0], [3.0], [5.0]]
weights = [1, 1, 0.25]
mse = tf.keras.losses.MeanSquaredError(
   reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
mse(y_true=y_true, y_pred=y_pred, sample_weight=weights).numpy()
0.66667

weights 作为损失的系数。如果提供了标量,则损失只是按给定值缩放。如果 weights 是大小为 [batch_size] 的张量,则批次的每个样本的总损失将由 weights 向量中的相应元素重新缩放。如果 weights 的形状与 predictions 的形状匹配,则 predictions 的每个可测量元素的损失将按 weights 的相应值进行缩放。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.losses.mean_squared_error。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。