当前位置: 首页>>代码示例>>Python>>正文


Python data.get_data方法代码示例

本文整理汇总了Python中data.get_data方法的典型用法代码示例。如果您正苦于以下问题:Python data.get_data方法的具体用法?Python data.get_data怎么用?Python data.get_data使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在data的用法示例。


在下文中一共展示了data.get_data方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_data

# 需要导入模块: import data [as 别名]
# 或者: from data import get_data [as 别名]
def load_data(data_dir, data_file, max_num_images=-1):
  data_path = os.path.join(data_dir, data_file)
  if max_num_images >= 0:
    data_source = get_data(data_path, data_dir, max_num_images)
  else:
    data_source = get_data(data_path, data_dir)
  num_images = len(data_source.images)
  num_batches = (num_images + batch_size - 1) / batch_size
  print 'num_images', num_images, 'batch_size', batch_size, 'num_batches', num_batches
  return data_source 
开发者ID:futurely,项目名称:deep-camera-relocalization,代码行数:12,代码来源:run_posenet.py

示例2: predict

# 需要导入模块: import data [as 别名]
# 或者: from data import get_data [as 别名]
def predict(model, modules, consts, options):
    print("start predicting,")
    model.eval()
    options["has_y"] = TESTING_DATASET_CLS.HAS_Y
    if options["beam_decoding"]:
        print("using beam search")
    else:
        print("using greedy search")
    rebuild_dir(cfg.cc.BEAM_SUMM_PATH)
    rebuild_dir(cfg.cc.BEAM_GT_PATH)
    rebuild_dir(cfg.cc.GROUND_TRUTH_PATH)
    rebuild_dir(cfg.cc.SUMM_PATH)

    print("loading test set...")
    if options["model_selection"]:
        xy_list = pickle.load(open(cfg.cc.VALIDATE_DATA_PATH + "pj1000.pkl", "rb")) 
    else:
        xy_list = pickle.load(open(cfg.cc.TESTING_DATA_PATH + "test.pkl", "rb")) 
    batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)

    print("num_files = ", num_files, ", num_batches = ", num_batches)
    
    running_start = time.time()
    partial_num = 0
    total_num = 0
    si = 0
    for idx_batch in range(num_batches):
        test_idx = batch_list[idx_batch]
        batch_raw = [xy_list[xy_idx] for xy_idx in test_idx]
        batch = datar.get_data(batch_raw, modules, consts, options)
        
        assert len(test_idx) == batch.x.shape[1] # local_batch_size

                    
        word_emb, padding_mask = model.encode(torch.LongTensor(batch.x).to(options["device"]))

        if options["beam_decoding"]:
            for idx_s in range(len(test_idx)):
                if options["copy"]:
                    inputx = (torch.LongTensor(batch.x_ext[:, idx_s]).to(options["device"]), \
                            torch.FloatTensor(batch.x_mask[:, idx_s, :]).to(options["device"]), \
                          word_emb[:, idx_s, :], padding_mask[:, idx_s],\
                          batch.y[:, idx_s], [batch.len_y[idx_s]], batch.original_summarys[idx_s],\
                          batch.max_ext_len, batch.x_ext_words[idx_s])
                else:
                    inputx = (torch.LongTensor(batch.x[:, idx_s]).to(options["device"]), word_emb[:, idx_s, :], padding_mask[:, idx_s],\
                              batch.y[:, idx_s], [batch.len_y[idx_s]], batch.original_summarys[idx_s])

                beam_decode(si, inputx, model, modules, consts, options)
                si += 1
        else:
            pass
            #greedy_decode()

        testing_batch_size = len(test_idx)
        partial_num += testing_batch_size
        total_num += testing_batch_size
        if partial_num >= consts["testing_print_size"]:
            print(total_num, "summs are generated")
            partial_num = 0
    print (si, total_num) 
开发者ID:lipiji,项目名称:TranSummar,代码行数:63,代码来源:main.py


注:本文中的data.get_data方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。