当前位置: 首页>>代码示例>>Python>>正文


Python vgg.preprocess方法代码示例

本文整理汇总了Python中vgg.preprocess方法的典型用法代码示例。如果您正苦于以下问题:Python vgg.preprocess方法的具体用法?Python vgg.preprocess怎么用?Python vgg.preprocess使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在vgg的用法示例。


在下文中一共展示了vgg.preprocess方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: import vgg [as 别名]
# 或者: from vgg import preprocess [as 别名]
def train(self):
        with tf.Session() as sess:
            out_im = self.U_net(self.holder[41]/127.5-1)
            
            gt_resize = tf.image.resize_images(self.holder[42]/127.5-1, [256,256])
            image_pre = vgg.preprocess(gt_resize)
            fai_imgt = {}
            net = vgg.net(self.vgg_path, image_pre)
            for layer in self.vgg_layer:
                fai_imgt[layer] = net[layer]
                
            
            image_pre = vgg.preprocess(tf.image.resize_images(out_im, [256,256]))
            fai_imout = {}
            net = vgg.net(self.vgg_path, image_pre)
            for layer in self.vgg_layer:
                fai_imout[layer] = net[layer]
            
            Im_compt = self.holder[16]*self.holder[42]+(tf.add(tf.multiply(self.holder[16],-1),1))*((out_im+1)*127.5)
            im_compt = tf.image.resize_images(Im_compt/127.5-1, [256,256])
            image_pre = vgg.preprocess(im_compt)
            fai_compt = {}
            net = vgg.net(self.vgg_path, image_pre)
            for layer in self.vgg_layer:
                fai_compt[layer] = net[layer]
                
            U_vars = [var for var in tf.trainable_variables() if 'UNET' in var.name]
            total_loss = get_total_loss(out_im,self.holder[-1]/127.5-1,self.holder[16],fai_imout,fai_imgt,fai_compt,self.vgg_layer,im_compt)
            optim = tf.train.AdamOptimizer()
            optimizer = optim.minimize(total_loss[0],var_list=U_vars)
            
    
            int_group = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
            sess.run(int_group)
            
            graph = tf.summary.FileWriter(self.logdir, sess.graph)
            saver = tf.train.Saver(U_vars,max_to_keep=20)
            
            for epoch in range(self.num_epochs):
                for imid in range(int(self.total_ims//self.batch)):
                    mask_ims,gt_ims = get_im(self.ims_dir,imid)
                    self.get_all_mask(mask_ims,gt_ims)
                    feed_dic = get_feedict(self.all_masks,self.holder)
                    _,loss_total = sess.run([optimizer,total_loss],feed_dict=feed_dic)
                    
                    if (int(epoch*self.total_ims)+imid)%1==0:
                        print('epoch: %d,  cur_num: %d,  total_loss: %f, l_hole: %f, l_valid: %f, percept_loss: %f, style_loss_out: %f, style_loss_comp: %f, tv_loss: %f'%(epoch,imid,loss_total[0],loss_total[1],loss_total[2],loss_total[3],loss_total[4],loss_total[5],loss_total[6]))
                if epoch%5==0:
                    saver.save(sess, self.save_path+'model.ckpt', global_step=epoch) 
开发者ID:Rongpeng-Lin,项目名称:PConv_in_tf,代码行数:51,代码来源:main.py


注:本文中的vgg.preprocess方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。