本文简要介绍python语言中 torch.nn.parallel.DistributedDataParallel.register_comm_hook
的用法。
用法:
register_comm_hook(state, hook)
state(object) -
在训练过程中传递给钩子以维护任何状态信息。示例包括梯度压缩中的错误反馈、与 GossipGrad 中的 next 通信的对等点等。
它由每个工作人员本地存储,并由工作人员上的所有梯度张量共享。
hook(可调用的) -
可使用以下签名调用:
hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]
:一旦存储桶准备好,就会调用此函数。该钩子可以执行所需的任何处理,并返回一个 Future 指示任何异步工作的完成(例如:allreduce)。如果钩子不执行任何通信,它仍然必须返回一个完整的 Future。未来应该保存梯度桶张量的新值。一旦存储桶准备好,c10d reducer 将调用此钩子并使用 Future 返回的张量并将梯度复制到各个参数。请注意,Future 的返回类型必须是单个张量。
我们还提供了一个名为
get_future
的 API 来检索与c10d.ProcessGroup.Work
的完成相关的 Future。get_future
当前支持 NCCL,也支持 GLOO 和 MPI 上的大多数操作,除了点对点操作(发送/接收)。
注册一个通信钩子,它是一种增强函数,为用户提供了一个灵活的钩子,用户可以在其中指定 DDP 如何在多个工作人员之间聚合梯度。
这个钩子对于研究人员尝试新想法非常有用。例如,此钩子可用于实现多种算法,例如GossipGrad和梯度压缩,这些算法涉及在运行分布式DataParallel训练时参数同步的不同通信策略。
警告
Grad bucket的张量不会被world_size预除。在 allreduce 等操作的情况下,用户负责除以world_size。
警告
DDP 通信钩子只能注册一次,并且应该在反向调用之前注册。
警告
hook 返回的 Future 对象应该包含一个与 grad 桶中的张量具有相同形状的单个张量。
警告
get_future
API 支持 NCCL,部分支持 GLOO 和 MPI 后端(不支持 peer-to-peer 操作,如发送/接收),并将返回torch.futures.Future
。下面是一个返回相同张量的 noop 钩子的示例。
>>> def noop(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor] >>> fut = torch.futures.Future() >>> fut.set_result(bucket.buffer()) >>> return fut
>>> ddp.register_comm_hook(state=None, hook=noop)
下面是一个并行 SGD 算法的示例,其中梯度在 allreduce 之前编码,然后在 allreduce 之后解码。
>>> def encode_and_decode(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor] >>> encoded_tensor = encode(bucket.buffer()) # encode gradients >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() >>> # Define the then callback to decode. >>> def decode(fut): >>> decoded_tensor = decode(fut.value()[0]) # decode gradients >>> return decoded_tensor >>> return fut.then(decode)
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
例子:
例子:
参数:
相关用法
- Python PyTorch DistributedDataParallel.join用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch DistributedDataParallel用法及代码示例
- Python PyTorch DistributedModelParallel用法及代码示例
- Python PyTorch DistributedSampler用法及代码示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch DistributedOptimizer用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DataParallel用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.parallel.DistributedDataParallel.register_comm_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。