用法:
class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)
params(
Iterable
) -torch.Tensor
的Iterable
給出所有參數,這些參數將跨等級分片。optimizer_class(
torch.nn.Optimizer
) -局部優化器的類。process_group(
ProcessGroup
, 可選的) -torch.distributed
ProcessGroup
(默認值:dist.group.WORLD
由torch.distributed.init_process_group()
初始化)。parameters_as_bucket_view(bool,可選的) -如果
True
,參數被打包到桶中以加速通信,並且param.data
字段指向不同偏移量的桶視圖;如果False
,每個單獨的參數單獨通信,每個params.data
保持不變(默認值:False
)。overlap_with_ddp(bool,可選的) -如果
True
,step()
與DistributedDataParallel
的梯度同步重疊;這需要 (1)optimizer_class
參數的函數優化器或具有等效函數的優化器,以及 (2) 注冊從ddp_zero_hook.py
中的一個函數構造的 DDP 通信鉤子;參數被打包到與DistributedDataParallel
中的匹配的桶中,這意味著parameters_as_bucket_view
參數被忽略。如果False
,step()
在向後傳遞後不相交地運行(按正常情況)。 (默認:False
)**defaults-任何尾隨參數,它們被轉發到本地優化器。
此類包裝任意
optim.Optimizer
並將其狀態分片到組中的等級,如 ZeRO 所述。每個等級的局部優化器實例隻負責更新大約1 / world_size
參數,因此隻需要保持1 / world_size
優化器狀態。在本地更新參數後,每個 rank 將其參數廣播給所有其他對等節點,以保持所有模型副本處於相同狀態。ZeroRedundancyOptimizer
可以與torch.nn.parallel.DistributedDataParallel
結合使用,以減少 per-rank 峰值內存消耗。ZeroRedundancyOptimizer
使用sorted-greedy 算法在每個等級打包多個參數。每個參數都屬於一個等級,並且不分等級。分區是任意的,可能與參數注冊或使用順序不匹配。例子:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer
要求所有 passed-in 參數都是相同的密集類型。警告
如果您通過
overlap_with_ddp=True
,請注意以下事項:鑒於當前實現DistributedDataParallel
與ZeroRedundancyOptimizer
重疊的方式,前兩個或三個訓練迭代不會在優化器步驟中執行參數更新,具體取決於是否static_graph=False
或static_graph=True
,分別。這是因為它需要有關DistributedDataParallel
使用的梯度分桶策略的信息,如果static_graph=False
直到第二次前向傳遞或如果static_graph=True
直到第三次前向傳遞才最終確定。要對此進行調整,一種選擇是預先添加虛擬輸入。警告
ZeroRedundancyOptimizer 是實驗性的,可能會發生變化。
參數:
關鍵字參數:
相關用法
- Python torch.distributed.optim.DistributedOptimizer用法及代碼示例
- Python torch.distributed.rpc.rpc_async用法及代碼示例
- Python torch.distributed.TCPStore用法及代碼示例
- Python torch.distributed.pipeline.sync.skip.skippable.stash用法及代碼示例
- Python torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook用法及代碼示例
- Python torch.distributed.all_reduce用法及代碼示例
- Python torch.distributed.gather_object用法及代碼示例
- Python torch.distributed.Store.set_timeout用法及代碼示例
- Python torch.distributed.rpc.functions.async_execution用法及代碼示例
- Python torch.distributed.all_gather_object用法及代碼示例
- Python torch.distributed.elastic.timer.expires用法及代碼示例
- Python torch.distributed.elastic.agent.server.ElasticAgent用法及代碼示例
- Python torch.distributed.rpc.rpc_sync用法及代碼示例
- Python torch.distributed.elastic.metrics.prof用法及代碼示例
- Python torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent用法及代碼示例
- Python torch.distributed.FileStore用法及代碼示例
- Python torch.distributed.Store.num_keys用法及代碼示例
- Python torch.distributed.pipeline.sync.skip.skippable.skippable用法及代碼示例
- Python torch.distributed.all_to_all用法及代碼示例
- Python torch.distributed.Store.get用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.optim.ZeroRedundancyOptimizer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。