用法
shard(
num_shards, index, name=None
)
參數
-
num_shards
一個tf.int64
標量tf.Tensor
,表示並行運行的分片數。 -
index
一個tf.int64
標量tf.Tensor
,表示工作人員索引。 -
name
(可選。) tf.data 操作的名稱。
返回
-
Dataset
一個Dataset
。
拋出
-
InvalidArgumentError
如果num_shards
或者index
是非法值。注意:錯誤檢查是盡最大努力完成的,並且不能保證在創建數據集時會發現錯誤。 (例如,提供占位符張量繞過了早期檢查,而是在 session.run 調用期間導致錯誤。)
創建一個僅包含此數據集的 1/num_shards
的 Dataset
。
shard
是確定性的。 A.shard(n, i)
生成的數據集將包含索引 mod n = i 的 A 的所有元素。
A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0, 3, 6, 9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1, 4, 7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2, 5, 8]
此數據集運算符在運行分布式訓練時非常有用,因為它允許每個工作人員讀取唯一的子集。
讀取單個輸入文件時,您可以按如下方式對元素進行分片:
d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)
重要警告:
- 確保在使用任何隨機操作符(例如 shuffle)之前進行分片。
- 通常,最好在數據集管道的早期使用分片運算符。例如,當從一組 TFRecord 文件中讀取數據時,在將數據集轉換為輸入樣本之前進行分片。這避免了讀取每個工作人員的每個文件。以下是完整管道中高效分片策略的示例:
d = Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)
相關用法
- Python tf.data.experimental.SqlDataset.shuffle用法及代碼示例
- Python tf.data.experimental.SqlDataset.snapshot用法及代碼示例
- Python tf.data.experimental.SqlDataset.skip用法及代碼示例
- Python tf.data.experimental.SqlDataset.scan用法及代碼示例
- Python tf.data.experimental.SqlDataset.sample_from_datasets用法及代碼示例
- Python tf.data.experimental.SqlDataset.enumerate用法及代碼示例
- Python tf.data.experimental.SqlDataset.zip用法及代碼示例
- Python tf.data.experimental.SqlDataset.get_single_element用法及代碼示例
- Python tf.data.experimental.SqlDataset.take用法及代碼示例
- Python tf.data.experimental.SqlDataset.random用法及代碼示例
- Python tf.data.experimental.SqlDataset.concatenate用法及代碼示例
- Python tf.data.experimental.SqlDataset.range用法及代碼示例
- Python tf.data.experimental.SqlDataset.from_tensor_slices用法及代碼示例
- Python tf.data.experimental.SqlDataset.from_generator用法及代碼示例
- Python tf.data.experimental.SqlDataset.rejection_resample用法及代碼示例
- Python tf.data.experimental.SqlDataset.from_tensors用法及代碼示例
- Python tf.data.experimental.SqlDataset.group_by_window用法及代碼示例
- Python tf.data.experimental.SqlDataset.unique用法及代碼示例
- Python tf.data.experimental.SqlDataset.bucket_by_sequence_length用法及代碼示例
- Python tf.data.experimental.SqlDataset.prefetch用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.SqlDataset.shard。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。