当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.experimental.save用法及代码示例


保存给定数据集的内容。

用法

tf.data.experimental.save(
    dataset, path, compression=None, shard_func=None, checkpoint_args=None
)

参数

  • dataset 要保存的数据集。
  • path 必需的。用于保存数据集的目录。
  • compression 可选的。写入数据时用于压缩数据的算法。支持的选项是 GZIPNONE 。默认为 NONE
  • shard_func 可选的。控制数据集元素到文件分片的映射的函数。该函数预计将输入数据集的元素映射到 int64 分片 ID。如果存在,该函数将被跟踪并作为图形计算执行。
  • checkpoint_args 检查点的可选参数将被传递到 tf.train.CheckpointManager 。如果未指定checkpoint_args,则不会执行检查点。 save() 实现在内部创建 tf.train.Checkpoint 对象,因此用户不应在 checkpoint_args 中设置 checkpoint 参数。

抛出

  • ValueError 如果 checkpoint 被传递到 checkpoint_args

示例用法:

import tempfile
path = os.path.join(tempfile.gettempdir(), "saved_data")
# Save a dataset
dataset = tf.data.Dataset.range(2)
tf.data.experimental.save(dataset, path)
new_dataset = tf.data.experimental.load(path)
for elem in new_dataset:
  print(elem)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)

保存的数据集保存在多个文件"shards"中。默认情况下,数据集输出以round-robin 方式划分为分片,但可以通过shard_func 函数指定自定义分片。例如,您可以将数据集保存为使用单个分片,如下所示:

dataset = make_dataset()
def custom_shard_func(element):
  return 0
dataset = tf.data.experimental.save(
    path="/path/to/data", ..., shard_func=custom_shard_func)

要启用检查点,请将 checkpoint_args 传递给 save 方法,如下所示:

dataset = tf.data.Dataset.range(100)
save_dir = "..."
checkpoint_prefix = "..."
step_counter = tf.Variable(0, trainable=False)
checkpoint_args = {
  "checkpoint_interval":50,
  "step_counter":step_counter,
  "directory":checkpoint_prefix,
  "max_to_keep":20,
}
dataset.save(dataset, save_dir, checkpoint_args=checkpoint_args)

注意:用于保存数据集的目录布局和文件格式被视为实现细节,可能会发生变化。因此,通过 tf.data.experimental.save 保存的数据集只能通过 tf.data.experimental.load 使用,保证向后兼容。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.experimental.save。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。