當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch save用法及代碼示例


本文簡要介紹python語言中 torch.jit.save 的用法。

用法:

torch.jit.save(m, f, _extra_files=None)

參數

  • m-A ScriptModule 要保存。

  • f-file-like 對象(必須實現寫入和刷新)或包含文件名的字符串。

  • _extra_files-從文件名映射到將作為 f 的一部分存儲的內容。

保存此模塊的離線版本以在單獨的過程中使用。保存的模塊序列化了該模塊的所有方法、子模塊、參數和屬性。它可以使用 torch::jit::load(filename) 加載到 C++ API 中,也可以使用 torch.jit.load 加載到 Python API 中。

為了能夠保存模塊,它不能對本機 Python 函數進行任何調用。這意味著所有子模塊也必須是 ScriptModule 的子類。

危險

所有模塊,無論其設備如何,在加載期間始終加載到 CPU 上。這與 torch.load() 的語義不同,將來可能會改變。

注意

torch.jit.save 嘗試跨版本保留某些運算符的行為。例如,將 PyTorch 1.5 中的兩個整數張量相除執行了樓層除法,如果包含該代碼的模塊保存在 PyTorch 1.5 中並加載到 PyTorch 1.6 中,則其除法行為將被保留。然而,在 PyTorch 1.6 中保存的相同模塊將無法在 PyTorch 1.5 中加載,因為除法的行為在 1.6 中發生了變化,並且 1.5 不知道如何複製 1.6 的行為。

例子:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.save。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。