當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch RRef.backward用法及代碼示例


本文簡要介紹python語言中 torch.distributed.rpc.RRef.backward 的用法。

用法:

backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = - 1, retain_graph: bool = False) → None

參數

  • dist_autograd_ctx_id(int,可選的) -我們應該檢索梯度的分布式 autograd 上下文 id(默認值:-1)。

  • retain_graph(bool,可選的) -如果 False ,用於計算 grad 的圖將被釋放。請注意,幾乎在所有情況下都不需要將此選項設置為True,並且通常可以以更有效的方式解決。通常,您需要將其設置為 True 以多次向後運行(默認值:False)。

使用 RRef 作為向後傳遞的根來運行向後傳遞。如果提供了 dist_autograd_ctx_id,我們將使用提供的 ctx_id 從 RRef 的所有者開始執行分布式向後傳遞。在這種情況下,應使用get_gradients() 來檢索梯度。如果 dist_autograd_ctx_idNone ,則假設這是一個局部自動求導圖,並且我們僅執行局部向後傳遞。在本地情況下,調用此 API 的節點必須是 RRef 的所有者。 RRef 的值預計是一個標量張量。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     rref.backward(context_id)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.rpc.RRef.backward。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。