当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python torch.distributed.optim.ZeroRedundancyOptimizer用法及代码示例


用法:

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)

参数

params(Iterable) -torch.TensorIterable 给出所有参数,这些参数将跨等级分片。

关键字参数

  • optimizer_class(torch.nn.Optimizer) -局部优化器的类。

  • process_group(ProcessGroup, 可选的) -torch.distributed ProcessGroup(默认值:dist.group.WORLDtorch.distributed.init_process_group() 初始化)。

  • parameters_as_bucket_view(bool,可选的) -如果 True ,参数被打包到桶中以加速通信,并且 param.data 字段指向不同偏移量的桶视图;如果 False ,每个单独的参数单独通信,每个 params.data 保持不变(默认值:False)。

  • overlap_with_ddp(bool,可选的) -如果Truestep()DistributedDataParallel的梯度同步重叠;这需要 (1) optimizer_class 参数的函数优化器或具有等效函数的优化器,以及 (2) 注册从 ddp_zero_hook.py 中的一个函数构造的 DDP 通信钩子;参数被打包到与 DistributedDataParallel 中的匹配的桶中,这意味着 parameters_as_bucket_view 参数被忽略。如果 Falsestep() 在向后传递后不相交地运行(按正常情况)。 (默认:False)

  • **defaults-任何尾随参数,它们被转发到本地优化器。

此类包装任意 optim.Optimizer 并将其状态分片到组中的等级,如 ZeRO 所述。每个等级的局部优化器实例只负责更新大约1 / world_size参数,因此只需要保持1 / world_size优化器状态。在本地更新参数后,每个 rank 将其参数广播给所有其他对等节点,以保持所有模型副本处于相同状态。 ZeroRedundancyOptimizer 可以与 torch.nn.parallel.DistributedDataParallel 结合使用,以减少 per-rank 峰值内存消耗。

ZeroRedundancyOptimizer 使用sorted-greedy 算法在每个等级打包多个参数。每个参数都属于一个等级,并且不分等级。分区是任意的,可能与参数注册或使用顺序不匹配。

例子:

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP

>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

目前,ZeroRedundancyOptimizer 要求所有 passed-in 参数都是相同的密集类型。

警告

如果您通过 overlap_with_ddp=True ,请注意以下事项:鉴于当前实现 DistributedDataParallelZeroRedundancyOptimizer 重叠的方式,前两个或三个训练迭代不会在优化器步骤中执行参数更新,具体取决于是否 static_graph=Falsestatic_graph=True ,分别。这是因为它需要有关 DistributedDataParallel 使用的梯度分桶策略的信息,如果 static_graph=False 直到第二次前向传递或如果 static_graph=True 直到第三次前向传递才最终确定。要对此进行调整,一种选择是预先添加虚拟输入。

警告

ZeroRedundancyOptimizer 是实验性的,可能会发生变化。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.optim.ZeroRedundancyOptimizer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。