当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch context用法及代码示例


本文简要介绍python语言中 torch.distributed.autograd.context 的用法。

用法:

class torch.distributed.autograd.context

使用分布式 autograd 时包装前向和后向传递的上下文对象。 with 语句中生成的context_id 需要唯一标识所有工作人员的分布式反向传递。每个工作人员都存储与此 context_id 关联的元数据,这是正确执行分布式 autograd pass 所必需的。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>   t1 = torch.rand((3, 3), requires_grad=True)
>>>   t2 = torch.rand((3, 3), requires_grad=True)
>>>   loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>   dist_autograd.backward(context_id, [loss])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.autograd.context。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。