当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Graph用法及代码示例


本文简要介绍python语言中 torch.fx.Graph 的用法。

用法:

class torch.fx.Graph(owning_module=None, tracer_cls=None)

Graph 是 FX 中间表示中使用的主要数据结构。它由一系列 Node 组成,每个代表调用点(或其他句法结构)。 Node 的列表一起构成了一个有效的 Python 函数。

例如下面的代码

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

将产生以下图表:

print(gm.graph)
graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有关 Graph 中表示的操作的语义,请参阅 Node

注意

保证此 API 的向后兼容性。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.fx.Graph。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。