当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch all_gather_object用法及代码示例


本文简要介绍python语言中 torch.distributed.all_gather_object 的用法。

用法:

torch.distributed.all_gather_object(object_list, obj, group=None)

参数

  • object_list(list[任何]) -输出列表。它的大小应该正确地调整为该集合的组大小,并将包含输出。

  • object(任何) -要从当前进程广播的可选 Python 对象。

  • group(ProcessGroup,可选的) -要处理的流程组。如果没有,将使用默认进程组。默认为 None

返回

没有。如果调用等级是该组的一部分,则集合的输出将填充到输入 object_list 中。如果调用等级不是组的一部分,则传入的object_list 将不被修改。

将整个组中的 picklable 对象收集到一个列表中。与all_gather()类似,但可以传入Python对象。请注意,该对象必须是可picklable的才能被收集。

注意

请注意,此 API 与 all_gather() 集合略有不同,因为它不提供 async_op 句柄,因此将是阻塞调用。

注意

对于基于 NCCL 的处理组,对象的内部张量表示必须在通信发生之前移动到 GPU 设备。在这种情况下,使用的设备由 torch.cuda.current_device() 给出,用户有责任通过 torch.cuda.set_device() 确保将其设置为每个等级都有单独的 GPU。

警告

all_gather_object() 隐式使用 pickle 模块,已知这是不安全的。可以构造恶意的 pickle 数据,该数据将在 unpickling 期间执行任意代码。仅使用您信任的数据调用此函数。

例子:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.all_gather_object。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。