检查点每 N 步或秒输入管道状态。
继承自:SessionRunHook
用法
tf.data.experimental.CheckpointInputPipelineHook(
estimator, external_state_policy=None
)
参数
-
estimator
估计器。 -
external_state_policy
一个字符串,用于标识如何处理依赖于外部状态的输入管道。可能的值为'ignore':外部状态被静默忽略。 'warn':外部状态被忽略,记录警告。 'fail':遇到外部状态操作失败。默认情况下,我们将其设置为'fail'。
抛出
-
ValueError
应设置save_steps
或save_secs
之一。 -
ValueError
最多只能设置一个 saver 或scaffold 。 -
ValueError
如果external_state_policy
不是 'warn'、'ignore' 或 'fail' 之一。
该钩子将迭代器的状态保存在Graph
中,以便在恢复训练时输入管道从中断处继续。这可以潜在地避免某些管道中的过度拟合,其中每个评估的训练步骤数与数据集大小相比较小,或者如果训练管道被抢占。
与 CheckpointSaverHook
的区别:
- 仅保存 "iterators" 集合中的输入管道,而不保存全局变量或其他可保存对象。
- 不将
GraphDef
和MetaGraphDef
写入摘要。
检查点训练管道的示例:
est = tf.estimator.Estimator(model_fn)
while True:
est.train(
train_input_fn,
hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
steps=train_steps_per_eval)
# Note:We do not pass the hook here.
metrics = est.evaluate(eval_input_fn)
if should_stop_the_training(metrics):
break
如果需要将输入管道状态与模型检查点分开保存,则应使用此挂钩。出于以下几个原因,这样做可能很有用:
- 输入管道检查点可能很大,例如,如果有大的 shuffle 或预取缓冲区,并且可能会使检查点大小膨胀。
- 如果输入管道在训练和验证之间共享,则在验证期间恢复检查点可能会覆盖验证输入管道。
为了将输入管道检查点与模型权重一起保存,请直接使用 tf.data.experimental.make_saveable_from_iterator
创建 SaveableObject
并添加到 SAVEABLE_OBJECTS
集合中。但是请注意,您需要注意不要在评估期间恢复训练迭代器。您可以通过在构建评估图时不将迭代器添加到 SAVEABLE_OBJECTS Collector来做到这一点。
相关用法
- Python tf.data.experimental.Counter用法及代码示例
- Python tf.data.experimental.CsvDataset.window用法及代码示例
- Python tf.data.experimental.CsvDataset.apply用法及代码示例
- Python tf.data.experimental.CsvDataset.flat_map用法及代码示例
- Python tf.data.experimental.CsvDataset.random用法及代码示例
- Python tf.data.experimental.CsvDataset.cardinality用法及代码示例
- Python tf.data.experimental.CsvDataset.interleave用法及代码示例
- Python tf.data.experimental.CsvDataset.group_by_window用法及代码示例
- Python tf.data.experimental.CsvDataset.as_numpy_iterator用法及代码示例
- Python tf.data.experimental.CsvDataset.from_generator用法及代码示例
- Python tf.data.experimental.CsvDataset.range用法及代码示例
- Python tf.data.experimental.CsvDataset.unique用法及代码示例
- Python tf.data.experimental.CsvDataset.shard用法及代码示例
- Python tf.data.experimental.CsvDataset.choose_from_datasets用法及代码示例
- Python tf.data.experimental.CsvDataset.batch用法及代码示例
- Python tf.data.experimental.CsvDataset用法及代码示例
- Python tf.data.experimental.CsvDataset.enumerate用法及代码示例
- Python tf.data.experimental.CsvDataset.from_tensors用法及代码示例
- Python tf.data.experimental.CsvDataset.bucket_by_sequence_length用法及代码示例
- Python tf.data.experimental.CsvDataset.padded_batch用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.experimental.CheckpointInputPipelineHook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。