本文简要介绍python语言中 torch.nn.parallel.DistributedDataParallel
的用法。
用法:
class torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False)
module(torch.nn.Module) -要并行化的模块
device_ids(Python列表:int或者torch.device) -
CUDA 设备。 1)对于单设备模块,
device_ids
可以只包含一个设备id,代表该进程对应的输入模块所在的唯一CUDA设备。或者,device_ids
也可以是None
。 2) 对于多设备模块和 CPU 模块,device_ids
必须为None
。对于这两种情况,当
device_ids
为None
时,前向传递的输入数据和实际模块都必须放置在正确的设备上。 (默认:None
)output_device(int或者torch.device) -单设备 CUDA 模块的输出的设备位置。对于多设备模块和CPU模块,它必须是
None
,并且模块本身决定输出位置。 (对于单设备模块,默认值:device_ids[0]
)broadcast_buffers(bool) -在
forward
函数开头启用模块同步(广播)缓冲区的标志。 (默认值:True
)process_group-用于分布式数据的进程组all-reduction。如果是
None
,则将使用由torch.distributed.init_process_group()
创建的默认进程组。 (默认:None
)bucket_cap_mb-
DistributedDataParallel
将参数分桶到多个桶中,以便每个桶的梯度减少可能与后向计算重叠。bucket_cap_mb
控制存储桶大小,单位为MegaBytes (MB)。 (默认值:25)find_unused_parameters(bool) -从包装模块的
forward
函数的返回值中包含的所有张量遍历 autograd 图。不接收梯度作为该图的一部分的参数被预先标记为准备减少。此外,可能已在包装模块的forward
函数中使用但不属于损失计算的一部分,因此也不会接收梯度的参数被抢先标记为准备减少。 (默认:False
)check_reduction-此论点已弃用。
gradient_as_bucket_view(bool) -当设置为
True
时,梯度将是指向allreduce
通信桶的不同偏移量的视图。这可以减少峰值内存使用量,其中节省的内存大小将等于总梯度大小。此外,它避免了在梯度和allreduce
通信桶之间复制的开销。当渐变是视图时,不能在渐变上调用detach_()
。如果遇到此类错误,请参考torch/optim/optimizer.py
中的zero_grad()
函数作为解决方案来修复它。
~DistributedDataParallel.module(torch.nn.Module) -要并行化的模块。
在模块级别实现基于
torch.distributed
包的分布式数据并行性。此容器通过在批处理维度中分块将输入拆分到指定的设备,从而并行化给定模块的应用程序。该模块在每台机器和每台设备上复制,每个这样的副本处理输入的一部分。在反向传播期间,来自每个节点的梯度被平均。
批量大小应大于本地使用的 GPU 数量。
另请参阅:基础知识和使用 nn.parallel.DistributedDataParallel 而不是多处理或 nn.DataParallel。对输入的限制与
torch.nn.DataParallel
中的相同。创建此类需要通过调用
torch.distributed.init_process_group()
来初始化torch.distributed
。对于 single-node multi-GPU 数据并行训练,
DistributedDataParallel
被证明比torch.nn.DataParallel
快得多。要在具有 N 个 GPU 的主机上使用
DistributedDataParallel
,您应该生成N
进程,确保每个进程仅在从 0 到 N-1 的单个 GPU 上工作。这可以通过为每个进程设置CUDA_VISIBLE_DEVICES
或调用:>>> torch.cuda.set_device(i)
其中 i 是从 0 到 N-1。在每个过程中,您应该参考以下内容来构建此模块:
>>> torch.distributed.init_process_group( >>> backend='nccl', world_size=N, init_method='...' >>> ) >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
为了在每个节点生成多个进程,您可以使用
torch.distributed.launch
或torch.multiprocessing.spawn
。注意
分布式训练相关的所有特性请参考PyTorch Distributed Overview简单介绍。
注意
DistributedDataParallel
可以与torch.distributed.optim.ZeroRedundancyOptimizer
结合使用,以减少 per-rank 优化器状态内存占用。更多详情请参阅ZeroRedundancyOptimizer recipe。注意
nccl
后端是目前使用 GPU 时速度最快且强烈推荐的后端。这适用于 single-node 和 multi-node 分布式训练。注意
该模块还支持mixed-precision分布式训练。这意味着您的模型可以具有不同类型的参数,例如
fp16
和fp32
的混合类型,这些混合类型的参数的梯度减少将正常工作。注意
如果您在一个进程上使用
torch.save
来检查模块,并在其他一些进程上使用torch.load
来恢复它,请确保为每个进程正确配置map_location
。如果没有map_location
,torch.load
会将模块恢复到保存模块的设备。注意
当使用
batch=N
在M
节点上训练模型时,如果对损失求和(不像平常那样平均),则与使用batch=M*N
在单个节点上训练的相同模型相比,梯度将小M
倍批次中的跨实例(因为不同节点之间的梯度是平均的)。当您想要获得与本地训练对应的数学上等效的训练过程时,您应该考虑到这一点。但在大多数情况下,您可以将 DistributedDataParallel 包装模型、DataParallel 包装模型和单个 GPU 上的普通模型视为相同(例如,对等效批量大小使用相同的学习率)。注意
参数永远不会在进程之间广播。该模块对梯度执行 all-reduce 步骤,并假设优化器将在所有进程中以相同的方式修改它们。在每次迭代中,缓冲区(例如 BatchNorm stats)从处于 0 级进程的模块广播到系统中的所有其他副本。
注意
如果您将 DistributedDataParallel 与分布式 RPC 框架结合使用,则应始终使用
torch.distributed.autograd.backward()
来计算梯度,并使用torch.distributed.optim.DistributedOptimizer
来优化参数。例子:
>>> import torch.distributed.autograd as dist_autograd >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> from torch.distributed.rpc import RRef >>> >>> t1 = torch.rand((3, 3), requires_grad=True) >>> t2 = torch.rand((3, 3), requires_grad=True) >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) >>> ddp_model = DDP(my_model) >>> >>> # Setup optimizer >>> optimizer_params = [rref] >>> for param in ddp_model.parameters(): >>> optimizer_params.append(RRef(param)) >>> >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> optimizer_params, >>> lr=0.05, >>> ) >>> >>> with dist_autograd.context() as context_id: >>> pred = ddp_model(rref.to_here()) >>> loss = loss_func(pred, loss) >>> dist_autograd.backward(context_id, loss) >>> dist_optim.step()
注意
要让非 DDP 模型从 DDP 模型加载状态字典,需要应用
consume_prefix_in_state_dict_if_present()
以在加载前去除 DDP 状态字典中的前缀 “module.”。警告
构造函数、前向方法和输出的微分(或本模块输出的函数)是分布式同步点。考虑到这一点,以防不同的进程可能正在执行不同的代码。
警告
该模块假定所有参数在模型创建时都已注册在模型中。以后不应添加或删除任何参数。同样适用于缓冲区。
警告
该模块假设所有参数都注册在模型中,每个分布式进程的顺序相同。模块本身将按照模型注册参数的相反顺序进行梯度
allreduce
。换句话说,用户有责任确保每个分布式进程具有完全相同的模型,从而具有完全相同的参数注册顺序。警告
此模块允许具有非行主要连续步幅的参数。例如,您的模型可能包含一些
torch.memory_format
为torch.contiguous_format
的参数和其他格式为torch.channels_last
的参数。但是,不同进程中对应的参数必须具有相同的步幅。警告
此模块不适用于
torch.autograd.grad()
(即,它仅在渐变将在参数的.grad
属性中累积时才有效)。警告
如果您计划将此模块与
nccl
后端或gloo
后端(使用 Infiniband)以及使用多个工作线程的 DataLoader 一起使用,请将多处理启动方法更改为forkserver
(仅限 Python 3) )或spawn
。不幸的是,Gloo(使用 Infiniband)和 NCCL2 不是分叉安全的,如果不更改此设置,您可能会遇到死锁。警告
module
及其子模块上定义的前向和后向挂钩将不再被调用,除非挂钩在forward()
方法中被初始化。警告
在使用
DistributedDataParallel
包装模型后,您永远不应该尝试更改模型的参数。因为,当用DistributedDataParallel
包你的模型时,DistributedDataParallel
的构造函数将在构建时在模型本身的所有参数上注册额外的梯度减少函数。如果之后更改模型的参数,梯度缩减函数将不再匹配正确的参数集。警告
将
DistributedDataParallel
与分布式 RPC 框架结合使用是实验性的,可能会发生变化。例子:
>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
参数:
变量:
相关用法
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例
- Python PyTorch DistributedDataParallel.join用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch DistributedModelParallel用法及代码示例
- Python PyTorch DistributedSampler用法及代码示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch DistributedOptimizer用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DataParallel用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.parallel.DistributedDataParallel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。