我正在尋找在PyTorch中保存訓練模型的不同方法。到目前為止,發現了兩種選擇。
-
torch.save()用於保存模型,torch.load()用於加載模型。
-
model.state_dict()用於保存訓練模型,model.load_state_dict()用於加載保存的模型。
我已經遇到過這個discussion,其中推薦方法2優於方法1。
我的問題是,為什麽第二種方法更受歡迎?是否因為torch.nn模塊具有這兩個功能所以鼓勵使用它們?
最佳解決思路
我在github repo上找到了this page,我隻是在這裏原樣粘貼複製內容過來:
保存模型的推薦方法
序列化和恢複模型有兩種主要方法。
第一個(推薦)保存並僅加載模型參數:
torch.save(the_model.state_dict(), PATH)
然後:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二個保存並加載整個模型:
torch.save(the_model, PATH)
然後:
the_model = torch.load(PATH)
但是在這種情況下,序列化數據綁定到特定的類和使用的確切目錄結構,因此當在其他項目中使用時,或者在一些嚴重的重構之後,它可能以各種奇怪的方式中斷活出問題。
次佳解決思路
這取決於你想做什麽。
案例#1:保存模型以自行使用它進行預測:保存模型,恢複模型,然後將模型更改為評估模式。這樣做是因為通常有BatchNorm
和Dropout
圖層,默認情況下在構造時處於訓練模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例#2:保存模型以便以後恢複訓練:如果您需要繼續訓練您將要保存的模型,則需要保存的不僅僅是模型。您還需要保存優化器,迭代輪次(epochs),分數等相關的狀態。您可以這樣做:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
要恢複訓練,您可以執行以下操作:state = torch.load(filepath)
,然後,恢複每個對象的狀態,如下所示:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
由於您正在恢複訓練,因此在加載時恢複狀態後請勿調用model.eval()
。
案例#3:其他人無法訪問您的代碼使用的模型:在Tensorflow中,您可以創建一個.pb
文件,該文件定義了模型的體係結構和權重。這非常方便,特別是在使用Tensorflow serve
時。在Pytorch中執行此操作的等效方法是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
這種方式不太健壯,因為pytorch已經經曆了很多變化,我們不太推薦這種方式。