当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch broadcast_object_list用法及代码示例


本文简要介绍python语言中 torch.distributed.broadcast_object_list 的用法。

用法:

torch.distributed.broadcast_object_list(object_list, src=0, group=None, device=None)

参数

  • object_list(List[任何]) -要广播的输入对象列表。每个对象都必须是 picklable 的。只有src 等级上的对象才会被广播,但每个等级必须提供大小相等的列表。

  • src(int) -广播的源排名 object_list

  • group-(ProcessGroup,可选):要处理的进程组。如果没有,将使用默认进程组。默认为 None

  • device(torch.device, 可选的) -如果不是 None,则将对象序列化并转换为张量,然后在广播之前将其移动到device。默认为 None

返回

None 。如果 rank 是组的一部分,object_list 将包含来自 src rank 的广播对象。

object_list 中的可挑选对象广播给整个组。与 broadcast() 类似,但可以传入 Python 对象。请注意,object_list 中的所有对象都必须是可picklable 的才能进行广播。

注意

对于基于 NCCL 的处理组,对象的内部张量表示必须在通信发生之前移动到 GPU 设备。在这种情况下,使用的设备由 torch.cuda.current_device() 给出,用户有责任通过 torch.cuda.set_device() 确保将其设置为每个等级都有单独的 GPU。

注意

请注意,此 API 与 all_gather() 集合略有不同,因为它不提供 async_op 句柄,因此将是阻塞调用。

警告

broadcast_object_list() 隐式使用 pickle 模块,已知这是不安全的。可以构造恶意的 pickle 数据,该数据将在 unpickling 期间执行任意代码。仅使用您信任的数据调用此函数。

例子:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3.
>>>     objects = ["foo", 12, {1: 2}] # any picklable object
>>> else:
>>>     objects = [None, None, None]
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> dist.broadcast_object_list(objects, src=0, device=device)
>>> broadcast_objects
['foo', 12, {1: 2}]

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.broadcast_object_list。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。