當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch all_gather_object用法及代碼示例


本文簡要介紹python語言中 torch.distributed.all_gather_object 的用法。

用法:

torch.distributed.all_gather_object(object_list, obj, group=None)

參數

  • object_list(list[任何]) -輸出列表。它的大小應該正確地調整為該集合的組大小,並將包含輸出。

  • object(任何) -要從當前進程廣播的可選 Python 對象。

  • group(ProcessGroup,可選的) -要處理的流程組。如果沒有,將使用默認進程組。默認為 None

返回

沒有。如果調用等級是該組的一部分,則集合的輸出將填充到輸入 object_list 中。如果調用等級不是組的一部分,則傳入的object_list 將不被修改。

將整個組中的 picklable 對象收集到一個列表中。與all_gather()類似,但可以傳入Python對象。請注意,該對象必須是可picklable的才能被收集。

注意

請注意,此 API 與 all_gather() 集合略有不同,因為它不提供 async_op 句柄,因此將是阻塞調用。

注意

對於基於 NCCL 的處理組,對象的內部張量表示必須在通信發生之前移動到 GPU 設備。在這種情況下,使用的設備由 torch.cuda.current_device() 給出,用戶有責任通過 torch.cuda.set_device() 確保將其設置為每個等級都有單獨的 GPU。

警告

all_gather_object() 隱式使用 pickle 模塊,已知這是不安全的。可以構造惡意的 pickle 數據,該數據將在 unpickling 期間執行任意代碼。僅使用您信任的數據調用此函數。

例子:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.all_gather_object。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。