当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch batched_powerSGD_hook用法及代码示例


本文简要介绍python语言中 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook 的用法。

用法:

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)

参数

  • state(PowerSGDState) -用于配置压缩率和支持错误反馈、热启动等的状态信息。要调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_iter

  • bucket(dist.GradBucket) -存储一维扁平梯度张量的桶,该张量批量处理多个每个变量的张量。请注意,由于 DDP 通信钩子仅支持单进程单设备模式,因此该存储桶中仅存储了一个张量。

返回

通信的未来处理程序,它更新到位的梯度。

这个 DDP 通信钩子实现了一个简化的 PowerSGD 梯度压缩算法,在.此变体不会逐层压缩梯度,而是压缩对所有梯度进行批处理的扁平化输入张量。因此它是快点powerSGD_hook(),但通常会导致精度低得多, 除非matrix_approximation_rank是 1。

警告

此处增加 matrix_approximation_rank 不一定会提高准确性,因为在没有列/行对齐的情况下对每个参数张量进行批处理可能会破坏低秩结构。因此,用户应始终首先考虑powerSGD_hook(),并且仅当matrix_approximation_rank为1时能够达到满意的精度时才考虑此变体。

一旦梯度张量在所有工作人员中聚合,此挂钩将按如下方式应用压缩:

  1. 将输入扁平化的 1D 梯度张量视为具有 0 个填充的 square-shaped 张量 M;

  2. 创建两个低秩张量 P 和 Q 用于分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

  3. 计算 P,它等于 MQ;

  4. 全部减少 P;

  5. 正交化 P;

  6. 计算 Q,它大约等于 M^TP;

  7. 全部减少 Q;

  8. 计算 M,它大约等于 PQ^T。

  9. 将输入张量截断为原始长度。

请注意,此通信挂钩对第一个 state.start_powerSGD_iter 迭代强制执行 vanilla allreduce。这不仅让用户可以更好地控制加速和准确性之间的权衡,还有助于为未来的通信钩子开发人员抽象出 DDP 内部优化的一些复杂性。

例子:

>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。