当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch FunctionCtx.save_for_backward用法及代码示例


本文简要介绍python语言中 torch.autograd.function.FunctionCtx.save_for_backward 的用法。

用法:

FunctionCtx.save_for_backward(*tensors)

保存给定的张量以供将来调用 backward()

这应该最多调用一次,并且只能从内部调用 forward() 方法。这只能用输入或输出张量调用

backward() 中,可以通过 saved_tensors 属性访问保存的张量。在将它们返回给用户之前,会进行检查以确保它们没有用于任何修改其内容的就地操作。

参数也可以是 None 。这是no-op。

有关如何使用此方法的更多详细信息,请参阅扩展torch.autograd。

例子:

>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * y * z
>>>         out = x * y + y * z + w
>>>         ctx.save_for_backward(x, y, out)
>>>         ctx.z = z  # z is not a tensor
>>>         ctx.w = w  # w is neither input nor output
>>>         return out
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         x, y, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + x * z)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.function.FunctionCtx.save_for_backward。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。