當前位置: 首頁>>代碼示例>>Python>>正文


Python TextLoader.tensor方法代碼示例

本文整理匯總了Python中utils.TextLoader.tensor方法的典型用法代碼示例。如果您正苦於以下問題:Python TextLoader.tensor方法的具體用法?Python TextLoader.tensor怎麽用?Python TextLoader.tensor使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在utils.TextLoader的用法示例。


在下文中一共展示了TextLoader.tensor方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: cross_validation

# 需要導入模塊: from utils import TextLoader [as 別名]
# 或者: from utils.TextLoader import tensor [as 別名]
def cross_validation(args):
    data_loader = TextLoader(args.utils_dir, args.data_path, args.batch_size, args.seq_length, None, None)
    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)
    with open(os.path.join(args.save_dir, 'labels.pkl'), 'wb') as f:
        pickle.dump(data_loader.labels, f)

    data = data_loader.tensor.copy()
    np.random.shuffle(data)
    data_list = np.array_split(data, 10, axis=0)

    model = Model(args)
    accuracy_list = []

    with tf.Session() as sess:
        for n in range(10):
            init = tf.initialize_all_variables()
            sess.run(init)
            saver = tf.train.Saver(tf.all_variables())

            test_data = data_list[n].copy()
            train_data = np.concatenate(map(lambda i: data_list[i], [j for j in range(10) if j!=n]), axis=0)
            data_loader.tensor = train_data

            for e in range(args.num_epochs):
                sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
                data_loader.reset_batch_pointer()

                for b in range(data_loader.num_batches):
                    start = time.time()
                    state = model.initial_state.eval()
                    x, y = data_loader.next_batch()
                    feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                    train_loss, state, _, accuracy = sess.run([model.cost, model.final_state, model.optimizer, model.accuracy], feed_dict=feed)
                    end = time.time()
                    print '{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'\
                        .format(e * data_loader.num_batches + b + 1,
                                args.num_epochs * data_loader.num_batches,
                                e + 1,
                                train_loss,
                                accuracy,
                                end - start)
                    if (e*data_loader.num_batches+b+1) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b==data_loader.num_batches-1):
                        checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=e*data_loader.num_batches+b+1)
                        print 'model saved to {}'.format(checkpoint_path)

            n_chunks = len(test_data) / args.batch_size
            if len(test_data) % args.batch_size:
                n_chunks += 1
            test_data_list = np.array_split(test_data, n_chunks, axis=0)

            correct_total = 0.0
            num_total = 0.0
            for m in range(n_chunks):
                start = time.time()
                x = test_data_list[m][:, :-1]
                y = test_data_list[m][:, -1]
                results = model.predict_class(sess, x)
                correct_num = np.sum(results==y)
                end = time.time()

                correct_total += correct_num
                num_total += len(x)

            accuracy_total = correct_total / num_total
            accuracy_list.append(accuracy_total)
            print 'total_num = {}, total_accuracy = {:.6f}'.format(int(num_total), accuracy_total)

    accuracy_average = np.average(accuracy_list)
    print 'The average accuracy of cross_validation is {}'.format(accuracy_average)
開發者ID:12190143,項目名稱:RNN-Classification,代碼行數:79,代碼來源:cross_validation.py


注:本文中的utils.TextLoader.tensor方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。