當前位置: 首頁>>代碼示例>>Python>>正文


Python models.Generator方法代碼示例

本文整理匯總了Python中models.Generator方法的典型用法代碼示例。如果您正苦於以下問題:Python models.Generator方法的具體用法?Python models.Generator怎麽用?Python models.Generator使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在models的用法示例。


在下文中一共展示了models.Generator方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: train

# 需要導入模塊: import models [as 別名]
# 或者: from models import Generator [as 別名]
def train(args):
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs)
    train, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train -= 1.0
    train_iter = iterators.SerialIterator(train, batch_size)

    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, args.nz),
                                 batch_size)

    optimizer_generator = optimizers.RMSprop(lr=0.00005)
    optimizer_critic = optimizers.RMSprop(lr=0.00005)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu)

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'critic/loss',
            'critic/loss/real', 'critic/loss/fake', 'generator/loss']))
    trainer.run() 
開發者ID:hvy,項目名稱:chainer-wasserstein-gan,代碼行數:35,代碼來源:train.py

示例2: build_models

# 需要導入模塊: import models [as 別名]
# 或者: from models import Generator [as 別名]
def build_models(hps, current_res_w, use_ema_sampling=False, num_classes=None, label_list=None): # todo: fix num_classes
    mapping_network = MappingNetwork() if hps.do_mapping_network else None
    gen_model = Generator(current_res_w, hps.res_w, use_pixel_norm=hps.do_pixel_norm,
                          start_shape=(hps.start_res_h, hps.start_res_w),
                          equalized_lr=hps.do_equalized_lr,
                          traditional_input=hps.do_traditional_input,
                          add_noise=hps.do_add_noise,
                          resize_method=hps.resize_method,
                          use_mapping_network=hps.do_mapping_network,
                          cond_layers=hps.cond_layers,
                          map_cond=hps.map_cond)
    dis_model = Discriminator(current_res_w, equalized_lr=hps.do_equalized_lr,
                              do_minibatch_stddev=hps.do_minibatch_stddev,
                              end_shape=(hps.start_res_h, hps.start_res_w),
                              resize_method=hps.resize_method, cgan_nclasses=num_classes,
                              label_list=label_list)
    if use_ema_sampling:
        sampling_model = Generator(current_res_w, hps.res_w, use_pixel_norm=hps.do_pixel_norm,
                                   start_shape=(hps.start_res_h, hps.start_res_w),
                                   equalized_lr=hps.do_equalized_lr,
                                   traditional_input=hps.do_traditional_input,
                                   add_noise=hps.do_add_noise,
                                   resize_method=hps.resize_method,
                                   use_mapping_network=hps.do_mapping_network,
                                   cond_layers=hps.cond_layers,
                                   map_cond=hps.map_cond)
        return gen_model, mapping_network, dis_model, sampling_model
    else:
        return gen_model, mapping_network, dis_model 
開發者ID:nolan-dev,項目名稱:stylegan_reimplementation,代碼行數:31,代碼來源:train.py

示例3: __init__

# 需要導入模塊: import models [as 別名]
# 或者: from models import Generator [as 別名]
def __init__(self,
                 device,
                 model,
                 model_num_labels,
                 image_nc,
                 box_min,
                 box_max,
                 model_path):
        output_nc = image_nc
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc
        self.box_min = box_min
        self.box_max = box_max
        self.model_path = model_path

        self.gen_input_nc = image_nc
        self.netG = models.Generator(self.gen_input_nc, image_nc).to(device)
        self.netDisc = models.Discriminator(image_nc).to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=0.001)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                            lr=0.001) 
開發者ID:PacktPublishing,項目名稱:Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x,代碼行數:33,代碼來源:advGAN.py


注:本文中的models.Generator方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。