当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.estimator.RunConfig用法及代码示例


此类指定Estimator 运行的配置。

用法

tf.estimator.RunConfig(
    model_dir=None, tf_random_seed=None, save_summary_steps=100,
    save_checkpoints_steps=_USE_DEFAULT, save_checkpoints_secs=_USE_DEFAULT,
    session_config=None, keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000,
    log_step_count_steps=100, train_distribute=None, device_fn=None, protocol=None,
    eval_distribute=None, experimental_distribute=None,
    experimental_max_worker_delay_secs=None, session_creation_timeout_secs=7200,
    checkpoint_save_graph_def=True
)

参数

  • model_dir 保存模型参数、图形等的目录。如果PathLike 对象,路径将被解析。如果 None ,将使用 Estimator 设置的默认值。
  • tf_random_seed TensorFlow 初始化程序的随机种子。设置此值允许重新运行之间的一致性。
  • save_summary_steps 每隔这么多步骤保存摘要。
  • save_checkpoints_steps 每隔这么多步骤保存检查点。不能用 save_checkpoints_secs 指定。
  • save_checkpoints_secs 每隔这么多秒保存检查点。不能用 save_checkpoints_steps 指定。如果 save_checkpoints_stepssave_checkpoints_secs 均未在构造函数中设置,则默认为 600 秒。如果 save_checkpoints_stepssave_checkpoints_secs 都是 None ,则禁用检查点。
  • session_config 用于设置会话参数的 ConfigProto,或 None
  • keep_checkpoint_max 要保留的最近检查点文件的最大数量。创建新文件时,会删除旧文件。如果None 或 0,则保留所有检查点文件。默认为 5(即保留 5 个最近的检查点文件)。如果将保护程序传递给估计器,则此参数将被忽略。
  • keep_checkpoint_every_n_hours 要保存的每个检查点之间的小时数。默认值 10,000 小时有效地禁用了该函数。
  • log_step_count_steps 在训练期间记录全局步长和损失的频率,以全局步数为单位。还控制在训练期间记录全局步数(并写入摘要)的频率。
  • train_distribute tf.distribute.Strategy 的可选实例。如果指定,则 Estimator 将根据该策略指定的策略在训练期间分发用户的模型。首选设置experimental_distribute.train_distribute
  • device_fn 为每个接受 Operation 并返回设备字符串的 Operation 调用的可调用对象。如果 None ,默认为 tf.train.replica_device_setter 返回的设备函数,使用 round-robin 策略。
  • protocol 一个可选参数,它指定启动服务器时使用的协议。 None 表示默认为 grpc。
  • eval_distribute tf.distribute.Strategy 的可选实例。如果指定,则 Estimator 将根据该策略指定的策略在评估期间分发用户的模型。首选设置experimental_distribute.eval_distribute
  • experimental_distribute 指定DistributionStrategy-related 配置的可选tf.contrib.distribute.DistributeConfig 对象。 train_distributeeval_distribute 可以作为参数传递给 RunConfig 或在 experimental_distribute 中设置,但不能同时设置。
  • experimental_max_worker_delay_secs 一个可选整数,指定工作人员在开始之前应等待的最长时间。默认情况下,worker 是在交错时间启动的,每个 worker 最多延迟 60 秒。这是为了降低发散的风险,当许多工作人员同时更新随机初始化模型的权重时,可能会发生发散的风险。热启动模型并进行短时间(几分钟或更短时间)训练的用户应考虑减少此默认值以缩短训练时间。
  • session_creation_timeout_secs 工作人员应使用 MonitoredTrainingSession 等待会话可用(在初始化或恢复会话时)的最长时间。默认为 7200 秒,但用户可能希望设置较低的值以更快地检测变量/会话(重新)初始化的问题。
  • checkpoint_save_graph_def 是否将 GraphDef 和 MetaGraphDef 保存到 checkpoint_dir 。 GraphDef 在会话创建后保存为 graph.pbtxt 。 MetaGraphDefs 为每个检查点保存为 model.ckpt-*.meta

抛出

  • ValueError 如果同时设置了save_checkpoints_stepssave_checkpoints_secs

属性

  • checkpoint_save_graph_def
  • cluster_spec
  • device_fn 返回device_fn。

    如果 device_fn 不是 None ,它将覆盖 Estimator 中使用的默认设备函数。否则使用默认值。

  • eval_distribute 用于评估的可选tf.distribute.Strategy
  • evaluation_master
  • experimental_max_worker_delay_secs
  • global_id_in_cluster 训练集群中的全局 id。

    训练集群中的所有全局 id 都是从一个递增的连续整数序列中分配的。第一个id是0。

    注意:任务 id(属性字段 task_id )正在跟踪具有相同任务类型的所有节点中的节点索引。例如,给定集群定义如下:

    cluster = {'chief':['host0:2222'],
                 'ps':['host1:2222', 'host2:2222'],
                 'worker':['host3:2222', 'host4:2222', 'host5:2222']}

    任务类型 worker 的节点可以有 id 0, 1, 2。任务类型 ps 的节点可以有 id, 0, 1。所以,task_id 不是唯一的,但对 (task_type , task_id) 可以是唯一的确定集群中的一个节点。

    全局id,即该字段,是跟踪该节点在集群中所有节点中的索引。它是唯一分配的。例如,对于上面给出的集群规范,全局 id 被分配为:

    task_type  | task_id  |  global_id
      --------------------------------
      chief      | 0        |  0
      worker     | 0        |  1
      worker     | 1        |  2
      worker     | 2        |  3
      ps         | 0        |  4
      ps         | 1        |  5
  • is_chief
  • keep_checkpoint_every_n_hours
  • keep_checkpoint_max
  • log_step_count_steps
  • master
  • model_dir
  • num_ps_replicas
  • num_worker_replicas
  • protocol 返回可选协议值。
  • save_checkpoints_secs
  • save_checkpoints_steps
  • save_summary_steps
  • service 返回定义的平台(在TF_CONFIG中)服务字典。
  • session_config
  • session_creation_timeout_secs
  • task_id
  • task_type
  • tf_random_seed
  • train_distribute 可选的tf.distribute.Strategy 用于训练。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.RunConfig。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。