以某种频率保存 Keras 模型或模型权重的回调。
继承自:Callback
用法
tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, mode='auto', save_freq='epoch',
options=None, initial_value_threshold=None, **kwargs
)
参数
-
filepath
string 或PathLike
,保存模型文件的路径。例如文件路径 = os.path.join(working_dir, 'ckpt', file_name)。filepath
可以包含命名格式选项,这些选项将填充epoch
的值和logs
中的键(在on_epoch_end
中传递)。例如:如果filepath
是weights.{epoch:02d}-{val_loss:.2f}.hdf5
,则模型检查点将与文件名中的纪元号和验证损失一起保存。文件路径的目录不应被任何其他回调重用以避免冲突。 -
monitor
要监控的指标名称。通常,指标由tf.keras.Model.compile方法。笔记:- 在名称前加上
"val_
" 以监控验证指标。 - 使用
"loss"
或“val_loss
”来监控模型的总损失。 如果您将指标指定为字符串,例如
"accuracy"
,请传递相同的字符串(带或不带"val_"
前缀)。如果传递
metrics.Metric
对象,monitor
应设置为metric.name
如果您不确定指标名称,您可以检查
history = model.fit()
返回的history.history
字典的内容Multi-output 模型在指标名称上设置额外的前缀。
- 在名称前加上
-
verbose
详细模式,0 或 1。 -
save_best_only
如果是save_best_only=True
,则仅在模型被认为是"best"时保存,并且不会覆盖根据监控数量的最新最佳模型。如果filepath
不包含像{epoch}
这样的格式化选项,那么filepath
将被每个新的更好的模型覆盖。 -
mode
{'auto'、'min'、'max'} 之一。如果
save_best_only=True
,则根据监控数量的最大化或最小化来决定是否覆盖当前保存文件。对于val_acc
,这应该是max
,对于val_loss
这应该是min
,等等。在
auto
模式下,如果监控的数量为'acc' 或以'fmeasure' 开头,则模式设置为max
,其余数量设置为min
。 -
save_weights_only
如果为 True,则仅保存模型的权重(model.save_weights(filepath)
),否则保存完整模型(model.save(filepath)
)。 -
save_freq
'epoch'
或整数。使用'epoch'
时,回调会在每个 epoch 后保存模型。使用整数时,回调将模型保存在这么多批次的末尾。如果Model
是用steps_per_execution=N
编译的,那么将每第 N 批检查一次保存条件。请注意,如果保存与 epoch 不一致,则监控的指标可能不太可靠(它可能只反映 1 个批次,因为每个 epoch 都会重置指标)。默认为'epoch'
。 -
options
如果save_weights_only
为真,则可选tf.train.CheckpointOptions
对象;如果save_weights_only
为假,则可选tf.saved_model.SaveOptions
对象。 -
initial_value_threshold
要监控的指标的浮点初始 "best" 值。仅适用于save_best_value=True
。如果当前模型的性能优于此值,则仅覆盖已保存的模型权重。 -
**kwargs
向后兼容的附加参数。可能的键是period
。
ModelCheckpoint
回调与使用model.fit()
的训练结合使用以在某个时间间隔保存模型或权重(在检查点文件中),因此可以稍后加载模型或权重以从保存的状态继续训练。
此回调提供的一些选项包括:
- 是只保留目前已经达到"best performance"的模型,还是不管性能如何在每个epoch结束时保存模型。
- 'best'的定义;要监控的数量以及是否应该最大化或最小化。
- 它应该保存的频率。目前,回调支持在每个 epoch 结束时保存,或者在固定数量的训练批次后保存。
- 是只保存权重,还是保存整个模型。
注意:如果您得到WARNING:tensorflow:Can save best model only with <name>
available, skipping
,请参阅monitor
参数的说明,了解有关如何正确执行此操作的详细信息。
例子:
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
相关用法
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.callbacks.EarlyStopping用法及代码示例
- Python tf.keras.callbacks.CSVLogger用法及代码示例
- Python tf.keras.callbacks.TensorBoard用法及代码示例
- Python tf.keras.callbacks.Callback用法及代码示例
- Python tf.keras.callbacks.LambdaCallback用法及代码示例
- Python tf.keras.callbacks.BackupAndRestore用法及代码示例
- Python tf.keras.callbacks.LearningRateScheduler用法及代码示例
- Python tf.keras.callbacks.History用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.ModelCheckpoint。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。