當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch MpModelWrapper用法及代碼示例

本文簡要介紹python語言中 torch_xla.distributed.xla_multiprocessing.MpModelWrapper 的用法。

用法:

class torch_xla.distributed.xla_multiprocessing.MpModelWrapper(model)

當使用fork 方法時,包裝一個模型以最小化主機內存使用。

此類應與spawn(…, start_method=’fork’) API 一起使用,以盡量減少主機內存的使用。不是在每個多處理進程上創建模型,從而複製模型的初始主機內存,而是在全局範圍內創建一次模型,然後移動到 spawn() 目標函數內的每個設備中。例子:

WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())

def _mp_fn(index, ...):
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  ...

xmp.spawn(_mp_fn, ..., start_method='fork')

這種方法有兩個優點。首先,如果僅使用內存頁麵的一個副本來托管原始模型的權重,其次它通過在此過程中降低係統內存的負載,將包裝模型的移動序列化到每個設備中。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch_xla.distributed.xla_multiprocessing.MpModelWrapper。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。