当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch get_gradients用法及代码示例


本文简要介绍python语言中 torch.distributed.autograd.get_gradients 的用法。

用法:

torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]

参数

context_id(int) -我们应该为其检索梯度的 autograd 上下文 id。

返回

一个映射,其中键是张量,值是该张量的相关梯度。

检索从张量到适当梯度的映射,该张量在与给定 context_id 对应的提供的上下文中累积,作为分布式 autograd 反向传递的一部分。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.autograd.get_gradients。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。