當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.utils.SidecarEvaluator用法及代碼示例


為專門的評估任務設計的類。

用法

tf.keras.utils.SidecarEvaluator(
    model, data, checkpoint_dir, steps=None, max_evaluations=None, callbacks=None
)

參數

  • model 用於評估的模型。此處使用的模型對象應該是 tf.keras.Model ,並且應該與訓練中使用的模型對象相同,其中 tf.keras.Model 是檢查點。該模型應該在使用 SidecarEvaluator 之前編譯一個或多個指標。
  • data 用於評估的輸入數據。 SidecarEvaluator 支持 Keras model.evaluate 支持的所有數據類型作為輸入數據 x ,例如 tf.data.Dataset
  • checkpoint_dir 保存檢查點文件的目錄。
  • steps 評估單個檢查點文件時執行評估的步驟數。如果 None ,評估將繼續,直到數據集用完。對於重複評估數據集,用戶必須指定steps以避免無限評估循環。
  • max_evaluations 要評估的檢查點文件的最大數量,對於SidecarEvaluator知道什麽時候停止。評估程序將在評估以 ' 結尾的檢查點文件路徑後停止-'。如果使用tf.train.CheckpointManager.save對於保存檢查點,第 k 個保存的檢查點具有文件路徑後綴 '-'(第一次保存時 k=1),如果在訓練後每個 epoch 都保存檢查點,則在第 k 個 epoch 保存的文件路徑將以 ' 結尾-.因此,如果訓練運行 n 個 epoch,並且評估器應該在訓練完成後結束,則使用 n 作為此參數。請注意,這不一定等於總評估的數量,因為如果評估慢於創建檢查點,則可能會跳過某些檢查點。如果None , SidecarEvaluator 將無限期評估,用戶必須自行終止評估程序。
  • callbacks 列表tf.keras.callbacks.Callback在評估期間應用的實例。看回調.

SidecarEvaluator 預計將在與訓練集群不同的機器上的進程中運行。它旨在用於專用評估器,評估具有一個或多個工作人員執行訓練的訓練集群的度量結果,並保存檢查點。

SidecarEvaluator API 與自定義訓練循環 (CTL) 和要在訓練集群中使用的 Keras Model.fit 兼容。使用 __init__ , SidecarEvaluator 提供的模型(帶有編譯指標)在找到尚未使用的檢查點時重複執行評估 "epochs"。根據 steps 參數,一個 eval epoch 是對所有 eval 數據或最多一定數量的步驟(批次)的評估。請參閱下麵的示例,了解訓練程序應如何保存檢查點以便被 SidecarEvaluator 識別。

由於在底層,SidecarEvaluator 使用 model.evaluate 進行評估,它還支持任意 Keras 回調。也就是說,如果提供了一個或多個回調,則它們的 on_test_batch_beginon_test_batch_end 方法在批處理的開始和結束時被調用,而它們的 on_test_beginon_test_end 在評估的開始和結束時被調用時代。請注意,SidecarEvaluator 可能會跳過一些檢查點,因為它總是選擇可用的最新檢查點,並且在評估時期,可以從訓練端產生多個檢查點。

例子:

model = tf.keras.models.Sequential(...)
model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(
    name="eval_metrics"))
data = tf.data.Dataset.from_tensor_slices(...)

tf.keras.SidecarEvaluator(
    model=model,
    data=data,
    checkpoint_dir='/tmp/checkpoint_dir',  # dir for training-saved checkpoint
    steps=None,  # Eval until dataset is exhausted
    max_evaluations=None,  # The evaluation needs to be stopped manually
    callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')]
).start()

SidecarEvaluator.start 編寫了一係列摘要文件,可以通過 tensorboard 可視化(提供網頁鏈接):

$ tensorboard --logdir=/tmp/log_dir/validation
...
TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit)

如果訓練集群使用 CTL,則 checkpoint_dir 應包含跟蹤 modeloptimizer 的檢查點,以滿足 SidecarEvaluator 的期望。這可以通過 tf.train.Checkpointtf.train.CheckpointManager 來完成:

checkpoint_dir = ...  # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, checkpoint_dir=..., max_to_keep=...)
checkpoint_manager.save()

如果訓練集群使用 Keras Model.fit API,則應使用 tf.keras.callbacks.ModelCheckpoint,並帶有 save_weights_only=True ,並且 filepath 應附加 'ckpt-{epoch}':

checkpoint_dir = ...  # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
    save_weights_only=True)
model.fit(dataset, epochs, callbacks=[model_checkpoint])

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.SidecarEvaluator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。