当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.callbacks.Callback用法及代码示例


用于构建新回调的抽象基类。

用法

tf.keras.callbacks.Callback()

属性

  • params 字典。训练参数(例如详细程度、批量大小、时期数......)。
  • model keras.models.Model 的实例。正在训练的模型的参考。

回调可以传递给诸如 fit , evaluatepredict 等 keras 方法,以连接到模型训练和推理生命周期的各个阶段。

要创建自定义回调,请继承 keras.callbacks.Callback 并覆盖与感兴趣的阶段关联的方法。有关更多信息,请参阅 https://www.tensorflow.org/guide/keras/custom_callback。

例子:

training_finished = False
class MyCallback(tf.keras.callbacks.Callback):
  def on_train_end(self, logs=None):
    global training_finished
    training_finished = True
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mean_squared_error')
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
          callbacks=[MyCallback()])
assert training_finished == True

如果您想在自定义训练循环中使用 Callback 对象:

  1. 您应该将所有回调打包到一个 callbacks.CallbackList 中,以便可以一起调用它们。
  2. 您将需要在循环中的适当位置手动调用所有on_* 方法。像这样:

    callbacks =  tf.keras.callbacks.CallbackList([...])
    callbacks.append(...)
    
    callbacks.on_train_begin(...)
    for epoch in range(EPOCHS):
      callbacks.on_epoch_begin(epoch)
      for i, data in dataset.enumerate():
        callbacks.on_train_batch_begin(i)
        batch_logs = model.train_step(data)
        callbacks.on_train_batch_end(i, batch_logs)
      epoch_logs = ...
      callbacks.on_epoch_end(epoch, epoch_logs)
    final_logs=...
    callbacks.on_train_end(final_logs)

    回调方法作为参数的logs 字典将包含与当前批次或时期相关的数量的键(参见method-specific 文档字符串)。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.Callback。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。