当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.FixedLengthRecordDataset.snapshot用法及代码示例


用法

snapshot(
    path, compression='AUTO', reader_func=None, shard_func=None, name=None
)

参数

  • path 必需的。用于将快照存储/加载到/从的目录。
  • compression 可选的。应用于写入磁盘的快照的压缩类型。支持的选项是GZIP , SNAPPY , AUTO 或无。默认为 AUTO ,它尝试为数据集选择适当的压缩算法。
  • reader_func 可选的。控制如何从快照分片中读取数据的函数。
  • shard_func 可选的。控制在写入快照时如何分片数据的函数。
  • name (可选。) tf.data 操作的名称。

返回

  • 一个Dataset

用于持久化输入数据集输出的 API。

快照 API 允许用户透明地将其预处理管道的输出保存到磁盘,并在不同的训练运行中实现预处理数据。

此 API 可以整合重复的预处理步骤,并允许重新使用已处理的数据,以牺牲磁盘存储和网络带宽来释放更多宝贵的 CPU 资源和加速器计算时间。

https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md具有此函数的详细设计文档。

用户可以指定各种选项来控制快照的行为,包括通过将用户定义的函数传递给reader_funcshard_func 参数来读取和写入快照的方式。

shard_func 是用户指定的函数,用于将输入元素映射到快照分片。

用户可能希望指定此函数来控制快照文件应如何写入磁盘。下面是如何编写潜在的shard_func 的示例。

dataset = ...
dataset = dataset.enumerate()
dataset = dataset.snapshot("/path/to/snapshot/dir",
    shard_func=lambda x, y:x % NUM_SHARDS, ...)
dataset = dataset.map(lambda x, y:y)

reader_func 是用户指定的函数,它接受单个参数:(1) 数据集的数据集,每个数据集代表原始数据集元素的 "split"。输入数据集的基数与 shard_func 中指定的分片数量相匹配(见上文)。该函数应返回原始数据集元素的数据集。

用户可能希望指定此函数来控制应如何从磁盘读取快照文件,包括混洗量和并行度。

这是用户可以定义的标准阅读器函数的示例。此函数启用数据集洗牌和数据集的并行读取:

def user_reader_func(datasets):
  # shuffle the datasets splits
  datasets = datasets.shuffle(NUM_CORES)
  # read datasets in parallel and interleave their elements
  return datasets.interleave(lambda x:x, num_parallel_calls=AUTOTUNE)

dataset = dataset.snapshot("/path/to/snapshot/dir",
    reader_func=user_reader_func)

默认情况下,快照按系统上可用的核心数并行读取,但不会尝试对数据进行混洗。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.FixedLengthRecordDataset.snapshot。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。