当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch all_to_all用法及代码示例


本文简要介绍python语言中 torch.distributed.all_to_all 的用法。

用法:

torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False)

参数

  • output_tensor_list(list[Tensor]) -每个等级要收集一个的张量列表。

  • input_tensor_list(list[Tensor]) -每个等级分散一个的张量列表。

  • group(ProcessGroup,可选的) -要处理的流程组。如果没有,将使用默认进程组。

  • async_op(bool,可选的) -此操作是否应该是异步操作。

返回

异步工作句柄,如果 async_op 设置为 True。无,如果不是 async_op 或不是该组的一部分。

每个进程将输入张量列表分散到组中的所有进程,并在输出列表中返回收集的张量列表。

支持复杂的张量。

警告

all_to_all 是实验性的,可能会发生变化。

例子

>>> input = torch.arange(4) + rank * 4
>>> input = list(input.chunk(4))
>>> input
[tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
[tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
[tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
[tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
[tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
[tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
[tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3
>>> # Essentially, it is similar to following operation:
>>> scatter_list = input
>>> gather_list  = output
>>> for i in range(world_size):
>>>   dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
>>> input
tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
tensor([20, 21, 22, 23, 24])                                     # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
>>> input_splits
[2, 2, 1, 1]                                                     # Rank 0
[3, 2, 2, 2]                                                     # Rank 1
[2, 1, 1, 1]                                                     # Rank 2
[2, 2, 2, 1]                                                     # Rank 3
>>> output_splits
[2, 3, 2, 2]                                                     # Rank 0
[2, 2, 1, 2]                                                     # Rank 1
[1, 2, 1, 2]                                                     # Rank 2
[1, 2, 1, 1]                                                     # Rank 3
>>> input = list(input.split(input_splits))
>>> input
[tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
[tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
[tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
[tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
>>> output = ...
>>> dist.all_to_all(output, input)
>>> output
[tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
[tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
[tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3
>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
>>> input = list(input.chunk(4))
>>> input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]            # Rank 0
[tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]            # Rank 1
[tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]      # Rank 2
[tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]    # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]          # Rank 0
[tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]        # Rank 1
[tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]        # Rank 2
[tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]        # Rank 3

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.all_to_all。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。