當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch handle_torch_function用法及代碼示例


本文簡要介紹python語言中 torch.overrides.handle_torch_function 的用法。

用法:

torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)

參數

  • public_api(函數) -由公共 torch API 公開的函數最初稱為 public_api(*args, **kwargs),現在正在檢查其參數。

  • relevant_args(可迭代的) -檢查 __torch_function__ 方法的可迭代參數。

  • args(tuple) -最初傳遞給 public_api 的任意位置參數。

  • kwargs(tuple) -最初傳遞給 public_api 的任意關鍵字參數。

返回

調用 implementation__torch_function__ 方法的結果,視情況而定。

返回類型

對象

實現一個檢查__torch_function__ 覆蓋的函數。

請參閱 torch::autograd::handle_torch_function 以了解 C++ 實現中此函數的等效項。

:引發類型錯誤:如果未找到實現。:

示例

>>> def func(a):
...     if type(a) is not torch.Tensor:  # This will make func dispatchable by __torch_function__
...         return handle_torch_function(func, (a,), a)
...     return a + 0

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.overrides.handle_torch_function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。