當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.data.experimental.save用法及代碼示例


保存給定數據集的內容。

用法

tf.data.experimental.save(
    dataset, path, compression=None, shard_func=None, checkpoint_args=None
)

參數

  • dataset 要保存的數據集。
  • path 必需的。用於保存數據集的目錄。
  • compression 可選的。寫入數據時用於壓縮數據的算法。支持的選項是 GZIPNONE 。默認為 NONE
  • shard_func 可選的。控製數據集元素到文件分片的映射的函數。該函數預計將輸入數據集的元素映射到 int64 分片 ID。如果存在,該函數將被跟蹤並作為圖形計算執行。
  • checkpoint_args 檢查點的可選參數將被傳遞到 tf.train.CheckpointManager 。如果未指定checkpoint_args,則不會執行檢查點。 save() 實現在內部創建 tf.train.Checkpoint 對象,因此用戶不應在 checkpoint_args 中設置 checkpoint 參數。

拋出

  • ValueError 如果 checkpoint 被傳遞到 checkpoint_args

示例用法:

import tempfile
path = os.path.join(tempfile.gettempdir(), "saved_data")
# Save a dataset
dataset = tf.data.Dataset.range(2)
tf.data.experimental.save(dataset, path)
new_dataset = tf.data.experimental.load(path)
for elem in new_dataset:
  print(elem)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)

保存的數據集保存在多個文件"shards"中。默認情況下,數據集輸出以round-robin 方式劃分為分片,但可以通過shard_func 函數指定自定義分片。例如,您可以將數據集保存為使用單個分片,如下所示:

dataset = make_dataset()
def custom_shard_func(element):
  return 0
dataset = tf.data.experimental.save(
    path="/path/to/data", ..., shard_func=custom_shard_func)

要啟用檢查點,請將 checkpoint_args 傳遞給 save 方法,如下所示:

dataset = tf.data.Dataset.range(100)
save_dir = "..."
checkpoint_prefix = "..."
step_counter = tf.Variable(0, trainable=False)
checkpoint_args = {
  "checkpoint_interval":50,
  "step_counter":step_counter,
  "directory":checkpoint_prefix,
  "max_to_keep":20,
}
dataset.save(dataset, save_dir, checkpoint_args=checkpoint_args)

注意:用於保存數據集的目錄布局和文件格式被視為實現細節,可能會發生變化。因此,通過 tf.data.experimental.save 保存的數據集隻能通過 tf.data.experimental.load 使用,保證向後兼容。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.save。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。