當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch symbolic_trace用法及代碼示例


本文簡要介紹python語言中 torch.fx.symbolic_trace 的用法。

用法:

torch.fx.symbolic_trace(root, concrete_args=None, enable_cpatching=False)

參數

  • root(聯盟[torch.nn.Module,可調用]) -要跟蹤並轉換為圖形表示的模塊或函數。

  • concrete_args(可選的[字典[str,任何]]) - 部分特化的輸入

  • enable_cpatching-啟用 C-level 修補函數(捕獲類似 torch.randn 的內容)

返回

root 記錄的操作創建的模塊。

返回類型

GraphModule

符號跟蹤 API

給定一個 nn.Module 或函數實例 root ,此函數將返回一個 GraphModule ,該 GraphModule 是通過記錄在跟蹤 root 時看到的操作構造的。

concrete_args 允許您對函數進行部分專門化,無論是刪除控製流還是數據結構。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由於存在控製流,FX 通常無法通過此跟蹤。但是,我們可以使用concrete_args 來專門研究b 的值來跟蹤它。

f = fx.symbolic_trace(f, concrete_args={‘b’: False}) 斷言 f(3, False) == 6

請注意,盡管您仍然可以傳入不同的 b 值,但它們將被忽略。

我們還可以使用concrete_args 來消除函數中的data-structure 處理。這將使用 pytrees 來展平您的輸入。為避免過度特化,請為不應特化的值傳入fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7

注意

保證此 API 的向後兼容性。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.fx.symbolic_trace。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。