当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch wrap用法及代码示例

本文简要介绍python语言中 torch.fx.wrap 的用法。

用法:

torch.fx.wrap(fn_or_name)

参数

fn_or_name(联盟[str,可调用]) -调用时插入到图中的全局函数的函数或名称

可以在模块级范围调用此函数,将fn_or_name注册为“leaf function”。 “leaf function” 将在 FX 跟踪中保留为 CallFunction 节点,而不是通过以下方式进行跟踪:

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数也可以等效地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

包装函数可以被认为是“leaf function”,类似于“leaf modules” 的概念,也就是说,它们是作为调用留在FX 跟踪中而不是通过跟踪的函数。

注意

保证此 API 的向后兼容性。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.fx.wrap。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。