用法
gather(
value, axis
)
参数
-
value
tf.distribute.DistributedValues
实例,例如由Strategy.run
返回,组合成一个张量。当与tf.distribute.OneDeviceStrategy
或默认策略一起使用时,它也可以是常规张量。构成 DistributedValues 的张量只能是具有非零秩的密集张量,而不是tf.IndexedSlices
。 -
axis
0-D int32 张量。沿其聚集的维度。必须在 [0, rank(value)) 范围内。
返回
-
Tensor
是value
沿axis
维度跨副本的串联。
沿 axis
跨副本收集 value
到当前设备。
给定一个 tf.distribute.DistributedValues
或 tf.Tensor
类对象 value
,此 API 沿 axis
-th 维度跨副本收集并连接 value
。结果被复制到 "current" 设备,该设备通常是运行程序的工作线程的 CPU。对于 tf.distribute.TPUStrategy
,它是第一个 TPU 主机。对于 multi-client tf.distribute.MultiWorkerMirroredStrategy
,这是每个工作人员的 CPU。
此 API 只能在 cross-replica 上下文中调用。对于副本上下文中的对应项,请参阅 tf.distribute.ReplicaContext.all_gather
。
注意:对于除 tf.distribute.TPUStrategy
之外的所有策略,不同副本上的输入 value
必须具有相同的等级,并且它们的形状必须在除 axis
-th 维度之外的所有维度上相同。换句话说,它们的形状在d
维度上不能不同,其中d
不等于axis
参数。例如,给定一个 tf.distribute.DistributedValues
在两个副本上具有形状为 (1, 2, 3)
和 (1, 3, 3)
的分量张量,您可以在其上调用 gather(..., axis=1, ...)
,但不能调用 gather(..., axis=0, ...)
或 gather(..., axis=2, ...)
。但是,对于 tf.distribute.TPUStrategy.gather
,所有张量必须具有完全相同的秩和相同的形状。
注意:给定 tf.distribute.DistributedValues
value
,其分量张量必须具有非零秩。否则,请考虑在收集它们之前使用tf.expand_dims
。
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# A DistributedValues with component tensor of shape (2, 1) on each replica
distributed_values = strategy.experimental_distribute_values_from_function(lambda _:tf.identity(tf.constant([[1], [2]])))
@tf.function
def run():
return strategy.gather(distributed_values, axis=0)
run()
<tf.Tensor:shape=(4, 1), dtype=int32, numpy=
array([[1],
[2],
[1],
[2]], dtype=int32)>
考虑以下示例以获得更多组合:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
distributed_values = strategy.experimental_distribute_values_from_function(lambda _:tf.identity(single_tensor))
@tf.function
def run(axis):
return strategy.gather(distributed_values, axis=axis)
axis=0
run(axis)
<tf.Tensor:shape=(4, 2, 3), dtype=int32, numpy=
array([[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]],
[[0, 1, 2],
[3, 4, 5]]], dtype=int32)>
axis=1
run(axis)
<tf.Tensor:shape=(1, 8, 3), dtype=int32, numpy=
array([[[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]]], dtype=int32)>
axis=2
run(axis)
<tf.Tensor:shape=(1, 2, 12), dtype=int32, numpy=
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
相关用法
- Python tf.distribute.experimental.TPUStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.distribute.experimental.TPUStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.distribute.experimental.TPUStrategy.reduce用法及代码示例
- Python tf.distribute.experimental.TPUStrategy.scope用法及代码示例
- Python tf.distribute.experimental.TPUStrategy用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.gather用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy用法及代码示例
- Python tf.distribute.experimental.rpc.Server.create用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.distribute.experimental.partitioners.Partitioner.__call__用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.run用法及代码示例
- Python tf.distribute.experimental.partitioners.MaxSizePartitioner.__call__用法及代码示例
- Python tf.distribute.experimental.partitioners.FixedShardsPartitioner用法及代码示例
- Python tf.distribute.experimental.ParameterServerStrategy.gather用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.scope用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.reduce用法及代码示例
- Python tf.distribute.experimental.partitioners.MinSizePartitioner用法及代码示例
- Python tf.distribute.experimental.ParameterServerStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.distribute.experimental.CentralStorageStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.distribute.experimental.CentralStorageStrategy用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.experimental.TPUStrategy.gather。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。