本文整理匯總了Python中vgg.preprocess方法的典型用法代碼示例。如果您正苦於以下問題:Python vgg.preprocess方法的具體用法?Python vgg.preprocess怎麽用?Python vgg.preprocess使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類vgg
的用法示例。
在下文中一共展示了vgg.preprocess方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: train
# 需要導入模塊: import vgg [as 別名]
# 或者: from vgg import preprocess [as 別名]
def train(self):
with tf.Session() as sess:
out_im = self.U_net(self.holder[41]/127.5-1)
gt_resize = tf.image.resize_images(self.holder[42]/127.5-1, [256,256])
image_pre = vgg.preprocess(gt_resize)
fai_imgt = {}
net = vgg.net(self.vgg_path, image_pre)
for layer in self.vgg_layer:
fai_imgt[layer] = net[layer]
image_pre = vgg.preprocess(tf.image.resize_images(out_im, [256,256]))
fai_imout = {}
net = vgg.net(self.vgg_path, image_pre)
for layer in self.vgg_layer:
fai_imout[layer] = net[layer]
Im_compt = self.holder[16]*self.holder[42]+(tf.add(tf.multiply(self.holder[16],-1),1))*((out_im+1)*127.5)
im_compt = tf.image.resize_images(Im_compt/127.5-1, [256,256])
image_pre = vgg.preprocess(im_compt)
fai_compt = {}
net = vgg.net(self.vgg_path, image_pre)
for layer in self.vgg_layer:
fai_compt[layer] = net[layer]
U_vars = [var for var in tf.trainable_variables() if 'UNET' in var.name]
total_loss = get_total_loss(out_im,self.holder[-1]/127.5-1,self.holder[16],fai_imout,fai_imgt,fai_compt,self.vgg_layer,im_compt)
optim = tf.train.AdamOptimizer()
optimizer = optim.minimize(total_loss[0],var_list=U_vars)
int_group = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(int_group)
graph = tf.summary.FileWriter(self.logdir, sess.graph)
saver = tf.train.Saver(U_vars,max_to_keep=20)
for epoch in range(self.num_epochs):
for imid in range(int(self.total_ims//self.batch)):
mask_ims,gt_ims = get_im(self.ims_dir,imid)
self.get_all_mask(mask_ims,gt_ims)
feed_dic = get_feedict(self.all_masks,self.holder)
_,loss_total = sess.run([optimizer,total_loss],feed_dict=feed_dic)
if (int(epoch*self.total_ims)+imid)%1==0:
print('epoch: %d, cur_num: %d, total_loss: %f, l_hole: %f, l_valid: %f, percept_loss: %f, style_loss_out: %f, style_loss_comp: %f, tv_loss: %f'%(epoch,imid,loss_total[0],loss_total[1],loss_total[2],loss_total[3],loss_total[4],loss_total[5],loss_total[6]))
if epoch%5==0:
saver.save(sess, self.save_path+'model.ckpt', global_step=epoch)