當前位置: 首頁>>代碼示例>>Python>>正文


Python densenet.DenseNet方法代碼示例

本文整理匯總了Python中densenet.DenseNet方法的典型用法代碼示例。如果您正苦於以下問題:Python densenet.DenseNet方法的具體用法?Python densenet.DenseNet怎麼用?Python densenet.DenseNet使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊densenet的用法示例。


在下文中一共展示了densenet.DenseNet方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_net

# 需要導入模塊: import densenet [as 別名]
# 或者: from densenet import DenseNet [as 別名]
def get_net(args):
    """Construct the network selected by ``args.model``.

    Args:
        args: Parsed command-line namespace. ``args.model`` selects the
            architecture; depending on that choice the function also reads
            ``args.nHidden``, ``args.proj``, ``args.nineq``, ``args.bn`` or
            ``args.neq``.

    Returns:
        The constructed network object.

    Raises:
        ValueError: If ``args.model`` is not one of ``'densenet'``,
            ``'lenet'``, ``'lenet-optnet'``, ``'fc'``, ``'optnet'`` or
            ``'optnet-eq'``.
    """
    if args.model == 'densenet':
        # Fixed configuration: growth rate 12, depth 100, bottleneck layers,
        # 0.5 reduction, 10 output classes.
        net = densenet.DenseNet(growthRate=12, depth=100, reduction=0.5,
                                bottleneck=True, nClasses=10)
    elif args.model == 'lenet':
        net = models.Lenet(args.nHidden, 10, args.proj)
    elif args.model == 'lenet-optnet':
        net = models.LenetOptNet(args.nHidden, args.nineq)
    elif args.model == 'fc':
        net = models.FC(args.nHidden, args.bn)
    elif args.model == 'optnet':
        # 28*28 input features -- presumably flattened 28x28 images; confirm
        # against the data loader.
        net = models.OptNet(28*28, args.nHidden, 10, args.bn, args.nineq)
    elif args.model == 'optnet-eq':
        net = models.OptNetEq(28*28, args.nHidden, 10, args.neq)
    else:
        # An explicit exception (instead of `assert False`) still fires under
        # `python -O` and names the offending value instead of falling
        # through to an UnboundLocalError on `net`.
        raise ValueError("Unknown model: %r" % (args.model,))

    return net
開發者ID:locuslab,項目名稱:optnet,代碼行數:20,代碼來源:train.py

示例2: main

# 需要導入模塊: import densenet [as 別名]
# 或者: from densenet import DenseNet [as 別名]
def main():
  """Evaluate the newest checkpoint in ``args.model_dir``.

  If ``args.export`` is set, the restored model is passed to ``export_model``
  and the process exits with status 0. Otherwise images from
  ``args.test_image_list`` are classified one at a time. A line is printed
  per image. Evaluation stops after at most 101 images, and the accuracy
  over the images seen is printed.

  Relies on module-level ``args``, ``tf``, ``np``, ``densenet``,
  ``data_generator``, ``load_vocab`` and ``export_model``.

  Raises:
    ValueError: If ``args.model_dir`` contains no checkpoint.
  """
  import sys

  save_path = tf.train.latest_checkpoint(args.model_dir)
  if save_path is None:
    # latest_checkpoint returns None when no checkpoint exists; fail early
    # with a message that names the directory that was searched.
    raise ValueError("no checkpoint found in %s" % (args.model_dir,))
  model = densenet.DenseNet(1, args.num_class, mode='test')
  saver = tf.train.Saver()
  id_to_word = load_vocab()

  with tf.Session() as sess:
    saver.restore(sess=sess, save_path=save_path)
    if args.export:
      export_model(sess, model)
      # sys.exit instead of the site-injected exit(), which is not
      # guaranteed to exist (e.g. `python -S` or frozen executables).
      sys.exit(0)

    print("load model from %s"%(save_path))
    counter = 0
    right_counter = 0
    for batch_data in data_generator.get_batch(args.test_image_list, batch_size=1, mode='test', workers=1, max_queue_size=12):
      image = np.array(batch_data[0])
      label = batch_data[1]
      image_path = batch_data[2]
      feed_dict = {model.images: image}
      prediction, predict_prob = sess.run([model.prediction, model.predict_prob], feed_dict=feed_dict)
      # batch_size=1, so index 0 selects the only image in the batch.
      predict_id = prediction[0]
      predict_label = id_to_word[predict_id]
      # Probability assigned to the predicted class (assumes predict_prob is
      # (batch, num_class) -- TODO confirm against the model definition).
      predict_prob = predict_prob[0][predict_id]
      true_label = id_to_word[label[0]]
      print("image_path: %s, true_id: %d, true_label: %s, predict_label: %s, predict_prob: %f"%(
        image_path, label[0], true_label ,predict_label, predict_prob))

      if true_label == predict_label :
        right_counter += 1
      counter += 1
      # Stops once 101 images have been evaluated.
      if counter > 100:
        break
    if counter == 0:
      # Avoid ZeroDivisionError when the test list yields no batches.
      print("acc : no test images evaluated")
    else:
      print("acc : %f"%(1.0 * right_counter / counter ))
開發者ID:rockyzhengwu,項目名稱:document-ocr,代碼行數:36,代碼來源:eval.py

示例3: train

# 需要導入模塊: import densenet [as 別名]
# 或者: from densenet import DenseNet [as 別名]
def train():
  """Train the DenseNet classifier on ``args.train_image_list``.

  Builds the model and an Adam optimizer whose learning rate decays
  exponentially from 1e-4. The newest checkpoint in
  ``args.checkpoint_path`` is restored if one exists; otherwise all
  variables are initialized. The function then iterates over
  ``data_generator.get_batch``:
  - Every 10 steps, it prints the logits range, learning rate, loss and
    accuracy.
  - Every 100 steps, it writes a summary and saves a checkpoint under
    ``args.checkpoint_path``.

  If the loss becomes NaN or infinite, the process exits with status 1.

  Relies on module-level ``args``, ``tf``, ``np``, ``os``, ``time``,
  ``densenet`` and ``data_generator``.
  """
  import sys

  batch_size = args.batch_size
  num_class = args.num_class
  model = densenet.DenseNet(batch_size=batch_size, num_classes=num_class)
  global_step = tf.train.get_or_create_global_step()
  start_learning_rate = 0.0001
  # lr = 1e-4 * 0.98 ** (global_step / 100000), decayed continuously
  # because staircase=False.
  learning_rate = tf.train.exponential_decay(
    start_learning_rate,
    global_step,
    100000,
    0.98,
    staircase=False,
    name="learning_rate"
  )
  # Ops registered in UPDATE_OPS (e.g. batch-norm moving statistics, if the
  # model creates any) are run together with every optimizer step.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=model.loss, global_step=global_step)
  train_op = tf.group([train_op, update_ops])
  saver = tf.train.Saver()
  tf.summary.scalar(name='loss', tensor=model.loss)
  tf.summary.scalar(name='accuracy', tensor=model.accuracy)
  merge_summary_op = tf.summary.merge_all()
  sess_config = tf.ConfigProto(allow_soft_placement=True,)
  # Loop-invariant checkpoint prefix; the misspelled `args.checkpont_path`
  # used previously raised AttributeError on the first save.
  save_prefix = os.path.join(args.checkpoint_path, 'model')
  with tf.Session(config=sess_config) as sess:
    ckpt = tf.train.latest_checkpoint(args.checkpoint_path)
    if ckpt:
      print("restore from %s "%(ckpt))
      # Checkpoints are saved below with global_step, so the prefix ends in
      # "-<step>".
      st = int(ckpt.split('-')[-1])
      saver.restore(sess, ckpt)
      sess.run(global_step.assign(st))
    else:
      tf.global_variables_initializer().run()
    summary_writer = tf.summary.FileWriter(args.checkpoint_path)
    try:
      summary_writer.add_graph(sess.graph)
      start_time = time.time()
      step = 0
      iterator = data_generator.get_batch(args.train_image_list, batch_size)
      for batch in iterator:
        if batch is None:
          print("batch is None")
          continue
        image = batch[0]
        labels = batch[1]
        feed_dict = {model.images: image, model.labels: labels}
        _, loss, accuracy, summary, g_step, logits, lr = sess.run(
                [train_op, model.loss, model.accuracy, merge_summary_op, global_step, model.logits, learning_rate],
                feed_dict=feed_dict)
        # sess.run never returns None for a tensor fetch; the meaningful
        # divergence check is a non-finite loss.
        if not np.isfinite(loss):
          print("loss is not finite at step %d" % (g_step,))
          print(np.max(logits), np.min(logits))
          sys.exit(1)
        if step % 10 == 0:
          print(np.max(logits), np.min(logits))
          print("step:%d, lr: %f, loss: %f, accuracy: %f"%(g_step, lr, loss, accuracy))
        if step % 100 == 0:
          summary_writer.add_summary(summary=summary, global_step=g_step)
          saver.save(sess=sess, save_path=save_prefix, global_step=g_step)
        step += 1
      print("cost: ", time.time() - start_time)
    finally:
      # Flush buffered events to disk even if training stops early.
      summary_writer.close()
開發者ID:rockyzhengwu,項目名稱:document-ocr,代碼行數:62,代碼來源:train.py


注:本文中的densenet.DenseNet方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。