当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch all_reduce用法及代码示例


本文简要介绍python语言中 torch.distributed.all_reduce 的用法。

用法:

torch.distributed.all_reduce(tensor, op=<ReduceOp.SUM: 0>, group=None, async_op=False)

参数

  • tensor(Tensor) -集体的输入和输出。该函数就地运行。

  • op(可选的) -torch.distributed.ReduceOp 枚举中的值之一。指定用于按元素减少的操作。

  • group(ProcessGroup,可选的) -要处理的流程组。如果没有,将使用默认进程组。

  • async_op(bool,可选的) -此操作是否应该是异步操作

返回

异步工作句柄,如果 async_op 设置为 True。无,如果不是 async_op 或者如果不是该组的一部分

以这样的方式减少所有机器上的张量数据,从而得到最终结果。

在调用 tensor 之后,所有进程将按位相同。

支持复杂的张量。

例子

>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.all_reduce。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。