當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch replace_pattern用法及代碼示例


本文簡要介紹python語言中 torch.fx.replace_pattern 的用法。

用法:

torch.fx.replace_pattern(gm, pattern, replacement)

參數

  • gm-包裝要操作的 Graph 的 GraphModule

  • pattern-gm 中匹配的子圖進行替換

  • replacement-替換pattern 的子圖

返回

Match 對象列表,表示原始圖中與 pattern 匹配的位置。如果沒有匹配項,則列表為空。 Match 定義為:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回類型

列表[匹配]

匹配 GraphModule ( gm ) 圖中所有可能的非重疊運算符集及其數據依賴性 ( pattern ),然後將每個匹配的子圖替換為另一個子圖 ( replacement )。

例子:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上麵的代碼將首先匹配 traced_moduleforward 方法中的 pattern 。 Pattern-matching 是基於 use-def 關係完成的,而不是節點名稱。例如,如果您在 pattern 中有 p = torch.cat([a, b]) ,則可以匹配原始 forward 函數中的 m = torch.cat([a, b]),盡管變量名稱不同( pm )。

pattern 中的return 語句僅根據其值進行匹配;它可能與較大圖中的 return 語句匹配,也可能不匹配。換句話說,模式不必延伸到更大圖的末尾。

當模式匹配時,它將從較大的函數中刪除並替換為 replacement 。如果較大的函數中有多個pattern 匹配,則每個不重疊的匹配都將被替換。在匹配重疊的情況下,重疊匹配集中第一個找到的匹配將被替換。 (“First” 在這裏被定義為節點的 use-def 關係的拓撲排序中的第一個。在大多數情況下,第一個 Node 是直接出現在 self 之後的參數,而最後一個 Node 是函數返回的任何內容.)

需要注意的重要一點是pattern Callable 的參數必須在Callable 本身中使用,並且replacement Callable 的參數必須與模式匹配。第一條規則是為什麽在上麵的代碼塊中,forward 函數有參數 x, w1, w2 ,但 pattern 函數隻有參數 w1, w2pattern 不使用 x ,因此不應將 x 指定為參數。作為第二條規則的示例,請考慮替換

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

def replacement(x, y):
    return torch.relu(x)

在這種情況下,replacement 需要與 pattern 相同數量的參數(xy ),即使在 replacement 中沒有使用參數 y

調用 subgraph_rewriter.replace_pattern 後,生成的 Python 代碼如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

保證此 API 的向後兼容性。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.fx.replace_pattern。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。