当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.callbacks.EarlyStopping用法及代码示例


当监控的指标停止改进时停止训练。

继承自:Callback

用法

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0,
    mode='auto', baseline=None, restore_best_weights=False
)

参数

  • monitor 要监控的数量。
  • min_delta 被监测数量的最小变化被认为是改进,即小于min_delta的绝对变化,将被视为没有改进。
  • patience 训练停止后没有改善的 epoch 数。
  • verbose 详细模式。
  • mode {"auto", "min", "max"} 之一。 min模式下,当监测的数量停止减少时,训练将停止;在"max"模式下,当监控的数量停止增加时它会停止;在"auto"模式下,根据监控量的名称自动推断方向。
  • baseline 监控数量的基线值。如果模型没有显示出对基线的改进,则训练将停止。
  • restore_best_weights 是否从监测量的最佳值的epoch恢复模型权重。如果为 False,则使用在训练的最后一步获得的模型权重。无论相对于 baseline 的性能如何,都会恢复一个纪元。如果在 baseline 上没有任何 epoch 得到改进,则将针对 patience epoch 运行训练,并从该集合中的最佳 epoch 恢复权重。

假设训练的目标是最小化损失。这样,要监控的指标将是 'loss' ,模式将是 'min'model.fit() 训练循环将在每个 epoch 结束时检查损失是否不再减少,如果适用的话,考虑 min_deltapatience。一旦发现它不再减少,model.stop_training 将被标记为 True 并且训练终止。

要监控的数量需要在logs dict 中可用。为此,请在 model.compile() 传递损失或指标。

例子:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# This callback will stop the training when there is no improvement in
# the loss for three consecutive epochs.
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                    epochs=10, batch_size=1, callbacks=[callback],
                    verbose=0)
len(history.history['loss'])  # Only 4 epochs are run.
4

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.EarlyStopping。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。