當前位置: 首頁>>代碼示例>>Python>>正文


Python DataLoader.DataLoader方法代碼示例

本文整理匯總了Python中DataLoader.DataLoader方法的典型用法代碼示例。如果您正苦於以下問題:Python DataLoader.DataLoader方法的具體用法?Python DataLoader.DataLoader怎麼用?Python DataLoader.DataLoader使用的例子?那麼, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊DataLoader的用法示例。


在下文中一共展示了DataLoader.DataLoader方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: import DataLoader [as 別名]
# 或者: from DataLoader import DataLoader [as 別名]
def main():
    """Parse command-line flags, then train, validate or run inference.

    Flags:
        --train           train the network on ``FilePaths.fnTrain``.
        --validate        validate the network on ``FilePaths.fnTrain``.
        --wordbeamsearch  decode with word beam search instead of best path.

    When neither ``--train`` nor ``--validate`` is given, the contents of
    ``FilePaths.fnAccuracy`` are printed and the model (built from the
    character list stored in ``FilePaths.fnCharList``) is run on
    ``FilePaths.fnInfer``.
    """
    # Optional command line args
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train", help="train the neural network", action="store_true")
    parser.add_argument(
        "--validate", help="validate the neural network", action="store_true")
    parser.add_argument(
        "--wordbeamsearch", help="use word beam search instead of best path decoding", action="store_true")
    args = parser.parse_args()

    decoderType = (DecoderType.WordBeamSearch if args.wordbeamsearch
                   else DecoderType.BestPath)

    # Train or validate on Cinnamon dataset
    if args.train or args.validate:
        # Load training data; load_aug=True is forwarded to DataLoader as-is.
        loader = DataLoader(FilePaths.fnTrain, Model.batchSize,
                            Model.imgSize, Model.maxTextLen, load_aug=True)

        # Execute training or validation (the outer condition guarantees
        # that --validate is set whenever --train is not).
        if args.train:
            model = Model(loader.charList, decoderType)
            train(model, loader)
        else:
            # NOTE(review): validation does not insist on a stored model
            # (mustRestore=False); confirm this is intended.
            model = Model(loader.charList, decoderType, mustRestore=False)
            validate(model, loader)

    # Infer text on test image
    else:
        # Context managers close the handles promptly instead of leaking them.
        with open(FilePaths.fnAccuracy) as accuracy_file:
            print(accuracy_file.read())
        with open(FilePaths.fnCharList) as char_list_file:
            charList = char_list_file.read()
        model = Model(charList, decoderType, mustRestore=False)
        infer(model, FilePaths.fnInfer)
開發者ID:sushant097,項目名稱:Handwritten-Line-Text-Recognition-using-Deep-Learning-with-Tensorflow,代碼行數:38,代碼來源:main.py

示例2: __init__

# 需要導入模塊: import DataLoader [as 別名]
# 或者: from DataLoader import DataLoader [as 別名]
def __init__(self):
    """Create this object's ``DataLoader`` and leave ``model`` unset.

    NOTE(review): ``model`` is presumably populated by a later build/train
    step — TODO confirm against the rest of the class.
    """
    # Same evaluation and assignment order as two separate statements:
    # DataLoader() is constructed first, then dataLoader and model are bound.
    self.dataLoader, self.model = DataLoader(), None
開發者ID:SamVenkatesh,項目名稱:FakeBlock,代碼行數:5,代碼來源:BuildTrainTestCNN.py

示例3: main

# 需要導入模塊: import DataLoader [as 別名]
# 或者: from DataLoader import DataLoader [as 別名]
def main():
    """Parse command-line flags, then train, validate or run inference (IAM).

    Flags:
        --train           train the NN on ``FilePaths.fnTrain``.
        --validate        validate the NN on ``FilePaths.fnTrain``.
        --beamsearch      decode with beam search instead of best path.
        --wordbeamsearch  decode with word beam search instead of best path
                          (ignored if --beamsearch is also given).
        --dump            forwarded to ``Model`` as ``dump`` in inference mode.

    In train/validate mode the loader's character list is written to
    ``FilePaths.fnCharList`` and its train+validation words to
    ``FilePaths.fnCorpus`` before the model is built. Otherwise the contents
    of ``FilePaths.fnAccuracy`` are printed and inference is run on
    ``FilePaths.fnInfer``.
    """
    # optional command line args
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', help='train the NN', action='store_true')
    parser.add_argument('--validate', help='validate the NN', action='store_true')
    parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true')
    parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true')
    parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')

    args = parser.parse_args()

    decoderType = DecoderType.BestPath
    if args.beamsearch:
        decoderType = DecoderType.BeamSearch
    elif args.wordbeamsearch:
        decoderType = DecoderType.WordBeamSearch

    # train or validate on IAM dataset
    if args.train or args.validate:
        # load training data, create TF model
        loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)

        # save characters of model for inference mode; the context manager
        # guarantees the file is flushed and closed before it is read back
        with open(FilePaths.fnCharList, 'w') as char_list_file:
            char_list_file.write(''.join(loader.charList))

        # save words contained in dataset into file
        with open(FilePaths.fnCorpus, 'w') as corpus_file:
            corpus_file.write(' '.join(loader.trainWords + loader.validationWords))

        # execute training or validation (--validate must be set if --train is not)
        if args.train:
            model = Model(loader.charList, decoderType)
            train(model, loader)
        else:
            model = Model(loader.charList, decoderType, mustRestore=True)
            validate(model, loader)

    # infer text on test image
    else:
        with open(FilePaths.fnAccuracy) as accuracy_file:
            print(accuracy_file.read())
        with open(FilePaths.fnCharList) as char_list_file:
            charList = char_list_file.read()
        model = Model(charList, decoderType, mustRestore=True, dump=args.dump)
        infer(model, FilePaths.fnInfer)
開發者ID:githubharald,項目名稱:SimpleHTR,代碼行數:44,代碼來源:main.py

示例4: train

# 需要導入模塊: import DataLoader [as 別名]
# 或者: from DataLoader import DataLoader [as 別名]
def train(sess, preprocessed_data, model):
    """Train ``model`` on ``preprocessed_data.train_set``.

    All intervals come from the module-level ``FLAGS``:

    * every flag name/value is first written to ``log_file``;
    * for ``FLAGS.epoch`` epochs, ``model(x, sess, 0)`` is called on every
      batch produced by a shuffled ``DataLoader``;
    * every ``FLAGS.batch_update`` batches, ``model(x, sess, 1)`` is called
      and the three losses it returns are accumulated;
    * every ``FLAGS.report_loss`` accumulated steps, the mean losses are
      logged and those accumulators are reset;
    * every ``FLAGS.report`` accumulated steps, the elapsed time is logged,
      the model is saved under ``saved_model_path/loads/<round>`` and
      ``evaluate`` is run on the validation split, writing into
      ``results_path/loads/<round>``; its result string is logged.

    NOTE(review): mode 0 presumably performs an optimisation step and mode 1
    returns (loss, copy-gate loss, coverage loss) — TODO confirm against the
    model class.
    """
    # keep track of all input parameters (build the flag dict only once)
    write_log(log_file, "####################INPUT PARAMETERS###################")
    for attr, value in FLAGS.flag_values_dict().items():
        write_log(log_file, attr + " = " + str(value))
    write_log(log_file, "#######################################################")

    train_iterator = DataLoader(preprocessed_data.train_set, FLAGS.domain,
                                batch_size=FLAGS.batch_size, shuffle=True, eos=eos, empty=empty)

    k = 0              # batches seen
    record_k = 0       # loss-recording steps seen (never reset)
    record_loss_k = 0  # loss-recording steps since the last loss report
    start_time = time.time()
    record_loss = 0.0
    record_copy_loss = 0.0
    record_cov_loss = 0.0

    for _ in range(FLAGS.epoch):
        train_iterator.reset()
        for x in train_iterator:
            model(x, sess, 0)
            k += 1

            # TODO also add to tensorboard
            if k % FLAGS.batch_update != 0:
                continue

            this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess, 1)
            record_loss += this_loss
            record_copy_loss += this_copy_gate_loss
            record_cov_loss += this_cov_loss
            record_k += 1
            record_loss_k += 1

            # Both counters were just incremented, so they are >= 1 here.
            # The former "> 1" guards only suppressed the first report when
            # the interval was 1, so they are dropped.
            if record_loss_k % FLAGS.report_loss == 0:
                write_log(log_file, "%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
                    (record_k, record_loss / record_loss_k, record_copy_loss / record_loss_k,
                     record_cov_loss / record_loss_k))
                record_loss = 0.0
                record_copy_loss = 0.0
                record_cov_loss = 0.0
                record_loss_k = 0

            if record_k % FLAGS.report == 0:
                # record_k is a positive multiple of FLAGS.report, so the
                # round index is always >= 1.
                round_idx = record_k // FLAGS.report
                print("Round: ", round_idx)
                cost_time = time.time() - start_time
                write_log(log_file, "%d : time = %.3f " % (round_idx, cost_time))
                start_time = time.time()

                # save model
                saved_model_path_cnt = os.path.join(saved_model_path, 'loads', str(round_idx))
                os.makedirs(saved_model_path_cnt, exist_ok=True)
                model.save(saved_model_path_cnt, sess)

                results_path_cnt = os.path.join(results_path, 'loads', str(round_idx))
                os.makedirs(results_path_cnt, exist_ok=True)
                validation_result = evaluate(sess, preprocessed_data, model, results_path_cnt, 'valid')
                write_log(log_file, validation_result)
開發者ID:czyssrs,項目名稱:Few-Shot-NLG,代碼行數:60,代碼來源:Main.py

示例5: evaluate

# 需要導入模塊: import DataLoader [as 別名]
# 或者: from DataLoader import DataLoader [as 別名]
def evaluate(sess, preprocessed_data, model, ksave_dir, mode='valid'):
    """Generate summaries for one data split and score them with BLEU.

    Args:
        sess: session forwarded to ``model.generate``.
        preprocessed_data: object exposing ``dev_set`` and ``test_set``.
        model: object whose ``generate(x, sess)`` returns
            ``(predictions, atts)``; only ``predictions`` is used.
        ksave_dir: directory under which a ``<mode>`` sub-directory is
            created for all output files.
        mode: ``'valid'`` evaluates ``dev_set`` against ``gold_path_valid``;
            any other value evaluates ``test_set`` against ``gold_path_test``.

    Writes, inside ``ksave_dir/<mode>``:
        ``<mode>_summary_bpe.txt``    one space-joined token summary per line,
        ``<mode>_summary.clean.txt``  one decoded summary per line,
        ``<mode>_pred_summary_<k>``   one file per decoded summary.

    Returns:
        The string ``"with copy BLEU: <score>"`` (score to 4 decimals)
        followed by a newline.
    """
    if mode == 'valid':
        gold_path, dataset = gold_path_valid, preprocessed_data.dev_set
    else:
        gold_path, dataset = gold_path_test, preprocessed_data.test_set
    data_iterator = DataLoader(dataset, FLAGS.domain, batch_size=FLAGS.batch_size,
                               shuffle=False, eos=eos, empty=empty)

    ksave_dir_mode = os.path.join(ksave_dir, mode)
    os.makedirs(ksave_dir_mode, exist_ok=True)

    bpe_path = os.path.join(ksave_dir_mode, mode + "_summary_bpe.txt")
    clean_path = os.path.join(ksave_dir_mode, mode + "_summary.clean.txt")
    pred_path = os.path.join(ksave_dir_mode, mode + "_pred_summary_")

    k = 0
    # Context managers guarantee both aggregate files are flushed and closed
    # even if generation fails part-way through.
    with open(bpe_path, "w") as out_bpe, open(clean_path, "w") as out_real:
        for x in tqdm(data_iterator):
            predictions, _ = model.generate(x, sess)

            for summary in np.array(predictions):
                summary = list(summary)

                # Truncate at the first eos; a summary that starts with eos
                # is kept as the single token [eos] rather than emptied.
                if eos in summary:
                    summary = summary[:summary.index(eos)] if summary[0] != eos else [eos]
                real_sum = enc.decode(summary).replace("\n", " ")
                bpe_sum = " ".join([enc.decoder[tmp] for tmp in summary])

                with open(pred_path + str(k), 'w') as sw:
                    sw.write(real_sum + '\n')
                out_real.write(real_sum + '\n')
                out_bpe.write(bpe_sum + '\n')

                k += 1

    # new bleu (computed on the clean file, now guaranteed to be flushed)
    bleu_copy = bleu_score(gold_path, clean_path)
    return "with copy BLEU: %.4f\n" % bleu_copy
開發者ID:czyssrs,項目名稱:Few-Shot-NLG,代碼行數:57,代碼來源:Main.py


注:本文中的DataLoader.DataLoader方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。