当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.ReplicaContext.merge_call用法及代码示例


用法

merge_call(
      merge_fn, args=(), kwargs=None
  )

参数

  • `merge_fn` 连接来自作为 PerReplica 给出的线程的参数的函数。它接受tf.distribute.Strategy 对象作为第一个参数。
  • `args` `merge_fn` 的带有位置 per-thread 参数的列表或元组。
  • `kwargs` `merge_fn` 的关键字 per-thread 参数的字典。

返回

  • `merge_fn` 的返回值,`PerReplica` 值除外,它是解包的。

跨副本合并 args 并在 cross-replica 上下文中运行 merge_fn

当对 strategy.run(step_fn, ...) 的调用触发对 step_fn 的多个调用时,这允许进行通信和协调。

有关说明,请参见tf.distribute.Strategy.run

如果不在分布式范围内,则相当于:

strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
  return merge_fn(strategy, *args, **kwargs)
```

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.ReplicaContext.merge_call。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。