本文简要介绍python语言中 torch.distributed.optim.DistributedOptimizer
的用法。
用法:
class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)
optimizer_class(optim.Optimizer) -在每个工作人员上实例化的优化器类。
args-传递给每个工作人员的优化器构造函数的参数。
kwargs-传递给每个工作人员的优化器构造函数的参数。
DistributedOptimizer 对分散在工作人员中的参数进行远程引用,并在本地为每个参数应用给定的优化器。
此类使用
get_gradients()
来检索特定参数的梯度。来自相同或不同客户端的对
step()
的并发调用将在每个工作线程上进行序列化 - 因为每个工作线程的优化器一次只能处理一组梯度。但是,无法保证一次为一个客户端执行完整的 forward-backward-optimizer 序列。这意味着所应用的梯度可能与在给定工作线程上执行的最新前向传递不对应。此外,不保证工人之间的排序。DistributedOptimizer
创建本地优化器时默认启用 TorchScript,因此在多线程训练(例如分布式模型并行)的情况下,优化器更新不会被 Python 全局解释器锁 (GIL) 阻止。目前大多数优化器都启用了此函数。您还可以按照PyTorch 教程中的the recipe 为您自己的自定义优化器启用TorchScript 支持。>>> import torch.distributed.autograd as dist_autograd >>> import torch.distributed.rpc as rpc >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> >>> with dist_autograd.context() as context_id: >>> # Forward pass. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) >>> loss = rref1.to_here() + rref2.to_here() >>> >>> # Backward pass. >>> dist_autograd.backward(context_id, [loss.sum()]) >>> >>> # Optimizer. >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> [rref1, rref2], >>> lr=0.05, >>> ) >>> dist_optim.step(context_id)
例子:
参数:
相关用法
- Python PyTorch DistributedModelParallel用法及代码示例
- Python PyTorch DistributedDataParallel用法及代码示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例
- Python PyTorch DistributedSampler用法及代码示例
- Python PyTorch DistributedDataParallel.join用法及代码示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DataParallel用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.optim.DistributedOptimizer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。