當前位置: 首頁>>代碼示例>>Python>>正文


Python runner.load_checkpoint方法代碼示例

本文整理匯總了Python中mmcv.runner.load_checkpoint方法的典型用法代碼示例。如果您正苦於以下問題：Python runner.load_checkpoint方法的具體用法？Python runner.load_checkpoint怎麼用？Python runner.load_checkpoint使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊mmcv.runner的用法示例。


在下文中一共展示了runner.load_checkpoint方法的8個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: init_model

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def init_model(config, checkpoint=None, device='cuda:0'):
    """Build a stereo model from a config and prepare it for inference.

    Args:
        config (str or :obj:`mmcv.Config`): Either a path to a config file or
            an already-loaded config object.
        checkpoint (str, optional): Path of a checkpoint whose weights are
            loaded into the model. When None, no weights are loaded.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The stereo model in eval mode, with the resolved config
        attached as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        cfg = mmcv.Config.fromfile(config)
    elif isinstance(config, mmcv.Config):
        cfg = config
    else:
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))

    stereo_model = build_model(cfg)
    if checkpoint is not None:
        # The return value of load_checkpoint is not needed here.
        load_checkpoint(stereo_model, checkpoint)

    # Keep the config on the model for convenience.
    stereo_model.cfg = cfg
    stereo_model.to(device)
    stereo_model.eval()
    return stereo_model
開發者ID:DeepMotionAIResearch,項目名稱:DenseMatchingBenchmark,代碼行數:26,代碼來源:inference.py

示例2: init_weights

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def init_weights(self, pretrained=None):
    """Placeholder weight initialisation for the HRNet backbone.

    Only prints a notice to stdout. ``pretrained`` is currently ignored, and
    no weights are loaded or initialised here. The former implementation is
    kept below, commented out, for reference.

    Args:
        pretrained (str, optional): Ignored.
    """
    notice = "init hrnet weights"
    print(notice)
    # Disabled reference implementation:
    # if isinstance(pretrained, str):
    #     logger = logging.getLogger()
    #     load_checkpoint(self, pretrained, strict=False, logger=logger)
    # elif pretrained is None:
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             kaiming_init(m)
    #         elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
    #             constant_init(m, 1)
    #
    #     if self.zero_init_residual:
    #         for m in self.modules():
    #             if isinstance(m, Bottleneck):
    #                 constant_init(m.norm3, 0)
    #             elif isinstance(m, BasicBlock):
    #                 constant_init(m.norm2, 0)
    # else:
    #     raise TypeError('pretrained must be a str or None')
開發者ID:lizhe960118,項目名稱:CenterNet,代碼行數:22,代碼來源:hrnet4.py

示例3: init_detector

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed detector in eval mode, with the config
        attached as ``model.cfg`` and class names set as ``model.CLASSES``
        (only when a checkpoint is given).

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    # Presumably stops build_detector from loading separate pretrained
    # (backbone) weights; the weights come from `checkpoint` instead.
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        # When targeting the CPU, map tensors to CPU on load so checkpoints
        # saved from GPU tensors also load on machines without CUDA.
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        # Plain state-dict checkpoints may carry no (or an empty) 'meta' entry.
        meta = checkpoint.get('meta') or {}
        if 'CLASSES' in meta:
            model.CLASSES = meta['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:34,代碼來源:inference.py

示例4: init_detector

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed detector in eval mode, with the config
        attached as ``model.cfg`` and class names set as ``model.CLASSES``
        (only when a checkpoint is given).

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    # Presumably stops build_detector from loading separate pretrained
    # (backbone) weights; the weights come from `checkpoint` instead.
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        # When targeting the CPU, map tensors to CPU on load so checkpoints
        # saved from GPU tensors also load on machines without CUDA.
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        # Plain state-dict checkpoints may carry no (or an empty) 'meta' entry.
        meta = checkpoint.get('meta') or {}
        if 'CLASSES' in meta:
            model.CLASSES = meta['CLASSES']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
開發者ID:dingjiansw101,項目名稱:AerialDetection,代碼行數:33,代碼來源:inference.py

示例5: init_detector

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed detector in eval mode, with the config
        attached as ``model.cfg`` and class names set as ``model.CLASSES``
        (only when a checkpoint is given).

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    # Presumably stops build_detector from loading separate pretrained
    # (backbone) weights; the weights come from `checkpoint` instead.
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        # When targeting the CPU, map tensors to CPU on load so checkpoints
        # saved from GPU tensors also load on machines without CUDA.
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        # Plain state-dict checkpoints may carry no (or an empty) 'meta' entry.
        meta = checkpoint.get('meta') or {}
        if 'CLASSES' in meta:
            # Read the same key that was just tested for.
            model.CLASSES = meta['CLASSES']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
開發者ID:STVIR,項目名稱:Grid-R-CNN,代碼行數:33,代碼來源:inference.py

示例6: load_checkpoint

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def load_checkpoint(model, filename, *args, **kwargs):
    """Resolve ``filename`` with ``get_mmskeleton_url`` and load it into ``model``.

    Any extra positional/keyword arguments are forwarded unchanged to
    ``mmcv_load_checkpoint``, and its result is returned.

    Raises:
        Exception: If resolving or loading raises ``urllib.error.URLError`` or
            ``urllib.error.HTTPError``. The message is ``url_error_message``
            formatted with the (resolved, when resolution succeeded) filename,
            and the original error is chained as the cause.
    """
    try:
        # Rebind `filename` so the error message below shows the resolved name.
        filename = get_mmskeleton_url(filename)
        loaded = mmcv_load_checkpoint(model, filename, *args, **kwargs)
    except (urllib.error.URLError, urllib.error.HTTPError) as url_err:
        raise Exception(url_error_message.format(filename)) from url_err
    return loaded
開發者ID:open-mmlab,項目名稱:mmskeleton,代碼行數:8,代碼來源:checkpoint.py

示例7: load_model

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def load_model():
    """Build the detector described by the module-level ``cfg`` and load weights.

    The checkpoint passed to ``load_checkpoint`` is ``model_cfgs[0][1]``.

    Returns:
        The detector returned by ``build_detector``, with weights loaded.
    """
    detector = build_detector(cfg.model, test_cfg=cfg.test_cfg)
    checkpoint_path = model_cfgs[0][1]
    # Original note: "7 it/s" (presumably a measured throughput; unverified).
    load_checkpoint(detector, checkpoint_path)
    return detector
開發者ID:lxy5513,項目名稱:hrnet,代碼行數:6,代碼來源:high_api.py

示例8: _make_stage

# 需要導入模塊: from mmcv import runner [as 別名]
# 或者: from mmcv.runner import load_checkpoint [as 別名]
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
    """Build one HRNet stage as a sequence of ``HRModule`` instances.

    Args:
        layer_config (dict): Stage settings. The keys read are 'num_modules',
            'num_branches', 'num_blocks', 'num_channels' and 'block' (the
            latter is looked up in ``self.blocks_dict``).
        in_channels: Per-branch input channels, passed to every HRModule.
        multiscale_output (bool): When falsy, only the last module of the
            stage is built with multiscale output disabled. Every earlier
            module always keeps multiscale output enabled.

    Returns:
        tuple: ``(nn.Sequential(*modules), in_channels)``. ``in_channels`` is
        the same object that was passed in. NOTE(review): if HRModule updates
        it in place, the caller sees the updated values; confirm.
    """
    num_modules = layer_config['num_modules']
    num_branches = layer_config['num_branches']
    num_blocks = layer_config['num_blocks']
    num_channels = layer_config['num_channels']
    block = self.blocks_dict[layer_config['block']]
    last_index = num_modules - 1

    stage_modules = []
    for idx in range(num_modules):
        # Only the final module may turn multiscale output off.
        keep_multiscale = bool(multiscale_output) or idx != last_index
        stage_modules.append(
            HRModule(
                num_branches,
                block,
                num_blocks,
                in_channels,
                num_channels,
                keep_multiscale,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))

    return nn.Sequential(*stage_modules), in_channels

#     def init_weights(self, pretrained=None):
#         if isinstance(pretrained, str):
#             logger = logging.getLogger()
#             load_checkpoint(self, pretrained, strict=False, logger=logger)
#         elif pretrained is None:
#             for m in self.modules():
#                 if isinstance(m, nn.Conv2d):
#                     kaiming_init(m)
#                 elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
#                     constant_init(m, 1)

#             if self.zero_init_residual:
#                 for m in self.modules():
#                     if isinstance(m, Bottleneck):
#                         constant_init(m.norm3, 0)
#                     elif isinstance(m, BasicBlock):
#                         constant_init(m.norm2, 0)
#         else:
#             raise TypeError('pretrained must be a str or None') 
開發者ID:lizhe960118,項目名稱:CenterNet,代碼行數:50,代碼來源:hrnet2.py


注:本文中的mmcv.runner.load_checkpoint方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。