當前位置: 首頁>>代碼示例>>Python>>正文


Python prediction_input.build_tfrecord_input方法代碼示例

本文整理匯總了Python中prediction_input.build_tfrecord_input方法的典型用法代碼示例。如果您正苦於以下問題:Python prediction_input.build_tfrecord_input方法的具體用法?Python prediction_input.build_tfrecord_input怎麽用?Python prediction_input.build_tfrecord_input使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在prediction_input的用法示例。


在下文中一共展示了prediction_input.build_tfrecord_input方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: import prediction_input [as 別名]
# 或者: from prediction_input import build_tfrecord_input [as 別名]
def main(unused_argv):

  print('Constructing models and inputs.')
  with tf.variable_scope('model', reuse=None) as training_scope:
    images, actions, states = build_tfrecord_input(training=True)
    model = Model(images, actions, states, FLAGS.sequence_length,
                  prefix='train')

  with tf.variable_scope('val_model', reuse=None):
    val_images, val_actions, val_states = build_tfrecord_input(training=False)
    val_model = Model(val_images, val_actions, val_states,
                      FLAGS.sequence_length, training_scope, prefix='val')

  print('Constructing saver.')
  # Make saver.
  saver = tf.train.Saver(
      tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)

  # Make training session.
  sess = tf.InteractiveSession()
  summary_writer = tf.summary.FileWriter(
      FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)

  if FLAGS.pretrained_model:
    saver.restore(sess, FLAGS.pretrained_model)

  tf.train.start_queue_runners(sess)
  sess.run(tf.global_variables_initializer())

  tf.logging.info('iteration number, cost')

  # Run training.
  for itr in range(FLAGS.num_iterations):
    # Generate new batch of data.
    feed_dict = {model.iter_num: np.float32(itr),
                 model.lr: FLAGS.learning_rate}
    cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
                                    feed_dict)

    # Print info: iteration #, cost.
    tf.logging.info(str(itr) + ' ' + str(cost))

    if (itr) % VAL_INTERVAL == 2:
      # Run through validation set.
      feed_dict = {val_model.lr: 0.0,
                   val_model.iter_num: np.float32(itr)}
      _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
                                     feed_dict)
      summary_writer.add_summary(val_summary_str, itr)

    if (itr) % SAVE_INTERVAL == 2:
      tf.logging.info('Saving model.')
      saver.save(sess, FLAGS.output_dir + '/model' + str(itr))

    if (itr) % SUMMARY_INTERVAL:
      summary_writer.add_summary(summary_str, itr)

  tf.logging.info('Saving model.')
  saver.save(sess, FLAGS.output_dir + '/model')
  tf.logging.info('Training complete')
  tf.logging.flush() 
開發者ID:ymao1993,項目名稱:HumanRecognition,代碼行數:63,代碼來源:prediction_train.py


注:本文中的prediction_input.build_tfrecord_input方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。