当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.experimental.CollectiveHints用法及代码示例


AllReduce 等集体操作的提示。

用法

tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=0, timeout_seconds=None
)

参数

  • bytes_per_pack 一个非负整数。将集体操作分成一定大小的包。如果为零,则自动确定该值。这仅适用于当前带有 MultiWorkerMirroredStrategy 的 all-reduce。
  • timeout_seconds 浮点数或无,以秒为单位超时。如果不是 None,如果花费的时间超过此超时时间,则集体会引发 tf.errors.DeadlineExceededError。这在调试挂起问题时很有用。这应该只用于调试,因为它为每个集合创建一个新线程,即 timeout_seconds * num_collectives_per_second 更多线程的开销。这仅适用于 tf.distribute.experimental.MultiWorkerMirroredStrategy

抛出

  • ValueError 当参数具有无效值时。

这可以传递给 tf.distribute.get_replica_context().all_reduce() 等方法以优化集体操作性能。请注意,这些只是提示,可能会也可能不会改变实际行为。某些选项仅适用于某些策略,而被其他选项忽略。

一种常见的优化是将梯度all-reduce分成多个包,以便权重更新可以与梯度all-reduce重叠。

例子:

  • bytes_per_pack
hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=50 * 1024 * 1024)
grads = tf.distribute.get_replica_context().all_reduce(
    'sum', grads, experimental_hints=hints)
optimizer.apply_gradients(zip(grads, vars),
    experimental_aggregate_gradients=False)
  • timeout_seconds
strategy = tf.distribute.MirroredStrategy()
hints = tf.distribute.experimental.CollectiveHints(
    timeout_seconds=120.0)
try:
  strategy.reduce("sum", v, axis=None, experimental_hints=hints)
except tf.errors.DeadlineExceededError:
  do_something()

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.experimental.CollectiveHints。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。