当前位置: 首页>>代码示例>>Python>>正文


Python model.Discriminator方法代码示例

本文整理汇总了Python中model.Discriminator方法的典型用法代码示例。如果您正苦于以下问题：Python model.Discriminator方法的具体用法？Python model.Discriminator怎么用？Python model.Discriminator使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块model的用法示例。


在下文中一共展示了model.Discriminator方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: build_model

# 需要导入模块: import model [as 别名]
# 或者: from model import Discriminator [as 别名]
def build_model(self):
    """Create a generator and a discriminator, plus their Adam optimizers.

    The label dimension passed to the networks depends on ``self.dataset``:

    * ``'CelebA'`` or ``'RaFD'``: both networks use ``self.c_dim``.
    * ``'Both'``: the generator uses ``self.c_dim + self.c2_dim + 2``
      (the extra 2 is for the mask vector) and the discriminator uses
      ``self.c_dim + self.c2_dim``.

    The networks are stored on ``self.G`` / ``self.D``, the optimizers on
    ``self.g_optimizer`` / ``self.d_optimizer``. Both networks are passed
    to ``self.print_network`` and then moved with ``.to(self.device)``.

    Raises:
        ValueError: if ``self.dataset`` is not one of the names above.
    """
    if self.dataset in ['CelebA', 'RaFD']:
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
    elif self.dataset in ['Both']:
        self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector.
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)
    else:
        # Without this branch an unknown dataset name would only surface
        # later as an AttributeError on self.G, hiding the real cause.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CelebA', 'RaFD' or 'Both'.".format(self.dataset)
        )

    # Adam takes (params, lr, betas) positionally.
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
    self.print_network(self.G, 'G')
    self.print_network(self.D, 'D')

    self.G.to(self.device)
    self.D.to(self.device)
开发者ID:yunjey,项目名称:stargan,代码行数:18,代码来源:solver.py

示例2: train

# 需要导入模块: import model [as 别名]
# 或者: from model import Discriminator [as 别名]
def train(self):
    """Run the WGAN training loop for ``self.FLAGS.iter`` iterations.

    Each iteration performs ``self.d_epoch`` discriminator updates, each on
    a fresh batch from ``self.data.next_data_batch``, followed by a single
    generator update. The generator update reuses the last discriminator
    batch with newly drawn noise.

    After the generator update the current global step is read back from
    the session. Whenever it is a multiple of the matching
    ``self.FLAGS.*_every`` value, the method prints the losses, saves a
    checkpoint through ``self.saver``, and/or calls ``self.eval``.

    NOTE(review): the generator step reads ``img``/``tags`` left over from
    the discriminator loop, so ``self.d_epoch`` must be at least 1.
    """
    print("Start training WGAN...\n")

    for _ in range(self.FLAGS.iter):

        # Mean discriminator loss over the d_epoch updates of this iteration.
        d_cost = 0

        for _ in range(self.d_epoch):

            img, tags, _, w_img, w_tags = self.data.next_data_batch(self.FLAGS.batch_size)
            z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)

            feed_dict = {
                self.seq: tags,
                self.img: img,
                self.z: z,
                self.w_seq: w_tags,
                self.w_img: w_img,
            }

            _, loss = self.sess.run([self.d_updates, self.d_loss], feed_dict=feed_dict)

            d_cost += loss/self.d_epoch

        # Generator step: same data batch as the last discriminator step,
        # but with a new noise sample.
        z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
        feed_dict = {
            self.img: img,
            self.w_seq: w_tags,
            self.w_img: w_img,
            self.seq: tags,
            self.z: z,
        }

        # The fetched global_step value is discarded; the step is re-read
        # below once the update has completed.
        _, g_cost, _ = self.sess.run([self.g_updates, self.g_loss, self.global_step], feed_dict=feed_dict)

        current_step = tf.train.global_step(self.sess, self.global_step)

        if current_step % self.FLAGS.display_every == 0:
            print("Epoch {}, Current_step {}".format(self.data.epoch, current_step))
            print("Discriminator loss :{}".format(d_cost))
            print("Generator loss     :{}".format(g_cost))
            print("---------------------------------")

        if current_step % self.FLAGS.checkpoint_every == 0:
            path = self.saver.save(self.sess, self.checkpoint_prefix, global_step=current_step)
            print ("\nSaved model checkpoint to {}\n".format(path))

        if current_step % self.FLAGS.dump_every == 0:
            self.eval(current_step)
            print("Dump test image")
开发者ID:m516825,项目名称:Conditional-GAN,代码行数:57,代码来源:improved_WGAN.py

示例3: main

# 需要导入模块: import model [as 别名]
# 或者: from model import Discriminator [as 别名]
def main():
    """Fine-tune the generator against a discriminator and a reward model.

    Steps, as visible in this function:

    1. Load the vocabulary, the environment (``netR``) and the pre-trained
       generator (``agent``) from fixed paths under ``data/`` and
       ``output/``.
    2. Build a ``DataLoader`` over the rows of ``data/CHEMBL251.txt`` whose
       ``PCHEMBL_VALUE`` is at least 6.5.
    3. Create the discriminator; if ``output/net_d.pkg`` does not exist,
       train it for 100 epochs first, then load its weights from that file.
    4. For 1000 epochs: run ``Train_GAN`` on the generator, train the
       discriminator for one epoch, sample 1000 sequences, score them, and
       save the generator whenever the fraction scoring >= 0.5 improves.
       Per-epoch statistics and every scored SMILES are written to
       ``<agent_path>.log``. The generator's learning rate is multiplied
       by 0.99 after each epoch.
    """
    voc = util.Voc(init_from_file="data/voc_b.txt")
    netR_path = 'output/rf_dis.pkg'
    netG_path = 'output/net_p'
    netD_path = 'output/net_d'
    agent_path = 'output/net_gan_%d_%d_%dx%d' % (SIGMA * 10, BL * 10, BATCH_SIZE, MC)

    netR = util.Environment(netR_path)

    agent = model.Generator(voc)
    agent.load_state_dict(T.load(netG_path + '.pkg'))

    df = pd.read_table('data/CHEMBL251.txt')
    df = df[df['PCHEMBL_VALUE'] >= 6.5]
    data = util.MolData(df, voc)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=data.collate_fn)

    netD = model.Discriminator(VOCAB_SIZE, EMBED_DIM, FILTER_SIZE, NUM_FILTER)
    if not os.path.exists(netD_path + '.pkg'):
        Train_dis_BCE(netD, agent, loader, epochs=100, out=netD_path)
    netD.load_state_dict(T.load(netD_path + '.pkg'))

    # Number of sequences sampled per epoch; also the denominator of the
    # success ratio below.
    n_samples = 1000
    best_score = 0
    # The context manager closes (and flushes) the log even if training
    # raises part-way through the run.
    with open(agent_path + '.log', 'w') as log:
        for epoch in range(1000):
            print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
            print('\nPolicy Gradient Training Generator : ')
            Train_GAN(agent, netD, netR)

            print('\nAdversarial Training Discriminator : ')
            Train_dis_BCE(netD, agent, loader, epochs=1)

            seqs = agent.sample(n_samples)
            ix = util.unique(seqs)
            smiles, valids = util.check_smiles(seqs[ix], agent.voc)
            scores = netR(smiles)
            # Element-wise mask, so `not valids` / `is False` would be wrong.
            # NOTE(review): assumes valids is array-like; confirm in util.
            scores[valids == False] = 0  # noqa: E712
            # Ratio is taken over all sampled sequences, not only the
            # de-duplicated ones that were scored.
            unique = (scores >= 0.5).sum() / n_samples
            if best_score < unique:
                T.save(agent.state_dict(), agent_path + '.pkg')
                best_score = unique
            print("Epoch+: %d average: %.4f valid: %.4f unique: %.4f" % (epoch, scores.mean(), valids.mean(), unique), file=log)
            for i, smile in enumerate(smiles):
                print('%f\t%s' % (scores[i], smile), file=log)

            # Decay the generator's learning rate by 1% per epoch.
            for param_group in agent.optim.param_groups:
                param_group['lr'] *= (1 - 0.01)
开发者ID:XuhanLiu,项目名称:DrugEx,代码行数:51,代码来源:organic.py


注：本文中的model.Discriminator方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。