當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch bf16_compress_wrapper用法及代碼示例


本文簡要介紹python語言中 torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper 的用法。

用法:

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)

警告:此 API 是實驗性的,需要 NCCL 版本高於 2.9.6。

此包裝器將給定 DDP 通信掛鉤的輸入梯度張量轉換為 half-precisionBrain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format> `_ (``torch.bfloat16`),並將給定鉤子的結果張量轉換回輸入數據類型,例如float32.

因此,bf16_compress_hook 等價於 bf16_compress_wrapper(allreduce_hook)

例子:

>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。