当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.estimator.experimental.stop_if_no_decrease_hook用法及代码示例


如果指标在给定的最大步长内没有减少,则创建挂钩以停止。

用法

tf.estimator.experimental.stop_if_no_decrease_hook(
    estimator, metric_name, max_steps_without_decrease, eval_dir=None, min_steps=0,
    run_every_secs=60, run_every_steps=None
)

参数

  • estimator tf.estimator.Estimator 实例。
  • metric_name str ,要跟踪的指标。 "loss"、"accuracy"等
  • max_steps_without_decrease int ,在给定指标没有减少的情况下,最大训练步数。
  • eval_dir 如果设置,则包含带有评估指标的摘要文件的目录。默认情况下,将使用estimator.eval_dir()
  • min_steps int ,如果全局步长小于此值,则永远不会请求停止。默认为 0。
  • run_every_secs 如果指定,则以 run_every_secs 秒的间隔调用 should_stop_fn。默认为 60 秒。必须设置这个或run_every_steps
  • run_every_steps 如果指定,则每 run_every_steps 步骤调用 should_stop_fn。必须设置这个或run_every_secs

返回

  • SessionRunHook 类型的 early-stopping 钩子,它定期检查给定的指标是否在给定的最大训练步数上没有减少,如果为真则启动提前停止。

使用示例:

estimator = ...
# Hook to stop training if loss does not decrease in over 100000 steps.
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

警告:当前实现支持 early-stopping 在本地模式下进行训练和评估。在分布式模式下,可以停止训练,但评估(这是一项单独的工作)将无限期地等待新模型检查点的评估,因此您将需要其他方法来检测和停止它。分布式模式下的Early-stopping 评估需要更改train_and_evaluate API,并将在未来的修订版中解决。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.experimental.stop_if_no_decrease_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。