當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch DistributedOptimizer用法及代碼示例


本文簡要介紹python語言中 torch.distributed.optim.DistributedOptimizer 的用法。

用法:

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)

參數

  • optimizer_class(optim.Optimizer) -在每個工作人員上實例化的優化器類。

  • params_rref(list[RRef]) -要優化的本地或遠程參數的 RRef 列表。

  • args-傳遞給每個工作人員的優化器構造函數的參數。

  • kwargs-傳遞給每個工作人員的優化器構造函數的參數。

DistributedOptimizer 對分散在工作人員中的參數進行遠程引用,並在本地為每個參數應用給定的優化器。

此類使用get_gradients() 來檢索特定參數的梯度。

來自相同或不同客戶端的對 step() 的並發調用將在每個工作線程上進行序列化 - 因為每個工作線程的優化器一次隻能處理一組梯度。但是,無法保證一次為一個客戶端執行完整的 forward-backward-optimizer 序列。這意味著所應用的梯度可能與在給定工作線程上執行的最新前向傳遞不對應。此外,不保證工人之間的排序。

DistributedOptimizer 創建本地優化器時默認啟用 TorchScript,因此在多線程訓練(例如分布式模型並行)的情況下,優化器更新不會被 Python 全局解釋器鎖 (GIL) 阻止。目前大多數優化器都啟用了此函數。您還可以按照PyTorch 教程中的the recipe 為您自己的自定義優化器啟用TorchScript 支持。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>>
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD,
>>>      [rref1, rref2],
>>>      lr=0.05,
>>>   )
>>>   dist_optim.step(context_id)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.optim.DistributedOptimizer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。