当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Interpreter用法及代码示例


本文简要介绍python语言中 torch.fx.Interpreter 的用法。

用法:

class torch.fx.Interpreter(module, garbage_collect_values=True)

参数

  • module(GraphModule) -要执行的模块

  • garbage_collect_values(bool) -是否在模块执行中最后一次使用后删除值。这确保了执行期间的最佳内存使用。这可以被禁用,例如,通过查看Interpreter.env 属性来检查执行中的所有中间值。

解释器执行 FX 图Node-by-Node。这种模式可以用于许多事情,包括编写代码转换以及分析传递。

可以重写解释器类中的方法以自定义执行行为。在调用层次结构方面的可覆盖方法的映射:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们想用torch.sigmoid 交换所有torch.neg 实例,反之亦然(包括它们的Tensor 方法等价物)。我们可以像这样子类解释器:

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_allclose(result, torch.neg(input).sigmoid())

注意

保证此 API 的向后兼容性。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.fx.Interpreter。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。