當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.train.Saver用法及代碼示例


保存和恢複變量。

用法

tf.compat.v1.train.Saver(
    var_list=None, reshape=False, sharded=False, max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False,
    saver_def=None, builder=None, defer_build=False, allow_empty=False,
    write_version=tf.train.SaverDef.V2, pad_step_number=False,
    save_relative_paths=False, filename=None
)

參數

  • var_list Variable /SaveableObject 的列表,或映射名稱到 SaveableObject 的字典。如果 None ,默認為所有可保存對象的列表。
  • reshape 如果 True ,允許從變量具有不同形狀的檢查點恢複參數。
  • sharded 如果 True ,將檢查點分片,每個設備一個。
  • max_to_keep 要保留的最近檢查點的最大數量。默認為 5。
  • keep_checkpoint_every_n_hours 多久設置一次檢查點。默認為 10,000 小時。
  • name String 。添加操作時用作前綴的可選名稱。
  • restore_sequentially A Bool ,如果為真,則會導致在每個設備內按順序恢複不同的變量。這可以在恢複非常大的模型時降低內存使用量。
  • saver_def 使用可選的SaverDef proto 而不是運行構建器。這僅對想要為先前構建的具有 SaverGraph 重新創建 Saver 對象的特殊代碼有用。 saver_def 原型應該是為該 Graph 創建的 Saveras_saver_def() 調用返回的原型。
  • builder 如果未提供 saver_def,則使用可選的 SaverBuilder。默認為 BulkSaverBuilder()
  • defer_build 如果 True ,請推遲將保存和恢複操作添加到 build() 調用。在這種情況下,應在完成圖形或使用保護程序之前調用 build()
  • allow_empty 如果 False(默認)在圖中沒有變量的情況下引發錯誤。否則,無論如何構建保護程序並將其設為no-op。
  • write_version 控製保存檢查點時使用的格式。它還會影響某些文件路徑匹配邏輯。推薦選擇 V2 格式:在所需的內存和還原期間產生的延遲方麵,它比 V1 優化得多。無論此標誌如何,Saver 都能夠從 V2 和 V1 檢查點恢複。
  • pad_step_number 如果為 True,則將檢查點文件路徑中的全局步驟編號填充到某個固定寬度(默認為 8)。這是默認關閉的。
  • save_relative_paths 如果 True ,將寫入檢查點狀態文件的相對路徑。如果用戶想要複製檢查點目錄並從複製的目錄重新加載,則需要這樣做。
  • filename 如果在圖形構建時已知,則用於變量加載/保存的文件名。

拋出

  • TypeError 如果var_list 無效。
  • ValueError 如果 var_list 中的任何鍵或值不是唯一的。
  • RuntimeError 如果啟用了即刻執行並且var_list 未指定要保存的變量列表。

屬性

  • last_checkpoints not-yet-deleted 檢查點文件名列表。

    您可以將任何返回值傳遞給 restore()

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

tf.compat.v1.train.Saver 不支持在 TF2 中保存和恢複檢查點。請切換到 tf.train.Checkpointtf.keras.Model.save_weights ,它們執行更強大的基於對象的保存。

如何重寫檢查點

請立即使用基於對象的檢查點 API 重寫您的檢查點。

您可以使用 tf.train.Checkpoint.restoretf.keras.Model.load_weights 加載由 tf.compat.v1.train.Saver 編寫的基於名稱的檢查點。但是,您可能必須更改模型中的變量名稱以匹配基於名稱的檢查點中的變量名稱,可以使用 tf.train.list_variables(path) 查看。

另一種選擇是創建一個assignment_map,將基於名稱的檢查點中的變量名稱映射到模型中的變量,例如:

{
    'sequential/dense/bias':model.variables[0],
    'sequential/dense/kernel':model.variables[1]
}

並使用tf.compat.v1.train.init_from_checkpoint恢複基於名稱的檢查點。

恢複後,使用 tf.train.Checkpoint.savetf.keras.Model.save_weights 重新編碼您的檢查點。

有關更多詳細信息,請參閱遷移指南的檢查點兼容性部分。

TF2 中的檢查點管理

使用tf.train.CheckpointManager 管理 TF2 中的檢查點。 tf.train.CheckpointManager 提供等效的 keep_checkpoint_every_n_hoursmax_to_keep 參數。

要恢複最新的檢查點,

checkpoint = tf.train.Checkpoint(model)
manager = tf.train.CheckpointManager(checkpoint)
status = checkpoint.restore(manager.latest_checkpoint)

tf.train.CheckpointManager 還編寫了一個 CheckpointState 原型,其中包含每個檢查點創建時的時間戳。

在 TF2 中寫入 MetaGraphDef

要替換 tf.compat.v1.train.Saver.save(write_meta_graph=True) ,請使用 tf.saved_model.save 編寫 MetaGraphDef (包含在 saved_model.pb 中)。

有關變量、保存和恢複的概述,請參閱變量。

Saver 類添加了操作來保存和恢複檢查點的變量。它還提供了運行這些操作的便捷方法。

檢查點是專有格式的二進製文件,將變量名稱映射到張量值。檢查檢查點內容的最佳方法是使用 Saver 加載它。

Savers 可以使用提供的計數器自動編號檢查點文件名。這使您可以在訓練模型時在不同的步驟中保留多個檢查點。例如,您可以使用訓練步驟編號對檢查點文件名進行編號。為了避免填滿磁盤,保存程序會自動管理檢查點文件。例如,他們可以隻保留 N 個最近的文件,或者每 N 小時的訓練一個檢查點。

通過將值傳遞給可選的 global_step 參數給 save() 來為檢查點文件名編號:

saver.save(sess, 'my-model', global_step=0) ==> filename:'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename:'my-model-1000'

此外,Saver() 構造函數的可選參數允許您控製磁盤上檢查點文件的擴散:

  • max_to_keep 表示要保留的最近檢查點文件的最大數量。創建新文件時,會刪除舊文件。如果 None 或 0,則不會從文件係統中刪除檢查點,但隻有最後一個保留在 checkpoint 文件中。默認為 5(即保留 5 個最近的檢查點文件。)

  • keep_checkpoint_every_n_hours :除了保留最新的max_to_keep 檢查點文件外,您可能希望每訓練 N 小時保留一個檢查點文件。如果您想稍後分析模型在長時間訓練期間的進展情況,這將很有用。例如,通過 keep_checkpoint_every_n_hours=2 可確保您為每 2 小時的訓練保留一個檢查點文件。默認值 10,000 小時有效地禁用了該函數。

請注意,您仍然必須調用save() 方法來保存模型。將這些參數傳遞給構造函數不會自動為您保存變量。

定期保存的訓練計劃如下所示:

...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.compat.v1.Session()
for step in range(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

除了檢查點文件之外,保存程序還在磁盤上保留一個協議緩衝區,其中包含最近檢查點的列表。這用於管理編號的檢查點文件和 latest_checkpoint() ,這使得發現最近檢查點的路徑變得容易。該協議緩衝區存儲在檢查點文件旁邊名為'checkpoint' 的文件中。

如果創建多個保存程序,則可以在調用 save() 時為協議緩衝區文件指定不同的文件名。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.train.Saver。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。