当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch save用法及代码示例


本文简要介绍python语言中 torch.jit.save 的用法。

用法:

torch.jit.save(m, f, _extra_files=None)

参数

  • m-A ScriptModule 要保存。

  • f-file-like 对象(必须实现写入和刷新)或包含文件名的字符串。

  • _extra_files-从文件名映射到将作为 f 的一部分存储的内容。

保存此模块的离线版本以在单独的过程中使用。保存的模块序列化了该模块的所有方法、子模块、参数和属性。它可以使用 torch::jit::load(filename) 加载到 C++ API 中,也可以使用 torch.jit.load 加载到 Python API 中。

为了能够保存模块,它不能对本机 Python 函数进行任何调用。这意味着所有子模块也必须是 ScriptModule 的子类。

危险

所有模块,无论其设备如何,在加载期间始终加载到 CPU 上。这与 torch.load() 的语义不同,将来可能会改变。

注意

torch.jit.save 尝试跨版本保留某些运算符的行为。例如,将 PyTorch 1.5 中的两个整数张量相除执行了楼层除法,如果包含该代码的模块保存在 PyTorch 1.5 中并加载到 PyTorch 1.6 中,则其除法行为将被保留。然而,在 PyTorch 1.6 中保存的相同模块将无法在 PyTorch 1.5 中加载,因为除法的行为在 1.6 中发生了变化,并且 1.5 不知道如何复制 1.6 的行为。

例子:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.save。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。