當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python torch.distributed.optim.ZeroRedundancyOptimizer用法及代碼示例


用法:

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)

參數

params(Iterable) -torch.TensorIterable 給出所有參數,這些參數將跨等級分片。

關鍵字參數

  • optimizer_class(torch.nn.Optimizer) -局部優化器的類。

  • process_group(ProcessGroup, 可選的) -torch.distributed ProcessGroup(默認值:dist.group.WORLDtorch.distributed.init_process_group() 初始化)。

  • parameters_as_bucket_view(bool,可選的) -如果 True ,參數被打包到桶中以加速通信,並且 param.data 字段指向不同偏移量的桶視圖;如果 False ,每個單獨的參數單獨通信,每個 params.data 保持不變(默認值:False)。

  • overlap_with_ddp(bool,可選的) -如果Truestep()DistributedDataParallel的梯度同步重疊;這需要 (1) optimizer_class 參數的函數優化器或具有等效函數的優化器,以及 (2) 注冊從 ddp_zero_hook.py 中的一個函數構造的 DDP 通信鉤子;參數被打包到與 DistributedDataParallel 中的匹配的桶中,這意味著 parameters_as_bucket_view 參數被忽略。如果 Falsestep() 在向後傳遞後不相交地運行(按正常情況)。 (默認:False)

  • **defaults-任何尾隨參數,它們被轉發到本地優化器。

此類包裝任意 optim.Optimizer 並將其狀態分片到組中的等級,如 ZeRO 所述。每個等級的局部優化器實例隻負責更新大約1 / world_size參數,因此隻需要保持1 / world_size優化器狀態。在本地更新參數後,每個 rank 將其參數廣播給所有其他對等節點,以保持所有模型副本處於相同狀態。 ZeroRedundancyOptimizer 可以與 torch.nn.parallel.DistributedDataParallel 結合使用,以減少 per-rank 峰值內存消耗。

ZeroRedundancyOptimizer 使用sorted-greedy 算法在每個等級打包多個參數。每個參數都屬於一個等級,並且不分等級。分區是任意的,可能與參數注冊或使用順序不匹配。

例子:

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP

>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

目前,ZeroRedundancyOptimizer 要求所有 passed-in 參數都是相同的密集類型。

警告

如果您通過 overlap_with_ddp=True ,請注意以下事項:鑒於當前實現 DistributedDataParallelZeroRedundancyOptimizer 重疊的方式,前兩個或三個訓練迭代不會在優化器步驟中執行參數更新,具體取決於是否 static_graph=Falsestatic_graph=True ,分別。這是因為它需要有關 DistributedDataParallel 使用的梯度分桶策略的信息,如果 static_graph=False 直到第二次前向傳遞或如果 static_graph=True 直到第三次前向傳遞才最終確定。要對此進行調整,一種選擇是預先添加虛擬輸入。

警告

ZeroRedundancyOptimizer 是實驗性的,可能會發生變化。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.optim.ZeroRedundancyOptimizer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。