當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.data.experimental.RandomDataset.shard用法及代碼示例


用法

shard(
    num_shards, index, name=None
)

參數

  • num_shards 一個 tf.int64 標量 tf.Tensor ,表示並行運行的分片數。
  • index 一個 tf.int64 標量 tf.Tensor ,表示工作人員索引。
  • name (可選。) tf.data 操作的名稱。

返回

  • Dataset 一個Dataset

拋出

  • InvalidArgumentError 如果num_shards或者index是非法值。

    注意:錯誤檢查是盡最大努力完成的,並且不能保證在創建數據集時會發現錯誤。 (例如,提供占位符張量繞過了早期檢查,而是在 session.run 調用期間導致錯誤。)

創建一個僅包含此數據集的 1/num_shardsDataset

shard 是確定性的。 A.shard(n, i) 生成的數據集將包含索引 mod n = i 的 A 的所有元素。

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0, 3, 6, 9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1, 4, 7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2, 5, 8]

此數據集運算符在運行分布式訓練時非常有用,因為它允許每個工作人員讀取唯一的子集。

讀取單個輸入文件時,您可以按如下方式對元素進行分片:

d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

重要警告:

  • 確保在使用任何隨機操作符(例如 shuffle)之前進行分片。
  • 通常,最好在數據集管道的早期使用分片運算符。例如,當從一組 TFRecord 文件中讀取數據時,在將數據集轉換為輸入樣本之前進行分片。這避免了讀取每個工作人員的每個文件。以下是完整管道中高效分片策略的示例:
d = Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
                 cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.RandomDataset.shard。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。