本文简要介绍python语言中 torch.nn.parallel.DistributedDataParallel.join
的用法。
用法:
join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)
divide_by_initial_world_size(bool) -如果
True
,将梯度除以初始world_size
DDP 训练启动。如果False
,将计算有效世界大小(尚未耗尽其输入的等级数)并在 allreduce 期间除以梯度。设置divide_by_initial_world_size=True
以确保每个输入样本(包括不均匀输入)在它们对全局梯度的贡献方掩码有相同的权重。这是通过始终将梯度除以初始world_size
来实现的,即使我们遇到不均匀的输入。如果将其设置为False
,我们将梯度除以剩余节点数。这确保了在较小的world_size
上进行训练的均等性,尽管这也意味着不均匀的输入将对全局梯度做出更多贡献。通常,对于训练作业的最后几个输入不均匀的情况,您可能希望将其设置为True
。在输入数量存在很大差异的极端情况下,将其设置为False
可能会提供更好的结果。enable(bool) -是否启用不均匀输入检测。在您知道输入甚至跨参与进程的情况下,传入
enable=False
以禁用。默认为True
。throw_on_early_termination(bool) -当至少一个等级用尽输入时是否抛出错误或继续训练。如果
True
,将抛出第一个到达数据末尾的排名。如果False
,将继续以较小的有效世界大小进行训练,直到加入所有行列。请注意,如果指定了此标志,则将忽略标志divide_by_initial_world_size
。默认为False
。
上下文管理器与
torch.nn.parallel.DistributedDataParallel
实例结合使用,能够在参与进程中使用不均匀的输入进行训练。此上下文管理器将跟踪already-joined DDP 进程和“shadow” 通过插入集体通信操作以匹配未加入的 DDP 进程创建的前向和后向传递。这将确保每个集体调用都有already-joined DDP 进程的相应调用,从而防止在跨进程输入不均匀的情况下进行训练时可能发生的挂起或错误。或者,如果将标志
throw_on_early_termination
指定为True
,则一旦一个等级的输入用完,所有训练器都会抛出错误,从而允许根据应用程序逻辑捕获和处理这些错误。一旦所有 DDP 进程都加入了,上下文管理器会将最后加入的进程对应的模型广播给所有进程,以确保模型在所有进程中都是相同的(这是由 DDP 保证的)。
要使用它来启用跨流程输入不均匀的训练,只需将此上下文管理器包装在您的训练循环周围。无需对模型或数据加载进行进一步修改。
警告
如果此上下文管理器所环绕的模型或训练循环具有额外的分布式集体操作,例如模型前向传递中的
SyncBatchNorm
,则必须启用标志throw_on_early_termination
。这是因为此上下文管理器不知道非 DDP 集体通信。当任何一个 rank 耗尽输入时,此标志将导致所有 rank 抛出,从而允许从所有 rank 中捕获和恢复这些错误。例子:
>>> import torch >>> import torch.distributed as dist >>> import os >>> import torch.multiprocessing as mp >>> import torch.nn as nn >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> torch.cuda.set_device(rank) >>> model = nn.Linear(1, 1, bias=False).to(rank) >>> model = torch.nn.parallel.DistributedDataParallel( >>> model, device_ids=[rank], output_device=rank >>> ) >>> # Rank 1 gets one more input than rank 0. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] >>> with model.join(): >>> for _ in range(5): >>> for inp in inputs: >>> loss = model(inp).sum() >>> loss.backward() >>> # Without the join() API, the below synchronization will hang >>> # blocking for rank 1's allreduce to complete. >>> torch.cuda.synchronize(device=rank)
参数:
相关用法
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch DistributedDataParallel用法及代码示例
- Python PyTorch DistributedModelParallel用法及代码示例
- Python PyTorch DistributedSampler用法及代码示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch DistributedOptimizer用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DataParallel用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.parallel.DistributedDataParallel.join。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。