同步、聚合梯度并将它们传递给优化器的类。
继承自:Optimizer
用法
tf.compat.v1.train.SyncReplicasOptimizer(
opt, replicas_to_aggregate, total_num_replicas=None, variable_averages=None,
variables_to_average=None, use_locking=False, name='sync_replicas'
)
参数
-
opt
将用于计算和应用梯度的实际优化器。必须是优化器类之一。 -
replicas_to_aggregate
为每个变量更新聚合的副本数。 -
total_num_replicas
任务/工作人员/副本的总数,可能与 replicas_to_aggregate 不同。如果total_num_replicas > replicas_to_aggregate:它是backup_replicas + replicas_to_aggregate。如果 total_num_replicas < replicas_to_aggregate:Replicas 每次更新变量时计算多个批次。 -
variable_averages
可选的ExponentialMovingAverage
对象,用于维护variables_to_average
中传递的变量的移动平均线。 -
variables_to_average
需要平均的变量列表。仅当 variable_averages 传入时才需要。 -
use_locking
如果 True 使用锁进行更新操作。 -
name
String 。返回操作的可选名称。
此类已弃用。对于同步训练,请使用分发策略。
在典型的异步训练环境中,通常会有一些陈旧的梯度。例如,使用N-replica 异步训练,梯度将独立应用于变量 N 次。根据每个副本的训练速度,一些梯度可能是从前几步(平均 N-1 步)的变量副本中计算出来的。该优化器通过从所有副本收集梯度,对其进行平均,然后一次性将它们应用于变量来避免过时的梯度,之后副本可以获取新变量并继续。
创建了以下累加器/队列:
- N
gradient accumulators
,每个要训练的变量一个。梯度被推送给他们,首席工作人员将等到收集到足够的梯度,然后在应用到变量之前对它们进行平均。累加器将丢弃所有过时的梯度(更多细节在累加器操作中)。 - 1
token
队列,优化器在所有变量更新后推送新的 global_step 值。
创建以下局部变量:
sync_rep_local_step
,每个副本一个。与每个累加器中的 global_step 进行比较,以检查梯度的陈旧性。
优化器向图中添加节点以收集梯度并暂停训练器,直到变量更新。对于参数服务器作业:
- 为每个变量创建一个累加器,每个副本将梯度推送到累加器中,而不是直接将它们应用于变量。
- 一旦累积了足够多的梯度 (replicas_to_aggregate),每个累加器就会平均。
- 将平均梯度应用于变量。
- 只有在所有变量都更新后,才增加全局步骤。
- 仅在第 4 步之后,将
global_step
推送到token_queue
中,每个工作人员副本一次。工作人员现在可以获取全局步骤,使用它来更新其local_step 变量并开始下一批。请注意,有些工人可以消耗多个小批量,而有些工人可能连一个都不能消耗。这是因为只要存在令牌,每个工作人员都会获取小批量。如果一个工人由于某种原因被卡住并且没有消耗令牌,那么另一个工人可以使用它。
对于副本:
- 开始一个步骤:获取变量并计算梯度。
- 计算出梯度后,将它们推入梯度累加器。每个累加器都会检查陈旧并丢弃陈旧。
- 推送所有梯度后,从令牌队列中取出更新后的 global_step 值,并将该步骤记录到其 local_step 变量中。请注意,这实际上是一个障碍。
- 开始下一批。
用法
# Create any optimizer to update the variables, say a simple SGD:
opt = GradientDescentOptimizer(learning_rate=0.1)
# Wrap the optimizer with sync_replicas_optimizer with 50 replicas:at each
# step the optimizer collects 50 gradients before applying to variables.
# Note that if you want to have 2 backup replicas, you can change
# total_num_replicas=52 and make sure this number matches how many physical
# replicas you started in your job.
opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
total_num_replicas=50)
# Some models have startup_delays to help stabilize the model but when using
# sync_replicas training, set it to 0.
# Now you can call `minimize()` or `compute_gradients()` and
# `apply_gradients()` normally
training_op = opt.minimize(total_loss, global_step=self.global_step)
# You can create the hook which handles initialization and queues.
sync_replicas_hook = opt.make_session_run_hook(is_chief)
在训练计划中,每个工人都会像不同步一样运行train_op。
with training.MonitoredTrainingSession(
master=workers[worker_id].target, is_chief=is_chief,
hooks=[sync_replicas_hook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(training_op)
要将 SyncReplicasOptimizer 与 Estimator
一起使用,您需要在调用 fit 时发送 sync_replicas_hook。
my_estimator = DNNClassifier(..., optimizer=opt)
my_estimator.fit(..., hooks=[sync_replicas_hook])
相关用法
- Python tf.compat.v1.train.Supervisor.managed_session用法及代码示例
- Python tf.compat.v1.train.Supervisor用法及代码示例
- Python tf.compat.v1.train.SessionManager用法及代码示例
- Python tf.compat.v1.train.SingularMonitoredSession用法及代码示例
- Python tf.compat.v1.train.Saver用法及代码示例
- Python tf.compat.v1.train.SingularMonitoredSession.run_step_fn用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代码示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代码示例
- Python tf.compat.v1.train.Optimizer用法及代码示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代码示例
- Python tf.compat.v1.train.Checkpoint用法及代码示例
- Python tf.compat.v1.train.Checkpoint.restore用法及代码示例
- Python tf.compat.v1.train.global_step用法及代码示例
- Python tf.compat.v1.train.MonitoredSession.run_step_fn用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.exponential_decay用法及代码示例
- Python tf.compat.v1.train.natural_exp_decay用法及代码示例
- Python tf.compat.v1.train.MomentumOptimizer用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.SyncReplicasOptimizer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。