本文整理汇总了Python中mmcv.Config.fromfile方法的典型用法代码示例。如果您正苦于以下问题：Python Config.fromfile方法的具体用法？Python Config.fromfile怎么用？Python Config.fromfile使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类mmcv.Config的用法示例。
在下文中一共展示了Config.fromfile方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def test():
    """Debug helper: build the training dataset and open an interactive shell.

    Parses the command-line arguments, loads ``args.config``, sets
    ``cfg.gpus = 1``, builds ``cfg.data.train`` into the local ``dataset``
    and then calls ``embed`` (presumably ``IPython.embed``) so that
    ``dataset`` can be inspected by hand.
    """
    # Neither import is used by the code below; presumably they are kept for
    # use from inside the interactive session (the snippet at the bottom
    # needs ``cv2``).
    from tqdm import trange
    import cv2

    print('debug mode ' * 10)

    opts = parse_args()
    cfg = Config.fromfile(opts.config)
    cfg.gpus = 1
    dataset = build_dataset(cfg.data.train)
    embed(header='123123')
    # Snippet that can be pasted into the shell to dump one training image
    # (apparently for checking the resize step):
    # def visual(i):
    #     img = dataset[i]['img'].data
    #     img = img.permute(1, 2, 0) + 100
    #     img = img.data.cpu().numpy()
    #     cv2.imwrite('./trash/resize_v1.jpg', img)
    #     embed(header='check data resizer')
示例2: parse_config
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def parse_config(config_strings):
    """Parse a config given as source text and classify the model it defines.

    The text is written to a ``.py`` file inside a fresh temporary directory
    so that ``Config.fromfile`` can load it. The directory, including that
    file, is removed afterwards, even if loading fails. (Previously the
    ``.py`` file was created next to a NamedTemporaryFile and never deleted.)

    Args:
        config_strings (str): Python source of an mmcv-style config.

    Returns:
        tuple[bool, bool, bool, bool]: ``(is_two_stage, is_ssd, is_retina,
        reg_cls_agnostic)``.
    """
    import os.path as osp
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep a random, identifier-safe file stem (as the original
        # NamedTemporaryFile name provided) in case the loader imports the
        # file by its stem.
        config_path = osp.join(tmp_dir, f'{osp.basename(tmp_dir)}.py')
        with open(config_path, 'w') as f:
            f.write(config_strings)
        config = Config.fromfile(config_path)

    is_two_stage = True
    is_ssd = False
    is_retina = False
    reg_cls_agnostic = False
    model_cfg = config.model
    if 'rpn_head' not in model_cfg:
        # No RPN head: treat as single-stage and check for SSD / RetinaNet.
        is_two_stage = False
        if model_cfg.bbox_head.type == 'SSDHead':
            is_ssd = True
        elif model_cfg.bbox_head.type == 'RetinaHead':
            is_retina = True
    elif isinstance(model_cfg['bbox_head'], list):
        # A list of bbox heads is reported as class-agnostic regression.
        reg_cls_agnostic = True
    elif 'reg_class_agnostic' in model_cfg.bbox_head:
        reg_cls_agnostic = model_cfg.bbox_head.reg_class_agnostic
    return is_two_stage, is_ssd, is_retina, reg_cls_agnostic
示例3: main
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Predict and display attribute probabilities for one input image.

    Builds the predictor from ``args.config``, loads ``args.checkpoint``
    (mapped to CPU, moved to CUDA when ``args.use_cuda``), runs it on
    ``args.input`` and hands the result to ``AttrPredictor.show_prediction``.
    """
    opts = parse_args()
    cfg = Config.fromfile(opts.config)
    img_tensor = get_img_tensor(opts.input, opts.use_cuda)

    # Clear ``pretrained``; the weights come from the checkpoint loaded below.
    cfg.model.pretrained = None
    model = build_predictor(cfg.model)
    load_checkpoint(model, opts.checkpoint, map_location='cpu')
    if opts.use_cuda:
        model.cuda()
    model.eval()

    # Per-attribute probabilities; no attribute labels or landmarks supplied.
    attr_prob = model(img_tensor, attr=None, landmark=None, return_loss=False)
    predictor = AttrPredictor(cfg.data.test)
    predictor.show_prediction(attr_prob)
示例4: main
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Extract features for one dataset split chosen on the command line.

    ``args.data_type`` selects ``cfg.data.train``, ``cfg.data.query`` or
    ``cfg.data.gallery``. The built dataset, the config and
    ``args.save_dir`` are passed to ``extract_features``.

    Raises:
        TypeError: if ``args.data_type`` is not 'train', 'query' or 'gallery'.
    """
    opts = parse_args()
    cfg = Config.fromfile(opts.config)

    split = opts.data_type
    if split not in ('train', 'query', 'gallery'):
        raise TypeError('So far only support train/query/gallery dataset')
    image_set = build_dataset(getattr(cfg.data, split))

    # Optional override of the weights used for extraction.
    if opts.checkpoint is not None:
        cfg.load_from = opts.checkpoint
    extract_features(image_set, cfg, opts.save_dir)
示例5: test_merge_from_base
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def test_merge_from_base():
    """Config ``d.py`` is merged on top of ``base.py``.

    Checks the merged field values, and that ``cfg.text`` is each source
    file's absolute path followed by its contents (``base.py`` first, then
    ``d.py``). Also checks that loading ``e.py`` raises ``TypeError``.
    """
    def _read(path):
        # Use a context manager so the handle is closed promptly; a bare
        # ``open(...).read()`` leaks it until garbage collection.
        with open(path, 'r') as f:
            return f.read()

    here = osp.dirname(__file__)
    cfg_file = osp.join(here, 'data/config/d.py')
    cfg = Config.fromfile(cfg_file)
    assert isinstance(cfg, Config)
    assert cfg.filename == cfg_file

    base_cfg_file = osp.join(here, 'data/config/base.py')
    merge_text = osp.abspath(osp.expanduser(base_cfg_file)) + '\n' + \
        _read(base_cfg_file)
    merge_text += '\n' + osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
        _read(cfg_file)
    assert cfg.text == merge_text
    assert cfg.item1 == [2, 3]
    assert cfg.item2.a == 1
    assert cfg.item3 is False
    assert cfg.item4 == 'test_base'

    with pytest.raises(TypeError):
        Config.fromfile(osp.join(here, 'data/config/e.py'))
示例6: test_merge_from_multiple_bases
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def test_merge_from_multiple_bases():
    """Config ``l.py`` merges fields coming from several base configs.

    Checks every merged field and that loading ``m.py`` raises ``KeyError``.
    """
    here = osp.dirname(__file__)
    cfg_file = osp.join(here, 'data/config/l.py')
    cfg = Config.fromfile(cfg_file)
    assert isinstance(cfg, Config)
    assert cfg.filename == cfg_file

    # Merged fields (cfg.field).
    assert cfg.item1 == [1, 2]
    assert cfg.item2.a == 0
    assert cfg.item3 is False  # identity: must be the bool False itself
    expected_fields = (
        ('item4', 'test'),
        ('item5', dict(a=0, b=1)),
        ('item6', [dict(a=0), dict(b=1)]),
        ('item7', dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))),
    )
    for field, expected in expected_fields:
        assert getattr(cfg, field) == expected, field

    # NOTE(review): m.py presumably has bases with conflicting keys; confirm.
    with pytest.raises(KeyError):
        Config.fromfile(osp.join(here, 'data/config/m.py'))
示例7: _get_config_module
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def _get_config_module(fname):
    """Load the config file ``fname`` from the config directory.

    Despite the name, this returns the ``mmcv.Config`` object produced by
    ``Config.fromfile``, not a Python module.
    """
    from mmcv import Config
    return Config.fromfile(join(_get_config_directory(), fname))
示例8: retrieve_data_cfg
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def retrieve_data_cfg(config_path, skip_type):
    """Load a config and drop selected transforms from its training pipeline.

    Args:
        config_path (str): Path of the config file to load.
        skip_type (str | Iterable[str]): Transform ``type`` name(s) to remove
            from ``cfg.data.train.pipeline``. A single name may be passed as
            a plain string.

    Returns:
        Config: The loaded config; ``cfg.data.train['pipeline']`` is replaced
        by the filtered list.
    """
    # A bare string would be matched by substring (skipping 'RandomResize'
    # would also drop 'Resize'), and a one-shot iterator would be consumed by
    # the first membership test; normalize both to a tuple.
    if isinstance(skip_type, str):
        skip = (skip_type,)
    else:
        skip = tuple(skip_type)

    cfg = Config.fromfile(config_path)
    train_data_cfg = cfg.data.train
    train_data_cfg['pipeline'] = [
        x for x in train_data_cfg.pipeline if x['type'] not in skip
    ]
    return cfg
示例9: main
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Print the config named on the command line.

    When ``args.options`` is given, it is merged into the config first via
    ``Config.merge_from_dict``; then the pretty-printed config is shown.
    """
    cli_args = parse_args()
    cfg = Config.fromfile(cli_args.config)
    overrides = cli_args.options
    if overrides is not None:
        cfg.merge_from_dict(overrides)
    print(f'Config:\n{cfg.pretty_text}')
示例10: __init__
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def __init__(self,
             config_file,
             checkpoint_file,
             device='cuda:0'):
    """Load the test-set metadata and an inference detector.

    Args:
        config_file (str): Path of the detector config.
        checkpoint_file (str): Path of the trained weights.
        device (str): Device passed to ``init_detector``. Defaults to
            ``'cuda:0'``.
    """
    # init RoITransformer
    self.config_file = config_file
    self.checkpoint_file = checkpoint_file
    self.cfg = Config.fromfile(self.config_file)
    self.data_test = self.cfg.data['test']
    self.dataset = get_dataset(self.data_test)
    # Class names are taken from the test dataset object.
    self.classnames = self.dataset.CLASSES
    self.model = init_detector(config_file, checkpoint_file, device=device)
示例11: main
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Report FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size. One value ``s`` yields a
    ``(3, s, s)`` input; two values are appended as given after the channel
    dimension. The model is moved to CUDA with ``.cuda()``.

    Raises:
        ValueError: if ``args.shape`` has neither one nor two values.
        NotImplementedError: if the model has no ``forward_dummy`` method.
    """
    args = parse_args()
    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # Complexity is measured on the ``forward_dummy`` path; models without
    # it are not supported here.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
示例12: setUpClass
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def setUpClass(cls):
    """Build the model from the PSMNet KITTI-2015 config on ``cuda:2``.

    Afterwards calls ``cls.setUpTimeTestingClass()`` and sets
    ``cls.avg_time`` to an empty dict.
    """
    cls.device = torch.device('cuda:2')
    # NOTE(review): machine-specific absolute path and GPU index.
    cfg_path = ('/home/zhixiang/youmin/projects/depth/public/'
                'DenseMatchingBenchmark/configs/PSMNet/kitti_2015.py')
    cls.cfg = Config.fromfile(cfg_path)

    cls.model = build_model(cls.cfg)
    cls.model.to(cls.device)

    cls.setUpTimeTestingClass()
    cls.avg_time = dict()
示例13: setUpClass
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def setUpClass(cls):
    """Build the backbone from the AcfNet scene-flow config on ``cuda:1``.

    Afterwards calls ``cls.setUpTimeTestingClass()`` and sets
    ``cls.avg_time`` to an empty dict.
    """
    cls.device = torch.device('cuda:1')
    # NOTE(review): machine-specific absolute path and GPU index.
    cfg_path = ('/home/zhixiang/youmin/projects/depth/public/'
                'DenseMatchingBenchmark/configs/AcfNet/scene_flow_uniform.py')
    cls.cfg = Config.fromfile(cfg_path)

    cls.backbone = build_backbone(cls.cfg)
    cls.backbone.to(cls.device)

    cls.setUpTimeTestingClass()
    cls.avg_time = dict()
示例14: main
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Retrieve gallery images for a query image and show them.

    Steps:
    - Seed torch (and all CUDA devices when available) with 0.
    - Embed ``args.input`` with the retriever from ``args.config`` and
      ``args.checkpoint``.
    - Embed the gallery split.
    - Show the retrieved images via ``ClothesRetriever``.
    """
    seed = 0
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    args = parse_args()
    cfg = Config.fromfile(args.config)
    model = build_retriever(cfg.model)
    # Map tensors to CPU while loading so a GPU-saved checkpoint also loads
    # when use_cuda is off or no GPU exists; the model is moved to CUDA below
    # when requested.
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    print('load checkpoint from {}'.format(args.checkpoint))
    if args.use_cuda:
        model.cuda()
    model.eval()

    img_tensor = get_img_tensor(args.input, args.use_cuda)
    query_feat = model(img_tensor, landmark=None, return_loss=False)
    query_feat = query_feat.data.cpu().numpy()

    gallery_set = build_dataset(cfg.data.gallery)
    gallery_embeds = _process_embeds(gallery_set, model, cfg)
    retriever = ClothesRetriever(cfg.data.gallery.img_file, cfg.data_root,
                                 cfg.data.gallery.img_path)
    retriever.show_retrieved_images(query_feat, gallery_embeds)
示例15: main
# 需要导入模块: from mmcv import Config [as 别名]
# 或者: from mmcv.Config import fromfile [as 别名]
def main():
    """Detect landmarks on ``args.input`` and draw the visible ones.

    Seeds torch (and all CUDA devices when available) with 0, then builds
    the landmark detector from ``args.config`` and loads ``args.checkpoint``.
    It runs the detector on the input image. Landmarks with a visibility
    score ``>= 0.5`` are printed, rescaled by ``w / 224.`` and ``h / 224.``.
    They are also collected, unscaled, and passed to ``draw_landmarks``.
    """
    seed = 0
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    args = parse_args()
    cfg = Config.fromfile(args.config)
    img_tensor, w, h = get_img_tensor(args.input, args.use_cuda, get_size=True)

    # build model and load checkpoint
    model = build_landmark_detector(cfg.model)
    print('model built')
    # Map tensors to CPU while loading so a GPU-saved checkpoint also loads
    # when use_cuda is off or no GPU exists; the model is moved to CUDA below
    # when requested.
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    print('load checkpoint from: {}'.format(args.checkpoint))
    if args.use_cuda:
        model.cuda()

    # detect landmark
    model.eval()
    pred_vis, pred_lm = model(img_tensor, return_loss=False)
    pred_lm = pred_lm.data.cpu().numpy()
    # NOTE(review): predictions appear to be in a 224x224 frame; the printed
    # coordinates are rescaled to the original image size — confirm.
    vis_lms = []
    for i, vis in enumerate(pred_vis):
        if vis >= 0.5:
            print('detected landmark {} {}'.format(
                pred_lm[i][0] * (w / 224.), pred_lm[i][1] * (h / 224.)))
            vis_lms.append(pred_lm[i])
    draw_landmarks(args.input, vis_lms)