本文簡要介紹python語言中 torch.distributed.algorithms.Join
的用法。
用法:
class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)
此類定義通用連接上下文管理器,它允許在進程連接後調用自定義鉤子。這些鉤子應該遮蔽非連接進程的集體通信,以防止掛起和錯誤並確保算法的正確性。有關鉤子定義的詳細信息,請參閱
JoinHook
。警告
上下文管理器要求每個參與的
Joinable
在其自己的每次迭代集體通信之前調用方法notify_join_context()
以確保正確性。警告
上下文管理器要求
JoinHook
對象中的所有process_group
屬性都相同。如果有多個JoinHook
對象,則使用第一個對象的device
。進程組和設備信息用於檢查未加入的進程,並用於通知進程在啟用throw_on_early_termination
時拋出異常,這兩者都使用 all-reduce。例子:
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
參數:
相關用法
- Python PyTorch JsonParser用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
- Python PyTorch cholesky用法及代碼示例
- Python PyTorch vdot用法及代碼示例
- Python PyTorch ELU用法及代碼示例
- Python PyTorch ScaledDotProduct.__init__用法及代碼示例
- Python PyTorch gumbel_softmax用法及代碼示例
- Python PyTorch get_tokenizer用法及代碼示例
- Python PyTorch saved_tensors_hooks用法及代碼示例
- Python PyTorch positive用法及代碼示例
- Python PyTorch renorm用法及代碼示例
- Python PyTorch AvgPool2d用法及代碼示例
- Python PyTorch MaxUnpool3d用法及代碼示例
- Python PyTorch Bernoulli用法及代碼示例
- Python PyTorch Tensor.unflatten用法及代碼示例
- Python PyTorch Sigmoid用法及代碼示例
- Python PyTorch Tensor.register_hook用法及代碼示例
- Python PyTorch ShardedEmbeddingBagCollection.named_parameters用法及代碼示例
- Python PyTorch sqrt用法及代碼示例
- Python PyTorch PackageImporter.id用法及代碼示例
- Python PyTorch column_stack用法及代碼示例
- Python PyTorch diag用法及代碼示例
- Python PyTorch skippable用法及代碼示例
- Python PyTorch EndOnDiskCacheHolder用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.algorithms.Join。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。