当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.utils.experimental.DatasetCreator用法及代码示例


调用时返回 tf.data.Dataset 的对象。

用法

tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=None
)

参数

  • dataset_fn 一个可调用的,它接受一个类型为 tf.distribute.InputContext 的单个参数,用于批量大小计算和 cross-worker 输入管道分片(如果两者都不需要,可以在 dataset_fn 中忽略 InputContext 参数),并返回一个 tf.data.Dataset
  • input_options 可选 tf.distribute.InputOptions ,用于与分发一起使用时的特定选项,例如,是否将数据集元素预取到加速器设备内存或主机设备内存,以及在副本设备内存中预取缓冲区大小。如果不与分布式训练一起使用,则无效。有关详细信息,请参阅tf.distribute.InputOptions

tf.keras.utils.experimental.DatasetCreator 被指定为 xtf.keras.Model.fit 中的输入的支持类型。当使用返回 tf.data.Dataset 的可调用(带有 input_context 参数)时,将此类的实例传递给 fit

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)

Model.fitDatasetCreator 一起使用旨在适用于所有tf.distribute.Strategy,只要在模型创建时使用Strategy.scope

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  ...

input_options = ...
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)

注意:Model.fit 中使用DatasetCreator , steps_per_epoch 参数时,必须提供此类输入的基数,因为无法推断此类输入的基数。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.utils.experimental.DatasetCreator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。