当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.estimator.experimental.stop_if_lower_hook用法及代码示例


如果给定指标低于阈值,则创建挂钩以停止。

用法

tf.estimator.experimental.stop_if_lower_hook(
    estimator, metric_name, threshold, eval_dir=None, min_steps=0,
    run_every_secs=60, run_every_steps=None
)

参数

  • estimator tf.estimator.Estimator 实例。
  • metric_name str ,要跟踪的指标。 "loss"、"accuracy"等
  • threshold 给定指标的数值阈值。
  • eval_dir 如果设置,则包含带有评估指标的摘要文件的目录。默认情况下,将使用estimator.eval_dir()
  • min_steps int ,如果全局步长小于此值,则永远不会请求停止。默认为 0。
  • run_every_secs 如果指定,则以 run_every_secs 秒的间隔调用 should_stop_fn。默认为 60 秒。必须设置这个或run_every_steps
  • run_every_steps 如果指定,则每 run_every_steps 步骤调用 should_stop_fn。必须设置这个或run_every_secs

返回

  • SessionRunHook 类型的 early-stopping 挂钩,定期检查给定指标是否低于指定阈值,如果为真则启动提前停止。

使用示例:

estimator = ...
# Hook to stop training if loss becomes lower than 100.
hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

警告:当前实现支持 early-stopping 在本地模式下进行训练和评估。在分布式模式下,可以停止训练,但评估(这是一项单独的工作)将无限期地等待新模型检查点的评估,因此您将需要其他方法来检测和停止它。分布式模式下的Early-stopping 评估需要更改train_and_evaluate API,并将在未来的修订版中解决。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.experimental.stop_if_lower_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。