当前位置: 首页>>代码示例>>Python>>正文


Python data.get_test_loader方法代码示例

本文整理汇总了Python中data.get_test_loader方法的典型用法代码示例。如果您正苦于以下问题：Python data.get_test_loader方法的具体用法是什么？Python data.get_test_loader怎么用？有哪些Python data.get_test_loader的使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块data的用法示例。


在下文中一共展示了data.get_test_loader方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: eval_with_extended

# 需要导入模块: import data [as 别名]
# 或者: from data import get_test_loader [as 别名]
def eval_with_extended(model_path, data_path=None, data_name=None, split='test'):
    """Evaluate image-to-text retrieval with external captions enabled.

    Loads the checkpoint at ``model_path``, rebuilds a ``VSE`` model from the
    options stored in it, encodes ``split`` and prints the recall figures
    returned by ``i2t_text_only``.  The ranks are saved with ``torch.save``
    to ``<prefix>ranks_extended.pth.tar``, where ``<prefix>`` is the part of
    ``model_path`` preceding ``'model_best'``.

    Args:
        model_path: Checkpoint path; the checkpoint must provide the keys
            ``'opt'`` and ``'model'``.
        data_path: Optional override for ``opt.data_path``.
        data_name: Optional override for ``opt.data_name``.
        split: Split name forwarded to ``get_test_loader``.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = True
    opt.negative_number = 5
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # Load the vocabulary the model was trained with.
    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # Rebuild the model and restore its weights.
    model = VSE(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    # NOTE(review): the division by 5 presumably reflects five captions per
    # image in the encoded split -- confirm against the data loader.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] // 5, cap_embs.shape[0]))

    r, rt = i2t_text_only(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)

    # str.find returns -1 when the marker is missing, which would silently
    # chop the last character off the path; fall back to the checkpoint's
    # directory in that case.
    marker = model_path.find('model_best')
    if marker >= 0:
        prefix = model_path[:marker]
    else:
        prefix = os.path.join(os.path.dirname(model_path), '')
    torch.save({'rt': rt}, prefix + 'ranks_extended.pth.tar')
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:38,代码来源:evaluation.py

示例2: eval_with_single_extended

# 需要导入模块: import data [as 别名]
# 或者: from data import get_test_loader [as 别名]
def eval_with_single_extended(model_path, data_path=None, data_name=None, split='test', backup_vec_ex=None):
    """Evaluate image-to-text retrieval against per-image extra captions.

    Encodes the regular ``split`` and, for every row of the image
    embeddings, an additional caption set loaded through ``get_text_loader``
    with the tag ``'ex/<row index>'``.  The result of ``i2t_split`` is
    printed and the ranks are saved to
    ``<prefix>ranks_single_extended.pth.tar``, where ``<prefix>`` is the part
    of ``model_path`` preceding ``'model_best'``.

    Args:
        model_path: Checkpoint path; the checkpoint must provide the keys
            ``'opt'`` and ``'model'``.
        data_path: Optional override for ``opt.data_path``.
        data_name: Optional override for ``opt.data_name``.
        split: Split name forwarded to the loaders.
        backup_vec_ex: Optional path of previously saved extra caption
            embeddings; when given they are loaded with ``torch.load``
            instead of being re-encoded.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # Load the vocabulary the model was trained with.
    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # Rebuild the model and restore its weights.
    model = VSE(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    img_embs, cap_embs = encode_data(model, data_loader)

    if backup_vec_ex is None:
        cap_embs_ex = []
        for i in range(img_embs.shape[0]):
            data_loader_ex = get_text_loader(
                split, opt.data_name, vocab, opt.batch_size, opt.workers, opt, 'ex/%d' % i)
            encoding = encode_data(model, data_loader_ex)[1]
            if encoding is not None:
                cap_embs_ex.append(encoding.copy())
            else:
                # No extra captions were encoded for this row: use a single
                # all-zero row shaped like one regular caption embedding.
                cap_embs_ex.append(np.zeros(cap_embs[:1].shape))
            print('Caption Embedding: %d' % i)
    else:
        cap_embs_ex = torch.load(backup_vec_ex)
    print('Computing results...')

    r, rt = i2t_split(img_embs, cap_embs, cap_embs_ex, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)

    # str.find returns -1 when the marker is missing, which would silently
    # chop the last character off the path; fall back to the checkpoint's
    # directory in that case.
    marker = model_path.find('model_best')
    if marker >= 0:
        prefix = model_path[:marker]
    else:
        prefix = os.path.join(os.path.dirname(model_path), '')
    torch.save({'rt': rt}, prefix + 'ranks_single_extended.pth.tar')
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:48,代码来源:evaluation.py

示例3: eval_with_manually_extended

# 需要导入模块: import data [as 别名]
# 或者: from data import get_test_loader [as 别名]
def eval_with_manually_extended(model_path, data_path=None, split='test'):
    """Evaluate image-to-text retrieval against manually written captions.

    Only the first 100 rows of the encoded image and caption embeddings are
    used.  For each of them, two rows from the ``'manually_ex_0'`` caption
    set and two rows from the ``'manually_ex_1'`` caption set are stacked
    into one extra-caption array.  The result of ``i2t_split`` is printed
    and the ranks are saved to
    ``<prefix>ranks_manually_extended_1.pth.tar``, where ``<prefix>`` is the
    part of ``model_path`` preceding ``'model_best'``.

    Args:
        model_path: Checkpoint path; the checkpoint must provide the keys
            ``'opt'`` and ``'model'``.
        data_path: Optional override for ``opt.data_path``.
        split: Split name forwarded to the loaders.
    """
    n_items = 100  # number of leading rows that are evaluated

    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path

    # Load the vocabulary the model was trained with.
    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # Rebuild the model and restore its weights.
    model = VSE(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    img_embs, cap_embs = encode_data(model, data_loader)
    img_embs = img_embs[:n_items]
    cap_embs = cap_embs[:n_items]

    # Encode both manually created caption sets (set 0 first, then set 1).
    encoding_0, encoding_1 = [
        encode_data(
            model,
            get_text_loader(split, opt.data_name, vocab, opt.batch_size,
                            opt.workers, opt, 'manually_ex_%d' % set_id),
        )[1]
        for set_id in (0, 1)
    ]

    # Item i owns rows [2*i, 2*i + 2) of each set.
    cap_embs_ex = [
        np.concatenate(
            (encoding_0[2 * i:2 * i + 2], encoding_1[2 * i:2 * i + 2]), axis=0)
        for i in range(n_items)
    ]
    print('Computing results...')

    r, rt = i2t_split(img_embs, cap_embs, cap_embs_ex, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)

    # str.find returns -1 when the marker is missing, which would silently
    # chop the last character off the path; fall back to the checkpoint's
    # directory in that case.
    marker = model_path.find('model_best')
    if marker >= 0:
        prefix = model_path[:marker]
    else:
        prefix = os.path.join(os.path.dirname(model_path), '')
    torch.save({'rt': rt}, prefix + 'ranks_manually_extended_1.pth.tar')
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:45,代码来源:evaluation.py

示例4: debug_show_similarity_with_manually_created_examples

# 需要导入模块: import data [as 别名]
# 或者: from data import get_test_loader [as 别名]
def debug_show_similarity_with_manually_created_examples(model_path, data_path=None, split='test'):
    """Dump dot-product similarities between images and three caption sets.

    The first 100 rows of the encoded image and caption embeddings are kept.
    Every fifth image row is used as a query and its dot products with the
    original captions and with the ``'manually_ex_0'`` / ``'manually_ex_1'``
    caption sets are collected.  The three lists are saved with
    ``torch.save`` to ``'shy_runs/debug.pt'`` under the keys ``'orig'``,
    ``'Tete'`` and ``'Haoyue'``.

    Args:
        model_path: Checkpoint path; the checkpoint must provide the keys
            ``'opt'`` and ``'model'``.
        data_path: Optional override for ``opt.data_path``.
        split: Split name forwarded to the loaders.

    Raises:
        Exception: If ``opt.measure`` is ``'order'``, which is not supported.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    # Fail fast: the measure does not change during the run, so reject the
    # unsupported one before the expensive model and data loading.
    if opt.measure == 'order':
        raise Exception('Measure order not supported.')
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path

    # Load the vocabulary the model was trained with.
    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # Rebuild the model and restore its weights.
    model = VSE(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    img_embs, cap_embs = encode_data(model, data_loader)
    img_embs = img_embs[:100]
    cap_embs = cap_embs[:100]

    # Encode both manually created caption sets (set 0 first, then set 1).
    encoding_0, encoding_1 = [
        encode_data(
            model,
            get_text_loader(split, opt.data_name, vocab, opt.batch_size,
                            opt.workers, opt, 'manually_ex_%d' % set_id),
        )[1]
        for set_id in (0, 1)
    ]
    print('Computing results...')

    result = []
    result_0 = []
    result_1 = []

    # NOTE(review): stepping by 5 presumably reflects five consecutive rows
    # per image -- confirm against the data loader.
    npts = img_embs.shape[0] // 5
    for index in range(npts):
        # Query image as a (1, dim) row so the dot products are row vectors.
        im = img_embs[5 * index].reshape(1, img_embs.shape[1])
        result.append(numpy.dot(im, cap_embs.T).flatten())
        result_0.append(numpy.dot(im, encoding_0.T).flatten())
        result_1.append(numpy.dot(im, encoding_1.T).flatten())
    torch.save({'orig': result, 'Tete': result_0, 'Haoyue': result_1}, 'shy_runs/debug.pt')
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:53,代码来源:evaluation.py

示例5: evalrank

# 需要导入模块: import data [as 别名]
# 或者: from data import get_test_loader [as 别名]
def evalrank(model, args, split='test'):
  """Evaluate image/text retrieval and print the recall metrics.

  Encodes the test set, then reports ``i2t`` / ``t2i`` results on the full
  set.  When ``args.data_name == 'coco'`` it additionally evaluates five
  consecutive chunks of 5000 rows and prints their averaged metrics.

  ``vocab`` is not a parameter: it is read from the enclosing (module)
  scope and must be defined before this function is called.

  Args:
    model: Model handed to ``encode_data``.
    args: Options object; ``eval_on_gpu``, ``data_name`` and ``order`` are
      read from it.
    split: Currently unused by this function.

  Returns:
    For ``'coco'``, a tuple with the per-fold metrics averaged over the five
    folds: 7 image-to-text values, 7 text-to-image values, then ``ar``,
    ``ari`` and ``rsum``.  ``None`` for any other dataset.
  """
  print('Loading dataset')
  data_loader = get_test_loader(args, vocab)

  print('Computing results... (eval_on_gpu={})'.format(args.eval_on_gpu))
  img_embs, txt_embs = encode_data(model, data_loader, args.eval_on_gpu)
  n_samples = img_embs.shape[0]

  nreps = 5 if args.data_name == 'coco' else 1
  print('Images: %d, Sentences: %d' % (img_embs.shape[0] // nreps, txt_embs.shape[0]))

  # 5-fold cross-validation, only for MSCOCO
  mean_metrics = None
  if args.data_name == 'coco':
    results = []
    for i in range(5):
      fold = slice(i * 5000, (i + 1) * 5000)
      r, _ = i2t(img_embs[fold], txt_embs[fold],
                 nreps=nreps, return_ranks=True, order=args.order, use_gpu=args.eval_on_gpu)
      # Append each rank statistic divided by the total sample count.
      r = (r[0], r[1], r[2], r[3], r[3] / n_samples, r[4], r[4] / n_samples)
      print("Image to text: %.2f, %.2f, %.2f, %.2f (%.2f), %.2f (%.2f)" % r)

      ri, _ = t2i(img_embs[fold], txt_embs[fold],
                  nreps=nreps, return_ranks=True, order=args.order, use_gpu=args.eval_on_gpu)
      ri = (ri[0], ri[1], ri[2], ri[3], ri[3] / n_samples, ri[4], ri[4] / n_samples)
      print("Text to image: %.2f, %.2f, %.2f, %.2f (%.2f), %.2f (%.2f)" % ri)

      ar = (r[0] + r[1] + r[2]) / 3
      ari = (ri[0] + ri[1] + ri[2]) / 3
      rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
      print("rsum: %.2f ar: %.2f ari: %.2f" % (rsum, ar, ari))
      results.append(list(r) + list(ri) + [ar, ari, rsum])

    mean_metrics = tuple(np.array(results).mean(axis=0).flatten())

    print("-----------------------------------")
    print("Mean metrics from 5-fold evaluation: ")
    # The last column already holds each fold's rsum, so its mean is
    # reported as is (it used to be multiplied by 6, inflating the value).
    print("rsum: %.2f" % mean_metrics[-1])
    print("Average i2t Recall: %.2f" % mean_metrics[-3])
    print("Image to text: %.2f %.2f %.2f %.2f (%.2f) %.2f (%.2f)" % mean_metrics[:7])
    print("Average t2i Recall: %.2f" % mean_metrics[-2])
    print("Text to image: %.2f %.2f %.2f %.2f (%.2f) %.2f (%.2f)" % mean_metrics[7:14])

  # no cross-validation, full evaluation
  # NOTE(review): unlike the per-fold calls, `order=args.order` is not passed
  # here -- confirm whether that is intentional.
  r, _ = i2t(img_embs, txt_embs, nreps=nreps, return_ranks=True, use_gpu=args.eval_on_gpu)
  ri, _ = t2i(img_embs, txt_embs, nreps=nreps, return_ranks=True, use_gpu=args.eval_on_gpu)
  ar = (r[0] + r[1] + r[2]) / 3
  ari = (ri[0] + ri[1] + ri[2]) / 3
  rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
  r = (r[0], r[1], r[2], r[3], r[3] / n_samples, r[4], r[4] / n_samples)
  ri = (ri[0], ri[1], ri[2], ri[3], ri[3] / n_samples, ri[4], ri[4] / n_samples)
  print("rsum: %.2f" % rsum)
  print("Average i2t Recall: %.2f" % ar)
  print("Image to text: %.2f %.2f %.2f %.2f (%.2f) %.2f (%.2f)" % r)
  print("Average t2i Recall: %.2f" % ari)
  print("Text to image: %.2f %.2f %.2f %.2f (%.2f) %.2f (%.2f)" % ri)

  return mean_metrics
开发者ID:yalesong,项目名称:pvse,代码行数:61,代码来源:eval.py

示例6: test_CAMP_model

# 需要导入模块: import data [as 别名]
# 或者: from data import get_test_loader [as 别名]
def test_CAMP_model(config_path):
    """Smoke-test a CAMP model end to end on the test split.

    Reads the ``'common'`` section of the YAML file at ``config_path``,
    builds the model (restoring weights from ``opt.resume`` when it is set),
    encodes the ``"test"`` split and logs the retrieval metrics returned by
    ``i2t``.

    Args:
        config_path: Path of the YAML configuration file.
    """
    print("OK!")
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    with open(config_path) as f:
        # safe_load: yaml.load without an explicit Loader is rejected by
        # PyYAML >= 6 and can build arbitrary objects on older versions.
        opt = yaml.safe_load(f)
    opt = EasyDict(opt['common'])

    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    train_logger = LogCollector()

    print("----Start init model----")
    CAMP = model.CAMP(opt)
    CAMP.logger = train_logger

    if opt.resume is not None:
        ckp = torch.load(opt.resume)
        CAMP.load_state_dict(ckp["model"])

    CAMP.train_start()
    print("----Model init success----")

    """
    fake_img = torch.randn(16, 36, opt.img_dim)
    fake_text = torch.ones(16, 32).long()
    fake_lengths = torch.Tensor([32] * 16)
    fake_pos = torch.ones(16, 32).long()
    fake_ids = torch.ones(16).long()

    CAMP.train_emb(fake_img, fake_text, fake_lengths,
                   instance_ids=fake_ids)
    print("----Test train_emb success----")
    """

    # NOTE(review): the returned loaders are not used below; the call is kept
    # in case building them is itself part of what this smoke test checks.
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.crop_size, 128, 4, opt)

    test_loader = data.get_test_loader("test", opt.data_name, vocab, opt.crop_size, 128, 4, opt)

    CAMP.val_start()
    img_embs, cap_embs, cap_masks = encode_data(
        CAMP, test_loader, opt.log_step, logging.info)

    i2t_metrics, t2i_metrics, _score_matrix = i2t(
        img_embs, cap_embs, cap_masks, measure=opt.measure,
        model=CAMP, return_ranks=True)
    (r1, r5, r10, medr, meanr) = i2t_metrics
    (r1i, r5i, r10i, medri, meanri) = t2i_metrics
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f",
                 r1, r5, r10, medr, meanr)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f",
                 r1i, r5i, r10i, medri, meanri)
开发者ID:ZihaoWang-CV,项目名称:CAMP_iccv19,代码行数:57,代码来源:test_modules.py


注:本文中的data.get_test_loader方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。