当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.TPUStrategy.gather用法及代码示例


用法

gather(
    value, axis
)

参数

返回

  • Tensorvalue 沿 axis 维度跨副本的串联。

沿 axis 跨副本收集 value 到当前设备。

给定一个 tf.distribute.DistributedValuestf.Tensor 类对象 value ,此 API 沿 axis -th 维度跨副本收集并连接 value。结果被复制到 "current" 设备,该设备通常是运行程序的工作线程的 CPU。对于 tf.distribute.TPUStrategy ,它是第一个 TPU 主机。对于 multi-client tf.distribute.MultiWorkerMirroredStrategy ,这是每个工作人员的 CPU。

此 API 只能在 cross-replica 上下文中调用。对于副本上下文中的对应项,请参阅 tf.distribute.ReplicaContext.all_gather

注意:对于除 tf.distribute.TPUStrategy 之外的所有策略,不同副本上的输入 value 必须具有相同的等级,并且它们的形状必须在除 axis -th 维度之外的所有维度上相同。换句话说,它们的形状在d 维度上不能不同,其中d 不等于axis 参数。例如,给定一个 tf.distribute.DistributedValues 在两个副本上具有形状为 (1, 2, 3)(1, 3, 3) 的分量张量,您可以在其上调用 gather(..., axis=1, ...) ,但不能调用 gather(..., axis=0, ...)gather(..., axis=2, ...) 。但是,对于 tf.distribute.TPUStrategy.gather ,所有张量必须具有完全相同的秩和相同的形状。

注意:给定 tf.distribute.DistributedValues value ,其分量张量必须具有非零秩。否则,请考虑在收集它们之前使用tf.expand_dims

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# A DistributedValues with component tensor of shape (2, 1) on each replica
distributed_values = strategy.experimental_distribute_values_from_function(lambda _:tf.identity(tf.constant([[1], [2]])))
@tf.function
def run():
  return strategy.gather(distributed_values, axis=0)
run()
<tf.Tensor:shape=(4, 1), dtype=int32, numpy=
array([[1],
       [2],
       [1],
       [2]], dtype=int32)>

考虑以下示例以获得更多组合:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
distributed_values = strategy.experimental_distribute_values_from_function(lambda _:tf.identity(single_tensor))
@tf.function
def run(axis):
  return strategy.gather(distributed_values, axis=axis)
axis=0
run(axis)
<tf.Tensor:shape=(4, 2, 3), dtype=int32, numpy=
array([[[0, 1, 2],
        [3, 4, 5]],
       [[0, 1, 2],
        [3, 4, 5]],
       [[0, 1, 2],
        [3, 4, 5]],
       [[0, 1, 2],
        [3, 4, 5]]], dtype=int32)>
axis=1
run(axis)
<tf.Tensor:shape=(1, 8, 3), dtype=int32, numpy=
array([[[0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5]]], dtype=int32)>
axis=2
run(axis)
<tf.Tensor:shape=(1, 2, 12), dtype=int32, numpy=
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.TPUStrategy.gather。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。