當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.callbacks.Callback用法及代碼示例


用於構建新回調的抽象基類。

用法

tf.keras.callbacks.Callback()

屬性

  • params 字典。訓練參數(例如詳細程度、批量大小、時期數......)。
  • model keras.models.Model 的實例。正在訓練的模型的參考。

回調可以傳遞給諸如 fit , evaluatepredict 等 keras 方法,以連接到模型訓練和推理生命周期的各個階段。

要創建自定義回調,請繼承 keras.callbacks.Callback 並覆蓋與感興趣的階段關聯的方法。有關更多信息,請參閱 https://www.tensorflow.org/guide/keras/custom_callback。

例子:

training_finished = False
class MyCallback(tf.keras.callbacks.Callback):
  def on_train_end(self, logs=None):
    global training_finished
    training_finished = True
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mean_squared_error')
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
          callbacks=[MyCallback()])
assert training_finished == True

如果您想在自定義訓練循環中使用 Callback 對象:

  1. 您應該將所有回調打包到一個 callbacks.CallbackList 中,以便可以一起調用它們。
  2. 您將需要在循環中的適當位置手動調用所有on_* 方法。像這樣:

    callbacks =  tf.keras.callbacks.CallbackList([...])
    callbacks.append(...)
    
    callbacks.on_train_begin(...)
    for epoch in range(EPOCHS):
      callbacks.on_epoch_begin(epoch)
      for i, data in dataset.enumerate():
        callbacks.on_train_batch_begin(i)
        batch_logs = model.train_step(data)
        callbacks.on_train_batch_end(i, batch_logs)
      epoch_logs = ...
      callbacks.on_epoch_end(epoch, epoch_logs)
    final_logs=...
    callbacks.on_train_end(final_logs)

    回調方法作為參數的logs 字典將包含與當前批次或時期相關的數量的鍵(參見method-specific 文檔字符串)。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.callbacks.Callback。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。