本文簡要介紹python語言中 torch.nn.parallel.DistributedDataParallel.join
的用法。
用法:
join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)
divide_by_initial_world_size(bool) -如果
True
,將梯度除以初始world_size
DDP 訓練啟動。如果False
,將計算有效世界大小(尚未耗盡其輸入的等級數)並在 allreduce 期間除以梯度。設置divide_by_initial_world_size=True
以確保每個輸入樣本(包括不均勻輸入)在它們對全局梯度的貢獻方掩碼有相同的權重。這是通過始終將梯度除以初始world_size
來實現的,即使我們遇到不均勻的輸入。如果將其設置為False
,我們將梯度除以剩餘節點數。這確保了在較小的world_size
上進行訓練的均等性,盡管這也意味著不均勻的輸入將對全局梯度做出更多貢獻。通常,對於訓練作業的最後幾個輸入不均勻的情況,您可能希望將其設置為True
。在輸入數量存在很大差異的極端情況下,將其設置為False
可能會提供更好的結果。enable(bool) -是否啟用不均勻輸入檢測。在您知道輸入甚至跨參與進程的情況下,傳入
enable=False
以禁用。默認為True
。throw_on_early_termination(bool) -當至少一個等級用盡輸入時是否拋出錯誤或繼續訓練。如果
True
,將拋出第一個到達數據末尾的排名。如果False
,將繼續以較小的有效世界大小進行訓練,直到加入所有行列。請注意,如果指定了此標誌,則將忽略標誌divide_by_initial_world_size
。默認為False
。
上下文管理器與
torch.nn.parallel.DistributedDataParallel
實例結合使用,能夠在參與進程中使用不均勻的輸入進行訓練。此上下文管理器將跟蹤already-joined DDP 進程和“shadow” 通過插入集體通信操作以匹配未加入的 DDP 進程創建的前向和後向傳遞。這將確保每個集體調用都有already-joined DDP 進程的相應調用,從而防止在跨進程輸入不均勻的情況下進行訓練時可能發生的掛起或錯誤。或者,如果將標誌
throw_on_early_termination
指定為True
,則一旦一個等級的輸入用完,所有訓練器都會拋出錯誤,從而允許根據應用程序邏輯捕獲和處理這些錯誤。一旦所有 DDP 進程都加入了,上下文管理器會將最後加入的進程對應的模型廣播給所有進程,以確保模型在所有進程中都是相同的(這是由 DDP 保證的)。
要使用它來啟用跨流程輸入不均勻的訓練,隻需將此上下文管理器包裝在您的訓練循環周圍。無需對模型或數據加載進行進一步修改。
警告
如果此上下文管理器所環繞的模型或訓練循環具有額外的分布式集體操作,例如模型前向傳遞中的
SyncBatchNorm
,則必須啟用標誌throw_on_early_termination
。這是因為此上下文管理器不知道非 DDP 集體通信。當任何一個 rank 耗盡輸入時,此標誌將導致所有 rank 拋出,從而允許從所有 rank 中捕獲和恢複這些錯誤。例子:
>>> import torch >>> import torch.distributed as dist >>> import os >>> import torch.multiprocessing as mp >>> import torch.nn as nn >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> torch.cuda.set_device(rank) >>> model = nn.Linear(1, 1, bias=False).to(rank) >>> model = torch.nn.parallel.DistributedDataParallel( >>> model, device_ids=[rank], output_device=rank >>> ) >>> # Rank 1 gets one more input than rank 0. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] >>> with model.join(): >>> for _ in range(5): >>> for inp in inputs: >>> loss = model(inp).sum() >>> loss.backward() >>> # Without the join() API, the below synchronization will hang >>> # blocking for rank 1's allreduce to complete. >>> torch.cuda.synchronize(device=rank)
參數:
相關用法
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代碼示例
- Python PyTorch DistributedDataParallel.no_sync用法及代碼示例
- Python PyTorch DistributedDataParallel用法及代碼示例
- Python PyTorch DistributedModelParallel用法及代碼示例
- Python PyTorch DistributedSampler用法及代碼示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代碼示例
- Python PyTorch DistributedModelParallel.state_dict用法及代碼示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代碼示例
- Python PyTorch DistributedOptimizer用法及代碼示例
- Python PyTorch Dirichlet用法及代碼示例
- Python PyTorch DeQuantize用法及代碼示例
- Python PyTorch DenseArch用法及代碼示例
- Python PyTorch DeepFM用法及代碼示例
- Python PyTorch DataFrameMaker用法及代碼示例
- Python PyTorch DLRM用法及代碼示例
- Python PyTorch Dropout用法及代碼示例
- Python PyTorch Dropout3d用法及代碼示例
- Python PyTorch DataParallel用法及代碼示例
- Python PyTorch Decompressor用法及代碼示例
- Python PyTorch Dropout2d用法及代碼示例
- Python PyTorch DeepFM.forward用法及代碼示例
- Python PyTorch Demultiplexer用法及代碼示例
- Python PyTorch DatasetFolder.find_classes用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.parallel.DistributedDataParallel.join。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。