当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例


本文简要介绍python语言中 torch.nn.parallel.DistributedDataParallel.register_comm_hook 的用法。

用法:

register_comm_hook(state, hook)

参数

  • state(object) -

    在训练过程中传递给钩子以维护任何状态信息。示例包括梯度压缩中的错误反馈、与 GossipGrad 中的 next 通信的对等点等。

    它由每个工作人员本地存储,并由工作人员上的所有梯度张量共享。

  • hook(可调用的) -

    可使用以下签名调用:hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]

    一旦存储桶准备好,就会调用此函数。该钩子可以执行所需的任何处理,并返回一个 Future 指示任何异步工作的完成(例如:allreduce)。如果钩子不执行任何通信,它仍然必须返回一个完整的 Future。未来应该保存梯度桶张量的新值。一旦存储桶准备好,c10d reducer 将调用此钩子并使用 Future 返回的张量并将梯度复制到各个参数。请注意,Future 的返回类型必须是单个张量。

    我们还提供了一个名为 get_future 的 API 来检索与 c10d.ProcessGroup.Work 的完成相关的 Future。 get_future 当前支持 NCCL,也支持 GLOO 和 MPI 上的大多数操作,除了点对点操作(发送/接收)。

注册一个通信钩子,它是一种增强函数,为用户提供了一个灵活的钩子,用户可以在其中指定 DDP 如何在多个工作人员之间聚合梯度。

这个钩子对于研究人员尝试新想法非常有用。例如,此钩子可用于实现多种算法,例如GossipGrad和梯度压缩,这些算法涉及在运行分布式DataParallel训练时参数同步的不同通信策略。

警告

Grad bucket的张量不会被world_size预除。在 allreduce 等操作的情况下,用户负责除以world_size。

警告

DDP 通信钩子只能注册一次,并且应该在反向调用之前注册。

警告

hook 返回的 Future 对象应该包含一个与 grad 桶中的张量具有相同形状的单个张量。

警告

get_future API 支持 NCCL,部分支持 GLOO 和 MPI 后端(不支持 peer-to-peer 操作,如发送/接收),并将返回 torch.futures.Future

例子:

下面是一个返回相同张量的 noop 钩子的示例。

>>> def noop(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor]
>>>     fut = torch.futures.Future()
>>>     fut.set_result(bucket.buffer())
>>>     return fut
>>> ddp.register_comm_hook(state=None, hook=noop)

例子:

下面是一个并行 SGD 算法的示例,其中梯度在 allreduce 之前编码,然后在 allreduce 之后解码。

>>> def encode_and_decode(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor]
>>>     encoded_tensor = encode(bucket.buffer()) # encode gradients
>>>     fut = torch.distributed.all_reduce(encoded_tensor).get_future()
>>>     # Define the then callback to decode.
>>>     def decode(fut):
>>>         decoded_tensor = decode(fut.value()[0]) # decode gradients
>>>         return decoded_tensor
>>>     return fut.then(decode)
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.parallel.DistributedDataParallel.register_comm_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。