当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch FunctionCtx.mark_non_differentiable用法及代码示例


本文简要介绍python语言中 torch.autograd.function.FunctionCtx.mark_non_differentiable 的用法。

用法:

FunctionCtx.mark_non_differentiable(*args)

将输出标记为不可微分。

这应该最多调用一次,只能从内部调用 forward() 方法,所有参数都应该是张量输出。

这会将输出标记为不需要梯度,从而提高反向计算的效率。您仍然需要为 backward() 中的每个输出接受一个梯度,但它始终是一个零张量,其形状与相应输出的形状相同。

这用于例如对于从排序返回的索引。见例子::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.autograd.function.FunctionCtx.mark_non_differentiable。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。