当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch saved_tensors_hooks用法及代码示例


本文简要介绍python语言中 torch.autograd.graph.saved_tensors_hooks 的用法。

用法:

class torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)

Context-manager 为保存的张量设置一对打包/解包钩子。

使用此context-manager 定义操作的中间结果应如何在保存之前打包,并在检索时解包。

在这种情况下,每次操作保存向后张量时都会调用 pack_hook 函数(这包括使用 save_for_backward() 保存的中间结果,也包括由 PyTorch-defined 操作记录的结果)。然后将pack_hook 的输出而不是原始张量存储在计算图中。

unpack_hook当需要访问保存的张量时,即执行时调用torch.Tensor.backward()或者torch.autograd.grad()。它以包装的返回的对象pack_hook并且应该返回一个与原始张量具有相同内容的张量(作为输入传递给相应的张量)pack_hook)。

挂钩应具有以下签名:

pack_hook(张量:张量) -> 任意

unpack_hook(任意) -> 张量

其中 pack_hook 的返回值是 unpack_hook 的有效输入。

通常,您希望 unpack_hook(pack_hook(t)) 在值、大小、dtype 和设备方面等于 t

例子:

>>> def pack_hook(x):
...     print("Packing", x)
...     return x
>>>
>>> def unpack_hook(x):
...     print("Unpacking", x)
...     return x
>>>
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
...     y = a * b
Packing tensor([1., 1., 1., 1., 1.])
Packing tensor([2., 2., 2., 2., 2.])
>>> y.sum().backward()
Unpacking tensor([1., 1., 1., 1., 1.])
Unpacking tensor([2., 2., 2., 2., 2.])

警告

对任一钩子的输入执行就地操作可能会导致未定义的行为。

警告

一次只允许一对钩子。尚不支持递归嵌套此 context-manager。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.graph.saved_tensors_hooks。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。