本文整理汇总了Python中models.Generator类的典型用法代码示例。如果您正苦于以下问题:Python models.Generator的具体用法?Python models.Generator怎么用?Python models.Generator使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块models的用法示例。
在下文中一共展示了models.Generator类的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train
# 需要导入模块: import models [as 别名]
# 或者: from models import Generator [as 别名]
def train(args):
    """Run Wasserstein-GAN training on CIFAR-10 through a trainer loop.

    Args:
        args: Option object; the attributes ``nz``, ``batch_size``,
            ``epochs`` and ``gpu`` are read from it.

    Returns:
        None. The function finishes when ``trainer.run()`` returns.
    """
    noise_dim, batch, n_epochs, device_id = (
        args.nz, args.batch_size, args.epochs, args.gpu)

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs): loaded with
    # scale=2, then shifted down by one in place.
    images, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    images -= 1.0

    real_iter = iterators.SerialIterator(images, batch)
    noise_iter = RandomNoiseIterator(
        GaussianNoiseGenerator(0, 1, noise_dim), batch)

    # Both networks use RMSprop with the same small learning rate.
    gen_opt = optimizers.RMSprop(lr=0.00005)
    critic_opt = optimizers.RMSprop(lr=0.00005)
    gen_opt.setup(Generator())
    critic_opt.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=real_iter,
        noise_iterator=noise_iter,
        optimizer_generator=gen_opt,
        optimizer_critic=critic_opt,
        device=device_id,
    )

    trainer = training.Trainer(updater, stop_trigger=(n_epochs, 'epoch'))

    report_columns = [
        'epoch',
        'iteration',
        'critic/loss',
        'critic/loss/real',
        'critic/loss/fake',
        'generator/loss',
    ]
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    # Sample from the generator once per epoch.
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(report_columns))

    trainer.run()
示例2: build_models
# 需要导入模块: import models [as 别名]
# 或者: from models import Generator [as 别名]
def build_models(hps, current_res_w, use_ema_sampling=False, num_classes=None, label_list=None):  # todo: fix num_classes
    """Construct the generator, optional mapping network and discriminator.

    Args:
        hps: Hyper-parameter object; the ``do_*``, ``res_w``,
            ``start_res_h``/``start_res_w``, ``resize_method``,
            ``cond_layers`` and ``map_cond`` attributes are read from it.
        current_res_w: Forwarded as the first positional argument to both
            ``Generator`` and ``Discriminator``.
        use_ema_sampling: When true, a second ``Generator`` built with the
            same arguments is created and returned as the sampling model.
        num_classes: Forwarded to ``Discriminator`` as ``cgan_nclasses``.
        label_list: Forwarded to ``Discriminator`` as ``label_list``.

    Returns:
        ``(gen_model, mapping_network, dis_model)`` or, when
        ``use_ema_sampling`` is true,
        ``(gen_model, mapping_network, dis_model, sampling_model)``.
        ``mapping_network`` is ``None`` unless ``hps.do_mapping_network``
        is truthy.
    """
    def new_generator():
        # The training and sampling generators share one configuration.
        return Generator(
            current_res_w,
            hps.res_w,
            use_pixel_norm=hps.do_pixel_norm,
            start_shape=(hps.start_res_h, hps.start_res_w),
            equalized_lr=hps.do_equalized_lr,
            traditional_input=hps.do_traditional_input,
            add_noise=hps.do_add_noise,
            resize_method=hps.resize_method,
            use_mapping_network=hps.do_mapping_network,
            cond_layers=hps.cond_layers,
            map_cond=hps.map_cond,
        )

    mapping_network = None
    if hps.do_mapping_network:
        mapping_network = MappingNetwork()

    gen_model = new_generator()

    dis_model = Discriminator(
        current_res_w,
        equalized_lr=hps.do_equalized_lr,
        do_minibatch_stddev=hps.do_minibatch_stddev,
        end_shape=(hps.start_res_h, hps.start_res_w),
        resize_method=hps.resize_method,
        cgan_nclasses=num_classes,
        label_list=label_list,
    )

    if not use_ema_sampling:
        return gen_model, mapping_network, dis_model

    sampling_model = new_generator()
    return gen_model, mapping_network, dis_model, sampling_model
示例3: __init__
# 需要导入模块: import models [as 别名]
# 或者: from models import Generator [as 别名]
def __init__(self,
             device,
             model,
             model_num_labels,
             image_nc,
             box_min,
             box_max,
             model_path):
    """Store the configuration and build the generator/discriminator pair.

    Args:
        device: Target passed to ``.to(...)`` for both networks.
        model: Stored as ``self.model``; not used further in this method.
        model_num_labels: Stored as ``self.model_num_labels``.
        image_nc: Channel count used for the generator input, the
            generator output and the discriminator input.
        box_min: Stored as ``self.box_min`` (presumably the lower value
            bound for generated images -- confirm against the caller).
        box_max: Stored as ``self.box_max`` (presumably the matching upper
            bound -- confirm against the caller).
        model_path: Stored as ``self.model_path``.
    """
    self.device = device
    self.model_num_labels = model_num_labels
    self.model = model

    # Input, output and generator-input channel counts are all image_nc.
    self.input_nc = self.output_nc = image_nc
    self.box_min, self.box_max = box_min, box_max
    self.model_path = model_path
    self.gen_input_nc = image_nc

    self.netG = models.Generator(self.gen_input_nc, image_nc).to(device)
    self.netDisc = models.Discriminator(image_nc).to(device)

    # initialize all weights (generator first, then discriminator)
    for net in (self.netG, self.netDisc):
        net.apply(weights_init)

    # initialize optimizers; both use Adam with the same learning rate
    adam_lr = 0.001
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=adam_lr)
    self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(), lr=adam_lr)
开发者ID:PacktPublishing,项目名称:Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x,代码行数:33,代码来源:advGAN.py