当前位置: 首页>>代码示例>>Python>>正文


Python densenet.DenseNet方法代码示例

本文整理汇总了 Python 中 densenet.DenseNet 方法的典型用法代码示例。如果您正苦于以下问题:Python densenet.DenseNet 方法的具体用法是什么?Python densenet.DenseNet 怎么用?有哪些 Python densenet.DenseNet 的使用例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 densenet 的用法示例。


在下文中一共展示了densenet.DenseNet方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_net

# 需要导入模块: import densenet [as 别名]
# 或者: from densenet import DenseNet [as 别名]
def get_net(args):
    """Build the network selected by ``args.model``.

    Args:
        args: Options object. ``args.model`` chooses the architecture.
            Depending on that choice, ``args.nHidden``, ``args.proj``,
            ``args.nineq``, ``args.neq`` and ``args.bn`` are also read.

    Returns:
        The constructed network instance.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names
            ('densenet', 'lenet', 'lenet-optnet', 'fc', 'optnet',
            'optnet-eq').
    """
    if args.model == 'densenet':
        net = densenet.DenseNet(growthRate=12, depth=100, reduction=0.5,
                                bottleneck=True, nClasses=10)
    elif args.model == 'lenet':
        net = models.Lenet(args.nHidden, 10, args.proj)
    elif args.model == 'lenet-optnet':
        net = models.LenetOptNet(args.nHidden, args.nineq)
    elif args.model == 'fc':
        net = models.FC(args.nHidden, args.bn)
    elif args.model == 'optnet':
        net = models.OptNet(28*28, args.nHidden, 10, args.bn, args.nineq)
    elif args.model == 'optnet-eq':
        net = models.OptNetEq(28*28, args.nHidden, 10, args.neq)
    else:
        # Raise explicitly instead of ``assert(False)``: assertions are
        # removed under ``python -O``, which would let control fall through
        # to ``return net`` and fail with an unrelated UnboundLocalError.
        raise ValueError('Unknown model: %r' % (args.model,))

    return net
开发者ID:locuslab,项目名称:optnet,代码行数:20,代码来源:train.py

示例2: main

# 需要导入模块: import densenet [as 别名]
# 或者: from densenet import DenseNet [as 别名]
def main():
  """Restore the latest checkpoint and print accuracy on the test image list.

  Reads the module-level ``args`` (``model_dir``, ``num_class``, ``export``,
  ``test_image_list``). When ``args.export`` is truthy, the restored model is
  passed to ``export_model`` and the process exits without evaluating.
  Otherwise each batch is run through the model, a per-image line is printed,
  and the fraction of correctly predicted labels is printed at the end.
  """
  save_path = tf.train.latest_checkpoint(args.model_dir)
  model = densenet.DenseNet(1, args.num_class, mode='test')
  saver = tf.train.Saver()
  id_to_word = load_vocab()

  with tf.Session() as sess:
    saver.restore(sess=sess, save_path=save_path)
    if args.export:
      export_model(sess, model)
      exit(0)

    print("load model from %s"%(save_path))
    counter = 0
    right_counter = 0
    for batch_data in data_generator.get_batch(args.test_image_list, batch_size=1, mode='test', workers=1, max_queue_size=12):
      image = np.array(batch_data[0])
      label = batch_data[1]
      image_path = batch_data[2]
      feed_dict = {model.images: image}
      prediction, predict_prob = sess.run([model.prediction, model.predict_prob], feed_dict=feed_dict)
      # batch_size is 1, so only the first entry of each output is used.
      predict_id = prediction[0]
      predict_label = id_to_word[predict_id]
      # Probability assigned to the predicted class.
      predict_prob = predict_prob[0][predict_id]
      true_label = id_to_word[label[0]]
      print("image_path: %s, true_id: %d, true_label: %s, predict_label: %s, predict_prob: %f"%(
        image_path, label[0], true_label ,predict_label, predict_prob))

      if true_label == predict_label :
        right_counter += 1
      counter += 1
      # Stop once more than 100 samples have been evaluated (i.e. after 101).
      if counter > 100:
        break
    if counter == 0:
      # The generator produced no batches; dividing by ``counter`` would
      # raise ZeroDivisionError, so report the situation instead.
      print("acc : no samples evaluated")
    else:
      print("acc : %f"%(1.0 * right_counter / counter ))
开发者ID:rockyzhengwu,项目名称:document-ocr,代码行数:36,代码来源:eval.py

示例3: train

# 需要导入模块: import densenet [as 别名]
# 或者: from densenet import DenseNet [as 别名]
def train():
  """Train the DenseNet model, logging summaries and saving checkpoints.

  Reads the module-level ``args`` (``batch_size``, ``num_class``,
  ``checkpoint_path``, ``train_image_list``). If a checkpoint exists under
  ``args.checkpoint_path`` it is restored and the global step is set from the
  checkpoint file name; otherwise all variables are freshly initialized.
  Summaries and checkpoints are both written to ``args.checkpoint_path``.
  """
  batch_size = args.batch_size
  num_class = args.num_class
  model = densenet.DenseNet(batch_size=batch_size, num_classes=num_class)
  global_step = tf.train.get_or_create_global_step()
  start_learning_rate= 0.0001
  learning_rate = tf.train.exponential_decay(
    start_learning_rate,
    global_step,
    100000,
    0.98,
    staircase=False,
    name="learning_rate"
  )
  # Run the graph's UPDATE_OPS collection together with the optimizer step.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  train_op= tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=model.loss, global_step=global_step)
  train_op = tf.group([train_op, update_ops])
  #optimizer=tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9).minimize(loss=model.loss)
  saver = tf.train.Saver()
  tf.summary.scalar(name='loss', tensor=model.loss)
  #tf.summary.scalar(name='softmax_loss', tensor=model.softmax_loss)
  #tf.summary.scalar(name='center_loss', tensor=model.center_loss)
  tf.summary.scalar(name='accuracy', tensor=model.accuracy)
  merge_summary_op = tf.summary.merge_all()
  sess_config = tf.ConfigProto(allow_soft_placement=True,)
  with tf.Session(config=sess_config) as sess:
    ckpt = tf.train.latest_checkpoint(args.checkpoint_path)
    if ckpt:
      print("restore form %s "%(ckpt))
      # The text after the last '-' in the checkpoint path is taken as the
      # step number to resume from.
      st = int(ckpt.split('-')[-1])
      saver.restore(sess, ckpt)
      sess.run(global_step.assign(st))
    else:
      tf.global_variables_initializer().run()
    summary_writer = tf.summary.FileWriter(args.checkpoint_path)
    summary_writer.add_graph(sess.graph)
    start_time = time.time()
    # Local iteration counter: starts at 0 on every run, unlike ``g_step``,
    # which continues from a restored checkpoint. It only controls how often
    # progress is printed and checkpoints are saved.
    step = 0
    iterator = data_generator.get_batch(args.train_image_list, batch_size)
    for batch in iterator:
      if batch is None:
        print("batch is None")
        continue
      image = batch[0]
      labels = batch[1]
      feed_dict = {model.images: image, model.labels: labels}
      _, loss, accuracy,summary, g_step, logits, lr = sess.run(
              [train_op, model.loss, model.accuracy, merge_summary_op, global_step, model.logits, learning_rate ], 
              feed_dict=feed_dict)
      # NOTE(review): a fetched loss is not expected to be None; this may
      # have been meant as a NaN/divergence check -- confirm intent.
      if loss is None:
        print(np.max(logits), np.min(logits))
        exit(0)
      if step % 10 ==0:
        print(np.max(logits), np.min(logits))
        print("step:%d, lr: %f, loss: %f, accuracy: %f"%(g_step, lr, loss, accuracy))
      if step % 100 == 0:
        summary_writer.add_summary(summary=summary, global_step=g_step)
        # Fixed attribute typo (was ``args.checkpont_path``) so checkpoints
        # go to the same directory they are restored from above.
        saver.save(sess=sess, save_path=os.path.join(args.checkpoint_path, 'model'), global_step=g_step)
      step += 1
    print("cost: ", time.time() - start_time)
开发者ID:rockyzhengwu,项目名称:document-ocr,代码行数:62,代码来源:train.py


注:本文中的densenet.DenseNet方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。