本文简要介绍python语言中 torchvision.models.feature_extraction.create_feature_extractor
的用法。
用法:
torchvision.models.feature_extraction.create_feature_extractor(model: torch.nn.modules.module.Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False) → torch.fx.graph_module.GraphModule
model(nn.Module) -我们将在其上提取特征的模型
return_nodes(list或者dict,可选的) -
List
或Dict
包含将为其返回激活的节点的名称(或部分名称 - 参见上面的注释)。如果是Dict
,键是节点名称,值是用户为图形模块返回的字典指定的键。如果它是List
,则将其视为Dict
将节点规范字符串直接映射到输出名称。在指定train_return_nodes
和eval_return_nodes
的情况下,不应指定。train_return_nodes(list或者dict,可选的) -类似于
return_nodes
。如果 train 模式的返回节点与 eval 模式的返回节点不同,则可以使用此选项。如果指定了此项,则还必须指定eval_return_nodes
,并且不应指定return_nodes
。eval_return_nodes(list或者dict,可选的) -类似于
return_nodes
。如果 train 模式的返回节点与 eval 模式的返回节点不同,则可以使用此选项。如果指定了此项,则还必须指定train_return_nodes
,并且不应指定return_nodes
。tracer_kwargs(dict,可选的) -
NodePathTracer
的关键参数字典(将它们传递给它的父类 torch.fx.Tracer )。suppress_diff_warning(bool,可选的) -当图表的 train 和 eval 版本之间存在差异时是否抑制警告。默认为假。
创建一个新的图形模块,该模块将给定模型中的中间节点作为字典返回,用户指定的键作为字符串,请求的输出作为值。这是通过 FX 重写模型的计算图以返回所需节点作为输出来实现的。所有未使用的节点及其相应的参数都将被删除。
所需的输出节点必须指定为
.
分隔路径,将模块层次结构从顶层模块向下传递到叶操作或叶模块。有关此处使用的节点命名约定的更多详细信息,请参阅documentation 中的相关子标题。并非所有模型都可以追踪外汇,尽管通过一些按摩可以使它们相互配合。这是一个(不详尽的)提示列表:
如果您不需要跟踪特定的有问题的sub-module,请将
leaf_modules
列表作为tracer_kwargs
之一传递(参见下面的示例),将其转换为“leaf module”。它不会被跟踪,而是生成的图形将包含对该模块的 forward 方法的引用。同样,您可以通过将
autowrap_functions
列表作为tracer_kwargs
之一传递来将函数转换为叶函数(参见下面的示例)。一些内置的 Python 函数可能会出现问题。例如,
int
将在跟踪期间引发错误。您可以将它们包装在您自己的函数中,然后将其作为tracer_kwargs
之一传递给autowrap_functions
。
有关 FX 的更多信息,请参阅torch.fx documentation。
例子:
>>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] >>> # Specifying leaf modules and leaf functions >>> def leaf_function(x): >>> # This would raise a TypeError if traced through >>> return int(x) >>> >>> class LeafModule(torch.nn.Module): >>> def forward(self, x): >>> # This would raise a TypeError if traced through >>> int(x.shape[0]) >>> return torch.nn.functional.relu(x + 4) >>> >>> class MyModule(torch.nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv = torch.nn.Conv2d(3, 1, 3) >>> self.leaf_module = LeafModule() >>> >>> def forward(self, x): >>> leaf_function(x.shape[0]) >>> x = self.conv(x) >>> return self.leaf_module(x) >>> >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]})
参数:
相关用法
- Python PyTorch criteo_terabyte用法及代码示例
- Python PyTorch cross_entropy用法及代码示例
- Python PyTorch cross用法及代码示例
- Python PyTorch criteo_kaggle用法及代码示例
- Python PyTorch cholesky用法及代码示例
- Python PyTorch column_stack用法及代码示例
- Python PyTorch cumprod用法及代码示例
- Python PyTorch calculate_gain用法及代码示例
- Python PyTorch cov用法及代码示例
- Python PyTorch cos用法及代码示例
- Python PyTorch compute_deltas用法及代码示例
- Python PyTorch conv_transpose3d用法及代码示例
- Python PyTorch combinations用法及代码示例
- Python PyTorch conv2d用法及代码示例
- Python PyTorch cummax用法及代码示例
- Python PyTorch custom_from_mask用法及代码示例
- Python PyTorch collect_all用法及代码示例
- Python PyTorch chunk用法及代码示例
- Python PyTorch convert用法及代码示例
- Python PyTorch conv1d用法及代码示例
- Python PyTorch chain_matmul用法及代码示例
- Python PyTorch cat用法及代码示例
- Python PyTorch constant_用法及代码示例
- Python PyTorch context用法及代码示例
- Python PyTorch count_nonzero用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchvision.models.feature_extraction.create_feature_extractor。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。