保存和恢複變量。
用法
tf.compat.v1.train.Saver(
var_list=None, reshape=False, sharded=False, max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False,
saver_def=None, builder=None, defer_build=False, allow_empty=False,
write_version=tf.train.SaverDef.V2, pad_step_number=False,
save_relative_paths=False, filename=None
)
參數
-
var_list
Variable
/SaveableObject
的列表,或映射名稱到SaveableObject
的字典。如果None
,默認為所有可保存對象的列表。 -
reshape
如果True
,允許從變量具有不同形狀的檢查點恢複參數。 -
sharded
如果True
,將檢查點分片,每個設備一個。 -
max_to_keep
要保留的最近檢查點的最大數量。默認為 5。 -
keep_checkpoint_every_n_hours
多久設置一次檢查點。默認為 10,000 小時。 -
name
String 。添加操作時用作前綴的可選名稱。 -
restore_sequentially
ABool
,如果為真,則會導致在每個設備內按順序恢複不同的變量。這可以在恢複非常大的模型時降低內存使用量。 -
saver_def
使用可選的SaverDef
proto 而不是運行構建器。這僅對想要為先前構建的具有Saver
的Graph
重新創建Saver
對象的特殊代碼有用。saver_def
原型應該是為該Graph
創建的Saver
的as_saver_def()
調用返回的原型。 -
builder
如果未提供saver_def
,則使用可選的SaverBuilder
。默認為BulkSaverBuilder()
。 -
defer_build
如果True
,請推遲將保存和恢複操作添加到build()
調用。在這種情況下,應在完成圖形或使用保護程序之前調用build()
。 -
allow_empty
如果False
(默認)在圖中沒有變量的情況下引發錯誤。否則,無論如何構建保護程序並將其設為no-op。 -
write_version
控製保存檢查點時使用的格式。它還會影響某些文件路徑匹配邏輯。推薦選擇 V2 格式:在所需的內存和還原期間產生的延遲方麵,它比 V1 優化得多。無論此標誌如何,Saver 都能夠從 V2 和 V1 檢查點恢複。 -
pad_step_number
如果為 True,則將檢查點文件路徑中的全局步驟編號填充到某個固定寬度(默認為 8)。這是默認關閉的。 -
save_relative_paths
如果True
,將寫入檢查點狀態文件的相對路徑。如果用戶想要複製檢查點目錄並從複製的目錄重新加載,則需要這樣做。 -
filename
如果在圖形構建時已知,則用於變量加載/保存的文件名。
拋出
-
TypeError
如果var_list
無效。 -
ValueError
如果var_list
中的任何鍵或值不是唯一的。 -
RuntimeError
如果啟用了即刻執行並且var_list
未指定要保存的變量列表。
屬性
-
last_checkpoints
not-yet-deleted 檢查點文件名列表。您可以將任何返回值傳遞給
restore()
。
遷移到 TF2
警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。
tf.compat.v1.train.Saver
不支持在 TF2 中保存和恢複檢查點。請切換到 tf.train.Checkpoint
或 tf.keras.Model.save_weights
,它們執行更強大的基於對象的保存。
如何重寫檢查點
請立即使用基於對象的檢查點 API 重寫您的檢查點。
您可以使用 tf.train.Checkpoint.restore
或 tf.keras.Model.load_weights
加載由 tf.compat.v1.train.Saver
編寫的基於名稱的檢查點。但是,您可能必須更改模型中的變量名稱以匹配基於名稱的檢查點中的變量名稱,可以使用 tf.train.list_variables(path)
查看。
另一種選擇是創建一個assignment_map
,將基於名稱的檢查點中的變量名稱映射到模型中的變量,例如:
{
'sequential/dense/bias':model.variables[0],
'sequential/dense/kernel':model.variables[1]
}
並使用tf.compat.v1.train.init_from_checkpoint恢複基於名稱的檢查點。
恢複後,使用 tf.train.Checkpoint.save
或 tf.keras.Model.save_weights
重新編碼您的檢查點。
有關更多詳細信息,請參閱遷移指南的檢查點兼容性部分。
TF2 中的檢查點管理
使用tf.train.CheckpointManager
管理 TF2 中的檢查點。 tf.train.CheckpointManager
提供等效的 keep_checkpoint_every_n_hours
和 max_to_keep
參數。
要恢複最新的檢查點,
checkpoint = tf.train.Checkpoint(model)
manager = tf.train.CheckpointManager(checkpoint)
status = checkpoint.restore(manager.latest_checkpoint)
tf.train.CheckpointManager
還編寫了一個 CheckpointState 原型,其中包含每個檢查點創建時的時間戳。
在 TF2 中寫入 MetaGraphDef
要替換 tf.compat.v1.train.Saver.save(write_meta_graph=True)
,請使用 tf.saved_model.save
編寫 MetaGraphDef
(包含在 saved_model.pb
中)。
有關變量、保存和恢複的概述,請參閱變量。
Saver
類添加了操作來保存和恢複檢查點的變量。它還提供了運行這些操作的便捷方法。
檢查點是專有格式的二進製文件,將變量名稱映射到張量值。檢查檢查點內容的最佳方法是使用 Saver
加載它。
Savers 可以使用提供的計數器自動編號檢查點文件名。這使您可以在訓練模型時在不同的步驟中保留多個檢查點。例如,您可以使用訓練步驟編號對檢查點文件名進行編號。為了避免填滿磁盤,保存程序會自動管理檢查點文件。例如,他們可以隻保留 N 個最近的文件,或者每 N 小時的訓練一個檢查點。
通過將值傳遞給可選的 global_step
參數給 save()
來為檢查點文件名編號:
saver.save(sess, 'my-model', global_step=0) ==> filename:'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename:'my-model-1000'
此外,Saver()
構造函數的可選參數允許您控製磁盤上檢查點文件的擴散:
max_to_keep
表示要保留的最近檢查點文件的最大數量。創建新文件時,會刪除舊文件。如果 None 或 0,則不會從文件係統中刪除檢查點,但隻有最後一個保留在checkpoint
文件中。默認為 5(即保留 5 個最近的檢查點文件。)keep_checkpoint_every_n_hours
:除了保留最新的max_to_keep
檢查點文件外,您可能希望每訓練 N 小時保留一個檢查點文件。如果您想稍後分析模型在長時間訓練期間的進展情況,這將很有用。例如,通過keep_checkpoint_every_n_hours=2
可確保您為每 2 小時的訓練保留一個檢查點文件。默認值 10,000 小時有效地禁用了該函數。
請注意,您仍然必須調用save()
方法來保存模型。將這些參數傳遞給構造函數不會自動為您保存變量。
定期保存的訓練計劃如下所示:
...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.compat.v1.Session()
for step in range(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
# Append the step number to the checkpoint name:
saver.save(sess, 'my-model', global_step=step)
除了檢查點文件之外,保存程序還在磁盤上保留一個協議緩衝區,其中包含最近檢查點的列表。這用於管理編號的檢查點文件和 latest_checkpoint()
,這使得發現最近檢查點的路徑變得容易。該協議緩衝區存儲在檢查點文件旁邊名為'checkpoint' 的文件中。
如果創建多個保存程序,則可以在調用 save()
時為協議緩衝區文件指定不同的文件名。
相關用法
- Python tf.compat.v1.train.Supervisor.managed_session用法及代碼示例
- Python tf.compat.v1.train.Supervisor用法及代碼示例
- Python tf.compat.v1.train.SessionManager用法及代碼示例
- Python tf.compat.v1.train.SingularMonitoredSession用法及代碼示例
- Python tf.compat.v1.train.SyncReplicasOptimizer用法及代碼示例
- Python tf.compat.v1.train.SingularMonitoredSession.run_step_fn用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代碼示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代碼示例
- Python tf.compat.v1.train.Optimizer用法及代碼示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代碼示例
- Python tf.compat.v1.train.Checkpoint用法及代碼示例
- Python tf.compat.v1.train.Checkpoint.restore用法及代碼示例
- Python tf.compat.v1.train.global_step用法及代碼示例
- Python tf.compat.v1.train.MonitoredSession.run_step_fn用法及代碼示例
- Python tf.compat.v1.train.RMSPropOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.train.exponential_decay用法及代碼示例
- Python tf.compat.v1.train.natural_exp_decay用法及代碼示例
- Python tf.compat.v1.train.MomentumOptimizer用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.train.Saver。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。