本文简要介绍python语言中 torch.load
的用法。
用法:
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
f-file-like 对象(必须实现
read()
、readline()
、tell()
和seek()
),或包含文件名的字符串或 os.PathLike 对象map_location-指定如何重新映射存储位置的函数、
torch.device
、字符串或字典pickle_module-用于 unpickling 元数据和对象的模块(必须与用于序列化文件的
pickle_module
匹配)pickle_load_args-(仅限 Python 3)传递给
pickle_module.load()
和pickle_module.Unpickler()
的可选关键字参数,例如errors=...
。
从文件中加载使用
torch.save()
保存的对象。torch.load()
使用 Python 的 unpickling 工具,但特别对待张量底层的存储。它们首先在 CPU 上反序列化,然后移动到保存它们的设备。如果失败(例如,因为运行时系统没有某些设备),则会引发异常。但是,可以使用map_location
参数将存储动态重新映射到一组备用设备。如果
map_location
是可调用的,则将为每个序列化存储调用一次,并带有两个参数:存储和位置。存储参数将是驻留在 CPU 上的存储的初始反序列化。每个序列化存储都有一个与其关联的位置标签,用于标识保存它的设备,该标签是传递给map_location
的第二个参数。对于 CPU 张量,内置位置标签是'cpu'
,对于 CUDA 张量,内置位置标签是'cuda:device_id'
(例如'cuda:2'
)。map_location
应返回None
或存储。如果map_location
返回一个存储,它将用作最终的反序列化对象,已经移动到正确的设备。否则,torch.load()
将回退到默认行为,就好像未指定map_location
一样。如果
map_location
是torch.device
对象或包含设备标签的字符串,则它指示应加载所有张量的位置。否则,如果
map_location
是一个字典,它将用于将文件中出现的位置标签(键)重新映射到指定存储位置(值)的位置标签。用户扩展可以使用
torch.serialization.register_package()
注册自己的位置标签以及标记和反序列化方法。警告
torch.load()
使用pickle
隐式模块,已知它是不安全的。可以构造恶意的 pickle 数据,在 unpickling 期间执行任意代码。切勿加载可能来自不受信任的来源或可能已被篡改的数据。仅加载您信任的数据.注意
当您在包含 GPU 张量的文件上调用
torch.load()
时,这些张量将默认加载到 GPU。您可以先调用torch.load(.., map_location='cpu')
,然后调用load_state_dict()
,以避免加载模型检查点时 GPU RAM 激增。注意
默认情况下,我们将字节字符串解码为
utf-8
。这是为了避免在 Python 3 中加载由 Python 2 保存的文件时出现常见错误情况UnicodeDecodeError: 'ascii' codec can't decode byte 0x...
。如果此默认值不正确,您可以使用额外的encoding
关键字参数来指定应如何加载这些对象,例如,encoding='latin1'
使用latin1
编码将它们解码为字符串,而encoding='bytes'
将它们保存为字节数组,稍后可以使用byte_array.decode(...)
对其进行解码。示例
>>> torch.load('tensors.pt') # Load all tensors onto the CPU >>> torch.load('tensors.pt', map_location=torch.device('cpu')) # Load all tensors onto the CPU, using a function >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) # Load all tensors onto GPU 1 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) # Load tensor from io.BytesIO object >>> with open('tensor.pt', 'rb') as f: ... buffer = io.BytesIO(f.read()) >>> torch.load(buffer) # Load a module with 'ascii' encoding for unpickling >>> torch.load('module.pt', encoding='ascii')
参数:
相关用法
- Python PyTorch load_state_dict_from_url用法及代码示例
- Python PyTorch load_sp_model用法及代码示例
- Python PyTorch load_url用法及代码示例
- Python PyTorch load_inline用法及代码示例
- Python PyTorch load用法及代码示例
- Python PyTorch log2用法及代码示例
- Python PyTorch logical_xor用法及代码示例
- Python PyTorch logical_and用法及代码示例
- Python PyTorch log_softmax用法及代码示例
- Python PyTorch logical_or用法及代码示例
- Python PyTorch logit用法及代码示例
- Python PyTorch logical_not用法及代码示例
- Python PyTorch logcumsumexp用法及代码示例
- Python PyTorch log10用法及代码示例
- Python PyTorch logaddexp用法及代码示例
- Python PyTorch logdet用法及代码示例
- Python PyTorch log用法及代码示例
- Python PyTorch logsumexp用法及代码示例
- Python PyTorch logspace用法及代码示例
- Python PyTorch log1p用法及代码示例
- Python PyTorch lstsq用法及代码示例
- Python PyTorch lerp用法及代码示例
- Python PyTorch lt用法及代码示例
- Python PyTorch lgamma用法及代码示例
- Python PyTorch lazy_apply用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.load。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。