本文簡要介紹python語言中 torch.load
的用法。
用法:
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
f-file-like 對象(必須實現
read()
、readline()
、tell()
和seek()
),或包含文件名的字符串或 os.PathLike 對象map_location-指定如何重新映射存儲位置的函數、
torch.device
、字符串或字典pickle_module-用於 unpickling 元數據和對象的模塊(必須與用於序列化文件的
pickle_module
匹配)pickle_load_args-(僅限 Python 3)傳遞給
pickle_module.load()
和pickle_module.Unpickler()
的可選關鍵字參數,例如errors=...
。
從文件中加載使用
torch.save()
保存的對象。torch.load()
使用 Python 的 unpickling 工具,但特別對待張量底層的存儲。它們首先在 CPU 上反序列化,然後移動到保存它們的設備。如果失敗(例如,因為運行時係統沒有某些設備),則會引發異常。但是,可以使用map_location
參數將存儲動態重新映射到一組備用設備。如果
map_location
是可調用的,則將為每個序列化存儲調用一次,並帶有兩個參數:存儲和位置。存儲參數將是駐留在 CPU 上的存儲的初始反序列化。每個序列化存儲都有一個與其關聯的位置標簽,用於標識保存它的設備,該標簽是傳遞給map_location
的第二個參數。對於 CPU 張量,內置位置標簽是'cpu'
,對於 CUDA 張量,內置位置標簽是'cuda:device_id'
(例如'cuda:2'
)。map_location
應返回None
或存儲。如果map_location
返回一個存儲,它將用作最終的反序列化對象,已經移動到正確的設備。否則,torch.load()
將回退到默認行為,就好像未指定map_location
一樣。如果
map_location
是torch.device
對象或包含設備標簽的字符串,則它指示應加載所有張量的位置。否則,如果
map_location
是一個字典,它將用於將文件中出現的位置標簽(鍵)重新映射到指定存儲位置(值)的位置標簽。用戶擴展可以使用
torch.serialization.register_package()
注冊自己的位置標簽以及標記和反序列化方法。警告
torch.load()
使用pickle
隱式模塊,已知它是不安全的。可以構造惡意的 pickle 數據,在 unpickling 期間執行任意代碼。切勿加載可能來自不受信任的來源或可能已被篡改的數據。僅加載您信任的數據.注意
當您在包含 GPU 張量的文件上調用
torch.load()
時,這些張量將默認加載到 GPU。您可以先調用torch.load(.., map_location='cpu')
,然後調用load_state_dict()
,以避免加載模型檢查點時 GPU RAM 激增。注意
默認情況下,我們將字節字符串解碼為
utf-8
。這是為了避免在 Python 3 中加載由 Python 2 保存的文件時出現常見錯誤情況UnicodeDecodeError: 'ascii' codec can't decode byte 0x...
。如果此默認值不正確,您可以使用額外的encoding
關鍵字參數來指定應如何加載這些對象,例如,encoding='latin1'
使用latin1
編碼將它們解碼為字符串,而encoding='bytes'
將它們保存為字節數組,稍後可以使用byte_array.decode(...)
對其進行解碼。示例
>>> torch.load('tensors.pt') # Load all tensors onto the CPU >>> torch.load('tensors.pt', map_location=torch.device('cpu')) # Load all tensors onto the CPU, using a function >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) # Load all tensors onto GPU 1 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) # Load tensor from io.BytesIO object >>> with open('tensor.pt', 'rb') as f: ... buffer = io.BytesIO(f.read()) >>> torch.load(buffer) # Load a module with 'ascii' encoding for unpickling >>> torch.load('module.pt', encoding='ascii')
參數:
相關用法
- Python PyTorch load_state_dict_from_url用法及代碼示例
- Python PyTorch load_sp_model用法及代碼示例
- Python PyTorch load_url用法及代碼示例
- Python PyTorch load_inline用法及代碼示例
- Python PyTorch load用法及代碼示例
- Python PyTorch log2用法及代碼示例
- Python PyTorch logical_xor用法及代碼示例
- Python PyTorch logical_and用法及代碼示例
- Python PyTorch log_softmax用法及代碼示例
- Python PyTorch logical_or用法及代碼示例
- Python PyTorch logit用法及代碼示例
- Python PyTorch logical_not用法及代碼示例
- Python PyTorch logcumsumexp用法及代碼示例
- Python PyTorch log10用法及代碼示例
- Python PyTorch logaddexp用法及代碼示例
- Python PyTorch logdet用法及代碼示例
- Python PyTorch log用法及代碼示例
- Python PyTorch logsumexp用法及代碼示例
- Python PyTorch logspace用法及代碼示例
- Python PyTorch log1p用法及代碼示例
- Python PyTorch lstsq用法及代碼示例
- Python PyTorch lerp用法及代碼示例
- Python PyTorch lt用法及代碼示例
- Python PyTorch lgamma用法及代碼示例
- Python PyTorch lazy_apply用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.load。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。