本文簡要介紹python語言中 torch.nn.parallel.DistributedDataParallel.register_comm_hook 的用法。
用法:
register_comm_hook(state, hook)state(object) -
在訓練過程中傳遞給鉤子以維護任何狀態信息。示例包括梯度壓縮中的錯誤反饋、與 GossipGrad 中的 next 通信的對等點等。
它由每個工作人員本地存儲,並由工作人員上的所有梯度張量共享。
hook(可調用的) -
可使用以下簽名調用:
hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:一旦存儲桶準備好,就會調用此函數。該鉤子可以執行所需的任何處理,並返回一個 Future 指示任何異步工作的完成(例如:allreduce)。如果鉤子不執行任何通信,它仍然必須返回一個完整的 Future。未來應該保存梯度桶張量的新值。一旦存儲桶準備好,c10d reducer 將調用此鉤子並使用 Future 返回的張量並將梯度複製到各個參數。請注意,Future 的返回類型必須是單個張量。
我們還提供了一個名為
get_future的 API 來檢索與c10d.ProcessGroup.Work的完成相關的 Future。get_future當前支持 NCCL,也支持 GLOO 和 MPI 上的大多數操作,除了點對點操作(發送/接收)。
注冊一個通信鉤子,它是一種增強函數,為用戶提供了一個靈活的鉤子,用戶可以在其中指定 DDP 如何在多個工作人員之間聚合梯度。
這個鉤子對於研究人員嘗試新想法非常有用。例如,此鉤子可用於實現多種算法,例如GossipGrad和梯度壓縮,這些算法涉及在運行分布式DataParallel訓練時參數同步的不同通信策略。
警告
Grad bucket的張量不會被world_size預除。在 allreduce 等操作的情況下,用戶負責除以world_size。
警告
DDP 通信鉤子隻能注冊一次,並且應該在反向調用之前注冊。
警告
hook 返回的 Future 對象應該包含一個與 grad 桶中的張量具有相同形狀的單個張量。
警告
get_futureAPI 支持 NCCL,部分支持 GLOO 和 MPI 後端(不支持 peer-to-peer 操作,如發送/接收),並將返回torch.futures.Future。下麵是一個返回相同張量的 noop 鉤子的示例。
>>> def noop(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor] >>> fut = torch.futures.Future() >>> fut.set_result(bucket.buffer()) >>> return fut>>> ddp.register_comm_hook(state=None, hook=noop)下麵是一個並行 SGD 算法的示例,其中梯度在 allreduce 之前編碼,然後在 allreduce 之後解碼。
>>> def encode_and_decode(state: object, bucket: dist.GradBucket): -> torch.futures.Future[torch.Tensor] >>> encoded_tensor = encode(bucket.buffer()) # encode gradients >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() >>> # Define the then callback to decode. >>> def decode(fut): >>> decoded_tensor = decode(fut.value()[0]) # decode gradients >>> return decoded_tensor >>> return fut.then(decode)>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
例子:
例子:
參數:
相關用法
- Python PyTorch DistributedDataParallel.join用法及代碼示例
- Python PyTorch DistributedDataParallel.no_sync用法及代碼示例
- Python PyTorch DistributedDataParallel用法及代碼示例
- Python PyTorch DistributedModelParallel用法及代碼示例
- Python PyTorch DistributedSampler用法及代碼示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代碼示例
- Python PyTorch DistributedModelParallel.state_dict用法及代碼示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代碼示例
- Python PyTorch DistributedOptimizer用法及代碼示例
- Python PyTorch Dirichlet用法及代碼示例
- Python PyTorch DeQuantize用法及代碼示例
- Python PyTorch DenseArch用法及代碼示例
- Python PyTorch DeepFM用法及代碼示例
- Python PyTorch DataFrameMaker用法及代碼示例
- Python PyTorch DLRM用法及代碼示例
- Python PyTorch Dropout用法及代碼示例
- Python PyTorch Dropout3d用法及代碼示例
- Python PyTorch DataParallel用法及代碼示例
- Python PyTorch Decompressor用法及代碼示例
- Python PyTorch Dropout2d用法及代碼示例
- Python PyTorch DeepFM.forward用法及代碼示例
- Python PyTorch Demultiplexer用法及代碼示例
- Python PyTorch DatasetFolder.find_classes用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.parallel.DistributedDataParallel.register_comm_hook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
