当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch DistributedOptimizer用法及代码示例


本文简要介绍python语言中 torch.distributed.optim.DistributedOptimizer 的用法。

用法:

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)

参数

  • optimizer_class(optim.Optimizer) -在每个工作人员上实例化的优化器类。

  • params_rref(list[RRef]) -要优化的本地或远程参数的 RRef 列表。

  • args-传递给每个工作人员的优化器构造函数的参数。

  • kwargs-传递给每个工作人员的优化器构造函数的参数。

DistributedOptimizer 对分散在工作人员中的参数进行远程引用,并在本地为每个参数应用给定的优化器。

此类使用get_gradients() 来检索特定参数的梯度。

来自相同或不同客户端的对 step() 的并发调用将在每个工作线程上进行序列化 - 因为每个工作线程的优化器一次只能处理一组梯度。但是,无法保证一次为一个客户端执行完整的 forward-backward-optimizer 序列。这意味着所应用的梯度可能与在给定工作线程上执行的最新前向传递不对应。此外,不保证工人之间的排序。

DistributedOptimizer 创建本地优化器时默认启用 TorchScript,因此在多线程训练(例如分布式模型并行)的情况下,优化器更新不会被 Python 全局解释器锁 (GIL) 阻止。目前大多数优化器都启用了此函数。您还可以按照PyTorch 教程中的the recipe 为您自己的自定义优化器启用TorchScript 支持。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>>
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD,
>>>      [rref1, rref2],
>>>      lr=0.05,
>>>   )
>>>   dist_optim.step(context_id)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.optim.DistributedOptimizer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。