為專門的評估任務設計的類。
用法
tf.keras.utils.SidecarEvaluator(
model, data, checkpoint_dir, steps=None, max_evaluations=None, callbacks=None
)
參數
-
model
用於評估的模型。此處使用的模型對象應該是tf.keras.Model
,並且應該與訓練中使用的模型對象相同,其中tf.keras.Model
是檢查點。該模型應該在使用SidecarEvaluator
之前編譯一個或多個指標。 -
data
用於評估的輸入數據。SidecarEvaluator
支持 Kerasmodel.evaluate
支持的所有數據類型作為輸入數據x
,例如tf.data.Dataset
。 -
checkpoint_dir
保存檢查點文件的目錄。 -
steps
評估單個檢查點文件時執行評估的步驟數。如果None
,評估將繼續,直到數據集用完。對於重複評估數據集,用戶必須指定steps
以避免無限評估循環。 -
max_evaluations
要評估的檢查點文件的最大數量,對於SidecarEvaluator
知道什麽時候停止。評估程序將在評估以 ' 結尾的檢查點文件路徑後停止- '。如果使用 tf.train.CheckpointManager.save
對於保存檢查點,第 k 個保存的檢查點具有文件路徑後綴 '- '(第一次保存時 k=1),如果在訓練後每個 epoch 都保存檢查點,則在第 k 個 epoch 保存的文件路徑將以 ' 結尾 - .因此,如果訓練運行 n 個 epoch,並且評估器應該在訓練完成後結束,則使用 n 作為此參數。請注意,這不一定等於總評估的數量,因為如果評估慢於創建檢查點,則可能會跳過某些檢查點。如果 None
,SidecarEvaluator
將無限期評估,用戶必須自行終止評估程序。 -
callbacks
列表tf.keras.callbacks.Callback在評估期間應用的實例。看回調.
SidecarEvaluator
預計將在與訓練集群不同的機器上的進程中運行。它旨在用於專用評估器,評估具有一個或多個工作人員執行訓練的訓練集群的度量結果,並保存檢查點。
SidecarEvaluator
API 與自定義訓練循環 (CTL) 和要在訓練集群中使用的 Keras Model.fit
兼容。使用 __init__
, SidecarEvaluator
提供的模型(帶有編譯指標)在找到尚未使用的檢查點時重複執行評估 "epochs"。根據 steps
參數,一個 eval epoch 是對所有 eval 數據或最多一定數量的步驟(批次)的評估。請參閱下麵的示例,了解訓練程序應如何保存檢查點以便被 SidecarEvaluator
識別。
由於在底層,SidecarEvaluator
使用 model.evaluate
進行評估,它還支持任意 Keras 回調。也就是說,如果提供了一個或多個回調,則它們的 on_test_batch_begin
和 on_test_batch_end
方法在批處理的開始和結束時被調用,而它們的 on_test_begin
和 on_test_end
在評估的開始和結束時被調用時代。請注意,SidecarEvaluator
可能會跳過一些檢查點,因為它總是選擇可用的最新檢查點,並且在評估時期,可以從訓練端產生多個檢查點。
例子:
model = tf.keras.models.Sequential(...)
model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(
name="eval_metrics"))
data = tf.data.Dataset.from_tensor_slices(...)
tf.keras.SidecarEvaluator(
model=model,
data=data,
checkpoint_dir='/tmp/checkpoint_dir', # dir for training-saved checkpoint
steps=None, # Eval until dataset is exhausted
max_evaluations=None, # The evaluation needs to be stopped manually
callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')]
).start()
SidecarEvaluator.start
編寫了一係列摘要文件,可以通過 tensorboard 可視化(提供網頁鏈接):
$ tensorboard --logdir=/tmp/log_dir/validation
...
TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit)
如果訓練集群使用 CTL,則 checkpoint_dir
應包含跟蹤 model
和 optimizer
的檢查點,以滿足 SidecarEvaluator
的期望。這可以通過 tf.train.Checkpoint
和 tf.train.CheckpointManager
來完成:
checkpoint_dir = ... # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, checkpoint_dir=..., max_to_keep=...)
checkpoint_manager.save()
如果訓練集群使用 Keras Model.fit
API,則應使用 tf.keras.callbacks.ModelCheckpoint
,並帶有 save_weights_only=True
,並且 filepath
應附加 'ckpt-{epoch}':
checkpoint_dir = ... # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
save_weights_only=True)
model.fit(dataset, epochs, callbacks=[model_checkpoint])
相關用法
- Python tf.keras.utils.SequenceEnqueuer用法及代碼示例
- Python tf.keras.utils.custom_object_scope用法及代碼示例
- Python tf.keras.utils.deserialize_keras_object用法及代碼示例
- Python tf.keras.utils.array_to_img用法及代碼示例
- Python tf.keras.utils.get_file用法及代碼示例
- Python tf.keras.utils.experimental.DatasetCreator用法及代碼示例
- Python tf.keras.utils.set_random_seed用法及代碼示例
- Python tf.keras.utils.timeseries_dataset_from_array用法及代碼示例
- Python tf.keras.utils.plot_model用法及代碼示例
- Python tf.keras.utils.get_custom_objects用法及代碼示例
- Python tf.keras.utils.pack_x_y_sample_weight用法及代碼示例
- Python tf.keras.utils.img_to_array用法及代碼示例
- Python tf.keras.utils.image_dataset_from_directory用法及代碼示例
- Python tf.keras.utils.get_registered_object用法及代碼示例
- Python tf.keras.utils.to_categorical用法及代碼示例
- Python tf.keras.utils.load_img用法及代碼示例
- Python tf.keras.utils.text_dataset_from_directory用法及代碼示例
- Python tf.keras.utils.unpack_x_y_sample_weight用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.SidecarEvaluator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。