此类指定Estimator
运行的配置。
用法
tf.estimator.RunConfig(
model_dir=None, tf_random_seed=None, save_summary_steps=100,
save_checkpoints_steps=_USE_DEFAULT, save_checkpoints_secs=_USE_DEFAULT,
session_config=None, keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100, train_distribute=None, device_fn=None, protocol=None,
eval_distribute=None, experimental_distribute=None,
experimental_max_worker_delay_secs=None, session_creation_timeout_secs=7200,
checkpoint_save_graph_def=True
)
参数
-
model_dir
保存模型参数、图形等的目录。如果PathLike
对象,路径将被解析。如果None
,将使用 Estimator 设置的默认值。 -
tf_random_seed
TensorFlow 初始化程序的随机种子。设置此值允许重新运行之间的一致性。 -
save_summary_steps
每隔这么多步骤保存摘要。 -
save_checkpoints_steps
每隔这么多步骤保存检查点。不能用save_checkpoints_secs
指定。 -
save_checkpoints_secs
每隔这么多秒保存检查点。不能用save_checkpoints_steps
指定。如果save_checkpoints_steps
和save_checkpoints_secs
均未在构造函数中设置,则默认为 600 秒。如果save_checkpoints_steps
和save_checkpoints_secs
都是None
,则禁用检查点。 -
session_config
用于设置会话参数的 ConfigProto,或None
。 -
keep_checkpoint_max
要保留的最近检查点文件的最大数量。创建新文件时,会删除旧文件。如果None
或 0,则保留所有检查点文件。默认为 5(即保留 5 个最近的检查点文件)。如果将保护程序传递给估计器,则此参数将被忽略。 -
keep_checkpoint_every_n_hours
要保存的每个检查点之间的小时数。默认值 10,000 小时有效地禁用了该函数。 -
log_step_count_steps
在训练期间记录全局步长和损失的频率,以全局步数为单位。还控制在训练期间记录全局步数(并写入摘要)的频率。 -
train_distribute
tf.distribute.Strategy
的可选实例。如果指定,则 Estimator 将根据该策略指定的策略在训练期间分发用户的模型。首选设置experimental_distribute.train_distribute
。 -
device_fn
为每个接受Operation
并返回设备字符串的Operation
调用的可调用对象。如果None
,默认为tf.train.replica_device_setter
返回的设备函数,使用 round-robin 策略。 -
protocol
一个可选参数,它指定启动服务器时使用的协议。None
表示默认为 grpc。 -
eval_distribute
tf.distribute.Strategy
的可选实例。如果指定,则 Estimator 将根据该策略指定的策略在评估期间分发用户的模型。首选设置experimental_distribute.eval_distribute
。 -
experimental_distribute
指定DistributionStrategy-related 配置的可选tf.contrib.distribute.DistributeConfig
对象。train_distribute
和eval_distribute
可以作为参数传递给RunConfig
或在experimental_distribute
中设置,但不能同时设置。 -
experimental_max_worker_delay_secs
一个可选整数,指定工作人员在开始之前应等待的最长时间。默认情况下,worker 是在交错时间启动的,每个 worker 最多延迟 60 秒。这是为了降低发散的风险,当许多工作人员同时更新随机初始化模型的权重时,可能会发生发散的风险。热启动模型并进行短时间(几分钟或更短时间)训练的用户应考虑减少此默认值以缩短训练时间。 -
session_creation_timeout_secs
工作人员应使用 MonitoredTrainingSession 等待会话可用(在初始化或恢复会话时)的最长时间。默认为 7200 秒,但用户可能希望设置较低的值以更快地检测变量/会话(重新)初始化的问题。 -
checkpoint_save_graph_def
是否将 GraphDef 和 MetaGraphDef 保存到checkpoint_dir
。 GraphDef 在会话创建后保存为graph.pbtxt
。 MetaGraphDefs 为每个检查点保存为model.ckpt-*.meta
。
抛出
-
ValueError
如果同时设置了save_checkpoints_steps
和save_checkpoints_secs
。
属性
-
checkpoint_save_graph_def
-
cluster_spec
-
device_fn
返回device_fn。如果 device_fn 不是
None
,它将覆盖Estimator
中使用的默认设备函数。否则使用默认值。 -
eval_distribute
用于评估的可选tf.distribute.Strategy
。 -
evaluation_master
-
experimental_max_worker_delay_secs
-
global_id_in_cluster
训练集群中的全局 id。训练集群中的所有全局 id 都是从一个递增的连续整数序列中分配的。第一个id是0。
注意:任务 id(属性字段
task_id
)正在跟踪具有相同任务类型的所有节点中的节点索引。例如,给定集群定义如下:cluster = {'chief':['host0:2222'], 'ps':['host1:2222', 'host2:2222'], 'worker':['host3:2222', 'host4:2222', 'host5:2222']}
任务类型
worker
的节点可以有 id 0, 1, 2。任务类型ps
的节点可以有 id, 0, 1。所以,task_id
不是唯一的,但对 (task_type
,task_id
) 可以是唯一的确定集群中的一个节点。全局id,即该字段,是跟踪该节点在集群中所有节点中的索引。它是唯一分配的。例如,对于上面给出的集群规范,全局 id 被分配为:
task_type | task_id | global_id -------------------------------- chief | 0 | 0 worker | 0 | 1 worker | 1 | 2 worker | 2 | 3 ps | 0 | 4 ps | 1 | 5
-
is_chief
-
keep_checkpoint_every_n_hours
-
keep_checkpoint_max
-
log_step_count_steps
-
master
-
model_dir
-
num_ps_replicas
-
num_worker_replicas
-
protocol
返回可选协议值。 -
save_checkpoints_secs
-
save_checkpoints_steps
-
save_summary_steps
-
service
返回定义的平台(在TF_CONFIG中)服务字典。 -
session_config
-
session_creation_timeout_secs
-
task_id
-
task_type
-
tf_random_seed
-
train_distribute
可选的tf.distribute.Strategy
用于训练。
相关用法
- Python tf.estimator.RegressionHead用法及代码示例
- Python tf.estimator.TrainSpec用法及代码示例
- Python tf.estimator.LogisticRegressionHead用法及代码示例
- Python tf.estimator.MultiHead用法及代码示例
- Python tf.estimator.PoissonRegressionHead用法及代码示例
- Python tf.estimator.WarmStartSettings用法及代码示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代码示例
- Python tf.estimator.MultiLabelHead用法及代码示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代码示例
- Python tf.estimator.BaselineEstimator用法及代码示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代码示例
- Python tf.estimator.Estimator用法及代码示例
- Python tf.estimator.experimental.LinearSDCA用法及代码示例
- Python tf.estimator.experimental.RNNClassifier用法及代码示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代码示例
- Python tf.estimator.LinearRegressor用法及代码示例
- Python tf.estimator.LinearEstimator用法及代码示例
- Python tf.estimator.DNNClassifier用法及代码示例
- Python tf.estimator.BaselineClassifier用法及代码示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.RunConfig。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。