本文簡要介紹python語言中 torch.fx.replace_pattern
的用法。
用法:
torch.fx.replace_pattern(gm, pattern, replacement)
gm-包裝要操作的 Graph 的 GraphModule
pattern-
gm
中匹配的子圖進行替換replacement-替換
pattern
的子圖
Match
對象列表,表示原始圖中與pattern
匹配的位置。如果沒有匹配項,則列表為空。Match
定義為:class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
列表[匹配]
匹配 GraphModule (
gm
) 圖中所有可能的非重疊運算符集及其數據依賴性 (pattern
),然後將每個匹配的子圖替換為另一個子圖 (replacement
)。例子:
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上麵的代碼將首先匹配
traced_module
的forward
方法中的pattern
。 Pattern-matching 是基於 use-def 關係完成的,而不是節點名稱。例如,如果您在pattern
中有p = torch.cat([a, b])
,則可以匹配原始forward
函數中的m = torch.cat([a, b])
,盡管變量名稱不同(p
與m
)。pattern
中的return
語句僅根據其值進行匹配;它可能與較大圖中的return
語句匹配,也可能不匹配。換句話說,模式不必延伸到更大圖的末尾。當模式匹配時,它將從較大的函數中刪除並替換為
replacement
。如果較大的函數中有多個pattern
匹配,則每個不重疊的匹配都將被替換。在匹配重疊的情況下,重疊匹配集中第一個找到的匹配將被替換。 (“First” 在這裏被定義為節點的 use-def 關係的拓撲排序中的第一個。在大多數情況下,第一個 Node 是直接出現在self
之後的參數,而最後一個 Node 是函數返回的任何內容.)需要注意的重要一點是
pattern
Callable 的參數必須在Callable 本身中使用,並且replacement
Callable 的參數必須與模式匹配。第一條規則是為什麽在上麵的代碼塊中,forward
函數有參數x, w1, w2
,但pattern
函數隻有參數w1, w2
。pattern
不使用x
,因此不應將x
指定為參數。作為第二條規則的示例,請考慮替換def pattern(x, y): return torch.neg(x) + torch.relu(y)
和
def replacement(x, y): return torch.relu(x)
在這種情況下,
replacement
需要與pattern
相同數量的參數(x
和y
),即使在replacement
中沒有使用參數y
。調用
subgraph_rewriter.replace_pattern
後,生成的 Python 代碼如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
保證此 API 的向後兼容性。
參數:
返回:
返回類型:
相關用法
- Python PyTorch repeat_interleave用法及代碼示例
- Python PyTorch renorm用法及代碼示例
- Python PyTorch reshape用法及代碼示例
- Python PyTorch real用法及代碼示例
- Python PyTorch remove用法及代碼示例
- Python PyTorch read_vec_flt_ark用法及代碼示例
- Python PyTorch register_kl用法及代碼示例
- Python PyTorch read_vec_int_ark用法及代碼示例
- Python PyTorch resolve_neg用法及代碼示例
- Python PyTorch remainder用法及代碼示例
- Python PyTorch register_module_forward_pre_hook用法及代碼示例
- Python PyTorch remote用法及代碼示例
- Python PyTorch register_module_full_backward_hook用法及代碼示例
- Python PyTorch remove_spectral_norm用法及代碼示例
- Python PyTorch record用法及代碼示例
- Python PyTorch remove_weight_norm用法及代碼示例
- Python PyTorch retinanet_resnet50_fpn用法及代碼示例
- Python PyTorch read_vec_flt_scp用法及代碼示例
- Python PyTorch resolve_conj用法及代碼示例
- Python PyTorch register_parametrization用法及代碼示例
- Python PyTorch reciprocal用法及代碼示例
- Python PyTorch result_type用法及代碼示例
- Python PyTorch register_module_forward_hook用法及代碼示例
- Python PyTorch read_mat_scp用法及代碼示例
- Python PyTorch read_mat_ark用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.fx.replace_pattern。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。