當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.distribute.TPUStrategy.gather用法及代碼示例


用法

gather(
    value, axis
)

參數

返回

  • Tensorvalue 沿 axis 維度跨副本的串聯。

沿 axis 跨副本收集 value 到當前設備。

給定一個 tf.distribute.DistributedValuestf.Tensor 類對象 value ,此 API 沿 axis -th 維度跨副本收集並連接 value。結果被複製到 "current" 設備,該設備通常是運行程序的工作線程的 CPU。對於 tf.distribute.TPUStrategy ,它是第一個 TPU 主機。對於 multi-client tf.distribute.MultiWorkerMirroredStrategy ,這是每個工作人員的 CPU。

此 API 隻能在 cross-replica 上下文中調用。對於副本上下文中的對應項,請參閱 tf.distribute.ReplicaContext.all_gather

注意:對於除 tf.distribute.TPUStrategy 之外的所有策略,不同副本上的輸入 value 必須具有相同的等級,並且它們的形狀必須在除 axis -th 維度之外的所有維度上相同。換句話說,它們的形狀在d 維度上不能不同,其中d 不等於axis 參數。例如,給定一個 tf.distribute.DistributedValues 在兩個副本上具有形狀為 (1, 2, 3)(1, 3, 3) 的分量張量,您可以在其上調用 gather(..., axis=1, ...) ,但不能調用 gather(..., axis=0, ...)gather(..., axis=2, ...) 。但是,對於 tf.distribute.TPUStrategy.gather ,所有張量必須具有完全相同的秩和相同的形狀。

注意:給定 tf.distribute.DistributedValues value ,其分量張量必須具有非零秩。否則,請考慮在收集它們之前使用tf.expand_dims

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# A DistributedValues with component tensor of shape (2, 1) on each replica
distributed_values = strategy.experimental_distribute_values_from_function(lambda _:tf.identity(tf.constant([[1], [2]])))
@tf.function
def run():
  return strategy.gather(distributed_values, axis=0)
run()
<tf.Tensor:shape=(4, 1), dtype=int32, numpy=
array([[1],
       [2],
       [1],
       [2]], dtype=int32)>

考慮以下示例以獲得更多組合:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
distributed_values = strategy.experimental_distribute_values_from_function(lambda _:tf.identity(single_tensor))
@tf.function
def run(axis):
  return strategy.gather(distributed_values, axis=axis)
axis=0
run(axis)
<tf.Tensor:shape=(4, 2, 3), dtype=int32, numpy=
array([[[0, 1, 2],
        [3, 4, 5]],
       [[0, 1, 2],
        [3, 4, 5]],
       [[0, 1, 2],
        [3, 4, 5]],
       [[0, 1, 2],
        [3, 4, 5]]], dtype=int32)>
axis=1
run(axis)
<tf.Tensor:shape=(1, 8, 3), dtype=int32, numpy=
array([[[0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5]]], dtype=int32)>
axis=2
run(axis)
<tf.Tensor:shape=(1, 2, 12), dtype=int32, numpy=
array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.distribute.TPUStrategy.gather。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。