當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.distribute.OneDeviceStrategy.experimental_distribute_dataset用法及代碼示例


用法

experimental_distribute_dataset(
    dataset, options=None
)

參數

返回

tf.data.Dataset 創建 tf.distribute.DistributedDataset

返回的tf.distribute.DistributedDataset 可以類似於常規數據集進行迭代。注意:用戶不能向 tf.distribute.DistributedDataset 添加任何更多的轉換。您隻能創建一個迭代器或檢查它生成的數據的tf.TypeSpec。請參閱tf.distribute.DistributedDataset 的 API 文檔以了解更多信息。

下麵是一個例子:

global_batch_size = 2
# Passing the devices is optional.
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
# Create a dataset
dataset = tf.data.Dataset.range(4).batch(global_batch_size)
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
@tf.function
def replica_fn(input):
  return input*2
result = []
# Iterate over the `tf.distribute.DistributedDataset`
for x in dist_dataset:
  # process dataset elements
  result.append(strategy.run(replica_fn, args=(x,)))
print(result)
[PerReplica:{
  0:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([0])>,
  1:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
  0:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([4])>,
  1:<tf.Tensor:shape=(1,), dtype=int64, numpy=array([6])>
}]

此方法背後發生的三個關鍵操作是批處理、分片和預取。

在上麵的代碼片段中,datasetglobal_batch_size 批處理,並在其上調用 experimental_distribute_datasetdataset 重新批處理為等於全局批處理大小除以同步副本數的新批處理大小。我們使用 Pythonic for 循環遍曆它。 x 是一個 tf.distribute.DistributedValues 包含所有副本的數據,每個副本獲取新批量大小的數據。 tf.distribute.Strategy.run 將負責將 x 中的正確 per-replica 數據饋送到在每個副本上執行的右側 replica_fn

分片包含跨多個工作人員和每個工作人員內部的自動分片。首先,在multi-worker 分布式訓練中(即當您使用tf.distribute.experimental.MultiWorkerMirroredStrategytf.distribute.TPUStrategy 時),在一組workers 上自動分片數據集意味著每個worker 被分配整個數據集的一個子集(如果正確的tf.data.experimental.AutoShardPolicy 是放)。這是為了確保在每個步驟中,每個工作人員將處理非重疊數據集元素的全局批量大小。自動分片有幾個不同的選項,可以使用 tf.data.experimental.DistributeOptions 指定。然後,在每個工作人員內進行分片意味著該方法將在所有工作人員設備之間拆分數據(如果存在多個)。無論multi-worker 自動分片如何,都會發生這種情況。

注意:對於跨多個工作人員的自動分片,默認模式是 tf.data.experimental.AutoShardPolicy.AUTO 。如果數據集是從讀取器數據集(例如 tf.data.TFRecordDatasettf.data.TextLineDataset 等)創建的,則此模式將嘗試按文件對輸入數據集進行分片,或者按數據對數據集進行分片,其中每個工作人員將讀取整個數據集,並且隻處理分配給它的分片。但是,如果每個工作人員的輸入文件少於一個,我們建議您通過將 tf.data.experimental.DistributeOptions.auto_shard_policy 設置為 tf.data.experimental.AutoShardPolicy.OFF 來禁用跨工作人員的數據集自動分片。

默認情況下,此方法在用戶提供的tf.data.Dataset 實例的末尾添加預取轉換。預取轉換的參數 buffer_size 等於同步的副本數。

如果上述批量拆分和數據集分片邏輯不可取,請使用tf.distribute.Strategy.distribute_datasets_from_function相反,它不會為您執行任何自動批處理或分片。

注意:如果您使用的是 TPUStrategy,則使用時工作人員處理數據的順序tf.distribute.Strategy.experimental_distribute_dataset或者tf.distribute.Strategy.distribute_datasets_from_function不保證。如果您正在使用,這通常是必需的tf.distribute規模預測。但是,您可以為批處理中的每個元素插入索引並相應地對輸出進行排序。參考這個片段有關如何排序輸出的示例。

注意:tf.distribute.experimental_distribute_datasettf.distribute.distribute_datasets_from_function 目前不支持有狀態數據集轉換。當前忽略數據集可能具有的任何有狀態操作。例如,如果您的數據集有一個使用tf.random.uniform 旋轉圖像的map_fn,那麽您有一個依賴於執行python 進程的本地機器上的狀態(即隨機種子)的數據集圖。

有關此方法的更多用法和屬性的教程,請參閱分布式輸入教程。如果您對最後部分批處理感興趣,請閱讀本節。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.distribute.OneDeviceStrategy.experimental_distribute_dataset。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。