当前位置: 首页>>代码示例>>Python>>正文


Python models.get_model方法代码示例

本文整理汇总了Python中models.get_model方法的典型用法代码示例。如果您正苦于以下问题：Python models.get_model方法的具体用法是什么？Python models.get_model怎么用？有哪些Python models.get_model的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块models的用法示例。


下文共展示了models.get_model方法的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的Python代码示例。

示例1: main

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def main():
  """Evaluate a trained model for every epoch listed in ``opt.which_epochs``.

  For each requested epoch the checkpoint is loaded from ``opt.ckpt_path``
  via ``model.load`` and ``evaluate`` is run on the validation loader. An
  epoch value of -1 selects the newest ``net*.pth`` checkpoint on disk. Every
  metric returned by ``evaluate`` is written through ``logger.print``.

  Raises:
    FileNotFoundError: if -1 is requested but no ``net*.pth`` file exists in
      ``opt.ckpt_path``.
  """
  opt, logger, _ = utils.build(is_train=False)

  dloader = data.get_data_loader(opt)
  print('Val dataset: {}'.format(len(dloader.dataset)))
  model = models.get_model(opt)

  for epoch in opt.which_epochs:
    # Load checkpoint
    if epoch == -1:
      # Find the latest checkpoint.
      checkpoints = glob.glob(os.path.join(opt.ckpt_path, 'net*.pth'))
      # Explicit raise instead of `assert`: it still fires under `python -O`
      # and names the directory that was searched.
      if not checkpoints:
        raise FileNotFoundError(
            'No checkpoints matching net*.pth found in {}'.format(opt.ckpt_path))
      # The epoch is the text after the last '_' with the 4-character '.pth'
      # suffix removed (e.g. '..._12.pth' -> 12).
      epochs = [int(filename.split('_')[-1][:-4]) for filename in checkpoints]
      epoch = max(epochs)
    logger.print('Loading checkpoints from {}, epoch {}'.format(opt.ckpt_path, epoch))
    model.load(opt.ckpt_path, epoch)

    results = evaluate(opt, dloader, model)
    for metric in results:
      logger.print('{}: {}'.format(metric, results[metric]))
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:23,代码来源:test.py

示例2: main

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def main(args):
    """Train, validate after every epoch, then evaluate on the test split.

    After each epoch whose 1-based index appears in ``args.T`` the optimizer
    is decayed by ``args.decay_factor``. When training finishes,
    ``load_best_model`` restores the model tracked by ``xp`` (presumably the
    best validation checkpoint - verify in ``load_best_model``) before the
    final test pass.
    """
    set_cuda(args)
    set_seed(args)

    train_loader, val_loader, test_loader = get_data_loaders(args)
    criterion = get_loss(args)
    model = get_model(args)
    optimizer = get_optimizer(args, parameters=model.parameters())
    xp = setup_xp(args, model, optimizer)

    for epoch_idx in range(args.epochs):
        xp.epoch.update(epoch_idx)

        train(model, criterion, optimizer, train_loader, args, xp)
        test(model, val_loader, args, xp)

        # args.T holds 1-based epoch numbers after which the LR is decayed.
        epochs_done = epoch_idx + 1
        if epochs_done in args.T:
            decay_optimizer(optimizer, args.decay_factor)

    load_best_model(model, xp)
    test(model, test_loader, args, xp)
开发者ID:oval-group,项目名称:dfw,代码行数:24,代码来源:main.py

示例3: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def __init__(self):
        """Configure the NFEL5836GRU experiment.

        Sets the GRU hyper-parameters and look-back window, takes the
        weights and landmark min/max bounds from ``exp_utils``, builds the
        "two_stream_rnn" model via ``get_model``, prints the experiment name
        and description, and finally creates the ``NFEL5836_2918`` feature
        architecture whose ``base_model`` is exposed on ``self``.
        """
        super().__init__()
        self.name = "NFEL5836GRU"
        self.description = ("Sequence of normalized face, eyes and landmarks. Frozen static model, fine-tune fusion "
                            "layers and train RNN-GRU module from scratch")

        # Recurrent module: a single GRU layer of 128 units over 4 time steps.
        self.recurrent_type, self.num_recurrent_layers, self.num_recurrent_units = "gru", 1, 128
        self.look_back = 4

        # Pre-computed weights and landmark normalisation bounds.
        self.weights = exp_utils.NFEL5836GRU_VGG16
        self.min_lndmk, self.max_lndmk = exp_utils.NFEL5836GRU_MIN_LNMDK, exp_utils.NFEL5836GRU_MAX_LNMDK
        self.label_pos = -1

        self.model = get_model("two_stream_rnn")
        for line in (self.name, self.description):
            print(line)

        self.feature_arch = NFEL5836_2918()
        self.base_model = self.feature_arch.base_model
开发者ID:crisie,项目名称:RecurrentGaze,代码行数:22,代码来源:experiment_helper.py

示例4: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def __init__(self, model_path, gpu_id=None):
        '''
        Initialise the PyTorch model for inference.
        :param model_path: path to the checkpoint file. It is read as a dict
            with a 'config' entry (used to build the network) and a
            'state_dict' entry (the weights loaded into it).
        :param gpu_id: index of the GPU to run on; the CPU is used when this
            is not an int or CUDA is unavailable.
        '''
        self.gpu_id = gpu_id

        if isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        # map_location puts every tensor directly on the chosen device; without
        # it tensors are restored onto the device they were saved from, which
        # need not be cuda:<gpu_id>.
        checkpoint = torch.load(model_path, map_location=self.device)
        print('device:', self.device)

        config = checkpoint['config']
        # The weights come from the checkpoint below, so the architecture is
        # built with pretrained=False (presumably to skip fetching backbone
        # weights - confirm in get_model).
        config['arch']['args']['pretrained'] = False
        self.net = get_model(config)

        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])  # load weights
        self.net.to(self.device)
        self.net.eval()
开发者ID:SURFZJY,项目名称:Real-time-Text-Detection,代码行数:26,代码来源:predict.py

示例5: prepare_reg_feat

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def prepare_reg_feat(hseq_utils, reg_model, overwrite):
    """Compute and cache regional features for the HSequences images.

    For every sequence directory in ``hseq_utils.seqs``, images ``1.ppm`` to
    ``6.ppm`` are processed when their ``<idx>_img_feat.npy`` output is
    missing or ``overwrite`` is true. Features are produced by the
    ``'reg_model'`` model built from ``reg_model`` and saved with ``np.save``.
    The model is only built when at least one image needs processing, and it
    is always closed afterwards.

    Raises:
        OSError: if an input image cannot be read by OpenCV.
    """
    in_img_path = []
    out_img_feat_list = []
    for seq_name in hseq_utils.seqs:
        for img_idx in range(1, 7):
            img_feat_path = os.path.join(seq_name, '%d_img_feat.npy' % img_idx)
            if not os.path.exists(img_feat_path) or overwrite:
                in_img_path.append(os.path.join(seq_name, '%d.ppm' % img_idx))
                out_img_feat_list.append(img_feat_path)

    if not in_img_path:
        return

    model = get_model('reg_model')(reg_model)
    try:
        prog_bar = progressbar.ProgressBar()
        prog_bar.max_value = len(in_img_path)
        for idx, (img_path, feat_path) in enumerate(zip(in_img_path, out_img_feat_list)):
            img = cv2.imread(img_path)
            # cv2.imread reports failure by returning None instead of raising.
            if img is None:
                raise OSError('Failed to read image: %s' % img_path)
            # OpenCV loads BGR; reversing the last axis yields RGB.
            img = img[..., ::-1]
            reg_feat = model.run_test_data(img)
            np.save(feat_path, reg_feat)
            prog_bar.update(idx)
    finally:
        # Release the model even if reading or inference fails.
        model.close()
开发者ID:luigifreda,项目名称:pyslam,代码行数:23,代码来源:hseq_eval.py

示例6: extract_reg_feat

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def extract_reg_feat(config):
    """Extract regional features into each test sample's HDF5 dump file.

    Sets ``config['stage'] = 'reg'``, builds the dataset named by
    ``config['data_name']`` and the ``'reg_model'`` model from
    ``config['pretrained']['reg_model']`` (with ``config['reg_feat']`` as
    keyword arguments). For each sample, the ``'reg_feat'`` dataset in the
    file at ``data['dump_path']`` is (re)written when it is missing or
    ``config['reg_feat']['overwrite']`` is set. Iteration stops when the
    test set raises ``dataset.end_set``.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'reg'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('reg_model')(config['pretrained']['reg_model'], **(config['reg_feat']))
    idx = 0
    try:
        while True:
            # Only fetching the next sample signals end-of-data; keep the
            # except clause narrow so other errors are not misread as the end.
            try:
                data = next(test_set)
            except dataset.end_set:
                break
            dump_path = data['dump_path'].decode('utf-8')
            # Context manager closes each HDF5 file so handles do not pile up
            # across samples.
            with h5py.File(dump_path, 'a') as reg_f:
                if 'reg_feat' not in reg_f or config['reg_feat']['overwrite']:
                    reg_feat = model.run_test_data(data['image'])
                    if 'reg_feat' in reg_f:
                        del reg_f['reg_feat']
                    reg_f.create_dataset('reg_feat', data=reg_feat)
            prog_bar.update(idx)
            idx += 1
    finally:
        model.close()
开发者ID:luigifreda,项目名称:pyslam,代码行数:27,代码来源:evaluations.py

示例7: run

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def run(config, num_checkpoint, epoch_end, output_filename):
    """Average several checkpoints (SWA) and save the averaged weights.

    The checkpoints come from ``get_checkpoints(config, num_checkpoint,
    epoch_end)``. They are folded into a running mean with
    ``swa.moving_average``. ``swa.bn_update`` is then run over the training
    split (loaded with the 'val' transform); presumably it refreshes
    BatchNorm statistics - confirm in ``swa``. The result is saved as
    ``'<output_filename>.<num_checkpoint>.<last_epoch:03d>'``, where
    ``last_epoch`` is the epoch reported for the last checkpoint loaded.

    Raises:
        ValueError: if ``get_checkpoints`` returns no checkpoints.
    """
    dataloader = get_dataloader(config, 'train', get_transform(config, 'val'))

    model = get_model(config).cuda()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)
    if len(checkpoints) == 0:
        raise ValueError('no checkpoints found to average')

    # Capture the epoch here as well so last_epoch is defined even when only
    # a single checkpoint is averaged.
    last_epoch, _ = utils.checkpoint.load_checkpoint(model, None, checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_model(config).cuda()
        last_epoch, _ = utils.checkpoint.load_checkpoint(model2, None, checkpoint)
        # Running mean: the (i+2)-th model contributes with weight 1/(i+2).
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint, last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(config, model, None, 0, 0,
                                     name=output_name,
                                     weights_dict={'state_dict': model.state_dict()})
开发者ID:pudae,项目名称:kaggle-hpa,代码行数:22,代码来源:swa.py

示例8: run

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def run(config):
    """Train a model described by ``config``, resuming from a checkpoint if any.

    Builds the model (on CUDA), loss, optimizer and scheduler. If
    ``utils.checkpoint.get_initial_checkpoint`` returns a checkpoint, its
    weights and optimizer state are loaded. Training then starts at
    ``last_epoch + 1``, with ``'train'`` and ``'val'`` data loaders and a
    TensorBoard ``SummaryWriter`` in ``config.train.dir``. The writer is
    closed when training ends or fails.
    """
    train_dir = config.train.dir

    model = get_model(config).cuda()
    criterion = get_loss(config)
    optimizer = get_optimizer(config, model.parameters())

    checkpoint = utils.checkpoint.get_initial_checkpoint(config)
    if checkpoint is not None:
        last_epoch, step = utils.checkpoint.load_checkpoint(model, optimizer, checkpoint)
    else:
        # No checkpoint: -1 makes training start at epoch 0.
        last_epoch, step = -1, -1

    print('from checkpoint: {} last epoch:{}'.format(checkpoint, last_epoch))
    scheduler = get_scheduler(config, optimizer, last_epoch)

    dataloaders = {split: get_dataloader(config, split, get_transform(config, split))
                   for split in ['train', 'val']}

    writer = SummaryWriter(train_dir)
    try:
        train(config, model, dataloaders, criterion, optimizer, scheduler,
              writer, last_epoch+1)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
开发者ID:pudae,项目名称:kaggle-hpa,代码行数:24,代码来源:train.py

示例9: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def __init__(self, config):
        """Build the anchor-based landmark detector.

        Anchors are every ``LandmarkDetector.RATIO`` row scaled by each entry
        of ``LandmarkDetector.SCALE``, stacked into an
        ``(num_scales * num_ratios, 2)`` array. The backbone from
        ``get_model`` is adapted as follows:
        - its ``avgpool`` becomes ``AdaptiveAvgPool2d(feature_size)``;
        - its ``last_linear`` becomes a 1x1 conv emitting
          ``len(anchors) * NUM_OUTPUTS`` channels;
        - its ``logits`` applies just those two layers.
        The model is moved to CUDA and wrapped in ``DataParallel`` when more
        than one GPU is visible.
        """
        self.anchors = [np.array(LandmarkDetector.RATIO) * s for s in LandmarkDetector.SCALE]
        self.anchors = np.concatenate(self.anchors, axis=0)
        assert self.anchors.shape == (len(LandmarkDetector.SCALE) * len(LandmarkDetector.RATIO), 2)
        self.feature_size = config.model.params.feature_size
        self.num_anchors = len(LandmarkDetector.SCALE) * len(LandmarkDetector.RATIO)

        num_outputs = LandmarkDetector.NUM_OUTPUTS
        self.model = get_model(config, num_outputs=num_outputs)
        self.model.avgpool = nn.AdaptiveAvgPool2d(self.feature_size)
        in_features = self.model.last_linear.in_features
        self.model.last_linear = nn.Conv2d(in_channels=in_features,
                                           out_channels=len(self.anchors)*num_outputs,
                                           kernel_size=1)

        # Replacement `logits` bound to the backbone (its `self` is the
        # backbone, not the detector): pool, then 1x1 conv.
        def logits(self, features):
            x = self.avgpool(features)
            x = self.last_linear(x)
            return x

        self.model.logits = types.MethodType(logits, self.model)

        # Keep a handle on the unwrapped backbone: DataParallel does not
        # forward plain attributes such as `mean` or `std`, so they must be
        # read from the inner module. .cuda() moves modules in place, so this
        # handle still refers to the CUDA-resident backbone.
        base_model = self.model
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.preprocess_opt = {'mean': base_model.mean,
                               'std': base_model.std,
                               'input_range': base_model.input_range,
                               'input_space': base_model.input_space}

        self.criterion = get_loss(config)
        self.cls_criterion = F.binary_cross_entropy_with_logits
开发者ID:pudae,项目名称:kaggle-humpback,代码行数:33,代码来源:landmark_detector.py

示例10: get_model

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def get_model(self):
        """Return the network stored on ``self.model``."""
        network = self.model
        return network
开发者ID:pudae,项目名称:kaggle-humpback,代码行数:4,代码来源:landmark_detector.py

示例11: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def __init__(self, config):
        """Build an ArcFace-style identification head on a ``get_model`` backbone.

        The backbone's ``last_linear`` input width (``in_channels`` for a
        Conv2d, ``in_features`` otherwise) sizes a BatchNorm2d, a Dropout2d
        and a Linear projection. The projection goes to ``channel_size``
        (default 512) and feeds an ``ArcModule`` with margin parameters
        ``s`` (default 65) and ``m`` (default 0.5). Preprocessing settings
        are copied from the backbone when ``pretrained`` is set; otherwise
        fixed 0.5/0.5 RGB defaults are used.
        """
        super().__init__()
        params = config.model.params
        num_outputs = params.num_outputs
        feature_size = params.feature_size
        channel_size = params.channel_size if 'channel_size' in params else 512

        self.model = get_model(config)
        head = self.model.last_linear
        in_features = head.in_channels if isinstance(head, nn.Conv2d) else head.in_features
        self.bn1 = nn.BatchNorm2d(in_features)
        self.dropout = nn.Dropout2d(params.drop_rate, inplace=True)
        self.fc1 = nn.Linear(in_features * feature_size * feature_size, channel_size)
        self.bn2 = nn.BatchNorm1d(channel_size)

        s = params.s if 's' in params else 65
        m = params.m if 'm' in params else 0.5
        self.arc = ArcModule(channel_size, num_outputs, s=s, m=m)

        if not params.pretrained:
            self.mean = [0.5, 0.5, 0.5]
            self.std = [0.5, 0.5, 0.5]
            self.input_range = [0, 1]
            self.input_space = 'RGB'
        else:
            # Reuse the backbone's own normalisation settings.
            for attr_name in ('mean', 'std', 'input_range', 'input_space'):
                setattr(self, attr_name, getattr(self.model, attr_name))
开发者ID:pudae,项目名称:kaggle-humpback,代码行数:35,代码来源:identifier.py

示例12: main

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def main(args):
  """Run testing.

  Reads the "test" split and builds the model on ``args.gpuid``. Inside a TF
  session restricted to that GPU, ``utils.initialize(load=True, ...)`` runs
  (presumably restoring saved weights) and ``utils.evaluate`` is called. Each
  returned metric is printed in sorted-key order, followed by one line of
  metric names and one line of their values.
  """
  test_data = utils.read_data(args, "test")
  print("total test samples:%s" % test_data.num_examples)

  if args.random_other:
    print("warning, testing mode with 'random_other' will result in "
          "different results every run...")

  model = models.get_model(args, gpuid=args.gpuid)
  tfconfig = tf.ConfigProto(allow_soft_placement=True)
  tfconfig.gpu_options.allow_growth = True
  # Expose only the selected GPU to TensorFlow.
  tfconfig.gpu_options.visible_device_list = "%s" % args.gpuid

  with tf.Session(config=tfconfig) as sess:
    utils.initialize(load=True, load_best=args.load_best,
                     args=args, sess=sess)

    # load the graph and variables
    tester = models.Tester(model, args, sess)

    perf = utils.evaluate(test_data, args, sess, tester)

  print("performance:")
  metric_names = sorted(perf.keys())
  values = []
  for name in metric_names:
    print("%s, %s" % (name, perf[name]))
    values.append("%s" % perf[name])
  print(" ".join(metric_names))
  print(" ".join(values))
开发者ID:google,项目名称:next-prediction,代码行数:33,代码来源:test.py

示例13: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def __init__(self, model_path, gpu_id=None):
        """
        Initialise the Gluon recognition model for inference.
        :param model_path: path to the ``.params`` weights file. A pickled
            companion file (same path with '.params' replaced by '.info')
            must exist; its 'epoch' and 'config' entries are read.
        :param gpu_id: which GPU to run on (passed to ``try_gpu``).
        """
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # load .info files from trusted sources.
        with open(model_path.replace('.params', '.info'), 'rb') as info_file:
            info = pickle.load(info_file)
        print('load {} epoch params'.format(info['epoch']))
        config = info['config']
        alphabet = config['dataset']['alphabet']
        self.ctx = try_gpu(gpu_id)

        dataset_args = config['dataset']['train']['dataset']['args']
        # Only the 'ToTensor' and 'Normalize' transforms from the training
        # config are reused for inference.
        self.transform = get_transforms(
            [t for t in dataset_args['transforms'] if t['type'] in ['ToTensor', 'Normalize']])

        self.gpu_id = gpu_id
        # Input size defaults to height 32 x width 100 unless the first
        # 'Resize' pre-process in the training config overrides it.
        img_h, img_w = 32, 100
        for process in dataset_args['pre_processes']:
            if process['type'] == "Resize":
                img_h = process['args']['img_h']
                img_w = process['args']['img_w']
                break
        self.img_w = img_w
        self.img_h = img_h
        self.img_mode = dataset_args['img_mode']
        self.alphabet = alphabet
        self.net = get_model(len(alphabet), self.ctx, config['arch']['args'])
        self.net.load_parameters(model_path, self.ctx)
        self.net.hybridize()
开发者ID:WenmuZhou,项目名称:crnn.gluon,代码行数:35,代码来源:predict.py

示例14: main

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def main(config):
    """Build the loader, loss and model described by ``config`` and train.

    The loss returned by ``get_loss`` is moved to CUDA here. The model is
    passed to ``Trainer`` exactly as ``get_model`` returns it.
    """
    loader_cfg = config['data_loader']
    train_loader = get_dataloader(loader_cfg['type'], loader_cfg['args'])

    criterion = get_loss(config).cuda()
    model = get_model(config)

    Trainer(config=config, model=model, criterion=criterion,
            train_loader=train_loader).train()
开发者ID:SURFZJY,项目名称:Real-time-Text-Detection,代码行数:14,代码来源:train.py

示例15: extract_loc_feat

# 需要导入模块: import models [as 别名]
# 或者: from models import get_model [as 别名]
def extract_loc_feat(config):
    """Extract local features into each test sample's HDF5 dump file.

    Sets ``config['stage'] = 'loc'``, builds the dataset named by
    ``config['data_name']`` and the ``'loc_model'`` model from
    ``config['pretrained']['loc_model']`` (with ``config['loc_feat']`` as
    keyword arguments).

    A sample is processed when its dump file has neither ``'loc_info'``
    nor ``'kpt'``, or when ``config['loc_feat']['overwrite']`` is set.
    Processing stores a ``'loc_info'`` dataset whose columns are, from
    left to right: the OpenCV keypoint fields (x, y, size, angle,
    response), then ``npy_kpts``, ``loc_feat`` and ``kpt_mb``.

    Iteration stops when the test set raises ``dataset.end_set``. The model
    is always closed.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'loc'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('loc_model')(config['pretrained']['loc_model'], **(config['loc_feat']))
    idx = 0
    try:
        while True:
            # Only fetching the next sample signals end-of-data.
            try:
                data = next(test_set)
            except dataset.end_set:
                break
            dump_path = data['dump_path'].decode('utf-8')
            # Context manager closes each HDF5 file after use.
            with h5py.File(dump_path, 'a') as loc_f:
                already_done = 'loc_info' in loc_f or 'kpt' in loc_f
                if not already_done or config['loc_feat']['overwrite']:
                    # detect SIFT keypoints and crop image patches.
                    loc_feat, kpt_mb, npy_kpts, cv_kpts, _ = model.run_test_data(data['image'])
                    loc_info = np.concatenate((npy_kpts, loc_feat, kpt_mb), axis=-1)
                    raw_kpts = [np.array((kp.pt[0], kp.pt[1], kp.size, kp.angle, kp.response))
                                for kp in cv_kpts]
                    # np.stack rejects an empty list; use an empty (0, 5) array
                    # when no keypoints were found (assumes the other arrays
                    # then also have 0 rows - TODO confirm with run_test_data).
                    raw_kpts = np.stack(raw_kpts, axis=0) if raw_kpts else np.zeros((0, 5))
                    loc_info = np.concatenate((raw_kpts, loc_info), axis=-1)
                    # Test for 'loc_info' itself: a file holding only 'kpt'
                    # has no 'loc_info' to delete.
                    if 'loc_info' in loc_f:
                        del loc_f['loc_info']
                    loc_f.create_dataset('loc_info', data=loc_info)
            prog_bar.update(idx)
            idx += 1
    finally:
        model.close()
开发者ID:luigifreda,项目名称:pyslam,代码行数:33,代码来源:evaluations.py


注：本文中的models.get_model方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。