當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch FunctionCtx.mark_non_differentiable用法及代碼示例


本文簡要介紹python語言中 torch.autograd.function.FunctionCtx.mark_non_differentiable 的用法。

用法:

FunctionCtx.mark_non_differentiable(*args)

將輸出標記為不可微分。

這應該最多調用一次,隻能從內部調用 forward() 方法,所有參數都應該是張量輸出。

這會將輸出標記為不需要梯度,從而提高反向計算的效率。您仍然需要為 backward() 中的每個輸出接受一個梯度,但它始終是一個零張量,其形狀與相應輸出的形狀相同。

這用於例如對於從排序返回的索引。見例子::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autograd.function.FunctionCtx.mark_non_differentiable。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。