用法
run(
fn, args=(), kwargs=None, options=None
)
参数
-
fn
在每个副本上运行的函数。 -
args
fn
的可选位置参数。它的元素可以是张量、张量的嵌套结构或tf.distribute.DistributedValues
。 -
kwargs
fn
的可选关键字参数。它的元素可以是张量、张量的嵌套结构或tf.distribute.DistributedValues
。 -
options
tf.distribute.RunOptions
的可选实例,指定运行fn
的选项。
返回
-
跨副本合并
fn
的返回值。返回值的结构与fn
的返回值相同。结构中的每个元素都可以是tf.distribute.DistributedValues
、Tensor
对象或Tensor
s(例如,如果在单个副本上运行)。
使用给定的参数在每个副本上调用 fn
。
此方法是使用 tf.distribute 对象分配计算的主要方法。它在每个副本上调用fn
。如果 args
或 kwargs
具有 tf.distribute.DistributedValues
,例如由 tf.distribute.Strategy.experimental_distribute_dataset
或 tf.distribute.Strategy.distribute_datasets_from_function
中的 tf.distribute.DistributedDataset
生成的那些,当 fn
在特定副本上执行时,它将使用以下组件执行tf.distribute.DistributedValues
对应于该副本。
fn
在副本上下文中调用。 fn
可以调用 tf.distribute.get_replica_context()
来访问成员,例如 all_reduce
。有关副本上下文的概念,请参阅 tf.distribute 的 module-level 文档字符串。
args
或 kwargs
中的所有参数都可以是张量的嵌套结构,例如张量列表,在这种情况下 args
和 kwargs
将传递给每个副本上调用的 fn
。或者 args
或 kwargs
可以是包含张量或复合张量的 tf.distribute.DistributedValues
,即 tf.compat.v1.TensorInfo.CompositeTensor
,在这种情况下,每个 fn
调用都将获得与其副本相对应的 tf.distribute.DistributedValues
的组件。请注意,不支持上述类型的任意 Python 值。
重要的:根据tf.distribute.Strategy
的实现以及是否启用了即刻执行,fn
可能会被调用一次或多次。如果 fn
使用 tf.function
注释或在 tf.function
内调用 tf.distribute.Strategy.run
(默认情况下,在 tf.function
内禁用即刻执行),则每个副本调用一次 fn
以生成 Tensorflow 图,然后将被重用于新输入的执行。否则,如果启用了即刻执行,fn
将在每个副本的每个步骤中调用一次,就像常规 python 代码一样。
示例用法:
- 恒定张量输入。
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
tensor_input = tf.constant(3.0)
@tf.function
def replica_fn(input):
return input*2.0
result = strategy.run(replica_fn, args=(tensor_input,))
result
PerReplica:{
0:<tf.Tensor:shape=(), dtype=float32, numpy=6.0>,
1:<tf.Tensor:shape=(), dtype=float32, numpy=6.0>
}
- 分布式值输入。
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
def value_fn(value_context):
return value_context.num_replicas_in_sync
distributed_values = (
strategy.experimental_distribute_values_from_function(
value_fn))
def replica_fn2(input):
return input*2
return strategy.run(replica_fn2, args=(distributed_values,))
result = run()
result
<tf.Tensor:shape=(), dtype=int32, numpy=4>
- 使用
tf.distribute.ReplicaContext
来减少所有值。
strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
@tf.function
def run():
def value_fn(value_context):
return tf.constant(value_context.replica_id_in_sync_group)
distributed_values = (
strategy.experimental_distribute_values_from_function(
value_fn))
def replica_fn(input):
return tf.distribute.get_replica_context().all_reduce("sum", input)
return strategy.run(replica_fn, args=(distributed_values,))
result = run()
result
PerReplica:{
0:<tf.Tensor:shape=(), dtype=int32, numpy=1>,
1:<tf.Tensor:shape=(), dtype=int32, numpy=1>
}
相关用法
- Python tf.compat.v1.distribute.OneDeviceStrategy.reduce用法及代码示例
- Python tf.compat.v1.distribute.OneDeviceStrategy.scope用法及代码示例
- Python tf.compat.v1.distribute.OneDeviceStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.distribute.OneDeviceStrategy.make_input_fn_iterator用法及代码示例
- Python tf.compat.v1.distribute.OneDeviceStrategy.experimental_make_numpy_dataset用法及代码示例
- Python tf.compat.v1.distribute.OneDeviceStrategy用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.distribute.Strategy.run用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_make_numpy_dataset用法及代码示例
- Python tf.compat.v1.distribute.experimental.TPUStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.distribute.StrategyExtended.batch_reduce_to用法及代码示例
- Python tf.compat.v1.distribute.experimental.TPUStrategy.experimental_make_numpy_dataset用法及代码示例
- Python tf.compat.v1.distribute.experimental.CentralStorageStrategy.make_input_fn_iterator用法及代码示例
- Python tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy.reduce用法及代码示例
- Python tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy.experimental_make_numpy_dataset用法及代码示例
- Python tf.compat.v1.distribute.Strategy.experimental_make_numpy_dataset用法及代码示例
- Python tf.compat.v1.distribute.StrategyExtended.colocate_vars_with用法及代码示例
- Python tf.compat.v1.distribute.experimental.CentralStorageStrategy用法及代码示例
- Python tf.compat.v1.distribute.ReplicaContext用法及代码示例
- Python tf.compat.v1.distribute.experimental.TPUStrategy.scope用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.distribute.OneDeviceStrategy.run。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。