當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.estimator.TrainSpec用法及代碼示例


train_and_evaluate 調用的"train" 部分的配置。

用法

tf.estimator.TrainSpec(
    input_fn, max_steps=None, hooks=None, saving_listeners=None
)

參數

  • input_fn 提供輸入數據以作為小批量進行訓練的函數。看預製估算器了解更多信息。該函數應構造並返回以下內容之一:
    • 'tf.data.Dataset' 對象:Dataset 對象的輸出必須是具有與以下相同約束的元組(特征、標簽)。
    • 元組(特征、標簽):其中 features 是 Tensor 或字符串特征名稱字典到 Tensor 並且標簽是 Tensor 或字符串標簽名稱字典到 Tensor
  • max_steps Int. 訓練模型的總步數為正數。如果 None ,永遠訓練。訓練 input_fn 預計不會生成 OutOfRangeErrorStopIteration 異常。有關詳細信息,請參閱train_and_evaluate 停止條件部分。
  • hooks tf.train.SessionRunHook 對象的可迭代對象,可在訓練期間在所有工作人員(包括主管)上運行。
  • saving_listeners tf.estimator.CheckpointSaverListener 對象的迭代,在訓練期間主要運行。

拋出

  • ValueError 如果任何輸入參數無效。
  • TypeError 如果任何參數不是預期的類型。

屬性

  • input_fn 字段編號 0 的 namedtuple 別名
  • max_steps 字段編號 1 的 namedtuple 別名
  • hooks 字段編號 2 的 namedtuple 別名
  • saving_listeners 字段編號 3 的 namedtuple 別名

TrainSpec 確定訓練的輸入數據以及持續時間。可選的鉤子在訓練的不同階段運行。

用法:

train_spec = tf.estimator.TrainSpec(
   input_fn=lambda:1,
   max_steps=100,
   hooks=[_StopAtSecsHook(stop_after_secs=10)],
   saving_listeners=[_NewCheckpointListenerForEvaluate(None, 20, None)])
train_spec.saving_listeners[0]._eval_throttle_secs
20
train_spec.hooks[0]._stop_after_secs
10
train_spec.max_steps
100

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.estimator.TrainSpec。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。