當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Interpreter用法及代碼示例


本文簡要介紹python語言中 torch.fx.Interpreter 的用法。

用法:

class torch.fx.Interpreter(module, garbage_collect_values=True)

參數

  • module(GraphModule) -要執行的模塊

  • garbage_collect_values(bool) -是否在模塊執行中最後一次使用後刪除值。這確保了執行期間的最佳內存使用。這可以被禁用,例如,通過查看Interpreter.env 屬性來檢查執行中的所有中間值。

解釋器執行 FX 圖Node-by-Node。這種模式可以用於許多事情,包括編寫代碼轉換以及分析傳遞。

可以重寫解釋器類中的方法以自定義執行行為。在調用層次結構方麵的可覆蓋方法的映射:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假設我們想用torch.sigmoid 交換所有torch.neg 實例,反之亦然(包括它們的Tensor 方法等價物)。我們可以像這樣子類解釋器:

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_allclose(result, torch.neg(input).sigmoid())

注意

保證此 API 的向後兼容性。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.fx.Interpreter。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。