本文整理汇总了Python中utils.TextLoader.tensor方法的典型用法代码示例。如果您正苦于以下问题:Python TextLoader.tensor方法的具体用法?Python TextLoader.tensor怎么用?Python TextLoader.tensor使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类utils.TextLoader的用法示例。
在下文中一共展示了TextLoader.tensor方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: cross_validation
# 需要导入模块: from utils import TextLoader [as 别名]
# 或者: from utils.TextLoader import tensor [as 别名]
def _save_metadata(args, data_loader):
    """Pickle the run configuration, vocabulary and labels into ``args.save_dir``.

    Writes ``config.pkl`` (the ``args`` object), ``chars_vocab.pkl`` (the
    tuple ``(data_loader.chars, data_loader.vocab)``) and ``labels.pkl``
    (``data_loader.labels``).
    """
    artifacts = [
        ('config.pkl', args),
        ('chars_vocab.pkl', (data_loader.chars, data_loader.vocab)),
        ('labels.pkl', data_loader.labels),
    ]
    for filename, obj in artifacts:
        with open(os.path.join(args.save_dir, filename), 'wb') as f:
            pickle.dump(obj, f)


def _count_correct(sess, model, test_data, batch_size):
    """Return ``(num_correct, num_rows)`` of ``model`` on ``test_data``.

    Each row of ``test_data`` is split into inputs (all columns but the last)
    and the expected label (the last column). Prediction is done in chunks.
    ``test_data`` must be non-empty.
    """
    # ceil(len / batch_size) chunks, so np.array_split never yields a chunk
    # with more than batch_size rows.
    n_chunks = (len(test_data) + batch_size - 1) // batch_size
    correct = 0
    total = 0
    for chunk in np.array_split(test_data, n_chunks, axis=0):
        x = chunk[:, :-1]
        y = chunk[:, -1]
        results = model.predict_class(sess, x)
        correct += int(np.sum(results == y))
        total += len(x)
    return correct, total


def cross_validation(args, num_folds=10):
    """Run k-fold cross-validation of ``Model`` on the data in ``args.data_path``.

    The rows of ``data_loader.tensor`` are shuffled and split into
    ``num_folds`` folds. For each fold, all graph variables are re-initialised,
    the model is trained on the remaining folds for ``args.num_epochs`` epochs
    (learning rate ``args.learning_rate * args.decay_rate ** epoch``), and then
    evaluated on the held-out fold. Per-fold and average accuracies are printed.

    Args:
        args: run configuration. Reads ``utils_dir``, ``data_path``,
            ``batch_size``, ``seq_length``, ``save_dir``, ``num_epochs``,
            ``learning_rate``, ``decay_rate`` and ``save_every``. Sets
            ``vocab_size`` and ``label_size`` from the loader before ``args``
            is pickled to ``save_dir``.
        num_folds: number of folds (default 10, the previously hard-coded
            value). Must be at least 2.

    Returns:
        The mean of the per-fold test accuracies.

    Raises:
        ValueError: if ``num_folds < 2`` or there are fewer data rows than
            folds (a fold would otherwise be empty).
    """
    if num_folds < 2:
        raise ValueError('num_folds must be at least 2, got {}'.format(num_folds))

    data_loader = TextLoader(args.utils_dir, args.data_path, args.batch_size, args.seq_length, None, None)
    data = data_loader.tensor.copy()
    if len(data) < num_folds:
        raise ValueError('need at least {} rows for {}-fold cross-validation, got {}'
                         .format(num_folds, num_folds, len(data)))

    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size
    _save_metadata(args, data_loader)

    np.random.shuffle(data)
    folds = np.array_split(data, num_folds, axis=0)

    model = Model(args)
    # Build every op the loops need exactly once. Calling
    # tf.initialize_all_variables / tf.train.Saver / tf.assign inside the
    # loops would add new nodes to the graph on every fold and every epoch.
    # NOTE: these are the pre-TF-0.12 API names, kept to match the TF version
    # this code targets.
    init_op = tf.initialize_all_variables()
    saver = tf.train.Saver(tf.all_variables())
    lr_value = tf.placeholder(model.lr.dtype.base_dtype)
    update_lr = tf.assign(model.lr, lr_value)
    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')

    accuracy_list = []
    with tf.Session() as sess:
        for n in range(num_folds):
            # Fresh weights per fold so training on one fold cannot leak into
            # the evaluation of another.
            sess.run(init_op)
            test_data = folds[n]
            train_data = np.concatenate([fold for i, fold in enumerate(folds) if i != n], axis=0)
            # NOTE(review): assumes TextLoader's next_batch/num_batches follow
            # the reassigned ``tensor``. If num_batches is computed once in the
            # constructor, it still reflects the full data set -- confirm.
            data_loader.tensor = train_data
            for e in range(args.num_epochs):
                sess.run(update_lr, feed_dict={lr_value: args.learning_rate * (args.decay_rate ** e)})
                data_loader.reset_batch_pointer()
                for b in range(data_loader.num_batches):
                    start = time.time()
                    state = sess.run(model.initial_state)
                    x, y = data_loader.next_batch()
                    feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                    train_loss, state, _, accuracy = sess.run(
                        [model.cost, model.final_state, model.optimizer, model.accuracy],
                        feed_dict=feed)
                    end = time.time()
                    step = e * data_loader.num_batches + b + 1
                    print('{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'
                          .format(step,
                                  args.num_epochs * data_loader.num_batches,
                                  e + 1,
                                  train_loss,
                                  accuracy,
                                  end - start))
                    is_last_batch = (e == args.num_epochs - 1 and b == data_loader.num_batches - 1)
                    if step % args.save_every == 0 or is_last_batch:
                        # NOTE(review): every fold saves to the same path with
                        # the same step numbers, so later folds overwrite the
                        # checkpoints of earlier ones.
                        saver.save(sess, checkpoint_path, global_step=step)
                        print('model saved to {}'.format(checkpoint_path))

            correct_total, num_total = _count_correct(sess, model, test_data, args.batch_size)
            accuracy_total = float(correct_total) / num_total
            accuracy_list.append(accuracy_total)
            print('total_num = {}, total_accuracy = {:.6f}'.format(int(num_total), accuracy_total))

    accuracy_average = np.average(accuracy_list)
    print('The average accuracy of cross_validation is {}'.format(accuracy_average))
    return accuracy_average