訓練和評估 estimator
。
用法
tf.estimator.train_and_evaluate(
estimator, train_spec, eval_spec
)
參數
-
estimator
用於訓練和評估的Estimator
實例。 -
train_spec
用於指定訓練規範的TrainSpec
實例。 -
eval_spec
用於指定評估和導出規範的EvalSpec
實例。
返回
-
evaluate
調用Estimator
的結果的元組和使用指定的Exporter
的導出結果。目前,分布式訓練模式的返回值未定義。
拋出
-
ValueError
如果環境變量TF_CONFIG
設置不正確。
此實用程序函數使用給定的 estimator
訓練、評估和(可選)導出模型。所有訓練相關規範都保存在train_spec
中,包括訓練input_fn
和訓練最大步數等。所有評估和導出相關規範都保存在eval_spec
中,包括評估input_fn
、步數等。
此實用程序函數為本地(非分布式)和分布式配置提供一致的行為。默認分發配置是基於參數服務器的between-graph 複製。對於其他類型的分布配置,例如all-reduce 訓練,請使用 DistributionStrategies。
過擬合:為了避免過擬合,建議設置訓練input_fn
,對訓練數據進行適當的shuffle。
停止條件:為了可靠地支持分布式和非分布式配置,模型訓練唯一支持的停止條件是 train_spec.max_steps
。如果 train_spec.max_steps
是 None
,則模型將永遠被訓練。如果模型停止條件不同,請小心使用。例如,假設模型預計使用一個 epoch 的訓練數據進行訓練,並且訓練 input_fn
配置為在經過一個 epoch 後拋出 OutOfRangeError
,這將停止 Estimator.train
。對於three-training-worker 分布式配置,每個訓練工作人員都可能獨立完成整個 epoch。因此,模型將使用三個 epoch 的訓練數據而不是一個 epoch 進行訓練。
本地(非分布式)訓練示例:
# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
categorical_column=categorial_feature_a, ...)
... # other feature columns
estimator = DNNClassifier(
feature_columns=[categorial_feature_a_emb, ...],
hidden_units=[1024, 512, 256])
# Or set up the model directory
# estimator = DNNClassifier(
# config=tf.estimator.RunConfig(
# model_dir='/my_model', save_summary_steps=100),
# feature_columns=[categorial_feature_a_emb, ...],
# hidden_units=[1024, 512, 256])
# Input pipeline for train and evaluate.
def train_input_fn():# returns x, y
# please shuffle the data.
pass
def eval_input_fn():# returns x, y
pass
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
請注意,在當前實現中,estimator.evaluate
將被多次調用。這意味著將為每個evaluate
調用重新創建評估圖(包括eval_input_fn)。 estimator.train
隻會被調用一次。
分布式訓練示例:
關於分布式訓練的例子,上麵的代碼可以不加改動地使用(請確保所有worker的RunConfig.model_dir
設置為同一目錄,即所有worker都可以讀寫的共享文件係統)。唯一要做的額外工作是為每個工作人員相應地設置環境變量TF_CONFIG
。
另請參閱分布式 TensorFlow。
設置環境變量取決於平台。例如,在 Linux 上,可以按如下方式進行($
是 shell 提示符):
$ TF_CONFIG='<replace_with_real_content>' python train_model.py
對於 TF_CONFIG
中的內容,假設訓練集群規範如下所示:
cluster = {"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]}
首席訓練工TF_CONFIG
示例(必須有且隻有一個):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"chief", "index":0}
}'
請注意,首席員工也從事模型訓練工作,類似於其他非首席訓練員工(見下一段)。除了模型訓練之外,它還管理一些額外的工作,例如檢查點保存和恢複、編寫摘要等。
非首席訓練人員的TF_CONFIG
示例(可選,可以是多個):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"worker", "index":0}
}'
其中task.index
在本例中應分別設置為0、1、2,分別用於非首席訓練人員。
參數服務器的TF_CONFIG
示例,又名 ps(可以是多個):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"ps", "index":0}
}'
其中 task.index
應設置為 0 和 1,在此示例中,分別用於參數服務器。
評估任務的TF_CONFIG
示例。 Evaluator 是一項特殊任務,不屬於訓練集群的一部分。可能隻有一個。它用於模型評估。
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"evaluator", "index":0}
}'
當設置distribute
或experimental_distribute.train_distribute
和experimental_distribute.remote_cluster
時,此方法將啟動一個在當前主機上運行的客戶端,該客戶端連接到remote_cluster
進行訓練和評估。
相關用法
- Python tf.estimator.TrainSpec用法及代碼示例
- Python tf.estimator.LogisticRegressionHead用法及代碼示例
- Python tf.estimator.MultiHead用法及代碼示例
- Python tf.estimator.PoissonRegressionHead用法及代碼示例
- Python tf.estimator.WarmStartSettings用法及代碼示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代碼示例
- Python tf.estimator.RunConfig用法及代碼示例
- Python tf.estimator.MultiLabelHead用法及代碼示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代碼示例
- Python tf.estimator.BaselineEstimator用法及代碼示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代碼示例
- Python tf.estimator.Estimator用法及代碼示例
- Python tf.estimator.experimental.LinearSDCA用法及代碼示例
- Python tf.estimator.experimental.RNNClassifier用法及代碼示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代碼示例
- Python tf.estimator.LinearRegressor用法及代碼示例
- Python tf.estimator.LinearEstimator用法及代碼示例
- Python tf.estimator.DNNClassifier用法及代碼示例
- Python tf.estimator.BaselineClassifier用法及代碼示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.estimator.train_and_evaluate。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。