用法:
class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)
params(
Iterable
) -torch.Tensor
的Iterable
给出所有参数,这些参数将跨等级分片。optimizer_class(
torch.nn.Optimizer
) -局部优化器的类。process_group(
ProcessGroup
, 可选的) -torch.distributed
ProcessGroup
(默认值:dist.group.WORLD
由torch.distributed.init_process_group()
初始化)。parameters_as_bucket_view(bool,可选的) -如果
True
,参数被打包到桶中以加速通信,并且param.data
字段指向不同偏移量的桶视图;如果False
,每个单独的参数单独通信,每个params.data
保持不变(默认值:False
)。overlap_with_ddp(bool,可选的) -如果
True
,step()
与DistributedDataParallel
的梯度同步重叠;这需要 (1)optimizer_class
参数的函数优化器或具有等效函数的优化器,以及 (2) 注册从ddp_zero_hook.py
中的一个函数构造的 DDP 通信钩子;参数被打包到与DistributedDataParallel
中的匹配的桶中,这意味着parameters_as_bucket_view
参数被忽略。如果False
,step()
在向后传递后不相交地运行(按正常情况)。 (默认:False
)**defaults-任何尾随参数,它们被转发到本地优化器。
此类包装任意
optim.Optimizer
并将其状态分片到组中的等级,如 ZeRO 所述。每个等级的局部优化器实例只负责更新大约1 / world_size
参数,因此只需要保持1 / world_size
优化器状态。在本地更新参数后,每个 rank 将其参数广播给所有其他对等节点,以保持所有模型副本处于相同状态。ZeroRedundancyOptimizer
可以与torch.nn.parallel.DistributedDataParallel
结合使用,以减少 per-rank 峰值内存消耗。ZeroRedundancyOptimizer
使用sorted-greedy 算法在每个等级打包多个参数。每个参数都属于一个等级,并且不分等级。分区是任意的,可能与参数注册或使用顺序不匹配。例子:
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer
要求所有 passed-in 参数都是相同的密集类型。警告
如果您通过
overlap_with_ddp=True
,请注意以下事项:鉴于当前实现DistributedDataParallel
与ZeroRedundancyOptimizer
重叠的方式,前两个或三个训练迭代不会在优化器步骤中执行参数更新,具体取决于是否static_graph=False
或static_graph=True
,分别。这是因为它需要有关DistributedDataParallel
使用的梯度分桶策略的信息,如果static_graph=False
直到第二次前向传递或如果static_graph=True
直到第三次前向传递才最终确定。要对此进行调整,一种选择是预先添加虚拟输入。警告
ZeroRedundancyOptimizer 是实验性的,可能会发生变化。
参数:
关键字参数:
相关用法
- Python torch.distributed.optim.DistributedOptimizer用法及代码示例
- Python torch.distributed.rpc.rpc_async用法及代码示例
- Python torch.distributed.TCPStore用法及代码示例
- Python torch.distributed.pipeline.sync.skip.skippable.stash用法及代码示例
- Python torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook用法及代码示例
- Python torch.distributed.all_reduce用法及代码示例
- Python torch.distributed.gather_object用法及代码示例
- Python torch.distributed.Store.set_timeout用法及代码示例
- Python torch.distributed.rpc.functions.async_execution用法及代码示例
- Python torch.distributed.all_gather_object用法及代码示例
- Python torch.distributed.elastic.timer.expires用法及代码示例
- Python torch.distributed.elastic.agent.server.ElasticAgent用法及代码示例
- Python torch.distributed.rpc.rpc_sync用法及代码示例
- Python torch.distributed.elastic.metrics.prof用法及代码示例
- Python torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent用法及代码示例
- Python torch.distributed.FileStore用法及代码示例
- Python torch.distributed.Store.num_keys用法及代码示例
- Python torch.distributed.pipeline.sync.skip.skippable.skippable用法及代码示例
- Python torch.distributed.all_to_all用法及代码示例
- Python torch.distributed.Store.get用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.optim.ZeroRedundancyOptimizer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。