當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch batched_powerSGD_hook用法及代碼示例

本文簡要介紹python語言中 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook 的用法。

用法:

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)

參數

  • state(PowerSGDState) -用於配置壓縮率和支持錯誤反饋、熱啟動等的狀態信息。要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_iter

  • bucket(dist.GradBucket) -存儲一維扁平梯度張量的桶,該張量批量處理多個每個變量的張量。請注意,由於 DDP 通信鉤子僅支持單進程單設備模式,因此該存儲桶中僅存儲了一個張量。

返回

通信的未來處理程序,它更新到位的梯度。

這個 DDP 通信鉤子實現了一個簡化的 PowerSGD 梯度壓縮算法,在.此變體不會逐層壓縮梯度,而是壓縮對所有梯度進行批處理的扁平化輸入張量。因此它是快點powerSGD_hook(),但通常會導致精度低得多, 除非matrix_approximation_rank是 1。

警告

此處增加 matrix_approximation_rank 不一定會提高準確性,因為在沒有列/行對齊的情況下對每個參數張量進行批處理可能會破壞低秩結構。因此,用戶應始終首先考慮powerSGD_hook(),並且僅當matrix_approximation_rank為1時能夠達到滿意的精度時才考慮此變體。

一旦梯度張量在所有工作人員中聚合,此掛鉤將按如下方式應用壓縮:

  1. 將輸入扁平化的 1D 梯度張量視為具有 0 個填充的 square-shaped 張量 M;

  2. 創建兩個低秩張量 P 和 Q 用於分解 M,使得 M = PQ^T,其中 Q 從標準正態分布初始化並正交化;

  3. 計算 P,它等於 MQ;

  4. 全部減少 P;

  5. 正交化 P;

  6. 計算 Q,它大約等於 M^TP;

  7. 全部減少 Q;

  8. 計算 M,它大約等於 PQ^T。

  9. 將輸入張量截斷為原始長度。

請注意,此通信掛鉤對第一個 state.start_powerSGD_iter 迭代強製執行 vanilla allreduce。這不僅讓用戶可以更好地控製加速和準確性之間的權衡,還有助於為未來的通信鉤子開發人員抽象出 DDP 內部優化的一些複雜性。

例子:

>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。