当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Function用法及代码示例


本文简要介绍python语言中 torch.autograd.Function 的用法。

用法:

class torch.autograd.Function(*args, **kwargs)

创建自定义autograd.Function的基类

要创建自定义 autograd.Function ,请子类化此类并实现 forward() 和 :meth`backward` 静态方法。然后,要在前向传递中使用自定义操作,请调用类方法 apply 。不要直接调用 forward()

为确保正确性和最佳性能,请确保您在 ctx 上调用正确的方法并使用 torch.autograd.gradcheck() 验证您的后向函数。

有关如何使用此类的更多详细信息,请参阅扩展torch.autograd。

例子:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。