當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch powerSGD_hook用法及代碼示例


本文簡要介紹python語言中 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook 的用法。

用法:

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)

參數

  • state(PowerSGDState) -用於配置壓縮率和支持錯誤反饋、熱啟動等的狀態信息。要調整壓縮配置,主要需要調整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • bucket(dist.GradBucket) -存儲一維扁平梯度張量的桶,該張量批量處理多個每個變量的張量。請注意,由於 DDP 通信鉤子僅支持單進程單設備模式,因此該存儲桶中僅存儲了一個張量。

返回

通信的未來處理程序,它更新到位的梯度。

此 DDP 通信鉤子實現 paper 中說明的 PowerSGD 梯度壓縮算法。一旦梯度張量在所有工作人員中聚合,此鉤子將按如下方式應用壓縮:

  1. 將輸入的扁平化一維梯度張量視為每個參數張量的列表,並將所有張量分為兩組:

    1.1 allreduce之前應該壓縮的張量,因為壓縮可以節省足夠的帶寬。

    1.2 其餘的張量將直接全部歸約而不壓縮,包括所有的向量張量(用於偏差)。

  2. 處理未壓縮的張量:

    2.1。為那些未壓縮的張量分配連續的內存,並將所有未壓縮的張量作為一個批次全部減少,不進行壓縮;

    2.2.將單個未壓縮張量從連續內存複製回輸入張量。

  3. 處理應該通過PowerSGD 壓縮來壓縮的張量:

    3.1.對於每個張量 M,創建兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分布初始化並正交化;

    3.2.計算 Ps 中的每個 P,等於 MQ;

    3.3. Allreduces Ps 作為一個批次;

    3.4.正交化 Ps 中的每個 P;

    3.5.計算 Qs 中的每個 Q,大約等於 M^TP;

    3.6. Allreduce Qs 作為一個批次;

    3.7.計算所有壓縮張量中的每個 M,大約等於 PQ^T。

請注意,此通信掛鉤對第一個 state.start_powerSGD_iter 迭代強製執行 vanilla allreduce。這不僅讓用戶可以更好地控製加速和準確性之間的權衡,還有助於為未來的通信鉤子開發人員抽象出 DDP 內部優化的一些複雜性。

例子:

>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。