當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch all_gather用法及代碼示例


本文簡要介紹python語言中 torch.distributed.all_gather 的用法。

用法:

torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)

參數

  • tensor_list(list[Tensor]) -輸出列表。它應該包含 correctly-sized 張量,用於集體的輸出。

  • tensor(Tensor) -從當前進程廣播的張量。

  • group(ProcessGroup,可選的) -要處理的流程組。如果沒有,將使用默認進程組。

  • async_op(bool,可選的) -此操作是否應該是異步操作

返回

異步工作句柄,如果 async_op 設置為 True。無,如果不是 async_op 或者如果不是該組的一部分

將整個組的張量收集到一個列表中。

支持複雜的張量。

例子

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1
>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.all_gather。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。