当前位置: 首页>>代码示例>>Python>>正文


Python TextLoader.tensor方法代码示例

本文整理汇总了Python中utils.TextLoader.tensor方法的典型用法代码示例。如果您正苦于以下问题:Python TextLoader.tensor方法的具体用法?Python TextLoader.tensor怎么用?Python TextLoader.tensor使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils.TextLoader的用法示例。


在下文中一共展示了TextLoader.tensor方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: cross_validation

# 需要导入模块: from utils import TextLoader [as 别名]
# 或者: from utils.TextLoader import tensor [as 别名]
def cross_validation(args):
    data_loader = TextLoader(args.utils_dir, args.data_path, args.batch_size, args.seq_length, None, None)
    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)
    with open(os.path.join(args.save_dir, 'labels.pkl'), 'wb') as f:
        pickle.dump(data_loader.labels, f)

    data = data_loader.tensor.copy()
    np.random.shuffle(data)
    data_list = np.array_split(data, 10, axis=0)

    model = Model(args)
    accuracy_list = []

    with tf.Session() as sess:
        for n in range(10):
            init = tf.initialize_all_variables()
            sess.run(init)
            saver = tf.train.Saver(tf.all_variables())

            test_data = data_list[n].copy()
            train_data = np.concatenate(map(lambda i: data_list[i], [j for j in range(10) if j!=n]), axis=0)
            data_loader.tensor = train_data

            for e in range(args.num_epochs):
                sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
                data_loader.reset_batch_pointer()

                for b in range(data_loader.num_batches):
                    start = time.time()
                    state = model.initial_state.eval()
                    x, y = data_loader.next_batch()
                    feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                    train_loss, state, _, accuracy = sess.run([model.cost, model.final_state, model.optimizer, model.accuracy], feed_dict=feed)
                    end = time.time()
                    print '{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'\
                        .format(e * data_loader.num_batches + b + 1,
                                args.num_epochs * data_loader.num_batches,
                                e + 1,
                                train_loss,
                                accuracy,
                                end - start)
                    if (e*data_loader.num_batches+b+1) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b==data_loader.num_batches-1):
                        checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=e*data_loader.num_batches+b+1)
                        print 'model saved to {}'.format(checkpoint_path)

            n_chunks = len(test_data) / args.batch_size
            if len(test_data) % args.batch_size:
                n_chunks += 1
            test_data_list = np.array_split(test_data, n_chunks, axis=0)

            correct_total = 0.0
            num_total = 0.0
            for m in range(n_chunks):
                start = time.time()
                x = test_data_list[m][:, :-1]
                y = test_data_list[m][:, -1]
                results = model.predict_class(sess, x)
                correct_num = np.sum(results==y)
                end = time.time()

                correct_total += correct_num
                num_total += len(x)

            accuracy_total = correct_total / num_total
            accuracy_list.append(accuracy_total)
            print 'total_num = {}, total_accuracy = {:.6f}'.format(int(num_total), accuracy_total)

    accuracy_average = np.average(accuracy_list)
    print 'The average accuracy of cross_validation is {}'.format(accuracy_average)
开发者ID:12190143,项目名称:RNN-Classification,代码行数:79,代码来源:cross_validation.py


注:本文中的utils.TextLoader.tensor方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。