當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.utils.experimental.DatasetCreator用法及代碼示例


調用時返回 tf.data.Dataset 的對象。

用法

tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=None
)

參數

  • dataset_fn 一個可調用的,它接受一個類型為 tf.distribute.InputContext 的單個參數,用於批量大小計算和 cross-worker 輸入管道分片(如果兩者都不需要,可以在 dataset_fn 中忽略 InputContext 參數),並返回一個 tf.data.Dataset
  • input_options 可選 tf.distribute.InputOptions ,用於與分發一起使用時的特定選項,例如,是否將數據集元素預取到加速器設備內存或主機設備內存,以及在副本設備內存中預取緩衝區大小。如果不與分布式訓練一起使用,則無效。有關詳細信息,請參閱tf.distribute.InputOptions

tf.keras.utils.experimental.DatasetCreator 被指定為 xtf.keras.Model.fit 中的輸入的支持類型。當使用返回 tf.data.Dataset 的可調用(帶有 input_context 參數)時,將此類的實例傳遞給 fit

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)

Model.fitDatasetCreator 一起使用旨在適用於所有tf.distribute.Strategy,隻要在模型創建時使用Strategy.scope

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  ...

input_options = ...
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)

注意:Model.fit 中使用DatasetCreator , steps_per_epoch 參數時,必須提供此類輸入的基數,因為無法推斷此類輸入的基數。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.experimental.DatasetCreator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。