当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch symbolic_trace用法及代码示例


本文简要介绍python语言中 torch.fx.symbolic_trace 的用法。

用法:

torch.fx.symbolic_trace(root, concrete_args=None, enable_cpatching=False)

参数

  • root(联盟[torch.nn.Module,可调用]) -要跟踪并转换为图形表示的模块或函数。

  • concrete_args(可选的[字典[str,任何]]) - 部分特化的输入

  • enable_cpatching-启用 C-level 修补函数(捕获类似 torch.randn 的内容)

返回

root 记录的操作创建的模块。

返回类型

GraphModule

符号跟踪 API

给定一个 nn.Module 或函数实例 root ,此函数将返回一个 GraphModule ,该 GraphModule 是通过记录在跟踪 root 时看到的操作构造的。

concrete_args 允许您对函数进行部分专门化,无论是删除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于存在控制流,FX 通常无法通过此跟踪。但是,我们可以使用concrete_args 来专门研究b 的值来跟踪它。

f = fx.symbolic_trace(f, concrete_args={‘b’: False}) 断言 f(3, False) == 6

请注意,尽管您仍然可以传入不同的 b 值,但它们将被忽略。

我们还可以使用concrete_args 来消除函数中的data-structure 处理。这将使用 pytrees 来展平您的输入。为避免过度特化,请为不应特化的值传入fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7

注意

保证此 API 的向后兼容性。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.fx.symbolic_trace。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。