當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.data.experimental.CheckpointInputPipelineHook用法及代碼示例

檢查點每 N 步或秒輸入管道狀態。

繼承自:SessionRunHook

用法

tf.data.experimental.CheckpointInputPipelineHook(
    estimator, external_state_policy=None
)

參數

  • estimator 估計器。
  • external_state_policy 一個字符串,用於標識如何處理依賴於外部狀態的輸入管道。可能的值為'ignore':外部狀態被靜默忽略。 'warn':外部狀態被忽略,記錄警告。 'fail':遇到外部狀態操作失敗。默認情況下,我們將其設置為'fail'。

拋出

  • ValueError 應設置save_stepssave_secs 之一。
  • ValueError 最多隻能設置一個 saver 或scaffold 。
  • ValueError 如果 external_state_policy 不是 'warn'、'ignore' 或 'fail' 之一。

該鉤子將迭代器的狀態保存在Graph 中,以便在恢複訓練時輸入管道從中斷處繼續。這可以潛在地避免某些管道中的過度擬合,其中每個評估的訓練步驟數與數據集大小相比較小,或者如果訓練管道被搶占。

CheckpointSaverHook 的區別:

  1. 僅保存 "iterators" 集合中的輸入管道,而不保存全局變量或其他可保存對象。
  2. 不將GraphDefMetaGraphDef 寫入摘要。

檢查點訓練管道的示例:

est = tf.estimator.Estimator(model_fn)
while True:
  est.train(
      train_input_fn,
      hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
      steps=train_steps_per_eval)
  # Note:We do not pass the hook here.
  metrics = est.evaluate(eval_input_fn)
  if should_stop_the_training(metrics):
    break

如果需要將輸入管道狀態與模型檢查點分開保存,則應使用此掛鉤。出於以下幾個原因,這樣做可能很有用:

  1. 輸入管道檢查點可能很大,例如,如果有大的 shuffle 或預取緩衝區,並且可能會使檢查點大小膨脹。
  2. 如果輸入管道在訓練和驗證之間共享,則在驗證期間恢複檢查點可能會覆蓋驗證輸入管道。

為了將輸入管道檢查點與模型權重一起保存,請直接使用 tf.data.experimental.make_saveable_from_iterator 創建 SaveableObject 並添加到 SAVEABLE_OBJECTS 集合中。但是請注意,您需要注意不要在評估期間恢複訓練迭代器。您可以通過在構建評估圖時不將迭代器添加到 SAVEABLE_OBJECTS Collector來做到這一點。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.CheckpointInputPipelineHook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。