本文整理匯總了Python中data_helper.load_data方法的典型用法代碼示例。如果您正苦於以下問題:Python data_helper.load_data方法的具體用法?Python data_helper.load_data怎麼用?Python data_helper.load_data使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊data_helper
的用法示例。
在下文中一共展示了data_helper.load_data方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: main
# 需要導入模塊: import data_helper [as 別名]
# 或者: from data_helper import load_data [as 別名]
def main(_):
    """Entry point: load data, build the AlternatingAttention model, then
    trace, evaluate, or train depending on the parsed command-line flags.

    Relies on module-level ``tf``, ``data_helper``, ``train``, ``test``,
    and ``AlternatingAttention`` being importable in this file.
    """
    FLAGS = tf.app.flags.FLAGS
    printer = pprint.PrettyPrinter()
    FLAGS._parse_flags()
    printer.pprint(FLAGS.__flags)

    # Training and validation splits; vocabulary size is derived from the
    # largest token id present in the training inputs.
    X_train, Q_train, Y_train = data_helper.load_data('train')
    X_test, Q_test, Y_test = data_helper.load_data('valid')
    vocab_size = np.max(X_train) + 1
    print('[?] Vocabulary Size:', vocab_size)

    # Checkpoint directory plus a per-run, timestamped log directory.
    if not os.path.exists(FLAGS.ckpt_dir):
        os.makedirs(FLAGS.ckpt_dir)
    run_stamp = datetime.now().strftime('%c')
    FLAGS.log_dir = os.path.join(FLAGS.log_dir, run_stamp)
    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    session_config = tf.ConfigProto(log_device_placement=False,
                                    allow_soft_placement=True)
    with tf.Session(config=session_config) as sess, tf.device('/gpu:0'):
        model = AlternatingAttention(FLAGS.batch_size, vocab_size,
                                     FLAGS.encoding_dim, FLAGS.embedding_dim,
                                     FLAGS.num_glimpses, session=sess)

        # Trace mode is a debugging aid and short-circuits everything else.
        if FLAGS.trace:
            train.trace(FLAGS, sess, model, (X_train, Q_train, Y_train))
            return

        saver = tf.train.Saver()
        if FLAGS.restore_file is not None:
            print('[?] Loading variables from checkpoint %s' % FLAGS.restore_file)
            saver.restore(sess, FLAGS.restore_file)

        if not FLAGS.evaluate:
            # Default path: train, validating against the 'valid' split.
            train.run(FLAGS, sess, model,
                      (X_train, Q_train, Y_train),
                      (X_test, Q_test, Y_test),
                      saver)
        elif FLAGS.restore_file:
            # Evaluation only makes sense with restored weights.
            test_data = data_helper.load_data('test')
            word2idx, _, _ = data_helper.build_vocab()
            test.run(FLAGS, sess, model, test_data, word2idx)
        else:
            print('Need to specify a restore_file checkpoint to evaluate')
示例2: train_step
# 需要導入模塊: import data_helper [as 別名]
# 或者: from data_helper import load_data [as 別名]
def train_step():
    """Train the RNN model and report test accuracy.

    Builds three models sharing one variable scope (the training model
    creates the variables; the validation and test models reuse them),
    runs ``config.num_epoch`` epochs with exponential learning-rate decay,
    writes TensorBoard summaries, checkpoints every
    ``config.checkpoint_every`` epochs, and finally evaluates on the test
    split.

    Relies on module-level ``FLAGS``, ``Config``, ``RNN_Model``,
    ``run_epoch``, ``evaluate``, ``data_helper``, and ``tf``.
    """
    print("loading the dataset...")
    config = Config()
    eval_config = Config()
    eval_config.keep_prob = 1.0  # disable dropout for evaluation

    train_data, valid_data, test_data = data_helper.load_data(
        FLAGS.max_len, batch_size=config.batch_size)
    print("begin training")

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-FLAGS.init_scale,
                                                    FLAGS.init_scale)
        # Shared scope: reuse=None creates variables, reuse=True shares them.
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            model = RNN_Model(config=config, is_training=True)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            valid_model = RNN_Model(config=eval_config, is_training=False)
            test_model = RNN_Model(config=eval_config, is_training=False)

        # Summary writers for TensorBoard (train and dev/validation).
        train_summary_dir = os.path.join(config.out_dir, "summaries", "train")
        train_summary_writer = tf.train.SummaryWriter(train_summary_dir,
                                                      session.graph)
        dev_summary_dir = os.path.join(eval_config.out_dir, "summaries", "dev")
        dev_summary_writer = tf.train.SummaryWriter(dev_summary_dir,
                                                    session.graph)

        # Checkpointing setup.
        checkpoint_dir = os.path.abspath(
            os.path.join(config.out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.all_variables())

        tf.initialize_all_variables().run()

        global_steps = 1
        begin_time = int(time.time())
        for i in range(config.num_epoch):
            print("the %d epoch training..." % (i + 1))
            # No decay until `max_decay_epoch` epochs have passed.
            lr_decay = config.lr_decay ** max(i - config.max_decay_epoch, 0.0)
            model.assign_new_lr(session, config.lr * lr_decay)
            global_steps = run_epoch(model, session, train_data, global_steps,
                                     valid_model, valid_data,
                                     train_summary_writer, dev_summary_writer)

            if i % config.checkpoint_every == 0:
                path = saver.save(session, checkpoint_prefix, global_steps)
                # Fixed message typo ("chechpoint") and missing space
                # before the path placeholder.
                print("Saved model checkpoint to {}\n".format(path))

        print("the train is finished")
        end_time = int(time.time())
        print("training takes %d seconds already\n" % (end_time - begin_time))

        test_accuracy = evaluate(test_model, session, test_data)
        print("the test data accuracy is %f" % test_accuracy)
        print("program end!")