本文简要介绍python语言中 torch.fx.replace_pattern
的用法。
用法:
torch.fx.replace_pattern(gm, pattern, replacement)
gm-包装要操作的 Graph 的 GraphModule
pattern-
gm
中匹配的子图进行替换replacement-替换
pattern
的子图
Match
对象列表,表示原始图中与pattern
匹配的位置。如果没有匹配项,则列表为空。Match
定义为:class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
列表[匹配]
匹配 GraphModule (
gm
) 图中所有可能的非重叠运算符集及其数据依赖性 (pattern
),然后将每个匹配的子图替换为另一个子图 (replacement
)。例子:
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上面的代码将首先匹配
traced_module
的forward
方法中的pattern
。 Pattern-matching 是基于 use-def 关系完成的,而不是节点名称。例如,如果您在pattern
中有p = torch.cat([a, b])
,则可以匹配原始forward
函数中的m = torch.cat([a, b])
,尽管变量名称不同(p
与m
)。pattern
中的return
语句仅根据其值进行匹配;它可能与较大图中的return
语句匹配,也可能不匹配。换句话说,模式不必延伸到更大图的末尾。当模式匹配时,它将从较大的函数中删除并替换为
replacement
。如果较大的函数中有多个pattern
匹配,则每个不重叠的匹配都将被替换。在匹配重叠的情况下,重叠匹配集中第一个找到的匹配将被替换。 (“First” 在这里被定义为节点的 use-def 关系的拓扑排序中的第一个。在大多数情况下,第一个 Node 是直接出现在self
之后的参数,而最后一个 Node 是函数返回的任何内容.)需要注意的重要一点是
pattern
Callable 的参数必须在Callable 本身中使用,并且replacement
Callable 的参数必须与模式匹配。第一条规则是为什么在上面的代码块中,forward
函数有参数x, w1, w2
,但pattern
函数只有参数w1, w2
。pattern
不使用x
,因此不应将x
指定为参数。作为第二条规则的示例,请考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
和
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement
需要与pattern
相同数量的参数(x
和y
),即使在replacement
中没有使用参数y
。调用
subgraph_rewriter.replace_pattern
后,生成的 Python 代码如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
保证此 API 的向后兼容性。
参数:
返回:
返回类型:
相关用法
- Python PyTorch repeat_interleave用法及代码示例
- Python PyTorch renorm用法及代码示例
- Python PyTorch reshape用法及代码示例
- Python PyTorch real用法及代码示例
- Python PyTorch remove用法及代码示例
- Python PyTorch read_vec_flt_ark用法及代码示例
- Python PyTorch register_kl用法及代码示例
- Python PyTorch read_vec_int_ark用法及代码示例
- Python PyTorch resolve_neg用法及代码示例
- Python PyTorch remainder用法及代码示例
- Python PyTorch register_module_forward_pre_hook用法及代码示例
- Python PyTorch remote用法及代码示例
- Python PyTorch register_module_full_backward_hook用法及代码示例
- Python PyTorch remove_spectral_norm用法及代码示例
- Python PyTorch record用法及代码示例
- Python PyTorch remove_weight_norm用法及代码示例
- Python PyTorch retinanet_resnet50_fpn用法及代码示例
- Python PyTorch read_vec_flt_scp用法及代码示例
- Python PyTorch resolve_conj用法及代码示例
- Python PyTorch register_parametrization用法及代码示例
- Python PyTorch reciprocal用法及代码示例
- Python PyTorch result_type用法及代码示例
- Python PyTorch register_module_forward_hook用法及代码示例
- Python PyTorch read_mat_scp用法及代码示例
- Python PyTorch read_mat_ark用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.fx.replace_pattern。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。