本文簡要介紹python語言中 torchvision.models.feature_extraction.create_feature_extractor
的用法。
用法:
torchvision.models.feature_extraction.create_feature_extractor(model: torch.nn.modules.module.Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False) → torch.fx.graph_module.GraphModule
model(nn.Module) -我們將在其上提取特征的模型
return_nodes(list或者dict,可選的) -
List
或Dict
包含將為其返回激活的節點的名稱(或部分名稱 - 參見上麵的注釋)。如果是Dict
,鍵是節點名稱,值是用戶為圖形模塊返回的字典指定的鍵。如果它是List
,則將其視為Dict
將節點規範字符串直接映射到輸出名稱。在指定train_return_nodes
和eval_return_nodes
的情況下,不應指定。train_return_nodes(list或者dict,可選的) -類似於
return_nodes
。如果 train 模式的返回節點與 eval 模式的返回節點不同,則可以使用此選項。如果指定了此項,則還必須指定eval_return_nodes
,並且不應指定return_nodes
。eval_return_nodes(list或者dict,可選的) -類似於
return_nodes
。如果 train 模式的返回節點與 eval 模式的返回節點不同,則可以使用此選項。如果指定了此項,則還必須指定train_return_nodes
,並且不應指定return_nodes
。tracer_kwargs(dict,可選的) -
NodePathTracer
的關鍵參數字典(將它們傳遞給它的父類 torch.fx.Tracer )。suppress_diff_warning(bool,可選的) -當圖表的 train 和 eval 版本之間存在差異時是否抑製警告。默認為假。
創建一個新的圖形模塊,該模塊將給定模型中的中間節點作為字典返回,用戶指定的鍵作為字符串,請求的輸出作為值。這是通過 FX 重寫模型的計算圖以返回所需節點作為輸出來實現的。所有未使用的節點及其相應的參數都將被刪除。
所需的輸出節點必須指定為
.
分隔路徑,將模塊層次結構從頂層模塊向下傳遞到葉操作或葉模塊。有關此處使用的節點命名約定的更多詳細信息,請參閱documentation 中的相關子標題。並非所有模型都可以追蹤外匯,盡管通過一些按摩可以使它們相互配合。這是一個(不詳盡的)提示列表:
如果您不需要跟蹤特定的有問題的sub-module,請將
leaf_modules
列表作為tracer_kwargs
之一傳遞(參見下麵的示例),將其轉換為“leaf module”。它不會被跟蹤,而是生成的圖形將包含對該模塊的 forward 方法的引用。同樣,您可以通過將
autowrap_functions
列表作為tracer_kwargs
之一傳遞來將函數轉換為葉函數(參見下麵的示例)。一些內置的 Python 函數可能會出現問題。例如,
int
將在跟蹤期間引發錯誤。您可以將它們包裝在您自己的函數中,然後將其作為tracer_kwargs
之一傳遞給autowrap_functions
。
有關 FX 的更多信息,請參閱torch.fx documentation。
例子:
>>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] >>> # Specifying leaf modules and leaf functions >>> def leaf_function(x): >>> # This would raise a TypeError if traced through >>> return int(x) >>> >>> class LeafModule(torch.nn.Module): >>> def forward(self, x): >>> # This would raise a TypeError if traced through >>> int(x.shape[0]) >>> return torch.nn.functional.relu(x + 4) >>> >>> class MyModule(torch.nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv = torch.nn.Conv2d(3, 1, 3) >>> self.leaf_module = LeafModule() >>> >>> def forward(self, x): >>> leaf_function(x.shape[0]) >>> x = self.conv(x) >>> return self.leaf_module(x) >>> >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]})
參數:
相關用法
- Python PyTorch criteo_terabyte用法及代碼示例
- Python PyTorch cross_entropy用法及代碼示例
- Python PyTorch cross用法及代碼示例
- Python PyTorch criteo_kaggle用法及代碼示例
- Python PyTorch cholesky用法及代碼示例
- Python PyTorch column_stack用法及代碼示例
- Python PyTorch cumprod用法及代碼示例
- Python PyTorch calculate_gain用法及代碼示例
- Python PyTorch cov用法及代碼示例
- Python PyTorch cos用法及代碼示例
- Python PyTorch compute_deltas用法及代碼示例
- Python PyTorch conv_transpose3d用法及代碼示例
- Python PyTorch combinations用法及代碼示例
- Python PyTorch conv2d用法及代碼示例
- Python PyTorch cummax用法及代碼示例
- Python PyTorch custom_from_mask用法及代碼示例
- Python PyTorch collect_all用法及代碼示例
- Python PyTorch chunk用法及代碼示例
- Python PyTorch convert用法及代碼示例
- Python PyTorch conv1d用法及代碼示例
- Python PyTorch chain_matmul用法及代碼示例
- Python PyTorch cat用法及代碼示例
- Python PyTorch constant_用法及代碼示例
- Python PyTorch context用法及代碼示例
- Python PyTorch count_nonzero用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchvision.models.feature_extraction.create_feature_extractor。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。