用法
restore(
save_path, options=None
)參數
-
save_path檢查點的路徑,由save或tf.train.latest_checkpoint返回。如果檢查點是由基於名稱的tf.compat.v1.train.Saver編寫的,則名稱用於匹配變量。此路徑也可能是 SavedModel 目錄。 -
options可選的tf.train.CheckpointOptions對象。
返回
-
加載狀態對象,可用於對檢查點恢複的狀態進行斷言。
返回的狀態對象有以下方法:
assert_consumed():如果任何變量不匹配,則引發異常:檢查點值沒有匹配的 Python 對象或依賴關係圖中的 Python 對象在檢查點中沒有值。此方法返回狀態對象,因此可以與其他斷言鏈接。assert_existing_objects_matched():如果依賴圖中的任何現有 Python 對象不匹配,則引發異常。與assert_consumed不同,如果檢查點中的值沒有對應的 Python 對象,則此斷言將通過。例如,尚未構建的tf.keras.Layer對象,因此尚未創建任何變量,將通過此斷言但失敗assert_consumed。在將較大檢查點的一部分加載到新的 Python 程序中時很有用,例如保存了帶有tf.compat.v1.train.Optimizer的訓練檢查點,但僅加載了推理所需的狀態。此方法返回狀態對象,因此可以與其他斷言鏈接。assert_nontrivial_match():斷言除根對象之外的其他內容已匹配。這是一個非常弱的斷言,但對於庫代碼中的健全性檢查很有用,其中對象可能存在於檢查點中,而這些對象可能尚未在 Python 中創建,並且某些 Python 對象可能沒有檢查點值。expect_partial():關於不完整的檢查點恢複的靜默警告。當Checkpoint對象被刪除時(通常在程序關閉時),檢查點文件或對象的未使用部分會打印警告。
拋出
-
NotFoundError如果在save_path找不到檢查點或 SavedModel。
恢複訓練檢查點。
恢複此 Checkpoint 及其依賴的任何對象。
此方法旨在用於加載由 save() 創建的檢查點。對於由 write() 創建的檢查點,請使用 read() 方法,該方法不需要 save() 添加的 save_counter 變量。
restore() 如果要恢複的變量已經創建,則立即賦值,或者推遲恢複直到創建變量。如果在檢查點中有相應的對象,則在此調用之後添加的依賴項將被匹配(恢複請求將在任何可跟蹤的對象中排隊等待添加預期的依賴項)。
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path)
# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options)
為確保加載完成並且不再發生延遲恢複,請使用 restore() 返回的狀態對象的 assert_consumed() 方法:
checkpoint.restore(path, options=options).assert_consumed()
如果在檢查點中未找到依賴圖中的任何 Python 對象,或者任何檢查點值沒有匹配的 Python 對象,則斷言將引發錯誤。
可以使用此方法加載來自 TensorFlow 1.x 的基於名稱的 tf.compat.v1.train.Saver 檢查點。名稱用於匹配變量。盡快使用tf.train.Checkpoint.save 重新編碼基於名稱的檢查點。
從 SavedModel 檢查點加載
要從 SavedModel 加載值,隻需將 SavedModel 目錄傳遞給 checkpoint.restore:
model = tf.keras.Model(...)
tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()
這個例子在加載狀態上調用expect_partial(),因為從 Keras 保存的 SavedModels 通常會在檢查點中生成額外的鍵。否則,程序會在退出時打印很多關於未使用 key 的警告。
相關用法
- Python tf.train.Checkpoint.read用法及代碼示例
- Python tf.train.Checkpoint.save用法及代碼示例
- Python tf.train.Checkpoint.write用法及代碼示例
- Python tf.train.CheckpointOptions用法及代碼示例
- Python tf.train.CheckpointManager用法及代碼示例
- Python tf.train.Checkpoint用法及代碼示例
- Python tf.train.Coordinator.stop_on_exception用法及代碼示例
- Python tf.train.ClusterSpec用法及代碼示例
- Python tf.train.Coordinator用法及代碼示例
- Python tf.train.ExponentialMovingAverage用法及代碼示例
- Python tf.train.list_variables用法及代碼示例
- Python tf.transpose用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.train.Checkpoint.restore。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
