本文整理匯總了Python中models.Generator方法的典型用法代碼示例。如果您正苦於以下問題:Python models.Generator方法的具體用法?Python models.Generator怎麼用?Python models.Generator使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類models的用法示例。
在下文中一共展示了models.Generator方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: train
# 需要導入模塊: import models [as 別名]
# 或者: from models import Generator [as 別名]
def train(args):
    """Run a Wasserstein GAN training loop on CIFAR-10.

    Args:
        args: Object exposing the attributes ``nz``, ``batch_size``,
            ``epochs`` and ``gpu``. ``nz`` is forwarded to
            ``GaussianNoiseGenerator`` (presumably the noise vector size --
            confirm against that class), ``epochs`` sets the trainer's stop
            trigger and ``gpu`` is passed to the updater as ``device``.
    """
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs): the data is
    # loaded with scale=2 and then shifted down by one, in place.
    images, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    images -= 1.0

    # One iterator feeds real images, the other feeds Gaussian noise batches.
    image_iter = iterators.SerialIterator(images, batch_size)
    noise_iter = RandomNoiseIterator(
        GaussianNoiseGenerator(0, 1, nz), batch_size)

    # Generator and critic each get their own RMSprop optimizer.
    gen_optimizer = optimizers.RMSprop(lr=0.00005)
    critic_optimizer = optimizers.RMSprop(lr=0.00005)
    gen_optimizer.setup(Generator())
    critic_optimizer.setup(Critic())

    wgan_updater = WassersteinGANUpdater(
        iterator=image_iter,
        noise_iterator=noise_iter,
        optimizer_generator=gen_optimizer,
        optimizer_critic=critic_optimizer,
        device=gpu,
    )

    trainer = training.Trainer(wgan_updater, stop_trigger=(epochs, 'epoch'))

    # Reporting: progress bar, per-iteration log, per-epoch generator
    # samples, and a printed table of the tracked losses.
    report_columns = [
        'epoch',
        'iteration',
        'critic/loss',
        'critic/loss/real',
        'critic/loss/fake',
        'generator/loss',
    ]
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(report_columns))

    trainer.run()
示例2: build_models
# 需要導入模塊: import models [as 別名]
# 或者: from models import Generator [as 別名]
def build_models(hps, current_res_w, use_ema_sampling=False, num_classes=None, label_list=None):  # TODO: fix num_classes
    """Construct the generator, mapping network and discriminator.

    Args:
        hps: Hyper-parameter object; the ``do_*``, ``res_w``, ``start_res_*``,
            ``resize_method``, ``cond_layers`` and ``map_cond`` attributes are
            read from it.
        current_res_w: Passed as the first positional argument to both
            ``Generator`` and ``Discriminator``.
        use_ema_sampling: When true, a second ``Generator`` with the same
            configuration is built and returned as a fourth element.
        num_classes: Forwarded to ``Discriminator`` as ``cgan_nclasses``.
        label_list: Forwarded to ``Discriminator`` as ``label_list``.

    Returns:
        ``(gen_model, mapping_network, dis_model)``, or
        ``(gen_model, mapping_network, dis_model, sampling_model)`` when
        ``use_ema_sampling`` is true. ``mapping_network`` is ``None`` unless
        ``hps.do_mapping_network`` is truthy.
    """

    def _new_generator():
        # Both the training generator and the optional sampling generator
        # are built from exactly the same hyper-parameters.
        return Generator(
            current_res_w,
            hps.res_w,
            use_pixel_norm=hps.do_pixel_norm,
            start_shape=(hps.start_res_h, hps.start_res_w),
            equalized_lr=hps.do_equalized_lr,
            traditional_input=hps.do_traditional_input,
            add_noise=hps.do_add_noise,
            resize_method=hps.resize_method,
            use_mapping_network=hps.do_mapping_network,
            cond_layers=hps.cond_layers,
            map_cond=hps.map_cond,
        )

    if hps.do_mapping_network:
        mapping_network = MappingNetwork()
    else:
        mapping_network = None

    gen_model = _new_generator()

    dis_model = Discriminator(
        current_res_w,
        equalized_lr=hps.do_equalized_lr,
        do_minibatch_stddev=hps.do_minibatch_stddev,
        end_shape=(hps.start_res_h, hps.start_res_w),
        resize_method=hps.resize_method,
        cgan_nclasses=num_classes,
        label_list=label_list,
    )

    if not use_ema_sampling:
        return gen_model, mapping_network, dis_model

    sampling_model = _new_generator()
    return gen_model, mapping_network, dis_model, sampling_model
示例3: __init__
# 需要導入模塊: import models [as 別名]
# 或者: from models import Generator [as 別名]
def __init__(self,
             device,
             model,
             model_num_labels,
             image_nc,
             box_min,
             box_max,
             model_path):
    """Store the configuration and build the generator/discriminator pair.

    Args:
        device: Device the two networks are moved to via ``.to(device)``.
        model: Stored as ``self.model``; not otherwise used here.
        model_num_labels: Stored as ``self.model_num_labels``.
        image_nc: Used for the input, output and generator-input channel
            attributes, and passed to both network constructors.
        box_min: Stored as ``self.box_min``.
        box_max: Stored as ``self.box_max``.
        model_path: Stored as ``self.model_path``.
    """
    # Plain configuration.
    self.device = device
    self.model = model
    self.model_num_labels = model_num_labels
    self.model_path = model_path
    self.box_min = box_min
    self.box_max = box_max

    # All three channel counts are set to the same value, image_nc.
    self.input_nc = image_nc
    self.output_nc = image_nc
    self.gen_input_nc = image_nc

    # Networks: generator first, then discriminator.
    self.netG = models.Generator(self.gen_input_nc, image_nc).to(device)
    self.netDisc = models.Discriminator(image_nc).to(device)

    # Initialize all weights.
    for network in (self.netG, self.netDisc):
        network.apply(weights_init)

    # One Adam optimizer per network, both with the same learning rate.
    learning_rate = 0.001
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=learning_rate)
    self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                        lr=learning_rate)
開發者ID:PacktPublishing,項目名稱:Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x,代碼行數:33,代碼來源:advGAN.py