檢查點每 N 步或秒輸入管道狀態。
繼承自:SessionRunHook
用法
tf.data.experimental.CheckpointInputPipelineHook(
estimator, external_state_policy=None
)
參數
-
estimator
估計器。 -
external_state_policy
一個字符串,用於標識如何處理依賴於外部狀態的輸入管道。可能的值為'ignore':外部狀態被靜默忽略。 'warn':外部狀態被忽略,記錄警告。 'fail':遇到外部狀態操作失敗。默認情況下,我們將其設置為'fail'。
拋出
-
ValueError
應設置save_steps
或save_secs
之一。 -
ValueError
最多隻能設置一個 saver 或scaffold 。 -
ValueError
如果external_state_policy
不是 'warn'、'ignore' 或 'fail' 之一。
該鉤子將迭代器的狀態保存在Graph
中,以便在恢複訓練時輸入管道從中斷處繼續。這可以潛在地避免某些管道中的過度擬合,其中每個評估的訓練步驟數與數據集大小相比較小,或者如果訓練管道被搶占。
與 CheckpointSaverHook
的區別:
- 僅保存 "iterators" 集合中的輸入管道,而不保存全局變量或其他可保存對象。
- 不將
GraphDef
和MetaGraphDef
寫入摘要。
檢查點訓練管道的示例:
est = tf.estimator.Estimator(model_fn)
while True:
est.train(
train_input_fn,
hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
steps=train_steps_per_eval)
# Note:We do not pass the hook here.
metrics = est.evaluate(eval_input_fn)
if should_stop_the_training(metrics):
break
如果需要將輸入管道狀態與模型檢查點分開保存,則應使用此掛鉤。出於以下幾個原因,這樣做可能很有用:
- 輸入管道檢查點可能很大,例如,如果有大的 shuffle 或預取緩衝區,並且可能會使檢查點大小膨脹。
- 如果輸入管道在訓練和驗證之間共享,則在驗證期間恢複檢查點可能會覆蓋驗證輸入管道。
為了將輸入管道檢查點與模型權重一起保存,請直接使用 tf.data.experimental.make_saveable_from_iterator
創建 SaveableObject
並添加到 SAVEABLE_OBJECTS
集合中。但是請注意,您需要注意不要在評估期間恢複訓練迭代器。您可以通過在構建評估圖時不將迭代器添加到 SAVEABLE_OBJECTS Collector來做到這一點。
相關用法
- Python tf.data.experimental.Counter用法及代碼示例
- Python tf.data.experimental.CsvDataset.window用法及代碼示例
- Python tf.data.experimental.CsvDataset.apply用法及代碼示例
- Python tf.data.experimental.CsvDataset.flat_map用法及代碼示例
- Python tf.data.experimental.CsvDataset.random用法及代碼示例
- Python tf.data.experimental.CsvDataset.cardinality用法及代碼示例
- Python tf.data.experimental.CsvDataset.interleave用法及代碼示例
- Python tf.data.experimental.CsvDataset.group_by_window用法及代碼示例
- Python tf.data.experimental.CsvDataset.as_numpy_iterator用法及代碼示例
- Python tf.data.experimental.CsvDataset.from_generator用法及代碼示例
- Python tf.data.experimental.CsvDataset.range用法及代碼示例
- Python tf.data.experimental.CsvDataset.unique用法及代碼示例
- Python tf.data.experimental.CsvDataset.shard用法及代碼示例
- Python tf.data.experimental.CsvDataset.choose_from_datasets用法及代碼示例
- Python tf.data.experimental.CsvDataset.batch用法及代碼示例
- Python tf.data.experimental.CsvDataset用法及代碼示例
- Python tf.data.experimental.CsvDataset.enumerate用法及代碼示例
- Python tf.data.experimental.CsvDataset.from_tensors用法及代碼示例
- Python tf.data.experimental.CsvDataset.bucket_by_sequence_length用法及代碼示例
- Python tf.data.experimental.CsvDataset.padded_batch用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.CheckpointInputPipelineHook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。