當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.estimator.train_and_evaluate用法及代碼示例


訓練和評估 estimator

用法

tf.estimator.train_and_evaluate(
    estimator, train_spec, eval_spec
)

參數

  • estimator 用於訓練和評估的 Estimator 實例。
  • train_spec 用於指定訓練規範的 TrainSpec 實例。
  • eval_spec 用於指定評估和導出規範的 EvalSpec 實例。

返回

  • evaluate 調用 Estimator 的結果的元組和使用指定的 Exporter 的導出結果。目前,分布式訓練模式的返回值未定義。

拋出

  • ValueError 如果環境變量TF_CONFIG 設置不正確。

此實用程序函數使用給定的 estimator 訓練、評估和(可選)導出模型。所有訓練相關規範都保存在train_spec中,包括訓練input_fn和訓練最大步數等。所有評估和導出相關規範都保存在eval_spec中,包括評估input_fn、步數等。

此實用程序函數為本地(非分布式)和分布式配置提供一致的行為。默認分發配置是基於參數服務器的between-graph 複製。對於其他類型的分布配置,例如all-reduce 訓練,請使用 DistributionStrategies。

過擬合:為了避免過擬合,建議設置訓練input_fn,對訓練數據進行適當的shuffle。

停止條件:為了可靠地支持分布式和非分布式配置,模型訓練唯一支持的停止條件是 train_spec.max_steps 。如果 train_spec.max_stepsNone ,則模型將永遠被訓練。如果模型停止條件不同,請小心使用。例如,假設模型預計使用一個 epoch 的訓練數據進行訓練,並且訓練 input_fn 配置為在經過一個 epoch 後拋出 OutOfRangeError,這將停止 Estimator.train 。對於three-training-worker 分布式配置,每個訓練工作人員都可能獨立完成整個 epoch。因此,模型將使用三個 epoch 的訓練數據而不是一個 epoch 進行訓練。

本地(非分布式)訓練示例:

# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
    categorical_column=categorial_feature_a, ...)
...  # other feature columns

estimator = DNNClassifier(
    feature_columns=[categorial_feature_a_emb, ...],
    hidden_units=[1024, 512, 256])

# Or set up the model directory
#   estimator = DNNClassifier(
#       config=tf.estimator.RunConfig(
#           model_dir='/my_model', save_summary_steps=100),
#       feature_columns=[categorial_feature_a_emb, ...],
#       hidden_units=[1024, 512, 256])

# Input pipeline for train and evaluate.
def train_input_fn():# returns x, y
  # please shuffle the data.
  pass
def eval_input_fn():# returns x, y
  pass

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

請注意,在當前實現中,estimator.evaluate 將被多次調用。這意味著將為每個evaluate 調用重新創建評估圖(包括eval_input_fn)。 estimator.train 隻會被調用一次。

分布式訓練示例:

關於分布式訓練的例子,上麵的代碼可以不加改動地使用(請確保所有worker的RunConfig.model_dir設置為同一目錄,即所有worker都可以讀寫的共享文件係統)。唯一要做的額外工作是為每個工作人員相應地設置環境變量TF_CONFIG

另請參閱分布式 TensorFlow。

設置環境變量取決於平台。例如,在 Linux 上,可以按如下方式進行($ 是 shell 提示符):

$ TF_CONFIG='<replace_with_real_content>' python train_model.py

對於 TF_CONFIG 中的內容,假設訓練集群規範如下所示:

cluster = {"chief":["host0:2222"],
           "worker":["host1:2222", "host2:2222", "host3:2222"],
           "ps":["host4:2222", "host5:2222"]}

首席訓練工TF_CONFIG 示例(必須有且隻有一個):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"chief", "index":0}
}'

請注意,首席員工也從事模型訓練工作,類似於其他非首席訓練員工(見下一段)。除了模型訓練之外,它還管理一些額外的工作,例如檢查點保存和恢複、編寫摘要等。

非首席訓練人員的TF_CONFIG 示例(可選,可以是多個):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"worker", "index":0}
}'

其中task.index在本例中應分別設置為0、1、2,分別用於非首席訓練人員。

參數服務器的TF_CONFIG 示例,又名 ps(可以是多個):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"ps", "index":0}
}'

其中 task.index 應設置為 0 和 1,在此示例中,分別用於參數服務器。

評估任務的TF_CONFIG 示例。 Evaluator 是一項特殊任務,不屬於訓練集群的一部分。可能隻有一個。它用於模型評估。

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"evaluator", "index":0}
}'

當設置distributeexperimental_distribute.train_distributeexperimental_distribute.remote_cluster時,此方法將啟動一個在當前主機上運行的客戶端,該客戶端連接到remote_cluster進行訓練和評估。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.estimator.train_and_evaluate。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。