當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.train.Checkpoint.restore用法及代碼示例


用法

restore(
    save_path, options=None
)

參數

返回

  • 加載狀態對象,可用於對檢查點恢複的狀態進行斷言。

    返回的狀態對象有以下方法:

    • assert_consumed() :如果任何變量不匹配,則引發異常:檢查點值沒有匹配的 Python 對象或依賴關係圖中的 Python 對象在檢查點中沒有值。此方法返回狀態對象,因此可以與其他斷言鏈接。

    • assert_existing_objects_matched() :如果依賴圖中的任何現有 Python 對象不匹配,則引發異常。與 assert_consumed 不同,如果檢查點中的值沒有對應的 Python 對象,則此斷言將通過。例如,尚未構建的 tf.keras.Layer 對象,因此尚未創建任何變量,將通過此斷言但失敗 assert_consumed 。在將較大檢查點的一部分加載到新的 Python 程序中時很有用,例如保存了帶有 tf.compat.v1.train.Optimizer 的訓練檢查點,但僅加載了推理所需的狀態。此方法返回狀態對象,因此可以與其他斷言鏈接。

    • assert_nontrivial_match() :斷言除根對象之外的其他內容已匹配。這是一個非常弱的斷言,但對於庫代碼中的健全性檢查很有用,其中對象可能存在於檢查點中,而這些對象可能尚未在 Python 中創建,並且某些 Python 對象可能沒有檢查點值。

    • expect_partial():關於不完整的檢查點恢複的靜默警告。當Checkpoint 對象被刪除時(通常在程序關閉時),檢查點文件或對象的未使用部分會打印警告。

拋出

  • NotFoundError 如果在 save_path 找不到檢查點或 SavedModel。

恢複訓練檢查點。

恢複此 Checkpoint 及其依賴的任何對象。

此方法旨在用於加載由 save() 創建的檢查點。對於由 write() 創建的檢查點,請使用 read() 方法,該方法不需要 save() 添加的 save_counter 變量。

restore() 如果要恢複的變量已經創建,則立即賦值,或者推遲恢複直到創建變量。如果在檢查點中有相應的對象,則在此調用之後添加的依賴項將被匹配(恢複請求將在任何可跟蹤的對象中排隊等待添加預期的依賴項)。

checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path)

# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options)

為確保加載完成並且不再發生延遲恢複,請使用 restore() 返回的狀態對象的 assert_consumed() 方法:

checkpoint.restore(path, options=options).assert_consumed()

如果在檢查點中未找到依賴圖中的任何 Python 對象,或者任何檢查點值沒有匹配的 Python 對象,則斷言將引發錯誤。

可以使用此方法加載來自 TensorFlow 1.x 的基於名稱的 tf.compat.v1.train.Saver 檢查點。名稱用於匹配變量。盡快使用tf.train.Checkpoint.save 重新編碼基於名稱的檢查點。

從 SavedModel 檢查點加載

要從 SavedModel 加載值,隻需將 SavedModel 目錄傳遞給 checkpoint.restore:

model = tf.keras.Model(...)
tf.saved_model.save(model, path)  # or model.save(path, save_format='tf')

checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()

這個例子在加載狀態上調用expect_partial(),因為從 Keras 保存的 SavedModels 通常會在檢查點中生成額外的鍵。否則,程序會在退出時打印很多關於未使用 key 的警告。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.train.Checkpoint.restore。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。