調用時返回 tf.data.Dataset
的對象。
用法
tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=None
)
參數
-
dataset_fn
一個可調用的,它接受一個類型為tf.distribute.InputContext
的單個參數,用於批量大小計算和 cross-worker 輸入管道分片(如果兩者都不需要,可以在dataset_fn
中忽略InputContext
參數),並返回一個tf.data.Dataset
。 -
input_options
可選tf.distribute.InputOptions
,用於與分發一起使用時的特定選項,例如,是否將數據集元素預取到加速器設備內存或主機設備內存,以及在副本設備內存中預取緩衝區大小。如果不與分布式訓練一起使用,則無效。有關詳細信息,請參閱tf.distribute.InputOptions
。
tf.keras.utils.experimental.DatasetCreator
被指定為 x
或 tf.keras.Model.fit
中的輸入的支持類型。當使用返回 tf.data.Dataset
的可調用(帶有 input_context
參數)時,將此類的實例傳遞給 fit
。
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
global_batch_size = 64
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
dataset = dataset.shard(
input_context.num_input_pipelines, input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)
return dataset
input_options = tf.distribute.InputOptions(
experimental_fetch_to_device=True,
experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
Model.fit
與DatasetCreator
一起使用旨在適用於所有tf.distribute.Strategy
,隻要在模型創建時使用Strategy.scope
:
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
with strategy.scope():
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
...
input_options = ...
model.fit(tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
注意:在Model.fit
中使用DatasetCreator
, steps_per_epoch
參數時,必須提供此類輸入的基數,因為無法推斷此類輸入的基數。
相關用法
- Python tf.keras.utils.custom_object_scope用法及代碼示例
- Python tf.keras.utils.deserialize_keras_object用法及代碼示例
- Python tf.keras.utils.array_to_img用法及代碼示例
- Python tf.keras.utils.get_file用法及代碼示例
- Python tf.keras.utils.set_random_seed用法及代碼示例
- Python tf.keras.utils.timeseries_dataset_from_array用法及代碼示例
- Python tf.keras.utils.plot_model用法及代碼示例
- Python tf.keras.utils.get_custom_objects用法及代碼示例
- Python tf.keras.utils.pack_x_y_sample_weight用法及代碼示例
- Python tf.keras.utils.img_to_array用法及代碼示例
- Python tf.keras.utils.image_dataset_from_directory用法及代碼示例
- Python tf.keras.utils.get_registered_object用法及代碼示例
- Python tf.keras.utils.SidecarEvaluator用法及代碼示例
- Python tf.keras.utils.to_categorical用法及代碼示例
- Python tf.keras.utils.load_img用法及代碼示例
- Python tf.keras.utils.text_dataset_from_directory用法及代碼示例
- Python tf.keras.utils.SequenceEnqueuer用法及代碼示例
- Python tf.keras.utils.unpack_x_y_sample_weight用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.experimental.DatasetCreator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。