This article collects and organizes typical usage examples of Python's torchvision.utils.save_image function. If you have been wondering what exactly save_image does, how to call it, and what it looks like in real code, the curated examples below should help.
Fifteen code examples of save_image are shown, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps surface better Python code samples.
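Before diving into the collected examples, here is a minimal, self-contained sketch of what torchvision.utils.save_image does: it takes a single image tensor or a 4-D mini-batch tensor of shape (B, C, H, W), arranges the batch into a grid (via make_grid), and writes the result to disk. The tensor values and file names below are placeholders chosen purely for illustration.

import torch
from torchvision.utils import save_image

# A fake mini-batch of 16 RGB images with values already in [0, 1].
batch = torch.rand(16, 3, 64, 64)

# Arrange the batch into a grid with 4 images per row and write it as a PNG.
save_image(batch, 'demo_grid.png', nrow=4)

# GAN outputs typically live in [-1, 1]; normalize=True min-max rescales
# the grid to [0, 1] before saving, so no manual denormalization is needed.
gan_like = batch * 2 - 1
save_image(gan_like, 'demo_like_gan.png', nrow=4, normalize=True)

Many of the examples below instead undo the normalization by hand (e.g. x.mul(0.5).add(0.5) or a denorm helper) and call save_image without normalize=True; both routes produce an image in the displayable [0, 1] range.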
Example 1: test
def test(self):
    """Translate images using StarGAN trained on a single dataset."""
    # Load the trained generator.
    self.restore_model(self.test_iters)

    # Set data loader.
    if self.dataset == 'CelebA':
        data_loader = self.celeba_loader
    elif self.dataset == 'RaFD':
        data_loader = self.rafd_loader

    with torch.no_grad():
        for i, (x_real, c_org) in enumerate(data_loader):

            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

            # Translate images.
            x_fake_list = [x_real]
            for c_trg in c_trg_list:
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
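The self.denorm call above is a small helper on the StarGAN solver that maps the generator's tanh output from [-1, 1] back to [0, 1] before saving, since save_image is called here without normalize=True. The snippet below is a sketch of what such a helper typically looks like; its exact body is an assumption, inferred from the .mul(0.5).add(0.5) pattern used in Example 3.

def denorm(self, x):
    """Convert tensor values from the [-1, 1] range to [0, 1] (assumed helper)."""
    out = (x + 1) / 2
    return out.clamp_(0, 1)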
Example 2: test_multi
def test_multi(self):
    """Translate images using StarGAN trained on multiple datasets."""
    # Load the trained generator.
    self.restore_model(self.test_iters)

    with torch.no_grad():
        for i, (x_real, c_org) in enumerate(self.celeba_loader):

            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
            c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
            zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
            zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
            mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
            mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device)     # Mask vector: [0, 1].

            # Translate images.
            x_fake_list = [x_real]
            for c_celeba in c_celeba_list:
                c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                x_fake_list.append(self.G(x_real, c_trg))
            for c_rafd in c_rafd_list:
                c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
Example 3: sampleTrue
def sampleTrue(dataset, imageSize, dataroot, sampleSize, batchSize, saveFolder, workers=4):
    print('sampling real images ...')
    saveFolder = saveFolder + '0/'
    dataset = make_dataset(dataset, dataroot, imageSize)
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=batchSize, num_workers=int(workers))

    if not os.path.exists(saveFolder):
        try:
            os.makedirs(saveFolder)
        except OSError:
            pass

    iter = 0
    for i, data in enumerate(dataloader, 0):
        img, _ = data
        for j in range(0, len(img)):
            vutils.save_image(img[j].mul(0.5).add(0.5),
                              saveFolder + giveName(iter) + ".png")
            iter += 1
            if iter >= sampleSize:
                break
        if iter >= sampleSize:
            break
Example 4: save_img_results
def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
                     count, image_dir, summary_writer):
    num = cfg.TRAIN.VIS_COUNT

    # The range of real_img (i.e., self.imgs_tcpu[i][0:num])
    # is changed to [0, 1] by function vutils.save_image
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(
        real_img, '%s/real_samples.png' % (image_dir),
        normalize=True)
    real_img_set = vutils.make_grid(real_img).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    summary_writer.add_summary(sup_real_img, count)

    for i in range(num_imgs):
        fake_img = fake_imgs[i][0:num]
        # The range of fake_img.data (i.e., self.fake_imgs[i][0:num])
        # is still [-1, 1]...
        vutils.save_image(
            fake_img.data, '%s/count_%09d_fake_samples%d.png' %
            (image_dir, count, i), normalize=True)

        fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()
        fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
        fake_img_set = (fake_img_set + 1) * 255 / 2
        fake_img_set = fake_img_set.astype(np.uint8)

        sup_fake_img = summary.image('fake_img%d' % i, fake_img_set)
        summary_writer.add_summary(sup_fake_img, count)
        summary_writer.flush()
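The two comments in this example hinge on the normalize=True flag: when it is set, save_image (through make_grid) shifts and scales the image so that its minimum maps to 0 and its maximum to 1, which is why tensors still in [-1, 1] can be written directly. A small sketch of the two equivalent routes, with illustrative file names:

import torch
from torchvision.utils import save_image

fake = torch.tanh(torch.randn(8, 3, 64, 64))  # stand-in for generator output in [-1, 1]

# Route 1: let save_image min-max rescale the grid to [0, 1].
save_image(fake, 'fake_normalized.png', normalize=True)

# Route 2: undo the [-1, 1] normalization by hand, as Example 3 does.
save_image(fake.mul(0.5).add(0.5).clamp(0, 1), 'fake_manual.png')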
Example 5: plot_rec
def plot_rec(x, netEC, netEP, netD):
    x_c = x[0]
    x_p = x[np.random.randint(1, opt.max_step)]

    h_c = netEC(x_c)
    h_p = netEP(x_p)
    # print('h_c shape: ', h_c.shape)
    # print('h_p shape: ', h_p.shape)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '%s/rec/rec_test.png' % (opt.log_dir)

    comparison = None
    for i in range(len(x_c)):
        if comparison is None:
            comparison = torch.stack([x_c[i], x_p[i], rec[i]])
        else:
            new_comparison = torch.stack([x_c[i], x_p[i], rec[i]])
            comparison = torch.cat([comparison, new_comparison])
    print('comparison: ', comparison.shape)

    # row_sz = 5
    # nplot = 20
    # for i in range(0, nplot - row_sz, row_sz):
    #     row = [[xc, xp, xr] for xc, xp, xr in zip(x_c[i:i + row_sz], x_p[i:i + row_sz], rec[i:i + row_sz])]
    #     print('row: ', row)
    #     to_plot.append(list(itertools.chain(*row)))
    # print(len(to_plot[0]))
    # utils.save_tensors_image(fname, comparison)

    if not os.path.exists(os.path.dirname(fname)):
        os.makedirs(os.path.dirname(fname))
    save_image(comparison.cpu(), fname, nrow=3)
Example 6: test
def test(self):
    """Facial attribute transfer on CelebA or facial expression synthesis on RaFD."""
    # Load trained parameters
    G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
    self.G.load_state_dict(torch.load(G_path))
    self.G.eval()

    if self.dataset == 'CelebA':
        data_loader = self.celebA_loader
    else:
        data_loader = self.rafd_loader

    for i, (real_x, org_c) in enumerate(data_loader):
        real_x = self.to_var(real_x, volatile=True)

        if self.dataset == 'CelebA':
            target_c_list = self.make_celeb_labels(org_c)
        else:
            target_c_list = []
            for j in range(self.c_dim):
                target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim)
                target_c_list.append(self.to_var(target_c, volatile=True))

        # Start translations
        fake_image_list = [real_x]
        for target_c in target_c_list:
            fake_image_list.append(self.G(real_x, target_c))
        fake_images = torch.cat(fake_image_list, dim=3)
        save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1))
        save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
        print('Translated test images and saved into "{}"..!'.format(save_path))
Example 7: saver
def saver(state):
    if state[torchbearer.BATCH] == 0:
        data = state[torchbearer.X]
        recon_batch = state[torchbearer.Y_PRED]
        comparison = torch.cat([data[:num_images],
                                recon_batch.view(128, 1, 28, 28)[:num_images]])
        save_image(comparison.cpu(),
                   str(folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png',
                   nrow=num_images)
Example 8: sample_image
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
Example 9: _train
def _train(self, epoch):
    """Perform the actual train."""
    # put model into train mode
    self.d_model.train()

    # TODO: why?
    cp_loader = deepcopy(self.train_loader)

    if self.verbose:
        progress_bar = tqdm(total=len(cp_loader),
                            desc='Current Epoch',
                            file=sys.stdout,
                            leave=False,
                            ncols=75,
                            position=0,
                            unit=' Batch')
    else:
        progress_bar = None

    real_label = 1
    fake_label = 0

    for batch_idx, inputs in enumerate(cp_loader):
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # train with real
        self.optimizer_d.zero_grad()
        inputs = inputs.to(self.device)
        batch_size = inputs.size(0)
        outputs = self.d_model(inputs)

        label = torch.full((batch_size,), real_label, device=self.device)
        loss_d_real = self.loss_function(outputs, label)
        loss_d_real.backward()

        # train with fake
        noise = torch.randn((batch_size, self.g_model.nz, 1, 1,), device=self.device)
        fake_outputs = self.g_model(noise)
        label.fill_(fake_label)
        outputs = self.d_model(fake_outputs.detach())
        loss_g_fake = self.loss_function(outputs, label)
        loss_g_fake.backward()
        self.optimizer_d.step()

        # (2) Update G network: maximize log(D(G(z)))
        self.g_model.zero_grad()
        label.fill_(real_label)
        outputs = self.d_model(fake_outputs)
        loss_g = self.loss_function(outputs, label)
        loss_g.backward()
        self.optimizer_g.step()

        if self.verbose:
            if batch_idx % 10 == 0:
                progress_bar.update(10)

        if self.out_f is not None and batch_idx % 100 == 0:
            fake = self.g_model(self.sample_noise)
            vutils.save_image(
                fake.detach(),
                '%s/fake_samples_epoch_%03d.png' % (self.out_f, epoch),
                normalize=True)

    if self.verbose:
        progress_bar.close()
Example 10: sample_images
def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    real_A = Variable(imgs['A'].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs['B'].type(Tensor))
    fake_A = G_BA(real_B)
    img_sample = torch.cat((real_A.data, fake_B.data,
                            real_B.data, fake_A.data), 0)
    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
Example 11: save_image
def save_image(img):
    post = transforms.Compose([transforms.Lambda(lambda x: x.mul_(1./255)),
                               transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], std=[1, 1, 1]),
                               transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to RGB
                               ])
    img = post(img)
    img = img.clamp_(0, 1)
    vutils.save_image(img,
                      '%s/transfer.png' % (opt.outf),
                      normalize=True)
    return
Example 12: reconstruction_loss
def reconstruction_loss(self, images, input, size_average=True):
    # Get the lengths of capsule outputs.
    v_mag = torch.sqrt((input**2).sum(dim=2))

    # Get index of longest capsule output.
    _, v_max_index = v_mag.max(dim=1)
    v_max_index = v_max_index.data

    # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
    batch_size = input.size(0)
    all_masked = [None] * batch_size
    for batch_idx in range(batch_size):
        # Get one sample from the batch.
        input_batch = input[batch_idx]

        # Copy only the maximum capsule index from this batch sample.
        # This masks out (leaves as zero) the other capsules in this sample.
        batch_masked = Variable(torch.zeros(input_batch.size())).cuda()
        batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
        all_masked[batch_idx] = batch_masked

    # Stack masked capsules over the batch dimension.
    masked = torch.stack(all_masked, dim=0)

    # Reconstruct input image.
    masked = masked.view(input.size(0), -1)
    output = self.relu(self.reconstruct0(masked))
    output = self.relu(self.reconstruct1(output))
    output = self.sigmoid(self.reconstruct2(output))
    output = output.view(-1, self.image_channels, self.image_height, self.image_width)

    # Save reconstructed images occasionally.
    if self.reconstructed_image_count % 10 == 0:
        if output.size(1) == 2:
            # handle two-channel images
            zeros = torch.zeros(output.size(0), 1, output.size(2), output.size(3))
            output_image = torch.cat([zeros, output.data.cpu()], dim=1)
        else:
            # assume RGB or grayscale
            output_image = output.data.cpu()
        vutils.save_image(output_image, "reconstruction.png")
    self.reconstructed_image_count += 1

    # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
    # Multiplied by a small number so it doesn't dominate the margin (class) loss.
    error = (output - images).view(output.size(0), -1)
    error = error**2
    error = torch.sum(error, dim=1) * 0.0005

    # Average over batch
    if size_average:
        error = error.mean()

    return error
Example 13: sample_images
def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    X1 = Variable(imgs['A'].type(Tensor))
    X2 = Variable(imgs['B'].type(Tensor))
    _, Z1 = E1(X1)
    _, Z2 = E2(X2)
    fake_X1 = G1(Z2)
    fake_X2 = G2(Z1)
    img_sample = torch.cat((X1.data, fake_X2.data,
                            X2.data, fake_X1.data), 0)
    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
Example 14: save_images
def save_images(netG, fixed_noise, outputDir, epoch):
    '''
    Generates a batch of images from the given noise.
    Saves 64 of the generated samples to the 'outputDir' system path.
    Inputs are the network (netG), a fixed 'noise' input, the system path to which
    images will be saved (outputDir), and the current 'epoch'.
    '''
    noise = Variable(fixed_noise)
    netG.eval()
    fake = netG(noise)
    netG.train()
    vutils.save_image(
        fake.data[0:64, :, :, :], '%s/fake_samples_epoch_%03d.png' % (outputDir, epoch), nrow=8)
Example 15: reconstruct_test
def reconstruct_test(self, epoch):
    for i, batch in enumerate(self.test_dataloader):
        # Fetch the image, bump map, and mask tensors for this batch.
        images = batch['image'].float()
        bumps = batch['bump'].float()
        masks = batch['mask'].float()
        images = Variable(images.cuda())

        # Reconstruct the mask and bump map from the input image.
        recon_mask, recon = self.Gnet.forward(images)

        # Concatenate ground truth and reconstructions side by side and save the grid.
        output = torch.cat((masks, recon_mask.data.cpu(), bumps, recon.data.cpu()), dim=3)
        utils.save_image(output, net.outpath + '/' + str(epoch) + '.' + str(i) + '.jpg', nrow=4, normalize=True)