同步、聚合梯度並將它們傳遞給優化器的類。
繼承自:Optimizer
用法
tf.compat.v1.train.SyncReplicasOptimizer(
opt, replicas_to_aggregate, total_num_replicas=None, variable_averages=None,
variables_to_average=None, use_locking=False, name='sync_replicas'
)
參數
-
opt
將用於計算和應用梯度的實際優化器。必須是優化器類之一。 -
replicas_to_aggregate
為每個變量更新聚合的副本數。 -
total_num_replicas
任務/工作人員/副本的總數,可能與 replicas_to_aggregate 不同。如果total_num_replicas > replicas_to_aggregate:它是backup_replicas + replicas_to_aggregate。如果 total_num_replicas < replicas_to_aggregate:Replicas 每次更新變量時計算多個批次。 -
variable_averages
可選的ExponentialMovingAverage
對象,用於維護variables_to_average
中傳遞的變量的移動平均線。 -
variables_to_average
需要平均的變量列表。僅當 variable_averages 傳入時才需要。 -
use_locking
如果 True 使用鎖進行更新操作。 -
name
String 。返回操作的可選名稱。
此類已棄用。對於同步訓練,請使用分發策略。
在典型的異步訓練環境中,通常會有一些陳舊的梯度。例如,使用N-replica 異步訓練,梯度將獨立應用於變量 N 次。根據每個副本的訓練速度,一些梯度可能是從前幾步(平均 N-1 步)的變量副本中計算出來的。該優化器通過從所有副本收集梯度,對其進行平均,然後一次性將它們應用於變量來避免過時的梯度,之後副本可以獲取新變量並繼續。
創建了以下累加器/隊列:
- N
gradient accumulators
,每個要訓練的變量一個。梯度被推送給他們,首席工作人員將等到收集到足夠的梯度,然後在應用到變量之前對它們進行平均。累加器將丟棄所有過時的梯度(更多細節在累加器操作中)。 - 1
token
隊列,優化器在所有變量更新後推送新的 global_step 值。
創建以下局部變量:
sync_rep_local_step
,每個副本一個。與每個累加器中的 global_step 進行比較,以檢查梯度的陳舊性。
優化器向圖中添加節點以收集梯度並暫停訓練器,直到變量更新。對於參數服務器作業:
- 為每個變量創建一個累加器,每個副本將梯度推送到累加器中,而不是直接將它們應用於變量。
- 一旦累積了足夠多的梯度 (replicas_to_aggregate),每個累加器就會平均。
- 將平均梯度應用於變量。
- 隻有在所有變量都更新後,才增加全局步驟。
- 僅在第 4 步之後,將
global_step
推送到token_queue
中,每個工作人員副本一次。工作人員現在可以獲取全局步驟,使用它來更新其local_step 變量並開始下一批。請注意,有些工人可以消耗多個小批量,而有些工人可能連一個都不能消耗。這是因為隻要存在令牌,每個工作人員都會獲取小批量。如果一個工人由於某種原因被卡住並且沒有消耗令牌,那麽另一個工人可以使用它。
對於副本:
- 開始一個步驟:獲取變量並計算梯度。
- 計算出梯度後,將它們推入梯度累加器。每個累加器都會檢查陳舊並丟棄陳舊。
- 推送所有梯度後,從令牌隊列中取出更新後的 global_step 值,並將該步驟記錄到其 local_step 變量中。請注意,這實際上是一個障礙。
- 開始下一批。
用法
# Create any optimizer to update the variables, say a simple SGD:
opt = GradientDescentOptimizer(learning_rate=0.1)
# Wrap the optimizer with sync_replicas_optimizer with 50 replicas:at each
# step the optimizer collects 50 gradients before applying to variables.
# Note that if you want to have 2 backup replicas, you can change
# total_num_replicas=52 and make sure this number matches how many physical
# replicas you started in your job.
opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
total_num_replicas=50)
# Some models have startup_delays to help stabilize the model but when using
# sync_replicas training, set it to 0.
# Now you can call `minimize()` or `compute_gradients()` and
# `apply_gradients()` normally
training_op = opt.minimize(total_loss, global_step=self.global_step)
# You can create the hook which handles initialization and queues.
sync_replicas_hook = opt.make_session_run_hook(is_chief)
在訓練計劃中,每個工人都會像不同步一樣運行train_op。
with training.MonitoredTrainingSession(
master=workers[worker_id].target, is_chief=is_chief,
hooks=[sync_replicas_hook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(training_op)
要將 SyncReplicasOptimizer 與 Estimator
一起使用,您需要在調用 fit 時發送 sync_replicas_hook。
my_estimator = DNNClassifier(..., optimizer=opt)
my_estimator.fit(..., hooks=[sync_replicas_hook])
相關用法
- Python tf.compat.v1.train.Supervisor.managed_session用法及代碼示例
- Python tf.compat.v1.train.Supervisor用法及代碼示例
- Python tf.compat.v1.train.SessionManager用法及代碼示例
- Python tf.compat.v1.train.SingularMonitoredSession用法及代碼示例
- Python tf.compat.v1.train.Saver用法及代碼示例
- Python tf.compat.v1.train.SingularMonitoredSession.run_step_fn用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代碼示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代碼示例
- Python tf.compat.v1.train.Optimizer用法及代碼示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代碼示例
- Python tf.compat.v1.train.Checkpoint用法及代碼示例
- Python tf.compat.v1.train.Checkpoint.restore用法及代碼示例
- Python tf.compat.v1.train.global_step用法及代碼示例
- Python tf.compat.v1.train.MonitoredSession.run_step_fn用法及代碼示例
- Python tf.compat.v1.train.RMSPropOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.train.exponential_decay用法及代碼示例
- Python tf.compat.v1.train.natural_exp_decay用法及代碼示例
- Python tf.compat.v1.train.MomentumOptimizer用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.train.SyncReplicasOptimizer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。