当前位置: 首页>>代码示例>>Python>>正文


Python load_data.load_gt_roidb方法代码示例

本文整理汇总了Python中utils.load_data.load_gt_roidb方法的典型用法代码示例。如果您正苦于以下问题：Python load_data.load_gt_roidb方法的具体用法是什么？Python load_data.load_gt_roidb怎么用？有哪些Python load_data.load_gt_roidb的使用例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils.load_data的用法示例。


在下文中一共展示了load_data.load_gt_roidb方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_rcnn

# 需要导入模块: from utils import load_data [as 别名]
# 或者: from utils.load_data import load_gt_roidb [as 别名]
def test_rcnn(imageset, year, root_path, devkit_path, prefix, epoch, ctx, vis=False, has_rpn=True, proposal='rpn',
              end2end=False):
    """Run detection evaluation for a saved model on one image set.

    When ``has_rpn`` is true the symbol comes from ``get_vgg_test``, the RPN
    test options on ``config.TEST`` are overwritten, and the roidb is loaded
    with ``load_gt_roidb``. Otherwise the symbol comes from
    ``get_vgg_rcnn_test`` and the roidb is loaded by the module-level
    function named ``load_test_<proposal>_roidb``.

    Args:
        imageset, year, root_path, devkit_path: forwarded unchanged to the
            roidb loader.
        prefix, epoch: forwarded to ``load_param`` to locate the checkpoint.
        ctx: forwarded to ``load_param`` and ``Detector``.
        vis: forwarded to ``pred_eval``.
        has_rpn: selects the branch described above.
        proposal: middle part of the loader name used when ``has_rpn`` is
            false.
        end2end: accepted for interface compatibility; not read here.

    Raises:
        NameError: if ``has_rpn`` is false and no module-level name
            ``load_test_<proposal>_roidb`` exists.
    """
    # load symbol and testing data
    if has_rpn:
        sym = get_vgg_test()
        # NOTE: these assignments mutate the shared config object.
        config.TEST.HAS_RPN = True
        config.TEST.RPN_PRE_NMS_TOP_N = 6000
        config.TEST.RPN_POST_NMS_TOP_N = 300
        voc, roidb = load_gt_roidb(imageset, year, root_path, devkit_path)
    else:
        sym = get_vgg_rcnn_test()
        # Resolve the loader by name rather than eval(): `proposal` is
        # caller-supplied, and eval() would execute any expression embedded
        # in it. A plain lookup only ever returns an existing module global.
        loader_name = 'load_test_' + proposal + '_roidb'
        loader = globals().get(loader_name)
        if loader is None:
            # Same exception type the eval-based lookup raised for a
            # missing loader.
            raise NameError("name '%s' is not defined" % loader_name)
        voc, roidb = loader(imageset, year, root_path, devkit_path)

    # get test data iter
    test_data = ROIIter(roidb, batch_size=1, shuffle=False, mode='test')

    # load model; the third value returned by load_param is not needed here
    args, auxs, _ = load_param(prefix, epoch, convert=True, ctx=ctx)

    # detect
    detector = Detector(sym, ctx, args, auxs)
    pred_eval(detector, test_data, voc, vis=vis)
开发者ID:giorking,项目名称:mx-rfcn,代码行数:24,代码来源:test_rcnn.py

示例2: test_rpn

# 需要导入模块: from utils import load_data [as 别名]
# 或者: from utils.load_data import load_gt_roidb [as 别名]
def test_rpn(image_set, year, root_path, devkit_path, prefix, epoch, ctx, vis=False):
    """Generate RPN proposals for an image set and evaluate their recall.

    Args:
        image_set, year, root_path, devkit_path: forwarded unchanged to
            ``load_gt_roidb``.
        prefix, epoch: forwarded to ``load_param`` to locate the checkpoint.
        ctx: forwarded to ``load_param`` and ``Detector``.
        vis: forwarded to ``generate_detections``.
    """
    # load symbol
    sym = get_vgg_rpn_test()

    # load testing data
    voc, roidb = load_gt_roidb(image_set, year, root_path, devkit_path)
    test_data = ROIIter(roidb, batch_size=1, shuffle=False, mode='test')

    # load model
    # Keep only the first two returned values: other call sites unpack three
    # values from load_param, so a strict two-name unpack would raise
    # ValueError if it returns three. Slicing works for either arity.
    args, auxs = load_param(prefix, epoch, convert=True, ctx=ctx)[:2]

    # start testing
    detector = Detector(sym, ctx, args, auxs)
    imdb_boxes = generate_detections(detector, test_data, voc, vis=vis)
    voc.evaluate_recall(roidb, candidate_boxes=imdb_boxes)
开发者ID:giorking,项目名称:mx-rfcn,代码行数:17,代码来源:test_rpn.py

示例3: main

# 需要导入模块: from utils import load_data [as 别名]
# 或者: from utils.load_data import load_gt_roidb [as 别名]
def main():
    """Train the detection network end to end using the module-level ``args``.

    Steps, in order: initialise the config, build the symbol (ResNet-50 when
    ``args.pretrained`` contains "resnet", otherwise the default Faster R-CNN
    symbol), set up one GPU context per id in ``args.gpu_ids``, load the
    roidb and anchor data loader, load (and, unless resuming, initialise)
    the parameters, then fit a ``MutableModule`` with SGD.
    """
    logging.info('########## TRAIN FASTER-RCNN WITH APPROXIMATE JOINT END2END #############')
    init_config()

    # Pick the backbone and the layers kept frozen during training.
    use_resnet = "resnet" in args.pretrained
    if use_resnet:
        # Only ResNet-50 is handled on this path; class count considers background.
        sym = resnet_50(num_class=args.num_classes, bn_mom=args.bn_mom, bn_global=True, is_train=True)
        fixed_param_prefix = ['conv0', 'stage1', 'stage2', 'bn_data', 'bn0']
    else:
        # Class count considers background.
        sym = get_faster_rcnn(num_classes=args.num_classes)
        fixed_param_prefix = ['conv1', 'conv2', 'conv3']

    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # Multi-GPU setup: one context per id, batch size scaled by device count.
    # The config must be updated before get_max_shape() is called.
    gpu_ids = args.gpu_ids.split(',')
    ctx = [mx.gpu(int(gpu_id)) for gpu_id in gpu_ids]
    config.TRAIN.IMS_PER_BATCH *= len(ctx)
    max_data_shape, max_label_shape = get_max_shape(feat_sym)

    # Data.
    voc, roidb = load_gt_roidb(
        args.image_set,
        args.year,
        args.root_path,
        args.devkit_path,
        flip=not args.no_flip,
    )
    train_data = AnchorLoader(
        feat_sym,
        roidb,
        batch_size=config.TRAIN.IMS_PER_BATCH,
        anchor_scales=(4, 8, 16, 32),
        shuffle=not args.no_shuffle,
        mode='train',
        ctx=ctx,
        need_mean=args.need_mean,
    )

    # Model parameters: start from the checkpoint, re-initialise unless resuming.
    args_params, auxs_params, _ = load_param(args.pretrained, args.load_epoch, convert=True)
    if not args.resume:
        args_params, auxs_params = init_model(args_params, auxs_params, train_data, sym, args.pretrained)

    data_names = [desc[0] for desc in train_data.provide_data]
    label_names = [desc[0] for desc in train_data.provide_label]

    batch_end_callback = Speedometer(train_data.batch_size, frequent=args.frequent)
    epoch_end_callback = do_checkpoint(args.prefix)

    # Plain step-decay schedule; no warm-up phase is used.
    optimizer_params = {
        'momentum': args.mom,
        'wd': args.wd,
        'learning_rate': args.lr,
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(args.factor_step, 0.1),
        'clip_gradient': 1.0,
        'rescale_grad': 1.0,
    }

    # Train.
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=max_data_shape,
        max_label_shapes=max_label_shape,
        fixed_param_prefix=fixed_param_prefix,
    )
    mod.fit(
        train_data,
        eval_metric=metric(),
        epoch_end_callback=epoch_end_callback,
        batch_end_callback=batch_end_callback,
        kvstore=args.kv_store,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        arg_params=args_params,
        aux_params=auxs_params,
        begin_epoch=args.load_epoch,
        num_epoch=args.num_epoch,
    )
开发者ID:giorking,项目名称:mx-rfcn,代码行数:53,代码来源:train_end2end_resnet.py


注:本文中的utils.load_data.load_gt_roidb方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。