当监控的指标停止改进时停止训练。
继承自:Callback
用法
tf.keras.callbacks.EarlyStopping(
monitor='val_loss', min_delta=0, patience=0, verbose=0,
mode='auto', baseline=None, restore_best_weights=False
)
参数
-
monitor
要监控的数量。 -
min_delta
被监测数量的最小变化被认为是改进,即小于min_delta的绝对变化,将被视为没有改进。 -
patience
训练停止后没有改善的 epoch 数。 -
verbose
详细模式。 -
mode
{"auto", "min", "max"}
之一。min
模式下,当监测的数量停止减少时,训练将停止;在"max"
模式下,当监控的数量停止增加时它会停止;在"auto"
模式下,根据监控量的名称自动推断方向。 -
baseline
监控数量的基线值。如果模型没有显示出对基线的改进,则训练将停止。 -
restore_best_weights
是否从监测量的最佳值的epoch恢复模型权重。如果为 False,则使用在训练的最后一步获得的模型权重。无论相对于baseline
的性能如何,都会恢复一个纪元。如果在baseline
上没有任何 epoch 得到改进,则将针对patience
epoch 运行训练,并从该集合中的最佳 epoch 恢复权重。
假设训练的目标是最小化损失。这样,要监控的指标将是 'loss'
,模式将是 'min'
。 model.fit()
训练循环将在每个 epoch 结束时检查损失是否不再减少,如果适用的话,考虑 min_delta
和 patience
。一旦发现它不再减少,model.stop_training
将被标记为 True 并且训练终止。
要监控的数量需要在logs
dict 中可用。为此,请在 model.compile()
传递损失或指标。
例子:
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# This callback will stop the training when there is no improvement in
# the loss for three consecutive epochs.
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
epochs=10, batch_size=1, callbacks=[callback],
verbose=0)
len(history.history['loss']) # Only 4 epochs are run.
4
相关用法
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.callbacks.CSVLogger用法及代码示例
- Python tf.keras.callbacks.TensorBoard用法及代码示例
- Python tf.keras.callbacks.Callback用法及代码示例
- Python tf.keras.callbacks.ModelCheckpoint用法及代码示例
- Python tf.keras.callbacks.LambdaCallback用法及代码示例
- Python tf.keras.callbacks.BackupAndRestore用法及代码示例
- Python tf.keras.callbacks.LearningRateScheduler用法及代码示例
- Python tf.keras.callbacks.History用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.EarlyStopping。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。