当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.TFRecordDataset.shard用法及代码示例


用法

shard(
    num_shards, index, name=None
)

参数

  • num_shards 一个 tf.int64 标量 tf.Tensor ,表示并行运行的分片数。
  • index 一个 tf.int64 标量 tf.Tensor ,表示工作人员索引。
  • name (可选。) tf.data 操作的名称。

返回

  • Dataset 一个Dataset

抛出

  • InvalidArgumentError 如果num_shards或者index是非法值。

    注意:错误检查是尽最大努力完成的,并且不能保证在创建数据集时会发现错误。 (例如,提供占位符张量绕过了早期检查,而是在 session.run 调用期间导致错误。)

创建一个仅包含此数据集的 1/num_shardsDataset

shard 是确定性的。 A.shard(n, i) 生成的数据集将包含索引 mod n = i 的 A 的所有元素。

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0, 3, 6, 9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1, 4, 7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2, 5, 8]

此数据集运算符在运行分布式训练时非常有用,因为它允许每个工作人员读取唯一的子集。

读取单个输入文件时,您可以按如下方式对元素进行分片:

d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

重要警告:

  • 确保在使用任何随机操作符(例如 shuffle)之前进行分片。
  • 通常,最好在数据集管道的早期使用分片运算符。例如,当从一组 TFRecord 文件中读取数据时,在将数据集转换为输入样本之前进行分片。这避免了读取每个工作人员的每个文件。以下是完整管道中高效分片策略的示例:
d = Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
                 cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.TFRecordDataset.shard。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。