当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch gather_object用法及代码示例


本文简要介绍python语言中 torch.distributed.gather_object 的用法。

用法:

torch.distributed.gather_object(obj, object_gather_list=None, dst=0, group=None)

参数

  • obj(任何) -输入对象。必须是 picklable 的。

  • object_gather_list(list[任何]) -输出列表。在dst 等级上,它的大小应该正确地调整为该集合的组大小,并将包含输出。在非 dst 等级上必须是 None。 (默认为 None )

  • dst(int,可选的) -目的地排名。 (默认为 0)

  • group-(ProcessGroup,可选):要处理的进程组。如果没有,将使用默认进程组。默认为 None

返回

没有。在dst 等级上,object_gather_list 将包含集体的输出。

在单个进程中从整个组中收集 picklable 的对象。与gather()类似,但可以传入Python对象。请注意,该对象必须是可picklable的才能被收集。

注意

请注意,此 API 与聚集集合略有不同,因为它不提供 async_op 句柄,因此将是一个阻塞调用。

注意

请注意,使用 NCCL 后端时不支持此 API。

警告

gather_object() 隐式使用 pickle 模块,已知这是不安全的。可以构造恶意的 pickle 数据,该数据将在 unpickling 期间执行任意代码。仅使用您信任的数据调用此函数。

例子:

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
        gather_objects[dist.get_rank()],
        output if dist.get_rank() == 0 else None,
        dst=0
    )
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.gather_object。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。