當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.distribute.TPUStrategy.run用法及代碼示例


用法

run(
    fn, args=(), kwargs=None, options=None
)

參數

  • fn 要運行的函數。輸出必須是 Tensortf.nest
  • args (可選)fn 的位置參數。
  • kwargs (可選)fn 的關鍵字參數。
  • options (可選)tf.distribute.RunOptions 的實例,指定運行 fn 的選項。

返回

  • 跨副本合並fn 的返回值。返回值的結構與 fn 的返回值相同。結構中的每個元素都可以是 tf.distribute.DistributedValuesTensor 對象或 Tensor s(例如,如果在單個副本上運行)。

在每個 TPU 副本上運行由 fn 定義的計算。

在每個副本上執行 fn 指定的操作。如果 argskwargs 具有 tf.distribute.DistributedValues ,例如由 tf.distribute.Strategy.experimental_distribute_datasettf.distribute.Strategy.distribute_datasets_from_function 中的 tf.distribute.DistributedDataset 生成的那些,當 fn 在特定副本上執行時,它將使用以下組件執行tf.distribute.DistributedValues 對應於該副本。

fn 可以調用 tf.distribute.get_replica_context() 來訪問成員,例如 all_reduce

argskwargs 中的所有參數應該是張量嵌套或包含張量或複合張量的 tf.distribute.DistributedValues

示例用法:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
@tf.function
def run():
  def value_fn(value_context):
    return value_context.num_replicas_in_sync
  distributed_values = (
      strategy.experimental_distribute_values_from_function(value_fn))
  def replica_fn(input):
    return input * 2
  return strategy.run(replica_fn, args=(distributed_values,))
result = run()

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.distribute.TPUStrategy.run。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。