用法:
class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)params(
Iterable) -torch.Tensor的Iterable給出所有參數,這些參數將跨等級分片。optimizer_class(
torch.nn.Optimizer) -局部優化器的類。process_group(
ProcessGroup, 可選的) -torch.distributedProcessGroup(默認值:dist.group.WORLD由torch.distributed.init_process_group()初始化)。parameters_as_bucket_view(bool,可選的) -如果
True,參數被打包到桶中以加速通信,並且param.data字段指向不同偏移量的桶視圖;如果False,每個單獨的參數單獨通信,每個params.data保持不變(默認值:False)。overlap_with_ddp(bool,可選的) -如果
True,step()與DistributedDataParallel的梯度同步重疊;這需要 (1)optimizer_class參數的函數優化器或具有等效函數的優化器,以及 (2) 注冊從ddp_zero_hook.py中的一個函數構造的 DDP 通信鉤子;參數被打包到與DistributedDataParallel中的匹配的桶中,這意味著parameters_as_bucket_view參數被忽略。如果False,step()在向後傳遞後不相交地運行(按正常情況)。 (默認:False)**defaults-任何尾隨參數,它們被轉發到本地優化器。
此類包裝任意
optim.Optimizer並將其狀態分片到組中的等級,如 ZeRO 所述。每個等級的局部優化器實例隻負責更新大約1 / world_size參數,因此隻需要保持1 / world_size優化器狀態。在本地更新參數後,每個 rank 將其參數廣播給所有其他對等節點,以保持所有模型副本處於相同狀態。ZeroRedundancyOptimizer可以與torch.nn.parallel.DistributedDataParallel結合使用,以減少 per-rank 峰值內存消耗。ZeroRedundancyOptimizer使用sorted-greedy 算法在每個等級打包多個參數。每個參數都屬於一個等級,並且不分等級。分區是任意的,可能與參數注冊或使用順序不匹配。例子:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()警告
目前,
ZeroRedundancyOptimizer要求所有 passed-in 參數都是相同的密集類型。警告
如果您通過
overlap_with_ddp=True,請注意以下事項:鑒於當前實現DistributedDataParallel與ZeroRedundancyOptimizer重疊的方式,前兩個或三個訓練迭代不會在優化器步驟中執行參數更新,具體取決於是否static_graph=False或static_graph=True,分別。這是因為它需要有關DistributedDataParallel使用的梯度分桶策略的信息,如果static_graph=False直到第二次前向傳遞或如果static_graph=True直到第三次前向傳遞才最終確定。要對此進行調整,一種選擇是預先添加虛擬輸入。警告
ZeroRedundancyOptimizer 是實驗性的,可能會發生變化。
參數:
關鍵字參數:
相關用法
- Python torch.distributed.optim.DistributedOptimizer用法及代碼示例
- Python torch.distributed.rpc.rpc_async用法及代碼示例
- Python torch.distributed.TCPStore用法及代碼示例
- Python torch.distributed.pipeline.sync.skip.skippable.stash用法及代碼示例
- Python torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook用法及代碼示例
- Python torch.distributed.all_reduce用法及代碼示例
- Python torch.distributed.gather_object用法及代碼示例
- Python torch.distributed.Store.set_timeout用法及代碼示例
- Python torch.distributed.rpc.functions.async_execution用法及代碼示例
- Python torch.distributed.all_gather_object用法及代碼示例
- Python torch.distributed.elastic.timer.expires用法及代碼示例
- Python torch.distributed.elastic.agent.server.ElasticAgent用法及代碼示例
- Python torch.distributed.rpc.rpc_sync用法及代碼示例
- Python torch.distributed.elastic.metrics.prof用法及代碼示例
- Python torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent用法及代碼示例
- Python torch.distributed.FileStore用法及代碼示例
- Python torch.distributed.Store.num_keys用法及代碼示例
- Python torch.distributed.pipeline.sync.skip.skippable.skippable用法及代碼示例
- Python torch.distributed.all_to_all用法及代碼示例
- Python torch.distributed.Store.get用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.optim.ZeroRedundancyOptimizer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
