當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.distribute.StrategyExtended.call_for_each_replica用法及代碼示例


用法

call_for_each_replica(
    fn, args=(), kwargs=None
)

參數

  • fn 要運行的函數(每個副本將運行一次)。
  • args 帶有 fn 的位置參數的元組或列表。
  • kwargs 帶有 fn 的關鍵字參數的字典。

返回

  • 合並所有副本的fn 的返回值。

每個副本運行一次fn

fn 可以調用 tf.get_replica_context() 來訪問 replica_id_in_sync_groupmerge_call() 等方法。

merge_call() 用於在副本之間進行通信並重新進入 cross-replica 上下文。所有副本都在遇到merge_call() 調用後暫停執行。之後執行 merge_fn -function。然後將其結果解包並返回給每個副本調用。之後繼續執行,直到 fn 完成或遇到另一個 merge_call() 。例子:

# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.run(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.distribute.StrategyExtended.call_for_each_replica。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。