当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.distribute.StrategyExtended.call_for_each_replica用法及代码示例


用法

call_for_each_replica(
    fn, args=(), kwargs=None
)

参数

  • fn 要运行的函数(每个副本将运行一次)。
  • args 带有 fn 的位置参数的元组或列表。
  • kwargs 带有 fn 的关键字参数的字典。

返回

  • 合并所有副本的fn 的返回值。

每个副本运行一次fn

fn 可以调用 tf.get_replica_context() 来访问 replica_id_in_sync_groupmerge_call() 等方法。

merge_call() 用于在副本之间进行通信并重新进入 cross-replica 上下文。所有副本都在遇到merge_call() 调用后暂停执行。之后执行 merge_fn -function。然后将其结果解包并返回给每个副本调用。之后继续执行,直到 fn 完成或遇到另一个 merge_call() 。例子:

# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.run(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.distribute.StrategyExtended.call_for_each_replica。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。