用法
compute_loss(
x=None, y=None, y_pred=None, sample_weight=None
)
参数
-
x
输入数据。 -
y
目标数据。 -
y_pred
模型返回的预测(model(x)
的输出) -
sample_weight
用于加权损失函数的样本权重。
返回
-
总损失为
tf.Tensor
,或None
如果没有损失结果(这是由Model.test_step
调用的情况)。
计算总损失,验证并返回。
子类可以选择覆盖此方法以提供自定义损失计算逻辑。
例子:
class MyModel(tf.keras.Model):
def __init__(self, *args, **kwargs):
super(MyModel, self).__init__(*args, **kwargs)
self.loss_tracker = tf.keras.metrics.Mean(name='loss')
def compute_loss(self, x, y, y_pred, sample_weight):
loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y))
loss += tf.add_n(self.losses)
self.loss_tracker.update_state(loss)
return loss
def reset_metrics(self):
self.loss_tracker.reset_states()
@property
def metrics(self):
return [self.loss_tracker]
tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,))
dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1)
inputs = tf.keras.layers.Input(shape=(10,), name='my_input')
outputs = tf.keras.layers.Dense(10)(inputs)
model = MyModel(inputs, outputs)
model.add_loss(tf.reduce_sum(outputs))
optimizer = tf.keras.optimizers.SGD()
model.compile(optimizer, loss='mse', steps_per_execution=10)
model.fit(dataset, epochs=2, steps_per_epoch=10)
print('My custom loss:', model.loss_tracker.result().numpy())
相关用法
- Python tf.keras.Model.compute_metrics用法及代码示例
- Python tf.keras.Model.compile用法及代码示例
- Python tf.keras.Model.reset_metrics用法及代码示例
- Python tf.keras.Model.save_spec用法及代码示例
- Python tf.keras.Model.save用法及代码示例
- Python tf.keras.Model用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代码示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代码示例
- Python tf.keras.layers.Dropout用法及代码示例
- Python tf.keras.activations.softplus用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.Model.compute_loss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。