This article collects typical usage examples of the Python method model.Discriminator. If you have been wondering what exactly model.Discriminator does, how to call it, or what real examples look like, the curated code samples below may help. You can also browse further usage examples from the containing module, model.
A total of 3 code examples of model.Discriminator are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: build_model
# Required import: import model [as alias]
# Or alternatively: from model import Discriminator [as alias]
def build_model(self):
    """Create a generator and a discriminator."""
    if self.dataset in ['CelebA', 'RaFD']:
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
    elif self.dataset in ['Both']:
        self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector.
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)

    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
    self.print_network(self.G, 'G')
    self.print_network(self.D, 'D')

    self.G.to(self.device)
    self.D.to(self.device)
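For context, the Discriminator constructed above follows the StarGAN signature (image_size, conv_dim, c_dim, repeat_num): a PatchGAN-style critic with an auxiliary domain classifier. The sketch below is modeled on the public StarGAN implementation and is only an illustration of what such a class might look like; the exact layers in this repository's model.py may differ.

import torch.nn as nn

class Discriminator(nn.Module):
    """PatchGAN-style discriminator with an auxiliary domain classifier (sketch)."""
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(3, conv_dim, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            layers += [nn.Conv2d(curr_dim, curr_dim * 2, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
            curr_dim *= 2
        self.main = nn.Sequential(*layers)
        kernel_size = image_size // (2 ** repeat_num)
        self.conv_src = nn.Conv2d(curr_dim, 1, 3, stride=1, padding=1, bias=False)   # real/fake patch map
        self.conv_cls = nn.Conv2d(curr_dim, c_dim, kernel_size, bias=False)          # domain logits

    def forward(self, x):
        h = self.main(x)
        out_src = self.conv_src(h)                                       # (batch, 1, H', W')
        out_cls = self.conv_cls(h)                                       # (batch, c_dim, 1, 1)
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))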
Example 2: train
# Required import: import model [as alias]
# Or alternatively: from model import Discriminator [as alias]
def train(self):
    batch_num = self.data.length//self.FLAGS.batch_size if self.data.length%self.FLAGS.batch_size==0 else self.data.length//self.FLAGS.batch_size + 1

    print("Start training WGAN...\n")

    for t in range(self.FLAGS.iter):

        d_cost = 0
        g_cost = 0

        # Update the discriminator (critic) d_epoch times per iteration.
        for d_ep in range(self.d_epoch):

            img, tags, _, w_img, w_tags = self.data.next_data_batch(self.FLAGS.batch_size)
            z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)

            feed_dict = {
                self.seq:tags,
                self.img:img,
                self.z:z,
                self.w_seq:w_tags,
                self.w_img:w_img
            }

            _, loss = self.sess.run([self.d_updates, self.d_loss], feed_dict=feed_dict)
            d_cost += loss/self.d_epoch

        # Update the generator once per iteration.
        z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
        feed_dict = {
            self.img:img,
            self.w_seq:w_tags,
            self.w_img:w_img,
            self.seq:tags,
            self.z:z
        }

        _, loss, step = self.sess.run([self.g_updates, self.g_loss, self.global_step], feed_dict=feed_dict)

        current_step = tf.train.global_step(self.sess, self.global_step)

        g_cost = loss

        if current_step % self.FLAGS.display_every == 0:
            print("Epoch {}, Current_step {}".format(self.data.epoch, current_step))
            print("Discriminator loss :{}".format(d_cost))
            print("Generator loss :{}".format(g_cost))
            print("---------------------------------")

        if current_step % self.FLAGS.checkpoint_every == 0:
            path = self.saver.save(self.sess, self.checkpoint_prefix, global_step=current_step)
            print("\nSaved model checkpoint to {}\n".format(path))

        if current_step % self.FLAGS.dump_every == 0:
            self.eval(current_step)
            print("Dump test image")
Example 3: main
# Required import: import model [as alias]
# Or alternatively: from model import Discriminator [as alias]
def main():
    voc = util.Voc(init_from_file="data/voc_b.txt")
    netR_path = 'output/rf_dis.pkg'
    netG_path = 'output/net_p'
    netD_path = 'output/net_d'
    agent_path = 'output/net_gan_%d_%d_%dx%d' % (SIGMA * 10, BL * 10, BATCH_SIZE, MC)

    # Reward environment and pretrained generator (the RL agent).
    netR = util.Environment(netR_path)
    agent = model.Generator(voc)
    agent.load_state_dict(T.load(netG_path + '.pkg'))

    df = pd.read_table('data/CHEMBL251.txt')
    df = df[df['PCHEMBL_VALUE'] >= 6.5]
    data = util.MolData(df, voc)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=data.collate_fn)

    # Discriminator: pretrain with BCE if no checkpoint exists, then load it.
    netD = model.Discriminator(VOCAB_SIZE, EMBED_DIM, FILTER_SIZE, NUM_FILTER)
    if not os.path.exists(netD_path + '.pkg'):
        Train_dis_BCE(netD, agent, loader, epochs=100, out=netD_path)
    netD.load_state_dict(T.load(netD_path + '.pkg'))

    best_score = 0
    log = open(agent_path + '.log', 'w')
    for epoch in range(1000):
        print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
        print('\nPolicy Gradient Training Generator : ')
        Train_GAN(agent, netD, netR)
        print('\nAdversarial Training Discriminator : ')
        Train_dis_BCE(netD, agent, loader, epochs=1)

        # Evaluate: sample molecules, keep unique valid ones, score them.
        seqs = agent.sample(1000)
        ix = util.unique(seqs)
        smiles, valids = util.check_smiles(seqs[ix], agent.voc)
        scores = netR(smiles)
        scores[valids == False] = 0
        unique = (scores >= 0.5).sum() / 1000
        if best_score < unique:
            T.save(agent.state_dict(), agent_path + '.pkg')
            best_score = unique
        print("Epoch+: %d average: %.4f valid: %.4f unique: %.4f" % (epoch, scores.mean(), valids.mean(), unique), file=log)
        for i, smile in enumerate(smiles):
            print('%f\t%s' % (scores[i], smile), file=log)

        # Decay the generator's learning rate by 1% each epoch.
        for param_group in agent.optim.param_groups:
            param_group['lr'] *= (1 - 0.01)
    log.close()
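In this example, model.Discriminator(VOCAB_SIZE, EMBED_DIM, FILTER_SIZE, NUM_FILTER) acts as a SeqGAN-style CNN classifier over token sequences (here, SMILES strings), trained with BCE to tell generated molecules from dataset molecules. The sketch below shows what such a class might look like; the constants and the exact architecture in the original project may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """SeqGAN-style text-CNN discriminator over token sequences (sketch)."""
    def __init__(self, vocab_size, embed_dim, filter_sizes, num_filters):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # One 2-D convolution per filter size; filter_sizes and num_filters are parallel lists.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, n, (size, embed_dim)) for size, n in zip(filter_sizes, num_filters)
        ])
        self.fc = nn.Linear(sum(num_filters), 1)

    def forward(self, tokens):
        emb = self.embedding(tokens).unsqueeze(1)        # (batch, 1, seq_len, embed_dim)
        feats = []
        for conv in self.convs:
            h = F.relu(conv(emb)).squeeze(3)             # (batch, n_filters, seq_len - size + 1)
            feats.append(F.max_pool1d(h, h.size(2)).squeeze(2))
        return torch.sigmoid(self.fc(torch.cat(feats, dim=1)))   # probability of being "real"

Here filter_sizes and num_filters are assumed to be parallel lists (e.g. [3, 5, 7] and [100, 100, 100]); the capitalized constants used in main() would be defined at module level in the original script.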