当前位置: 首页>>代码示例>>Python>>正文


Python Config.fromfile方法代码示例

本文整理汇总了Python中mmcv.Config.fromfile方法的典型用法代码示例。如果您正苦于以下问题:Python Config.fromfile方法的具体用法?Python Config.fromfile怎么用?Python Config.fromfile使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在mmcv.Config的用法示例。


在下文中一共展示了Config.fromfile方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def test():
    """Debug helper: build the training dataset, then call ``embed``.

    Parses the command-line arguments, loads the config file they name,
    overrides ``cfg.gpus`` to 1, builds the dataset described by
    ``cfg.data.train`` and finally calls ``embed`` (presumably an
    interactive shell for manual inspection -- TODO confirm its origin).
    """
    # NOTE(review): neither import is used by the live code below; cv2 is
    # only referenced by the commented-out ``visual`` helper. Presumably they
    # are kept so both names are available inside the embedded session --
    # confirm before removing them.
    from tqdm import trange
    import cv2
    # Prints the banner text ten times on one line.
    print('debug mode '*10 )
    args = parse_args()
    cfg = Config.fromfile(args.config)
    cfg.gpus = 1

    dataset = build_dataset(cfg.data.train)
    embed(header='123123')
    # Disabled helper: took sample ``i``, moved the first image axis last,
    # added 100 to every value and wrote the result to disk.
    # def visual(i):
    #     img = dataset[i]['img'].data
    #     img = img.permute(1,2,0) + 100
    #     img = img.data.cpu().numpy()
    #     cv2.imwrite('./trash/resize_v1.jpg',img)

    # embed(header='check data resizer')
开发者ID:xieenze,项目名称:PolarMask,代码行数:19,代码来源:train.py

示例2: parse_config

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def parse_config(config_strings):
    """Classify the detector described by a config given as source text.

    The text is written to a temporary ``.py`` file, loaded with
    ``Config.fromfile`` and the resulting ``model`` section is inspected.

    Args:
        config_strings (str): Source text of the config file.

    Returns:
        tuple: ``(is_two_stage, is_ssd, is_retina, reg_cls_agnostic)``.
        ``is_two_stage`` is True when the model has an ``rpn_head``.
        ``is_ssd`` / ``is_retina`` are only set for models without an
        ``rpn_head`` and reflect ``bbox_head.type``. ``reg_cls_agnostic``
        is True when ``bbox_head`` is a list, otherwise it mirrors
        ``bbox_head.reg_class_agnostic`` when that key exists.
    """
    import os

    # The previous implementation wrote to "<tmpname>.py", which is a
    # different file from the NamedTemporaryFile it created, so the written
    # file was never removed. A temporary directory cleans up everything,
    # including when loading the config raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = os.path.join(tmp_dir, 'config.py')
        with open(config_path, 'w') as f:
            f.write(config_strings)
        config = Config.fromfile(config_path)

    is_two_stage = True
    is_ssd = False
    is_retina = False
    reg_cls_agnostic = False
    if 'rpn_head' not in config.model:
        is_two_stage = False
        # Without an RPN, tell SSD and RetinaNet apart by their head type.
        head_type = config.model.bbox_head.type
        if head_type == 'SSDHead':
            is_ssd = True
        elif head_type == 'RetinaHead':
            is_retina = True
    elif isinstance(config.model['bbox_head'], list):
        # A list of bbox heads is always treated as class agnostic.
        reg_cls_agnostic = True
    elif 'reg_class_agnostic' in config.model.bbox_head:
        reg_cls_agnostic = config.model.bbox_head.reg_class_agnostic
    return is_two_stage, is_ssd, is_retina, reg_cls_agnostic
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:27,代码来源:upgrade_model_version.py

示例3: main

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Run the attribute predictor on one input image and show the result.

    Reads the config, input image and checkpoint named on the command line,
    builds the predictor, runs it in eval mode and hands the predicted
    attribute probabilities to ``AttrPredictor.show_prediction``.
    """
    cli_args = parse_args()
    config = Config.fromfile(cli_args.config)

    image = get_img_tensor(cli_args.input, cli_args.use_cuda)

    # Clear the ``pretrained`` entry before building; the weights are loaded
    # from the checkpoint right below.
    config.model.pretrained = None
    predictor = build_predictor(config.model)
    load_checkpoint(predictor, cli_args.checkpoint, map_location='cpu')
    if cli_args.use_cuda:
        predictor.cuda()

    predictor.eval()

    # Forward pass in prediction mode (no loss, no attribute/landmark input).
    attr_scores = predictor(
        image, attr=None, landmark=None, return_loss=False)
    AttrPredictor(config.data.test).show_prediction(attr_scores)
开发者ID:open-mmlab,项目名称:mmfashion,代码行数:21,代码来源:test_predictor.py

示例4: main

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Extract features for the dataset split chosen on the command line.

    Raises:
        TypeError: If ``args.data_type`` is not one of ``train``, ``query``
            or ``gallery``.
    """
    options = parse_args()
    config = Config.fromfile(options.config)

    split = options.data_type
    if split not in ('train', 'query', 'gallery'):
        raise TypeError('So far only support train/query/gallery dataset')
    # The split name doubles as the attribute name under ``cfg.data``.
    image_set = build_dataset(getattr(config.data, split))

    checkpoint = options.checkpoint
    if checkpoint is not None:
        config.load_from = checkpoint

    extract_features(image_set, config, options.save_dir)
开发者ID:open-mmlab,项目名称:mmfashion,代码行数:19,代码来源:extract_features.py

示例5: test_merge_from_base

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def test_merge_from_base():
    """Check the result of loading ``d.py`` against ``base.py``.

    Verifies the loaded object's type and ``filename``, that ``cfg.text`` is
    the absolute path and contents of ``base.py`` followed by the absolute
    path and contents of ``d.py``, and the values of ``item1``..``item4``.
    Loading ``e.py`` must raise ``TypeError``.
    """
    cfg_file = osp.join(osp.dirname(__file__), 'data/config/d.py')
    cfg = Config.fromfile(cfg_file)
    assert isinstance(cfg, Config)
    assert cfg.filename == cfg_file
    base_cfg_file = osp.join(osp.dirname(__file__), 'data/config/base.py')

    # Read both files through context managers so the handles are closed
    # deterministically instead of being left to the garbage collector.
    with open(base_cfg_file, 'r') as f:
        base_text = f.read()
    with open(cfg_file, 'r') as f:
        cfg_text = f.read()

    merge_text = osp.abspath(osp.expanduser(base_cfg_file)) + '\n' + \
        base_text
    merge_text += '\n' + osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
        cfg_text
    assert cfg.text == merge_text
    assert cfg.item1 == [2, 3]
    assert cfg.item2.a == 1
    assert cfg.item3 is False
    assert cfg.item4 == 'test_base'

    with pytest.raises(TypeError):
        Config.fromfile(osp.join(osp.dirname(__file__), 'data/config/e.py'))
开发者ID:open-mmlab,项目名称:mmcv,代码行数:20,代码来源:test_config.py

示例6: test_merge_from_multiple_bases

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def test_merge_from_multiple_bases():
    """Check the fields produced by loading ``l.py``.

    Loading ``m.py`` from the same directory must raise ``KeyError``.
    """
    here = osp.dirname(__file__)
    config_path = osp.join(here, 'data/config/l.py')
    loaded = Config.fromfile(config_path)
    assert isinstance(loaded, Config)
    assert loaded.filename == config_path

    # Field values expected after loading.
    assert loaded.item1 == [1, 2]
    assert loaded.item2.a == 0
    assert loaded.item3 is False
    assert loaded.item4 == 'test'
    assert loaded.item5 == dict(a=0, b=1)
    assert loaded.item6 == [dict(a=0), dict(b=1)]
    assert loaded.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))

    failing_path = osp.join(osp.dirname(__file__), 'data/config/m.py')
    with pytest.raises(KeyError):
        Config.fromfile(failing_path)
开发者ID:open-mmlab,项目名称:mmcv,代码行数:18,代码来源:test_config.py

示例7: _get_config_module

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def _get_config_module(fname):
    """Return the config loaded from ``fname`` inside the config directory.

    ``fname`` is joined onto the directory returned by
    ``_get_config_directory()`` and the result is passed to
    ``Config.fromfile``.
    """
    from mmcv import Config
    return Config.fromfile(join(_get_config_directory(), fname))
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:9,代码来源:test_forward.py

示例8: retrieve_data_cfg

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def retrieve_data_cfg(config_path, skip_type):
    """Load a config and drop unwanted steps from its training pipeline.

    Args:
        config_path: Path handed to ``Config.fromfile``.
        skip_type: Container of pipeline step ``type`` values to remove.

    Returns:
        The loaded config, with ``data.train['pipeline']`` replaced by the
        filtered list.
    """
    config = Config.fromfile(config_path)
    train_section = config.data.train

    kept_steps = []
    for step in train_section.pipeline:
        if step['type'] in skip_type:
            continue
        kept_steps.append(step)
    train_section['pipeline'] = kept_steps

    return config
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:10,代码来源:browse_dataset.py

示例9: main

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Print the config named on the command line.

    Overrides passed through ``args.options`` are merged in before the
    config's ``pretty_text`` is printed.
    """
    cli = parse_args()

    config = Config.fromfile(cli.config)
    overrides = cli.options
    if overrides is not None:
        config.merge_from_dict(overrides)
    print(f'Config:\n{config.pretty_text}')
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:9,代码来源:print_config.py

示例10: __init__

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def __init__(self,
                 config_file,
                 checkpoint_file,
                 device='cuda:0'):
        """Load the config, test dataset metadata and detector.

        Args:
            config_file: Path of the config file.
            checkpoint_file: Path of the checkpoint passed to
                ``init_detector``.
            device: Device handed to ``init_detector``. Defaults to
                ``'cuda:0'``, the value that used to be hard-coded, so
                existing callers are unaffected.
        """
        # init RoITransformer
        self.config_file = config_file
        self.checkpoint_file = checkpoint_file
        self.cfg = Config.fromfile(self.config_file)
        # The test dataset is built to obtain the class names.
        self.data_test = self.cfg.data['test']
        self.dataset = get_dataset(self.data_test)
        self.classnames = self.dataset.CLASSES
        self.model = init_detector(config_file, checkpoint_file, device=device)
开发者ID:dingjiansw101,项目名称:AerialDetection,代码行数:13,代码来源:demo_large_image.py

示例11: main

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Report FLOPs and parameter count of the detector in a config.

    ``args.shape`` may hold one value (used for both spatial sizes) or two
    values; the input shape always has 3 as its first entry.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two entries.
        NotImplementedError: If the built model has no ``forward_dummy``.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter calls ``forward``; route it to the dummy
    # forward when the model provides one.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        # Message fixed: it previously read "currently not currently".
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
开发者ID:dingjiansw101,项目名称:AerialDetection,代码行数:29,代码来源:get_flops.py

示例12: setUpClass

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def setUpClass(cls):
        """Build the model under test once for the whole test class.

        Stores the device, the loaded config, the built model and an empty
        ``avg_time`` dict on the class.
        """
        # NOTE(review): the device index and the absolute config path are
        # specific to one machine -- confirm before running elsewhere.
        cls.device = torch.device('cuda:2')
        cls.cfg = Config.fromfile(
            '/home/zhixiang/youmin/projects/depth/public/'
            'DenseMatchingBenchmark/configs/PSMNet/kitti_2015.py')
        built_model = build_model(cls.cfg)
        cls.model = built_model
        cls.model.to(cls.device)

        cls.setUpTimeTestingClass()
        cls.avg_time = dict()
开发者ID:DeepMotionAIResearch,项目名称:DenseMatchingBenchmark,代码行数:12,代码来源:test_model.py

示例13: setUpClass

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def setUpClass(cls):
        """Build the backbone under test once for the whole test class.

        Stores the device, the loaded config, the built backbone and an
        empty ``avg_time`` dict on the class.
        """
        # NOTE(review): the device index and the absolute config path are
        # specific to one machine -- confirm before running elsewhere.
        cls.device = torch.device('cuda:1')
        cls.cfg = Config.fromfile(
            '/home/zhixiang/youmin/projects/depth/public/'
            'DenseMatchingBenchmark/configs/AcfNet/scene_flow_uniform.py')
        built_backbone = build_backbone(cls.cfg)
        cls.backbone = built_backbone
        cls.backbone.to(cls.device)

        cls.setUpTimeTestingClass()
        cls.avg_time = dict()
开发者ID:DeepMotionAIResearch,项目名称:DenseMatchingBenchmark,代码行数:12,代码来源:test_backbones.py

示例14: main

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Retrieve gallery images for one query image and display them.

    Seeds torch with 0, builds the retriever from the config, loads the
    checkpoint, embeds the query image and the gallery set, then passes
    both to ``ClothesRetriever.show_retrieved_images``.
    """
    # Fixed seed for the CPU generator and, when CUDA exists, every GPU.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    opts = parse_args()
    config = Config.fromfile(opts.config)

    retrieval_model = build_retriever(config.model)
    load_checkpoint(retrieval_model, opts.checkpoint)
    print('load checkpoint from {}'.format(opts.checkpoint))

    if opts.use_cuda:
        retrieval_model.cuda()
    retrieval_model.eval()

    query_tensor = get_img_tensor(opts.input, opts.use_cuda)

    # Query embedding, converted to a NumPy array on the host.
    query_output = retrieval_model(
        query_tensor, landmark=None, return_loss=False)
    query_feat = query_output.data.cpu().numpy()

    gallery_cfg = config.data.gallery
    gallery_embeds = _process_embeds(
        build_dataset(gallery_cfg), retrieval_model, config)

    retriever = ClothesRetriever(
        gallery_cfg.img_file, config.data_root, gallery_cfg.img_path)
    retriever.show_retrieved_images(query_feat, gallery_embeds)
开发者ID:open-mmlab,项目名称:mmfashion,代码行数:30,代码来源:test_retriever.py

示例15: main

# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Detect landmarks on one input image, print and draw them.

    A landmark is reported when its predicted visibility is at least 0.5.
    Printed coordinates are the predictions scaled by ``w / 224`` and
    ``h / 224`` (presumably 224 is the network input size -- TODO confirm).
    """
    # Fixed seed for the CPU generator and, when CUDA exists, every GPU.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    opts = parse_args()
    config = Config.fromfile(opts.config)

    img_tensor, w, h = get_img_tensor(
        opts.input, opts.use_cuda, get_size=True)

    # Build the detector, then restore its weights.
    detector = build_landmark_detector(config.model)
    print('model built')
    load_checkpoint(detector, opts.checkpoint)
    print('load checkpoint from: {}'.format(opts.checkpoint))

    if opts.use_cuda:
        detector.cuda()

    # Inference.
    detector.eval()
    pred_vis, pred_lm = detector(img_tensor, return_loss=False)
    pred_lm = pred_lm.data.cpu().numpy()

    visible_landmarks = []
    for idx, visibility in enumerate(pred_vis):
        if visibility >= 0.5:
            landmark = pred_lm[idx]
            scaled_x = pred_lm[idx][0] * (w / 224.)
            scaled_y = pred_lm[idx][1] * (h / 224.)
            print('detected landmark {} {}'.format(scaled_x, scaled_y))
            visible_landmarks.append(landmark)

    draw_landmarks(opts.input, visible_landmarks)
开发者ID:open-mmlab,项目名称:mmfashion,代码行数:35,代码来源:test_landmark_detector.py


注:本文中的mmcv.Config.fromfile方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。