为专门的评估任务设计的类。
用法
tf.keras.utils.SidecarEvaluator(
model, data, checkpoint_dir, steps=None, max_evaluations=None, callbacks=None
)
参数
-
model
用于评估的模型。此处使用的模型对象应该是tf.keras.Model
,并且应该与训练中使用的模型对象相同,其中tf.keras.Model
是检查点。该模型应该在使用SidecarEvaluator
之前编译一个或多个指标。 -
data
用于评估的输入数据。SidecarEvaluator
支持 Kerasmodel.evaluate
支持的所有数据类型作为输入数据x
,例如tf.data.Dataset
。 -
checkpoint_dir
保存检查点文件的目录。 -
steps
评估单个检查点文件时执行评估的步骤数。如果None
,评估将继续,直到数据集用完。对于重复评估数据集,用户必须指定steps
以避免无限评估循环。 -
max_evaluations
要评估的检查点文件的最大数量,对于SidecarEvaluator
知道什么时候停止。评估程序将在评估以 ' 结尾的检查点文件路径后停止- '。如果使用 tf.train.CheckpointManager.save
对于保存检查点,第 k 个保存的检查点具有文件路径后缀 '- '(第一次保存时 k=1),如果在训练后每个 epoch 都保存检查点,则在第 k 个 epoch 保存的文件路径将以 ' 结尾 - .因此,如果训练运行 n 个 epoch,并且评估器应该在训练完成后结束,则使用 n 作为此参数。请注意,这不一定等于总评估的数量,因为如果评估慢于创建检查点,则可能会跳过某些检查点。如果 None
,SidecarEvaluator
将无限期评估,用户必须自行终止评估程序。 -
callbacks
列表tf.keras.callbacks.Callback在评估期间应用的实例。看回调.
SidecarEvaluator
预计将在与训练集群不同的机器上的进程中运行。它旨在用于专用评估器,评估具有一个或多个工作人员执行训练的训练集群的度量结果,并保存检查点。
SidecarEvaluator
API 与自定义训练循环 (CTL) 和要在训练集群中使用的 Keras Model.fit
兼容。使用 __init__
, SidecarEvaluator
提供的模型(带有编译指标)在找到尚未使用的检查点时重复执行评估 "epochs"。根据 steps
参数,一个 eval epoch 是对所有 eval 数据或最多一定数量的步骤(批次)的评估。请参阅下面的示例,了解训练程序应如何保存检查点以便被 SidecarEvaluator
识别。
由于在底层,SidecarEvaluator
使用 model.evaluate
进行评估,它还支持任意 Keras 回调。也就是说,如果提供了一个或多个回调,则它们的 on_test_batch_begin
和 on_test_batch_end
方法在批处理的开始和结束时被调用,而它们的 on_test_begin
和 on_test_end
在评估的开始和结束时被调用时代。请注意,SidecarEvaluator
可能会跳过一些检查点,因为它总是选择可用的最新检查点,并且在评估时期,可以从训练端产生多个检查点。
例子:
model = tf.keras.models.Sequential(...)
model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(
name="eval_metrics"))
data = tf.data.Dataset.from_tensor_slices(...)
tf.keras.SidecarEvaluator(
model=model,
data=data,
checkpoint_dir='/tmp/checkpoint_dir', # dir for training-saved checkpoint
steps=None, # Eval until dataset is exhausted
max_evaluations=None, # The evaluation needs to be stopped manually
callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')]
).start()
SidecarEvaluator.start
编写了一系列摘要文件,可以通过 tensorboard 可视化(提供网页链接):
$ tensorboard --logdir=/tmp/log_dir/validation
...
TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit)
如果训练集群使用 CTL,则 checkpoint_dir
应包含跟踪 model
和 optimizer
的检查点,以满足 SidecarEvaluator
的期望。这可以通过 tf.train.Checkpoint
和 tf.train.CheckpointManager
来完成:
checkpoint_dir = ... # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint_dir=..., max_to_keep=...)
checkpoint_manager.save()
如果训练集群使用 Keras Model.fit
API,则应使用 tf.keras.callbacks.ModelCheckpoint
,并带有 save_weights_only=True
,并且 filepath
应附加 'ckpt-{epoch}':
checkpoint_dir = ... # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
save_weights_only=True)
model.fit(dataset, epochs, callbacks=[model_checkpoint])
相关用法
- Python tf.keras.utils.SequenceEnqueuer用法及代码示例
- Python tf.keras.utils.custom_object_scope用法及代码示例
- Python tf.keras.utils.deserialize_keras_object用法及代码示例
- Python tf.keras.utils.array_to_img用法及代码示例
- Python tf.keras.utils.get_file用法及代码示例
- Python tf.keras.utils.experimental.DatasetCreator用法及代码示例
- Python tf.keras.utils.set_random_seed用法及代码示例
- Python tf.keras.utils.timeseries_dataset_from_array用法及代码示例
- Python tf.keras.utils.plot_model用法及代码示例
- Python tf.keras.utils.get_custom_objects用法及代码示例
- Python tf.keras.utils.pack_x_y_sample_weight用法及代码示例
- Python tf.keras.utils.img_to_array用法及代码示例
- Python tf.keras.utils.image_dataset_from_directory用法及代码示例
- Python tf.keras.utils.get_registered_object用法及代码示例
- Python tf.keras.utils.to_categorical用法及代码示例
- Python tf.keras.utils.load_img用法及代码示例
- Python tf.keras.utils.text_dataset_from_directory用法及代码示例
- Python tf.keras.utils.unpack_x_y_sample_weight用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.utils.SidecarEvaluator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。