用法
restore(
save_path, options=None
)
參數
-
save_path
檢查點的路徑,由save
或tf.train.latest_checkpoint
返回。如果檢查點是由基於名稱的tf.compat.v1.train.Saver
編寫的,則名稱用於匹配變量。此路徑也可能是 SavedModel 目錄。 -
options
可選的tf.train.CheckpointOptions
對象。
返回
-
加載狀態對象,可用於對檢查點恢複的狀態進行斷言。
返回的狀態對象有以下方法:
assert_consumed()
:如果任何變量不匹配,則引發異常:檢查點值沒有匹配的 Python 對象或依賴關係圖中的 Python 對象在檢查點中沒有值。此方法返回狀態對象,因此可以與其他斷言鏈接。assert_existing_objects_matched()
:如果依賴圖中的任何現有 Python 對象不匹配,則引發異常。與assert_consumed
不同,如果檢查點中的值沒有對應的 Python 對象,則此斷言將通過。例如,尚未構建的tf.keras.Layer
對象,因此尚未創建任何變量,將通過此斷言但失敗assert_consumed
。在將較大檢查點的一部分加載到新的 Python 程序中時很有用,例如保存了帶有tf.compat.v1.train.Optimizer
的訓練檢查點,但僅加載了推理所需的狀態。此方法返回狀態對象,因此可以與其他斷言鏈接。assert_nontrivial_match()
:斷言除根對象之外的其他內容已匹配。這是一個非常弱的斷言,但對於庫代碼中的健全性檢查很有用,其中對象可能存在於檢查點中,而這些對象可能尚未在 Python 中創建,並且某些 Python 對象可能沒有檢查點值。expect_partial()
:關於不完整的檢查點恢複的靜默警告。當Checkpoint
對象被刪除時(通常在程序關閉時),檢查點文件或對象的未使用部分會打印警告。
拋出
-
NotFoundError
如果在save_path
找不到檢查點或 SavedModel。
恢複訓練檢查點。
恢複此 Checkpoint
及其依賴的任何對象。
此方法旨在用於加載由 save()
創建的檢查點。對於由 write()
創建的檢查點,請使用 read()
方法,該方法不需要 save()
添加的 save_counter
變量。
restore()
如果要恢複的變量已經創建,則立即賦值,或者推遲恢複直到創建變量。如果在檢查點中有相應的對象,則在此調用之後添加的依賴項將被匹配(恢複請求將在任何可跟蹤的對象中排隊等待添加預期的依賴項)。
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path)
# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options)
為確保加載完成並且不再發生延遲恢複,請使用 restore()
返回的狀態對象的 assert_consumed()
方法:
checkpoint.restore(path, options=options).assert_consumed()
如果在檢查點中未找到依賴圖中的任何 Python 對象,或者任何檢查點值沒有匹配的 Python 對象,則斷言將引發錯誤。
可以使用此方法加載來自 TensorFlow 1.x 的基於名稱的 tf.compat.v1.train.Saver
檢查點。名稱用於匹配變量。盡快使用tf.train.Checkpoint.save
重新編碼基於名稱的檢查點。
從 SavedModel 檢查點加載
要從 SavedModel 加載值,隻需將 SavedModel 目錄傳遞給 checkpoint.restore:
model = tf.keras.Model(...)
tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()
這個例子在加載狀態上調用expect_partial()
,因為從 Keras 保存的 SavedModels 通常會在檢查點中生成額外的鍵。否則,程序會在退出時打印很多關於未使用 key 的警告。
相關用法
- Python tf.train.Checkpoint.read用法及代碼示例
- Python tf.train.Checkpoint.save用法及代碼示例
- Python tf.train.Checkpoint.write用法及代碼示例
- Python tf.train.CheckpointOptions用法及代碼示例
- Python tf.train.CheckpointManager用法及代碼示例
- Python tf.train.Checkpoint用法及代碼示例
- Python tf.train.Coordinator.stop_on_exception用法及代碼示例
- Python tf.train.ClusterSpec用法及代碼示例
- Python tf.train.Coordinator用法及代碼示例
- Python tf.train.ExponentialMovingAverage用法及代碼示例
- Python tf.train.list_variables用法及代碼示例
- Python tf.transpose用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.train.Checkpoint.restore。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。