本文整理汇总了Python中data.get_data方法的典型用法代码示例。如果您正苦于以下问题：Python data.get_data方法的具体用法是什么？Python data.get_data怎么用？有哪些Python data.get_data的使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块data的用法示例。
在下文中一共展示了data.get_data方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load_data
# 需要导入模块: import data [as 别名]
# 或者: from data import get_data [as 别名]
def load_data(data_dir, data_file, max_num_images=-1):
    """Load a data source with ``get_data`` and print its batch statistics.

    Args:
        data_dir: Directory containing ``data_file``; it is also passed to
            ``get_data`` as its second argument.
        data_file: File name joined onto ``data_dir`` to form the data path.
        max_num_images: When ``>= 0``, forwarded to ``get_data`` as a limit.
            The default ``-1`` means "no limit": ``get_data`` is then called
            with only two arguments.

    Returns:
        The object returned by ``get_data``. Its ``images`` attribute must
        support ``len()``.

    NOTE(review): ``batch_size`` is not a parameter. It is read from module
    scope, presumably defined elsewhere in the original file. Confirm.
    """
    import os  # no module-level ``import os`` is visible for this snippet

    data_path = os.path.join(data_dir, data_file)
    if max_num_images >= 0:
        data_source = get_data(data_path, data_dir, max_num_images)
    else:
        data_source = get_data(data_path, data_dir)
    num_images = len(data_source.images)
    # Ceiling division. ``//`` keeps the count an int on Python 3, whereas the
    # original ``/`` would produce a float there.
    num_batches = (num_images + batch_size - 1) // batch_size
    print('num_images', num_images, 'batch_size', batch_size,
          'num_batches', num_batches)
    return data_source
示例2: predict
# 需要导入模块: import data [as 别名]
# 或者: from data import get_data [as 别名]
def _load_eval_pairs(options):
    """Unpickle the evaluation list chosen by ``options["model_selection"]``.

    Uses ``cfg.cc.VALIDATE_DATA_PATH + "pj1000.pkl"`` when model selection is
    on, otherwise ``cfg.cc.TESTING_DATA_PATH + "test.pkl"``.
    """
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point these paths at trusted, locally produced data.
    if options["model_selection"]:
        path = cfg.cc.VALIDATE_DATA_PATH + "pj1000.pkl"
    else:
        path = cfg.cc.TESTING_DATA_PATH + "test.pkl"
    with open(path, "rb") as f:  # close the handle instead of leaking it
        return pickle.load(f)


def _beam_input(batch, word_emb, padding_mask, idx_s, options):
    """Build the per-sample input tuple handed to ``beam_decode``.

    Sample ``idx_s`` is selected on axis 1 of the batch arrays, the axis
    ``predict`` asserts equals the batch size. With ``options["copy"]`` the
    tuple uses ``batch.x_ext`` and ``batch.x_mask`` and appends
    ``batch.max_ext_len`` and ``batch.x_ext_words[idx_s]``. Otherwise it uses
    ``batch.x``. Element order matches what the original code passed.
    """
    device = options["device"]
    if options["copy"]:
        return (torch.LongTensor(batch.x_ext[:, idx_s]).to(device),
                torch.FloatTensor(batch.x_mask[:, idx_s, :]).to(device),
                word_emb[:, idx_s, :], padding_mask[:, idx_s],
                batch.y[:, idx_s], [batch.len_y[idx_s]],
                batch.original_summarys[idx_s],
                batch.max_ext_len, batch.x_ext_words[idx_s])
    return (torch.LongTensor(batch.x[:, idx_s]).to(device),
            word_emb[:, idx_s, :], padding_mask[:, idx_s],
            batch.y[:, idx_s], [batch.len_y[idx_s]],
            batch.original_summarys[idx_s])


def predict(model, modules, consts, options):
    """Run decoding over the test (or validation) set with a trained model.

    Steps:
    - Put ``model`` in eval mode.
    - Set ``options["has_y"]`` from ``TESTING_DATASET_CLS.HAS_Y``, so
      ``options`` is mutated.
    - Recreate the four ``cfg.cc`` output directories with ``rebuild_dir``.
    - Load the pickled examples.
    - Encode them batch by batch with ``model.encode``.

    With ``options["beam_decoding"]`` every sample is passed to
    ``beam_decode`` along with a running sample index across all batches.
    Greedy decoding is not implemented: that branch only counts samples.

    Args:
        model: Must provide ``eval()`` and ``encode(tensor)``, the latter
            returning ``(word_emb, padding_mask)``.
        modules: Forwarded to ``datar.get_data`` and ``beam_decode``.
        consts: Forwarded likewise. ``consts["testing_print_size"]`` sets how
            often progress is printed.
        options: Forwarded likewise. Keys read here are ``"beam_decoding"``,
            ``"model_selection"``, ``"device"`` and ``"copy"``.
    """
    print("start predicting,")
    model.eval()
    options["has_y"] = TESTING_DATASET_CLS.HAS_Y
    if options["beam_decoding"]:
        print("using beam search")
    else:
        print("using greedy search")
        # The greedy branch below is a stub; say so rather than silently
        # reporting summaries as generated.
        print("warning: greedy decoding is not implemented, "
              "no summaries will be written")
    for out_dir in (cfg.cc.BEAM_SUMM_PATH, cfg.cc.BEAM_GT_PATH,
                    cfg.cc.GROUND_TRUTH_PATH, cfg.cc.SUMM_PATH):
        rebuild_dir(out_dir)

    print("loading test set...")
    xy_list = _load_eval_pairs(options)
    batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
    print("num_files = ", num_files, ", num_batches = ", num_batches)

    partial_num = 0  # samples since the last progress message
    total_num = 0
    si = 0  # running sample index across all batches, passed to beam_decode
    # Inference only: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for idx_batch in range(num_batches):
            test_idx = batch_list[idx_batch]
            batch_raw = [xy_list[xy_idx] for xy_idx in test_idx]
            batch = datar.get_data(batch_raw, modules, consts, options)
            assert len(test_idx) == batch.x.shape[1]  # local_batch_size
            word_emb, padding_mask = model.encode(
                torch.LongTensor(batch.x).to(options["device"]))
            if options["beam_decoding"]:
                for idx_s in range(len(test_idx)):
                    inputx = _beam_input(batch, word_emb, padding_mask, idx_s, options)
                    beam_decode(si, inputx, model, modules, consts, options)
                    si += 1
            # else: greedy decoding is not implemented (see warning above).

            testing_batch_size = len(test_idx)
            partial_num += testing_batch_size
            total_num += testing_batch_size
            if partial_num >= consts["testing_print_size"]:
                print(total_num, "summs are generated")
                partial_num = 0
    print(si, total_num)