當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch FunctionCtx.save_for_backward用法及代碼示例


本文簡要介紹python語言中 torch.autograd.function.FunctionCtx.save_for_backward 的用法。

用法:

FunctionCtx.save_for_backward(*tensors)

保存給定的張量以供將來調用 backward()

這應該最多調用一次,並且隻能從內部調用 forward() 方法。這隻能用輸入或輸出張量調用

backward() 中,可以通過 saved_tensors 屬性訪問保存的張量。在將它們返回給用戶之前,會進行檢查以確保它們沒有用於任何修改其內容的就地操作。

參數也可以是 None 。這是no-op。

有關如何使用此方法的更多詳細信息,請參閱擴展torch.autograd。

例子:

>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * y * z
>>>         out = x * y + y * z + w
>>>         ctx.save_for_backward(x, y, out)
>>>         ctx.z = z  # z is not a tensor
>>>         ctx.w = w  # w is neither input nor output
>>>         return out
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         x, y, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + x * z)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autograd.function.FunctionCtx.save_for_backward。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。