训练和评估 estimator
。
用法
tf.estimator.train_and_evaluate(
estimator, train_spec, eval_spec
)
参数
-
estimator
用于训练和评估的Estimator
实例。 -
train_spec
用于指定训练规范的TrainSpec
实例。 -
eval_spec
用于指定评估和导出规范的EvalSpec
实例。
返回
-
evaluate
调用Estimator
的结果的元组和使用指定的Exporter
的导出结果。目前,分布式训练模式的返回值未定义。
抛出
-
ValueError
如果环境变量TF_CONFIG
设置不正确。
此实用程序函数使用给定的 estimator
训练、评估和(可选)导出模型。所有训练相关规范都保存在train_spec
中,包括训练input_fn
和训练最大步数等。所有评估和导出相关规范都保存在eval_spec
中,包括评估input_fn
、步数等。
此实用程序函数为本地(非分布式)和分布式配置提供一致的行为。默认分发配置是基于参数服务器的between-graph 复制。对于其他类型的分布配置,例如all-reduce 训练,请使用 DistributionStrategies。
过拟合:为了避免过拟合,建议设置训练input_fn
,对训练数据进行适当的shuffle。
停止条件:为了可靠地支持分布式和非分布式配置,模型训练唯一支持的停止条件是 train_spec.max_steps
。如果 train_spec.max_steps
是 None
,则模型将永远被训练。如果模型停止条件不同,请小心使用。例如,假设模型预计使用一个 epoch 的训练数据进行训练,并且训练 input_fn
配置为在经过一个 epoch 后抛出 OutOfRangeError
,这将停止 Estimator.train
。对于three-training-worker 分布式配置,每个训练工作人员都可能独立完成整个 epoch。因此,模型将使用三个 epoch 的训练数据而不是一个 epoch 进行训练。
本地(非分布式)训练示例:
# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
categorical_column=categorial_feature_a, ...)
... # other feature columns
estimator = DNNClassifier(
feature_columns=[categorial_feature_a_emb, ...],
hidden_units=[1024, 512, 256])
# Or set up the model directory
# estimator = DNNClassifier(
# config=tf.estimator.RunConfig(
# model_dir='/my_model', save_summary_steps=100),
# feature_columns=[categorial_feature_a_emb, ...],
# hidden_units=[1024, 512, 256])
# Input pipeline for train and evaluate.
def train_input_fn():# returns x, y
# please shuffle the data.
pass
def eval_input_fn():# returns x, y
pass
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
请注意,在当前实现中,estimator.evaluate
将被多次调用。这意味着将为每个evaluate
调用重新创建评估图(包括eval_input_fn)。 estimator.train
只会被调用一次。
分布式训练示例:
关于分布式训练的例子,上面的代码可以不加改动地使用(请确保所有worker的RunConfig.model_dir
设置为同一目录,即所有worker都可以读写的共享文件系统)。唯一要做的额外工作是为每个工作人员相应地设置环境变量TF_CONFIG
。
另请参阅分布式 TensorFlow。
设置环境变量取决于平台。例如,在 Linux 上,可以按如下方式进行($
是 shell 提示符):
$ TF_CONFIG='<replace_with_real_content>' python train_model.py
对于 TF_CONFIG
中的内容,假设训练集群规范如下所示:
cluster = {"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]}
首席训练工TF_CONFIG
示例(必须有且只有一个):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"chief", "index":0}
}'
请注意,首席员工也从事模型训练工作,类似于其他非首席训练员工(见下一段)。除了模型训练之外,它还管理一些额外的工作,例如检查点保存和恢复、编写摘要等。
非首席训练人员的TF_CONFIG
示例(可选,可以是多个):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"worker", "index":0}
}'
其中task.index
在本例中应分别设置为0、1、2,分别用于非首席训练人员。
参数服务器的TF_CONFIG
示例,又名 ps(可以是多个):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"ps", "index":0}
}'
其中 task.index
应设置为 0 和 1,在此示例中,分别用于参数服务器。
评估任务的TF_CONFIG
示例。 Evaluator 是一项特殊任务,不属于训练集群的一部分。可能只有一个。它用于模型评估。
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster":{
"chief":["host0:2222"],
"worker":["host1:2222", "host2:2222", "host3:2222"],
"ps":["host4:2222", "host5:2222"]
},
"task":{"type":"evaluator", "index":0}
}'
当设置distribute
或experimental_distribute.train_distribute
和experimental_distribute.remote_cluster
时,此方法将启动一个在当前主机上运行的客户端,该客户端连接到remote_cluster
进行训练和评估。
相关用法
- Python tf.estimator.TrainSpec用法及代码示例
- Python tf.estimator.LogisticRegressionHead用法及代码示例
- Python tf.estimator.MultiHead用法及代码示例
- Python tf.estimator.PoissonRegressionHead用法及代码示例
- Python tf.estimator.WarmStartSettings用法及代码示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代码示例
- Python tf.estimator.RunConfig用法及代码示例
- Python tf.estimator.MultiLabelHead用法及代码示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代码示例
- Python tf.estimator.BaselineEstimator用法及代码示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代码示例
- Python tf.estimator.Estimator用法及代码示例
- Python tf.estimator.experimental.LinearSDCA用法及代码示例
- Python tf.estimator.experimental.RNNClassifier用法及代码示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代码示例
- Python tf.estimator.LinearRegressor用法及代码示例
- Python tf.estimator.LinearEstimator用法及代码示例
- Python tf.estimator.DNNClassifier用法及代码示例
- Python tf.estimator.BaselineClassifier用法及代码示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.train_and_evaluate。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。