當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch FunctionCtx.set_materialize_grads用法及代碼示例


本文簡要介紹python語言中 torch.autograd.function.FunctionCtx.set_materialize_grads 的用法。

用法:

FunctionCtx.set_materialize_grads(value)

設置是否物化輸出梯度張量。默認為 True

這應該隻從內部調用 forward() 方法

如果 True ,未定義的輸出梯度張量將在調用 backward() 方法之前擴展為全零張量。

例子:

>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autograd.function.FunctionCtx.set_materialize_grads。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。