当前位置: 首页>>代码示例>>Python>>正文


Python detectron_weight_helper.resnet_weights_name_pattern方法代码示例

本文整理汇总了Python中utils.detectron_weight_helper.resnet_weights_name_pattern方法的典型用法代码示例。如果您正苦于以下问题:Python detectron_weight_helper.resnet_weights_name_pattern方法的具体用法?Python detectron_weight_helper.resnet_weights_name_pattern怎么用?Python detectron_weight_helper.resnet_weights_name_pattern使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils.detectron_weight_helper的用法示例。


在下文中一共展示了detectron_weight_helper.resnet_weights_name_pattern方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_pretrained_imagenet_weights

# 需要导入模块: from utils import detectron_weight_helper [as 别名]
# 或者: from utils.detectron_weight_helper import resnet_weights_name_pattern [as 别名]
def load_pretrained_imagenet_weights(model):
    """Load pretrained weights
    Args:
        num_layers: 50 for res50 and so on.
        model: the generalized rcnnn module
    """
    _, ext = os.path.splitext(cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
    if ext == '.pkl':
        with open(cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS, 'rb') as fp:
            src_blobs = pickle.load(fp, encoding='latin1')
        if 'blobs' in src_blobs:
            src_blobs = src_blobs['blobs']
        pretrianed_state_dict = src_blobs
    else:
        weights_file = os.path.join(cfg.ROOT_DIR, cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
        pretrianed_state_dict = convert_state_dict(torch.load(weights_file))

        # Convert batchnorm weights
        for name, mod in model.named_modules():
            if isinstance(mod, mynn.AffineChannel2d):
                if cfg.FPN.FPN_ON:
                    pretrianed_name = name.split('.', 2)[-1]
                else:
                    pretrianed_name = name.split('.', 1)[-1]
                bn_mean = pretrianed_state_dict[pretrianed_name + '.running_mean']
                bn_var = pretrianed_state_dict[pretrianed_name + '.running_var']
                scale = pretrianed_state_dict[pretrianed_name + '.weight']
                bias = pretrianed_state_dict[pretrianed_name + '.bias']
                std = torch.sqrt(bn_var + 1e-5)
                new_scale = scale / std
                new_bias = bias - bn_mean * scale / std
                pretrianed_state_dict[pretrianed_name + '.weight'] = new_scale
                pretrianed_state_dict[pretrianed_name + '.bias'] = new_bias

    model_state_dict = model.state_dict()

    pattern = dwh.resnet_weights_name_pattern()

    name_mapping, _ = model.detectron_weight_mapping

    for k, v in name_mapping.items():
        if isinstance(v, str):  # maybe a str, None or True
            if pattern.match(v):
                if cfg.FPN.FPN_ON:
                    pretrianed_key = k.split('.', 2)[-1]
                else:
                    pretrianed_key = k.split('.', 1)[-1]
                if ext == '.pkl':
                    model_state_dict[k].copy_(torch.Tensor(pretrianed_state_dict[v]))
                else:
                    model_state_dict[k].copy_(pretrianed_state_dict[pretrianed_key]) 
开发者ID:roytseng-tw,项目名称:Detectron.pytorch,代码行数:53,代码来源:resnet_weights_helper.py


注:本文中的utils.detectron_weight_helper.resnet_weights_name_pattern方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。