本文整理汇总了Python中densenet.DenseNet方法的典型用法代码示例。如果您正苦于以下问题:Python densenet.DenseNet方法的具体用法?Python densenet.DenseNet怎么用?Python densenet.DenseNet使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块densenet的用法示例。
在下文中一共展示了densenet.DenseNet方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_net
# 需要导入模块: import densenet [as 别名]
# 或者: from densenet import DenseNet [as 别名]
def get_net(args):
    """Build the network selected by ``args.model``.

    Args:
        args: Parsed command-line namespace. ``args.model`` selects the
            architecture. Depending on that choice, ``nHidden``, ``proj``,
            ``nineq``, ``neq`` and ``bn`` are also read from it.

    Returns:
        The constructed network object.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """
    if args.model == 'densenet':
        # Hyperparameters are hard-coded here rather than taken from args.
        net = densenet.DenseNet(growthRate=12, depth=100, reduction=0.5,
                                bottleneck=True, nClasses=10)
    elif args.model == 'lenet':
        net = models.Lenet(args.nHidden, 10, args.proj)
    elif args.model == 'lenet-optnet':
        net = models.LenetOptNet(args.nHidden, args.nineq)
    elif args.model == 'fc':
        net = models.FC(args.nHidden, args.bn)
    elif args.model == 'optnet':
        # 28*28 input features -- presumably flattened 28x28 images; confirm.
        net = models.OptNet(28*28, args.nHidden, 10, args.bn, args.nineq)
    elif args.model == 'optnet-eq':
        net = models.OptNetEq(28*28, args.nHidden, 10, args.neq)
    else:
        # An explicit exception instead of ``assert(False)``: asserts are
        # stripped under ``python -O``, which previously let execution fall
        # through to ``return net`` with ``net`` unbound.
        raise ValueError("unknown model: %r" % (args.model,))
    return net
示例2: main
# 需要导入模块: import densenet [as 别名]
# 或者: from densenet import DenseNet [as 别名]
def main():
    """Evaluate the latest checkpoint on test images, or export the model.

    Restores the newest checkpoint in ``args.model_dir``.
    - If ``args.export`` is set, the model is exported via ``export_model``
      and the process exits.
    - Otherwise, images from ``args.test_image_list`` are classified one at
      a time. Each prediction is printed, followed by the overall accuracy.

    Relies on module-level ``args``, ``tf``, ``np``, ``densenet``,
    ``data_generator``, ``load_vocab`` and ``export_model``.

    Raises:
        ValueError: if no checkpoint exists in ``args.model_dir``.
    """
    import sys

    save_path = tf.train.latest_checkpoint(args.model_dir)
    if save_path is None:
        # latest_checkpoint returns None when there is nothing to restore;
        # fail early with a message that names the directory.
        raise ValueError("no checkpoint found in %s" % (args.model_dir,))
    model = densenet.DenseNet(1, args.num_class, mode='test')
    saver = tf.train.Saver()
    id_to_word = load_vocab()
    counter = 0
    right_counter = 0
    with tf.Session() as sess:
        saver.restore(sess=sess, save_path=save_path)
        if args.export:
            export_model(sess, model)
            # sys.exit rather than the ``exit`` helper, which the site module
            # provides only for interactive use.
            sys.exit(0)
        print("load model from %s"%(save_path))
        batches = data_generator.get_batch(
            args.test_image_list, batch_size=1, mode='test',
            workers=1, max_queue_size=12)
        for batch_data in batches:
            # Each batch is (images, labels, image_paths).
            image = np.array(batch_data[0])
            label = batch_data[1]
            image_path = batch_data[2]
            feed_dict = {model.images: image}
            prediction, probs = sess.run(
                [model.prediction, model.predict_prob], feed_dict=feed_dict)
            predict_id = prediction[0]
            predict_label = id_to_word[predict_id]
            # Probability the model assigned to its own predicted class.
            predict_prob = probs[0][predict_id]
            true_label = id_to_word[label[0]]
            print("image_path: %s, true_id: %d, true_label: %s, predict_label: %s, predict_prob: %f"%(
                image_path, label[0], true_label, predict_label, predict_prob))
            if true_label == predict_label:
                right_counter += 1
            counter += 1
            # NOTE(review): counter is checked after incrementing, so 101
            # images are evaluated -- confirm whether 100 was intended.
            if counter > 100:
                break
    if counter == 0:
        # Avoid ZeroDivisionError when the generator yields nothing.
        print("no test images were evaluated")
    else:
        print("acc : %f"%(1.0 * right_counter / counter ))
示例3: train
# 需要导入模块: import densenet [as 别名]
# 或者: from densenet import DenseNet [as 别名]
def train():
    """Train the DenseNet model with Adam, writing summaries and checkpoints.

    Reads ``args.batch_size``, ``args.num_class``, ``args.checkpoint_path``
    and ``args.train_image_list`` from the module-level ``args``.

    Behavior:
    - Resumes from the latest checkpoint in ``args.checkpoint_path`` if one
      exists; otherwise initializes all global variables.
    - Every 10 local steps, prints progress.
    - Every 100 local steps, writes a summary and saves a checkpoint.
    - Exits the process with status 1 if the loss becomes NaN or infinite.
    """
    import sys

    batch_size = args.batch_size
    num_class = args.num_class
    model = densenet.DenseNet(batch_size=batch_size, num_classes=num_class)
    global_step = tf.train.get_or_create_global_step()
    start_learning_rate = 0.0001
    # Continuous decay: lr = 1e-4 * 0.98 ** (global_step / 100000).
    learning_rate = tf.train.exponential_decay(
        start_learning_rate,
        global_step,
        100000,
        0.98,
        staircase=False,
        name="learning_rate"
    )
    # Ops the model registered in UPDATE_OPS while being built (presumably
    # batch-norm statistics) are grouped with the optimizer step so they run
    # on every training iteration.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        loss=model.loss, global_step=global_step)
    train_op = tf.group([train_op, update_ops])
    saver = tf.train.Saver()
    tf.summary.scalar(name='loss', tensor=model.loss)
    tf.summary.scalar(name='accuracy', tensor=model.accuracy)
    merge_summary_op = tf.summary.merge_all()
    sess_config = tf.ConfigProto(allow_soft_placement=True,)
    with tf.Session(config=sess_config) as sess:
        ckpt = tf.train.latest_checkpoint(args.checkpoint_path)
        if ckpt:
            print("restore from %s "%(ckpt))
            # Checkpoints are saved below as ".../model-<global_step>".
            st = int(ckpt.split('-')[-1])
            saver.restore(sess, ckpt)
            sess.run(global_step.assign(st))
        else:
            tf.global_variables_initializer().run()
        summary_writer = tf.summary.FileWriter(args.checkpoint_path)
        try:
            summary_writer.add_graph(sess.graph)
            start_time = time.time()
            step = 0
            iterator = data_generator.get_batch(args.train_image_list, batch_size)
            for batch in iterator:
                if batch is None:
                    print("batch is None")
                    continue
                image = batch[0]
                labels = batch[1]
                feed_dict = {model.images: image, model.labels: labels}
                _, loss, accuracy, summary, g_step, logits, lr = sess.run(
                    [train_op, model.loss, model.accuracy, merge_summary_op,
                     global_step, model.logits, learning_rate],
                    feed_dict=feed_dict)
                # sess.run never returns None for a tensor, so the former
                # ``loss is None`` check could not fire. Detect divergence
                # explicitly and exit non-zero.
                if not np.isfinite(loss):
                    print(np.max(logits), np.min(logits))
                    sys.exit(1)
                if step % 10 == 0:
                    print(np.max(logits), np.min(logits))
                    print("step:%d, lr: %f, loss: %f, accuracy: %f"%(g_step, lr, loss, accuracy))
                if step % 100 == 0:
                    summary_writer.add_summary(summary=summary, global_step=g_step)
                    # Was ``args.checkpont_path`` (typo), an AttributeError
                    # on the first save.
                    saver.save(sess=sess,
                               save_path=os.path.join(args.checkpoint_path, 'model'),
                               global_step=g_step)
                step += 1
            print("cost: ", time.time() - start_time)
        finally:
            # Flush pending events and release the event file even on exit.
            summary_writer.close()