用法
run(
fn, args=(), kwargs=None, options=None
)
參數
-
fn
在每個副本上運行的函數。 -
args
fn
的可選位置參數。它的元素可以是張量、張量的嵌套結構或tf.distribute.DistributedValues
。 -
kwargs
fn
的可選關鍵字參數。它的元素可以是張量、張量的嵌套結構或tf.distribute.DistributedValues
。 -
options
tf.distribute.RunOptions
的可選實例,指定運行fn
的選項。
返回
-
跨副本合並
fn
的返回值。返回值的結構與fn
的返回值相同。結構中的每個元素都可以是tf.distribute.DistributedValues
、Tensor
對象或Tensor
s(例如,如果在單個副本上運行)。
使用給定的參數在每個副本上調用 fn
。
此方法是使用 tf.distribute 對象分配計算的主要方法。它在每個副本上調用fn
。如果 args
或 kwargs
具有 tf.distribute.DistributedValues
,例如由 tf.distribute.Strategy.experimental_distribute_dataset
或 tf.distribute.Strategy.distribute_datasets_from_function
中的 tf.distribute.DistributedDataset
生成的那些,當 fn
在特定副本上執行時,它將使用以下組件執行tf.distribute.DistributedValues
對應於該副本。
fn
在副本上下文中調用。 fn
可以調用 tf.distribute.get_replica_context()
來訪問成員,例如 all_reduce
。有關副本上下文的概念,請參閱 tf.distribute 的 module-level 文檔字符串。
args
或 kwargs
中的所有參數都可以是張量的嵌套結構,例如張量列表,在這種情況下 args
和 kwargs
將傳遞給每個副本上調用的 fn
。或者 args
或 kwargs
可以是包含張量或複合張量的 tf.distribute.DistributedValues
,即 tf.compat.v1.TensorInfo.CompositeTensor
,在這種情況下,每個 fn
調用都將獲得與其副本相對應的 tf.distribute.DistributedValues
的組件。請注意,不支持上述類型的任意 Python 值。
重要的:根據tf.distribute.Strategy
的實現以及是否啟用了即刻執行,fn
可能會被調用一次或多次。如果 fn
使用 tf.function
注釋或在 tf.function
內調用 tf.distribute.Strategy.run
(默認情況下,在 tf.function
內禁用即刻執行),則每個副本調用一次 fn
以生成 Tensorflow 圖,然後將被重用於新輸入的執行。否則,如果啟用了即刻執行,fn
將在每個副本的每個步驟中調用一次,就像常規 python 代碼一樣。
示例用法:
- 恒定張量輸入。
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
tensor_input = tf.constant(3.0)
@tf.function
def replica_fn(input):
return input*2.0
result = strategy.run(replica_fn, args=(tensor_input,))
result
PerReplica:{
0:<tf.Tensor:shape=(), dtype=float32, numpy=6.0>,
1:<tf.Tensor:shape=(), dtype=float32, numpy=6.0>
}
- 分布式值輸入。
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
def value_fn(value_context):
return value_context.num_replicas_in_sync
distributed_values = (
strategy.experimental_distribute_values_from_function(
value_fn))
def replica_fn2(input):
return input*2
return strategy.run(replica_fn2, args=(distributed_values,))
result = run()
result
<tf.Tensor:shape=(), dtype=int32, numpy=4>
- 使用
tf.distribute.ReplicaContext
來減少所有值。
strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
@tf.function
def run():
def value_fn(value_context):
return tf.constant(value_context.replica_id_in_sync_group)
distributed_values = (
strategy.experimental_distribute_values_from_function(
value_fn))
def replica_fn(input):
return tf.distribute.get_replica_context().all_reduce("sum", input)
return strategy.run(replica_fn, args=(distributed_values,))
result = run()
result
PerReplica:{
0:<tf.Tensor:shape=(), dtype=int32, numpy=1>,
1:<tf.Tensor:shape=(), dtype=int32, numpy=1>
}
相關用法
- Python tf.compat.v1.distribute.Strategy.reduce用法及代碼示例
- Python tf.compat.v1.distribute.Strategy.experimental_make_numpy_dataset用法及代碼示例
- Python tf.compat.v1.distribute.Strategy.make_input_fn_iterator用法及代碼示例
- Python tf.compat.v1.distribute.Strategy.scope用法及代碼示例
- Python tf.compat.v1.distribute.Strategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.batch_reduce_to用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.colocate_vars_with用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.non_slot_devices用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.update用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.reduce_to用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.call_for_each_replica用法及代碼示例
- Python tf.compat.v1.distribute.Strategy用法及代碼示例
- Python tf.compat.v1.distribute.StrategyExtended.variable_created_in_scope用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.distribute.OneDeviceStrategy用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_make_numpy_dataset用法及代碼示例
- Python tf.compat.v1.distribute.experimental.TPUStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.distribute.OneDeviceStrategy.scope用法及代碼示例
- Python tf.compat.v1.distribute.experimental.TPUStrategy.experimental_make_numpy_dataset用法及代碼示例
- Python tf.compat.v1.distribute.OneDeviceStrategy.experimental_distribute_dataset用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.distribute.Strategy.run。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。