當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch get_graph_node_names用法及代碼示例


本文簡要介紹python語言中 torchvision.models.feature_extraction.get_graph_node_names 的用法。

用法:

torchvision.models.feature_extraction.get_graph_node_names(model: torch.nn.modules.module.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False) → Tuple[List[str], List[str]]

參數

  • model(nn.Module) -我們要為其打印節點名稱的模型

  • tracer_kwargs(dict,可選的) -

    NodePathTracer 的關鍵參數字典(它們最終被傳遞到 torch.fx.Tracer )。

  • suppress_diff_warning(bool,可選的) -當圖表的 train 和 eval 版本之間存在差異時是否抑製警告。默認為假。

返回

在訓練模式下跟蹤模型的節點名稱列表,以及在評估模式下跟蹤模型的另一個節點名稱列表。

返回類型

tuple ( list , list )

開發實用程序按執行順序返回節點名稱。請參閱 create_feature_extractor() 下有關節點名稱的注釋。對於查看哪些節點名稱可用於特征提取非常有用。無法輕鬆地直接從模型代碼中讀取節點名稱有兩個原因:

  1. 並非所有子模塊都被跟蹤。 torch.nn 中的模塊都屬於此類別。

  2. 表示重複應用相同操作或葉模塊的節點獲得_{counter} 後綴。

該模型被跟蹤兩次:一次在訓練模式下,一次在評估模式下。返回兩組節點名稱。

有關此處使用的節點命名約定的更多詳細信息,請參閱documentation 中的相關子標題。

例子:

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchvision.models.feature_extraction.get_graph_node_names。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。