当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.estimator.train_and_evaluate用法及代码示例


训练和评估 estimator

用法

tf.estimator.train_and_evaluate(
    estimator, train_spec, eval_spec
)

参数

  • estimator 用于训练和评估的 Estimator 实例。
  • train_spec 用于指定训练规范的 TrainSpec 实例。
  • eval_spec 用于指定评估和导出规范的 EvalSpec 实例。

返回

  • evaluate 调用 Estimator 的结果的元组和使用指定的 Exporter 的导出结果。目前,分布式训练模式的返回值未定义。

抛出

  • ValueError 如果环境变量TF_CONFIG 设置不正确。

此实用程序函数使用给定的 estimator 训练、评估和(可选)导出模型。所有训练相关规范都保存在train_spec中,包括训练input_fn和训练最大步数等。所有评估和导出相关规范都保存在eval_spec中,包括评估input_fn、步数等。

此实用程序函数为本地(非分布式)和分布式配置提供一致的行为。默认分发配置是基于参数服务器的between-graph 复制。对于其他类型的分布配置,例如all-reduce 训练,请使用 DistributionStrategies。

过拟合:为了避免过拟合,建议设置训练input_fn,对训练数据进行适当的shuffle。

停止条件:为了可靠地支持分布式和非分布式配置,模型训练唯一支持的停止条件是 train_spec.max_steps 。如果 train_spec.max_stepsNone ,则模型将永远被训练。如果模型停止条件不同,请小心使用。例如,假设模型预计使用一个 epoch 的训练数据进行训练,并且训练 input_fn 配置为在经过一个 epoch 后抛出 OutOfRangeError,这将停止 Estimator.train 。对于three-training-worker 分布式配置,每个训练工作人员都可能独立完成整个 epoch。因此,模型将使用三个 epoch 的训练数据而不是一个 epoch 进行训练。

本地(非分布式)训练示例:

# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
    categorical_column=categorial_feature_a, ...)
...  # other feature columns

estimator = DNNClassifier(
    feature_columns=[categorial_feature_a_emb, ...],
    hidden_units=[1024, 512, 256])

# Or set up the model directory
#   estimator = DNNClassifier(
#       config=tf.estimator.RunConfig(
#           model_dir='/my_model', save_summary_steps=100),
#       feature_columns=[categorial_feature_a_emb, ...],
#       hidden_units=[1024, 512, 256])

# Input pipeline for train and evaluate.
def train_input_fn():# returns x, y
  # please shuffle the data.
  pass
def eval_input_fn():# returns x, y
  pass

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

请注意,在当前实现中,estimator.evaluate 将被多次调用。这意味着将为每个evaluate 调用重新创建评估图(包括eval_input_fn)。 estimator.train 只会被调用一次。

分布式训练示例:

关于分布式训练的例子,上面的代码可以不加改动地使用(请确保所有worker的RunConfig.model_dir设置为同一目录,即所有worker都可以读写的共享文件系统)。唯一要做的额外工作是为每个工作人员相应地设置环境变量TF_CONFIG

另请参阅分布式 TensorFlow。

设置环境变量取决于平台。例如,在 Linux 上,可以按如下方式进行($ 是 shell 提示符):

$ TF_CONFIG='<replace_with_real_content>' python train_model.py

对于 TF_CONFIG 中的内容,假设训练集群规范如下所示:

cluster = {"chief":["host0:2222"],
           "worker":["host1:2222", "host2:2222", "host3:2222"],
           "ps":["host4:2222", "host5:2222"]}

首席训练工TF_CONFIG 示例(必须有且只有一个):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"chief", "index":0}
}'

请注意,首席员工也从事模型训练工作,类似于其他非首席训练员工(见下一段)。除了模型训练之外,它还管理一些额外的工作,例如检查点保存和恢复、编写摘要等。

非首席训练人员的TF_CONFIG 示例(可选,可以是多个):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"worker", "index":0}
}'

其中task.index在本例中应分别设置为0、1、2,分别用于非首席训练人员。

参数服务器的TF_CONFIG 示例,又名 ps(可以是多个):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"ps", "index":0}
}'

其中 task.index 应设置为 0 和 1,在此示例中,分别用于参数服务器。

评估任务的TF_CONFIG 示例。 Evaluator 是一项特殊任务,不属于训练集群的一部分。可能只有一个。它用于模型评估。

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster":{
        "chief":["host0:2222"],
        "worker":["host1:2222", "host2:2222", "host3:2222"],
        "ps":["host4:2222", "host5:2222"]
    },
    "task":{"type":"evaluator", "index":0}
}'

当设置distributeexperimental_distribute.train_distributeexperimental_distribute.remote_cluster时,此方法将启动一个在当前主机上运行的客户端,该客户端连接到remote_cluster进行训练和评估。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.train_and_evaluate。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。