當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch context用法及代碼示例


本文簡要介紹python語言中 torch.distributed.autograd.context 的用法。

用法:

class torch.distributed.autograd.context

使用分布式 autograd 時包裝前向和後向傳遞的上下文對象。 with 語句中生成的context_id 需要唯一標識所有工作人員的分布式反向傳遞。每個工作人員都存儲與此 context_id 關聯的元數據,這是正確執行分布式 autograd pass 所必需的。

例子:

>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>   t1 = torch.rand((3, 3), requires_grad=True)
>>>   t2 = torch.rand((3, 3), requires_grad=True)
>>>   loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>   dist_autograd.backward(context_id, [loss])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.autograd.context。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。