This article collects typical usage examples of the Python method prediction_input.build_tfrecord_input. If you want to know what prediction_input.build_tfrecord_input does or how to call it, the example below shows it in context. You can also look at other usage examples from prediction_input, the module that defines this method.

One code example of prediction_input.build_tfrecord_input is shown below.
Example 1: main
```python
# Imports this example relies on (TensorFlow 1.x graph-mode API).
import numpy as np
import tensorflow as tf

# Alternatively: import prediction_input [as alias]
from prediction_input import build_tfrecord_input

# FLAGS, Model, VAL_INTERVAL, SAVE_INTERVAL and SUMMARY_INTERVAL are defined
# elsewhere in the same training script and are not shown here.


def main(unused_argv):
    print('Constructing models and inputs.')
    # Training model: reads the training split of the TFRecord data.
    with tf.variable_scope('model', reuse=None) as training_scope:
        images, actions, states = build_tfrecord_input(training=True)
        model = Model(images, actions, states, FLAGS.sequence_length,
                      prefix='train')

    # Validation model: reads the validation split and receives
    # training_scope so that it shares the training model's weights.
    with tf.variable_scope('val_model', reuse=None):
        val_images, val_actions, val_states = build_tfrecord_input(training=False)
        val_model = Model(val_images, val_actions, val_states,
                          FLAGS.sequence_length, training_scope, prefix='val')

    print('Constructing saver.')
    # Make saver; max_to_keep=0 keeps every checkpoint instead of only the newest.
    saver = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)

    # Make training session.
    sess = tf.InteractiveSession()
    summary_writer = tf.summary.FileWriter(
        FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)

    # Initialize first, then restore: running the initializer after
    # saver.restore() would overwrite the pretrained weights.
    sess.run(tf.global_variables_initializer())
    if FLAGS.pretrained_model:
        saver.restore(sess, FLAGS.pretrained_model)

    # Start the queue runners that feed the TFRecord input pipeline.
    tf.train.start_queue_runners(sess)

    tf.logging.info('iteration number, cost')

    # Run training.
    for itr in range(FLAGS.num_iterations):
        # Each sess.run dequeues a new batch from the input queues; only the
        # iteration number and learning rate are fed explicitly.
        feed_dict = {model.iter_num: np.float32(itr),
                     model.lr: FLAGS.learning_rate}
        cost, _, summary_str = sess.run(
            [model.loss, model.train_op, model.summ_op], feed_dict)

        # Log iteration number and cost.
        tf.logging.info(str(itr) + ' ' + str(cost))

        if itr % VAL_INTERVAL == 2:
            # Run one validation batch; lr=0.0 leaves the shared weights unchanged.
            feed_dict = {val_model.lr: 0.0,
                         val_model.iter_num: np.float32(itr)}
            _, val_summary_str = sess.run(
                [val_model.train_op, val_model.summ_op], feed_dict)
            summary_writer.add_summary(val_summary_str, itr)

        if itr % SAVE_INTERVAL == 2:
            tf.logging.info('Saving model.')
            saver.save(sess, FLAGS.output_dir + '/model' + str(itr))

        # Write training summaries once every SUMMARY_INTERVAL iterations
        # (a bare `itr % SUMMARY_INTERVAL` would fire on all other iterations).
        if itr % SUMMARY_INTERVAL == 0:
            summary_writer.add_summary(summary_str, itr)

    tf.logging.info('Saving model.')
    saver.save(sess, FLAGS.output_dir + '/model')
    tf.logging.info('Training complete')
    tf.logging.flush()
```
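The periodic branches inside the training loop are easy to misread, so the following minimal pure-Python sketch (no TensorFlow required) lists which iterations trigger validation, checkpointing and summary writing. The helper names `periodic_actions`, `schedule` and `summary_steps_as_written`, and the small interval values in the usage note, are hypothetical and chosen only for illustration; the real script defines its own VAL_INTERVAL, SAVE_INTERVAL and SUMMARY_INTERVAL.

```python
def periodic_actions(itr, val_interval, save_interval, summary_interval):
    """Return the side actions the training loop takes at iteration `itr`."""
    actions = []
    # Validation and checkpointing fire when itr % interval == 2,
    # i.e. two steps into each interval (2, 2 + interval, ...).
    if itr % val_interval == 2:
        actions.append('validate')
    if itr % save_interval == 2:
        actions.append('checkpoint')
    # Training summaries are written on multiples of the summary interval.
    if itr % summary_interval == 0:
        actions.append('summary')
    return actions


def schedule(num_iterations, val_interval, save_interval, summary_interval):
    """Map every iteration that triggers at least one action to its actions."""
    plan = {}
    for itr in range(num_iterations):
        acts = periodic_actions(itr, val_interval, save_interval,
                                summary_interval)
        if acts:
            plan[itr] = acts
    return plan


def summary_steps_as_written(num_iterations, summary_interval):
    """Iterations on which a bare `if itr % summary_interval:` is truthy."""
    # A nonzero remainder is truthy, so this selects every iteration that is
    # NOT a multiple of the interval.
    return [itr for itr in range(num_iterations) if itr % summary_interval]


if __name__ == '__main__':
    print(schedule(10, val_interval=4, save_interval=8, summary_interval=5))
    print(summary_steps_as_written(10, summary_interval=5))
```

With intervals 4, 8 and 5 over 10 iterations, `schedule` returns `{0: ['summary'], 2: ['validate', 'checkpoint'], 5: ['summary'], 6: ['validate']}`, whereas a bare `if itr % SUMMARY_INTERVAL:` test would write summaries on iterations 1, 2, 3, 4, 6, 7, 8 and 9, which is why the listing above compares the remainder against 0.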