本文简要介绍python语言中 torch.autograd.graph.saved_tensors_hooks
的用法。
用法:
class torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)
Context-manager 为保存的张量设置一对打包/解包钩子。
使用此context-manager 定义操作的中间结果应如何在保存之前打包,并在检索时解包。
在这种情况下,每次操作保存向后张量时都会调用
pack_hook
函数(这包括使用save_for_backward()
保存的中间结果,也包括由 PyTorch-defined 操作记录的结果)。然后将pack_hook
的输出而不是原始张量存储在计算图中。unpack_hook
当需要访问保存的张量时,即执行时调用torch.Tensor.backward()
或者torch.autograd.grad()
。它以包装的返回的对象pack_hook
并且应该返回一个与原始张量具有相同内容的张量(作为输入传递给相应的张量)pack_hook
)。挂钩应具有以下签名:
pack_hook(张量:张量) -> 任意
unpack_hook(任意) -> 张量
其中
pack_hook
的返回值是unpack_hook
的有效输入。通常,您希望
unpack_hook(pack_hook(t))
在值、大小、dtype 和设备方面等于t
。例子:
>>> def pack_hook(x): ... print("Packing", x) ... return x >>> >>> def unpack_hook(x): ... print("Unpacking", x) ... return x >>> >>> a = torch.ones(5, requires_grad=True) >>> b = torch.ones(5, requires_grad=True) * 2 >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): ... y = a * b Packing tensor([1., 1., 1., 1., 1.]) Packing tensor([2., 2., 2., 2., 2.]) >>> y.sum().backward() Unpacking tensor([1., 1., 1., 1., 1.]) Unpacking tensor([2., 2., 2., 2., 2.])
警告
对任一钩子的输入执行就地操作可能会导致未定义的行为。
警告
一次只允许一对钩子。尚不支持递归嵌套此 context-manager。
相关用法
- Python PyTorch save_on_cpu用法及代码示例
- Python PyTorch save用法及代码示例
- Python PyTorch sqrt用法及代码示例
- Python PyTorch skippable用法及代码示例
- Python PyTorch squeeze用法及代码示例
- Python PyTorch square用法及代码示例
- Python PyTorch scatter_object_list用法及代码示例
- Python PyTorch skip_init用法及代码示例
- Python PyTorch simple_space_split用法及代码示例
- Python PyTorch sum用法及代码示例
- Python PyTorch sub用法及代码示例
- Python PyTorch sparse_csr_tensor用法及代码示例
- Python PyTorch sentencepiece_numericalizer用法及代码示例
- Python PyTorch symeig用法及代码示例
- Python PyTorch sinh用法及代码示例
- Python PyTorch sinc用法及代码示例
- Python PyTorch std_mean用法及代码示例
- Python PyTorch spectral_norm用法及代码示例
- Python PyTorch slogdet用法及代码示例
- Python PyTorch symbolic_trace用法及代码示例
- Python PyTorch shutdown用法及代码示例
- Python PyTorch sgn用法及代码示例
- Python PyTorch set_flush_denormal用法及代码示例
- Python PyTorch set_default_dtype用法及代码示例
- Python PyTorch signbit用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.graph.saved_tensors_hooks。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。