當前位置: 首頁>>代碼示例>>Python>>正文


Python cifar_input.build_input方法代碼示例

本文整理彙總了Python中cifar_input.build_input方法的典型用法代碼示例。如果您正苦於以下問題：Python cifar_input.build_input方法的具體用法？Python cifar_input.build_input怎麼用？Python cifar_input.build_input使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊cifar_input的用法示例。


在下文中一共展示了cifar_input.build_input方法的4個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_model

# 需要導入模塊: import cifar_input [as 別名]
# 或者: from cifar_input import build_input [as 別名]
def get_model(hps, dataset, train_data_path, mode='train'):
  """Build the CIFAR input pipeline and a ResNet graph on top of it.

  Args:
    hps: model hyper-parameters; `hps.batch_size` is used as the input
      batch size and `hps` itself is handed to `resnet_model.ResNet`.
    dataset: dataset name, forwarded to `cifar_input.build_input`.
    train_data_path: data path/pattern, forwarded to
      `cifar_input.build_input`.
    mode: mode string, forwarded to both the input builder and the model.

  Returns:
    The `resnet_model.ResNet` instance, after `build_graph()` has run.
  """
  batch_images, batch_labels = cifar_input.build_input(
    dataset, train_data_path, hps.batch_size, mode)
  net = resnet_model.ResNet(hps, batch_images, batch_labels, mode)
  net.build_graph()
  return net
開發者ID:JianGoForIt,項目名稱:YellowFin,代碼行數:8,代碼來源:resnet_utils.py

示例2: evaluate

# 需要導入模塊: import cifar_input [as 別名]
# 或者: from cifar_input import build_input [as 別名]
def evaluate(hps):
  """Eval loop.

  Builds the eval input pipeline and model, then repeatedly restores the
  latest checkpoint found under `FLAGS.log_root`, runs
  `FLAGS.eval_batch_count` batches and writes 'Precision' /
  'Best Precision' summaries (plus the model summaries of the last batch)
  to `FLAGS.eval_dir`. Stops after one successful round when
  `FLAGS.eval_once` is set; otherwise waits 60 seconds between rounds,
  and also between retries while no checkpoint is available.

  Args:
    hps: model hyper-parameters; `hps.batch_size` sets the eval batch size.

  Raises:
    ValueError: if `FLAGS.eval_batch_count` is less than 1, since no batch
      would be evaluated and precision would be undefined.
  """
  if FLAGS.eval_batch_count < 1:
    raise ValueError('eval_batch_count must be >= 1, got %r' %
                     (FLAGS.eval_batch_count,))

  images, labels = cifar_input.build_input(
      FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()
  saver = tf.train.Saver()
  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  # Start the input queue runners so `sess.run` can dequeue batches.
  tf.train.start_queue_runners(sess)

  poll_secs = 60  # Delay between eval rounds and between checkpoint retries.
  best_precision = 0.0
  while True:
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      tf.logging.error('Cannot restore checkpoint: %s', e)
      # Back off before retrying: a bare `continue` would skip the sleep at
      # the bottom of the loop and re-poll in a tight busy loop.
      time.sleep(poll_secs)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
      time.sleep(poll_secs)
      continue
    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    total_prediction, correct_prediction = 0, 0
    for _ in six.moves.range(FLAGS.eval_batch_count):
      (summaries, loss, predictions, truth, train_step) = sess.run(
          [model.summaries, model.cost, model.predictions,
           model.labels, model.global_step])

      # Both tensors are compared via argmax along axis 1.
      truth = np.argmax(truth, axis=1)
      predictions = np.argmax(predictions, axis=1)
      correct_prediction += np.sum(truth == predictions)
      total_prediction += predictions.shape[0]

    # total_prediction > 0 is guaranteed by the eval_batch_count check above
    # (assuming non-empty batches).
    precision = 1.0 * correct_prediction / total_prediction
    best_precision = max(precision, best_precision)

    precision_summ = tf.Summary()
    precision_summ.value.add(
        tag='Precision', simple_value=precision)
    summary_writer.add_summary(precision_summ, train_step)
    best_precision_summ = tf.Summary()
    best_precision_summ.value.add(
        tag='Best Precision', simple_value=best_precision)
    summary_writer.add_summary(best_precision_summ, train_step)
    summary_writer.add_summary(summaries, train_step)
    tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f',
                    loss, precision, best_precision)
    summary_writer.flush()

    if FLAGS.eval_once:
      break

    time.sleep(poll_secs)
開發者ID:ringringyi,項目名稱:DOTA_models,代碼行數:58,代碼來源:resnet_main.py

示例3: evaluate

# 需要導入模塊: import cifar_input [as 別名]
# 或者: from cifar_input import build_input [as 別名]
def evaluate(hps):
  """Eval loop.

  Builds the eval input pipeline and model, then repeatedly restores the
  latest checkpoint found under `FLAGS.log_root`, runs
  `FLAGS.eval_batch_count` batches and writes 'Precision' /
  'Best Precision' summaries (plus the model summaries of the last batch)
  to `FLAGS.eval_dir`. Stops after one successful round when
  `FLAGS.eval_once` is set; otherwise waits 60 seconds between rounds,
  and also between retries while no checkpoint is available.

  Args:
    hps: model hyper-parameters; `hps.batch_size` sets the eval batch size.

  Raises:
    ValueError: if `FLAGS.eval_batch_count` is less than 1, since no batch
      would be evaluated and precision would be undefined.
  """
  if FLAGS.eval_batch_count < 1:
    raise ValueError('eval_batch_count must be >= 1, got %r' %
                     (FLAGS.eval_batch_count,))

  images, labels = cifar_input.build_input(
      FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()
  saver = tf.train.Saver()
  # `tf.train.SummaryWriter` was deprecated in favour of
  # `tf.summary.FileWriter` and is absent from TF 1.x; prefer the new name
  # and fall back to the old one on older TensorFlow releases.
  writer_cls = getattr(getattr(tf, 'summary', None), 'FileWriter', None)
  if writer_cls is None:
    writer_cls = tf.train.SummaryWriter
  summary_writer = writer_cls(FLAGS.eval_dir)

  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  # Start the input queue runners so `sess.run` can dequeue batches.
  tf.train.start_queue_runners(sess)

  poll_secs = 60  # Delay between eval rounds and between checkpoint retries.
  best_precision = 0.0
  while True:
    # Sleep only on the retry paths and between rounds, so an `eval_once`
    # run with a checkpoint already present does not wait 60s for nothing.
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      tf.logging.error('Cannot restore checkpoint: %s', e)
      time.sleep(poll_secs)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
      time.sleep(poll_secs)
      continue
    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    total_prediction, correct_prediction = 0, 0
    # `range` instead of the Python-2-only `xrange` (NameError on Python 3).
    for _ in range(FLAGS.eval_batch_count):
      (summaries, loss, predictions, truth, train_step) = sess.run(
          [model.summaries, model.cost, model.predictions,
           model.labels, model.global_step])

      # Both tensors are compared via argmax along axis 1.
      truth = np.argmax(truth, axis=1)
      predictions = np.argmax(predictions, axis=1)
      correct_prediction += np.sum(truth == predictions)
      total_prediction += predictions.shape[0]

    precision = 1.0 * correct_prediction / total_prediction
    best_precision = max(precision, best_precision)

    precision_summ = tf.Summary()
    precision_summ.value.add(
        tag='Precision', simple_value=precision)
    summary_writer.add_summary(precision_summ, train_step)
    best_precision_summ = tf.Summary()
    best_precision_summ.value.add(
        tag='Best Precision', simple_value=best_precision)
    summary_writer.add_summary(best_precision_summ, train_step)
    summary_writer.add_summary(summaries, train_step)
    tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f\n',
                    loss, precision, best_precision)
    summary_writer.flush()

    if FLAGS.eval_once:
      break

    time.sleep(poll_secs)
開發者ID:awslabs,項目名稱:deeplearning-benchmark,代碼行數:57,代碼來源:resnet_main.py

示例4: train

# 需要導入模塊: import cifar_input [as 別名]
# 或者: from cifar_input import build_input [as 別名]
def train(hps):
  """Training loop."""
  images, labels = cifar_input.build_input(
      FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()
  summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)

  sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                           is_chief=True,
                           summary_op=None,
                           save_summaries_secs=60,
                           save_model_secs=300,
                           global_step=model.global_step)
  sess = sv.prepare_or_wait_for_session(
      config=tf.ConfigProto(allow_soft_placement=True))

  step = 0
  lrn_rate = 0.1

  while not sv.should_stop():
    (_, summaries, loss, predictions, truth, train_step) = sess.run(
        [model.train_op, model.summaries, model.cost, model.predictions,
         model.labels, model.global_step],
        feed_dict={model.lrn_rate: lrn_rate})

    if train_step < 40000:
      lrn_rate = 0.1
    elif train_step < 60000:
      lrn_rate = 0.01
    elif train_step < 80000:
      lrn_rate = 0.001
    else:
      lrn_rate = 0.0001

    truth = np.argmax(truth, axis=1)
    predictions = np.argmax(predictions, axis=1)
    precision = np.mean(truth == predictions)

    step += 1
    if step % 100 == 0:
      precision_summ = tf.Summary()
      precision_summ.value.add(
          tag='Precision', simple_value=precision)
      summary_writer.add_summary(precision_summ, train_step)
      summary_writer.add_summary(summaries, train_step)
      tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
      summary_writer.flush()

  sv.Stop() 
開發者ID:coderSkyChen,項目名稱:Action_Recognition_Zoo,代碼行數:52,代碼來源:resnet_main.py


注：本文中的cifar_input.build_input方法示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台，相關代碼片段篩選自各路編程大神貢獻的開源項目，源碼版權歸原作者所有，傳播和使用請參考對應項目的License；未經允許，請勿轉載。