以某種頻率保存 Keras 模型或模型權重的回調。
繼承自:Callback
用法
tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, mode='auto', save_freq='epoch',
options=None, initial_value_threshold=None, **kwargs
)
參數
-
filepath
string 或PathLike
,保存模型文件的路徑。例如文件路徑 = os.path.join(working_dir, 'ckpt', file_name)。filepath
可以包含命名格式選項,這些選項將填充epoch
的值和logs
中的鍵(在on_epoch_end
中傳遞)。例如:如果filepath
是weights.{epoch:02d}-{val_loss:.2f}.hdf5
,則模型檢查點將與文件名中的紀元號和驗證損失一起保存。文件路徑的目錄不應被任何其他回調重用以避免衝突。 -
monitor
要監控的指標名稱。通常,指標由tf.keras.Model.compile方法。筆記:- 在名稱前加上
"val_
" 以監控驗證指標。 - 使用
"loss"
或“val_loss
”來監控模型的總損失。 如果您將指標指定為字符串,例如
"accuracy"
,請傳遞相同的字符串(帶或不帶"val_"
前綴)。如果傳遞
metrics.Metric
對象,monitor
應設置為metric.name
如果您不確定指標名稱,您可以檢查
history = model.fit()
返回的history.history
字典的內容Multi-output 模型在指標名稱上設置額外的前綴。
- 在名稱前加上
-
verbose
詳細模式,0 或 1。 -
save_best_only
如果是save_best_only=True
,則僅在模型被認為是"best"時保存,並且不會覆蓋根據監控數量的最新最佳模型。如果filepath
不包含像{epoch}
這樣的格式化選項,那麽filepath
將被每個新的更好的模型覆蓋。 -
mode
{'auto'、'min'、'max'} 之一。如果
save_best_only=True
,則根據監控數量的最大化或最小化來決定是否覆蓋當前保存文件。對於val_acc
,這應該是max
,對於val_loss
這應該是min
,等等。在
auto
模式下,如果監控的數量為'acc' 或以'fmeasure' 開頭,則模式設置為max
,其餘數量設置為min
。 -
save_weights_only
如果為 True,則僅保存模型的權重(model.save_weights(filepath)
),否則保存完整模型(model.save(filepath)
)。 -
save_freq
'epoch'
或整數。使用'epoch'
時,回調會在每個 epoch 後保存模型。使用整數時,回調將模型保存在這麽多批次的末尾。如果Model
是用steps_per_execution=N
編譯的,那麽將每第 N 批檢查一次保存條件。請注意,如果保存與 epoch 不一致,則監控的指標可能不太可靠(它可能隻反映 1 個批次,因為每個 epoch 都會重置指標)。默認為'epoch'
。 -
options
如果save_weights_only
為真,則可選tf.train.CheckpointOptions
對象;如果save_weights_only
為假,則可選tf.saved_model.SaveOptions
對象。 -
initial_value_threshold
要監控的指標的浮點初始 "best" 值。僅適用於save_best_value=True
。如果當前模型的性能優於此值,則僅覆蓋已保存的模型權重。 -
**kwargs
向後兼容的附加參數。可能的鍵是period
。
ModelCheckpoint
回調與使用model.fit()
的訓練結合使用以在某個時間間隔保存模型或權重(在檢查點文件中),因此可以稍後加載模型或權重以從保存的狀態繼續訓練。
此回調提供的一些選項包括:
- 是隻保留目前已經達到"best performance"的模型,還是不管性能如何在每個epoch結束時保存模型。
- 'best'的定義;要監控的數量以及是否應該最大化或最小化。
- 它應該保存的頻率。目前,回調支持在每個 epoch 結束時保存,或者在固定數量的訓練批次後保存。
- 是隻保存權重,還是保存整個模型。
注意:如果您得到WARNING:tensorflow:Can save best model only with <name>
available, skipping
,請參閱monitor
參數的說明,了解有關如何正確執行此操作的詳細信息。
例子:
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
相關用法
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代碼示例
- Python tf.keras.callbacks.EarlyStopping用法及代碼示例
- Python tf.keras.callbacks.CSVLogger用法及代碼示例
- Python tf.keras.callbacks.TensorBoard用法及代碼示例
- Python tf.keras.callbacks.Callback用法及代碼示例
- Python tf.keras.callbacks.LambdaCallback用法及代碼示例
- Python tf.keras.callbacks.BackupAndRestore用法及代碼示例
- Python tf.keras.callbacks.LearningRateScheduler用法及代碼示例
- Python tf.keras.callbacks.History用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
- Python tf.keras.layers.InputLayer用法及代碼示例
- Python tf.keras.layers.serialize用法及代碼示例
- Python tf.keras.metrics.Hinge用法及代碼示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代碼示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代碼示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.callbacks.ModelCheckpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。