當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Graph用法及代碼示例


本文簡要介紹python語言中 torch.fx.Graph 的用法。

用法:

class torch.fx.Graph(owning_module=None, tracer_cls=None)

Graph 是 FX 中間表示中使用的主要數據結構。它由一係列 Node 組成,每個代表調用點(或其他句法結構)。 Node 的列表一起構成了一個有效的 Python 函數。

例如下麵的代碼

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

將產生以下圖表:

print(gm.graph)
graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有關 Graph 中表示的操作的語義,請參閱 Node

注意

保證此 API 的向後兼容性。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.fx.Graph。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。