当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch RRef.backward用法及代码示例


本文简要介绍python语言中 torch.distributed.rpc.RRef.backward 的用法。

用法:

backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = - 1, retain_graph: bool = False) → None

参数

  • dist_autograd_ctx_id(int,可选的) -我们应该检索梯度的分布式 autograd 上下文 id(默认值:-1)。

  • retain_graph(bool,可选的) -如果 False ,用于计算 grad 的图将被释放。请注意,几乎在所有情况下都不需要将此选项设置为True,并且通常可以以更有效的方式解决。通常,您需要将其设置为 True 以多次向后运行(默认值:False)。

使用 RRef 作为向后传递的根来运行向后传递。如果提供了 dist_autograd_ctx_id,我们将使用提供的 ctx_id 从 RRef 的所有者开始执行分布式向后传递。在这种情况下,应使用get_gradients() 来检索梯度。如果 dist_autograd_ctx_idNone ,则假设这是一个局部自动求导图,并且我们仅执行局部向后传递。在本地情况下,调用此 API 的节点必须是 RRef 的所有者。 RRef 的值预计是一个标量张量。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     rref.backward(context_id)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.rpc.RRef.backward。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。