本文简要介绍python语言中 torch.fx.symbolic_trace
的用法。
用法:
torch.fx.symbolic_trace(root, concrete_args=None, enable_cpatching=False)
root(联盟[torch.nn.Module,可调用]) -要跟踪并转换为图形表示的模块或函数。
concrete_args(可选的[字典[str,任何]]) - 部分特化的输入
enable_cpatching-启用 C-level 修补函数(捕获类似
torch.randn
的内容)
从
root
记录的操作创建的模块。符号跟踪 API
给定一个
nn.Module
或函数实例root
,此函数将返回一个GraphModule
,该GraphModule
是通过记录在跟踪root
时看到的操作构造的。concrete_args
允许您对函数进行部分专门化,无论是删除控制流还是数据结构。例如:
def f(a, b): if b == True: return a else: return a*2
由于存在控制流,FX 通常无法通过此跟踪。但是,我们可以使用
concrete_args
来专门研究b
的值来跟踪它。f = fx.symbolic_trace(f, concrete_args={‘b’: False}) 断言 f(3, False) == 6
请注意,尽管您仍然可以传入不同的
b
值,但它们将被忽略。我们还可以使用
concrete_args
来消除函数中的data-structure 处理。这将使用 pytrees 来展平您的输入。为避免过度特化,请为不应特化的值传入fx.PH
。例如:def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) assert f({'a': 1, 'b': 2, 'c': 4}) == 7
注意
保证此 API 的向后兼容性。
参数:
返回:
返回类型:
相关用法
- Python PyTorch symeig用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
- Python PyTorch sqrt用法及代码示例
- Python PyTorch skippable用法及代码示例
- Python PyTorch squeeze用法及代码示例
- Python PyTorch square用法及代码示例
- Python PyTorch save_on_cpu用法及代码示例
- Python PyTorch scatter_object_list用法及代码示例
- Python PyTorch skip_init用法及代码示例
- Python PyTorch simple_space_split用法及代码示例
- Python PyTorch sum用法及代码示例
- Python PyTorch sub用法及代码示例
- Python PyTorch sparse_csr_tensor用法及代码示例
- Python PyTorch sentencepiece_numericalizer用法及代码示例
- Python PyTorch sinh用法及代码示例
- Python PyTorch sinc用法及代码示例
- Python PyTorch std_mean用法及代码示例
- Python PyTorch spectral_norm用法及代码示例
- Python PyTorch slogdet用法及代码示例
- Python PyTorch shutdown用法及代码示例
- Python PyTorch sgn用法及代码示例
- Python PyTorch set_flush_denormal用法及代码示例
- Python PyTorch set_default_dtype用法及代码示例
- Python PyTorch signbit用法及代码示例
- Python PyTorch sort用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.fx.symbolic_trace。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。