當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch Join用法及代碼示例

本文簡要介紹python語言中 torch.distributed.algorithms.Join 的用法。

用法:

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)

參數

  • joinables(List[Joinable]) -參與Joinable 的列表;它們的鉤子按照給定的順序迭代。

  • enable(bool) -啟用不均勻輸入檢測的標誌;設置為 False 會禁用上下文管理器的函數,並且僅應在用戶知道輸入不會不均勻時設置(默認值:True)。

  • throw_on_early_termination(bool) -控製是否在檢測到不均勻輸入時拋出異常的標誌(默認值:False)。

此類定義通用連接上下文管理器,它允許在進程連接後調用自定義鉤子。這些鉤子應該遮蔽非連接進程的集體通信,以防止掛起和錯誤並確保算法的正確性。有關鉤子定義的詳細信息,請參閱JoinHook

警告

上下文管理器要求每個參與的 Joinable 在其自己的每次迭代集體通信之前調用方法 notify_join_context() 以確保正確性。

警告

上下文管理器要求 JoinHook 對象中的所有 process_group 屬性都相同。如果有多個 JoinHook 對象,則使用第一個對象的 device。進程組和設備信息用於檢查未加入的進程,並用於通知進程在啟用 throw_on_early_termination 時拋出異常,這兩者都使用 all-reduce。

例子:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.algorithms.Join。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。