當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.distribute.experimental.CollectiveHints用法及代碼示例


AllReduce 等集體操作的提示。

用法

tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=0, timeout_seconds=None
)

參數

  • bytes_per_pack 一個非負整數。將集體操作分成一定大小的包。如果為零,則自動確定該值。這僅適用於當前帶有 MultiWorkerMirroredStrategy 的 all-reduce。
  • timeout_seconds 浮點數或無,以秒為單位超時。如果不是 None,如果花費的時間超過此超時時間,則集體會引發 tf.errors.DeadlineExceededError。這在調試掛起問題時很有用。這應該隻用於調試,因為它為每個集合創建一個新線程,即 timeout_seconds * num_collectives_per_second 更多線程的開銷。這僅適用於 tf.distribute.experimental.MultiWorkerMirroredStrategy

拋出

  • ValueError 當參數具有無效值時。

這可以傳遞給 tf.distribute.get_replica_context().all_reduce() 等方法以優化集體操作性能。請注意,這些隻是提示,可能會也可能不會改變實際行為。某些選項僅適用於某些策略,而被其他選項忽略。

一種常見的優化是將梯度all-reduce分成多個包,以便權重更新可以與梯度all-reduce重疊。

例子:

  • bytes_per_pack
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=50 * 1024 * 1024)
grads = tf.distribute.get_replica_context().all_reduce(
    'sum', grads, experimental_hints=hints)
optimizer.apply_gradients(zip(grads, vars),
    experimental_aggregate_gradients=False)
  • timeout_seconds
strategy = tf.distribute.MirroredStrategy()
hints = tf.distribute.experimental.CollectiveHints(
    timeout_seconds=120.0)
try:
  strategy.reduce("sum", v, axis=None, experimental_hints=hints)
except tf.errors.DeadlineExceededError:
  do_something()

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.distribute.experimental.CollectiveHints。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。