本文整理汇总了Python中torchvision.datasets.LSUN类的典型用法代码示例。如果您正苦于以下问题：Python datasets.LSUN类的具体用法？Python datasets.LSUN怎么用？Python datasets.LSUN使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.datasets的用法示例。
在下文中一共展示了datasets.LSUN类的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_lsun_dataloader
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    Parameters
    ----------
    path_to_data : str
        Root directory of the LSUN database files.
    dataset : str
        LSUN category/split to load, e.g. 'bedroom_val' or 'bedroom_train'.
    batch_size : int
        Number of images per batch.

    Returns
    -------
    DataLoader
        A DataLoader over the selected category with shuffling enabled.
    """
    # Resize the shorter side to 128, then take the central 128x128 square.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ])
    # The root is passed positionally on purpose: old torchvision releases
    # named this argument `db_path`, current ones name it `root`. The keyword
    # `db_path=` raises TypeError on current torchvision.
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset],
                              transform=transform)
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
示例2: get_lsun_dataloader
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_lsun_dataloader(path_to_data='/data/dgl/LSUN', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    Parameters
    ----------
    path_to_data : str
        Root directory of the LSUN database files.
    dataset : str
        LSUN category/split to load, e.g. 'bedroom_val' or 'bedroom_train'.
    batch_size : int
        Number of images per batch.

    Returns
    -------
    DataLoader
        A DataLoader over the selected category with shuffling enabled.
    """
    # Resize the shorter side to 128, then take the central 128x128 square.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ])
    # Root passed positionally so the call works whether the installed
    # torchvision names the argument `root` (current) or `db_path` (old).
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset],
                              transform=transform)
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
示例3: load_lsun
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_lsun(self, classes=None):
    """Build an LSUN dataset rooted at ``self.path``.

    Args:
        classes: list of LSUN category names to load. When omitted, defaults
            to ``['church_outdoor_train', 'classroom_train']``.

    Returns:
        The ``dsets.LSUN`` dataset, using the transform from
        ``self.transform(True, True, True, False)``.
    """
    if classes is None:
        # Build a fresh list per call. A list literal as the default value
        # would be a single object shared by every call.
        classes = ['church_outdoor_train', 'classroom_train']
    # Named `transform` (not `transforms`) so the torchvision `transforms`
    # module name is not shadowed inside this method.
    # NOTE(review): the meaning of the four boolean flags is defined by
    # self.transform elsewhere; confirm before changing them.
    transform = self.transform(True, True, True, False)
    return dsets.LSUN(self.path, classes=classes, transform=transform)
示例4: make_dataloader
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args=None,
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Build a dataset of type ``dataset_type`` and wrap it in a DataLoader.

    Args:
        batch_size: images per batch.
        dataset_type: 'folder', 'imagenet' or 'lfw' (ImageFolder), 'lsun'
            (the 'bedroom_train' category), 'cifar10' (downloaded when
            ``data_path`` is missing) or 'fake' (torchvision FakeData).
        data_path: dataset root directory.
        shuffle, drop_last: forwarded to the DataLoader.
        dataloader_args: extra keyword arguments for the DataLoader;
            ``None`` means no extra arguments.
        resize, imsize, centercrop, centercrop_size, totensor, normalize,
        norm_mean, norm_std: forwarded to ``make_transform``.

    Returns:
        ``(dataloader, num_of_classes)``.

    Raises:
        ValueError: if ``dataset_type`` is not supported.
        AssertionError: if a required ``data_path`` does not exist, or the
            dataset is empty.
    """
    # Fail fast on an unknown type. Previously such a value fell through
    # every branch and crashed later with an UnboundLocalError on `dataset`.
    supported = ('folder', 'imagenet', 'lfw', 'lsun', 'cifar10', 'fake')
    if dataset_type not in supported:
        raise ValueError("Unsupported dataset_type: {!r} (expected one of {})".format(dataset_type, supported))
    if dataloader_args is None:
        dataloader_args = {}
    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ('folder', 'imagenet', 'lfw'):
        assert os.path.exists(data_path), "data_path does not exist! Given: {}".format(data_path)
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: {}".format(data_path)
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    else:  # 'fake'
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    assert len(dataset) > 0, "dataset is empty"
    # FakeData exposes `num_classes` but no `classes` list, so fall back to
    # integer labels instead of crashing with AttributeError.
    classes = getattr(dataset, 'classes', None)
    if classes is None:
        classes = list(range(dataset.num_classes))
    num_of_classes = len(classes)
    print("Data found! # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", classes)
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                             drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
示例5: __init__
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def __init__(self, options):
    """Create ``self.dataloader`` for ``options.dataset`` and an iterator over it.

    Reads these fields from ``options``: image_size, image_colors, dataset,
    data_dir, image_class, batch_size, loader_workers, pin_memory.
    Unrecognised ``options.dataset`` values fall back to an ImageFolder
    rooted at ``options.data_dir``.
    """
    steps = []
    if options.image_size is not None:
        edge = options.image_size
        steps.append(transforms.Resize((edge, edge)))
    steps.append(transforms.ToTensor())

    # Per-channel mean/std of 0.5 for 1- or 3-channel images. Any other
    # channel count is left un-normalized.
    channel_stats = None
    if options.image_colors == 1:
        channel_stats = [0.5]
    elif options.image_colors == 3:
        channel_stats = [0.5] * 3
    if channel_stats is not None:
        steps.append(transforms.Normalize(mean=channel_stats, std=list(channel_stats)))
    transform = transforms.Compose(steps)

    kind = options.dataset
    root = options.data_dir
    # Datasets sharing the (root, train=True, download=True) call shape,
    # looked up lazily by attribute name.
    downloadable = {
        'mnist': 'MNIST',
        'fashion-mnist': 'FashionMNIST',
        'cifar10': 'CIFAR10',
        'cifar100': 'CIFAR100',
    }
    if kind in downloadable:
        dataset_cls = getattr(datasets, downloadable[kind])
        dataset = dataset_cls(root, train=True, download=True, transform=transform)
    elif kind == 'emnist':
        # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
        # NOTE(review): this overrides the class attribute process-wide;
        # confirm the mirror is still reachable.
        datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
        dataset = datasets.EMNIST(root, split=options.image_class, train=True, download=True, transform=transform)
    elif kind == 'lsun':
        dataset = datasets.LSUN(root, classes=[options.image_class + '_train'], transform=transform)
    else:
        dataset = datasets.ImageFolder(root=root, transform=transform)

    self.dataloader = DataLoader(
        dataset,
        batch_size=options.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=options.loader_workers,
        pin_memory=options.pin_memory,
    )
    self.iterator = iter(self.dataloader)
示例6: load_lsun
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_lsun(self, classes='church_outdoor_train'):
    """Build an LSUN dataset rooted at ``self.path``.

    Args:
        classes: either a single LSUN category name (e.g.
            ``'church_outdoor_train'``) or an iterable of category names.

    Returns:
        The ``dsets.LSUN`` dataset, using the transform from
        ``self.transform(True, True, True, False)``.
    """
    # A bare string names one category. Wrapping an already-iterable value
    # would hand LSUN a nested list, which it cannot parse.
    if isinstance(classes, str):
        class_list = [classes]
    else:
        class_list = list(classes)
    # Named `transform` so the torchvision `transforms` module is not shadowed.
    # NOTE(review): flag semantics are defined by self.transform elsewhere.
    transform = self.transform(True, True, True, False)
    return dsets.LSUN(self.path, classes=class_list, transform=transform)
示例7: get_dataset
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_dataset(name, data_dir, size=64, lsun_categories=None):
    """Build a training dataset and report its number of labels.

    Images are resized (shorter side to ``size``), center-cropped to
    ``size`` x ``size`` and randomly flipped horizontally. They are then
    normalized with mean/std 0.5 per channel, and uniform noise in
    ``[0, 1/128)`` is added to every value.

    Args:
        name: one of 'image', 'npy', 'cifar10', 'lsun', 'lsun_class'.
        data_dir: dataset root directory.
        size: output image side length.
        lsun_categories: forwarded as LSUN's ``classes`` argument; defaults
            to ``'train'``.

    Returns:
        ``(dataset, nlabels)``.

    Raises:
        NotImplementedError: if ``name`` is not supported.
    """
    supported = ('image', 'npy', 'cifar10', 'lsun', 'lsun_class')
    if name not in supported:
        # The original `raise NotImplemented` itself failed with TypeError:
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError("Unsupported dataset name: {!r}".format(name))
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
    ])
    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now
        # NOTE(review): no transform is passed here, so `transform` above
        # is not applied to .npy samples.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir, train=True, download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    else:  # 'lsun_class'
        # Every sample gets label 0.
        dataset = datasets.LSUNClass(data_dir, transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    return dataset, nlabels
示例8: get_dataset
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_dataset(args):
    """Build the training set selected by ``args.data`` and its data shape.

    Args:
        args: namespace with ``data`` ('mnist', 'cifar10' or 'lsun_church'),
            ``imagesize`` (``None`` selects the per-dataset default) and
            ``conv`` (whether the model consumes images or flat vectors).

    Returns:
        ``(train_set, data_shape)``. ``data_shape`` is ``(C, H, W)`` when
        ``args.conv`` is true, otherwise ``(C * H * W,)``.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset.
    """
    def trans(im_size):
        # Resize -> tensor -> add_noise (add_noise is defined elsewhere).
        return tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root="./data", train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
                add_noise,
            ]), download=True
        )
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
                tforms.ToTensor(),
                add_noise,
            ])
        )
    else:
        # Previously an unknown name fell through and crashed with
        # UnboundLocalError on `im_dim`/`train_set`.
        raise ValueError("Unsupported dataset: {!r}".format(args.data))
    data_shape = (im_dim, im_size, im_size)
    if not args.conv:
        data_shape = (im_dim * im_size * im_size,)
    return train_set, data_shape
示例9: load_data
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_data(image_data_type, path_to_folder, data_transform, batch_size, classes=None, num_workers=5):
    """Return a shuffled DataLoader over an LSUN database or an image folder.

    Args:
        image_data_type: 'lsun' or 'image_folder'.
        path_to_folder: dataset root directory.
        data_transform: transform applied to every image.
        batch_size: images per batch (incomplete final batch is dropped).
        classes: LSUN ``classes`` argument. ``None`` means LSUN's own default
            ``'train'``. Ignored for 'image_folder'.
        num_workers: DataLoader worker processes.

    Raises:
        ValueError: for an unknown ``image_data_type``.
    """
    # torch issue
    # https://github.com/pytorch/pytorch/issues/22866
    # NOTE: set_num_threads changes the thread count for the whole process.
    torch.set_num_threads(1)
    if image_data_type == 'lsun':
        # LSUN rejects classes=None (it expects a str or an iterable of
        # names), so map the default to LSUN's own default split.
        lsun_classes = 'train' if classes is None else classes
        dataset = datasets.LSUN(path_to_folder, classes=lsun_classes, transform=data_transform)
    elif image_data_type == "image_folder":
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    else:
        raise ValueError("Invalid image data type")
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                                 num_workers=num_workers, drop_last=True, pin_memory=True)
    return dataset_loader
示例10: load_data
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_data(path_to_folder, classes):
    """Return a shuffled DataLoader of 64x64 images normalized to mean/std 0.5.

    Uses LSUN when the module-level ``IMAGE_DATA_SET`` is ``'lsun'``, and an
    ImageFolder otherwise. The batch size comes from the module-level
    ``BATCH_SIZE``.

    Args:
        path_to_folder: dataset root directory.
        classes: LSUN ``classes`` argument (ignored for ImageFolder).
    """
    data_transform = transforms.Compose([
        # `transforms.Scale` was a deprecated alias of `Resize` and is no
        # longer provided by newer torchvision; `Resize` behaves the same.
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    if IMAGE_DATA_SET == 'lsun':
        dataset = datasets.LSUN(path_to_folder, classes=classes, transform=data_transform)
    else:
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                                                 num_workers=5, drop_last=True, pin_memory=True)
    return dataset_loader
示例11: check_dataset
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def check_dataset(dataset, dataroot):
    """Instantiate the torchvision dataset named ``dataset``.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details.
        dataroot (str): Root directory where the dataset is (or will be) stored.

    Returns:
        tuple: ``(dataset, nc)``, the torchvision Dataset object and its
        number of image channels.

    Raises:
        RuntimeError: if ``dataset`` is not a recognised name.
    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    if dataset in {"imagenet", "folder", "lfw"}:
        data = dset.ImageFolder(root=dataroot, transform=transforms.Compose([resize, crop, to_tensor, normalize]))
        nc = 3
    elif dataset == "lsun":
        data = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=transforms.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3
    elif dataset == "cifar10":
        data = dset.CIFAR10(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize])
        )
        nc = 3
    elif dataset == "mnist":
        # MNIST tensors have one channel. A 3-value Normalize fails on them
        # in current torchvision (shape-mismatch error at load time), so the
        # single channel is normalized explicitly.
        gray_normalize = transforms.Normalize((0.5,), (0.5,))
        data = dset.MNIST(root=dataroot, download=True,
                          transform=transforms.Compose([resize, to_tensor, gray_normalize]))
        nc = 1
    elif dataset == "fake":
        data = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3
    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))
    return data, nc
示例12: __getDataSet
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def __getDataSet(opt):
    """Build the torchvision dataset selected by ``opt.dataset``.

    Side effects on ``opt``: for 'mnist' it sets ``opt.nc = 1`` and
    ``opt.imageSize = 32``. For 'cifar10'/'mnist' with ``opt.load_dict`` set,
    it points ``opt.netD``/``opt.netG`` at the NETD_*/NETG_* constants.

    Returns:
        The dataset, or ``None`` when ``opt.dataset`` is not recognised.
    """
    if isDebug:
        print(f"Getting dataset: {opt.dataset} ... ")
    dataset = None
    # `transforms.Resize` replaces the deprecated alias `transforms.Scale`,
    # which newer torchvision releases no longer provide.
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        # Root passed positionally: old torchvision named it `db_path`,
        # current torchvision names it `root`.
        dataset = dset.LSUN(opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
        # Load pre-trained state dict
        if opt.load_dict:
            opt.netD = NETD_CIFAR10
            opt.netG = NETG_CIFAR10
    elif opt.dataset == 'mnist':
        opt.nc = 1
        opt.imageSize = 32
        dataset = dset.MNIST(root=opt.dataroot, download=True, transform=transforms.Compose([
            transforms.Resize(opt.imageSize),
            transforms.ToTensor(),
            # Single-channel stats: MNIST tensors have one channel, and a
            # 3-value Normalize fails on them in current torchvision.
            transforms.Normalize((0.5,), (0.5,)),
        ]))
        # Update opt params for mnist
        if opt.load_dict:
            opt.netD = NETD_MNIST
            opt.netG = NETG_MNIST
    return dataset
示例13: get_data_loader
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_data_loader(dataset, dataroot, workers, image_size, batch_size):
    """Return a shuffled DataLoader for the dataset named ``dataset``.

    Args:
        dataset: 'imagenet', 'folder' or 'lfw' (ImageFolder), 'lsun'
            ('bedroom_train' category), 'cifar10', 'mnist' or 'fake'.
        dataroot: dataset root directory (CIFAR-10 and MNIST are created
            with ``download=True``).
        workers: number of DataLoader worker processes.
        image_size: target image side length.
        batch_size: images per batch.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        data = dset.ImageFolder(root=dataroot,
                                transform=transforms.Compose([
                                    transforms.Resize(image_size),
                                    transforms.CenterCrop(image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5)),
                                ]))
    elif dataset == 'lsun':
        data = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                         transform=transforms.Compose([
                             transforms.Resize(image_size),
                             transforms.CenterCrop(image_size),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5)),
                         ]))
    elif dataset == 'cifar10':
        data = dset.CIFAR10(root=dataroot, download=True,
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif dataset == 'mnist':
        data = dset.MNIST(root=dataroot, train=True, download=True,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.CenterCrop(image_size),
                              transforms.ToTensor(),
                              # MNIST is single-channel. 3-value stats fail
                              # on 1xHxW tensors in current torchvision.
                              transforms.Normalize((0.5,), (0.5,)),
                          ]))
    elif dataset == 'fake':
        data = dset.FakeData(image_size=(3, image_size, image_size),
                             transform=transforms.ToTensor())
    else:
        # An explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would have built a DataLoader over the
        # name string itself.
        raise ValueError("Unsupported dataset: {!r}".format(dataset))
    assert len(data) > 0, "dataset is empty"
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(workers))
    return data_loader