當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Function用法及代碼示例


本文簡要介紹python語言中 torch.autograd.Function 的用法。

用法:

class torch.autograd.Function(*args, **kwargs)

創建自定義autograd.Function的基類

要創建自定義 autograd.Function ,請子類化此類並實現 forward() 和 :meth`backward` 靜態方法。然後,要在前向傳遞中使用自定義操作,請調用類方法 apply 。不要直接調用 forward()

為確保正確性和最佳性能,請確保您在 ctx 上調用正確的方法並使用 torch.autograd.gradcheck() 驗證您的後向函數。

有關如何使用此類的更多詳細信息,請參閱擴展torch.autograd。

例子:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autograd.Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。