當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch broadcast_object_list用法及代碼示例


本文簡要介紹python語言中 torch.distributed.broadcast_object_list 的用法。

用法:

torch.distributed.broadcast_object_list(object_list, src=0, group=None, device=None)

參數

  • object_list(List[任何]) -要廣播的輸入對象列表。每個對象都必須是 picklable 的。隻有src 等級上的對象才會被廣播,但每個等級必須提供大小相等的列表。

  • src(int) -廣播的源排名 object_list

  • group-(ProcessGroup,可選):要處理的進程組。如果沒有,將使用默認進程組。默認為 None

  • device(torch.device, 可選的) -如果不是 None,則將對象序列化並轉換為張量,然後在廣播之前將其移動到device。默認為 None

返回

None 。如果 rank 是組的一部分,object_list 將包含來自 src rank 的廣播對象。

object_list 中的可挑選對象廣播給整個組。與 broadcast() 類似,但可以傳入 Python 對象。請注意,object_list 中的所有對象都必須是可picklable 的才能進行廣播。

注意

對於基於 NCCL 的處理組,對象的內部張量表示必須在通信發生之前移動到 GPU 設備。在這種情況下,使用的設備由 torch.cuda.current_device() 給出,用戶有責任通過 torch.cuda.set_device() 確保將其設置為每個等級都有單獨的 GPU。

注意

請注意,此 API 與 all_gather() 集合略有不同,因為它不提供 async_op 句柄,因此將是阻塞調用。

警告

broadcast_object_list() 隱式使用 pickle 模塊,已知這是不安全的。可以構造惡意的 pickle 數據,該數據將在 unpickling 期間執行任意代碼。僅使用您信任的數據調用此函數。

例子:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3.
>>>     objects = ["foo", 12, {1: 2}] # any picklable object
>>> else:
>>>     objects = [None, None, None]
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> dist.broadcast_object_list(objects, src=0, device=device)
>>> broadcast_objects
['foo', 12, {1: 2}]

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.broadcast_object_list。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。