当前位置: 首页>>代码示例>>Python>>正文


Python utils.save_image函数代码示例

本文整理汇总了Python中torchvision.utils.save_image函数的典型用法代码示例。如果您正苦于以下问题：Python save_image函数的具体用法？Python save_image怎么用？Python save_image使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了save_image函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

    def test(self):
        """Translate images using StarGAN trained on a single dataset.

        Restores the generator via ``self.restore_model(self.test_iters)``,
        then, for every batch of the selected loader, generates one fake
        batch per target label and writes the real batch plus all fake
        batches, concatenated along dim 3, as a single JPEG file under
        ``self.result_dir``.

        Raises:
            ValueError: If ``self.dataset`` is neither ``'CelebA'`` nor
                ``'RaFD'``.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader. An unknown dataset name used to fall through and
        # fail later with an UnboundLocalError; fail explicitly instead.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader
        else:
            raise ValueError('Unsupported dataset: {!r}'.format(self.dataset))

        with torch.no_grad():
            for batch_no, (x_real, c_org) in enumerate(data_loader, start=1):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

                # Translate images: the real batch first, then one fake batch per target label.
                x_fake_list = [x_real]
                x_fake_list.extend(self.G(x_real, c_trg) for c_trg in c_trg_list)

                # Save the translated images (file names are 1-based).
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_no))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
开发者ID:JacobLee121,项目名称:StarGAN,代码行数:28,代码来源:solver.py

示例2: test_multi

    def test_multi(self):
        """Translate images using StarGAN trained on multiple datasets.

        Iterates over ``self.celeba_loader`` and, for each batch, generates
        one fake batch per CelebA target label and one per RaFD target label.
        Each generator condition is the concatenation (dim 1) of the CelebA
        label slot, the RaFD label slot and a 2-way mask vector. The real and
        fake batches are concatenated along dim 3 and saved as one JPEG per
        batch under ``self.result_dir``.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        with torch.no_grad():
            for batch_no, (x_real, c_org) in enumerate(self.celeba_loader, start=1):
                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                n_samples = x_real.size(0)
                c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
                c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')

                # Zero label vectors stand in for the dataset that is not being targeted.
                zero_celeba = torch.zeros(n_samples, self.c_dim).to(self.device)
                zero_rafd = torch.zeros(n_samples, self.c2_dim).to(self.device)
                # Mask vectors: [1, 0] selects CelebA, [0, 1] selects RaFD.
                mask_celeba = self.label2onehot(torch.zeros(n_samples), 2).to(self.device)
                mask_rafd = self.label2onehot(torch.ones(n_samples), 2).to(self.device)

                # Build every target condition, CelebA labels first, then RaFD labels.
                targets = [torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                           for c_celeba in c_celeba_list]
                targets += [torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                            for c_rafd in c_rafd_list]

                # Translate images: the real batch followed by one fake batch per target.
                x_fake_list = [x_real] + [self.G(x_real, c_trg) for c_trg in targets]

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_no))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
开发者ID:JacobLee121,项目名称:StarGAN,代码行数:31,代码来源:solver.py

示例3: sampleTrue

def sampleTrue(dataset, imageSize, dataroot, sampleSize, batchSize, saveFolder, workers=4):
    """Save individual real images from a dataset as PNG files.

    Images are drawn from a shuffled ``DataLoader`` over
    ``make_dataset(dataset, dataroot, imageSize)``; each one is rescaled with
    ``x * 0.5 + 0.5`` and written to ``saveFolder + '0/'`` under the name
    ``giveName(k) + ".png"``, where ``k`` counts saved images from 0. Saving
    stops once ``sampleSize`` images have been written (the limit is checked
    after each save) or the loader is exhausted.
    """
    print('sampling real images ...')
    saveFolder = saveFolder + '0/'

    dataset = make_dataset(dataset, dataroot, imageSize)
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=batchSize, num_workers=int(workers))

    # Best-effort directory creation: an OSError here is deliberately ignored.
    if not os.path.exists(saveFolder):
        try:
            os.makedirs(saveFolder)
        except OSError:
            pass

    saved = 0
    for batch in dataloader:
        img, _ = batch
        for sample in img:
            target = saveFolder + giveName(saved) + ".png"
            vutils.save_image(sample.mul(0.5).add(0.5), target)
            saved += 1
            if saved >= sampleSize:
                # Quota reached: nothing follows the loops, so return directly.
                return
开发者ID:RobinROAR,项目名称:TensorflowTutorialsCode,代码行数:26,代码来源:metric.py

示例4: save_img_results

def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
                     count, image_dir, summary_writer):
    """Write real and generated sample grids to disk and to a summary writer.

    The first ``cfg.TRAIN.VIS_COUNT`` images of the last entry of
    ``imgs_tcpu`` are saved as ``real_samples.png``; for each of the first
    ``num_imgs`` entries of ``fake_imgs`` the same number of images is saved
    as ``count_<count>_fake_samples<i>.png``. Every grid is also converted to
    an HWC uint8 array and added to ``summary_writer`` at step ``count``.
    """
    vis_count = cfg.TRAIN.VIS_COUNT

    # NOTE(review): the original author states that vutils.save_image changes
    # the range of the real images to [0, 1]; the plain `* 255` below relies on
    # that. Confirm this holds for the torchvision version in use.
    real_batch = imgs_tcpu[-1][0:vis_count]
    vutils.save_image(
        real_batch, '%s/real_samples.png' % (image_dir),
        normalize=True)
    real_grid = vutils.make_grid(real_batch).numpy()
    # CHW -> HWC, then scale to 8-bit.
    real_grid = (np.transpose(real_grid, (1, 2, 0)) * 255).astype(np.uint8)
    summary_writer.add_summary(summary.image('real_img', real_grid), count)

    for idx in range(num_imgs):
        fake_batch = fake_imgs[idx][0:vis_count]
        # Per the original author, fake_batch.data is still in [-1, 1] here,
        # hence the (x + 1) * 255 / 2 mapping further down.
        vutils.save_image(
            fake_batch.data, '%s/count_%09d_fake_samples%d.png' %
            (image_dir, count, idx), normalize=True)

        fake_grid = vutils.make_grid(fake_batch.data).cpu().numpy()
        fake_grid = np.transpose(fake_grid, (1, 2, 0))
        fake_grid = ((fake_grid + 1) * 255 / 2).astype(np.uint8)

        summary_writer.add_summary(
            summary.image('fake_img%d' % idx, fake_grid), count)
        summary_writer.flush()
开发者ID:tensoralex,项目名称:StackGAN-v2,代码行数:34,代码来源:trainer.py

示例5: plot_rec

def plot_rec(x, netEC, netEP, netD):
    """Save a grid comparing two input frames with their reconstruction.

    ``x[0]`` is encoded with ``netEC`` and a randomly chosen ``x[t]``
    (``1 <= t < opt.max_step``) with ``netEP``; ``netD`` decodes the pair of
    encodings. For every sample ``i`` the triple
    ``(x[0][i], x[t][i], reconstruction[i])`` becomes one row of three images
    in ``<opt.log_dir>/rec/rec_test.png``.

    Args:
        x: Indexable sequence of batches; index 0 and a random later index
            are used.
        netEC: Callable encoder applied to ``x[0]``.
        netEP: Callable encoder applied to the randomly chosen batch.
        netD: Callable decoder taking the list ``[h_c, h_p]``.
    """
    x_c = x[0]
    x_p = x[np.random.randint(1, opt.max_step)]

    h_c = netEC(x_c)
    h_p = netEP(x_p)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '%s/rec/rec_test.png' % (opt.log_dir)

    # Interleave the three sources sample by sample. Concatenating once avoids
    # re-copying the growing result on every iteration.
    comparison = torch.cat(
        [torch.stack([x_c[i], x_p[i], rec[i]]) for i in range(len(x_c))])
    print('comparison: ', comparison.shape)

    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    save_image(comparison.cpu(), fname, nrow=3)
开发者ID:ZhenyueQin,项目名称:drnet-py,代码行数:34,代码来源:drnet_test_field.py

示例6: test

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD.

        Loads the generator weights from ``<model_save_path>/<test_model>_G.pth``,
        switches the generator to eval mode and, for each batch of the chosen
        loader, saves the real batch next to one generated batch per target
        label (concatenated along dim 3) as ``<i>_fake.png`` in
        ``self.result_path``.
        """
        # Load trained parameters
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        # Any dataset name other than 'CelebA' falls back to the RaFD loader.
        use_celeba = self.dataset == 'CelebA'
        data_loader = self.celebA_loader if use_celeba else self.rafd_loader

        for batch_no, (real_x, org_c) in enumerate(data_loader, start=1):
            real_x = self.to_var(real_x, volatile=True)

            if use_celeba:
                target_c_list = self.make_celeb_labels(org_c)
            else:
                # One one-hot target per class index 0..c_dim-1, applied to the whole batch.
                target_c_list = [
                    self.to_var(
                        self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim),
                        volatile=True)
                    for j in range(self.c_dim)
                ]

            # Start translations
            fake_image_list = [real_x] + [self.G(real_x, target_c)
                                          for target_c in target_c_list]
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.result_path, '{}_fake.png'.format(batch_no))
            save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            print('Translated test images and saved into "{}"..!'.format(save_path))
开发者ID:rafalsc,项目名称:StarGAN,代码行数:31,代码来源:solver.py

示例7: saver

 def saver(state):
     """Save inputs and their reconstructions for the first batch of an epoch.

     Does nothing unless ``state[torchbearer.BATCH]`` is 0. Otherwise the
     first ``num_images`` inputs and the first ``num_images`` reconstructions
     (reshaped to ``(-1, 1, 28, 28)``) are concatenated and written to
     ``str(folder) + 'reconstruction_<epoch>.png'`` with ``num_images``
     images per row.

     Args:
         state: torchbearer state mapping; reads the BATCH, X, Y_PRED and
             EPOCH keys.
     """
     if state[torchbearer.BATCH] != 0:
         return

     data = state[torchbearer.X]
     recon_batch = state[torchbearer.Y_PRED]
     # Let view() infer the batch dimension: the former hard-coded 128 raised
     # for any first batch that did not contain exactly 128 samples.
     comparison = torch.cat([data[:num_images],
                             recon_batch.view(-1, 1, 28, 28)[:num_images]])
     save_image(comparison.cpu(),
                str(folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png', nrow=num_images)
开发者ID:little1tow,项目名称:torchbearer,代码行数:8,代码来源:vae.py

示例8: sample_image

def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated images.

    Draws ``n_row**2`` standard-normal latent vectors of size
    ``opt.latent_dim`` and pairs them with the label sequence
    ``0..n_row-1`` repeated ``n_row`` times, so that every grid row shows one
    sample per label. The result is written to ``images/<batches_done>.png``.
    """
    # Sample noise
    noise = np.random.normal(0, 1, (n_row**2, opt.latent_dim))
    z = Variable(FloatTensor(noise))

    # Labels 0..n_row-1, repeated once per grid row
    label_ids = np.array(list(range(n_row)) * n_row)
    labels = Variable(LongTensor(label_ids))

    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
开发者ID:hjpwhu,项目名称:PyTorch-GAN,代码行数:9,代码来源:cgan.py

示例9: _train

    def _train(self, epoch):
        """Run one epoch of alternating discriminator / generator updates.

        For every batch the discriminator is updated on the real batch
        (label 1) and on a detached generated batch (label 0), then the
        generator is updated so that the discriminator scores the same
        generated batch as real. Every 100 batches, when ``self.out_f`` is
        set, samples generated from ``self.sample_noise`` are written to
        ``<out_f>/fake_samples_epoch_<epoch>.png`` (same file for the whole
        epoch, so later samples overwrite earlier ones).

        Args:
            epoch: Epoch number; used here only to name the sample image.
        """
        # put model into train mode
        self.d_model.train()
        # TODO: why?
        cp_loader = deepcopy(self.train_loader)
        progress_bar = None
        if self.verbose:
            progress_bar = tqdm(total=len(cp_loader),
                                desc='Current Epoch',
                                file=sys.stdout,
                                leave=False,
                                ncols=75,
                                position=0,
                                unit=' Batch')
        real_label = 1
        fake_label = 0
        for batch_idx, inputs in enumerate(cp_loader):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # train with real
            self.optimizer_d.zero_grad()
            inputs = inputs.to(self.device)
            batch_size = inputs.size(0)
            outputs = self.d_model(inputs)

            # dtype must be explicit: with an integer fill value, current
            # PyTorch versions infer an integer tensor, which float-valued
            # losses (e.g. BCE) reject as a target.
            label = torch.full((batch_size,), real_label,
                               dtype=torch.float, device=self.device)
            loss_d_real = self.loss_function(outputs, label)
            loss_d_real.backward()

            # train with fake; detach so no gradient reaches the generator here
            noise = torch.randn((batch_size, self.g_model.nz, 1, 1,), device=self.device)
            fake_outputs = self.g_model(noise)
            label.fill_(fake_label)
            outputs = self.d_model(fake_outputs.detach())
            loss_d_fake = self.loss_function(outputs, label)
            loss_d_fake.backward()
            self.optimizer_d.step()

            # (2) Update G network: maximize log(D(G(z)))
            self.g_model.zero_grad()
            label.fill_(real_label)
            outputs = self.d_model(fake_outputs)
            loss_g = self.loss_function(outputs, label)
            loss_g.backward()
            self.optimizer_g.step()

            if self.verbose and batch_idx % 10 == 0:
                progress_bar.update(10)
            if self.out_f is not None and batch_idx % 100 == 0:
                fake = self.g_model(self.sample_noise)
                vutils.save_image(
                    fake.detach(),
                    '%s/fake_samples_epoch_%03d.png' % (self.out_f, epoch),
                    normalize=True)
        if self.verbose:
            progress_bar.close()
开发者ID:Saiuz,项目名称:autokeras,代码行数:57,代码来源:model_trainer.py

示例10: sample_images

def sample_images(batches_done):
    """Save a grid of real and translated images from one validation batch.

    Takes the first batch of ``val_dataloader``, translates A -> B with
    ``G_AB`` and B -> A with ``G_BA``, and writes real A, fake B, real B and
    fake A (concatenated along dim 0) to
    ``images/<opt.dataset_name>/<batches_done>.png`` with 5 images per row.
    """
    batch = next(iter(val_dataloader))

    real_A = Variable(batch['A'].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(batch['B'].type(Tensor))
    fake_A = G_BA(real_B)

    panels = (real_A.data, fake_B.data, real_B.data, fake_A.data)
    img_sample = torch.cat(panels, 0)
    out_path = 'images/%s/%s.png' % (opt.dataset_name, batches_done)
    save_image(img_sample, out_path, nrow=5, normalize=True)
开发者ID:hjpwhu,项目名称:PyTorch-GAN,代码行数:10,代码来源:cyclegan.py

示例11: save_image

def save_image(img):
    """Post-process an image tensor and write it to ``<opt.outf>/transfer.png``.

    The post-processing scales values by 1/255, adds the per-channel offsets
    back (``Normalize`` with negative means and unit std), reverses the
    channel order (the original comment says this turns the image to RGB),
    and clamps the result to [0, 1] before saving.

    Args:
        img: Image tensor indexed by channel on its first axis. It is not
            modified; a copy is post-processed.
    """
    post = transforms.Compose([transforms.Lambda(lambda x: x.mul_(1./255)),
         transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], std=[1,1,1]),
         transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), #turn to RGB
         ])
    # Work on a clone: the first step scales in place (mul_) and would
    # otherwise silently rescale the caller's tensor on every call.
    img = post(img.clone())
    img = img.clamp_(0,1)
    vutils.save_image(img,
                '%s/transfer.png' % (opt.outf),
                normalize=True)
开发者ID:HadXu,项目名称:machine-learning,代码行数:11,代码来源:train.py

示例12: reconstruction_loss

    def reconstruction_loss(self, images, input, size_average=True):
        """Compute the reconstruction loss from the winning capsule only.

        For every sample, the capsule with the largest vector length (L2 norm
        over dim 2 of ``input``) is kept and all other capsules are zeroed.
        The masked capsules are flattened and decoded through
        ``reconstruct0`` -> ReLU -> ``reconstruct1`` -> ReLU ->
        ``reconstruct2`` -> sigmoid, then reshaped to
        ``(-1, image_channels, image_height, image_width)``.

        Side effects: increments ``self.reconstructed_image_count`` and, on
        every 10th call (counter divisible by 10), writes the current
        reconstructions to ``reconstruction.png``.

        Args:
            images: Target images; must be broadcast-compatible with the
                reshaped decoder output.
            input: Capsule outputs, indexed as (batch, capsule, component).
            size_average: If true, return the mean over the batch; otherwise
                return the per-sample losses.

        Returns:
            Sum of squared differences per sample scaled by 0.0005, averaged
            over the batch when ``size_average`` is true.
        """
        # Get the lengths of capsule outputs.
        v_mag = torch.sqrt((input**2).sum(dim=2))

        # Get index of longest capsule output.
        _, v_max_index = v_mag.max(dim=1)
        v_max_index = v_max_index.data

        # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
        batch_size = input.size(0)
        all_masked = [None] * batch_size
        for batch_idx in range(batch_size):
            # Get one sample from the batch.
            input_batch = input[batch_idx]

            # Copy only the maximum capsule index from this batch sample.
            # This masks out (leaves as zero) the other capsules in this sample.
            # zeros_like follows the device and dtype of the input, so this
            # no longer requires a CUDA device.
            batch_masked = torch.zeros_like(input_batch)
            winner = v_max_index[batch_idx]
            batch_masked[winner] = input_batch[winner]
            all_masked[batch_idx] = batch_masked

        # Stack masked capsules over the batch dimension.
        masked = torch.stack(all_masked, dim=0)

        # Reconstruct input image.
        masked = masked.view(input.size(0), -1)
        output = self.relu(self.reconstruct0(masked))
        output = self.relu(self.reconstruct1(output))
        output = self.sigmoid(self.reconstruct2(output))
        output = output.view(-1, self.image_channels, self.image_height, self.image_width)

        # Save reconstructed images occasionally.
        if self.reconstructed_image_count % 10 == 0:
            if output.size(1) == 2:
                # handle two-channel images: prepend an all-zero channel so
                # the image has three channels
                zeros = torch.zeros(output.size(0), 1, output.size(2), output.size(3))
                output_image = torch.cat([zeros, output.data.cpu()], dim=1)
            else:
                # assume RGB or grayscale
                output_image = output.data.cpu()
            vutils.save_image(output_image, "reconstruction.png")
        self.reconstructed_image_count += 1

        # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
        # Multiplied by a small number so it doesn't dominate the margin (class) loss.
        error = (output - images).view(output.size(0), -1)
        error = error**2
        error = torch.sum(error, dim=1) * 0.0005

        # Average over batch
        if size_average:
            error = error.mean()

        return error
开发者ID:weridmaid,项目名称:pytorch-capsule,代码行数:54,代码来源:capsule_network.py

示例13: sample_images

def sample_images(batches_done):
    """Save a grid of real and cross-domain generated images from one validation batch.

    Takes the first batch of ``val_dataloader``, encodes domain A with ``E1``
    and domain B with ``E2`` (using the second value each encoder returns),
    decodes B's code with ``G1`` and A's code with ``G2``, and writes
    X1, fake X2, X2 and fake X1 (concatenated along dim 0) to
    ``images/<opt.dataset_name>/<batches_done>.png`` with 5 images per row.
    """
    batch = next(iter(val_dataloader))

    X1 = Variable(batch['A'].type(Tensor))
    X2 = Variable(batch['B'].type(Tensor))

    # Only the second element returned by each encoder is used here.
    Z1 = E1(X1)[1]
    Z2 = E2(X2)[1]

    # Swap the codes between the generators to translate across domains.
    fake_X1 = G1(Z2)
    fake_X2 = G2(Z1)

    panels = (X1.data, fake_X2.data, X2.data, fake_X1.data)
    img_sample = torch.cat(panels, 0)
    out_path = 'images/%s/%s.png' % (opt.dataset_name, batches_done)
    save_image(img_sample, out_path, nrow=5, normalize=True)
开发者ID:hjpwhu,项目名称:PyTorch-GAN,代码行数:12,代码来源:unit.py

示例14: save_images

def save_images(netG, fixed_noise, outputDir, epoch):
    '''
    Generate a batch of images from 'fixed_noise' and save the first 64.

    The generator is switched to eval mode for the forward pass and put back
    into train mode afterwards (regardless of the mode it was in before the
    call). The samples are written, 8 per row, to
    '<outputDir>/fake_samples_epoch_<epoch>.png'.

    Inputs are the network (netG), a 'noise' input, the system path to which
    images will be saved (outputDir) and the current 'epoch'.
    '''
    noise = Variable(fixed_noise)

    netG.eval()
    fake = netG(noise)
    netG.train()

    first_samples = fake.data[:64, :, :, :]
    out_path = '%s/fake_samples_epoch_%03d.png' % (outputDir, epoch)
    vutils.save_image(first_samples, out_path, nrow=8)
开发者ID:apsvieira,项目名称:Projeto_final_ia368z,代码行数:12,代码来源:classifier_DCGAN_review.py

示例15: reconstruct_test

 def reconstruct_test(self,epoch):
     """Run the generator over the test set and save one comparison image per batch.

     For each batch of ``self.test_dataloader`` the ground-truth mask, the
     predicted mask, the ground-truth bump map and the predicted bump map are
     concatenated along dim 3 and written to
     ``<net.outpath>/<epoch>.<batch index>.jpg``.
     """
     for batch_idx, batch in enumerate(self.test_dataloader):
         images = batch['image'].float()
         bumps = batch['bump'].float()
         masks = batch['mask'].float()

         # Only the input images are moved to the GPU; targets stay on the CPU
         # and the predictions are brought back with .cpu() below.
         images = Variable(images.cuda())
         recon_mask, recon = self.Gnet.forward(images)

         panels = (masks, recon_mask.data.cpu(), bumps, recon.data.cpu())
         output = torch.cat(panels, dim=3)
         # NOTE(review): `net` is not defined in this method; presumably a
         # module-level object. Confirm it is not meant to be `self`.
         out_path = net.outpath + '/'+str(epoch)+'.'+str(batch_idx)+'.jpg'
         utils.save_image(output, out_path, nrow=4, normalize=True)
开发者ID:whztt07,项目名称:extreme_3d_faces,代码行数:12,代码来源:bumpMapRegressor.py


注：本文中的torchvision.utils.save_image函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。