當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch DistributedDataParallel.register_comm_hook用法及代碼示例

本文簡要介紹python語言中 torch.nn.parallel.DistributedDataParallel.register_comm_hook 的用法。

用法:

register_comm_hook(state, hook)

參數

  • state(object) -

    在訓練過程中傳遞給鉤子以維護任何狀態信息。示例包括梯度壓縮中的錯誤反饋、與 GossipGrad 中的 next 通信的對等點等。

    它由每個工作人員本地存儲,並由工作人員上的所有梯度張量共享。

  • hook(可調用的) -

    可使用以下簽名調用:hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]

    一旦存儲桶準備好,就會調用此函數。該鉤子可以執行所需的任何處理,並返回一個 Future 指示任何異步工作的完成(例如:allreduce)。如果鉤子不執行任何通信,它仍然必須返回一個完整的 Future。未來應該保存梯度桶張量的新值。一旦存儲桶準備好,c10d reducer 將調用此鉤子並使用 Future 返回的張量並將梯度複製到各個參數。請注意,Future 的返回類型必須是單個張量。

    我們還提供了一個名為 get_future 的 API 來檢索與 c10d.ProcessGroup.Work 的完成相關的 Future。 get_future 當前支持 NCCL,也支持 GLOO 和 MPI 上的大多數操作,除了點對點操作(發送/接收)。

注冊一個通信鉤子,它是一種增強函數,為用戶提供了一個靈活的鉤子,用戶可以在其中指定 DDP 如何在多個工作人員之間聚合梯度。

這個鉤子對於研究人員嘗試新想法非常有用。例如,此鉤子可用於實現多種算法,例如GossipGrad和梯度壓縮,這些算法涉及在運行分布式DataParallel訓練時參數同步的不同通信策略。

警告

Grad bucket的張量不會被world_size預除。在 allreduce 等操作的情況下,用戶負責除以world_size。

警告

DDP 通信鉤子隻能注冊一次,並且應該在反向調用之前注冊。

警告

hook 返回的 Future 對象應該包含一個與 grad 桶中的張量具有相同形狀的單個張量。

警告

get_future API 支持 NCCL,部分支持 GLOO 和 MPI 後端(不支持 peer-to-peer 操作,如發送/接收),並將返回 torch.futures.Future

例子:

下麵是一個返回相同張量的 noop 鉤子的示例。

>>> def noop(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor]
>>>     fut = torch.futures.Future()
>>>     fut.set_result(bucket.buffer())
>>>     return fut
>>> ddp.register_comm_hook(state=None, hook=noop)

例子:

下麵是一個並行 SGD 算法的示例,其中梯度在 allreduce 之前編碼,然後在 allreduce 之後解碼。

>>> def encode_and_decode(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor]
>>>     encoded_tensor = encode(bucket.buffer()) # encode gradients
>>>     fut = torch.distributed.all_reduce(encoded_tensor).get_future()
>>>     # Define the then callback to decode.
>>>     def decode(fut):
>>>         decoded_tensor = decode(fut.value()[0]) # decode gradients
>>>         return decoded_tensor
>>>     return fut.then(decode)
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.parallel.DistributedDataParallel.register_comm_hook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。