本文整理匯總了Python中data_helper.batch_iter方法的典型用法代碼示例。如果您正苦於以下問題:Python data_helper.batch_iter方法的具體用法?Python data_helper.batch_iter怎麼用?Python data_helper.batch_iter使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊data_helper的用法示例。
在下文中一共展示了data_helper.batch_iter方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: run_epoch
# 需要導入模塊: import data_helper [as 別名]
# 或者: from data_helper import batch_iter [as 別名]
def run_epoch(model, session, data, global_steps, valid_model, valid_data,
              train_summary_writer, valid_summary_writer=None):
    """Run one training pass over `data` and return the advanced step counter.

    For each batch produced by ``data_helper.batch_iter`` this performs one
    optimisation step, records the training summary, evaluates `valid_model`
    on `valid_data`, and prints a progress line whenever `global_steps` is a
    multiple of 100.

    Args:
        model: Training model; must expose `input_data`, `target`, `mask_x`,
            `cost`, `accuracy`, `train_op`, `summary`, `_initial_state` and
            `assign_new_batch_size(session, n)`.
        session: Session used to run the graph.
        data: Training data handed to ``data_helper.batch_iter``.
        global_steps: Step counter value at the start of this epoch.
        valid_model: Model passed to `evaluate` for validation.
        valid_data: Validation data passed to `evaluate`.
        train_summary_writer: Writer receiving the per-step training summary.
        valid_summary_writer: Optional writer forwarded to `evaluate`.

    Returns:
        `global_steps` incremented once for every batch processed.
    """
    batches = data_helper.batch_iter(data, batch_size=FLAGS.batch_size)
    for x, y, mask_x in batches:
        inputs = {
            model.input_data: x,
            model.target: y,
            model.mask_x: mask_x,
        }
        # Tell the model how many examples this batch holds
        # (presumably the last batch can be shorter -- confirm in batch_iter).
        model.assign_new_batch_size(session, len(x))
        train_fetches = [model.cost, model.accuracy, model.train_op, model.summary]

        # Evaluate the initial recurrent state and feed its values back in,
        # one (c, h) pair per entry of model._initial_state.
        initial_values = session.run(model._initial_state)
        for (c_tensor, h_tensor), layer_value in zip(model._initial_state, initial_values):
            inputs[c_tensor] = layer_value.c
            inputs[h_tensor] = layer_value.h

        cost, accuracy, _, summary = session.run(train_fetches, inputs)
        train_summary_writer.add_summary(summary, global_steps)
        train_summary_writer.flush()

        # NOTE(review): validation runs on every step although the result is
        # only printed every 100th step; kept as-is because it also drives
        # how often validation summaries are written.
        valid_accuracy = evaluate(valid_model, session, valid_data,
                                  global_steps, valid_summary_writer)
        if global_steps % 100 == 0:
            print("the %i step, train cost is: %f and the train accuracy is %f and the valid accuracy is %f"%(global_steps, cost, accuracy, valid_accuracy))
        global_steps += 1

    return global_steps
示例2: evaluate
# 需要導入模塊: import data_helper [as 別名]
# 或者: from data_helper import batch_iter [as 別名]
def evaluate(rnn, sess, x, y):
    """Evaluate the model's loss and accuracy on another dataset.

    Each batch from ``batch_iter(x, y)`` is run through `rnn.loss` and
    `rnn.acc` with `rnn.keep_prob` fed as 1.0 (presumably disabling dropout).
    The per-batch values are weighted by batch length and averaged over
    ``len(x)``.

    Args:
        rnn: Model exposing `input_x`, `input_y`, `keep_prob`, `loss`, `acc`.
        sess: Session used to run the graph.
        x: Input examples; must be non-empty because the averages divide by
            ``len(x)``.
        y: Labels passed alongside `x` to ``batch_iter``.

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple.
    """
    sample_count = len(x)
    total_loss = 0.0
    total_acc = 0.0

    for x_batch, y_batch in batch_iter(x, y):
        weight = len(x_batch)
        feeds = {rnn.input_x: x_batch, rnn.input_y: y_batch, rnn.keep_prob: 1.0}
        batch_loss, batch_acc = sess.run([rnn.loss, rnn.acc], feeds)
        # Weight by batch length so batches of different sizes average correctly.
        total_loss += batch_loss * weight
        total_acc += batch_acc * weight

    mean_loss = total_loss / sample_count
    mean_acc = total_acc / sample_count
    return mean_loss, mean_acc
示例3: evaluate
# 需要導入模塊: import data_helper [as 別名]
# 或者: from data_helper import batch_iter [as 別名]
def evaluate(model,session,data,global_steps=None,summary_writer=None):
    """Compute the accuracy of `model` over `data`, optionally logging it.

    Every batch from ``data_helper.batch_iter`` is run through
    `model.correct_num`; the counts are summed and divided by
    ``len(data[0])``.

    Args:
        model: Model exposing `input_data`, `target`, `mask_x`,
            `correct_num`, `_initial_state` and
            `assign_new_batch_size(session, n)`.
        session: Session used to run the graph.
        data: Dataset handed to ``data_helper.batch_iter``; ``data[0]`` must
            have a length equal to the number of examples.
        global_steps: Step value attached to the logged summary.
        summary_writer: Optional writer; when given, a ``dev_accuracy``
            scalar summary is added and flushed.

    Returns:
        The fraction of correctly classified examples as a float.
    """
    correct_num = 0
    total_num = len(data[0])
    for x, y, mask_x in data_helper.batch_iter(data, batch_size=FLAGS.batch_size):
        feed_dict = {
            model.input_data: x,
            model.target: y,
            model.mask_x: mask_x,
        }
        model.assign_new_batch_size(session, len(x))
        # Evaluate the initial recurrent state and feed its values back in,
        # one (c, h) pair per entry of model._initial_state.
        state = session.run(model._initial_state)
        for (c, h), layer_state in zip(model._initial_state, state):
            feed_dict[c] = layer_state.c
            feed_dict[h] = layer_state.h
        correct_num += session.run(model.correct_num, feed_dict)

    accuracy = float(correct_num) / total_num

    if summary_writer:
        # Build the summary protobuf directly instead of calling
        # tf.scalar_summary: that call adds a new op to the graph on every
        # invocation, so the graph would grow each time evaluate() runs.
        dev_summary = tf.Summary(
            value=[tf.Summary.Value(tag='dev_accuracy', simple_value=accuracy)]
        )
        summary_writer.add_summary(dev_summary, global_steps)
        summary_writer.flush()
    return accuracy