當前位置: 首頁>>代碼示例>>Python>>正文


Python vgg.preprocess方法代碼示例

本文整理匯總了Python中vgg.preprocess方法的典型用法代碼示例。如果您正苦於以下問題:Python vgg.preprocess方法的具體用法?Python vgg.preprocess怎麽用?Python vgg.preprocess使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在vgg的用法示例。


在下文中一共展示了vgg.preprocess方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: train

# 需要導入模塊: import vgg [as 別名]
# 或者: from vgg import preprocess [as 別名]
def train(self):
        with tf.Session() as sess:
            out_im = self.U_net(self.holder[41]/127.5-1)
            
            gt_resize = tf.image.resize_images(self.holder[42]/127.5-1, [256,256])
            image_pre = vgg.preprocess(gt_resize)
            fai_imgt = {}
            net = vgg.net(self.vgg_path, image_pre)
            for layer in self.vgg_layer:
                fai_imgt[layer] = net[layer]
                
            
            image_pre = vgg.preprocess(tf.image.resize_images(out_im, [256,256]))
            fai_imout = {}
            net = vgg.net(self.vgg_path, image_pre)
            for layer in self.vgg_layer:
                fai_imout[layer] = net[layer]
            
            Im_compt = self.holder[16]*self.holder[42]+(tf.add(tf.multiply(self.holder[16],-1),1))*((out_im+1)*127.5)
            im_compt = tf.image.resize_images(Im_compt/127.5-1, [256,256])
            image_pre = vgg.preprocess(im_compt)
            fai_compt = {}
            net = vgg.net(self.vgg_path, image_pre)
            for layer in self.vgg_layer:
                fai_compt[layer] = net[layer]
                
            U_vars = [var for var in tf.trainable_variables() if 'UNET' in var.name]
            total_loss = get_total_loss(out_im,self.holder[-1]/127.5-1,self.holder[16],fai_imout,fai_imgt,fai_compt,self.vgg_layer,im_compt)
            optim = tf.train.AdamOptimizer()
            optimizer = optim.minimize(total_loss[0],var_list=U_vars)
            
    
            int_group = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
            sess.run(int_group)
            
            graph = tf.summary.FileWriter(self.logdir, sess.graph)
            saver = tf.train.Saver(U_vars,max_to_keep=20)
            
            for epoch in range(self.num_epochs):
                for imid in range(int(self.total_ims//self.batch)):
                    mask_ims,gt_ims = get_im(self.ims_dir,imid)
                    self.get_all_mask(mask_ims,gt_ims)
                    feed_dic = get_feedict(self.all_masks,self.holder)
                    _,loss_total = sess.run([optimizer,total_loss],feed_dict=feed_dic)
                    
                    if (int(epoch*self.total_ims)+imid)%1==0:
                        print('epoch: %d,  cur_num: %d,  total_loss: %f, l_hole: %f, l_valid: %f, percept_loss: %f, style_loss_out: %f, style_loss_comp: %f, tv_loss: %f'%(epoch,imid,loss_total[0],loss_total[1],loss_total[2],loss_total[3],loss_total[4],loss_total[5],loss_total[6]))
                if epoch%5==0:
                    saver.save(sess, self.save_path+'model.ckpt', global_step=epoch) 
開發者ID:Rongpeng-Lin,項目名稱:PConv_in_tf,代碼行數:51,代碼來源:main.py


注:本文中的vgg.preprocess方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。