為專門的評估任務設計的類。
用法
tf.keras.utils.SidecarEvaluator(
model, data, checkpoint_dir, steps=None, max_evaluations=None, callbacks=None
)參數
-
model用於評估的模型。此處使用的模型對象應該是tf.keras.Model,並且應該與訓練中使用的模型對象相同,其中tf.keras.Model是檢查點。該模型應該在使用SidecarEvaluator之前編譯一個或多個指標。 -
data用於評估的輸入數據。SidecarEvaluator支持 Kerasmodel.evaluate支持的所有數據類型作為輸入數據x,例如tf.data.Dataset。 -
checkpoint_dir保存檢查點文件的目錄。 -
steps評估單個檢查點文件時執行評估的步驟數。如果None,評估將繼續,直到數據集用完。對於重複評估數據集,用戶必須指定steps以避免無限評估循環。 -
max_evaluations要評估的檢查點文件的最大數量,對於SidecarEvaluator知道什麽時候停止。評估程序將在評估以 ' 結尾的檢查點文件路徑後停止- '。如果使用 tf.train.CheckpointManager.save對於保存檢查點,第 k 個保存的檢查點具有文件路徑後綴 '- '(第一次保存時 k=1),如果在訓練後每個 epoch 都保存檢查點,則在第 k 個 epoch 保存的文件路徑將以 ' 結尾 - .因此,如果訓練運行 n 個 epoch,並且評估器應該在訓練完成後結束,則使用 n 作為此參數。請注意,這不一定等於總評估的數量,因為如果評估慢於創建檢查點,則可能會跳過某些檢查點。如果 None,SidecarEvaluator將無限期評估,用戶必須自行終止評估程序。 -
callbacks列表tf.keras.callbacks.Callback在評估期間應用的實例。看回調.
SidecarEvaluator 預計將在與訓練集群不同的機器上的進程中運行。它旨在用於專用評估器,評估具有一個或多個工作人員執行訓練的訓練集群的度量結果,並保存檢查點。
SidecarEvaluator API 與自定義訓練循環 (CTL) 和要在訓練集群中使用的 Keras Model.fit 兼容。使用 __init__ , SidecarEvaluator 提供的模型(帶有編譯指標)在找到尚未使用的檢查點時重複執行評估 "epochs"。根據 steps 參數,一個 eval epoch 是對所有 eval 數據或最多一定數量的步驟(批次)的評估。請參閱下麵的示例,了解訓練程序應如何保存檢查點以便被 SidecarEvaluator 識別。
由於在底層,SidecarEvaluator 使用 model.evaluate 進行評估,它還支持任意 Keras 回調。也就是說,如果提供了一個或多個回調,則它們的 on_test_batch_begin 和 on_test_batch_end 方法在批處理的開始和結束時被調用,而它們的 on_test_begin 和 on_test_end 在評估的開始和結束時被調用時代。請注意,SidecarEvaluator 可能會跳過一些檢查點,因為它總是選擇可用的最新檢查點,並且在評估時期,可以從訓練端產生多個檢查點。
例子:
model = tf.keras.models.Sequential(...)
model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(
name="eval_metrics"))
data = tf.data.Dataset.from_tensor_slices(...)
tf.keras.SidecarEvaluator(
model=model,
data=data,
checkpoint_dir='/tmp/checkpoint_dir', # dir for training-saved checkpoint
steps=None, # Eval until dataset is exhausted
max_evaluations=None, # The evaluation needs to be stopped manually
callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')]
).start()
SidecarEvaluator.start 編寫了一係列摘要文件,可以通過 tensorboard 可視化(提供網頁鏈接):
$ tensorboard --logdir=/tmp/log_dir/validation
...
TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit)
如果訓練集群使用 CTL,則 checkpoint_dir 應包含跟蹤 model 和 optimizer 的檢查點,以滿足 SidecarEvaluator 的期望。這可以通過 tf.train.Checkpoint 和 tf.train.CheckpointManager 來完成:
checkpoint_dir = ... # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint_dir=..., max_to_keep=...)
checkpoint_manager.save()
如果訓練集群使用 Keras Model.fit API,則應使用 tf.keras.callbacks.ModelCheckpoint,並帶有 save_weights_only=True ,並且 filepath 應附加 'ckpt-{epoch}':
checkpoint_dir = ... # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
save_weights_only=True)
model.fit(dataset, epochs, callbacks=[model_checkpoint])
相關用法
- Python tf.keras.utils.SequenceEnqueuer用法及代碼示例
- Python tf.keras.utils.custom_object_scope用法及代碼示例
- Python tf.keras.utils.deserialize_keras_object用法及代碼示例
- Python tf.keras.utils.array_to_img用法及代碼示例
- Python tf.keras.utils.get_file用法及代碼示例
- Python tf.keras.utils.experimental.DatasetCreator用法及代碼示例
- Python tf.keras.utils.set_random_seed用法及代碼示例
- Python tf.keras.utils.timeseries_dataset_from_array用法及代碼示例
- Python tf.keras.utils.plot_model用法及代碼示例
- Python tf.keras.utils.get_custom_objects用法及代碼示例
- Python tf.keras.utils.pack_x_y_sample_weight用法及代碼示例
- Python tf.keras.utils.img_to_array用法及代碼示例
- Python tf.keras.utils.image_dataset_from_directory用法及代碼示例
- Python tf.keras.utils.get_registered_object用法及代碼示例
- Python tf.keras.utils.to_categorical用法及代碼示例
- Python tf.keras.utils.load_img用法及代碼示例
- Python tf.keras.utils.text_dataset_from_directory用法及代碼示例
- Python tf.keras.utils.unpack_x_y_sample_weight用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.SidecarEvaluator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
