當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch saved_tensors_hooks用法及代碼示例


本文簡要介紹python語言中 torch.autograd.graph.saved_tensors_hooks 的用法。

用法:

class torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)

Context-manager 為保存的張量設置一對打包/解包鉤子。

使用此context-manager 定義操作的中間結果應如何在保存之前打包,並在檢索時解包。

在這種情況下,每次操作保存向後張量時都會調用 pack_hook 函數(這包括使用 save_for_backward() 保存的中間結果,也包括由 PyTorch-defined 操作記錄的結果)。然後將pack_hook 的輸出而不是原始張量存儲在計算圖中。

unpack_hook當需要訪問保存的張量時,即執行時調用torch.Tensor.backward()或者torch.autograd.grad()。它以包裝的返回的對象pack_hook並且應該返回一個與原始張量具有相同內容的張量(作為輸入傳遞給相應的張量)pack_hook)。

掛鉤應具有以下簽名:

pack_hook(張量:張量) -> 任意

unpack_hook(任意) -> 張量

其中 pack_hook 的返回值是 unpack_hook 的有效輸入。

通常,您希望 unpack_hook(pack_hook(t)) 在值、大小、dtype 和設備方麵等於 t

例子:

>>> def pack_hook(x):
...     print("Packing", x)
...     return x
>>>
>>> def unpack_hook(x):
...     print("Unpacking", x)
...     return x
>>>
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
...     y = a * b
Packing tensor([1., 1., 1., 1., 1.])
Packing tensor([2., 2., 2., 2., 2.])
>>> y.sum().backward()
Unpacking tensor([1., 1., 1., 1., 1.])
Unpacking tensor([2., 2., 2., 2., 2.])

警告

對任一鉤子的輸入執行就地操作可能會導致未定義的行為。

警告

一次隻允許一對鉤子。尚不支持遞歸嵌套此 context-manager。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autograd.graph.saved_tensors_hooks。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。