當前位置: 首頁>>技術問答>>正文


在PyTorch中保存訓練模型的最佳方法?

我正在尋找在PyTorch中保存訓練模型的不同方法。到目前為止,發現了兩種選擇。

  1. torch.save()用於保存模型,torch.load()用於加載模型。

  2. model.state_dict()用於保存訓練模型,model.load_state_dict()用於加載保存的模型。

我已經遇到過這個discussion,其中推薦方法2優於方法1。

我的問題是,為什麽第二種方法更受歡迎?是否因為torch.nn模塊具有這兩個功能所以鼓勵使用它們?

最佳解決思路

我在github repo上找到了this page,我隻是在這裏原樣粘貼複製內容過來:


保存模型的推薦方法

序列化和恢複模型有兩種主要方法。

第一個(推薦)保存並僅加載模型參數:

torch.save(the_model.state_dict(), PATH)

然後:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二個保存並加載整個模型:

torch.save(the_model, PATH)

然後:

the_model = torch.load(PATH)

但是在這種情況下,序列化數據綁定到特定的類和使用的確切目錄結構,因此當在其他項目中使用時,或者在一些嚴重的重構之後,它可能以各種奇怪的方式中斷活出問題。

次佳解決思路

這取決於你想做什麽。

案例#1:保存模型以自行使用它進行預測:保存模型,恢複模型,然後將模型更改為評估模式。這樣做是因為通常有BatchNormDropout圖層,默認情況下在構造時處於訓練模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便以後恢複訓練:如果您需要繼續訓練您將要保存的模型,則需要保存的不僅僅是模型。您還需要保存優化器,迭代輪次(epochs),分數等相關的狀態。您可以這樣做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢複訓練,您可以執行以下操作:state = torch.load(filepath),然後,恢複每個對象的狀態,如下所示:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由於您正在恢複訓練,因此在加載時恢複狀態後請勿調用model.eval()

案例#3:其他人無法訪問您的代碼使用的模型:在Tensorflow中,您可以創建一個.pb文件,該文件定義了模型的體係結構和權重。這非常方便,特別是在使用Tensorflow serve時。在Pytorch中執行此操作的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

這種方式不太健壯,因為pytorch已經經曆了很多變化,我們不太推薦這種方式。

pytorch

參考資料

本文由《純淨天空》出品。文章地址: https://vimsky.com/zh-tw/article/3884.html,未經允許,請勿轉載。