当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch Join用法及代码示例

本文简要介绍python语言中 torch.distributed.algorithms.Join 的用法。

用法:

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)

参数

  • joinables(List[Joinable]) -参与Joinable 的列表;它们的钩子按照给定的顺序迭代。

  • enable(bool) -启用不均匀输入检测的标志;设置为 False 会禁用上下文管理器的函数,并且仅应在用户知道输入不会不均匀时设置(默认值:True)。

  • throw_on_early_termination(bool) -控制是否在检测到不均匀输入时抛出异常的标志(默认值:False)。

此类定义通用连接上下文管理器,它允许在进程连接后调用自定义钩子。这些钩子应该遮蔽非连接进程的集体通信,以防止挂起和错误并确保算法的正确性。有关钩子定义的详细信息,请参阅JoinHook

警告

上下文管理器要求每个参与的 Joinable 在其自己的每次迭代集体通信之前调用方法 notify_join_context() 以确保正确性。

警告

上下文管理器要求 JoinHook 对象中的所有 process_group 属性都相同。如果有多个 JoinHook 对象,则使用第一个对象的 device。进程组和设备信息用于检查未加入的进程,并用于通知进程在启用 throw_on_early_termination 时抛出异常,这两者都使用 all-reduce。

例子:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.algorithms.Join。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。