當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch fuse_modules用法及代碼示例


本文簡要介紹python語言中 torch.quantization.fuse_modules 的用法。

用法:

torch.quantization.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)

參數

  • model-包含要融合的模塊的模型

  • modules_to_fuse-要融合的模塊名稱列表。如果隻有一個模塊列表要融合,也可以是字符串列表。

  • inplace-bool 指定模型上是否發生融合,默認情況下返回一個新模型

  • fuser_func-接收模塊列表並輸出相同長度的融合模塊列表的函數。例如,fuser_func([convModule, BNModule]) 返回列表 [ConvBNModule, nn.Identity()] 默認為 torch.quantization.fuse_known_modules

  • fuse_custom_config_dict-用於融合的自定義配置

返回

帶有融合模塊的模型。如果inplace=True,則創建一個新副本。

將模塊列表融合為單個模塊

僅融合以下模塊序列:conv、bn conv、bn、relu conv、relu linear、relu bn、relu 所有其他序列保持不變。對於這些序列,將列表中的第一項替換為融合模塊,將其餘模塊替換為標識。

# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
    # Additional fuser_method mapping
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
    },
}

例子:

>>> m = myModel()
>>> # m is a module containing  the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = myModel()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.quantization.fuse_modules。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。