本文整理汇总了Python中data_helper.batch_iter方法的典型用法代码示例。如果您正苦于以下问题：Python data_helper.batch_iter方法的具体用法？Python data_helper.batch_iter怎么用？Python data_helper.batch_iter使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块data_helper的用法示例。
在下文中一共展示了data_helper.batch_iter方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: run_epoch
# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import batch_iter [as 别名]
def run_epoch(model,session,data,global_steps,valid_model,valid_data,train_summary_writer,valid_summary_writer=None):
    """Run one training pass over `data`, one optimizer step per batch.

    For every batch produced by ``data_helper.batch_iter`` this:
      1. resizes the model to the batch length and feeds the batch,
      2. feeds each (c, h) tensor of ``model._initial_state`` with the value
         obtained by running that initial state,
      3. runs cost / accuracy / train_op / summary and logs the summary,
      4. calls the module-level ``evaluate`` on the validation data,
      5. prints a progress line whenever ``global_steps % 100 == 0``.

    Args:
        model: training model exposing input_data, target, mask_x, cost,
            accuracy, train_op, summary, _initial_state and
            assign_new_batch_size(session, n).
        session: session used for every ``run`` call.
        data: dataset handed unchanged to ``data_helper.batch_iter``.
        global_steps: value of the step counter when the epoch starts.
        valid_model: model passed to ``evaluate`` for validation.
        valid_data: dataset passed to ``evaluate`` for validation.
        train_summary_writer: receives the training summary every step.
        valid_summary_writer: optional writer forwarded to ``evaluate``.

    Returns:
        The step counter after the epoch (start value + number of batches).
    """
    for x, y, mask_x in data_helper.batch_iter(data, batch_size=FLAGS.batch_size):
        feed_dict = {
            model.input_data: x,
            model.target: y,
            model.mask_x: mask_x,
        }
        # The batch may presumably be shorter than FLAGS.batch_size (e.g. the
        # final one), so the model is told the actual length first.
        model.assign_new_batch_size(session, len(x))
        fetches = [model.cost, model.accuracy, model.train_op, model.summary]

        # Evaluate the initial RNN state once and feed it back explicitly;
        # assumes _initial_state is a sequence of (c, h) pairs — TODO confirm.
        init_values = session.run(model._initial_state)
        for (c, h), layer_value in zip(model._initial_state, init_values):
            feed_dict[c] = layer_value.c
            feed_dict[h] = layer_value.h

        cost, accuracy, _, summary = session.run(fetches, feed_dict)
        train_summary_writer.add_summary(summary, global_steps)
        train_summary_writer.flush()

        # NOTE(review): this scores the whole validation set after *every*
        # training batch, which can dominate runtime on a large dev set.
        valid_accuracy = evaluate(valid_model, session, valid_data, global_steps, valid_summary_writer)

        is_log_step = global_steps % 100 == 0
        if is_log_step:
            print("the %i step, train cost is: %f and the train accuracy is %f and the valid accuracy is %f" % (global_steps, cost, accuracy, valid_accuracy))
        global_steps += 1
    return global_steps
示例2: evaluate
# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import batch_iter [as 别名]
def evaluate(rnn, sess, x, y):
    """Evaluate the model's loss and accuracy on another dataset.

    Runs ``rnn`` over ``(x, y)`` batch by batch, as produced by
    ``batch_iter(x, y)``, with ``keep_prob`` fed as 1.0. Returns the
    per-sample (batch-size weighted) average of the batch loss and accuracy.

    Args:
        rnn: model exposing ``input_x``, ``input_y``, ``keep_prob``,
            ``loss`` and ``acc``.
        sess: session used to run the graph.
        x: inputs, passed to ``batch_iter``.
        y: labels, passed to ``batch_iter``.

    Returns:
        ``(mean_loss, mean_accuracy)`` over every sample that ``batch_iter``
        yielded.

    Raises:
        ValueError: if ``batch_iter`` yields no samples (e.g. empty ``x``).
    """
    total_loss = 0.0
    total_acc = 0.0
    # Count the samples actually evaluated instead of using len(x): if
    # batch_iter ever drops a trailing partial batch, dividing by len(x)
    # would bias both averages downward.
    seen = 0
    for x_batch, y_batch in batch_iter(x, y):
        batch_len = len(x_batch)
        feed_dict = {
            rnn.input_x: x_batch,
            rnn.input_y: y_batch,
            # Presumably the dropout keep-probability: 1.0 disables dropout
            # for evaluation.
            rnn.keep_prob: 1.0,
        }
        loss, acc = sess.run([rnn.loss, rnn.acc], feed_dict)
        # loss/acc are assumed to be per-batch means; weighting by batch size
        # keeps a short last batch from skewing the overall average.
        total_loss += loss * batch_len
        total_acc += acc * batch_len
        seen += batch_len
    if seen == 0:
        raise ValueError("evaluate() received no samples to evaluate")
    return total_loss / seen, total_acc / seen
示例3: evaluate
# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import batch_iter [as 别名]
def evaluate(model,session,data,global_steps=None,summary_writer=None):
    """Compute the classification accuracy of ``model`` on ``data``.

    Feeds every ``(x, y, mask_x)`` batch from ``data_helper.batch_iter`` and
    sums ``model.correct_num``. If ``summary_writer`` is given, a scalar
    summary tagged ``'dev_accuracy'`` is written at step ``global_steps``.

    Args:
        model: model exposing input_data, target, mask_x, correct_num,
            _initial_state and assign_new_batch_size(session, n).
        session: session used for every ``run`` call.
        data: dataset handed unchanged to ``data_helper.batch_iter``.
        global_steps: step recorded with the summary (may be None).
        summary_writer: optional summary writer; nothing is logged if falsy.

    Returns:
        float: correct predictions / number of samples evaluated.

    Raises:
        ValueError: if ``batch_iter`` yields no samples.
    """
    fetches = model.correct_num  # loop-invariant, look it up once
    correct_num = 0
    # Count evaluated samples rather than using len(data[0]): if batch_iter
    # skips a trailing partial batch, dividing by the full dataset size would
    # understate accuracy.
    total_num = 0
    for x, y, mask_x in data_helper.batch_iter(data, batch_size=FLAGS.batch_size):
        feed_dict = {
            model.input_data: x,
            model.target: y,
            model.mask_x: mask_x,
        }
        model.assign_new_batch_size(session, len(x))
        # Evaluate the initial RNN state and feed it back explicitly; assumes
        # _initial_state is a sequence of (c, h) pairs — TODO confirm.
        state = session.run(model._initial_state)
        for i, (c, h) in enumerate(model._initial_state):
            feed_dict[c] = state[i].c
            feed_dict[h] = state[i].h
        correct_num += session.run(fetches, feed_dict)
        total_num += len(x)

    if total_num == 0:
        raise ValueError("evaluate() received no samples to evaluate")
    accuracy = float(correct_num) / total_num

    if summary_writer:
        # Build the Summary protobuf directly. The previous
        # tf.scalar_summary(...) call added a new op to the graph (and to the
        # SUMMARIES collection) on every invocation; since run_epoch calls
        # evaluate() once per training step, the graph grew without bound.
        dev_summary = tf.Summary(value=[
            tf.Summary.Value(tag='dev_accuracy', simple_value=accuracy),
        ])
        summary_writer.add_summary(dev_summary, global_steps)
        summary_writer.flush()
    return accuracy