当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch FunctionCtx.set_materialize_grads用法及代码示例


本文简要介绍python语言中 torch.autograd.function.FunctionCtx.set_materialize_grads 的用法。

用法:

FunctionCtx.set_materialize_grads(value)

设置是否物化输出梯度张量。默认为 True

这应该只从内部调用 forward() 方法

如果 True ,未定义的输出梯度张量将在调用 backward() 方法之前扩展为全零张量。

例子:

>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.function.FunctionCtx.set_materialize_grads。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。