用法
shard(
num_shards, index, name=None
)
参数
-
num_shards
一个tf.int64
标量tf.Tensor
,表示并行运行的分片数。 -
index
一个tf.int64
标量tf.Tensor
,表示工作人员索引。 -
name
(可选。) tf.data 操作的名称。
返回
-
Dataset
一个Dataset
。
抛出
-
InvalidArgumentError
如果num_shards
或者index
是非法值。注意:错误检查是尽最大努力完成的,并且不能保证在创建数据集时会发现错误。 (例如,提供占位符张量绕过了早期检查,而是在 session.run 调用期间导致错误。)
创建一个仅包含此数据集的 1/num_shards
的 Dataset
。
shard
是确定性的。 A.shard(n, i)
生成的数据集将包含索引 mod n = i 的 A 的所有元素。
A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0, 3, 6, 9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1, 4, 7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2, 5, 8]
此数据集运算符在运行分布式训练时非常有用,因为它允许每个工作人员读取唯一的子集。
读取单个输入文件时,您可以按如下方式对元素进行分片:
d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)
重要警告:
- 确保在使用任何随机操作符(例如 shuffle)之前进行分片。
- 通常,最好在数据集管道的早期使用分片运算符。例如,当从一组 TFRecord 文件中读取数据时,在将数据集转换为输入样本之前进行分片。这避免了读取每个工作人员的每个文件。以下是完整管道中高效分片策略的示例:
d = Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)
相关用法
- Python tf.data.TFRecordDataset.shuffle用法及代码示例
- Python tf.data.TFRecordDataset.scan用法及代码示例
- Python tf.data.TFRecordDataset.snapshot用法及代码示例
- Python tf.data.TFRecordDataset.sample_from_datasets用法及代码示例
- Python tf.data.TFRecordDataset.skip用法及代码示例
- Python tf.data.TFRecordDataset.filter用法及代码示例
- Python tf.data.TFRecordDataset.random用法及代码示例
- Python tf.data.TFRecordDataset.zip用法及代码示例
- Python tf.data.TFRecordDataset.choose_from_datasets用法及代码示例
- Python tf.data.TFRecordDataset.apply用法及代码示例
- Python tf.data.TFRecordDataset.rejection_resample用法及代码示例
- Python tf.data.TFRecordDataset.flat_map用法及代码示例
- Python tf.data.TFRecordDataset.unique用法及代码示例
- Python tf.data.TFRecordDataset.cardinality用法及代码示例
- Python tf.data.TFRecordDataset.group_by_window用法及代码示例
- Python tf.data.TFRecordDataset.cache用法及代码示例
- Python tf.data.TFRecordDataset.range用法及代码示例
- Python tf.data.TFRecordDataset.reduce用法及代码示例
- Python tf.data.TFRecordDataset.take_while用法及代码示例
- Python tf.data.TFRecordDataset.with_options用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.TFRecordDataset.shard。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。