用于构建新回调的抽象基类。
用法
tf.keras.callbacks.Callback()
回调可以传递给诸如 fit
, evaluate
和 predict
等 keras 方法,以连接到模型训练和推理生命周期的各个阶段。
要创建自定义回调,请继承 keras.callbacks.Callback
并覆盖与感兴趣的阶段关联的方法。有关更多信息,请参阅 https://www.tensorflow.org/guide/keras/custom_callback。
例子:
training_finished = False
class MyCallback(tf.keras.callbacks.Callback):
def on_train_end(self, logs=None):
global training_finished
training_finished = True
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mean_squared_error')
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
callbacks=[MyCallback()])
assert training_finished == True
如果您想在自定义训练循环中使用 Callback
对象:
- 您应该将所有回调打包到一个
callbacks.CallbackList
中,以便可以一起调用它们。 您将需要在循环中的适当位置手动调用所有
on_*
方法。像这样:callbacks = tf.keras.callbacks.CallbackList([...]) callbacks.append(...) callbacks.on_train_begin(...) for epoch in range(EPOCHS): callbacks.on_epoch_begin(epoch) for i, data in dataset.enumerate(): callbacks.on_train_batch_begin(i) batch_logs = model.train_step(data) callbacks.on_train_batch_end(i, batch_logs) epoch_logs = ... callbacks.on_epoch_end(epoch, epoch_logs) final_logs=... callbacks.on_train_end(final_logs)
回调方法作为参数的
logs
字典将包含与当前批次或时期相关的数量的键(参见method-specific 文档字符串)。
相关用法
- Python tf.keras.callbacks.CSVLogger用法及代码示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.callbacks.EarlyStopping用法及代码示例
- Python tf.keras.callbacks.TensorBoard用法及代码示例
- Python tf.keras.callbacks.ModelCheckpoint用法及代码示例
- Python tf.keras.callbacks.LambdaCallback用法及代码示例
- Python tf.keras.callbacks.BackupAndRestore用法及代码示例
- Python tf.keras.callbacks.LearningRateScheduler用法及代码示例
- Python tf.keras.callbacks.History用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.Callback。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。