当前位置: 首页>>代码示例>>Python>>正文


Python data_helper.batch_iter方法代码示例

本文整理汇总了Python中data_helper.batch_iter方法的典型用法代码示例。如果您正苦于以下问题:Python data_helper.batch_iter方法的具体用法?Python data_helper.batch_iter怎么用?Python data_helper.batch_iter使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在data_helper的用法示例。


在下文中一共展示了data_helper.batch_iter方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: run_epoch

# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import batch_iter [as 别名]
def run_epoch(model,session,data,global_steps,valid_model,valid_data,train_summary_writer,valid_summary_writer=None):
    for step, (x,y,mask_x) in enumerate(data_helper.batch_iter(data,batch_size=FLAGS.batch_size)):

        feed_dict={}
        feed_dict[model.input_data]=x
        feed_dict[model.target]=y
        feed_dict[model.mask_x]=mask_x
        model.assign_new_batch_size(session,len(x))
        fetches = [model.cost,model.accuracy,model.train_op,model.summary]
        state = session.run(model._initial_state)
        for i , (c,h) in enumerate(model._initial_state):
            feed_dict[c]=state[i].c
            feed_dict[h]=state[i].h
        cost,accuracy,_,summary = session.run(fetches,feed_dict)
        train_summary_writer.add_summary(summary,global_steps)
        train_summary_writer.flush()
        valid_accuracy=evaluate(valid_model,session,valid_data,global_steps,valid_summary_writer)
        if(global_steps%100==0):
            print("the %i step, train cost is: %f and the train accuracy is %f and the valid accuracy is %f"%(global_steps,cost,accuracy,valid_accuracy))
        global_steps+=1

    return global_steps 
开发者ID:luchi007,项目名称:RNN_Text_Classify,代码行数:24,代码来源:train_rnn_classify.py

示例2: evaluate

# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import batch_iter [as 别名]
def evaluate(rnn, sess, x, y):
    """在其他数据集上评估模型的准确率"""
    data_len = len(x)
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_iter(x, y):
        batch_len = len(x_batch)
        feed_dict = {
            rnn.input_x: x_batch,
            rnn.input_y: y_batch,
            rnn.keep_prob: 1.0
        }
        loss, acc = sess.run([rnn.loss, rnn.acc], feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / data_len, total_acc / data_len 
开发者ID:baiyyang,项目名称:medical-diagnosis-cnn-rnn-rcnn,代码行数:18,代码来源:run_rnn.py

示例3: evaluate

# 需要导入模块: import data_helper [as 别名]
# 或者: from data_helper import batch_iter [as 别名]
def evaluate(model,session,data,global_steps=None,summary_writer=None):


    correct_num=0
    total_num=len(data[0])
    for step, (x,y,mask_x) in enumerate(data_helper.batch_iter(data,batch_size=FLAGS.batch_size)):

         fetches = model.correct_num
         feed_dict={}
         feed_dict[model.input_data]=x
         feed_dict[model.target]=y
         feed_dict[model.mask_x]=mask_x
         model.assign_new_batch_size(session,len(x))
         state = session.run(model._initial_state)
         for i , (c,h) in enumerate(model._initial_state):
            feed_dict[c]=state[i].c
            feed_dict[h]=state[i].h
         count=session.run(fetches,feed_dict)
         correct_num+=count

    accuracy=float(correct_num)/total_num
    dev_summary = tf.scalar_summary('dev_accuracy',accuracy)
    dev_summary = session.run(dev_summary)
    if summary_writer:
        summary_writer.add_summary(dev_summary,global_steps)
        summary_writer.flush()
    return accuracy 
开发者ID:luchi007,项目名称:RNN_Text_Classify,代码行数:29,代码来源:train_rnn_classify.py


注:本文中的data_helper.batch_iter方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。