调用时返回 tf.data.Dataset
的对象。
用法
tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=None
)
参数
-
dataset_fn
一个可调用的,它接受一个类型为tf.distribute.InputContext
的单个参数,用于批量大小计算和 cross-worker 输入管道分片(如果两者都不需要,可以在dataset_fn
中忽略InputContext
参数),并返回一个tf.data.Dataset
。 -
input_options
可选tf.distribute.InputOptions
,用于与分发一起使用时的特定选项,例如,是否将数据集元素预取到加速器设备内存或主机设备内存,以及在副本设备内存中预取缓冲区大小。如果不与分布式训练一起使用,则无效。有关详细信息,请参阅tf.distribute.InputOptions
。
tf.keras.utils.experimental.DatasetCreator
被指定为 x
或 tf.keras.Model.fit
中的输入的支持类型。当使用返回 tf.data.Dataset
的可调用(带有 input_context
参数)时,将此类的实例传递给 fit
。
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
global_batch_size = 64
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
dataset = dataset.shard(
input_context.num_input_pipelines, input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)
return dataset
input_options = tf.distribute.InputOptions(
experimental_fetch_to_device=True,
experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
Model.fit
与DatasetCreator
一起使用旨在适用于所有tf.distribute.Strategy
,只要在模型创建时使用Strategy.scope
:
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
with strategy.scope():
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
...
input_options = ...
model.fit(tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
注意:在Model.fit
中使用DatasetCreator
, steps_per_epoch
参数时,必须提供此类输入的基数,因为无法推断此类输入的基数。
相关用法
- Python tf.keras.utils.custom_object_scope用法及代码示例
- Python tf.keras.utils.deserialize_keras_object用法及代码示例
- Python tf.keras.utils.array_to_img用法及代码示例
- Python tf.keras.utils.get_file用法及代码示例
- Python tf.keras.utils.set_random_seed用法及代码示例
- Python tf.keras.utils.timeseries_dataset_from_array用法及代码示例
- Python tf.keras.utils.plot_model用法及代码示例
- Python tf.keras.utils.get_custom_objects用法及代码示例
- Python tf.keras.utils.pack_x_y_sample_weight用法及代码示例
- Python tf.keras.utils.img_to_array用法及代码示例
- Python tf.keras.utils.image_dataset_from_directory用法及代码示例
- Python tf.keras.utils.get_registered_object用法及代码示例
- Python tf.keras.utils.SidecarEvaluator用法及代码示例
- Python tf.keras.utils.to_categorical用法及代码示例
- Python tf.keras.utils.load_img用法及代码示例
- Python tf.keras.utils.text_dataset_from_directory用法及代码示例
- Python tf.keras.utils.SequenceEnqueuer用法及代码示例
- Python tf.keras.utils.unpack_x_y_sample_weight用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.utils.experimental.DatasetCreator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。