當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch Tensor.register_hook用法及代碼示例

本文簡要介紹python語言中 torch.Tensor.register_hook 的用法。

用法:

Tensor.register_hook(hook)

注冊一個後向鉤子。

每次計算相對於張量的梯度時,都會調用該鉤子。掛鉤應具有以下簽名:

hook(grad) -> Tensor or None

鉤子不應該修改它的參數,但它可以選擇返回一個新的漸變,它將用來代替 grad

此函數返回一個帶有方法handle.remove() 的句柄,該方法從模塊中刪除鉤子。

例子:

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.Tensor.register_hook。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。