本文简要介绍python语言中 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook
的用法。
用法:
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)
state(PowerSGDState) -用于配置压缩率和支持错误反馈、热启动等的状态信息。要调整压缩配置,主要需要调整
matrix_approximation_rank
和start_powerSGD_iter
。bucket(dist.GradBucket) -存储一维扁平梯度张量的桶,该张量批量处理多个每个变量的张量。请注意,由于 DDP 通信钩子仅支持单进程单设备模式,因此该存储桶中仅存储了一个张量。
通信的未来处理程序,它更新到位的梯度。
这个 DDP 通信钩子实现了一个简化的 PowerSGD 梯度压缩算法,在纸.此变体不会逐层压缩梯度,而是压缩对所有梯度进行批处理的扁平化输入张量。因此它是快点比
powerSGD_hook()
,但通常会导致精度低得多, 除非matrix_approximation_rank
是 1。警告
此处增加
matrix_approximation_rank
不一定会提高准确性,因为在没有列/行对齐的情况下对每个参数张量进行批处理可能会破坏低秩结构。因此,用户应始终首先考虑powerSGD_hook()
,并且仅当matrix_approximation_rank
为1时能够达到满意的精度时才考虑此变体。一旦梯度张量在所有工作人员中聚合,此挂钩将按如下方式应用压缩:
将输入扁平化的 1D 梯度张量视为具有 0 个填充的 square-shaped 张量 M;
创建两个低秩张量 P 和 Q 用于分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;
计算 P,它等于 MQ;
全部减少 P;
正交化 P;
计算 Q,它大约等于 M^TP;
全部减少 Q;
计算 M,它大约等于 PQ^T。
将输入张量截断为原始长度。
请注意,此通信挂钩对第一个
state.start_powerSGD_iter
迭代强制执行 vanilla allreduce。这不仅让用户可以更好地控制加速和准确性之间的权衡,还有助于为未来的通信钩子开发人员抽象出 DDP 内部优化的一些复杂性。>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
例子:
参数:
返回:
相关用法
- Python PyTorch backward用法及代码示例
- Python PyTorch baddbmm用法及代码示例
- Python PyTorch bincount用法及代码示例
- Python PyTorch bitwise_right_shift用法及代码示例
- Python PyTorch bernoulli用法及代码示例
- Python PyTorch bitwise_and用法及代码示例
- Python PyTorch bitwise_not用法及代码示例
- Python PyTorch binary_cross_entropy用法及代码示例
- Python PyTorch bitwise_xor用法及代码示例
- Python PyTorch binary_cross_entropy_with_logits用法及代码示例
- Python PyTorch bleu_score用法及代码示例
- Python PyTorch broadcast_tensors用法及代码示例
- Python PyTorch build_vocab_from_iterator用法及代码示例
- Python PyTorch broadcast_object_list用法及代码示例
- Python PyTorch broadcast_shapes用法及代码示例
- Python PyTorch bitwise_or用法及代码示例
- Python PyTorch bitwise_left_shift用法及代码示例
- Python PyTorch bf16_compress_wrapper用法及代码示例
- Python PyTorch bmm用法及代码示例
- Python PyTorch broadcast_to用法及代码示例
- Python PyTorch bf16_compress_hook用法及代码示例
- Python PyTorch bucketize用法及代码示例
- Python PyTorch block_diag用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。