當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.callbacks.ModelCheckpoint用法及代碼示例


以某種頻率保存 Keras 模型或模型權重的回調。

繼承自:Callback

用法

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, initial_value_threshold=None, **kwargs
)

參數

  • filepath string 或 PathLike ,保存模型文件的路徑。例如文件路徑 = os.path.join(working_dir, 'ckpt', file_name)。 filepath 可以包含命名格式選項,這些選項將填充 epoch 的值和 logs 中的鍵(在 on_epoch_end 中傳遞)。例如:如果 filepathweights.{epoch:02d}-{val_loss:.2f}.hdf5 ,則模型檢查點將與文件名中的紀元號和驗證損失一起保存。文件路徑的目錄不應被任何其他回調重用以避免衝突。
  • monitor 要監控的指標名稱。通常,指標由tf.keras.Model.compile方法。筆記:
    • 在名稱前加上 "val_ " 以監控驗證指標。
    • 使用"loss" 或“val_loss”來監控模型的總損失。
    • 如果您將指標指定為字符串,例如 "accuracy" ,請傳遞相同的字符串(帶或不帶 "val_" 前綴)。

    • 如果傳遞metrics.Metric 對象,monitor 應設置為metric.name

    • 如果您不確定指標名稱,您可以檢查 history = model.fit() 返回的 history.history 字典的內容

    • Multi-output 模型在指標名稱上設置額外的前綴。

  • verbose 詳細模式,0 或 1。
  • save_best_only 如果是save_best_only=True,則僅在模型被認為是"best"時保存,並且不會覆蓋根據監控數量的最新最佳模型。如果 filepath 不包含像 {epoch} 這樣的格式化選項,那麽 filepath 將被每個新的更好的模型覆蓋。
  • mode

    {'auto'、'min'、'max'} 之一。如果 save_best_only=True ,則根據監控數量的最大化或最小化來決定是否覆蓋當前保存文件。對於 val_acc ,這應該是 max ,對於 val_loss 這應該是 min ,等等。

    auto 模式下,如果監控的數量為'acc' 或以'fmeasure' 開頭,則模式設置為max,其餘數量設置為min

  • save_weights_only 如果為 True,則僅保存模型的權重(model.save_weights(filepath)),否則保存完整模型(model.save(filepath))。
  • save_freq 'epoch' 或整數。使用 'epoch' 時,回調會在每個 epoch 後保存模型。使用整數時,回調將模型保存在這麽多批次的末尾。如果 Model 是用 steps_per_execution=N 編譯的,那麽將每第 N 批檢查一次保存條件。請注意,如果保存與 epoch 不一致,則監控的指標可能不太可靠(它可能隻反映 1 個批次,因為每個 epoch 都會重置指標)。默認為 'epoch'
  • options 如果save_weights_only 為真,則可選tf.train.CheckpointOptions 對象;如果save_weights_only 為假,則可選tf.saved_model.SaveOptions 對象。
  • initial_value_threshold 要監控的指標的浮點初始 "best" 值。僅適用於 save_best_value=True 。如果當前模型的性能優於此值,則僅覆蓋已保存的模型權重。
  • **kwargs 向後兼容的附加參數。可能的鍵是 period

ModelCheckpoint 回調與使用model.fit() 的訓練結合使用以在某個時間間隔保存模型或權重(在檢查點文件中),因此可以稍後加載模型或權重以從保存的狀態繼續訓練。

此回調提供的一些選項包括:

  • 是隻保留目前已經達到"best performance"的模型,還是不管性能如何在每個epoch結束時保存模型。
  • 'best'的定義;要監控的數量以及是否應該最大化或最小化。
  • 它應該保存的頻率。目前,回調支持在每個 epoch 結束時保存,或者在固定數量的訓練批次後保存。
  • 是隻保存權重,還是保存整個模型。

注意:如果您得到WARNING:tensorflow:Can save best model only with <name> available, skipping,請參閱monitor 參數的說明,了解有關如何正確執行此操作的詳細信息。

例子:

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.callbacks.ModelCheckpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。