当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Tensor.register_hook用法及代码示例


本文简要介绍python语言中 torch.Tensor.register_hook 的用法。

用法:

Tensor.register_hook(hook)

注册一个后向钩子。

每次计算相对于张量的梯度时,都会调用该钩子。挂钩应具有以下签名:

hook(grad) -> Tensor or None

钩子不应该修改它的参数,但它可以选择返回一个新的渐变,它将用来代替 grad

此函数返回一个带有方法handle.remove() 的句柄,该方法从模块中删除钩子。

例子:

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.Tensor.register_hook。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。