当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.utils.SidecarEvaluator用法及代码示例


为专门的评估任务设计的类。

用法

tf.keras.utils.SidecarEvaluator(
    model, data, checkpoint_dir, steps=None, max_evaluations=None, callbacks=None
)

参数

  • model 用于评估的模型。此处使用的模型对象应该是 tf.keras.Model ,并且应该与训练中使用的模型对象相同,其中 tf.keras.Model 是检查点。该模型应该在使用 SidecarEvaluator 之前编译一个或多个指标。
  • data 用于评估的输入数据。 SidecarEvaluator 支持 Keras model.evaluate 支持的所有数据类型作为输入数据 x ,例如 tf.data.Dataset
  • checkpoint_dir 保存检查点文件的目录。
  • steps 评估单个检查点文件时执行评估的步骤数。如果 None ,评估将继续,直到数据集用完。对于重复评估数据集,用户必须指定steps以避免无限评估循环。
  • max_evaluations 要评估的检查点文件的最大数量,对于SidecarEvaluator知道什么时候停止。评估程序将在评估以 ' 结尾的检查点文件路径后停止-'。如果使用tf.train.CheckpointManager.save对于保存检查点,第 k 个保存的检查点具有文件路径后缀 '-'(第一次保存时 k=1),如果在训练后每个 epoch 都保存检查点,则在第 k 个 epoch 保存的文件路径将以 ' 结尾-.因此,如果训练运行 n 个 epoch,并且评估器应该在训练完成后结束,则使用 n 作为此参数。请注意,这不一定等于总评估的数量,因为如果评估慢于创建检查点,则可能会跳过某些检查点。如果None , SidecarEvaluator 将无限期评估,用户必须自行终止评估程序。
  • callbacks 列表tf.keras.callbacks.Callback在评估期间应用的实例。看回调.

SidecarEvaluator 预计将在与训练集群不同的机器上的进程中运行。它旨在用于专用评估器,评估具有一个或多个工作人员执行训练的训练集群的度量结果,并保存检查点。

SidecarEvaluator API 与自定义训练循环 (CTL) 和要在训练集群中使用的 Keras Model.fit 兼容。使用 __init__ , SidecarEvaluator 提供的模型(带有编译指标)在找到尚未使用的检查点时重复执行评估 "epochs"。根据 steps 参数,一个 eval epoch 是对所有 eval 数据或最多一定数量的步骤(批次)的评估。请参阅下面的示例,了解训练程序应如何保存检查点以便被 SidecarEvaluator 识别。

由于在底层,SidecarEvaluator 使用 model.evaluate 进行评估,它还支持任意 Keras 回调。也就是说,如果提供了一个或多个回调,则它们的 on_test_batch_beginon_test_batch_end 方法在批处理的开始和结束时被调用,而它们的 on_test_beginon_test_end 在评估的开始和结束时被调用时代。请注意,SidecarEvaluator 可能会跳过一些检查点,因为它总是选择可用的最新检查点,并且在评估时期,可以从训练端产生多个检查点。

例子:

model = tf.keras.models.Sequential(...)
model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(
    name="eval_metrics"))
data = tf.data.Dataset.from_tensor_slices(...)

tf.keras.SidecarEvaluator(
    model=model,
    data=data,
    checkpoint_dir='/tmp/checkpoint_dir',  # dir for training-saved checkpoint
    steps=None,  # Eval until dataset is exhausted
    max_evaluations=None,  # The evaluation needs to be stopped manually
    callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')]
).start()

SidecarEvaluator.start 编写了一系列摘要文件,可以通过 tensorboard 可视化(提供网页链接):

$ tensorboard --logdir=/tmp/log_dir/validation
...
TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit)

如果训练集群使用 CTL,则 checkpoint_dir 应包含跟踪 modeloptimizer 的检查点,以满足 SidecarEvaluator 的期望。这可以通过 tf.train.Checkpointtf.train.CheckpointManager 来完成:

checkpoint_dir = ...  # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, checkpoint_dir=..., max_to_keep=...)
checkpoint_manager.save()

如果训练集群使用 Keras Model.fit API,则应使用 tf.keras.callbacks.ModelCheckpoint,并带有 save_weights_only=True ,并且 filepath 应附加 'ckpt-{epoch}':

checkpoint_dir = ...  # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
    save_weights_only=True)
model.fit(dataset, epochs, callbacks=[model_checkpoint])

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.utils.SidecarEvaluator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。