当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch handle_torch_function用法及代码示例


本文简要介绍python语言中 torch.overrides.handle_torch_function 的用法。

用法:

torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)

参数

  • public_api(函数) -由公共 torch API 公开的函数最初称为 public_api(*args, **kwargs),现在正在检查其参数。

  • relevant_args(可迭代的) -检查 __torch_function__ 方法的可迭代参数。

  • args(tuple) -最初传递给 public_api 的任意位置参数。

  • kwargs(tuple) -最初传递给 public_api 的任意关键字参数。

返回

调用 implementation__torch_function__ 方法的结果,视情况而定。

返回类型

对象

实现一个检查__torch_function__ 覆盖的函数。

请参阅 torch::autograd::handle_torch_function 以了解 C++ 实现中此函数的等效项。

:引发类型错误:如果未找到实现。:

示例

>>> def func(a):
...     if type(a) is not torch.Tensor:  # This will make func dispatchable by __torch_function__
...         return handle_torch_function(func, (a,), a)
...     return a + 0

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.overrides.handle_torch_function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。