当前位置: 首页>>代码示例>>Python>>正文


Python net.load_ckpt方法代码示例

本文整理汇总了 Python 中 utils.net.load_ckpt 方法的典型用法代码示例。如果您想了解 net.load_ckpt 方法的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 utils.net 的其他用法示例。


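下文共展示了 net.load_ckpt 方法的 13 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。注意：其中部分示例（如 parse_args 和 multi_gpu_test_net_on_dataset）并未直接调用 net.load_ckpt，而只是定义或传递同名的 --load_ckpt 命令行参数。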

示例1: initialize_model_from_cfg

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def initialize_model_from_cfg(args, gpu_id=0):
    """Build the test-time Generalized_RCNN, load its weights and wrap it.

    The network is switched to evaluation mode, moved to CUDA when
    ``args.cuda`` is truthy, then populated from ``args.load_ckpt`` (a torch
    checkpoint whose ``'model'`` entry goes to ``net_utils.load_ckpt``) and/or
    ``args.load_detectron`` (via ``load_detectron_weight``).

    Args:
        args: parsed command-line namespace; reads ``cuda``, ``load_ckpt``
            and ``load_detectron``.
        gpu_id: not referenced by this function; kept for interface
            compatibility.

    Returns:
        The network wrapped in ``mynn.DataParallel`` with
        ``cpu_keywords=['im_info', 'roidb']`` and ``minibatch=True``.
    """
    net = model_builder.Generalized_RCNN()
    net.eval()
    if args.cuda:
        net.cuda()

    ckpt_path = args.load_ckpt
    if ckpt_path:
        logger.info("loading checkpoint %s", ckpt_path)
        # Keep each deserialized storage where it is (on the CPU), whatever
        # device it was saved from.
        state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(net, state['model'])

    detectron_path = args.load_detectron
    if detectron_path:
        logger.info("loading detectron weights %s", detectron_path)
        load_detectron_weight(net, detectron_path)

    return mynn.DataParallel(net, cpu_keywords=['im_info', 'roidb'], minibatch=True)
开发者ID:roytseng-tw,项目名称:Detectron.pytorch,代码行数:25,代码来源:test_engine.py

示例2: initialize_model_from_cfg

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def initialize_model_from_cfg(args, gpu_id=0):
    """Create the inference model and populate it with trained weights.

    Steps:
      1. construct ``Generalized_RCNN`` and switch it to eval mode;
      2. move it to CUDA when ``args.cuda`` is set;
      3. load ``args.load_ckpt`` (torch checkpoint, ``'model'`` entry handed
         to ``net_utils.load_ckpt``) and/or ``args.load_detectron``;
      4. wrap it in ``mynn.DataParallel``.

    ``gpu_id`` is not referenced by this function.
    """
    detector = Generalized_RCNN()
    detector.eval()

    if args.cuda:
        detector.cuda()

    checkpoint_path = args.load_ckpt
    if checkpoint_path:
        logger.info("loading checkpoint %s", checkpoint_path)
        # Deserialize all storages onto the CPU regardless of the saving device.
        saved = torch.load(checkpoint_path,
                           map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(detector, saved['model'])

    if args.load_detectron:
        logger.info("loading detectron weights %s", args.load_detectron)
        load_detectron_weight(detector, args.load_detectron)

    wrapped = mynn.DataParallel(
        detector, cpu_keywords=['im_info', 'roidb'], minibatch=True)
    return wrapped
开发者ID:bobwan1995,项目名称:PMFNet,代码行数:26,代码来源:test_engine.py

示例3: initialize_model_from_cfg

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def initialize_model_from_cfg(args, gpu_id=0):
    """Instantiate the relationship-detection model for testing.

    Builds ``model_builder_rel.Generalized_RCNN``, puts it in eval mode,
    optionally moves it to CUDA, restores weights from ``args.load_ckpt``
    and/or ``args.load_detectron``, and returns it wrapped in
    ``mynn.DataParallel``.

    Args:
        args: namespace with ``cuda``, ``load_ckpt`` and ``load_detectron``.
        gpu_id: unused here; retained so existing callers keep working.
    """
    rel_model = model_builder_rel.Generalized_RCNN()
    rel_model.eval()
    if args.cuda:
        rel_model.cuda()

    if args.load_ckpt:
        ckpt_file = args.load_ckpt
        logger.info("loading checkpoint %s", ckpt_file)
        # The identity map_location leaves every storage on the CPU.
        ckpt = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(rel_model, ckpt['model'])

    if args.load_detectron:
        weights_file = args.load_detectron
        logger.info("loading detectron weights %s", weights_file)
        load_detectron_weight(rel_model, weights_file)

    return mynn.DataParallel(
        rel_model,
        cpu_keywords=['im_info', 'roidb'],
        minibatch=True,
    )
开发者ID:jz462,项目名称:Large-Scale-VRD.pytorch,代码行数:25,代码来源:test_engine_rel.py

示例4: parse_args

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def parse_args():
    """Parse the command-line arguments of the inference demo.

    Returns:
        argparse.Namespace with ``dataset``, ``cfg_file``, ``set_cfgs``,
        ``cuda`` (False when ``--no_cuda`` is given), ``load_ckpt``,
        ``load_detectron``, ``image_dir``, ``images``, ``output_dir`` and
        ``merge_pdfs`` (bool, default True).

    ``--image_dir`` and ``--images`` are mutually exclusive; giving both makes
    argparse exit with a usage error.
    """
    def _str2bool(value):
        """Convert a yes/no style string to bool.

        Drop-in replacement for ``distutils.util.strtobool`` (distutils is
        gone in Python 3.12): accepts the same spellings, case-insensitively,
        and raises ValueError for anything else.
        """
        lowered = value.lower()
        if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise ValueError('invalid truth value %r' % (value,))

    parser = argparse.ArgumentParser(description='Demonstrate mask-rcnn results')
    parser.add_argument(
        '--dataset', required=True,
        help='training dataset')

    parser.add_argument(
        '--cfg', dest='cfg_file', required=True,
        help='config file')
    parser.add_argument(
        '--set', dest='set_cfgs',
        help='set config keys, will overwrite config in the cfg_file',
        default=[], nargs='+')

    parser.add_argument(
        '--no_cuda', dest='cuda', help='do not use CUDA', action='store_false')

    parser.add_argument('--load_ckpt', help='path of checkpoint to load')
    parser.add_argument(
        '--load_detectron', help='path to the detectron weight pickle file')

    # The two image sources are alternatives; let argparse reject both at once.
    image_source = parser.add_mutually_exclusive_group()
    image_source.add_argument(
        '--image_dir',
        help='directory to load images for demo')
    image_source.add_argument(
        '--images', nargs='+',
        help='images to infer. Must not use with --image_dir')
    parser.add_argument(
        '--output_dir',
        help='directory to save demo results',
        default="infer_outputs")
    parser.add_argument(
        '--merge_pdfs', type=_str2bool, default=True)

    return parser.parse_args()
开发者ID:roytseng-tw,项目名称:Detectron.pytorch,代码行数:40,代码来源:infer_simple.py

示例5: multi_gpu_test_net_on_dataset

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def multi_gpu_test_net_on_dataset(
        args, dataset_name, proposal_file, num_images, output_dir):
    """Multi-gpu inference on a dataset.

    Launches ``args.test_net_file`` through
    ``subprocess_utils.process_in_parallel`` and merges the ``all_boxes``
    dicts returned by the subprocesses. The result, together with a YAML
    dump of ``cfg``, is saved to ``discovery.pkl`` when ``dataset_name``
    contains ``'train'`` and to ``detections.pkl`` otherwise.

    Args:
        args: namespace; reads ``test_net_file``, ``load_ckpt`` and
            ``load_detectron``.
        dataset_name: passed to the subprocesses as ``TEST.DATASETS``.
        proposal_file: optional; passed as ``TEST.PROPOSAL_FILES`` when truthy.
        num_images: forwarded to ``process_in_parallel``.
        output_dir: forwarded to ``process_in_parallel``; also receives the
            saved pickle.

    Returns:
        dict: the merged ``all_boxes`` (later subprocesses win on duplicate keys).

    Raises:
        FileNotFoundError: if the test binary does not exist.
    """
    binary_dir = envu.get_runtime_dir()
    binary_ext = envu.get_py_bin_ext()
    binary = os.path.join(binary_dir, args.test_net_file + binary_ext)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(binary):
        raise FileNotFoundError('Binary \'{}\' not found'.format(binary))

    # Pass the target dataset and proposal file (if any) via the command line
    opts = ['TEST.DATASETS', '("{}",)'.format(dataset_name)]
    if proposal_file:
        opts += ['TEST.PROPOSAL_FILES', '("{}",)'.format(proposal_file)]

    # Training splits get the 'discovery' tag and output file, everything else
    # 'detection'; decide once so the tag and the file name cannot disagree.
    is_train = 'train' in dataset_name
    tag = 'discovery' if is_train else 'detection'

    # Run inference in parallel in subprocesses. ``outputs`` holds one
    # dictionary per subprocess (the one saved by test_net()).
    outputs = subprocess_utils.process_in_parallel(
        tag, num_images, binary, output_dir,
        args.load_ckpt, args.load_detectron, opts
    )

    # Collate the results from each subprocess
    all_boxes = {}
    for det_data in outputs:
        all_boxes.update(det_data['all_boxes'])

    det_file = os.path.join(
        output_dir, 'discovery.pkl' if is_train else 'detections.pkl')
    cfg_yaml = yaml.dump(cfg)
    save_object(
        dict(
            all_boxes=all_boxes,
            cfg=cfg_yaml
        ), det_file
    )
    logger.info('Wrote detections to: %s', os.path.abspath(det_file))

    return all_boxes
开发者ID:ppengtang,项目名称:pcl.pytorch,代码行数:43,代码来源:test_engine.py

示例6: _init_modules

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def _init_modules(self):
        """Load the pretrained weights selected by ``cfg``; optionally freeze the conv body.

        Sources, each used only when its cfg path is non-empty, in this order:
        ImageNet ResNet weights, Detectron COCO ResNet weights, a VGG16 COCO
        checkpoint (with the ``Box_Outs`` cls_score/bbox_pred tensors removed),
        and a ResNet checkpoint to be fine-tuned. When
        ``cfg.TRAIN.FREEZE_CONV_BODY`` is set, ``Conv_Body`` parameters get
        ``requires_grad = False``.
        """
        def _same_state(module_a, module_b):
            # state_dict() maps names to tensors. Comparing two such dicts with
            # `==` compares tensors element-wise and then takes the truth value
            # of the result, which raises for multi-element tensors. Compare the
            # key sets and use torch.equal on each value instead.
            state_a = module_a.state_dict()
            state_b = module_b.state_dict()
            return (state_a.keys() == state_b.keys()
                    and all(torch.equal(state_a[k], state_b[k]) for k in state_a))

        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s", cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)
        # Heads that share res5 with the box head must hold identical weights.
        if cfg.MODEL.MASK_ON and getattr(self.Mask_Head, 'SHARE_RES5', False):
            assert _same_state(self.Mask_Head.res5, self.Box_Head.res5)
        if cfg.MODEL.KEYPOINTS_ON and getattr(self.Keypoint_Head, 'SHARE_RES5', False):
            assert _same_state(self.Keypoint_Head.res5, self.Box_Head.res5)

        # load detectron pretrained weights for resnet; the third argument
        # presumably lists layers to skip (cls_score / bbox_pred) -- confirm
        if cfg.RESNETS.COCO_PRETRAINED_WEIGHTS != '':
            logger.info("loading detectron pretrained weights from %s", cfg.RESNETS.COCO_PRETRAINED_WEIGHTS)
            load_detectron_weight(self, cfg.RESNETS.COCO_PRETRAINED_WEIGHTS, ('cls_score', 'bbox_pred'))

        if cfg.VGG16.COCO_PRETRAINED_WEIGHTS != '':
            logger.info("loading pretrained weights from %s", cfg.VGG16.COCO_PRETRAINED_WEIGHTS)
            checkpoint = torch.load(cfg.VGG16.COCO_PRETRAINED_WEIGHTS, map_location=lambda storage, loc: storage)
            # not using the last softmax layers: drop the classifier and box
            # regressor tensors before loading (KeyError if one is missing)
            for key in ('Box_Outs.cls_score.weight', 'Box_Outs.cls_score.bias',
                        'Box_Outs.bbox_pred.weight', 'Box_Outs.bbox_pred.bias'):
                del checkpoint['model'][key]
            net_utils.load_ckpt(self, checkpoint['model'])

        if cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS != '':
            logger.info("loading trained and to be finetuned weights from %s", cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS)
            checkpoint = torch.load(cfg.RESNETS.TO_BE_FINETUNED_WEIGHTS, map_location=lambda storage, loc: storage)
            net_utils.load_ckpt(self, checkpoint['model'])

        if cfg.TRAIN.FREEZE_CONV_BODY:
            for p in self.Conv_Body.parameters():
                p.requires_grad = False
开发者ID:jz462,项目名称:Large-Scale-VRD.pytorch,代码行数:35,代码来源:model_builder.py

示例7: load_detector_weights

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def load_detector_weights(self, weight_name):
        """Load a pretrained detector checkpoint into ``self`` and freeze it.

        The ``'model'`` entry of the torch checkpoint at ``weight_name`` is
        passed to ``net_utils.load_ckpt``. Afterwards ``Conv_Body`` and ``RPN``
        always get ``requires_grad = False``; ``Box_Head`` and ``Box_Outs`` are
        frozen too unless ``cfg.MODEL.UNFREEZE_DET`` is set.
        """
        logger.info("loading pretrained weights from %s", weight_name)
        state = torch.load(weight_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(self, state['model'])

        # freeze everything above the rel module
        frozen_modules = [self.Conv_Body, self.RPN]
        if not cfg.MODEL.UNFREEZE_DET:
            frozen_modules += [self.Box_Head, self.Box_Outs]
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False
开发者ID:jz462,项目名称:Large-Scale-VRD.pytorch,代码行数:16,代码来源:model_builder_rel.py

示例8: parse_args

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def parse_args():
    """Parse the command-line arguments of the WIDER face-detection runner.

    Returns:
        argparse.Namespace with ``det_dir`` (from ``--exp_name``), ``cuda``
        (False when ``--no_cuda`` is given), ``cfg_file``, ``set_cfgs``,
        ``load_ckpt``, ``load_detectron`` and ``split`` (default ``'val'``).
    """
    parser = argparse.ArgumentParser(description='Face Detection using Faster R-CNN')

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = [
        ('--exp_name', dict(required=True, dest='det_dir', help='detector name')),
        ('--no_cuda', dict(dest='cuda', help='whether use CUDA', action='store_false')),
        ('--cfg', dict(dest='cfg_file', required=True, help='config file')),
        ('--set', dict(dest='set_cfgs',
                       help='set config keys, will overwrite config in the cfg_file',
                       default=[], nargs='+')),
        ('--load_ckpt', dict(help='path of checkpoint to load')),
        ('--load_detectron', dict(help='path to the detectron weight pickle file')),
        ('--split', dict(dest='split', default='val', help='train or val')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
开发者ID:AruniRC,项目名称:detectron-self-train,代码行数:36,代码来源:run_face_detection_on_wider.py

示例9: parse_args

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def parse_args():
    """Parse the command-line arguments of the inference demo.

    Returns:
        argparse.Namespace with ``dataset``, ``cfg_file``, ``set_cfgs``,
        ``cuda`` (False when ``--no_cuda`` is given), ``load_ckpt``,
        ``load_detectron``, ``image_dir``, ``images``, ``output_dir`` and
        ``merge_pdfs`` (bool, default False).

    ``--image_dir`` and ``--images`` are mutually exclusive; giving both makes
    argparse exit with a usage error.
    """
    def _str2bool(value):
        """Convert a yes/no style string to bool.

        Drop-in replacement for ``distutils.util.strtobool`` (distutils is
        gone in Python 3.12): accepts the same spellings, case-insensitively,
        and raises ValueError for anything else.
        """
        lowered = value.lower()
        if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise ValueError('invalid truth value %r' % (value,))

    parser = argparse.ArgumentParser(description='Demonstrate mask-rcnn results')
    parser.add_argument(
        '--dataset', required=True,
        help='training dataset')

    parser.add_argument(
        '--cfg', dest='cfg_file', required=True,
        help='config file')
    parser.add_argument(
        '--set', dest='set_cfgs',
        help='set config keys, will overwrite config in the cfg_file',
        default=[], nargs='+')

    parser.add_argument(
        '--no_cuda', dest='cuda', help='do not use CUDA', action='store_false')

    parser.add_argument('--load_ckpt', help='path of checkpoint to load')
    parser.add_argument(
        '--load_detectron', help='path to the detectron weight pickle file')

    # The two image sources are alternatives; let argparse reject both at once.
    image_source = parser.add_mutually_exclusive_group()
    image_source.add_argument(
        '--image_dir',
        help='directory to load images for demo')
    image_source.add_argument(
        '--images', nargs='+',
        help='images to infer. Must not use with --image_dir')
    parser.add_argument(
        '--output_dir',
        help='directory to save demo results',
        default="infer_outputs")
    parser.add_argument(
        '--merge_pdfs', type=_str2bool, default=False)

    return parser.parse_args()
开发者ID:AruniRC,项目名称:detectron-self-train,代码行数:40,代码来源:infer_simple.py

示例10: multi_gpu_test_net_on_dataset

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def multi_gpu_test_net_on_dataset(
        args, dataset_name, proposal_file, num_images, output_dir):
    """Multi-gpu inference on a dataset.

    Launches ``args.test_net_file`` through
    ``subprocess_utils.process_in_parallel`` and concatenates, per class, the
    box / segmentation / keypoint lists returned by each subprocess. The
    result and a YAML dump of ``cfg`` are saved to ``detections.pkl`` in
    ``output_dir``.

    Args:
        args: namespace; reads ``test_net_file``, ``load_ckpt`` and
            ``load_detectron``.
        dataset_name: passed to the subprocesses as ``TEST.DATASETS``.
        proposal_file: optional; passed as ``TEST.PROPOSAL_FILES`` when truthy.
        num_images: forwarded to ``process_in_parallel``.
        output_dir: forwarded to ``process_in_parallel``; also receives
            ``detections.pkl``.

    Returns:
        (all_boxes, all_segms, all_keyps): lists indexed by class id; index 0
        is never filled (presumably the background class).

    Raises:
        FileNotFoundError: if the test binary does not exist.
    """
    binary_dir = envu.get_runtime_dir()
    binary_ext = envu.get_py_bin_ext()
    binary = os.path.join(binary_dir, args.test_net_file + binary_ext)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(binary):
        raise FileNotFoundError('Binary \'{}\' not found'.format(binary))

    # Pass the target dataset and proposal file (if any) via the command line
    opts = ['TEST.DATASETS', '("{}",)'.format(dataset_name)]
    if proposal_file:
        opts += ['TEST.PROPOSAL_FILES', '("{}",)'.format(proposal_file)]

    # Run inference in parallel in subprocesses. ``outputs`` holds one
    # dictionary per subprocess (the one saved by test_net()).
    outputs = subprocess_utils.process_in_parallel(
        'detection', num_images, binary, output_dir,
        args.load_ckpt, args.load_detectron, opts
    )

    # Collate the results from each subprocess, class by class.
    num_classes = cfg.MODEL.NUM_CLASSES
    all_boxes = [[] for _ in range(num_classes)]
    all_segms = [[] for _ in range(num_classes)]
    all_keyps = [[] for _ in range(num_classes)]
    for det_data in outputs:
        boxes_batch = det_data['all_boxes']
        segms_batch = det_data['all_segms']
        keyps_batch = det_data['all_keyps']
        for cls_idx in range(1, num_classes):
            all_boxes[cls_idx] += boxes_batch[cls_idx]
            all_segms[cls_idx] += segms_batch[cls_idx]
            all_keyps[cls_idx] += keyps_batch[cls_idx]

    det_file = os.path.join(output_dir, 'detections.pkl')
    cfg_yaml = yaml.dump(cfg)
    save_object(
        dict(
            all_boxes=all_boxes,
            all_segms=all_segms,
            all_keyps=all_keyps,
            cfg=cfg_yaml
        ), det_file
    )
    logger.info('Wrote detections to: %s', os.path.abspath(det_file))

    return all_boxes, all_segms, all_keyps
开发者ID:roytseng-tw,项目名称:Detectron.pytorch,代码行数:48,代码来源:test_engine.py

示例11: initialize_model_from_cfg

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def initialize_model_from_cfg(args, roidb=None, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    set to evaluation mode.

    Dataset metadata is read from ``roidb[0]`` before the weights are loaded:
    the source/target class split goes into ``cfg.TEST.CLASS_SPLIT``; word
    embeddings (if present) go to ``model.Box_Outs``; the ignored classes
    (when ``cfg.MODEL.IGNORE_CLASSES`` is set) go to the model and its box
    outputs; and a (subject_id, object_id) -> rel_ids dict goes to
    ``model.Rel_Outs`` when ``cfg.MODEL.NUM_RELATIONS > 0``.

    Args:
        args: namespace with ``cuda``, ``load_ckpt`` and ``load_detectron``.
        roidb: sequence of roidb entries; only ``roidb[0]`` is read, and it
            must contain ``'source'`` and ``'target'`` (plus
            ``'relationships'`` when relations are enabled). Required despite
            the ``None`` default.
        gpu_id: not referenced by this function.

    Returns:
        The model wrapped in ``mynn.DataParallel``.

    Raises:
        ValueError: if ``roidb`` is None or empty.
    """
    if roidb is None or len(roidb) == 0:
        raise ValueError('initialize_model_from_cfg requires a non-empty roidb')
    meta = roidb[0]

    model = model_builder.Generalized_RCNN()
    model.eval()

    # Temporarily unlock cfg to record the class split; re-lock it even if
    # the roidb lookup raises, so cfg is never left mutable.
    cfg.immutable(False)
    try:
        cfg.TEST.CLASS_SPLIT = {'source': meta['source'], 'target': meta['target']}
    finally:
        cfg.immutable(True)

    if 'word_embeddings' in meta:
        model.Box_Outs.set_word_embedding(torch.tensor(meta['word_embeddings']))
    if cfg.MODEL.IGNORE_CLASSES:
        if cfg.MODEL.IGNORE_CLASSES == 'all':
            # Store the union under 'all' so the keyed lookup below works
            # uniformly (note: this writes into the caller's roidb[0]).
            meta['all'] = meta['source'] + meta['target']
        ignored = meta[cfg.MODEL.IGNORE_CLASSES]
        model._ignore_classes = ignored
        model.Box_Outs._ignore_classes = ignored

    if cfg.MODEL.NUM_RELATIONS > 0:
        # (subject_id, object_id) -> list of rel_ids seen for that pair.
        pair_to_rels = {}
        for rel in meta['relationships']:
            pair = (rel['subject_id'], rel['object_id'])
            pair_to_rels.setdefault(pair, []).append(rel['rel_id'])
        if cfg.MODEL.RELATION_COOCCUR:
            # Only record that a pair co-occurs, not which relations it has.
            for pair in pair_to_rels:
                pair_to_rels[pair] = [1]
        model.Rel_Outs.relationship_dict = pair_to_rels

    if args.cuda:
        model.cuda()

    if args.load_ckpt:
        load_name = args.load_ckpt
        logger.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(model, checkpoint['model'])

    if args.load_detectron:
        logger.info("loading detectron weights %s", args.load_detectron)
        load_detectron_weight(model, args.load_detectron)

    model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'], minibatch=True)

    return model
开发者ID:ruotianluo,项目名称:Context-aware-ZSR,代码行数:47,代码来源:test_engine.py

示例12: multi_gpu_test_net_on_dataset

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def multi_gpu_test_net_on_dataset(
        args, dataset_name, proposal_file, num_images, output_dir):
    """Multi-gpu inference on a dataset.

    Launches ``args.test_net_file`` through
    ``subprocess_utils.process_in_parallel`` (forwarding the checkpoint and
    network options from ``args``) and merges the per-subprocess results:
    per-class box/segm/keypoint/V-COCO keypoint lists are concatenated, HOI
    dicts are merged, and loss lists are extended per key. Everything plus a
    YAML dump of ``cfg`` is saved to ``detections.pkl`` in ``output_dir``.

    Args:
        args: namespace; reads ``test_net_file``, ``load_ckpt``,
            ``load_detectron``, ``net_name``, ``mlp_head_dim``,
            ``heatmap_kernel_size``, ``part_crop_size`` and ``use_kps17``.
        dataset_name: passed to the subprocesses as ``TEST.DATASETS``.
        proposal_file: optional; passed as ``TEST.PROPOSAL_FILES`` when truthy.
        num_images: forwarded to ``process_in_parallel``.
        output_dir: forwarded to ``process_in_parallel``; also receives
            ``detections.pkl``.

    Returns:
        (all_boxes, all_segms, all_keyps, all_hois, all_keyps_vcoco, all_losses)

    Raises:
        FileNotFoundError: if the test binary does not exist.
    """
    binary_dir = envu.get_runtime_dir()
    binary_ext = envu.get_py_bin_ext()
    binary = os.path.join(binary_dir, args.test_net_file + binary_ext)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(binary):
        raise FileNotFoundError('Binary \'{}\' not found'.format(binary))

    # Pass the target dataset and proposal file (if any) via the command line
    opts = ['TEST.DATASETS', '("{}",)'.format(dataset_name)]
    if proposal_file:
        opts += ['TEST.PROPOSAL_FILES', '("{}",)'.format(proposal_file)]

    # Run inference in parallel in subprocesses. ``outputs`` holds one
    # dictionary per subprocess (the one saved by test_net()).
    outputs = subprocess_utils.process_in_parallel(
        'detection', num_images, binary, output_dir,
        args.load_ckpt, args.load_detectron, args.net_name, args.mlp_head_dim,
        args.heatmap_kernel_size, args.part_crop_size, args.use_kps17,
        opts)

    # Collate the results from each subprocess
    num_classes = cfg.MODEL.NUM_CLASSES
    all_boxes = [[] for _ in range(num_classes)]
    all_segms = [[] for _ in range(num_classes)]
    all_keyps = [[] for _ in range(num_classes)]
    all_hois = {}
    all_losses = defaultdict(list)
    all_keyps_vcoco = [[] for _ in range(num_classes)]
    for det_data in outputs:
        boxes_batch = det_data['all_boxes']
        segms_batch = det_data['all_segms']
        keyps_batch = det_data['all_keyps']
        # Merge in place (later subprocesses win on duplicate keys) instead of
        # rebuilding the whole dict on every iteration.
        all_hois.update(det_data['all_hois'])
        for loss_name, values in det_data['all_losses'].items():
            all_losses[loss_name].extend(values)

        keyps_vcoco_batch = det_data['all_keyps_vcoco']
        for cls_idx in range(1, num_classes):
            all_boxes[cls_idx] += boxes_batch[cls_idx]
            all_segms[cls_idx] += segms_batch[cls_idx]
            all_keyps[cls_idx] += keyps_batch[cls_idx]
            all_keyps_vcoco[cls_idx] += keyps_vcoco_batch[cls_idx]

    det_file = os.path.join(output_dir, 'detections.pkl')
    cfg_yaml = yaml.dump(cfg)
    save_object(
        dict(
            all_boxes=all_boxes,
            all_segms=all_segms,
            all_keyps=all_keyps,
            all_hois=all_hois,
            all_keyps_vcoco=all_keyps_vcoco,
            all_losses=all_losses,
            cfg=cfg_yaml
        ), det_file
    )
    logger.info('Wrote detections to: %s', os.path.abspath(det_file))

    return all_boxes, all_segms, all_keyps, all_hois, all_keyps_vcoco, all_losses
开发者ID:bobwan1995,项目名称:PMFNet,代码行数:61,代码来源:test_engine.py

示例13: _init_modules

# 需要导入模块: from utils import net [as 别名]
# 或者: from utils.net import load_ckpt [as 别名]
def _init_modules(self):
        """Load pretrained detector / predicate weights selected by ``cfg``.

        Steps:
          1. optional ImageNet ResNet weights;
          2. each configured VRD / VG detector checkpoint (ResNet then VGG16)
             via ``load_detector_weights``;
          3. optional freezing of ``Conv_Body``;
          4. at most one predicate-branch checkpoint loaded into ``Prd_RCNN``
             (without the ``Box_Outs`` cls_score/bbox_pred tensors), with
             optional freezing of its conv body and box head.
        """
        # VGG16 imagenet pretrained model is initialized in VGG16.py
        if cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS != '':
            logger.info("Loading pretrained weights from %s", cfg.RESNETS.IMAGENET_PRETRAINED_WEIGHTS)
            resnet_utils.load_pretrained_imagenet_weights(self)

        # Detector weights: every configured source is loaded, in this order.
        for det_weights in (cfg.RESNETS.VRD_PRETRAINED_WEIGHTS,
                            cfg.VGG16.VRD_PRETRAINED_WEIGHTS,
                            cfg.RESNETS.VG_PRETRAINED_WEIGHTS,
                            cfg.VGG16.VG_PRETRAINED_WEIGHTS):
            if det_weights != '':
                self.load_detector_weights(det_weights)

        if cfg.TRAIN.FREEZE_CONV_BODY:
            for p in self.Conv_Body.parameters():
                p.requires_grad = False

        prd_candidates = [path for path in (cfg.RESNETS.VRD_PRD_PRETRAINED_WEIGHTS,
                                            cfg.VGG16.VRD_PRD_PRETRAINED_WEIGHTS,
                                            cfg.RESNETS.VG_PRD_PRETRAINED_WEIGHTS,
                                            cfg.VGG16.VG_PRD_PRETRAINED_WEIGHTS)
                          if path != '']
        if prd_candidates:
            # Only the last configured predicate checkpoint takes effect, so
            # only that one is read from disk.
            prd_weights = prd_candidates[-1]
            if len(prd_candidates) > 1:
                logger.warning("multiple prd pretrained weights configured %s; using %s",
                               prd_candidates, prd_weights)
            logger.info("loading prd pretrained weights from %s", prd_weights)
            checkpoint = torch.load(prd_weights, map_location=lambda storage, loc: storage)
            # not using the last softmax layers (KeyError if one is missing)
            for key in ('Box_Outs.cls_score.weight', 'Box_Outs.cls_score.bias',
                        'Box_Outs.bbox_pred.weight', 'Box_Outs.bbox_pred.bias'):
                del checkpoint['model'][key]
            net_utils.load_ckpt(self.Prd_RCNN, checkpoint['model'])
            if cfg.TRAIN.FREEZE_PRD_CONV_BODY:
                for p in self.Prd_RCNN.Conv_Body.parameters():
                    p.requires_grad = False
            if cfg.TRAIN.FREEZE_PRD_BOX_HEAD:
                for p in self.Prd_RCNN.Box_Head.parameters():
                    p.requires_grad = False
开发者ID:jz462,项目名称:Large-Scale-VRD.pytorch,代码行数:48,代码来源:model_builder_rel.py


注:本文中的utils.net.load_ckpt方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。