當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch get_gradients用法及代碼示例


本文簡要介紹python語言中 torch.distributed.autograd.get_gradients 的用法。

用法:

torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]

參數

context_id(int) -我們應該為其檢索梯度的 autograd 上下文 id。

返回

一個映射,其中鍵是張量,值是該張量的相關梯度。

檢索從張量到適當梯度的映射,該張量在與給定 context_id 對應的提供的上下文中累積,作為分布式 autograd 反向傳遞的一部分。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.autograd.get_gradients。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。