当前位置: 首页>>代码示例>>Python>>正文


Python data_helper.load_data方法代码示例

本文整理汇总了Python中data_helper.load_data方法的典型用法代码示例。如果您正苦于以下问题:Python data_helper.load_data方法的具体用法?Python data_helper.load_data怎么用?Python data_helper.load_data使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在data_helper的用法示例。


在下文中一共展示了data_helper.load_data方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import load_data [as 别名]
def main(_):
    """Entry point: load data, build an AlternatingAttention model, then trace, evaluate or train.

    The single positional argument is ignored (presumably the argv list passed
    by ``tf.app.run`` -- confirm against the caller).

    The mode is selected by flags read from ``tf.app.flags.FLAGS``:
      * ``trace``    -- run ``train.trace`` on the training split and return.
      * ``evaluate`` -- requires ``restore_file``; runs ``test.run`` on the
                        'test' split using the vocabulary from
                        ``data_helper.build_vocab``.
      * otherwise    -- run ``train.run`` on the 'train' / 'valid' splits.
    """
    FLAGS = tf.app.flags.FLAGS
    pp = pprint.PrettyPrinter()
    FLAGS._parse_flags()
    pp.pprint(FLAGS.__flags)

    # Load the training split and the validation split (used as the eval set
    # during training).
    X_train, Q_train, Y_train = data_helper.load_data('train')
    X_test, Q_test, Y_test = data_helper.load_data('valid')

    # Vocabulary size is inferred from the largest token id in the training
    # documents.
    # NOTE(review): ids that appear only in Q_train or in the valid split are
    # not counted here; confirm data_helper guarantees X_train covers the
    # whole vocabulary (changing this also changes checkpoint compatibility).
    vocab_size = np.max(X_train) + 1
    print('[?] Vocabulary Size:', vocab_size)

    # Create directories
    if not os.path.exists(FLAGS.ckpt_dir):
        os.makedirs(FLAGS.ckpt_dir)

    # One log sub-directory per run. Use a fixed, path-safe format instead of
    # the locale-dependent '%c', whose output contains ':' (invalid in Windows
    # paths) and embedded spaces.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    FLAGS.log_dir = os.path.join(FLAGS.log_dir, timestamp)
    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    # Train Model
    session_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    with tf.Session(config=session_config) as sess, tf.device('/gpu:0'):
        model = AlternatingAttention(FLAGS.batch_size, vocab_size, FLAGS.encoding_dim, FLAGS.embedding_dim, FLAGS.num_glimpses, session=sess)

        if FLAGS.trace:  # Trace model for debugging
            train.trace(FLAGS, sess, model, (X_train, Q_train, Y_train))
            return

        saver = tf.train.Saver()

        # Truthiness (rather than `is not None`) so an empty-string flag means
        # "no checkpoint", matching the evaluate-mode check below instead of
        # attempting saver.restore(sess, '').
        if FLAGS.restore_file:
            print('[?] Loading variables from checkpoint %s' % FLAGS.restore_file)
            saver.restore(sess, FLAGS.restore_file)

        # Run evaluation
        if FLAGS.evaluate:
            if not FLAGS.restore_file:
                print('Need to specify a restore_file checkpoint to evaluate')
            else:
                test_data = data_helper.load_data('test')
                word2idx, _, _ = data_helper.build_vocab()
                test.run(FLAGS, sess, model, test_data, word2idx)
        else:
            train.run(FLAGS, sess, model,
                    (X_train, Q_train, Y_train),
                    (X_test, Q_test, Y_test),
                    saver)
开发者ID:nschuc,项目名称:alternating-reader-tf,代码行数:51,代码来源:main.py

示例2: train_step

# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import load_data [as 别名]
def train_step():
    """Train the RNN text classifier, checkpoint it, and report test accuracy.

    Hyper-parameters come from ``Config()`` and the module-level ``FLAGS``
    (``max_len``, ``init_scale``). Three ``RNN_Model`` instances share the
    variables of the "model" scope:
      * a training model;
      * a validation model and a test model built from a second Config with
        ``keep_prob = 1.0``.

    Summaries go to ``<out_dir>/summaries/{train,dev}``. Checkpoints go to
    ``<out_dir>/checkpoints/model``. The directories are checkpointed every
    ``checkpoint_every`` epochs and always after the final epoch.
    """
    print("loading the dataset...")
    config = Config()
    eval_config = Config()
    # Evaluation models use keep_prob=1.0 (presumably disabling dropout).
    eval_config.keep_prob = 1.0

    train_data, valid_data, test_data = data_helper.load_data(FLAGS.max_len, batch_size=config.batch_size)

    print("begin training")

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-FLAGS.init_scale, FLAGS.init_scale)
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            model = RNN_Model(config=config, is_training=True)

        # reuse=True: the validation/test models share the training variables.
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            valid_model = RNN_Model(config=eval_config, is_training=False)
            test_model = RNN_Model(config=eval_config, is_training=False)

        # Summary writers for training and dev metrics.
        train_summary_dir = os.path.join(config.out_dir, "summaries", "train")
        train_summary_writer = tf.train.SummaryWriter(train_summary_dir, session.graph)

        dev_summary_dir = os.path.join(eval_config.out_dir, "summaries", "dev")
        dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir, session.graph)

        # Checkpointing.
        checkpoint_dir = os.path.abspath(os.path.join(config.out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.all_variables())

        tf.initialize_all_variables().run()
        global_steps = 1
        begin_time = int(time.time())

        try:
            for i in range(config.num_epoch):
                print("the %d epoch training..." % (i + 1))
                # LR stays at config.lr for the first max_decay_epoch epochs,
                # then decays geometrically by config.lr_decay per epoch.
                lr_decay = config.lr_decay ** max(i - config.max_decay_epoch, 0.0)
                model.assign_new_lr(session, config.lr * lr_decay)
                global_steps = run_epoch(model, session, train_data, global_steps,
                                         valid_model, valid_data,
                                         train_summary_writer, dev_summary_writer)

                # Save periodically, and always after the last epoch so the
                # final weights are not lost when num_epoch is not aligned
                # with checkpoint_every.
                is_last_epoch = i == config.num_epoch - 1
                if i % config.checkpoint_every == 0 or is_last_epoch:
                    path = saver.save(session, checkpoint_prefix, global_steps)
                    print("Saved model checkpoint to {}\n".format(path))
        finally:
            # Flush pending summary events to disk and release the event files.
            train_summary_writer.close()
            dev_summary_writer.close()

        print("the train is finished")
        end_time = int(time.time())
        print("training takes %d seconds already\n" % (end_time - begin_time))
        test_accuracy = evaluate(test_model, session, test_data)
        print("the test data accuracy is %f" % test_accuracy)
        print("program end!")
开发者ID:luchi007,项目名称:RNN_Text_Classify,代码行数:61,代码来源:train_rnn_classify.py


注:本文中的data_helper.load_data方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。