当前位置: 首页>>代码示例>>Python>>正文


Python dataloader.DataLoader方法代码示例

本文整理汇总了Python中dataloader.DataLoader方法的典型用法代码示例。如果您正苦于以下问题：Python dataloader.DataLoader方法的具体用法？Python dataloader.DataLoader怎么用？Python dataloader.DataLoader使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块dataloader的用法示例。


在下文中一共展示了dataloader.DataLoader方法的7个代码示例，这些例子默认根据受欢迎程度排序。注意：这些示例来自不同的开源项目，各项目中 DataLoader 的构造参数和用法并不相同，请以对应项目的源码为准。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def train(args):
    """Build the QANet data pipeline and train a model.

    Loads the pickled vocabulary from ``<args.vocab_dir>/vocab.data``, builds a
    ``DataLoader`` over ``args.train_files`` and ``args.dev_files``, converts the
    text to ids with that vocabulary, then constructs a ``Model`` and calls its
    ``train`` method with the epoch / batch-size / dropout / save settings
    taken from ``args``.

    Args:
        args: parsed CLI namespace. Attributes read directly here: vocab_dir,
            max_p_num, max_p_len, max_q_len, max_ch_len, train_files,
            dev_files, epochs, batch_size, model_dir, algo, dropout
            (``Model`` may read more).
    """
    log = logging.getLogger("QANet")
    log.info("====== training ======")

    log.info('Load data_set and vocab...')
    vocab_path = os.path.join(args.vocab_dir, 'vocab.data')
    # NOTE(review): pickle.load can execute arbitrary code; the vocab file must
    # come from a trusted source (presumably the one written by prepro()).
    with open(vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # Named data_loader so the ``dataloader`` module is not shadowed.
    data_loader = DataLoader(
        args.max_p_num,
        args.max_p_len,
        args.max_q_len,
        args.max_ch_len,
        args.train_files,
        args.dev_files,
    )

    log.info('Converting text into ids...')
    data_loader.convert_to_ids(vocab)

    log.info('Initialize the model...')
    model = Model(vocab, args)

    log.info('Training the model...')
    model.train(
        data_loader,
        args.epochs,
        args.batch_size,
        save_dir=args.model_dir,
        save_prefix=args.algo,
        dropout=args.dropout,
    )

    log.info('====== Done with model training! ======')
开发者ID:SeanLee97,项目名称:QANet_dureader,代码行数:23,代码来源:cli.py

示例2: evaluate

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def evaluate(args):
    """Restore a trained QANet model and evaluate it on the dev set.

    Loads the pickled vocabulary from ``<args.vocab_dir>/vocab.data``, builds a
    ``DataLoader`` over ``args.dev_files``, restores the model saved under
    ``args.model_dir`` with prefix ``args.algo``, then runs ``Model.evaluate``
    on unshuffled dev batches. Predicted answers go to ``args.result_dir``
    with prefix ``'dev.predicted'``. The loss and the second value returned by
    ``Model.evaluate`` (``dev_bleu_rouge``) are logged.

    Args:
        args: parsed CLI namespace. Attributes read directly here: vocab_dir,
            dev_files, max_p_num, max_p_len, max_q_len, max_ch_len, model_dir,
            algo, batch_size, result_dir (``Model`` may read more).

    Raises:
        ValueError: if ``args.dev_files`` is empty or None.
    """
    logger = logging.getLogger("QANet")
    logger.info("====== evaluating ======")

    # Validate before the (potentially slow) vocab load. Use an explicit raise
    # rather than ``assert``, which is stripped under ``python -O``;
    # ``not`` also covers dev_files=None, where len() would raise TypeError.
    if not args.dev_files:
        raise ValueError('No dev files are provided.')

    logger.info('Load data_set and vocab...')
    # NOTE(review): pickle.load can execute arbitrary code; only load vocab
    # files from a trusted source.
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    # Named data_loader so the ``dataloader`` module is not shadowed.
    data_loader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len,
                             args.max_ch_len, dev_files=args.dev_files)

    logger.info('Converting text into ids...')
    data_loader.convert_to_ids(vocab)

    logger.info('Restoring the model...')
    model = Model(vocab, args)
    model.restore(args.model_dir, args.algo)

    logger.info('Evaluating the model on dev set...')
    # Ids of the vocabulary's pad token, passed to next_batch for padding.
    pad_word_id = vocab.get_word_id(vocab.pad_token)
    pad_char_id = vocab.get_char_id(vocab.pad_token)
    dev_batches = data_loader.next_batch('dev', args.batch_size, pad_word_id,
                                         pad_char_id, shuffle=False)

    dev_loss, dev_bleu_rouge = model.evaluate(
        dev_batches, result_dir=args.result_dir, result_prefix='dev.predicted')

    logger.info('Loss on dev set: {}'.format(dev_loss))
    logger.info('Result on dev set: {}'.format(dev_bleu_rouge))
    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
开发者ID:SeanLee97,项目名称:QANet_dureader,代码行数:27,代码来源:cli.py

示例3: predict

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def predict(args):
    """Restore a trained QANet model and predict answers for the test set.

    Loads the pickled vocabulary from ``<args.vocab_dir>/vocab.data``, builds a
    ``DataLoader`` over ``args.test_files``, restores the model saved under
    ``args.model_dir`` with prefix ``args.algo``, and runs ``Model.evaluate``
    on unshuffled test batches. Predictions are written to
    ``args.result_dir`` with prefix ``'test.predicted'``. The values returned
    by ``Model.evaluate`` are discarded.

    Args:
        args: parsed CLI namespace. Attributes read directly here: vocab_dir,
            test_files, max_p_num, max_p_len, max_q_len, max_ch_len,
            model_dir, algo, batch_size, result_dir (``Model`` may read more).

    Raises:
        ValueError: if ``args.test_files`` is empty or None.
    """
    logger = logging.getLogger("QANet")
    logger.info("====== predicting ======")

    # Explicit raise instead of ``assert`` (stripped under ``python -O``);
    # checked before loading the vocab so misuse fails fast. ``not`` also
    # covers test_files=None.
    if not args.test_files:
        raise ValueError('No test files are provided.')

    logger.info('Load data_set and vocab...')
    # NOTE(review): pickle.load can execute arbitrary code; only load vocab
    # files from a trusted source.
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)

    # Named data_loader so the ``dataloader`` module is not shadowed.
    data_loader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len,
                             args.max_ch_len, test_files=args.test_files)

    logger.info('Converting text into ids...')
    data_loader.convert_to_ids(vocab)

    logger.info('Restoring the model...')
    model = Model(vocab, args)
    model.restore(args.model_dir, args.algo)

    logger.info('Predicting answers for test set...')
    # Ids of the vocabulary's pad token, passed to next_batch for padding.
    pad_word_id = vocab.get_word_id(vocab.pad_token)
    pad_char_id = vocab.get_char_id(vocab.pad_token)
    test_batches = data_loader.next_batch('test', args.batch_size, pad_word_id,
                                          pad_char_id, shuffle=False)

    model.evaluate(test_batches,
                   result_dir=args.result_dir, result_prefix='test.predicted')

    logger.info('Predicted answers are saved to {}'.format(args.result_dir))
开发者ID:SeanLee97,项目名称:QANet_dureader,代码行数:24,代码来源:cli.py

示例4: main

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def main():
    """Parse CLI arguments, then pre-train and train a Fast-SRGAN model.

    Creates ``./models`` if it is missing and builds the image dataset from
    ``args.image_dir`` / ``args.hr_size`` with batch size ``args.batch_size``.
    It then runs ``pretrain_generator`` once, logging to ``logs/pretrain``,
    followed by ``args.epochs`` calls to ``train``, logging to ``logs/train``.
    """
    # Parse the CLI arguments.
    args = parser.parse_args()

    # Create the directory for saving trained models. exist_ok avoids the
    # check-then-create race of os.path.exists() followed by os.makedirs().
    os.makedirs('models', exist_ok=True)

    # Create the tensorflow dataset.
    ds = DataLoader(args.image_dir, args.hr_size).dataset(args.batch_size)

    # Initialize the GAN object.
    gan = FastSRGAN(args)

    # Directory for the pre-training loss TensorBoard summary.
    pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain')

    # Run pre-training.
    pretrain_generator(gan, ds, pretrain_summary_writer)

    # Directory for the SRGAN training TensorBoard summary.
    train_summary_writer = tf.summary.create_file_writer('logs/train')

    # Run training, one call per epoch.
    for _ in range(args.epochs):
        train(gan, ds, args.save_iter, train_summary_writer)
开发者ID:HasnainRaz,项目名称:Fast-SRGAN,代码行数:28,代码来源:main.py

示例5: test

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def test(model_dict, using_cuda=True):
    """Load trained ``Net`` weights and print classification accuracy on ``test_set.pkl``.

    Args:
        model_dict: anything ``torch.load`` accepts (e.g. a file path) that
            holds a state dict saved from a ``Net``.
        using_cuda: if True, run the network on the GPU. The flag is also
            forwarded to the project ``dataloader.DataLoader``, presumably so
            batches are placed on the GPU too (TODO confirm).
    """
    net = Net().cuda() if using_cuda else Net()
    # map_location='cpu' lets a checkpoint saved from a GPU run be loaded on a
    # CPU-only machine; None is torch.load's default behaviour.
    state = torch.load(model_dict, map_location=None if using_cuda else 'cpu')
    net.load_state_dict(state)
    # Switch to inference mode (affects dropout / batch-norm layers, if Net has any).
    net.eval()

    dataset = dataloader.DataLoader("test_set.pkl", batch_size=1, using_cuda=using_cuda)
    correct = 0
    for batch in dataset:
        X = batch["feature"]
        y = batch["class"]
        y_pred, _ = net(X)
        # Predicted class = arg-max over dim 1 of the network output.
        _, idx = torch.max(y_pred.data, dim=1)
        # int() accepts both a Python number (older PyTorch) and a 0-d tensor,
        # so `correct` stays a plain int across versions.
        correct += int(torch.sum(torch.eq(idx.cpu(), y.data.cpu())))
    # float() guarantees true division even if dataset.num is an int.
    print("accuracy: %f" % (float(correct) / dataset.num))
开发者ID:fastnlp,项目名称:fastNLP,代码行数:17,代码来源:example.py

示例6: test

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def test(args):
    """Restore the latest checkpoint and report test accuracy (TF1 graph mode).

    Resets the default graph and reads ``args.test_file`` through the project
    ``DataLoader`` without shuffling. Builds ``GVGG`` under variable scope
    ``"model"`` and restores the latest checkpoint in ``args.save_dir``. It
    then prints two numbers:

    * plain accuracy of per-example arg-max predictions, averaged over batches;
    * "rotation-averaged" accuracy: the softmax outputs of a whole batch are
      averaged and the arg-max is compared with the batch's first target.
      NOTE(review): this presumably assumes every batch holds rotated copies
      of one example (see the HACK note below). TODO confirm.

    Only ``args.data_size // args.batch_size`` full batches are evaluated,
    because the input shape is pinned to ``args.batch_size``. Any remainder
    examples are skipped.

    Side effects: overwrites ``args.n_classes`` and ``args.data_size`` with
    the values reported by the loader.
    """
    print('...Building inputs')
    tf.reset_default_graph()

    print('...Connecting data io and preprocessing')
    with tf.device("/cpu:0"):
        with tf.name_scope("IO"):
            test_data = DataLoader(args.test_file, 'test', args.batch_size,
                                   args.height, args.jitter, shuffle=False)
            args.n_classes = test_data.n_classes
            args.data_size = test_data.data_size
            print("Found {} test examples".format(args.data_size))

            test_iterator = test_data.data.make_initializable_iterator()
            test_inputs, test_targets = test_iterator.get_next()
            # Static shape with a fixed batch dimension; this is why only full
            # batches are evaluated below.
            test_inputs.set_shape([args.batch_size, args.height, args.width, args.depth, 1])
            test_init_op = test_iterator.make_initializer(test_data.data)

    # Outputs
    print('...Constructing model')
    with tf.get_default_graph().as_default():
        with tf.variable_scope("model", reuse=False):
            model = GVGG(test_inputs, False, args)
            test_logits = model.pred_logits
            test_preds = tf.nn.softmax(test_logits)

            # Prediction loss
            print("...Building metrics")
            preds = tf.to_int32(tf.argmax(test_preds, 1))
            test_accuracy = tf.contrib.metrics.accuracy(preds, test_targets)
            # HACK: Rotation averaging is brittle.
            preds_rot = tf.to_int32(tf.argmax(tf.reduce_mean(test_preds, 0)))
            test_targets_rot = test_targets[0]
            test_accuracy_rot = tf.contrib.metrics.accuracy(preds_rot, test_targets_rot)

    with tf.Session() as sess:
        # Initialize all variables, then overwrite them with every variable
        # saved in the latest checkpoint (the Saver covers all variables).
        print('...Restore variables')
        tf.global_variables_initializer().run()
        restorer = tf.train.Saver()
        model_path = tf.train.latest_checkpoint(args.save_dir)
        restorer.restore(sess, model_path)

        accuracies = []
        accuracies_rotavg = []
        print("...Testing")

        sess.run([test_init_op])
        for i in range(args.data_size // args.batch_size):
            tacc, tacc_rotavg = sess.run([test_accuracy, test_accuracy_rot])

            accuracies.append(tacc)
            accuracies_rotavg.append(tacc_rotavg)

            # (i + 1) batches have been processed at this point.
            sys.stdout.write("[{} | {}] Running acc: {:0.4f}, Running rot acc: {:0.4f}\r".format((i + 1) * args.batch_size, args.data_size, np.mean(accuracies), np.mean(accuracies_rotavg)))
            sys.stdout.flush()

        print()
        # NOTE(review): if data_size < batch_size no batch runs and np.mean([]) is nan.
        print("Test accuracy: {:0.4f}".format(np.mean(accuracies)))
        print("Test accuracy rot avg: {:0.4f}".format(np.mean(accuracies_rotavg)))
        print()
开发者ID:deworrall92,项目名称:cubenet,代码行数:63,代码来源:test.py

示例7: prepro

# 需要导入模块: import dataloader [as 别名]
# 或者: from dataloader import DataLoader [as 别名]
def prepro(args):
    """Validate inputs, build the QANet vocabulary, and pickle it.

    Steps:
    1. Check that every train/dev/test file exists.
    2. Create the vocab/model/result/summary directories.
    3. Collect words and their characters from the *train* split into a
       lower-casing ``Vocab``.
    4. Drop words and characters whose count is below 2.
    5. Load pretrained embeddings when a path is given; otherwise initialise
       them randomly with the configured size.
    6. Pickle the vocab to ``<args.vocab_dir>/vocab.data``.

    Args:
        args: parsed CLI namespace. Attributes read here: train_files,
            dev_files, test_files, vocab_dir, model_dir, result_dir,
            summary_dir, max_p_num, max_p_len, max_q_len, max_ch_len,
            pretrained_word_path, pretrained_char_path, word_embed_size,
            char_embed_size.

    Raises:
        FileNotFoundError: if any listed data file does not exist.
    """
    logger = logging.getLogger("QANet")
    logger.info("====== preprocessing ======")
    logger.info('Checking the data files...')
    for data_path in args.train_files + args.dev_files + args.test_files:
        # Explicit raise instead of ``assert`` (stripped under ``python -O``).
        if not os.path.exists(data_path):
            raise FileNotFoundError('{} file does not exist.'.format(data_path))

    logger.info('Preparing the directories...')
    for dir_path in [args.vocab_dir, args.model_dir, args.result_dir, args.summary_dir]:
        # exist_ok avoids the exists()/makedirs() check-then-create race.
        os.makedirs(dir_path, exist_ok=True)

    logger.info('Building vocabulary...')
    # Named data_loader so the ``dataloader`` module is not shadowed.
    data_loader = DataLoader(args.max_p_num, args.max_p_len, args.max_q_len, args.max_ch_len,
                             args.train_files, args.dev_files, args.test_files)

    vocab = Vocab(lower=True)
    # Only the train split contributes to the vocabulary.
    for word in data_loader.word_iter('train'):
        vocab.add_word(word)
        for ch in word:
            vocab.add_char(ch)

    unfiltered_vocab_size = vocab.word_size()
    vocab.filter_words_by_cnt(min_cnt=2)
    filtered_num = unfiltered_vocab_size - vocab.word_size()
    # The char size logged here is taken before the char filtering below.
    logger.info('After filter {} tokens, the final vocab size is {}, char size is {}'.format(
        filtered_num, vocab.word_size(), vocab.char_size()))

    unfiltered_vocab_char_size = vocab.char_size()
    vocab.filter_chars_by_cnt(min_cnt=2)
    filtered_char_num = unfiltered_vocab_char_size - vocab.char_size()
    logger.info('After filter {} chars, the final char vocab size is {}'.format(
        filtered_char_num, vocab.char_size()))

    logger.info('Assigning embeddings...')
    if args.pretrained_word_path is not None:
        vocab.load_pretrained_word_embeddings(args.pretrained_word_path)
    else:
        vocab.randomly_init_word_embeddings(args.word_embed_size)

    if args.pretrained_char_path is not None:
        vocab.load_pretrained_char_embeddings(args.pretrained_char_path)
    else:
        vocab.randomly_init_char_embeddings(args.char_embed_size)

    logger.info('Saving vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'wb') as fout:
        pickle.dump(vocab, fout)

    logger.info('====== Done with preparing! ======')
开发者ID:SeanLee97,项目名称:QANet_dureader,代码行数:51,代码来源:cli.py


注:本文中的dataloader.DataLoader方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。